Unverified Commit 07cfc3a1 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: kvbm + connector (#2258)


Signed-off-by: default avatarRyan Olson <rolson@nvidia.com>
Co-authored-by: default avatarOlga Andreeva <oandreeva@nvidia.com>
Co-authored-by: default avatarZiqi Fan <ziqif@nvidia.com>
Co-authored-by: default avatarJohn Thompson <jothomson@nvidia.com>
Co-authored-by: default avatarRichard Huo <rihuo@nvidia.com>
Co-authored-by: default avatarZicheng Ma <zichengm@nvidia.com>
parent bf5862a1
......@@ -38,21 +38,24 @@
//! 3. A worker thread (consuming this bounded channel and enforcing rate limiting) awaits the incoming transfers.
//! 4. After a transfer is complete, the worker thread registers the blocks with the target pool, and returns the registered blocks to the caller.
use nixl_sys::NixlDescriptor;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::runtime::Handle;
use tokio::sync::mpsc;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use crate::block_manager::block::{
transfer::{WriteTo, WriteToStrategy},
BlockError, BlockExt, BlockMetadata, BlockState, MutableBlock, ReadableBlock, TransferContext,
WritableBlock,
locality::LocalityProvider,
transfer::{TransferContext, WriteTo, WriteToStrategy},
BlockDataProvider, BlockDataProviderMut, BlockError, BlockMetadata, BlockState, ImmutableBlock,
MutableBlock, ReadableBlock, WritableBlock,
};
use crate::block_manager::pool::BlockPoolError;
use crate::block_manager::metrics::PoolMetrics;
use crate::block_manager::pool::{BlockPool, BlockPoolError};
use crate::block_manager::storage::{Local, Storage};
use crate::block_manager::BlockPool;
use anyhow::Result;
use async_trait::async_trait;
......@@ -62,26 +65,33 @@ use super::BlockResult;
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
const BLOCKS_BW_MIN_PUBLISH_INTERVAL_MS: u64 = 50;
/// Manage a set of pending transfers.
pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
pub struct PendingTransfer<
Source: Storage,
Target: Storage,
Locality: LocalityProvider,
Metadata: BlockMetadata,
> {
/// The block being copied from.
sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
sources: Vec<ImmutableBlock<Source, Locality, Metadata>>,
/// The block being copied to.
targets: Vec<MutableBlock<Target, Metadata>>,
targets: Vec<MutableBlock<Target, Locality, Metadata>>,
/// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>,
completion_indicator: Option<oneshot::Sender<BlockResult<Target, Locality, Metadata>>>,
/// The target pool that will receive the registered block.
target_pool: Arc<BlockPool<Target, Metadata>>,
target_pool: Arc<dyn BlockPool<Target, Locality, Metadata>>,
}
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
PendingTransfer<Source, Target, Metadata>
impl<Source: Storage, Target: Storage, Locality: LocalityProvider, Metadata: BlockMetadata>
PendingTransfer<Source, Target, Locality, Metadata>
{
pub fn new(
sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
targets: Vec<MutableBlock<Target, Metadata>>,
completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>,
target_pool: Arc<BlockPool<Target, Metadata>>,
sources: Vec<ImmutableBlock<Source, Locality, Metadata>>,
targets: Vec<MutableBlock<Target, Locality, Metadata>>,
completion_indicator: Option<oneshot::Sender<BlockResult<Target, Locality, Metadata>>>,
target_pool: Arc<dyn BlockPool<Target, Locality, Metadata>>,
) -> Self {
assert_eq!(sources.len(), targets.len());
Self {
......@@ -92,7 +102,7 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
}
}
fn handle_complete(self) -> Result<()> {
async fn handle_complete(self) -> Result<()> {
let Self {
sources,
mut targets,
......@@ -105,7 +115,9 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
transfer_metadata(source, target)?;
}
let blocks = target_pool.register_blocks_blocking(targets)?;
let blocks = target_pool.register_blocks(targets).await?;
tracing::debug!("Transfer complete. Registered {} blocks.", blocks.len());
if let Some(completion_indicator) = completion_indicator {
completion_indicator
......@@ -117,9 +129,14 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
}
}
fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>(
source: &Arc<MutableBlock<Source, Metadata>>,
target: &mut MutableBlock<Target, Metadata>,
fn transfer_metadata<
Source: Storage,
Target: Storage,
Locality: LocalityProvider,
Metadata: BlockMetadata,
>(
source: &ImmutableBlock<Source, Locality, Metadata>,
target: &mut MutableBlock<Target, Locality, Metadata>,
) -> Result<()> {
// Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail.
if let BlockState::Registered(reg_handle, _) = source.state() {
......@@ -139,136 +156,118 @@ fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>(
}
#[async_trait]
pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata>:
Send + Sync
pub trait TransferManager<
Source: Storage,
Target: Storage,
Locality: LocalityProvider,
Metadata: BlockMetadata,
>: Send + Sync
{
/// Begin a transfer. Blocks if the pending queue is full.
async fn enqueue_transfer(
&self,
pending_transfer: PendingTransfer<Source, Target, Metadata>,
pending_transfer: PendingTransfer<Source, Target, Locality, Metadata>,
) -> Result<()>;
}
pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
pending_transfer_q: mpsc::Sender<(
PendingTransfer<Source, Target, Metadata>,
tokio::sync::oneshot::Receiver<()>,
)>,
transfer_ctx: Arc<TransferContext>,
struct TransferCompletionManager<
Source: Storage,
Target: Storage,
Locality: LocalityProvider,
Metadata: BlockMetadata,
> {
pool_metrics: Arc<PoolMetrics>,
transfer_type: String,
last_publish_time: Option<Instant>,
transfer_start: Instant,
num_blocks_transferred: usize,
_phantom: PhantomData<(Source, Target, Locality, Metadata)>,
}
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
CudaTransferManager<Source, Target, Metadata>
impl<Source: Storage, Target: Storage, Locality: LocalityProvider, Metadata: BlockMetadata>
TransferCompletionManager<Source, Target, Locality, Metadata>
{
pub fn new(
transfer_ctx: Arc<TransferContext>,
max_concurrent_transfers: usize,
runtime: &Handle,
cancellation_token: CancellationToken,
) -> Result<Self> {
let (tx, mut rx) = mpsc::channel::<(
PendingTransfer<Source, Target, Metadata>,
tokio::sync::oneshot::Receiver<()>,
)>(max_concurrent_transfers);
pub fn new(pool_metrics: Arc<PoolMetrics>, transfer_type: String) -> Self {
Self {
pool_metrics,
transfer_type,
last_publish_time: None,
transfer_start: Instant::now(),
num_blocks_transferred: 0,
_phantom: PhantomData,
}
}
CriticalTaskExecutionHandle::new_with_runtime(
move |cancel_token| async move {
loop {
tokio::select! {
Some((pending_transfer, notify)) = rx.recv() => {
// Wait for the event.
notify.await.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
// Only finalize the transfer after the event is signaled.
match pending_transfer.handle_complete() {
Ok(_) => {}
Err(e) => {
// The only case where this can fail is if the progress engine is being shutdown.
// This is not a problem, so we can just ignore it.
tracing::warn!("Error handling transfer completion: {:?}", e);
}
}
}
pub async fn handle_complete(
&mut self,
pending_transfer: PendingTransfer<Source, Target, Locality, Metadata>,
) -> Result<()> {
self.num_blocks_transferred += pending_transfer.sources.len();
_ = cancel_token.cancelled() => {
return Ok(());
}
}
}
},
cancellation_token.clone(),
"Cuda Transfer Manager",
runtime,
)?
.detach();
let should_publish = self.last_publish_time.is_none_or(|last_publish_time| {
last_publish_time.elapsed() > Duration::from_millis(BLOCKS_BW_MIN_PUBLISH_INTERVAL_MS)
});
Ok(Self {
pending_transfer_q: tx,
transfer_ctx,
})
}
}
if should_publish {
self.last_publish_time = Some(Instant::now());
let duration = self.transfer_start.elapsed();
let blocks_per_sec = self.num_blocks_transferred as f64 / duration.as_secs_f64();
#[async_trait]
impl<Source, Target, Metadata> TransferManager<Source, Target, Metadata>
for CudaTransferManager<Source, Target, Metadata>
where
Source: Storage,
Target: Storage,
Metadata: BlockMetadata,
// Check that the source block is readable, local, and writable to the target block.
MutableBlock<Source, Metadata>: ReadableBlock<StorageType = Source>
+ Local
+ WriteToStrategy<MutableBlock<Target, Metadata>>,
// Check that the target block is writable.
MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>,
{
async fn enqueue_transfer(
&self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> {
let notify = pending_transfer
.sources
.write_to(
&mut pending_transfer.targets,
true,
self.transfer_ctx.clone(),
)?
.ok_or_else(|| {
anyhow::anyhow!(
"write_to returned None when notify was true. This should never happen!"
)
})?;
// Send the pending transfer and event to the worker thread.
// If the queue is full, we block the worker until space becomes available.
self.pending_transfer_q
.send((pending_transfer, notify))
.await?;
self.pool_metrics
.gauge(self.transfer_type.as_str())
.set(blocks_per_sec as i64);
}
match pending_transfer.handle_complete().await {
Ok(_) => {}
Err(e) => {
// The only case where this can fail is if the progress engine is being shutdown.
// This is not a problem, so we can just ignore it.
tracing::warn!("Error handling transfer completion: {:?}", e);
}
}
Ok(())
}
}
pub struct DiskTransferManager {
futures_tx: mpsc::Sender<Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync>>>,
type TransferFuture<Source, Target, Locality, Metadata> = Pin<
Box<
dyn std::future::Future<Output = PendingTransfer<Source, Target, Locality, Metadata>>
+ Send
+ Sync,
>,
>;
pub struct LocalTransferManager<
Source: Storage,
Target: Storage,
Locality: LocalityProvider,
Metadata: BlockMetadata,
> {
futures_tx: mpsc::Sender<TransferFuture<Source, Target, Locality, Metadata>>,
transfer_ctx: Arc<TransferContext>,
}
impl DiskTransferManager {
impl<Source: Storage, Target: Storage, Locality: LocalityProvider, Metadata: BlockMetadata>
LocalTransferManager<Source, Target, Locality, Metadata>
{
pub fn new(
transfer_ctx: Arc<TransferContext>,
max_concurrent_transfers: usize,
runtime: &Handle,
cancellation_token: CancellationToken,
pool_metrics: Arc<PoolMetrics>,
transfer_type: String,
) -> Result<Self> {
let (futures_tx, mut futures_rx) = mpsc::channel(1);
let mut completion_manager =
TransferCompletionManager::new(pool_metrics.clone(), transfer_type.clone());
CriticalTaskExecutionHandle::new_with_runtime(
move |cancel_token| async move {
// Keep track of our pending transfers.
// Consume the futures as they complete, while also receiving new ones.
let mut pending_transfers = FuturesUnordered::new();
let mut pending_transfers: FuturesUnordered<TransferFuture<Source, Target, Locality, Metadata>> = FuturesUnordered::new();
loop {
tokio::select! {
......@@ -279,19 +278,23 @@ impl DiskTransferManager {
Some(future) = futures_rx.recv() => {
// If we're at max size, block the worker thread on the next() call until we have capacity.
while pending_transfers.len() >= max_concurrent_transfers {
pending_transfers.next().await;
if let Some(pending_transfer) = pending_transfers.next().await {
completion_manager.handle_complete(pending_transfer).await?;
} else {
break;
}
}
// Once we have capacity, push the new future onto the queue.
pending_transfers.push(future);
}
Some(_) = pending_transfers.next(), if !pending_transfers.is_empty() => {
// A transfer completed, just continue to process more
Some(pending_transfer) = pending_transfers.next(), if !pending_transfers.is_empty() => {
completion_manager.handle_complete(pending_transfer).await?;
}
}
}
},
cancellation_token.clone(),
"Disk Transfer Manager",
"Local Transfer Manager",
runtime,
)?
.detach();
......@@ -304,45 +307,34 @@ impl DiskTransferManager {
}
#[async_trait]
impl<Source, Target, Metadata> TransferManager<Source, Target, Metadata> for DiskTransferManager
impl<Source, Target, Locality, Metadata> TransferManager<Source, Target, Locality, Metadata>
for LocalTransferManager<Source, Target, Locality, Metadata>
where
Source: Storage,
Target: Storage,
Source: Storage + NixlDescriptor,
Target: Storage + NixlDescriptor,
Locality: LocalityProvider,
Metadata: BlockMetadata,
// Check that the source block is readable, local, and writable to the target block.
MutableBlock<Source, Metadata>: ReadableBlock<StorageType = Source>
ImmutableBlock<Source, Locality, Metadata>: ReadableBlock<StorageType = Source>
+ Local
+ WriteToStrategy<MutableBlock<Target, Metadata>>,
+ WriteToStrategy<MutableBlock<Target, Locality, Metadata>>,
// Check that the target block is writable.
MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>,
MutableBlock<Target, Locality, Metadata>: WritableBlock<StorageType = Target>,
// Check that the source and target blocks have the same locality.
ImmutableBlock<Source, Locality, Metadata>: BlockDataProvider<Locality = Locality>,
MutableBlock<Target, Locality, Metadata>: BlockDataProviderMut<Locality = Locality>,
{
async fn enqueue_transfer(
&self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
mut pending_transfer: PendingTransfer<Source, Target, Locality, Metadata>,
) -> Result<()> {
let notify = pending_transfer
.sources
.write_to(
&mut pending_transfer.targets,
true,
self.transfer_ctx.clone(),
)?
.ok_or_else(|| {
anyhow::anyhow!(
"write_to returned None when notify was true. This should never happen!"
)
})?;
.write_to(&mut pending_transfer.targets, self.transfer_ctx.clone())?;
let completion_future = async move {
let _ = notify.await;
match pending_transfer.handle_complete() {
Ok(_) => {}
Err(e) => {
// The only case where this can fail is if the progress engine is being shutdown.
// This is not a problem, so we can just ignore it.
tracing::warn!("Error handling transfer completion: {:?}", e);
}
}
pending_transfer
};
// Futures_(tx/rx) has a capacity of 1. If the queue worker has received another future and is awaiting next() due to a full `FuturesUnordered`,
......@@ -354,26 +346,29 @@ where
}
/// A transfer manager that enforces a max batch size for transfers.
pub struct TransferBatcher<Source, Target, Metadata, Manager>
pub struct TransferBatcher<Source, Target, Locality, Metadata, Manager>
where
Source: Storage,
Target: Storage,
Locality: LocalityProvider,
Metadata: BlockMetadata,
Manager: TransferManager<Source, Target, Metadata>,
Manager: TransferManager<Source, Target, Locality, Metadata>,
{
transfer_manager: Manager,
max_transfer_batch_size: usize,
runtime: Handle,
cancellation_token: CancellationToken,
_phantom: PhantomData<(Source, Target, Metadata)>,
_phantom: PhantomData<(Source, Target, Locality, Metadata)>,
}
impl<Source, Target, Metadata, Manager> TransferBatcher<Source, Target, Metadata, Manager>
impl<Source, Target, Locality, Metadata, Manager>
TransferBatcher<Source, Target, Locality, Metadata, Manager>
where
Source: Storage,
Target: Storage,
Metadata: BlockMetadata,
Manager: TransferManager<Source, Target, Metadata>,
Locality: LocalityProvider + 'static,
Metadata: BlockMetadata + 'static,
Manager: TransferManager<Source, Target, Locality, Metadata> + 'static,
{
pub fn new(
transfer_manager: Manager,
......@@ -392,17 +387,19 @@ where
}
#[async_trait]
impl<Source, Target, Metadata, Manager> TransferManager<Source, Target, Metadata>
for TransferBatcher<Source, Target, Metadata, Manager>
impl<Source, Target, Locality, Metadata, Manager>
TransferManager<Source, Target, Locality, Metadata>
for TransferBatcher<Source, Target, Locality, Metadata, Manager>
where
Source: Storage,
Target: Storage,
Source: Storage + 'static,
Target: Storage + 'static,
Locality: LocalityProvider + 'static,
Metadata: BlockMetadata,
Manager: TransferManager<Source, Target, Metadata>,
Manager: TransferManager<Source, Target, Locality, Metadata>,
{
async fn enqueue_transfer(
&self,
pending_transfer: PendingTransfer<Source, Target, Metadata>,
pending_transfer: PendingTransfer<Source, Target, Locality, Metadata>,
) -> Result<()> {
// If it's smaller than the max batch size, just enqueue it.
if pending_transfer.sources.len() < self.max_transfer_batch_size {
......@@ -462,7 +459,7 @@ where
Ok(result) => result,
Err(e) => {
tracing::error!("Error receiving transfer results: {:?}", e);
completion_indicator.send(Err(e)).unwrap();
let _ = completion_indicator.send(Err(e));
return Ok(());
}
};
......@@ -472,7 +469,7 @@ where
}
// Send the final results to the top-level completion indicator.
completion_indicator.send(Ok(results))?;
let _ = completion_indicator.send(Ok(results));
Ok(())
},
......
......@@ -15,8 +15,11 @@
use std::cmp::Ordering;
use std::sync::Weak;
use tokio::sync::oneshot;
use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock};
use crate::block_manager::block::{
locality::LocalityProvider, BlockMetadata, ImmutableBlock, MutableBlock,
};
use crate::block_manager::pool::BlockPoolError;
use crate::block_manager::storage::Storage;
......@@ -46,53 +49,65 @@ impl Ord for OffloadRequestKey {
/// Data needed to offload a block.
/// While the block is in the offload queue, we hold a weak reference to it.
/// This way, we don't prevent the block from being reused if needed.
pub struct OffloadRequest<S: Storage, M: BlockMetadata> {
pub struct OffloadRequest<S: Storage, L: LocalityProvider, M: BlockMetadata> {
pub key: OffloadRequestKey,
pub block: Weak<MutableBlock<S, M>>,
pub block: Weak<MutableBlock<S, L, M>>,
pub sequence_hash: u64,
}
impl<S: Storage, M: BlockMetadata> PartialOrd for OffloadRequest<S, M> {
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> PartialOrd for OffloadRequest<S, L, M> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
/// Order offload requests by priority, high to low.
impl<S: Storage, M: BlockMetadata> Ord for OffloadRequest<S, M> {
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Ord for OffloadRequest<S, L, M> {
fn cmp(&self, other: &Self) -> Ordering {
self.key.cmp(&other.key)
}
}
/// Equality is based on sequence hash, priority, and location.
impl<S: Storage, M: BlockMetadata> PartialEq for OffloadRequest<S, M> {
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> PartialEq for OffloadRequest<S, L, M> {
fn eq(&self, other: &Self) -> bool {
self.key == other.key
}
}
impl<S: Storage, M: BlockMetadata> Eq for OffloadRequest<S, M> {}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Eq for OffloadRequest<S, L, M> {}
pub type BlockResult<Target, Metadata> =
Result<Vec<ImmutableBlock<Target, Metadata>>, BlockPoolError>;
pub type BlockResult<Target, Locality, Metadata> =
Result<Vec<ImmutableBlock<Target, Locality, Metadata>>, BlockPoolError>;
pub type ResponseSender<Target, Locality, Metadata> =
oneshot::Sender<Result<Vec<ImmutableBlock<Target, Locality, Metadata>>, BlockPoolError>>;
/// Data needed for onboarding.
/// Unlike offloading, we need a means to return the resulting blocks to the caller.
pub struct OnboardRequest<Source: Storage, Target: Storage, M: BlockMetadata> {
pub blocks: Vec<ImmutableBlock<Source, M>>,
pub response_tx:
oneshot::Sender<std::result::Result<Vec<ImmutableBlock<Target, M>>, BlockPoolError>>,
pub struct OnboardRequest<
Source: Storage,
Target: Storage,
Locality: LocalityProvider,
M: BlockMetadata,
> {
pub blocks: Vec<ImmutableBlock<Source, Locality, M>>,
pub response_tx: ResponseSender<Target, Locality, M>,
pub targets: Option<Vec<MutableBlock<Target, Locality, M>>>,
}
impl<Source: Storage, Target: Storage, M: BlockMetadata> OnboardRequest<Source, Target, M> {
impl<Source: Storage, Target: Storage, Locality: LocalityProvider, M: BlockMetadata>
OnboardRequest<Source, Target, Locality, M>
{
pub fn new(
blocks: Vec<ImmutableBlock<Source, M>>,
response_tx: oneshot::Sender<Result<Vec<ImmutableBlock<Target, M>>, BlockPoolError>>,
blocks: Vec<ImmutableBlock<Source, Locality, M>>,
response_tx: ResponseSender<Target, Locality, M>,
targets: Option<Vec<MutableBlock<Target, Locality, M>>>,
) -> Self {
Self {
blocks,
response_tx,
targets,
}
}
}
......
......@@ -13,81 +13,89 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! # KV Cache Block Pool Management
//!
//! This module provides the primary [`BlockPool`] structure for managing KV cache blocks.
//! It orchestrates the allocation, registration, and reuse of blocks by coordinating
//! between an [`ActiveBlockPool`] and an [`InactiveBlockPool`].
//!
//! ## Core Components:
//!
//! - **[`BlockPool`]**: The main entry point for interacting with the block management system.
//! It holds the shared state containing both active and inactive pools.
//! - **[`ActiveBlockPool`]**: Manages blocks that are currently associated with active sequences.
//! It primarily uses weak references to track these blocks, allowing them to be potentially
//! reclaimed by the inactive pool if no strong references remain.
//! - **[`InactiveBlockPool`]**: Manages blocks that are not currently in active use. It supports
//! block reuse by matching sequence hashes and employs a priority-based eviction strategy
//! for acquiring free blocks.
//! - **[`BlockRegistry`]**: Manages the registration of blocks that have transitioned from the
//! Complete to Registered state.
//! - **[`MutableBlock`]**: Represents a uniquely owned block, typically obtained from allocation.
//! It allows modification and is returned to the inactive pool upon being dropped.
//! - **[`ImmutableBlock`]**: Represents a shared, immutable reference to a block, usually after
//! it has been registered or matched. Ensures that multiple sequences can reference the
//! same underlying block data.
//!
//! ## Workflow:
//!
//! 1. Blocks are initially added to the [`BlockPool`] via [`BlockPool::add_blocks`], populating the
//! [`InactiveBlockPool`].
//! 2. Sequences request blocks via [`BlockPool::allocate_blocks`], which attempts to acquire them
//! from the [`InactiveBlockPool`]. This returns [`MutableBlock`]s.
//! 3. Once a [`MutableBlock`] is filled and ready, it's registered using [`BlockPool::register_block`].
//! This process checks the both the [`ActiveBlockPool`] and the [`InactiveBlockPool`] for existing blocks
//! with the same content hash. It returns an [`ImmutableBlock`] representing the canonical block
//! (either the one provided or an existing one).
//! 4. Sequences can also try to reuse blocks directly using [`BlockPool::match_sequence_hash`], which
//! checks both the active and inactive pools.
//! 5. When an [`ImmutableBlock`] is no longer needed by any sequence (its `Arc` count drops to zero),
//! the underlying [`MutableBlock`] (if it still exists via the weak reference in the active pool)
//! can eventually be returned to the [`InactiveBlockPool`] when its final strong reference (the `Arc`
//! within `ImmutableBlock`) is dropped.
//! 6. Dropped [`MutableBlock`]s are automatically returned to the [`InactiveBlockPool`].
mod active;
mod inactive;
mod priority_key;
mod state;
use active::ActiveBlockPool;
pub mod managed;
pub use managed::ManagedBlockPool;
use derive_builder::Builder;
use derive_getters::Dissolve;
use inactive::InactiveBlockPool;
use priority_key::PriorityKey;
use serde::{Deserialize, Serialize};
pub use super::block::{ImmutableBlock, MutableBlock};
use super::block::{
nixl::short_type_name, registry::BlockRegistry, Block, BlockError, BlockMetadata,
GlobalRegistry,
nixl::short_type_name, private, registry::BlockRegistry, Block, BlockError, BlockMetadata,
GlobalRegistry, MaybeReturnableBlock,
};
use super::events::{EventManager, NullEventManager};
use super::metrics::{BlockManagerMetrics, PoolMetrics};
use super::storage::Storage;
use crate::block_manager::block::locality::LocalityProvider;
use crate::block_manager::CacheLevel;
use crate::tokens::{SequenceHash, TokenBlock};
use async_trait::async_trait;
use prometheus::Registry;
use std::sync::atomic::{AtomicU64, Ordering};
use std::{
collections::{BTreeSet, HashMap, VecDeque},
sync::{Arc, Weak},
};
use tokio::runtime::Handle;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use dynamo_runtime::Result;
// Type aliases to reduce complexity across the module
type BlockPoolResult<T> = Result<T, BlockPoolError>;
type AsyncResponse<T> = Result<oneshot::Receiver<T>, BlockPoolError>;
// Collection type aliases
pub type MutableBlocks<S, L, M> = Vec<MutableBlock<S, L, M>>;
pub type ImmutableBlocks<S, L, M> = Vec<ImmutableBlock<S, L, M>>;
/// Enum representing either a mutable or immutable block that can be returned to the pool
#[derive(Debug)]
pub enum OwnedBlock<S: Storage, L: LocalityProvider, M: BlockMetadata> {
Mutable(MutableBlock<S, L, M>),
Immutable(ImmutableBlock<S, L, M>),
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> MaybeReturnableBlock<S, L, M>
for OwnedBlock<S, L, M>
{
fn is_returnable(&self) -> bool {
match self {
OwnedBlock::Mutable(block) => block.is_returnable(),
OwnedBlock::Immutable(block) => block.is_returnable(),
}
}
fn try_take_block(self, token: private::PrivateToken) -> Option<Vec<Block<S, L, M>>> {
match self {
OwnedBlock::Mutable(block) => block.try_take_block(token),
OwnedBlock::Immutable(block) => block.try_take_block(token),
}
}
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> From<MutableBlock<S, L, M>>
for OwnedBlock<S, L, M>
{
fn from(block: MutableBlock<S, L, M>) -> Self {
OwnedBlock::Mutable(block)
}
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> From<ImmutableBlock<S, L, M>>
for OwnedBlock<S, L, M>
{
fn from(block: ImmutableBlock<S, L, M>) -> Self {
OwnedBlock::Immutable(block)
}
}
#[derive(Debug, thiserror::Error)]
pub enum BlockPoolError {
#[error("Block is not complete")]
......@@ -107,74 +115,47 @@ pub enum BlockPoolError {
#[error(transparent)]
BlockError(#[from] BlockError),
}
#[derive(Builder, Dissolve)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
pub struct BlockPoolArgs<S: Storage, M: BlockMetadata> {
#[builder(default = "NullEventManager::new()")]
event_manager: Arc<dyn EventManager>,
#[builder(default = "CancellationToken::new()")]
cancel_token: CancellationToken,
#[error("Reset error: {0}")]
ResetError(String),
#[builder(default)]
blocks: Vec<Block<S, M>>,
#[error("Block is not returnable")]
NotReturnable,
#[builder(default)]
global_registry: GlobalRegistry,
#[error("Unsupported cache level: {0:?}")]
UnsupportedCacheLevel(CacheLevel),
#[builder(default = "Handle::current()")]
async_runtime: Handle,
#[builder(
default = "BlockManagerMetrics::new(&Arc::new(Registry::new())).unwrap().pool(\"pool\")"
)]
pool_metrics: Arc<PoolMetrics>,
#[error("No blocks to register")]
NoBlocksToRegister,
}
impl<S: Storage, M: BlockMetadata> BlockPoolArgsBuilder<S, M> {
pub fn build(self) -> anyhow::Result<BlockPool<S, M>> {
let args = self.build_internal()?;
let (event_manager, cancel_token, blocks, global_registry, async_runtime, metrics) =
args.dissolve();
tracing::info!("building block pool");
let pool = BlockPool::new(
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
metrics,
);
Ok(pool)
}
}
/// Manages the blocks in a specific storage backenda
pub struct BlockPool<S: Storage, M: BlockMetadata> {
priority_tx: tokio::sync::mpsc::UnboundedSender<PriorityRequest<S, M>>,
ctrl_tx: tokio::sync::mpsc::UnboundedSender<ControlRequest<S, M>>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockRegistrationDuplicationSetting {
/// On registration, if duplication is allowed, blocks with duplicate hashes cannot be registered directly,
/// but instead can be held live with a strong arc to the primary block. This maintains the lifetime of
/// the duplicate block.
Allowed,
impl<S: Storage, M: BlockMetadata> Clone for BlockPool<S, M> {
fn clone(&self) -> Self {
Self {
priority_tx: self.priority_tx.clone(),
ctrl_tx: self.ctrl_tx.clone(),
}
}
/// On registration, if duplication is disabled, blocks with duplicate hashes will be returned immediately
/// to the inactive pool and the primary block, the one first registered, will be returned to the caller,
/// replacing the duplicate block.
///
/// Note: If block duplication is disabled, then the implementation must always respect the fact that the
/// mutable block that was registered, may not be the same block returned by the registration function, and
/// thus be able to update any references that wish to use the block after registration.
Disabled,
}
/// Generic request-response pattern for background task communication
#[derive(Dissolve)]
struct Unary<Req, Resp> {
request: Req,
response_tx: oneshot::Sender<Resp>,
pub struct RequestResponse<Req, Resp> {
pub request: Req,
pub response_tx: oneshot::Sender<Resp>,
}
impl<Req, Resp> Unary<Req, Resp> {
fn make_request(request: Req) -> (Self, oneshot::Receiver<Resp>) {
impl<Req, Resp> RequestResponse<Req, Resp> {
/// Create a new request-response pair
pub fn new(request: Req) -> (Self, oneshot::Receiver<Resp>) {
let (response_tx, response_rx) = oneshot::channel();
(
Self {
......@@ -186,119 +167,11 @@ impl<Req, Resp> Unary<Req, Resp> {
}
}
type UnaryResponse<T> = Result<oneshot::Receiver<T>, BlockPoolError>;
type ImmutableBlocksResult<S, M> = Result<Vec<ImmutableBlock<S, M>>, BlockPoolError>;
pub type MutableBlocks<S, M> = Vec<MutableBlock<S, M>>;
pub type ImmutableBlocks<S, M> = Vec<ImmutableBlock<S, M>>;
enum PriorityRequest<S: Storage, M: BlockMetadata> {
AllocateBlocks(Unary<usize, Result<Vec<MutableBlock<S, M>>, BlockPoolError>>),
RegisterBlocks(Unary<MutableBlocks<S, M>, Result<ImmutableBlocks<S, M>, BlockPoolError>>),
MatchSequenceHashes(Unary<Vec<SequenceHash>, Vec<ImmutableBlock<S, M>>>),
}
enum ControlRequest<S: Storage, M: BlockMetadata> {
AddBlocks(Unary<Vec<Block<S, M>>, ()>),
}
impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
pub fn builder() -> BlockPoolArgsBuilder<S, M> {
BlockPoolArgsBuilder::default()
}
/// Creates a new [`BlockPool`] with the given [`EventManager`].
///
/// The pool starts empty and requires blocks to be added via [`add_blocks`].
///
/// # Arguments
///
/// * `event_manager` - An [`Arc<dyn EventManager>`] used for publishing block registration/removal events.
///
/// # Returns
///
/// A new [`BlockPool`] instance.
fn new(
event_manager: Arc<dyn EventManager>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
metrics: Arc<PoolMetrics>,
) -> Self {
let (pool, progress_engine) = Self::with_progress_engine(
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
metrics,
);
// pool.runtime.handle().spawn(async move {
// let mut progress_engine = progress_engine;
// tracing::debug!("starting progress engine");
// while progress_engine.step().await {
// tracing::trace!("progress engine step");
// }
// });
let thread_name = format!("block-pool-{}", short_type_name::<S>());
std::thread::Builder::new()
.name(thread_name)
.spawn(move || {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to build Tokio runtime for block pool progress engine");
runtime.block_on(async move {
let mut progress_engine = progress_engine;
tracing::debug!("starting progress engine");
while progress_engine.step().await {
tracing::trace!("progress engine step");
}
});
})
.expect("Failed to spawn block pool progress engine thread");
pool
}
fn with_progress_engine(
event_manager: Arc<dyn EventManager>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
metrics: Arc<PoolMetrics>,
) -> (Self, ProgressEngine<S, M>) {
let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel();
let (ctrl_tx, ctrl_rx) = tokio::sync::mpsc::unbounded_channel();
let progress_engine = ProgressEngine::<S, M>::new(
event_manager,
priority_rx,
ctrl_rx,
cancel_token,
blocks,
global_registry,
async_runtime,
metrics,
);
(
Self {
priority_tx,
ctrl_tx,
},
progress_engine,
)
}
/// Adds a vector of [`Block`]s to the [`InactiveBlockPool`].
#[async_trait]
pub trait BlockPool<S: Storage, L: LocalityProvider, M: BlockMetadata>:
BlockPoolController + AsyncBlockPoolController + Send + Sync
{
/// Add a vector of [`Block`]s to the pool.
///
/// These blocks are typically created from a [`super::block::Blocks`]
/// and represent the initial set of available cache blocks.
......@@ -307,38 +180,12 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
/// # Arguments
///
/// * `blocks` - A [`Vec<Block<S, M>>`] to add to the inactive pool.
#[expect(dead_code)]
pub(crate) async fn add_blocks(&self, blocks: Vec<Block<S, M>>) -> Result<(), BlockPoolError> {
self._add_blocks(blocks)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
async fn add_blocks(&self, blocks: Vec<Block<S, L, M>>) -> BlockPoolResult<()>;
/// Blocking version of [`BlockPool::add_blocks`].
pub(crate) fn add_blocks_blocking(
&self,
blocks: Vec<Block<S, M>>,
) -> Result<(), BlockPoolError> {
self._add_blocks(blocks)?
.recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
fn _add_blocks(&self, blocks: Vec<Block<S, M>>) -> UnaryResponse<()> {
let (req, resp_rx) = Unary::<_, ()>::make_request(blocks);
self.ctrl_tx
.send(ControlRequest::AddBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn add_blocks_blocking(&self, blocks: Vec<Block<S, L, M>>) -> BlockPoolResult<()>;
/// Attempts to allocate a specified number of free blocks from the [`InactiveBlockPool`].
///
/// Blocks acquired this way are returned as [`MutableBlock`]s, granting unique ownership
/// and allowing modification. Dropping a [`MutableBlock`] automatically returns it
/// to the [`InactiveBlockPool`].
/// Allocate a specified number of free blocks from the pool.
///
/// # Arguments
///
......@@ -349,633 +196,122 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
/// A [`Result`] containing:
/// - `Ok(Vec<MutableBlock<S, M>>)`: If successful, a vector of allocated mutable blocks.
/// - `Err(BlockPoolError)`: If not enough blocks are available in the inactive pool.
pub async fn allocate_blocks(
&self,
count: usize,
) -> Result<Vec<MutableBlock<S, M>>, BlockPoolError> {
self._allocate_blocks(count)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
async fn allocate_blocks(&self, count: usize) -> BlockPoolResult<MutableBlocks<S, L, M>>;
/// Blocking version of [`BlockPool::allocate_blocks`].
pub fn allocate_blocks_blocking(
&self,
count: usize,
) -> Result<Vec<MutableBlock<S, M>>, BlockPoolError> {
self._allocate_blocks(count)?
.recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
fn allocate_blocks_blocking(&self, count: usize) -> BlockPoolResult<MutableBlocks<S, L, M>>;
fn _allocate_blocks(
/// Register a vector of [`MutableBlock`]s with the pool.
async fn register_blocks(
&self,
count: usize,
) -> UnaryResponse<Result<Vec<MutableBlock<S, M>>, BlockPoolError>> {
// Create the request
let (req, resp_rx) =
Unary::<_, Result<Vec<MutableBlock<S, M>>, BlockPoolError>>::make_request(count);
// Issue the request
self.priority_tx
.send(PriorityRequest::AllocateBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
// Await a response
Ok(resp_rx)
}
/// Registers a vector of [`MutableBlock`]s (presumably after filling them) with the pool,
/// making them available for sharing via the [`ActiveBlockPool`].
///
/// This function checks if any of the blocks have the same sequence hash as an existing block
/// in the active pool. If so, it returns an [`ImmutableBlock`] pointing to the existing block,
/// and the provided `block` is implicitly dropped (returned to the [`InactiveBlockPool`]).
pub async fn register_blocks(
&self,
blocks: Vec<MutableBlock<S, M>>,
) -> ImmutableBlocksResult<S, M> {
self._register_blocks(blocks)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
blocks: Vec<MutableBlock<S, L, M>>,
) -> BlockPoolResult<ImmutableBlocks<S, L, M>>;
/// Blocking version of [`BlockPool::register_blocks`].
pub fn register_blocks_blocking(
fn register_blocks_blocking(
&self,
blocks: Vec<MutableBlock<S, M>>,
) -> ImmutableBlocksResult<S, M> {
self._register_blocks(blocks)?
.recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
blocks: Vec<MutableBlock<S, L, M>>,
) -> BlockPoolResult<ImmutableBlocks<S, L, M>>;
fn _register_blocks(
&self,
blocks: Vec<MutableBlock<S, M>>,
) -> UnaryResponse<ImmutableBlocksResult<S, M>> {
// Make the request
let (req, resp_rx) = Unary::<_, ImmutableBlocksResult<S, M>>::make_request(blocks);
// Issue the request
self.priority_tx
.send(PriorityRequest::RegisterBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
// Await a response
Ok(resp_rx)
}
/// Attempts to match the given [`SequenceHash`] to an existing block, checking
/// both the active and inactive pools.
///
/// Checks the [`ActiveBlockPool`] first. If a valid strong reference exists, it returns
/// an [`ImmutableBlock`] cloned from it. If the weak reference exists but is stale,
/// it's removed.
///
/// If not found in the active pool, it checks the [`InactiveBlockPool`]. If found there,
/// the block is moved to the active pool (tracked by a weak reference) and returned
/// as a new [`ImmutableBlock`].
/// Match a set of [`SequenceHash`]s to existing blocks in the pool.
///
/// # Arguments
///
/// * `sequence_hash` - The [`SequenceHash`] to look for.
/// * `sequence_hashes` - A [`Vec<SequenceHash>`] to match.
///
/// # Returns
///
/// An [`Option<ImmutableBlock<S, M>>`] containing the shared block if found, otherwise `None`.
pub async fn match_sequence_hashes(
async fn match_sequence_hashes(
&self,
sequence_hashes: &[SequenceHash],
) -> ImmutableBlocksResult<S, M> {
self._match_sequence_hashes(sequence_hashes)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
) -> BlockPoolResult<ImmutableBlocks<S, L, M>>;
/// Blocking version of [`BlockPool::match_sequence_hashes`].
pub fn match_sequence_hashes_blocking(
fn match_sequence_hashes_blocking(
&self,
sequence_hashes: &[SequenceHash],
) -> ImmutableBlocksResult<S, M> {
self._match_sequence_hashes(sequence_hashes)?
.recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
) -> BlockPoolResult<ImmutableBlocks<S, L, M>>;
fn _match_sequence_hashes(
&self,
sequence_hashes: &[SequenceHash],
) -> UnaryResponse<Vec<ImmutableBlock<S, M>>> {
// Create the request
let (req, resp_rx) =
Unary::<_, Vec<ImmutableBlock<S, M>>>::make_request(sequence_hashes.into());
// Issue the request
self.priority_tx
.send(PriorityRequest::MatchSequenceHashes(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
// Await a response
Ok(resp_rx)
}
}
/// Touch a set of blocks. Equivalent to registering and then immediately dropping.
async fn touch_blocks(&self, sequence_hashes: &[SequenceHash]) -> BlockPoolResult<()>;
struct State<S: Storage, M: BlockMetadata> {
active: ActiveBlockPool<S, M>,
inactive: InactiveBlockPool<S, M>,
registry: BlockRegistry,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, M>>,
event_manager: Arc<dyn EventManager>,
metrics: Arc<PoolMetrics>,
}
/// Blocking version of [`BlockPool::touch_blocks`].
fn touch_blocks_blocking(&self, sequence_hashes: &[SequenceHash]) -> BlockPoolResult<()>;
struct ProgressEngine<S: Storage, M: BlockMetadata> {
priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, M>>,
ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, M>>,
cancel_token: CancellationToken,
state: State<S, M>,
return_rx: tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>,
metrics: Arc<PoolMetrics>,
}
/// Attempt to return a block to the pool. Blocks will naturally be returned to the pool when they are dropped
/// and their reference count drops to 0; however, for testing and to synchronize the block returning to the
/// pool, this function can be used.
async fn try_return_block(&self, block: OwnedBlock<S, L, M>) -> BlockPoolResult<()>;
#[cfg(test)]
mod tests {
use super::super::block::{BasicMetadata, Blocks};
use super::super::layout::{tests::setup_layout, FullyContiguous, LayoutConfig};
use super::*;
use crate::block_manager::block::BlockExt;
use crate::block_manager::DType;
use crate::tokens::{TokenBlockSequence, Tokens};
use crate::block_manager::storage::tests::{NullDeviceAllocator, NullDeviceStorage};
/// Helper method to build a [`BlockPool`] with a [`ProgressEngine`] for unit testing
impl<S: Storage, M: BlockMetadata> BlockPoolArgsBuilder<S, M> {
fn build_with_progress_engine(
self,
) -> anyhow::Result<(BlockPool<S, M>, ProgressEngine<S, M>)> {
let args = self.build_internal()?;
let (event_manager, cancel_token, blocks, global_registry, async_runtime, metrics) =
args.dissolve();
let (pool, progress_engine) = BlockPool::with_progress_engine(
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
metrics,
);
Ok((pool, progress_engine))
}
}
/// Blocking version of [`BlockPool::try_return_block`].
fn try_return_block_blocking(&self, block: OwnedBlock<S, L, M>) -> BlockPoolResult<()>;
#[tokio::test]
async fn test_block_pool_state() {
let layout = setup_layout(None).unwrap();
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
fn total_blocks(&self) -> u64;
let (_pool, mut progress) = BlockPool::builder()
.blocks(blocks)
.build_with_progress_engine()
.unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 7);
let blocks = progress.state.allocate_blocks(1).unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 6);
assert_eq!(blocks.len(), 1);
drop(blocks);
progress.step().await;
assert_eq!(progress.state.inactive.available_blocks(), 7);
let mut blocks = progress.state.allocate_blocks(1).unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 6);
assert_eq!(blocks.len(), 1);
let mut block = blocks.pop().unwrap();
block.init_sequence(1337).unwrap();
block.add_token(1).unwrap();
block.add_token(2).unwrap();
block.add_token(3).unwrap();
block.add_token(4).unwrap();
assert!(block.add_token(5).is_err());
}
#[tokio::test]
async fn test_block_pool() {
let layout = setup_layout(None).unwrap();
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
let (pool, mut progress) = BlockPool::builder()
.blocks(blocks)
.build_with_progress_engine()
.unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 7);
let pool_clone = pool.clone();
let allocate_1_block =
tokio::spawn(async move { pool_clone.allocate_blocks(1).await.unwrap() });
progress.step().await;
let blocks = allocate_1_block.await.unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 6);
assert_eq!(blocks.len(), 1);
// drop the single block
drop(blocks);
// check before and after the progress engine step
assert_eq!(progress.state.inactive.available_blocks(), 6);
progress.step().await;
assert_eq!(progress.state.inactive.available_blocks(), 7);
}
#[test]
fn test_block_pool_blocking() {
const EXPECTED_SEQUENCE_HASH: u64 = 14643705804678351452;
// Create a new layout
let layout = setup_layout(None).unwrap();
// Create the Blocks
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
let async_runtime = tokio::runtime::Runtime::new().unwrap();
// Create the BlockPool and add the blocks
let pool = BlockPool::builder()
.blocks(blocks)
.async_runtime(async_runtime.handle().clone())
.build()
.unwrap();
// All blocks should be in the Reset/Empty state
// No blocks should match the expected sequence hash
let matched_blocks = pool
.match_sequence_hashes_blocking(&[EXPECTED_SEQUENCE_HASH])
.unwrap();
assert_eq!(matched_blocks.len(), 0);
// Allocate a single block from the pool
let mut mutable_blocks = pool.allocate_blocks_blocking(1).unwrap();
assert_eq!(mutable_blocks.len(), 1);
let mut block = mutable_blocks.pop().unwrap();
// Initialize the sequence on the block with a salt hash
block.init_sequence(1337).unwrap();
// Add some tokens to the block - our page_size is 4
block.add_token(1).unwrap();
block.add_token(2).unwrap();
block.add_token(3).unwrap();
block.add_token(4).unwrap();
// Should fail because we don't have space in the block
assert!(block.add_token(5).is_err());
// Commit the block - this will generate a sequence hash
// This will put the block in a Complete state
block.commit().unwrap();
assert!(block.state().is_complete()); // perhaps renamed to Commited
let sequence_hash = block.sequence_hash().unwrap();
assert_eq!(sequence_hash, EXPECTED_SEQUENCE_HASH);
// Register the block
// We provide a mutable block to the register_blocks function
// This will take ownership of the block and return an immutable block
let mut immutable_blocks = pool.register_blocks_blocking(vec![block]).unwrap();
let block = immutable_blocks.pop().unwrap();
assert!(block.state().is_registered());
assert_eq!(block.sequence_hash().unwrap(), sequence_hash);
// Dropping the immutable block should return the block to the pool
// However, the block should remain in the BlockPool as an inactive block until it is reused
// or promoted back to an immutable block by being matched with a sequence hash
drop(block);
// Get the list of ImmutableBlocks that match the sequence hash
let matched = pool
.match_sequence_hashes_blocking(&[sequence_hash])
.unwrap();
assert_eq!(matched.len(), 1);
assert_eq!(matched[0].sequence_hash().unwrap(), sequence_hash);
}
async fn create_blocks<S: Storage, M: BlockMetadata>(
pool: &BlockPool<S, M>,
num_blocks: usize,
) -> anyhow::Result<(Vec<ImmutableBlock<S, M>>, Vec<SequenceHash>)> {
let tokens = vec![0; num_blocks * 4];
let token_blocks = TokenBlockSequence::new(Tokens::from(tokens), 4, None);
assert_eq!(token_blocks.blocks().len(), num_blocks);
let mut sequence_hashes = Vec::new();
let mut mutable_blocks = Vec::new();
for token_block in token_blocks.blocks().iter() {
let mut block = pool.allocate_blocks(1).await?.pop().unwrap();
block.apply_token_block(token_block.clone())?;
sequence_hashes.push(block.sequence_hash().unwrap());
mutable_blocks.push(block);
}
let immutable_blocks = pool.register_blocks(mutable_blocks).await?;
Ok((immutable_blocks, sequence_hashes))
}
async fn make_simple_pool(
num_blocks: usize,
) -> anyhow::Result<BlockPool<NullDeviceStorage, BasicMetadata>> {
let config = LayoutConfig {
num_blocks,
num_layers: 1,
outer_dim: 1,
page_size: 4,
inner_dim: 1024,
alignment: 1,
dtype: DType::FP16,
};
let layout = FullyContiguous::<NullDeviceStorage>::allocate(config, &NullDeviceAllocator)?;
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)?.into_blocks()?;
let pool = BlockPool::builder().blocks(blocks).build()?;
Ok(pool)
}
/// A test that ensures that we only ever evict leaves from the inactive pool.
#[tokio::test]
async fn test_block_pool_evict_leaves() -> anyhow::Result<()> {
let pool = make_simple_pool(4).await?;
let (_, sequence_hashes) = create_blocks(&pool, 4).await?;
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Allocate 1 block. This should evict the leaf of our allocated sequence.
pool.allocate_blocks(1).await?;
// The leaf should be evicted, so we should have 3 matches.
let matched = pool
.match_sequence_hashes(sequence_hashes.as_slice())
.await?;
assert_eq!(matched.len(), 3);
drop(matched);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Allocate 2 blocks. This should get the previously allocated block, as well as one more leaf.
pool.allocate_blocks(2).await.unwrap();
// The next leaf should be evicted, so we should have 2 matches.
let matched = pool
.match_sequence_hashes(sequence_hashes.as_slice())
.await?;
assert_eq!(matched.len(), 2);
drop(matched);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// If we allocate all the blocks, the entire remaining sequence should be evicted.
let blocks = pool.allocate_blocks(4).await?;
assert_eq!(blocks.len(), 4);
Ok(())
}
/// When a block has two children, we need to ensure that we evict both children before
/// adding the parent to the leaf set.
#[tokio::test]
async fn test_block_pool_parent_child() -> anyhow::Result<()> {
let pool = make_simple_pool(3).await?;
let tokens = vec![1, 2, 3, 4, 5];
let sequence = TokenBlockSequence::new(Tokens::from(tokens.clone()), 4, None);
// Create a root block, with two child blocks.
let mut root_block = pool.allocate_blocks(1).await?.pop().unwrap();
root_block.apply_token_block(sequence.blocks().first().unwrap().clone())?;
let root_block_hash = root_block.sequence_hash().unwrap();
let mut child_blocks = Vec::new();
let mut child_block_hashes = Vec::new();
for i in 0..2 {
// Create a new token sequence using the common prefix.
let mut tokens = tokens.clone();
for _ in 0..4 {
tokens.push(i);
}
let seq = TokenBlockSequence::new(Tokens::from(tokens), 4, None);
// Allocate and apply the suffix to the child block.
let mut child_block = pool.allocate_blocks(1).await?.pop().unwrap();
child_block.apply_token_block(seq.blocks()[1].clone())?;
child_block_hashes.push(child_block.sequence_hash().unwrap());
child_blocks.push(child_block);
}
// Register the children first. This can happen with offloading.
let child_blocks = pool.register_blocks(child_blocks).await?;
// After the children are registered, we can register the root block.
let root_block = pool.register_blocks(vec![root_block]).await?;
fn available_blocks(&self) -> u64;
}
// Drop both of them.
drop(root_block);
drop(child_blocks);
/// State of the pool when queried.
///
/// Provides a snapshot of the pool's current state including:
/// - Active blocks currently in use
/// - Inactive blocks ordered by reuse priority
/// - Number of empty blocks
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct BlockPoolStatus {
/// Active blocks currently in use
pub active_blocks: usize,
/// Inactive blocks ordered by reuse priority
/// Blocks at the front of the list are more likely to be reused
pub inactive_blocks: usize,
/// Number of empty blocks
pub empty_blocks: usize,
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct ResetBlocksResponse {
/// Blocks that were reset
pub reset_blocks: Vec<SequenceHash>,
// Allocate two new blocks, which should evict both children.
pool.allocate_blocks(2).await?;
/// Blocks that were not found in the pool
pub not_found: Vec<SequenceHash>,
// Now, the root block should be the only block left.
for child_block_hash in child_block_hashes {
let matched = pool.match_sequence_hashes(&[child_block_hash]).await?;
assert_eq!(matched.len(), 0);
}
// Check that the root block remains.
let matched = pool.match_sequence_hashes(&[root_block_hash]).await?;
assert_eq!(matched.len(), 1);
/// Blocks that were not reset
pub not_reset: Vec<SequenceHash>,
}
Ok(())
}
pub trait BlockPoolController: Send + Sync {
/// Returns the [`BlockPoolStatus`] of the pool.
fn status_blocking(&self) -> Result<BlockPoolStatus, BlockPoolError>;
/// When offloading, it's possible that the tail of a sequence in a pool is evicted before
/// the entire sequence can be offloaded. This can happen in the following case:
/// Resets the pool to its initial state.
///
/// Assume a sequence of 4 blocks: [0, 1, 2, 3]
/// 1. Blocks 0, 1, and 2 are offloaded to host memory.
/// 2. Block 2 is evicted from the host.
/// 3. Block 3 is offloaded to host memory.
/// Now, the contents of the cache are [0, 1] and [3].
/// We need to treat these as two separate sequences.
#[tokio::test]
async fn test_block_pool_fragmentation() -> anyhow::Result<()> {
let pool = make_simple_pool(4).await?;
let tokens = vec![0; 16];
let token_blocks = TokenBlockSequence::new(Tokens::from(tokens), 4, None);
assert_eq!(token_blocks.blocks().len(), 4);
let mut sequence_hashes = Vec::new();
// Allocate and register the first 3 blocks.
for block in token_blocks.blocks()[..3].iter() {
let mut mutable_block = pool.allocate_blocks(1).await?.pop().unwrap();
mutable_block.apply_token_block(block.clone())?;
sequence_hashes.push(mutable_block.sequence_hash()?);
let _ = pool.register_blocks(vec![mutable_block]).await?;
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Allocate 2 blocks. This should take the remaining uninitialized block as well as the
// tail of the currently registered sequence.
let _ = pool.allocate_blocks(2).await?;
assert_eq!(
pool.match_sequence_hashes(sequence_hashes.as_slice())
.await?
.len(),
2
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Allocate 1 more block for the leaf of the sequence.
let mut mutable_block = pool.allocate_blocks(1).await?.into_iter().next().unwrap();
mutable_block.apply_token_block(token_blocks.blocks()[3].clone())?;
/// This function will error unless all blocks have returned to the inactive pool.
fn reset_blocking(&self) -> Result<(), BlockPoolError>;
let _ = pool.register_blocks(vec![mutable_block]).await?;
// We should still only match the first 2 blocks, since the 3rd block has been evicted.
assert_eq!(
pool.match_sequence_hashes(sequence_hashes.as_slice())
.await?
.len(),
2
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Now, we should be able to allocate all 4 blocks.
let _ = pool.allocate_blocks(4).await?;
Ok(())
}
/// Matching an entire sequence (moving it to the active pool), and returning it
/// should not affect the parent-child relationships of the blocks.
#[tokio::test]
async fn test_block_pool_match_return() -> anyhow::Result<()> {
let pool = make_simple_pool(4).await?;
let (_, sequence_hashes) = create_blocks(&pool, 4).await?;
// We match the root of the sequence (moving it to the active pool), then
// immediately return it.
assert_eq!(
pool.match_sequence_hashes(vec![sequence_hashes[0]].as_slice())
.await?
.len(),
1
);
let _alloc_blocks1 = pool.allocate_blocks(3).await?;
// Allocating 3 blocks should evict all but the root of the sequence.
assert_eq!(
pool.match_sequence_hashes(sequence_hashes.as_slice())
.await?
.len(),
1
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let _alloc_blocks2 = pool.allocate_blocks(1).await?;
// Now, allocating one more block should evict the root.
assert_eq!(
pool.match_sequence_hashes(sequence_hashes.as_slice())
.await?
.len(),
0
);
Ok(())
}
/// When we move a suffix of a sequence to the active pool (like what happens when onboarding),
/// then return it to the inactive pool, we need to ensure that the parent-child relationships
/// are still correct, and that the temporary leaf in the inactive pool can't be evicted.
#[tokio::test]
async fn test_block_pool_match_partial() -> anyhow::Result<()> {
let pool = make_simple_pool(4).await?;
let (_, sequence_hashes) = create_blocks(&pool, 4).await?;
// Assert that all 4 blocks are in the pool.
assert_eq!(
pool.match_sequence_hashes(sequence_hashes.as_slice())
.await?
.len(),
4
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Now, we match only the last 2 blocks
let matched_suffix = pool.match_sequence_hashes(&sequence_hashes[2..]).await?;
assert_eq!(matched_suffix.len(), 2);
// This allocation should fail. Although there are 2 inactive blocks, the leaf is in the active pool.
let new_alloc_block = pool.allocate_blocks(1).await?;
assert_eq!(new_alloc_block.len(), 0);
// Now, drop the leaf, and return it to the inactive pool.
drop(matched_suffix);
/// Attempt to reset a set of blocks.
fn reset_blocks_blocking(
&self,
sequence_hashes: &[SequenceHash],
) -> Result<ResetBlocksResponse, BlockPoolError>;
}
// All 4 blocks should still be in the pool.
assert_eq!(
pool.match_sequence_hashes(sequence_hashes.as_slice())
.await?
.len(),
4
);
#[async_trait::async_trait]
pub trait AsyncBlockPoolController: Send + Sync {
/// Returns the [`BlockPoolStatus`] of the pool.
async fn status(&self) -> Result<BlockPoolStatus, BlockPoolError>;
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
/// Resets the pool to its initial state.
///
/// This function will error unless all blocks have returned to the inactive pool.
async fn reset(&self) -> Result<(), BlockPoolError>;
Ok(())
}
/// Attempt to reset a set of blocks.
async fn reset_blocks(
&self,
sequence_hashes: &[SequenceHash],
) -> Result<ResetBlocksResponse, BlockPoolError>;
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! # KV Cache Block Pool Management
//!
//! This module provides the primary [`BlockPool`] structure for managing KV cache blocks.
//! It orchestrates the allocation, registration, and reuse of blocks by coordinating
//! between an [`ActiveBlockPool`] and an [`InactiveBlockPool`].
//!
//! ## Core Components:
//!
//! - **[`BlockPool`]**: The main entry point for interacting with the block management system.
//! It holds the shared state containing both active and inactive pools.
//! - **[`ActiveBlockPool`]**: Manages blocks that are currently associated with active sequences.
//! It primarily uses weak references to track these blocks, allowing them to be potentially
//! reclaimed by the inactive pool if no strong references remain.
//! - **[`InactiveBlockPool`]**: Manages blocks that are not currently in active use. It supports
//! block reuse by matching sequence hashes and employs a priority-based eviction strategy
//! for acquiring free blocks.
//! - **[`BlockRegistry`]**: Manages the registration of blocks that have transitioned from the
//! Complete to Registered state.
//! - **[`MutableBlock`]**: Represents a uniquely owned block, typically obtained from allocation.
//! It allows modification and is returned to the inactive pool upon being dropped.
//! - **[`ImmutableBlock`]**: Represents a shared, immutable reference to a block, usually after
//! it has been registered or matched. Ensures that multiple sequences can reference the
//! same underlying block data.
//!
//! ## Workflow:
//!
//! 1. Blocks are initially added to the [`BlockPool`] via [`BlockPool::add_blocks`], populating the
//! [`InactiveBlockPool`].
//! 2. Sequences request blocks via [`BlockPool::allocate_blocks`], which attempts to acquire them
//! from the [`InactiveBlockPool`]. This returns [`MutableBlock`]s.
//! 3. Once a [`MutableBlock`] is filled and ready, it's registered using [`BlockPool::register_block`].
//! This process checks the both the [`ActiveBlockPool`] and the [`InactiveBlockPool`] for existing blocks
//! with the same content hash. It returns an [`ImmutableBlock`] representing the canonical block
//! (either the one provided or an existing one).
//! 4. Sequences can also try to reuse blocks directly using [`BlockPool::match_sequence_hash`], which
//! checks both the active and inactive pools.
//! 5. When an [`ImmutableBlock`] is no longer needed by any sequence (its `Arc` count drops to zero),
//! the underlying [`MutableBlock`] (if it still exists via the weak reference in the active pool)
//! can eventually be returned to the [`InactiveBlockPool`] when its final strong reference (the `Arc`
//! within `ImmutableBlock`) is dropped.
//! 6. Dropped [`MutableBlock`]s are automatically returned to the [`InactiveBlockPool`].
use super::*;
pub mod active;
pub mod controller;
pub mod inactive;
pub mod priority_key;
pub mod state;
use active::ActiveBlockPool;
use inactive::InactiveBlockPool;
#[derive(Builder, Dissolve)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
pub struct ManagedBlockPoolArgs<S: Storage, L: LocalityProvider, M: BlockMetadata> {
#[builder(default = "NullEventManager::new()")]
event_manager: Arc<dyn EventManager>,
#[builder(default = "CancellationToken::new()")]
cancel_token: CancellationToken,
#[builder(default)]
blocks: Vec<Block<S, L, M>>,
#[builder(default)]
global_registry: GlobalRegistry,
#[builder(default = "Handle::current()")]
async_runtime: Handle,
#[builder(
default = "BlockManagerMetrics::new(&Arc::new(Registry::new())).unwrap().pool(\"pool\")"
)]
pool_metrics: Arc<PoolMetrics>,
#[builder(default = "BlockRegistrationDuplicationSetting::Disabled")]
default_duplication_setting: BlockRegistrationDuplicationSetting,
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPoolArgsBuilder<S, L, M> {
pub fn build(self) -> anyhow::Result<ManagedBlockPool<S, L, M>> {
let args = self.build_internal()?;
let (
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
metrics,
default_duplication_setting,
) = args.dissolve();
tracing::info!("building block pool");
let pool = ManagedBlockPool::new(
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
metrics,
default_duplication_setting,
);
Ok(pool)
}
}
// Specific request type aliases for our use cases
type AllocateBlocksReq<S, L, M> = RequestResponse<usize, BlockPoolResult<MutableBlocks<S, L, M>>>;
type RegisterBlocksReq<S, L, M> = RequestResponse<
(MutableBlocks<S, L, M>, BlockRegistrationDuplicationSetting),
BlockPoolResult<ImmutableBlocks<S, L, M>>,
>;
type MatchHashesReq<S, L, M> =
RequestResponse<Vec<SequenceHash>, BlockPoolResult<ImmutableBlocks<S, L, M>>>;
type TouchBlocksReq = RequestResponse<Vec<SequenceHash>, BlockPoolResult<()>>;
type AddBlocksReq<S, L, M> = RequestResponse<Vec<Block<S, L, M>>, ()>;
type ResetReq = RequestResponse<(), BlockPoolResult<()>>;
type ReturnBlockReq<S, L, M> = RequestResponse<Vec<Block<S, L, M>>, BlockPoolResult<()>>;
type StatusReq = RequestResponse<(), BlockPoolResult<BlockPoolStatus>>;
type ResetBlocksReq = RequestResponse<Vec<SequenceHash>, BlockPoolResult<ResetBlocksResponse>>;
// Update the request enums to use the cleaner types
pub enum PriorityRequest<S: Storage, L: LocalityProvider, M: BlockMetadata> {
AllocateBlocks(AllocateBlocksReq<S, L, M>),
RegisterBlocks(RegisterBlocksReq<S, L, M>),
MatchSequenceHashes(MatchHashesReq<S, L, M>),
TouchBlocks(TouchBlocksReq),
Reset(ResetReq),
ReturnBlock(ReturnBlockReq<S, L, M>),
}
pub enum ControlRequest<S: Storage, L: LocalityProvider, M: BlockMetadata> {
AddBlocks(AddBlocksReq<S, L, M>),
Status(StatusReq),
ResetBlocks(ResetBlocksReq),
}
/// Manages the blocks in a specific storage backenda
pub struct ManagedBlockPool<S: Storage, L: LocalityProvider, M: BlockMetadata> {
priority_tx: tokio::sync::mpsc::UnboundedSender<PriorityRequest<S, L, M>>,
ctrl_tx: tokio::sync::mpsc::UnboundedSender<ControlRequest<S, L, M>>,
available_blocks_counter: Arc<AtomicU64>,
total_blocks_counter: Arc<AtomicU64>,
default_duplication_setting: BlockRegistrationDuplicationSetting,
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Clone for ManagedBlockPool<S, L, M> {
fn clone(&self) -> Self {
Self {
priority_tx: self.priority_tx.clone(),
ctrl_tx: self.ctrl_tx.clone(),
available_blocks_counter: self.available_blocks_counter.clone(),
total_blocks_counter: self.total_blocks_counter.clone(),
default_duplication_setting: self.default_duplication_setting,
}
}
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M> {
pub fn builder() -> ManagedBlockPoolArgsBuilder<S, L, M> {
ManagedBlockPoolArgsBuilder::default()
}
/// Creates a new [`ManagedBlockPool`] with the given [`EventManager`].
///
/// The pool starts empty and requires blocks to be added via [`add_blocks`].
///
/// # Arguments
///
/// * `event_manager` - An [`Arc<dyn EventManager>`] used for publishing block registration/removal events.
///
/// # Returns
///
/// A new [`ManagedBlockPool`] instance.
pub fn new(
event_manager: Arc<dyn EventManager>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, L, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
metrics: Arc<PoolMetrics>,
default_duplication_setting: BlockRegistrationDuplicationSetting,
) -> Self {
let (pool, progress_engine) = Self::with_progress_engine(
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
metrics,
default_duplication_setting,
);
// pool.runtime.handle().spawn(async move {
// let mut progress_engine = progress_engine;
// tracing::debug!("starting progress engine");
// while progress_engine.step().await {
// tracing::trace!("progress engine step");
// }
// });
let thread_name = format!(
"block-pool-{}-{}",
short_type_name::<S>(),
short_type_name::<L>()
);
std::thread::Builder::new()
.name(thread_name)
.spawn(move || {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to build Tokio runtime for block pool progress engine");
runtime.block_on(async move {
let mut progress_engine = progress_engine;
tracing::debug!("starting progress engine");
while progress_engine.step().await {
tracing::trace!("progress engine step");
}
});
})
.expect("Failed to spawn block pool progress engine thread");
pool
}
fn with_progress_engine(
event_manager: Arc<dyn EventManager>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, L, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
metrics: Arc<PoolMetrics>,
default_duplication_setting: BlockRegistrationDuplicationSetting,
) -> (Self, ProgressEngine<S, L, M>) {
let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel();
let (ctrl_tx, ctrl_rx) = tokio::sync::mpsc::unbounded_channel();
let progress_engine = ProgressEngine::<S, L, M>::new(
event_manager,
priority_rx,
ctrl_rx,
cancel_token,
blocks,
global_registry,
async_runtime,
metrics,
);
let available_blocks_counter = progress_engine.available_blocks_counter.clone();
let total_blocks_counter = progress_engine.total_blocks_counter.clone();
(
Self {
priority_tx,
ctrl_tx,
available_blocks_counter,
total_blocks_counter,
default_duplication_setting,
},
progress_engine,
)
}
pub fn default_duplication_setting(&self) -> BlockRegistrationDuplicationSetting {
self.default_duplication_setting
}
fn _add_blocks(&self, blocks: Vec<Block<S, L, M>>) -> AsyncResponse<()> {
let (req, resp_rx) = AddBlocksReq::new(blocks);
self.ctrl_tx
.send(ControlRequest::AddBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn _allocate_blocks(
&self,
count: usize,
) -> AsyncResponse<BlockPoolResult<Vec<MutableBlock<S, L, M>>>> {
let (req, resp_rx) = AllocateBlocksReq::new(count);
self.priority_tx
.send(PriorityRequest::AllocateBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn _register_blocks(
&self,
blocks: Vec<MutableBlock<S, L, M>>,
duplication_setting: BlockRegistrationDuplicationSetting,
) -> AsyncResponse<BlockPoolResult<ImmutableBlocks<S, L, M>>> {
if blocks.is_empty() {
return Err(BlockPoolError::NoBlocksToRegister);
}
let (req, resp_rx) = RegisterBlocksReq::new((blocks, duplication_setting));
self.priority_tx
.send(PriorityRequest::RegisterBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn _match_sequence_hashes(
&self,
sequence_hashes: &[SequenceHash],
) -> AsyncResponse<BlockPoolResult<ImmutableBlocks<S, L, M>>> {
let (req, resp_rx) = MatchHashesReq::new(sequence_hashes.into());
self.priority_tx
.send(PriorityRequest::MatchSequenceHashes(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn _touch_blocks(
&self,
sequence_hashes: &[SequenceHash],
) -> AsyncResponse<BlockPoolResult<()>> {
let (req, resp_rx) = TouchBlocksReq::new(sequence_hashes.into());
self.priority_tx
.send(PriorityRequest::TouchBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn _reset(&self) -> AsyncResponse<BlockPoolResult<()>> {
let (req, resp_rx) = ResetReq::new(());
self.priority_tx
.send(PriorityRequest::Reset(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn _try_return_block(&self, block: OwnedBlock<S, L, M>) -> AsyncResponse<BlockPoolResult<()>> {
let raw_blocks = block
.try_take_block(private::PrivateToken)
.ok_or(BlockPoolError::NotReturnable)?;
let (req, resp_rx) = ReturnBlockReq::new(raw_blocks);
self.priority_tx
.send(PriorityRequest::ReturnBlock(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
}
#[async_trait]
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockPool<S, L, M>
for ManagedBlockPool<S, L, M>
{
/// Adds a vector of [`Block`]s to the [`InactiveBlockPool`].
async fn add_blocks(&self, blocks: Vec<Block<S, L, M>>) -> Result<(), BlockPoolError> {
self._add_blocks(blocks)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
fn add_blocks_blocking(&self, blocks: Vec<Block<S, L, M>>) -> Result<(), BlockPoolError> {
self._add_blocks(blocks)?
.blocking_recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
/// Attempts to allocate a specified number of free blocks from the [`InactiveBlockPool`].
///
/// Blocks acquired this way are returned as [`MutableBlock`]s, granting unique ownership
/// and allowing modification. Dropping a [`MutableBlock`] automatically returns it
/// to the [`InactiveBlockPool`].
async fn allocate_blocks(
&self,
count: usize,
) -> Result<Vec<MutableBlock<S, L, M>>, BlockPoolError> {
self._allocate_blocks(count)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
fn allocate_blocks_blocking(
&self,
count: usize,
) -> Result<Vec<MutableBlock<S, L, M>>, BlockPoolError> {
self._allocate_blocks(count)?
.blocking_recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
/// Registers a vector of [`MutableBlock`]s (presumably after filling them) with the pool,
/// making them available for sharing via the [`ActiveBlockPool`].
///
/// This function checks if any of the blocks have the same sequence hash as an existing block
/// in the active pool. If so, it returns an [`ImmutableBlock`].
///
/// Note: Depending on the [`BlockRegistrationDuplicationSetting`], the returned [`ImmutableBlock`] may
/// not be the same block that was provided -- that is, it should hold the same content, but was the
/// first block registered. If duplication is allowed, we will keep alive both the primary block and
/// the duplicate block.
async fn register_blocks(
&self,
blocks: Vec<MutableBlock<S, L, M>>,
) -> BlockPoolResult<ImmutableBlocks<S, L, M>> {
self._register_blocks(blocks, self.default_duplication_setting)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
fn register_blocks_blocking(
&self,
blocks: Vec<MutableBlock<S, L, M>>,
) -> BlockPoolResult<ImmutableBlocks<S, L, M>> {
self._register_blocks(blocks, self.default_duplication_setting)?
.blocking_recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
/// Attempts to match the given [`SequenceHash`] to an existing block, checking
/// both the active and inactive pools.
///
/// Checks the [`ActiveBlockPool`] first. If a valid strong reference exists, it returns
/// an [`ImmutableBlock`] cloned from it. If the weak reference exists but is stale,
/// it's removed.
///
/// If not found in the active pool, it checks the [`InactiveBlockPool`]. If found there,
/// the block is moved to the active pool (tracked by a weak reference) and returned
/// as a new [`ImmutableBlock`].
async fn match_sequence_hashes(
&self,
sequence_hashes: &[SequenceHash],
) -> BlockPoolResult<ImmutableBlocks<S, L, M>> {
self._match_sequence_hashes(sequence_hashes)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
fn match_sequence_hashes_blocking(
&self,
sequence_hashes: &[SequenceHash],
) -> BlockPoolResult<ImmutableBlocks<S, L, M>> {
self._match_sequence_hashes(sequence_hashes)?
.blocking_recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
async fn touch_blocks(&self, sequence_hashes: &[SequenceHash]) -> Result<(), BlockPoolError> {
self._touch_blocks(sequence_hashes)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
fn touch_blocks_blocking(
&self,
sequence_hashes: &[SequenceHash],
) -> Result<(), BlockPoolError> {
self._touch_blocks(sequence_hashes)?
.blocking_recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
async fn try_return_block(&self, block: OwnedBlock<S, L, M>) -> BlockPoolResult<()> {
self._try_return_block(block)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
fn try_return_block_blocking(&self, block: OwnedBlock<S, L, M>) -> BlockPoolResult<()> {
self._try_return_block(block)?
.blocking_recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
fn total_blocks(&self) -> u64 {
self.total_blocks_counter.load(Ordering::Relaxed)
}
fn available_blocks(&self) -> u64 {
self.available_blocks_counter.load(Ordering::Relaxed)
}
}
struct ProgressEngine<S: Storage, L: LocalityProvider, M: BlockMetadata> {
priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, L, M>>,
ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, L, M>>,
cancel_token: CancellationToken,
state: State<S, L, M>,
return_rx: tokio::sync::mpsc::UnboundedReceiver<Block<S, L, M>>,
metrics: Arc<PoolMetrics>,
available_blocks_counter: Arc<AtomicU64>,
total_blocks_counter: Arc<AtomicU64>,
}
pub struct State<S: Storage, L: LocalityProvider, M: BlockMetadata> {
active: ActiveBlockPool<S, L, M>,
inactive: InactiveBlockPool<S, L, M>,
registry: BlockRegistry,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, L, M>>,
event_manager: Arc<dyn EventManager>,
metrics: Arc<PoolMetrics>,
}
impl<S: Storage, L: LocalityProvider + 'static, M: BlockMetadata> ProgressEngine<S, L, M> {
#[allow(clippy::too_many_arguments)]
pub fn new(
event_manager: Arc<dyn EventManager>,
priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, L, M>>,
ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, L, M>>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, L, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
metrics: Arc<PoolMetrics>,
) -> Self {
let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel();
let mut state = State::<S, L, M>::new(
event_manager,
return_tx,
global_registry,
async_runtime,
metrics.clone(),
);
let count = blocks.len();
tracing::debug!(count, "adding blocks to inactive pool");
state.inactive.add_blocks(blocks);
let available_blocks_counter = state.inactive.available_blocks_counter();
let total_blocks_counter = state.inactive.total_blocks_counter();
Self {
priority_rx,
ctrl_rx,
cancel_token,
state,
return_rx,
metrics,
available_blocks_counter,
total_blocks_counter,
}
}
pub async fn step(&mut self) -> bool {
tokio::select! {
biased;
Some(priority_req) = self.priority_rx.recv(), if !self.priority_rx.is_closed() => {
self.metrics.gauge("priority_request_queue_size").set(self.priority_rx.len() as i64);
self.state.handle_priority_request(priority_req, &mut self.return_rx).await;
}
Some(req) = self.ctrl_rx.recv(), if !self.ctrl_rx.is_closed() => {
self.metrics.gauge("control_request_queue_size").set(self.ctrl_rx.len() as i64);
self.state.handle_control_request(req);
}
Some(block) = self.return_rx.recv() => {
self.metrics.gauge("return_block_queue_size").set(self.return_rx.len() as i64);
self.state.handle_return_block(block);
}
_ = self.cancel_token.cancelled() => {
return false;
}
}
true
}
}
#[cfg(test)]
mod tests {
use crate::block_manager::block::{BasicMetadata, Blocks};
use crate::block_manager::layout::{tests::setup_layout, FullyContiguous, LayoutConfig};
use crate::block_manager::locality::Local;
use crate::tokens::{TokenBlockSequence, Tokens};
use crate::block_manager::storage::tests::{NullDeviceAllocator, NullDeviceStorage};
use super::*;
/// Helper method to build a [`ManagedBlockPool`] with a [`ProgressEngine`] for unit testing
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPoolArgsBuilder<S, L, M> {
#[allow(clippy::type_complexity)]
fn build_with_progress_engine(
self,
) -> anyhow::Result<(ManagedBlockPool<S, L, M>, ProgressEngine<S, L, M>)> {
let args = self.build_internal()?;
let (
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
metrics,
default_duplication_setting,
) = args.dissolve();
let (pool, progress_engine) = ManagedBlockPool::with_progress_engine(
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
metrics,
default_duplication_setting,
);
Ok((pool, progress_engine))
}
}
#[tokio::test]
async fn test_block_pool_state() {
let layout = setup_layout(None).unwrap();
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
let (_pool, mut progress) = ManagedBlockPool::builder()
.blocks(blocks)
.build_with_progress_engine()
.unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 7);
let blocks = progress.state.allocate_blocks(1).unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 6);
assert_eq!(blocks.len(), 1);
drop(blocks);
progress.step().await;
assert_eq!(progress.state.inactive.available_blocks(), 7);
let mut blocks = progress.state.allocate_blocks(1).unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 6);
assert_eq!(blocks.len(), 1);
let mut block = blocks.pop().unwrap();
block.init_sequence(1337).unwrap();
block.add_token(1).unwrap();
block.add_token(2).unwrap();
block.add_token(3).unwrap();
block.add_token(4).unwrap();
assert!(block.add_token(5).is_err());
}
#[tokio::test]
async fn test_block_pool() {
let layout = setup_layout(None).unwrap();
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
let (pool, mut progress) = ManagedBlockPool::builder()
.blocks(blocks)
.build_with_progress_engine()
.unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 7);
let pool_clone = pool.clone();
let allocate_1_block =
tokio::spawn(async move { pool_clone.allocate_blocks(1).await.unwrap() });
progress.step().await;
let blocks = allocate_1_block.await.unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 6);
assert_eq!(blocks.len(), 1);
// drop the single block
drop(blocks);
// check before and after the progress engine step
assert_eq!(progress.state.inactive.available_blocks(), 6);
progress.step().await;
assert_eq!(progress.state.inactive.available_blocks(), 7);
}
#[test]
fn test_block_pool_blocking() {
const EXPECTED_SEQUENCE_HASH: u64 = 14643705804678351452;
// Create a new layout
let layout = setup_layout(None).unwrap();
// Create the Blocks
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
let async_runtime = tokio::runtime::Runtime::new().unwrap();
// Create the ManagedBlockPool and add the blocks
let pool = ManagedBlockPool::builder()
.blocks(blocks)
.async_runtime(async_runtime.handle().clone())
.build()
.unwrap();
// All blocks should be in the Reset/Empty state
// No blocks should match the expected sequence hash
let matched_blocks = pool
.match_sequence_hashes_blocking(&[EXPECTED_SEQUENCE_HASH])
.unwrap();
assert_eq!(matched_blocks.len(), 0);
// Allocate a single block from the pool
let mut mutable_blocks = pool.allocate_blocks_blocking(1).unwrap();
assert_eq!(mutable_blocks.len(), 1);
let mut block = mutable_blocks.pop().unwrap();
// Initialize the sequence on the block with a salt hash
block.init_sequence(1337).unwrap();
// Add some tokens to the block - our page_size is 4
block.add_token(1).unwrap();
block.add_token(2).unwrap();
block.add_token(3).unwrap();
block.add_token(4).unwrap();
// Should fail because we don't have space in the block
assert!(block.add_token(5).is_err());
// Commit the block - this will generate a sequence hash
// This will put the block in a Complete state
block.commit().unwrap();
assert!(block.state().is_complete()); // perhaps renamed to Commited
let sequence_hash = block.sequence_hash().unwrap();
assert_eq!(sequence_hash, EXPECTED_SEQUENCE_HASH);
// Register the block
// We provide a mutable block to the register_blocks function
// This will take ownership of the block and return an immutable block
let mut immutable_blocks = pool.register_blocks_blocking(vec![block]).unwrap();
let block = immutable_blocks.pop().unwrap();
assert!(block.state().is_registered());
assert_eq!(block.sequence_hash(), sequence_hash);
// Dropping the immutable block should return the block to the pool
// However, the block should remain in the ManagedBlockPool as an inactive block until it is reused
// or promoted back to an immutable block by being matched with a sequence hash
drop(block);
// Get the list of ImmutableBlocks that match the sequence hash
let matched = pool
.match_sequence_hashes_blocking(&[sequence_hash])
.unwrap();
assert_eq!(matched.len(), 1);
assert_eq!(matched[0].sequence_hash(), sequence_hash);
}
async fn create_blocks<S: Storage, L: LocalityProvider, M: BlockMetadata>(
pool: &ManagedBlockPool<S, L, M>,
num_blocks: usize,
) -> anyhow::Result<(Vec<ImmutableBlock<S, L, M>>, Vec<SequenceHash>)> {
let tokens = vec![0; num_blocks * 4];
let token_blocks = TokenBlockSequence::new(Tokens::from(tokens), 4, None);
assert_eq!(token_blocks.blocks().len(), num_blocks);
let mut sequence_hashes = Vec::new();
let mut mutable_blocks = Vec::new();
for token_block in token_blocks.blocks().iter() {
let mut block = pool.allocate_blocks(1).await?.pop().unwrap();
block.apply_token_block(token_block.clone())?;
sequence_hashes.push(block.sequence_hash().unwrap());
mutable_blocks.push(block);
}
let immutable_blocks = pool.register_blocks(mutable_blocks).await?;
Ok((immutable_blocks, sequence_hashes))
}
async fn make_simple_pool(
num_blocks: usize,
) -> anyhow::Result<
ManagedBlockPool<NullDeviceStorage, crate::block_manager::locality::Local, BasicMetadata>,
> {
let config = LayoutConfig {
num_blocks,
num_layers: 1,
outer_dim: 1,
page_size: 4,
inner_dim: 1024,
alignment: 1,
dtype_width_bytes: 2,
};
let layout = FullyContiguous::<NullDeviceStorage>::allocate(config, &NullDeviceAllocator)?;
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)?.into_blocks()?;
let pool = ManagedBlockPool::builder().blocks(blocks).build()?;
Ok(pool)
}
/// A test that ensures that we only ever evict leaves from the inactive pool.
#[tokio::test]
async fn test_block_pool_evict_leaves() -> anyhow::Result<()> {
let pool = make_simple_pool(4).await?;
let (_, sequence_hashes) = create_blocks(&pool, 4).await?;
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Allocate 1 block. This should evict the leaf of our allocated sequence.
pool.allocate_blocks(1).await?;
// The leaf should be evicted, so we should have 3 matches.
let matched = pool
.match_sequence_hashes(sequence_hashes.as_slice())
.await?;
assert_eq!(matched.len(), 3);
drop(matched);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Allocate 2 blocks. This should get the previously allocated block, as well as one more leaf.
pool.allocate_blocks(2).await.unwrap();
// The next leaf should be evicted, so we should have 2 matches.
let matched = pool
.match_sequence_hashes(sequence_hashes.as_slice())
.await?;
assert_eq!(matched.len(), 2);
drop(matched);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// If we allocate all the blocks, the entire remaining sequence should be evicted.
let blocks = pool.allocate_blocks(4).await?;
assert_eq!(blocks.len(), 4);
Ok(())
}
/// When a block has two children, we need to ensure that we evict both children before
/// adding the parent to the leaf set.
#[tokio::test]
async fn test_block_pool_parent_child() -> anyhow::Result<()> {
let pool = make_simple_pool(3).await?;
let tokens = vec![1, 2, 3, 4, 5];
let sequence = TokenBlockSequence::new(Tokens::from(tokens.clone()), 4, None);
// Create a root block, with two child blocks.
let mut root_block = pool.allocate_blocks(1).await?.pop().unwrap();
root_block.apply_token_block(sequence.blocks().first().unwrap().clone())?;
let root_block_hash = root_block.sequence_hash().unwrap();
let mut child_blocks = Vec::new();
let mut child_block_hashes = Vec::new();
for i in 0..2 {
// Create a new token sequence using the common prefix.
let mut tokens = tokens.clone();
for _ in 0..4 {
tokens.push(i);
}
let seq = TokenBlockSequence::new(Tokens::from(tokens), 4, None);
// Allocate and apply the suffix to the child block.
let mut child_block = pool.allocate_blocks(1).await?.pop().unwrap();
child_block.apply_token_block(seq.blocks()[1].clone())?;
child_block_hashes.push(child_block.sequence_hash().unwrap());
child_blocks.push(child_block);
}
// Register the root block
let root_block = pool.register_blocks(vec![root_block]).await?;
// Register the children
let child_blocks = pool.register_blocks(child_blocks).await?;
// Drop both of them.
drop(root_block);
drop(child_blocks);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Allocate two new blocks, which should evict both children.
pool.allocate_blocks(2).await?;
// Now, the root block should be the only block left.
for child_block_hash in child_block_hashes {
let matched = pool.match_sequence_hashes(&[child_block_hash]).await?;
assert_eq!(matched.len(), 0);
}
// Check that the root block remains.
let matched = pool.match_sequence_hashes(&[root_block_hash]).await?;
assert_eq!(matched.len(), 1);
Ok(())
}
/// Matching an entire sequence (moving it to the active pool), and returning it
/// should not affect the parent-child relationships of the blocks.
#[tokio::test]
async fn test_block_pool_match_return() -> anyhow::Result<()> {
let pool = make_simple_pool(4).await?;
let (_, sequence_hashes) = create_blocks(&pool, 4).await?;
// We match the root of the sequence (moving it to the active pool), then
// immediately return it.
assert_eq!(
pool.match_sequence_hashes(vec![sequence_hashes[0]].as_slice())
.await?
.len(),
1
);
let _alloc_blocks1 = pool.allocate_blocks(3).await?;
// Allocating 3 blocks should evict all but the root of the sequence.
assert_eq!(
pool.match_sequence_hashes(sequence_hashes.as_slice())
.await?
.len(),
1
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let _alloc_blocks2 = pool.allocate_blocks(1).await?;
// Now, allocating one more block should evict the root.
assert_eq!(
pool.match_sequence_hashes(sequence_hashes.as_slice())
.await?
.len(),
0
);
Ok(())
}
#[tokio::test]
async fn test_block_pool_touch() -> anyhow::Result<()> {
let pool = make_simple_pool(4).await?;
let (_, sequence_hashes) = create_blocks(&pool, 4).await?;
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let _block0 = pool.allocate_blocks(1).await?;
// The leaf should be evicted.
assert_eq!(
pool.match_sequence_hashes(vec![sequence_hashes[3]].as_slice())
.await?
.len(),
0
);
// Now, touch the new leaf.
pool.touch_blocks(vec![sequence_hashes[2]].as_slice())
.await?;
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let _block1 = pool.allocate_blocks(1).await?;
// Since we touched block 2, block 1 should have been evicted.
assert_eq!(
pool.match_sequence_hashes(vec![sequence_hashes[1]].as_slice())
.await?
.len(),
0
);
pool.touch_blocks(vec![sequence_hashes[3]].as_slice())
.await?;
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
pool.allocate_blocks(1).await?;
// Now block 0 was evicted, since it was the last to be touched.
assert_eq!(
pool.match_sequence_hashes(vec![sequence_hashes[0]].as_slice())
.await?
.len(),
0
);
Ok(())
}
const EXPECTED_SEQUENCE_HASH: u64 = 14643705804678351452;
fn create_block(
pool: &ManagedBlockPool<NullDeviceStorage, Local, BasicMetadata>,
) -> ImmutableBlock<NullDeviceStorage, Local, BasicMetadata> {
let count = pool.available_blocks();
// Allocate a single block from the pool
let mut mutable_blocks = pool.allocate_blocks_blocking(1).unwrap();
assert_eq!(mutable_blocks.len(), 1);
let mut block = mutable_blocks.pop().unwrap();
assert_eq!(pool.available_blocks(), count - 1);
// Initialize the sequence on the block with a salt hash
block.init_sequence(1337).unwrap();
// Add some tokens to the block - our page_size is 4
block.add_token(1).unwrap();
block.add_token(2).unwrap();
block.add_token(3).unwrap();
block.add_token(4).unwrap();
// Should fail because we don't have space in the block
assert!(block.add_token(5).is_err());
// Commit the block - this will generate a sequence hash
// This will put the block in a Complete state
block.commit().unwrap();
assert!(block.state().is_complete()); // perhaps renamed to Commited
let sequence_hash = block.sequence_hash().unwrap();
assert_eq!(sequence_hash, EXPECTED_SEQUENCE_HASH);
// Register the block
// We provide a mutable block to the register_blocks function
// This will take ownership of the block and return an immutable block
let mut immutable_blocks = pool.register_blocks_blocking(vec![block]).unwrap();
let block = immutable_blocks.pop().unwrap();
assert!(block.state().is_registered());
assert_eq!(block.sequence_hash(), sequence_hash);
block
}
#[test]
fn test_block_registration_allow_duplicates() {
// const EXPECTED_SEQUENCE_HASH: u64 = 14643705804678351452;
// Create a new layout
let layout = setup_layout(None).unwrap();
// Create the Blocks
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
let count = blocks.len() as u64;
let async_runtime = tokio::runtime::Runtime::new().unwrap();
// Create the ManagedBlockPool and add the blocks
let pool = ManagedBlockPool::builder()
.blocks(blocks)
.async_runtime(async_runtime.handle().clone())
.default_duplication_setting(BlockRegistrationDuplicationSetting::Allowed)
.build()
.unwrap();
assert_eq!(pool.total_blocks(), count);
assert_eq!(pool.available_blocks(), count);
assert_eq!(
pool.default_duplication_setting(),
BlockRegistrationDuplicationSetting::Allowed
);
// All blocks should be in the Reset/Empty state
// No blocks should match the expected sequence hash
let matched_blocks = pool
.match_sequence_hashes_blocking(&[EXPECTED_SEQUENCE_HASH])
.unwrap();
assert_eq!(matched_blocks.len(), 0);
let primary = create_block(&pool);
let primary_id = primary.block_id();
assert_eq!(pool.available_blocks(), count - 1);
// Now allocate another and register it with the same sequence
let duplicate = create_block(&pool);
assert!(duplicate.is_duplicate());
assert_ne!(duplicate.block_id(), primary_id);
assert_eq!(pool.available_blocks(), count - 2);
// Reset only succeeds if all the blocks have been returned to the pool
let reset_result = pool.reset_blocking();
assert!(reset_result.is_err());
// we hold both the primary and the duplicate in the duplicate
// since we hold the primary in the duplicate, we expect this to fail
assert!(pool.try_return_block_blocking(primary.into()).is_err());
assert_eq!(pool.available_blocks(), count - 2);
assert!(pool.try_return_block_blocking(duplicate.into()).is_ok());
assert_eq!(pool.available_blocks(), count);
// we can still match the primary block because we have not reset the pool
let mut matched_blocks = pool
.match_sequence_hashes_blocking(&[EXPECTED_SEQUENCE_HASH])
.unwrap();
let primary = matched_blocks.pop().unwrap();
assert!(pool.try_return_block_blocking(primary.into()).is_ok());
assert_eq!(pool.available_blocks(), count);
// we can still create a duplicate even if the block is inactive
let duplicate = create_block(&pool);
assert!(duplicate.is_duplicate());
assert_ne!(duplicate.block_id(), primary_id);
assert_eq!(pool.available_blocks(), count - 2);
assert!(pool.try_return_block_blocking(duplicate.into()).is_ok());
assert_eq!(pool.available_blocks(), count);
// Reset the pool
let reset_result = pool.reset_blocking();
assert!(reset_result.is_ok());
// Now we should not be able to match the primary block
let matched_blocks = pool
.match_sequence_hashes_blocking(&[EXPECTED_SEQUENCE_HASH])
.unwrap();
assert_eq!(matched_blocks.len(), 0);
}
#[test]
fn test_block_registration_disable_duplicates() {
const EXPECTED_SEQUENCE_HASH: u64 = 14643705804678351452;
// Create a new layout
let layout = setup_layout(None).unwrap();
// Create the Blocks
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
let count = blocks.len() as u64;
let async_runtime = tokio::runtime::Runtime::new().unwrap();
// Create the ManagedBlockPool and add the blocks
let pool = ManagedBlockPoolArgsBuilder::default()
.blocks(blocks)
.async_runtime(async_runtime.handle().clone())
.default_duplication_setting(BlockRegistrationDuplicationSetting::Disabled)
.build()
.unwrap();
assert_eq!(pool.total_blocks(), count);
assert_eq!(pool.available_blocks(), count);
// All blocks should be in the Reset/Empty state
// No blocks should match the expected sequence hash
let matched_blocks = pool
.match_sequence_hashes_blocking(&[EXPECTED_SEQUENCE_HASH])
.unwrap();
assert_eq!(matched_blocks.len(), 0);
// allocate and register the primary block
let primary = create_block(&pool);
let primary_id = primary.block_id();
assert_eq!(pool.available_blocks(), count - 1);
// Now allocate another and register it with the same sequence
let duplicate = create_block(&pool);
assert_eq!(pool.available_blocks(), count - 1);
assert_eq!(duplicate.block_id(), primary_id);
}
}
......@@ -13,14 +13,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::block_manager::block::locality::LocalityProvider;
use super::*;
/// Manages active blocks being used by sequences
pub struct ActiveBlockPool<S: Storage, M: BlockMetadata> {
pub(super) map: HashMap<SequenceHash, Weak<MutableBlock<S, M>>>,
pub struct ActiveBlockPool<S: Storage, L: LocalityProvider, M: BlockMetadata> {
pub(super) map: HashMap<SequenceHash, Weak<MutableBlock<S, L, M>>>,
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Default for ActiveBlockPool<S, L, M> {
fn default() -> Self {
Self::new()
}
}
impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> {
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ActiveBlockPool<S, L, M> {
pub fn new() -> Self {
Self {
map: HashMap::new(),
......@@ -29,8 +37,8 @@ impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> {
pub fn register(
&mut self,
mut block: MutableBlock<S, M>,
) -> Result<ImmutableBlock<S, M>, BlockPoolError> {
mut block: MutableBlock<S, L, M>,
) -> Result<ImmutableBlock<S, L, M>, BlockPoolError> {
if !block.state().is_registered() {
return Err(BlockPoolError::InvalidMutableBlock(
"block is not registered".to_string(),
......@@ -69,7 +77,7 @@ impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> {
}
}
pub fn remove(&mut self, block: &mut Block<S, M>) {
pub fn remove(&mut self, block: &mut Block<S, L, M>) {
if let Ok(sequence_hash) = block.sequence_hash() {
if let Some(weak) = self.map.get(&sequence_hash) {
if let Some(_arc) = weak.upgrade() {
......@@ -84,7 +92,7 @@ impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> {
pub fn match_sequence_hash(
&mut self,
sequence_hash: SequenceHash,
) -> Option<ImmutableBlock<S, M>> {
) -> Option<ImmutableBlock<S, L, M>> {
if let Some(weak) = self.map.get(&sequence_hash) {
if let Some(arc) = weak.upgrade() {
Some(ImmutableBlock::new(arc))
......@@ -97,4 +105,8 @@ impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> {
None
}
}
pub fn status(&self) -> usize {
self.map.keys().len()
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M> {
fn _status(&self) -> AsyncResponse<BlockPoolResult<BlockPoolStatus>> {
let (req, resp_rx) = StatusReq::new(());
self.ctrl_tx
.send(ControlRequest::Status(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn _reset_blocks(
&self,
sequence_hashes: &[SequenceHash],
) -> AsyncResponse<BlockPoolResult<ResetBlocksResponse>> {
let (req, resp_rx) = ResetBlocksReq::new(sequence_hashes.into());
self.ctrl_tx
.send(ControlRequest::ResetBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockPoolController
for ManagedBlockPool<S, L, M>
{
fn status_blocking(&self) -> Result<BlockPoolStatus, BlockPoolError> {
self._status()?
.blocking_recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
fn reset_blocking(&self) -> Result<(), BlockPoolError> {
self._reset()?
.blocking_recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
fn reset_blocks_blocking(
&self,
sequence_hashes: &[SequenceHash],
) -> Result<ResetBlocksResponse, BlockPoolError> {
self._reset_blocks(sequence_hashes)?
.blocking_recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
}
#[async_trait::async_trait]
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> AsyncBlockPoolController
for ManagedBlockPool<S, L, M>
{
async fn status(&self) -> Result<BlockPoolStatus, BlockPoolError> {
self._status()?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
async fn reset(&self) -> Result<(), BlockPoolError> {
self._reset()?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
async fn reset_blocks(
&self,
sequence_hashes: &[SequenceHash],
) -> Result<ResetBlocksResponse, BlockPoolError> {
self._reset_blocks(sequence_hashes)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
}
......@@ -13,35 +13,37 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::block_manager::block::BlockState;
use std::sync::atomic::AtomicU64;
use crate::block_manager::block::{locality::LocalityProvider, BlockState};
use super::*;
use std::collections::HashSet;
use priority_key::PriorityKey;
use tracing::instrument;
#[derive(Default)]
pub struct InactiveBlockPool<S: Storage, M: BlockMetadata> {
pub struct InactiveBlockPool<S: Storage, L: LocalityProvider, M: BlockMetadata> {
// Direct lookup by sequence_hash.
lookup_map: HashMap<SequenceHash, Block<S, M>>,
lookup_map: HashMap<SequenceHash, Block<S, L, M>>,
// A priority ordering for the leaf nodes.
// Leaf nodes are defined as blocks that have no children in the inactive pool.
leaf_set: BTreeSet<PriorityKey<M>>,
// Mapping from parents to their children.
parent_children: HashMap<SequenceHash, HashSet<SequenceHash>>,
// Ordered by timestamp (oldest first)
priority_set: BTreeSet<PriorityKey<M>>,
// Fully Uninitialized
uninitialized_set: VecDeque<Block<S, M>>,
uninitialized_set: VecDeque<Block<S, L, M>>,
// Return Tick
return_tick: u64,
// Total blocks
total_blocks: u64,
// Total blocks counter
total_blocks: Arc<AtomicU64>,
// Inactive blocks
available_blocks: Arc<AtomicU64>,
}
impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> InactiveBlockPool<S, L, M> {
/// Creates a new, empty [`InactiveBlockPool`].
///
/// # Returns
......@@ -50,21 +52,39 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
pub(crate) fn new() -> Self {
Self {
lookup_map: HashMap::new(),
leaf_set: BTreeSet::new(),
parent_children: HashMap::new(),
priority_set: BTreeSet::new(),
uninitialized_set: VecDeque::new(),
return_tick: 0,
total_blocks: 0,
total_blocks: Arc::new(AtomicU64::new(0)),
available_blocks: Arc::new(AtomicU64::new(0)),
}
}
/// Returns a counter for the number of available blocks.
///
/// # Returns
///
/// A counter for the number of available blocks as an [`Arc<AtomicU64>`].
pub fn available_blocks_counter(&self) -> Arc<AtomicU64> {
self.available_blocks.clone()
}
/// Returns a counter for the total number of blocks.
///
/// # Returns
///
/// A counter for the total number of blocks as an [`Arc<AtomicU64>`].
pub fn total_blocks_counter(&self) -> Arc<AtomicU64> {
self.total_blocks.clone()
}
/// Returns the total number of blocks managed by this pool (both available and acquired).
///
/// # Returns
///
/// The total block count as a [`u64`].
pub fn total_blocks(&self) -> u64 {
self.total_blocks
self.total_blocks.load(Ordering::Relaxed)
}
/// Returns the number of blocks currently available in the pool.
......@@ -84,17 +104,15 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// If an entry with the same sequence hash already exists in the [`lookup_map`]
/// the block is reset and moved to the [`uninitialized_set`].
/// Otherwise, the block is added to the [`lookup_map`].
/// If there are no children of the block, it is added to the [`leaf_set`].
/// If the parent of the block is in the [`leaf_set`], it is removed from the [`leaf_set`].
///
/// # Arguments
///
/// * `block` - The block to insert ([`Block<T, M>`]).
/// * `sequence_hash` - The sequence hash associated with the block's content ([`SequenceHash`]).
#[instrument(level = "trace", skip(self, block), fields(sequence_hash = ?sequence_hash))]
fn insert_with_sequence_hash(&mut self, block: Block<S, M>, sequence_hash: SequenceHash) {
fn insert_with_sequence_hash(&mut self, block: Block<S, L, M>, sequence_hash: SequenceHash) {
let priority_key = PriorityKey::new(block.metadata().clone(), sequence_hash);
if self.lookup_map.contains_key(&sequence_hash) {
if self.priority_set.contains(&priority_key) {
tracing::trace!("multiple entries with the same sequence hash, resetting block and inserting into uninitialized set");
let mut block = block;
block.reset();
......@@ -102,27 +120,8 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
} else {
tracing::trace!("inserting block to map and priority set");
if let Ok(Some(parent)) = block.parent_sequence_hash() {
// Add the entry for the parent->child link.
self.parent_children
.entry(parent)
.or_default()
.insert(sequence_hash);
// If the parent is currently in the inactive pool, remove it from the leaf set.
if let Some(parent_block) = self.lookup_map.get_mut(&parent) {
self.leaf_set
.remove(&PriorityKey::new(parent_block.metadata().clone(), parent));
}
}
// Create the entry for the block in the lookup map.
self.priority_set.insert(priority_key);
self.lookup_map.insert(sequence_hash, block);
// If the block has no children, it is a leaf.
if !self.parent_children.contains_key(&sequence_hash) {
self.leaf_set.insert(priority_key);
}
}
}
......@@ -137,7 +136,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
///
/// * `block` - The block to insert ([`Block<S, M>`]).
#[instrument(level = "trace", skip(self, block), fields(block_state = ?block.state()))]
fn insert(&mut self, block: Block<S, M>) {
fn insert(&mut self, block: Block<S, L, M>) {
tracing::trace!("Inserting block into available pool");
// If we already have an entry for this sequence hash or the block is reset,
......@@ -161,6 +160,8 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
self.insert_with_sequence_hash(block, sequence_hash);
}
}
self.available_blocks.fetch_add(1, Ordering::Relaxed);
}
/// Adds multiple blocks to the pool.
......@@ -171,7 +172,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
///
/// * `blocks` - A vector of blocks ([`Block<T, M>`]) to add.
#[instrument(level = "debug", skip(self, blocks))]
pub fn add_blocks(&mut self, blocks: Vec<Block<S, M>>) {
pub fn add_blocks(&mut self, blocks: Vec<Block<S, L, M>>) {
let count = blocks.len();
tracing::debug!(count, "Adding blocks to pool");
......@@ -181,7 +182,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
self.insert(block);
}
self.total_blocks += count as u64;
self.total_blocks.fetch_add(count as u64, Ordering::Relaxed);
}
/// Adds multiple blocks to the pool.
......@@ -192,10 +193,10 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
///
/// * `blocks` - A vector of blocks ([`Block<T, M>`]) to add.
#[instrument(level = "debug", skip(self, blocks))]
pub fn add_blocks_with_state(&mut self, blocks: Vec<Block<S, M>>) {
pub fn add_blocks_with_state(&mut self, blocks: Vec<Block<S, L, M>>) {
let count = blocks.len();
tracing::debug!(count, "Adding blocks to pool");
self.total_blocks += count as u64;
self.total_blocks.fetch_add(count as u64, Ordering::Relaxed);
// self.available_blocks += count as u64;
self.return_blocks(blocks);
}
......@@ -209,7 +210,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
///
/// * `block` - The block ([`Block<S, M>`]) to return.
#[instrument(level = "debug", skip(self, block))]
pub fn return_block(&mut self, mut block: Block<S, M>) {
pub fn return_block(&mut self, mut block: Block<S, L, M>) {
// increment the return tick
self.return_tick += 1;
......@@ -231,7 +232,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
///
/// * `blocks` - A vector of blocks ([`Block<T, M>`]) to return.
#[instrument(level = "debug", skip(self, blocks))]
pub fn return_blocks(&mut self, blocks: Vec<Block<S, M>>) {
pub fn return_blocks(&mut self, blocks: Vec<Block<S, L, M>>) {
let count = blocks.len();
tracing::debug!(count, "Returning blocks to pool");
// return the block to the pool from tail to head
......@@ -243,7 +244,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
}
/// Attempts to remove and return a block associated with the given sequence hash
/// from the [`lookup_map`] and [`leaf_set`].
/// from the [`lookup_map`] and [`priority_set`].
///
/// # Arguments
///
......@@ -253,13 +254,15 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
///
/// An [`Option<Block<S, M>>`] containing the block if found, otherwise `None`.
#[instrument(level = "trace", skip(self), fields(sequence_hash = ?sequence_hash))]
fn take_with_sequence_hash(&mut self, sequence_hash: SequenceHash) -> Option<Block<S, M>> {
fn take_with_sequence_hash(&mut self, sequence_hash: SequenceHash) -> Option<Block<S, L, M>> {
match self.lookup_map.remove(&sequence_hash) {
Some(block) => {
// Remove from leaf set, if it exists.
self.leaf_set
.remove(&PriorityKey::new(block.metadata().clone(), sequence_hash));
// Remove from priority set.
let priority_key = PriorityKey::new(block.metadata().clone(), sequence_hash);
// Remove from priority set, if it exists.
self.priority_set.remove(&priority_key);
self.available_blocks.fetch_sub(1, Ordering::Relaxed);
Some(block)
}
None => None,
......@@ -278,7 +281,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
///
/// An [`Option<Block<S, M>>`] containing the block if found, otherwise `None`.
#[instrument(level = "debug", skip(self), fields(sequence_hash = ?sequence_hash))]
pub fn match_sequence_hash(&mut self, sequence_hash: SequenceHash) -> Option<Block<S, M>> {
pub fn match_sequence_hash(&mut self, sequence_hash: SequenceHash) -> Option<Block<S, L, M>> {
self.take_with_sequence_hash(sequence_hash)
}
......@@ -299,7 +302,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
pub fn match_sequence_hashes(
&mut self,
sequence_hashes: Vec<SequenceHash>,
) -> Vec<Block<S, M>> {
) -> Vec<Block<S, L, M>> {
let total_hashes = sequence_hashes.len();
let mut matched_blocks = Vec::with_capacity(total_hashes);
......@@ -332,7 +335,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// A vector containing the blocks ([`Block<T, M>`]) that were successfully matched and taken.
/// The vector may be shorter than `token_blocks` if not all corresponding hashes were found.
#[instrument(level = "debug", skip(self, token_blocks), fields(num_token_blocks = token_blocks.len()))]
pub fn match_token_blocks(&mut self, token_blocks: &[TokenBlock]) -> Vec<Block<S, M>> {
pub fn match_token_blocks(&mut self, token_blocks: &[TokenBlock]) -> Vec<Block<S, L, M>> {
let total_blocks = token_blocks.len();
let mut matched_blocks = Vec::with_capacity(total_blocks);
......@@ -375,52 +378,27 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// and [`lookup_map`] (i.e., a key exists in the set but not the map). This indicates
/// a bug in the pool's internal logic.
#[instrument(level = "debug", skip(self))]
pub fn acquire_free_block(&mut self) -> Option<Block<S, M>> {
pub fn acquire_free_block(&mut self) -> Option<Block<S, L, M>> {
// First try uninitialized blocks - these are often part of sequences
// that have been arranged in the correct order
if let Some(mut block) = self.uninitialized_set.pop_front() {
tracing::trace!("Acquired uninitialized block");
self.return_tick += 1;
block.metadata_on_acquired(self.return_tick);
self.available_blocks.fetch_sub(1, Ordering::Relaxed);
return Some(block);
}
// if we have blocks in the leaf set, pop the first (it's sorted by priority)
// if we have blocks in the priority set, pop the first (it's sorted by priority)
// a fatal error will occur if the block is not found in the lookup map
if let Some(key) = self.leaf_set.pop_first() {
if let Some(key) = self.priority_set.pop_first() {
tracing::trace!("Acquired priority/registered block map; resetting block");
match self.lookup_map.remove(&key.sequence_hash()) {
Some(mut block) => {
if let Some(children) = self.parent_children.get(&key.sequence_hash()) {
panic!(
"Block has {} inactive children, but should have none.",
children.len()
);
}
if let Ok(Some(parent)) = block.parent_sequence_hash() {
let is_leaf = match self.parent_children.get_mut(&parent) {
Some(children) => {
children.remove(&key.sequence_hash());
children.is_empty()
}
None => true,
};
if is_leaf {
self.parent_children.remove(&parent);
if let Some(parent_block) = self.lookup_map.get(&parent) {
self.leaf_set.insert(PriorityKey::new(
parent_block.metadata().clone(),
parent,
));
}
}
}
block.reset();
self.return_tick += 1;
block.metadata_on_acquired(self.return_tick);
self.available_blocks.fetch_sub(1, Ordering::Relaxed);
Some(block)
}
None => {
......@@ -457,7 +435,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
pub fn acquire_free_blocks(
&mut self,
count: usize,
) -> Result<Vec<Block<S, M>>, BlockPoolError> {
) -> Result<Vec<Block<S, L, M>>, BlockPoolError> {
if count == 0 {
return Ok(Vec::new());
}
......@@ -529,13 +507,48 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
Ok(blocks)
}
/// Resets the pool to its initial state.
///
/// This function will acquire all blocks, which will reset their state, then return them.
///
/// A [`Result`] containing `Ok(())` if the reset was successful, otherwise an error.
pub fn reset(&mut self) -> Result<(), BlockPoolError> {
let total_blocks = self.total_blocks.load(Ordering::Relaxed);
let available_blocks = self.available_blocks.load(Ordering::Relaxed);
if total_blocks != available_blocks {
return Err(BlockPoolError::ResetError(format!(
"total blocks: {}, available blocks: {}",
total_blocks, available_blocks
)));
}
let blocks = self.acquire_free_blocks(total_blocks as usize)?;
for block in blocks.into_iter() {
self.return_block(block);
}
Ok(())
}
/// Returns the [`PoolStatus`] of the pool.
pub fn status(&self) -> (usize, usize) {
let inactive_blocks = self.priority_set.len();
let empty_blocks = self.uninitialized_set.len();
(inactive_blocks, empty_blocks)
}
}
#[cfg(test)]
pub(crate) mod tests {
use crate::{
block_manager::{
block::{registry::BlockRegistry, state::CompleteState, Blocks, PrivateBlockExt},
block::{
locality::Local, registry::BlockRegistry, state::CompleteState, Blocks,
PrivateBlockExt,
},
events::NullEventManager,
layout::{BlockLayout, FullyContiguous, LayoutConfigBuilder},
storage::tests::{NullDeviceAllocator, NullDeviceStorage},
......@@ -650,7 +663,7 @@ pub(crate) mod tests {
tokens: Tokens,
block_size: u32,
async_runtime: Handle,
) -> Vec<Block<NullDeviceStorage, TestMetadata>> {
) -> Vec<Block<NullDeviceStorage, Local, TestMetadata>> {
let (token_blocks, _partial_token_block) =
tokens.into_sequence(block_size, None).into_parts();
let num_blocks = token_blocks.len();
......@@ -681,7 +694,7 @@ pub(crate) mod tests {
pub fn create_block_pool(
num_blocks: usize,
) -> InactiveBlockPool<NullDeviceStorage, TestMetadata> {
) -> InactiveBlockPool<NullDeviceStorage, Local, TestMetadata> {
let mut pool = InactiveBlockPool::new();
let blocks = create_block_collection(num_blocks).into_blocks().unwrap();
pool.add_blocks(blocks);
......@@ -692,9 +705,9 @@ pub(crate) mod tests {
pub fn acquire_blocks(
tokens: Tokens,
block_size: u32,
pool: &mut InactiveBlockPool<NullDeviceStorage, TestMetadata>,
pool: &mut InactiveBlockPool<NullDeviceStorage, Local, TestMetadata>,
async_runtime: Handle,
) -> (Vec<Block<NullDeviceStorage, TestMetadata>>, usize) {
) -> (Vec<Block<NullDeviceStorage, Local, TestMetadata>>, usize) {
let (mut token_blocks, _partial_token_block) =
tokens.into_sequence(block_size, None).into_parts();
......@@ -764,6 +777,10 @@ pub(crate) mod tests {
assert_eq!(pool.total_blocks(), 10);
assert_eq!(pool.available_blocks(), 10);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
let tokens = create_token_sequence(&[1, 2, 3, 4]);
......@@ -776,11 +793,19 @@ pub(crate) mod tests {
assert_eq!(blocks.len(), 2);
assert_eq!(matched_block_count, 0);
assert_eq!(pool.available_blocks(), 8);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
pool.return_blocks(blocks);
assert_eq!(pool.total_blocks(), 10);
assert_eq!(pool.available_blocks(), 10);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
let (blocks, matched_block_count) = acquire_blocks(
tokens.clone(),
......@@ -791,11 +816,19 @@ pub(crate) mod tests {
assert_eq!(blocks.len(), 2);
assert_eq!(matched_block_count, 2);
assert_eq!(pool.available_blocks(), 8);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
pool.return_blocks(blocks);
assert_eq!(pool.total_blocks(), 10);
assert_eq!(pool.available_blocks(), 10);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
let blocks = pool.acquire_free_blocks(10).unwrap();
for block in &blocks {
......@@ -828,6 +861,10 @@ pub(crate) mod tests {
assert_eq!(pool.total_blocks(), 2);
assert_eq!(pool.available_blocks(), 2);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
// Match the blocks in sequence
let matched = pool.match_sequence_hashes(hashes.clone());
......@@ -835,6 +872,10 @@ pub(crate) mod tests {
assert_eq!(pool.total_blocks(), 2);
assert_eq!(pool.available_blocks(), 0);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
// Validate the blocks are in the correct order and match the sequence hashes
assert_eq!(matched[0].sequence_hash().unwrap(), hashes[0]);
......@@ -845,5 +886,9 @@ pub(crate) mod tests {
assert_eq!(pool.total_blocks(), 2);
assert_eq!(pool.available_blocks(), 2);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
}
}
......@@ -20,10 +20,13 @@ use crate::block_manager::{
use super::*;
impl<S: Storage, M: BlockMetadata> State<S, M> {
fn new(
use active::ActiveBlockPool;
use inactive::InactiveBlockPool;
impl<S: Storage, L: LocalityProvider + 'static, M: BlockMetadata> State<S, L, M> {
pub fn new(
event_manager: Arc<dyn EventManager>,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, M>>,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, L, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
metrics: Arc<PoolMetrics>,
......@@ -38,10 +41,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
}
}
async fn handle_priority_request(
pub async fn handle_priority_request(
&mut self,
req: PriorityRequest<S, M>,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>,
req: PriorityRequest<S, L, M>,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, L, M>>,
) {
match req {
PriorityRequest::AllocateBlocks(req) => {
......@@ -52,8 +55,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
}
}
PriorityRequest::RegisterBlocks(req) => {
let (blocks, resp_tx) = req.dissolve();
let immutable_blocks = self.register_blocks(blocks, return_rx).await;
let ((blocks, duplication_setting), resp_tx) = req.dissolve();
let immutable_blocks = self
.register_blocks(blocks, duplication_setting, return_rx)
.await;
if resp_tx.send(immutable_blocks).is_err() {
tracing::error!("failed to send response to register blocks");
}
......@@ -61,14 +66,37 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
PriorityRequest::MatchSequenceHashes(req) => {
let (sequence_hashes, resp_tx) = req.dissolve();
let immutable_blocks = self.match_sequence_hashes(sequence_hashes, return_rx).await;
if resp_tx.send(immutable_blocks).is_err() {
if resp_tx.send(Ok(immutable_blocks)).is_err() {
tracing::error!("failed to send response to match sequence hashes");
}
}
PriorityRequest::TouchBlocks(req) => {
let (sequence_hashes, resp_tx) = req.dissolve();
self.touch_blocks(&sequence_hashes, return_rx).await;
if resp_tx.send(Ok(())).is_err() {
tracing::error!("failed to send response to touch blocks");
}
}
PriorityRequest::Reset(req) => {
let (_req, resp_tx) = req.dissolve();
let result = self.inactive.reset();
if resp_tx.send(result).is_err() {
tracing::error!("failed to send response to reset");
}
}
PriorityRequest::ReturnBlock(req) => {
let (returnable_blocks, resp_tx) = req.dissolve();
for block in returnable_blocks {
self.return_block(block);
}
if resp_tx.send(Ok(())).is_err() {
tracing::error!("failed to send response to return block");
}
}
}
}
fn handle_control_request(&mut self, req: ControlRequest<S, M>) {
pub fn handle_control_request(&mut self, req: ControlRequest<S, L, M>) {
match req {
ControlRequest::AddBlocks(blocks) => {
let (blocks, resp_rx) = blocks.dissolve();
......@@ -77,10 +105,25 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
tracing::error!("failed to send response to add blocks");
}
}
ControlRequest::Status(req) => {
let (_, resp_rx) = req.dissolve();
if resp_rx.send(Ok(self.status())).is_err() {
tracing::error!("failed to send response to status");
}
}
ControlRequest::ResetBlocks(req) => {
let (sequence_hashes, resp_rx) = req.dissolve();
if resp_rx
.send(Ok(self.try_reset_blocks(&sequence_hashes)))
.is_err()
{
tracing::error!("failed to send response to reset blocks");
}
}
}
}
fn handle_return_block(&mut self, block: Block<S, M>) {
pub fn handle_return_block(&mut self, block: Block<S, L, M>) {
self.return_block(block);
}
......@@ -89,8 +132,8 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
async fn wait_for_returned_block(
&mut self,
sequence_hash: SequenceHash,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>,
) -> Block<S, M> {
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, L, M>>,
) -> Block<S, L, M> {
while let Some(block) = return_rx.recv().await {
if matches!(block.state(), BlockState::Registered(handle, _) if handle.sequence_hash() == sequence_hash)
{
......@@ -105,7 +148,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
pub fn allocate_blocks(
&mut self,
count: usize,
) -> Result<Vec<MutableBlock<S, M>>, BlockPoolError> {
) -> Result<Vec<MutableBlock<S, L, M>>, BlockPoolError> {
let available_blocks = self.inactive.available_blocks() as usize;
if available_blocks < count {
......@@ -135,11 +178,15 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
Ok(blocks)
}
#[tracing::instrument(level = "debug", skip_all, fields(blocks = ?blocks))]
pub async fn register_blocks(
&mut self,
blocks: Vec<MutableBlock<S, M>>,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>,
) -> Result<Vec<ImmutableBlock<S, M>>, BlockPoolError> {
blocks: Vec<MutableBlock<S, L, M>>,
duplication_setting: BlockRegistrationDuplicationSetting,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, L, M>>,
) -> Result<Vec<ImmutableBlock<S, L, M>>, BlockPoolError> {
assert!(!blocks.is_empty(), "no blocks to register");
let expected_len = blocks.len();
let mut immutable_blocks = Vec::new();
......@@ -151,44 +198,81 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
// If the block is already registered, acquire a clone of the immutable block
if let Some(immutable) = self.active.match_sequence_hash(sequence_hash) {
let immutable = if duplication_setting
== BlockRegistrationDuplicationSetting::Allowed
{
immutable.with_duplicate(block.into()).expect("incompatible immutable block; only primary should be returned from match_sequence_hash")
} else {
// immediate return the block to the pool if duplicates are disabled
if let Some(blocks) = block.try_take_block(private::PrivateToken) {
self.inactive.return_blocks(blocks);
}
immutable
};
immutable_blocks.push(immutable);
continue;
}
let mut offload = true;
let mutable = if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash)
{
assert!(matches!(raw_block.state(), BlockState::Registered(_, _)));
MutableBlock::new(raw_block, self.return_tx.clone())
} else {
// Attempt to register the block
// On the very rare chance that the block is registered, but in the process of being returned,
// we will wait for it to be returned and then register it.
let result = block.register(&mut self.registry);
match result {
Ok(handle) => {
// Only create our publish handle if this block is new, and not transfered.
if let Some(handle) = handle {
publish_handles.take_handle(handle);
let (mutable, duplicate) =
if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash) {
// We already have a match, so our block is a duplicate.
assert!(matches!(raw_block.state(), BlockState::Registered(_, _)));
(
MutableBlock::new(raw_block, self.return_tx.clone()),
Some(block),
)
} else {
// Attempt to register the block
// On the very rare chance that the block is registered, but in the process of being returned,
// we will wait for it to be returned and then register it.
let result = block.register(&mut self.registry);
match result {
Ok(handle) => {
// Only create our publish handle if this block is new, and not transfered.
if let Some(handle) = handle {
publish_handles.take_handle(handle);
}
(block, None)
}
Err(BlockRegistrationError::BlockAlreadyRegistered(_)) => {
// Block is already registered, wait for it to be returned
// Return the original block as the primary, and the block we passed in as the duplicate.
offload = false;
let raw_block =
self.wait_for_returned_block(sequence_hash, return_rx).await;
(
MutableBlock::new(raw_block, self.return_tx.clone()),
Some(block),
)
}
Err(e) => {
return Err(BlockPoolError::FailedToRegisterBlock(e.to_string()));
}
block
}
Err(BlockRegistrationError::BlockAlreadyRegistered(_)) => {
// Block is already registered, wait for it to be returned
offload = false;
let raw_block =
self.wait_for_returned_block(sequence_hash, return_rx).await;
MutableBlock::new(raw_block, self.return_tx.clone())
};
let mut immutable = self.active.register(mutable)?;
match duplication_setting {
BlockRegistrationDuplicationSetting::Allowed => {
if let Some(duplicate) = duplicate {
immutable = immutable
.with_duplicate(duplicate.into())
.expect("incompatible immutable block; only primary should be returned from ActiveBlockPool::register");
}
Err(e) => {
return Err(BlockPoolError::FailedToRegisterBlock(e.to_string()));
}
BlockRegistrationDuplicationSetting::Disabled => {
if let Some(block) = duplicate {
if let Some(raw_blocks) = block.try_take_block(private::PrivateToken) {
self.inactive.return_blocks(raw_blocks);
}
}
}
};
let immutable = self.active.register(mutable)?;
}
if offload {
if let Some(priority) = immutable.metadata().offload_priority() {
......@@ -211,8 +295,8 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
async fn match_sequence_hashes(
&mut self,
sequence_hashes: Vec<SequenceHash>,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>,
) -> Vec<ImmutableBlock<S, M>> {
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, L, M>>,
) -> Vec<ImmutableBlock<S, L, M>> {
let mut immutable_blocks = Vec::new();
for sequence_hash in &sequence_hashes {
if !self.registry.is_registered(*sequence_hash) {
......@@ -245,7 +329,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
let immutable = self
.active
.register(mutable)
.expect("unable to register block; should ever happen");
.expect("unable to register block; should never happen");
immutable_blocks.push(immutable);
}
......@@ -260,8 +344,31 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
immutable_blocks
}
async fn touch_blocks(
&mut self,
sequence_hashes: &[SequenceHash],
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, L, M>>,
) {
for sequence_hash in sequence_hashes {
if !self.registry.is_registered(*sequence_hash) {
break;
}
let block = if let Some(block) = self.inactive.match_sequence_hash(*sequence_hash) {
block
} else if self.active.match_sequence_hash(*sequence_hash).is_none() {
self.wait_for_returned_block(*sequence_hash, return_rx)
.await
} else {
continue;
};
self.inactive.return_block(block);
}
}
/// Returns a block to the inactive pool
pub fn return_block(&mut self, mut block: Block<S, M>) {
pub fn return_block(&mut self, mut block: Block<S, L, M>) {
self.active.remove(&mut block);
self.inactive.return_block(block);
}
......@@ -269,111 +376,41 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
fn publisher(&self) -> Publisher {
Publisher::new(self.event_manager.clone())
}
}
impl<S: Storage, M: BlockMetadata> ProgressEngine<S, M> {
#[allow(clippy::too_many_arguments)]
pub fn new(
event_manager: Arc<dyn EventManager>,
priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, M>>,
ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, M>>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
metrics: Arc<PoolMetrics>,
) -> Self {
let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel();
let mut state = State::<S, M>::new(
event_manager,
return_tx,
global_registry,
async_runtime,
metrics.clone(),
);
tracing::debug!(count = blocks.len(), "adding blocks to inactive pool");
state.inactive.add_blocks(blocks);
Self {
priority_rx,
ctrl_rx,
cancel_token,
state,
return_rx,
metrics,
fn status(&self) -> BlockPoolStatus {
let active = self.active.status();
let (inactive, empty) = self.inactive.status();
BlockPoolStatus {
active_blocks: active,
inactive_blocks: inactive,
empty_blocks: empty,
}
}
pub async fn step(&mut self) -> bool {
tokio::select! {
biased;
fn try_reset_blocks(&mut self, sequence_hashes: &[SequenceHash]) -> ResetBlocksResponse {
let mut reset_blocks = Vec::new();
let mut not_found = Vec::new();
let mut not_reset = Vec::new();
Some(priority_req) = self.priority_rx.recv(), if !self.priority_rx.is_closed() => {
self.metrics.gauge("priority_request_queue_size").set(self.priority_rx.len() as i64);
self.state.handle_priority_request(priority_req, &mut self.return_rx).await;
}
Some(req) = self.ctrl_rx.recv(), if !self.ctrl_rx.is_closed() => {
self.metrics.gauge("control_request_queue_size").set(self.ctrl_rx.len() as i64);
self.state.handle_control_request(req);
}
Some(block) = self.return_rx.recv() => {
self.metrics.gauge("return_block_queue_size").set(self.return_rx.len() as i64);
self.state.handle_return_block(block);
for sequence_hash in sequence_hashes {
if !self.registry.is_registered(*sequence_hash) {
not_found.push(*sequence_hash);
continue;
}
_ = self.cancel_token.cancelled() => {
return false;
if let Some(mut block) = self.inactive.match_sequence_hash(*sequence_hash) {
reset_blocks.push(*sequence_hash);
block.reset();
self.inactive.return_block(block);
} else {
not_reset.push(*sequence_hash);
}
}
true
ResetBlocksResponse {
reset_blocks,
not_found,
not_reset,
}
}
}
// pub(crate) async fn progress_engine<S: Storage, M: BlockMetadata>(
// event_manager: Arc<dyn EventManager>,
// mut priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, M>>,
// mut ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, M>>,
// cancel_token: CancellationToken,
// ) {
// let (return_tx, mut return_rx) = tokio::sync::mpsc::unbounded_channel();
// let mut state = State::<S, M>::new(event_manager, return_tx);
// loop {
// tokio::select! {
// biased;
// Some(priority_req) = priority_rx.recv(), if !priority_rx.is_closed() => {
// state.handle_priority_request(priority_req, &mut return_rx).await;
// }
// Some(req) = ctrl_rx.recv(), if !ctrl_rx.is_closed() => {
// state.handle_control_request(req);
// }
// Some(block) = return_rx.recv() => {
// state.handle_return_block(block);
// }
// _ = cancel_token.cancelled() => {
// break;
// }
// }
// }
// }
// pub(crate) async fn progress_engine_v2<S: Storage, M: BlockMetadata>(
// event_manager: Arc<dyn EventManager>,
// priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, M>>,
// ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, M>>,
// cancel_token: CancellationToken,
// ) {
// let mut progress_engine =
// ProgressEngine::<S, M>::new(event_manager, priority_rx, ctrl_rx, cancel_token);
// while progress_engine.step().await {
// tracing::trace!("progress engine step");
// }
// }
......@@ -13,190 +13,255 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod local;
mod logical;
mod resources;
use crate::block_manager::block::{factory::IntoBlocks, MutableBlock};
use crate::block_manager::locality::LogicalResources;
use crate::block_manager::offload::request::BlockResult;
use super::*;
use super::offload::OffloadManager;
// use super::offload::OffloadManager;
use super::{
block::{Block, GlobalRegistry, ImmutableBlock},
block::{
factory::LocalBlockDataFactory, locality::LocalityProvider, Block, GlobalRegistry,
ImmutableBlock,
},
config::NixlOptions,
events::{EventManager, NullEventManager},
metrics::{BlockManagerMetrics, PoolMetrics},
metrics::BlockManagerMetrics,
offload::OffloadManager,
};
use derive_getters::Dissolve;
use std::sync::Arc;
use tokio::runtime::Handle;
use tokio::sync::oneshot;
#[allow(dead_code)]
pub struct KvBlockManagerState<Metadata: BlockMetadata> {
worker_id: WorkerID,
cancellation_token: CancellationToken,
pub(crate) struct Resources {
pub worker_id: WorkerID,
pub cancellation_token: CancellationToken,
pub async_rt_handle: Handle,
nixl_agent: Arc<Option<NixlAgent>>,
nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>,
// nixl agent/backends for the block manager
pub nixl_agent: Arc<Option<NixlAgent>>,
#[expect(dead_code)]
pub nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>,
disk_pool: Option<Arc<BlockPool<DiskStorage, Metadata>>>,
host_pool: Option<Arc<BlockPool<PinnedStorage, Metadata>>>,
device_pool: Option<Arc<BlockPool<DeviceStorage, Metadata>>>,
// registry for blocks across all storage types
pub global_registry: GlobalRegistry,
local_block_set: NixlBlockSet,
remote_block_sets: RwLock<HashMap<WorkerID, HashMap<usize, RemoteBlocks>>>,
// event manager for block manager events
pub event_manager: Arc<dyn EventManager>,
offload_manager: Arc<OffloadManager<Metadata>>,
// metrics for the block manager
pub metrics: Arc<BlockManagerMetrics>,
// config for the block manager
pub config: KvBlockManagerConfig,
}
impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
pub fn new(config: KvBlockManagerConfig) -> Result<Arc<Self>> {
config
.runtime
.validate()
.context("Validating runtime config")?;
#[allow(dead_code)]
pub struct KvBlockManagerState<Locality: LocalityProvider, Metadata: BlockMetadata> {
resources: Arc<Resources>,
config.model.validate().context("Validating model config")?;
disk_pool: Option<Arc<dyn BlockPool<DiskStorage, Locality, Metadata>>>,
host_pool: Option<Arc<dyn BlockPool<PinnedStorage, Locality, Metadata>>>,
device_pool: Option<Arc<dyn BlockPool<DeviceStorage, Locality, Metadata>>>,
let worker_id = config.runtime.worker_id;
let cancellation_token = config.runtime.cancellation_token;
local_block_set: NixlBlockSet,
remote_block_sets: RwLock<HashMap<WorkerID, HashMap<usize, RemoteBlocks>>>,
offload_manager: Arc<OffloadManager<Locality, Metadata>>,
}
// Create a map of NIXL backends
let mut nixl_backends: HashMap<String, Arc<nixl_sys::Backend>> = HashMap::new();
impl<Locality: LocalityProvider, Metadata: BlockMetadata> KvBlockManagerState<Locality, Metadata> {
pub fn disk(&self) -> Option<&dyn BlockPool<DiskStorage, Locality, Metadata>> {
self.disk_pool.as_ref().map(|pool| pool.as_ref())
}
let global_registry = GlobalRegistry::default();
pub fn host(&self) -> Option<&dyn BlockPool<PinnedStorage, Locality, Metadata>> {
self.host_pool.as_ref().map(|pool| pool.as_ref())
}
let metrics = BlockManagerMetrics::new(&config.runtime.metrics_registry)?;
pub fn device(&self) -> Option<&dyn BlockPool<DeviceStorage, Locality, Metadata>> {
self.device_pool.as_ref().map(|pool| pool.as_ref())
}
let event_manager = config
.event_manager
.clone()
.unwrap_or_else(|| NullEventManager::new());
pub fn worker_id(&self) -> WorkerID {
self.resources.worker_id
}
// Create a NIXL agent if NIXL is enabled and instantiate requested backends
// TODO: Build a map of NIXL backends to block pools/sets
let nixl_agent = Arc::new(match config.runtime.nixl {
NixlOptions::Enabled => {
tracing::debug!("Creating NIXL agent");
let agent = NixlAgent::new(&worker_id.to_string())?;
pub(crate) async fn enqueue_offload_block<S: Storage + 'static>(
&self,
block: &ImmutableBlock<S, Locality, Metadata>,
priority: u64,
) -> Result<()> {
self.offload_manager.offload(block, priority).await?;
tracing::debug!("Creating NIXL backends");
Ok(())
}
if let Ok((_, ucx_params)) = agent.get_plugin_params("UCX") {
let backend = agent.create_backend("UCX", &ucx_params)?;
nixl_backends.insert("UCX".to_string(), Arc::new(backend));
} else {
tracing::warn!("No UCX plugin found; will not create UCX backend");
}
pub fn onboard_blocks<S: Storage + 'static>(
&self,
blocks: Vec<ImmutableBlock<S, Locality, Metadata>>,
targets: Option<Vec<MutableBlock<DeviceStorage, Locality, Metadata>>>,
) -> oneshot::Receiver<BlockResult<DeviceStorage, Locality, Metadata>> {
self.offload_manager.onboard(blocks, targets)
}
}
if config.disk_layout.is_some() {
if let Ok((_, gds_params)) = agent.get_plugin_params("GDS") {
let backend = agent.create_backend("GDS", &gds_params)?;
nixl_backends.insert("GDS".to_string(), Arc::new(backend));
} else {
tracing::warn!("No GDS plugin found; will not create GDS backend");
}
}
impl<R: LogicalResources, Metadata: BlockMetadata>
KvBlockManagerState<locality::Logical<R>, Metadata>
{
pub async fn new(config: KvBlockManagerConfig, logical_resources: R) -> Result<Arc<Self>> {
let mut resources = Resources::new(config)?;
let block_data_factories =
logical::LogicalBlockFactories::new(&mut resources, logical_resources)?;
let (disk_factory, host_factory, device_factory) = block_data_factories.dissolve();
let (disk_pool, disk_blocks) = match disk_factory {
Some(factory) => {
let (pool, blocks) =
create_block_pool::<_, _, Metadata>(factory, &resources, "disk")?;
(Some(pool), Some(blocks))
}
None => {
tracing::debug!("No disk layout provided; will not allocate disk blocks.");
(None, None)
}
};
Some(agent)
let (host_pool, host_blocks) = match host_factory {
Some(factory) => {
let (pool, blocks) =
create_block_pool::<_, _, Metadata>(factory, &resources, "host")?;
(Some(pool), Some(blocks))
}
NixlOptions::EnabledWithAgent(agent) => Some(agent),
NixlOptions::Disabled => None,
});
None => {
tracing::debug!("No host layout provided; will not allocate host blocks.");
(None, None)
}
};
// Initialize model-specific layout config. The layout_builder is incomplete at this point.
// We will clone this builder and apply the storage-specific configs to each clone in the
// following steps.
let model = &config.model;
let mut layout_builder = LayoutConfig::builder();
layout_builder
.num_layers(model.num_layers)
.outer_dim(model.outer_dim)
.page_size(model.page_size)
.inner_dim(model.inner_dim)
.dtype(model.dtype);
let mut next_block_set_idx = 0;
let mut local_block_set = block::nixl::NixlBlockSet::new(worker_id);
let async_rt_handle = match config.runtime.async_runtime {
Some(rt) => rt.handle().clone(),
None => match Handle::try_current() {
Ok(handle) => handle,
Err(e) => anyhow::bail!(e),
},
let (device_pool, device_blocks) = match device_factory {
Some(factory) => {
let (pool, blocks) =
create_block_pool::<_, _, Metadata>(factory, &resources, "device")?;
(Some(pool), Some(blocks))
}
None => {
tracing::debug!("No device layout provided; will not allocate device blocks.");
(None, None)
}
};
let (disk_pool, disk_blocks) = if let Some(config) = config.disk_layout {
if nixl_agent.is_none() {
tracing::warn!("NIXL is disabled; will not allocate disk blocks.");
let offload_manager = OffloadManager::new(
disk_pool.clone(),
host_pool.clone(),
device_pool.clone(),
resources.nixl_agent.clone(),
resources.async_rt_handle.clone(),
resources.metrics.clone(),
resources.cancellation_token.clone(),
)?;
let resources = Arc::new(resources);
let state = Arc::new(Self {
resources: resources.clone(),
disk_pool,
host_pool,
device_pool,
local_block_set: NixlBlockSet::new(resources.worker_id),
remote_block_sets: RwLock::new(HashMap::new()),
offload_manager,
});
if let Some(mut blocks) = disk_blocks {
blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone());
});
state.disk_pool.as_ref().unwrap().add_blocks(blocks).await?;
}
if let Some(mut blocks) = host_blocks {
blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone());
});
state.host_pool.as_ref().unwrap().add_blocks(blocks).await?;
}
if let Some(mut blocks) = device_blocks {
blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone());
});
state
.device_pool
.as_ref()
.unwrap()
.add_blocks(blocks)
.await?;
}
Ok(state)
}
}
// move into mod local
// move local block data factory into mod super::block
// create a method on locality to construct a block data factory from a layout builder and resources
// - this will allow us to use the locality abstraction to build our factories and block pools
impl<Metadata: BlockMetadata> KvBlockManagerState<locality::Local, Metadata> {
pub async fn new(config: KvBlockManagerConfig) -> Result<Arc<Self>> {
let mut resources = Resources::new(config)?;
let block_data_factories = local::LocalBlockDataFactories::new(&mut resources)?;
let (mut local_block_set, disk_factory, host_factory, device_factory) =
block_data_factories.dissolve();
let (disk_pool, disk_blocks) = match disk_factory {
Some(factory) => {
let (pool, blocks) =
create_block_pool::<_, _, Metadata>(factory, &resources, "disk")?;
(Some(pool), Some(blocks))
}
None => {
tracing::debug!("No disk layout provided; will not allocate disk blocks.");
(None, None)
} else {
next_block_set_idx += 1;
tracing::debug!("Constructing disk pool.");
let layout =
create_layout(layout_builder.clone(), config, nixl_agent.as_ref().as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>(
layout,
next_block_set_idx,
cancellation_token.clone(),
worker_id,
global_registry.clone(),
async_rt_handle.clone(),
metrics.pool("disk"),
Some(event_manager.clone()),
)?;
(Some(Arc::new(pool)), Some(blocks))
}
} else {
tracing::debug!("No disk layout provided; will not allocate disk blocks.");
(None, None)
};
// Create the host block pool if a host layout is provided
let (host_pool, host_blocks) = if let Some(config) = config.host_layout {
next_block_set_idx += 1;
tracing::debug!("Constructing host pool.");
let layout =
create_layout(layout_builder.clone(), config, nixl_agent.as_ref().as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>(
layout,
next_block_set_idx,
cancellation_token.clone(),
worker_id,
global_registry.clone(),
async_rt_handle.clone(),
metrics.pool("host"),
Some(event_manager.clone()),
)?;
(Some(Arc::new(pool)), Some(blocks))
} else {
tracing::debug!("No host layout provided; will not allocate host blocks.");
(None, None)
let (host_pool, host_blocks) = match host_factory {
Some(factory) => {
let (pool, blocks) =
create_block_pool::<_, _, Metadata>(factory, &resources, "host")?;
(Some(pool), Some(blocks))
}
None => {
tracing::debug!("No disk layout provided; will not allocate disk blocks.");
(None, None)
}
};
// Create the device block pool if a device layout is provided
let (device_pool, device_blocks) = if let Some(config) = config.device_layout {
next_block_set_idx += 1;
tracing::debug!("Constructing device pool.");
let layout =
create_layout(layout_builder.clone(), config, nixl_agent.as_ref().as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>(
layout,
next_block_set_idx,
cancellation_token.clone(),
worker_id,
global_registry.clone(),
async_rt_handle.clone(),
metrics.pool("device"),
Some(event_manager.clone()),
)?;
(Some(Arc::new(pool)), Some(blocks))
} else {
tracing::debug!("No device layout provided; will not allocate device blocks.");
(None, None)
let (device_pool, device_blocks) = match device_factory {
Some(factory) => {
let (pool, blocks) =
create_block_pool::<_, _, Metadata>(factory, &resources, "disk")?;
(Some(pool), Some(blocks))
}
None => {
tracing::debug!("No disk layout provided; will not allocate disk blocks.");
(None, None)
}
};
// Finalize the local block set by adding NIXL metadata
if let Some(nixl_agent) = nixl_agent.as_ref() {
if let Some(nixl_agent) = resources.nixl_agent.as_ref() {
tracing::debug!("Finalize NixlBlockSet: adding NIXL metadata.");
local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?);
}
......@@ -205,17 +270,16 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
disk_pool.clone(),
host_pool.clone(),
device_pool.clone(),
nixl_agent.clone(),
async_rt_handle,
metrics.clone(),
cancellation_token.clone(),
resources.nixl_agent.clone(),
resources.async_rt_handle.clone(),
resources.metrics.clone(),
resources.cancellation_token.clone(),
)?;
let resources = Arc::new(resources);
let state = Arc::new(Self {
worker_id,
cancellation_token,
nixl_agent,
nixl_backends,
resources: resources.clone(),
disk_pool,
host_pool,
device_pool,
......@@ -229,12 +293,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
block.set_manager(state.clone());
});
state
.disk_pool
.as_ref()
.as_ref()
.unwrap()
.add_blocks_blocking(blocks)?;
state.disk_pool.as_ref().unwrap().add_blocks(blocks).await?;
}
if let Some(mut blocks) = host_blocks {
......@@ -242,12 +301,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
block.set_manager(state.clone());
});
state
.host_pool
.as_ref()
.as_ref()
.unwrap()
.add_blocks_blocking(blocks)?;
state.host_pool.as_ref().unwrap().add_blocks(blocks).await?;
}
if let Some(mut blocks) = device_blocks {
......@@ -258,9 +312,9 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
state
.device_pool
.as_ref()
.as_ref()
.unwrap()
.add_blocks_blocking(blocks)?;
.add_blocks(blocks)
.await?;
}
Ok(state)
......@@ -296,11 +350,12 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
tracing::debug!("Importing remote blockset from worker {}", worker_id);
assert_ne!(
worker_id, self.worker_id,
worker_id, self.resources.worker_id,
"Cannot import blockset from self"
);
let agent = self
.resources
.nixl_agent
.as_ref()
.as_ref()
......@@ -417,91 +472,51 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
Ok(blocks)
}
pub fn disk(&self) -> Option<&BlockPool<DiskStorage, Metadata>> {
self.disk_pool.as_ref().map(|pool| pool.as_ref())
}
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
self.host_pool.as_ref().map(|pool| pool.as_ref())
}
pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> {
self.device_pool.as_ref().map(|pool| pool.as_ref())
}
pub fn worker_id(&self) -> WorkerID {
self.worker_id
}
pub(crate) async fn enqueue_offload_block<S: Storage + 'static>(
&self,
block: &ImmutableBlock<S, Metadata>,
priority: u64,
) -> Result<()> {
self.offload_manager.offload(block, priority).await?;
Ok(())
}
pub async fn onboard_blocks<S: Storage>(
&self,
blocks: Vec<ImmutableBlock<S, Metadata>>,
) -> BlockResult<DeviceStorage, Metadata> {
self.offload_manager.onboard(blocks).await
}
}
impl<Metadata: BlockMetadata> std::fmt::Debug for KvBlockManagerState<Metadata> {
impl<Locality: LocalityProvider, Metadata: BlockMetadata> std::fmt::Debug
for KvBlockManagerState<Locality, Metadata>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "KvBlockManagerState")
}
}
fn create_layout<S: Storage + NixlRegisterableStorage>(
mut builder: LayoutConfigBuilder,
config: KvManagerLayoutConfig<S>,
nixl_agent: Option<&NixlAgent>,
) -> Result<Arc<dyn NixlLayout<StorageType = S>>> {
let layout = builder.num_blocks(config.num_blocks).build()?;
if let Some(storage) = config.storage {
let mut layout = layout.create_layout(config.layout_type, storage)?;
if let Some(nixl_agent) = nixl_agent {
layout.nixl_register(nixl_agent, None)?;
}
return Ok(Arc::new(layout));
}
if let Some(allocator) = config.allocator {
let mut layout = layout.allocate_layout(config.layout_type, allocator)?;
if let Some(nixl_agent) = nixl_agent {
layout.nixl_register(nixl_agent, None)?;
}
return Ok(Arc::new(layout));
}
// if let Some(storage) = config.storage {
// let mut layout = layout.create_layout(config.layout_type, storage, false)?;
// if let Some(nixl_agent) = nixl_agent {
// layout.nixl_register(nixl_agent, None)?;
// }
// return Ok(layout.into());
// }
// if let Some(allocator) = config.allocator {
// let mut layout = layout.allocate_layout(config.layout_type, allocator)?;
// if let Some(nixl_agent) = nixl_agent {
// layout.nixl_register(nixl_agent, None)?;
// }
// return Ok(layout.into());
// }
// anyhow::bail!("failed to create layout");
// }
#[expect(clippy::type_complexity)]
pub(crate) fn create_block_pool<S: Storage, L: LocalityProvider, M: BlockMetadata>(
factory: impl IntoBlocks<S, L>,
resources: &Resources,
pool_name: &str,
) -> Result<(Arc<dyn BlockPool<S, L, M>>, Vec<Block<S, L, M>>)> {
let pool = ManagedBlockPool::<S, L, M>::builder()
.cancel_token(resources.cancellation_token.clone())
.global_registry(resources.global_registry.clone())
.async_runtime(resources.async_rt_handle.clone())
.event_manager(resources.event_manager.clone())
.pool_metrics(resources.metrics.pool(pool_name))
.build()?;
anyhow::bail!("failed to create layout");
let blocks = factory.into_blocks()?;
Ok((Arc::new(pool), blocks))
}
#[expect(clippy::type_complexity, clippy::too_many_arguments)]
fn create_block_pool<S: Storage + NixlRegisterableStorage, M: BlockMetadata>(
layout: Arc<dyn NixlLayout<StorageType = S>>,
block_set_idx: usize,
cancellation_token: CancellationToken,
worker_id: WorkerID,
global_registry: GlobalRegistry,
async_runtime: Handle,
pool_metrics: Arc<PoolMetrics>,
event_manager: Option<Arc<dyn EventManager>>,
) -> Result<(BlockPool<S, M>, Vec<Block<S, M>>)> {
let blocks = block::layout_to_blocks::<_, M>(layout, block_set_idx, worker_id)?;
let event_manager = event_manager.unwrap_or_else(|| NullEventManager::new());
let pool = BlockPool::<S, M>::builder()
.cancel_token(cancellation_token)
.global_registry(global_registry)
.async_runtime(async_runtime)
.pool_metrics(pool_metrics)
.event_manager(event_manager)
.build()?;
Ok((pool, blocks))
}
// Block state operations moved to block.rs for better organization and private field access
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
/// The local block factories for the block manager
///
/// This struct will construct the factories in a consistent order and can be
/// used as an intermediate step before creating the block pools.
///
/// This is useful for debugging and for testing.
#[derive(Dissolve)]
pub struct LocalBlockDataFactories {
block_set: NixlBlockSet,
disk_factory: Option<LocalBlockDataFactory<DiskStorage>>,
host_factory: Option<LocalBlockDataFactory<PinnedStorage>>,
device_factory: Option<LocalBlockDataFactory<DeviceStorage>>,
}
impl LocalBlockDataFactories {
/// Construct the local block factories
pub fn new(resources: &mut Resources) -> Result<Self> {
let mut block_set = NixlBlockSet::new(resources.worker_id);
let mut next_block_set_idx = 0;
let layout_builder = resources.layout_builder();
let device_factory = if let Some(config) = resources.config.device_layout.take() {
next_block_set_idx += 1;
tracing::debug!("Constructing device pool.");
let layout = create_layout(
layout_builder.clone(),
config,
resources.nixl_agent.as_ref().as_ref(),
)?;
block_set.add_block_set(next_block_set_idx, layout.serialize()?);
Some(LocalBlockDataFactory::new(
layout,
next_block_set_idx,
resources.worker_id,
))
} else {
None
};
let host_factory = if let Some(config) = resources.config.host_layout.take() {
next_block_set_idx += 1;
tracing::debug!("Constructing host pool.");
let layout = create_layout(
layout_builder.clone(),
config,
resources.nixl_agent.as_ref().as_ref(),
)?;
block_set.add_block_set(next_block_set_idx, layout.serialize()?);
Some(LocalBlockDataFactory::new(
layout,
next_block_set_idx,
resources.worker_id,
))
} else {
None
};
let disk_factory = if let Some(config) = resources.config.disk_layout.take() {
if resources.nixl_agent.is_none() {
tracing::warn!("NIXL is disabled; will not allocate disk blocks.");
None
} else {
next_block_set_idx += 1;
tracing::debug!("Constructing disk pool.");
let layout = create_layout(
layout_builder.clone(),
config,
resources.nixl_agent.as_ref().as_ref(),
)?;
block_set.add_block_set(next_block_set_idx, layout.serialize()?);
Some(LocalBlockDataFactory::new(
layout,
next_block_set_idx,
resources.worker_id,
))
}
} else {
None
};
Ok(Self {
block_set,
disk_factory,
host_factory,
device_factory,
})
}
}
fn create_layout<S: Storage + NixlRegisterableStorage>(
mut builder: LayoutConfigBuilder,
config: KvManagerLayoutConfig<S>,
nixl_agent: Option<&NixlAgent>,
) -> Result<Arc<dyn NixlLayout<StorageType = S>>> {
let layout = builder.num_blocks(config.num_blocks).build()?;
if let Some(_logical) = config.logical {
return Err(anyhow::anyhow!(
"Logical layouts are not supported by the local builder"
));
}
if let Some(storage) = config.storage {
let mut layout = layout.create_layout(config.layout_type, storage)?;
if let Some(nixl_agent) = nixl_agent {
layout.nixl_register(nixl_agent, None)?;
}
return Ok(layout.into());
}
if let Some(allocator) = config.allocator {
let mut layout = layout.allocate_layout(config.layout_type, allocator)?;
if let Some(nixl_agent) = nixl_agent {
layout.nixl_register(nixl_agent, None)?;
}
return Ok(layout.into());
}
anyhow::bail!("failed to create layout");
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
use crate::block_manager::{block::factory::logical::LogicalBlockFactory, storage::StorageType};
/// The local block factories for the block manager
///
/// This struct will construct the factories in a consistent order and can be
/// used as an intermediate step before creating the block pools.
///
/// This is useful for debugging and for testing.
#[derive(Dissolve)]
pub struct LogicalBlockFactories<R: LogicalResources> {
disk_factory: Option<LogicalBlockFactory<DiskStorage, R>>,
host_factory: Option<LogicalBlockFactory<PinnedStorage, R>>,
device_factory: Option<LogicalBlockFactory<DeviceStorage, R>>,
}
impl<R: LogicalResources> LogicalBlockFactories<R> {
/// Construct the local block factories
pub fn new(resources: &mut Resources, logical_resources: R) -> Result<Self> {
let mut next_block_set_idx = 0;
let layout_builder = resources.layout_builder();
let logical_resources = Arc::new(logical_resources);
let device_factory = if let Some(config) = resources.config.device_layout.take() {
next_block_set_idx += 1;
let mut builder = layout_builder.clone();
let config = Arc::new(builder.num_blocks(config.num_blocks).build()?);
let factory = LogicalBlockFactory::new(
config,
next_block_set_idx,
resources.worker_id,
logical_resources.clone(),
StorageType::Device(0),
);
Some(factory)
} else {
None
};
let host_factory = if let Some(config) = resources.config.host_layout.take() {
next_block_set_idx += 1;
let mut builder = layout_builder.clone();
let config = Arc::new(builder.num_blocks(config.num_blocks).build()?);
let factory = LogicalBlockFactory::new(
config,
next_block_set_idx,
resources.worker_id,
logical_resources.clone(),
StorageType::Pinned,
);
Some(factory)
} else {
None
};
let disk_factory = if let Some(config) = resources.config.disk_layout.take() {
next_block_set_idx += 1;
let mut builder = layout_builder.clone();
let config = Arc::new(builder.num_blocks(config.num_blocks).build()?);
let factory = LogicalBlockFactory::new(
config,
next_block_set_idx,
resources.worker_id,
logical_resources.clone(),
StorageType::Disk(0),
);
Some(factory)
} else {
None
};
Ok(Self {
disk_factory,
host_factory,
device_factory,
})
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
impl Resources {
/// Create a new [`Resources`] instance
pub fn new(config: KvBlockManagerConfig) -> Result<Self> {
config
.runtime
.validate()
.context("Validating runtime config")?;
config.model.validate().context("Validating model config")?;
let worker_id = config.runtime.worker_id;
let cancellation_token = config.runtime.cancellation_token.clone();
let global_registry = GlobalRegistry::default();
let metrics = BlockManagerMetrics::new(&config.runtime.metrics_registry)?;
let event_manager = config
.event_manager
.clone()
.unwrap_or_else(|| NullEventManager::new());
// Create a NIXL agent if NIXL is enabled and instantiate requested backends
// TODO: Build a map of NIXL backends to block pools/sets
let mut nixl_backends: HashMap<String, Arc<nixl_sys::Backend>> = HashMap::new();
let nixl_agent = Arc::new(match &config.runtime.nixl {
NixlOptions::Enabled => {
tracing::debug!("Creating NIXL agent");
let agent = NixlAgent::new(&worker_id.to_string())?;
tracing::debug!("Creating NIXL backends");
if let Ok((_, ucx_params)) = agent.get_plugin_params("UCX") {
let backend = agent.create_backend("UCX", &ucx_params)?;
nixl_backends.insert("UCX".to_string(), Arc::new(backend));
} else {
tracing::warn!("No UCX plugin found; will not create UCX backend");
}
if config.disk_layout.is_some() {
if let Ok((_, gds_mt_params)) = agent.get_plugin_params("GDS_MT") {
let backend = agent.create_backend("GDS_MT", &gds_mt_params)?;
nixl_backends.insert("GDS_MT".to_string(), Arc::new(backend));
} else {
tracing::warn!("No GDS_MT plugin found; will not create GDS_MT backend");
}
}
Some(agent)
}
NixlOptions::EnabledWithAgent(agent) => Some(agent.clone()),
NixlOptions::Disabled => None,
});
let async_rt_handle = match &config.runtime.async_runtime {
Some(rt) => rt.handle().clone(),
None => match Handle::try_current() {
Ok(handle) => handle,
Err(e) => anyhow::bail!(e),
},
};
Ok(Self {
worker_id,
cancellation_token,
async_rt_handle,
nixl_agent,
nixl_backends,
global_registry,
event_manager,
metrics,
config,
})
}
/// Create a new [`LayoutConfigBuilder`] with the model configuration
pub fn layout_builder(&self) -> LayoutConfigBuilder {
let mut layout_builder = LayoutConfig::builder();
let model = &self.config.model;
layout_builder
.num_layers(model.num_layers)
.outer_dim(model.outer_dim)
.page_size(model.page_size)
.inner_dim(model.inner_dim)
.dtype_width_bytes(model.dtype_width_bytes);
layout_builder
}
}
......@@ -77,14 +77,15 @@
//! - [`StorageMemset`] - Memory initialization operations
//! - [`StorageAllocator`] - Factory for creating storage instances
pub mod arena;
pub mod cuda;
pub mod disk;
pub mod nixl;
pub mod arena;
pub mod torch;
pub use cuda::*;
pub use disk::*;
use torch::*;
use std::{
alloc::{alloc_zeroed, dealloc, Layout},
......@@ -100,7 +101,7 @@ use thiserror::Error;
pub type StorageResult<T> = std::result::Result<T, StorageError>;
/// Represents the type of storage used for a block
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum StorageType {
/// System memory
System,
......@@ -112,7 +113,7 @@ pub enum StorageType {
Pinned,
/// Disk memory
Disk,
Disk(u64),
/// Remote memory accessible through NIXL
Nixl,
......@@ -193,6 +194,14 @@ pub trait Storage: Debug + Send + Sync + 'static {
unsafe fn as_mut_ptr(&mut self) -> *mut u8;
}
pub trait StorageTypeProvider {
type StorageType: Storage;
fn storage_type_id(&self) -> std::any::TypeId {
std::any::TypeId::of::<Self::StorageType>()
}
}
/// Extension trait for storage types that support memory setting operations
pub trait StorageMemset: Storage {
/// Sets a region of memory to a specific value
......@@ -524,3 +533,41 @@ pub mod tests {
}
}
}
// Comment out Nixl-related code for now
/*
pub trait NixlDescriptor: Storage {
fn as_nixl_descriptor(&self) -> NixlMemoryDescriptor<'_, BlockKind, IsImmutable>;
fn as_nixl_descriptor_mut(&mut self) -> NixlMemoryDescriptor<'_, BlockKind, IsMutable>;
}
impl NixlDescriptor for SystemStorage {
fn as_nixl_descriptor(&self) -> NixlMemoryDescriptor<'_, BlockKind, IsImmutable> {
NixlMemoryDescriptor::new(self.as_ptr() as *const u8, self.size())
}
fn as_nixl_descriptor_mut(&mut self) -> NixlMemoryDescriptor<'_, BlockKind, IsMutable> {
NixlMemoryDescriptor::new_mut(self.as_mut_ptr() as *mut u8, self.size())
}
}
impl NixlDescriptor for PinnedStorage {
fn as_nixl_descriptor(&self) -> NixlMemoryDescriptor<'_, BlockKind, IsImmutable> {
NixlMemoryDescriptor::new(self.as_ptr() as *const u8, self.size())
}
fn as_nixl_descriptor_mut(&mut self) -> NixlMemoryDescriptor<'_, BlockKind, IsMutable> {
NixlMemoryDescriptor::new_mut(self.as_mut_ptr() as *mut u8, self.size())
}
}
impl NixlDescriptor for DeviceStorage {
fn as_nixl_descriptor(&self) -> NixlMemoryDescriptor<'_, BlockKind, IsImmutable> {
NixlMemoryDescriptor::new(self.as_ptr() as *const u8, self.size())
}
fn as_nixl_descriptor_mut(&mut self) -> NixlMemoryDescriptor<'_, BlockKind, IsMutable> {
NixlMemoryDescriptor::new_mut(self.as_mut_ptr() as *mut u8, self.size())
}
}
*/
......@@ -303,6 +303,17 @@ impl StorageAllocator<PinnedStorage> for PinnedAllocator {
}
}
/// An enum indicating the type of device storage.
/// This is needed to ensure ownership of memory is correctly handled.
/// When building a [`DeviceStorage`] from a torch tensor, we need to ensure that
/// the torch tensor is not GCed until the [`DeviceStorage`] is dropped.
/// Because of this, we need to store a reference to the torch tensor in the [`DeviceStorage`]
#[derive(Debug)]
enum DeviceStorageType {
Owned, // Memory that we allocated ourselves.
Torch { _tensor: Arc<dyn TorchTensor> }, // Memory that came from a torch tensor.
}
/// CUDA device memory storage
#[derive(Debug)]
pub struct DeviceStorage {
......@@ -310,6 +321,7 @@ pub struct DeviceStorage {
size: usize,
ctx: Arc<CudaContext>,
handles: RegistrationHandles,
_storage_type: DeviceStorageType,
}
impl Local for DeviceStorage {}
......@@ -326,6 +338,35 @@ impl DeviceStorage {
size,
ctx: ctx.clone(),
handles: RegistrationHandles::new(),
_storage_type: DeviceStorageType::Owned,
})
}
pub fn new_from_torch(
ctx: &Arc<CudaContext>,
tensor: Arc<dyn TorchTensor>,
) -> Result<Self, StorageError> {
let device = tensor.device();
let TorchDevice::Cuda(device_id) = device else {
return Err(StorageError::InvalidConfig("Tensor is not CUDA!".into()));
};
if device_id != ctx.cu_device() as usize {
return Err(StorageError::InvalidConfig(
"Tensor is not on the same device as the context!".into(),
));
}
let data_ptr = tensor.data_ptr();
let size = tensor.size_bytes();
Ok(Self {
ptr: data_ptr,
size,
ctx: ctx.clone(),
handles: RegistrationHandles::new(),
_storage_type: DeviceStorageType::Torch { _tensor: tensor },
})
}
......@@ -366,7 +407,14 @@ impl CudaContextProivder for DeviceStorage {
impl Drop for DeviceStorage {
fn drop(&mut self) {
self.handles.release();
unsafe { cudarc::driver::result::free_sync(self.ptr as _) }.unwrap();
match &self._storage_type {
DeviceStorageType::Owned => {
unsafe { cudarc::driver::result::free_sync(self.ptr as _) }.unwrap()
}
DeviceStorageType::Torch { _tensor } => {
// Do nothing. The torch storage is resposible for cleaning up itself.
}
}
}
}
......@@ -419,3 +467,100 @@ impl StorageAllocator<DeviceStorage> for DeviceAllocator {
DeviceStorage::new(&self.ctx, size)
}
}
#[cfg(all(test, feature = "testing-cuda"))]
mod tests {
use super::*;
#[derive(Debug, Clone)]
struct MockTensor {
device: TorchDevice,
data_ptr: u64,
size_bytes: usize,
}
impl MockTensor {
pub fn new(device: TorchDevice, data_ptr: u64, size_bytes: usize) -> Self {
Self {
device,
data_ptr,
size_bytes,
}
}
}
impl TorchTensor for MockTensor {
fn device(&self) -> TorchDevice {
self.device.clone()
}
fn data_ptr(&self) -> u64 {
self.data_ptr
}
fn size_bytes(&self) -> usize {
self.size_bytes
}
fn shape(&self) -> Vec<usize> {
vec![self.size_bytes]
}
fn stride(&self) -> Vec<usize> {
vec![1]
}
}
#[test]
fn test_device_storage_from_torch_valid_tensor() {
let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
let size_bytes = 1024;
let actual_storage =
std::mem::ManuallyDrop::new(DeviceStorage::new(&ctx, size_bytes).unwrap());
let tensor = MockTensor::new(TorchDevice::Cuda(0), actual_storage.addr(), size_bytes);
let storage = DeviceStorage::new_from_torch(&ctx, Arc::new(tensor)).unwrap();
assert_eq!(storage.size(), size_bytes);
assert_eq!(storage.storage_type(), StorageType::Device(0));
assert_eq!(storage.addr(), actual_storage.addr());
}
#[test]
fn test_device_storage_from_torch_cpu_tensor_fails() {
let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
let size_bytes = 1024;
let actual_storage = DeviceStorage::new(&ctx, size_bytes).unwrap();
let tensor = MockTensor::new(
TorchDevice::Other("cpu".to_string()),
actual_storage.addr(),
size_bytes,
);
let result = DeviceStorage::new_from_torch(&ctx, Arc::new(tensor));
assert!(result.is_err());
if let Err(StorageError::InvalidConfig(msg)) = result {
assert!(msg.contains("Tensor is not CUDA"));
} else {
panic!("Expected InvalidConfig error for CPU tensor");
}
}
#[test]
fn test_device_storage_wrong_device() {
let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
let size_bytes = 1024;
let actual_storage = DeviceStorage::new(&ctx, size_bytes).unwrap();
let tensor = MockTensor::new(TorchDevice::Cuda(1), actual_storage.addr(), size_bytes);
let result = DeviceStorage::new_from_torch(&ctx, Arc::new(tensor));
assert!(result.is_err());
}
}
......@@ -17,16 +17,21 @@ use super::*;
use core::ffi::c_char;
use nix::fcntl::{fallocate, FallocateFlags};
use nix::unistd::unlink;
use std::ffi::CStr;
use std::ffi::CString;
use std::fs::File;
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::path::Path;
const DISK_CACHE_KEY: &str = "DYN_KVBM_DISK_CACHE_DIR";
const DEFAULT_DISK_CACHE_DIR: &str = "/tmp/";
#[derive(Debug)]
pub struct DiskStorage {
file: File,
fd: u64,
file_name: String,
size: usize,
handles: RegistrationHandles,
unlinked: bool,
}
impl Local for DiskStorage {}
......@@ -37,7 +42,17 @@ impl DiskStorage {
// We need to open our file with some special flags that aren't supported by the tempfile crate.
// Instead, we'll use the mkostemp function to create a temporary file with the correct flags.
let template = CString::new("/tmp/dynamo-kvbm-disk-cache-XXXXXX").unwrap();
let specified_dir =
std::env::var(DISK_CACHE_KEY).unwrap_or_else(|_| DEFAULT_DISK_CACHE_DIR.to_string());
let file_path = Path::new(&specified_dir).join("dynamo-kvbm-disk-cache-XXXXXX");
if !file_path.exists() {
std::fs::create_dir_all(file_path.parent().unwrap()).unwrap();
}
tracing::debug!("Allocating disk cache file at {}", file_path.display());
let template = CString::new(file_path.to_str().unwrap()).unwrap();
let mut template_bytes = template.into_bytes_with_nul();
let raw_fd = unsafe {
......@@ -50,45 +65,63 @@ impl DiskStorage {
)
};
let file = unsafe { File::from_raw_fd(raw_fd) };
let file_name = String::from_utf8_lossy(&template_bytes)
.trim_end_matches("\0")
let file_name = CStr::from_bytes_with_nul(template_bytes.as_slice())
.unwrap()
.to_str()
.map_err(|e| {
StorageError::AllocationFailed(format!("Failed to read temp file name: {}", e))
})?
.to_string();
file.set_len(size as u64).map_err(|_| {
StorageError::AllocationFailed("Failed to set temp file size".to_string())
})?;
// File::set_len() only updates the metadata of the file, it does not allocate the underlying storage.
// We need to use fallocate to actually allocate the storage and create the blocks on disk.
fallocate(file.as_raw_fd(), FallocateFlags::empty(), 0, size as i64).map_err(|_| {
StorageError::AllocationFailed("Failed to allocate temp file".to_string())
fallocate(raw_fd, FallocateFlags::empty(), 0, size as i64).map_err(|e| {
StorageError::AllocationFailed(format!("Failed to allocate temp file: {}", e))
})?;
Ok(Self {
file,
fd: raw_fd as u64,
file_name,
size,
handles: RegistrationHandles::new(),
unlinked: false,
})
}
pub fn fd(&self) -> u64 {
self.file.as_raw_fd() as u64
self.fd
}
/// Unlink our temp file.
/// This means that when this process terminates, the file will be automatically deleted by the OS.
/// Unfortunately, GDS requires that files we try to register must be linked.
/// To get around this, we unlink the file only after we've registered it with NIXL.
pub fn unlink(&mut self) -> Result<(), StorageError> {
if self.unlinked {
return Ok(());
}
self.unlinked = true;
unlink(self.file_name.as_str()).map_err(|e| {
StorageError::AllocationFailed(format!("Failed to unlink temp file: {}", e))
})
}
pub fn unlinked(&self) -> bool {
self.unlinked
}
}
impl Drop for DiskStorage {
// TODO: How robust is this actually?
fn drop(&mut self) {
self.handles.release();
std::fs::remove_file(self.file_name.clone()).unwrap();
let _ = self.unlink();
}
}
impl Storage for DiskStorage {
fn storage_type(&self) -> StorageType {
StorageType::Disk
StorageType::Disk(self.fd())
}
fn addr(&self) -> u64 {
......
......@@ -156,7 +156,7 @@ impl StorageType {
StorageType::Device(_) => MemType::Vram,
StorageType::Nixl => MemType::Unknown,
StorageType::Null => MemType::Unknown,
StorageType::Disk => MemType::File,
StorageType::Disk(_) => MemType::File,
}
}
}
......@@ -169,6 +169,15 @@ impl RegistationHandle for NixlRegistrationHandle {
}
}
fn handle_nixl_register<S: NixlRegisterableStorage>(
storage: &mut S,
agent: &NixlAgent,
opt_args: Option<&OptArgs>,
) -> Result<(), StorageError> {
let handle = Box::new(agent.register_memory(storage, opt_args)?);
storage.register("nixl", handle)
}
/// Extension to the [`RegisterableStorage`] trait for NIXL-compatible storage.
pub trait NixlRegisterableStorage: RegisterableStorage + NixlDescriptor + Sized {
/// Register the storage with the NIXL agent.
......@@ -177,9 +186,7 @@ pub trait NixlRegisterableStorage: RegisterableStorage + NixlDescriptor + Sized
agent: &NixlAgent,
opt_args: Option<&OptArgs>,
) -> Result<(), StorageError> {
let handle = Box::new(agent.register_memory(self, opt_args)?);
// Assuming PinnedStorage has `handles: RegistrationHandles`
self.register("nixl", handle)
handle_nixl_register(self, agent, opt_args)
}
/// Check if the storage is registered with the NIXL agent.
......@@ -379,7 +386,23 @@ impl NixlDescriptor for DeviceStorage {
}
impl NixlAccessible for DiskStorage {}
impl NixlRegisterableStorage for DiskStorage {}
impl NixlRegisterableStorage for DiskStorage {
fn nixl_register(
&mut self,
agent: &NixlAgent,
opt_args: Option<&OptArgs>,
) -> Result<(), StorageError> {
if self.unlinked() {
return Err(StorageError::AllocationFailed(
"Disk storage has already been unlinked. GDS registration will fail.".to_string(),
));
}
handle_nixl_register(self, agent, opt_args)?;
self.unlink()?;
Ok(())
}
}
impl MemoryRegion for DiskStorage {
unsafe fn as_ptr(&self) -> *const u8 {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TorchDevice {
Cuda(usize),
Other(String),
}
pub trait TorchTensor: std::fmt::Debug + Send + Sync {
fn device(&self) -> TorchDevice;
fn data_ptr(&self) -> u64;
fn size_bytes(&self) -> usize;
fn shape(&self) -> Vec<usize>;
fn stride(&self) -> Vec<usize>;
}
......@@ -22,6 +22,7 @@ where
}
/// A generic recorder for events that streams directly to a JSONL file
#[derive(Debug)]
pub struct Recorder<T> {
/// A sender for events that can be cloned and shared with producers
event_tx: mpsc::Sender<T>,
......
......@@ -87,6 +87,12 @@ impl From<&[Token]> for Tokens {
}
}
impl From<Vec<usize>> for Tokens {
fn from(tokens: Vec<usize>) -> Self {
Tokens(tokens.into_iter().map(|t| t as u32).collect())
}
}
impl From<Vec<i32>> for Tokens {
/// Converts `Vec<i32>` to `Tokens`, casting each `i32` to `u32`.
fn from(tokens: Vec<i32>) -> Self {
......@@ -460,6 +466,11 @@ impl TokenBlock {
pub fn parent_sequence_hash(&self) -> Option<SequenceHash> {
self.parent_sequence_hash
}
/// Returns the number of tokens in the block.
pub fn block_size(&self) -> usize {
self.tokens.0.len()
}
}
/// Represents a sequence of tokens, segmented into fixed-size, hashed blocks.
......@@ -481,6 +492,7 @@ pub struct TokenBlockSequence {
blocks: Vec<TokenBlock>,
current_block: PartialTokenBlock,
salt_hash: SaltHash,
block_size: usize,
}
impl TokenBlockSequence {
......@@ -507,6 +519,7 @@ impl TokenBlockSequence {
blocks,
current_block,
salt_hash,
block_size: block_size as usize,
}
}
......@@ -545,14 +558,12 @@ impl TokenBlockSequence {
tokens_to_append = self.current_block.push_tokens(available_tokens);
// Check if the current block *became* full after pushing tokens
if self.current_block.remaining() == 0 && !tokens_to_append.is_empty() {
if self.current_block.remaining() == 0 {
// If it became full AND there are still more tokens to append,
// commit it now so the next loop iteration starts with a fresh block.
let new_block = self.current_block.commit()?;
self.blocks.push(new_block);
}
// If it became full and there are NO more tokens, the loop will exit,
// and the block remains partial but full, ready for the next append/commit.
}
let end_block_index = self.blocks.len();
......@@ -708,6 +719,13 @@ impl TokenBlockSequence {
self.truncate(len)
}
/// Resets the sequence to the initial state.
pub fn reset(&mut self) {
self.blocks.clear();
self.current_block =
PartialTokenBlock::create_sequence_root(self.block_size as u32, self.salt_hash);
}
/// Removes the last token from the sequence and returns it, or [`None`] if it is empty.
///
/// This operation is analogous to `Vec::pop`.
......@@ -779,6 +797,11 @@ impl TokenBlockSequence {
(self.blocks, self.current_block)
}
/// Returns the block size used for this sequence.
pub fn block_size(&self) -> usize {
self.block_size
}
/// Returns the [`SaltHash`] used for this sequence.
pub fn salt_hash(&self) -> SaltHash {
self.salt_hash
......@@ -791,6 +814,38 @@ impl TokenBlockSequence {
(self.blocks.len() * block_size) + self.current_block.len()
}
/// Extract the token with the range
pub fn tokens_at(&self, range: Range<usize>) -> Tokens {
let total = self.total_tokens();
// Validate range - return empty tokens for invalid ranges
if range.start > range.end || range.end > total {
return Tokens::default();
}
// Handle empty range
if range.is_empty() {
return Tokens::default();
}
let mut result = Vec::with_capacity(range.len());
for i in range {
if i < self.blocks.len() * self.block_size {
// Token is in a completed block
let block_index = i / self.block_size;
let token_index = i % self.block_size;
result.push(self.blocks[block_index].tokens()[token_index]);
} else {
// Token is in the current partial block
let current_block_index = i - (self.blocks.len() * self.block_size);
result.push(self.current_block.tokens()[current_block_index]);
}
}
Tokens::from(result)
}
/// Splits a [`Tokens`] object into a vector of completed blocks and a final partial block.
///
/// This is primarily used internally by [`TokenBlockSequence::new`] but can be used externally.
......@@ -857,6 +912,7 @@ impl TokenBlockSequence {
blocks,
current_block,
salt_hash,
block_size: block_size as usize,
}
}
}
......@@ -1109,6 +1165,15 @@ mod tests {
Some(SEQ_HASH_5_8)
);
// Test tokens_at across blocks and partial block
assert_eq!(seq_multi.tokens_at(0..4).as_ref(), &[1, 2, 3, 4]); // First complete block
assert_eq!(seq_multi.tokens_at(4..8).as_ref(), &[5, 6, 7, 8]); // Second complete block
assert_eq!(seq_multi.tokens_at(8..9).as_ref(), &[9]); // Current partial block
assert_eq!(seq_multi.tokens_at(2..6).as_ref(), &[3, 4, 5, 6]); // Spanning blocks
assert_eq!(seq_multi.tokens_at(6..9).as_ref(), &[7, 8, 9]); // Spanning to partial
assert_eq!(seq_multi.tokens_at(5..5).as_ref(), &[0u32; 0]); // Empty range
assert_eq!(seq_multi.tokens_at(10..15).as_ref(), &[0u32; 0]); // Out of bounds
// No salt hash
let seq_no_salt = create_test_sequence(&[1, 2, 3, 4, 5], 4, None);
assert_eq!(seq_no_salt.salt_hash(), 0);
......@@ -1142,22 +1207,22 @@ mod tests {
assert_eq!(sequence.current_block().tokens.as_ref(), &[9, 10, 11]);
// Append token 12 - should complete block 2 (index 2)
// This will also commit block 2
let completed_idx = sequence.append(12).unwrap();
assert_eq!(completed_idx, None); // Lazy commit: extend returns None
assert_eq!(sequence.blocks().len(), 2); // Block 2 not added yet
assert_eq!(sequence.current_block.tokens.as_ref(), &[9, 10, 11, 12]); // Current block is now full
assert_eq!(sequence.current_block.remaining(), 0);
assert_eq!(completed_idx, Some(2));
assert_eq!(sequence.blocks().len(), 3);
assert_eq!(sequence.current_block.tokens.as_ref(), &[0u32; 0]);
assert_eq!(sequence.current_block.remaining(), 4);
assert_eq!(
sequence.current_block().parent_sequence_hash,
Some(SEQ_HASH_5_8)
Some(SEQ_HASH_9_12)
); // Still linked to block 1
// Append token 13 - should not complete a block
// NOW appending 13 should first commit block 2, then add 13 to the new current
let completed_idx_13 = sequence.append(13).unwrap();
assert_eq!(completed_idx_13, Some(2)); // Block 2 (index 2) was completed by this append
assert_eq!(sequence.blocks.len(), 3); // Now 3 blocks committed
assert_eq!(sequence.blocks[2].tokens().as_ref(), &[9, 10, 11, 12]); // Verify committed block 2
assert_eq!(completed_idx_13, None);
assert_eq!(sequence.blocks().len(), 3);
assert_eq!(sequence.blocks[2].tokens().as_ref(), &[9, 10, 11, 12]);
assert_eq!(sequence.blocks[2].sequence_hash(), SEQ_HASH_9_12);
assert_eq!(sequence.current_block.tokens.as_ref(), &[13]); // New current block has 13
assert_eq!(sequence.current_block.remaining(), 3);
......@@ -1180,16 +1245,17 @@ mod tests {
assert_eq!(seq1.blocks.len(), 0);
assert_eq!(seq1.current_block.tokens.as_ref(), &[1, 2]);
assert_eq!(seq1.current_block.remaining(), 2);
assert_eq!(seq1.current_block.parent_sequence_hash, None); // Still the root block
// Case 2: Extend exactly block size
let mut seq2 = create_test_sequence(&[], block_size, salt_hash);
let tokens2 = Tokens::from(vec![1, 2, 3, 4]);
let completed2 = seq2.extend(tokens2).unwrap();
assert_eq!(completed2, None); // Block is full but not committed yet
assert_eq!(seq2.blocks.len(), 0); // No blocks committed
assert_eq!(seq2.current_block.tokens.as_ref(), &[1, 2, 3, 4]); // Current block is full
assert_eq!(seq2.current_block.remaining(), 0);
assert_eq!(seq2.current_block.parent_sequence_hash, None); // Still the root block
assert_eq!(completed2, Some(0..1));
assert_eq!(seq2.blocks.len(), 1);
assert_eq!(seq2.current_block.tokens.as_ref(), &[0u32; 0]); // Current block is empty
assert_eq!(seq2.current_block.remaining(), 4);
assert_eq!(seq2.current_block.parent_sequence_hash, Some(SEQ_HASH_1_4)); // Still the root block
// Case 3: Extend more than block size, less than two blocks
let mut seq3 = create_test_sequence(&[], block_size, salt_hash);
......@@ -1206,13 +1272,13 @@ mod tests {
let mut seq4 = create_test_sequence(&[], block_size, salt_hash);
let tokens4 = Tokens::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
let completed4 = seq4.extend(tokens4).unwrap();
assert_eq!(completed4, Some(0..1)); // Only block 0 is committed
assert_eq!(seq4.blocks.len(), 1); // Only 1 block committed
assert_eq!(seq4.current_block.tokens.as_ref(), &[5, 6, 7, 8]); // Current block holds the second block's tokens
assert_eq!(seq4.current_block.remaining(), 0); // Current block is full
assert_eq!(completed4, Some(0..2)); // Only block 0 is committed
assert_eq!(seq4.blocks.len(), 2); // Only 1 block committed
assert_eq!(seq4.current_block.tokens.as_ref(), &[0u32; 0]);
assert_eq!(seq4.current_block.remaining(), 4);
assert_eq!(seq4.blocks[0].tokens().as_ref(), &[1, 2, 3, 4]);
assert_eq!(seq4.blocks[0].sequence_hash(), SEQ_HASH_1_4);
assert_eq!(seq4.current_block.parent_sequence_hash, Some(SEQ_HASH_1_4)); // Parent is the first block
assert_eq!(seq4.current_block.parent_sequence_hash, Some(SEQ_HASH_5_8)); // Parent is the first block
// Case 5: Extend multiple times, completing blocks across calls
let mut seq5 = create_test_sequence(&[], block_size, salt_hash);
......@@ -1252,12 +1318,18 @@ mod tests {
let mut seq7 = create_test_sequence(&[1, 2], block_size, salt_hash);
let tokens7 = Tokens::from(vec![3, 4]);
let completed7 = seq7.extend(tokens7).unwrap();
assert_eq!(completed7, None); // Block is full but not committed yet
assert_eq!(seq7.blocks.len(), 0);
assert_eq!(seq7.current_block.tokens.as_ref(), &[1, 2, 3, 4]); // Current block is full
assert_eq!(seq7.current_block.remaining(), 0);
assert_eq!(completed7, Some(0..1)); // Block is full but not committed yet
assert_eq!(seq7.blocks.len(), 1);
assert_eq!(seq7.current_block.tokens.as_ref(), &[0u32; 0]); // Current block is full
assert_eq!(seq7.current_block.remaining(), 4);
assert_eq!(seq7.total_tokens(), 4);
assert_eq!(seq7.current_block.parent_sequence_hash, None); // Still the root block
assert_eq!(seq7.current_block.parent_sequence_hash, Some(SEQ_HASH_1_4)); // Still the root block
// Test tokens_at extraction
assert_eq!(seq7.tokens_at(0..2).as_ref(), &[1, 2]);
assert_eq!(seq7.tokens_at(1..3).as_ref(), &[2, 3]);
assert_eq!(seq7.tokens_at(0..4).as_ref(), &[1, 2, 3, 4]);
assert_eq!(seq7.tokens_at(2..2).as_ref(), &[0u32; 0]); // Empty range
}
#[test]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment