"examples/basics/vscode:/vscode.git/clone" did not exist on "6720dfb65820785bafe5ad9745f9a1a430fe36eb"
Unverified Commit 008683d6 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: adding kvbm-engine (#6773)


Signed-off-by: default avatarRyan Olson <rolson@nvidia.com>
parent cf79c4fc
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! S3-based distributed lock manager implementation.
//!
//! This module provides [`S3LockManager`], an implementation of [`ObjectLockManager`]
//! that uses S3 conditional PUT operations for atomic lock acquisition.
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use futures::future::BoxFuture;
use super::S3ObjectBlockClient;
use crate::SequenceHash;
use crate::object::{LockFileContent, ObjectLockManager};
/// S3-based implementation of [`ObjectLockManager`].
///
/// Uses conditional PUT (If-None-Match: *) for atomic lock acquisition.
/// Lock files contain instance_id and deadline; stale locks (past deadline)
/// can be overwritten.
///
/// # Lock File Format
///
/// Lock files are stored at `{hash}.lock` as JSON:
/// ```json
/// {
/// "instance_id": "uuid-of-leader-instance",
/// "acquired_at": "2025-12-14T10:30:00Z",
/// "deadline": "2025-12-14T10:35:00Z"
/// }
/// ```
///
/// # Meta File Format
///
/// Meta files are stored at `{hash}.meta` as empty objects (presence-only).
pub struct S3LockManager {
client: Arc<S3ObjectBlockClient>,
instance_id: String,
lock_timeout: Duration,
}
impl S3LockManager {
/// Default lock timeout: 300 seconds (5 minutes).
pub const DEFAULT_LOCK_TIMEOUT: Duration = Duration::from_secs(300);
/// Create a new S3 lock manager.
///
/// # Arguments
/// * `client` - S3 client for object operations
/// * `instance_id` - Unique identifier for this instance (e.g., UUID)
pub fn new(client: Arc<S3ObjectBlockClient>, instance_id: String) -> Self {
Self {
client,
instance_id,
lock_timeout: Self::DEFAULT_LOCK_TIMEOUT,
}
}
/// Create with a custom lock timeout.
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.lock_timeout = timeout;
self
}
/// Format the lock key for a given hash.
fn lock_key(&self, hash: &SequenceHash) -> String {
format!("{}.lock", hash)
}
/// Format the meta key for a given hash.
fn meta_key(&self, hash: &SequenceHash) -> String {
format!("{}.meta", hash)
}
/// Create lock file content with current timestamp.
fn create_lock_content(&self) -> LockFileContent {
let now = chrono::Utc::now();
let deadline = now + chrono::Duration::from_std(self.lock_timeout).unwrap_or_default();
LockFileContent {
instance_id: self.instance_id.clone(),
acquired_at: now.to_rfc3339(),
deadline: deadline.to_rfc3339(),
}
}
/// Check if a lock's deadline has been breached.
fn is_lock_expired(lock: &LockFileContent) -> bool {
if let Ok(deadline) = chrono::DateTime::parse_from_rfc3339(&lock.deadline) {
let now = chrono::Utc::now();
now > deadline.with_timezone(&chrono::Utc)
} else {
// If we can't parse the deadline, consider it expired
true
}
}
}
impl ObjectLockManager for S3LockManager {
fn has_meta(&self, hash: SequenceHash) -> BoxFuture<'static, Result<bool>> {
let client = self.client.clone();
let meta_key = self.meta_key(&hash);
Box::pin(async move { client.has_object(&meta_key).await })
}
fn try_acquire_lock(&self, hash: SequenceHash) -> BoxFuture<'static, Result<bool>> {
let client = self.client.clone();
let lock_key = self.lock_key(&hash);
let lock_content = self.create_lock_content();
let our_instance_id = self.instance_id.clone();
Box::pin(async move {
// Serialize lock content
let lock_data = serde_json::to_vec(&lock_content)
.map_err(|e| anyhow::anyhow!("failed to serialize lock content: {}", e))?;
// Try conditional PUT (If-None-Match: *)
match client
.put_if_not_exists(&lock_key, bytes::Bytes::from(lock_data.clone()))
.await
{
Ok(true) => {
// Successfully acquired lock
tracing::debug!(lock_key = %lock_key, "Acquired lock");
Ok(true)
}
Ok(false) => {
// Lock exists, read it with ETag for CAS-style takeover
tracing::debug!(lock_key = %lock_key, "Lock exists, checking deadline");
match client.get_object_with_etag(&lock_key).await? {
Some((existing_data, etag)) => {
match serde_json::from_slice::<LockFileContent>(&existing_data) {
Ok(existing_lock) => {
// Check if we own the lock
if existing_lock.instance_id == our_instance_id {
tracing::debug!(lock_key = %lock_key, "We own this lock");
return Ok(true);
}
// Check if the lock is expired
if Self::is_lock_expired(&existing_lock) {
tracing::debug!(
lock_key = %lock_key,
old_instance = %existing_lock.instance_id,
deadline = %existing_lock.deadline,
"Lock expired, attempting atomic takeover"
);
// Atomically overwrite the expired lock using ETag
if let Some(etag) = etag {
let won = client
.put_object_if_match(
&lock_key,
bytes::Bytes::from(lock_data),
&etag,
)
.await?;
if !won {
tracing::debug!(
lock_key = %lock_key,
"Lost race for expired lock takeover"
);
}
Ok(won)
} else {
// No ETag available, fall back to unconditional put
tracing::warn!(
lock_key = %lock_key,
"No ETag on expired lock, falling back to unconditional overwrite"
);
client
.put_object(
&lock_key,
bytes::Bytes::from(lock_data),
)
.await?;
Ok(true)
}
} else {
tracing::debug!(
lock_key = %lock_key,
owner = %existing_lock.instance_id,
deadline = %existing_lock.deadline,
"Lock held by another instance"
);
Ok(false)
}
}
Err(e) => {
// Malformed lock file, attempt atomic overwrite
tracing::warn!(
lock_key = %lock_key,
error = %e,
"Malformed lock file, attempting atomic overwrite"
);
if let Some(etag) = etag {
let won = client
.put_object_if_match(
&lock_key,
bytes::Bytes::from(lock_data),
&etag,
)
.await?;
if !won {
tracing::debug!(
lock_key = %lock_key,
"Lost race for malformed lock takeover"
);
}
Ok(won)
} else {
tracing::warn!(
lock_key = %lock_key,
"No ETag on malformed lock, falling back to unconditional overwrite"
);
client
.put_object(&lock_key, bytes::Bytes::from(lock_data))
.await?;
Ok(true)
}
}
}
}
None => {
// Lock was deleted between checks, try to acquire again
tracing::debug!(lock_key = %lock_key, "Lock disappeared, retrying");
match client
.put_if_not_exists(&lock_key, bytes::Bytes::from(lock_data))
.await
{
Ok(created) => Ok(created),
Err(e) => Err(e),
}
}
}
}
Err(e) => Err(e),
}
})
}
fn create_meta(&self, hash: SequenceHash) -> BoxFuture<'static, Result<()>> {
let client = self.client.clone();
let meta_key = self.meta_key(&hash);
Box::pin(async move {
// Create empty meta file to mark block as offloaded
client.put_object(&meta_key, bytes::Bytes::new()).await?;
tracing::debug!(meta_key = %meta_key, "Created meta file");
Ok(())
})
}
fn release_lock(&self, hash: SequenceHash) -> BoxFuture<'static, Result<()>> {
let client = self.client.clone();
let lock_key = self.lock_key(&hash);
Box::pin(async move {
client.delete_object(&lock_key).await?;
tracing::debug!(lock_key = %lock_key, "Released lock");
Ok(())
})
}
}
#[cfg(all(test, feature = "testing-s3"))]
mod s3_integration {
use super::*;
use crate::object::s3::client::s3_integration::create_test_client;
#[tokio::test]
async fn test_lock_expired_takeover_is_atomic() {
let client = Arc::new(create_test_client("test-lock-atomic").await);
let hash = SequenceHash::new(0xDEAD_BEEF_u64, None, 0);
// Create a lock manager with an already-expired timeout (1ms)
let manager_a = S3LockManager::new(client.clone(), "instance-a".into())
.with_timeout(Duration::from_millis(1));
// Acquire the lock with instance A (it will expire almost immediately)
let acquired = manager_a.try_acquire_lock(hash).await.unwrap();
assert!(acquired, "instance A should acquire lock");
// Wait for the lock to expire
tokio::time::sleep(Duration::from_millis(50)).await;
// Now race two instances trying to take over the expired lock
let client_b = client.clone();
let client_c = client.clone();
let manager_b =
S3LockManager::new(client_b, "instance-b".into()).with_timeout(Duration::from_secs(60));
let manager_c =
S3LockManager::new(client_c, "instance-c".into()).with_timeout(Duration::from_secs(60));
let (result_b, result_c) = tokio::join!(
manager_b.try_acquire_lock(hash),
manager_c.try_acquire_lock(hash),
);
let won_b = result_b.unwrap();
let won_c = result_c.unwrap();
// At most one should win. Both could fail if timing is unlucky (B wins the
// conditional put, then C sees B's non-expired lock). The key invariant is
// that they can't BOTH win.
assert!(
!(won_b && won_c),
"both instances won the lock — race condition!"
);
// Cleanup
if won_b {
manager_b.release_lock(hash).await.unwrap();
} else if won_c {
manager_c.release_lock(hash).await.unwrap();
} else {
// Neither won, clean up the expired lock
client.delete_object(&format!("{}.lock", hash)).await.ok();
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! S3-compatible object storage implementations.
//!
//! This module contains all S3-specific implementations:
//! - [`S3ObjectBlockClient`] - Implements [`super::ObjectBlockOps`] for S3/MinIO
//! - [`S3LockManager`] - Implements [`super::ObjectLockManager`] for distributed locking
//!
//! All types in this module are feature-gated behind `s3`. Consumers should use
//! the factory functions in the parent [`object`](super) module to create trait
//! objects without needing to depend on the `s3` feature.
mod client;
mod lock;
pub use client::{S3Config, S3ObjectBlockClient};
pub use lock::S3LockManager;
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Batch collection and accumulation for offload transfers.
//!
//! The batch collector accumulates blocks that pass policy evaluation and
//! groups them into batches for efficient transfer execution.
use std::collections::HashSet;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, watch};
use velo::EventHandle;
use crate::{BlockId, SequenceHash};
use kvbm_logical::blocks::BlockMetadata;
use super::handle::TransferId;
use super::pending::PendingGuard;
use super::queue::CancellableQueue;
use super::source::SourceBlock;
/// Timing trace for tracking block progression through pipeline stages.
///
/// Each block carries a timing trace that records when it passed through
/// each stage. This enables per-container and batch-level timing analysis.
#[derive(Debug, Clone)]
pub struct TimingTrace {
/// When the block was initially enqueued into the pipeline
pub enqueued_at: Instant,
/// When policy evaluation completed for this block
pub policy_complete_at: Option<Instant>,
/// When the precondition (e.g., forward pass) completed
pub precondition_complete_at: Option<Instant>,
/// When the block was added to a transfer batch
pub batched_at: Option<Instant>,
/// When the transfer operation started
pub transfer_start_at: Option<Instant>,
/// When the transfer operation completed
pub transfer_complete_at: Option<Instant>,
}
impl TimingTrace {
/// Create a new timing trace, recording the current time as enqueue time.
pub fn new() -> Self {
Self {
enqueued_at: Instant::now(),
policy_complete_at: None,
precondition_complete_at: None,
batched_at: None,
transfer_start_at: None,
transfer_complete_at: None,
}
}
/// Mark policy evaluation complete.
pub fn mark_policy_complete(&mut self) {
self.policy_complete_at = Some(Instant::now());
}
/// Mark precondition complete.
pub fn mark_precondition_complete(&mut self) {
self.precondition_complete_at = Some(Instant::now());
}
/// Mark block as batched.
pub fn mark_batched(&mut self) {
self.batched_at = Some(Instant::now());
}
/// Mark transfer start.
pub fn mark_transfer_start(&mut self) {
self.transfer_start_at = Some(Instant::now());
}
/// Mark transfer complete.
pub fn mark_transfer_complete(&mut self) {
self.transfer_complete_at = Some(Instant::now());
}
/// Get total time from enqueue to transfer complete (if available).
pub fn total_duration(&self) -> Option<Duration> {
self.transfer_complete_at
.map(|end| end.duration_since(self.enqueued_at))
}
/// Get policy evaluation duration (if available).
pub fn policy_duration(&self) -> Option<Duration> {
self.policy_complete_at
.map(|end| end.duration_since(self.enqueued_at))
}
/// Get precondition wait duration (if available).
pub fn precondition_duration(&self) -> Option<Duration> {
match (self.policy_complete_at, self.precondition_complete_at) {
(Some(start), Some(end)) => Some(end.duration_since(start)),
_ => None,
}
}
/// Get transfer duration (if available).
pub fn transfer_duration(&self) -> Option<Duration> {
match (self.transfer_start_at, self.transfer_complete_at) {
(Some(start), Some(end)) => Some(end.duration_since(start)),
_ => None,
}
}
}
impl Default for TimingTrace {
fn default() -> Self {
Self::new()
}
}
/// Configuration for batch collection.
#[derive(Debug, Clone)]
pub struct BatchConfig {
/// Maximum blocks per batch
pub max_batch_size: usize,
/// Time to wait before flushing a partial batch
pub flush_interval: Duration,
/// Minimum batch size before flush (unless timeout)
pub min_batch_size: usize,
}
impl Default for BatchConfig {
fn default() -> Self {
Self {
max_batch_size: 1024,
flush_interval: Duration::from_millis(10),
min_batch_size: 8,
}
}
}
impl BatchConfig {
/// Create a new batch config with specified max size.
pub fn with_max_size(mut self, size: usize) -> Self {
self.max_batch_size = size;
self
}
/// Set the flush interval.
pub fn with_flush_interval(mut self, interval: Duration) -> Self {
self.flush_interval = interval;
self
}
/// Set the minimum batch size.
pub fn with_min_size(mut self, size: usize) -> Self {
self.min_batch_size = size;
self
}
}
/// A block that passed policy evaluation and is queued for transfer.
#[allow(dead_code)]
pub struct QueuedBlock<T: BlockMetadata> {
/// Transfer ID this block belongs to
pub transfer_id: TransferId,
/// Block ID - Some for External/Strong, None for Weak (determined at upgrade)
pub block_id: Option<BlockId>,
/// Sequence hash
pub sequence_hash: SequenceHash,
/// Source block - Strong/External pass through, Weak upgraded just before transfer
pub source: SourceBlock<T>,
/// Transfer state for completion tracking
pub(crate) state: Arc<std::sync::Mutex<TransferState>>,
/// RAII guard that removes this block from pending set on drop.
///
/// This ensures duplicate prevention tracking is automatically cleaned up
/// when the block completes transfer, is cancelled, or errors out.
pub pending_guard: Option<PendingGuard>,
}
impl<T: BlockMetadata> std::fmt::Debug for QueuedBlock<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("QueuedBlock")
.field("transfer_id", &self.transfer_id)
.field("block_id", &self.block_id)
.field("sequence_hash", &self.sequence_hash)
.finish()
}
}
/// A batch of blocks ready for transfer execution.
pub struct TransferBatch<T: BlockMetadata> {
/// Blocks in this batch
pub blocks: Vec<QueuedBlock<T>>,
/// Optional precondition event that must be satisfied before processing.
/// If Some, the pipeline will await this event before executing the transfer.
pub precondition: Option<EventHandle>,
/// Timing trace for performance monitoring (batch-level, not per-block).
pub timing: TimingTrace,
}
impl<T: BlockMetadata> TransferBatch<T> {
/// Create a new empty batch.
pub fn new() -> Self {
Self {
blocks: Vec::new(),
precondition: None,
timing: TimingTrace::new(),
}
}
/// Create with pre-allocated capacity.
pub fn with_capacity(capacity: usize) -> Self {
Self {
blocks: Vec::with_capacity(capacity),
precondition: None,
timing: TimingTrace::new(),
}
}
/// Set the precondition event for this batch.
#[allow(dead_code)]
pub fn with_precondition(mut self, precondition: EventHandle) -> Self {
self.precondition = Some(precondition);
self
}
/// Add a block to this batch.
pub fn push(&mut self, block: QueuedBlock<T>) {
self.blocks.push(block);
}
/// Get the number of blocks in this batch.
pub fn len(&self) -> usize {
self.blocks.len()
}
/// Check if batch is empty.
pub fn is_empty(&self) -> bool {
self.blocks.is_empty()
}
/// Get block IDs in this batch (only for blocks with known IDs).
///
/// Weak blocks may have `None` for block_id until upgraded.
/// The TransferExecutor resolves actual block_ids at transfer time.
#[allow(dead_code)]
pub fn block_ids(&self) -> Vec<BlockId> {
self.blocks.iter().filter_map(|b| b.block_id).collect()
}
/// Get sequence hashes in this batch.
#[allow(dead_code)]
pub fn sequence_hashes(&self) -> Vec<SequenceHash> {
self.blocks.iter().map(|b| b.sequence_hash).collect()
}
/// Get unique transfer IDs in this batch.
#[allow(dead_code)]
pub fn transfer_ids(&self) -> Vec<TransferId> {
let mut ids: Vec<TransferId> = self.blocks.iter().map(|b| b.transfer_id).collect();
ids.sort_by_key(|id| id.as_uuid());
ids.dedup();
ids
}
/// Take all blocks out of this batch.
#[allow(dead_code)]
pub fn take(&mut self) -> Vec<QueuedBlock<T>> {
std::mem::take(&mut self.blocks)
}
/// Drain blocks for the given transfer ID (for cancellation).
#[allow(dead_code)]
pub fn drain_transfer(&mut self, transfer_id: TransferId) -> Vec<QueuedBlock<T>> {
let mut kept = Vec::new();
let mut drained = Vec::new();
for block in std::mem::take(&mut self.blocks) {
if block.transfer_id == transfer_id {
drained.push(block);
} else {
kept.push(block);
}
}
self.blocks = kept;
drained
}
}
impl<T: BlockMetadata> Default for TransferBatch<T> {
fn default() -> Self {
Self::new()
}
}
use super::handle::TransferState;
/// Result of policy evaluation - blocks ready for batching.
#[allow(dead_code)]
pub struct EvalResult<T: BlockMetadata> {
/// Transfer ID
pub transfer_id: TransferId,
/// Blocks that passed all policies
pub passed_blocks: Vec<QueuedBlock<T>>,
/// Block IDs that were filtered out
pub filtered_ids: Vec<BlockId>,
/// Transfer state for completion tracking
pub(crate) state: Arc<std::sync::Mutex<TransferState>>,
}
/// Output from the batch collector to transfer executor.
pub type BatchOutput<T> = mpsc::Sender<TransferBatch<T>>;
/// Receiver side of batch output channel.
pub type BatchOutputRx<T> = mpsc::Receiver<TransferBatch<T>>;
/// Extract the common precondition from a batch of blocks.
///
/// If all blocks share the same precondition, returns it.
/// Otherwise returns `None`.
fn extract_common_precondition<T: BlockMetadata>(blocks: &[QueuedBlock<T>]) -> Option<EventHandle> {
blocks.first().and_then(|first_block| {
let first_precondition = first_block.state.lock().unwrap().precondition;
let all_same = blocks
.iter()
.all(|block| block.state.lock().unwrap().precondition == first_precondition);
if all_same { first_precondition } else { None }
})
}
/// Batch collector that accumulates blocks and flushes batches.
///
/// The collector accumulates blocks from policy evaluation (via `CancellableQueue`)
/// and groups them into batches based on the configuration. Batches are flushed when:
/// - `max_batch_size` is reached
/// - `flush_interval` expires and `min_batch_size` is met
/// - Shutdown is requested
pub struct BatchCollector<T: BlockMetadata> {
config: BatchConfig,
/// Input queue from policy evaluator
input_queue: Arc<CancellableQueue<EvalResult<T>>>,
/// Output channel to transfer executor
output_tx: BatchOutput<T>,
/// Cancel watch receiver
cancel_rx: watch::Receiver<HashSet<TransferId>>,
/// Current batch being built
current_batch: TransferBatch<T>,
}
impl<T: BlockMetadata> BatchCollector<T> {
/// Create a new batch collector.
pub fn new(
config: BatchConfig,
input_queue: Arc<CancellableQueue<EvalResult<T>>>,
output_tx: BatchOutput<T>,
cancel_rx: watch::Receiver<HashSet<TransferId>>,
) -> Self {
let max_batch_size = config.max_batch_size;
Self {
config,
input_queue,
output_tx,
cancel_rx,
current_batch: TransferBatch::with_capacity(max_batch_size),
}
}
/// Run the batch collector loop.
pub async fn run(mut self) {
let mut flush_timer = tokio::time::interval(self.config.flush_interval);
flush_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
let mut poll_interval = tokio::time::interval(Duration::from_micros(100));
poll_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
// Poll queue for items
_ = poll_interval.tick() => {
while let Some(item) = self.input_queue.pop_valid() {
self.handle_eval_result(item.data).await;
}
}
// Periodic flush timer
_ = flush_timer.tick() => {
self.try_flush().await;
}
// Check for shutdown
result = self.cancel_rx.changed() => {
if result.is_err() {
// Channel closed, flush and exit
self.flush_if_not_empty().await;
break;
}
}
}
}
}
/// Handle an evaluation result.
///
/// Adds passed blocks to the current batch and flushes when:
/// - max_batch_size is reached, OR
/// - all blocks for a transfer have been processed (per-transfer sentinel flush)
async fn handle_eval_result(&mut self, result: EvalResult<T>) {
// Count blocks processed in this eval result (both passed and filtered)
let blocks_in_eval = result.passed_blocks.len() + result.filtered_ids.len();
// Add passed blocks to current batch
for block in result.passed_blocks {
self.current_batch.push(block);
// Flush if we've reached max batch size
if self.current_batch.len() >= self.config.max_batch_size {
self.flush().await;
}
}
// Update transfer state and check if transfer is complete (sentinel flush)
let should_flush = {
let mut state = result.state.lock().unwrap();
state.blocks_processed += blocks_in_eval;
// Flush when all blocks for this transfer have been processed
state.blocks_processed >= state.total_expected_blocks && state.total_expected_blocks > 0
};
// Flush immediately when a transfer completes to avoid waiting for min_batch_size
if should_flush && !self.current_batch.is_empty() {
tracing::debug!(
transfer_id = %result.transfer_id,
batch_size = self.current_batch.len(),
"Per-transfer sentinel flush"
);
self.flush().await;
}
}
/// Try to flush if minimum batch size is reached.
async fn try_flush(&mut self) {
if self.current_batch.len() >= self.config.min_batch_size {
self.flush().await;
}
}
/// Flush current batch if not empty.
async fn flush_if_not_empty(&mut self) {
if !self.current_batch.is_empty() {
self.flush().await;
}
}
/// Flush the current batch to the output channel.
async fn flush(&mut self) {
nvtx_range!("offload::batch");
if self.current_batch.is_empty() {
return;
}
let mut batch = std::mem::replace(
&mut self.current_batch,
TransferBatch::with_capacity(self.config.max_batch_size),
);
// Mark batch as ready (single O(1) call, not per-block)
batch.timing.mark_batched();
batch.precondition = extract_common_precondition(&batch.blocks);
// Send to transfer executor
if self.output_tx.send(batch).await.is_err() {
// Output channel closed, log and continue
tracing::warn!("Batch output channel closed");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_batch_config_default() {
let config = BatchConfig::default();
assert_eq!(config.max_batch_size, 1024);
assert_eq!(config.min_batch_size, 8);
}
#[test]
fn test_batch_config_builder() {
let config = BatchConfig::default()
.with_max_size(128)
.with_min_size(16)
.with_flush_interval(Duration::from_millis(50));
assert_eq!(config.max_batch_size, 128);
assert_eq!(config.min_batch_size, 16);
assert_eq!(config.flush_interval, Duration::from_millis(50));
}
#[test]
fn test_transfer_batch() {
let batch: TransferBatch<()> = TransferBatch::new();
assert!(batch.is_empty());
assert_eq!(batch.len(), 0);
}
#[tokio::test]
async fn test_batch_collector_empty_input() {
let input_queue = Arc::new(CancellableQueue::<EvalResult<()>>::new());
let (output_tx, mut output_rx) = mpsc::channel::<TransferBatch<()>>(10);
let (cancel_tx, cancel_rx) = watch::channel(HashSet::new());
let collector =
BatchCollector::new(BatchConfig::default(), input_queue, output_tx, cancel_rx);
// Drop cancel sender to close channel (triggers shutdown)
drop(cancel_tx);
// Run collector
tokio::spawn(async move {
collector.run().await;
});
// Should receive nothing (empty input)
let result = tokio::time::timeout(Duration::from_millis(50), output_rx.recv()).await;
assert!(result.is_err() || result.unwrap().is_none());
}
#[test]
fn test_transfer_batch_with_capacity() {
let batch: TransferBatch<()> = TransferBatch::with_capacity(128);
assert!(batch.is_empty());
assert_eq!(batch.len(), 0);
}
#[test]
fn test_batch_config_with_methods() {
let config = BatchConfig::default()
.with_max_size(256)
.with_min_size(32)
.with_flush_interval(Duration::from_millis(100));
assert_eq!(config.max_batch_size, 256);
assert_eq!(config.min_batch_size, 32);
assert_eq!(config.flush_interval, Duration::from_millis(100));
}
#[test]
fn test_transfer_batch_methods() {
let mut batch: TransferBatch<()> = TransferBatch::new();
// Note: We can't easily create QueuedBlock without the full pipeline setup,
// so this test just verifies the batch structure methods work on empty batches
assert!(batch.block_ids().is_empty());
assert!(batch.sequence_hashes().is_empty());
assert!(batch.transfer_ids().is_empty());
// Verify take() works
let taken = batch.take();
assert!(taken.is_empty());
assert!(batch.is_empty());
}
#[test]
fn test_batch_precondition() {
let batch: TransferBatch<()> = TransferBatch::new();
assert!(batch.precondition.is_none());
// Note: with_precondition requires an EventHandle which is complex to create
// in a unit test, so we just verify the field exists and is None by default
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Cancellation protocol for offload transfers.
//!
//! The cancellation protocol ensures clean release of all blocks with confirmation
//! that no outstanding operations remain:
//!
//! 1. `cancel()` called → sets `CancelState::Requested`
//! 2. Each stage checks at safe points (between items, not during ops)
//! 3. If in-flight ops: `CancelState::Draining` → wait for completion
//! 4. Drop all `ImmutableBlock` guards → blocks released
//! 5. `CancelState::Confirmed` → `CancelConfirmation` resolves
use std::sync::Arc;
use tokio::sync::watch;
/// State of a cancellation request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CancelState {
/// Transfer is active, not cancelled
Active,
/// Cancel requested, waiting for checkpoint
Requested,
/// Draining in-flight operations
Draining {
/// Number of in-flight operations remaining
in_flight: usize,
},
/// All operations complete, blocks released, confirmed
Confirmed,
}
impl CancelState {
/// Check if cancellation has been requested (including draining/confirmed states).
pub fn is_cancelled(&self) -> bool {
!matches!(self, CancelState::Active)
}
/// Check if we're in the draining phase.
pub fn is_draining(&self) -> bool {
matches!(self, CancelState::Draining { .. })
}
/// Check if cancellation is fully confirmed.
pub fn is_confirmed(&self) -> bool {
matches!(self, CancelState::Confirmed)
}
}
/// Token for requesting and tracking cancellation.
///
/// The token is shared between the `TransferHandle` (user-facing) and the
/// pipeline stages (internal). When `request()` is called, stages will
/// check at safe points and transition through draining to confirmed.
#[derive(Clone)]
pub struct CancellationToken {
/// Sender for cancellation requests
request_tx: Arc<watch::Sender<bool>>,
/// Receiver for cancel state updates
state_rx: watch::Receiver<CancelState>,
}
impl CancellationToken {
/// Create a new cancellation token pair.
///
/// Returns `(token, state_tx)` where:
/// - `token`: Clone and give to TransferHandle for user access
/// - `state_tx`: Keep in pipeline for updating state
pub fn new() -> (Self, CancelStateUpdater) {
let (request_tx, request_rx) = watch::channel(false);
let (state_tx, state_rx) = watch::channel(CancelState::Active);
let token = CancellationToken {
request_tx: Arc::new(request_tx),
state_rx,
};
let updater = CancelStateUpdater {
request_rx,
state_tx,
};
(token, updater)
}
/// Request cancellation.
///
/// This signals all pipeline stages to stop processing at the next safe point.
/// Returns immediately - use `wait_confirmed()` to await full confirmation.
pub fn request(&self) {
let _ = self.request_tx.send(true);
}
/// Check if cancellation has been requested.
pub fn is_requested(&self) -> bool {
*self.request_tx.borrow()
}
/// Get the current cancellation state.
pub fn state(&self) -> CancelState {
*self.state_rx.borrow()
}
/// Check if cancellation is fully confirmed.
pub fn is_confirmed(&self) -> bool {
self.state().is_confirmed()
}
/// Create a future that resolves when cancellation is confirmed.
///
/// This is the primary way to await clean release of all blocks.
pub fn wait_confirmed(&self) -> CancelConfirmation {
CancelConfirmation {
state_rx: self.state_rx.clone(),
}
}
}
/// Internal updater for cancellation state.
///
/// Held by pipeline stages to update state and check for cancel requests.
pub struct CancelStateUpdater {
/// Receiver for cancellation requests
request_rx: watch::Receiver<bool>,
/// Sender for state updates
state_tx: watch::Sender<CancelState>,
}
impl CancelStateUpdater {
/// Check if cancellation has been requested.
pub fn is_requested(&self) -> bool {
*self.request_rx.borrow()
}
/// Wait for a cancellation request (async).
pub async fn wait_for_request(&mut self) {
while !*self.request_rx.borrow() {
if self.request_rx.changed().await.is_err() {
// Channel closed, treat as cancelled
break;
}
}
}
/// Get the current state.
pub fn state(&self) -> CancelState {
*self.state_tx.borrow()
}
/// Transition to Requested state.
pub fn set_requested(&self) {
let _ = self.state_tx.send(CancelState::Requested);
}
/// Transition to Draining state with count of in-flight operations.
pub fn set_draining(&self, in_flight: usize) {
let _ = self.state_tx.send(CancelState::Draining { in_flight });
}
/// Update the in-flight count during draining.
pub fn update_draining(&self, in_flight: usize) {
if in_flight == 0 {
self.set_confirmed();
} else {
let _ = self.state_tx.send(CancelState::Draining { in_flight });
}
}
/// Transition to Confirmed state (all blocks released).
pub fn set_confirmed(&self) {
let _ = self.state_tx.send(CancelState::Confirmed);
}
/// Subscribe to state changes.
pub fn subscribe(&self) -> watch::Receiver<CancelState> {
self.state_tx.subscribe()
}
}
/// Future that resolves when cancellation is fully confirmed.
///
/// Obtained via `CancellationToken::wait_confirmed()` or `TransferHandle::cancel()`.
pub struct CancelConfirmation {
state_rx: watch::Receiver<CancelState>,
}
impl CancelConfirmation {
/// Wait for confirmation (async).
///
/// This is the recommended way to await cancellation confirmation.
pub async fn wait(mut self) {
loop {
// Check current state
if self.state_rx.borrow().is_confirmed() {
return;
}
// Wait for state change
if self.state_rx.changed().await.is_err() {
// Channel closed, treat as confirmed
return;
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cancel_state_transitions() {
let state = CancelState::Active;
assert!(!state.is_cancelled());
assert!(!state.is_draining());
assert!(!state.is_confirmed());
let state = CancelState::Requested;
assert!(state.is_cancelled());
assert!(!state.is_draining());
assert!(!state.is_confirmed());
let state = CancelState::Draining { in_flight: 5 };
assert!(state.is_cancelled());
assert!(state.is_draining());
assert!(!state.is_confirmed());
let state = CancelState::Confirmed;
assert!(state.is_cancelled());
assert!(!state.is_draining());
assert!(state.is_confirmed());
}
#[test]
fn test_cancellation_token_request() {
let (token, _updater) = CancellationToken::new();
assert!(!token.is_requested());
assert_eq!(token.state(), CancelState::Active);
token.request();
assert!(token.is_requested());
}
#[test]
fn test_cancellation_updater_state() {
let (token, updater) = CancellationToken::new();
assert_eq!(token.state(), CancelState::Active);
updater.set_requested();
assert_eq!(token.state(), CancelState::Requested);
updater.set_draining(3);
assert_eq!(token.state(), CancelState::Draining { in_flight: 3 });
updater.update_draining(1);
assert_eq!(token.state(), CancelState::Draining { in_flight: 1 });
updater.update_draining(0);
assert_eq!(token.state(), CancelState::Confirmed);
}
#[tokio::test]
async fn test_cancel_confirmation_immediate() {
let (token, updater) = CancellationToken::new();
// Set confirmed before waiting
updater.set_confirmed();
// Should resolve immediately
token.wait_confirmed().wait().await;
assert!(token.is_confirmed());
}
#[tokio::test]
async fn test_cancel_confirmation_delayed() {
let (token, updater) = CancellationToken::new();
let confirmation = token.wait_confirmed();
// Spawn task to confirm after short delay
let updater_clone = updater.state_tx.clone();
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let _ = updater_clone.send(CancelState::Confirmed);
});
// Wait for confirmation
tokio::time::timeout(tokio::time::Duration::from_millis(100), confirmation.wait())
.await
.expect("Should complete within timeout");
assert!(token.is_confirmed());
}
/// Test that confirmation does NOT resolve while in-flight > 0.
/// This is a critical invariant: cancellation only completes after draining.
#[tokio::test]
async fn test_confirmation_blocked_during_draining() {
let (token, updater) = CancellationToken::new();
token.request();
updater.set_draining(2);
// Confirmation should NOT resolve while draining
let confirmation = token.wait_confirmed();
let result =
tokio::time::timeout(tokio::time::Duration::from_millis(30), confirmation.wait()).await;
assert!(result.is_err(), "Should timeout while in_flight > 0");
// Still draining
assert_eq!(token.state(), CancelState::Draining { in_flight: 2 });
}
/// Test that update_draining(0) transitions directly to Confirmed.
#[test]
fn test_draining_zero_confirms() {
let (token, updater) = CancellationToken::new();
token.request();
updater.set_draining(1);
assert_eq!(token.state(), CancelState::Draining { in_flight: 1 });
// Drain to 0 should confirm
updater.update_draining(0);
assert_eq!(token.state(), CancelState::Confirmed);
}
/// Test the full draining sequence: Requested → Draining(n) → ... → Confirmed.
#[test]
fn test_full_draining_sequence() {
let (token, updater) = CancellationToken::new();
// Start active
assert_eq!(token.state(), CancelState::Active);
// Request
token.request();
assert!(token.is_requested());
// Set draining
updater.set_draining(3);
assert_eq!(token.state(), CancelState::Draining { in_flight: 3 });
// Drain one by one
updater.update_draining(2);
assert_eq!(token.state(), CancelState::Draining { in_flight: 2 });
updater.update_draining(1);
assert_eq!(token.state(), CancelState::Draining { in_flight: 1 });
// Final drain confirms
updater.update_draining(0);
assert!(token.is_confirmed());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Comprehensive cancellation tests for the offload pipeline.
//!
//! These tests verify the cancellation invariants documented in README.md:
//! - P1: Container is the unit of cancellation
//! - P2: Token travels with container
//! - P3: Upgrade is the commitment boundary
//! - P4: Sweep before upgrade
//!
//! Key invariant: Cancellation is only confirmed when:
//! 1. All source block lists are removed from queues
//! 2. All in-flight transfers have completed
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::Barrier;
use crate::offload::cancel::{CancelState, CancellationToken};
use crate::offload::handle::TransferId;
use crate::offload::queue::CancellableQueue;
// =========================================================================
// Draining Invariant Tests
// =========================================================================
/// Test that confirmation does NOT resolve while in-flight transfers remain.
#[tokio::test]
async fn test_confirmation_waits_for_in_flight_to_drain() {
let (token, updater) = CancellationToken::new();
// Request cancellation
token.request();
assert!(token.is_requested());
// Set draining with 3 in-flight
updater.set_draining(3);
assert_eq!(token.state(), CancelState::Draining { in_flight: 3 });
// Confirmation should NOT resolve yet
let confirmation = token.wait_confirmed();
let result = tokio::time::timeout(Duration::from_millis(50), confirmation.wait()).await;
assert!(
result.is_err(),
"Confirmation should timeout while in-flight > 0"
);
// Drain to 1
updater.update_draining(1);
assert_eq!(token.state(), CancelState::Draining { in_flight: 1 });
// Still should not resolve
let confirmation = token.wait_confirmed();
let result = tokio::time::timeout(Duration::from_millis(50), confirmation.wait()).await;
assert!(
result.is_err(),
"Confirmation should timeout while in-flight > 0"
);
// Drain to 0 - this should trigger confirmation
updater.update_draining(0);
assert_eq!(token.state(), CancelState::Confirmed);
// Now confirmation should resolve immediately
let confirmation = token.wait_confirmed();
tokio::time::timeout(Duration::from_millis(50), confirmation.wait())
.await
.expect("Confirmation should resolve when in-flight = 0");
}
/// Test that draining countdown correctly transitions to confirmed.
#[tokio::test]
async fn test_draining_countdown_to_confirmation() {
let (token, updater) = CancellationToken::new();
let in_flight = Arc::new(AtomicUsize::new(5));
let in_flight_clone = in_flight.clone();
token.request();
updater.set_draining(in_flight.load(Ordering::SeqCst));
// Spawn task to simulate transfers completing
// We need to move updater into the spawned task
tokio::spawn(async move {
for _ in 0..5 {
tokio::time::sleep(Duration::from_millis(10)).await;
let remaining = in_flight_clone.fetch_sub(1, Ordering::SeqCst) - 1;
updater.update_draining(remaining);
}
});
// Wait for confirmation
let confirmation = token.wait_confirmed();
tokio::time::timeout(Duration::from_millis(200), confirmation.wait())
.await
.expect("Should confirm after all in-flight complete");
assert!(token.is_confirmed());
assert_eq!(in_flight.load(Ordering::SeqCst), 0);
}
/// Test concurrent cancellation requests are idempotent.
#[tokio::test]
async fn test_concurrent_cancel_requests() {
let (token, updater) = CancellationToken::new();
let barrier = Arc::new(Barrier::new(3));
// Spawn multiple tasks requesting cancellation
let token1 = token.clone();
let barrier1 = barrier.clone();
let t1 = tokio::spawn(async move {
barrier1.wait().await;
token1.request();
});
let token2 = token.clone();
let barrier2 = barrier.clone();
let t2 = tokio::spawn(async move {
barrier2.wait().await;
token2.request();
});
barrier.wait().await;
token.request();
t1.await.unwrap();
t2.await.unwrap();
// Should still be requested (idempotent)
assert!(token.is_requested());
// Confirm and verify
updater.set_confirmed();
assert!(token.is_confirmed());
}
// =========================================================================
// Token-Based Cancellation Tests
// =========================================================================
/// Container that carries its own CancellationToken.
struct MockContainer {
id: usize,
cancel_token: CancellationToken,
}
impl MockContainer {
fn new(id: usize, token: CancellationToken) -> Self {
Self {
id,
cancel_token: token,
}
}
fn is_cancelled(&self) -> bool {
self.cancel_token.is_requested()
}
}
/// Test that container carries its own token and can check cancellation.
#[test]
fn test_container_carries_token() {
let (token, _updater) = CancellationToken::new();
let container = MockContainer::new(1, token.clone());
assert!(!container.is_cancelled());
// Cancel via the original token
token.request();
// Container should see cancellation via its cloned token
assert!(container.is_cancelled());
}
/// Test multiple containers sharing same token (from same TransferHandle).
#[test]
fn test_multiple_containers_same_token() {
let (token, _updater) = CancellationToken::new();
let c1 = MockContainer::new(1, token.clone());
let c2 = MockContainer::new(2, token.clone());
let c3 = MockContainer::new(3, token.clone());
assert!(!c1.is_cancelled());
assert!(!c2.is_cancelled());
assert!(!c3.is_cancelled());
// Cancel via handle's token
token.request();
// All containers should see cancellation
assert!(c1.is_cancelled());
assert!(c2.is_cancelled());
assert!(c3.is_cancelled());
}
/// Test containers from different handles have independent cancellation.
#[test]
fn test_independent_container_cancellation() {
let (token1, _updater1) = CancellationToken::new();
let (token2, _updater2) = CancellationToken::new();
let c1 = MockContainer::new(1, token1.clone());
let c2 = MockContainer::new(2, token2.clone());
// Cancel only token1
token1.request();
assert!(c1.is_cancelled());
assert!(!c2.is_cancelled());
}
// =========================================================================
// Queue + Token Integration Tests
// =========================================================================
/// Wrapper that includes a CancellationToken for queue testing.
struct TokenWrapper {
data: i32,
cancel_token: CancellationToken,
}
/// Test queue sweep using token-based cancellation check.
#[test]
fn test_queue_sweep_with_token_check() {
let queue: CancellableQueue<TokenWrapper> = CancellableQueue::new();
let (token1, _) = CancellationToken::new();
let (token2, _) = CancellationToken::new();
let id1 = TransferId::new();
let id2 = TransferId::new();
// Push items with different tokens
queue.push(
id1,
TokenWrapper {
data: 1,
cancel_token: token1.clone(),
},
);
queue.push(
id2,
TokenWrapper {
data: 2,
cancel_token: token2.clone(),
},
);
queue.push(
id1,
TokenWrapper {
data: 3,
cancel_token: token1.clone(),
},
);
assert_eq!(queue.len_approx(), 3);
// Cancel token1 (and mark in queue for sweep)
token1.request();
queue.mark_cancelled(id1);
// Sweep should remove token1's items
let removed = queue.sweep();
assert_eq!(removed, 2);
assert_eq!(queue.len_approx(), 1);
// Remaining item should be from token2
let item = queue.pop().unwrap();
assert_eq!(item.data.data, 2);
assert!(!item.data.cancel_token.is_requested());
}
// =========================================================================
// Batch Partial Cancellation Tests
// =========================================================================
/// Mock batch of containers for testing partial cancellation.
struct MockBatch {
containers: Vec<MockContainer>,
}
impl MockBatch {
fn new(containers: Vec<MockContainer>) -> Self {
Self { containers }
}
/// Remove cancelled containers, return count removed.
fn sweep_cancelled(&mut self) -> usize {
let before = self.containers.len();
self.containers.retain(|c| !c.is_cancelled());
before - self.containers.len()
}
fn len(&self) -> usize {
self.containers.len()
}
fn is_empty(&self) -> bool {
self.containers.is_empty()
}
}
/// Test partial batch cancellation - some containers cancelled, others proceed.
#[test]
fn test_batch_partial_cancellation() {
let (token1, _updater1) = CancellationToken::new();
let (token2, _updater2) = CancellationToken::new();
let (token3, _updater3) = CancellationToken::new();
// Create container with cloned token
let c1 = MockContainer::new(1, token1.clone());
let c2 = MockContainer::new(2, token2.clone());
let c3 = MockContainer::new(3, token3.clone());
let c4 = MockContainer::new(4, token1.clone()); // Same token as c1
// Verify tokens work before batching
assert!(!c1.is_cancelled());
assert!(!c4.is_cancelled());
let mut batch = MockBatch::new(vec![c1, c2, c3, c4]);
assert_eq!(batch.len(), 4);
// Cancel token1 (affects containers 1 and 4)
token1.request();
assert!(token1.is_requested());
// Verify containers in batch see the cancellation
assert!(
batch.containers[0].is_cancelled(),
"Container 1 should be cancelled"
);
assert!(
!batch.containers[1].is_cancelled(),
"Container 2 should NOT be cancelled"
);
assert!(
!batch.containers[2].is_cancelled(),
"Container 3 should NOT be cancelled"
);
assert!(
batch.containers[3].is_cancelled(),
"Container 4 should be cancelled"
);
let removed = batch.sweep_cancelled();
assert_eq!(removed, 2);
assert_eq!(batch.len(), 2);
// Remaining containers should be 2 and 3
assert_eq!(batch.containers[0].id, 2);
assert_eq!(batch.containers[1].id, 3);
}
/// Test batch where all containers are cancelled.
#[test]
fn test_batch_full_cancellation() {
let (token, _updater) = CancellationToken::new();
// Create containers with cloned tokens
let c1 = MockContainer::new(1, token.clone());
let c2 = MockContainer::new(2, token.clone());
let c3 = MockContainer::new(3, token.clone());
// Verify token clone works
assert!(!c1.is_cancelled());
assert!(!c2.is_cancelled());
assert!(!c3.is_cancelled());
let mut batch = MockBatch::new(vec![c1, c2, c3]);
token.request();
assert!(token.is_requested());
// Verify containers see cancellation
assert!(
batch.containers[0].is_cancelled(),
"Container 1 should be cancelled"
);
assert!(
batch.containers[1].is_cancelled(),
"Container 2 should be cancelled"
);
assert!(
batch.containers[2].is_cancelled(),
"Container 3 should be cancelled"
);
let removed = batch.sweep_cancelled();
assert_eq!(removed, 3);
assert!(batch.is_empty());
}
/// Test batch where no containers are cancelled.
#[test]
fn test_batch_no_cancellation() {
let (token1, _updater1) = CancellationToken::new();
let (token2, _updater2) = CancellationToken::new();
let mut batch = MockBatch::new(vec![
MockContainer::new(1, token1.clone()),
MockContainer::new(2, token2.clone()),
]);
// Don't cancel anything
let removed = batch.sweep_cancelled();
assert_eq!(removed, 0);
assert_eq!(batch.len(), 2);
}
// =========================================================================
// Select-Based Cancellation Tests
// =========================================================================
/// Simulate precondition awaiter with select on event OR cancel.
#[tokio::test]
async fn test_select_cancellation_during_wait() {
let (token, _updater) = CancellationToken::new();
let (event_tx, event_rx) = tokio::sync::oneshot::channel::<()>();
// Verify initial state
assert!(!token.is_requested());
let token_clone = token.clone();
let result = tokio::spawn(async move {
// Poll-based cancellation check with timeout
let cancel_check = async {
loop {
if token_clone.is_requested() {
return;
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
};
tokio::select! {
biased; // Prefer first branch to complete
_ = event_rx => {
"event"
}
_ = cancel_check => {
"cancelled"
}
}
});
// Give the task time to start and enter select
tokio::time::sleep(Duration::from_millis(20)).await;
// Cancel (don't send event)
token.request();
assert!(token.is_requested());
let outcome = tokio::time::timeout(Duration::from_millis(200), result)
.await
.expect("Should complete within timeout")
.expect("Task should not panic");
assert_eq!(outcome, "cancelled");
// Event sender still exists - wasn't used
drop(event_tx);
}
/// Test that event completes before cancellation.
#[tokio::test]
async fn test_select_event_before_cancel() {
let (token, _) = CancellationToken::new();
let (event_tx, event_rx) = tokio::sync::oneshot::channel::<()>();
let token_clone = token.clone();
let result = tokio::spawn(async move {
tokio::select! {
_ = event_rx => {
"event"
}
_ = async {
loop {
if token_clone.is_requested() {
break;
}
tokio::time::sleep(Duration::from_millis(5)).await;
}
} => {
"cancelled"
}
}
});
// Give the task time to start
tokio::time::sleep(Duration::from_millis(10)).await;
// Send event (before cancellation)
event_tx.send(()).unwrap();
let outcome = tokio::time::timeout(Duration::from_millis(100), result)
.await
.expect("Should complete")
.expect("Should not panic");
assert_eq!(outcome, "event");
assert!(!token.is_requested());
}
// =========================================================================
// End-to-End Cancellation Flow Tests
// =========================================================================
/// Test complete cancellation flow: request → sweep → drain → confirm.
#[tokio::test]
async fn test_end_to_end_cancellation_flow() {
let (token, updater) = CancellationToken::new();
let queue: CancellableQueue<i32> = CancellableQueue::new();
let id = TransferId::new();
// Simulate: 3 items in queue, 2 in-flight
queue.push(id, 1);
queue.push(id, 2);
queue.push(id, 3);
let in_flight = Arc::new(AtomicUsize::new(2));
// Request cancellation
token.request();
assert!(token.is_requested());
// Mark cancelled in queue
queue.mark_cancelled(id);
// Sweep queue
let removed = queue.sweep();
assert_eq!(removed, 3);
assert_eq!(queue.len_approx(), 0);
// Set draining for in-flight
updater.set_draining(in_flight.load(Ordering::SeqCst));
assert!(token.state().is_draining());
// Simulate in-flight completing
let in_flight_clone = in_flight.clone();
tokio::spawn(async move {
for _ in 0..2 {
tokio::time::sleep(Duration::from_millis(10)).await;
let remaining = in_flight_clone.fetch_sub(1, Ordering::SeqCst) - 1;
updater.update_draining(remaining);
}
});
// Wait for confirmation
let confirmation = token.wait_confirmed();
tokio::time::timeout(Duration::from_millis(100), confirmation.wait())
.await
.expect("Should confirm after queue swept and in-flight drained");
assert!(token.is_confirmed());
assert_eq!(queue.len_approx(), 0);
assert_eq!(in_flight.load(Ordering::SeqCst), 0);
}
/// Test cancellation with nothing in-flight (immediate confirmation after sweep).
#[tokio::test]
async fn test_cancellation_nothing_in_flight() {
let (token, updater) = CancellationToken::new();
let queue: CancellableQueue<i32> = CancellableQueue::new();
let id = TransferId::new();
// Items only in queue, nothing in-flight
queue.push(id, 1);
queue.push(id, 2);
// Request and sweep
token.request();
queue.mark_cancelled(id);
let removed = queue.sweep();
assert_eq!(removed, 2);
// No in-flight, go directly to confirmed
updater.update_draining(0); // This sets Confirmed when in_flight = 0
assert!(token.is_confirmed());
// Confirmation should resolve immediately
let confirmation = token.wait_confirmed();
tokio::time::timeout(Duration::from_millis(10), confirmation.wait())
.await
.expect("Should confirm immediately with nothing in-flight");
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Main offload engine coordinating pipelines.
//!
//! The `OffloadEngine` is a standalone component that manages block offloading
//! between storage tiers (G1→G2, G2→G3, G2→G4).
//!
//! # Example
//! ```ignore
//! let engine = OffloadEngine::builder(leader.clone())
//! .with_registry(registry.clone())
//! .with_g1_to_g2_pipeline(
//! PipelineBuilder::<G1, G2>::new()
//! .policy(Arc::new(PresenceFilter::new(registry.clone())))
//! .batch_size(32)
//! .auto_chain(true)
//! .build()
//! )
//! .with_g2_to_g3_pipeline(
//! PipelineBuilder::<G2, G3>::new()
//! .policy(Arc::new(PresenceAndLFUFilter::with_default_threshold(registry.clone())))
//! .batch_size(64)
//! .build()
//! )
//! .build()?;
//!
//! let handle = engine.enqueue_g2_to_g3(blocks);
//! handle.wait().await?;
//! ```
use std::sync::Arc;
use anyhow::Result;
use dashmap::DashMap;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use crate::leader::InstanceLeader;
use crate::object::ObjectBlockOps;
use crate::worker::RemoteDescriptor;
use crate::{BlockId, G1, G2, G3, SequenceHash};
use kvbm_common::LogicalLayoutHandle;
use kvbm_logical::blocks::{BlockMetadata, BlockRegistry, WeakBlock};
use kvbm_logical::manager::BlockManager;
use kvbm_physical::transfer::{PhysicalLayout, TransferOptions};
use super::handle::{TransferHandle, TransferId, TransferState};
use super::pipeline::{
ChainOutput, ChainOutputRx, ObjectPipeline, ObjectPipelineConfig, Pipeline, PipelineConfig,
PipelineInput,
};
use super::queue::CancellableQueue;
use super::source::SourceBlocks;
/// Central coordinator for offload pipelines.
///
/// The engine manages multiple pipelines (G1→G2, G2→G3, G2→G4) and provides
/// a unified interface for enqueueing blocks for offload.
///
/// # Storage Tier Model
///
/// - G1→G2: `BlockManager<G2>` destination (host memory)
/// - G2→G3: `BlockManager<G3>` destination (disk/NVMe)
/// - G2→G4: ObjectBlockOps destination (object storage like S3)
///
/// # Distributed G2→G4 Offloading
///
/// For distributed setups where the leader doesn't have physical layouts (only workers do),
/// use `with_enable_remote_g4(true)` instead of `with_g2_to_g4_pipeline()`. This enables
/// remote G4 offloading where workers execute object storage uploads via their local
/// `ObjectBlockOps` implementations.
#[allow(dead_code)]
pub struct OffloadEngine {
/// Reference to the instance leader for transfers
leader: Arc<InstanceLeader>,
/// Block registry for policy evaluation
registry: Arc<BlockRegistry>,
/// G1→G2 pipeline (BlockManager destination)
g1_to_g2: Option<Pipeline<G1, G2>>,
/// G2→G3 pipeline (BlockManager destination)
g2_to_g3: Option<Pipeline<G2, G3>>,
/// G2→G4 pipeline (Object storage destination) - for local mode only
g2_to_g4: Option<ObjectPipeline<G2>>,
/// Active transfer tracking
transfers: Arc<DashMap<TransferId, Arc<std::sync::Mutex<TransferState>>>>,
/// Chain router task handle (routes G1→G2 output to downstream pipelines)
_chain_router_handle: Option<JoinHandle<()>>,
/// Remote G4 offload task handle (for distributed mode)
_remote_g4_offload_handle: Option<JoinHandle<()>>,
}
impl OffloadEngine {
/// Create a new builder for the offload engine.
pub fn builder(leader: Arc<InstanceLeader>) -> OffloadEngineBuilder {
OffloadEngineBuilder::new(leader)
}
/// Enqueue blocks for G1→G2 offload.
///
/// Returns a `TransferHandle` for tracking progress and cancellation.
pub fn enqueue_g1_to_g2(&self, blocks: impl Into<SourceBlocks<G1>>) -> Result<TransferHandle> {
let pipeline = self
.g1_to_g2
.as_ref()
.ok_or_else(|| anyhow::anyhow!("G1→G2 pipeline not configured"))?;
self.enqueue_to_pipeline(pipeline, blocks.into())
}
/// Enqueue blocks for G1→G2 offload with a precondition event.
///
/// The precondition event must be satisfied before the batch is processed
/// by the transfer executor. This enables coordination with worker forward passes.
///
/// Returns a `TransferHandle` for tracking progress and cancellation.
pub fn enqueue_g1_to_g2_with_precondition(
&self,
blocks: impl Into<SourceBlocks<G1>>,
precondition: Option<velo::EventHandle>,
) -> Result<TransferHandle> {
let pipeline = self
.g1_to_g2
.as_ref()
.ok_or_else(|| anyhow::anyhow!("G1→G2 pipeline not configured"))?;
self.enqueue_to_pipeline_with_precondition(pipeline, blocks.into(), precondition)
}
/// Enqueue blocks for G2→G3 offload.
///
/// Returns a `TransferHandle` for tracking progress and cancellation.
pub fn enqueue_g2_to_g3(&self, blocks: impl Into<SourceBlocks<G2>>) -> Result<TransferHandle> {
let pipeline = self
.g2_to_g3
.as_ref()
.ok_or_else(|| anyhow::anyhow!("G2→G3 pipeline not configured"))?;
self.enqueue_to_pipeline(pipeline, blocks.into())
}
/// Enqueue blocks for G2→G4 offload (object storage).
///
/// Returns a `TransferHandle` for tracking progress and cancellation.
pub fn enqueue_g2_to_g4(&self, blocks: impl Into<SourceBlocks<G2>>) -> Result<TransferHandle> {
let pipeline = self
.g2_to_g4
.as_ref()
.ok_or_else(|| anyhow::anyhow!("G2→G4 pipeline not configured"))?;
self.enqueue_to_object_pipeline(pipeline, blocks.into())
}
/// Create transfer state, store it, and return the components needed for enqueueing.
fn create_transfer<T: BlockMetadata>(
&self,
source: &SourceBlocks<T>,
) -> (
TransferId,
Arc<std::sync::Mutex<TransferState>>,
TransferHandle,
) {
let input_block_ids = self.extract_block_ids(source);
let transfer_id = TransferId::new();
let (state, handle) = TransferState::new(transfer_id, input_block_ids);
let state = Arc::new(std::sync::Mutex::new(state));
self.transfers.insert(transfer_id, state.clone());
(transfer_id, state, handle)
}
/// Internal: enqueue to a specific pipeline.
fn enqueue_to_pipeline<Src: BlockMetadata, Dst: BlockMetadata>(
&self,
pipeline: &Pipeline<Src, Dst>,
source: SourceBlocks<Src>,
) -> Result<TransferHandle> {
let (transfer_id, state, handle) = self.create_transfer(&source);
if !pipeline.enqueue(transfer_id, source, state) {
tracing::warn!("Transfer {} was cancelled before enqueueing", transfer_id);
}
Ok(handle)
}
/// Internal: enqueue to a specific pipeline with a precondition.
fn enqueue_to_pipeline_with_precondition<Src: BlockMetadata, Dst: BlockMetadata>(
&self,
pipeline: &Pipeline<Src, Dst>,
source: SourceBlocks<Src>,
precondition: Option<velo::EventHandle>,
) -> Result<TransferHandle> {
let (transfer_id, state, handle) = self.create_transfer(&source);
state.lock().unwrap().precondition = precondition;
if !pipeline.enqueue(transfer_id, source, state) {
tracing::warn!("Transfer {} was cancelled before enqueueing", transfer_id);
}
Ok(handle)
}
/// Internal: enqueue to an object pipeline (G2→G4).
fn enqueue_to_object_pipeline(
&self,
pipeline: &ObjectPipeline<G2>,
source: SourceBlocks<G2>,
) -> Result<TransferHandle> {
let (transfer_id, state, handle) = self.create_transfer(&source);
if !pipeline.enqueue(transfer_id, source, state) {
tracing::warn!("Transfer {} was cancelled before enqueueing", transfer_id);
}
Ok(handle)
}
/// Extract block IDs from source blocks.
///
/// For External/Strong blocks, returns the known block IDs.
/// For Weak blocks, returns empty vec (IDs determined at upgrade time).
fn extract_block_ids<T: BlockMetadata>(&self, source: &SourceBlocks<T>) -> Vec<BlockId> {
match source {
SourceBlocks::External(blocks) => blocks.iter().map(|b| b.block_id).collect(),
SourceBlocks::Strong(blocks) => blocks.iter().map(|b| b.block_id()).collect(),
SourceBlocks::Weak(_) => Vec::new(), // IDs not available without upgrade
}
}
/// Release a completed transfer's resources.
///
/// This is optional - transfers are automatically cleaned up,
/// but call this to release resources earlier.
pub fn release_transfer(&self, transfer_id: TransferId) {
self.transfers.remove(&transfer_id);
}
/// Get the number of active transfers.
pub fn active_transfer_count(&self) -> usize {
self.transfers.len()
}
/// Check if G1→G2 pipeline is configured.
pub fn has_g1_to_g2(&self) -> bool {
self.g1_to_g2.is_some()
}
/// Check if G2→G3 pipeline is configured.
pub fn has_g2_to_g3(&self) -> bool {
self.g2_to_g3.is_some()
}
/// Check if G2→G4 pipeline is configured.
pub fn has_g2_to_g4(&self) -> bool {
self.g2_to_g4.is_some()
}
}
/// Builder for OffloadEngine.
pub struct OffloadEngineBuilder {
leader: Arc<InstanceLeader>,
registry: Option<Arc<BlockRegistry>>,
g1_manager: Option<Arc<BlockManager<G1>>>,
g2_manager: Option<Arc<BlockManager<G2>>>,
g3_manager: Option<Arc<BlockManager<G3>>>,
/// Object storage operations for G4 (replaces `BlockManager<G4>`)
object_ops: Option<Arc<dyn ObjectBlockOps>>,
/// G2 physical layout for object transfers (needed by ObjectTransferExecutor)
g2_physical_layout: Option<PhysicalLayout>,
g1_to_g2_config: Option<PipelineConfig<G1, G2>>,
g2_to_g3_config: Option<PipelineConfig<G2, G3>>,
/// G2→G4 uses ObjectPipelineConfig (no destination BlockManager)
g2_to_g4_config: Option<ObjectPipelineConfig<G2>>,
/// Optional runtime handle override (defaults to leader.runtime())
runtime: Option<tokio::runtime::Handle>,
/// Enable remote G4 offloading via workers' ObjectBlockOps (for distributed mode)
enable_remote_g4: bool,
}
impl OffloadEngineBuilder {
/// Create a new builder with the given instance leader.
pub fn new(leader: Arc<InstanceLeader>) -> Self {
Self {
leader,
registry: None,
g1_manager: None,
g2_manager: None,
g3_manager: None,
object_ops: None,
g2_physical_layout: None,
g1_to_g2_config: None,
g2_to_g3_config: None,
g2_to_g4_config: None,
runtime: None,
enable_remote_g4: false,
}
}
/// Set an explicit runtime handle for spawning pipeline tasks.
///
/// If not set, defaults to `leader.runtime()`. Use this when you need
/// pipeline tasks to run on a specific runtime (e.g., in tests).
pub fn with_runtime(mut self, runtime: tokio::runtime::Handle) -> Self {
self.runtime = Some(runtime);
self
}
/// Set the block registry.
pub fn with_registry(mut self, registry: Arc<BlockRegistry>) -> Self {
self.registry = Some(registry);
self
}
/// Set the G1 block manager.
pub fn with_g1_manager(mut self, manager: Arc<BlockManager<G1>>) -> Self {
self.g1_manager = Some(manager);
self
}
/// Set the G2 block manager.
pub fn with_g2_manager(mut self, manager: Arc<BlockManager<G2>>) -> Self {
self.g2_manager = Some(manager);
self
}
/// Set the G3 block manager.
pub fn with_g3_manager(mut self, manager: Arc<BlockManager<G3>>) -> Self {
self.g3_manager = Some(manager);
self
}
/// Set object storage operations for G4.
///
/// G4 is object storage (S3, MinIO, etc.) and uses `ObjectBlockOps`
/// instead of a `BlockManager`. This replaces `with_g4_manager`.
pub fn with_object_ops(mut self, object_ops: Arc<dyn ObjectBlockOps>) -> Self {
self.object_ops = Some(object_ops);
self
}
/// Set the G2 physical layout for object transfers.
///
/// Required when using G2→G4 pipeline. The ObjectTransferExecutor needs
/// the physical layout to read block data for upload to object storage.
pub fn with_g2_physical_layout(mut self, layout: PhysicalLayout) -> Self {
self.g2_physical_layout = Some(layout);
self
}
/// Configure G1→G2 pipeline.
pub fn with_g1_to_g2_pipeline(mut self, config: PipelineConfig<G1, G2>) -> Self {
self.g1_to_g2_config = Some(config);
self
}
/// Configure G2→G3 pipeline.
pub fn with_g2_to_g3_pipeline(mut self, config: PipelineConfig<G2, G3>) -> Self {
self.g2_to_g3_config = Some(config);
self
}
/// Configure G2→G4 pipeline (object storage).
///
/// Uses `ObjectPipelineConfig` instead of `PipelineConfig` since G4
/// is object storage, not a BlockManager destination.
///
/// For distributed setups where the leader doesn't have physical layouts,
/// use `with_enable_remote_g4(true)` instead.
pub fn with_g2_to_g4_pipeline(mut self, config: ObjectPipelineConfig<G2>) -> Self {
self.g2_to_g4_config = Some(config);
self
}
/// Enable remote G4 offloading via workers' ObjectBlockOps.
///
/// In distributed setups, the leader doesn't have physical layouts (only workers do).
/// This enables G2→G4 offloading where:
/// 1. G1→G2 chain output is routed to a remote offload task
/// 2. The task calls workers' ObjectBlockOps::put_blocks() via RPC
/// 3. Workers upload blocks from their local G2 to object storage
/// 4. Per-block results are returned and logged
///
/// This is mutually exclusive with `with_g2_to_g4_pipeline()` - use one or the other.
pub fn with_enable_remote_g4(mut self, enable: bool) -> Self {
self.enable_remote_g4 = enable;
self
}
/// Build the offload engine.
pub fn build(self) -> Result<OffloadEngine> {
let registry = self
.registry
.ok_or_else(|| anyhow::anyhow!("Block registry required"))?;
// Get the runtime handle for spawning background tasks
// Use explicit override if provided, otherwise get from leader
let runtime = self.runtime.unwrap_or_else(|| self.leader.runtime());
// Build G1→G2 pipeline if configured
// Note: G1 is externally owned (vLLM GPU cache), so no G1 manager needed.
// Pipeline works with ExternalBlock<G1> which contains block_id + sequence_hash.
let mut g1_to_g2 = if let Some(config) = self.g1_to_g2_config {
let g2_manager = self
.g2_manager
.clone()
.ok_or_else(|| anyhow::anyhow!("G2 manager required for G1→G2 pipeline"))?;
Some(Pipeline::new(
config,
registry.clone(),
g2_manager,
self.leader.clone(),
LogicalLayoutHandle::G1,
LogicalLayoutHandle::G2,
runtime.clone(),
))
} else {
None
};
// Build G2→G3 pipeline if configured
let g2_to_g3 = if let Some(config) = self.g2_to_g3_config {
let g3_manager = self
.g3_manager
.ok_or_else(|| anyhow::anyhow!("G3 manager required for G2→G3 pipeline"))?;
Some(Pipeline::new(
config,
registry.clone(),
g3_manager,
self.leader.clone(),
LogicalLayoutHandle::G2,
LogicalLayoutHandle::G3,
runtime.clone(),
))
} else {
None
};
// Build G2→G4 pipeline if configured (object storage destination)
// Note: For distributed mode, use enable_remote_g4 instead
let g2_to_g4 = if let Some(config) = self.g2_to_g4_config {
let object_ops = self
.object_ops
.ok_or_else(|| anyhow::anyhow!("ObjectBlockOps required for G2→G4 pipeline"))?;
// ObjectPipeline takes LogicalLayoutHandle - the ObjectBlockOps implementation
// resolves this to a physical layout internally
Some(ObjectPipeline::new(
config,
object_ops,
LogicalLayoutHandle::G2,
self.leader.clone(),
runtime.clone(),
))
} else {
None
};
// Create channel for remote G4 offload if enabled
let (remote_g4_tx, remote_g4_rx) = if self.enable_remote_g4 {
let (tx, rx) = mpsc::channel::<RemoteG4OffloadRequest>(64);
(Some(tx), Some(rx))
} else {
(None, None)
};
// Wire up auto-chaining from G1→G2 to downstream G2→G3/G2→G4 pipelines
let chain_router_handle = if let Some(ref mut g1_to_g2_pipeline) = g1_to_g2 {
if g1_to_g2_pipeline.auto_chain() {
if let Some(chain_rx) = g1_to_g2_pipeline.take_chain_rx() {
// Get references to downstream pipeline queues
let g2_to_g3_queue = g2_to_g3.as_ref().map(|p| p.eval_queue.clone());
let g2_to_g4_queue = g2_to_g4.as_ref().map(|p| p.eval_queue.clone());
// Check if we have any downstream target (local pipelines or remote G4)
let has_g2_to_g4_local = g2_to_g4_queue.is_some();
let has_g2_to_g4_remote = remote_g4_tx.is_some();
// Only spawn if there's at least one downstream target
if g2_to_g3_queue.is_some() || has_g2_to_g4_local || has_g2_to_g4_remote {
tracing::debug!(
has_g2_to_g3 = g2_to_g3_queue.is_some(),
has_g2_to_g4_local,
has_g2_to_g4_remote,
"Spawning chain router for G1→G2 auto-chaining"
);
Some(runtime.spawn(chain_router_task(
chain_rx,
g2_to_g3_queue,
g2_to_g4_queue,
remote_g4_tx,
)))
} else {
tracing::debug!(
"G1→G2 auto_chain enabled but no downstream pipelines configured"
);
None
}
} else {
None
}
} else {
None
}
} else {
None
};
// Spawn remote G4 offload task if enabled
let remote_g4_offload_handle = if let Some(rx) = remote_g4_rx {
tracing::info!("Enabling remote G4 offload via workers' ObjectBlockOps");
Some(runtime.spawn(remote_g4_offload_task(rx, self.leader.clone())))
} else {
None
};
Ok(OffloadEngine {
leader: self.leader,
registry,
g1_to_g2,
g2_to_g3,
g2_to_g4,
transfers: Arc::new(DashMap::new()),
_chain_router_handle: chain_router_handle,
_remote_g4_offload_handle: remote_g4_offload_handle,
})
}
}
/// Request for remote G4 offload (distributed mode).
///
/// Contains the information needed to call workers' ObjectBlockOps::put_blocks().
struct RemoteG4OffloadRequest {
/// Transfer ID for tracking
transfer_id: TransferId,
/// Sequence hashes (keys for object storage)
keys: Vec<SequenceHash>,
/// Block IDs in G2 layout
block_ids: Vec<BlockId>,
}
/// Routes chain output from G1→G2 to downstream G2→G3/G2→G4 pipelines.
///
/// Blocks are converted to WeakBlocks for best-effort offloading - if they're
/// evicted before the downstream pipeline processes them, that's acceptable.
/// This enables graceful degradation under memory pressure.
async fn chain_router_task(
mut chain_rx: ChainOutputRx<G2>,
g2_to_g3_queue: Option<Arc<CancellableQueue<PipelineInput<G2>>>>,
g2_to_g4_queue: Option<Arc<CancellableQueue<PipelineInput<G2>>>>,
remote_g4_tx: Option<mpsc::Sender<RemoteG4OffloadRequest>>,
) {
while let Some(output) = chain_rx.recv().await {
let ChainOutput {
transfer_id,
blocks,
state,
} = output;
if blocks.is_empty() {
continue;
}
// Convert strong blocks to weak blocks for best-effort downstream processing
// This allows blocks to be evicted if memory pressure requires it
let weak_blocks: Vec<WeakBlock<G2>> =
blocks.iter().map(|block| block.downgrade()).collect();
// Extract sequence hashes and block IDs for remote G4 offload before dropping
let remote_g4_data: Option<(Vec<SequenceHash>, Vec<BlockId>)> = if remote_g4_tx.is_some() {
Some((
blocks.iter().map(|b| b.sequence_hash()).collect(),
blocks.iter().map(|b| b.block_id()).collect(),
))
} else {
None
};
// Drop strong references - blocks can now be evicted if needed
drop(blocks);
tracing::debug!(
%transfer_id,
num_blocks = weak_blocks.len(),
"Routing chain output to downstream pipelines as WeakBlocks"
);
// Enqueue to G2→G3 if available
if let Some(ref queue) = g2_to_g3_queue {
let input = PipelineInput {
transfer_id,
source: SourceBlocks::Weak(weak_blocks.clone()),
state: state.clone(),
};
if !queue.push(transfer_id, input) {
tracing::debug!(%transfer_id, "G2→G3 chain enqueue skipped (cancelled)");
}
}
// Enqueue to local G2→G4 pipeline if available
if let Some(ref queue) = g2_to_g4_queue {
let input = PipelineInput {
transfer_id,
source: SourceBlocks::Weak(weak_blocks.clone()),
state: state.clone(),
};
if !queue.push(transfer_id, input) {
tracing::debug!(%transfer_id, "G2→G4 chain enqueue skipped (cancelled)");
}
}
// Send to remote G4 offload if enabled (distributed mode)
if let (Some(tx), Some((keys, block_ids))) = (&remote_g4_tx, remote_g4_data) {
let request = RemoteG4OffloadRequest {
transfer_id,
keys,
block_ids,
};
if tx.send(request).await.is_err() {
tracing::debug!(%transfer_id, "Remote G4 offload channel closed");
}
}
}
tracing::debug!("Chain router task shutting down");
}
/// Task that processes remote G4 offload requests.
///
/// In distributed mode, this task receives requests from the chain router
/// and calls workers' ObjectBlockOps to upload blocks to object storage.
/// Uses execute_remote_offload with RemoteDescriptor::Object to coordinate
/// workers uploading their local G2 data to S3.
async fn remote_g4_offload_task(
mut rx: mpsc::Receiver<RemoteG4OffloadRequest>,
leader: Arc<InstanceLeader>,
) {
tracing::info!("Remote G4 offload task started");
while let Some(request) = rx.recv().await {
let num_blocks = request.keys.len();
tracing::debug!(
%request.transfer_id,
num_blocks,
"Processing remote G4 offload request"
);
// Use the leader's execute_remote_offload with RemoteDescriptor::Object
// This coordinates all workers to upload from their local G2 to object storage
let result = leader.execute_remote_offload(
LogicalLayoutHandle::G2, // Source is G2 (host memory)
RemoteDescriptor::Object {
keys: request.keys.clone(),
},
request.block_ids.clone(),
TransferOptions::default(),
);
match result {
Ok(notification) => {
// Wait for all workers to complete
match notification.await {
Ok(()) => {
tracing::info!(
%request.transfer_id,
num_blocks,
"Remote G4 offload completed successfully"
);
}
Err(e) => {
tracing::warn!(
%request.transfer_id,
num_blocks,
error = %e,
"Remote G4 offload failed"
);
}
}
}
Err(e) => {
tracing::warn!(
%request.transfer_id,
num_blocks,
error = %e,
"Failed to initiate remote G4 offload"
);
}
}
}
tracing::info!("Remote G4 offload task shutting down");
}
#[cfg(test)]
mod tests {
use super::*;
// Note: Full tests require complex infrastructure setup (InstanceLeader, BlockManagers, etc.)
// Basic API tests here.
#[test]
fn test_transfer_id_generation() {
let id1 = TransferId::new();
let id2 = TransferId::new();
assert_ne!(id1, id2);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer handle and status tracking for offload operations.
//!
//! The `TransferHandle` is the user-facing interface for tracking and controlling
//! an offload transfer. It provides:
//! - Status tracking (Evaluating, Queued, Transferring, Complete, Cancelled)
//! - Block visibility (passed, completed, remaining)
//! - Cancellation with confirmation
use std::collections::HashSet;
use anyhow::Result;
use tokio::sync::watch;
use uuid::Uuid;
use crate::BlockId;
use super::cancel::{CancelConfirmation, CancelStateUpdater, CancellationToken};
/// Unique identifier for a transfer operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TransferId(Uuid);
impl TransferId {
/// Create a new random transfer ID.
pub fn new() -> Self {
TransferId(Uuid::new_v4())
}
/// Get the underlying UUID.
pub fn as_uuid(&self) -> Uuid {
self.0
}
}
impl Default for TransferId {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for TransferId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<Uuid> for TransferId {
fn from(uuid: Uuid) -> Self {
TransferId(uuid)
}
}
/// Status of a transfer operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferStatus {
/// Policy/filter evaluation in progress
Evaluating,
/// Passed filters, waiting in batch queue
Queued,
/// Transfer operation in progress
Transferring,
/// Transfer completed successfully
Complete,
/// Transfer was cancelled
Cancelled,
/// Transfer failed with error
Failed,
}
impl TransferStatus {
/// Check if the transfer is in a terminal state.
pub fn is_terminal(&self) -> bool {
matches!(
self,
TransferStatus::Complete | TransferStatus::Cancelled | TransferStatus::Failed
)
}
/// Check if the transfer is still in progress.
pub fn is_active(&self) -> bool {
!self.is_terminal()
}
}
/// Result of a completed transfer.
#[derive(Debug, Clone)]
pub struct TransferResult {
/// Transfer ID
pub id: TransferId,
/// Final status
pub status: TransferStatus,
/// Blocks that passed all filters
pub passed_blocks: Vec<BlockId>,
/// Blocks successfully transferred
pub completed_blocks: Vec<BlockId>,
/// Blocks that failed transfer
pub failed_blocks: Vec<BlockId>,
/// Blocks that were filtered out
pub filtered_blocks: Vec<BlockId>,
/// Error message if failed
pub error: Option<String>,
}
/// Handle for tracking and controlling an offload transfer.
///
/// Obtained from `OffloadEngine::enqueue()`. Use this to:
/// - Monitor transfer progress via `status()`, `passed_blocks()`, etc.
/// - Cancel the transfer via `cancel()` and await confirmation
/// - Wait for completion via `wait()`
#[derive(Clone)]
pub struct TransferHandle {
id: TransferId,
status_rx: watch::Receiver<TransferStatus>,
passed_blocks_rx: watch::Receiver<Vec<BlockId>>,
completed_rx: watch::Receiver<Vec<BlockId>>,
failed_rx: watch::Receiver<Vec<BlockId>>,
remaining_rx: watch::Receiver<Vec<BlockId>>,
cancel_token: CancellationToken,
result_rx: watch::Receiver<Option<TransferResult>>,
}
impl TransferHandle {
/// Get the transfer ID.
pub fn id(&self) -> TransferId {
self.id
}
/// Get the current transfer status.
pub fn status(&self) -> TransferStatus {
*self.status_rx.borrow()
}
/// Get blocks that passed all filter policies.
pub fn passed_blocks(&self) -> Vec<BlockId> {
self.passed_blocks_rx.borrow().clone()
}
/// Get blocks that have been successfully transferred.
pub fn completed_blocks(&self) -> Vec<BlockId> {
self.completed_rx.borrow().clone()
}
/// Get blocks that failed transfer.
pub fn failed_blocks(&self) -> Vec<BlockId> {
self.failed_rx.borrow().clone()
}
/// Get blocks remaining to be transferred.
pub fn remaining_blocks(&self) -> Vec<BlockId> {
self.remaining_rx.borrow().clone()
}
/// Check if the transfer is complete (success, cancelled, or failed).
pub fn is_complete(&self) -> bool {
self.status().is_terminal()
}
/// Cancel the transfer and await confirmation.
///
/// Returns a future that resolves when all blocks are confirmed released
/// with no outstanding operations.
///
/// # Example
/// ```ignore
/// // Request cancellation and wait for confirmation
/// handle.cancel().wait().await;
/// // All blocks are now released
/// ```
pub fn cancel(&self) -> CancelConfirmation {
self.cancel_token.request();
self.cancel_token.wait_confirmed()
}
/// Check if cancellation has been requested.
pub fn is_cancelled(&self) -> bool {
self.cancel_token.is_requested()
}
/// Wait for the transfer to complete.
///
/// Returns the final `TransferResult` when the transfer reaches a terminal state.
pub async fn wait(&mut self) -> Result<TransferResult> {
// Wait until we have a result
loop {
{
let result = self.result_rx.borrow();
if let Some(r) = result.as_ref() {
return Ok(r.clone());
}
}
if self.result_rx.changed().await.is_err() {
// Channel closed without result
return Err(anyhow::anyhow!("Transfer channel closed unexpectedly"));
}
}
}
/// Subscribe to status changes.
pub fn subscribe_status(&self) -> watch::Receiver<TransferStatus> {
self.status_rx.clone()
}
}
impl std::fmt::Debug for TransferHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TransferHandle")
.field("id", &self.id)
.field("status", &self.status())
.field("passed_count", &self.passed_blocks().len())
.field("completed_count", &self.completed_blocks().len())
.field("failed_count", &self.failed_blocks().len())
.field("remaining_count", &self.remaining_blocks().len())
.finish()
}
}
/// Internal state for tracking a transfer through the pipeline.
#[allow(dead_code)]
pub(crate) struct TransferState {
pub(crate) id: TransferId,
/// Current phase
pub(crate) status: TransferStatus,
/// Original input block IDs
pub(crate) input_blocks: Vec<BlockId>,
/// Blocks that passed policy filters
pub(crate) passed_blocks: Vec<BlockId>,
/// Blocks currently in-flight (being transferred)
pub(crate) in_flight: HashSet<BlockId>,
/// Successfully transferred blocks
pub(crate) completed: Vec<BlockId>,
/// Blocks that failed transfer
pub(crate) failed: Vec<BlockId>,
/// Blocks that failed filters
pub(crate) filtered_out: Vec<BlockId>,
/// Error message if failed
pub(crate) error: Option<String>,
/// Notifier channels
pub(crate) notifiers: TransferNotifiers,
/// Cancel state updater
pub(crate) cancel_updater: CancelStateUpdater,
/// Total blocks expected in this transfer (set by PolicyEvaluator)
pub(crate) total_expected_blocks: usize,
/// Blocks that have been processed through policy evaluation (for sentinel flush)
pub(crate) blocks_processed: usize,
/// Precondition event that must be satisfied before processing this transfer.
/// Set by the caller when enqueuing offload operations. BatchCollector will
/// attach this to the TransferBatch, and PreconditionAwaiter will await it
/// before forwarding to TransferExecutor.
pub(crate) precondition: Option<velo::EventHandle>,
}
#[allow(dead_code)]
impl TransferState {
/// Create transfer state and associated handle.
pub(crate) fn new(id: TransferId, input_blocks: Vec<BlockId>) -> (Self, TransferHandle) {
let (status_tx, status_rx) = watch::channel(TransferStatus::Evaluating);
let (passed_tx, passed_rx) = watch::channel(Vec::new());
let (completed_tx, completed_rx) = watch::channel(Vec::new());
let (failed_tx, failed_rx) = watch::channel(Vec::new());
let (remaining_tx, remaining_rx) = watch::channel(input_blocks.clone());
let (result_tx, result_rx) = watch::channel(None);
let (cancel_token, cancel_updater) = CancellationToken::new();
let notifiers = TransferNotifiers {
status_tx,
passed_tx,
completed_tx,
failed_tx,
remaining_tx,
result_tx,
};
let state = TransferState {
id,
status: TransferStatus::Evaluating,
input_blocks: input_blocks.clone(),
passed_blocks: Vec::new(),
in_flight: HashSet::new(),
completed: Vec::new(),
failed: Vec::new(),
filtered_out: Vec::new(),
error: None,
notifiers,
cancel_updater,
total_expected_blocks: 0, // Set by PolicyEvaluator when transfer starts
blocks_processed: 0,
precondition: None, // Set by caller via enqueue_with_precondition
};
let handle = TransferHandle {
id,
status_rx,
passed_blocks_rx: passed_rx,
completed_rx,
failed_rx,
remaining_rx,
cancel_token,
result_rx,
};
(state, handle)
}
/// Check if cancellation has been requested.
pub(crate) fn is_cancel_requested(&self) -> bool {
self.cancel_updater.is_requested()
}
/// Update status and notify.
pub(crate) fn set_status(&mut self, status: TransferStatus) {
self.status = status;
let _ = self.notifiers.status_tx.send(status);
}
/// Add blocks that passed filters.
pub(crate) fn add_passed(&mut self, block_ids: impl IntoIterator<Item = BlockId>) {
self.passed_blocks.extend(block_ids);
let _ = self.notifiers.passed_tx.send(self.passed_blocks.clone());
self.update_remaining();
}
/// Add blocks that were filtered out.
pub(crate) fn add_filtered(&mut self, block_ids: impl IntoIterator<Item = BlockId>) {
self.filtered_out.extend(block_ids);
self.update_remaining();
}
/// Mark blocks as in-flight (being transferred).
pub(crate) fn mark_in_flight(&mut self, block_ids: impl IntoIterator<Item = BlockId>) {
self.in_flight.extend(block_ids);
}
/// Mark blocks as completed (transferred successfully).
pub(crate) fn mark_completed(&mut self, block_ids: impl IntoIterator<Item = BlockId>) {
for id in block_ids {
self.in_flight.remove(&id);
self.completed.push(id);
}
let _ = self.notifiers.completed_tx.send(self.completed.clone());
self.update_remaining();
}
/// Mark blocks as failed (transfer unsuccessful).
pub(crate) fn mark_failed(&mut self, block_ids: impl IntoIterator<Item = BlockId>) {
for id in block_ids {
self.in_flight.remove(&id);
self.failed.push(id);
}
let _ = self.notifiers.failed_tx.send(self.failed.clone());
self.update_remaining();
}
/// Update remaining blocks notification.
fn update_remaining(&self) {
let remaining: Vec<BlockId> = self
.passed_blocks
.iter()
.filter(|id| !self.completed.contains(id) && !self.failed.contains(id))
.copied()
.collect();
let _ = self.notifiers.remaining_tx.send(remaining);
}
/// Set error and mark as failed.
pub(crate) fn set_error(&mut self, error: String) {
self.error = Some(error);
self.set_status(TransferStatus::Failed);
self.finalize();
}
/// Mark as cancelled.
pub(crate) fn set_cancelled(&mut self) {
self.set_status(TransferStatus::Cancelled);
self.cancel_updater.set_confirmed();
self.finalize();
}
/// Mark as complete (all blocks transferred).
pub(crate) fn set_complete(&mut self) {
self.set_status(TransferStatus::Complete);
self.finalize();
}
/// Finalize and send result.
fn finalize(&mut self) {
let result = TransferResult {
id: self.id,
status: self.status,
passed_blocks: self.passed_blocks.clone(),
completed_blocks: self.completed.clone(),
failed_blocks: self.failed.clone(),
filtered_blocks: self.filtered_out.clone(),
error: self.error.clone(),
};
let _ = self.notifiers.result_tx.send(Some(result));
}
/// Get current in-flight count (for draining).
pub(crate) fn in_flight_count(&self) -> usize {
self.in_flight.len()
}
/// Begin draining (cancellation in progress).
pub(crate) fn begin_draining(&self) {
self.cancel_updater.set_draining(self.in_flight.len());
}
/// Update draining count.
pub(crate) fn update_draining(&self) {
self.cancel_updater.update_draining(self.in_flight.len());
}
}
/// Internal notification channels for transfer state updates.
#[allow(dead_code)]
pub(crate) struct TransferNotifiers {
pub(crate) status_tx: watch::Sender<TransferStatus>,
pub(crate) passed_tx: watch::Sender<Vec<BlockId>>,
pub(crate) completed_tx: watch::Sender<Vec<BlockId>>,
pub(crate) failed_tx: watch::Sender<Vec<BlockId>>,
pub(crate) remaining_tx: watch::Sender<Vec<BlockId>>,
pub(crate) result_tx: watch::Sender<Option<TransferResult>>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_transfer_id() {
let id1 = TransferId::new();
let id2 = TransferId::new();
assert_ne!(id1, id2);
}
#[test]
fn test_transfer_status() {
assert!(!TransferStatus::Evaluating.is_terminal());
assert!(!TransferStatus::Queued.is_terminal());
assert!(!TransferStatus::Transferring.is_terminal());
assert!(TransferStatus::Complete.is_terminal());
assert!(TransferStatus::Cancelled.is_terminal());
assert!(TransferStatus::Failed.is_terminal());
}
#[test]
fn test_transfer_state_creation() {
let id = TransferId::new();
let blocks = vec![1, 2, 3];
let (state, handle) = TransferState::new(id, blocks.clone());
assert_eq!(state.id, id);
assert_eq!(state.status, TransferStatus::Evaluating);
assert_eq!(state.input_blocks, blocks);
assert!(state.passed_blocks.is_empty());
assert!(state.completed.is_empty());
assert_eq!(handle.id(), id);
assert_eq!(handle.status(), TransferStatus::Evaluating);
assert_eq!(handle.remaining_blocks(), blocks);
}
#[test]
fn test_transfer_state_progress() {
let id = TransferId::new();
let blocks = vec![1, 2, 3, 4, 5];
let (mut state, handle) = TransferState::new(id, blocks);
// Some blocks pass filters
state.add_passed(vec![1, 2, 3]);
state.add_filtered(vec![4, 5]);
assert_eq!(handle.passed_blocks(), vec![1, 2, 3]);
// Start transferring
state.set_status(TransferStatus::Transferring);
state.mark_in_flight(vec![1, 2]);
assert_eq!(handle.status(), TransferStatus::Transferring);
// Complete some
state.mark_completed(vec![1]);
assert_eq!(handle.completed_blocks(), vec![1]);
assert_eq!(state.in_flight_count(), 1);
// Complete rest
state.mark_completed(vec![2, 3]);
state.set_complete();
assert_eq!(handle.status(), TransferStatus::Complete);
assert_eq!(handle.completed_blocks(), vec![1, 2, 3]);
}
#[tokio::test]
async fn test_transfer_handle_wait() {
let id = TransferId::new();
let blocks = vec![1, 2, 3];
let (mut state, mut handle) = TransferState::new(id, blocks);
// Spawn task to complete the transfer
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
state.add_passed(vec![1, 2, 3]);
state.mark_completed(vec![1, 2, 3]);
state.set_complete();
});
// Wait for completion
let result = tokio::time::timeout(tokio::time::Duration::from_millis(100), handle.wait())
.await
.expect("Should complete within timeout")
.expect("Should succeed");
assert_eq!(result.status, TransferStatus::Complete);
assert_eq!(result.completed_blocks, vec![1, 2, 3]);
}
#[test]
fn test_mark_failed_removes_from_in_flight() {
let id = TransferId::new();
let blocks = vec![1, 2, 3];
let (mut state, handle) = TransferState::new(id, blocks);
state.add_passed(vec![1, 2, 3]);
state.mark_in_flight(vec![1, 2, 3]);
assert_eq!(state.in_flight_count(), 3);
state.mark_failed(vec![2]);
assert_eq!(state.in_flight_count(), 2);
assert_eq!(handle.failed_blocks(), vec![2]);
assert!(handle.completed_blocks().is_empty());
}
#[test]
fn test_mark_failed_updates_remaining() {
let id = TransferId::new();
let blocks = vec![1, 2, 3];
let (mut state, handle) = TransferState::new(id, blocks);
state.add_passed(vec![1, 2, 3]);
state.mark_in_flight(vec![1, 2, 3]);
// Fail block 2 — remaining should exclude it
state.mark_failed(vec![2]);
let remaining = handle.remaining_blocks();
assert!(remaining.contains(&1));
assert!(!remaining.contains(&2));
assert!(remaining.contains(&3));
}
#[test]
fn test_partial_failure_result() {
let id = TransferId::new();
let blocks = vec![1, 2, 3, 4, 5];
let (mut state, _handle) = TransferState::new(id, blocks);
state.add_passed(vec![1, 2, 3]);
state.add_filtered(vec![4, 5]);
state.mark_in_flight(vec![1, 2, 3]);
// Block 1 succeeds, block 2 fails, block 3 succeeds
state.mark_completed(vec![1, 3]);
state.mark_failed(vec![2]);
assert_eq!(state.completed, vec![1, 3]);
assert_eq!(state.failed, vec![2]);
assert_eq!(state.in_flight_count(), 0);
// Simulate the pipeline's terminal state logic
let total = state.passed_blocks.len() + state.filtered_out.len();
let done = state.completed.len() + state.failed.len() + state.filtered_out.len();
assert_eq!(done, total);
// With failures, should set_error not set_complete
let failed_count = state.failed.len();
assert!(failed_count > 0);
state.set_error(format!(
"{failed_count} blocks failed to transfer to object storage",
));
assert_eq!(state.status, TransferStatus::Failed);
}
#[tokio::test]
async fn test_partial_failure_wait_result() {
let id = TransferId::new();
let blocks = vec![1, 2, 3];
let (mut state, mut handle) = TransferState::new(id, blocks);
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
state.add_passed(vec![1, 2, 3]);
state.mark_in_flight(vec![1, 2, 3]);
state.mark_completed(vec![1, 3]);
state.mark_failed(vec![2]);
state.set_error("1 blocks failed to transfer to object storage".to_string());
});
let result = tokio::time::timeout(tokio::time::Duration::from_millis(100), handle.wait())
.await
.expect("Should complete within timeout")
.expect("Should succeed");
assert_eq!(result.status, TransferStatus::Failed);
assert_eq!(result.completed_blocks, vec![1, 3]);
assert_eq!(result.failed_blocks, vec![2]);
assert!(result.error.is_some());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Offload Engine for asynchronous block transfers between storage tiers.
//!
//! The offload engine provides a policy-based, cancellable pipeline for moving
//! blocks from higher-performance tiers (G1/G2) to lower-cost tiers (G3/G4).
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │ OffloadEngine │
//! │ │
//! │ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │
//! │ │G1→G2 Pipeline │────│ G2→G3 Pipeline│ │ G2→G4 Pipeline│ │
//! │ └───────────────┘ └───────────────┘ └───────────────┘ │
//! │ │ │ │ │
//! │ └─────────auto_chain──┘ │ │
//! │ │
//! └─────────────────────────────────────────────────────────────────┘
//!
//! Pipeline stages:
//! ┌─────────────┐ ┌────────────────┐ ┌──────────────────┐
//! │ Policy │───▶│ Batch │───▶│ Transfer │
//! │ Evaluator │ │ Collector │ │ Executor │
//! └─────────────┘ └────────────────┘ └──────────────────┘
//! │ │ │
//! ▼ ▼ ▼
//! cancel check cancel check wait for in-flight
//! ```
//!
//! # Features
//!
//! - **Policy-based filtering**: Blocks pass through configurable policies
//! (presence checks, LFU thresholds) before transfer
//! - **Batched transfers**: Blocks are accumulated into batches for efficient
//! bulk transfers
//! - **Cancellation**: Clean cancellation with confirmation that all blocks
//! are released and no outstanding operations remain
//! - **Pipeline chaining**: G1→G2 completions can automatically feed G2→G3
//!
//! See also: [Developer Guide](../../docs/offload-developer.md) for implementation
//! details and extension rules.
//!
//! # Example
//!
//! ```ignore
//! use kvbm::v2::distributed::offload::{
//! OffloadEngine, PipelineBuilder, PresenceFilter, PresenceAndLFUFilter,
//! };
//!
//! // Build engine with pipelines
//! let engine = OffloadEngine::builder(leader.clone())
//! .with_registry(registry.clone())
//! .with_g2_manager(g2_manager.clone())
//! .with_g3_manager(g3_manager.clone())
//! .with_g2_to_g3_pipeline(
//! PipelineBuilder::<G2, G3>::new()
//! .policy(Arc::new(PresenceAndLFUFilter::with_default_threshold(registry.clone())))
//! .batch_size(64)
//! .build()
//! )
//! .build()?;
//!
//! // Enqueue blocks for offload
//! let handle = engine.enqueue_g2_to_g3(blocks)?;
//!
//! // Wait for completion or cancel
//! tokio::select! {
//! result = handle.wait() => {
//! println!("Completed: {:?}", result?.completed_blocks);
//! }
//! _ = shutdown_signal => {
//! handle.cancel().wait().await;
//! println!("Cancelled");
//! }
//! }
//! ```
//!
//! See also: [Developer Guide](../../docs/offload-developer.md)
/// Helper macro to create an NVTX range when the nvtx feature is enabled.
/// The range automatically ends when the returned guard is dropped.
macro_rules! nvtx_range {
($name:expr) => {{
#[cfg(feature = "nvtx")]
let _range = nvtx::range!($name);
#[cfg(not(feature = "nvtx"))]
let _range = ();
_range
}};
}
mod batch;
mod cancel;
mod engine;
mod handle;
mod pending;
mod pipeline;
mod policy;
mod queue;
mod source;
#[cfg(test)]
mod cancel_tests;
// Re-export public API
pub use cancel::{CancelConfirmation, CancelState, CancellationToken};
pub use engine::{OffloadEngine, OffloadEngineBuilder};
pub use handle::{TransferHandle, TransferId, TransferResult, TransferStatus};
pub use pending::{PendingGuard, PendingTracker};
pub use pipeline::{
ObjectPipeline, ObjectPipelineBuilder, ObjectPipelineConfig, Pipeline, PipelineBuilder,
PipelineConfig, ResolvedBatch, ResolvedBlock, upgrade_batch,
};
pub use policy::{
AllOfPolicy, AnyOfPolicy, BoxFuture, EvalContext, ObjectLockPresenceFilter,
ObjectPresenceFilter, OffloadPolicy, PassAllPolicy, PolicyBatchFuture, PolicyFuture,
PresenceAndLFUFilter, PresenceChecker, PresenceFilter, S3PresenceChecker, async_batch_result,
async_result, create_policy_from_config, sync_batch_result, sync_result,
};
pub use queue::CancellableQueue;
pub use source::{ExternalBlock, SourceBlock, SourceBlocks};
// Re-export batch config for advanced users
pub use batch::{BatchConfig, TimingTrace};
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Pending transfer tracking for duplicate prevention.
//!
//! This module provides `PendingTracker` and `PendingGuard` types that work together
//! to track blocks that are currently in-flight through the transfer pipeline.
//!
//! # Problem
//!
//! When overlapping sequences are enqueued for transfer at roughly the same time,
//! the presence policy may allow duplicate transfers because:
//! - The first sequence's blocks haven't completed registration yet
//! - The second sequence sees the same blocks as "not present"
//!
//! # Solution
//!
//! The `PendingTracker` maintains a set of sequence hashes currently in the pipeline.
//! When blocks pass policy evaluation, a `PendingGuard` is created that:
//! - Adds the sequence hash to the pending set on creation
//! - Automatically removes it on drop (RAII pattern)
//!
//! The `PresenceFilter` can then check both the registry (completed transfers)
//! AND the pending set (in-flight transfers) to avoid duplicates.
//!
//! # Example
//!
//! ```ignore
//! let tracker = Arc::new(PendingTracker::new());
//!
//! // Create guard when block passes policy
//! let guard = tracker.guard(sequence_hash);
//!
//! // Guard travels with block through pipeline stages
//! queued_block.pending_guard = Some(guard);
//!
//! // When block completes or is cancelled, guard is dropped
//! // and hash is automatically removed from pending set
//! ```
use std::sync::Arc;
use dashmap::DashSet;
use crate::SequenceHash;
/// Tracks sequence hashes that are currently pending transfer.
///
/// This is shared between the pipeline and the presence policy via `Arc`.
/// Thread-safe for concurrent access from multiple pipeline stages.
#[derive(Debug, Default)]
pub struct PendingTracker {
pending: DashSet<SequenceHash>,
}
impl PendingTracker {
/// Create a new empty pending tracker.
pub fn new() -> Self {
Self {
pending: DashSet::new(),
}
}
/// Check if a sequence hash is currently pending transfer.
///
/// Used by `PresenceFilter` to skip blocks that are already in-flight.
pub fn is_pending(&self, hash: &SequenceHash) -> bool {
self.pending.contains(hash)
}
/// Get the number of pending transfers.
///
/// Useful for metrics and debugging.
pub fn len(&self) -> usize {
self.pending.len()
}
/// Check if there are no pending transfers.
pub fn is_empty(&self) -> bool {
self.pending.is_empty()
}
/// Create a guard that marks a sequence hash as pending until dropped.
///
/// The guard uses RAII to ensure the hash is removed when:
/// - Transfer completes successfully
/// - Transfer is cancelled
/// - Block is evicted from pipeline
/// - Any error causes the block to be dropped
pub fn guard(self: &Arc<Self>, hash: SequenceHash) -> PendingGuard {
self.pending.insert(hash);
PendingGuard {
hash,
tracker: Arc::clone(self),
}
}
}
/// Extension trait for `Option<Arc<PendingTracker>>` to simplify pending checks.
///
/// Reduces the common pattern `self.pending_tracker.as_ref().is_some_and(|t| t.is_pending(&hash))`
/// to a single method call.
pub(crate) trait PendingCheck {
fn is_hash_pending(&self, hash: &SequenceHash) -> bool;
}
impl PendingCheck for Option<Arc<PendingTracker>> {
fn is_hash_pending(&self, hash: &SequenceHash) -> bool {
self.as_ref().is_some_and(|t| t.is_pending(hash))
}
}
/// RAII guard that removes a sequence hash from the pending set on drop.
///
/// This guard travels with the block through all pipeline stages and ensures
/// cleanup happens automatically regardless of how the transfer completes.
///
/// # Clone Behavior
///
/// Cloning a `PendingGuard` is cheap (Arc clone) but does NOT create a new
/// pending entry. The hash is only inserted once when the first guard is
/// created, and removed when ALL clones are dropped.
///
/// However, the current implementation removes on first drop, so cloning
/// should be avoided unless you understand the implications.
pub struct PendingGuard {
hash: SequenceHash,
tracker: Arc<PendingTracker>,
}
impl PendingGuard {
/// Get the sequence hash this guard is tracking.
pub fn sequence_hash(&self) -> SequenceHash {
self.hash
}
}
impl Drop for PendingGuard {
fn drop(&mut self) {
self.tracker.pending.remove(&self.hash);
}
}
impl std::fmt::Debug for PendingGuard {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PendingGuard")
.field("sequence_hash", &self.hash)
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Helper to create a test SequenceHash with unique values.
fn test_hash(id: u64) -> SequenceHash {
SequenceHash::new(id, Some(0), id)
}
#[test]
fn test_pending_tracker_new() {
let tracker = PendingTracker::new();
assert!(tracker.is_empty());
assert_eq!(tracker.len(), 0);
}
#[test]
fn test_pending_guard_inserts_and_removes() {
let tracker = Arc::new(PendingTracker::new());
let hash = test_hash(12345);
assert!(!tracker.is_pending(&hash));
{
let _guard = tracker.guard(hash);
assert!(tracker.is_pending(&hash));
assert_eq!(tracker.len(), 1);
}
// Guard dropped, hash should be removed
assert!(!tracker.is_pending(&hash));
assert!(tracker.is_empty());
}
#[test]
fn test_multiple_guards_different_hashes() {
let tracker = Arc::new(PendingTracker::new());
let hash1 = test_hash(111);
let hash2 = test_hash(222);
let hash3 = test_hash(333);
let guard1 = tracker.guard(hash1);
let guard2 = tracker.guard(hash2);
assert!(tracker.is_pending(&hash1));
assert!(tracker.is_pending(&hash2));
assert!(!tracker.is_pending(&hash3));
assert_eq!(tracker.len(), 2);
drop(guard1);
assert!(!tracker.is_pending(&hash1));
assert!(tracker.is_pending(&hash2));
assert_eq!(tracker.len(), 1);
drop(guard2);
assert!(tracker.is_empty());
}
#[test]
fn test_guard_sequence_hash_accessor() {
let tracker = Arc::new(PendingTracker::new());
let hash = test_hash(42);
let guard = tracker.guard(hash);
assert_eq!(guard.sequence_hash(), hash);
}
#[test]
fn test_tracker_debug() {
let tracker = PendingTracker::new();
let debug_str = format!("{:?}", tracker);
assert!(debug_str.contains("PendingTracker"));
}
#[test]
fn test_guard_debug() {
let tracker = Arc::new(PendingTracker::new());
let hash = test_hash(999);
let guard = tracker.guard(hash);
let debug_str = format!("{:?}", guard);
assert!(debug_str.contains("PendingGuard"));
assert!(debug_str.contains("sequence_hash"));
}
#[test]
fn test_concurrent_access_to_same_hash() {
// Test that the same hash being added twice is handled correctly
let tracker = Arc::new(PendingTracker::new());
let hash = test_hash(555);
// First guard marks it as pending
let guard1 = tracker.guard(hash);
assert!(tracker.is_pending(&hash));
assert_eq!(tracker.len(), 1);
// Second guard for same hash - DashSet.insert returns false if already present
// but our guard() always inserts (doesn't check first)
let guard2 = tracker.guard(hash);
assert!(tracker.is_pending(&hash));
// DashSet deduplicates, so len is still 1
assert_eq!(tracker.len(), 1);
// Drop first guard - hash removed from set
drop(guard1);
// DashSet now doesn't have the hash
assert!(!tracker.is_pending(&hash));
// Second guard still exists but hash was already removed
// This is expected behavior - the RAII ensures cleanup on any drop
drop(guard2);
assert!(tracker.is_empty());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Pipeline coordination for offload transfers.
//!
//! A pipeline connects these stages:
//! 1. **PolicyEvaluator**: Evaluates blocks against policies, filters out non-passing blocks
//! 2. **BatchCollector**: Accumulates passing blocks into batches
//! 3. **PreconditionAwaiter**: Awaits precondition events before processing
//! 4. **BlockUpgrader**: Upgrades `WeakBlock` → `ImmutableBlock` (via `upgrade_batch`)
//! 5. **Transfer Executor**: Executes the actual data transfer
//! - `BlockTransferExecutor`: For BlockManager destinations (G2, G3)
//! - `ObjectTransferExecutor`: For object storage destinations (G4)
//!
//! # Cancellation Architecture
//!
//! Unlike mpsc-based pipelines where cancellation only happens at dequeue boundaries,
//! this implementation uses [`CancellableQueue`] which enables a dedicated sweeper task
//! to actively remove items from cancelled transfers. This ensures that `ImmutableBlock`
//! guards are dropped promptly when a transfer is cancelled.
//!
//! ```text
//! enqueue() ─┬─► [CancellableQueue A] ──► PolicyEvaluator ──┬─► [CancellableQueue B] ──► ...
//! │ │
//! └──────────────► [CancelSweeper] ◄─────────────┘
//! │
//! (iterates queues,
//! removes by TransferId,
//! drops ImmutableBlock guards)
//! ```
use std::collections::HashSet;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::{Duration, Instant};
use futures::future::Either;
use tokio::sync::{Semaphore, mpsc, watch};
use tokio::task::JoinHandle;
use crate::leader::InstanceLeader;
use crate::object::ObjectBlockOps;
use crate::{BlockId, SequenceHash};
use kvbm_common::LogicalLayoutHandle;
use kvbm_logical::blocks::{BlockMetadata, BlockRegistry, ImmutableBlock};
use kvbm_logical::manager::BlockManager;
use kvbm_physical::transfer::TransferOptions;
use super::batch::{
BatchCollector, BatchConfig, BatchOutputRx, EvalResult, QueuedBlock, TimingTrace, TransferBatch,
};
use super::handle::{TransferId, TransferState, TransferStatus};
use super::pending::PendingTracker;
use super::policy::{EvalContext, OffloadPolicy};
use super::queue::CancellableQueue;
use super::source::{SourceBlock, SourceBlocks};
use crate::object::ObjectLockManager;
/// Configuration for a pipeline.
#[derive(Clone)]
pub struct PipelineConfig<Src: BlockMetadata, Dst: BlockMetadata> {
/// Policies to evaluate (all must pass)
pub policies: Vec<Arc<dyn OffloadPolicy<Src>>>,
/// Batch configuration
pub batch_config: BatchConfig,
/// Timeout for policy evaluation (fail-fast)
pub policy_timeout: Duration,
/// Whether arrivals from this pipeline auto-feed downstream
pub auto_chain: bool,
/// Channel capacity for evaluation input
pub eval_input_capacity: usize,
/// Channel capacity for batch input
pub batch_input_capacity: usize,
/// Channel capacity for transfer input
pub transfer_input_capacity: usize,
/// Sweep interval for cancellation task
pub sweep_interval: Duration,
/// Skip actual transfers (for testing)
pub skip_transfers: bool,
/// Maximum number of concurrent transfer batches.
///
/// This controls how many batches can be transferred simultaneously.
/// Setting this higher can improve throughput at the cost of memory.
/// Default: 1 (sequential execution)
pub max_concurrent_transfers: usize,
/// Pending tracker for duplicate prevention.
///
/// If provided, this tracker is used. If None, the pipeline creates its own.
/// Share this tracker with presence-based policies to prevent duplicate transfers.
pub pending_tracker: Option<Arc<PendingTracker>>,
/// Maximum number of concurrent precondition awaits.
///
/// This controls how many batches can be awaiting their preconditions simultaneously.
/// Allows multiple iterations to be in-flight without blocking the pipeline.
/// Default: 8 (allows ~8 iterations in-flight concurrently)
pub max_concurrent_precondition_awaits: usize,
/// Marker
_marker: PhantomData<(Src, Dst)>,
}
impl<Src: BlockMetadata, Dst: BlockMetadata> Default for PipelineConfig<Src, Dst> {
fn default() -> Self {
Self {
policies: Vec::new(),
batch_config: BatchConfig::default(),
policy_timeout: Duration::from_millis(100),
auto_chain: false,
eval_input_capacity: 128,
batch_input_capacity: 256,
transfer_input_capacity: 8,
sweep_interval: Duration::from_millis(10),
skip_transfers: false,
max_concurrent_transfers: 1,
pending_tracker: None,
max_concurrent_precondition_awaits: 8,
_marker: PhantomData,
}
}
}
/// Builder for pipeline configuration.
pub struct PipelineBuilder<Src: BlockMetadata, Dst: BlockMetadata> {
config: PipelineConfig<Src, Dst>,
}
impl<Src: BlockMetadata, Dst: BlockMetadata> PipelineBuilder<Src, Dst> {
/// Create a new pipeline builder with defaults.
pub fn new() -> Self {
Self {
config: PipelineConfig::default(),
}
}
/// Add a policy to the pipeline.
pub fn policy(mut self, policy: Arc<dyn OffloadPolicy<Src>>) -> Self {
self.config.policies.push(policy);
self
}
/// Set batch size.
pub fn batch_size(mut self, size: usize) -> Self {
self.config.batch_config.max_batch_size = size;
self
}
/// Set minimum batch size for flush.
pub fn min_batch_size(mut self, size: usize) -> Self {
self.config.batch_config.min_batch_size = size;
self
}
/// Set batch flush interval.
pub fn flush_interval(mut self, interval: Duration) -> Self {
self.config.batch_config.flush_interval = interval;
self
}
/// Set policy timeout.
pub fn policy_timeout(mut self, timeout: Duration) -> Self {
self.config.policy_timeout = timeout;
self
}
/// Enable auto-chaining to downstream pipelines.
pub fn auto_chain(mut self, enabled: bool) -> Self {
self.config.auto_chain = enabled;
self
}
/// Set the sweep interval for cancellation.
pub fn sweep_interval(mut self, interval: Duration) -> Self {
self.config.sweep_interval = interval;
self
}
/// Skip actual transfers (for testing).
///
/// When enabled, the transfer executor will mark blocks as completed
/// without executing actual data transfers.
pub fn skip_transfers(mut self, skip: bool) -> Self {
self.config.skip_transfers = skip;
self
}
/// Set maximum concurrent transfers.
///
/// This controls how many batches can be transferred simultaneously.
/// Must be at least 1.
///
/// # Default
/// 1 (sequential execution)
pub fn max_concurrent_transfers(mut self, n: usize) -> Self {
self.config.max_concurrent_transfers = n.max(1);
self
}
/// Set the pending tracker for duplicate prevention.
///
/// Share this tracker with presence-based policies (via `create_policy_from_config`)
/// to prevent duplicate transfers when overlapping sequences are enqueued.
pub fn pending_tracker(mut self, tracker: Arc<PendingTracker>) -> Self {
self.config.pending_tracker = Some(tracker);
self
}
/// Build the configuration.
pub fn build(self) -> PipelineConfig<Src, Dst> {
self.config
}
}
impl<Src: BlockMetadata, Dst: BlockMetadata> Default for PipelineBuilder<Src, Dst> {
fn default() -> Self {
Self::new()
}
}
/// Input to the pipeline (from enqueue).
pub(crate) struct PipelineInput<T: BlockMetadata> {
pub(crate) transfer_id: TransferId,
/// Source blocks - can be External, Strong, or Weak
pub(crate) source: SourceBlocks<T>,
pub(crate) state: Arc<std::sync::Mutex<TransferState>>,
}
/// Output from the pipeline (completed transfer).
pub struct PipelineOutput {
pub transfer_id: TransferId,
pub completed_hashes: Vec<SequenceHash>,
}
/// Chain output - carries registered blocks for downstream pipelines.
///
/// When `auto_chain` is enabled, the pipeline sends registered blocks
/// through this channel instead of dropping them. The receiving pipeline
/// can then process them through its own policy evaluation and transfer.
pub struct ChainOutput<T: BlockMetadata> {
pub transfer_id: TransferId,
pub blocks: Vec<ImmutableBlock<T>>,
/// State for transfer tracking (used when feeding downstream pipelines)
#[allow(dead_code)]
pub(crate) state: Arc<std::sync::Mutex<TransferState>>,
}
/// Receiver for chain output from a pipeline.
pub type ChainOutputRx<T> = mpsc::Receiver<ChainOutput<T>>;
/// A running pipeline instance.
pub struct Pipeline<Src: BlockMetadata, Dst: BlockMetadata> {
config: PipelineConfig<Src, Dst>,
/// Input queue for new blocks (CancellableQueue for sweep support)
pub(crate) eval_queue: Arc<CancellableQueue<PipelineInput<Src>>>,
/// Output channel for completed blocks (may feed downstream)
output_tx: Option<mpsc::Sender<PipelineOutput>>,
/// Chain output receiver - provides registered blocks for downstream pipelines
chain_rx: Option<ChainOutputRx<Dst>>,
/// Watch channel for cancelled transfer IDs (triggers sweep)
cancel_tx: watch::Sender<HashSet<TransferId>>,
/// Tracker for pending (in-flight) transfers to prevent duplicates
pending_tracker: Arc<PendingTracker>,
/// Task handles for pipeline stages
_task_handles: Vec<JoinHandle<()>>,
/// Marker
_marker: PhantomData<Dst>,
}
impl<Src: BlockMetadata, Dst: BlockMetadata> Pipeline<Src, Dst> {
/// Create a new pipeline with the given configuration.
///
/// # Arguments
/// * `config` - Pipeline configuration
/// * `registry` - Block registry for policy evaluation
/// * `dst_manager` - Destination tier block manager
/// * `leader` - Instance leader for transfer execution
/// * `src_layout` - Source logical layout handle
/// * `dst_layout` - Destination logical layout handle
/// * `runtime` - Tokio runtime handle for spawning background tasks
#[allow(clippy::too_many_arguments)]
pub fn new(
config: PipelineConfig<Src, Dst>,
_registry: Arc<BlockRegistry>,
dst_manager: Arc<BlockManager<Dst>>,
leader: Arc<InstanceLeader>,
src_layout: LogicalLayoutHandle,
dst_layout: LogicalLayoutHandle,
runtime: tokio::runtime::Handle,
) -> Self {
// Create cancellable queues
let eval_queue: Arc<CancellableQueue<PipelineInput<Src>>> =
Arc::new(CancellableQueue::new());
let batch_queue: Arc<CancellableQueue<EvalResult<Src>>> = Arc::new(CancellableQueue::new());
// Create output channel (still mpsc for downstream chaining)
let (output_tx, _output_rx) = mpsc::channel(64);
// Create watch channel for cancelled transfer IDs
let (cancel_tx, cancel_rx) = watch::channel(HashSet::new());
// Create batch output channel (BatchCollector → PreconditionAwaiter)
let (batch_tx, batch_rx) = mpsc::channel(config.transfer_input_capacity);
// Create precondition output channel (PreconditionAwaiter → TransferExecutor)
let (precond_tx, precond_rx) = mpsc::channel(config.transfer_input_capacity);
// Create chain output channel if auto_chain is enabled
let (chain_tx, chain_rx) = if config.auto_chain {
let (tx, rx) = mpsc::channel(64);
(Some(tx), Some(rx))
} else {
(None, None)
};
// Use provided pending tracker or create a new one
let pending_tracker = config
.pending_tracker
.clone()
.unwrap_or_else(|| Arc::new(PendingTracker::new()));
// Spawn policy evaluator
let evaluator = PolicyEvaluator {
policies: config.policies.clone(),
timeout: config.policy_timeout,
input_queue: eval_queue.clone(),
output_queue: batch_queue.clone(),
cancel_rx: cancel_rx.clone(),
pending_tracker: pending_tracker.clone(),
};
let eval_handle = runtime.spawn(async move {
evaluator.run().await;
});
// Spawn batch collector (reads from CancellableQueue, outputs to mpsc)
let collector_input_queue = batch_queue.clone();
let batch_config = config.batch_config.clone();
let collector_cancel_rx = cancel_rx.clone();
let batch_handle = runtime.spawn(async move {
let collector = BatchCollector::new(
batch_config,
collector_input_queue,
batch_tx,
collector_cancel_rx,
);
collector.run().await;
});
// Spawn precondition awaiter (reads from batch_rx, outputs to precond_tx)
let awaiter_leader = leader.clone();
let precond_handle = runtime.spawn(async move {
let awaiter = PreconditionAwaiter {
input_rx: batch_rx,
output_tx: precond_tx,
leader: awaiter_leader,
};
awaiter.run().await;
});
// Spawn block transfer executor (reads from precond_rx)
let executor = BlockTransferExecutor {
input_rx: precond_rx,
leader,
dst_manager,
src_layout,
dst_layout,
skip_transfers: config.skip_transfers,
max_concurrent_transfers: config.max_concurrent_transfers,
chain_tx,
_src_marker: PhantomData::<Src>,
};
let transfer_handle = runtime.spawn(async move {
executor.run().await;
});
// Spawn cancel sweeper
let sweeper_queues = vec![eval_queue.clone()];
let sweeper_batch_queue = batch_queue;
let sweeper_interval = config.sweep_interval;
let sweeper_cancel_rx = cancel_rx;
let sweeper_handle = runtime.spawn(async move {
cancel_sweeper(
sweeper_queues,
sweeper_batch_queue,
sweeper_cancel_rx,
sweeper_interval,
)
.await;
});
Self {
config,
eval_queue,
output_tx: Some(output_tx),
chain_rx,
cancel_tx,
pending_tracker,
_task_handles: vec![
eval_handle,
batch_handle,
precond_handle,
transfer_handle,
sweeper_handle,
],
_marker: PhantomData,
}
}
/// Enqueue blocks for offloading through this pipeline.
pub(crate) fn enqueue(
&self,
transfer_id: TransferId,
source: SourceBlocks<Src>,
state: Arc<std::sync::Mutex<TransferState>>,
) -> bool {
tracing::debug!(%transfer_id, num_blocks = source.len(), "Pipeline: enqueueing blocks");
let input = PipelineInput {
transfer_id,
source,
state,
};
self.eval_queue.push(transfer_id, input)
}
/// Request cancellation for a transfer.
///
/// This marks the transfer as cancelled in all queues, triggering the sweeper
/// to remove queued items and the evaluator/collector to skip them.
pub fn request_cancel(&self, transfer_id: TransferId) {
// Mark cancelled in queues
self.eval_queue.mark_cancelled(transfer_id);
// Notify sweeper via watch channel
self.cancel_tx.send_modify(|set| {
set.insert(transfer_id);
});
}
/// Check if this pipeline auto-chains to downstream.
pub fn auto_chain(&self) -> bool {
self.config.auto_chain
}
/// Get a clone of the output channel sender.
pub fn output_tx(&self) -> Option<mpsc::Sender<PipelineOutput>> {
self.output_tx.clone()
}
/// Take the chain output receiver for downstream pipeline feeding.
///
/// This transfers ownership of the receiver - can only be called once.
/// When `auto_chain` is enabled, this receiver will yield `ChainOutput<Dst>`
/// containing registered blocks that can be fed to a downstream pipeline.
///
/// # Returns
/// - `Some(rx)` if `auto_chain` is enabled and receiver hasn't been taken
/// - `None` if `auto_chain` is false or receiver was already taken
pub fn take_chain_rx(&mut self) -> Option<ChainOutputRx<Dst>> {
self.chain_rx.take()
}
/// Get the pending tracker for this pipeline.
///
/// This can be shared with presence policies to enable duplicate prevention
/// for blocks currently in-flight through this pipeline.
pub fn pending_tracker(&self) -> &Arc<PendingTracker> {
&self.pending_tracker
}
}
// ============================================================================
// Object Pipeline (for G4 / object storage destinations)
// ============================================================================
/// Configuration for an object storage pipeline.
///
/// Similar to `PipelineConfig` but designed for object storage destinations
/// that don't use a `BlockManager`. The destination is `ObjectBlockOps`.
#[derive(Clone)]
pub struct ObjectPipelineConfig<Src: BlockMetadata> {
/// Policies to evaluate (all must pass)
pub policies: Vec<Arc<dyn OffloadPolicy<Src>>>,
/// Batch configuration
pub batch_config: BatchConfig,
/// Timeout for policy evaluation (fail-fast)
pub policy_timeout: Duration,
/// Channel capacity for evaluation input
pub eval_input_capacity: usize,
/// Channel capacity for batch input
pub batch_input_capacity: usize,
/// Channel capacity for transfer input
pub transfer_input_capacity: usize,
/// Sweep interval for cancellation task
pub sweep_interval: Duration,
/// Skip actual transfers (for testing)
pub skip_transfers: bool,
/// Maximum number of concurrent transfer batches
pub max_concurrent_transfers: usize,
/// Pending tracker for duplicate prevention
pub pending_tracker: Option<Arc<PendingTracker>>,
/// Maximum concurrent precondition awaits
pub max_concurrent_precondition_awaits: usize,
/// Lock manager for distributed locking (optional)
///
/// When provided, the executor will:
/// - Create `.meta` files after successful transfers
/// - Release `.lock` files after transfer completion
pub lock_manager: Option<Arc<dyn ObjectLockManager>>,
/// Marker
_marker: PhantomData<Src>,
}
impl<Src: BlockMetadata> Default for ObjectPipelineConfig<Src> {
fn default() -> Self {
Self {
policies: Vec::new(),
batch_config: BatchConfig::default(),
policy_timeout: Duration::from_millis(100),
eval_input_capacity: 128,
batch_input_capacity: 256,
transfer_input_capacity: 8,
sweep_interval: Duration::from_millis(10),
skip_transfers: false,
max_concurrent_transfers: 1,
pending_tracker: None,
max_concurrent_precondition_awaits: 8,
lock_manager: None,
_marker: PhantomData,
}
}
}
/// Builder for object pipeline configuration.
pub struct ObjectPipelineBuilder<Src: BlockMetadata> {
config: ObjectPipelineConfig<Src>,
}
impl<Src: BlockMetadata> ObjectPipelineBuilder<Src> {
/// Create a new object pipeline builder with defaults.
pub fn new() -> Self {
Self {
config: ObjectPipelineConfig::default(),
}
}
/// Add a policy to the pipeline.
pub fn policy(mut self, policy: Arc<dyn OffloadPolicy<Src>>) -> Self {
self.config.policies.push(policy);
self
}
/// Set batch size.
pub fn batch_size(mut self, size: usize) -> Self {
self.config.batch_config.max_batch_size = size;
self
}
/// Set minimum batch size for flush.
pub fn min_batch_size(mut self, size: usize) -> Self {
self.config.batch_config.min_batch_size = size;
self
}
/// Set batch flush interval.
pub fn flush_interval(mut self, interval: Duration) -> Self {
self.config.batch_config.flush_interval = interval;
self
}
/// Set policy timeout.
pub fn policy_timeout(mut self, timeout: Duration) -> Self {
self.config.policy_timeout = timeout;
self
}
/// Set the sweep interval for cancellation.
pub fn sweep_interval(mut self, interval: Duration) -> Self {
self.config.sweep_interval = interval;
self
}
/// Skip actual transfers (for testing).
pub fn skip_transfers(mut self, skip: bool) -> Self {
self.config.skip_transfers = skip;
self
}
/// Set maximum concurrent transfers.
pub fn max_concurrent_transfers(mut self, n: usize) -> Self {
self.config.max_concurrent_transfers = n.max(1);
self
}
/// Set the pending tracker for duplicate prevention.
pub fn pending_tracker(mut self, tracker: Arc<PendingTracker>) -> Self {
self.config.pending_tracker = Some(tracker);
self
}
/// Set the lock manager for distributed locking.
///
/// When provided, the executor will create `.meta` files after successful
/// transfers and release `.lock` files after completion.
pub fn lock_manager(mut self, manager: Arc<dyn ObjectLockManager>) -> Self {
self.config.lock_manager = Some(manager);
self
}
/// Build the configuration.
pub fn build(self) -> ObjectPipelineConfig<Src> {
self.config
}
}
impl<Src: BlockMetadata> Default for ObjectPipelineBuilder<Src> {
fn default() -> Self {
Self::new()
}
}
/// A running pipeline instance for object storage destinations.
///
/// Similar to `Pipeline` but uses `ObjectTransferExecutor` for G4 (object storage)
/// instead of `BlockTransferExecutor`. There is no destination `BlockManager`.
#[allow(dead_code)]
pub struct ObjectPipeline<Src: BlockMetadata> {
config: ObjectPipelineConfig<Src>,
/// Input queue for new blocks (CancellableQueue for sweep support)
pub(crate) eval_queue: Arc<CancellableQueue<PipelineInput<Src>>>,
/// Output channel for completed blocks
output_tx: Option<mpsc::Sender<PipelineOutput>>,
/// Watch channel for cancelled transfer IDs (triggers sweep)
cancel_tx: watch::Sender<HashSet<TransferId>>,
/// Tracker for pending (in-flight) transfers to prevent duplicates
pending_tracker: Arc<PendingTracker>,
/// Task handles for pipeline stages
_task_handles: Vec<JoinHandle<()>>,
}
impl<Src: BlockMetadata> ObjectPipeline<Src> {
/// Create a new object pipeline with the given configuration.
///
/// # Arguments
/// * `config` - Pipeline configuration
/// * `object_ops` - Object storage operations (e.g., S3 client)
/// * `src_layout` - Source physical layout for reading block data
/// * `leader` - Instance leader for precondition events
/// * `runtime` - Tokio runtime handle for spawning background tasks
#[allow(clippy::too_many_arguments)]
pub fn new(
config: ObjectPipelineConfig<Src>,
object_ops: Arc<dyn ObjectBlockOps>,
src_layout: LogicalLayoutHandle,
leader: Arc<InstanceLeader>,
runtime: tokio::runtime::Handle,
) -> Self {
// Create cancellable queues
let eval_queue: Arc<CancellableQueue<PipelineInput<Src>>> =
Arc::new(CancellableQueue::new());
let batch_queue: Arc<CancellableQueue<EvalResult<Src>>> = Arc::new(CancellableQueue::new());
// Create output channel
let (output_tx, _output_rx) = mpsc::channel(64);
// Create watch channel for cancelled transfer IDs
let (cancel_tx, cancel_rx) = watch::channel(HashSet::new());
// Create batch output channel (BatchCollector → PreconditionAwaiter)
let (batch_tx, batch_rx) = mpsc::channel(config.transfer_input_capacity);
// Create precondition output channel (PreconditionAwaiter → ObjectTransferExecutor)
let (precond_tx, precond_rx) = mpsc::channel(config.transfer_input_capacity);
// Use provided pending tracker or create a new one
let pending_tracker = config
.pending_tracker
.clone()
.unwrap_or_else(|| Arc::new(PendingTracker::new()));
// Spawn policy evaluator
let evaluator = PolicyEvaluator {
policies: config.policies.clone(),
timeout: config.policy_timeout,
input_queue: eval_queue.clone(),
output_queue: batch_queue.clone(),
cancel_rx: cancel_rx.clone(),
pending_tracker: pending_tracker.clone(),
};
let eval_handle = runtime.spawn(async move {
evaluator.run().await;
});
// Spawn batch collector
let collector_input_queue = batch_queue.clone();
let batch_config = config.batch_config.clone();
let collector_cancel_rx = cancel_rx.clone();
let batch_handle = runtime.spawn(async move {
let collector = BatchCollector::new(
batch_config,
collector_input_queue,
batch_tx,
collector_cancel_rx,
);
collector.run().await;
});
// Spawn precondition awaiter
let awaiter_leader = leader.clone();
let precond_handle = runtime.spawn(async move {
let awaiter = PreconditionAwaiter {
input_rx: batch_rx,
output_tx: precond_tx,
leader: awaiter_leader,
};
awaiter.run().await;
});
// Spawn object transfer executor
let executor = ObjectTransferExecutor::new(
precond_rx,
object_ops,
src_layout,
config.skip_transfers,
config.max_concurrent_transfers,
config.lock_manager.clone(),
);
let transfer_handle = runtime.spawn(async move {
executor.run().await;
});
// Spawn cancel sweeper
let sweeper_queues = vec![eval_queue.clone()];
let sweeper_batch_queue = batch_queue;
let sweeper_interval = config.sweep_interval;
let sweeper_cancel_rx = cancel_rx;
let sweeper_handle = runtime.spawn(async move {
cancel_sweeper(
sweeper_queues,
sweeper_batch_queue,
sweeper_cancel_rx,
sweeper_interval,
)
.await;
});
Self {
config,
eval_queue,
output_tx: Some(output_tx),
cancel_tx,
pending_tracker,
_task_handles: vec![
eval_handle,
batch_handle,
precond_handle,
transfer_handle,
sweeper_handle,
],
}
}
/// Enqueue blocks for offloading through this pipeline.
pub(crate) fn enqueue(
&self,
transfer_id: TransferId,
source: SourceBlocks<Src>,
state: Arc<std::sync::Mutex<TransferState>>,
) -> bool {
tracing::debug!(%transfer_id, num_blocks = source.len(), "ObjectPipeline: enqueueing blocks");
let input = PipelineInput {
transfer_id,
source,
state,
};
self.eval_queue.push(transfer_id, input)
}
/// Request cancellation for a transfer.
pub fn request_cancel(&self, transfer_id: TransferId) {
self.eval_queue.mark_cancelled(transfer_id);
self.cancel_tx.send_modify(|set| {
set.insert(transfer_id);
});
}
/// Get a clone of the output channel sender.
#[allow(dead_code)]
pub fn output_tx(&self) -> Option<mpsc::Sender<PipelineOutput>> {
self.output_tx.clone()
}
/// Get the pending tracker for this pipeline.
pub fn pending_tracker(&self) -> &Arc<PendingTracker> {
&self.pending_tracker
}
}
/// Sweeper task that removes cancelled items from queues.
async fn cancel_sweeper<Src: BlockMetadata>(
input_queues: Vec<Arc<CancellableQueue<PipelineInput<Src>>>>,
batch_queue: Arc<CancellableQueue<EvalResult<Src>>>,
mut cancel_rx: watch::Receiver<HashSet<TransferId>>,
interval: Duration,
) {
let mut ticker = tokio::time::interval(interval);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
_ = ticker.tick() => {
// Sweep all queues
for queue in &input_queues {
let removed = queue.sweep();
if removed > 0 {
tracing::debug!("Sweeper removed {} cancelled input items", removed);
}
}
let batch_removed = batch_queue.sweep();
if batch_removed > 0 {
tracing::debug!("Sweeper removed {} cancelled batch items", batch_removed);
}
}
result = cancel_rx.changed() => {
if result.is_err() {
// Channel closed, shutdown
break;
}
// New cancellation added, sweep immediately
for queue in &input_queues {
queue.sweep();
}
batch_queue.sweep();
}
}
}
}
/// Policy evaluator stage.
struct PolicyEvaluator<T: BlockMetadata> {
policies: Vec<Arc<dyn OffloadPolicy<T>>>,
timeout: Duration,
input_queue: Arc<CancellableQueue<PipelineInput<T>>>,
output_queue: Arc<CancellableQueue<EvalResult<T>>>,
cancel_rx: watch::Receiver<HashSet<TransferId>>,
/// Tracker for pending transfers - guards are created when blocks pass policy
pending_tracker: Arc<PendingTracker>,
}
impl<T: BlockMetadata> PolicyEvaluator<T> {
async fn run(self) {
let mut poll_interval = tokio::time::interval(Duration::from_micros(100));
poll_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
// Poll for items
if let Some(item) = self.input_queue.pop_valid() {
self.evaluate(item.data).await;
} else {
// No items available, wait a bit
poll_interval.tick().await;
}
// Check for shutdown (cancel channel closed)
if self.cancel_rx.has_changed().is_err() {
break;
}
}
}
async fn evaluate(&self, input: PipelineInput<T>) {
nvtx_range!("offload::policy");
let transfer_id = input.transfer_id;
// Set total_expected_blocks for per-transfer sentinel flush
let total_blocks = input.source.len();
{
let mut state = input.state.lock().unwrap();
state.total_expected_blocks = total_blocks;
}
// Check if already cancelled (via queue or via handle)
{
let state = input.state.lock().unwrap();
if state.is_cancel_requested() {
drop(state); // Release lock before calling set_cancelled
tracing::debug!(%transfer_id, "Transfer cancelled before evaluation");
let mut state = input.state.lock().unwrap();
state.set_cancelled();
return;
}
}
let mut passed = Vec::new();
let mut filtered = Vec::new();
// Process blocks based on source type
match input.source {
SourceBlocks::External(external_blocks) => {
// External blocks (e.g., G1 from vLLM) still need policy evaluation
// to check presence in destination tier
for ext in external_blocks {
// Check for cancellation between blocks
if self.check_cancelled(&input.state, transfer_id) {
return;
}
// Create context with sequence_hash - block_id is known for External
let ctx = EvalContext::from_external(ext.block_id, ext.sequence_hash);
let pass = self.evaluate_policies(&ctx).await;
if pass {
// Create pending guard for duplicate prevention
let pending_guard = self.pending_tracker.guard(ext.sequence_hash);
passed.push(QueuedBlock {
transfer_id,
block_id: Some(ext.block_id),
sequence_hash: ext.sequence_hash,
source: SourceBlock::External(ext),
state: input.state.clone(),
pending_guard: Some(pending_guard),
});
} else {
filtered.push(ext.block_id);
}
}
tracing::debug!(%transfer_id, passed = passed.len(), filtered = filtered.len(), "External blocks evaluated");
}
SourceBlocks::Strong(strong_blocks) => {
// Strong blocks get full policy evaluation
for block in strong_blocks {
// Check for cancellation between blocks
if self.check_cancelled(&input.state, transfer_id) {
return;
}
let ctx = EvalContext::new(block);
let pass = self.evaluate_policies(&ctx).await;
if pass {
let block = ctx.block.expect("Strong block context always has block");
// Create pending guard for duplicate prevention
let pending_guard = self.pending_tracker.guard(ctx.sequence_hash);
passed.push(QueuedBlock {
transfer_id,
block_id: Some(ctx.block_id),
sequence_hash: ctx.sequence_hash,
source: SourceBlock::Strong(block),
state: input.state.clone(),
pending_guard: Some(pending_guard),
});
} else {
filtered.push(ctx.block_id);
}
}
}
SourceBlocks::Weak(weak_blocks) => {
// Weak blocks get policy evaluation using metadata (deferred upgrade)
// block_id is unknown until upgrade at transfer time
for weak in weak_blocks {
// Check for cancellation between blocks
if self.check_cancelled(&input.state, transfer_id) {
return;
}
let sequence_hash = weak.sequence_hash();
let ctx = EvalContext::from_weak(BlockId::default(), sequence_hash);
let pass = self.evaluate_policies(&ctx).await;
if pass {
// Create pending guard for duplicate prevention
let pending_guard = self.pending_tracker.guard(sequence_hash);
passed.push(QueuedBlock {
transfer_id,
block_id: None, // Determined at upgrade time
sequence_hash,
source: SourceBlock::Weak(weak),
state: input.state.clone(),
pending_guard: Some(pending_guard),
});
} else {
// For weak blocks, we track by sequence_hash since block_id is unknown
// We'll add sequence_hash tracking in TransferState if needed
tracing::debug!(%transfer_id, ?sequence_hash, "Weak block filtered by policy");
}
}
}
}
// Check for cancellation after evaluation
{
let state = input.state.lock().unwrap();
if state.is_cancel_requested() {
drop(state);
tracing::debug!(%transfer_id, "Transfer cancelled after evaluation");
let mut state = input.state.lock().unwrap();
state.set_cancelled();
return;
}
}
tracing::debug!(%transfer_id, passed = passed.len(), filtered = filtered.len(), "Policy evaluation complete");
// Update state with evaluation results
{
let mut state = input.state.lock().unwrap();
// Only track block_ids for blocks that have them (External/Strong)
// Weak blocks don't have block_id until upgrade
state.add_passed(passed.iter().filter_map(|b| b.block_id));
state.add_filtered(filtered.iter().copied());
state.set_status(TransferStatus::Queued);
}
// Check if all blocks were filtered (transfer complete with no transfers)
if passed.is_empty() {
tracing::debug!(%transfer_id, "All blocks filtered, completing transfer");
let mut state = input.state.lock().unwrap();
state.set_complete();
return;
}
// Send to batch collector
let result = EvalResult {
transfer_id,
passed_blocks: passed,
filtered_ids: filtered,
state: input.state,
};
if !self.output_queue.push(transfer_id, result) {
tracing::debug!(%transfer_id, "Push to output queue failed (cancelled)");
}
}
/// Check if transfer is cancelled and handle state update.
fn check_cancelled(
&self,
state: &Arc<std::sync::Mutex<TransferState>>,
transfer_id: TransferId,
) -> bool {
let state_guard = state.lock().unwrap();
if state_guard.is_cancel_requested() {
drop(state_guard);
tracing::debug!(%transfer_id, "Transfer cancelled mid-evaluation");
let mut state_guard = state.lock().unwrap();
state_guard.set_cancelled();
true
} else {
false
}
}
async fn evaluate_policies(&self, ctx: &EvalContext<T>) -> bool {
for policy in &self.policies {
let eval_future = policy.evaluate(ctx);
let timed_result = tokio::time::timeout(self.timeout, async {
match eval_future {
Either::Left(ready) => ready.await,
Either::Right(boxed) => boxed.await,
}
})
.await;
match timed_result {
Ok(Ok(true)) => continue,
Ok(Ok(false)) => return false,
Ok(Err(e)) => {
tracing::warn!("Policy {} error: {}", policy.name(), e);
return false;
}
Err(_) => {
tracing::warn!("Policy {} timed out", policy.name());
return false;
}
}
}
true
}
}
// ============================================================================
// Block Upgrader Types
// ============================================================================
/// A resolved block ready for transfer execution.
///
/// Created during the upgrade stage when `WeakBlock` references are upgraded
/// to `ImmutableBlock` guards. This type is used by both `BlockTransferExecutor`
/// and `ObjectTransferExecutor`.
pub struct ResolvedBlock<T: BlockMetadata> {
/// Transfer ID this block belongs to
pub transfer_id: TransferId,
/// Block ID in the source layout
pub block_id: BlockId,
/// Sequence hash identifying the block content
pub sequence_hash: SequenceHash,
/// Guard holding the block - Some for Strong/Weak, None for External.
/// The guard is held to prevent eviction during transfer.
#[allow(dead_code)]
pub guard: Option<ImmutableBlock<T>>,
/// Transfer state for progress tracking
pub(crate) state: Arc<std::sync::Mutex<TransferState>>,
}
/// A batch of resolved blocks ready for transfer.
///
/// This is the output of the block upgrade stage and input to transfer executors.
pub struct ResolvedBatch<T: BlockMetadata> {
/// Resolved blocks ready for transfer
pub blocks: Vec<ResolvedBlock<T>>,
/// Sequence hashes of blocks that were evicted during upgrade
#[allow(dead_code)]
pub evicted: Vec<SequenceHash>,
/// Timing trace from the original batch (batch-level, not per-block)
pub timing: TimingTrace,
}
impl<T: BlockMetadata> ResolvedBatch<T> {
/// Check if the batch has any resolved blocks.
pub fn is_empty(&self) -> bool {
self.blocks.is_empty()
}
/// Get the number of resolved blocks.
#[allow(dead_code)]
pub fn len(&self) -> usize {
self.blocks.len()
}
}
/// Upgrade a batch of queued blocks by resolving weak references.
///
/// This is the "block upgrader" stage that converts `TransferBatch` (containing
/// mixed `SourceBlock` types) into `ResolvedBatch` (containing only resolved
/// `ImmutableBlock` guards).
///
/// # Block Type Handling
///
/// - `Strong`: Already have a guard, pass through directly
/// - `External`: No guard needed, caller holds the reference
/// - `Weak`: Attempt upgrade; if evicted, record in `evicted` list
///
/// This function is synchronous CPU work that can run in an "on-deck" slot
/// while other transfers are executing.
pub fn upgrade_batch<T: BlockMetadata>(batch: TransferBatch<T>) -> ResolvedBatch<T> {
let mut resolved: Vec<ResolvedBlock<T>> = Vec::with_capacity(batch.len());
let mut evicted: Vec<SequenceHash> = Vec::new();
// Copy timing from batch and mark transfer start (O(1), not per-block)
let mut timing = batch.timing;
timing.mark_transfer_start();
for queued in batch.blocks {
// Note: pending_guard is automatically dropped when QueuedBlock is processed,
// which removes the sequence_hash from the pending set. This happens either
// when the block is resolved and transferred, or when it's evicted/dropped.
match queued.source {
SourceBlock::Strong(block) => {
resolved.push(ResolvedBlock {
transfer_id: queued.transfer_id,
block_id: block.block_id(),
sequence_hash: queued.sequence_hash,
guard: Some(block),
state: queued.state,
});
}
SourceBlock::External(ext) => {
resolved.push(ResolvedBlock {
transfer_id: queued.transfer_id,
block_id: ext.block_id,
sequence_hash: ext.sequence_hash,
guard: None,
state: queued.state,
});
}
SourceBlock::Weak(weak) => match weak.upgrade() {
Some(block) => {
resolved.push(ResolvedBlock {
transfer_id: queued.transfer_id,
block_id: block.block_id(),
sequence_hash: queued.sequence_hash,
guard: Some(block),
state: queued.state,
});
}
None => {
tracing::debug!(
sequence_hash = ?queued.sequence_hash,
"Weak block evicted before transfer"
);
evicted.push(queued.sequence_hash);
}
},
}
}
ResolvedBatch {
blocks: resolved,
evicted,
timing,
}
}
// ============================================================================
// Precondition Awaiter
// ============================================================================
/// Precondition awaiter stage.
///
/// Sits between BatchCollector and the transfer executors, awaiting precondition events
/// before forwarding batches. Spawns unbounded tasks to ensure all preconditions
/// are awaited - event awaiting is cheap (just waiting, no compute), so we never
/// skip awaiting a precondition to prevent deadlock scenarios.
struct PreconditionAwaiter<T: BlockMetadata> {
input_rx: BatchOutputRx<T>,
output_tx: mpsc::Sender<TransferBatch<T>>,
leader: Arc<InstanceLeader>,
}
impl<T: BlockMetadata> PreconditionAwaiter<T> {
async fn run(mut self) {
// NO SEMAPHORE - spawn unbounded tasks
// Event awaiting is cheap, we must never skip awaiting a precondition
while let Some(mut batch) = self.input_rx.recv().await {
let output_tx = self.output_tx.clone();
let nova = self.leader.messenger().clone();
// Spawn task for each batch - unbounded
tokio::spawn(async move {
nvtx_range!("offload::precondition");
if let Some(event_handle) = batch.precondition {
tracing::debug!(?event_handle, "Awaiting precondition for batch");
// Create awaiter (returns Result<LocalEventWaiter, Error>)
let awaiter_result = nova.events().awaiter(event_handle);
match awaiter_result {
Ok(awaiter) => {
// Now await the LocalEventWaiter with timeout
match tokio::time::timeout(Duration::from_secs(300), awaiter).await {
Ok(Ok(())) => {
tracing::debug!(?event_handle, "Precondition satisfied");
}
Ok(Err(poison)) => {
tracing::error!(
?event_handle,
?poison,
"Precondition poisoned, marking all blocks as failed"
);
// Mark all blocks as failed
for queued in batch.blocks {
let mut state = queued.state.lock().unwrap();
state.set_error(format!(
"precondition poisoned: {:?}",
poison
));
}
return;
}
Err(_) => {
tracing::error!(
?event_handle,
"Precondition timeout after 30s"
);
// Mark all blocks as failed
for queued in batch.blocks {
let mut state = queued.state.lock().unwrap();
state.set_error("precondition timeout".to_string());
}
return;
}
}
}
Err(e) => {
tracing::error!(?event_handle, ?e, "Failed to create awaiter");
// Mark all blocks as failed
for queued in batch.blocks {
let mut state = queued.state.lock().unwrap();
state.set_error(format!("failed to create awaiter: {}", e));
}
return;
}
}
}
// Mark precondition complete (batch-level, O(1))
batch.timing.mark_precondition_complete();
// Forward batch to transfer executor
if let Err(e) = output_tx.send(batch).await {
tracing::error!("Failed to forward batch after precondition: {}", e);
}
});
}
}
}
// ============================================================================
// Block Transfer Executor (for G2, G3 destinations)
// ============================================================================
/// Block transfer executor stage for BlockManager-based destinations.
///
/// Executes transfers to destinations with a `BlockManager` (G2, G3).
/// Uses `leader.execute_local_transfer()` to copy block data between layouts.
///
/// For object storage destinations (G4), use `ObjectTransferExecutor` instead.
struct BlockTransferExecutor<Src: BlockMetadata, Dst: BlockMetadata> {
input_rx: BatchOutputRx<Src>,
leader: Arc<InstanceLeader>,
dst_manager: Arc<BlockManager<Dst>>,
src_layout: LogicalLayoutHandle,
dst_layout: LogicalLayoutHandle,
/// Skip actual transfers (for testing)
skip_transfers: bool,
/// Maximum concurrent transfers
max_concurrent_transfers: usize,
/// Channel to send registered blocks for chaining to downstream pipeline
chain_tx: Option<mpsc::Sender<ChainOutput<Dst>>>,
_src_marker: PhantomData<Src>,
}
/// Shared state for BlockTransferExecutor that can be cloned across concurrent tasks.
struct SharedBlockExecutorState<Dst: BlockMetadata> {
leader: Arc<InstanceLeader>,
dst_manager: Arc<BlockManager<Dst>>,
src_layout: LogicalLayoutHandle,
dst_layout: LogicalLayoutHandle,
skip_transfers: bool,
chain_tx: Option<mpsc::Sender<ChainOutput<Dst>>>,
}
impl<Src: BlockMetadata, Dst: BlockMetadata> BlockTransferExecutor<Src, Dst> {
async fn run(mut self) {
// N slots for active transfers
let transfer_semaphore = Arc::new(Semaphore::new(self.max_concurrent_transfers));
// 1 slot for preparation (upgrade) work - on-deck
let prepare_semaphore = Arc::new(Semaphore::new(1));
// Extract shared state for concurrent tasks
let shared = Arc::new(SharedBlockExecutorState {
leader: self.leader.clone(),
dst_manager: self.dst_manager.clone(),
src_layout: self.src_layout,
dst_layout: self.dst_layout,
skip_transfers: self.skip_transfers,
chain_tx: self.chain_tx.take(),
});
while let Some(batch) = self.input_rx.recv().await {
if batch.is_empty() {
continue;
}
// Wait for prepare slot (only 1 batch preparing at a time)
// This is the "on-deck" slot for preparing while transfers run
let prepare_permit = prepare_semaphore.clone().acquire_owned().await;
if prepare_permit.is_err() {
break; // Semaphore closed
}
let prepare_permit = prepare_permit.unwrap();
// Prepare stage: resolve/upgrade blocks (weak→strong)
// This happens in the "on-deck" slot while other transfers may be running
let upgraded = upgrade_batch(batch);
// Done preparing, release prepare slot for next batch
drop(prepare_permit);
if upgraded.is_empty() {
tracing::debug!("All blocks in batch evicted, skipping transfer");
continue;
}
// Now wait for transfer slot
let transfer_permit = transfer_semaphore.clone().acquire_owned().await;
if transfer_permit.is_err() {
break; // Semaphore closed
}
let transfer_permit = transfer_permit.unwrap();
// Spawn transfer task
let shared_clone = shared.clone();
tokio::spawn(async move {
let _permit = transfer_permit; // Hold permit until task completes
if let Err(e) = Self::execute_transfer(&shared_clone, upgraded).await {
tracing::error!("BlockTransferExecutor: transfer failed: {}", e);
}
});
}
// Wait for all in-flight transfers to complete by acquiring all permits
let _ = transfer_semaphore
.acquire_many(self.max_concurrent_transfers as u32)
.await;
}
/// Execute the actual transfer for resolved blocks.
///
/// This is async I/O work that runs concurrently with other transfers.
async fn execute_transfer(
shared: &SharedBlockExecutorState<Dst>,
mut batch: ResolvedBatch<Src>,
) -> anyhow::Result<()> {
nvtx_range!("offload::transfer");
if batch.is_empty() {
return Ok(());
}
let resolved = &batch.blocks;
// Collect block_ids and sequence_hashes from resolved blocks
let src_block_ids: Vec<BlockId> = resolved.iter().map(|b| b.block_id).collect();
let sequence_hashes: Vec<SequenceHash> = resolved.iter().map(|b| b.sequence_hash).collect();
// Collect states for completion tracking (group by transfer_id)
let mut transfer_states: std::collections::HashMap<
TransferId,
(Arc<std::sync::Mutex<TransferState>>, Vec<BlockId>),
> = std::collections::HashMap::new();
for block in resolved {
transfer_states
.entry(block.transfer_id)
.or_insert_with(|| (block.state.clone(), Vec::new()))
.1
.push(block.block_id);
}
// Skip actual transfers when in test mode
if !shared.skip_transfers {
// Allocate destination blocks
let dst_blocks = shared
.dst_manager
.allocate_blocks(resolved.len())
.ok_or_else(|| {
anyhow::anyhow!("Failed to allocate {} destination blocks", resolved.len())
})?;
let dst_block_ids: Vec<BlockId> = dst_blocks.iter().map(|b| b.block_id()).collect();
// Execute transfer via leader
let start_xfer = Instant::now();
let notification = shared.leader.execute_local_transfer(
shared.src_layout,
shared.dst_layout,
src_block_ids.clone(),
dst_block_ids.clone(),
TransferOptions::default(),
)?;
// Wait for transfer completion
notification.await?;
let end_xfer = Instant::now();
// Register each transferred block in the destination tier
let registered_blocks: Vec<ImmutableBlock<Dst>> = dst_blocks
.into_iter()
.zip(sequence_hashes.iter())
.map(|(dst_block, seq_hash)| {
let complete = dst_block
.stage(*seq_hash, shared.dst_manager.block_size())
.expect("block size mismatch");
shared.dst_manager.register_block(complete)
})
.collect();
let registration_timepoint = Instant::now();
// Compute timing statistics from batch timing (O(1), not per-block)
let unique_transfer_ids: std::collections::HashSet<_> =
resolved.iter().map(|b| b.transfer_id).collect();
let policy_ms = batch
.timing
.policy_duration()
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
let precondition_ms = batch
.timing
.precondition_duration()
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
let total_ms = batch
.timing
.total_duration()
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
tracing::info!(
blocks = resolved.len(),
containers = unique_transfer_ids.len(),
policy_ms,
precondition_ms,
xfer_ms = end_xfer.duration_since(start_xfer).as_millis() as u64,
registration_ms =
registration_timepoint.duration_since(end_xfer).as_millis() as u64,
total_ms,
src = std::any::type_name::<Src>(),
dst = std::any::type_name::<Dst>(),
"Batch transfer complete"
);
// Send registered blocks to downstream pipeline if chaining is enabled
if let Some(chain_tx) = &shared.chain_tx {
#[allow(clippy::type_complexity)]
let mut chain_outputs: std::collections::HashMap<
TransferId,
(
Arc<std::sync::Mutex<TransferState>>,
Vec<ImmutableBlock<Dst>>,
),
> = std::collections::HashMap::new();
for (registered, resolved_block) in
registered_blocks.into_iter().zip(resolved.iter())
{
chain_outputs
.entry(resolved_block.transfer_id)
.or_insert_with(|| (resolved_block.state.clone(), Vec::new()))
.1
.push(registered);
}
for (transfer_id, (state, blocks)) in chain_outputs {
let output = ChainOutput {
transfer_id,
blocks,
state,
};
if chain_tx.send(output).await.is_err() {
tracing::warn!(
%transfer_id,
"Chain channel closed, downstream pipeline unavailable"
);
} else {
tracing::debug!(
%transfer_id,
"Sent blocks to chain output for downstream processing"
);
}
}
}
}
// Mark transfer complete (batch-level, O(1))
batch.timing.mark_transfer_complete();
// Mark blocks as completed in each transfer state
for (transfer_id, (state, block_ids)) in transfer_states {
let mut state_guard = state.lock().unwrap();
state_guard.mark_completed(block_ids);
let total = state_guard.passed_blocks.len() + state_guard.filtered_out.len();
let done = state_guard.completed.len() + state_guard.filtered_out.len();
tracing::debug!(
%transfer_id,
total,
done,
passed = state_guard.passed_blocks.len(),
filtered = state_guard.filtered_out.len(),
completed = state_guard.completed.len(),
"Transfer batch progress"
);
if done >= total && total > 0 {
state_guard.set_complete();
}
}
Ok(())
}
}
// ============================================================================
// Object Transfer Executor (for G4 / object storage destinations)
// ============================================================================
/// Object transfer executor stage for object storage destinations.
///
/// Executes transfers to object storage (G4) via `ObjectBlockOps::put_blocks()`.
/// Unlike `BlockTransferExecutor`, this does not require a destination `BlockManager`.
///
/// # Source Requirements
///
/// The source blocks must be `ImmutableBlock<Src>` (post-upgrade). The executor:
/// 1. Receives `ResolvedBlock<Src>` from the upgrade stage
/// 2. Extracts `SequenceHash` as the object key
/// 3. Calls `ObjectBlockOps::put_blocks()` with the source layout
///
/// # Lock Management
///
/// When a `lock_manager` is provided, after successful transfers:
/// 1. Creates `.meta` file to mark block as offloaded
/// 2. Releases `.lock` file to allow other instances to proceed
///
/// # No Destination Registration
///
/// Object storage is external - there's no local `BlockManager<G4>` to register with.
/// The object is simply stored at the key derived from `SequenceHash`.
pub struct ObjectTransferExecutor<Src: BlockMetadata> {
/// Input channel from the batch/precondition stage
input_rx: BatchOutputRx<Src>,
/// Object storage operations
object_ops: Arc<dyn ObjectBlockOps>,
/// Source logical layout handle for reading block data
/// The ObjectBlockOps implementation resolves this to a physical layout
src_layout: LogicalLayoutHandle,
/// Skip actual transfers (for testing)
skip_transfers: bool,
/// Maximum concurrent transfer batches
max_concurrent_transfers: usize,
/// Optional lock manager for creating meta files and releasing locks
lock_manager: Option<Arc<dyn ObjectLockManager>>,
}
/// Shared state for ObjectTransferExecutor that can be cloned across concurrent tasks.
struct SharedObjectExecutorState {
object_ops: Arc<dyn ObjectBlockOps>,
src_layout: LogicalLayoutHandle,
skip_transfers: bool,
lock_manager: Option<Arc<dyn ObjectLockManager>>,
}
impl<Src: BlockMetadata> ObjectTransferExecutor<Src> {
/// Create a new object transfer executor.
#[allow(dead_code)]
pub fn new(
input_rx: BatchOutputRx<Src>,
object_ops: Arc<dyn ObjectBlockOps>,
src_layout: LogicalLayoutHandle,
skip_transfers: bool,
max_concurrent_transfers: usize,
lock_manager: Option<Arc<dyn ObjectLockManager>>,
) -> Self {
Self {
input_rx,
object_ops,
src_layout,
skip_transfers,
max_concurrent_transfers,
lock_manager,
}
}
/// Run the executor loop.
pub async fn run(mut self) {
// N slots for active transfers
let transfer_semaphore = Arc::new(Semaphore::new(self.max_concurrent_transfers));
// 1 slot for preparation (upgrade) work - on-deck
let prepare_semaphore = Arc::new(Semaphore::new(1));
// Extract shared state for concurrent tasks
let shared = Arc::new(SharedObjectExecutorState {
object_ops: self.object_ops.clone(),
src_layout: self.src_layout,
skip_transfers: self.skip_transfers,
lock_manager: self.lock_manager.clone(),
});
while let Some(batch) = self.input_rx.recv().await {
if batch.is_empty() {
continue;
}
// Wait for prepare slot (only 1 batch preparing at a time)
let prepare_permit = prepare_semaphore.clone().acquire_owned().await;
if prepare_permit.is_err() {
break; // Semaphore closed
}
let prepare_permit = prepare_permit.unwrap();
// Prepare stage: resolve/upgrade blocks (weak→strong)
let upgraded = upgrade_batch(batch);
// Done preparing, release prepare slot for next batch
drop(prepare_permit);
if upgraded.is_empty() {
tracing::debug!("All blocks in batch evicted, skipping object transfer");
continue;
}
// Now wait for transfer slot
let transfer_permit = transfer_semaphore.clone().acquire_owned().await;
if transfer_permit.is_err() {
break; // Semaphore closed
}
let transfer_permit = transfer_permit.unwrap();
// Spawn transfer task
let shared_clone = shared.clone();
tokio::spawn(async move {
let _permit = transfer_permit; // Hold permit until task completes
if let Err(e) = Self::execute_transfer(&shared_clone, upgraded).await {
tracing::error!("ObjectTransferExecutor: transfer failed: {}", e);
}
});
}
// Wait for all in-flight transfers to complete by acquiring all permits
let _ = transfer_semaphore
.acquire_many(self.max_concurrent_transfers as u32)
.await;
}
/// Execute the actual transfer for resolved blocks to object storage.
async fn execute_transfer(
shared: &SharedObjectExecutorState,
mut batch: ResolvedBatch<Src>,
) -> anyhow::Result<()> {
nvtx_range!("offload::transfer");
if batch.is_empty() {
return Ok(());
}
let resolved = &batch.blocks;
// Collect keys (sequence hashes) and block_ids from resolved blocks
let keys: Vec<SequenceHash> = resolved.iter().map(|b| b.sequence_hash).collect();
let block_ids: Vec<BlockId> = resolved.iter().map(|b| b.block_id).collect();
// Collect states for completion tracking (group by transfer_id)
let mut transfer_states: std::collections::HashMap<
TransferId,
(Arc<std::sync::Mutex<TransferState>>, Vec<BlockId>),
> = std::collections::HashMap::new();
for block in resolved {
transfer_states
.entry(block.transfer_id)
.or_insert_with(|| (block.state.clone(), Vec::new()))
.1
.push(block.block_id);
}
// Track successfully transferred sequence hashes for lock management
let mut successful_hashes: Vec<SequenceHash> = Vec::new();
// Skip actual transfers when in test mode
if !shared.skip_transfers {
// Execute object put via ObjectBlockOps
let results = shared
.object_ops
.put_blocks(keys.clone(), shared.src_layout, block_ids)
.await;
// Guard: put_blocks must return exactly one result per input block.
// If mismatched, mark all blocks as failed since we can't correlate results.
if results.len() != keys.len() {
tracing::error!(
expected = keys.len(),
actual = results.len(),
"put_blocks returned mismatched result count"
);
for (_transfer_id, (state, block_ids)) in transfer_states {
let mut state_guard = state.lock().unwrap();
state_guard.mark_failed(block_ids);
state_guard
.set_error("put_blocks returned mismatched result count".to_string());
}
return Ok(());
}
// Log results and track successful transfers
let mut success_count = 0;
let mut fail_count = 0;
for result in results {
match result {
Ok(hash) => {
success_count += 1;
successful_hashes.push(hash);
}
Err(hash) => {
fail_count += 1;
tracing::warn!(?hash, "Failed to transfer block to object storage");
}
}
}
if fail_count > 0 {
tracing::warn!(
success = success_count,
failed = fail_count,
"Object transfer partially failed"
);
} else {
tracing::debug!(
num_blocks = success_count,
"Successfully transferred blocks to object storage"
);
}
// todo: merge the else part of this conditional and perhaps add the event tap for the successful transfers
// for block transfers we emit an event as part of registration; however, we don't register g4 blocks in the
// same way; therefore, we need a new convention on how we inform the broader system of the object creation
// Create meta files and release locks for successful transfers
if let Some(lock_manager) = &shared.lock_manager {
for hash in &successful_hashes {
// Create meta file to mark block as offloaded
if let Err(e) = lock_manager.create_meta(*hash).await {
tracing::error!(?hash, error = %e, "Failed to create meta file");
}
// Release lock
if let Err(e) = lock_manager.release_lock(*hash).await {
tracing::error!(?hash, error = %e, "Failed to release lock");
}
}
tracing::debug!(
num_blocks = successful_hashes.len(),
"Created meta files and released locks"
);
}
} else {
// In skip mode, still do lock management if configured
if let Some(lock_manager) = &shared.lock_manager {
for hash in &keys {
if let Err(e) = lock_manager.create_meta(*hash).await {
tracing::error!(?hash, error = %e, "Failed to create meta file");
}
if let Err(e) = lock_manager.release_lock(*hash).await {
tracing::error!(?hash, error = %e, "Failed to release lock");
}
}
}
}
// Mark transfer complete (batch-level, O(1))
batch.timing.mark_transfer_complete();
// Compute timing statistics from batch timing
let unique_transfer_ids: std::collections::HashSet<_> =
resolved.iter().map(|b| b.transfer_id).collect();
let policy_ms = batch
.timing
.policy_duration()
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
let precondition_ms = batch
.timing
.precondition_duration()
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
let transfer_ms = batch
.timing
.transfer_duration()
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
let total_ms = batch
.timing
.total_duration()
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
tracing::info!(
blocks = resolved.len(),
containers = unique_transfer_ids.len(),
policy_ms,
precondition_ms,
transfer_ms,
total_ms,
src = std::any::type_name::<Src>(),
dst = "G4-object",
"Object batch transfer complete"
);
// Build success lookup for filtering completion tracking.
//
// INVARIANT: SequenceHash values within a batch must be unique. This is
// enforced by PendingTracker in PolicyEvaluator — each block's pending guard
// is inserted into a DashSet before the next block is evaluated, so duplicate
// hashes are filtered out. If this invariant is violated, success/failure
// correlation becomes ambiguous because put_blocks() returns Result<SequenceHash, _>
// without block-level identity (and S3 uses buffer_unordered, losing input order).
let block_to_hash: std::collections::HashMap<BlockId, SequenceHash> = resolved
.iter()
.map(|b| (b.block_id, b.sequence_hash))
.collect();
let success_set: std::collections::HashSet<SequenceHash> =
successful_hashes.into_iter().collect();
debug_assert_eq!(
block_to_hash.len(),
resolved.len(),
"duplicate BlockId in batch — block_to_hash would lose entries"
);
debug_assert_eq!(
resolved
.iter()
.map(|b| b.sequence_hash)
.collect::<std::collections::HashSet<_>>()
.len(),
resolved.len(),
"duplicate SequenceHash in batch — hash-based success correlation is ambiguous"
);
// Mark blocks as completed/failed in each transfer state
for (transfer_id, (state, block_ids)) in transfer_states {
let mut state_guard = state.lock().unwrap();
if shared.skip_transfers {
// In test/skip mode, all blocks are considered successful
state_guard.mark_completed(block_ids);
} else {
let (succeeded, failed): (Vec<_>, Vec<_>) = block_ids.into_iter().partition(|id| {
block_to_hash
.get(id)
.is_some_and(|h| success_set.contains(h))
});
state_guard.mark_completed(succeeded);
if !failed.is_empty() {
tracing::warn!(
%transfer_id,
failed_count = failed.len(),
"Marking blocks as failed in transfer state"
);
state_guard.mark_failed(failed);
}
}
let total = state_guard.passed_blocks.len() + state_guard.filtered_out.len();
let done = state_guard.completed.len()
+ state_guard.failed.len()
+ state_guard.filtered_out.len();
tracing::debug!(
%transfer_id,
total,
done,
passed = state_guard.passed_blocks.len(),
filtered = state_guard.filtered_out.len(),
completed = state_guard.completed.len(),
failed = state_guard.failed.len(),
"Object transfer batch progress"
);
if done >= total && total > 0 {
let failed_count = state_guard.failed.len();
if failed_count == 0 {
state_guard.set_complete();
} else {
state_guard.set_error(format!(
"{failed_count} blocks failed to transfer to object storage",
));
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pipeline_builder() {
let config = PipelineBuilder::<(), ()>::new()
.batch_size(32)
.min_batch_size(8)
.policy_timeout(Duration::from_millis(50))
.auto_chain(true)
.sweep_interval(Duration::from_millis(5))
.build();
assert_eq!(config.batch_config.max_batch_size, 32);
assert_eq!(config.batch_config.min_batch_size, 8);
assert_eq!(config.policy_timeout, Duration::from_millis(50));
assert!(config.auto_chain);
assert_eq!(config.sweep_interval, Duration::from_millis(5));
}
#[test]
fn test_pipeline_config_default() {
let config = PipelineConfig::<(), ()>::default();
assert!(config.policies.is_empty());
assert!(!config.auto_chain);
assert_eq!(config.sweep_interval, Duration::from_millis(10));
}
/// Mock ObjectBlockOps that fails specific hashes.
struct FailableObjectBlockOps {
fail_hashes: std::collections::HashSet<SequenceHash>,
}
impl crate::object::ObjectBlockOps for FailableObjectBlockOps {
fn has_blocks(
&self,
keys: Vec<SequenceHash>,
) -> futures::future::BoxFuture<'static, Vec<(SequenceHash, Option<usize>)>> {
Box::pin(async move { keys.into_iter().map(|h| (h, Some(1))).collect() })
}
fn put_blocks(
&self,
keys: Vec<SequenceHash>,
_layout: LogicalLayoutHandle,
_block_ids: Vec<BlockId>,
) -> futures::future::BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
let fail_set = self.fail_hashes.clone();
Box::pin(async move {
keys.into_iter()
.map(|h| if fail_set.contains(&h) { Err(h) } else { Ok(h) })
.collect()
})
}
fn get_blocks(
&self,
keys: Vec<SequenceHash>,
_layout: LogicalLayoutHandle,
_block_ids: Vec<BlockId>,
) -> futures::future::BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
Box::pin(async move { keys.into_iter().map(Ok).collect() })
}
}
fn test_hash(n: u64) -> SequenceHash {
SequenceHash::new(n, None, 0)
}
#[tokio::test]
async fn test_execute_transfer_partial_failure() {
use crate::offload::handle::{TransferState, TransferStatus};
let hash_ok_1 = test_hash(1);
let hash_fail = test_hash(2);
let hash_ok_2 = test_hash(3);
let fail_hashes = [hash_fail].into_iter().collect();
let object_ops: Arc<dyn crate::object::ObjectBlockOps> =
Arc::new(FailableObjectBlockOps { fail_hashes });
let shared = SharedObjectExecutorState {
object_ops,
src_layout: LogicalLayoutHandle::G2,
skip_transfers: false,
lock_manager: None,
};
let transfer_id = crate::offload::handle::TransferId::new();
let (mut state, handle) = TransferState::new(transfer_id, vec![10, 20, 30]);
state.add_passed(vec![10, 20, 30]);
state.mark_in_flight(vec![10, 20, 30]);
let state_arc = Arc::new(std::sync::Mutex::new(state));
let blocks = vec![
ResolvedBlock::<crate::G2> {
transfer_id,
block_id: 10,
sequence_hash: hash_ok_1,
guard: None,
state: state_arc.clone(),
},
ResolvedBlock::<crate::G2> {
transfer_id,
block_id: 20,
sequence_hash: hash_fail,
guard: None,
state: state_arc.clone(),
},
ResolvedBlock::<crate::G2> {
transfer_id,
block_id: 30,
sequence_hash: hash_ok_2,
guard: None,
state: state_arc.clone(),
},
];
let mut timing = TimingTrace::new();
timing.mark_policy_complete();
timing.mark_precondition_complete();
let batch = ResolvedBatch {
blocks,
evicted: Vec::new(),
timing,
};
ObjectTransferExecutor::<crate::G2>::execute_transfer(&shared, batch)
.await
.expect("execute_transfer should succeed");
// Verify: block 20 (hash_fail) should be in failed, not completed
let state_guard = state_arc.lock().unwrap();
assert_eq!(state_guard.completed, vec![10, 30]);
assert_eq!(state_guard.failed, vec![20]);
assert_eq!(state_guard.in_flight.len(), 0);
assert_eq!(state_guard.status, TransferStatus::Failed);
assert!(state_guard.error.is_some());
// Handle should reflect the same
drop(state_guard);
assert_eq!(handle.completed_blocks(), vec![10, 30]);
assert_eq!(handle.failed_blocks(), vec![20]);
}
#[tokio::test]
async fn test_execute_transfer_all_success() {
use crate::offload::handle::{TransferState, TransferStatus};
let hash1 = test_hash(1);
let hash2 = test_hash(2);
let object_ops: Arc<dyn crate::object::ObjectBlockOps> = Arc::new(FailableObjectBlockOps {
fail_hashes: std::collections::HashSet::new(),
});
let shared = SharedObjectExecutorState {
object_ops,
src_layout: LogicalLayoutHandle::G2,
skip_transfers: false,
lock_manager: None,
};
let transfer_id = crate::offload::handle::TransferId::new();
let (mut state, handle) = TransferState::new(transfer_id, vec![10, 20]);
state.add_passed(vec![10, 20]);
state.mark_in_flight(vec![10, 20]);
let state_arc = Arc::new(std::sync::Mutex::new(state));
let blocks = vec![
ResolvedBlock::<crate::G2> {
transfer_id,
block_id: 10,
sequence_hash: hash1,
guard: None,
state: state_arc.clone(),
},
ResolvedBlock::<crate::G2> {
transfer_id,
block_id: 20,
sequence_hash: hash2,
guard: None,
state: state_arc.clone(),
},
];
let mut timing = TimingTrace::new();
timing.mark_policy_complete();
timing.mark_precondition_complete();
let batch = ResolvedBatch {
blocks,
evicted: Vec::new(),
timing,
};
ObjectTransferExecutor::<crate::G2>::execute_transfer(&shared, batch)
.await
.expect("execute_transfer should succeed");
let state_guard = state_arc.lock().unwrap();
assert_eq!(state_guard.completed, vec![10, 20]);
assert!(state_guard.failed.is_empty());
assert_eq!(state_guard.status, TransferStatus::Complete);
drop(state_guard);
assert_eq!(handle.completed_blocks(), vec![10, 20]);
assert!(handle.failed_blocks().is_empty());
}
/// Mixed batch: two transfer_ids, one partially fails, the other fully succeeds.
#[tokio::test]
async fn test_execute_transfer_mixed_transfers() {
use crate::offload::handle::{TransferState, TransferStatus};
let hash_a1 = test_hash(10);
let hash_a2_fail = test_hash(20); // transfer A, will fail
let hash_b1 = test_hash(30);
let hash_b2 = test_hash(40);
let fail_hashes = [hash_a2_fail].into_iter().collect();
let object_ops: Arc<dyn crate::object::ObjectBlockOps> =
Arc::new(FailableObjectBlockOps { fail_hashes });
let shared = SharedObjectExecutorState {
object_ops,
src_layout: LogicalLayoutHandle::G2,
skip_transfers: false,
lock_manager: None,
};
// Transfer A: blocks 100, 200 (200 will fail)
let tid_a = crate::offload::handle::TransferId::new();
let (mut state_a, handle_a) = TransferState::new(tid_a, vec![100, 200]);
state_a.add_passed(vec![100, 200]);
state_a.mark_in_flight(vec![100, 200]);
let state_a_arc = Arc::new(std::sync::Mutex::new(state_a));
// Transfer B: blocks 300, 400 (both succeed)
let tid_b = crate::offload::handle::TransferId::new();
let (mut state_b, handle_b) = TransferState::new(tid_b, vec![300, 400]);
state_b.add_passed(vec![300, 400]);
state_b.mark_in_flight(vec![300, 400]);
let state_b_arc = Arc::new(std::sync::Mutex::new(state_b));
let blocks = vec![
ResolvedBlock::<crate::G2> {
transfer_id: tid_a,
block_id: 100,
sequence_hash: hash_a1,
guard: None,
state: state_a_arc.clone(),
},
ResolvedBlock::<crate::G2> {
transfer_id: tid_a,
block_id: 200,
sequence_hash: hash_a2_fail,
guard: None,
state: state_a_arc.clone(),
},
ResolvedBlock::<crate::G2> {
transfer_id: tid_b,
block_id: 300,
sequence_hash: hash_b1,
guard: None,
state: state_b_arc.clone(),
},
ResolvedBlock::<crate::G2> {
transfer_id: tid_b,
block_id: 400,
sequence_hash: hash_b2,
guard: None,
state: state_b_arc.clone(),
},
];
let mut timing = TimingTrace::new();
timing.mark_policy_complete();
timing.mark_precondition_complete();
let batch = ResolvedBatch {
blocks,
evicted: Vec::new(),
timing,
};
ObjectTransferExecutor::<crate::G2>::execute_transfer(&shared, batch)
.await
.expect("execute_transfer should succeed");
// Transfer A: block 100 succeeded, block 200 failed
let sa = state_a_arc.lock().unwrap();
assert_eq!(sa.completed, vec![100]);
assert_eq!(sa.failed, vec![200]);
assert_eq!(sa.status, TransferStatus::Failed);
assert!(sa.error.is_some());
drop(sa);
assert_eq!(handle_a.completed_blocks(), vec![100]);
assert_eq!(handle_a.failed_blocks(), vec![200]);
// Transfer B: both succeeded
let sb = state_b_arc.lock().unwrap();
assert_eq!(sb.completed, vec![300, 400]);
assert!(sb.failed.is_empty());
assert_eq!(sb.status, TransferStatus::Complete);
drop(sb);
assert_eq!(handle_b.completed_blocks(), vec![300, 400]);
assert!(handle_b.failed_blocks().is_empty());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Policy trait and built-in implementations for offload filtering.
//!
//! Policies determine which blocks should be offloaded. They are evaluated
//! as filters - blocks that fail any filter are removed from the transfer.
//!
//! # Performance Optimization
//!
//! This module uses `Either<Ready, BoxFuture>` instead of `#[async_trait]` to
//! avoid heap allocations for synchronous policies. Policies that only perform
//! local, synchronous operations (like `PresenceFilter`, `PassAllPolicy`) return
//! `Either::Left(ready(...))` which requires zero heap allocation. Policies that
//! need actual async operations return `Either::Right(Box::pin(...))`.
//!
//! # Built-in Policies
//!
//! - `PresenceFilter<Src, Dst>`: Skip blocks already present in destination tier
//! - `PresenceAndLFUFilter<Src, Dst>`: Presence check + LFU count threshold
//! - `PassAllPolicy`: No filtering (pass all blocks)
//! - `AllOfPolicy`: Composite AND policy
//! - `AnyOfPolicy`: Composite OR policy
use std::future::{Future, Ready, ready};
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use anyhow::Result;
use futures::future::Either;
use kvbm_config::{PolicyType, TierOffloadConfig};
use crate::{BlockId, SequenceHash};
use kvbm_logical::blocks::{BlockMetadata, BlockRegistry, ImmutableBlock};
use super::pending::{PendingCheck, PendingTracker};
use crate::object::{ObjectBlockOps, ObjectLockManager};
/// Boxed future type for async policy evaluation.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Future type for single-block policy evaluation.
///
/// - `Left(Ready<...>)`: Synchronous result, zero heap allocation
/// - `Right(BoxFuture<...>)`: Async result, requires heap allocation
pub type PolicyFuture<'a> = Either<Ready<Result<bool>>, BoxFuture<'a, Result<bool>>>;
/// Future type for batch policy evaluation.
///
/// - `Left(Ready<...>)`: Synchronous result, zero heap allocation
/// - `Right(BoxFuture<...>)`: Async result, requires heap allocation
pub type PolicyBatchFuture<'a> = Either<Ready<Result<Vec<bool>>>, BoxFuture<'a, Result<Vec<bool>>>>;
/// Create a synchronous policy result (zero allocation).
#[inline]
pub fn sync_result(result: Result<bool>) -> PolicyFuture<'static> {
Either::Left(ready(result))
}
/// Create a synchronous batch policy result (zero allocation).
#[inline]
pub fn sync_batch_result(result: Result<Vec<bool>>) -> PolicyBatchFuture<'static> {
Either::Left(ready(result))
}
/// Create an async policy result (boxes the future).
#[inline]
pub fn async_result<'a, F>(future: F) -> PolicyFuture<'a>
where
F: Future<Output = Result<bool>> + Send + 'a,
{
Either::Right(Box::pin(future))
}
/// Create an async batch policy result (boxes the future).
#[inline]
pub fn async_batch_result<'a, F>(future: F) -> PolicyBatchFuture<'a>
where
F: Future<Output = Result<Vec<bool>>> + Send + 'a,
{
Either::Right(Box::pin(future))
}
// ============================================================================
// Presence Checker Trait
// ============================================================================
/// Async presence checker for object storage or other external destinations.
///
/// This trait abstracts presence checking for destinations that require async
/// operations (like S3, caching services). Unlike `BlockRegistry::check_presence`
/// which is synchronous, this is designed for remote/external destinations.
///
/// # Implementations
///
/// - `S3PresenceChecker`: Wraps `ObjectBlockOps::has_blocks()` for S3/object storage
/// - Future: `CachedPresenceChecker` - local bloom filter / LRU cache layer
/// - Future: `DistributedCacheChecker` - remote caching service
pub trait PresenceChecker: Send + Sync {
/// Check if blocks exist at the destination.
///
/// Returns a vector of (SequenceHash, exists: bool) pairs.
fn check_presence(
&self,
keys: Vec<SequenceHash>,
) -> BoxFuture<'static, Vec<(SequenceHash, bool)>>;
}
/// S3/Object storage presence checker.
///
/// Wraps `ObjectBlockOps::has_blocks()` and converts `Option<usize>` → `bool`.
/// This is the default presence checker for G2→G4 (object storage) pipelines.
///
/// # Example
/// ```ignore
/// let object_ops: Arc<dyn ObjectBlockOps> = ...;
/// let checker = S3PresenceChecker::new(object_ops);
/// let results = checker.check_presence(keys).await;
/// ```
pub struct S3PresenceChecker {
object_ops: Arc<dyn ObjectBlockOps>,
}
impl S3PresenceChecker {
/// Create a new S3 presence checker wrapping the given object operations.
pub fn new(object_ops: Arc<dyn ObjectBlockOps>) -> Self {
Self { object_ops }
}
}
impl PresenceChecker for S3PresenceChecker {
fn check_presence(
&self,
keys: Vec<SequenceHash>,
) -> BoxFuture<'static, Vec<(SequenceHash, bool)>> {
let future = self.object_ops.has_blocks(keys);
Box::pin(async move {
let results = future.await;
// Convert Option<usize> (size) → bool (exists)
results
.into_iter()
.map(|(hash, size_opt)| (hash, size_opt.is_some()))
.collect()
})
}
}
// ============================================================================
// Evaluation Context
// ============================================================================
/// Context provided to policies for block evaluation.
#[derive(Debug)]
pub struct EvalContext<T: BlockMetadata> {
/// Block ID
pub block_id: BlockId,
/// Sequence hash for this block
pub sequence_hash: SequenceHash,
/// Optional strong reference to the block.
/// - Some: Strong blocks (held during evaluation)
/// - None: Weak blocks (deferred upgrade)
pub block: Option<ImmutableBlock<T>>,
}
impl<T: BlockMetadata> EvalContext<T> {
/// Create a new evaluation context from a strong block reference.
pub fn new(block: ImmutableBlock<T>) -> Self {
Self {
block_id: block.block_id(),
sequence_hash: block.sequence_hash(),
block: Some(block),
}
}
/// Create a context for weak block evaluation (deferred upgrade).
///
/// Used when evaluating weak blocks - we have the metadata
/// but defer the actual upgrade until just before transfer.
pub fn from_weak(block_id: BlockId, sequence_hash: SequenceHash) -> Self {
Self {
block_id,
sequence_hash,
block: None,
}
}
/// Create a context for external block evaluation.
///
/// Used when evaluating external blocks (e.g., G1 from vLLM) - we have
/// the block_id and sequence_hash but no ImmutableBlock reference.
pub fn from_external(block_id: BlockId, sequence_hash: SequenceHash) -> Self {
Self {
block_id,
sequence_hash,
block: None,
}
}
}
/// Trait for offload policies that filter blocks.
///
/// Policies are evaluated as a chain - a block must pass ALL policies to proceed.
/// Each policy receives an `EvalContext` with block information and returns
/// `Ok(true)` to pass or `Ok(false)` to filter out.
///
/// # Performance
///
/// This trait uses `Either<Ready, BoxFuture>` instead of `#[async_trait]` to
/// avoid heap allocations for synchronous policies. Implement using:
/// - `sync_result(Ok(true))` for synchronous policies (zero allocation)
/// - `async_result(async { ... })` for async policies (boxes the future)
///
/// # Batch Evaluation
///
/// The `evaluate_batch` method provides a default implementation that calls
/// `evaluate` for each block. Override for efficiency when the policy can
/// benefit from batching (e.g., batch registry lookups).
pub trait OffloadPolicy<T: BlockMetadata>: Send + Sync {
/// Unique name for this policy (for logging/debugging).
fn name(&self) -> &str;
/// Evaluate whether a block should be offloaded.
///
/// Returns:
/// - `Ok(true)`: Block passes this filter, continue to next policy
/// - `Ok(false)`: Block filtered out, remove from transfer
/// - `Err(_)`: Fatal error, fail the entire transfer
fn evaluate<'a>(&'a self, ctx: &'a EvalContext<T>) -> PolicyFuture<'a>;
/// Batch evaluate multiple blocks.
///
/// Default implementation calls `evaluate` for each block.
/// Override for efficiency when batching is beneficial.
fn evaluate_batch<'a>(&'a self, contexts: &'a [EvalContext<T>]) -> PolicyBatchFuture<'a> {
// Default: sequential evaluation
let contexts_clone: Vec<_> = contexts.iter().collect();
async_batch_result(async move {
let mut results = Vec::with_capacity(contexts_clone.len());
for ctx in contexts_clone {
// This calls the sync or async evaluate
let result = match self.evaluate(ctx) {
Either::Left(ready) => ready.await,
Either::Right(boxed) => boxed.await,
};
results.push(result?);
}
Ok(results)
})
}
}
/// G1→G2 filter: skip blocks already present in destination tier.
///
/// Uses `BlockRegistry::check_presence` to determine if a block exists
/// in the destination tier without acquiring a full block reference.
/// This is efficient because it only checks the registry metadata.
///
/// # Duplicate Prevention
///
/// When a `PendingTracker` is configured, this filter also checks for blocks
/// that are currently in-flight through the pipeline. This prevents duplicate
/// transfers when overlapping sequences are enqueued at roughly the same time.
///
/// # Performance
///
/// This policy is fully synchronous and returns `Either::Left(Ready)`,
/// avoiding any heap allocation per evaluation.
///
/// # Example
/// ```ignore
/// let tracker = Arc::new(PendingTracker::new());
/// let filter = PresenceFilter::<G1, G2>::new(registry.clone())
/// .with_pending_tracker(tracker);
/// // Blocks already in G2 OR in-flight will be filtered out
/// ```
pub struct PresenceFilter<Src: BlockMetadata, Dst: BlockMetadata> {
registry: Arc<BlockRegistry>,
/// Optional tracker for pending (in-flight) transfers.
/// When set, blocks that are already being transferred will be filtered out.
pending_tracker: Option<Arc<PendingTracker>>,
_marker: PhantomData<(Src, Dst)>,
}
impl<Src: BlockMetadata, Dst: BlockMetadata> PresenceFilter<Src, Dst> {
/// Create a new presence filter without pending tracking.
pub fn new(registry: Arc<BlockRegistry>) -> Self {
Self {
registry,
pending_tracker: None,
_marker: PhantomData,
}
}
/// Add a pending tracker for duplicate prevention.
///
/// When set, blocks that are currently in-flight (passed policy but not
/// yet registered in destination) will be filtered out.
pub fn with_pending_tracker(mut self, tracker: Arc<PendingTracker>) -> Self {
self.pending_tracker = Some(tracker);
self
}
/// Get a reference to the pending tracker if configured.
pub fn pending_tracker(&self) -> Option<&Arc<PendingTracker>> {
self.pending_tracker.as_ref()
}
}
impl<Src: BlockMetadata, Dst: BlockMetadata> OffloadPolicy<Src> for PresenceFilter<Src, Dst> {
fn name(&self) -> &str {
"PresenceFilter"
}
fn evaluate<'a>(&'a self, ctx: &'a EvalContext<Src>) -> PolicyFuture<'a> {
// Purely synchronous - uses Left(Ready), zero heap allocation
// 1. Check if already present in destination registry
let presence = self.registry.check_presence::<Dst>(&[ctx.sequence_hash]);
if presence[0].1 {
return sync_result(Ok(false)); // Already transferred
}
// 2. Check if currently in-flight (pending transfer)
if self.pending_tracker.is_hash_pending(&ctx.sequence_hash) {
return sync_result(Ok(false)); // Already being transferred
}
sync_result(Ok(true)) // Not present, not pending - pass
}
fn evaluate_batch<'a>(&'a self, contexts: &'a [EvalContext<Src>]) -> PolicyBatchFuture<'a> {
if contexts.is_empty() {
return sync_batch_result(Ok(Vec::new()));
}
// Batch lookup for efficiency - still synchronous
let hashes: Vec<SequenceHash> = contexts.iter().map(|ctx| ctx.sequence_hash).collect();
let presence = self.registry.check_presence::<Dst>(&hashes);
// Build results checking both registry presence and pending status
let results: Vec<bool> = presence
.into_iter()
.map(|(hash, present)| {
if present {
return false;
}
if self.pending_tracker.is_hash_pending(&hash) {
return false;
}
true
})
.collect();
sync_batch_result(Ok(results))
}
}
/// G2→G3 filter: presence check + LFU count threshold.
///
/// Combines two filter conditions:
/// 1. Skip blocks already present in destination tier
/// 2. Only offload blocks with LFU count above threshold
///
/// The LFU threshold ensures we only offload "hot" blocks that have been
/// accessed frequently, avoiding wasted transfers for rarely-used blocks.
///
/// # Duplicate Prevention
///
/// When a `PendingTracker` is configured, this filter also checks for blocks
/// that are currently in-flight through the pipeline.
///
/// # Performance
///
/// This policy is fully synchronous and returns `Either::Left(Ready)`,
/// avoiding any heap allocation per evaluation.
///
/// # Example
/// ```ignore
/// // Only offload blocks with LFU count > 8 that aren't in G3 or in-flight
/// let tracker = Arc::new(PendingTracker::new());
/// let filter = PresenceAndLFUFilter::<G2, G3>::new(registry.clone(), 8)
/// .with_pending_tracker(tracker);
/// ```
pub struct PresenceAndLFUFilter<Src: BlockMetadata, Dst: BlockMetadata> {
registry: Arc<BlockRegistry>,
min_lfu_count: u32,
/// Optional tracker for pending (in-flight) transfers.
pending_tracker: Option<Arc<PendingTracker>>,
_marker: PhantomData<(Src, Dst)>,
}
impl<Src: BlockMetadata, Dst: BlockMetadata> PresenceAndLFUFilter<Src, Dst> {
/// Create a new presence + LFU filter with specified threshold.
pub fn new(registry: Arc<BlockRegistry>, min_lfu_count: u32) -> Self {
Self {
registry,
min_lfu_count,
pending_tracker: None,
_marker: PhantomData,
}
}
/// Create with default threshold of 8.
pub fn with_default_threshold(registry: Arc<BlockRegistry>) -> Self {
Self::new(registry, 8)
}
/// Add a pending tracker for duplicate prevention.
pub fn with_pending_tracker(mut self, tracker: Arc<PendingTracker>) -> Self {
self.pending_tracker = Some(tracker);
self
}
}
impl<Src: BlockMetadata, Dst: BlockMetadata> OffloadPolicy<Src> for PresenceAndLFUFilter<Src, Dst> {
fn name(&self) -> &str {
"PresenceAndLFUFilter"
}
fn evaluate<'a>(&'a self, ctx: &'a EvalContext<Src>) -> PolicyFuture<'a> {
// 1. Skip if already in Dst
let presence = self.registry.check_presence::<Dst>(&[ctx.sequence_hash]);
if presence[0].1 {
return sync_result(Ok(false));
}
// 2. Skip if currently pending transfer
if self.pending_tracker.is_hash_pending(&ctx.sequence_hash) {
return sync_result(Ok(false));
}
// 3. Check LFU count > threshold
if let Some(tracker) = self.registry.frequency_tracker() {
// Convert SequenceHash to u128 for the tracker
let count = tracker.count(ctx.sequence_hash.as_u128());
return sync_result(Ok(count > self.min_lfu_count));
}
// No frequency tracker = pass all (conservative default)
sync_result(Ok(true))
}
fn evaluate_batch<'a>(&'a self, contexts: &'a [EvalContext<Src>]) -> PolicyBatchFuture<'a> {
if contexts.is_empty() {
return sync_batch_result(Ok(Vec::new()));
}
// Batch presence lookup
let hashes: Vec<SequenceHash> = contexts.iter().map(|ctx| ctx.sequence_hash).collect();
let presence = self.registry.check_presence::<Dst>(&hashes);
// Get trackers once
let freq_tracker = self.registry.frequency_tracker();
let min_lfu = self.min_lfu_count;
let results: Vec<bool> = presence
.into_iter()
.zip(contexts.iter())
.map(|((hash, present), ctx)| {
// Skip if present in Dst
if present {
return false;
}
// Skip if currently pending
if self.pending_tracker.is_hash_pending(&hash) {
return false;
}
// Check LFU count
if let Some(ref t) = freq_tracker {
let count = t.count(ctx.sequence_hash.as_u128());
count > min_lfu
} else {
true // No tracker = pass
}
})
.collect();
sync_batch_result(Ok(results))
}
}
/// G2→G4 filter: async presence check for object storage destinations.
///
/// Unlike `PresenceFilter` which checks local `BlockRegistry` synchronously,
/// this filter queries object storage (S3, etc.) asynchronously via a
/// `PresenceChecker` implementation.
///
/// # Duplicate Prevention
///
/// When a `PendingTracker` is configured, this filter also checks for blocks
/// that are currently in-flight through the pipeline before querying object storage.
///
/// # Performance
///
/// This policy returns `Either::Right(BoxFuture)` since it requires async I/O.
/// The pending tracker check is done synchronously first to avoid unnecessary
/// object storage queries.
///
/// # Example
/// ```ignore
/// let object_ops: Arc<dyn ObjectBlockOps> = ...;
/// let checker = Arc::new(S3PresenceChecker::new(object_ops));
/// let tracker = Arc::new(PendingTracker::new());
/// let filter = ObjectPresenceFilter::<G2>::new(checker)
/// .with_pending_tracker(tracker);
/// // Blocks already in object storage OR in-flight will be filtered out
/// ```
pub struct ObjectPresenceFilter<Src: BlockMetadata> {
presence_checker: Arc<dyn PresenceChecker>,
/// Optional tracker for pending (in-flight) transfers.
pending_tracker: Option<Arc<PendingTracker>>,
_marker: PhantomData<Src>,
}
impl<Src: BlockMetadata> ObjectPresenceFilter<Src> {
/// Create a new object presence filter.
pub fn new(presence_checker: Arc<dyn PresenceChecker>) -> Self {
Self {
presence_checker,
pending_tracker: None,
_marker: PhantomData,
}
}
/// Add a pending tracker for duplicate prevention.
///
/// When set, blocks that are currently in-flight (passed policy but not
/// yet stored in object storage) will be filtered out.
pub fn with_pending_tracker(mut self, tracker: Arc<PendingTracker>) -> Self {
self.pending_tracker = Some(tracker);
self
}
/// Get a reference to the pending tracker if configured.
pub fn pending_tracker(&self) -> Option<&Arc<PendingTracker>> {
self.pending_tracker.as_ref()
}
}
impl<Src: BlockMetadata> OffloadPolicy<Src> for ObjectPresenceFilter<Src> {
fn name(&self) -> &str {
"ObjectPresenceFilter"
}
fn evaluate<'a>(&'a self, ctx: &'a EvalContext<Src>) -> PolicyFuture<'a> {
// 1. Synchronous check: skip if currently pending
if self.pending_tracker.is_hash_pending(&ctx.sequence_hash) {
return sync_result(Ok(false)); // Already being transferred
}
// 2. Async check: query object storage for presence
let checker = self.presence_checker.clone();
let hash = ctx.sequence_hash;
async_result(async move {
let results = checker.check_presence(vec![hash]).await;
// If present in object storage, filter out
let exists = results
.into_iter()
.next()
.map(|(_, exists)| exists)
.unwrap_or(false);
Ok(!exists) // Pass if NOT present
})
}
fn evaluate_batch<'a>(&'a self, contexts: &'a [EvalContext<Src>]) -> PolicyBatchFuture<'a> {
if contexts.is_empty() {
return sync_batch_result(Ok(Vec::new()));
}
// Collect hashes, filtering out pending ones first (sync)
let mut pending_status: Vec<bool> = Vec::with_capacity(contexts.len());
let mut hashes_to_check: Vec<SequenceHash> = Vec::new();
let mut hash_indices: Vec<usize> = Vec::new();
for (i, ctx) in contexts.iter().enumerate() {
if self.pending_tracker.is_hash_pending(&ctx.sequence_hash) {
pending_status.push(true); // Mark as pending (will be filtered)
} else {
pending_status.push(false);
hashes_to_check.push(ctx.sequence_hash);
hash_indices.push(i);
}
}
// If all are pending, return immediately
if hashes_to_check.is_empty() {
return sync_batch_result(Ok(vec![false; contexts.len()]));
}
let checker = self.presence_checker.clone();
let num_contexts = contexts.len();
async_batch_result(async move {
// Query object storage for non-pending hashes
let presence_results = checker.check_presence(hashes_to_check).await;
// Build final results
let mut results = vec![false; num_contexts]; // Default: filtered out
// Map presence results back to original indices
for (check_idx, original_idx) in hash_indices.into_iter().enumerate() {
if let Some((_, exists)) = presence_results.get(check_idx) {
// Pass if NOT present in object storage
results[original_idx] = !*exists;
}
}
Ok(results)
})
}
}
/// G2→G4 filter with distributed locking: check meta, acquire lock, track acquired locks.
///
/// This filter implements the full locking protocol for object storage offloads:
/// 1. Check if `.meta` file exists (block already offloaded) - skip if yes
/// 2. Check if currently pending (in-flight transfer) - skip if yes
/// 3. Try to acquire `.lock` file with conditional PUT
/// - If lock doesn't exist, create it atomically
/// - If lock exists and expired, overwrite it
/// - If lock exists and valid (owned by another instance), skip
/// 4. If we own the lock (either just acquired or already owned), pass the block
///
/// # Lock Management
///
/// Locks acquired during policy evaluation are tracked and must be:
/// - Released after successful transfer (via `ObjectTransferExecutor`)
/// - Released on error/cancellation (via guard or explicit cleanup)
///
/// # Duplicate Prevention
///
/// When a `PendingTracker` is configured, blocks currently in-flight are filtered
/// out before checking object storage, avoiding redundant network calls.
///
/// # Example
/// ```ignore
/// let lock_manager = Arc::new(S3LockManager::new(s3_client, instance_id));
/// let tracker = Arc::new(PendingTracker::new());
/// let filter = ObjectLockPresenceFilter::<G2>::new(lock_manager)
/// .with_pending_tracker(tracker);
/// // Blocks already offloaded, in-flight, or locked by others will be filtered out
/// ```
pub struct ObjectLockPresenceFilter<Src: BlockMetadata> {
lock_manager: Arc<dyn ObjectLockManager>,
/// Optional tracker for pending (in-flight) transfers.
pending_tracker: Option<Arc<PendingTracker>>,
_marker: PhantomData<Src>,
}
impl<Src: BlockMetadata> ObjectLockPresenceFilter<Src> {
/// Create a new object lock presence filter.
pub fn new(lock_manager: Arc<dyn ObjectLockManager>) -> Self {
Self {
lock_manager,
pending_tracker: None,
_marker: PhantomData,
}
}
/// Add a pending tracker for duplicate prevention.
///
/// When set, blocks that are currently in-flight (passed policy but not
/// yet stored in object storage) will be filtered out.
pub fn with_pending_tracker(mut self, tracker: Arc<PendingTracker>) -> Self {
self.pending_tracker = Some(tracker);
self
}
/// Get a reference to the pending tracker if configured.
pub fn pending_tracker(&self) -> Option<&Arc<PendingTracker>> {
self.pending_tracker.as_ref()
}
/// Get a reference to the lock manager.
pub fn lock_manager(&self) -> &Arc<dyn ObjectLockManager> {
&self.lock_manager
}
}
impl<Src: BlockMetadata> OffloadPolicy<Src> for ObjectLockPresenceFilter<Src> {
fn name(&self) -> &str {
"ObjectLockPresenceFilter"
}
fn evaluate<'a>(&'a self, ctx: &'a EvalContext<Src>) -> PolicyFuture<'a> {
// 1. Synchronous check: skip if currently pending
if self.pending_tracker.is_hash_pending(&ctx.sequence_hash) {
return sync_result(Ok(false)); // Already being transferred
}
// 2. Async checks: meta presence, then lock acquisition
let lock_manager = self.lock_manager.clone();
let hash = ctx.sequence_hash;
async_result(async move {
// Check if meta file exists (already offloaded)
match lock_manager.has_meta(hash).await {
Ok(true) => {
tracing::debug!(?hash, "Block already offloaded (meta exists)");
return Ok(false); // Already offloaded, skip
}
Ok(false) => {
// Continue to lock acquisition
}
Err(e) => {
tracing::warn!(?hash, error = %e, "Error checking meta file");
return Ok(false); // Error, skip to be safe
}
}
// Try to acquire lock
match lock_manager.try_acquire_lock(hash).await {
Ok(true) => {
tracing::debug!(?hash, "Lock acquired");
Ok(true) // Pass - we own the lock
}
Ok(false) => {
tracing::debug!(?hash, "Lock held by another instance");
Ok(false) // Skip - another instance owns the lock
}
Err(e) => {
tracing::warn!(?hash, error = %e, "Error acquiring lock");
Ok(false) // Error, skip to be safe
}
}
})
}
fn evaluate_batch<'a>(&'a self, contexts: &'a [EvalContext<Src>]) -> PolicyBatchFuture<'a> {
if contexts.is_empty() {
return sync_batch_result(Ok(Vec::new()));
}
// Filter out pending blocks first (sync)
let mut pending_mask: Vec<bool> = Vec::with_capacity(contexts.len());
let mut to_check: Vec<(usize, SequenceHash)> = Vec::new();
for (i, ctx) in contexts.iter().enumerate() {
if self.pending_tracker.is_hash_pending(&ctx.sequence_hash) {
pending_mask.push(true);
} else {
pending_mask.push(false);
to_check.push((i, ctx.sequence_hash));
}
}
// If all are pending, return immediately
if to_check.is_empty() {
return sync_batch_result(Ok(vec![false; contexts.len()]));
}
let lock_manager = self.lock_manager.clone();
let num_contexts = contexts.len();
async_batch_result(async move {
let mut results = vec![false; num_contexts]; // Default: filtered out
// Process each non-pending block
for (original_idx, hash) in to_check {
// Check meta first
let has_meta = match lock_manager.has_meta(hash).await {
Ok(has) => has,
Err(e) => {
tracing::warn!(?hash, error = %e, "Error checking meta file");
continue; // Skip this block
}
};
if has_meta {
tracing::debug!(?hash, "Block already offloaded (meta exists)");
continue; // Already offloaded
}
// Try to acquire lock
match lock_manager.try_acquire_lock(hash).await {
Ok(true) => {
tracing::debug!(?hash, "Lock acquired");
results[original_idx] = true; // Pass
}
Ok(false) => {
tracing::debug!(?hash, "Lock held by another instance");
// Skip - another instance owns the lock
}
Err(e) => {
tracing::warn!(?hash, error = %e, "Error acquiring lock");
// Skip on error
}
}
}
Ok(results)
})
}
}
/// Composite policy that requires ALL sub-policies to pass (AND logic).
pub struct AllOfPolicy<T: BlockMetadata> {
policies: Vec<Arc<dyn OffloadPolicy<T>>>,
}
impl<T: BlockMetadata> AllOfPolicy<T> {
/// Create a new AND composite policy.
pub fn new(policies: Vec<Arc<dyn OffloadPolicy<T>>>) -> Self {
Self { policies }
}
/// Add a policy to the composite.
pub fn with(mut self, policy: Arc<dyn OffloadPolicy<T>>) -> Self {
self.policies.push(policy);
self
}
}
impl<T: BlockMetadata> OffloadPolicy<T> for AllOfPolicy<T> {
fn name(&self) -> &str {
"AllOfPolicy"
}
fn evaluate<'a>(&'a self, ctx: &'a EvalContext<T>) -> PolicyFuture<'a> {
// Must use async because sub-policies might be async
let policies = &self.policies;
async_result(async move {
for policy in policies {
let result = match policy.evaluate(ctx) {
Either::Left(ready) => ready.await,
Either::Right(boxed) => boxed.await,
};
if !result? {
return Ok(false);
}
}
Ok(true)
})
}
}
/// Composite policy that requires ANY sub-policy to pass (OR logic).
pub struct AnyOfPolicy<T: BlockMetadata> {
policies: Vec<Arc<dyn OffloadPolicy<T>>>,
}
impl<T: BlockMetadata> AnyOfPolicy<T> {
/// Create a new OR composite policy.
pub fn new(policies: Vec<Arc<dyn OffloadPolicy<T>>>) -> Self {
Self { policies }
}
/// Add a policy to the composite.
pub fn with(mut self, policy: Arc<dyn OffloadPolicy<T>>) -> Self {
self.policies.push(policy);
self
}
}
impl<T: BlockMetadata> OffloadPolicy<T> for AnyOfPolicy<T> {
fn name(&self) -> &str {
"AnyOfPolicy"
}
fn evaluate<'a>(&'a self, ctx: &'a EvalContext<T>) -> PolicyFuture<'a> {
if self.policies.is_empty() {
return sync_result(Ok(true)); // No policies = pass
}
// Must use async because sub-policies might be async
let policies = &self.policies;
async_result(async move {
for policy in policies {
let result = match policy.evaluate(ctx) {
Either::Left(ready) => ready.await,
Either::Right(boxed) => boxed.await,
};
if result? {
return Ok(true);
}
}
Ok(false)
})
}
}
/// A pass-all policy (no filtering).
///
/// # Performance
///
/// This policy is fully synchronous and returns `Either::Left(Ready)`,
/// avoiding any heap allocation per evaluation.
pub struct PassAllPolicy<T: BlockMetadata> {
_marker: PhantomData<T>,
}
impl<T: BlockMetadata> PassAllPolicy<T> {
/// Create a new pass-all policy.
pub fn new() -> Self {
Self {
_marker: PhantomData,
}
}
}
impl<T: BlockMetadata> Default for PassAllPolicy<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: BlockMetadata> OffloadPolicy<T> for PassAllPolicy<T> {
fn name(&self) -> &str {
"PassAllPolicy"
}
fn evaluate<'a>(&'a self, _ctx: &'a EvalContext<T>) -> PolicyFuture<'a> {
// Zero allocation - just returns ready(Ok(true))
sync_result(Ok(true))
}
fn evaluate_batch<'a>(&'a self, contexts: &'a [EvalContext<T>]) -> PolicyBatchFuture<'a> {
sync_batch_result(Ok(vec![true; contexts.len()]))
}
}
/// Create a composite policy from tier configuration.
///
/// Policies are applied in order with AND logic - blocks must pass all policies.
/// Returns `PassAllPolicy` if no policies are configured.
///
/// When a `pending_tracker` is provided, it is automatically wired into
/// `Presence` and `PresenceLfu` policies to enable duplicate prevention
/// for blocks currently in-flight through the pipeline.
///
/// # Example
///
/// ```ignore
/// use kvbm_config::offload::TierOffloadConfig;
///
/// let tracker = Arc::new(PendingTracker::new());
/// let config = TierOffloadConfig {
/// policies: vec![PolicyType::Presence, PolicyType::PresenceLfu],
/// presence_lfu: PresenceLfuFilterConfig { min_lfu_count: 8 },
/// ..Default::default()
/// };
///
/// // Pending tracker is automatically wired into presence-based policies
/// let policy = create_policy_from_config::<G2, G3>(&config, registry.clone(), Some(tracker));
/// ```
pub fn create_policy_from_config<Src, Dst>(
config: &TierOffloadConfig,
registry: Arc<BlockRegistry>,
pending_tracker: Option<Arc<PendingTracker>>,
) -> Arc<dyn OffloadPolicy<Src>>
where
Src: BlockMetadata + 'static,
Dst: BlockMetadata + 'static,
{
if config.policies.is_empty() {
return Arc::new(PassAllPolicy::<Src>::new());
}
let policies: Vec<Arc<dyn OffloadPolicy<Src>>> = config
.policies
.iter()
.map(|policy_type| -> Arc<dyn OffloadPolicy<Src>> {
match policy_type {
PolicyType::PassAll => Arc::new(PassAllPolicy::<Src>::new()),
PolicyType::Presence => {
let mut filter = PresenceFilter::<Src, Dst>::new(registry.clone());
if let Some(tracker) = &pending_tracker {
filter = filter.with_pending_tracker(tracker.clone());
}
Arc::new(filter)
}
PolicyType::PresenceLfu => {
let mut filter = PresenceAndLFUFilter::<Src, Dst>::new(
registry.clone(),
config.presence_lfu.min_lfu_count,
);
if let Some(tracker) = &pending_tracker {
filter = filter.with_pending_tracker(tracker.clone());
}
Arc::new(filter)
}
}
})
.collect();
if policies.len() == 1 {
policies.into_iter().next().unwrap()
} else {
Arc::new(AllOfPolicy::new(policies))
}
}
#[cfg(test)]
mod tests {
use super::*;
// Note: Full tests require BlockRegistry infrastructure which needs
// tokio runtime and complex setup. Basic API tests here.
#[test]
fn test_pass_all_policy() {
let _policy: PassAllPolicy<()> = PassAllPolicy::new();
// Would test evaluate with proper setup
}
#[test]
fn test_all_of_policy_creation() {
let policies: Vec<Arc<dyn OffloadPolicy<()>>> = vec![Arc::new(PassAllPolicy::new())];
let composite = AllOfPolicy::new(policies);
assert_eq!(composite.name(), "AllOfPolicy");
}
#[test]
fn test_any_of_policy_creation() {
let policies: Vec<Arc<dyn OffloadPolicy<()>>> = vec![Arc::new(PassAllPolicy::new())];
let composite = AnyOfPolicy::new(policies);
assert_eq!(composite.name(), "AnyOfPolicy");
}
#[tokio::test]
async fn test_sync_result_zero_alloc() {
// Verify sync_result returns Left variant
let future = sync_result(Ok(true));
assert!(matches!(future, Either::Left(_)));
let result = match future {
Either::Left(ready) => ready.await,
Either::Right(_) => unreachable!(),
};
assert!(result.unwrap());
}
#[tokio::test]
async fn test_async_result_boxes() {
// Verify async_result returns Right variant
let future = async_result(async { Ok(false) });
assert!(matches!(future, Either::Right(_)));
let result = match future {
Either::Left(_) => unreachable!(),
Either::Right(boxed) => boxed.await,
};
assert!(!result.unwrap());
}
#[test]
fn test_pending_tracker_wiring() {
use super::PendingTracker;
// Verify pending_tracker can be set on PresenceFilter
let tracker = Arc::new(PendingTracker::new());
let registry = Arc::new(BlockRegistry::new());
let filter: PresenceFilter<(), ()> =
PresenceFilter::new(registry).with_pending_tracker(tracker.clone());
// Verify we can get the tracker back
assert!(filter.pending_tracker().is_some());
assert!(Arc::ptr_eq(filter.pending_tracker().unwrap(), &tracker));
}
#[test]
fn test_pending_tracker_wiring_lfu() {
use super::PendingTracker;
// Verify pending_tracker can be set on PresenceAndLFUFilter
let tracker = Arc::new(PendingTracker::new());
let registry = Arc::new(BlockRegistry::new());
let filter: PresenceAndLFUFilter<(), ()> =
PresenceAndLFUFilter::new(registry, 8).with_pending_tracker(tracker);
// Filter was successfully created with pending tracker
assert_eq!(filter.name(), "PresenceAndLFUFilter");
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Cancellable queue implementation using crossbeam SegQueue.
//!
//! Provides a lock-free queue wrapper that supports active cancellation via
//! a sweeper task that can iterate through queued items and remove those
//! belonging to cancelled transfers.
use std::sync::atomic::{AtomicUsize, Ordering};
use crossbeam_queue::SegQueue;
use dashmap::DashSet;
use super::handle::TransferId;
/// A queued item with its associated transfer ID.
pub struct QueueItem<T> {
/// The transfer this item belongs to
pub transfer_id: TransferId,
/// The actual data
pub data: T,
}
impl<T> QueueItem<T> {
/// Create a new queue item.
pub fn new(transfer_id: TransferId, data: T) -> Self {
Self { transfer_id, data }
}
}
/// A lock-free queue that supports active cancellation via sweeping.
///
/// Unlike mpsc channels where cancellation can only be checked at dequeue time,
/// this queue allows a dedicated sweeper task to iterate through queued items
/// and remove those belonging to cancelled transfers. This ensures that
/// `ImmutableBlock` guards are dropped promptly when a transfer is cancelled.
///
/// # Architecture
///
/// ```text
/// Producer ──► [SegQueue] ◄── Consumer
/// ▲
/// │
/// [Sweeper Task]
/// │
/// (removes cancelled items)
/// ```
pub struct CancellableQueue<T> {
/// The underlying lock-free queue
inner: SegQueue<QueueItem<T>>,
/// Set of cancelled transfer IDs
cancelled: DashSet<TransferId>,
/// Approximate length for monitoring (not exact due to concurrent access)
len: AtomicUsize,
}
impl<T> CancellableQueue<T> {
/// Create a new cancellable queue.
pub fn new() -> Self {
Self {
inner: SegQueue::new(),
cancelled: DashSet::new(),
len: AtomicUsize::new(0),
}
}
/// Push an item onto the queue.
///
/// If the transfer has already been cancelled, the item is dropped immediately.
/// Returns `true` if the item was queued, `false` if it was dropped due to cancellation.
pub fn push(&self, transfer_id: TransferId, data: T) -> bool {
// Fast path: check if already cancelled before queuing
if self.cancelled.contains(&transfer_id) {
return false;
}
self.inner.push(QueueItem::new(transfer_id, data));
self.len.fetch_add(1, Ordering::Relaxed);
true
}
/// Pop an item from the queue.
///
/// Returns `None` if the queue is empty.
/// Items from cancelled transfers may still be returned - use `pop_valid()`
/// if you want to skip cancelled items automatically.
pub fn pop(&self) -> Option<QueueItem<T>> {
let item = self.inner.pop();
if item.is_some() {
self.len.fetch_sub(1, Ordering::Relaxed);
}
item
}
/// Pop a valid (non-cancelled) item from the queue.
///
/// Skips and drops items belonging to cancelled transfers.
/// Returns `None` if no valid items are available.
pub fn pop_valid(&self) -> Option<QueueItem<T>> {
loop {
match self.inner.pop() {
Some(item) => {
self.len.fetch_sub(1, Ordering::Relaxed);
if self.cancelled.contains(&item.transfer_id) {
// Drop cancelled item and try again
continue;
}
return Some(item);
}
None => return None,
}
}
}
/// Mark a transfer as cancelled.
///
/// Items belonging to this transfer will be:
/// - Dropped immediately if pushed after this call
/// - Removed by the sweeper task if already in the queue
/// - Skipped by `pop_valid()` if dequeued
pub fn mark_cancelled(&self, transfer_id: TransferId) {
self.cancelled.insert(transfer_id);
}
/// Check if a transfer has been cancelled.
pub fn is_cancelled(&self, transfer_id: TransferId) -> bool {
self.cancelled.contains(&transfer_id)
}
/// Remove cancelled items from the queue.
///
/// This is called by the sweeper task to actively remove items from
/// cancelled transfers, ensuring their resources (like `ImmutableBlock` guards)
/// are released promptly.
///
/// Returns the number of items removed.
///
/// # Implementation Note
///
/// This performs a full drain-and-requeue operation. While not ideal for
/// very large queues, it ensures correctness with the lock-free SegQueue.
/// For typical offload workloads (batches of 64-256 blocks), this is efficient.
pub fn sweep(&self) -> usize {
if self.cancelled.is_empty() {
return 0;
}
// Drain all items and requeue non-cancelled ones
let mut removed = 0;
let mut kept = Vec::new();
while let Some(item) = self.inner.pop() {
if self.cancelled.contains(&item.transfer_id) {
removed += 1;
// Item is dropped here, releasing any held resources
} else {
kept.push(item);
}
}
// Requeue kept items
for item in kept {
self.inner.push(item);
}
// Update length counter
if removed > 0 {
self.len.fetch_sub(removed, Ordering::Relaxed);
}
removed
}
/// Clear the cancelled set for a specific transfer.
///
/// Called when a transfer is fully complete to clean up the cancelled set.
pub fn clear_cancelled(&self, transfer_id: TransferId) {
self.cancelled.remove(&transfer_id);
}
/// Get the approximate queue length.
///
/// This is not exact due to concurrent modifications but useful for monitoring.
pub fn len_approx(&self) -> usize {
self.len.load(Ordering::Relaxed)
}
/// Check if the queue is approximately empty.
pub fn is_empty_approx(&self) -> bool {
self.len_approx() == 0
}
/// Get the number of cancelled transfers being tracked.
pub fn cancelled_count(&self) -> usize {
self.cancelled.len()
}
}
impl<T> Default for CancellableQueue<T> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_basic_push_pop() {
let queue: CancellableQueue<i32> = CancellableQueue::new();
let id = TransferId::new();
assert!(queue.push(id, 42));
assert_eq!(queue.len_approx(), 1);
let item = queue.pop().unwrap();
assert_eq!(item.transfer_id, id);
assert_eq!(item.data, 42);
assert_eq!(queue.len_approx(), 0);
}
#[test]
fn test_cancelled_push_rejected() {
let queue: CancellableQueue<i32> = CancellableQueue::new();
let id = TransferId::new();
queue.mark_cancelled(id);
assert!(!queue.push(id, 42));
assert_eq!(queue.len_approx(), 0);
}
#[test]
fn test_pop_valid_skips_cancelled() {
let queue: CancellableQueue<i32> = CancellableQueue::new();
let id1 = TransferId::new();
let id2 = TransferId::new();
queue.push(id1, 1);
queue.push(id2, 2);
queue.push(id1, 3);
queue.mark_cancelled(id1);
// pop_valid should skip items from id1
let item = queue.pop_valid().unwrap();
assert_eq!(item.transfer_id, id2);
assert_eq!(item.data, 2);
// No more valid items
assert!(queue.pop_valid().is_none());
}
#[test]
fn test_sweep_removes_cancelled() {
let queue: CancellableQueue<i32> = CancellableQueue::new();
let id1 = TransferId::new();
let id2 = TransferId::new();
queue.push(id1, 1);
queue.push(id2, 2);
queue.push(id1, 3);
queue.push(id2, 4);
assert_eq!(queue.len_approx(), 4);
queue.mark_cancelled(id1);
let removed = queue.sweep();
assert_eq!(removed, 2);
assert_eq!(queue.len_approx(), 2);
// Remaining items should be from id2
let item1 = queue.pop().unwrap();
let item2 = queue.pop().unwrap();
assert_eq!(item1.transfer_id, id2);
assert_eq!(item2.transfer_id, id2);
}
#[test]
fn test_sweep_empty_cancelled_set() {
let queue: CancellableQueue<i32> = CancellableQueue::new();
let id = TransferId::new();
queue.push(id, 1);
queue.push(id, 2);
// Sweep with no cancelled transfers should be a no-op
let removed = queue.sweep();
assert_eq!(removed, 0);
assert_eq!(queue.len_approx(), 2);
}
#[test]
fn test_clear_cancelled() {
let queue: CancellableQueue<i32> = CancellableQueue::new();
let id = TransferId::new();
queue.mark_cancelled(id);
assert!(queue.is_cancelled(id));
assert_eq!(queue.cancelled_count(), 1);
queue.clear_cancelled(id);
assert!(!queue.is_cancelled(id));
assert_eq!(queue.cancelled_count(), 0);
}
/// Test multiple transfer IDs with interleaved cancellation.
#[test]
fn test_multiple_transfers_interleaved() {
let queue: CancellableQueue<i32> = CancellableQueue::new();
let id1 = TransferId::new();
let id2 = TransferId::new();
let id3 = TransferId::new();
// Push items from different transfers
queue.push(id1, 1);
queue.push(id2, 2);
queue.push(id1, 3);
queue.push(id3, 4);
queue.push(id2, 5);
queue.push(id3, 6);
assert_eq!(queue.len_approx(), 6);
// Cancel id2
queue.mark_cancelled(id2);
let removed = queue.sweep();
assert_eq!(removed, 2); // items 2 and 5
assert_eq!(queue.len_approx(), 4);
// Cancel id1
queue.mark_cancelled(id1);
let removed = queue.sweep();
assert_eq!(removed, 2); // items 1 and 3
assert_eq!(queue.len_approx(), 2);
// Remaining should be from id3
let item1 = queue.pop().unwrap();
let item2 = queue.pop().unwrap();
assert_eq!(item1.transfer_id, id3);
assert_eq!(item2.transfer_id, id3);
}
/// Test sweep with empty queue.
#[test]
fn test_sweep_empty_queue() {
let queue: CancellableQueue<i32> = CancellableQueue::new();
let id = TransferId::new();
queue.mark_cancelled(id);
let removed = queue.sweep();
assert_eq!(removed, 0);
assert!(queue.is_empty_approx());
}
/// Test pop_valid exhausts queue of only cancelled items.
#[test]
fn test_pop_valid_exhausts_cancelled() {
let queue: CancellableQueue<i32> = CancellableQueue::new();
let id = TransferId::new();
queue.push(id, 1);
queue.push(id, 2);
queue.push(id, 3);
queue.mark_cancelled(id);
// pop_valid should return None after exhausting cancelled items
assert!(queue.pop_valid().is_none());
// Queue should be empty now (items were dropped during pop_valid)
assert_eq!(queue.len_approx(), 0);
}
/// Test that cancelled items are dropped (not leaked) during sweep.
#[test]
fn test_sweep_drops_items() {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
struct DropCounter {
counter: Arc<AtomicUsize>,
}
impl Drop for DropCounter {
fn drop(&mut self) {
self.counter.fetch_add(1, Ordering::SeqCst);
}
}
let drop_count = Arc::new(AtomicUsize::new(0));
let queue: CancellableQueue<DropCounter> = CancellableQueue::new();
let id = TransferId::new();
queue.push(
id,
DropCounter {
counter: drop_count.clone(),
},
);
queue.push(
id,
DropCounter {
counter: drop_count.clone(),
},
);
queue.push(
id,
DropCounter {
counter: drop_count.clone(),
},
);
assert_eq!(drop_count.load(Ordering::SeqCst), 0);
queue.mark_cancelled(id);
let removed = queue.sweep();
assert_eq!(removed, 3);
assert_eq!(drop_count.load(Ordering::SeqCst), 3);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Source block types for the offload engine.
//!
//! Blocks can be provided to the offload engine in three forms:
//! - External: BlockId + SequenceHash, block is held elsewhere
//! - Strong: RAII ImmutableBlock reference
//! - Weak: WeakBlock that may have been evicted
use std::marker::PhantomData;
use crate::{BlockId, SequenceHash};
use kvbm_logical::blocks::{BlockMetadata, ImmutableBlock, WeakBlock};
/// External block reference with sequence hash for registration.
///
/// Used when the caller holds the actual block but wants to provide
/// the offload engine with enough information to register blocks
/// in the destination tier after transfer.
#[derive(Debug, Clone, Copy)]
pub struct ExternalBlock<T: BlockMetadata> {
/// The block ID in the source tier
pub block_id: BlockId,
/// The sequence hash for registration in destination tier
pub sequence_hash: SequenceHash,
_marker: PhantomData<T>,
}
impl<T: BlockMetadata> ExternalBlock<T> {
/// Create a new external block reference.
pub fn new(block_id: BlockId, sequence_hash: SequenceHash) -> Self {
Self {
block_id,
sequence_hash,
_marker: PhantomData,
}
}
}
/// Represents a single block source for offloading.
///
/// The source type determines how the block is resolved:
/// - `External`: Caller holds the block, we have ID + SequenceHash for registration
/// - `Strong`: We hold a strong RAII reference
/// - `Weak`: We hold a weak reference that may need upgrading
#[derive(Debug)]
pub enum SourceBlock<T: BlockMetadata> {
/// External block reference with ID and sequence hash
External(ExternalBlock<T>),
/// Strong RAII reference to an immutable block
Strong(ImmutableBlock<T>),
/// Weak reference that may have been evicted
Weak(WeakBlock<T>),
}
impl<T: BlockMetadata> SourceBlock<T> {
/// Get the block ID if available without upgrading.
///
/// For External and Strong variants, returns Some(id).
/// For Weak variant, returns None (would need upgrade to get ID).
pub fn block_id(&self) -> Option<BlockId> {
match self {
SourceBlock::External(ext) => Some(ext.block_id),
SourceBlock::Strong(block) => Some(block.block_id()),
SourceBlock::Weak(_) => None,
}
}
/// Get the sequence hash if available without upgrading.
///
/// All variants can provide sequence_hash without upgrading:
/// - External: stored in ExternalBlock
/// - Strong: from ImmutableBlock
/// - Weak: WeakBlock stores sequence_hash directly
pub fn sequence_hash(&self) -> Option<SequenceHash> {
match self {
SourceBlock::External(ext) => Some(ext.sequence_hash),
SourceBlock::Strong(block) => Some(block.sequence_hash()),
SourceBlock::Weak(weak) => Some(weak.sequence_hash()),
}
}
/// Check if this is an external block reference.
pub fn is_external(&self) -> bool {
matches!(self, SourceBlock::External(_))
}
/// Check if this is a strong reference.
pub fn is_strong(&self) -> bool {
matches!(self, SourceBlock::Strong(_))
}
/// Check if this is a weak reference.
pub fn is_weak(&self) -> bool {
matches!(self, SourceBlock::Weak(_))
}
}
impl<T: BlockMetadata> From<ExternalBlock<T>> for SourceBlock<T> {
fn from(ext: ExternalBlock<T>) -> Self {
SourceBlock::External(ext)
}
}
impl<T: BlockMetadata> From<ImmutableBlock<T>> for SourceBlock<T> {
fn from(block: ImmutableBlock<T>) -> Self {
SourceBlock::Strong(block)
}
}
impl<T: BlockMetadata> From<WeakBlock<T>> for SourceBlock<T> {
fn from(block: WeakBlock<T>) -> Self {
SourceBlock::Weak(block)
}
}
/// Collection of source blocks for batch operations.
///
/// Blocks are grouped by their source type for efficient processing.
/// All blocks in a SourceBlocks must be of the same type.
#[derive(Debug)]
pub enum SourceBlocks<T: BlockMetadata> {
/// External block references with IDs and sequence hashes
External(Vec<ExternalBlock<T>>),
/// Strong RAII references
Strong(Vec<ImmutableBlock<T>>),
/// Weak references that may need upgrading
Weak(Vec<WeakBlock<T>>),
}
impl<T: BlockMetadata> SourceBlocks<T> {
/// Create an empty collection of external blocks.
pub fn empty_external() -> Self {
SourceBlocks::External(Vec::new())
}
/// Create an empty collection of strong blocks.
pub fn empty_strong() -> Self {
SourceBlocks::Strong(Vec::new())
}
/// Create an empty collection of weak blocks.
pub fn empty_weak() -> Self {
SourceBlocks::Weak(Vec::new())
}
/// Get the number of blocks in this collection.
pub fn len(&self) -> usize {
match self {
SourceBlocks::External(blocks) => blocks.len(),
SourceBlocks::Strong(blocks) => blocks.len(),
SourceBlocks::Weak(blocks) => blocks.len(),
}
}
/// Check if the collection is empty.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Get external blocks, or None for other types.
pub fn external_blocks(&self) -> Option<&[ExternalBlock<T>]> {
match self {
SourceBlocks::External(blocks) => Some(blocks),
_ => None,
}
}
/// Get strong blocks, or None for other types.
pub fn strong_blocks(&self) -> Option<&[ImmutableBlock<T>]> {
match self {
SourceBlocks::Strong(blocks) => Some(blocks),
_ => None,
}
}
/// Get weak blocks, or None for other types.
pub fn weak_blocks(&self) -> Option<&[WeakBlock<T>]> {
match self {
SourceBlocks::Weak(blocks) => Some(blocks),
_ => None,
}
}
/// Check if this is external blocks.
pub fn is_external(&self) -> bool {
matches!(self, SourceBlocks::External(_))
}
/// Check if this is strong blocks.
pub fn is_strong(&self) -> bool {
matches!(self, SourceBlocks::Strong(_))
}
/// Check if this is weak blocks.
pub fn is_weak(&self) -> bool {
matches!(self, SourceBlocks::Weak(_))
}
}
impl<T: BlockMetadata> From<Vec<ExternalBlock<T>>> for SourceBlocks<T> {
fn from(blocks: Vec<ExternalBlock<T>>) -> Self {
SourceBlocks::External(blocks)
}
}
impl<T: BlockMetadata> From<Vec<ImmutableBlock<T>>> for SourceBlocks<T> {
fn from(blocks: Vec<ImmutableBlock<T>>) -> Self {
SourceBlocks::Strong(blocks)
}
}
impl<T: BlockMetadata> From<Vec<WeakBlock<T>>> for SourceBlocks<T> {
fn from(blocks: Vec<WeakBlock<T>>) -> Self {
SourceBlocks::Weak(blocks)
}
}
// Allow converting a single SourceBlock into SourceBlocks
impl<T: BlockMetadata> From<SourceBlock<T>> for SourceBlocks<T> {
fn from(block: SourceBlock<T>) -> Self {
match block {
SourceBlock::External(ext) => SourceBlocks::External(vec![ext]),
SourceBlock::Strong(b) => SourceBlocks::Strong(vec![b]),
SourceBlock::Weak(b) => SourceBlocks::Weak(vec![b]),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use kvbm_common::tokens::TokenBlockSequence;
use kvbm_logical::KvbmSequenceHashProvider;
/// Create a test sequence hash at a given position.
fn test_seq_hash(position: usize) -> SequenceHash {
let tokens_per_block = 4;
let total_tokens = (position + 1) * tokens_per_block;
let tokens: Vec<u32> = (0..total_tokens as u32).collect();
let seq = TokenBlockSequence::from_slice(&tokens, tokens_per_block as u32, Some(1337));
seq.blocks()[position].kvbm_sequence_hash()
}
#[test]
fn test_external_block_creation() {
let hash = test_seq_hash(0);
let ext: ExternalBlock<()> = ExternalBlock::new(42, hash);
assert_eq!(ext.block_id, 42);
assert_eq!(ext.sequence_hash, hash);
}
#[test]
fn test_source_blocks_from_vec_external() {
let ext1: ExternalBlock<()> = ExternalBlock::new(1, test_seq_hash(0));
let ext2: ExternalBlock<()> = ExternalBlock::new(2, test_seq_hash(1));
let ext3: ExternalBlock<()> = ExternalBlock::new(3, test_seq_hash(2));
let blocks: SourceBlocks<()> = vec![ext1, ext2, ext3].into();
assert!(blocks.is_external());
assert_eq!(blocks.len(), 3);
let external = blocks.external_blocks().unwrap();
assert_eq!(external[0].block_id, 1);
assert_eq!(external[1].block_id, 2);
assert_eq!(external[2].block_id, 3);
}
#[test]
fn test_source_blocks_empty() {
let blocks: SourceBlocks<()> = SourceBlocks::empty_external();
assert!(blocks.is_empty());
assert!(blocks.is_external());
}
#[test]
fn test_source_block_accessors() {
let hash = test_seq_hash(5);
let ext: ExternalBlock<()> = ExternalBlock::new(42, hash);
let block: SourceBlock<()> = ext.into();
assert_eq!(block.block_id(), Some(42));
assert_eq!(block.sequence_hash(), Some(hash));
assert!(block.is_external());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! PubSub abstraction for distributed messaging.
//!
//! This module provides traits for publish/subscribe messaging patterns,
//! with implementations for NATS and an in-memory stub for testing.
use anyhow::Result;
use bytes::Bytes;
use futures::future::BoxFuture;
use futures::stream::BoxStream;
#[cfg(feature = "nats")]
mod nats;
mod stub;
#[cfg(feature = "nats")]
pub use self::nats::{NatsConfig, NatsPublisher, NatsSubscriber};
pub use stub::{StubBus, StubPublisher, StubSubscriber};
/// Message received from a subscription.
#[derive(Debug, Clone)]
pub struct Message {
/// The subject the message was published to.
pub subject: String,
/// The message payload.
pub payload: Bytes,
}
/// A subscription stream that yields messages.
pub type Subscription = BoxStream<'static, Message>;
pub use kvbm_logical::pubsub::Publisher;
/// Subscriber trait for receiving messages from subjects.
///
/// Subscribers receive messages published to matching subjects.
/// Subject patterns support wildcards:
/// - `*` matches a single token (e.g., `foo.*.bar`)
/// - `>` matches one or more tokens at the tail (e.g., `foo.>`)
pub trait Subscriber: Send + Sync {
/// Subscribe to a subject pattern, returning a message stream.
///
/// The returned stream yields messages as they arrive. The subscription
/// remains active until the stream is dropped.
fn subscribe(&self, subject: &str) -> BoxFuture<'static, Result<Subscription>>;
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NATS implementation of the PubSub traits.
use std::sync::Arc;
use anyhow::{Context, Result};
use async_nats::Client;
use bytes::Bytes;
use flume::{Receiver, Sender};
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use futures::{FutureExt, StreamExt};
use tokio::sync::oneshot;
use tracing::error;
use super::{Message, Publisher, Subscriber, Subscription};
/// Configuration for NATS publisher/subscriber.
#[derive(Debug, Clone)]
pub struct NatsConfig {
/// NATS server URL (e.g., "nats://localhost:4222").
pub server_url: String,
/// Optional subject prefix prepended to all subjects.
pub subject_prefix: Option<String>,
}
impl NatsConfig {
/// Create a new NATS configuration.
pub fn new(server_url: impl Into<String>) -> Self {
Self {
server_url: server_url.into(),
subject_prefix: None,
}
}
/// Set an optional subject prefix.
pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
self.subject_prefix = Some(prefix.into());
self
}
/// Connect to the NATS server and return a client.
pub async fn connect(&self) -> Result<Client> {
async_nats::connect(&self.server_url)
.await
.context("failed to connect to NATS server")
}
/// Format a subject with the optional prefix.
fn format_subject(&self, subject: &str) -> String {
match &self.subject_prefix {
Some(prefix) => format!("{}.{}", prefix, subject),
None => subject.to_string(),
}
}
}
/// Command sent to the publisher background task.
enum PublishCommand {
/// Publish a message to a subject.
Publish { subject: String, payload: Bytes },
/// Flush pending messages and notify when complete.
Flush { done: oneshot::Sender<Result<()>> },
}
/// NATS implementation of the [`Publisher`] trait.
///
/// Uses a background task with a flume channel to handle async publishes.
pub struct NatsPublisher {
tx: Sender<PublishCommand>,
config: Arc<NatsConfig>,
}
impl NatsPublisher {
/// Create a new NATS publisher from a client and configuration.
///
/// Spawns a background task to handle async publish operations.
pub fn new(client: Client, config: NatsConfig) -> Self {
let (tx, rx) = flume::unbounded();
let config = Arc::new(config);
tokio::spawn(Self::run_publish_loop(client, rx));
Self { tx, config }
}
/// Create a new NATS publisher by connecting to the server.
pub async fn connect(config: NatsConfig) -> Result<Self> {
let client = config.connect().await?;
Ok(Self::new(client, config))
}
/// Background task that processes publish commands.
async fn run_publish_loop(client: Client, rx: Receiver<PublishCommand>) {
while let Ok(cmd) = rx.recv_async().await {
match cmd {
PublishCommand::Publish { subject, payload } => {
if let Err(e) = client.publish(subject, payload).await {
error!("failed to publish message: {e}");
}
}
PublishCommand::Flush { done } => {
let result = client.flush().await.context("failed to flush");
// Ignore send error (receiver may have dropped)
let _ = done.send(result);
}
}
}
}
}
impl Publisher for NatsPublisher {
fn publish(&self, subject: &str, payload: Bytes) -> Result<()> {
let subject = self.config.format_subject(subject);
self.tx
.send(PublishCommand::Publish { subject, payload })
.map_err(|_| anyhow::anyhow!("publisher task has terminated"))
}
fn flush(&self) -> BoxFuture<'static, Result<()>> {
let (done_tx, done_rx) = oneshot::channel();
let tx = self.tx.clone();
async move {
tx.send(PublishCommand::Flush { done: done_tx })
.map_err(|_| anyhow::anyhow!("publisher task has terminated"))?;
done_rx
.await
.map_err(|_| anyhow::anyhow!("publisher task has terminated"))?
}
.boxed()
}
}
/// NATS implementation of the [`Subscriber`] trait.
pub struct NatsSubscriber {
client: Client,
config: NatsConfig,
}
impl NatsSubscriber {
/// Create a new NATS subscriber from a client and configuration.
pub fn new(client: Client, config: NatsConfig) -> Self {
Self { client, config }
}
/// Create a new NATS subscriber by connecting to the server.
pub async fn connect(config: NatsConfig) -> Result<Self> {
let client = config.connect().await?;
Ok(Self::new(client, config))
}
}
impl Subscriber for NatsSubscriber {
fn subscribe(&self, subject: &str) -> BoxFuture<'static, Result<Subscription>> {
let subject = self.config.format_subject(subject);
let client = self.client.clone();
async move {
let subscriber = client
.subscribe(subject)
.await
.context("failed to subscribe")?;
let stream: BoxStream<'static, Message> = subscriber
.map(|msg| Message {
subject: msg.subject.to_string(),
payload: msg.payload,
})
.boxed();
Ok(stream)
}
.boxed()
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! In-memory stub implementation of the PubSub traits for testing.
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Result;
use bytes::Bytes;
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use futures::{FutureExt, StreamExt};
use parking_lot::RwLock;
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;
use super::{Message, Publisher, Subscriber, Subscription};
/// Shared state for stub publisher/subscriber pairs.
#[derive(Clone)]
pub struct StubBus {
inner: Arc<StubBusInner>,
}
struct StubBusInner {
/// Map of subject patterns to broadcast channels.
channels: RwLock<HashMap<String, broadcast::Sender<Message>>>,
/// Channel capacity for new subscriptions.
capacity: usize,
}
impl Default for StubBus {
fn default() -> Self {
Self::new(256)
}
}
impl StubBus {
/// Create a new stub bus with the specified channel capacity.
pub fn new(capacity: usize) -> Self {
Self {
inner: Arc::new(StubBusInner {
channels: RwLock::new(HashMap::new()),
capacity,
}),
}
}
/// Create a publisher for this bus.
pub fn publisher(&self) -> StubPublisher {
StubPublisher { bus: self.clone() }
}
/// Create a subscriber for this bus.
pub fn subscriber(&self) -> StubSubscriber {
StubSubscriber { bus: self.clone() }
}
fn get_or_create_channel(&self, subject: &str) -> broadcast::Sender<Message> {
let channels = self.inner.channels.read();
if let Some(tx) = channels.get(subject) {
return tx.clone();
}
drop(channels);
let mut channels = self.inner.channels.write();
// Double-check after acquiring write lock
if let Some(tx) = channels.get(subject) {
return tx.clone();
}
let (tx, _) = broadcast::channel(self.inner.capacity);
channels.insert(subject.to_string(), tx.clone());
tx
}
}
/// Stub implementation of the [`Publisher`] trait for testing.
pub struct StubPublisher {
bus: StubBus,
}
impl StubPublisher {
/// Create a new stub publisher with a dedicated bus.
pub fn new() -> (Self, StubSubscriber) {
let bus = StubBus::default();
(bus.publisher(), bus.subscriber())
}
}
impl Publisher for StubPublisher {
fn publish(&self, subject: &str, payload: Bytes) -> Result<()> {
let tx = self.bus.get_or_create_channel(subject);
let msg = Message {
subject: subject.to_string(),
payload,
};
// Ignore send errors (no receivers is ok)
let _ = tx.send(msg);
Ok(())
}
fn flush(&self) -> BoxFuture<'static, Result<()>> {
// In-memory delivery is synchronous, nothing to flush
async { Ok(()) }.boxed()
}
}
/// Stub implementation of the [`Subscriber`] trait for testing.
pub struct StubSubscriber {
bus: StubBus,
}
impl Subscriber for StubSubscriber {
fn subscribe(&self, subject: &str) -> BoxFuture<'static, Result<Subscription>> {
let tx = self.bus.get_or_create_channel(subject);
let rx = tx.subscribe();
let stream: BoxStream<'static, Message> = BroadcastStream::new(rx)
.filter_map(|result| async move { result.ok() })
.boxed();
async move { Ok(stream) }.boxed()
}
}
#[cfg(test)]
mod tests {
use super::*;
use futures::StreamExt;
#[tokio::test]
async fn test_stub_pubsub() {
let bus = StubBus::default();
let publisher = bus.publisher();
let subscriber = bus.subscriber();
// Subscribe first
let mut sub = subscriber.subscribe("test.subject").await.unwrap();
// Publish a message
publisher
.publish("test.subject", Bytes::from("hello"))
.unwrap();
// Receive the message
let msg = sub.next().await.unwrap();
assert_eq!(msg.subject, "test.subject");
assert_eq!(msg.payload.as_ref(), b"hello");
}
#[tokio::test]
async fn test_stub_multiple_subscribers() {
let bus = StubBus::default();
let publisher = bus.publisher();
let mut sub1 = bus.subscriber().subscribe("multi").await.unwrap();
let mut sub2 = bus.subscriber().subscribe("multi").await.unwrap();
publisher
.publish("multi", Bytes::from("broadcast"))
.unwrap();
let msg1 = sub1.next().await.unwrap();
let msg2 = sub2.next().await.unwrap();
assert_eq!(msg1.payload.as_ref(), b"broadcast");
assert_eq!(msg2.payload.as_ref(), b"broadcast");
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Builder for KvbmRuntime with optional pre-built components.
use std::sync::Arc;
use anyhow::Result;
use dynamo_memory::nixl::NixlAgent;
use kvbm_config::KvbmConfig;
use tokio::runtime::{Handle, Runtime};
use velo::Messenger;
/// Runtime handle - either owned or borrowed.
pub enum RuntimeHandle {
/// Owned runtime (created by builder).
Owned(Arc<Runtime>),
/// Borrowed handle (external runtime).
Handle(Handle),
}
impl RuntimeHandle {
/// Get a handle to the runtime.
pub fn handle(&self) -> Handle {
match self {
RuntimeHandle::Owned(rt) => rt.handle().clone(),
RuntimeHandle::Handle(h) => h.clone(),
}
}
}
/// Builder for KvbmRuntime with optional pre-built components.
///
/// The builder allows injecting pre-built components or building them from config:
/// - If a component is provided, it's used directly
/// - If not provided, the component is built from the config
pub struct KvbmRuntimeBuilder {
config: KvbmConfig,
runtime: Option<RuntimeHandle>,
messenger: Option<Arc<Messenger>>,
nixl_agent: Option<NixlAgent>,
}
impl KvbmRuntimeBuilder {
/// Create builder from config.
pub fn new(config: KvbmConfig) -> Self {
Self {
config,
runtime: None,
messenger: None,
nixl_agent: None,
}
}
/// Create builder from environment.
pub fn from_env() -> Result<Self, kvbm_config::ConfigError> {
Ok(Self::new(KvbmConfig::from_env()?))
}
/// Create builder from JSON config string (merged with env/files).
///
/// JSON has highest priority - overrides env vars, TOML files, and defaults.
/// This is the primary entrypoint for vLLM's `kv_connector_extra_config` dict.
pub fn from_json(json: &str) -> Result<Self, kvbm_config::ConfigError> {
Ok(Self::new(KvbmConfig::from_figment_with_json(json)?))
}
/// Use an existing tokio Runtime (takes ownership via Arc).
pub fn with_runtime(mut self, runtime: Arc<Runtime>) -> Self {
self.runtime = Some(RuntimeHandle::Owned(runtime));
self
}
/// Use an existing tokio Handle (borrowed).
pub fn with_runtime_handle(mut self, handle: Handle) -> Self {
self.runtime = Some(RuntimeHandle::Handle(handle));
self
}
/// Use an existing Messenger instance.
pub fn with_messenger(mut self, messenger: Arc<Messenger>) -> Self {
self.messenger = Some(messenger);
self
}
/// Use an existing NixlAgent instance.
pub fn with_nixl_agent(mut self, agent: NixlAgent) -> Self {
self.nixl_agent = Some(agent);
self
}
/// Build runtime for leader role.
pub async fn build_leader(self) -> Result<super::KvbmRuntime> {
self.build_internal().await
}
/// Build runtime for worker role.
pub async fn build_worker(self) -> Result<super::KvbmRuntime> {
self.build_internal().await
}
async fn build_internal(self) -> Result<super::KvbmRuntime> {
// 1. Tokio runtime - use provided or build from config
let runtime = match self.runtime {
Some(rt) => rt,
None => RuntimeHandle::Owned(Arc::new(self.config.tokio.build_runtime()?)),
};
// 2. Messenger - use provided or build from config (BEFORE NixL)
let messenger = match self.messenger {
Some(m) => m,
None => self.config.messenger.build_messenger().await?,
};
// 3. NixL - use provided or build from config (AFTER Messenger)
// Only build if config.nixl is Some (NixL enabled)
let nixl_agent = match self.nixl_agent {
Some(agent) => Some(agent),
None => match &self.config.nixl {
Some(nixl_config) => {
let agent_name = format!("nixl-{}", messenger.instance_id());
let backend_config = nixl_config.clone().into();
Some(NixlAgent::from_nixl_backend_config(
&agent_name,
backend_config,
)?)
}
None => None, // NixL disabled
},
};
Ok(super::KvbmRuntime {
config: self.config,
runtime,
messenger,
nixl_agent,
})
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! KVBM Runtime - composed infrastructure for kvbm operations.
//!
//! The runtime contains the minimal shared components needed to construct
//! all downstream managers and services:
//! - Tokio runtime (for async execution)
//! - NixlAgent (for RDMA/UCX transfers)
//! - Nova (for distributed RPC)
//!
//! # Usage
//!
//! ```rust,ignore
//! // Build from environment (leader role)
//! let runtime = KvbmRuntime::from_env_leader().await?;
//!
//! // Build with custom config and injected components
//! let config = KvbmConfig::extract_from(
//! KvbmConfig::figment()
//! .merge(("nova.backend.tcp_port", 8080u16))
//! )?;
//! let runtime = KvbmRuntime::builder(config)
//! .with_runtime_handle(Handle::current())
//! .build_leader()
//! .await?;
//!
//! // Use runtime components
//! let transfer_mgr = TransferManager::builder()
//! .nixl_agent(runtime.nixl_agent().clone())
//! .event_system(runtime.event_system().clone())
//! .build()?;
//! ```
mod builder;
pub use builder::{KvbmRuntimeBuilder, RuntimeHandle};
use std::sync::Arc;
use dynamo_memory::nixl::NixlAgent;
use kvbm_config::KvbmConfig;
use tokio::runtime::Handle;
use velo::Messenger;
/// KVBM Runtime - composed infrastructure for kvbm operations.
///
/// Contains the minimal shared components needed to construct
/// all downstream managers and services:
/// - Tokio runtime (for async execution)
/// - NixlAgent (for RDMA/UCX transfers)
/// - Nova (for distributed RPC)
///
/// The `LocalEventSystem` is available via `event_system()` which
/// returns the system from Nova.
pub struct KvbmRuntime {
pub(crate) config: KvbmConfig,
pub(crate) runtime: RuntimeHandle,
pub(crate) messenger: Arc<Messenger>,
pub(crate) nixl_agent: Option<NixlAgent>,
}
impl KvbmRuntime {
/// Create a builder for customized construction.
pub fn builder(config: KvbmConfig) -> KvbmRuntimeBuilder {
KvbmRuntimeBuilder::new(config)
}
/// Quick construction from environment (for leader role).
pub async fn from_env_leader() -> anyhow::Result<Self> {
KvbmRuntimeBuilder::from_env()?.build_leader().await
}
/// Quick construction from environment (for worker role).
pub async fn from_env_worker() -> anyhow::Result<Self> {
KvbmRuntimeBuilder::from_env()?.build_worker().await
}
/// Get the configuration.
pub fn config(&self) -> &KvbmConfig {
&self.config
}
/// Get the tokio runtime handle.
pub fn handle(&self) -> Handle {
self.runtime.handle()
}
/// Get the tokio runtime handle (convenience alias for handle()).
pub fn tokio(&self) -> Handle {
self.handle()
}
/// Get Messenger.
pub fn messenger(&self) -> &Arc<Messenger> {
&self.messenger
}
/// Get NixlAgent for RDMA/UCX transfers.
/// Returns None if NixL is disabled in config.
pub fn nixl_agent(&self) -> Option<&NixlAgent> {
self.nixl_agent.as_ref()
}
/// Get the event manager for worker coordination and transfer notifications.
pub fn event_system(&self) -> Arc<velo::EventManager> {
Arc::new(self.messenger.event_manager())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Distributed leader testing utilities.
//!
//! This module provides test infrastructure for:
//! - Single-leader tests with `TestInstanceLeader` and `InstanceLeaderPair`
//! - Multi-worker RDMA tests with `TestWorker` and `TestInstanceLeaderWithWorkers`
use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use crate::{
BlockId, G2, G3, InstanceId, SequenceHash,
leader::InstanceLeader,
worker::{DirectWorker, Worker},
};
use kvbm_logical::manager::BlockManager;
use kvbm_physical::manager::{LayoutHandle, TransferManager};
use kvbm_physical::transfer::StorageKind;
use kvbm_physical::{
layout::LayoutConfig,
transfer::{BlockChecksum, FillPattern},
};
use super::{managers, messenger, physical, token_blocks};
/// Number of layers for layerwise transfer tests.
pub const DEFAULT_NUM_LAYERS: usize = 3;
/// Container for a test InstanceLeader with its managers.
pub struct TestInstanceLeader {
pub instance_id: InstanceId,
pub leader: InstanceLeader,
pub g2_manager: Arc<BlockManager<G2>>,
pub g3_manager: Option<Arc<BlockManager<G3>>>,
}
/// Container for a pair of connected InstanceLeaders.
pub struct InstanceLeaderPair {
pub leader_a: TestInstanceLeader,
pub leader_b: TestInstanceLeader,
}
/// Create a pair of InstanceLeaders connected via Messenger for integration testing.
///
/// Setup:
/// - Two Messenger instances with TCP transport
/// - Bidirectional peer registration
/// - G2 BlockManagers for each leader
/// - Handlers registered for distributed communication
///
/// # Arguments
/// * `block_count` - Number of blocks in each G2 manager
/// * `block_size` - Tokens per block
///
/// # Returns
/// InstanceLeaderPair with both leaders ready for testing
///
/// # Example
/// ```ignore
/// let pair = create_instance_leader_pair(100, 16).await?;
///
/// // Populate leader A with blocks
/// let (_, hashes) = populate_leader_with_blocks(&pair.leader_a, 32, 16, 0)?;
///
/// // Leader B can search leader A
/// let result = pair.leader_b.leader.find_matches(&hashes)?;
/// ```
pub async fn create_instance_leader_pair(
block_count: usize,
block_size: usize,
) -> Result<InstanceLeaderPair> {
// Create Messenger pair
let messenger::MessengerPair {
messenger_a,
messenger_b,
} = messenger::create_messenger_pair_tcp().await?;
// Create G2 managers
let registry_a = managers::TestRegistryBuilder::new().build();
let registry_b = managers::TestRegistryBuilder::new().build();
let g2_manager_a = Arc::new(
managers::TestManagerBuilder::<G2>::new()
.block_count(block_count)
.block_size(block_size)
.registry(registry_a.clone())
.build(),
);
let g3_manager_a = Arc::new(
managers::TestManagerBuilder::<G3>::new()
.block_count(block_count)
.block_size(block_size)
.registry(registry_a.clone())
.build(),
);
let g2_manager_b = Arc::new(
managers::TestManagerBuilder::<G2>::new()
.block_count(block_count)
.block_size(block_size)
.registry(registry_b.clone())
.build(),
);
let g3_manager_b = Arc::new(
managers::TestManagerBuilder::<G3>::new()
.block_count(block_count)
.block_size(block_size)
.registry(registry_b.clone())
.build(),
);
// Build InstanceLeader A
let leader_a = InstanceLeader::builder()
.messenger(messenger_a.clone())
.registry(registry_a.clone())
.g2_manager(g2_manager_a.clone())
.g3_manager(g3_manager_a.clone())
.workers(vec![]) // No workers for now (no transfers)
.remote_leaders(vec![messenger_b.instance_id()])
.build()?;
// Register handlers for A
leader_a.register_handlers()?;
// Build InstanceLeader B
let leader_b = InstanceLeader::builder()
.messenger(messenger_b.clone())
.registry(registry_b.clone())
.g2_manager(g2_manager_b.clone())
.g3_manager(g3_manager_b.clone())
.workers(vec![]) // No workers for now
.remote_leaders(vec![messenger_a.instance_id()])
.build()?;
// Register handlers for B
leader_b.register_handlers()?;
Ok(InstanceLeaderPair {
leader_a: TestInstanceLeader {
instance_id: messenger_a.instance_id(),
leader: leader_a,
g2_manager: g2_manager_a,
g3_manager: Some(g3_manager_a),
},
leader_b: TestInstanceLeader {
instance_id: messenger_b.instance_id(),
leader: leader_b,
g2_manager: g2_manager_b,
g3_manager: Some(g3_manager_b),
},
})
}
/// Populate a leader's G2 manager with token blocks.
///
/// # Arguments
/// * `leader` - The test leader instance
/// * `num_blocks` - Number of blocks to create
/// * `block_size` - Tokens per block
/// * `start_token` - Starting token value
///
/// # Returns
/// (BlockManager, Vec<SequenceHash>) - Manager and sequence hashes of populated blocks
///
/// # Example
/// ```ignore
/// let pair = create_instance_leader_pair(100, 4).await?;
/// let (manager, hashes) = populate_leader_with_blocks(&pair.leader_a, 32, 4, 0)?;
/// assert_eq!(hashes.len(), 32);
/// ```
pub fn populate_leader_with_blocks(
leader: &TestInstanceLeader,
num_blocks: usize,
block_size: usize,
start_token: u32,
) -> Result<(Arc<BlockManager<G2>>, Vec<SequenceHash>)> {
let token_sequence =
super::token_blocks::create_token_sequence(num_blocks, block_size, start_token);
let seq_hashes =
managers::populate_manager_with_blocks(&leader.g2_manager, token_sequence.blocks())?;
Ok((leader.g2_manager.clone(), seq_hashes))
}
// =============================================================================
// Multi-worker RDMA test infrastructure
// =============================================================================
/// Container for a test worker with its transfer infrastructure.
///
/// This wraps a DirectWorker with access to its TransferManager and registered layouts,
/// enabling fine-grained control over worker-level operations in tests.
pub struct TestWorker {
/// Unique instance identifier (primary identity).
pub instance_id: InstanceId,
/// Unique worker identifier derived from instance_id (used in LayoutHandle encoding).
pub worker_id: u64,
/// The DirectWorker instance (implements Worker trait).
pub worker: Arc<DirectWorker>,
/// TransferManager owned by this worker (for direct transfer operations).
pub manager: Arc<TransferManager>,
/// G2 layout handle registered with this worker.
pub g2_handle: LayoutHandle,
}
impl TestWorker {
/// Fill G2 blocks with test data and return checksums.
///
/// This uses the internal registry accessor to fill blocks in the
/// registered G2 layout. Only works with System or Pinned storage.
pub fn fill_g2_blocks(
&self,
block_ids: &[BlockId],
pattern: FillPattern,
) -> Result<HashMap<BlockId, BlockChecksum>> {
physical::fill_and_checksum_manager(&self.manager, self.g2_handle, block_ids, pattern)
}
/// Compute checksums for G2 blocks (for verification after transfers).
///
/// This uses the internal registry accessor to compute checksums for
/// blocks in the registered G2 layout.
pub fn compute_g2_checksums(
&self,
block_ids: &[BlockId],
) -> Result<HashMap<BlockId, BlockChecksum>> {
physical::compute_manager_checksums(&self.manager, self.g2_handle, block_ids)
}
}
/// Container for a test InstanceLeader with accessible workers.
///
/// This extends TestInstanceLeader with actual DirectWorker instances,
/// allowing tests to access both the leader-level APIs and the underlying
/// worker infrastructure for RDMA operations.
pub struct TestInstanceLeaderWithWorkers {
/// Instance identifier.
pub instance_id: InstanceId,
/// The InstanceLeader.
pub leader: InstanceLeader,
/// G2 BlockManager for logical block management.
pub g2_manager: Arc<BlockManager<G2>>,
/// G3 BlockManager for disk-backed blocks.
pub g3_manager: Option<Arc<BlockManager<G3>>>,
/// Workers with their transfer infrastructure.
pub workers: Vec<TestWorker>,
}
impl TestInstanceLeaderWithWorkers {
/// Get the G2 layout handle (from first worker).
///
/// This is used for constructing BlockInfo in tests.
/// Returns `None` if there are no workers.
pub fn g2_layout_handle(&self) -> Option<LayoutHandle> {
self.workers.first().map(|w| w.g2_handle)
}
/// Populate G2 with blocks and return their sequence hashes.
///
/// This is a convenience method that combines allocation, filling,
/// and registration into one step.
pub fn populate_g2_blocks(
&self,
num_blocks: usize,
block_size: usize,
start_token: u32,
) -> Result<(Vec<BlockId>, Vec<SequenceHash>)> {
let token_sequence =
token_blocks::create_token_sequence(num_blocks, block_size, start_token);
let seq_hashes =
managers::populate_manager_with_blocks(&self.g2_manager, token_sequence.blocks())?;
// Get the block IDs that were allocated
let matched = self.g2_manager.match_blocks(&seq_hashes);
let block_ids: Vec<BlockId> = matched.into_iter().map(|b| b.block_id()).collect();
Ok((block_ids, seq_hashes))
}
/// Fill blocks on all workers with a layer-specific pattern.
///
/// Each layer gets a different fill byte: layer 0 = 0xA0, layer 1 = 0xA1, etc.
/// This enables verification that the correct layer was transferred.
pub fn fill_blocks_with_layer_pattern(
&self,
block_ids: &[BlockId],
layer: usize,
) -> Result<HashMap<BlockId, BlockChecksum>> {
let pattern = FillPattern::Constant(0xA0 + layer as u8);
let mut all_checksums = HashMap::new();
for worker in &self.workers {
let checksums = worker.fill_g2_blocks(block_ids, pattern)?;
all_checksums.extend(checksums);
}
Ok(all_checksums)
}
/// Verify that blocks have the expected layer pattern.
///
/// Checks that blocks were transferred correctly by verifying
/// the checksum matches the expected layer pattern.
pub fn verify_layer_checksums(
&self,
block_ids: &[BlockId],
expected_checksums: &HashMap<BlockId, BlockChecksum>,
) -> Result<()> {
for worker in &self.workers {
let actual_checksums = worker.compute_g2_checksums(block_ids)?;
for block_id in block_ids {
let expected = expected_checksums.get(block_id).ok_or_else(|| {
anyhow::anyhow!("Missing expected checksum for block {}", block_id)
})?;
let actual = actual_checksums.get(block_id).ok_or_else(|| {
anyhow::anyhow!("Missing actual checksum for block {}", block_id)
})?;
if expected != actual {
anyhow::bail!(
"Checksum mismatch for block {}: expected {:?}, got {:?}",
block_id,
expected,
actual
);
}
}
}
Ok(())
}
}
// =============================================================================
// Test Session Helper
// =============================================================================
use crate::leader::session::{
BlockInfo, EndpointSessionHandle, SessionHandle as UnifiedSessionHandle, SessionId,
SessionPhase, SessionStateSnapshot,
};
use kvbm_physical::transfer::TransferCompleteNotification;
use std::time::Duration;
/// Helper for establishing and managing test sessions with reduced boilerplate.
///
/// Encapsulates the create->attach->wait_for_ready pattern common in tests.
///
/// # Example
///
/// ```ignore
/// // BEFORE: 6 lines repeated in many tests
/// let (session_id, handle) = leader.create_endpoint_session(&hashes)?;
/// let mut remote_handle = remote_leader.attach_session(instance_id, session_id).await?;
/// let state = timeout(Duration::from_secs(5), remote_handle.wait_for_ready())
/// .await.expect("Timeout").expect("Ready");
///
/// // AFTER: 1 line
/// let session = TestSession::establish_default(&leader, &remote_leader, &hashes).await?;
/// ```
pub struct TestSession {
/// The session ID.
pub session_id: SessionId,
/// Handle held by the endpoint (source).
pub endpoint_handle: EndpointSessionHandle,
/// Handle held by the controller (destination).
pub controller_handle: UnifiedSessionHandle,
/// The initial state snapshot after ready.
pub initial_state: SessionStateSnapshot,
}
impl TestSession {
/// Establish a session between two leaders with default timeout (5 seconds).
///
/// # Arguments
/// * `endpoint_leader` - The source leader (endpoint) that creates the session
/// * `controller_leader` - The destination leader (controller) that attaches
/// * `hashes` - Sequence hashes for the blocks to expose in the session
pub async fn establish_default(
endpoint_leader: &InstanceLeader,
controller_leader: &InstanceLeader,
hashes: &[SequenceHash],
) -> Result<Self> {
Self::establish(
endpoint_leader,
controller_leader,
hashes,
Duration::from_secs(5),
)
.await
}
/// Establish a session between two leaders with custom timeout.
///
/// # Arguments
/// * `endpoint_leader` - The source leader (endpoint) that creates the session
/// * `controller_leader` - The destination leader (controller) that attaches
/// * `hashes` - Sequence hashes for the blocks to expose in the session
/// * `timeout_duration` - How long to wait for the session to become ready
pub async fn establish(
endpoint_leader: &InstanceLeader,
controller_leader: &InstanceLeader,
hashes: &[SequenceHash],
timeout_duration: Duration,
) -> Result<Self> {
// Create endpoint session on source
let (session_id, endpoint_handle) = endpoint_leader.create_endpoint_session(hashes)?;
// Controller attaches - get instance ID from Messenger
let endpoint_instance_id = endpoint_leader.messenger().instance_id();
let mut controller_handle = controller_leader
.attach_session(endpoint_instance_id, session_id)
.await?;
// Wait for ready state
let initial_state =
tokio::time::timeout(timeout_duration, controller_handle.wait_for_ready())
.await
.map_err(|_| anyhow::anyhow!("Timeout waiting for session to become ready"))?
.map_err(|e| anyhow::anyhow!("Session ready failed: {}", e))?;
Ok(Self {
session_id,
endpoint_handle,
controller_handle,
initial_state,
})
}
/// Returns the G2 blocks available in the session.
pub fn g2_blocks(&self) -> &[BlockInfo] {
&self.initial_state.g2_blocks
}
/// Returns the count of G3 blocks pending staging.
pub fn g3_pending(&self) -> usize {
self.initial_state.g3_pending
}
/// Returns the session phase.
pub fn phase(&self) -> &SessionPhase {
&self.initial_state.phase
}
/// Pull blocks via RDMA using the controller handle.
///
/// # Arguments
/// * `src_blocks` - Source block info (from g2_blocks())
/// * `dst_ids` - Destination block IDs on the controller side
pub async fn pull_blocks_rdma(
&mut self,
src_blocks: &[BlockInfo],
dst_ids: &[BlockId],
) -> Result<TransferCompleteNotification> {
self.controller_handle
.pull_blocks_rdma(src_blocks, dst_ids)
.await
}
/// Notify that layers are ready (called from endpoint side).
///
/// # Arguments
/// * `layer_range` - Range of layers that are ready
pub async fn notify_layers_ready(&self, layer_range: std::ops::Range<usize>) -> Result<()> {
self.endpoint_handle.notify_layers_ready(layer_range).await
}
/// Mark blocks as pulled (called from controller side).
pub async fn mark_blocks_pulled(&mut self, hashes: Vec<SequenceHash>) -> Result<()> {
self.controller_handle.mark_blocks_pulled(hashes).await
}
/// Close the endpoint session.
pub async fn close_endpoint(&self) -> Result<()> {
self.endpoint_handle.close().await
}
/// Clean shutdown of both sides of the session.
///
/// This consumes self because detach() takes ownership of the handle.
pub async fn close(self) -> Result<()> {
self.controller_handle.detach().await.ok();
self.endpoint_handle.close().await.ok();
Ok(())
}
}
// =============================================================================
// Instance Leader Pair with Workers
// =============================================================================
/// Container for a pair of leaders with workers for RDMA testing.
///
/// This is the primary test fixture for prefill-decode RDMA scenarios:
/// - `decode`: The source instance (has data to pull from)
/// - `prefill`: The destination instance (pulls data via RDMA)
pub struct InstanceLeaderPairWithWorkers {
/// Decode leader (source of RDMA transfers).
pub decode: TestInstanceLeaderWithWorkers,
/// Prefill leader (destination of RDMA transfers).
pub prefill: TestInstanceLeaderWithWorkers,
}
/// Create a DirectWorker with UCX backend and registered G2 layout.
///
/// # Arguments
/// * `instance_id` - Unique instance identifier for this worker
/// * `agent_name` - NIXL agent name (must be unique for RDMA addressing)
/// * `layout_config` - Configuration for the G2 physical layout
/// * `storage` - Storage type for the layout (typically Pinned for RDMA)
///
/// # Returns
/// TestWorker with TransferManager and registered G2 layout
///
/// # Worker ID Derivation
/// The worker_id is derived from instance_id using xxh3_64 hash, ensuring
/// unique LayoutHandles (worker_id, layout_id) for each worker.
///
/// # Backend Requirements
/// This function requires UCX backend for RDMA operations. Use
/// `physical::TestAgentBuilder` for more flexible backend handling.
pub fn create_direct_worker(
instance_id: InstanceId,
agent_name: &str,
layout_config: &LayoutConfig,
storage: StorageKind,
) -> Result<TestWorker> {
// Derive worker_id from instance_id (deterministic hash)
let worker_id = instance_id.worker_id().as_u64();
// Create local EventManager (purely local event system for this worker)
let event_system = velo::EventManager::local();
// Create NixlAgent with UCX backend using TestAgentBuilder
// UCX is required for RDMA operations
let test_agent = physical::TestAgentBuilder::new(agent_name)
.require_backend("UCX")
.build()?;
let agent = test_agent.into_nixl_agent();
// Create TransferManager with the event_system
let manager = TransferManager::builder()
.event_system(Arc::new(event_system))
.nixl_agent(agent.clone())
.cuda_device_id(0)
.build()?;
// Create and register G2 physical layout
// This will create LayoutHandle(worker_id, 0) - now unique per worker!
let layout = physical::create_fc_layout_with_config(agent, storage, layout_config.clone());
let g2_handle = manager.register_layout(layout)?;
// Create DirectWorker with G2 handle via builder
let direct_worker = DirectWorker::builder()
.manager(manager.clone())
.g2_handle(g2_handle)
.build()?;
Ok(TestWorker {
instance_id,
worker_id,
worker: Arc::new(direct_worker),
manager: Arc::new(manager),
g2_handle,
})
}
/// Create multiple DirectWorkers for a single leader.
///
/// Each worker gets:
/// - A unique InstanceId (UUID v4)
/// - A unique NixlAgent with UCX backend
/// - Its own TransferManager with unique worker_id
/// - A registered G2 physical layout
///
/// # Arguments
/// * `num_workers` - Number of workers to create
/// * `layout_config` - Configuration for G2 layouts
/// * `storage` - Storage type (typically Pinned for RDMA)
/// * `agent_name_prefix` - Prefix for agent names (e.g., "decode" -> "decode-worker-0")
///
/// # Returns
/// Vector of TestWorkers, one per worker, each with unique InstanceId
pub fn create_direct_workers(
num_workers: usize,
layout_config: &LayoutConfig,
storage: StorageKind,
agent_name_prefix: &str,
) -> Result<Vec<TestWorker>> {
let mut workers = Vec::with_capacity(num_workers);
for i in 0..num_workers {
// Create unique InstanceId for this worker
let instance_id = InstanceId::new_v4();
let agent_name = format!("{}-worker-{}", agent_name_prefix, i);
let worker = create_direct_worker(instance_id, &agent_name, layout_config, storage)?;
workers.push(worker);
}
Ok(workers)
}
/// Create an InstanceLeader with DirectWorkers for RDMA testing.
///
/// # Arguments
/// * `block_count` - Number of blocks in G2 manager
/// * `block_size` - Tokens per block
/// * `num_workers` - Number of DirectWorkers to create
/// * `layout_config` - Configuration for worker G2 layouts
/// * `storage` - Storage type for layouts
/// * `messenger` - Messenger instance for leader communication
/// * `remote_leaders` - Instance IDs of remote leaders
///
/// # Returns
/// TestInstanceLeaderWithWorkers with leader and worker infrastructure
#[allow(clippy::too_many_arguments)]
pub async fn create_instance_leader_with_workers(
block_count: usize,
block_size: usize,
num_workers: usize,
layout_config: &LayoutConfig,
storage: StorageKind,
messenger: Arc<velo::Messenger>,
remote_leaders: Vec<InstanceId>,
agent_name_prefix: &str,
) -> Result<TestInstanceLeaderWithWorkers> {
// Create G2 and G3 managers
let registry = managers::TestRegistryBuilder::new().build();
let g2_manager = Arc::new(
managers::TestManagerBuilder::<G2>::new()
.block_count(block_count)
.block_size(block_size)
.registry(registry.clone())
.build(),
);
let g3_manager = Arc::new(
managers::TestManagerBuilder::<G3>::new()
.block_count(block_count)
.block_size(block_size)
.registry(registry.clone())
.build(),
);
// Create DirectWorkers
let workers = create_direct_workers(num_workers, layout_config, storage, agent_name_prefix)?;
// Extract worker references for the leader
let worker_refs: Vec<Arc<dyn Worker>> = workers
.iter()
.map(|w| w.worker.clone() as Arc<dyn Worker>)
.collect();
// Build InstanceLeader
let leader = InstanceLeader::builder()
.messenger(messenger.clone())
.registry(registry.clone())
.g2_manager(g2_manager.clone())
.g3_manager(g3_manager.clone())
.workers(worker_refs)
.remote_leaders(remote_leaders)
.build()?;
// Register handlers
leader.register_handlers()?;
Ok(TestInstanceLeaderWithWorkers {
instance_id: messenger.instance_id(),
leader,
g2_manager,
g3_manager: Some(g3_manager),
workers,
})
}
/// Create a pair of InstanceLeaders with workers for RDMA integration testing.
///
/// Setup:
/// - Two Messenger instances with TCP transport
/// - Bidirectional peer registration
/// - N DirectWorkers per leader with UCX-registered layouts
/// - G2 BlockManagers for logical block management
///
/// # Arguments
/// * `block_count` - Number of blocks in each G2 manager
/// * `block_size` - Tokens per block
/// * `num_workers` - Number of workers per leader (must match for RDMA)
/// * `layout_config` - Configuration for worker G2 layouts
/// * `storage` - Storage type (typically Pinned for RDMA)
///
/// # Returns
/// InstanceLeaderPairWithWorkers ready for RDMA testing
///
/// # Example
/// ```ignore
/// let layout_config = custom_config(64, 3, 2, 4, 64, 2);
/// let pair = create_instance_leader_pair_with_workers(
/// 64, 4, 2, &layout_config, StorageKind::Pinned
/// ).await?;
///
/// // Fill decode workers with data
/// for worker in &pair.decode.workers {
/// fill_and_checksum(&layout, &block_ids, FillPattern::Sequential)?;
/// }
/// ```
pub async fn create_instance_leader_pair_with_workers(
block_count: usize,
block_size: usize,
num_workers: usize,
layout_config: &LayoutConfig,
storage: StorageKind,
) -> Result<InstanceLeaderPairWithWorkers> {
// Create Messenger pair
let messenger::MessengerPair {
messenger_a,
messenger_b,
} = messenger::create_messenger_pair_tcp().await?;
// Create Decode leader with workers
let decode = create_instance_leader_with_workers(
block_count,
block_size,
num_workers,
layout_config,
storage,
messenger_a.clone(),
vec![messenger_b.instance_id()],
"decode",
)
.await?;
// Create Prefill leader with workers
let prefill = create_instance_leader_with_workers(
block_count,
block_size,
num_workers,
layout_config,
storage,
messenger_b.clone(),
vec![messenger_a.instance_id()],
"prefill",
)
.await?;
Ok(InstanceLeaderPairWithWorkers { decode, prefill })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::leader::ControllableSessionOptions;
#[tokio::test]
async fn test_create_instance_leader_pair() {
let pair = create_instance_leader_pair(100, 16)
.await
.expect("Should create leader pair");
// Verify different instance IDs
assert_ne!(pair.leader_a.instance_id, pair.leader_b.instance_id);
// Verify managers are configured correctly
assert_eq!(pair.leader_a.g2_manager.total_blocks(), 100);
assert_eq!(pair.leader_a.g2_manager.block_size(), 16);
assert_eq!(pair.leader_b.g2_manager.total_blocks(), 100);
assert_eq!(pair.leader_b.g2_manager.block_size(), 16);
}
#[tokio::test]
async fn test_populate_leader_with_blocks() {
let pair = create_instance_leader_pair(50, 4)
.await
.expect("Should create pair");
let (manager, hashes) =
populate_leader_with_blocks(&pair.leader_a, 10, 4, 0).expect("Should populate");
assert_eq!(hashes.len(), 10);
assert_eq!(manager.available_blocks(), 50); // All blocks available (10 in inactive)
// Verify blocks can be matched
let matched = manager.match_blocks(&hashes);
assert_eq!(matched.len(), 10);
}
// =========================================================================
// scan_with_policy Tests
// =========================================================================
/// Test simple linear scan policy.
///
/// This tests the basic usage of scan_with_policy where the policy
/// iterates through all hashes and yields each found block.
#[tokio::test]
async fn test_scan_with_policy_linear_scan() {
use crate::leader::TieredBlock;
let pair = create_instance_leader_pair(50, 4)
.await
.expect("Should create pair");
// Populate with blocks
let (_, hashes) =
populate_leader_with_blocks(&pair.leader_a, 10, 4, 0).expect("Should populate");
// Simple linear scan policy: find all blocks
let blocks: Vec<TieredBlock> =
pair.leader_a
.leader
.scan_with_policy(&hashes, true, |hashes, ctx| {
for hash in hashes {
if let Some(block) = ctx.accessor().find(*hash) {
ctx.yield_item(block);
}
}
});
// Should find all 10 blocks
assert_eq!(blocks.len(), 10);
// Verify all are G2 blocks (since we populated G2)
for block in &blocks {
assert!(block.is_g2());
}
// Verify positions are sequential (0-9)
for (i, block) in blocks.iter().enumerate() {
assert_eq!(block.position(), i as u64);
}
}
/// Test scan_with_policy with partial matches.
///
/// Tests that the policy correctly handles cases where some hashes
/// are not found in the manager.
#[tokio::test]
async fn test_scan_with_policy_partial_matches() {
use crate::leader::TieredBlock;
let pair = create_instance_leader_pair(50, 4)
.await
.expect("Should create pair");
// Populate with blocks at positions 0-4
let (_, hashes) =
populate_leader_with_blocks(&pair.leader_a, 5, 4, 0).expect("Should populate");
// Create some hashes that won't be found
let token_seq = token_blocks::create_token_sequence(3, 4, 1000);
let nonexistent_hashes = token_blocks::generate_sequence_hashes(&token_seq);
// Mix found and not-found hashes
let mixed_hashes: Vec<_> = hashes
.iter()
.take(2)
.chain(nonexistent_hashes.iter().take(2))
.chain(hashes.iter().skip(2))
.copied()
.collect();
// Linear scan should only find the existing blocks
let blocks: Vec<TieredBlock> =
pair.leader_a
.leader
.scan_with_policy(&mixed_hashes, true, |hashes, ctx| {
for hash in hashes {
if let Some(block) = ctx.accessor().find(*hash) {
ctx.yield_item(block);
}
}
});
// Should find only the 5 blocks that exist
assert_eq!(blocks.len(), 5);
}
/// Test contiguous subsequence discovery policy with a single contiguous sequence.
///
/// This tests that the policy correctly identifies a fully contiguous sequence
/// as a single run.
#[tokio::test]
async fn test_scan_with_policy_contiguous_single_run() {
use crate::leader::TieredBlock;
let pair = create_instance_leader_pair(100, 4)
.await
.expect("Should create pair");
// Create a single contiguous sequence (all positions are consecutive)
let (_, hashes) =
populate_leader_with_blocks(&pair.leader_a, 10, 4, 0).expect("Should populate");
// Contiguous subsequence discovery policy
let runs: Vec<Vec<TieredBlock>> =
pair.leader_a
.leader
.scan_with_policy(&hashes, true, |hashes, ctx| {
let mut sorted_hashes = hashes.to_vec();
sorted_hashes.sort_by_key(|h| h.position());
let mut current_run = Vec::new();
let mut last_pos: Option<u64> = None;
for hash in &sorted_hashes {
if let Some(block) = ctx.accessor().find(*hash) {
let pos = block.position();
let is_contiguous = last_pos.is_none_or(|p| pos == p + 1);
if is_contiguous {
current_run.push(block);
} else {
if !current_run.is_empty() {
ctx.yield_item(std::mem::take(&mut current_run));
}
current_run.push(block);
}
last_pos = Some(pos);
} else if !current_run.is_empty() {
ctx.yield_item(std::mem::take(&mut current_run));
last_pos = None;
}
}
if !current_run.is_empty() {
ctx.yield_item(current_run);
}
});
// Should find exactly 1 contiguous run containing all 10 blocks
assert_eq!(runs.len(), 1, "Expected single contiguous run");
assert_eq!(runs[0].len(), 10, "Run should contain all 10 blocks");
// Verify positions are consecutive 0-9
for (i, block) in runs[0].iter().enumerate() {
assert_eq!(block.position(), i as u64);
}
}
/// Test contiguous subsequence discovery policy with gaps.
///
/// This tests that when some blocks are missing from the search,
/// the policy correctly identifies separate runs.
#[tokio::test]
async fn test_scan_with_policy_contiguous_with_gaps() {
use crate::leader::TieredBlock;
let pair = create_instance_leader_pair(100, 4)
.await
.expect("Should create pair");
// Create a contiguous sequence of 10 blocks (positions 0-9)
let (_, all_hashes) =
populate_leader_with_blocks(&pair.leader_a, 10, 4, 0).expect("Should populate");
// Query only for blocks 0-2, 5-6, 8-9 (skipping 3-4 and 7)
// This should create 3 runs: [0,1,2], [5,6], [8,9]
let query_hashes: Vec<_> = all_hashes
.iter()
.enumerate()
.filter(|(i, _)| matches!(*i, 0..=2 | 5..=6 | 8..=9))
.map(|(_, h)| *h)
.collect();
// Contiguous subsequence discovery policy
let runs: Vec<Vec<TieredBlock>> =
pair.leader_a
.leader
.scan_with_policy(&query_hashes, true, |hashes, ctx| {
let mut sorted_hashes = hashes.to_vec();
sorted_hashes.sort_by_key(|h| h.position());
let mut current_run = Vec::new();
let mut last_pos: Option<u64> = None;
for hash in &sorted_hashes {
if let Some(block) = ctx.accessor().find(*hash) {
let pos = block.position();
let is_contiguous = last_pos.is_none_or(|p| pos == p + 1);
if is_contiguous {
current_run.push(block);
} else {
if !current_run.is_empty() {
ctx.yield_item(std::mem::take(&mut current_run));
}
current_run.push(block);
}
last_pos = Some(pos);
} else if !current_run.is_empty() {
ctx.yield_item(std::mem::take(&mut current_run));
last_pos = None;
}
}
if !current_run.is_empty() {
ctx.yield_item(current_run);
}
});
// Should find 3 runs
assert_eq!(runs.len(), 3, "Expected 3 contiguous runs");
// First run: positions 0, 1, 2
assert_eq!(runs[0].len(), 3);
assert_eq!(runs[0][0].position(), 0);
assert_eq!(runs[0][1].position(), 1);
assert_eq!(runs[0][2].position(), 2);
// Second run: positions 5, 6
assert_eq!(runs[1].len(), 2);
assert_eq!(runs[1][0].position(), 5);
assert_eq!(runs[1][1].position(), 6);
// Third run: positions 8, 9
assert_eq!(runs[2].len(), 2);
assert_eq!(runs[2][0].position(), 8);
assert_eq!(runs[2][1].position(), 9);
}
/// Test scan_with_policy with tiered G2/G3 blocks.
///
/// This tests the precedence behavior where G2 blocks are returned
/// preferentially over G3 blocks when both exist.
///
/// Setup:
/// - 4 blocks total (positions 0, 1, 2, 3)
/// - All 4 blocks are in G3
/// - Even blocks (0, 2) are ALSO in G2
///
/// Expected result:
/// - Blocks 0, 2 should come from G2 (precedence)
/// - Blocks 1, 3 should come from G3
#[tokio::test]
async fn test_scan_with_policy_tiered_g2_g3() {
use crate::leader::TieredBlock;
let pair = create_instance_leader_pair(50, 4)
.await
.expect("Should create pair");
// Create 4 token blocks
let token_sequence = token_blocks::create_token_sequence(4, 4, 0);
let all_token_blocks = token_sequence.blocks();
// Populate G3 with ALL 4 blocks
let g3_manager = pair
.leader_a
.g3_manager
.as_ref()
.expect("G3 manager should exist");
let g3_hashes =
managers::populate_manager_with_blocks(g3_manager, all_token_blocks).expect("G3 pop");
// Populate G2 with only EVEN blocks (positions 0, 2)
let even_token_blocks: Vec<_> = all_token_blocks
.iter()
.enumerate()
.filter(|(i, _)| i % 2 == 0)
.map(|(_, b)| b.clone())
.collect();
let _g2_hashes =
managers::populate_manager_with_blocks(&pair.leader_a.g2_manager, &even_token_blocks)
.expect("G2 pop");
// The hashes from G3 are for all blocks; G2 hashes are for even blocks only
// We'll query using the G3 hashes (which cover all 4 blocks)
// Note: The actual sequence hashes should match between G2 and G3 for the same token content
// Simple linear scan to get all blocks
let blocks: Vec<TieredBlock> =
pair.leader_a
.leader
.scan_with_policy(&g3_hashes, true, |hashes, ctx| {
for hash in hashes {
if let Some(block) = ctx.accessor().find(*hash) {
ctx.yield_item(block);
}
}
});
// Should find all 4 blocks
assert_eq!(blocks.len(), 4, "Should find all 4 blocks");
// Count G2 vs G3 blocks
let g2_count = blocks.iter().filter(|b| b.is_g2()).count();
let g3_count = blocks.iter().filter(|b| b.is_g3()).count();
// Even positions (0, 2) should be G2, odd positions (1, 3) should be G3
assert_eq!(g2_count, 2, "Should have 2 G2 blocks (even positions)");
assert_eq!(g3_count, 2, "Should have 2 G3 blocks (odd positions)");
// Verify the specific tier for each position
// Blocks are returned in the order we queried (g3_hashes order = 0, 1, 2, 3)
assert!(blocks[0].is_g2(), "Block at position 0 should be G2 (even)");
assert!(blocks[1].is_g3(), "Block at position 1 should be G3 (odd)");
assert!(blocks[2].is_g2(), "Block at position 2 should be G2 (even)");
assert!(blocks[3].is_g3(), "Block at position 3 should be G3 (odd)");
// Verify positions are correct
for (i, block) in blocks.iter().enumerate() {
assert_eq!(
block.position(),
i as u64,
"Block {} should be at position {}",
i,
i
);
}
}
/// Test scan_with_policy with empty input.
#[tokio::test]
async fn test_scan_with_policy_empty_hashes() {
use crate::leader::TieredBlock;
let pair = create_instance_leader_pair(50, 4)
.await
.expect("Should create pair");
let empty_hashes: Vec<SequenceHash> = vec![];
let blocks: Vec<TieredBlock> =
pair.leader_a
.leader
.scan_with_policy(&empty_hashes, true, |hashes, ctx| {
for hash in hashes {
if let Some(block) = ctx.accessor().find(*hash) {
ctx.yield_item(block);
}
}
});
assert!(blocks.is_empty());
}
/// Test scan_with_policy with yield_items (batch yield).
#[tokio::test]
async fn test_scan_with_policy_yield_items() {
use crate::leader::TieredBlock;
let pair = create_instance_leader_pair(50, 4)
.await
.expect("Should create pair");
// Populate with blocks
let (_, hashes) =
populate_leader_with_blocks(&pair.leader_a, 10, 4, 0).expect("Should populate");
// Policy that uses yield_items to batch results
let blocks: Vec<TieredBlock> =
pair.leader_a
.leader
.scan_with_policy(&hashes, true, |hashes, ctx| {
let found: Vec<TieredBlock> = hashes
.iter()
.filter_map(|hash| ctx.accessor().find(*hash))
.collect();
ctx.yield_items(found);
});
assert_eq!(blocks.len(), 10);
}
/// Test scan_with_policy touch parameter.
///
/// Verifies that the touch parameter is correctly passed to the accessor
/// and affects frequency tracking behavior.
#[tokio::test]
async fn test_scan_with_policy_touch_parameter() {
use crate::leader::TieredBlock;
let pair = create_instance_leader_pair(50, 4)
.await
.expect("Should create pair");
let (_, hashes) =
populate_leader_with_blocks(&pair.leader_a, 5, 4, 0).expect("Should populate");
// Scan with touch=false
let blocks_no_touch: Vec<TieredBlock> =
pair.leader_a
.leader
.scan_with_policy(&hashes, false, |hashes, ctx| {
// Verify accessor has correct touch setting
assert!(!ctx.accessor().touch());
for hash in hashes {
if let Some(block) = ctx.accessor().find(*hash) {
ctx.yield_item(block);
}
}
});
// Drop blocks so they return to the pool
drop(blocks_no_touch);
// Scan with touch=true
let blocks_with_touch: Vec<TieredBlock> =
pair.leader_a
.leader
.scan_with_policy(&hashes, true, |hashes, ctx| {
// Verify accessor has correct touch setting
assert!(ctx.accessor().touch());
for hash in hashes {
if let Some(block) = ctx.accessor().find(*hash) {
ctx.yield_item(block);
}
}
});
assert_eq!(blocks_with_touch.len(), 5);
}
// =========================================================================
// RDMA Transfer Tests (require UCX and CUDA)
// =========================================================================
// Test constants - scaling up to 2 workers, 4 blocks each
const NUM_WORKERS: usize = 2; // Two workers now
const LAYOUT_BLOCKS: usize = 16; // Blocks per layout
const TEST_BLOCKS: usize = 4; // Test 4 blocks at once
const BLOCK_SIZE: usize = 4; // Tokens per block
const NUM_LAYERS: usize = 2; // Layers
const OUTER_DIM: usize = 1; // Outer dim
const PAGE_SIZE: usize = 4;
const INNER_DIM: usize = 64;
const DTYPE_WIDTH: usize = 2; // bf16
const MANAGER_BLOCKS: usize = 16; // Blocks in G2 BlockManager
fn test_layout_config() -> LayoutConfig {
physical::custom_config(
LAYOUT_BLOCKS,
NUM_LAYERS,
OUTER_DIM,
PAGE_SIZE,
INNER_DIM,
DTYPE_WIDTH,
)
}
/// Full RDMA transfer test with checksum verification.
///
/// This test (simplified to 1 block for debugging):
/// 1. Creates a pair of leaders with 1 worker each
/// 2. Fills Decode worker's G2 layout with 0xAA pattern
/// 3. Fills Prefill worker's G2 destination with 0xBB pattern
/// 4. Prefill pulls block via RDMA
/// 5. Verifies: Decode unchanged (still 0xAA), Prefill has Decode's data (now 0xAA)
///
/// If the transfer goes the wrong direction (PUT instead of GET):
/// - Decode would have 0xBB (wrong!)
/// - Prefill would have 0xBB (unchanged, wrong!)
#[tokio::test(flavor = "multi_thread")]
async fn test_rdma_transfer_with_checksum_verification() {
use crate::leader::ControllableSessionOptions;
use std::time::Duration;
use tokio::time::timeout;
let layout_config = test_layout_config();
// 1. Create leader pair with workers
let pair = create_instance_leader_pair_with_workers(
MANAGER_BLOCKS,
BLOCK_SIZE,
NUM_WORKERS,
&layout_config,
StorageKind::Pinned,
)
.await
.expect("Should create leader pair with workers");
println!(
"\n=== RDMA Direction Test (1 block) ===\n\
Decode (source): instance={}, {} workers\n\
Prefill (dest): instance={}, {} workers",
pair.decode.instance_id,
pair.decode.workers.len(),
pair.prefill.instance_id,
pair.prefill.workers.len()
);
// 2. Define block IDs - multiple blocks now
let src_block_ids: Vec<BlockId> = (0..TEST_BLOCKS as BlockId).collect();
// Use non-overlapping block IDs for destination
let dst_block_ids: Vec<BlockId> = (TEST_BLOCKS..(TEST_BLOCKS * 2) as BlockId).collect();
println!(
"Testing {} blocks x {} workers: src={:?}, dst={:?}",
TEST_BLOCKS, NUM_WORKERS, src_block_ids, dst_block_ids
);
// 3. Fill ALL DECODE workers' source blocks with 0xAA pattern
let mut decode_checksums_before_by_worker = Vec::new();
for (i, worker) in pair.decode.workers.iter().enumerate() {
let checksums = worker
.fill_g2_blocks(&src_block_ids, FillPattern::Constant(0xAA))
.expect("Should fill Decode G2 blocks");
println!(
"BEFORE transfer - Decode worker {} blocks: {:?}",
i, src_block_ids
);
decode_checksums_before_by_worker.push(checksums);
}
// 4. Fill ALL PREFILL workers' destination blocks with 0xBB pattern (different!)
let mut prefill_checksums_before_by_worker = Vec::new();
for (i, worker) in pair.prefill.workers.iter().enumerate() {
let checksums = worker
.fill_g2_blocks(&dst_block_ids, FillPattern::Constant(0xBB))
.expect("Should fill Prefill G2 blocks");
println!(
"BEFORE transfer - Prefill worker {} blocks: {:?}",
i, dst_block_ids
);
prefill_checksums_before_by_worker.push(checksums);
}
// Sanity check: they should be different
assert_ne!(
decode_checksums_before_by_worker[0][&0],
prefill_checksums_before_by_worker[0][&dst_block_ids[0]],
"Pre-transfer: Decode and Prefill should have different data"
);
// 5. Populate Decode's logical manager with blocks
let test_leader = TestInstanceLeader {
instance_id: pair.decode.instance_id,
leader: pair.decode.leader.clone(),
g2_manager: pair.decode.g2_manager.clone(),
g3_manager: pair.decode.g3_manager.clone(),
};
let (_, sequence_hashes) =
populate_leader_with_blocks(&test_leader, TEST_BLOCKS, BLOCK_SIZE, 0)
.expect("Should populate leader");
// 6. Create controllable session on Decode
let session_result = pair
.decode
.leader
.create_controllable_session_with_options(
&sequence_hashes,
ControllableSessionOptions { auto_stage: false },
)
.expect("Should create controllable session");
println!(
"Decode session created: {} G2 blocks",
session_result.local_g2_count
);
// 7. Prefill attaches
let mut handle = pair
.prefill
.leader
.attach_session(pair.decode.instance_id, session_result.session_id)
.await
.expect("Should attach");
// 8. Wait for initial state
let state = timeout(Duration::from_secs(5), handle.wait_for_ready())
.await
.expect("Timeout")
.expect("Should get state");
println!(
"Prefill sees {} G2 blocks from Decode",
state.g2_blocks.len()
);
// 9. Execute RDMA PULL: Prefill pulls FROM Decode
println!("\n--- Executing RDMA pull: Decode block 0 -> Prefill block 1 ---");
let notification = handle
.pull_blocks_rdma(&state.g2_blocks, &dst_block_ids)
.await
.expect("Should initiate RDMA pull");
notification.await.expect("Transfer should complete");
println!("Transfer complete!\n");
// 10. SPMD replication: ALL workers have ALL blocks
// Each Prefill worker N pulled from Decode worker N
// So verify ALL blocks on ALL workers
println!("\nVerifying SPMD replication - all workers have all blocks:");
println!(
" Each worker: src={:?} -> dst={:?}",
src_block_ids, dst_block_ids
);
// 11. Verify transfer for each worker - all workers have all blocks
for (worker_idx, (decode_worker, prefill_worker)) in pair
.decode
.workers
.iter()
.zip(pair.prefill.workers.iter())
.enumerate()
{
let decode_checksums_after = decode_worker
.compute_g2_checksums(&src_block_ids)
.expect("compute Decode checksums");
let prefill_checksums_after = prefill_worker
.compute_g2_checksums(&dst_block_ids)
.expect("compute Prefill checksums");
println!(
"\nWorker {} verification ({} blocks):",
worker_idx, TEST_BLOCKS
);
let decode_checksums_before = &decode_checksums_before_by_worker[worker_idx];
for i in 0..TEST_BLOCKS {
let src_id = src_block_ids[i];
let dst_id = dst_block_ids[i];
// Decode source should be unchanged (still 0xAA)
assert_eq!(
decode_checksums_before[&src_id], decode_checksums_after[&src_id],
"Decode block {} was modified!",
src_id
);
// Prefill dest should have Decode's data (now 0xAA, not 0xBB)
assert_eq!(
decode_checksums_before[&src_id], prefill_checksums_after[&dst_id],
"Prefill block {} doesn't have Decode block {}'s data",
dst_id, src_id
);
println!(" Worker {} block {} -> {}", worker_idx, src_id, dst_id);
}
}
println!(
"\n=== SUCCESS: {} blocks correctly transferred across {} workers (SPMD) ===",
TEST_BLOCKS, NUM_WORKERS
);
// 12. Cleanup
handle.mark_blocks_pulled(sequence_hashes).await.ok();
handle.detach().await.ok();
}
/// Bidirectional layerwise transfer test.
///
/// This test demonstrates the full prefill-decode workflow:
/// 1. Decode holds cached blocks
/// 2. Prefill pulls cached blocks from Decode (standard flow)
/// 3. Control inverts - Prefill creates session for Decode to attach
/// 4. Prefill sends layerwise events as layers are "computed" (simulated)
/// 5. Decode pulls layer-by-layer via RDMA
///
/// The test validates:
/// - EndpointSession creation and remote attachment
/// - Layerwise notifications via EndpointSessionHandle
/// - Layer-specific RDMA pulls with TransferOptions
/// - Data integrity verification at each layer
#[tokio::test(flavor = "multi_thread")]
async fn test_bidirectional_layerwise_transfer() {
use crate::leader::session::SessionHandle as UnifiedSessionHandle;
use std::time::Duration;
use tokio::time::timeout;
// Test parameters
const CACHED_BLOCKS: usize = 4; // Blocks Prefill pulls from Decode
const NEW_BLOCKS: usize = 2; // Blocks Prefill exposes for Decode to pull
const NUM_TEST_LAYERS: usize = NUM_LAYERS; // Use module constant
let layout_config = test_layout_config();
// 1. Create leader pair with workers
let pair = create_instance_leader_pair_with_workers(
MANAGER_BLOCKS,
BLOCK_SIZE,
NUM_WORKERS,
&layout_config,
StorageKind::Pinned,
)
.await
.expect("Should create leader pair with workers");
println!(
"\n=== Bidirectional Layerwise Transfer Test ===\n\
Decode: instance={}, {} workers\n\
Prefill: instance={}, {} workers\n\
Cached blocks: {}, New blocks: {}, Layers: {}",
pair.decode.instance_id,
pair.decode.workers.len(),
pair.prefill.instance_id,
pair.prefill.workers.len(),
CACHED_BLOCKS,
NEW_BLOCKS,
NUM_TEST_LAYERS
);
// =====================================================================
// Phase 1: Decode Setup - Populate with cached blocks
// =====================================================================
println!("\n--- Phase 1: Decode Setup ---");
// Populate Decode with cached blocks
let (decode_cached_block_ids, cached_hashes) = pair
.decode
.populate_g2_blocks(CACHED_BLOCKS, BLOCK_SIZE, 0)
.expect("Should populate Decode");
// Fill cached blocks with test pattern (0xCA = "cache")
for worker in &pair.decode.workers {
worker
.fill_g2_blocks(&decode_cached_block_ids, FillPattern::Constant(0xCA))
.expect("Should fill cached blocks");
}
println!(
"Decode populated with {} cached blocks: {:?}",
CACHED_BLOCKS, decode_cached_block_ids
);
// Also populate Prefill with "new prefill" blocks (these will be pulled by Decode later)
let (prefill_new_block_ids, new_hashes) = pair
.prefill
.populate_g2_blocks(NEW_BLOCKS, BLOCK_SIZE, 1000)
.expect("Should populate Prefill");
println!(
"Prefill populated with {} new blocks: {:?}",
NEW_BLOCKS, prefill_new_block_ids
);
// =====================================================================
// Phase 2: Prefill Pulls Cached Blocks from Decode (Standard Flow)
// =====================================================================
println!("\n--- Phase 2: Prefill Pulls from Decode ---");
// Create controllable session on Decode
let session_result = pair
.decode
.leader
.create_controllable_session_with_options(
&cached_hashes,
ControllableSessionOptions { auto_stage: false },
)
.expect("Should create controllable session");
println!(
"Decode session created: {} G2 blocks",
session_result.local_g2_count
);
// Prefill attaches using unified session API
let mut prefill_handle = pair
.prefill
.leader
.attach_session(pair.decode.instance_id, session_result.session_id)
.await
.expect("Should attach");
// Wait for initial state
let state = timeout(Duration::from_secs(5), prefill_handle.wait_for_ready())
.await
.expect("Timeout waiting for initial state")
.expect("Should get initial state");
println!(
"Prefill sees {} G2 blocks from Decode",
state.g2_blocks.len()
);
// Prefill allocates destination blocks from its BlockManager
// We hold these MutableBlocks for the duration of the transfer - they are NOT
// owned by the session, the caller maintains ownership.
let prefill_dst_blocks = pair
.prefill
.g2_manager
.allocate_blocks(CACHED_BLOCKS)
.expect("Should allocate destination blocks on Prefill");
let prefill_dst_block_ids: Vec<BlockId> =
prefill_dst_blocks.iter().map(|b| b.block_id()).collect();
println!(
"Prefill allocated destination blocks: {:?}",
prefill_dst_block_ids
);
// Prefill pulls cached blocks (caller holds prefill_dst_blocks for duration)
let notification = prefill_handle
.pull_blocks_rdma(&state.g2_blocks, &prefill_dst_block_ids)
.await
.expect("Should initiate RDMA pull");
notification.await.expect("Transfer should complete");
println!("Prefill pulled {} cached blocks", CACHED_BLOCKS);
// Verify Prefill received Decode's cached data (0xCA pattern)
println!("Verifying Prefill received Decode's cached data...");
for (worker_idx, (decode_worker, prefill_worker)) in pair
.decode
.workers
.iter()
.zip(pair.prefill.workers.iter())
.enumerate()
{
let decode_checksums = decode_worker
.compute_g2_checksums(&decode_cached_block_ids)
.expect("Should compute Decode checksums");
let prefill_checksums = prefill_worker
.compute_g2_checksums(&prefill_dst_block_ids)
.expect("Should compute Prefill checksums");
for i in 0..CACHED_BLOCKS {
let src_id = decode_cached_block_ids[i];
let dst_id = prefill_dst_block_ids[i];
assert_eq!(
decode_checksums[&src_id], prefill_checksums[&dst_id],
"Worker {}: Prefill block {} should match Decode block {}",
worker_idx, dst_id, src_id
);
}
println!(
" Worker {} verified: Prefill has Decode's cached data",
worker_idx
);
}
// Detach without marking blocks pulled (Decode keeps those blocks)
// Detach without marking blocks pulled (Decode keeps those blocks).
prefill_handle.detach().await.ok();
println!("Prefill detached (Decode keeps cached blocks)");
// =====================================================================
// Phase 3: Role Reversal - Prefill Creates Session for Decode
// =====================================================================
println!("\n--- Phase 3: Role Reversal ---");
// Create EndpointSession on Prefill for the new blocks
let (prefill_session_id, prefill_session_handle) = pair
.prefill
.leader
.create_endpoint_session(&new_hashes)
.expect("Should create endpoint session");
println!("Prefill created endpoint session: {}", prefill_session_id);
// Decode attaches to Prefill's session as Controller (role reversal!)
let mut decode_handle: UnifiedSessionHandle = pair
.decode
.leader
.attach_session(pair.prefill.instance_id, prefill_session_id)
.await
.expect("Should attach to Prefill's session");
println!("Decode attached to Prefill's session");
// Wait for session to be ready
let state = timeout(Duration::from_secs(5), decode_handle.wait_for_ready())
.await
.expect("Timeout waiting for ready")
.expect("Should get ready state");
println!(
"Decode sees {} G2 blocks from Prefill, phase: {:?}",
state.g2_blocks.len(),
state.phase
);
// =====================================================================
// Phase 4: Layerwise Transfer (Decode Pulls from Prefill)
// =====================================================================
println!("\n--- Phase 4: Layerwise Transfer ---");
// Decode allocates destination blocks from its BlockManager
// We hold these MutableBlocks for the duration of the transfer - caller owns them.
let decode_dst_blocks = pair
.decode
.g2_manager
.allocate_blocks(NEW_BLOCKS)
.expect("Should allocate destination blocks on Decode");
let decode_dst_block_ids: Vec<BlockId> =
decode_dst_blocks.iter().map(|b| b.block_id()).collect();
println!(
"Decode allocated destination blocks: {:?}",
decode_dst_block_ids
);
// Fill Prefill's blocks with test pattern (0xBB = distinct from Decode's 0xCA)
for worker in &pair.prefill.workers {
worker
.fill_g2_blocks(&prefill_new_block_ids, FillPattern::Constant(0xBB))
.expect("Should fill Prefill blocks");
}
println!("Prefill blocks filled with pattern 0xBB");
// Demonstrate layerwise notification mechanism
// In a real scenario, each layer would be filled as compute completes
for layer in 0..NUM_TEST_LAYERS {
println!("\n Layer {} notification:", layer);
// Prefill signals layer is ready
prefill_session_handle
.notify_layers_ready(0..layer + 1)
.await
.expect("Should notify layer ready");
println!(" Prefill notified layers 0..{} ready", layer + 1);
// Wait briefly for notification to propagate
tokio::time::sleep(Duration::from_millis(10)).await;
}
// After all layers ready, Decode pulls all blocks at once
println!("\n Decode pulling all layers...");
let notification = decode_handle
.pull_blocks_rdma(&state.g2_blocks, &decode_dst_block_ids)
.await
.expect("Should initiate RDMA pull");
notification.await.expect("Transfer should complete");
println!(" Decode pulled all {} blocks", NEW_BLOCKS);
// Verify data integrity - Decode should have Prefill's data
println!("Verifying Decode received Prefill's data...");
for (worker_idx, (prefill_worker, decode_worker)) in pair
.prefill
.workers
.iter()
.zip(pair.decode.workers.iter())
.enumerate()
{
let prefill_checksums = prefill_worker
.compute_g2_checksums(&prefill_new_block_ids)
.expect("Should compute Prefill checksums");
let decode_checksums = decode_worker
.compute_g2_checksums(&decode_dst_block_ids)
.expect("Should compute Decode checksums");
for i in 0..NEW_BLOCKS {
let src_id = prefill_new_block_ids[i];
let dst_id = decode_dst_block_ids[i];
assert_eq!(
prefill_checksums[&src_id], decode_checksums[&dst_id],
"Worker {}: Decode block {} should match Prefill block {}",
worker_idx, dst_id, src_id
);
}
println!(
" Worker {} verified: Decode has Prefill's data (pattern 0xBB)",
worker_idx
);
}
// =====================================================================
// Phase 5: Cleanup
// =====================================================================
println!("\n--- Phase 5: Cleanup ---");
// Decode detaches
decode_handle
.mark_blocks_pulled(new_hashes.clone())
.await
.ok();
decode_handle.detach().await.ok();
println!("Decode detached from Prefill's session");
// Prefill closes its session
prefill_session_handle.close().await.ok();
println!("Prefill closed endpoint session");
println!(
"\n=== SUCCESS: Bidirectional layerwise transfer completed ===\n\
- {} cached blocks transferred Decode -> Prefill\n\
- {} new blocks transferred Prefill -> Decode ({} layers)",
CACHED_BLOCKS, NEW_BLOCKS, NUM_TEST_LAYERS
);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! End-to-end testing utilities for the events pipeline.
//!
//! This module provides test infrastructure for verifying the complete event flow
//! from BlockManager registration through EventsManager, batching, and publishing.
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use derive_builder::Builder;
use futures::StreamExt;
use crate::InstanceId;
use crate::pubsub::{StubBus, Subscriber, Subscription};
use kvbm_logical::blocks::BlockMetadata;
use kvbm_logical::events::{
BatchingConfig, EventsManager, KvbmCacheEvents, KvbmCacheEventsPublisher,
};
use kvbm_logical::manager::BlockManager;
use super::managers::TestManagerBuilder;
// =============================================================================
// Events Pipeline Fixture
// =============================================================================
/// Configuration for creating an EventsPipelineFixture.
///
/// Uses `derive_builder` with a custom async build function since fixture
/// construction requires async operations (subscription setup).
///
/// # Example
///
/// ```ignore
/// // BEFORE: 15 lines of setup
/// let events_manager = Arc::new(EventsManager::builder().build());
/// let bus = StubBus::default();
/// let publisher = Arc::new(bus.publisher());
/// let subscriber = bus.subscriber();
/// let mut subscription = subscriber.subscribe("kvbm.events").await?;
/// let _publisher = KvbmCacheEventsPublisher::builder()
/// .instance_id(12345)
/// .event_stream(events_manager.subscribe())
/// .publisher(publisher)
/// .batching_config(BatchingConfig::default().with_window(Duration::from_millis(50)))
/// .subject("kvbm.events")
/// .build()?;
///
/// // AFTER: 3 lines
/// let mut fixture = EventsPipelineFixture::builder()
/// .batching_window(Duration::from_millis(50))
/// .build_async().await?;
/// let manager = fixture.create_manager::<G1>(100, 4);
/// ```
#[derive(Builder)]
#[builder(setter(into, strip_option), build_fn(skip), pattern = "owned")]
#[allow(dead_code)] // Fields are read via builder pattern
pub struct EventsPipelineConfig {
/// Instance ID for events (default: random v4 UUID).
#[builder(default)]
instance_id: Option<InstanceId>,
/// Batching window duration (default: 50ms).
#[builder(default = "Duration::from_millis(50)")]
batching_window: Duration,
/// Subject for publishing events (default: "kvbm.events").
#[builder(default = "\"kvbm.events\".to_string()")]
subject: String,
}
impl EventsPipelineConfigBuilder {
/// Builds the fixture asynchronously.
///
/// This is a custom build function because fixture construction requires
/// async operations (setting up the subscription).
pub async fn build_async(self) -> Result<EventsPipelineFixture> {
let instance_id = self
.instance_id
.flatten()
.unwrap_or_else(InstanceId::new_v4);
let batching_window = self.batching_window.unwrap_or(Duration::from_millis(50));
let subject = self.subject.unwrap_or_else(|| "kvbm.events".to_string());
// Create EventsManager
let events_manager = Arc::new(EventsManager::builder().build());
// Create stub pubsub
let bus = StubBus::default();
let publisher_arc = Arc::new(bus.publisher());
let subscriber = bus.subscriber();
// Subscribe BEFORE publishing (stub doesn't buffer)
let subscription = subscriber.subscribe(&subject).await?;
// Build the publishing pipeline - convert InstanceId to u128
let publisher = KvbmCacheEventsPublisher::builder()
.instance_id(instance_id.as_u128())
.event_stream(events_manager.subscribe())
.publisher(publisher_arc)
.batching_config(BatchingConfig::default().with_window(batching_window))
.subject(&subject)
.build()?;
Ok(EventsPipelineFixture {
events_manager,
subscription,
publisher,
bus,
instance_id,
subject,
})
}
}
/// Test fixture that encapsulates the full events pipeline setup.
///
/// This reduces the ~15 lines of boilerplate for setting up:
/// - EventsManager
/// - StubBus (publisher + subscriber)
/// - Subscription
/// - KvbmCacheEventsPublisher with batching
pub struct EventsPipelineFixture {
/// The EventsManager instance.
pub events_manager: Arc<EventsManager>,
/// Subscription for receiving published events.
pub subscription: Subscription,
/// The publisher (held to keep it alive).
#[allow(dead_code)]
publisher: KvbmCacheEventsPublisher,
/// The stub bus (held for reference).
#[allow(dead_code)]
bus: StubBus,
/// Instance ID used for events.
pub instance_id: InstanceId,
/// Subject for events.
pub subject: String,
}
impl EventsPipelineFixture {
/// Creates a new builder for the fixture.
///
/// Use `.build_async().await?` to construct the fixture.
pub fn builder() -> EventsPipelineConfigBuilder {
EventsPipelineConfigBuilder::default()
}
/// Creates a BlockManager with events integration.
///
/// # Arguments
/// * `block_count` - Number of blocks in the manager
/// * `block_size` - Tokens per block
pub fn create_manager<M: BlockMetadata>(
&self,
block_count: usize,
block_size: usize,
) -> BlockManager<M> {
TestManagerBuilder::<M>::new()
.block_count(block_count)
.block_size(block_size)
.events_manager(self.events_manager.clone())
.build()
}
/// Receive a batch of events with timeout.
///
/// Returns `None` if timeout expires before receiving events.
pub async fn receive_batch(&mut self, timeout: Duration) -> Option<KvbmCacheEvents> {
match tokio::time::timeout(timeout, self.subscription.next()).await {
Ok(Some(msg)) => rmp_serde::from_slice(&msg.payload).ok(),
_ => None,
}
}
/// Receive a batch with default timeout (500ms).
pub async fn receive_batch_default(&mut self) -> Option<KvbmCacheEvents> {
self.receive_batch(Duration::from_millis(500)).await
}
/// Wait for the batching window to flush, then receive events.
///
/// This sleeps for a bit longer than the batching window to ensure
/// events are flushed, then attempts to receive.
pub async fn flush_and_receive(
&mut self,
batching_window: Duration,
) -> Option<KvbmCacheEvents> {
tokio::time::sleep(batching_window + Duration::from_millis(50)).await;
self.receive_batch(Duration::from_millis(500)).await
}
/// Returns a reference to the EventsManager.
pub fn events_manager(&self) -> &Arc<EventsManager> {
&self.events_manager
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::time::Duration;
use futures::StreamExt;
use super::super::managers::TestManagerBuilder;
use super::super::token_blocks;
use crate::G1;
use crate::pubsub::{StubBus, Subscriber};
use kvbm_logical::events::{
BatchingConfig, EventsManager, KvCacheEvents, KvbmCacheEvents, KvbmCacheEventsPublisher,
PowerOfTwoPolicy,
};
/// Full end-to-end test: G1 BlockManager -> EventsManager -> Batcher -> Publisher -> Subscriber
///
/// This test verifies the complete event pipeline:
/// 1. Token sequences are created and registered with the BlockManager
/// 2. EventsManager emits Create events via BlockRegistry integration
/// 3. EventBatcher batches and sorts events
/// 4. KvbmCacheEventsPublisher serializes and publishes via stub
/// 5. StubSubscriber receives the batched events
#[tokio::test]
async fn test_full_events_pipeline_with_block_manager() {
// 1. Create EventsManager (default AllEventsPolicy)
let events_manager = Arc::new(EventsManager::builder().build());
// 2. Create G1 BlockManager with EventsManager integrated via builder
let block_count = 100;
let block_size = 4;
let manager = TestManagerBuilder::<G1>::new()
.block_count(block_count)
.block_size(block_size)
.events_manager(events_manager.clone())
.build();
// 4. Create stub pubsub for testing
let bus = StubBus::default();
let publisher = Arc::new(bus.publisher());
let subscriber = bus.subscriber();
// 5. Subscribe BEFORE publishing (stub doesn't buffer)
let mut subscription = subscriber
.subscribe("kvbm.events")
.await
.expect("Should subscribe");
// 6. Build the publishing pipeline
let _events_publisher = KvbmCacheEventsPublisher::builder()
.instance_id(12345)
.event_stream(events_manager.subscribe())
.publisher(publisher)
.batching_config(BatchingConfig::default().with_window(Duration::from_millis(50)))
.subject("kvbm.events")
.build()
.expect("Should build publisher");
// 7. Create token sequence and register blocks via BlockManager
// This automatically triggers events through the registry -> events_manager chain
let num_blocks = 5;
let token_sequence = token_blocks::create_token_sequence(num_blocks, block_size, 0);
// Allocate, complete, and register blocks
let allocated_blocks = manager
.allocate_blocks(num_blocks)
.expect("Should allocate blocks");
// Complete blocks with token data
let complete_blocks: Vec<_> = allocated_blocks
.into_iter()
.zip(token_sequence.blocks())
.map(|(block, token_block)| block.complete(token_block).expect("Should complete"))
.collect();
// Register blocks - this triggers Create events via the integrated EventsManager
let _immutable_blocks = manager.register_blocks(complete_blocks);
// 8. Wait for batch window to flush
tokio::time::sleep(Duration::from_millis(100)).await;
// 9. Receive and verify the batched events
let msg = tokio::time::timeout(Duration::from_millis(500), subscription.next())
.await
.expect("Should receive within timeout")
.expect("Should have message");
// Deserialize the msgpack payload
let batch: KvbmCacheEvents =
rmp_serde::from_slice(&msg.payload).expect("Should deserialize");
assert_eq!(batch.instance_id, 12345);
match &batch.events {
KvCacheEvents::Create(hashes) => {
assert_eq!(hashes.len(), num_blocks);
// Verify sorted ascending by position
for i in 1..hashes.len() {
assert!(
hashes[i - 1].position() <= hashes[i].position(),
"Create events should be sorted ascending"
);
}
}
KvCacheEvents::Remove(_) => panic!("Expected Create events, got Remove"),
KvCacheEvents::Shutdown => panic!("Expected Create events, got Shutdown"),
}
}
/// Test events with PowerOfTwoPolicy filtering.
///
/// Only blocks at power-of-2 positions should emit events.
#[tokio::test]
async fn test_events_with_power_of_two_policy() {
// Use PowerOfTwoPolicy instead of default AllEventsPolicy
let events_manager = Arc::new(
EventsManager::builder()
.policy(Arc::new(PowerOfTwoPolicy::new()))
.build(),
);
let block_size = 4;
let manager = TestManagerBuilder::<G1>::new()
.block_count(100)
.block_size(block_size)
.events_manager(events_manager.clone())
.build();
let bus = StubBus::default();
let publisher = Arc::new(bus.publisher());
let subscriber = bus.subscriber();
let mut subscription = subscriber
.subscribe("kvbm.events")
.await
.expect("Should subscribe");
let _events_publisher = KvbmCacheEventsPublisher::builder()
.instance_id(12345)
.event_stream(events_manager.subscribe())
.publisher(publisher)
.batching_config(BatchingConfig::default().with_window(Duration::from_millis(50)))
.subject("kvbm.events")
.build()
.expect("Should build publisher");
// Create sequence with 32 blocks (positions 0-31)
// Power of 2 positions: 1, 2, 4, 8, 16
let num_blocks = 32;
let token_sequence = token_blocks::create_token_sequence(num_blocks, block_size, 0);
let allocated_blocks = manager
.allocate_blocks(num_blocks)
.expect("Should allocate blocks");
let complete_blocks: Vec<_> = allocated_blocks
.into_iter()
.zip(token_sequence.blocks())
.map(|(block, token_block)| block.complete(token_block).expect("Should complete"))
.collect();
let _immutable_blocks = manager.register_blocks(complete_blocks);
// Wait for batch window
tokio::time::sleep(Duration::from_millis(100)).await;
// Should receive only power-of-2 position events
let msg = tokio::time::timeout(Duration::from_millis(500), subscription.next())
.await
.expect("Should receive within timeout")
.expect("Should have message");
let batch: KvbmCacheEvents =
rmp_serde::from_slice(&msg.payload).expect("Should deserialize");
match &batch.events {
KvCacheEvents::Create(hashes) => {
// Verify all received are at power-of-2 positions
for hash in hashes {
let pos = hash.position();
assert!(
pos.is_power_of_two(),
"Position {} should be power of 2",
pos
);
}
}
KvCacheEvents::Remove(_) => panic!("Expected Create events"),
KvCacheEvents::Shutdown => panic!("Expected Create events"),
}
}
/// Test that Remove events are emitted when blocks are evicted from the pool.
///
/// Note: Dropping ImmutableBlocks returns them to the pool (not dropped).
/// Remove events only fire when blocks are actually evicted from the pool
/// due to capacity limits (LRU eviction).
#[tokio::test]
async fn test_remove_events_on_pool_eviction() {
let events_manager = Arc::new(EventsManager::builder().build());
// Small pool capacity to force eviction
let block_count = 10;
let block_size = 4;
let manager = TestManagerBuilder::<G1>::new()
.block_count(block_count)
.block_size(block_size)
.events_manager(events_manager.clone())
.build();
let bus = StubBus::default();
let publisher = Arc::new(bus.publisher());
let subscriber = bus.subscriber();
let mut subscription = subscriber
.subscribe("kvbm.events")
.await
.expect("Should subscribe");
let _events_publisher = KvbmCacheEventsPublisher::builder()
.instance_id(12345)
.event_stream(events_manager.subscribe())
.publisher(publisher)
.batching_config(BatchingConfig::default().with_window(Duration::from_millis(50)))
.subject("kvbm.events")
.build()
.expect("Should build publisher");
// Fill the pool completely with first batch of blocks
let first_batch_size = block_count;
let token_sequence1 = token_blocks::create_token_sequence(first_batch_size, block_size, 0);
let allocated_blocks = manager
.allocate_blocks(first_batch_size)
.expect("Should allocate blocks");
let complete_blocks: Vec<_> = allocated_blocks
.into_iter()
.zip(token_sequence1.blocks())
.map(|(block, token_block)| block.complete(token_block).expect("Should complete"))
.collect();
// Register and immediately drop handles to return blocks to pool
let _first_batch = manager.register_blocks(complete_blocks);
// Wait for Create batch
tokio::time::sleep(Duration::from_millis(100)).await;
let msg = tokio::time::timeout(Duration::from_millis(500), subscription.next())
.await
.expect("Should receive Create batch")
.expect("Should have message");
let batch: KvbmCacheEvents = rmp_serde::from_slice(&msg.payload).unwrap();
assert!(
matches!(batch.events, KvCacheEvents::Create(ref h) if h.len() == first_batch_size)
);
// Drop first batch to return blocks to pool
drop(_first_batch);
// Now allocate more blocks - this should trigger eviction of old blocks
// since we're reusing the same pool slots with new sequence hashes
let second_batch_size = block_count;
let token_sequence2 =
token_blocks::create_token_sequence(second_batch_size, block_size, 1000);
let allocated_blocks = manager
.allocate_blocks(second_batch_size)
.expect("Should allocate blocks for second batch");
let complete_blocks: Vec<_> = allocated_blocks
.into_iter()
.zip(token_sequence2.blocks())
.map(|(block, token_block)| block.complete(token_block).expect("Should complete"))
.collect();
let _second_batch = manager.register_blocks(complete_blocks);
// Wait for events
tokio::time::sleep(Duration::from_millis(100)).await;
// We should receive Remove events for evicted blocks and Create events for new blocks.
// The order depends on the event batching (type switches cause flush).
// Collect all events received
let mut received_creates = 0;
let mut received_removes = 0;
// Try to receive messages with timeout
while let Ok(Some(msg)) =
tokio::time::timeout(Duration::from_millis(200), subscription.next()).await
{
let batch: KvbmCacheEvents = rmp_serde::from_slice(&msg.payload).unwrap();
match batch.events {
KvCacheEvents::Create(hashes) => received_creates += hashes.len(),
KvCacheEvents::Remove(hashes) => received_removes += hashes.len(),
KvCacheEvents::Shutdown => {} // Ignore shutdown events in counting
}
}
// Should have created second batch
assert_eq!(
received_creates, second_batch_size,
"Should receive Create events for second batch"
);
// Should have removed first batch (evicted due to pool reuse)
assert_eq!(
received_removes, first_batch_size,
"Should receive Remove events for evicted blocks"
);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment