// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 //! Batch collection and accumulation for offload transfers. //! //! The batch collector accumulates blocks that pass policy evaluation and //! groups them into batches for efficient transfer execution. use std::collections::HashSet; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::{mpsc, watch}; use velo::EventHandle; use crate::{BlockId, SequenceHash}; use kvbm_logical::blocks::BlockMetadata; use super::handle::TransferId; use super::pending::PendingGuard; use super::queue::CancellableQueue; use super::source::SourceBlock; /// Timing trace for tracking block progression through pipeline stages. /// /// Each block carries a timing trace that records when it passed through /// each stage. This enables per-container and batch-level timing analysis. #[derive(Debug, Clone)] pub struct TimingTrace { /// When the block was initially enqueued into the pipeline pub enqueued_at: Instant, /// When policy evaluation completed for this block pub policy_complete_at: Option, /// When the precondition (e.g., forward pass) completed pub precondition_complete_at: Option, /// When the block was added to a transfer batch pub batched_at: Option, /// When the transfer operation started pub transfer_start_at: Option, /// When the transfer operation completed pub transfer_complete_at: Option, } impl TimingTrace { /// Create a new timing trace, recording the current time as enqueue time. pub fn new() -> Self { Self { enqueued_at: Instant::now(), policy_complete_at: None, precondition_complete_at: None, batched_at: None, transfer_start_at: None, transfer_complete_at: None, } } /// Mark policy evaluation complete. pub fn mark_policy_complete(&mut self) { self.policy_complete_at = Some(Instant::now()); } /// Mark precondition complete. pub fn mark_precondition_complete(&mut self) { self.precondition_complete_at = Some(Instant::now()); } /// Mark block as batched. pub fn mark_batched(&mut self) { self.batched_at = Some(Instant::now()); } /// Mark transfer start. pub fn mark_transfer_start(&mut self) { self.transfer_start_at = Some(Instant::now()); } /// Mark transfer complete. pub fn mark_transfer_complete(&mut self) { self.transfer_complete_at = Some(Instant::now()); } /// Get total time from enqueue to transfer complete (if available). pub fn total_duration(&self) -> Option { self.transfer_complete_at .map(|end| end.duration_since(self.enqueued_at)) } /// Get policy evaluation duration (if available). pub fn policy_duration(&self) -> Option { self.policy_complete_at .map(|end| end.duration_since(self.enqueued_at)) } /// Get precondition wait duration (if available). pub fn precondition_duration(&self) -> Option { match (self.policy_complete_at, self.precondition_complete_at) { (Some(start), Some(end)) => Some(end.duration_since(start)), _ => None, } } /// Get transfer duration (if available). pub fn transfer_duration(&self) -> Option { match (self.transfer_start_at, self.transfer_complete_at) { (Some(start), Some(end)) => Some(end.duration_since(start)), _ => None, } } } impl Default for TimingTrace { fn default() -> Self { Self::new() } } /// Configuration for batch collection. #[derive(Debug, Clone)] pub struct BatchConfig { /// Maximum blocks per batch pub max_batch_size: usize, /// Time to wait before flushing a partial batch pub flush_interval: Duration, /// Minimum batch size before flush (unless timeout) pub min_batch_size: usize, } impl Default for BatchConfig { fn default() -> Self { Self { max_batch_size: 1024, flush_interval: Duration::from_millis(10), min_batch_size: 8, } } } impl BatchConfig { /// Create a new batch config with specified max size. pub fn with_max_size(mut self, size: usize) -> Self { self.max_batch_size = size; self } /// Set the flush interval. pub fn with_flush_interval(mut self, interval: Duration) -> Self { self.flush_interval = interval; self } /// Set the minimum batch size. pub fn with_min_size(mut self, size: usize) -> Self { self.min_batch_size = size; self } } /// A block that passed policy evaluation and is queued for transfer. #[allow(dead_code)] pub struct QueuedBlock { /// Transfer ID this block belongs to pub transfer_id: TransferId, /// Block ID - Some for External/Strong, None for Weak (determined at upgrade) pub block_id: Option, /// Sequence hash pub sequence_hash: SequenceHash, /// Source block - Strong/External pass through, Weak upgraded just before transfer pub source: SourceBlock, /// Transfer state for completion tracking pub(crate) state: Arc>, /// RAII guard that removes this block from pending set on drop. /// /// This ensures duplicate prevention tracking is automatically cleaned up /// when the block completes transfer, is cancelled, or errors out. pub pending_guard: Option, } impl std::fmt::Debug for QueuedBlock { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("QueuedBlock") .field("transfer_id", &self.transfer_id) .field("block_id", &self.block_id) .field("sequence_hash", &self.sequence_hash) .finish() } } /// A batch of blocks ready for transfer execution. pub struct TransferBatch { /// Blocks in this batch pub blocks: Vec>, /// Optional precondition event that must be satisfied before processing. /// If Some, the pipeline will await this event before executing the transfer. pub precondition: Option, /// Timing trace for performance monitoring (batch-level, not per-block). pub timing: TimingTrace, } impl TransferBatch { /// Create a new empty batch. pub fn new() -> Self { Self { blocks: Vec::new(), precondition: None, timing: TimingTrace::new(), } } /// Create with pre-allocated capacity. pub fn with_capacity(capacity: usize) -> Self { Self { blocks: Vec::with_capacity(capacity), precondition: None, timing: TimingTrace::new(), } } /// Set the precondition event for this batch. #[allow(dead_code)] pub fn with_precondition(mut self, precondition: EventHandle) -> Self { self.precondition = Some(precondition); self } /// Add a block to this batch. pub fn push(&mut self, block: QueuedBlock) { self.blocks.push(block); } /// Get the number of blocks in this batch. pub fn len(&self) -> usize { self.blocks.len() } /// Check if batch is empty. pub fn is_empty(&self) -> bool { self.blocks.is_empty() } /// Get block IDs in this batch (only for blocks with known IDs). /// /// Weak blocks may have `None` for block_id until upgraded. /// The TransferExecutor resolves actual block_ids at transfer time. #[allow(dead_code)] pub fn block_ids(&self) -> Vec { self.blocks.iter().filter_map(|b| b.block_id).collect() } /// Get sequence hashes in this batch. #[allow(dead_code)] pub fn sequence_hashes(&self) -> Vec { self.blocks.iter().map(|b| b.sequence_hash).collect() } /// Get unique transfer IDs in this batch. #[allow(dead_code)] pub fn transfer_ids(&self) -> Vec { let mut ids: Vec = self.blocks.iter().map(|b| b.transfer_id).collect(); ids.sort_by_key(|id| id.as_uuid()); ids.dedup(); ids } /// Take all blocks out of this batch. #[allow(dead_code)] pub fn take(&mut self) -> Vec> { std::mem::take(&mut self.blocks) } /// Drain blocks for the given transfer ID (for cancellation). #[allow(dead_code)] pub fn drain_transfer(&mut self, transfer_id: TransferId) -> Vec> { let mut kept = Vec::new(); let mut drained = Vec::new(); for block in std::mem::take(&mut self.blocks) { if block.transfer_id == transfer_id { drained.push(block); } else { kept.push(block); } } self.blocks = kept; drained } } impl Default for TransferBatch { fn default() -> Self { Self::new() } } use super::handle::TransferState; /// Result of policy evaluation - blocks ready for batching. #[allow(dead_code)] pub struct EvalResult { /// Transfer ID pub transfer_id: TransferId, /// Blocks that passed all policies pub passed_blocks: Vec>, /// Block IDs that were filtered out pub filtered_ids: Vec, /// Transfer state for completion tracking pub(crate) state: Arc>, } /// Output from the batch collector to transfer executor. pub type BatchOutput = mpsc::Sender>; /// Receiver side of batch output channel. pub type BatchOutputRx = mpsc::Receiver>; /// Extract the common precondition from a batch of blocks. /// /// If all blocks share the same precondition, returns it. /// Otherwise returns `None`. fn extract_common_precondition(blocks: &[QueuedBlock]) -> Option { blocks.first().and_then(|first_block| { let first_precondition = first_block.state.lock().unwrap().precondition; let all_same = blocks .iter() .all(|block| block.state.lock().unwrap().precondition == first_precondition); if all_same { first_precondition } else { None } }) } /// Batch collector that accumulates blocks and flushes batches. /// /// The collector accumulates blocks from policy evaluation (via `CancellableQueue`) /// and groups them into batches based on the configuration. Batches are flushed when: /// - `max_batch_size` is reached /// - `flush_interval` expires and `min_batch_size` is met /// - Shutdown is requested pub struct BatchCollector { config: BatchConfig, /// Input queue from policy evaluator input_queue: Arc>>, /// Output channel to transfer executor output_tx: BatchOutput, /// Cancel watch receiver cancel_rx: watch::Receiver>, /// Current batch being built current_batch: TransferBatch, } impl BatchCollector { /// Create a new batch collector. pub fn new( config: BatchConfig, input_queue: Arc>>, output_tx: BatchOutput, cancel_rx: watch::Receiver>, ) -> Self { let max_batch_size = config.max_batch_size; Self { config, input_queue, output_tx, cancel_rx, current_batch: TransferBatch::with_capacity(max_batch_size), } } /// Run the batch collector loop. pub async fn run(mut self) { let mut flush_timer = tokio::time::interval(self.config.flush_interval); flush_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); let mut poll_interval = tokio::time::interval(Duration::from_micros(100)); poll_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); loop { tokio::select! { // Poll queue for items _ = poll_interval.tick() => { while let Some(item) = self.input_queue.pop_valid() { self.handle_eval_result(item.data).await; } } // Periodic flush timer _ = flush_timer.tick() => { self.try_flush().await; } // Check for shutdown result = self.cancel_rx.changed() => { if result.is_err() { // Channel closed, flush and exit self.flush_if_not_empty().await; break; } } } } } /// Handle an evaluation result. /// /// Adds passed blocks to the current batch and flushes when: /// - max_batch_size is reached, OR /// - all blocks for a transfer have been processed (per-transfer sentinel flush) async fn handle_eval_result(&mut self, result: EvalResult) { // Count blocks processed in this eval result (both passed and filtered) let blocks_in_eval = result.passed_blocks.len() + result.filtered_ids.len(); // Add passed blocks to current batch for block in result.passed_blocks { self.current_batch.push(block); // Flush if we've reached max batch size if self.current_batch.len() >= self.config.max_batch_size { self.flush().await; } } // Update transfer state and check if transfer is complete (sentinel flush) let should_flush = { let mut state = result.state.lock().unwrap(); state.blocks_processed += blocks_in_eval; // Flush when all blocks for this transfer have been processed state.blocks_processed >= state.total_expected_blocks && state.total_expected_blocks > 0 }; // Flush immediately when a transfer completes to avoid waiting for min_batch_size if should_flush && !self.current_batch.is_empty() { tracing::debug!( transfer_id = %result.transfer_id, batch_size = self.current_batch.len(), "Per-transfer sentinel flush" ); self.flush().await; } } /// Try to flush if minimum batch size is reached. async fn try_flush(&mut self) { if self.current_batch.len() >= self.config.min_batch_size { self.flush().await; } } /// Flush current batch if not empty. async fn flush_if_not_empty(&mut self) { if !self.current_batch.is_empty() { self.flush().await; } } /// Flush the current batch to the output channel. async fn flush(&mut self) { nvtx_range!("offload::batch"); if self.current_batch.is_empty() { return; } let mut batch = std::mem::replace( &mut self.current_batch, TransferBatch::with_capacity(self.config.max_batch_size), ); // Mark batch as ready (single O(1) call, not per-block) batch.timing.mark_batched(); batch.precondition = extract_common_precondition(&batch.blocks); // Send to transfer executor if self.output_tx.send(batch).await.is_err() { // Output channel closed, log and continue tracing::warn!("Batch output channel closed"); } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_batch_config_default() { let config = BatchConfig::default(); assert_eq!(config.max_batch_size, 1024); assert_eq!(config.min_batch_size, 8); } #[test] fn test_batch_config_builder() { let config = BatchConfig::default() .with_max_size(128) .with_min_size(16) .with_flush_interval(Duration::from_millis(50)); assert_eq!(config.max_batch_size, 128); assert_eq!(config.min_batch_size, 16); assert_eq!(config.flush_interval, Duration::from_millis(50)); } #[test] fn test_transfer_batch() { let batch: TransferBatch<()> = TransferBatch::new(); assert!(batch.is_empty()); assert_eq!(batch.len(), 0); } #[tokio::test] async fn test_batch_collector_empty_input() { let input_queue = Arc::new(CancellableQueue::>::new()); let (output_tx, mut output_rx) = mpsc::channel::>(10); let (cancel_tx, cancel_rx) = watch::channel(HashSet::new()); let collector = BatchCollector::new(BatchConfig::default(), input_queue, output_tx, cancel_rx); // Drop cancel sender to close channel (triggers shutdown) drop(cancel_tx); // Run collector tokio::spawn(async move { collector.run().await; }); // Should receive nothing (empty input) let result = tokio::time::timeout(Duration::from_millis(50), output_rx.recv()).await; assert!(result.is_err() || result.unwrap().is_none()); } #[test] fn test_transfer_batch_with_capacity() { let batch: TransferBatch<()> = TransferBatch::with_capacity(128); assert!(batch.is_empty()); assert_eq!(batch.len(), 0); } #[test] fn test_batch_config_with_methods() { let config = BatchConfig::default() .with_max_size(256) .with_min_size(32) .with_flush_interval(Duration::from_millis(100)); assert_eq!(config.max_batch_size, 256); assert_eq!(config.min_batch_size, 32); assert_eq!(config.flush_interval, Duration::from_millis(100)); } #[test] fn test_transfer_batch_methods() { let mut batch: TransferBatch<()> = TransferBatch::new(); // Note: We can't easily create QueuedBlock without the full pipeline setup, // so this test just verifies the batch structure methods work on empty batches assert!(batch.block_ids().is_empty()); assert!(batch.sequence_hashes().is_empty()); assert!(batch.transfer_ids().is_empty()); // Verify take() works let taken = batch.take(); assert!(taken.is_empty()); assert!(batch.is_empty()); } #[test] fn test_batch_precondition() { let batch: TransferBatch<()> = TransferBatch::new(); assert!(batch.precondition.is_none()); // Note: with_precondition requires an EventHandle which is complex to create // in a unit test, so we just verify the field exists and is None by default } }