"ml/vscode:/vscode.git/clone" did not exist on "73d6a82cce18f84ff5c67148783224cf25b30b32"
Unverified Commit 07cfc3a1 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: kvbm + connector (#2258)


Signed-off-by: default avatarRyan Olson <rolson@nvidia.com>
Co-authored-by: default avatarOlga Andreeva <oandreeva@nvidia.com>
Co-authored-by: default avatarZiqi Fan <ziqif@nvidia.com>
Co-authored-by: default avatarJohn Thompson <jothomson@nvidia.com>
Co-authored-by: default avatarRichard Huo <rihuo@nvidia.com>
Co-authored-by: default avatarZicheng Ma <zichengm@nvidia.com>
parent bf5862a1
...@@ -38,21 +38,24 @@ ...@@ -38,21 +38,24 @@
//! 3. A worker thread (consuming this bounded channel and enforcing rate limiting) awaits the incoming transfers. //! 3. A worker thread (consuming this bounded channel and enforcing rate limiting) awaits the incoming transfers.
//! 4. After a transfer is complete, the worker thread registers the blocks with the target pool, and returns the registered blocks to the caller. //! 4. After a transfer is complete, the worker thread registers the blocks with the target pool, and returns the registered blocks to the caller.
use nixl_sys::NixlDescriptor;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::runtime::Handle; use tokio::runtime::Handle;
use tokio::sync::mpsc; use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::block_manager::block::{ use crate::block_manager::block::{
transfer::{WriteTo, WriteToStrategy}, locality::LocalityProvider,
BlockError, BlockExt, BlockMetadata, BlockState, MutableBlock, ReadableBlock, TransferContext, transfer::{TransferContext, WriteTo, WriteToStrategy},
WritableBlock, BlockDataProvider, BlockDataProviderMut, BlockError, BlockMetadata, BlockState, ImmutableBlock,
MutableBlock, ReadableBlock, WritableBlock,
}; };
use crate::block_manager::pool::BlockPoolError; use crate::block_manager::metrics::PoolMetrics;
use crate::block_manager::pool::{BlockPool, BlockPoolError};
use crate::block_manager::storage::{Local, Storage}; use crate::block_manager::storage::{Local, Storage};
use crate::block_manager::BlockPool;
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
...@@ -62,26 +65,33 @@ use super::BlockResult; ...@@ -62,26 +65,33 @@ use super::BlockResult;
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle; use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
const BLOCKS_BW_MIN_PUBLISH_INTERVAL_MS: u64 = 50;
/// Manage a set of pending transfers. /// Manage a set of pending transfers.
pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMetadata> { pub struct PendingTransfer<
Source: Storage,
Target: Storage,
Locality: LocalityProvider,
Metadata: BlockMetadata,
> {
/// The block being copied from. /// The block being copied from.
sources: Vec<Arc<MutableBlock<Source, Metadata>>>, sources: Vec<ImmutableBlock<Source, Locality, Metadata>>,
/// The block being copied to. /// The block being copied to.
targets: Vec<MutableBlock<Target, Metadata>>, targets: Vec<MutableBlock<Target, Locality, Metadata>>,
/// The oneshot sender that optionally returns the registered blocks once the transfer is complete. /// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>, completion_indicator: Option<oneshot::Sender<BlockResult<Target, Locality, Metadata>>>,
/// The target pool that will receive the registered block. /// The target pool that will receive the registered block.
target_pool: Arc<BlockPool<Target, Metadata>>, target_pool: Arc<dyn BlockPool<Target, Locality, Metadata>>,
} }
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> impl<Source: Storage, Target: Storage, Locality: LocalityProvider, Metadata: BlockMetadata>
PendingTransfer<Source, Target, Metadata> PendingTransfer<Source, Target, Locality, Metadata>
{ {
pub fn new( pub fn new(
sources: Vec<Arc<MutableBlock<Source, Metadata>>>, sources: Vec<ImmutableBlock<Source, Locality, Metadata>>,
targets: Vec<MutableBlock<Target, Metadata>>, targets: Vec<MutableBlock<Target, Locality, Metadata>>,
completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>, completion_indicator: Option<oneshot::Sender<BlockResult<Target, Locality, Metadata>>>,
target_pool: Arc<BlockPool<Target, Metadata>>, target_pool: Arc<dyn BlockPool<Target, Locality, Metadata>>,
) -> Self { ) -> Self {
assert_eq!(sources.len(), targets.len()); assert_eq!(sources.len(), targets.len());
Self { Self {
...@@ -92,7 +102,7 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> ...@@ -92,7 +102,7 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
} }
} }
fn handle_complete(self) -> Result<()> { async fn handle_complete(self) -> Result<()> {
let Self { let Self {
sources, sources,
mut targets, mut targets,
...@@ -105,7 +115,9 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> ...@@ -105,7 +115,9 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
transfer_metadata(source, target)?; transfer_metadata(source, target)?;
} }
let blocks = target_pool.register_blocks_blocking(targets)?; let blocks = target_pool.register_blocks(targets).await?;
tracing::debug!("Transfer complete. Registered {} blocks.", blocks.len());
if let Some(completion_indicator) = completion_indicator { if let Some(completion_indicator) = completion_indicator {
completion_indicator completion_indicator
...@@ -117,9 +129,14 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> ...@@ -117,9 +129,14 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
} }
} }
fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>( fn transfer_metadata<
source: &Arc<MutableBlock<Source, Metadata>>, Source: Storage,
target: &mut MutableBlock<Target, Metadata>, Target: Storage,
Locality: LocalityProvider,
Metadata: BlockMetadata,
>(
source: &ImmutableBlock<Source, Locality, Metadata>,
target: &mut MutableBlock<Target, Locality, Metadata>,
) -> Result<()> { ) -> Result<()> {
// Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail. // Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail.
if let BlockState::Registered(reg_handle, _) = source.state() { if let BlockState::Registered(reg_handle, _) = source.state() {
...@@ -139,136 +156,118 @@ fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>( ...@@ -139,136 +156,118 @@ fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>(
} }
#[async_trait] #[async_trait]
pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata>: pub trait TransferManager<
Send + Sync Source: Storage,
Target: Storage,
Locality: LocalityProvider,
Metadata: BlockMetadata,
>: Send + Sync
{ {
/// Begin a transfer. Blocks if the pending queue is full. /// Begin a transfer. Blocks if the pending queue is full.
async fn enqueue_transfer( async fn enqueue_transfer(
&self, &self,
pending_transfer: PendingTransfer<Source, Target, Metadata>, pending_transfer: PendingTransfer<Source, Target, Locality, Metadata>,
) -> Result<()>; ) -> Result<()>;
} }
pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata> { struct TransferCompletionManager<
pending_transfer_q: mpsc::Sender<( Source: Storage,
PendingTransfer<Source, Target, Metadata>, Target: Storage,
tokio::sync::oneshot::Receiver<()>, Locality: LocalityProvider,
)>, Metadata: BlockMetadata,
transfer_ctx: Arc<TransferContext>, > {
pool_metrics: Arc<PoolMetrics>,
transfer_type: String,
last_publish_time: Option<Instant>,
transfer_start: Instant,
num_blocks_transferred: usize,
_phantom: PhantomData<(Source, Target, Locality, Metadata)>,
} }
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> impl<Source: Storage, Target: Storage, Locality: LocalityProvider, Metadata: BlockMetadata>
CudaTransferManager<Source, Target, Metadata> TransferCompletionManager<Source, Target, Locality, Metadata>
{ {
pub fn new( pub fn new(pool_metrics: Arc<PoolMetrics>, transfer_type: String) -> Self {
transfer_ctx: Arc<TransferContext>, Self {
max_concurrent_transfers: usize, pool_metrics,
runtime: &Handle, transfer_type,
cancellation_token: CancellationToken, last_publish_time: None,
) -> Result<Self> { transfer_start: Instant::now(),
let (tx, mut rx) = mpsc::channel::<( num_blocks_transferred: 0,
PendingTransfer<Source, Target, Metadata>, _phantom: PhantomData,
tokio::sync::oneshot::Receiver<()>, }
)>(max_concurrent_transfers); }
CriticalTaskExecutionHandle::new_with_runtime( pub async fn handle_complete(
move |cancel_token| async move { &mut self,
loop { pending_transfer: PendingTransfer<Source, Target, Locality, Metadata>,
tokio::select! { ) -> Result<()> {
Some((pending_transfer, notify)) = rx.recv() => { self.num_blocks_transferred += pending_transfer.sources.len();
// Wait for the event.
notify.await.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
// Only finalize the transfer after the event is signaled.
match pending_transfer.handle_complete() {
Ok(_) => {}
Err(e) => {
// The only case where this can fail is if the progress engine is being shutdown.
// This is not a problem, so we can just ignore it.
tracing::warn!("Error handling transfer completion: {:?}", e);
}
}
}
_ = cancel_token.cancelled() => { let should_publish = self.last_publish_time.is_none_or(|last_publish_time| {
return Ok(()); last_publish_time.elapsed() > Duration::from_millis(BLOCKS_BW_MIN_PUBLISH_INTERVAL_MS)
} });
}
}
},
cancellation_token.clone(),
"Cuda Transfer Manager",
runtime,
)?
.detach();
Ok(Self { if should_publish {
pending_transfer_q: tx, self.last_publish_time = Some(Instant::now());
transfer_ctx, let duration = self.transfer_start.elapsed();
}) let blocks_per_sec = self.num_blocks_transferred as f64 / duration.as_secs_f64();
}
}
#[async_trait] self.pool_metrics
impl<Source, Target, Metadata> TransferManager<Source, Target, Metadata> .gauge(self.transfer_type.as_str())
for CudaTransferManager<Source, Target, Metadata> .set(blocks_per_sec as i64);
where }
Source: Storage,
Target: Storage, match pending_transfer.handle_complete().await {
Metadata: BlockMetadata, Ok(_) => {}
// Check that the source block is readable, local, and writable to the target block. Err(e) => {
MutableBlock<Source, Metadata>: ReadableBlock<StorageType = Source> // The only case where this can fail is if the progress engine is being shutdown.
+ Local // This is not a problem, so we can just ignore it.
+ WriteToStrategy<MutableBlock<Target, Metadata>>, tracing::warn!("Error handling transfer completion: {:?}", e);
// Check that the target block is writable. }
MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>, }
{
async fn enqueue_transfer(
&self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> {
let notify = pending_transfer
.sources
.write_to(
&mut pending_transfer.targets,
true,
self.transfer_ctx.clone(),
)?
.ok_or_else(|| {
anyhow::anyhow!(
"write_to returned None when notify was true. This should never happen!"
)
})?;
// Send the pending transfer and event to the worker thread.
// If the queue is full, we block the worker until space becomes available.
self.pending_transfer_q
.send((pending_transfer, notify))
.await?;
Ok(()) Ok(())
} }
} }
pub struct DiskTransferManager { type TransferFuture<Source, Target, Locality, Metadata> = Pin<
futures_tx: mpsc::Sender<Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync>>>, Box<
dyn std::future::Future<Output = PendingTransfer<Source, Target, Locality, Metadata>>
+ Send
+ Sync,
>,
>;
pub struct LocalTransferManager<
Source: Storage,
Target: Storage,
Locality: LocalityProvider,
Metadata: BlockMetadata,
> {
futures_tx: mpsc::Sender<TransferFuture<Source, Target, Locality, Metadata>>,
transfer_ctx: Arc<TransferContext>, transfer_ctx: Arc<TransferContext>,
} }
impl DiskTransferManager { impl<Source: Storage, Target: Storage, Locality: LocalityProvider, Metadata: BlockMetadata>
LocalTransferManager<Source, Target, Locality, Metadata>
{
pub fn new( pub fn new(
transfer_ctx: Arc<TransferContext>, transfer_ctx: Arc<TransferContext>,
max_concurrent_transfers: usize, max_concurrent_transfers: usize,
runtime: &Handle, runtime: &Handle,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
pool_metrics: Arc<PoolMetrics>,
transfer_type: String,
) -> Result<Self> { ) -> Result<Self> {
let (futures_tx, mut futures_rx) = mpsc::channel(1); let (futures_tx, mut futures_rx) = mpsc::channel(1);
let mut completion_manager =
TransferCompletionManager::new(pool_metrics.clone(), transfer_type.clone());
CriticalTaskExecutionHandle::new_with_runtime( CriticalTaskExecutionHandle::new_with_runtime(
move |cancel_token| async move { move |cancel_token| async move {
// Keep track of our pending transfers. let mut pending_transfers: FuturesUnordered<TransferFuture<Source, Target, Locality, Metadata>> = FuturesUnordered::new();
// Consume the futures as they complete, while also receiving new ones.
let mut pending_transfers = FuturesUnordered::new();
loop { loop {
tokio::select! { tokio::select! {
...@@ -279,19 +278,23 @@ impl DiskTransferManager { ...@@ -279,19 +278,23 @@ impl DiskTransferManager {
Some(future) = futures_rx.recv() => { Some(future) = futures_rx.recv() => {
// If we're at max size, block the worker thread on the next() call until we have capacity. // If we're at max size, block the worker thread on the next() call until we have capacity.
while pending_transfers.len() >= max_concurrent_transfers { while pending_transfers.len() >= max_concurrent_transfers {
pending_transfers.next().await; if let Some(pending_transfer) = pending_transfers.next().await {
completion_manager.handle_complete(pending_transfer).await?;
} else {
break;
}
} }
// Once we have capacity, push the new future onto the queue.
pending_transfers.push(future); pending_transfers.push(future);
} }
Some(_) = pending_transfers.next(), if !pending_transfers.is_empty() => { Some(pending_transfer) = pending_transfers.next(), if !pending_transfers.is_empty() => {
// A transfer completed, just continue to process more completion_manager.handle_complete(pending_transfer).await?;
} }
} }
} }
}, },
cancellation_token.clone(), cancellation_token.clone(),
"Disk Transfer Manager", "Local Transfer Manager",
runtime, runtime,
)? )?
.detach(); .detach();
...@@ -304,45 +307,34 @@ impl DiskTransferManager { ...@@ -304,45 +307,34 @@ impl DiskTransferManager {
} }
#[async_trait] #[async_trait]
impl<Source, Target, Metadata> TransferManager<Source, Target, Metadata> for DiskTransferManager impl<Source, Target, Locality, Metadata> TransferManager<Source, Target, Locality, Metadata>
for LocalTransferManager<Source, Target, Locality, Metadata>
where where
Source: Storage, Source: Storage + NixlDescriptor,
Target: Storage, Target: Storage + NixlDescriptor,
Locality: LocalityProvider,
Metadata: BlockMetadata, Metadata: BlockMetadata,
// Check that the source block is readable, local, and writable to the target block. // Check that the source block is readable, local, and writable to the target block.
MutableBlock<Source, Metadata>: ReadableBlock<StorageType = Source> ImmutableBlock<Source, Locality, Metadata>: ReadableBlock<StorageType = Source>
+ Local + Local
+ WriteToStrategy<MutableBlock<Target, Metadata>>, + WriteToStrategy<MutableBlock<Target, Locality, Metadata>>,
// Check that the target block is writable. // Check that the target block is writable.
MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>, MutableBlock<Target, Locality, Metadata>: WritableBlock<StorageType = Target>,
// Check that the source and target blocks have the same locality.
ImmutableBlock<Source, Locality, Metadata>: BlockDataProvider<Locality = Locality>,
MutableBlock<Target, Locality, Metadata>: BlockDataProviderMut<Locality = Locality>,
{ {
async fn enqueue_transfer( async fn enqueue_transfer(
&self, &self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>, mut pending_transfer: PendingTransfer<Source, Target, Locality, Metadata>,
) -> Result<()> { ) -> Result<()> {
let notify = pending_transfer let notify = pending_transfer
.sources .sources
.write_to( .write_to(&mut pending_transfer.targets, self.transfer_ctx.clone())?;
&mut pending_transfer.targets,
true,
self.transfer_ctx.clone(),
)?
.ok_or_else(|| {
anyhow::anyhow!(
"write_to returned None when notify was true. This should never happen!"
)
})?;
let completion_future = async move { let completion_future = async move {
let _ = notify.await; let _ = notify.await;
match pending_transfer.handle_complete() { pending_transfer
Ok(_) => {}
Err(e) => {
// The only case where this can fail is if the progress engine is being shutdown.
// This is not a problem, so we can just ignore it.
tracing::warn!("Error handling transfer completion: {:?}", e);
}
}
}; };
// Futures_(tx/rx) has a capacity of 1. If the queue worker has received another future and is awaiting next() due to a full `FuturesUnordered`, // Futures_(tx/rx) has a capacity of 1. If the queue worker has received another future and is awaiting next() due to a full `FuturesUnordered`,
...@@ -354,26 +346,29 @@ where ...@@ -354,26 +346,29 @@ where
} }
/// A transfer manager that enforces a max batch size for transfers. /// A transfer manager that enforces a max batch size for transfers.
pub struct TransferBatcher<Source, Target, Metadata, Manager> pub struct TransferBatcher<Source, Target, Locality, Metadata, Manager>
where where
Source: Storage, Source: Storage,
Target: Storage, Target: Storage,
Locality: LocalityProvider,
Metadata: BlockMetadata, Metadata: BlockMetadata,
Manager: TransferManager<Source, Target, Metadata>, Manager: TransferManager<Source, Target, Locality, Metadata>,
{ {
transfer_manager: Manager, transfer_manager: Manager,
max_transfer_batch_size: usize, max_transfer_batch_size: usize,
runtime: Handle, runtime: Handle,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
_phantom: PhantomData<(Source, Target, Metadata)>, _phantom: PhantomData<(Source, Target, Locality, Metadata)>,
} }
impl<Source, Target, Metadata, Manager> TransferBatcher<Source, Target, Metadata, Manager> impl<Source, Target, Locality, Metadata, Manager>
TransferBatcher<Source, Target, Locality, Metadata, Manager>
where where
Source: Storage, Source: Storage,
Target: Storage, Target: Storage,
Metadata: BlockMetadata, Locality: LocalityProvider + 'static,
Manager: TransferManager<Source, Target, Metadata>, Metadata: BlockMetadata + 'static,
Manager: TransferManager<Source, Target, Locality, Metadata> + 'static,
{ {
pub fn new( pub fn new(
transfer_manager: Manager, transfer_manager: Manager,
...@@ -392,17 +387,19 @@ where ...@@ -392,17 +387,19 @@ where
} }
#[async_trait] #[async_trait]
impl<Source, Target, Metadata, Manager> TransferManager<Source, Target, Metadata> impl<Source, Target, Locality, Metadata, Manager>
for TransferBatcher<Source, Target, Metadata, Manager> TransferManager<Source, Target, Locality, Metadata>
for TransferBatcher<Source, Target, Locality, Metadata, Manager>
where where
Source: Storage, Source: Storage + 'static,
Target: Storage, Target: Storage + 'static,
Locality: LocalityProvider + 'static,
Metadata: BlockMetadata, Metadata: BlockMetadata,
Manager: TransferManager<Source, Target, Metadata>, Manager: TransferManager<Source, Target, Locality, Metadata>,
{ {
async fn enqueue_transfer( async fn enqueue_transfer(
&self, &self,
pending_transfer: PendingTransfer<Source, Target, Metadata>, pending_transfer: PendingTransfer<Source, Target, Locality, Metadata>,
) -> Result<()> { ) -> Result<()> {
// If it's smaller than the max batch size, just enqueue it. // If it's smaller than the max batch size, just enqueue it.
if pending_transfer.sources.len() < self.max_transfer_batch_size { if pending_transfer.sources.len() < self.max_transfer_batch_size {
...@@ -462,7 +459,7 @@ where ...@@ -462,7 +459,7 @@ where
Ok(result) => result, Ok(result) => result,
Err(e) => { Err(e) => {
tracing::error!("Error receiving transfer results: {:?}", e); tracing::error!("Error receiving transfer results: {:?}", e);
completion_indicator.send(Err(e)).unwrap(); let _ = completion_indicator.send(Err(e));
return Ok(()); return Ok(());
} }
}; };
...@@ -472,7 +469,7 @@ where ...@@ -472,7 +469,7 @@ where
} }
// Send the final results to the top-level completion indicator. // Send the final results to the top-level completion indicator.
completion_indicator.send(Ok(results))?; let _ = completion_indicator.send(Ok(results));
Ok(()) Ok(())
}, },
......
...@@ -15,8 +15,11 @@ ...@@ -15,8 +15,11 @@
use std::cmp::Ordering; use std::cmp::Ordering;
use std::sync::Weak; use std::sync::Weak;
use tokio::sync::oneshot;
use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock}; use crate::block_manager::block::{
locality::LocalityProvider, BlockMetadata, ImmutableBlock, MutableBlock,
};
use crate::block_manager::pool::BlockPoolError; use crate::block_manager::pool::BlockPoolError;
use crate::block_manager::storage::Storage; use crate::block_manager::storage::Storage;
...@@ -46,53 +49,65 @@ impl Ord for OffloadRequestKey { ...@@ -46,53 +49,65 @@ impl Ord for OffloadRequestKey {
/// Data needed to offload a block. /// Data needed to offload a block.
/// While the block is in the offload queue, we hold a weak reference to it. /// While the block is in the offload queue, we hold a weak reference to it.
/// This way, we don't prevent the block from being reused if needed. /// This way, we don't prevent the block from being reused if needed.
pub struct OffloadRequest<S: Storage, M: BlockMetadata> { pub struct OffloadRequest<S: Storage, L: LocalityProvider, M: BlockMetadata> {
pub key: OffloadRequestKey, pub key: OffloadRequestKey,
pub block: Weak<MutableBlock<S, M>>, pub block: Weak<MutableBlock<S, L, M>>,
pub sequence_hash: u64, pub sequence_hash: u64,
} }
impl<S: Storage, M: BlockMetadata> PartialOrd for OffloadRequest<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> PartialOrd for OffloadRequest<S, L, M> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other)) Some(self.cmp(other))
} }
} }
/// Order offload requests by priority, high to low. /// Order offload requests by priority, high to low.
impl<S: Storage, M: BlockMetadata> Ord for OffloadRequest<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Ord for OffloadRequest<S, L, M> {
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering {
self.key.cmp(&other.key) self.key.cmp(&other.key)
} }
} }
/// Equality is based on sequence hash, priority, and location. /// Equality is based on sequence hash, priority, and location.
impl<S: Storage, M: BlockMetadata> PartialEq for OffloadRequest<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> PartialEq for OffloadRequest<S, L, M> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.key == other.key self.key == other.key
} }
} }
impl<S: Storage, M: BlockMetadata> Eq for OffloadRequest<S, M> {} impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Eq for OffloadRequest<S, L, M> {}
pub type BlockResult<Target, Metadata> = pub type BlockResult<Target, Locality, Metadata> =
Result<Vec<ImmutableBlock<Target, Metadata>>, BlockPoolError>; Result<Vec<ImmutableBlock<Target, Locality, Metadata>>, BlockPoolError>;
pub type ResponseSender<Target, Locality, Metadata> =
oneshot::Sender<Result<Vec<ImmutableBlock<Target, Locality, Metadata>>, BlockPoolError>>;
/// Data needed for onboarding. /// Data needed for onboarding.
/// Unlike offloading, we need a means to return the resulting blocks to the caller. /// Unlike offloading, we need a means to return the resulting blocks to the caller.
pub struct OnboardRequest<Source: Storage, Target: Storage, M: BlockMetadata> { pub struct OnboardRequest<
pub blocks: Vec<ImmutableBlock<Source, M>>, Source: Storage,
pub response_tx: Target: Storage,
oneshot::Sender<std::result::Result<Vec<ImmutableBlock<Target, M>>, BlockPoolError>>, Locality: LocalityProvider,
M: BlockMetadata,
> {
pub blocks: Vec<ImmutableBlock<Source, Locality, M>>,
pub response_tx: ResponseSender<Target, Locality, M>,
pub targets: Option<Vec<MutableBlock<Target, Locality, M>>>,
} }
impl<Source: Storage, Target: Storage, M: BlockMetadata> OnboardRequest<Source, Target, M> { impl<Source: Storage, Target: Storage, Locality: LocalityProvider, M: BlockMetadata>
OnboardRequest<Source, Target, Locality, M>
{
pub fn new( pub fn new(
blocks: Vec<ImmutableBlock<Source, M>>, blocks: Vec<ImmutableBlock<Source, Locality, M>>,
response_tx: oneshot::Sender<Result<Vec<ImmutableBlock<Target, M>>, BlockPoolError>>, response_tx: ResponseSender<Target, Locality, M>,
targets: Option<Vec<MutableBlock<Target, Locality, M>>>,
) -> Self { ) -> Self {
Self { Self {
blocks, blocks,
response_tx, response_tx,
targets,
} }
} }
} }
......
...@@ -13,81 +13,89 @@ ...@@ -13,81 +13,89 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//! # KV Cache Block Pool Management pub mod managed;
//! pub use managed::ManagedBlockPool;
//! This module provides the primary [`BlockPool`] structure for managing KV cache blocks.
//! It orchestrates the allocation, registration, and reuse of blocks by coordinating
//! between an [`ActiveBlockPool`] and an [`InactiveBlockPool`].
//!
//! ## Core Components:
//!
//! - **[`BlockPool`]**: The main entry point for interacting with the block management system.
//! It holds the shared state containing both active and inactive pools.
//! - **[`ActiveBlockPool`]**: Manages blocks that are currently associated with active sequences.
//! It primarily uses weak references to track these blocks, allowing them to be potentially
//! reclaimed by the inactive pool if no strong references remain.
//! - **[`InactiveBlockPool`]**: Manages blocks that are not currently in active use. It supports
//! block reuse by matching sequence hashes and employs a priority-based eviction strategy
//! for acquiring free blocks.
//! - **[`BlockRegistry`]**: Manages the registration of blocks that have transitioned from the
//! Complete to Registered state.
//! - **[`MutableBlock`]**: Represents a uniquely owned block, typically obtained from allocation.
//! It allows modification and is returned to the inactive pool upon being dropped.
//! - **[`ImmutableBlock`]**: Represents a shared, immutable reference to a block, usually after
//! it has been registered or matched. Ensures that multiple sequences can reference the
//! same underlying block data.
//!
//! ## Workflow:
//!
//! 1. Blocks are initially added to the [`BlockPool`] via [`BlockPool::add_blocks`], populating the
//! [`InactiveBlockPool`].
//! 2. Sequences request blocks via [`BlockPool::allocate_blocks`], which attempts to acquire them
//! from the [`InactiveBlockPool`]. This returns [`MutableBlock`]s.
//! 3. Once a [`MutableBlock`] is filled and ready, it's registered using [`BlockPool::register_block`].
//! This process checks the both the [`ActiveBlockPool`] and the [`InactiveBlockPool`] for existing blocks
//! with the same content hash. It returns an [`ImmutableBlock`] representing the canonical block
//! (either the one provided or an existing one).
//! 4. Sequences can also try to reuse blocks directly using [`BlockPool::match_sequence_hash`], which
//! checks both the active and inactive pools.
//! 5. When an [`ImmutableBlock`] is no longer needed by any sequence (its `Arc` count drops to zero),
//! the underlying [`MutableBlock`] (if it still exists via the weak reference in the active pool)
//! can eventually be returned to the [`InactiveBlockPool`] when its final strong reference (the `Arc`
//! within `ImmutableBlock`) is dropped.
//! 6. Dropped [`MutableBlock`]s are automatically returned to the [`InactiveBlockPool`].
mod active;
mod inactive;
mod priority_key;
mod state;
use active::ActiveBlockPool;
use derive_builder::Builder; use derive_builder::Builder;
use derive_getters::Dissolve; use derive_getters::Dissolve;
use inactive::InactiveBlockPool; use serde::{Deserialize, Serialize};
use priority_key::PriorityKey;
pub use super::block::{ImmutableBlock, MutableBlock}; pub use super::block::{ImmutableBlock, MutableBlock};
use super::block::{ use super::block::{
nixl::short_type_name, registry::BlockRegistry, Block, BlockError, BlockMetadata, nixl::short_type_name, private, registry::BlockRegistry, Block, BlockError, BlockMetadata,
GlobalRegistry, GlobalRegistry, MaybeReturnableBlock,
}; };
use super::events::{EventManager, NullEventManager}; use super::events::{EventManager, NullEventManager};
use super::metrics::{BlockManagerMetrics, PoolMetrics}; use super::metrics::{BlockManagerMetrics, PoolMetrics};
use super::storage::Storage; use super::storage::Storage;
use crate::block_manager::block::locality::LocalityProvider;
use crate::block_manager::CacheLevel;
use crate::tokens::{SequenceHash, TokenBlock}; use crate::tokens::{SequenceHash, TokenBlock};
use async_trait::async_trait;
use prometheus::Registry; use prometheus::Registry;
use std::sync::atomic::{AtomicU64, Ordering};
use std::{ use std::{
collections::{BTreeSet, HashMap, VecDeque}, collections::{BTreeSet, HashMap, VecDeque},
sync::{Arc, Weak}, sync::{Arc, Weak},
}; };
use tokio::runtime::Handle; use tokio::runtime::Handle;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use dynamo_runtime::Result; use dynamo_runtime::Result;
// Type aliases to reduce complexity across the module
type BlockPoolResult<T> = Result<T, BlockPoolError>;
type AsyncResponse<T> = Result<oneshot::Receiver<T>, BlockPoolError>;
// Collection type aliases
pub type MutableBlocks<S, L, M> = Vec<MutableBlock<S, L, M>>;
pub type ImmutableBlocks<S, L, M> = Vec<ImmutableBlock<S, L, M>>;
/// Enum representing either a mutable or immutable block that can be returned to the pool
#[derive(Debug)]
pub enum OwnedBlock<S: Storage, L: LocalityProvider, M: BlockMetadata> {
Mutable(MutableBlock<S, L, M>),
Immutable(ImmutableBlock<S, L, M>),
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> MaybeReturnableBlock<S, L, M>
for OwnedBlock<S, L, M>
{
fn is_returnable(&self) -> bool {
match self {
OwnedBlock::Mutable(block) => block.is_returnable(),
OwnedBlock::Immutable(block) => block.is_returnable(),
}
}
fn try_take_block(self, token: private::PrivateToken) -> Option<Vec<Block<S, L, M>>> {
match self {
OwnedBlock::Mutable(block) => block.try_take_block(token),
OwnedBlock::Immutable(block) => block.try_take_block(token),
}
}
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> From<MutableBlock<S, L, M>>
for OwnedBlock<S, L, M>
{
fn from(block: MutableBlock<S, L, M>) -> Self {
OwnedBlock::Mutable(block)
}
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> From<ImmutableBlock<S, L, M>>
    for OwnedBlock<S, L, M>
{
    /// Wraps an [`ImmutableBlock`] in the `Immutable` variant.
    fn from(block: ImmutableBlock<S, L, M>) -> Self {
        Self::Immutable(block)
    }
}
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum BlockPoolError { pub enum BlockPoolError {
#[error("Block is not complete")] #[error("Block is not complete")]
...@@ -107,74 +115,47 @@ pub enum BlockPoolError { ...@@ -107,74 +115,47 @@ pub enum BlockPoolError {
#[error(transparent)] #[error(transparent)]
BlockError(#[from] BlockError), BlockError(#[from] BlockError),
}
#[derive(Builder, Dissolve)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
pub struct BlockPoolArgs<S: Storage, M: BlockMetadata> {
#[builder(default = "NullEventManager::new()")]
event_manager: Arc<dyn EventManager>,
#[builder(default = "CancellationToken::new()")] #[error("Reset error: {0}")]
cancel_token: CancellationToken, ResetError(String),
#[builder(default)] #[error("Block is not returnable")]
blocks: Vec<Block<S, M>>, NotReturnable,
#[builder(default)] #[error("Unsupported cache level: {0:?}")]
global_registry: GlobalRegistry, UnsupportedCacheLevel(CacheLevel),
#[builder(default = "Handle::current()")] #[error("No blocks to register")]
async_runtime: Handle, NoBlocksToRegister,
#[builder(
default = "BlockManagerMetrics::new(&Arc::new(Registry::new())).unwrap().pool(\"pool\")"
)]
pool_metrics: Arc<PoolMetrics>,
} }
impl<S: Storage, M: BlockMetadata> BlockPoolArgsBuilder<S, M> { #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub fn build(self) -> anyhow::Result<BlockPool<S, M>> { pub enum BlockRegistrationDuplicationSetting {
let args = self.build_internal()?; /// On registration, if duplication is allowed, blocks with duplicate hashes cannot be registered directly,
let (event_manager, cancel_token, blocks, global_registry, async_runtime, metrics) = /// but instead can be held live with a strong arc to the primary block. This maintains the lifetime of
args.dissolve(); /// the duplicate block.
Allowed,
tracing::info!("building block pool");
let pool = BlockPool::new(
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
metrics,
);
Ok(pool)
}
}
/// Manages the blocks in a specific storage backend.
///
/// The struct itself only holds the sending halves of two unbounded channels; the
/// matching receivers are handed to a `ProgressEngine` when the pool is constructed.
pub struct BlockPool<S: Storage, M: BlockMetadata> {
// Sender for `PriorityRequest` messages (allocate, register, match).
priority_tx: tokio::sync::mpsc::UnboundedSender<PriorityRequest<S, M>>,
// Sender for `ControlRequest` messages (add blocks).
ctrl_tx: tokio::sync::mpsc::UnboundedSender<ControlRequest<S, M>>,
}
impl<S: Storage, M: BlockMetadata> Clone for BlockPool<S, M> { /// On registration, if duplication is disabled, blocks with duplicate hashes will be returned immediately
fn clone(&self) -> Self { /// to the inactive pool and the primary block, the one first registered, will be returned to the caller,
Self { /// replacing the duplicate block.
priority_tx: self.priority_tx.clone(), ///
ctrl_tx: self.ctrl_tx.clone(), /// Note: If block duplication is disabled, then the implementation must always respect the fact that the
} /// mutable block that was registered, may not be the same block returned by the registration function, and
} /// thus be able to update any references that wish to use the block after registration.
Disabled,
} }
/// Generic request-response pattern for background task communication
#[derive(Dissolve)] #[derive(Dissolve)]
struct Unary<Req, Resp> { pub struct RequestResponse<Req, Resp> {
request: Req, pub request: Req,
response_tx: oneshot::Sender<Resp>, pub response_tx: oneshot::Sender<Resp>,
} }
impl<Req, Resp> Unary<Req, Resp> { impl<Req, Resp> RequestResponse<Req, Resp> {
fn make_request(request: Req) -> (Self, oneshot::Receiver<Resp>) { /// Create a new request-response pair
pub fn new(request: Req) -> (Self, oneshot::Receiver<Resp>) {
let (response_tx, response_rx) = oneshot::channel(); let (response_tx, response_rx) = oneshot::channel();
( (
Self { Self {
...@@ -186,119 +167,11 @@ impl<Req, Resp> Unary<Req, Resp> { ...@@ -186,119 +167,11 @@ impl<Req, Resp> Unary<Req, Resp> {
} }
} }
type UnaryResponse<T> = Result<oneshot::Receiver<T>, BlockPoolError>; #[async_trait]
pub trait BlockPool<S: Storage, L: LocalityProvider, M: BlockMetadata>:
type ImmutableBlocksResult<S, M> = Result<Vec<ImmutableBlock<S, M>>, BlockPoolError>; BlockPoolController + AsyncBlockPoolController + Send + Sync
{
pub type MutableBlocks<S, M> = Vec<MutableBlock<S, M>>; /// Add a vector of [`Block`]s to the pool.
pub type ImmutableBlocks<S, M> = Vec<ImmutableBlock<S, M>>;
/// Requests sent over the pool's `priority_tx` channel.
///
/// Each variant wraps a `Unary`, which pairs the request payload with a oneshot
/// sender used to deliver the response.
enum PriorityRequest<S: Storage, M: BlockMetadata> {
/// Payload is the number of blocks to allocate; the reply is the allocated
/// mutable blocks or a [`BlockPoolError`].
AllocateBlocks(Unary<usize, Result<Vec<MutableBlock<S, M>>, BlockPoolError>>),
/// Payload is the mutable blocks to register; the reply is the resulting
/// immutable blocks or a [`BlockPoolError`].
RegisterBlocks(Unary<MutableBlocks<S, M>, Result<ImmutableBlocks<S, M>, BlockPoolError>>),
/// Payload is the sequence hashes to look up; the reply is the matched
/// immutable blocks (no error channel).
MatchSequenceHashes(Unary<Vec<SequenceHash>, Vec<ImmutableBlock<S, M>>>),
}
/// Requests sent over the pool's `ctrl_tx` channel.
enum ControlRequest<S: Storage, M: BlockMetadata> {
/// Payload is the blocks to add to the pool; the reply is `()` once handled.
AddBlocks(Unary<Vec<Block<S, M>>, ()>),
}
impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
pub fn builder() -> BlockPoolArgsBuilder<S, M> {
BlockPoolArgsBuilder::default()
}
/// Creates a new [`BlockPool`] and starts its progress engine on a dedicated thread.
///
/// The pool handle and its `ProgressEngine` are produced by
/// [`Self::with_progress_engine`]. The engine is then moved onto a freshly spawned
/// OS thread named `block-pool-<storage type>`, where a current-thread Tokio runtime
/// repeatedly awaits `step()` until it returns `false`.
///
/// # Arguments
///
/// All arguments are forwarded unchanged to [`Self::with_progress_engine`]:
///
/// * `event_manager` - An [`Arc<dyn EventManager>`] handed to the progress engine.
/// * `cancel_token` - The [`CancellationToken`] handed to the progress engine.
/// * `blocks` - The blocks handed to the progress engine at construction.
/// * `global_registry` - The [`GlobalRegistry`] handed to the progress engine.
/// * `async_runtime` - The runtime [`Handle`] handed to the progress engine.
/// * `metrics` - The [`PoolMetrics`] handed to the progress engine.
///
/// # Returns
///
/// A new [`BlockPool`] instance.
///
/// # Panics
///
/// Panics if the thread cannot be spawned. A failure to build the Tokio runtime
/// panics on the spawned thread.
fn new(
    event_manager: Arc<dyn EventManager>,
    cancel_token: CancellationToken,
    blocks: Vec<Block<S, M>>,
    global_registry: GlobalRegistry,
    async_runtime: Handle,
    metrics: Arc<PoolMetrics>,
) -> Self {
    let (pool, progress_engine) = Self::with_progress_engine(
        event_manager,
        cancel_token,
        blocks,
        global_registry,
        async_runtime,
        metrics,
    );

    let thread_name = format!("block-pool-{}", short_type_name::<S>());

    // Body of the dedicated thread: own a single-threaded runtime and drive the
    // engine on it until `step()` reports that it is done.
    let drive_engine = move || {
        let local_runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to build Tokio runtime for block pool progress engine");

        local_runtime.block_on(async move {
            let mut engine = progress_engine;
            tracing::debug!("starting progress engine");
            while engine.step().await {
                tracing::trace!("progress engine step");
            }
        });
    };

    // The join handle is intentionally dropped; the thread runs detached.
    std::thread::Builder::new()
        .name(thread_name)
        .spawn(drive_engine)
        .expect("Failed to spawn block pool progress engine thread");

    pool
}
/// Builds a [`BlockPool`] handle together with the `ProgressEngine` that serves it,
/// without starting the engine.
///
/// Two unbounded channels are created: the pool keeps both senders, and the
/// receivers are passed to `ProgressEngine::new` along with every other argument.
/// The caller is responsible for driving the returned engine.
fn with_progress_engine(
    event_manager: Arc<dyn EventManager>,
    cancel_token: CancellationToken,
    blocks: Vec<Block<S, M>>,
    global_registry: GlobalRegistry,
    async_runtime: Handle,
    metrics: Arc<PoolMetrics>,
) -> (Self, ProgressEngine<S, M>) {
    let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel();
    let (ctrl_tx, ctrl_rx) = tokio::sync::mpsc::unbounded_channel();

    let engine = ProgressEngine::<S, M>::new(
        event_manager,
        priority_rx,
        ctrl_rx,
        cancel_token,
        blocks,
        global_registry,
        async_runtime,
        metrics,
    );

    let handle = Self {
        priority_tx,
        ctrl_tx,
    };

    (handle, engine)
}
/// Adds a vector of [`Block`]s to the [`InactiveBlockPool`].
/// ///
/// These blocks are typically created from a [`super::block::Blocks`] /// These blocks are typically created from a [`super::block::Blocks`]
/// and represent the initial set of available cache blocks. /// and represent the initial set of available cache blocks.
...@@ -307,38 +180,12 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> { ...@@ -307,38 +180,12 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
/// # Arguments /// # Arguments
/// ///
/// * `blocks` - A [`Vec<Block<S, M>>`] to add to the inactive pool. /// * `blocks` - A [`Vec<Block<S, M>>`] to add to the inactive pool.
#[expect(dead_code)] async fn add_blocks(&self, blocks: Vec<Block<S, L, M>>) -> BlockPoolResult<()>;
pub(crate) async fn add_blocks(&self, blocks: Vec<Block<S, M>>) -> Result<(), BlockPoolError> {
self._add_blocks(blocks)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
/// Blocking version of [`BlockPool::add_blocks`]. /// Blocking version of [`BlockPool::add_blocks`].
pub(crate) fn add_blocks_blocking( fn add_blocks_blocking(&self, blocks: Vec<Block<S, L, M>>) -> BlockPoolResult<()>;
&self,
blocks: Vec<Block<S, M>>,
) -> Result<(), BlockPoolError> {
self._add_blocks(blocks)?
.recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
/// Queues an `AddBlocks` control request and returns the receiver on which the
/// acknowledgement will arrive.
///
/// Returns [`BlockPoolError::ProgressEngineShutdown`] if the control channel is closed.
fn _add_blocks(&self, blocks: Vec<Block<S, M>>) -> UnaryResponse<()> {
    let (request, response_rx) = Unary::<_, ()>::make_request(blocks);

    match self.ctrl_tx.send(ControlRequest::AddBlocks(request)) {
        Ok(()) => Ok(response_rx),
        Err(_) => Err(BlockPoolError::ProgressEngineShutdown),
    }
}
/// Attempts to allocate a specified number of free blocks from the [`InactiveBlockPool`]. /// Allocate a specified number of free blocks from the pool.
///
/// Blocks acquired this way are returned as [`MutableBlock`]s, granting unique ownership
/// and allowing modification. Dropping a [`MutableBlock`] automatically returns it
/// to the [`InactiveBlockPool`].
/// ///
/// # Arguments /// # Arguments
/// ///
...@@ -349,633 +196,122 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> { ...@@ -349,633 +196,122 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
/// A [`Result`] containing: /// A [`Result`] containing:
/// - `Ok(Vec<MutableBlock<S, M>>)`: If successful, a vector of allocated mutable blocks. /// - `Ok(Vec<MutableBlock<S, M>>)`: If successful, a vector of allocated mutable blocks.
/// - `Err(BlockPoolError)`: If not enough blocks are available in the inactive pool. /// - `Err(BlockPoolError)`: If not enough blocks are available in the inactive pool.
pub async fn allocate_blocks( async fn allocate_blocks(&self, count: usize) -> BlockPoolResult<MutableBlocks<S, L, M>>;
&self,
count: usize,
) -> Result<Vec<MutableBlock<S, M>>, BlockPoolError> {
self._allocate_blocks(count)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
/// Blocking version of [`BlockPool::allocate_blocks`]. /// Blocking version of [`BlockPool::allocate_blocks`].
pub fn allocate_blocks_blocking( fn allocate_blocks_blocking(&self, count: usize) -> BlockPoolResult<MutableBlocks<S, L, M>>;
&self,
count: usize,
) -> Result<Vec<MutableBlock<S, M>>, BlockPoolError> {
self._allocate_blocks(count)?
.recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
fn _allocate_blocks( /// Register a vector of [`MutableBlock`]s with the pool.
async fn register_blocks(
&self, &self,
count: usize, blocks: Vec<MutableBlock<S, L, M>>,
) -> UnaryResponse<Result<Vec<MutableBlock<S, M>>, BlockPoolError>> { ) -> BlockPoolResult<ImmutableBlocks<S, L, M>>;
// Create the request
let (req, resp_rx) =
Unary::<_, Result<Vec<MutableBlock<S, M>>, BlockPoolError>>::make_request(count);
// Issue the request
self.priority_tx
.send(PriorityRequest::AllocateBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
// Await a response
Ok(resp_rx)
}
/// Registers a vector of [`MutableBlock`]s (presumably after filling them) with the pool,
/// making them available for sharing via the [`ActiveBlockPool`].
///
/// If a block has the same sequence hash as a block already present in the active pool,
/// the returned [`ImmutableBlock`] points at the existing block and the provided block is
/// implicitly dropped (returned to the [`InactiveBlockPool`]).
///
/// # Errors
///
/// Returns [`BlockPoolError::ProgressEngineShutdown`] if the request cannot be sent or the
/// response channel is closed before a reply arrives; otherwise returns whatever result the
/// progress engine replied with.
pub async fn register_blocks(
    &self,
    blocks: Vec<MutableBlock<S, M>>,
) -> ImmutableBlocksResult<S, M> {
    let response_rx = self._register_blocks(blocks)?;

    match response_rx.await {
        Ok(registration_result) => registration_result,
        Err(_) => Err(BlockPoolError::ProgressEngineShutdown),
    }
}
/// Blocking version of [`BlockPool::register_blocks`]. /// Blocking version of [`BlockPool::register_blocks`].
pub fn register_blocks_blocking( fn register_blocks_blocking(
&self, &self,
blocks: Vec<MutableBlock<S, M>>, blocks: Vec<MutableBlock<S, L, M>>,
) -> ImmutableBlocksResult<S, M> { ) -> BlockPoolResult<ImmutableBlocks<S, L, M>>;
self._register_blocks(blocks)?
.recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
fn _register_blocks( /// Match a set of [`SequenceHash`]s to existing blocks in the pool.
&self,
blocks: Vec<MutableBlock<S, M>>,
) -> UnaryResponse<ImmutableBlocksResult<S, M>> {
// Make the request
let (req, resp_rx) = Unary::<_, ImmutableBlocksResult<S, M>>::make_request(blocks);
// Issue the request
self.priority_tx
.send(PriorityRequest::RegisterBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
// Await a response
Ok(resp_rx)
}
/// Attempts to match the given [`SequenceHash`] to an existing block, checking
/// both the active and inactive pools.
///
/// Checks the [`ActiveBlockPool`] first. If a valid strong reference exists, it returns
/// an [`ImmutableBlock`] cloned from it. If the weak reference exists but is stale,
/// it's removed.
///
/// If not found in the active pool, it checks the [`InactiveBlockPool`]. If found there,
/// the block is moved to the active pool (tracked by a weak reference) and returned
/// as a new [`ImmutableBlock`].
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `sequence_hash` - The [`SequenceHash`] to look for. /// * `sequence_hashes` - A [`Vec<SequenceHash>`] to match.
/// ///
/// # Returns /// # Returns
/// ///
/// An [`Option<ImmutableBlock<S, M>>`] containing the shared block if found, otherwise `None`. /// An [`Option<ImmutableBlock<S, M>>`] containing the shared block if found, otherwise `None`.
pub async fn match_sequence_hashes( async fn match_sequence_hashes(
&self, &self,
sequence_hashes: &[SequenceHash], sequence_hashes: &[SequenceHash],
) -> ImmutableBlocksResult<S, M> { ) -> BlockPoolResult<ImmutableBlocks<S, L, M>>;
self._match_sequence_hashes(sequence_hashes)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
/// Blocking version of [`BlockPool::match_sequence_hashes`]. /// Blocking version of [`BlockPool::match_sequence_hashes`].
pub fn match_sequence_hashes_blocking( fn match_sequence_hashes_blocking(
&self, &self,
sequence_hashes: &[SequenceHash], sequence_hashes: &[SequenceHash],
) -> ImmutableBlocksResult<S, M> { ) -> BlockPoolResult<ImmutableBlocks<S, L, M>>;
self._match_sequence_hashes(sequence_hashes)?
.recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
fn _match_sequence_hashes( /// Touch a set of blocks. Equivalent to registering and then immediately dropping.
&self, async fn touch_blocks(&self, sequence_hashes: &[SequenceHash]) -> BlockPoolResult<()>;
sequence_hashes: &[SequenceHash],
) -> UnaryResponse<Vec<ImmutableBlock<S, M>>> {
// Create the request
let (req, resp_rx) =
Unary::<_, Vec<ImmutableBlock<S, M>>>::make_request(sequence_hashes.into());
// Issue the request
self.priority_tx
.send(PriorityRequest::MatchSequenceHashes(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
// Await a response
Ok(resp_rx)
}
}
struct State<S: Storage, M: BlockMetadata> { /// Blocking version of [`BlockPool::touch_blocks`].
active: ActiveBlockPool<S, M>, fn touch_blocks_blocking(&self, sequence_hashes: &[SequenceHash]) -> BlockPoolResult<()>;
inactive: InactiveBlockPool<S, M>,
registry: BlockRegistry,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, M>>,
event_manager: Arc<dyn EventManager>,
metrics: Arc<PoolMetrics>,
}
struct ProgressEngine<S: Storage, M: BlockMetadata> { /// Attempt to return a block to the pool. Blocks will naturally be returned to the pool when they are dropped
priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, M>>, /// and their reference count drops to 0; however, for testing and to synchronize the block returning to the
ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, M>>, /// pool, this function can be used.
cancel_token: CancellationToken, async fn try_return_block(&self, block: OwnedBlock<S, L, M>) -> BlockPoolResult<()>;
state: State<S, M>,
return_rx: tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>,
metrics: Arc<PoolMetrics>,
}
#[cfg(test)] /// Blocking version of [`BlockPool::try_return_block`].
mod tests { fn try_return_block_blocking(&self, block: OwnedBlock<S, L, M>) -> BlockPoolResult<()>;
use super::super::block::{BasicMetadata, Blocks};
use super::super::layout::{tests::setup_layout, FullyContiguous, LayoutConfig};
use super::*;
use crate::block_manager::block::BlockExt;
use crate::block_manager::DType;
use crate::tokens::{TokenBlockSequence, Tokens};
use crate::block_manager::storage::tests::{NullDeviceAllocator, NullDeviceStorage};
/// Test-only builder extension that yields the [`BlockPool`] together with its
/// un-started [`ProgressEngine`], so unit tests can step the engine manually.
impl<S: Storage, M: BlockMetadata> BlockPoolArgsBuilder<S, M> {
    fn build_with_progress_engine(
        self,
    ) -> anyhow::Result<(BlockPool<S, M>, ProgressEngine<S, M>)> {
        let (event_manager, cancel_token, blocks, global_registry, async_runtime, metrics) =
            self.build_internal()?.dissolve();

        Ok(BlockPool::with_progress_engine(
            event_manager,
            cancel_token,
            blocks,
            global_registry,
            async_runtime,
            metrics,
        ))
    }
}
#[tokio::test] fn total_blocks(&self) -> u64;
async fn test_block_pool_state() {
let layout = setup_layout(None).unwrap();
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
let (_pool, mut progress) = BlockPool::builder() fn available_blocks(&self) -> u64;
.blocks(blocks) }
.build_with_progress_engine()
.unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 7);
let blocks = progress.state.allocate_blocks(1).unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 6);
assert_eq!(blocks.len(), 1);
drop(blocks);
progress.step().await;
assert_eq!(progress.state.inactive.available_blocks(), 7);
let mut blocks = progress.state.allocate_blocks(1).unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 6);
assert_eq!(blocks.len(), 1);
let mut block = blocks.pop().unwrap();
block.init_sequence(1337).unwrap();
block.add_token(1).unwrap();
block.add_token(2).unwrap();
block.add_token(3).unwrap();
block.add_token(4).unwrap();
assert!(block.add_token(5).is_err());
}
/// Drives the progress engine by hand and checks that an allocation request and the
/// subsequent return of the dropped block are each only observed after a `step()`.
#[tokio::test]
async fn test_block_pool() {
let layout = setup_layout(None).unwrap();
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
let (pool, mut progress) = BlockPool::builder()
.blocks(blocks)
.build_with_progress_engine()
.unwrap();
// NOTE(review): 7 is presumably the block count of the default test layout — confirm in `setup_layout`.
assert_eq!(progress.state.inactive.available_blocks(), 7);
let pool_clone = pool.clone();
// Issue the allocation from a separate task; it cannot complete until the engine is stepped.
let allocate_1_block =
tokio::spawn(async move { pool_clone.allocate_blocks(1).await.unwrap() });
progress.step().await;
let blocks = allocate_1_block.await.unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 6);
assert_eq!(blocks.len(), 1);
// drop the single block
drop(blocks);
// check before and after the progress engine step
assert_eq!(progress.state.inactive.available_blocks(), 6);
progress.step().await;
assert_eq!(progress.state.inactive.available_blocks(), 7);
}
/// Exercises the blocking API end to end: match (miss), allocate, fill, commit,
/// register, drop, and match again (hit) against a pool built with `build()`.
#[test]
fn test_block_pool_blocking() {
const EXPECTED_SEQUENCE_HASH: u64 = 14643705804678351452;
// Create a new layout
let layout = setup_layout(None).unwrap();
// Create the Blocks
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
let async_runtime = tokio::runtime::Runtime::new().unwrap();
// Create the BlockPool and add the blocks
let pool = BlockPool::builder()
.blocks(blocks)
.async_runtime(async_runtime.handle().clone())
.build()
.unwrap();
// All blocks should be in the Reset/Empty state
// No blocks should match the expected sequence hash
let matched_blocks = pool
.match_sequence_hashes_blocking(&[EXPECTED_SEQUENCE_HASH])
.unwrap();
assert_eq!(matched_blocks.len(), 0);
// Allocate a single block from the pool
let mut mutable_blocks = pool.allocate_blocks_blocking(1).unwrap();
assert_eq!(mutable_blocks.len(), 1);
let mut block = mutable_blocks.pop().unwrap();
// Initialize the sequence on the block with a salt hash
block.init_sequence(1337).unwrap();
// Add some tokens to the block - our page_size is 4
block.add_token(1).unwrap();
block.add_token(2).unwrap();
block.add_token(3).unwrap();
block.add_token(4).unwrap();
// Should fail because we don't have space in the block
assert!(block.add_token(5).is_err());
// Commit the block - this will generate a sequence hash
// This will put the block in a Complete state
block.commit().unwrap();
assert!(block.state().is_complete()); // perhaps renamed to Committed
let sequence_hash = block.sequence_hash().unwrap();
assert_eq!(sequence_hash, EXPECTED_SEQUENCE_HASH);
// Register the block
// We provide a mutable block to the register_blocks function
// This will take ownership of the block and return an immutable block
let mut immutable_blocks = pool.register_blocks_blocking(vec![block]).unwrap();
let block = immutable_blocks.pop().unwrap();
assert!(block.state().is_registered());
assert_eq!(block.sequence_hash().unwrap(), sequence_hash);
// Dropping the immutable block should return the block to the pool
// However, the block should remain in the BlockPool as an inactive block until it is reused
// or promoted back to an immutable block by being matched with a sequence hash
drop(block);
// Get the list of ImmutableBlocks that match the sequence hash
let matched = pool
.match_sequence_hashes_blocking(&[sequence_hash])
.unwrap();
assert_eq!(matched.len(), 1);
assert_eq!(matched[0].sequence_hash().unwrap(), sequence_hash);
}
/// Test helper: allocates `num_blocks` blocks one at a time, fills each with an
/// all-zero token block of size 4, and registers the whole batch in a single call.
///
/// Returns the registered immutable blocks together with the sequence hash recorded
/// for each block, in allocation order.
async fn create_blocks<S: Storage, M: BlockMetadata>(
    pool: &BlockPool<S, M>,
    num_blocks: usize,
) -> anyhow::Result<(Vec<ImmutableBlock<S, M>>, Vec<SequenceHash>)> {
    let sequence = TokenBlockSequence::new(Tokens::from(vec![0; num_blocks * 4]), 4, None);
    assert_eq!(sequence.blocks().len(), num_blocks);

    let mut hashes = Vec::new();
    let mut filled = Vec::new();

    for token_block in sequence.blocks().iter() {
        let mut allocated = pool.allocate_blocks(1).await?.pop().unwrap();
        allocated.apply_token_block(token_block.clone())?;
        hashes.push(allocated.sequence_hash().unwrap());
        filled.push(allocated);
    }

    let registered = pool.register_blocks(filled).await?;
    Ok((registered, hashes))
}
/// Test helper: builds a running [`BlockPool`] over `num_blocks` null-device blocks
/// laid out fully contiguously (1 layer, page size 4, FP16).
async fn make_simple_pool(
    num_blocks: usize,
) -> anyhow::Result<BlockPool<NullDeviceStorage, BasicMetadata>> {
    let layout = FullyContiguous::<NullDeviceStorage>::allocate(
        LayoutConfig {
            num_blocks,
            num_layers: 1,
            outer_dim: 1,
            page_size: 4,
            inner_dim: 1024,
            alignment: 1,
            dtype: DType::FP16,
        },
        &NullDeviceAllocator,
    )?;

    let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)?.into_blocks()?;

    BlockPool::builder().blocks(blocks).build()
}
/// A test that ensures that we only ever evict leaves from the inactive pool.
///
/// Registers a 4-block sequence, then repeatedly allocates fresh blocks and checks
/// how many of the original sequence hashes still match.
#[tokio::test]
async fn test_block_pool_evict_leaves() -> anyhow::Result<()> {
let pool = make_simple_pool(4).await?;
let (_, sequence_hashes) = create_blocks(&pool, 4).await?;
// NOTE(review): the sleeps presumably give the background progress engine time to
// process blocks returned by drops — confirm.
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Allocate 1 block. This should evict the leaf of our allocated sequence.
pool.allocate_blocks(1).await?;
// The leaf should be evicted, so we should have 3 matches.
let matched = pool
.match_sequence_hashes(sequence_hashes.as_slice())
.await?;
assert_eq!(matched.len(), 3);
drop(matched);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Allocate 2 blocks. This should get the previously allocated block, as well as one more leaf.
pool.allocate_blocks(2).await.unwrap();
// The next leaf should be evicted, so we should have 2 matches.
let matched = pool
.match_sequence_hashes(sequence_hashes.as_slice())
.await?;
assert_eq!(matched.len(), 2);
drop(matched);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// If we allocate all the blocks, the entire remaining sequence should be evicted.
let blocks = pool.allocate_blocks(4).await?;
assert_eq!(blocks.len(), 4);
Ok(())
}
/// When a block has two children, we need to ensure that we evict both children before
/// adding the parent to the leaf set.
#[tokio::test]
async fn test_block_pool_parent_child() -> anyhow::Result<()> {
let pool = make_simple_pool(3).await?;
let tokens = vec![1, 2, 3, 4, 5];
let sequence = TokenBlockSequence::new(Tokens::from(tokens.clone()), 4, None);
// Create a root block, with two child blocks.
let mut root_block = pool.allocate_blocks(1).await?.pop().unwrap();
root_block.apply_token_block(sequence.blocks().first().unwrap().clone())?;
let root_block_hash = root_block.sequence_hash().unwrap();
let mut child_blocks = Vec::new();
let mut child_block_hashes = Vec::new();
for i in 0..2 {
// Create a new token sequence using the common prefix.
let mut tokens = tokens.clone();
for _ in 0..4 {
tokens.push(i);
}
let seq = TokenBlockSequence::new(Tokens::from(tokens), 4, None);
// Allocate and apply the suffix to the child block.
let mut child_block = pool.allocate_blocks(1).await?.pop().unwrap();
child_block.apply_token_block(seq.blocks()[1].clone())?;
child_block_hashes.push(child_block.sequence_hash().unwrap());
child_blocks.push(child_block);
}
// Register the children first. This can happen with offloading.
let child_blocks = pool.register_blocks(child_blocks).await?;
// After the children are registered, we can register the root block.
let root_block = pool.register_blocks(vec![root_block]).await?;
// Drop both of them. /// State of the pool when queried.
drop(root_block); ///
drop(child_blocks); /// Provides a snapshot of the pool's current state including:
/// - Active blocks currently in use
/// - Inactive blocks ordered by reuse priority
/// - Number of empty blocks
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct BlockPoolStatus {
/// Active blocks currently in use
pub active_blocks: usize,
/// Inactive blocks ordered by reuse priority
/// Blocks at the front of the list are more likely to be reused
pub inactive_blocks: usize,
/// Number of empty blocks
pub empty_blocks: usize,
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await; #[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct ResetBlocksResponse {
/// Blocks that were reset
pub reset_blocks: Vec<SequenceHash>,
// Allocate two new blocks, which should evict both children. /// Blocks that were not found in the pool
pool.allocate_blocks(2).await?; pub not_found: Vec<SequenceHash>,
// Now, the root block should be the only block left. /// Blocks that were not reset
for child_block_hash in child_block_hashes { pub not_reset: Vec<SequenceHash>,
let matched = pool.match_sequence_hashes(&[child_block_hash]).await?; }
assert_eq!(matched.len(), 0);
}
// Check that the root block remains.
let matched = pool.match_sequence_hashes(&[root_block_hash]).await?;
assert_eq!(matched.len(), 1);
Ok(()) pub trait BlockPoolController: Send + Sync {
} /// Returns the [`BlockPoolStatus`] of the pool.
fn status_blocking(&self) -> Result<BlockPoolStatus, BlockPoolError>;
/// When offloading, it's possible that the tail of a sequence in a pool is evicted before /// Resets the pool to its initial state.
/// the entire sequence can be offloaded. This can happen in the following case:
/// ///
/// Assume a sequence of 4 blocks: [0, 1, 2, 3] /// This function will error unless all blocks have returned to the inactive pool.
/// 1. Blocks 0, 1, and 2 are offloaded to host memory. fn reset_blocking(&self) -> Result<(), BlockPoolError>;
/// 2. Block 2 is evicted from the host.
/// 3. Block 3 is offloaded to host memory.
/// Now, the contents of the cache are [0, 1] and [3].
/// We need to treat these as two separate sequences.
#[tokio::test]
async fn test_block_pool_fragmentation() -> anyhow::Result<()> {
let pool = make_simple_pool(4).await?;
let tokens = vec![0; 16];
let token_blocks = TokenBlockSequence::new(Tokens::from(tokens), 4, None);
assert_eq!(token_blocks.blocks().len(), 4);
let mut sequence_hashes = Vec::new();
// Allocate and register the first 3 blocks.
for block in token_blocks.blocks()[..3].iter() {
let mut mutable_block = pool.allocate_blocks(1).await?.pop().unwrap();
mutable_block.apply_token_block(block.clone())?;
sequence_hashes.push(mutable_block.sequence_hash()?);
let _ = pool.register_blocks(vec![mutable_block]).await?;
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Allocate 2 blocks. This should take the remaining uninitialized block as well as the
// tail of the currently registered sequence.
let _ = pool.allocate_blocks(2).await?;
assert_eq!(
pool.match_sequence_hashes(sequence_hashes.as_slice())
.await?
.len(),
2
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Allocate 1 more block for the leaf of the sequence.
let mut mutable_block = pool.allocate_blocks(1).await?.into_iter().next().unwrap();
mutable_block.apply_token_block(token_blocks.blocks()[3].clone())?;
let _ = pool.register_blocks(vec![mutable_block]).await?; /// Attempt to reset a set of blocks.
fn reset_blocks_blocking(
// We should still only match the first 2 blocks, since the 3rd block has been evicted. &self,
assert_eq!( sequence_hashes: &[SequenceHash],
pool.match_sequence_hashes(sequence_hashes.as_slice()) ) -> Result<ResetBlocksResponse, BlockPoolError>;
.await? }
.len(),
2
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Now, we should be able to allocate all 4 blocks.
let _ = pool.allocate_blocks(4).await?;
Ok(())
}
/// Matching an entire sequence (moving it to the active pool), and returning it
/// should not affect the parent-child relationships of the blocks.
///
/// After matching and releasing the root, allocations are expected to consume the
/// rest of the sequence first and the root last.
#[tokio::test]
async fn test_block_pool_match_return() -> anyhow::Result<()> {
let pool = make_simple_pool(4).await?;
let (_, sequence_hashes) = create_blocks(&pool, 4).await?;
// We match the root of the sequence (moving it to the active pool), then
// immediately return it.
assert_eq!(
pool.match_sequence_hashes(vec![sequence_hashes[0]].as_slice())
.await?
.len(),
1
);
// Held in a named binding so these three blocks stay allocated for the rest of the test.
let _alloc_blocks1 = pool.allocate_blocks(3).await?;
// Allocating 3 blocks should evict all but the root of the sequence.
assert_eq!(
pool.match_sequence_hashes(sequence_hashes.as_slice())
.await?
.len(),
1
);
// NOTE(review): the sleep presumably lets the pool process the block released by the
// temporary match result above — confirm.
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let _alloc_blocks2 = pool.allocate_blocks(1).await?;
// Now, allocating one more block should evict the root.
assert_eq!(
pool.match_sequence_hashes(sequence_hashes.as_slice())
.await?
.len(),
0
);
Ok(())
}
/// When we move a suffix of a sequence to the active pool (like what happens when onboarding),
/// then return it to the inactive pool, we need to ensure that the parent-child relationships
/// are still correct, and that the temporary leaf in the inactive pool can't be evicted.
#[tokio::test]
async fn test_block_pool_match_partial() -> anyhow::Result<()> {
let pool = make_simple_pool(4).await?;
let (_, sequence_hashes) = create_blocks(&pool, 4).await?;
// Assert that all 4 blocks are in the pool.
assert_eq!(
pool.match_sequence_hashes(sequence_hashes.as_slice())
.await?
.len(),
4
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Now, we match only the last 2 blocks
let matched_suffix = pool.match_sequence_hashes(&sequence_hashes[2..]).await?;
assert_eq!(matched_suffix.len(), 2);
// This allocation should fail. Although there are 2 inactive blocks, the leaf is in the active pool.
let new_alloc_block = pool.allocate_blocks(1).await?;
assert_eq!(new_alloc_block.len(), 0);
// Now, drop the leaf, and return it to the inactive pool.
drop(matched_suffix);
// All 4 blocks should still be in the pool. #[async_trait::async_trait]
assert_eq!( pub trait AsyncBlockPoolController: Send + Sync {
pool.match_sequence_hashes(sequence_hashes.as_slice()) /// Returns the [`BlockPoolStatus`] of the pool.
.await? async fn status(&self) -> Result<BlockPoolStatus, BlockPoolError>;
.len(),
4
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await; /// Resets the pool to its initial state.
///
/// This function will error unless all blocks have returned to the inactive pool.
async fn reset(&self) -> Result<(), BlockPoolError>;
Ok(()) /// Attempt to reset a set of blocks.
} async fn reset_blocks(
&self,
sequence_hashes: &[SequenceHash],
) -> Result<ResetBlocksResponse, BlockPoolError>;
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! # KV Cache Block Pool Management
//!
//! This module provides the primary [`BlockPool`] structure for managing KV cache blocks.
//! It orchestrates the allocation, registration, and reuse of blocks by coordinating
//! between an [`ActiveBlockPool`] and an [`InactiveBlockPool`].
//!
//! ## Core Components:
//!
//! - **[`BlockPool`]**: The main entry point for interacting with the block management system.
//! It holds the shared state containing both active and inactive pools.
//! - **[`ActiveBlockPool`]**: Manages blocks that are currently associated with active sequences.
//! It primarily uses weak references to track these blocks, allowing them to be potentially
//! reclaimed by the inactive pool if no strong references remain.
//! - **[`InactiveBlockPool`]**: Manages blocks that are not currently in active use. It supports
//! block reuse by matching sequence hashes and employs a priority-based eviction strategy
//! for acquiring free blocks.
//! - **[`BlockRegistry`]**: Manages the registration of blocks that have transitioned from the
//! Complete to Registered state.
//! - **[`MutableBlock`]**: Represents a uniquely owned block, typically obtained from allocation.
//! It allows modification and is returned to the inactive pool upon being dropped.
//! - **[`ImmutableBlock`]**: Represents a shared, immutable reference to a block, usually after
//! it has been registered or matched. Ensures that multiple sequences can reference the
//! same underlying block data.
//!
//! ## Workflow:
//!
//! 1. Blocks are initially added to the [`BlockPool`] via [`BlockPool::add_blocks`], populating the
//! [`InactiveBlockPool`].
//! 2. Sequences request blocks via [`BlockPool::allocate_blocks`], which attempts to acquire them
//! from the [`InactiveBlockPool`]. This returns [`MutableBlock`]s.
//! 3. Once a [`MutableBlock`] is filled and ready, it's registered using [`BlockPool::register_blocks`].
//! This process checks both the [`ActiveBlockPool`] and the [`InactiveBlockPool`] for existing blocks
//! with the same content hash. It returns an [`ImmutableBlock`] representing the canonical block
//! (either the one provided or an existing one).
//! 4. Sequences can also try to reuse blocks directly using [`BlockPool::match_sequence_hashes`], which
//! checks both the active and inactive pools.
//! 5. When an [`ImmutableBlock`] is no longer needed by any sequence (its `Arc` count drops to zero),
//! the underlying [`MutableBlock`] (if it still exists via the weak reference in the active pool)
//! can eventually be returned to the [`InactiveBlockPool`] when its final strong reference (the `Arc`
//! within `ImmutableBlock`) is dropped.
//! 6. Dropped [`MutableBlock`]s are automatically returned to the [`InactiveBlockPool`].
use super::*;
pub mod active;
pub mod controller;
pub mod inactive;
pub mod priority_key;
pub mod state;
use active::ActiveBlockPool;
use inactive::InactiveBlockPool;
/// Arguments from which a [`ManagedBlockPool`] is constructed.
///
/// The derived builder's own build step is private and named `build_internal`;
/// callers use `ManagedBlockPoolArgsBuilder::build`, which dissolves these arguments
/// and forwards them to [`ManagedBlockPool::new`].
///
/// The field order matters: `dissolve()` is destructured positionally by its callers.
#[derive(Builder, Dissolve)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
pub struct ManagedBlockPoolArgs<S: Storage, L: LocalityProvider, M: BlockMetadata> {
/// Event manager handed to the pool. Defaults to `NullEventManager::new()`.
#[builder(default = "NullEventManager::new()")]
event_manager: Arc<dyn EventManager>,
/// Token whose cancellation stops the progress engine. Defaults to a fresh token.
#[builder(default = "CancellationToken::new()")]
cancel_token: CancellationToken,
/// Blocks the pool is seeded with. Defaults to an empty vector.
#[builder(default)]
blocks: Vec<Block<S, L, M>>,
/// Registry forwarded to the pool state. Defaults to `GlobalRegistry::default()`.
#[builder(default)]
global_registry: GlobalRegistry,
/// Tokio runtime handle forwarded to the pool state.
/// Defaults to `Handle::current()`, so the default requires building inside a runtime.
#[builder(default = "Handle::current()")]
async_runtime: Handle,
/// Metrics sink for the pool. Defaults to a pool named "pool" on a fresh registry.
#[builder(
default = "BlockManagerMetrics::new(&Arc::new(Registry::new())).unwrap().pool(\"pool\")"
)]
pool_metrics: Arc<PoolMetrics>,
/// Duplication setting applied by the pool's `register_blocks` calls.
/// Defaults to `BlockRegistrationDuplicationSetting::Disabled`.
#[builder(default = "BlockRegistrationDuplicationSetting::Disabled")]
default_duplication_setting: BlockRegistrationDuplicationSetting,
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPoolArgsBuilder<S, L, M> {
    /// Finalizes the builder and constructs a [`ManagedBlockPool`] via
    /// [`ManagedBlockPool::new`].
    ///
    /// # Errors
    ///
    /// Returns an error if the generated `build_internal` step fails.
    pub fn build(self) -> anyhow::Result<ManagedBlockPool<S, L, M>> {
        // `dissolve` yields the fields in declaration order.
        let (
            event_manager,
            cancel_token,
            blocks,
            global_registry,
            async_runtime,
            pool_metrics,
            duplication_setting,
        ) = self.build_internal()?.dissolve();

        tracing::info!("building block pool");

        Ok(ManagedBlockPool::new(
            event_manager,
            cancel_token,
            blocks,
            global_registry,
            async_runtime,
            pool_metrics,
            duplication_setting,
        ))
    }
}
// Request type aliases used by the pool's channels.
// Each alias is a `RequestResponse<Payload, Response>`: the first parameter is the data
// sent to the progress engine, the second is the value delivered back to the caller.

// Payload: number of blocks to allocate.
type AllocateBlocksReq<S, L, M> = RequestResponse<usize, BlockPoolResult<MutableBlocks<S, L, M>>>;
// Payload: the blocks to register plus the duplication setting to apply.
type RegisterBlocksReq<S, L, M> = RequestResponse<
(MutableBlocks<S, L, M>, BlockRegistrationDuplicationSetting),
BlockPoolResult<ImmutableBlocks<S, L, M>>,
>;
// Payload: sequence hashes to match against the pool.
type MatchHashesReq<S, L, M> =
RequestResponse<Vec<SequenceHash>, BlockPoolResult<ImmutableBlocks<S, L, M>>>;
// Payload: sequence hashes of the blocks to touch.
type TouchBlocksReq = RequestResponse<Vec<SequenceHash>, BlockPoolResult<()>>;
// Payload: raw blocks to add to the pool; the response carries no value.
type AddBlocksReq<S, L, M> = RequestResponse<Vec<Block<S, L, M>>, ()>;
// No payload: reset the whole pool.
type ResetReq = RequestResponse<(), BlockPoolResult<()>>;
// Payload: raw blocks being handed back to the pool.
type ReturnBlockReq<S, L, M> = RequestResponse<Vec<Block<S, L, M>>, BlockPoolResult<()>>;
// No payload: query the pool status.
type StatusReq = RequestResponse<(), BlockPoolResult<BlockPoolStatus>>;
// Payload: sequence hashes of the blocks to reset.
type ResetBlocksReq = RequestResponse<Vec<SequenceHash>, BlockPoolResult<ResetBlocksResponse>>;
/// Requests carried on the pool's priority channel.
///
/// The progress engine polls this channel first in its biased `select!`, ahead of
/// control requests and returned blocks.
pub enum PriorityRequest<S: Storage, L: LocalityProvider, M: BlockMetadata> {
/// Allocate a number of blocks.
AllocateBlocks(AllocateBlocksReq<S, L, M>),
/// Register filled blocks under a given duplication setting.
RegisterBlocks(RegisterBlocksReq<S, L, M>),
/// Look up blocks by sequence hash.
MatchSequenceHashes(MatchHashesReq<S, L, M>),
/// Touch the blocks with the given sequence hashes.
TouchBlocks(TouchBlocksReq),
/// Reset the pool.
Reset(ResetReq),
/// Hand raw blocks back to the pool.
ReturnBlock(ReturnBlockReq<S, L, M>),
}
/// Requests carried on the pool's control channel.
///
/// The progress engine polls this channel after the priority channel.
pub enum ControlRequest<S: Storage, L: LocalityProvider, M: BlockMetadata> {
/// Add raw blocks to the pool.
AddBlocks(AddBlocksReq<S, L, M>),
/// Query the pool status.
Status(StatusReq),
/// Reset the blocks with the given sequence hashes.
ResetBlocks(ResetBlocksReq),
}
/// Manages the blocks in a specific storage backend.
///
/// This is a cheap-to-clone handle: it holds only channel senders to the progress
/// engine and shared counters, so every clone talks to the same engine.
pub struct ManagedBlockPool<S: Storage, L: LocalityProvider, M: BlockMetadata> {
/// Sender for [`PriorityRequest`]s.
priority_tx: tokio::sync::mpsc::UnboundedSender<PriorityRequest<S, L, M>>,
/// Sender for [`ControlRequest`]s.
ctrl_tx: tokio::sync::mpsc::UnboundedSender<ControlRequest<S, L, M>>,
/// Counter of available blocks, shared with the progress engine.
available_blocks_counter: Arc<AtomicU64>,
/// Counter of total blocks, shared with the progress engine.
total_blocks_counter: Arc<AtomicU64>,
/// Duplication setting applied by `register_blocks` / `register_blocks_blocking`.
default_duplication_setting: BlockRegistrationDuplicationSetting,
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Clone for ManagedBlockPool<S, L, M> {
fn clone(&self) -> Self {
Self {
priority_tx: self.priority_tx.clone(),
ctrl_tx: self.ctrl_tx.clone(),
available_blocks_counter: self.available_blocks_counter.clone(),
total_blocks_counter: self.total_blocks_counter.clone(),
default_duplication_setting: self.default_duplication_setting,
}
}
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M> {
pub fn builder() -> ManagedBlockPoolArgsBuilder<S, L, M> {
ManagedBlockPoolArgsBuilder::default()
}
/// Creates a new [`ManagedBlockPool`] with the given [`EventManager`].
///
/// The pool starts empty and requires blocks to be added via [`add_blocks`].
///
/// # Arguments
///
/// * `event_manager` - An [`Arc<dyn EventManager>`] used for publishing block registration/removal events.
///
/// # Returns
///
/// A new [`ManagedBlockPool`] instance.
pub fn new(
event_manager: Arc<dyn EventManager>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, L, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
metrics: Arc<PoolMetrics>,
default_duplication_setting: BlockRegistrationDuplicationSetting,
) -> Self {
let (pool, progress_engine) = Self::with_progress_engine(
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
metrics,
default_duplication_setting,
);
// pool.runtime.handle().spawn(async move {
// let mut progress_engine = progress_engine;
// tracing::debug!("starting progress engine");
// while progress_engine.step().await {
// tracing::trace!("progress engine step");
// }
// });
let thread_name = format!(
"block-pool-{}-{}",
short_type_name::<S>(),
short_type_name::<L>()
);
std::thread::Builder::new()
.name(thread_name)
.spawn(move || {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to build Tokio runtime for block pool progress engine");
runtime.block_on(async move {
let mut progress_engine = progress_engine;
tracing::debug!("starting progress engine");
while progress_engine.step().await {
tracing::trace!("progress engine step");
}
});
})
.expect("Failed to spawn block pool progress engine thread");
pool
}
fn with_progress_engine(
event_manager: Arc<dyn EventManager>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, L, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
metrics: Arc<PoolMetrics>,
default_duplication_setting: BlockRegistrationDuplicationSetting,
) -> (Self, ProgressEngine<S, L, M>) {
let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel();
let (ctrl_tx, ctrl_rx) = tokio::sync::mpsc::unbounded_channel();
let progress_engine = ProgressEngine::<S, L, M>::new(
event_manager,
priority_rx,
ctrl_rx,
cancel_token,
blocks,
global_registry,
async_runtime,
metrics,
);
let available_blocks_counter = progress_engine.available_blocks_counter.clone();
let total_blocks_counter = progress_engine.total_blocks_counter.clone();
(
Self {
priority_tx,
ctrl_tx,
available_blocks_counter,
total_blocks_counter,
default_duplication_setting,
},
progress_engine,
)
}
pub fn default_duplication_setting(&self) -> BlockRegistrationDuplicationSetting {
self.default_duplication_setting
}
fn _add_blocks(&self, blocks: Vec<Block<S, L, M>>) -> AsyncResponse<()> {
let (req, resp_rx) = AddBlocksReq::new(blocks);
self.ctrl_tx
.send(ControlRequest::AddBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn _allocate_blocks(
&self,
count: usize,
) -> AsyncResponse<BlockPoolResult<Vec<MutableBlock<S, L, M>>>> {
let (req, resp_rx) = AllocateBlocksReq::new(count);
self.priority_tx
.send(PriorityRequest::AllocateBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn _register_blocks(
&self,
blocks: Vec<MutableBlock<S, L, M>>,
duplication_setting: BlockRegistrationDuplicationSetting,
) -> AsyncResponse<BlockPoolResult<ImmutableBlocks<S, L, M>>> {
if blocks.is_empty() {
return Err(BlockPoolError::NoBlocksToRegister);
}
let (req, resp_rx) = RegisterBlocksReq::new((blocks, duplication_setting));
self.priority_tx
.send(PriorityRequest::RegisterBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn _match_sequence_hashes(
&self,
sequence_hashes: &[SequenceHash],
) -> AsyncResponse<BlockPoolResult<ImmutableBlocks<S, L, M>>> {
let (req, resp_rx) = MatchHashesReq::new(sequence_hashes.into());
self.priority_tx
.send(PriorityRequest::MatchSequenceHashes(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn _touch_blocks(
&self,
sequence_hashes: &[SequenceHash],
) -> AsyncResponse<BlockPoolResult<()>> {
let (req, resp_rx) = TouchBlocksReq::new(sequence_hashes.into());
self.priority_tx
.send(PriorityRequest::TouchBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn _reset(&self) -> AsyncResponse<BlockPoolResult<()>> {
let (req, resp_rx) = ResetReq::new(());
self.priority_tx
.send(PriorityRequest::Reset(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
fn _try_return_block(&self, block: OwnedBlock<S, L, M>) -> AsyncResponse<BlockPoolResult<()>> {
let raw_blocks = block
.try_take_block(private::PrivateToken)
.ok_or(BlockPoolError::NotReturnable)?;
let (req, resp_rx) = ReturnBlockReq::new(raw_blocks);
self.priority_tx
.send(PriorityRequest::ReturnBlock(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
}
#[async_trait]
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockPool<S, L, M>
    for ManagedBlockPool<S, L, M>
{
    /// Adds a vector of [`Block`]s to the [`InactiveBlockPool`].
    ///
    /// Fails with [`BlockPoolError::ProgressEngineShutdown`] if the request cannot
    /// be sent or no response arrives.
    async fn add_blocks(&self, blocks: Vec<Block<S, L, M>>) -> Result<(), BlockPoolError> {
        let pending = self._add_blocks(blocks)?;
        pending
            .await
            .map_err(|_| BlockPoolError::ProgressEngineShutdown)
    }

    /// Blocking counterpart of `add_blocks`.
    fn add_blocks_blocking(&self, blocks: Vec<Block<S, L, M>>) -> Result<(), BlockPoolError> {
        let pending = self._add_blocks(blocks)?;
        pending
            .blocking_recv()
            .map_err(|_| BlockPoolError::ProgressEngineShutdown)
    }

    /// Attempts to allocate `count` free blocks from the [`InactiveBlockPool`].
    ///
    /// The blocks come back as [`MutableBlock`]s, which grant unique ownership and
    /// allow modification. A dropped [`MutableBlock`] is returned to the
    /// [`InactiveBlockPool`] automatically.
    async fn allocate_blocks(
        &self,
        count: usize,
    ) -> Result<Vec<MutableBlock<S, L, M>>, BlockPoolError> {
        let pending = self._allocate_blocks(count)?;
        pending
            .await
            .map_err(|_| BlockPoolError::ProgressEngineShutdown)?
    }

    /// Blocking counterpart of `allocate_blocks`.
    fn allocate_blocks_blocking(
        &self,
        count: usize,
    ) -> Result<Vec<MutableBlock<S, L, M>>, BlockPoolError> {
        let pending = self._allocate_blocks(count)?;
        pending
            .blocking_recv()
            .map_err(|_| BlockPoolError::ProgressEngineShutdown)?
    }

    /// Registers filled [`MutableBlock`]s with the pool so they can be shared via
    /// the [`ActiveBlockPool`], using the pool's default duplication setting.
    ///
    /// If a block carries the same sequence hash as one already in the active pool,
    /// an [`ImmutableBlock`] is returned for it.
    ///
    /// Note: depending on the [`BlockRegistrationDuplicationSetting`], the returned
    /// [`ImmutableBlock`] may not be the block that was passed in. It should hold the
    /// same content but be the block that was registered first. When duplication is
    /// allowed, both the primary and the duplicate block are kept alive.
    async fn register_blocks(
        &self,
        blocks: Vec<MutableBlock<S, L, M>>,
    ) -> BlockPoolResult<ImmutableBlocks<S, L, M>> {
        let pending = self._register_blocks(blocks, self.default_duplication_setting)?;
        pending
            .await
            .map_err(|_| BlockPoolError::ProgressEngineShutdown)?
    }

    /// Blocking counterpart of `register_blocks`.
    fn register_blocks_blocking(
        &self,
        blocks: Vec<MutableBlock<S, L, M>>,
    ) -> BlockPoolResult<ImmutableBlocks<S, L, M>> {
        let pending = self._register_blocks(blocks, self.default_duplication_setting)?;
        pending
            .blocking_recv()
            .map_err(|_| BlockPoolError::ProgressEngineShutdown)?
    }

    /// Attempts to match the given [`SequenceHash`]es to existing blocks, looking in
    /// both the active and the inactive pool.
    ///
    /// The [`ActiveBlockPool`] is consulted first. A valid strong reference yields an
    /// [`ImmutableBlock`] cloned from it; a stale weak reference is removed.
    ///
    /// On a miss, the [`InactiveBlockPool`] is consulted. A block found there is
    /// moved to the active pool (tracked by a weak reference) and returned as a new
    /// [`ImmutableBlock`].
    async fn match_sequence_hashes(
        &self,
        sequence_hashes: &[SequenceHash],
    ) -> BlockPoolResult<ImmutableBlocks<S, L, M>> {
        let pending = self._match_sequence_hashes(sequence_hashes)?;
        pending
            .await
            .map_err(|_| BlockPoolError::ProgressEngineShutdown)?
    }

    /// Blocking counterpart of `match_sequence_hashes`.
    fn match_sequence_hashes_blocking(
        &self,
        sequence_hashes: &[SequenceHash],
    ) -> BlockPoolResult<ImmutableBlocks<S, L, M>> {
        let pending = self._match_sequence_hashes(sequence_hashes)?;
        pending
            .blocking_recv()
            .map_err(|_| BlockPoolError::ProgressEngineShutdown)?
    }

    /// Asks the pool to touch the blocks identified by `sequence_hashes`.
    async fn touch_blocks(&self, sequence_hashes: &[SequenceHash]) -> Result<(), BlockPoolError> {
        let pending = self._touch_blocks(sequence_hashes)?;
        pending
            .await
            .map_err(|_| BlockPoolError::ProgressEngineShutdown)?
    }

    /// Blocking counterpart of `touch_blocks`.
    fn touch_blocks_blocking(
        &self,
        sequence_hashes: &[SequenceHash],
    ) -> Result<(), BlockPoolError> {
        let pending = self._touch_blocks(sequence_hashes)?;
        pending
            .blocking_recv()
            .map_err(|_| BlockPoolError::ProgressEngineShutdown)?
    }

    /// Tries to hand `block` back to the pool.
    ///
    /// Fails with [`BlockPoolError::NotReturnable`] when the raw blocks cannot be
    /// taken out of `block`.
    async fn try_return_block(&self, block: OwnedBlock<S, L, M>) -> BlockPoolResult<()> {
        let pending = self._try_return_block(block)?;
        pending
            .await
            .map_err(|_| BlockPoolError::ProgressEngineShutdown)?
    }

    /// Blocking counterpart of `try_return_block`.
    fn try_return_block_blocking(&self, block: OwnedBlock<S, L, M>) -> BlockPoolResult<()> {
        let pending = self._try_return_block(block)?;
        pending
            .blocking_recv()
            .map_err(|_| BlockPoolError::ProgressEngineShutdown)?
    }

    /// Current value of the shared total-blocks counter.
    fn total_blocks(&self) -> u64 {
        let counter = &self.total_blocks_counter;
        counter.load(Ordering::Relaxed)
    }

    /// Current value of the shared available-blocks counter.
    fn available_blocks(&self) -> u64 {
        let counter = &self.available_blocks_counter;
        counter.load(Ordering::Relaxed)
    }
}
/// Event loop behind a [`ManagedBlockPool`].
///
/// Each call to `step` services at most one event: a priority request, a control
/// request, a returned block, or cancellation.
struct ProgressEngine<S: Storage, L: LocalityProvider, M: BlockMetadata> {
/// Receiver for [`PriorityRequest`]s; polled first.
priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, L, M>>,
/// Receiver for [`ControlRequest`]s; polled second.
ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, L, M>>,
/// Cancelling this token makes `step` return `false`.
cancel_token: CancellationToken,
/// Pool state that handles every request.
state: State<S, L, M>,
/// Receiver for blocks coming back to the pool; its sender is passed to `State::new`.
return_rx: tokio::sync::mpsc::UnboundedReceiver<Block<S, L, M>>,
/// Metrics sink; `step` records queue sizes on it.
metrics: Arc<PoolMetrics>,
/// Available-blocks counter obtained from the inactive pool.
available_blocks_counter: Arc<AtomicU64>,
/// Total-blocks counter obtained from the inactive pool.
total_blocks_counter: Arc<AtomicU64>,
}
/// State owned by the progress engine: the pools, the registry, and the handles
/// needed while servicing requests.
pub struct State<S: Storage, L: LocalityProvider, M: BlockMetadata> {
/// Pool of blocks currently in active use.
active: ActiveBlockPool<S, L, M>,
/// Pool of blocks not in active use; allocation draws from here.
inactive: InactiveBlockPool<S, L, M>,
/// Registry of registered blocks.
registry: BlockRegistry,
/// Sender for blocks returning to the pool.
// NOTE(review): presumably the sender half of the engine's `return_rx`; confirm in `State::new`.
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, L, M>>,
/// Event manager supplied at construction.
event_manager: Arc<dyn EventManager>,
/// Metrics sink supplied at construction.
metrics: Arc<PoolMetrics>,
}
impl<S: Storage, L: LocalityProvider + 'static, M: BlockMetadata> ProgressEngine<S, L, M> {
    /// Builds the engine and its [`State`], seeding the inactive pool with `blocks`.
    ///
    /// A fresh return channel is created here: the sender goes to the state, the
    /// receiver stays with the engine.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        event_manager: Arc<dyn EventManager>,
        priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, L, M>>,
        ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, L, M>>,
        cancel_token: CancellationToken,
        blocks: Vec<Block<S, L, M>>,
        global_registry: GlobalRegistry,
        async_runtime: Handle,
        metrics: Arc<PoolMetrics>,
    ) -> Self {
        let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel();

        let mut state = State::<S, L, M>::new(
            event_manager,
            return_tx,
            global_registry,
            async_runtime,
            Arc::clone(&metrics),
        );

        let count = blocks.len();
        tracing::debug!(count, "adding blocks to inactive pool");
        state.inactive.add_blocks(blocks);

        // Fields are evaluated in the order written, so the counters are read
        // before `state` is moved into the struct.
        Self {
            available_blocks_counter: state.inactive.available_blocks_counter(),
            total_blocks_counter: state.inactive.total_blocks_counter(),
            priority_rx,
            ctrl_rx,
            cancel_token,
            state,
            return_rx,
            metrics,
        }
    }

    /// Services one event and reports whether the engine should keep running.
    ///
    /// The `select!` is biased, so sources are polled in this order: priority
    /// requests, control requests, returned blocks, cancellation. Returns `false`
    /// only when the cancellation token fires.
    pub async fn step(&mut self) -> bool {
        tokio::select! {
            biased;

            Some(request) = self.priority_rx.recv(), if !self.priority_rx.is_closed() => {
                let backlog = self.priority_rx.len() as i64;
                self.metrics.gauge("priority_request_queue_size").set(backlog);
                self.state.handle_priority_request(request, &mut self.return_rx).await;
                true
            }

            Some(request) = self.ctrl_rx.recv(), if !self.ctrl_rx.is_closed() => {
                let backlog = self.ctrl_rx.len() as i64;
                self.metrics.gauge("control_request_queue_size").set(backlog);
                self.state.handle_control_request(request);
                true
            }

            Some(returned) = self.return_rx.recv() => {
                let backlog = self.return_rx.len() as i64;
                self.metrics.gauge("return_block_queue_size").set(backlog);
                self.state.handle_return_block(returned);
                true
            }

            _ = self.cancel_token.cancelled() => false,
        }
    }
}
#[cfg(test)]
mod tests {
use crate::block_manager::block::{BasicMetadata, Blocks};
use crate::block_manager::layout::{tests::setup_layout, FullyContiguous, LayoutConfig};
use crate::block_manager::locality::Local;
use crate::tokens::{TokenBlockSequence, Tokens};
use crate::block_manager::storage::tests::{NullDeviceAllocator, NullDeviceStorage};
use super::*;
/// Test-only helper: builds a [`ManagedBlockPool`] together with its
/// [`ProgressEngine`] so a test can step the engine by hand.
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPoolArgsBuilder<S, L, M> {
    #[allow(clippy::type_complexity)]
    fn build_with_progress_engine(
        self,
    ) -> anyhow::Result<(ManagedBlockPool<S, L, M>, ProgressEngine<S, L, M>)> {
        // `dissolve` yields the fields in declaration order.
        let (
            event_manager,
            cancel_token,
            blocks,
            global_registry,
            async_runtime,
            pool_metrics,
            duplication_setting,
        ) = self.build_internal()?.dissolve();

        Ok(ManagedBlockPool::with_progress_engine(
            event_manager,
            cancel_token,
            blocks,
            global_registry,
            async_runtime,
            pool_metrics,
            duplication_setting,
        ))
    }
}
/// Drives the pool state directly: allocate, drop, step the engine, then fill a block.
#[tokio::test]
async fn test_block_pool_state() {
    let layout = setup_layout(None).unwrap();
    let raw_blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
        .unwrap()
        .into_blocks()
        .unwrap();

    // `_pool` must stay alive so the request channels remain open.
    let (_pool, mut engine) = ManagedBlockPool::builder()
        .blocks(raw_blocks)
        .build_with_progress_engine()
        .unwrap();

    assert_eq!(engine.state.inactive.available_blocks(), 7);

    let allocated = engine.state.allocate_blocks(1).unwrap();
    assert_eq!(engine.state.inactive.available_blocks(), 6);
    assert_eq!(allocated.len(), 1);

    // A dropped block only re-enters the inactive pool once the engine steps.
    drop(allocated);
    engine.step().await;
    assert_eq!(engine.state.inactive.available_blocks(), 7);

    let mut allocated = engine.state.allocate_blocks(1).unwrap();
    assert_eq!(engine.state.inactive.available_blocks(), 6);
    assert_eq!(allocated.len(), 1);

    let mut block = allocated.pop().unwrap();
    block.init_sequence(1337).unwrap();

    // The block takes exactly four tokens; the fifth must be rejected.
    for token in 1..=4 {
        block.add_token(token).unwrap();
    }
    assert!(block.add_token(5).is_err());
}
/// Allocates through the public handle while stepping the engine by hand.
#[tokio::test]
async fn test_block_pool() {
    let layout = setup_layout(None).unwrap();
    let raw_blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
        .unwrap()
        .into_blocks()
        .unwrap();

    let (pool, mut engine) = ManagedBlockPool::builder()
        .blocks(raw_blocks)
        .build_with_progress_engine()
        .unwrap();

    assert_eq!(engine.state.inactive.available_blocks(), 7);

    // The request is issued from a spawned task; one engine step services it.
    let handle = pool.clone();
    let allocation = tokio::spawn(async move { handle.allocate_blocks(1).await.unwrap() });
    engine.step().await;

    let allocated = allocation.await.unwrap();
    assert_eq!(engine.state.inactive.available_blocks(), 6);
    assert_eq!(allocated.len(), 1);

    // Dropping the block does not change the count until the engine steps again.
    drop(allocated);
    assert_eq!(engine.state.inactive.available_blocks(), 6);
    engine.step().await;
    assert_eq!(engine.state.inactive.available_blocks(), 7);
}
/// Full block life cycle through the blocking API: allocate, fill, commit, register,
/// drop, and match again by sequence hash.
#[test]
fn test_block_pool_blocking() {
    const EXPECTED_SEQUENCE_HASH: u64 = 14643705804678351452;

    let layout = setup_layout(None).unwrap();
    let raw_blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
        .unwrap()
        .into_blocks()
        .unwrap();

    let runtime = tokio::runtime::Runtime::new().unwrap();

    let pool = ManagedBlockPool::builder()
        .blocks(raw_blocks)
        .async_runtime(runtime.handle().clone())
        .build()
        .unwrap();

    // Nothing is registered yet, so the expected hash has no match.
    let initial_matches = pool
        .match_sequence_hashes_blocking(&[EXPECTED_SEQUENCE_HASH])
        .unwrap();
    assert_eq!(initial_matches.len(), 0);

    let mut allocated = pool.allocate_blocks_blocking(1).unwrap();
    assert_eq!(allocated.len(), 1);
    let mut block = allocated.pop().unwrap();

    // Seed the sequence with a salt hash, then fill the page (page_size is 4).
    block.init_sequence(1337).unwrap();
    for token in 1..=4 {
        block.add_token(token).unwrap();
    }

    // The block is full, so one more token must be rejected.
    assert!(block.add_token(5).is_err());

    // Committing computes the sequence hash and moves the block to Complete.
    block.commit().unwrap();
    assert!(block.state().is_complete());

    let sequence_hash = block.sequence_hash().unwrap();
    assert_eq!(sequence_hash, EXPECTED_SEQUENCE_HASH);

    // Registration consumes the mutable block and hands back an immutable one.
    let mut registered = pool.register_blocks_blocking(vec![block]).unwrap();
    let block = registered.pop().unwrap();
    assert!(block.state().is_registered());
    assert_eq!(block.sequence_hash(), sequence_hash);

    // Dropping the immutable block returns it to the pool, where it stays matchable
    // as an inactive block until it is reused.
    drop(block);

    let rematched = pool
        .match_sequence_hashes_blocking(&[sequence_hash])
        .unwrap();
    assert_eq!(rematched.len(), 1);
    assert_eq!(rematched[0].sequence_hash(), sequence_hash);
}
/// Allocates, fills and registers `num_blocks` blocks forming one token sequence.
///
/// Returns the registered blocks together with their sequence hashes, in sequence order.
async fn create_blocks<S: Storage, L: LocalityProvider, M: BlockMetadata>(
    pool: &ManagedBlockPool<S, L, M>,
    num_blocks: usize,
) -> anyhow::Result<(Vec<ImmutableBlock<S, L, M>>, Vec<SequenceHash>)> {
    // Four zero tokens per block, matching the block size passed below.
    let tokens = vec![0; num_blocks * 4];
    let token_blocks = TokenBlockSequence::new(Tokens::from(tokens), 4, None);
    assert_eq!(token_blocks.blocks().len(), num_blocks);

    let mut sequence_hashes = Vec::with_capacity(num_blocks);
    let mut filled = Vec::with_capacity(num_blocks);

    for token_block in token_blocks.blocks() {
        let mut block = pool.allocate_blocks(1).await?.pop().unwrap();
        block.apply_token_block(token_block.clone())?;
        sequence_hashes.push(block.sequence_hash().unwrap());
        filled.push(block);
    }

    let registered = pool.register_blocks(filled).await?;
    Ok((registered, sequence_hashes))
}
/// Builds a running pool of `num_blocks` null-device blocks with a page size of 4.
async fn make_simple_pool(
    num_blocks: usize,
) -> anyhow::Result<
    ManagedBlockPool<NullDeviceStorage, crate::block_manager::locality::Local, BasicMetadata>,
> {
    let layout = FullyContiguous::<NullDeviceStorage>::allocate(
        LayoutConfig {
            num_blocks,
            num_layers: 1,
            outer_dim: 1,
            page_size: 4,
            inner_dim: 1024,
            alignment: 1,
            dtype_width_bytes: 2,
        },
        &NullDeviceAllocator,
    )?;

    let raw_blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)?.into_blocks()?;

    ManagedBlockPool::builder().blocks(raw_blocks).build()
}
/// Eviction from the inactive pool must always take leaves, never interior blocks.
#[tokio::test]
async fn test_block_pool_evict_leaves() -> anyhow::Result<()> {
    let pool = make_simple_pool(4).await?;
    let (_, sequence_hashes) = create_blocks(&pool, 4).await?;

    tokio::time::sleep(std::time::Duration::from_millis(100)).await;

    // One allocation evicts the leaf of the sequence, leaving three matchable blocks.
    pool.allocate_blocks(1).await?;
    assert_eq!(pool.match_sequence_hashes(&sequence_hashes).await?.len(), 3);

    tokio::time::sleep(std::time::Duration::from_millis(100)).await;

    // Two allocations reuse the freed block and evict the next leaf.
    pool.allocate_blocks(2).await.unwrap();
    assert_eq!(pool.match_sequence_hashes(&sequence_hashes).await?.len(), 2);

    tokio::time::sleep(std::time::Duration::from_millis(100)).await;

    // Allocating everything evicts what is left of the sequence.
    let everything = pool.allocate_blocks(4).await?;
    assert_eq!(everything.len(), 4);

    Ok(())
}
/// A parent with two children may only become a leaf (and thus evictable) after
/// both children have been evicted.
#[tokio::test]
async fn test_block_pool_parent_child() -> anyhow::Result<()> {
    let pool = make_simple_pool(3).await?;

    let tokens = vec![1, 2, 3, 4, 5];
    let sequence = TokenBlockSequence::new(Tokens::from(tokens.clone()), 4, None);

    // Root block of the shared prefix.
    let mut root_block = pool.allocate_blocks(1).await?.pop().unwrap();
    root_block.apply_token_block(sequence.blocks()[0].clone())?;
    let root_block_hash = root_block.sequence_hash().unwrap();

    // Two children that share the prefix but differ in their suffix.
    let mut child_blocks = Vec::with_capacity(2);
    let mut child_block_hashes = Vec::with_capacity(2);

    for suffix_token in 0..2 {
        let mut child_tokens = tokens.clone();
        child_tokens.extend([suffix_token; 4]);
        let child_sequence = TokenBlockSequence::new(Tokens::from(child_tokens), 4, None);

        let mut child_block = pool.allocate_blocks(1).await?.pop().unwrap();
        child_block.apply_token_block(child_sequence.blocks()[1].clone())?;

        child_block_hashes.push(child_block.sequence_hash().unwrap());
        child_blocks.push(child_block);
    }

    // Register the parent first, then the children, and release all of them.
    let root_block = pool.register_blocks(vec![root_block]).await?;
    let child_blocks = pool.register_blocks(child_blocks).await?;
    drop(root_block);
    drop(child_blocks);

    tokio::time::sleep(std::time::Duration::from_millis(100)).await;

    // Two allocations should evict both children.
    pool.allocate_blocks(2).await?;

    // Neither child can be matched any more ...
    for child_block_hash in child_block_hashes {
        let matched = pool.match_sequence_hashes(&[child_block_hash]).await?;
        assert_eq!(matched.len(), 0);
    }

    // ... while the root is still there.
    let matched = pool.match_sequence_hashes(&[root_block_hash]).await?;
    assert_eq!(matched.len(), 1);

    Ok(())
}
/// Matching a block (which moves it to the active pool) and handing it straight back
/// must leave the parent-child relationships of the sequence intact.
#[tokio::test]
async fn test_block_pool_match_return() -> anyhow::Result<()> {
    let pool = make_simple_pool(4).await?;
    let (_, sequence_hashes) = create_blocks(&pool, 4).await?;

    // Match only the root; the match result is dropped at the end of the statement,
    // which returns the block immediately.
    let root_matches = pool
        .match_sequence_hashes(&[sequence_hashes[0]])
        .await?
        .len();
    assert_eq!(root_matches, 1);

    // Holding three fresh blocks evicts everything except the root.
    let _alloc_blocks1 = pool.allocate_blocks(3).await?;
    let remaining = pool.match_sequence_hashes(&sequence_hashes).await?.len();
    assert_eq!(remaining, 1);

    tokio::time::sleep(std::time::Duration::from_millis(100)).await;

    // One more held block evicts the root as well.
    let _alloc_blocks2 = pool.allocate_blocks(1).await?;
    let remaining = pool.match_sequence_hashes(&sequence_hashes).await?.len();
    assert_eq!(remaining, 0);

    Ok(())
}
/// Touching blocks changes which block is evicted by the next allocation.
#[tokio::test]
async fn test_block_pool_touch() -> anyhow::Result<()> {
    let pool = make_simple_pool(4).await?;
    let (_, sequence_hashes) = create_blocks(&pool, 4).await?;

    tokio::time::sleep(std::time::Duration::from_millis(100)).await;

    // The first held allocation evicts the leaf (block 3).
    let _block0 = pool.allocate_blocks(1).await?;
    let leaf_matches = pool
        .match_sequence_hashes(&[sequence_hashes[3]])
        .await?
        .len();
    assert_eq!(leaf_matches, 0);

    // Touch the new leaf (block 2).
    pool.touch_blocks(&[sequence_hashes[2]]).await?;

    tokio::time::sleep(std::time::Duration::from_millis(100)).await;

    // Since block 2 was touched, block 1 should have been evicted.
    let _block1 = pool.allocate_blocks(1).await?;
    let block1_matches = pool
        .match_sequence_hashes(&[sequence_hashes[1]])
        .await?
        .len();
    assert_eq!(block1_matches, 0);

    pool.touch_blocks(&[sequence_hashes[3]]).await?;

    tokio::time::sleep(std::time::Duration::from_millis(100)).await;

    // This allocation is not held; afterwards block 0 is expected to be gone.
    pool.allocate_blocks(1).await?;
    let block0_matches = pool
        .match_sequence_hashes(&[sequence_hashes[0]])
        .await?
        .len();
    assert_eq!(block0_matches, 0);

    Ok(())
}
/// Sequence hash that `create_block` asserts for a block initialized with salt
/// 1337 and filled with the tokens 1, 2, 3, 4.
const EXPECTED_SEQUENCE_HASH: u64 = 14643705804678351452;
/// Allocates one block, fills and commits it, registers it, and returns the
/// resulting immutable block.
///
/// Every call produces the same content, so the committed block always hashes to
/// `EXPECTED_SEQUENCE_HASH`.
fn create_block(
    pool: &ManagedBlockPool<NullDeviceStorage, Local, BasicMetadata>,
) -> ImmutableBlock<NullDeviceStorage, Local, BasicMetadata> {
    let available_before = pool.available_blocks();

    let mut allocated = pool.allocate_blocks_blocking(1).unwrap();
    assert_eq!(allocated.len(), 1);
    let mut block = allocated.pop().unwrap();

    // The allocation must be reflected in the shared counter.
    assert_eq!(pool.available_blocks(), available_before - 1);

    // Seed the sequence with a salt hash, then fill the page (page_size is 4).
    block.init_sequence(1337).unwrap();
    for token in 1..=4 {
        block.add_token(token).unwrap();
    }

    // The block is full, so one more token must be rejected.
    assert!(block.add_token(5).is_err());

    // Committing computes the sequence hash and moves the block to Complete.
    block.commit().unwrap();
    assert!(block.state().is_complete());

    let sequence_hash = block.sequence_hash().unwrap();
    assert_eq!(sequence_hash, EXPECTED_SEQUENCE_HASH);

    // Registration consumes the mutable block and hands back an immutable one.
    let mut registered = pool.register_blocks_blocking(vec![block]).unwrap();
    let registered_block = registered.pop().unwrap();
    assert!(registered_block.state().is_registered());
    assert_eq!(registered_block.sequence_hash(), sequence_hash);

    registered_block
}
/// With duplication `Allowed`, registering identical content twice produces a
/// second, distinct block that is flagged as a duplicate of the primary.
#[test]
fn test_block_registration_allow_duplicates() {
    // Build the layout and carve it into blocks.
    let layout = setup_layout(None).unwrap();
    let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
        .unwrap()
        .into_blocks()
        .unwrap();
    let total = blocks.len() as u64;

    // The pool is handed a handle to this runtime.
    let async_runtime = tokio::runtime::Runtime::new().unwrap();
    let pool = ManagedBlockPool::builder()
        .blocks(blocks)
        .async_runtime(async_runtime.handle().clone())
        .default_duplication_setting(BlockRegistrationDuplicationSetting::Allowed)
        .build()
        .unwrap();

    assert_eq!(pool.total_blocks(), total);
    assert_eq!(pool.available_blocks(), total);
    assert_eq!(
        pool.default_duplication_setting(),
        BlockRegistrationDuplicationSetting::Allowed
    );

    // Nothing has been registered yet, so the expected hash must not match.
    let initial_matches = pool
        .match_sequence_hashes_blocking(&[EXPECTED_SEQUENCE_HASH])
        .unwrap();
    assert_eq!(initial_matches.len(), 0);

    // First registration: this becomes the primary block.
    let primary = create_block(&pool);
    let primary_id = primary.block_id();
    assert_eq!(pool.available_blocks(), total - 1);

    // Second registration of the same sequence: a separate, duplicate block.
    let duplicate = create_block(&pool);
    assert!(duplicate.is_duplicate());
    assert_ne!(duplicate.block_id(), primary_id);
    assert_eq!(pool.available_blocks(), total - 2);

    // Reset only succeeds once every block has been returned to the pool.
    assert!(pool.reset_blocking().is_err());

    // The duplicate also holds the primary, so the primary cannot be returned
    // while the duplicate is still outstanding.
    assert!(pool.try_return_block_blocking(primary.into()).is_err());
    assert_eq!(pool.available_blocks(), total - 2);

    assert!(pool.try_return_block_blocking(duplicate.into()).is_ok());
    assert_eq!(pool.available_blocks(), total);

    // The pool has not been reset, so the primary can still be matched.
    let mut rematched = pool
        .match_sequence_hashes_blocking(&[EXPECTED_SEQUENCE_HASH])
        .unwrap();
    let primary = rematched.pop().unwrap();
    assert!(pool.try_return_block_blocking(primary.into()).is_ok());
    assert_eq!(pool.available_blocks(), total);

    // A duplicate can still be created while the primary is inactive.
    let duplicate = create_block(&pool);
    assert!(duplicate.is_duplicate());
    assert_ne!(duplicate.block_id(), primary_id);
    assert_eq!(pool.available_blocks(), total - 2);

    assert!(pool.try_return_block_blocking(duplicate.into()).is_ok());
    assert_eq!(pool.available_blocks(), total);

    // Every block is back, so the reset now goes through.
    assert!(pool.reset_blocking().is_ok());

    // After the reset the primary can no longer be matched.
    let final_matches = pool
        .match_sequence_hashes_blocking(&[EXPECTED_SEQUENCE_HASH])
        .unwrap();
    assert_eq!(final_matches.len(), 0);
}
/// With duplication `Disabled`, registering identical content a second time
/// yields the already-registered primary block instead of a new one.
///
/// Uses the module-level `EXPECTED_SEQUENCE_HASH` (the same constant
/// `create_block` asserts against) rather than a shadowing local copy.
#[test]
fn test_block_registration_disable_duplicates() {
    // Create a new layout
    let layout = setup_layout(None).unwrap();

    // Create the Blocks
    let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
        .unwrap()
        .into_blocks()
        .unwrap();
    let count = blocks.len() as u64;

    let async_runtime = tokio::runtime::Runtime::new().unwrap();

    // Create the ManagedBlockPool and add the blocks
    let pool = ManagedBlockPoolArgsBuilder::default()
        .blocks(blocks)
        .async_runtime(async_runtime.handle().clone())
        .default_duplication_setting(BlockRegistrationDuplicationSetting::Disabled)
        .build()
        .unwrap();

    assert_eq!(pool.total_blocks(), count);
    assert_eq!(pool.available_blocks(), count);

    // Nothing has been registered yet, so the expected hash must not match.
    let matched_blocks = pool
        .match_sequence_hashes_blocking(&[EXPECTED_SEQUENCE_HASH])
        .unwrap();
    assert_eq!(matched_blocks.len(), 0);

    // Allocate and register the primary block.
    let primary = create_block(&pool);
    let primary_id = primary.block_id();
    assert_eq!(pool.available_blocks(), count - 1);

    // Register the same sequence again: no additional block stays checked out,
    // and the block handed back has the primary's id.
    let duplicate = create_block(&pool);
    assert_eq!(pool.available_blocks(), count - 1);
    assert_eq!(duplicate.block_id(), primary_id);
}
}
...@@ -13,14 +13,22 @@ ...@@ -13,14 +13,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::block_manager::block::locality::LocalityProvider;
use super::*; use super::*;
/// Manages active blocks being used by sequences /// Manages active blocks being used by sequences
pub struct ActiveBlockPool<S: Storage, M: BlockMetadata> { pub struct ActiveBlockPool<S: Storage, L: LocalityProvider, M: BlockMetadata> {
pub(super) map: HashMap<SequenceHash, Weak<MutableBlock<S, M>>>, pub(super) map: HashMap<SequenceHash, Weak<MutableBlock<S, L, M>>>,
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Default for ActiveBlockPool<S, L, M> {
fn default() -> Self {
Self::new()
}
} }
impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ActiveBlockPool<S, L, M> {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
map: HashMap::new(), map: HashMap::new(),
...@@ -29,8 +37,8 @@ impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> { ...@@ -29,8 +37,8 @@ impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> {
pub fn register( pub fn register(
&mut self, &mut self,
mut block: MutableBlock<S, M>, mut block: MutableBlock<S, L, M>,
) -> Result<ImmutableBlock<S, M>, BlockPoolError> { ) -> Result<ImmutableBlock<S, L, M>, BlockPoolError> {
if !block.state().is_registered() { if !block.state().is_registered() {
return Err(BlockPoolError::InvalidMutableBlock( return Err(BlockPoolError::InvalidMutableBlock(
"block is not registered".to_string(), "block is not registered".to_string(),
...@@ -69,7 +77,7 @@ impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> { ...@@ -69,7 +77,7 @@ impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> {
} }
} }
pub fn remove(&mut self, block: &mut Block<S, M>) { pub fn remove(&mut self, block: &mut Block<S, L, M>) {
if let Ok(sequence_hash) = block.sequence_hash() { if let Ok(sequence_hash) = block.sequence_hash() {
if let Some(weak) = self.map.get(&sequence_hash) { if let Some(weak) = self.map.get(&sequence_hash) {
if let Some(_arc) = weak.upgrade() { if let Some(_arc) = weak.upgrade() {
...@@ -84,7 +92,7 @@ impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> { ...@@ -84,7 +92,7 @@ impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> {
pub fn match_sequence_hash( pub fn match_sequence_hash(
&mut self, &mut self,
sequence_hash: SequenceHash, sequence_hash: SequenceHash,
) -> Option<ImmutableBlock<S, M>> { ) -> Option<ImmutableBlock<S, L, M>> {
if let Some(weak) = self.map.get(&sequence_hash) { if let Some(weak) = self.map.get(&sequence_hash) {
if let Some(arc) = weak.upgrade() { if let Some(arc) = weak.upgrade() {
Some(ImmutableBlock::new(arc)) Some(ImmutableBlock::new(arc))
...@@ -97,4 +105,8 @@ impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> { ...@@ -97,4 +105,8 @@ impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> {
None None
} }
} }
/// Returns the number of sequence-hash entries currently held in `map`.
///
/// The count is taken from the map itself, so it includes any entry whose
/// weak reference can no longer be upgraded, if such entries are still present.
pub fn status(&self) -> usize {
    self.map.len()
}
} }
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ManagedBlockPool<S, L, M> {
    /// Sends a status request over the control channel.
    ///
    /// Returns the receiver on which the response will be delivered, or
    /// [`BlockPoolError::ProgressEngineShutdown`] if the request could not be
    /// sent.
    fn _status(&self) -> AsyncResponse<BlockPoolResult<BlockPoolStatus>> {
        let (request, receiver) = StatusReq::new(());
        match self.ctrl_tx.send(ControlRequest::Status(request)) {
            Ok(_) => Ok(receiver),
            Err(_) => Err(BlockPoolError::ProgressEngineShutdown),
        }
    }

    /// Sends a request to reset the blocks identified by `sequence_hashes`.
    ///
    /// Returns the receiver on which the response will be delivered, or
    /// [`BlockPoolError::ProgressEngineShutdown`] if the request could not be
    /// sent.
    fn _reset_blocks(
        &self,
        sequence_hashes: &[SequenceHash],
    ) -> AsyncResponse<BlockPoolResult<ResetBlocksResponse>> {
        let (request, receiver) = ResetBlocksReq::new(sequence_hashes.into());
        match self.ctrl_tx.send(ControlRequest::ResetBlocks(request)) {
            Ok(_) => Ok(receiver),
            Err(_) => Err(BlockPoolError::ProgressEngineShutdown),
        }
    }
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockPoolController
    for ManagedBlockPool<S, L, M>
{
    /// Requests the pool status and blocks the calling thread until the
    /// response arrives.
    ///
    /// A dropped response channel is reported as
    /// [`BlockPoolError::ProgressEngineShutdown`].
    fn status_blocking(&self) -> Result<BlockPoolStatus, BlockPoolError> {
        match self._status()?.blocking_recv() {
            Ok(outcome) => outcome,
            Err(_) => Err(BlockPoolError::ProgressEngineShutdown),
        }
    }

    /// Requests a reset of the pool and blocks the calling thread until the
    /// response arrives.
    ///
    /// A dropped response channel is reported as
    /// [`BlockPoolError::ProgressEngineShutdown`].
    fn reset_blocking(&self) -> Result<(), BlockPoolError> {
        match self._reset()?.blocking_recv() {
            Ok(outcome) => outcome,
            Err(_) => Err(BlockPoolError::ProgressEngineShutdown),
        }
    }

    /// Requests a reset of the blocks identified by `sequence_hashes` and
    /// blocks the calling thread until the response arrives.
    ///
    /// A dropped response channel is reported as
    /// [`BlockPoolError::ProgressEngineShutdown`].
    fn reset_blocks_blocking(
        &self,
        sequence_hashes: &[SequenceHash],
    ) -> Result<ResetBlocksResponse, BlockPoolError> {
        match self._reset_blocks(sequence_hashes)?.blocking_recv() {
            Ok(outcome) => outcome,
            Err(_) => Err(BlockPoolError::ProgressEngineShutdown),
        }
    }
}
#[async_trait::async_trait]
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> AsyncBlockPoolController
    for ManagedBlockPool<S, L, M>
{
    /// Requests the pool status and awaits the response.
    ///
    /// A dropped response channel is reported as
    /// [`BlockPoolError::ProgressEngineShutdown`].
    async fn status(&self) -> Result<BlockPoolStatus, BlockPoolError> {
        match self._status()?.await {
            Ok(outcome) => outcome,
            Err(_) => Err(BlockPoolError::ProgressEngineShutdown),
        }
    }

    /// Requests a reset of the pool and awaits the response.
    ///
    /// A dropped response channel is reported as
    /// [`BlockPoolError::ProgressEngineShutdown`].
    async fn reset(&self) -> Result<(), BlockPoolError> {
        match self._reset()?.await {
            Ok(outcome) => outcome,
            Err(_) => Err(BlockPoolError::ProgressEngineShutdown),
        }
    }

    /// Requests a reset of the blocks identified by `sequence_hashes` and
    /// awaits the response.
    ///
    /// A dropped response channel is reported as
    /// [`BlockPoolError::ProgressEngineShutdown`].
    async fn reset_blocks(
        &self,
        sequence_hashes: &[SequenceHash],
    ) -> Result<ResetBlocksResponse, BlockPoolError> {
        match self._reset_blocks(sequence_hashes)?.await {
            Ok(outcome) => outcome,
            Err(_) => Err(BlockPoolError::ProgressEngineShutdown),
        }
    }
}
...@@ -13,35 +13,37 @@ ...@@ -13,35 +13,37 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::block_manager::block::BlockState; use std::sync::atomic::AtomicU64;
use crate::block_manager::block::{locality::LocalityProvider, BlockState};
use super::*; use super::*;
use std::collections::HashSet; use priority_key::PriorityKey;
use tracing::instrument; use tracing::instrument;
#[derive(Default)] #[derive(Default)]
pub struct InactiveBlockPool<S: Storage, M: BlockMetadata> { pub struct InactiveBlockPool<S: Storage, L: LocalityProvider, M: BlockMetadata> {
// Direct lookup by sequence_hash. // Direct lookup by sequence_hash.
lookup_map: HashMap<SequenceHash, Block<S, M>>, lookup_map: HashMap<SequenceHash, Block<S, L, M>>,
// A priority ordering for the leaf nodes. // Ordered by timestamp (oldest first)
// Leaf nodes are defined as blocks that have no children in the inactive pool. priority_set: BTreeSet<PriorityKey<M>>,
leaf_set: BTreeSet<PriorityKey<M>>,
// Mapping from parents to their children.
parent_children: HashMap<SequenceHash, HashSet<SequenceHash>>,
// Fully Uninitialized // Fully Uninitialized
uninitialized_set: VecDeque<Block<S, M>>, uninitialized_set: VecDeque<Block<S, L, M>>,
// Return Tick // Return Tick
return_tick: u64, return_tick: u64,
// Total blocks // Total blocks counter
total_blocks: u64, total_blocks: Arc<AtomicU64>,
// Inactive blocks
available_blocks: Arc<AtomicU64>,
} }
impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { impl<S: Storage, L: LocalityProvider, M: BlockMetadata> InactiveBlockPool<S, L, M> {
/// Creates a new, empty [`InactiveBlockPool`]. /// Creates a new, empty [`InactiveBlockPool`].
/// ///
/// # Returns /// # Returns
...@@ -50,21 +52,39 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -50,21 +52,39 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
pub(crate) fn new() -> Self { pub(crate) fn new() -> Self {
Self { Self {
lookup_map: HashMap::new(), lookup_map: HashMap::new(),
leaf_set: BTreeSet::new(), priority_set: BTreeSet::new(),
parent_children: HashMap::new(),
uninitialized_set: VecDeque::new(), uninitialized_set: VecDeque::new(),
return_tick: 0, return_tick: 0,
total_blocks: 0, total_blocks: Arc::new(AtomicU64::new(0)),
available_blocks: Arc::new(AtomicU64::new(0)),
} }
} }
/// Hands out a shared handle to the available-blocks counter.
///
/// # Returns
///
/// A clone of the pool's [`Arc<AtomicU64>`]; it points at the same atomic
/// the pool itself updates, so later changes are visible through it.
pub fn available_blocks_counter(&self) -> Arc<AtomicU64> {
    Arc::clone(&self.available_blocks)
}
/// Hands out a shared handle to the total-blocks counter.
///
/// # Returns
///
/// A clone of the pool's [`Arc<AtomicU64>`]; it points at the same atomic
/// the pool itself updates, so later changes are visible through it.
pub fn total_blocks_counter(&self) -> Arc<AtomicU64> {
    Arc::clone(&self.total_blocks)
}
/// Returns the total number of blocks managed by this pool (both available and acquired). /// Returns the total number of blocks managed by this pool (both available and acquired).
/// ///
/// # Returns /// # Returns
/// ///
/// The total block count as a [`u64`]. /// The total block count as a [`u64`].
pub fn total_blocks(&self) -> u64 { pub fn total_blocks(&self) -> u64 {
self.total_blocks self.total_blocks.load(Ordering::Relaxed)
} }
/// Returns the number of blocks currently available in the pool. /// Returns the number of blocks currently available in the pool.
...@@ -84,17 +104,15 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -84,17 +104,15 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// If an entry with the same sequence hash already exists in the [`lookup_map`] /// If an entry with the same sequence hash already exists in the [`lookup_map`]
/// the block is reset and moved to the [`uninitialized_set`]. /// the block is reset and moved to the [`uninitialized_set`].
/// Otherwise, the block is added to the [`lookup_map`]. /// Otherwise, the block is added to the [`lookup_map`].
/// If there are no children of the block, it is added to the [`leaf_set`].
/// If the parent of the block is in the [`leaf_set`], it is removed from the [`leaf_set`].
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `block` - The block to insert ([`Block<T, M>`]). /// * `block` - The block to insert ([`Block<T, M>`]).
/// * `sequence_hash` - The sequence hash associated with the block's content ([`SequenceHash`]). /// * `sequence_hash` - The sequence hash associated with the block's content ([`SequenceHash`]).
#[instrument(level = "trace", skip(self, block), fields(sequence_hash = ?sequence_hash))] #[instrument(level = "trace", skip(self, block), fields(sequence_hash = ?sequence_hash))]
fn insert_with_sequence_hash(&mut self, block: Block<S, M>, sequence_hash: SequenceHash) { fn insert_with_sequence_hash(&mut self, block: Block<S, L, M>, sequence_hash: SequenceHash) {
let priority_key = PriorityKey::new(block.metadata().clone(), sequence_hash); let priority_key = PriorityKey::new(block.metadata().clone(), sequence_hash);
if self.lookup_map.contains_key(&sequence_hash) { if self.priority_set.contains(&priority_key) {
tracing::trace!("multiple entries with the same sequence hash, resetting block and inserting into uninitialized set"); tracing::trace!("multiple entries with the same sequence hash, resetting block and inserting into uninitialized set");
let mut block = block; let mut block = block;
block.reset(); block.reset();
...@@ -102,27 +120,8 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -102,27 +120,8 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
} else { } else {
tracing::trace!("inserting block to map and priority set"); tracing::trace!("inserting block to map and priority set");
if let Ok(Some(parent)) = block.parent_sequence_hash() { self.priority_set.insert(priority_key);
// Add the entry for the parent->child link.
self.parent_children
.entry(parent)
.or_default()
.insert(sequence_hash);
// If the parent is currently in the inactive pool, remove it from the leaf set.
if let Some(parent_block) = self.lookup_map.get_mut(&parent) {
self.leaf_set
.remove(&PriorityKey::new(parent_block.metadata().clone(), parent));
}
}
// Create the entry for the block in the lookup map.
self.lookup_map.insert(sequence_hash, block); self.lookup_map.insert(sequence_hash, block);
// If the block has no children, it is a leaf.
if !self.parent_children.contains_key(&sequence_hash) {
self.leaf_set.insert(priority_key);
}
} }
} }
...@@ -137,7 +136,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -137,7 +136,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// ///
/// * `block` - The block to insert ([`Block<S, M>`]). /// * `block` - The block to insert ([`Block<S, M>`]).
#[instrument(level = "trace", skip(self, block), fields(block_state = ?block.state()))] #[instrument(level = "trace", skip(self, block), fields(block_state = ?block.state()))]
fn insert(&mut self, block: Block<S, M>) { fn insert(&mut self, block: Block<S, L, M>) {
tracing::trace!("Inserting block into available pool"); tracing::trace!("Inserting block into available pool");
// If we already have an entry for this sequence hash or the block is reset, // If we already have an entry for this sequence hash or the block is reset,
...@@ -161,6 +160,8 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -161,6 +160,8 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
self.insert_with_sequence_hash(block, sequence_hash); self.insert_with_sequence_hash(block, sequence_hash);
} }
} }
self.available_blocks.fetch_add(1, Ordering::Relaxed);
} }
/// Adds multiple blocks to the pool. /// Adds multiple blocks to the pool.
...@@ -171,7 +172,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -171,7 +172,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// ///
/// * `blocks` - A vector of blocks ([`Block<T, M>`]) to add. /// * `blocks` - A vector of blocks ([`Block<T, M>`]) to add.
#[instrument(level = "debug", skip(self, blocks))] #[instrument(level = "debug", skip(self, blocks))]
pub fn add_blocks(&mut self, blocks: Vec<Block<S, M>>) { pub fn add_blocks(&mut self, blocks: Vec<Block<S, L, M>>) {
let count = blocks.len(); let count = blocks.len();
tracing::debug!(count, "Adding blocks to pool"); tracing::debug!(count, "Adding blocks to pool");
...@@ -181,7 +182,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -181,7 +182,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
self.insert(block); self.insert(block);
} }
self.total_blocks += count as u64; self.total_blocks.fetch_add(count as u64, Ordering::Relaxed);
} }
/// Adds multiple blocks to the pool. /// Adds multiple blocks to the pool.
...@@ -192,10 +193,10 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -192,10 +193,10 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// ///
/// * `blocks` - A vector of blocks ([`Block<T, M>`]) to add. /// * `blocks` - A vector of blocks ([`Block<T, M>`]) to add.
#[instrument(level = "debug", skip(self, blocks))] #[instrument(level = "debug", skip(self, blocks))]
pub fn add_blocks_with_state(&mut self, blocks: Vec<Block<S, M>>) { pub fn add_blocks_with_state(&mut self, blocks: Vec<Block<S, L, M>>) {
let count = blocks.len(); let count = blocks.len();
tracing::debug!(count, "Adding blocks to pool"); tracing::debug!(count, "Adding blocks to pool");
self.total_blocks += count as u64; self.total_blocks.fetch_add(count as u64, Ordering::Relaxed);
// self.available_blocks += count as u64; // self.available_blocks += count as u64;
self.return_blocks(blocks); self.return_blocks(blocks);
} }
...@@ -209,7 +210,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -209,7 +210,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// ///
/// * `block` - The block ([`Block<S, M>`]) to return. /// * `block` - The block ([`Block<S, M>`]) to return.
#[instrument(level = "debug", skip(self, block))] #[instrument(level = "debug", skip(self, block))]
pub fn return_block(&mut self, mut block: Block<S, M>) { pub fn return_block(&mut self, mut block: Block<S, L, M>) {
// increment the return tick // increment the return tick
self.return_tick += 1; self.return_tick += 1;
...@@ -231,7 +232,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -231,7 +232,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// ///
/// * `blocks` - A vector of blocks ([`Block<T, M>`]) to return. /// * `blocks` - A vector of blocks ([`Block<T, M>`]) to return.
#[instrument(level = "debug", skip(self, blocks))] #[instrument(level = "debug", skip(self, blocks))]
pub fn return_blocks(&mut self, blocks: Vec<Block<S, M>>) { pub fn return_blocks(&mut self, blocks: Vec<Block<S, L, M>>) {
let count = blocks.len(); let count = blocks.len();
tracing::debug!(count, "Returning blocks to pool"); tracing::debug!(count, "Returning blocks to pool");
// return the block to the pool from tail to head // return the block to the pool from tail to head
...@@ -243,7 +244,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -243,7 +244,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
} }
/// Attempts to remove and return a block associated with the given sequence hash /// Attempts to remove and return a block associated with the given sequence hash
/// from the [`lookup_map`] and [`leaf_set`]. /// from the [`lookup_map`] and [`priority_set`].
/// ///
/// # Arguments /// # Arguments
/// ///
...@@ -253,13 +254,15 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -253,13 +254,15 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// ///
/// An [`Option<Block<S, M>>`] containing the block if found, otherwise `None`. /// An [`Option<Block<S, M>>`] containing the block if found, otherwise `None`.
#[instrument(level = "trace", skip(self), fields(sequence_hash = ?sequence_hash))] #[instrument(level = "trace", skip(self), fields(sequence_hash = ?sequence_hash))]
fn take_with_sequence_hash(&mut self, sequence_hash: SequenceHash) -> Option<Block<S, M>> { fn take_with_sequence_hash(&mut self, sequence_hash: SequenceHash) -> Option<Block<S, L, M>> {
match self.lookup_map.remove(&sequence_hash) { match self.lookup_map.remove(&sequence_hash) {
Some(block) => { Some(block) => {
// Remove from leaf set, if it exists. // Remove from priority set.
self.leaf_set let priority_key = PriorityKey::new(block.metadata().clone(), sequence_hash);
.remove(&PriorityKey::new(block.metadata().clone(), sequence_hash)); // Remove from priority set, if it exists.
self.priority_set.remove(&priority_key);
self.available_blocks.fetch_sub(1, Ordering::Relaxed);
Some(block) Some(block)
} }
None => None, None => None,
...@@ -278,7 +281,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -278,7 +281,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// ///
/// An [`Option<Block<S, M>>`] containing the block if found, otherwise `None`. /// An [`Option<Block<S, M>>`] containing the block if found, otherwise `None`.
#[instrument(level = "debug", skip(self), fields(sequence_hash = ?sequence_hash))] #[instrument(level = "debug", skip(self), fields(sequence_hash = ?sequence_hash))]
pub fn match_sequence_hash(&mut self, sequence_hash: SequenceHash) -> Option<Block<S, M>> { pub fn match_sequence_hash(&mut self, sequence_hash: SequenceHash) -> Option<Block<S, L, M>> {
self.take_with_sequence_hash(sequence_hash) self.take_with_sequence_hash(sequence_hash)
} }
...@@ -299,7 +302,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -299,7 +302,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
pub fn match_sequence_hashes( pub fn match_sequence_hashes(
&mut self, &mut self,
sequence_hashes: Vec<SequenceHash>, sequence_hashes: Vec<SequenceHash>,
) -> Vec<Block<S, M>> { ) -> Vec<Block<S, L, M>> {
let total_hashes = sequence_hashes.len(); let total_hashes = sequence_hashes.len();
let mut matched_blocks = Vec::with_capacity(total_hashes); let mut matched_blocks = Vec::with_capacity(total_hashes);
...@@ -332,7 +335,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -332,7 +335,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// A vector containing the blocks ([`Block<T, M>`]) that were successfully matched and taken. /// A vector containing the blocks ([`Block<T, M>`]) that were successfully matched and taken.
/// The vector may be shorter than `token_blocks` if not all corresponding hashes were found. /// The vector may be shorter than `token_blocks` if not all corresponding hashes were found.
#[instrument(level = "debug", skip(self, token_blocks), fields(num_token_blocks = token_blocks.len()))] #[instrument(level = "debug", skip(self, token_blocks), fields(num_token_blocks = token_blocks.len()))]
pub fn match_token_blocks(&mut self, token_blocks: &[TokenBlock]) -> Vec<Block<S, M>> { pub fn match_token_blocks(&mut self, token_blocks: &[TokenBlock]) -> Vec<Block<S, L, M>> {
let total_blocks = token_blocks.len(); let total_blocks = token_blocks.len();
let mut matched_blocks = Vec::with_capacity(total_blocks); let mut matched_blocks = Vec::with_capacity(total_blocks);
...@@ -375,52 +378,27 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -375,52 +378,27 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// and [`lookup_map`] (i.e., a key exists in the set but not the map). This indicates /// and [`lookup_map`] (i.e., a key exists in the set but not the map). This indicates
/// a bug in the pool's internal logic. /// a bug in the pool's internal logic.
#[instrument(level = "debug", skip(self))] #[instrument(level = "debug", skip(self))]
pub fn acquire_free_block(&mut self) -> Option<Block<S, M>> { pub fn acquire_free_block(&mut self) -> Option<Block<S, L, M>> {
// First try uninitialized blocks - these are often part of sequences // First try uninitialized blocks - these are often part of sequences
// that have been arranged in the correct order // that have been arranged in the correct order
if let Some(mut block) = self.uninitialized_set.pop_front() { if let Some(mut block) = self.uninitialized_set.pop_front() {
tracing::trace!("Acquired uninitialized block"); tracing::trace!("Acquired uninitialized block");
self.return_tick += 1; self.return_tick += 1;
block.metadata_on_acquired(self.return_tick); block.metadata_on_acquired(self.return_tick);
self.available_blocks.fetch_sub(1, Ordering::Relaxed);
return Some(block); return Some(block);
} }
// if we have blocks in the leaf set, pop the first (it's sorted by priority) // if we have blocks in the priority set, pop the first (it's sorted by priority)
// a fatal error will occur if the block is not found in the lookup map // a fatal error will occur if the block is not found in the lookup map
if let Some(key) = self.leaf_set.pop_first() { if let Some(key) = self.priority_set.pop_first() {
tracing::trace!("Acquired priority/registered block map; resetting block"); tracing::trace!("Acquired priority/registered block map; resetting block");
match self.lookup_map.remove(&key.sequence_hash()) { match self.lookup_map.remove(&key.sequence_hash()) {
Some(mut block) => { Some(mut block) => {
if let Some(children) = self.parent_children.get(&key.sequence_hash()) {
panic!(
"Block has {} inactive children, but should have none.",
children.len()
);
}
if let Ok(Some(parent)) = block.parent_sequence_hash() {
let is_leaf = match self.parent_children.get_mut(&parent) {
Some(children) => {
children.remove(&key.sequence_hash());
children.is_empty()
}
None => true,
};
if is_leaf {
self.parent_children.remove(&parent);
if let Some(parent_block) = self.lookup_map.get(&parent) {
self.leaf_set.insert(PriorityKey::new(
parent_block.metadata().clone(),
parent,
));
}
}
}
block.reset(); block.reset();
self.return_tick += 1; self.return_tick += 1;
block.metadata_on_acquired(self.return_tick); block.metadata_on_acquired(self.return_tick);
self.available_blocks.fetch_sub(1, Ordering::Relaxed);
Some(block) Some(block)
} }
None => { None => {
...@@ -457,7 +435,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -457,7 +435,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
pub fn acquire_free_blocks( pub fn acquire_free_blocks(
&mut self, &mut self,
count: usize, count: usize,
) -> Result<Vec<Block<S, M>>, BlockPoolError> { ) -> Result<Vec<Block<S, L, M>>, BlockPoolError> {
if count == 0 { if count == 0 {
return Ok(Vec::new()); return Ok(Vec::new());
} }
...@@ -529,13 +507,48 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -529,13 +507,48 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
Ok(blocks) Ok(blocks)
} }
/// Resets the pool to its initial state.
///
/// Every block is acquired (which resets its state) and then handed back
/// to the pool one at a time.
///
/// # Errors
///
/// Returns [`BlockPoolError::ResetError`] when the available-block count
/// differs from the total-block count, and propagates any error from
/// `acquire_free_blocks`.
pub fn reset(&mut self) -> Result<(), BlockPoolError> {
    let total = self.total_blocks.load(Ordering::Relaxed);
    let available = self.available_blocks.load(Ordering::Relaxed);

    if total == available {
        let reclaimed = self.acquire_free_blocks(total as usize)?;
        reclaimed
            .into_iter()
            .for_each(|block| self.return_block(block));
        Ok(())
    } else {
        Err(BlockPoolError::ResetError(format!(
            "total blocks: {}, available blocks: {}",
            total, available
        )))
    }
}
/// Reports the sizes of the pool's two block collections.
///
/// # Returns
///
/// A tuple `(inactive_blocks, empty_blocks)`: the number of entries in the
/// priority set, followed by the number of blocks in the uninitialized set.
pub fn status(&self) -> (usize, usize) {
    (self.priority_set.len(), self.uninitialized_set.len())
}
} }
#[cfg(test)] #[cfg(test)]
pub(crate) mod tests { pub(crate) mod tests {
use crate::{ use crate::{
block_manager::{ block_manager::{
block::{registry::BlockRegistry, state::CompleteState, Blocks, PrivateBlockExt}, block::{
locality::Local, registry::BlockRegistry, state::CompleteState, Blocks,
PrivateBlockExt,
},
events::NullEventManager, events::NullEventManager,
layout::{BlockLayout, FullyContiguous, LayoutConfigBuilder}, layout::{BlockLayout, FullyContiguous, LayoutConfigBuilder},
storage::tests::{NullDeviceAllocator, NullDeviceStorage}, storage::tests::{NullDeviceAllocator, NullDeviceStorage},
...@@ -650,7 +663,7 @@ pub(crate) mod tests { ...@@ -650,7 +663,7 @@ pub(crate) mod tests {
tokens: Tokens, tokens: Tokens,
block_size: u32, block_size: u32,
async_runtime: Handle, async_runtime: Handle,
) -> Vec<Block<NullDeviceStorage, TestMetadata>> { ) -> Vec<Block<NullDeviceStorage, Local, TestMetadata>> {
let (token_blocks, _partial_token_block) = let (token_blocks, _partial_token_block) =
tokens.into_sequence(block_size, None).into_parts(); tokens.into_sequence(block_size, None).into_parts();
let num_blocks = token_blocks.len(); let num_blocks = token_blocks.len();
...@@ -681,7 +694,7 @@ pub(crate) mod tests { ...@@ -681,7 +694,7 @@ pub(crate) mod tests {
pub fn create_block_pool( pub fn create_block_pool(
num_blocks: usize, num_blocks: usize,
) -> InactiveBlockPool<NullDeviceStorage, TestMetadata> { ) -> InactiveBlockPool<NullDeviceStorage, Local, TestMetadata> {
let mut pool = InactiveBlockPool::new(); let mut pool = InactiveBlockPool::new();
let blocks = create_block_collection(num_blocks).into_blocks().unwrap(); let blocks = create_block_collection(num_blocks).into_blocks().unwrap();
pool.add_blocks(blocks); pool.add_blocks(blocks);
...@@ -692,9 +705,9 @@ pub(crate) mod tests { ...@@ -692,9 +705,9 @@ pub(crate) mod tests {
pub fn acquire_blocks( pub fn acquire_blocks(
tokens: Tokens, tokens: Tokens,
block_size: u32, block_size: u32,
pool: &mut InactiveBlockPool<NullDeviceStorage, TestMetadata>, pool: &mut InactiveBlockPool<NullDeviceStorage, Local, TestMetadata>,
async_runtime: Handle, async_runtime: Handle,
) -> (Vec<Block<NullDeviceStorage, TestMetadata>>, usize) { ) -> (Vec<Block<NullDeviceStorage, Local, TestMetadata>>, usize) {
let (mut token_blocks, _partial_token_block) = let (mut token_blocks, _partial_token_block) =
tokens.into_sequence(block_size, None).into_parts(); tokens.into_sequence(block_size, None).into_parts();
...@@ -764,6 +777,10 @@ pub(crate) mod tests { ...@@ -764,6 +777,10 @@ pub(crate) mod tests {
assert_eq!(pool.total_blocks(), 10); assert_eq!(pool.total_blocks(), 10);
assert_eq!(pool.available_blocks(), 10); assert_eq!(pool.available_blocks(), 10);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
let tokens = create_token_sequence(&[1, 2, 3, 4]); let tokens = create_token_sequence(&[1, 2, 3, 4]);
...@@ -776,11 +793,19 @@ pub(crate) mod tests { ...@@ -776,11 +793,19 @@ pub(crate) mod tests {
assert_eq!(blocks.len(), 2); assert_eq!(blocks.len(), 2);
assert_eq!(matched_block_count, 0); assert_eq!(matched_block_count, 0);
assert_eq!(pool.available_blocks(), 8); assert_eq!(pool.available_blocks(), 8);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
pool.return_blocks(blocks); pool.return_blocks(blocks);
assert_eq!(pool.total_blocks(), 10); assert_eq!(pool.total_blocks(), 10);
assert_eq!(pool.available_blocks(), 10); assert_eq!(pool.available_blocks(), 10);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
let (blocks, matched_block_count) = acquire_blocks( let (blocks, matched_block_count) = acquire_blocks(
tokens.clone(), tokens.clone(),
...@@ -791,11 +816,19 @@ pub(crate) mod tests { ...@@ -791,11 +816,19 @@ pub(crate) mod tests {
assert_eq!(blocks.len(), 2); assert_eq!(blocks.len(), 2);
assert_eq!(matched_block_count, 2); assert_eq!(matched_block_count, 2);
assert_eq!(pool.available_blocks(), 8); assert_eq!(pool.available_blocks(), 8);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
pool.return_blocks(blocks); pool.return_blocks(blocks);
assert_eq!(pool.total_blocks(), 10); assert_eq!(pool.total_blocks(), 10);
assert_eq!(pool.available_blocks(), 10); assert_eq!(pool.available_blocks(), 10);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
let blocks = pool.acquire_free_blocks(10).unwrap(); let blocks = pool.acquire_free_blocks(10).unwrap();
for block in &blocks { for block in &blocks {
...@@ -828,6 +861,10 @@ pub(crate) mod tests { ...@@ -828,6 +861,10 @@ pub(crate) mod tests {
assert_eq!(pool.total_blocks(), 2); assert_eq!(pool.total_blocks(), 2);
assert_eq!(pool.available_blocks(), 2); assert_eq!(pool.available_blocks(), 2);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
// Match the blocks in sequence // Match the blocks in sequence
let matched = pool.match_sequence_hashes(hashes.clone()); let matched = pool.match_sequence_hashes(hashes.clone());
...@@ -835,6 +872,10 @@ pub(crate) mod tests { ...@@ -835,6 +872,10 @@ pub(crate) mod tests {
assert_eq!(pool.total_blocks(), 2); assert_eq!(pool.total_blocks(), 2);
assert_eq!(pool.available_blocks(), 0); assert_eq!(pool.available_blocks(), 0);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
// Validate the blocks are in the correct order and match the sequence hashes // Validate the blocks are in the correct order and match the sequence hashes
assert_eq!(matched[0].sequence_hash().unwrap(), hashes[0]); assert_eq!(matched[0].sequence_hash().unwrap(), hashes[0]);
...@@ -845,5 +886,9 @@ pub(crate) mod tests { ...@@ -845,5 +886,9 @@ pub(crate) mod tests {
assert_eq!(pool.total_blocks(), 2); assert_eq!(pool.total_blocks(), 2);
assert_eq!(pool.available_blocks(), 2); assert_eq!(pool.available_blocks(), 2);
assert_eq!(
pool.available_blocks_counter().load(Ordering::Relaxed),
pool.available_blocks()
);
} }
} }
...@@ -20,10 +20,13 @@ use crate::block_manager::{ ...@@ -20,10 +20,13 @@ use crate::block_manager::{
use super::*; use super::*;
impl<S: Storage, M: BlockMetadata> State<S, M> { use active::ActiveBlockPool;
fn new( use inactive::InactiveBlockPool;
impl<S: Storage, L: LocalityProvider + 'static, M: BlockMetadata> State<S, L, M> {
pub fn new(
event_manager: Arc<dyn EventManager>, event_manager: Arc<dyn EventManager>,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, M>>, return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, L, M>>,
global_registry: GlobalRegistry, global_registry: GlobalRegistry,
async_runtime: Handle, async_runtime: Handle,
metrics: Arc<PoolMetrics>, metrics: Arc<PoolMetrics>,
...@@ -38,10 +41,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -38,10 +41,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
} }
} }
async fn handle_priority_request( pub async fn handle_priority_request(
&mut self, &mut self,
req: PriorityRequest<S, M>, req: PriorityRequest<S, L, M>,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>, return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, L, M>>,
) { ) {
match req { match req {
PriorityRequest::AllocateBlocks(req) => { PriorityRequest::AllocateBlocks(req) => {
...@@ -52,8 +55,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -52,8 +55,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
} }
} }
PriorityRequest::RegisterBlocks(req) => { PriorityRequest::RegisterBlocks(req) => {
let (blocks, resp_tx) = req.dissolve(); let ((blocks, duplication_setting), resp_tx) = req.dissolve();
let immutable_blocks = self.register_blocks(blocks, return_rx).await; let immutable_blocks = self
.register_blocks(blocks, duplication_setting, return_rx)
.await;
if resp_tx.send(immutable_blocks).is_err() { if resp_tx.send(immutable_blocks).is_err() {
tracing::error!("failed to send response to register blocks"); tracing::error!("failed to send response to register blocks");
} }
...@@ -61,14 +66,37 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -61,14 +66,37 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
PriorityRequest::MatchSequenceHashes(req) => { PriorityRequest::MatchSequenceHashes(req) => {
let (sequence_hashes, resp_tx) = req.dissolve(); let (sequence_hashes, resp_tx) = req.dissolve();
let immutable_blocks = self.match_sequence_hashes(sequence_hashes, return_rx).await; let immutable_blocks = self.match_sequence_hashes(sequence_hashes, return_rx).await;
if resp_tx.send(immutable_blocks).is_err() { if resp_tx.send(Ok(immutable_blocks)).is_err() {
tracing::error!("failed to send response to match sequence hashes"); tracing::error!("failed to send response to match sequence hashes");
} }
} }
PriorityRequest::TouchBlocks(req) => {
let (sequence_hashes, resp_tx) = req.dissolve();
self.touch_blocks(&sequence_hashes, return_rx).await;
if resp_tx.send(Ok(())).is_err() {
tracing::error!("failed to send response to touch blocks");
}
}
PriorityRequest::Reset(req) => {
let (_req, resp_tx) = req.dissolve();
let result = self.inactive.reset();
if resp_tx.send(result).is_err() {
tracing::error!("failed to send response to reset");
}
}
PriorityRequest::ReturnBlock(req) => {
let (returnable_blocks, resp_tx) = req.dissolve();
for block in returnable_blocks {
self.return_block(block);
}
if resp_tx.send(Ok(())).is_err() {
tracing::error!("failed to send response to return block");
}
}
} }
} }
fn handle_control_request(&mut self, req: ControlRequest<S, M>) { pub fn handle_control_request(&mut self, req: ControlRequest<S, L, M>) {
match req { match req {
ControlRequest::AddBlocks(blocks) => { ControlRequest::AddBlocks(blocks) => {
let (blocks, resp_rx) = blocks.dissolve(); let (blocks, resp_rx) = blocks.dissolve();
...@@ -77,10 +105,25 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -77,10 +105,25 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
tracing::error!("failed to send response to add blocks"); tracing::error!("failed to send response to add blocks");
} }
} }
ControlRequest::Status(req) => {
let (_, resp_rx) = req.dissolve();
if resp_rx.send(Ok(self.status())).is_err() {
tracing::error!("failed to send response to status");
}
}
ControlRequest::ResetBlocks(req) => {
let (sequence_hashes, resp_rx) = req.dissolve();
if resp_rx
.send(Ok(self.try_reset_blocks(&sequence_hashes)))
.is_err()
{
tracing::error!("failed to send response to reset blocks");
}
}
} }
} }
fn handle_return_block(&mut self, block: Block<S, M>) { pub fn handle_return_block(&mut self, block: Block<S, L, M>) {
self.return_block(block); self.return_block(block);
} }
...@@ -89,8 +132,8 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -89,8 +132,8 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
async fn wait_for_returned_block( async fn wait_for_returned_block(
&mut self, &mut self,
sequence_hash: SequenceHash, sequence_hash: SequenceHash,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>, return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, L, M>>,
) -> Block<S, M> { ) -> Block<S, L, M> {
while let Some(block) = return_rx.recv().await { while let Some(block) = return_rx.recv().await {
if matches!(block.state(), BlockState::Registered(handle, _) if handle.sequence_hash() == sequence_hash) if matches!(block.state(), BlockState::Registered(handle, _) if handle.sequence_hash() == sequence_hash)
{ {
...@@ -105,7 +148,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -105,7 +148,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
pub fn allocate_blocks( pub fn allocate_blocks(
&mut self, &mut self,
count: usize, count: usize,
) -> Result<Vec<MutableBlock<S, M>>, BlockPoolError> { ) -> Result<Vec<MutableBlock<S, L, M>>, BlockPoolError> {
let available_blocks = self.inactive.available_blocks() as usize; let available_blocks = self.inactive.available_blocks() as usize;
if available_blocks < count { if available_blocks < count {
...@@ -135,11 +178,15 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -135,11 +178,15 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
Ok(blocks) Ok(blocks)
} }
#[tracing::instrument(level = "debug", skip_all, fields(blocks = ?blocks))]
pub async fn register_blocks( pub async fn register_blocks(
&mut self, &mut self,
blocks: Vec<MutableBlock<S, M>>, blocks: Vec<MutableBlock<S, L, M>>,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>, duplication_setting: BlockRegistrationDuplicationSetting,
) -> Result<Vec<ImmutableBlock<S, M>>, BlockPoolError> { return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, L, M>>,
) -> Result<Vec<ImmutableBlock<S, L, M>>, BlockPoolError> {
assert!(!blocks.is_empty(), "no blocks to register");
let expected_len = blocks.len(); let expected_len = blocks.len();
let mut immutable_blocks = Vec::new(); let mut immutable_blocks = Vec::new();
...@@ -151,44 +198,81 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -151,44 +198,81 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
// If the block is already registered, acquire a clone of the immutable block // If the block is already registered, acquire a clone of the immutable block
if let Some(immutable) = self.active.match_sequence_hash(sequence_hash) { if let Some(immutable) = self.active.match_sequence_hash(sequence_hash) {
let immutable = if duplication_setting
== BlockRegistrationDuplicationSetting::Allowed
{
immutable.with_duplicate(block.into()).expect("incompatible immutable block; only primary should be returned from match_sequence_hash")
} else {
// immediate return the block to the pool if duplicates are disabled
if let Some(blocks) = block.try_take_block(private::PrivateToken) {
self.inactive.return_blocks(blocks);
}
immutable
};
immutable_blocks.push(immutable); immutable_blocks.push(immutable);
continue; continue;
} }
let mut offload = true; let mut offload = true;
let mutable = if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash) let (mutable, duplicate) =
{ if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash) {
assert!(matches!(raw_block.state(), BlockState::Registered(_, _))); // We already have a match, so our block is a duplicate.
MutableBlock::new(raw_block, self.return_tx.clone()) assert!(matches!(raw_block.state(), BlockState::Registered(_, _)));
} else { (
// Attempt to register the block MutableBlock::new(raw_block, self.return_tx.clone()),
// On the very rare chance that the block is registered, but in the process of being returned, Some(block),
// we will wait for it to be returned and then register it. )
let result = block.register(&mut self.registry); } else {
// Attempt to register the block
match result { // On the very rare chance that the block is registered, but in the process of being returned,
Ok(handle) => { // we will wait for it to be returned and then register it.
// Only create our publish handle if this block is new, and not transfered. let result = block.register(&mut self.registry);
if let Some(handle) = handle {
publish_handles.take_handle(handle); match result {
Ok(handle) => {
// Only create our publish handle if this block is new, and not transfered.
if let Some(handle) = handle {
publish_handles.take_handle(handle);
}
(block, None)
}
Err(BlockRegistrationError::BlockAlreadyRegistered(_)) => {
// Block is already registered, wait for it to be returned
// Return the original block as the primary, and the block we passed in as the duplicate.
offload = false;
let raw_block =
self.wait_for_returned_block(sequence_hash, return_rx).await;
(
MutableBlock::new(raw_block, self.return_tx.clone()),
Some(block),
)
}
Err(e) => {
return Err(BlockPoolError::FailedToRegisterBlock(e.to_string()));
} }
block
} }
Err(BlockRegistrationError::BlockAlreadyRegistered(_)) => { };
// Block is already registered, wait for it to be returned
offload = false; let mut immutable = self.active.register(mutable)?;
let raw_block =
self.wait_for_returned_block(sequence_hash, return_rx).await; match duplication_setting {
MutableBlock::new(raw_block, self.return_tx.clone()) BlockRegistrationDuplicationSetting::Allowed => {
if let Some(duplicate) = duplicate {
immutable = immutable
.with_duplicate(duplicate.into())
.expect("incompatible immutable block; only primary should be returned from ActiveBlockPool::register");
} }
Err(e) => { }
return Err(BlockPoolError::FailedToRegisterBlock(e.to_string())); BlockRegistrationDuplicationSetting::Disabled => {
if let Some(block) = duplicate {
if let Some(raw_blocks) = block.try_take_block(private::PrivateToken) {
self.inactive.return_blocks(raw_blocks);
}
} }
} }
}; }
let immutable = self.active.register(mutable)?;
if offload { if offload {
if let Some(priority) = immutable.metadata().offload_priority() { if let Some(priority) = immutable.metadata().offload_priority() {
...@@ -211,8 +295,8 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -211,8 +295,8 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
async fn match_sequence_hashes( async fn match_sequence_hashes(
&mut self, &mut self,
sequence_hashes: Vec<SequenceHash>, sequence_hashes: Vec<SequenceHash>,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>, return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, L, M>>,
) -> Vec<ImmutableBlock<S, M>> { ) -> Vec<ImmutableBlock<S, L, M>> {
let mut immutable_blocks = Vec::new(); let mut immutable_blocks = Vec::new();
for sequence_hash in &sequence_hashes { for sequence_hash in &sequence_hashes {
if !self.registry.is_registered(*sequence_hash) { if !self.registry.is_registered(*sequence_hash) {
...@@ -245,7 +329,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -245,7 +329,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
let immutable = self let immutable = self
.active .active
.register(mutable) .register(mutable)
.expect("unable to register block; should ever happen"); .expect("unable to register block; should never happen");
immutable_blocks.push(immutable); immutable_blocks.push(immutable);
} }
...@@ -260,8 +344,31 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -260,8 +344,31 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
immutable_blocks immutable_blocks
} }
async fn touch_blocks(
&mut self,
sequence_hashes: &[SequenceHash],
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, L, M>>,
) {
for sequence_hash in sequence_hashes {
if !self.registry.is_registered(*sequence_hash) {
break;
}
let block = if let Some(block) = self.inactive.match_sequence_hash(*sequence_hash) {
block
} else if self.active.match_sequence_hash(*sequence_hash).is_none() {
self.wait_for_returned_block(*sequence_hash, return_rx)
.await
} else {
continue;
};
self.inactive.return_block(block);
}
}
/// Returns a block to the inactive pool /// Returns a block to the inactive pool
pub fn return_block(&mut self, mut block: Block<S, M>) { pub fn return_block(&mut self, mut block: Block<S, L, M>) {
self.active.remove(&mut block); self.active.remove(&mut block);
self.inactive.return_block(block); self.inactive.return_block(block);
} }
...@@ -269,111 +376,41 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -269,111 +376,41 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
fn publisher(&self) -> Publisher { fn publisher(&self) -> Publisher {
Publisher::new(self.event_manager.clone()) Publisher::new(self.event_manager.clone())
} }
}
impl<S: Storage, M: BlockMetadata> ProgressEngine<S, M> { fn status(&self) -> BlockPoolStatus {
#[allow(clippy::too_many_arguments)] let active = self.active.status();
pub fn new( let (inactive, empty) = self.inactive.status();
event_manager: Arc<dyn EventManager>, BlockPoolStatus {
priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, M>>, active_blocks: active,
ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, M>>, inactive_blocks: inactive,
cancel_token: CancellationToken, empty_blocks: empty,
blocks: Vec<Block<S, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
metrics: Arc<PoolMetrics>,
) -> Self {
let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel();
let mut state = State::<S, M>::new(
event_manager,
return_tx,
global_registry,
async_runtime,
metrics.clone(),
);
tracing::debug!(count = blocks.len(), "adding blocks to inactive pool");
state.inactive.add_blocks(blocks);
Self {
priority_rx,
ctrl_rx,
cancel_token,
state,
return_rx,
metrics,
} }
} }
pub async fn step(&mut self) -> bool { fn try_reset_blocks(&mut self, sequence_hashes: &[SequenceHash]) -> ResetBlocksResponse {
tokio::select! { let mut reset_blocks = Vec::new();
biased; let mut not_found = Vec::new();
let mut not_reset = Vec::new();
Some(priority_req) = self.priority_rx.recv(), if !self.priority_rx.is_closed() => { for sequence_hash in sequence_hashes {
self.metrics.gauge("priority_request_queue_size").set(self.priority_rx.len() as i64); if !self.registry.is_registered(*sequence_hash) {
self.state.handle_priority_request(priority_req, &mut self.return_rx).await; not_found.push(*sequence_hash);
} continue;
Some(req) = self.ctrl_rx.recv(), if !self.ctrl_rx.is_closed() => {
self.metrics.gauge("control_request_queue_size").set(self.ctrl_rx.len() as i64);
self.state.handle_control_request(req);
}
Some(block) = self.return_rx.recv() => {
self.metrics.gauge("return_block_queue_size").set(self.return_rx.len() as i64);
self.state.handle_return_block(block);
} }
_ = self.cancel_token.cancelled() => { if let Some(mut block) = self.inactive.match_sequence_hash(*sequence_hash) {
return false; reset_blocks.push(*sequence_hash);
block.reset();
self.inactive.return_block(block);
} else {
not_reset.push(*sequence_hash);
} }
} }
true ResetBlocksResponse {
reset_blocks,
not_found,
not_reset,
}
} }
} }
// pub(crate) async fn progress_engine<S: Storage, M: BlockMetadata>(
// event_manager: Arc<dyn EventManager>,
// mut priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, M>>,
// mut ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, M>>,
// cancel_token: CancellationToken,
// ) {
// let (return_tx, mut return_rx) = tokio::sync::mpsc::unbounded_channel();
// let mut state = State::<S, M>::new(event_manager, return_tx);
// loop {
// tokio::select! {
// biased;
// Some(priority_req) = priority_rx.recv(), if !priority_rx.is_closed() => {
// state.handle_priority_request(priority_req, &mut return_rx).await;
// }
// Some(req) = ctrl_rx.recv(), if !ctrl_rx.is_closed() => {
// state.handle_control_request(req);
// }
// Some(block) = return_rx.recv() => {
// state.handle_return_block(block);
// }
// _ = cancel_token.cancelled() => {
// break;
// }
// }
// }
// }
// pub(crate) async fn progress_engine_v2<S: Storage, M: BlockMetadata>(
// event_manager: Arc<dyn EventManager>,
// priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, M>>,
// ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, M>>,
// cancel_token: CancellationToken,
// ) {
// let mut progress_engine =
// ProgressEngine::<S, M>::new(event_manager, priority_rx, ctrl_rx, cancel_token);
// while progress_engine.step().await {
// tracing::trace!("progress engine step");
// }
// }
...@@ -13,190 +13,255 @@ ...@@ -13,190 +13,255 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
mod local;
mod logical;
mod resources;
use crate::block_manager::block::{factory::IntoBlocks, MutableBlock};
use crate::block_manager::locality::LogicalResources;
use crate::block_manager::offload::request::BlockResult;
use super::*; use super::*;
use super::offload::OffloadManager; // use super::offload::OffloadManager;
use super::{ use super::{
block::{Block, GlobalRegistry, ImmutableBlock}, block::{
factory::LocalBlockDataFactory, locality::LocalityProvider, Block, GlobalRegistry,
ImmutableBlock,
},
config::NixlOptions, config::NixlOptions,
events::{EventManager, NullEventManager}, events::{EventManager, NullEventManager},
metrics::{BlockManagerMetrics, PoolMetrics}, metrics::BlockManagerMetrics,
offload::OffloadManager,
}; };
use derive_getters::Dissolve;
use std::sync::Arc; use std::sync::Arc;
use tokio::runtime::Handle; use tokio::runtime::Handle;
use tokio::sync::oneshot;
#[allow(dead_code)] pub(crate) struct Resources {
pub struct KvBlockManagerState<Metadata: BlockMetadata> { pub worker_id: WorkerID,
worker_id: WorkerID, pub cancellation_token: CancellationToken,
cancellation_token: CancellationToken, pub async_rt_handle: Handle,
nixl_agent: Arc<Option<NixlAgent>>, // nixl agent/backends for the block manager
nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>, pub nixl_agent: Arc<Option<NixlAgent>>,
#[expect(dead_code)]
pub nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>,
disk_pool: Option<Arc<BlockPool<DiskStorage, Metadata>>>, // registry for blocks across all storage types
host_pool: Option<Arc<BlockPool<PinnedStorage, Metadata>>>, pub global_registry: GlobalRegistry,
device_pool: Option<Arc<BlockPool<DeviceStorage, Metadata>>>,
local_block_set: NixlBlockSet, // event manager for block manager events
remote_block_sets: RwLock<HashMap<WorkerID, HashMap<usize, RemoteBlocks>>>, pub event_manager: Arc<dyn EventManager>,
offload_manager: Arc<OffloadManager<Metadata>>, // metrics for the block manager
pub metrics: Arc<BlockManagerMetrics>,
// config for the block manager
pub config: KvBlockManagerConfig,
} }
impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { #[allow(dead_code)]
pub fn new(config: KvBlockManagerConfig) -> Result<Arc<Self>> { pub struct KvBlockManagerState<Locality: LocalityProvider, Metadata: BlockMetadata> {
config resources: Arc<Resources>,
.runtime
.validate()
.context("Validating runtime config")?;
config.model.validate().context("Validating model config")?; disk_pool: Option<Arc<dyn BlockPool<DiskStorage, Locality, Metadata>>>,
host_pool: Option<Arc<dyn BlockPool<PinnedStorage, Locality, Metadata>>>,
device_pool: Option<Arc<dyn BlockPool<DeviceStorage, Locality, Metadata>>>,
let worker_id = config.runtime.worker_id; local_block_set: NixlBlockSet,
let cancellation_token = config.runtime.cancellation_token; remote_block_sets: RwLock<HashMap<WorkerID, HashMap<usize, RemoteBlocks>>>,
offload_manager: Arc<OffloadManager<Locality, Metadata>>,
}
// Create a map of NIXL backends impl<Locality: LocalityProvider, Metadata: BlockMetadata> KvBlockManagerState<Locality, Metadata> {
let mut nixl_backends: HashMap<String, Arc<nixl_sys::Backend>> = HashMap::new(); pub fn disk(&self) -> Option<&dyn BlockPool<DiskStorage, Locality, Metadata>> {
self.disk_pool.as_ref().map(|pool| pool.as_ref())
}
let global_registry = GlobalRegistry::default(); pub fn host(&self) -> Option<&dyn BlockPool<PinnedStorage, Locality, Metadata>> {
self.host_pool.as_ref().map(|pool| pool.as_ref())
}
let metrics = BlockManagerMetrics::new(&config.runtime.metrics_registry)?; pub fn device(&self) -> Option<&dyn BlockPool<DeviceStorage, Locality, Metadata>> {
self.device_pool.as_ref().map(|pool| pool.as_ref())
}
let event_manager = config pub fn worker_id(&self) -> WorkerID {
.event_manager self.resources.worker_id
.clone() }
.unwrap_or_else(|| NullEventManager::new());
// Create a NIXL agent if NIXL is enabled and instantiate requested backends pub(crate) async fn enqueue_offload_block<S: Storage + 'static>(
// TODO: Build a map of NIXL backends to block pools/sets &self,
let nixl_agent = Arc::new(match config.runtime.nixl { block: &ImmutableBlock<S, Locality, Metadata>,
NixlOptions::Enabled => { priority: u64,
tracing::debug!("Creating NIXL agent"); ) -> Result<()> {
let agent = NixlAgent::new(&worker_id.to_string())?; self.offload_manager.offload(block, priority).await?;
tracing::debug!("Creating NIXL backends"); Ok(())
}
if let Ok((_, ucx_params)) = agent.get_plugin_params("UCX") { pub fn onboard_blocks<S: Storage + 'static>(
let backend = agent.create_backend("UCX", &ucx_params)?; &self,
nixl_backends.insert("UCX".to_string(), Arc::new(backend)); blocks: Vec<ImmutableBlock<S, Locality, Metadata>>,
} else { targets: Option<Vec<MutableBlock<DeviceStorage, Locality, Metadata>>>,
tracing::warn!("No UCX plugin found; will not create UCX backend"); ) -> oneshot::Receiver<BlockResult<DeviceStorage, Locality, Metadata>> {
} self.offload_manager.onboard(blocks, targets)
}
}
if config.disk_layout.is_some() { impl<R: LogicalResources, Metadata: BlockMetadata>
if let Ok((_, gds_params)) = agent.get_plugin_params("GDS") { KvBlockManagerState<locality::Logical<R>, Metadata>
let backend = agent.create_backend("GDS", &gds_params)?; {
nixl_backends.insert("GDS".to_string(), Arc::new(backend)); pub async fn new(config: KvBlockManagerConfig, logical_resources: R) -> Result<Arc<Self>> {
} else { let mut resources = Resources::new(config)?;
tracing::warn!("No GDS plugin found; will not create GDS backend"); let block_data_factories =
} logical::LogicalBlockFactories::new(&mut resources, logical_resources)?;
}
let (disk_factory, host_factory, device_factory) = block_data_factories.dissolve();
let (disk_pool, disk_blocks) = match disk_factory {
Some(factory) => {
let (pool, blocks) =
create_block_pool::<_, _, Metadata>(factory, &resources, "disk")?;
(Some(pool), Some(blocks))
}
None => {
tracing::debug!("No disk layout provided; will not allocate disk blocks.");
(None, None)
}
};
Some(agent) let (host_pool, host_blocks) = match host_factory {
Some(factory) => {
let (pool, blocks) =
create_block_pool::<_, _, Metadata>(factory, &resources, "host")?;
(Some(pool), Some(blocks))
} }
NixlOptions::EnabledWithAgent(agent) => Some(agent), None => {
NixlOptions::Disabled => None, tracing::debug!("No host layout provided; will not allocate host blocks.");
}); (None, None)
}
};
// Initialize model-specific layout config. The layout_builder is incomplete at this point. let (device_pool, device_blocks) = match device_factory {
// We will clone this builder and apply the storage-specific configs to each clone in the Some(factory) => {
// following steps. let (pool, blocks) =
let model = &config.model; create_block_pool::<_, _, Metadata>(factory, &resources, "device")?;
let mut layout_builder = LayoutConfig::builder(); (Some(pool), Some(blocks))
}
layout_builder None => {
.num_layers(model.num_layers) tracing::debug!("No device layout provided; will not allocate device blocks.");
.outer_dim(model.outer_dim) (None, None)
.page_size(model.page_size) }
.inner_dim(model.inner_dim)
.dtype(model.dtype);
let mut next_block_set_idx = 0;
let mut local_block_set = block::nixl::NixlBlockSet::new(worker_id);
let async_rt_handle = match config.runtime.async_runtime {
Some(rt) => rt.handle().clone(),
None => match Handle::try_current() {
Ok(handle) => handle,
Err(e) => anyhow::bail!(e),
},
}; };
let (disk_pool, disk_blocks) = if let Some(config) = config.disk_layout { let offload_manager = OffloadManager::new(
if nixl_agent.is_none() { disk_pool.clone(),
tracing::warn!("NIXL is disabled; will not allocate disk blocks."); host_pool.clone(),
device_pool.clone(),
resources.nixl_agent.clone(),
resources.async_rt_handle.clone(),
resources.metrics.clone(),
resources.cancellation_token.clone(),
)?;
let resources = Arc::new(resources);
let state = Arc::new(Self {
resources: resources.clone(),
disk_pool,
host_pool,
device_pool,
local_block_set: NixlBlockSet::new(resources.worker_id),
remote_block_sets: RwLock::new(HashMap::new()),
offload_manager,
});
if let Some(mut blocks) = disk_blocks {
blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone());
});
state.disk_pool.as_ref().unwrap().add_blocks(blocks).await?;
}
if let Some(mut blocks) = host_blocks {
blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone());
});
state.host_pool.as_ref().unwrap().add_blocks(blocks).await?;
}
if let Some(mut blocks) = device_blocks {
blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone());
});
state
.device_pool
.as_ref()
.unwrap()
.add_blocks(blocks)
.await?;
}
Ok(state)
}
}
// move into mod local
// move local block data factory into mod super::block
// create a method on locality to construct a block data factory from a layout builder and resources
// - this will allow us to use the locality abstraction to build our factories and block pools
impl<Metadata: BlockMetadata> KvBlockManagerState<locality::Local, Metadata> {
pub async fn new(config: KvBlockManagerConfig) -> Result<Arc<Self>> {
let mut resources = Resources::new(config)?;
let block_data_factories = local::LocalBlockDataFactories::new(&mut resources)?;
let (mut local_block_set, disk_factory, host_factory, device_factory) =
block_data_factories.dissolve();
let (disk_pool, disk_blocks) = match disk_factory {
Some(factory) => {
let (pool, blocks) =
create_block_pool::<_, _, Metadata>(factory, &resources, "disk")?;
(Some(pool), Some(blocks))
}
None => {
tracing::debug!("No disk layout provided; will not allocate disk blocks.");
(None, None) (None, None)
} else {
next_block_set_idx += 1;
tracing::debug!("Constructing disk pool.");
let layout =
create_layout(layout_builder.clone(), config, nixl_agent.as_ref().as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>(
layout,
next_block_set_idx,
cancellation_token.clone(),
worker_id,
global_registry.clone(),
async_rt_handle.clone(),
metrics.pool("disk"),
Some(event_manager.clone()),
)?;
(Some(Arc::new(pool)), Some(blocks))
} }
} else {
tracing::debug!("No disk layout provided; will not allocate disk blocks.");
(None, None)
}; };
// Create the host block pool if a host layout is provided let (host_pool, host_blocks) = match host_factory {
let (host_pool, host_blocks) = if let Some(config) = config.host_layout { Some(factory) => {
next_block_set_idx += 1; let (pool, blocks) =
tracing::debug!("Constructing host pool."); create_block_pool::<_, _, Metadata>(factory, &resources, "host")?;
let layout = (Some(pool), Some(blocks))
create_layout(layout_builder.clone(), config, nixl_agent.as_ref().as_ref())?; }
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?); None => {
let (pool, blocks) = create_block_pool::<_, Metadata>( tracing::debug!("No disk layout provided; will not allocate disk blocks.");
layout, (None, None)
next_block_set_idx, }
cancellation_token.clone(),
worker_id,
global_registry.clone(),
async_rt_handle.clone(),
metrics.pool("host"),
Some(event_manager.clone()),
)?;
(Some(Arc::new(pool)), Some(blocks))
} else {
tracing::debug!("No host layout provided; will not allocate host blocks.");
(None, None)
}; };
// Create the device block pool if a device layout is provided let (device_pool, device_blocks) = match device_factory {
let (device_pool, device_blocks) = if let Some(config) = config.device_layout { Some(factory) => {
next_block_set_idx += 1; let (pool, blocks) =
tracing::debug!("Constructing device pool."); create_block_pool::<_, _, Metadata>(factory, &resources, "disk")?;
let layout = (Some(pool), Some(blocks))
create_layout(layout_builder.clone(), config, nixl_agent.as_ref().as_ref())?; }
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?); None => {
let (pool, blocks) = create_block_pool::<_, Metadata>( tracing::debug!("No disk layout provided; will not allocate disk blocks.");
layout, (None, None)
next_block_set_idx, }
cancellation_token.clone(),
worker_id,
global_registry.clone(),
async_rt_handle.clone(),
metrics.pool("device"),
Some(event_manager.clone()),
)?;
(Some(Arc::new(pool)), Some(blocks))
} else {
tracing::debug!("No device layout provided; will not allocate device blocks.");
(None, None)
}; };
// Finalize the local block set by adding NIXL metadata // Finalize the local block set by adding NIXL metadata
if let Some(nixl_agent) = nixl_agent.as_ref() { if let Some(nixl_agent) = resources.nixl_agent.as_ref() {
tracing::debug!("Finalize NixlBlockSet: adding NIXL metadata."); tracing::debug!("Finalize NixlBlockSet: adding NIXL metadata.");
local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?); local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?);
} }
...@@ -205,17 +270,16 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -205,17 +270,16 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
disk_pool.clone(), disk_pool.clone(),
host_pool.clone(), host_pool.clone(),
device_pool.clone(), device_pool.clone(),
nixl_agent.clone(), resources.nixl_agent.clone(),
async_rt_handle, resources.async_rt_handle.clone(),
metrics.clone(), resources.metrics.clone(),
cancellation_token.clone(), resources.cancellation_token.clone(),
)?; )?;
let resources = Arc::new(resources);
let state = Arc::new(Self { let state = Arc::new(Self {
worker_id, resources: resources.clone(),
cancellation_token,
nixl_agent,
nixl_backends,
disk_pool, disk_pool,
host_pool, host_pool,
device_pool, device_pool,
...@@ -229,12 +293,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -229,12 +293,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
block.set_manager(state.clone()); block.set_manager(state.clone());
}); });
state state.disk_pool.as_ref().unwrap().add_blocks(blocks).await?;
.disk_pool
.as_ref()
.as_ref()
.unwrap()
.add_blocks_blocking(blocks)?;
} }
if let Some(mut blocks) = host_blocks { if let Some(mut blocks) = host_blocks {
...@@ -242,12 +301,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -242,12 +301,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
block.set_manager(state.clone()); block.set_manager(state.clone());
}); });
state state.host_pool.as_ref().unwrap().add_blocks(blocks).await?;
.host_pool
.as_ref()
.as_ref()
.unwrap()
.add_blocks_blocking(blocks)?;
} }
if let Some(mut blocks) = device_blocks { if let Some(mut blocks) = device_blocks {
...@@ -258,9 +312,9 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -258,9 +312,9 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
state state
.device_pool .device_pool
.as_ref() .as_ref()
.as_ref()
.unwrap() .unwrap()
.add_blocks_blocking(blocks)?; .add_blocks(blocks)
.await?;
} }
Ok(state) Ok(state)
...@@ -296,11 +350,12 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -296,11 +350,12 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
tracing::debug!("Importing remote blockset from worker {}", worker_id); tracing::debug!("Importing remote blockset from worker {}", worker_id);
assert_ne!( assert_ne!(
worker_id, self.worker_id, worker_id, self.resources.worker_id,
"Cannot import blockset from self" "Cannot import blockset from self"
); );
let agent = self let agent = self
.resources
.nixl_agent .nixl_agent
.as_ref() .as_ref()
.as_ref() .as_ref()
...@@ -417,91 +472,51 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -417,91 +472,51 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
Ok(blocks) Ok(blocks)
} }
pub fn disk(&self) -> Option<&BlockPool<DiskStorage, Metadata>> {
self.disk_pool.as_ref().map(|pool| pool.as_ref())
}
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
self.host_pool.as_ref().map(|pool| pool.as_ref())
}
pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> {
self.device_pool.as_ref().map(|pool| pool.as_ref())
}
pub fn worker_id(&self) -> WorkerID {
self.worker_id
}
pub(crate) async fn enqueue_offload_block<S: Storage + 'static>(
&self,
block: &ImmutableBlock<S, Metadata>,
priority: u64,
) -> Result<()> {
self.offload_manager.offload(block, priority).await?;
Ok(())
}
pub async fn onboard_blocks<S: Storage>(
&self,
blocks: Vec<ImmutableBlock<S, Metadata>>,
) -> BlockResult<DeviceStorage, Metadata> {
self.offload_manager.onboard(blocks).await
}
} }
impl<Metadata: BlockMetadata> std::fmt::Debug for KvBlockManagerState<Metadata> { impl<Locality: LocalityProvider, Metadata: BlockMetadata> std::fmt::Debug
for KvBlockManagerState<Locality, Metadata>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "KvBlockManagerState") write!(f, "KvBlockManagerState")
} }
} }
fn create_layout<S: Storage + NixlRegisterableStorage>( // if let Some(storage) = config.storage {
mut builder: LayoutConfigBuilder, // let mut layout = layout.create_layout(config.layout_type, storage, false)?;
config: KvManagerLayoutConfig<S>, // if let Some(nixl_agent) = nixl_agent {
nixl_agent: Option<&NixlAgent>, // layout.nixl_register(nixl_agent, None)?;
) -> Result<Arc<dyn NixlLayout<StorageType = S>>> { // }
let layout = builder.num_blocks(config.num_blocks).build()?; // return Ok(layout.into());
if let Some(storage) = config.storage { // }
let mut layout = layout.create_layout(config.layout_type, storage)?;
if let Some(nixl_agent) = nixl_agent { // if let Some(allocator) = config.allocator {
layout.nixl_register(nixl_agent, None)?; // let mut layout = layout.allocate_layout(config.layout_type, allocator)?;
} // if let Some(nixl_agent) = nixl_agent {
return Ok(Arc::new(layout)); // layout.nixl_register(nixl_agent, None)?;
} // }
// return Ok(layout.into());
if let Some(allocator) = config.allocator { // }
let mut layout = layout.allocate_layout(config.layout_type, allocator)?;
if let Some(nixl_agent) = nixl_agent { // anyhow::bail!("failed to create layout");
layout.nixl_register(nixl_agent, None)?; // }
}
return Ok(Arc::new(layout)); #[expect(clippy::type_complexity)]
} pub(crate) fn create_block_pool<S: Storage, L: LocalityProvider, M: BlockMetadata>(
factory: impl IntoBlocks<S, L>,
resources: &Resources,
pool_name: &str,
) -> Result<(Arc<dyn BlockPool<S, L, M>>, Vec<Block<S, L, M>>)> {
let pool = ManagedBlockPool::<S, L, M>::builder()
.cancel_token(resources.cancellation_token.clone())
.global_registry(resources.global_registry.clone())
.async_runtime(resources.async_rt_handle.clone())
.event_manager(resources.event_manager.clone())
.pool_metrics(resources.metrics.pool(pool_name))
.build()?;
anyhow::bail!("failed to create layout"); let blocks = factory.into_blocks()?;
Ok((Arc::new(pool), blocks))
} }
#[expect(clippy::type_complexity, clippy::too_many_arguments)] // Block state operations moved to block.rs for better organization and private field access
fn create_block_pool<S: Storage + NixlRegisterableStorage, M: BlockMetadata>(
layout: Arc<dyn NixlLayout<StorageType = S>>,
block_set_idx: usize,
cancellation_token: CancellationToken,
worker_id: WorkerID,
global_registry: GlobalRegistry,
async_runtime: Handle,
pool_metrics: Arc<PoolMetrics>,
event_manager: Option<Arc<dyn EventManager>>,
) -> Result<(BlockPool<S, M>, Vec<Block<S, M>>)> {
let blocks = block::layout_to_blocks::<_, M>(layout, block_set_idx, worker_id)?;
let event_manager = event_manager.unwrap_or_else(|| NullEventManager::new());
let pool = BlockPool::<S, M>::builder()
.cancel_token(cancellation_token)
.global_registry(global_registry)
.async_runtime(async_runtime)
.pool_metrics(pool_metrics)
.event_manager(event_manager)
.build()?;
Ok((pool, blocks))
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
/// The local block factories for the block manager
///
/// This struct will construct the factories in a consistent order and can be
/// used as an intermediate step before creating the block pools.
///
/// [`LocalBlockDataFactories::new`] visits the tiers in the order device, host,
/// disk; a tier that is not constructed is stored as `None`.
///
/// This is useful for debugging and for testing.
// NOTE(review): `Dissolve` presumably generates a consuming accessor that splits
// the struct into its fields; confirm against the derive's documentation.
#[derive(Dissolve)]
pub struct LocalBlockDataFactories {
/// Block set holding the serialized layout of every tier that was constructed,
/// keyed by the block-set index assigned to that tier.
block_set: NixlBlockSet,
/// Factory for disk-backed blocks; `None` when no disk layout was constructed.
disk_factory: Option<LocalBlockDataFactory<DiskStorage>>,
/// Factory for pinned-host blocks; `None` when no host layout was constructed.
host_factory: Option<LocalBlockDataFactory<PinnedStorage>>,
/// Factory for device blocks; `None` when no device layout was constructed.
device_factory: Option<LocalBlockDataFactory<DeviceStorage>>,
}
impl LocalBlockDataFactories {
    /// Build the per-tier block data factories described by `resources.config`.
    ///
    /// Tiers are visited in the fixed order device, host, disk. Every tier that
    /// is actually constructed is assigned the next block-set index (the first
    /// one is 1) and its serialized layout is recorded in the [`NixlBlockSet`]
    /// under that index. Each layout config is taken out of `resources.config`,
    /// leaving `None` behind.
    ///
    /// The disk tier is skipped with a warning when no NIXL agent is available.
    ///
    /// # Errors
    ///
    /// Propagates any failure from creating or serializing a layout.
    pub fn new(resources: &mut Resources) -> Result<Self> {
        let layout_builder = resources.layout_builder();
        let mut block_set = NixlBlockSet::new(resources.worker_id);
        let mut next_block_set_idx = 0;

        let device_factory = match resources.config.device_layout.take() {
            None => None,
            Some(device_config) => {
                next_block_set_idx += 1;
                tracing::debug!("Constructing device pool.");
                let agent = resources.nixl_agent.as_ref().as_ref();
                let layout = create_layout(layout_builder.clone(), device_config, agent)?;
                block_set.add_block_set(next_block_set_idx, layout.serialize()?);
                Some(LocalBlockDataFactory::new(
                    layout,
                    next_block_set_idx,
                    resources.worker_id,
                ))
            }
        };

        let host_factory = match resources.config.host_layout.take() {
            None => None,
            Some(host_config) => {
                next_block_set_idx += 1;
                tracing::debug!("Constructing host pool.");
                let agent = resources.nixl_agent.as_ref().as_ref();
                let layout = create_layout(layout_builder.clone(), host_config, agent)?;
                block_set.add_block_set(next_block_set_idx, layout.serialize()?);
                Some(LocalBlockDataFactory::new(
                    layout,
                    next_block_set_idx,
                    resources.worker_id,
                ))
            }
        };

        let disk_factory = match resources.config.disk_layout.take() {
            None => None,
            // The disk tier is only built when a NIXL agent is present.
            Some(disk_config) => match resources.nixl_agent.as_ref().as_ref() {
                None => {
                    tracing::warn!("NIXL is disabled; will not allocate disk blocks.");
                    None
                }
                Some(agent) => {
                    next_block_set_idx += 1;
                    tracing::debug!("Constructing disk pool.");
                    let layout = create_layout(layout_builder.clone(), disk_config, Some(agent))?;
                    block_set.add_block_set(next_block_set_idx, layout.serialize()?);
                    Some(LocalBlockDataFactory::new(
                        layout,
                        next_block_set_idx,
                        resources.worker_id,
                    ))
                }
            },
        };

        Ok(Self {
            block_set,
            disk_factory,
            host_factory,
            device_factory,
        })
    }
}
/// Create a NIXL-capable layout for one storage tier.
///
/// The block count from `config` is applied to `builder` first. The layout is
/// then backed by `config.storage` when present, otherwise allocated through
/// `config.allocator`. When a NIXL agent is supplied, the layout is registered
/// with it before being returned.
///
/// # Errors
///
/// Fails when the layout config cannot be built, when `config.logical` is set
/// (logical layouts are not handled here), when neither storage nor an
/// allocator is provided, or when layout creation / NIXL registration fails.
fn create_layout<S: Storage + NixlRegisterableStorage>(
    mut builder: LayoutConfigBuilder,
    config: KvManagerLayoutConfig<S>,
    nixl_agent: Option<&NixlAgent>,
) -> Result<Arc<dyn NixlLayout<StorageType = S>>> {
    let layout_config = builder.num_blocks(config.num_blocks).build()?;

    if config.logical.is_some() {
        anyhow::bail!("Logical layouts are not supported by the local builder");
    }

    if let Some(storage) = config.storage {
        // Caller supplied pre-existing storage: wrap it.
        let mut built = layout_config.create_layout(config.layout_type, storage)?;
        if let Some(agent) = nixl_agent {
            built.nixl_register(agent, None)?;
        }
        Ok(built.into())
    } else if let Some(allocator) = config.allocator {
        // No storage given: allocate it through the provided allocator.
        let mut built = layout_config.allocate_layout(config.layout_type, allocator)?;
        if let Some(agent) = nixl_agent {
            built.nixl_register(agent, None)?;
        }
        Ok(built.into())
    } else {
        Err(anyhow::anyhow!("failed to create layout"))
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
use crate::block_manager::{block::factory::logical::LogicalBlockFactory, storage::StorageType};
/// The logical block factories for the block manager
///
/// This struct will construct the factories in a consistent order and can be
/// used as an intermediate step before creating the block pools.
///
/// [`LogicalBlockFactories::new`] visits the tiers in the order device, host,
/// disk; a tier without a configured layout is stored as `None`.
///
/// This is useful for debugging and for testing.
// NOTE(review): `Dissolve` presumably generates a consuming accessor that splits
// the struct into its fields; confirm against the derive's documentation.
#[derive(Dissolve)]
pub struct LogicalBlockFactories<R: LogicalResources> {
/// Factory for logical disk blocks; `None` when no disk layout is configured.
disk_factory: Option<LogicalBlockFactory<DiskStorage, R>>,
/// Factory for logical pinned-host blocks; `None` when no host layout is configured.
host_factory: Option<LogicalBlockFactory<PinnedStorage, R>>,
/// Factory for logical device blocks; `None` when no device layout is configured.
device_factory: Option<LogicalBlockFactory<DeviceStorage, R>>,
}
impl<R: LogicalResources> LogicalBlockFactories<R> {
    /// Construct the logical block factories.
    ///
    /// Tiers are visited in the fixed order device, host, disk. Each configured
    /// tier is assigned the next block-set index (the first one is 1) and gets a
    /// factory built from the shared model layout with that tier's block count.
    /// All factories share one `Arc` of `logical_resources`. Each layout config
    /// is taken out of `resources.config`, leaving `None` behind.
    ///
    /// # Errors
    ///
    /// Propagates a failure to build any tier's layout config.
    pub fn new(resources: &mut Resources, logical_resources: R) -> Result<Self> {
        let layout_builder = resources.layout_builder();
        let logical_resources = Arc::new(logical_resources);
        let mut next_block_set_idx = 0;

        let device_factory = match resources.config.device_layout.take() {
            None => None,
            Some(tier) => {
                next_block_set_idx += 1;
                let layout_config =
                    Arc::new(layout_builder.clone().num_blocks(tier.num_blocks).build()?);
                Some(LogicalBlockFactory::new(
                    layout_config,
                    next_block_set_idx,
                    resources.worker_id,
                    Arc::clone(&logical_resources),
                    StorageType::Device(0),
                ))
            }
        };

        let host_factory = match resources.config.host_layout.take() {
            None => None,
            Some(tier) => {
                next_block_set_idx += 1;
                let layout_config =
                    Arc::new(layout_builder.clone().num_blocks(tier.num_blocks).build()?);
                Some(LogicalBlockFactory::new(
                    layout_config,
                    next_block_set_idx,
                    resources.worker_id,
                    Arc::clone(&logical_resources),
                    StorageType::Pinned,
                ))
            }
        };

        let disk_factory = match resources.config.disk_layout.take() {
            None => None,
            Some(tier) => {
                next_block_set_idx += 1;
                let layout_config =
                    Arc::new(layout_builder.clone().num_blocks(tier.num_blocks).build()?);
                Some(LogicalBlockFactory::new(
                    layout_config,
                    next_block_set_idx,
                    resources.worker_id,
                    Arc::clone(&logical_resources),
                    StorageType::Disk(0),
                ))
            }
        };

        Ok(Self {
            disk_factory,
            host_factory,
            device_factory,
        })
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
impl Resources {
    /// Create a new [`Resources`] instance from a block manager configuration.
    ///
    /// Validates the runtime and model sections, builds the metrics handle,
    /// picks the configured event manager (or a [`NullEventManager`]), sets up
    /// the NIXL agent according to `config.runtime.nixl`, and resolves the
    /// async runtime handle.
    ///
    /// When NIXL is enabled without a pre-built agent, a new agent named after
    /// the worker id is created together with a UCX backend and, if a disk
    /// layout is configured, a GDS_MT backend. A missing plugin only produces
    /// a warning.
    ///
    /// # Errors
    ///
    /// Fails on invalid configuration, metrics or NIXL setup errors, or when no
    /// async runtime is configured and none is current.
    pub fn new(config: KvBlockManagerConfig) -> Result<Self> {
        config
            .runtime
            .validate()
            .context("Validating runtime config")?;
        config.model.validate().context("Validating model config")?;

        let worker_id = config.runtime.worker_id;
        let metrics = BlockManagerMetrics::new(&config.runtime.metrics_registry)?;
        let event_manager = config
            .event_manager
            .clone()
            .unwrap_or_else(|| NullEventManager::new());

        // Create a NIXL agent if NIXL is enabled and instantiate requested backends
        // TODO: Build a map of NIXL backends to block pools/sets
        let mut nixl_backends: HashMap<String, Arc<nixl_sys::Backend>> = HashMap::new();
        let agent = match &config.runtime.nixl {
            NixlOptions::Disabled => None,
            NixlOptions::EnabledWithAgent(existing) => Some(existing.clone()),
            NixlOptions::Enabled => {
                tracing::debug!("Creating NIXL agent");
                let created = NixlAgent::new(&worker_id.to_string())?;

                tracing::debug!("Creating NIXL backends");
                match created.get_plugin_params("UCX") {
                    Ok((_, ucx_params)) => {
                        let backend = created.create_backend("UCX", &ucx_params)?;
                        nixl_backends.insert("UCX".to_string(), Arc::new(backend));
                    }
                    Err(_) => {
                        tracing::warn!("No UCX plugin found; will not create UCX backend");
                    }
                }

                // The GDS_MT backend is only requested when a disk tier is configured.
                if config.disk_layout.is_some() {
                    match created.get_plugin_params("GDS_MT") {
                        Ok((_, gds_mt_params)) => {
                            let backend = created.create_backend("GDS_MT", &gds_mt_params)?;
                            nixl_backends.insert("GDS_MT".to_string(), Arc::new(backend));
                        }
                        Err(_) => {
                            tracing::warn!(
                                "No GDS_MT plugin found; will not create GDS_MT backend"
                            );
                        }
                    }
                }

                Some(created)
            }
        };

        // Prefer the explicitly configured runtime; otherwise use the ambient one.
        let async_rt_handle = match &config.runtime.async_runtime {
            Some(rt) => rt.handle().clone(),
            None => Handle::try_current()?,
        };

        Ok(Self {
            worker_id,
            cancellation_token: config.runtime.cancellation_token.clone(),
            async_rt_handle,
            nixl_agent: Arc::new(agent),
            nixl_backends,
            global_registry: GlobalRegistry::default(),
            event_manager,
            metrics,
            config,
        })
    }

    /// Create a new [`LayoutConfigBuilder`] with the model configuration
    ///
    /// The builder has the model's layer count, outer/inner dimensions, page
    /// size and dtype width already applied; the block count is left unset.
    pub fn layout_builder(&self) -> LayoutConfigBuilder {
        let model = &self.config.model;
        let mut builder = LayoutConfig::builder();
        builder
            .num_layers(model.num_layers)
            .outer_dim(model.outer_dim)
            .page_size(model.page_size)
            .inner_dim(model.inner_dim)
            .dtype_width_bytes(model.dtype_width_bytes);
        builder
    }
}
...@@ -77,14 +77,15 @@ ...@@ -77,14 +77,15 @@
//! - [`StorageMemset`] - Memory initialization operations //! - [`StorageMemset`] - Memory initialization operations
//! - [`StorageAllocator`] - Factory for creating storage instances //! - [`StorageAllocator`] - Factory for creating storage instances
pub mod arena;
pub mod cuda; pub mod cuda;
pub mod disk; pub mod disk;
pub mod nixl; pub mod nixl;
pub mod torch;
pub mod arena;
pub use cuda::*; pub use cuda::*;
pub use disk::*; pub use disk::*;
use torch::*;
use std::{ use std::{
alloc::{alloc_zeroed, dealloc, Layout}, alloc::{alloc_zeroed, dealloc, Layout},
...@@ -100,7 +101,7 @@ use thiserror::Error; ...@@ -100,7 +101,7 @@ use thiserror::Error;
pub type StorageResult<T> = std::result::Result<T, StorageError>; pub type StorageResult<T> = std::result::Result<T, StorageError>;
/// Represents the type of storage used for a block /// Represents the type of storage used for a block
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum StorageType { pub enum StorageType {
/// System memory /// System memory
System, System,
...@@ -112,7 +113,7 @@ pub enum StorageType { ...@@ -112,7 +113,7 @@ pub enum StorageType {
Pinned, Pinned,
/// Disk memory /// Disk memory
Disk, Disk(u64),
/// Remote memory accessible through NIXL /// Remote memory accessible through NIXL
Nixl, Nixl,
...@@ -193,6 +194,14 @@ pub trait Storage: Debug + Send + Sync + 'static { ...@@ -193,6 +194,14 @@ pub trait Storage: Debug + Send + Sync + 'static {
unsafe fn as_mut_ptr(&mut self) -> *mut u8; unsafe fn as_mut_ptr(&mut self) -> *mut u8;
} }
/// Associates an implementor with a concrete [`Storage`] type.
pub trait StorageTypeProvider {
/// The storage type this provider is tied to.
type StorageType: Storage;
/// Returns the [`std::any::TypeId`] of the associated [`Self::StorageType`].
fn storage_type_id(&self) -> std::any::TypeId {
std::any::TypeId::of::<Self::StorageType>()
}
}
/// Extension trait for storage types that support memory setting operations /// Extension trait for storage types that support memory setting operations
pub trait StorageMemset: Storage { pub trait StorageMemset: Storage {
/// Sets a region of memory to a specific value /// Sets a region of memory to a specific value
...@@ -524,3 +533,41 @@ pub mod tests { ...@@ -524,3 +533,41 @@ pub mod tests {
} }
} }
} }
// Comment out Nixl-related code for now
/*
pub trait NixlDescriptor: Storage {
fn as_nixl_descriptor(&self) -> NixlMemoryDescriptor<'_, BlockKind, IsImmutable>;
fn as_nixl_descriptor_mut(&mut self) -> NixlMemoryDescriptor<'_, BlockKind, IsMutable>;
}
impl NixlDescriptor for SystemStorage {
fn as_nixl_descriptor(&self) -> NixlMemoryDescriptor<'_, BlockKind, IsImmutable> {
NixlMemoryDescriptor::new(self.as_ptr() as *const u8, self.size())
}
fn as_nixl_descriptor_mut(&mut self) -> NixlMemoryDescriptor<'_, BlockKind, IsMutable> {
NixlMemoryDescriptor::new_mut(self.as_mut_ptr() as *mut u8, self.size())
}
}
impl NixlDescriptor for PinnedStorage {
fn as_nixl_descriptor(&self) -> NixlMemoryDescriptor<'_, BlockKind, IsImmutable> {
NixlMemoryDescriptor::new(self.as_ptr() as *const u8, self.size())
}
fn as_nixl_descriptor_mut(&mut self) -> NixlMemoryDescriptor<'_, BlockKind, IsMutable> {
NixlMemoryDescriptor::new_mut(self.as_mut_ptr() as *mut u8, self.size())
}
}
impl NixlDescriptor for DeviceStorage {
fn as_nixl_descriptor(&self) -> NixlMemoryDescriptor<'_, BlockKind, IsImmutable> {
NixlMemoryDescriptor::new(self.as_ptr() as *const u8, self.size())
}
fn as_nixl_descriptor_mut(&mut self) -> NixlMemoryDescriptor<'_, BlockKind, IsMutable> {
NixlMemoryDescriptor::new_mut(self.as_mut_ptr() as *mut u8, self.size())
}
}
*/
...@@ -303,6 +303,17 @@ impl StorageAllocator<PinnedStorage> for PinnedAllocator { ...@@ -303,6 +303,17 @@ impl StorageAllocator<PinnedStorage> for PinnedAllocator {
} }
} }
/// An enum indicating the type of device storage.
///
/// This is needed to ensure ownership of memory is correctly handled.
/// When building a [`DeviceStorage`] from a torch tensor, we need to ensure that
/// the torch tensor is not garbage-collected until the [`DeviceStorage`] is dropped.
/// Because of this, we store a reference to the torch tensor in the [`DeviceStorage`].
#[derive(Debug)]
enum DeviceStorageType {
Owned, // Memory that we allocated ourselves.
Torch { _tensor: Arc<dyn TorchTensor> }, // Memory that came from a torch tensor.
}
/// CUDA device memory storage /// CUDA device memory storage
#[derive(Debug)] #[derive(Debug)]
pub struct DeviceStorage { pub struct DeviceStorage {
...@@ -310,6 +321,7 @@ pub struct DeviceStorage { ...@@ -310,6 +321,7 @@ pub struct DeviceStorage {
size: usize, size: usize,
ctx: Arc<CudaContext>, ctx: Arc<CudaContext>,
handles: RegistrationHandles, handles: RegistrationHandles,
_storage_type: DeviceStorageType,
} }
impl Local for DeviceStorage {} impl Local for DeviceStorage {}
...@@ -326,6 +338,35 @@ impl DeviceStorage { ...@@ -326,6 +338,35 @@ impl DeviceStorage {
size, size,
ctx: ctx.clone(), ctx: ctx.clone(),
handles: RegistrationHandles::new(), handles: RegistrationHandles::new(),
_storage_type: DeviceStorageType::Owned,
})
}
pub fn new_from_torch(
ctx: &Arc<CudaContext>,
tensor: Arc<dyn TorchTensor>,
) -> Result<Self, StorageError> {
let device = tensor.device();
let TorchDevice::Cuda(device_id) = device else {
return Err(StorageError::InvalidConfig("Tensor is not CUDA!".into()));
};
if device_id != ctx.cu_device() as usize {
return Err(StorageError::InvalidConfig(
"Tensor is not on the same device as the context!".into(),
));
}
let data_ptr = tensor.data_ptr();
let size = tensor.size_bytes();
Ok(Self {
ptr: data_ptr,
size,
ctx: ctx.clone(),
handles: RegistrationHandles::new(),
_storage_type: DeviceStorageType::Torch { _tensor: tensor },
}) })
} }
...@@ -366,7 +407,14 @@ impl CudaContextProivder for DeviceStorage { ...@@ -366,7 +407,14 @@ impl CudaContextProivder for DeviceStorage {
impl Drop for DeviceStorage { impl Drop for DeviceStorage {
fn drop(&mut self) { fn drop(&mut self) {
self.handles.release(); self.handles.release();
unsafe { cudarc::driver::result::free_sync(self.ptr as _) }.unwrap(); match &self._storage_type {
DeviceStorageType::Owned => {
unsafe { cudarc::driver::result::free_sync(self.ptr as _) }.unwrap()
}
DeviceStorageType::Torch { _tensor } => {
// Do nothing. The torch storage is responsible for cleaning up itself.
}
}
} }
} }
...@@ -419,3 +467,100 @@ impl StorageAllocator<DeviceStorage> for DeviceAllocator { ...@@ -419,3 +467,100 @@ impl StorageAllocator<DeviceStorage> for DeviceAllocator {
DeviceStorage::new(&self.ctx, size) DeviceStorage::new(&self.ctx, size)
} }
} }
// Tests for building a `DeviceStorage` from a torch tensor. Only compiled for
// test builds with the `testing-cuda` feature enabled.
#[cfg(all(test, feature = "testing-cuda"))]
mod tests {
use super::*;
/// Minimal [`TorchTensor`] stand-in that reports whatever device, pointer and
/// size it was constructed with.
#[derive(Debug, Clone)]
struct MockTensor {
device: TorchDevice,
data_ptr: u64,
size_bytes: usize,
}
impl MockTensor {
/// Create a mock tensor with the given device, data pointer and byte size.
pub fn new(device: TorchDevice, data_ptr: u64, size_bytes: usize) -> Self {
Self {
device,
data_ptr,
size_bytes,
}
}
}
impl TorchTensor for MockTensor {
fn device(&self) -> TorchDevice {
self.device.clone()
}
fn data_ptr(&self) -> u64 {
self.data_ptr
}
fn size_bytes(&self) -> usize {
self.size_bytes
}
// Reported as a one-dimensional tensor of `size_bytes` elements with unit stride.
fn shape(&self) -> Vec<usize> {
vec![self.size_bytes]
}
fn stride(&self) -> Vec<usize> {
vec![1]
}
}
/// A CUDA tensor on the context's device yields a storage with the tensor's
/// size and address.
#[test]
fn test_device_storage_from_torch_valid_tensor() {
let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
let size_bytes = 1024;
// Wrapped in `ManuallyDrop`, so the destructor of this backing allocation never runs.
let actual_storage =
std::mem::ManuallyDrop::new(DeviceStorage::new(&ctx, size_bytes).unwrap());
let tensor = MockTensor::new(TorchDevice::Cuda(0), actual_storage.addr(), size_bytes);
let storage = DeviceStorage::new_from_torch(&ctx, Arc::new(tensor)).unwrap();
assert_eq!(storage.size(), size_bytes);
assert_eq!(storage.storage_type(), StorageType::Device(0));
assert_eq!(storage.addr(), actual_storage.addr());
}
/// A tensor that is not on a CUDA device is rejected with `InvalidConfig`.
#[test]
fn test_device_storage_from_torch_cpu_tensor_fails() {
let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
let size_bytes = 1024;
let actual_storage = DeviceStorage::new(&ctx, size_bytes).unwrap();
let tensor = MockTensor::new(
TorchDevice::Other("cpu".to_string()),
actual_storage.addr(),
size_bytes,
);
let result = DeviceStorage::new_from_torch(&ctx, Arc::new(tensor));
assert!(result.is_err());
if let Err(StorageError::InvalidConfig(msg)) = result {
assert!(msg.contains("Tensor is not CUDA"));
} else {
panic!("Expected InvalidConfig error for CPU tensor");
}
}
/// A CUDA tensor on a different device than the context is rejected.
#[test]
fn test_device_storage_wrong_device() {
let ctx = Cuda::device_or_create(0).expect("Failed to create CUDA context");
let size_bytes = 1024;
let actual_storage = DeviceStorage::new(&ctx, size_bytes).unwrap();
let tensor = MockTensor::new(TorchDevice::Cuda(1), actual_storage.addr(), size_bytes);
let result = DeviceStorage::new_from_torch(&ctx, Arc::new(tensor));
assert!(result.is_err());
}
}
...@@ -17,16 +17,21 @@ use super::*; ...@@ -17,16 +17,21 @@ use super::*;
use core::ffi::c_char; use core::ffi::c_char;
use nix::fcntl::{fallocate, FallocateFlags}; use nix::fcntl::{fallocate, FallocateFlags};
use nix::unistd::unlink;
use std::ffi::CStr;
use std::ffi::CString; use std::ffi::CString;
use std::fs::File; use std::path::Path;
use std::os::unix::io::{AsRawFd, FromRawFd};
const DISK_CACHE_KEY: &str = "DYN_KVBM_DISK_CACHE_DIR";
const DEFAULT_DISK_CACHE_DIR: &str = "/tmp/";
#[derive(Debug)] #[derive(Debug)]
pub struct DiskStorage { pub struct DiskStorage {
file: File, fd: u64,
file_name: String, file_name: String,
size: usize, size: usize,
handles: RegistrationHandles, handles: RegistrationHandles,
unlinked: bool,
} }
impl Local for DiskStorage {} impl Local for DiskStorage {}
...@@ -37,7 +42,17 @@ impl DiskStorage { ...@@ -37,7 +42,17 @@ impl DiskStorage {
// We need to open our file with some special flags that aren't supported by the tempfile crate. // We need to open our file with some special flags that aren't supported by the tempfile crate.
// Instead, we'll use the mkostemp function to create a temporary file with the correct flags. // Instead, we'll use the mkostemp function to create a temporary file with the correct flags.
let template = CString::new("/tmp/dynamo-kvbm-disk-cache-XXXXXX").unwrap(); let specified_dir =
std::env::var(DISK_CACHE_KEY).unwrap_or_else(|_| DEFAULT_DISK_CACHE_DIR.to_string());
let file_path = Path::new(&specified_dir).join("dynamo-kvbm-disk-cache-XXXXXX");
if !file_path.exists() {
std::fs::create_dir_all(file_path.parent().unwrap()).unwrap();
}
tracing::debug!("Allocating disk cache file at {}", file_path.display());
let template = CString::new(file_path.to_str().unwrap()).unwrap();
let mut template_bytes = template.into_bytes_with_nul(); let mut template_bytes = template.into_bytes_with_nul();
let raw_fd = unsafe { let raw_fd = unsafe {
...@@ -50,45 +65,63 @@ impl DiskStorage { ...@@ -50,45 +65,63 @@ impl DiskStorage {
) )
}; };
let file = unsafe { File::from_raw_fd(raw_fd) }; let file_name = CStr::from_bytes_with_nul(template_bytes.as_slice())
let file_name = String::from_utf8_lossy(&template_bytes) .unwrap()
.trim_end_matches("\0") .to_str()
.map_err(|e| {
StorageError::AllocationFailed(format!("Failed to read temp file name: {}", e))
})?
.to_string(); .to_string();
file.set_len(size as u64).map_err(|_| {
StorageError::AllocationFailed("Failed to set temp file size".to_string())
})?;
// File::set_len() only updates the metadata of the file, it does not allocate the underlying storage.
// We need to use fallocate to actually allocate the storage and create the blocks on disk. // We need to use fallocate to actually allocate the storage and create the blocks on disk.
fallocate(file.as_raw_fd(), FallocateFlags::empty(), 0, size as i64).map_err(|_| { fallocate(raw_fd, FallocateFlags::empty(), 0, size as i64).map_err(|e| {
StorageError::AllocationFailed("Failed to allocate temp file".to_string()) StorageError::AllocationFailed(format!("Failed to allocate temp file: {}", e))
})?; })?;
Ok(Self { Ok(Self {
file, fd: raw_fd as u64,
file_name, file_name,
size, size,
handles: RegistrationHandles::new(), handles: RegistrationHandles::new(),
unlinked: false,
}) })
} }
pub fn fd(&self) -> u64 { pub fn fd(&self) -> u64 {
self.file.as_raw_fd() as u64 self.fd
}
/// Unlink our temp file.
///
/// This means that when this process terminates, the file will be automatically deleted by the OS.
/// Unfortunately, GDS requires that files we try to register must be linked.
/// To get around this, we unlink the file only after we've registered it with NIXL.
///
/// Calling this again after a successful unlink is a no-op. The `unlinked` flag
/// is only set once the file has actually been removed, so a failed attempt can
/// be retried later (for example when the storage is dropped) and `unlinked()`
/// never reports a file that is still linked as unlinked.
///
/// # Errors
///
/// Returns [`StorageError::AllocationFailed`] when the underlying `unlink` call fails.
pub fn unlink(&mut self) -> Result<(), StorageError> {
    if self.unlinked {
        return Ok(());
    }

    unlink(self.file_name.as_str()).map_err(|e| {
        StorageError::AllocationFailed(format!("Failed to unlink temp file: {}", e))
    })?;

    // Record success only after the file is really gone.
    self.unlinked = true;
    Ok(())
}
pub fn unlinked(&self) -> bool {
self.unlinked
} }
} }
impl Drop for DiskStorage { impl Drop for DiskStorage {
// TODO: How robust is this actually?
fn drop(&mut self) { fn drop(&mut self) {
self.handles.release(); self.handles.release();
std::fs::remove_file(self.file_name.clone()).unwrap(); let _ = self.unlink();
} }
} }
impl Storage for DiskStorage { impl Storage for DiskStorage {
fn storage_type(&self) -> StorageType { fn storage_type(&self) -> StorageType {
StorageType::Disk StorageType::Disk(self.fd())
} }
fn addr(&self) -> u64 { fn addr(&self) -> u64 {
......
...@@ -156,7 +156,7 @@ impl StorageType { ...@@ -156,7 +156,7 @@ impl StorageType {
StorageType::Device(_) => MemType::Vram, StorageType::Device(_) => MemType::Vram,
StorageType::Nixl => MemType::Unknown, StorageType::Nixl => MemType::Unknown,
StorageType::Null => MemType::Unknown, StorageType::Null => MemType::Unknown,
StorageType::Disk => MemType::File, StorageType::Disk(_) => MemType::File,
} }
} }
} }
...@@ -169,6 +169,15 @@ impl RegistationHandle for NixlRegistrationHandle { ...@@ -169,6 +169,15 @@ impl RegistationHandle for NixlRegistrationHandle {
} }
} }
/// Register `storage` with the NIXL agent and store the resulting handle on the
/// storage under the `"nixl"` key.
///
/// # Errors
///
/// Propagates a failure from the agent's memory registration or from storing
/// the handle on the storage.
fn handle_nixl_register<S: NixlRegisterableStorage>(
    storage: &mut S,
    agent: &NixlAgent,
    opt_args: Option<&OptArgs>,
) -> Result<(), StorageError> {
    let registration = agent.register_memory(storage, opt_args)?;
    storage.register("nixl", Box::new(registration))
}
/// Extension to the [`RegisterableStorage`] trait for NIXL-compatible storage. /// Extension to the [`RegisterableStorage`] trait for NIXL-compatible storage.
pub trait NixlRegisterableStorage: RegisterableStorage + NixlDescriptor + Sized { pub trait NixlRegisterableStorage: RegisterableStorage + NixlDescriptor + Sized {
/// Register the storage with the NIXL agent. /// Register the storage with the NIXL agent.
...@@ -177,9 +186,7 @@ pub trait NixlRegisterableStorage: RegisterableStorage + NixlDescriptor + Sized ...@@ -177,9 +186,7 @@ pub trait NixlRegisterableStorage: RegisterableStorage + NixlDescriptor + Sized
agent: &NixlAgent, agent: &NixlAgent,
opt_args: Option<&OptArgs>, opt_args: Option<&OptArgs>,
) -> Result<(), StorageError> { ) -> Result<(), StorageError> {
let handle = Box::new(agent.register_memory(self, opt_args)?); handle_nixl_register(self, agent, opt_args)
// Assuming PinnedStorage has `handles: RegistrationHandles`
self.register("nixl", handle)
} }
/// Check if the storage is registered with the NIXL agent. /// Check if the storage is registered with the NIXL agent.
...@@ -379,7 +386,23 @@ impl NixlDescriptor for DeviceStorage { ...@@ -379,7 +386,23 @@ impl NixlDescriptor for DeviceStorage {
} }
impl NixlAccessible for DiskStorage {} impl NixlAccessible for DiskStorage {}
impl NixlRegisterableStorage for DiskStorage {} impl NixlRegisterableStorage for DiskStorage {
/// Registers this disk storage with the NIXL agent, then unlinks it.
///
/// Returns [`StorageError::AllocationFailed`] when the storage reports that it is already
/// unlinked. Otherwise registration runs first, and `unlink` is only called once
/// registration has succeeded.
fn nixl_register(
    &mut self,
    agent: &NixlAgent,
    opt_args: Option<&OptArgs>,
) -> Result<(), StorageError> {
    if self.unlinked() {
        Err(StorageError::AllocationFailed(String::from(
            "Disk storage has already been unlinked. GDS registration will fail.",
        )))
    } else {
        // Order matters: register while the storage is still linked, then unlink.
        handle_nixl_register(self, agent, opt_args)?;
        self.unlink()?;
        Ok(())
    }
}
}
impl MemoryRegion for DiskStorage { impl MemoryRegion for DiskStorage {
unsafe fn as_ptr(&self) -> *const u8 { unsafe fn as_ptr(&self) -> *const u8 {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// Device reported by a [`TorchTensor`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TorchDevice {
    /// A CUDA device. NOTE(review): the `usize` is presumably the device index — confirm
    /// against implementors.
    Cuda(usize),
    /// Any non-CUDA device, carried as a string. NOTE(review): presumably the device
    /// name — confirm against implementors.
    Other(String),
}
/// Read-only view of a tensor: its device, data pointer, byte size, shape and stride.
///
/// Implementors must also be `Debug`, `Send` and `Sync`.
pub trait TorchTensor: std::fmt::Debug + Send + Sync {
    /// Returns the [`TorchDevice`] associated with this tensor.
    fn device(&self) -> TorchDevice;
    /// Returns the tensor's data pointer as a `u64`.
    /// NOTE(review): presumably a raw memory address — confirm with implementors.
    fn data_ptr(&self) -> u64;
    /// Returns the tensor's size in bytes.
    fn size_bytes(&self) -> usize;
    /// Returns the tensor's shape, one entry per dimension.
    fn shape(&self) -> Vec<usize>;
    /// Returns the tensor's stride, one entry per dimension.
    /// NOTE(review): units (elements vs. bytes) are not visible here — confirm.
    fn stride(&self) -> Vec<usize>;
}
...@@ -22,6 +22,7 @@ where ...@@ -22,6 +22,7 @@ where
} }
/// A generic recorder for events that streams directly to a JSONL file /// A generic recorder for events that streams directly to a JSONL file
#[derive(Debug)]
pub struct Recorder<T> { pub struct Recorder<T> {
/// A sender for events that can be cloned and shared with producers /// A sender for events that can be cloned and shared with producers
event_tx: mpsc::Sender<T>, event_tx: mpsc::Sender<T>,
......
...@@ -87,6 +87,12 @@ impl From<&[Token]> for Tokens { ...@@ -87,6 +87,12 @@ impl From<&[Token]> for Tokens {
} }
} }
impl From<Vec<usize>> for Tokens {
    /// Converts `Vec<usize>` to `Tokens`, casting each `usize` to `u32`.
    ///
    /// The conversion uses an `as` cast, so values that do not fit in a `u32` are
    /// truncated rather than rejected.
    fn from(tokens: Vec<usize>) -> Self {
        let narrowed = tokens.into_iter().map(|id| id as u32);
        Tokens(narrowed.collect())
    }
}
impl From<Vec<i32>> for Tokens { impl From<Vec<i32>> for Tokens {
/// Converts `Vec<i32>` to `Tokens`, casting each `i32` to `u32`. /// Converts `Vec<i32>` to `Tokens`, casting each `i32` to `u32`.
fn from(tokens: Vec<i32>) -> Self { fn from(tokens: Vec<i32>) -> Self {
...@@ -460,6 +466,11 @@ impl TokenBlock { ...@@ -460,6 +466,11 @@ impl TokenBlock {
pub fn parent_sequence_hash(&self) -> Option<SequenceHash> { pub fn parent_sequence_hash(&self) -> Option<SequenceHash> {
self.parent_sequence_hash self.parent_sequence_hash
} }
/// Returns how many tokens this block holds.
///
/// The count is taken from the tokens stored in the block.
pub fn block_size(&self) -> usize {
    let stored_tokens = &self.tokens.0;
    stored_tokens.len()
}
} }
/// Represents a sequence of tokens, segmented into fixed-size, hashed blocks. /// Represents a sequence of tokens, segmented into fixed-size, hashed blocks.
...@@ -481,6 +492,7 @@ pub struct TokenBlockSequence { ...@@ -481,6 +492,7 @@ pub struct TokenBlockSequence {
blocks: Vec<TokenBlock>, blocks: Vec<TokenBlock>,
current_block: PartialTokenBlock, current_block: PartialTokenBlock,
salt_hash: SaltHash, salt_hash: SaltHash,
block_size: usize,
} }
impl TokenBlockSequence { impl TokenBlockSequence {
...@@ -507,6 +519,7 @@ impl TokenBlockSequence { ...@@ -507,6 +519,7 @@ impl TokenBlockSequence {
blocks, blocks,
current_block, current_block,
salt_hash, salt_hash,
block_size: block_size as usize,
} }
} }
...@@ -545,14 +558,12 @@ impl TokenBlockSequence { ...@@ -545,14 +558,12 @@ impl TokenBlockSequence {
tokens_to_append = self.current_block.push_tokens(available_tokens); tokens_to_append = self.current_block.push_tokens(available_tokens);
// Check if the current block *became* full after pushing tokens // Check if the current block *became* full after pushing tokens
if self.current_block.remaining() == 0 && !tokens_to_append.is_empty() { if self.current_block.remaining() == 0 {
// If it became full AND there are still more tokens to append, // If it became full AND there are still more tokens to append,
// commit it now so the next loop iteration starts with a fresh block. // commit it now so the next loop iteration starts with a fresh block.
let new_block = self.current_block.commit()?; let new_block = self.current_block.commit()?;
self.blocks.push(new_block); self.blocks.push(new_block);
} }
// If it became full and there are NO more tokens, the loop will exit,
// and the block remains partial but full, ready for the next append/commit.
} }
let end_block_index = self.blocks.len(); let end_block_index = self.blocks.len();
...@@ -708,6 +719,13 @@ impl TokenBlockSequence { ...@@ -708,6 +719,13 @@ impl TokenBlockSequence {
self.truncate(len) self.truncate(len)
} }
/// Resets the sequence to its initial, empty state.
///
/// All committed blocks are dropped and the current partial block is replaced by a new
/// sequence root built from the stored block size and salt hash.
pub fn reset(&mut self) {
    let fresh_root =
        PartialTokenBlock::create_sequence_root(self.block_size as u32, self.salt_hash);
    self.current_block = fresh_root;
    self.blocks.clear();
}
/// Removes the last token from the sequence and returns it, or [`None`] if it is empty. /// Removes the last token from the sequence and returns it, or [`None`] if it is empty.
/// ///
/// This operation is analogous to `Vec::pop`. /// This operation is analogous to `Vec::pop`.
...@@ -779,6 +797,11 @@ impl TokenBlockSequence { ...@@ -779,6 +797,11 @@ impl TokenBlockSequence {
(self.blocks, self.current_block) (self.blocks, self.current_block)
} }
/// Returns the block size used for this sequence.
///
/// This is the value of the stored `block_size` field, i.e. the number of tokens per
/// committed block that the sequence uses when indexing into its blocks.
pub fn block_size(&self) -> usize {
    self.block_size
}
/// Returns the [`SaltHash`] used for this sequence. /// Returns the [`SaltHash`] used for this sequence.
pub fn salt_hash(&self) -> SaltHash { pub fn salt_hash(&self) -> SaltHash {
self.salt_hash self.salt_hash
...@@ -791,6 +814,38 @@ impl TokenBlockSequence { ...@@ -791,6 +814,38 @@ impl TokenBlockSequence {
(self.blocks.len() * block_size) + self.current_block.len() (self.blocks.len() * block_size) + self.current_block.len()
} }
/// Returns a copy of the tokens located at the positions covered by `range`.
///
/// Positions index the sequence as a whole: tokens of the committed blocks come first,
/// followed by the tokens of the current partial block. An empty or inverted range, or a
/// range whose end exceeds [`TokenBlockSequence::total_tokens`], yields empty [`Tokens`].
pub fn tokens_at(&self, range: Range<usize>) -> Tokens {
    let total = self.total_tokens();

    // `is_empty` also covers inverted ranges (start > end).
    if range.is_empty() || range.end > total {
        return Tokens::default();
    }

    // Positions below this bound live in committed blocks; the rest are in the partial block.
    let committed_len = self.blocks.len() * self.block_size;

    let collected: Vec<_> = range
        .map(|position| {
            if position < committed_len {
                let block = &self.blocks[position / self.block_size];
                block.tokens()[position % self.block_size]
            } else {
                self.current_block.tokens()[position - committed_len]
            }
        })
        .collect();

    Tokens::from(collected)
}
/// Splits a [`Tokens`] object into a vector of completed blocks and a final partial block. /// Splits a [`Tokens`] object into a vector of completed blocks and a final partial block.
/// ///
/// This is primarily used internally by [`TokenBlockSequence::new`] but can be used externally. /// This is primarily used internally by [`TokenBlockSequence::new`] but can be used externally.
...@@ -857,6 +912,7 @@ impl TokenBlockSequence { ...@@ -857,6 +912,7 @@ impl TokenBlockSequence {
blocks, blocks,
current_block, current_block,
salt_hash, salt_hash,
block_size: block_size as usize,
} }
} }
} }
...@@ -1109,6 +1165,15 @@ mod tests { ...@@ -1109,6 +1165,15 @@ mod tests {
Some(SEQ_HASH_5_8) Some(SEQ_HASH_5_8)
); );
// Test tokens_at across blocks and partial block
assert_eq!(seq_multi.tokens_at(0..4).as_ref(), &[1, 2, 3, 4]); // First complete block
assert_eq!(seq_multi.tokens_at(4..8).as_ref(), &[5, 6, 7, 8]); // Second complete block
assert_eq!(seq_multi.tokens_at(8..9).as_ref(), &[9]); // Current partial block
assert_eq!(seq_multi.tokens_at(2..6).as_ref(), &[3, 4, 5, 6]); // Spanning blocks
assert_eq!(seq_multi.tokens_at(6..9).as_ref(), &[7, 8, 9]); // Spanning to partial
assert_eq!(seq_multi.tokens_at(5..5).as_ref(), &[0u32; 0]); // Empty range
assert_eq!(seq_multi.tokens_at(10..15).as_ref(), &[0u32; 0]); // Out of bounds
// No salt hash // No salt hash
let seq_no_salt = create_test_sequence(&[1, 2, 3, 4, 5], 4, None); let seq_no_salt = create_test_sequence(&[1, 2, 3, 4, 5], 4, None);
assert_eq!(seq_no_salt.salt_hash(), 0); assert_eq!(seq_no_salt.salt_hash(), 0);
...@@ -1142,22 +1207,22 @@ mod tests { ...@@ -1142,22 +1207,22 @@ mod tests {
assert_eq!(sequence.current_block().tokens.as_ref(), &[9, 10, 11]); assert_eq!(sequence.current_block().tokens.as_ref(), &[9, 10, 11]);
// Append token 12 - should complete block 2 (index 2) // Append token 12 - should complete block 2 (index 2)
// This will also commit block 2
let completed_idx = sequence.append(12).unwrap(); let completed_idx = sequence.append(12).unwrap();
assert_eq!(completed_idx, None); // Lazy commit: extend returns None assert_eq!(completed_idx, Some(2));
assert_eq!(sequence.blocks().len(), 2); // Block 2 not added yet assert_eq!(sequence.blocks().len(), 3);
assert_eq!(sequence.current_block.tokens.as_ref(), &[9, 10, 11, 12]); // Current block is now full assert_eq!(sequence.current_block.tokens.as_ref(), &[0u32; 0]);
assert_eq!(sequence.current_block.remaining(), 0); assert_eq!(sequence.current_block.remaining(), 4);
assert_eq!( assert_eq!(
sequence.current_block().parent_sequence_hash, sequence.current_block().parent_sequence_hash,
Some(SEQ_HASH_5_8) Some(SEQ_HASH_9_12)
); // Still linked to block 1 ); // Now linked to the freshly committed block 2
// Append token 13 - should not complete a block // Append token 13 - should not complete a block
// NOW appending 13 should first commit block 2, then add 13 to the new current
let completed_idx_13 = sequence.append(13).unwrap(); let completed_idx_13 = sequence.append(13).unwrap();
assert_eq!(completed_idx_13, Some(2)); // Block 2 (index 2) was completed by this append assert_eq!(completed_idx_13, None);
assert_eq!(sequence.blocks.len(), 3); // Now 3 blocks committed assert_eq!(sequence.blocks().len(), 3);
assert_eq!(sequence.blocks[2].tokens().as_ref(), &[9, 10, 11, 12]); // Verify committed block 2 assert_eq!(sequence.blocks[2].tokens().as_ref(), &[9, 10, 11, 12]);
assert_eq!(sequence.blocks[2].sequence_hash(), SEQ_HASH_9_12); assert_eq!(sequence.blocks[2].sequence_hash(), SEQ_HASH_9_12);
assert_eq!(sequence.current_block.tokens.as_ref(), &[13]); // New current block has 13 assert_eq!(sequence.current_block.tokens.as_ref(), &[13]); // New current block has 13
assert_eq!(sequence.current_block.remaining(), 3); assert_eq!(sequence.current_block.remaining(), 3);
...@@ -1180,16 +1245,17 @@ mod tests { ...@@ -1180,16 +1245,17 @@ mod tests {
assert_eq!(seq1.blocks.len(), 0); assert_eq!(seq1.blocks.len(), 0);
assert_eq!(seq1.current_block.tokens.as_ref(), &[1, 2]); assert_eq!(seq1.current_block.tokens.as_ref(), &[1, 2]);
assert_eq!(seq1.current_block.remaining(), 2); assert_eq!(seq1.current_block.remaining(), 2);
assert_eq!(seq1.current_block.parent_sequence_hash, None); // Still the root block
// Case 2: Extend exactly block size // Case 2: Extend exactly block size
let mut seq2 = create_test_sequence(&[], block_size, salt_hash); let mut seq2 = create_test_sequence(&[], block_size, salt_hash);
let tokens2 = Tokens::from(vec![1, 2, 3, 4]); let tokens2 = Tokens::from(vec![1, 2, 3, 4]);
let completed2 = seq2.extend(tokens2).unwrap(); let completed2 = seq2.extend(tokens2).unwrap();
assert_eq!(completed2, None); // Block is full but not committed yet assert_eq!(completed2, Some(0..1));
assert_eq!(seq2.blocks.len(), 0); // No blocks committed assert_eq!(seq2.blocks.len(), 1);
assert_eq!(seq2.current_block.tokens.as_ref(), &[1, 2, 3, 4]); // Current block is full assert_eq!(seq2.current_block.tokens.as_ref(), &[0u32; 0]); // Current block is empty
assert_eq!(seq2.current_block.remaining(), 0); assert_eq!(seq2.current_block.remaining(), 4);
assert_eq!(seq2.current_block.parent_sequence_hash, None); // Still the root block assert_eq!(seq2.current_block.parent_sequence_hash, Some(SEQ_HASH_1_4)); // Parent is the committed block 0
// Case 3: Extend more than block size, less than two blocks // Case 3: Extend more than block size, less than two blocks
let mut seq3 = create_test_sequence(&[], block_size, salt_hash); let mut seq3 = create_test_sequence(&[], block_size, salt_hash);
...@@ -1206,13 +1272,13 @@ mod tests { ...@@ -1206,13 +1272,13 @@ mod tests {
let mut seq4 = create_test_sequence(&[], block_size, salt_hash); let mut seq4 = create_test_sequence(&[], block_size, salt_hash);
let tokens4 = Tokens::from(vec![1, 2, 3, 4, 5, 6, 7, 8]); let tokens4 = Tokens::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
let completed4 = seq4.extend(tokens4).unwrap(); let completed4 = seq4.extend(tokens4).unwrap();
assert_eq!(completed4, Some(0..1)); // Only block 0 is committed assert_eq!(completed4, Some(0..2)); // Blocks 0 and 1 are committed
assert_eq!(seq4.blocks.len(), 1); // Only 1 block committed assert_eq!(seq4.blocks.len(), 2); // Both blocks committed
assert_eq!(seq4.current_block.tokens.as_ref(), &[5, 6, 7, 8]); // Current block holds the second block's tokens assert_eq!(seq4.current_block.tokens.as_ref(), &[0u32; 0]);
assert_eq!(seq4.current_block.remaining(), 0); // Current block is full assert_eq!(seq4.current_block.remaining(), 4);
assert_eq!(seq4.blocks[0].tokens().as_ref(), &[1, 2, 3, 4]); assert_eq!(seq4.blocks[0].tokens().as_ref(), &[1, 2, 3, 4]);
assert_eq!(seq4.blocks[0].sequence_hash(), SEQ_HASH_1_4); assert_eq!(seq4.blocks[0].sequence_hash(), SEQ_HASH_1_4);
assert_eq!(seq4.current_block.parent_sequence_hash, Some(SEQ_HASH_1_4)); // Parent is the first block assert_eq!(seq4.current_block.parent_sequence_hash, Some(SEQ_HASH_5_8)); // Parent is the second block
// Case 5: Extend multiple times, completing blocks across calls // Case 5: Extend multiple times, completing blocks across calls
let mut seq5 = create_test_sequence(&[], block_size, salt_hash); let mut seq5 = create_test_sequence(&[], block_size, salt_hash);
...@@ -1252,12 +1318,18 @@ mod tests { ...@@ -1252,12 +1318,18 @@ mod tests {
let mut seq7 = create_test_sequence(&[1, 2], block_size, salt_hash); let mut seq7 = create_test_sequence(&[1, 2], block_size, salt_hash);
let tokens7 = Tokens::from(vec![3, 4]); let tokens7 = Tokens::from(vec![3, 4]);
let completed7 = seq7.extend(tokens7).unwrap(); let completed7 = seq7.extend(tokens7).unwrap();
assert_eq!(completed7, None); // Block is full but not committed yet assert_eq!(completed7, Some(0..1)); // Block 0 is committed as soon as it fills
assert_eq!(seq7.blocks.len(), 0); assert_eq!(seq7.blocks.len(), 1);
assert_eq!(seq7.current_block.tokens.as_ref(), &[1, 2, 3, 4]); // Current block is full assert_eq!(seq7.current_block.tokens.as_ref(), &[0u32; 0]); // Current block is empty
assert_eq!(seq7.current_block.remaining(), 0); assert_eq!(seq7.current_block.remaining(), 4);
assert_eq!(seq7.total_tokens(), 4); assert_eq!(seq7.total_tokens(), 4);
assert_eq!(seq7.current_block.parent_sequence_hash, None); // Still the root block assert_eq!(seq7.current_block.parent_sequence_hash, Some(SEQ_HASH_1_4)); // Parent is the committed block 0
// Test tokens_at extraction
assert_eq!(seq7.tokens_at(0..2).as_ref(), &[1, 2]);
assert_eq!(seq7.tokens_at(1..3).as_ref(), &[2, 3]);
assert_eq!(seq7.tokens_at(0..4).as_ref(), &[1, 2, 3, 4]);
assert_eq!(seq7.tokens_at(2..2).as_ref(), &[0u32; 0]); // Empty range
} }
#[test] #[test]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment