Unverified Commit 42ce6931 authored by Harrison Saturley-Hall, committed by GitHub

feat: kv block manager (#965) (#1021)


Co-authored-by: Ryan Olson <ryanolson@users.noreply.github.com>
parent cafc74eb
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use super::block::registry::RegistrationHandle;
/// The [EventManager] is not responsible for managing the history of the blocks, nor what
/// events have been published.
///
/// The [EventManager] is only responsible for issuing events on state changes. In this case,
/// there are two states:
///
/// - Store: a dynamo event plane message will be published which defines the registration/storing
/// of the block. Details include, but are not limited to, the sequence/prefix hash, the local block
/// hash, the sequence position of the block, the block size, and the storage location/class in
/// which the block is stored.
///
/// - Remove: a dynamo event plane message will be published which defines the removal of the block
/// from the cache. This message will include enough information to identify the block within a
/// storage hierarchy; minimally, the sequence hash and the storage location/class.
///
/// The [RegistrationHandle] associated with an [EventManager::block_register] call is an RAII
/// object which triggers a `Remove` event when dropped.
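///
/// ## Example (sketch)
///
/// A minimal sketch of the intended flow, using [NullEventManager] as the
/// concrete implementation and assuming `handle` is a [RegistrationHandle]
/// obtained from block registration (hypothetical here):
///
/// ```rust,ignore
/// let manager = NullEventManager::new();
///
/// // Wrap the registration handle; publication is deferred.
/// let publish_handle = PublishHandle::new(handle, manager.clone());
///
/// // Transfer ownership to a Publisher, which batches Store events.
/// let mut publisher = Publisher::new(manager.clone());
/// let shared = publisher.take_handle(publish_handle);
/// publisher.publish(); // one batched Store event for all taken handles
///
/// // Dropping the last Arc<RegistrationHandle> later triggers the Remove event.
/// drop(shared);
/// ```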
pub trait EventManager: EventPublisher + EventReleaseManager + Send + Sync {
// fn register_block(&self, token_block: &TokenBlock) -> PublishHandle;
// fn publisher(&self) -> Publisher;
}
pub trait EventPublisher: Send + Sync {
fn publish(&self, handles: Vec<Arc<RegistrationHandle>>);
}
pub trait EventReleaseManager: Send + Sync {
fn block_release(&self, registration_handle: &RegistrationHandle);
}
/// A handle to a registered block.
///
/// Ensures that the register event is published before the release event by
/// holding an [Arc] to the [RegistrationHandle], which in turn issues the
/// release event when dropped.
///
/// Ownership of the [PublishHandle] is transferred to a [Publisher] object
/// which is responsible for coordinating the publication of multiple
/// registration events.
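///
/// If a [PublishHandle] is dropped without being transferred to a [Publisher],
/// it publishes its own `Store` event immediately via its [Drop] implementation.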
pub struct PublishHandle {
handle: Arc<RegistrationHandle>,
publisher: Option<Arc<dyn EventPublisher>>,
}
impl PublishHandle {
pub fn new(handle: RegistrationHandle, publisher: Arc<dyn EventPublisher>) -> Self {
let handle = Arc::new(handle);
let publisher = Some(publisher);
Self { handle, publisher }
}
pub fn remove_handle(&self) -> Arc<RegistrationHandle> {
self.handle.clone()
}
fn disarm(&mut self) {
self.publisher = None;
}
}
impl Drop for PublishHandle {
fn drop(&mut self) {
if let Some(publisher) = self.publisher.take() {
publisher.publish(vec![self.handle.clone()]);
}
}
}
/// Responsible for publishing multiple registration events.
///
/// Because [EventPublisher::publish] takes a list of shared [RegistrationHandles][RegistrationHandle],
/// the [EventPublisher] logic can reduce the number of events published by consolidating
/// multiple registration events with additional sequence logic.
///
/// The behavior of the [EventPublisher] is left entirely up to the implementor.
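///
/// A minimal sketch of batching, assuming `event_publisher` is an
/// `Arc<dyn EventPublisher>` and `publish_handles` is a `Vec<PublishHandle>`
/// built elsewhere (both hypothetical here):
///
/// ```rust,ignore
/// let mut publisher = Publisher::new(event_publisher);
/// let shared: Vec<_> = publish_handles
///     .into_iter()
///     .map(|h| publisher.take_handle(h)) // disarms each PublishHandle
///     .collect();
/// publisher.publish(); // a single batched publish for all taken handles
/// ```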
#[derive(Clone)]
pub struct Publisher {
handles: Vec<Arc<RegistrationHandle>>,
publisher: Arc<dyn EventPublisher>,
}
impl Publisher {
pub fn new(publisher: Arc<dyn EventPublisher>) -> Self {
Self {
handles: Vec::new(),
publisher,
}
}
pub fn take_handle(&mut self, publish_handle: PublishHandle) -> Arc<RegistrationHandle> {
let handle = publish_handle.remove_handle();
self.handles.push(handle.clone());
let mut publish_handle = publish_handle;
publish_handle.disarm();
handle
}
pub fn publish(&mut self) {
let handles = std::mem::take(&mut self.handles);
if !handles.is_empty() {
self.publisher.publish(handles);
}
}
}
impl Drop for Publisher {
fn drop(&mut self) {
self.publish();
}
}
// Implementation notes:
//
// - Remove events are issued per block. We will likely want to leverage a task to collect
// drop/remove events so that we can batch them together.
//
// - Registration events can be batched by the nature of the [EventManager::register_blocks] call.
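/// An [EventManager] that silently drops all events. Useful as a default when
/// event publishing is disabled (e.g., standalone pools or tests).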
pub struct NullEventManager;
impl NullEventManager {
pub fn new() -> Arc<Self> {
Arc::new(Self {})
}
}
impl EventManager for NullEventManager {}
impl EventPublisher for NullEventManager {
fn publish(&self, _handles: Vec<Arc<RegistrationHandle>>) {}
}
impl EventReleaseManager for NullEventManager {
fn block_release(&self, _registration_handle: &RegistrationHandle) {}
}
#[cfg(test)]
pub mod tests {
use crate::tokens::SequenceHash;
use super::*;
#[derive(Debug, PartialEq, Eq)]
pub enum EventType {
Register(SequenceHash),
Remove(SequenceHash),
}
pub struct MockEventManager {
tx: tokio::sync::mpsc::UnboundedSender<Vec<EventType>>,
}
impl MockEventManager {
pub fn new() -> (
Arc<Self>,
tokio::sync::mpsc::UnboundedReceiver<Vec<EventType>>,
) {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
(Arc::new(Self { tx }), rx)
}
pub fn publisher(self: &Arc<Self>) -> Publisher {
Publisher::new(self.clone())
}
}
impl EventManager for MockEventManager {}
impl EventPublisher for MockEventManager {
fn publish(&self, handles: Vec<Arc<RegistrationHandle>>) {
let events = handles
.iter()
.map(|handle| EventType::Register(handle.sequence_hash()))
.collect::<Vec<_>>();
self.tx.send(events).unwrap();
}
}
impl EventReleaseManager for MockEventManager {
fn block_release(&self, registration_handle: &RegistrationHandle) {
let events = vec![EventType::Remove(registration_handle.sequence_hash())];
self.tx.send(events).unwrap();
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::{
kv_router::{
indexer::RouterEvent,
protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData,
KvCacheStoreData, KvCacheStoredBlockData, LocalBlockHash,
},
KV_EVENT_SUBJECT,
},
tokens::BlockHash,
};
use derive_getters::{Dissolve, Getters};
use dynamo_runtime::traits::events::EventPublisher;
use dynamo_runtime::{
component::{Component, Namespace},
raise, Result,
};
use std::sync::Arc;
use tokio::sync::mpsc;
pub enum DynamoPublisher {
Component(Component),
Namespace(Namespace),
}
impl DynamoPublisher {
pub async fn publish(&self, event: RouterEvent) -> Result<()> {
match self {
DynamoPublisher::Component(component) => {
component.publish(KV_EVENT_SUBJECT, &event).await
}
DynamoPublisher::Namespace(namespace) => {
namespace.publish(KV_EVENT_SUBJECT, &event).await
}
}
}
}
struct EventChannel {
tx: mpsc::UnboundedSender<Event>,
}
impl EventReleaseManager for EventChannel {
// Generalize sequence_hash
fn block_release(&self, sequence_hash: SequenceHash) {
if self.tx.send(Event::RemoveSingle(sequence_hash)).is_err() {
tracing::warn!("Failed to send remove block event");
}
}
}
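/// An event manager that forwards block store/remove events to the dynamo
/// event plane. Events are queued on an unbounded channel and drained by a
/// background progress engine task that publishes them via a [DynamoPublisher].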
pub struct NatsEventManager {
event_channel: Arc<EventChannel>,
}
impl NatsEventManager {
// todo - generalize identifier
pub async fn new(publisher: DynamoPublisher, worker_identifier: u64) -> Self {
let (tx, rx) = mpsc::unbounded_channel();
let state = NatsEventsManagerState {
rx,
publisher,
worker_identifier,
};
tokio::spawn(progress_engine(state));
Self {
event_channel: Arc::new(EventChannel { tx }),
}
}
}
impl std::fmt::Debug for NatsEventManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "NatsEventManager")
}
}
impl EventManager for NatsEventManager {
fn register_block(&self, token_block: &TokenBlock) -> Result<RegistrationHandle> {
let event = Event::StoreSingle(RegisterBlockEvent {
block_hash: LocalBlockHash(token_block.block_hash()),
sequence_hash: ExternalSequenceBlockHash(token_block.sequence_hash()),
parent_hash: token_block
.parent_sequence_hash()
.map(ExternalSequenceBlockHash),
});
if self.event_channel.tx.send(event).is_err() {
tracing::warn!("Failed to send store block event");
raise!("Failed to send store block event");
}
Ok(RegistrationHandle {
sequence_hash: token_block.sequence_hash(),
release_manager: Some(self.event_channel.clone()),
})
}
fn register_blocks(&self, token_blocks: &[TokenBlock]) -> Result<Vec<RegistrationHandle>> {
let event = Event::StoreMultiple(RegisterBlocksEvent {
hashes: token_blocks
.iter()
.map(|block| {
(
LocalBlockHash(block.block_hash()),
ExternalSequenceBlockHash(block.sequence_hash()),
)
})
.collect(),
parent_hash: token_blocks
.first()
.and_then(|block| block.parent_sequence_hash().map(ExternalSequenceBlockHash)),
});
let handles = token_blocks
.iter()
.map(|block| RegistrationHandle {
sequence_hash: block.sequence_hash(),
release_manager: Some(self.event_channel.clone()),
})
.collect();
if self.event_channel.tx.send(event).is_err() {
tracing::warn!("Failed to send store block event");
raise!("Failed to send store block event");
}
Ok(handles)
}
}
#[derive(Dissolve)]
struct NatsEventsManagerState {
rx: mpsc::UnboundedReceiver<Event>,
publisher: DynamoPublisher,
worker_identifier: WorkerIdentifier,
}
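/// Progress engine loop: drains events from the channel, converts each into a
/// [KvCacheEvent] with a monotonically increasing `event_id`, wraps it in a
/// [RouterEvent] tagged with the worker identifier, and publishes it via the
/// [DynamoPublisher]. Publish failures are logged and do not stop the loop.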
async fn progress_engine(state: NatsEventsManagerState) {
let (mut rx, publisher, worker_identifier) = state.dissolve();
let mut event_id = 0;
while let Some(event) = rx.recv().await {
match event {
Event::StoreSingle(event) => {
let store_data = KvCacheStoreData {
blocks: vec![KvCacheStoredBlockData {
block_hash: event.sequence_hash,
tokens_hash: event.block_hash,
}],
parent_hash: event.parent_hash,
};
let data = KvCacheEventData::Stored(store_data);
let event = KvCacheEvent { event_id, data };
let event = RouterEvent::new(worker_identifier as i64, event);
if publisher.publish(event).await.is_err() {
tracing::warn!("Failed to publish store event");
}
}
Event::StoreMultiple(event) => {
let store_data = KvCacheStoreData {
blocks: event
.hashes
.iter()
.map(|(local_hash, external_hash)| KvCacheStoredBlockData {
block_hash: *external_hash,
tokens_hash: *local_hash,
})
.collect(),
parent_hash: event.parent_hash,
};
let data = KvCacheEventData::Stored(store_data);
let event = KvCacheEvent { event_id, data };
let event = RouterEvent::new(worker_identifier as i64, event);
if publisher.publish(event).await.is_err() {
tracing::warn!("Failed to publish store event");
}
}
Event::RemoveSingle(sequence_hash) => {
let remove_data = KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(sequence_hash)],
};
let data = KvCacheEventData::Removed(remove_data);
let event = KvCacheEvent { event_id, data };
let event = RouterEvent::new(worker_identifier as i64, event);
if publisher.publish(event).await.is_err() {
tracing::warn!("Failed to publish remove event");
}
}
}
event_id += 1;
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![deny(missing_docs)]
//! # Block Layout Management 🧱
//!
//! This module is responsible for defining and managing the memory layout of data blocks.
//! It provides the foundational traits and concrete implementations for how blocks,
//! composed of multiple layers and pages, are arranged within a given [`Storage`].
//! The primary goal is to abstract the complexities of memory organization, including
//! contiguity, strides, and alignment, to ensure efficient data access and manipulation.
//!
//! ## Core Concepts
//!
//! ### 1. Layout Traits
//! The module defines a set of traits to ensure a consistent interface across different layout strategies:
//! - [`BlockLayout`]: The central trait that combines configuration and lookup capabilities. It specifies the
//! associated [`StorageType`].
//! - [`BlockLayoutConfig`]: Provides metadata about the layout, such as the number of blocks, layers, page size,
//! and data type.
//! - [`BlockLayoutLookup`]: Offers methods to retrieve the memory address and size of a specific memory region
//! (page) within the layout.
//!
//! ### 2. Layout Configuration
//! The [`LayoutConfig`] struct is used to define the parameters of a block layout, including:
//! - `num_blocks`: Total number of blocks.
//! - `num_layers`: Number of layers per block.
//! - `page_size`: Size of each page (often corresponds to a dimension like sequence length or number of tokens).
//! - `inner_dim`: The inner dimension of the data (e.g., hidden size).
//! - `alignment`: Required memory alignment for certain operations or hardware. Must be a power of 2.
//! - `dtype`: The data type ([`DType`]) of the elements stored.
//!
//! This configuration is validated to ensure consistency and correctness (e.g., alignment must be a power of 2).
//!
//! ### 3. Concrete Layouts
//! Currently, the primary implemented layout is:
//! - [`FullyContiguous<S>`]: Represents a layout where all blocks and their constituent layers are stored sequentially
//! in a single contiguous memory region provided by the generic storage `S`. It handles potential alignment
//! requirements by calculating a `base_offset` within the provided storage and adjusting strides between blocks if
//! necessary.
//!
//! ### 4. Strides and Alignment
//! The layout calculations meticulously handle strides between layers and blocks. For instance, in [`FullyContiguousConfig`]:
//! - `layer_stride_in_bytes`: The size of one memory region (page).
//! - `natural_block_stride`: The size of one block if there were no additional alignment padding between blocks.
//! - `block_stride_in_bytes`: The actual stride between the start of consecutive blocks, potentially larger than
//! `natural_block_stride` to meet `alignment` requirements.
//! - `base_offset`: An offset applied from the start of the allocated [`Storage`] to ensure the first block's
//! data begins at an aligned address.
//!
//! The function `align_up` is a utility to ensure values are aligned to the nearest multiple of a power-of-2 alignment.
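//!
//! For example, with 5 layers, a 208-byte memory region, and a 256-byte alignment
//! (the values used in this module's alignment test): `layer_stride_in_bytes = 208`,
//! `natural_block_stride = 5 * 208 = 1040`, and
//! `block_stride_in_bytes = align_up(1040, 256) = 1280`.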
//!
//! ### 5. Storage Interaction
//! Layouts are tightly coupled with the [`Storage`] trait from the `super::storage` module.
//! The [`BlockLayout::allocate`] method uses a [`StorageAllocator`] to obtain the necessary memory,
//! calculating the required size including any padding for alignment.
//!
//! ### 6. Error Handling
//! Operations within this module can result in [`LayoutError`], which covers issues like invalid configuration, validation errors, or out-of-bounds indexing.
//!
//! ## Usage Example
//!
//! ```rust
//! use dynamo_llm::block_manager::layout::{
//! LayoutConfig, FullyContiguous, BlockLayout, BlockLayoutLookup, BlockLayoutConfig,
//! };
//! use dynamo_llm::block_manager::storage::{SystemAllocator, StorageType};
//! use dynamo_llm::common::dtype::DType;
//!
//! // Define the layout configuration
//! let config = LayoutConfig::builder()
//! .num_blocks(10)
//! .num_layers(4)
//! .page_size(16)
//! .inner_dim(128)
//! .dtype(DType::FP16)
//! .build()
//! .unwrap();
//!
//! // Allocate a FullyContiguous layout using a SystemAllocator
//! let allocator = SystemAllocator;
//! let layout = FullyContiguous::allocate(config, &allocator).unwrap();
//!
//! // Access layout properties
//! assert_eq!(layout.num_blocks(), 10);
//! assert_eq!(layout.storage_type(), StorageType::System);
//!
//! // Get the address of a specific page
//! let addr = layout.memory_region_addr(0, 0).unwrap();
//! println!("Address of block 0, layer 0: {}", addr);
//! ```
//!
//! ## NIXL Integration
//! This module also includes a submodule `nixl` ([`crate::block_manager::layout::nixl`])
//! which extends these layout concepts for NIXL (NVIDIA Interface eXchange Layer), enabling
//! layouts to be registered and serialized for use in distributed environments.
pub mod nixl;
use thiserror::Error;
use crate::block_manager::storage::{Storage, StorageAllocator};
use crate::common::dtype::DType;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use validator::Validate;
use super::storage::StorageType;
/// Errors that can occur during layout operations
#[derive(Debug, Error)]
#[allow(missing_docs)]
pub enum LayoutError {
#[error("Invalid configuration: {0}")]
InvalidConfig(String),
#[error("Validation failed: {0}")]
ValidationError(#[from] validator::ValidationErrors),
#[error("Invalid block index: {0}")]
InvalidBlockIndex(usize),
#[error("Invalid layer index: {0}")]
InvalidLayerIndex(usize),
#[error("Operation failed: {0}")]
OperationFailed(String),
#[error("Serialization error: {0}")]
SerdeError(#[from] serde_json::Error),
}
/// Storage pattern for layers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LayoutType {
/// All layers are contiguous in memory [n_layers, ...]
FullyContiguous,
// /// Each layer is stored separately with a common stride between blocks
// /// in different layers
// LayerContiguousWithCommonStride,
// /// Each layer is stored separately with no guaranteed stride
// LayerContiguousWithSeparateStride,
// /// Each page is stored separately with no guaranteed stride
// PageContiguousWithSeparateStride,
// /// NullLayout
// /// Used for testing and debugging
// Null,
}
/// Core trait for block layouts
pub trait BlockLayout:
BlockLayoutConfig + BlockLayoutLookup + Send + Sync + std::fmt::Debug
{
/// The type of storage this layout uses
type StorageType: Storage;
/// Get the memory regions for all blocks and layers
fn storage(&self) -> Vec<&Self::StorageType>;
/// Get the mutable memory regions for all blocks and layers
fn storage_mut(&mut self) -> Vec<&mut Self::StorageType>;
/// Storage type for the layout
fn storage_type(&self) -> StorageType;
}
/// Configuration for block layouts
pub trait BlockLayoutConfig: std::fmt::Debug {
/// Returns the layout type
fn layout_type(&self) -> LayoutType;
/// Returns the total number of blocks this layout manages
fn num_blocks(&self) -> usize;
/// Returns the number of layers per block
fn num_layers(&self) -> usize;
/// Returns the page size (e.g., the number of tokens per page)
fn page_size(&self) -> usize;
/// Returns the inner dimension size
fn inner_dim(&self) -> usize;
}
/// Trait for looking up memory regions in a block layout
pub trait BlockLayoutLookup {
/// Get the memory region for a specific page [page_size, inner_dim]
fn memory_region_addr(&self, block_idx: usize, layer_idx: usize) -> Result<u64, LayoutError>;
/// Get the size in bytes of a single memory region (page)
fn memory_region_size(&self) -> usize;
}
/// Configuration for block layouts
#[derive(Debug, Clone, Builder, Validate, Serialize, Deserialize)]
pub struct LayoutConfig {
/// Number of blocks
#[validate(range(min = 1))]
pub num_blocks: usize,
/// Number of layers
#[validate(range(min = 1))]
pub num_layers: usize,
/// Page size
#[validate(range(min = 1))]
pub page_size: usize,
/// Inner dimension
#[validate(range(min = 1))]
pub inner_dim: usize,
/// Alignment
#[validate(custom(function = "validate_power_of_2"))]
#[builder(default = "1")]
pub alignment: usize,
/// Data type
#[builder(default = "DType::FP16")]
pub dtype: DType,
}
impl LayoutConfig {
/// Builder for LayoutConfig
pub fn builder() -> LayoutConfigBuilder {
LayoutConfigBuilder::default()
}
}
/// Validation function to check that the alignment is a power of 2.
fn validate_power_of_2(alignment: usize) -> Result<(), validator::ValidationError> {
if !alignment.is_power_of_two() {
// Return validation error if alignment is not a power of 2
return Err(validator::ValidationError::new(
"alignment_must_be_power_of_2",
));
}
// Passes validation if alignment is a power of 2
Ok(())
}
/// Helper to align a value up to the nearest multiple of alignment.
/// Alignment must be a power of 2.
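/// For example, `align_up(1040, 256)` returns `1280`, while `align_up(1024, 256)`
/// returns `1024` (already aligned).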
fn align_up(value: usize, alignment: usize) -> usize {
(value + alignment - 1) & !(alignment - 1)
}
/// Internal struct to hold calculated layout dimensions specific to FullyContiguous.
// Module-level, but only used internally by FullyContiguous
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct FullyContiguousConfig {
inner: LayoutConfig,
memory_region_size: usize,
layer_stride_in_bytes: usize,
natural_block_stride: usize,
block_stride_in_bytes: usize, // Aligned if necessary
layout_data_bytes: usize, // Size of the layout data itself (post base offset)
}
impl FullyContiguousConfig {
/// Calculates the core dimensions based on the configuration.
/// Returns an error if the configuration is invalid.
fn new(config: LayoutConfig) -> Result<Self, LayoutError> {
// Validate first, propagating errors via `?`
config.validate()?;
let alignment = config.alignment;
let memory_region_size = config.page_size * config.inner_dim * config.dtype.size_in_bytes();
let layer_stride_in_bytes = memory_region_size;
let natural_block_stride = config.num_layers * layer_stride_in_bytes;
let block_stride_in_bytes = if alignment > 1 {
align_up(natural_block_stride, alignment)
} else {
natural_block_stride
};
let layout_data_bytes =
(config.num_blocks - 1) * block_stride_in_bytes + natural_block_stride;
Ok(Self {
inner: config,
memory_region_size,
layer_stride_in_bytes,
natural_block_stride,
block_stride_in_bytes,
layout_data_bytes,
})
}
/// Calculate the total number of bytes required for allocation, including initial alignment padding.
pub fn required_allocation_size(&self) -> usize {
let initial_padding = if self.inner.alignment > 1 {
self.inner.alignment - 1
} else {
0
};
self.layout_data_bytes + initial_padding
}
}
impl BlockLayoutConfig for FullyContiguousConfig {
fn layout_type(&self) -> LayoutType {
LayoutType::FullyContiguous
}
fn num_blocks(&self) -> usize {
self.inner.num_blocks
}
fn num_layers(&self) -> usize {
self.inner.num_layers
}
fn page_size(&self) -> usize {
self.inner.page_size
}
fn inner_dim(&self) -> usize {
self.inner.inner_dim
}
}
/// Contiguous memory layout where all blocks and layers are sequential
#[derive(Debug)]
pub struct FullyContiguous<S: Storage> {
/// Configuration for the layout
config: FullyContiguousConfig,
/// Storage for the layout
storage: S,
/// Storage type for the layout
storage_type: StorageType,
// Offset from storage.addr() to the aligned start of block 0
base_offset: usize,
}
impl<S: Storage> FullyContiguous<S> {
/// Create a new contiguous layout using the provided configuration and pre-allocated storage.
/// Performs validation and calculates strides/offsets.
#[instrument(level = "debug", skip(storage), fields(config = ?config))]
pub fn new(config: LayoutConfig, storage: Vec<S>) -> Result<Self, LayoutError> {
// Calculate dimensions, which includes validation.
let config = FullyContiguousConfig::new(config)?;
if storage.len() != 1 {
return Err(LayoutError::InvalidConfig(
"FullyContiguous layout requires exactly one storage region".to_string(),
));
}
let mut storage = storage;
let storage = storage.remove(0);
let storage_type = storage.storage_type();
let provided_size = storage.size();
let storage_addr = storage.addr();
let alignment = config.inner.alignment;
// Calculate base offset needed to align the start of block 0
let base_offset = if alignment > 1 {
align_up(storage_addr as usize, alignment) - storage_addr as usize
} else {
0
};
let total_required_size_with_offset = base_offset + config.layout_data_bytes;
tracing::debug!(
provided_size,
total_required_size_with_offset,
base_offset,
required_layout_data_bytes = config.layout_data_bytes,
alignment,
"Validating storage size with base offset and alignment"
);
// Validate storage size fits the configuration *with base offset and alignment*
if provided_size < total_required_size_with_offset {
tracing::warn!(
provided_size,
total_required_size_with_offset,
"Storage size too small for aligned layout including base offset"
);
return Err(LayoutError::InvalidConfig(format!(
"Storage size {} is less than required size {} (including base offset for alignment)",
provided_size,
total_required_size_with_offset
)));
}
tracing::debug!(
config.memory_region_size,
config.layer_stride_in_bytes,
config.block_stride_in_bytes,
config.natural_block_stride,
alignment = config.inner.alignment,
base_offset,
"Calculated layout strides (aligned)"
);
Ok(Self {
config,
storage,
storage_type,
base_offset,
})
}
/// Internal constructor used for reconstruction from serialized parts.
/// Assumes the provided config, storage, and base_offset are consistent
/// and skips size/alignment validation against the storage.
pub(crate) fn new_internal(
config: FullyContiguousConfig,
storage: S,
base_offset: usize,
storage_type: StorageType,
) -> Result<Self, LayoutError> {
// Basic check: Ensure the storage address matches expectations based on offset if possible?
// Maybe not strictly necessary if we trust the serialized data.
Ok(Self {
config,
storage,
storage_type,
base_offset,
})
}
/// Allocate storage using the provided allocator and create a new FullyContiguous layout.
///
/// Calculates the required size based on the configuration, allocates the storage
/// (including potential padding for initial alignment), and then constructs the
/// `FullyContiguous` layout instance.
///
/// # Type Parameters
///
/// * `A`: The type of the storage allocator, implementing `StorageAllocator<S>`.
///
/// # Arguments
///
/// * `config` - The layout configuration.
/// * `allocator` - A reference to the storage allocator.
///
/// # Returns
///
/// A `Result` containing the new `FullyContiguous<S>` instance or an error if allocation
/// or layout creation fails.
#[instrument(level = "debug", skip(allocator), fields(config = ?config))]
pub fn allocate(
config: LayoutConfig,
allocator: &dyn StorageAllocator<S>,
) -> Result<Self, LayoutError> {
// Calculate total bytes needed. Propagate error if config is invalid.
let config = FullyContiguousConfig::new(config)?;
let bytes_to_allocate = config.required_allocation_size();
tracing::debug!(
bytes_to_allocate,
alignment = config.inner.alignment,
"Calculated storage size for allocation (with alignment padding)"
);
let storage = allocator.allocate(bytes_to_allocate).map_err(|e| {
LayoutError::OperationFailed(format!("Storage allocation failed: {}", e))
})?;
tracing::debug!(
allocated_size = storage.size(),
allocated_addr = storage.addr(),
"Storage allocated successfully"
);
// Pass the config by value as Self::new takes ownership
Self::new(config.inner, vec![storage])
}
}
impl<S: Storage> BlockLayout for FullyContiguous<S> {
type StorageType = S;
fn storage(&self) -> Vec<&Self::StorageType> {
vec![&self.storage]
}
fn storage_mut(&mut self) -> Vec<&mut Self::StorageType> {
vec![&mut self.storage]
}
fn storage_type(&self) -> StorageType {
self.storage_type.clone()
}
}
impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> {
fn layout_type(&self) -> LayoutType {
LayoutType::FullyContiguous
}
fn num_blocks(&self) -> usize {
self.config.inner.num_blocks
}
fn num_layers(&self) -> usize {
self.config.inner.num_layers
}
fn page_size(&self) -> usize {
self.config.inner.page_size
}
fn inner_dim(&self) -> usize {
self.config.inner.inner_dim
}
}
impl<S: Storage> BlockLayoutLookup for FullyContiguous<S> {
fn memory_region_addr(&self, block_idx: usize, layer_idx: usize) -> Result<u64, LayoutError> {
if block_idx >= self.num_blocks() {
return Err(LayoutError::InvalidBlockIndex(block_idx));
}
if layer_idx >= self.num_layers() {
return Err(LayoutError::InvalidLayerIndex(layer_idx));
}
// Start from the aligned base address
let aligned_start_addr = self.storage.addr() + self.base_offset as u64;
// Calculate offset relative to the aligned start using stored config
let block_offset = block_idx * self.config.block_stride_in_bytes;
let layer_offset = layer_idx * self.config.layer_stride_in_bytes;
let final_addr = aligned_start_addr + block_offset as u64 + layer_offset as u64;
Ok(final_addr)
}
fn memory_region_size(&self) -> usize {
// Access via stored dims
self.config.memory_region_size
}
}
#[allow(missing_docs)]
#[cfg(test)]
pub mod tests {
use super::*;
use crate::block_manager::storage::tests::{NullDeviceAllocator, NullDeviceStorage};
use crate::block_manager::storage::{StorageType, SystemAllocator};
use crate::common::dtype::DType;
use dynamo_runtime::logging::init as init_logging;
const NUM_BLOCKS: usize = 7;
const NUM_LAYERS: usize = 5;
const PAGE_SIZE: usize = 4;
const INNER_DIM: usize = 13;
const DTYPE: DType = DType::FP32; // Example dtype
/// Helper function to calculate expected memory offset
fn calculate_expected_offset(
base_addr: u64,
block_idx: usize,
layer_idx: usize,
block_stride: usize,
layer_stride: usize,
) -> u64 {
base_addr + (block_idx * block_stride + layer_idx * layer_stride) as u64
}
// Updated setup_layout: Calculates size internally, uses default alignment for simplicity in non-alignment tests.
pub fn setup_layout(
alignment: Option<usize>, // Option to override default alignment
) -> Result<FullyContiguous<NullDeviceStorage>, LayoutError> {
let config = LayoutConfig {
num_blocks: NUM_BLOCKS,
num_layers: NUM_LAYERS,
page_size: PAGE_SIZE,
inner_dim: INNER_DIM,
alignment: alignment.unwrap_or(1),
dtype: DTYPE,
};
FullyContiguous::allocate(config, &NullDeviceAllocator)
}
#[test]
fn test_fc_creation_invalid_alignment() {
let config = LayoutConfig::builder()
.num_blocks(NUM_BLOCKS)
.num_layers(NUM_LAYERS)
.page_size(PAGE_SIZE)
.inner_dim(INNER_DIM)
.alignment(3)
.build()
.unwrap();
assert!(config.validate().is_err());
}
#[test]
fn test_fc_creation_success() {
// Setup with default (None) alignment
let layout_result = setup_layout(None);
assert!(
layout_result.is_ok(),
"Layout creation failed: {:?}",
layout_result.err()
);
}
#[test]
fn test_fc_creation_insufficient_storage() {
init_logging();
let config = LayoutConfig {
num_blocks: NUM_BLOCKS,
num_layers: NUM_LAYERS,
page_size: PAGE_SIZE,
inner_dim: INNER_DIM,
alignment: 1,
dtype: DTYPE,
};
// Calculate correct size needed
let fc_config = FullyContiguousConfig::new(config.clone()).unwrap();
let required_size = fc_config.required_allocation_size();
let storage = NullDeviceStorage::new((required_size - 1) as u64);
let layout_result = FullyContiguous::new(config, vec![storage]);
assert!(layout_result.is_err());
match layout_result.err().unwrap() {
LayoutError::InvalidConfig(_) => {} // Expected error
e => panic!("Expected InvalidConfig error, got {:?}", e),
}
}
#[test]
fn test_fc_accessor_methods() {
let layout = setup_layout(None).expect("Layout setup failed");
assert_eq!(layout.num_blocks(), NUM_BLOCKS);
assert_eq!(layout.num_layers(), NUM_LAYERS);
assert_eq!(layout.page_size(), PAGE_SIZE);
assert_eq!(layout.inner_dim(), INNER_DIM);
}
#[test]
fn test_fc_memory_region_size() {
let layout = setup_layout(None).expect("Layout setup failed");
let expected_region_size = PAGE_SIZE * INNER_DIM * DTYPE.size_in_bytes();
assert_eq!(layout.memory_region_size(), expected_region_size);
}
#[test]
fn test_fc_offset_calculation() {
let layout = setup_layout(None).expect("Layout setup failed");
let dims = layout.config.clone();
let block_stride = dims.block_stride_in_bytes;
let layer_stride = dims.layer_stride_in_bytes;
let base_addr = layout.storage.addr() + layout.base_offset as u64;
// Test first block, first layer
let expected_offset_0_0 =
calculate_expected_offset(base_addr, 0, 0, block_stride, layer_stride);
assert_eq!(
layout.memory_region_addr(0, 0).unwrap(),
expected_offset_0_0
);
// Test first block, last layer
let last_layer_idx = NUM_LAYERS - 1;
let expected_offset_0_last =
calculate_expected_offset(base_addr, 0, last_layer_idx, block_stride, layer_stride);
assert_eq!(
layout.memory_region_addr(0, last_layer_idx).unwrap(),
expected_offset_0_last
);
// Test last block, first layer
let last_block_idx = NUM_BLOCKS - 1;
let expected_offset_last_0 =
calculate_expected_offset(base_addr, last_block_idx, 0, block_stride, layer_stride);
assert_eq!(
layout.memory_region_addr(last_block_idx, 0).unwrap(),
expected_offset_last_0
);
// Test last block, last layer
let expected_offset_last_last = calculate_expected_offset(
base_addr,
last_block_idx,
last_layer_idx,
block_stride,
layer_stride,
);
assert_eq!(
layout
.memory_region_addr(last_block_idx, last_layer_idx)
.unwrap(),
expected_offset_last_last
);
// Test intermediate block/layer
let mid_block_idx = NUM_BLOCKS / 2;
let mid_layer_idx = NUM_LAYERS / 2;
let expected_offset_mid_mid = calculate_expected_offset(
base_addr,
mid_block_idx,
mid_layer_idx,
block_stride,
layer_stride,
);
assert_eq!(
layout
.memory_region_addr(mid_block_idx, mid_layer_idx)
.unwrap(),
expected_offset_mid_mid
);
}
#[test]
fn test_fc_invalid_block_index() {
let layout = setup_layout(None).expect("Layout setup failed");
let result = layout.memory_region_addr(NUM_BLOCKS, 0); // Index == num_blocks (out of bounds)
assert!(result.is_err());
assert!(matches!(
result.err().unwrap(),
LayoutError::InvalidBlockIndex(NUM_BLOCKS)
));
}
#[test]
fn test_fc_invalid_layer_index() {
let layout = setup_layout(None).expect("Layout setup failed");
let result = layout.memory_region_addr(0, NUM_LAYERS); // Index == num_layers (out of bounds)
assert!(result.is_err());
assert!(matches!(
result.err().unwrap(),
LayoutError::InvalidLayerIndex(NUM_LAYERS)
));
}
#[test]
fn test_fc_allocation_system() {
init_logging();
let config = LayoutConfig {
num_blocks: NUM_BLOCKS,
num_layers: NUM_LAYERS,
page_size: PAGE_SIZE,
inner_dim: INNER_DIM,
alignment: 1,
dtype: DTYPE,
};
let allocator = SystemAllocator;
let layout_result = FullyContiguous::allocate(config, &allocator);
assert!(layout_result.is_ok());
let layout = layout_result.unwrap();
// Basic checks on the allocated layout
assert_eq!(layout.num_blocks(), NUM_BLOCKS);
assert_eq!(layout.num_layers(), NUM_LAYERS);
assert_eq!(layout.page_size(), PAGE_SIZE);
assert_eq!(layout.inner_dim(), INNER_DIM);
assert_eq!(layout.storage.storage_type(), StorageType::System);
assert_eq!(
layout.storage.size(),
layout.config.required_allocation_size()
);
assert_eq!(
layout.storage.size(),
NUM_BLOCKS * NUM_LAYERS * PAGE_SIZE * INNER_DIM * DTYPE.size_in_bytes()
);
}
#[test]
fn test_fc_alignment() {
init_logging();
const ALIGNMENT: usize = 256; // Must be power of 2
let config = LayoutConfig {
num_blocks: NUM_BLOCKS,
num_layers: NUM_LAYERS,
page_size: PAGE_SIZE,
inner_dim: INNER_DIM,
alignment: ALIGNMENT,
dtype: DTYPE,
};
// Calculate expected size needed *for the data layout itself*
let memory_region_size = PAGE_SIZE * INNER_DIM * DTYPE.size_in_bytes();
assert_eq!(memory_region_size, 208);
let natural_block_stride = NUM_LAYERS * memory_region_size;
assert_eq!(natural_block_stride, 1040);
let aligned_block_stride = align_up(natural_block_stride, ALIGNMENT);
assert_eq!(aligned_block_stride, 1280);
// Calculate the expected *allocated* size (data + initial padding)
let fc_config = FullyContiguousConfig::new(config.clone()).unwrap();
let expected_allocated_size = fc_config.required_allocation_size();
// Use allocate method
let allocator = SystemAllocator;
let layout_result = FullyContiguous::allocate(config.clone(), &allocator);
assert!(
layout_result.is_ok(),
"Allocation failed: {:?}",
layout_result.err()
);
let layout = layout_result.unwrap();
// Verify total *allocated* size matches expectation
assert_eq!(
layout.storage.size(),
expected_allocated_size,
"Allocated storage size mismatch"
);
assert_eq!(
layout.config.block_stride_in_bytes, aligned_block_stride,
"Stored block stride mismatch"
);
// Check alignment of block starts
let addr_block_0 = layout
.memory_region_addr(0, 0)
.expect("Failed to get addr block 0");
let addr_block_1 = layout
.memory_region_addr(1, 0)
.expect("Failed to get addr block 1");
let addr_block_2 = layout
.memory_region_addr(2, 0)
.expect("Failed to get addr block 2");
// All blocks should now be aligned due to base_offset adjustment
assert_eq!(
addr_block_0 % ALIGNMENT as u64,
0,
"Block 0 start address is not aligned"
);
assert_eq!(
addr_block_1 % ALIGNMENT as u64,
0,
"Block 1 start address is not aligned"
);
assert_eq!(
addr_block_2 % ALIGNMENT as u64,
0,
"Block 2 start address is not aligned"
);
// Verify the difference matches the aligned stride
assert_eq!(
addr_block_1 - addr_block_0,
aligned_block_stride as u64,
"Stride between block 0 and 1 mismatch"
);
assert_eq!(
addr_block_2 - addr_block_1,
aligned_block_stride as u64,
"Stride between block 1 and 2 mismatch"
);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # NIXL Integration for Block Layouts 🤝
//!
//! This module extends the core block layout functionalities defined in the parent `layout` module
//! with [NIXL](http://github.com/ai-dynamo/nixl) specific capabilities. It enables block layouts,
//! whose underlying storage is NIXL-registerable, to be registered with a NIXL agent and
//! serialized into a format suitable for sharing and reconstruction in distributed environments.
//!
//! ## Key Features & Components
//!
//! ### 1. NIXL-Specific Layout Traits
//! - [`NixlLayout`]: An umbrella trait that augments a [`BlockLayout`]. It requires the layout's
//! associated `StorageType` to implement [`NixlRegisterableStorage`]. This trait provides the
//! `nixl_register` method to register all underlying storage regions of the layout with a NIXL agent.
//! - [`BlockLayoutNixlStorage`]: A trait implemented by layouts to provide NIXL-specific memory
//! information like `mem_type` and `device_id` directly from the layout structure, typically
//! derived from its underlying storage.
//! - [`ToSerializedNixlBlockLayout`]: Implemented by layouts that can be converted into a
//! [`SerializedNixlBlockLayout`]. This involves capturing the layout configuration and the NIXL
//! descriptors of its storage.
//!
//! ### 2. Serializable NIXL Layout
//! - [`SerializedNixlBlockLayout`]: A struct that holds the serialized representation (as `Vec<u8>`)
//! of a NIXL-compatible block layout. It can be deserialized to reconstruct the layout, typically
//! on a remote node, assuming the described NIXL memory regions are accessible.
//! - `NixlBlockLayoutKinds`: An internal enum used during serialization to differentiate between
//! different types of layouts (e.g., `FullyContiguous`).
//! - `SerializableNixlLayout<C>`: An internal generic struct that captures the configuration (`C`),
//! base offset, NIXL storage descriptors, and storage type for a specific layout kind.
//!
//! ### 3. Integration with Core Layouts
//! The module provides implementations of these NIXL traits for concrete layout types from the
//! parent module, such as [`FullyContiguous`]. For example:
//! - `FullyContiguous<S>` (where `S:` [`NixlRegisterableStorage`]) implements [`NixlLayout`], allowing
//! its storage to be registered.
//! - It also implements [`ToSerializedNixlBlockLayout`], enabling its configuration and NIXL storage
//! descriptors to be serialized.
//!
//! ### 4. Layout Creation and Allocation Extensions
//! The [`LayoutConfig`] from the parent module is extended with methods like:
//! - `create_layout`: To create a NIXL-aware layout from existing NIXL-registerable storage.
//! - `allocate_layout`: To allocate storage using a NIXL-registerable storage allocator and then
//! create the NIXL-aware layout.
//!
//! ## Usage Flow
//!
//! 1. **Create/Allocate Layout**: A block layout (e.g., [`FullyContiguous`]) is created or allocated,
//! ensuring its underlying storage is NIXL-compatible (e.g., using [`SystemStorage`] that implements
//! [`NixlRegisterableStorage`]).
//! 2. **Register with NIXL**: The [`nixl_register`] method from the [`NixlLayout`] trait is called on the
//! layout instance with a [`NixlAgent`].
//! 3. **Serialize**: The [`serialize`] method from [`ToSerializedNixlBlockLayout`] is used to get a
//! [`SerializedNixlBlockLayout`].
//! 4. **Transmit**: The [`SerializedNixlBlockLayout`] (or its byte representation) is sent to another
//! process/node.
//! 5. **Deserialize**: On the receiving end, [`SerializedNixlBlockLayout::deserialize`] is called to
//! reconstruct an `Arc<dyn BlockLayout<StorageType = NixlStorage>>`. This reconstructed layout now
//! refers to the remote NIXL memory regions.
//!
//! ```rust
//! use dynamo_llm::block_manager::layout::{LayoutConfig, LayoutType};
//! use dynamo_llm::block_manager::layout::nixl::{NixlLayout, ToSerializedNixlBlockLayout, SerializedNixlBlockLayout};
//! use dynamo_llm::block_manager::storage::nixl::NixlAgent;
//! use dynamo_llm::block_manager::storage::PinnedAllocator; // Assuming PinnedStorage is NixlRegisterable
//! use std::sync::Arc;
//!
//! // Configuration
//! let config = LayoutConfig::builder()
//! .num_blocks(10)
//! .num_layers(2)
//! .page_size(4)
//! .inner_dim(13)
//! .build().unwrap();
//!
//! // 1. Allocate a NIXL-compatible layout
//! let allocator = Arc::new(PinnedAllocator::new().unwrap()); // PinnedAllocator provides NixlRegisterable PinnedStorage
//! let mut layout = config.allocate_layout(LayoutType::FullyContiguous, allocator).unwrap();
//!
//! // 2. Register with NIXL Agent
//! let agent = NixlAgent::new("my_agent").unwrap();
//! layout.nixl_register(&agent, None).unwrap();
//!
//! // 3. Serialize the layout
//! let serialized_layout = layout.serialize().unwrap();
//!
//! // 4. (Transmit serialized_layout to another process)
//!
//! // 5. Deserialize on the other end
//! let reconstructed_layout = SerializedNixlBlockLayout::deserialize(&serialized_layout).unwrap();
//! println!("Reconstructed layout refers to storage type: {:?}", reconstructed_layout.storage_type());
//! ```
//!
//! This module effectively bridges the local layout definitions with the requirements of distributed memory management via NIXL.
use crate::block_manager::storage::StorageType;
use super::{BlockLayout, BlockLayoutConfig, LayoutConfig, LayoutError, LayoutType};
use super::super::storage::{
nixl::{MemType, NixlAgent, NixlRegisterableStorage, NixlStorage, OptArgs},
Storage, StorageAllocator,
};
use super::{FullyContiguous, FullyContiguousConfig};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Extends [BlockLayout] with NIXL-specific methods for registering with an NIXL agent.
pub trait NixlLayout: BlockLayout + BlockLayoutNixlStorage + ToSerializedNixlBlockLayout {
/// Register the layout with an NIXL agent
///
/// This will register all the individual memory regions associated with the [BlockLayout].
fn nixl_register(
&mut self,
agent: &NixlAgent,
opt_args: Option<&OptArgs>,
) -> anyhow::Result<()>;
}
/// Trait for providing NIXL-specific memory information
pub trait BlockLayoutNixlStorage {
/// Returns the memory type of the storage
fn mem_type(&self) -> MemType;
/// Returns the device ID of the storage
fn device_id(&self) -> u64;
}
// Umbrella impl for all BlockLayout types that are NixlRegisterableStorage
impl<T> NixlLayout for T
where
T: BlockLayout + BlockLayoutNixlStorage + ToSerializedNixlBlockLayout + ?Sized, // Implement for any T that is BlockLayout (potentially unsized)
T::StorageType: NixlRegisterableStorage, // T's associated StorageType must be NixlRegisterableStorage
{
fn nixl_register(
&mut self,
agent: &NixlAgent,
opt_args: Option<&OptArgs>,
) -> anyhow::Result<()> {
for storage in self.storage_mut() {
storage.nixl_register(agent, opt_args)?;
}
Ok(())
}
}
impl LayoutConfig {
/// Create a new NIXL-aware layout from existing NIXL-registerable storage.
pub fn create_layout<S: Storage + NixlRegisterableStorage>(
&self,
layout_type: LayoutType,
storage: Vec<S>,
) -> Result<impl NixlLayout<StorageType = S>, LayoutError> {
match layout_type {
LayoutType::FullyContiguous => FullyContiguous::new(self.clone(), storage),
}
}
/// Allocate a new NIXL-aware layout using a NIXL-registerable storage allocator.
pub fn allocate_layout<S: Storage + NixlRegisterableStorage>(
&self,
layout_type: LayoutType,
allocator: Arc<dyn StorageAllocator<S>>,
) -> Result<impl NixlLayout<StorageType = S>, LayoutError> {
match layout_type {
LayoutType::FullyContiguous => {
FullyContiguous::allocate(self.clone(), allocator.as_ref())
}
}
}
}
/// Trait to convert a BlockLayout instance into its NIXL-specific serializable representation.
pub trait ToSerializedNixlBlockLayout: BlockLayout<StorageType: NixlRegisterableStorage> {
/// Converts the layout into a serializable format, ensuring it's backed by NIXL storage.
/// Returns an error if the layout is not backed by storage providing NIXL descriptors.
fn serialize(&self) -> Result<SerializedNixlBlockLayout, LayoutError>;
}
/// Serializable representation of a BlockLayout backed by NIXL storage.
#[derive(Serialize, Deserialize, Clone)]
pub struct SerializedNixlBlockLayout(Vec<u8>);
/// Enum representing the serializable state of different BlockLayout types
/// specifically when backed by NIXL-compatible storage.
#[derive(Serialize, Deserialize, Debug, Clone)]
enum NixlBlockLayoutKinds {
FullyContiguous(SerializableNixlLayout<FullyContiguousConfig>),
// Add variants for other layout types here
}
/// Internal generic struct capturing a layout's configuration, base offset,
/// NIXL storage descriptors, and storage type for a specific layout kind.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct SerializableNixlLayout<C: BlockLayoutConfig> {
config: C,
base_offset: usize,
storage_descriptors: Vec<NixlStorage>,
storage_type: StorageType,
}
impl<C> SerializableNixlLayout<C>
where
C: BlockLayoutConfig + Serialize + for<'de> Deserialize<'de> + Clone + std::fmt::Debug,
{
/// Create a new SerializableNixlLayout
fn new(
config: C,
base_offset: usize,
storage_descriptors: Vec<NixlStorage>,
storage_type: StorageType,
) -> Self {
Self {
config,
base_offset,
storage_descriptors,
storage_type,
}
}
}
impl<S: NixlRegisterableStorage> ToSerializedNixlBlockLayout for FullyContiguous<S> {
fn serialize(&self) -> Result<SerializedNixlBlockLayout, LayoutError> {
// Use accessors added previously
let config = self.config.clone();
let base_offset = self.base_offset;
let storages = self.storage();
if storages.len() != 1 {
return Err(LayoutError::InvalidConfig(
"FullyContiguous reconstruction expects exactly one NixlStorage descriptor"
.to_string(),
));
}
// FullyContiguous uses a Vec<Storage>, but should only contain one element.
let storage_instance = storages.first().ok_or_else(|| {
LayoutError::OperationFailed("FullyContiguous requires one storage element".to_string())
})?;
let storage_descriptors =
unsafe { storage_instance.as_nixl_descriptor() }.ok_or_else(|| {
LayoutError::OperationFailed(
"Storage does not provide NIXL descriptors for serialization".to_string(),
)
})?;
let serializable_data = SerializableNixlLayout::new(
config,
base_offset,
vec![storage_descriptors],
self.storage_type(),
);
let nixl_block_layout = NixlBlockLayoutKinds::FullyContiguous(serializable_data);
Ok(SerializedNixlBlockLayout(serde_json::to_vec(
&nixl_block_layout,
)?))
}
}
impl SerializedNixlBlockLayout {
/// Reconstructs a dynamic BlockLayout trait object backed by NixlStorage
/// from the serialized layout information.
/// Assumes the NixlStorage regions described within already exist and are valid.
pub fn deserialize(
&self,
) -> Result<Arc<dyn BlockLayout<StorageType = NixlStorage>>, LayoutError> {
let nixl_block_layout: NixlBlockLayoutKinds = serde_json::from_slice(&self.0)?;
match nixl_block_layout {
NixlBlockLayoutKinds::FullyContiguous(config) => {
if config.storage_descriptors.len() != 1 {
return Err(LayoutError::InvalidConfig(
"FullyContiguous reconstruction expects exactly one NixlStorage descriptor"
.to_string(),
));
}
// Clone the single NixlStorage descriptor to become the storage instance
let storage = config.storage_descriptors[0].clone();
// Use the internal constructor which skips allocation checks
let layout = FullyContiguous::new_internal(
config.config.clone(),
storage, // Pass the NixlStorage instance
config.base_offset,
config.storage_type,
)?;
Ok(Arc::new(layout))
} // Handle other variants when added...
}
}
}
impl<S> BlockLayoutNixlStorage for FullyContiguous<S>
where
S: Storage + NixlRegisterableStorage,
{
fn mem_type(&self) -> MemType {
self.storage.mem_type()
}
fn device_id(&self) -> u64 {
self.storage.device_id()
}
}
#[cfg(test)]
mod tests {
use super::super::*;
use super::*;
use crate::block_manager::storage::SystemAllocator;
use dynamo_runtime::logging::init as init_logging;
#[test]
fn test_nixl_layout() {
init_logging();
let config = LayoutConfig::builder()
.num_blocks(10)
.num_layers(2)
.page_size(4)
.inner_dim(13)
.build()
.unwrap();
config.validate().unwrap();
let mut layout = FullyContiguous::allocate(config, &SystemAllocator).unwrap();
let agent = NixlAgent::new("test").unwrap();
tracing::info!("Registering layout");
layout.nixl_register(&agent, None).unwrap();
tracing::info!("Layout registered");
let local_storage_type = layout.storage_type();
let serialized = layout.serialize().unwrap();
let remote_layout = SerializedNixlBlockLayout::deserialize(&serialized).unwrap();
println!("Nixl layout: {:?}", remote_layout);
let remote_storage_type = remote_layout.storage_type();
assert_eq!(local_storage_type, remote_storage_type);
drop(layout);
tracing::info!("Layout dropped");
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # KV Cache Block Pool Management
//!
//! This module provides the primary [`BlockPool`] structure for managing KV cache blocks.
//! It orchestrates the allocation, registration, and reuse of blocks by coordinating
//! between an [`ActiveBlockPool`] and an [`InactiveBlockPool`].
//!
//! ## Core Components:
//!
//! - **[`BlockPool`]**: The main entry point for interacting with the block management system.
//! It holds the shared state containing both active and inactive pools.
//! - **[`ActiveBlockPool`]**: Manages blocks that are currently associated with active sequences.
//! It primarily uses weak references to track these blocks, allowing them to be potentially
//! reclaimed by the inactive pool if no strong references remain.
//! - **[`InactiveBlockPool`]**: Manages blocks that are not currently in active use. It supports
//! block reuse by matching sequence hashes and employs a priority-based eviction strategy
//! for acquiring free blocks.
//! - **[`BlockRegistry`]**: Manages the registration of blocks that have transitioned from the
//! Complete to Registered state.
//! - **[`MutableBlock`]**: Represents a uniquely owned block, typically obtained from allocation.
//! It allows modification and is returned to the inactive pool upon being dropped.
//! - **[`ImmutableBlock`]**: Represents a shared, immutable reference to a block, usually after
//! it has been registered or matched. Ensures that multiple sequences can reference the
//! same underlying block data.
//!
//! ## Workflow:
//!
//! 1. Blocks are initially added to the [`BlockPool`] via [`BlockPool::add_blocks`], populating the
//! [`InactiveBlockPool`].
//! 2. Sequences request blocks via [`BlockPool::allocate_blocks`], which attempts to acquire them
//! from the [`InactiveBlockPool`]. This returns [`MutableBlock`]s.
//! 3. Once a [`MutableBlock`] is filled and ready, it's registered using [`BlockPool::register_block`].
//! This process checks both the [`ActiveBlockPool`] and the [`InactiveBlockPool`] for existing blocks
//! with the same content hash. It returns an [`ImmutableBlock`] representing the canonical block
//! (either the one provided or an existing one).
//! 4. Sequences can also try to reuse blocks directly using [`BlockPool::match_sequence_hash`], which
//! checks both the active and inactive pools.
//! 5. When an [`ImmutableBlock`] is no longer needed by any sequence, its final strong reference
//! (the `Arc` within the `ImmutableBlock`) is dropped, and the underlying block can eventually be
//! returned to the [`InactiveBlockPool`] via the weak reference held by the active pool.
//! 6. Dropped [`MutableBlock`]s are automatically returned to the [`InactiveBlockPool`].
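//!
//! A minimal sketch of this workflow (method names follow the description
//! above; `blocks` and `seq_hash` are assumed to be prepared elsewhere, and
//! exact signatures may differ):
//!
//! ```rust,ignore
//! // 1. Build a pool seeded with an initial set of blocks.
//! let pool = BlockPool::builder().blocks(blocks).build()?;
//!
//! // 2. Allocate a mutable block and fill it with tokens.
//! let mut block = pool.allocate_blocks(1).await?.remove(0);
//!
//! // 3. Register the completed block; the canonical ImmutableBlock is returned.
//! let immutable = pool.register_block(block).await?;
//!
//! // 4. Reuse blocks by sequence hash where possible.
//! let matched = pool.match_sequence_hash(seq_hash).await?;
//! ```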
mod active;
mod inactive;
mod priority_key;
mod state;
use active::ActiveBlockPool;
use derive_builder::Builder;
use derive_getters::Dissolve;
use inactive::InactiveBlockPool;
use priority_key::PriorityKey;
pub use super::block::{ImmutableBlock, MutableBlock};
use super::block::{
nixl::short_type_name, registry::BlockRegistry, Block, BlockError, BlockMetadata,
};
use super::events::{EventManager, NullEventManager};
use super::storage::Storage;
use crate::tokens::{SequenceHash, TokenBlock};
use std::{
collections::{BTreeSet, HashMap, VecDeque},
sync::{Arc, Weak},
};
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use dynamo_runtime::Result;
#[derive(Debug, thiserror::Error)]
pub enum BlockPoolError {
#[error("Block is not complete")]
BlockNotComplete,
#[error("Not enough blocks available, requested: {0}, available: {1}")]
NotEnoughBlocksAvailable(usize, usize),
#[error("Invalid MutableBlock: {0}")]
InvalidMutableBlock(String),
#[error("Failed to register block: {0}")]
FailedToRegisterBlock(String),
#[error("Progress engine shutdown")]
ProgressEngineShutdown,
#[error(transparent)]
BlockError(#[from] BlockError),
}
#[derive(Builder, Dissolve)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
pub struct BlockPoolArgs<S: Storage, M: BlockMetadata> {
#[builder(default = "NullEventManager::new()")]
event_manager: Arc<dyn EventManager>,
#[builder(default = "CancellationToken::new()")]
cancel_token: CancellationToken,
#[builder(default)]
blocks: Vec<Block<S, M>>,
}
impl<S: Storage, M: BlockMetadata> BlockPoolArgsBuilder<S, M> {
pub fn build(self) -> anyhow::Result<BlockPool<S, M>> {
let args = self.build_internal()?;
let (event_manager, cancel_token, blocks) = args.dissolve();
tracing::info!("building block pool");
let pool = BlockPool::new(event_manager, cancel_token, blocks);
Ok(pool)
}
}
/// Manages the blocks in a specific storage backend
pub struct BlockPool<S: Storage, M: BlockMetadata> {
priority_tx: tokio::sync::mpsc::UnboundedSender<PriorityRequest<S, M>>,
ctrl_tx: tokio::sync::mpsc::UnboundedSender<ControlRequest<S, M>>,
}
impl<S: Storage, M: BlockMetadata> Clone for BlockPool<S, M> {
fn clone(&self) -> Self {
Self {
priority_tx: self.priority_tx.clone(),
ctrl_tx: self.ctrl_tx.clone(),
}
}
}
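/// A single request/response exchange with the progress engine: the request
/// payload travels over a channel together with a oneshot sender on which the
/// engine returns the response.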
#[derive(Dissolve)]
struct Unary<Req, Resp> {
request: Req,
response_tx: oneshot::Sender<Resp>,
}
impl<Req, Resp> Unary<Req, Resp> {
fn make_request(request: Req) -> (Self, oneshot::Receiver<Resp>) {
let (response_tx, response_rx) = oneshot::channel();
(
Self {
request,
response_tx,
},
response_rx,
)
}
}
type UnaryResponse<T> = Result<oneshot::Receiver<T>, BlockPoolError>;
type ImmutableBlocksResult<S, M> = Result<Vec<ImmutableBlock<S, M>>, BlockPoolError>;
pub type MutableBlocks<S, M> = Vec<MutableBlock<S, M>>;
pub type ImmutableBlocks<S, M> = Vec<ImmutableBlock<S, M>>;
enum PriorityRequest<S: Storage, M: BlockMetadata> {
AllocateBlocks(Unary<usize, Result<Vec<MutableBlock<S, M>>, BlockPoolError>>),
RegisterBlocks(Unary<MutableBlocks<S, M>, Result<ImmutableBlocks<S, M>, BlockPoolError>>),
MatchSequenceHashes(Unary<Vec<SequenceHash>, Vec<ImmutableBlock<S, M>>>),
}
enum ControlRequest<S: Storage, M: BlockMetadata> {
AddBlocks(Unary<Vec<Block<S, M>>, ()>),
}
impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
pub fn builder() -> BlockPoolArgsBuilder<S, M> {
BlockPoolArgsBuilder::default()
}
/// Creates a new [`BlockPool`] with the given [`EventManager`].
///
/// The pool is seeded with the provided `blocks`; additional blocks can be
/// added later via [`BlockPool::add_blocks`].
///
/// # Arguments
///
/// * `event_manager` - An [`Arc<dyn EventManager>`] used for publishing block registration/removal events.
/// * `cancel_token` - A [`CancellationToken`] used to stop the background progress engine.
/// * `blocks` - The initial set of [`Block`]s to place in the inactive pool.
///
/// # Returns
///
/// A new [`BlockPool`] instance.
fn new(
event_manager: Arc<dyn EventManager>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, M>>,
) -> Self {
let (pool, progress_engine) =
Self::with_progress_engine(event_manager, cancel_token, blocks);
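// Run the progress engine on a dedicated OS thread with its own
// single-threaded Tokio runtime, so the pool's bookkeeping runs
// independently of the caller's async runtime.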
let thread_name = format!("block-pool-{}", short_type_name::<S>());
std::thread::Builder::new()
.name(thread_name)
.spawn(move || {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to build Tokio runtime for block pool progress engine");
runtime.block_on(async move {
let mut progress_engine = progress_engine;
tracing::debug!("starting progress engine");
while progress_engine.step().await {
tracing::trace!("progress engine step");
}
});
})
.expect("Failed to spawn block pool progress engine thread");
pool
}
fn with_progress_engine(
event_manager: Arc<dyn EventManager>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, M>>,
) -> (Self, ProgressEngine<S, M>) {
let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel();
let (ctrl_tx, ctrl_rx) = tokio::sync::mpsc::unbounded_channel();
let progress_engine =
ProgressEngine::<S, M>::new(event_manager, priority_rx, ctrl_rx, cancel_token, blocks);
(
Self {
priority_tx,
ctrl_tx,
},
progress_engine,
)
}
/// Adds a vector of [`Block`]s to the [`InactiveBlockPool`].
///
/// These blocks are typically created from a [`super::block::Blocks`]
/// and represent the initial set of available cache blocks.
/// Blocks added this way are initially reset.
///
/// # Arguments
///
/// * `blocks` - A [`Vec<Block<S, M>>`] to add to the inactive pool.
#[expect(dead_code)]
pub(crate) async fn add_blocks(&self, blocks: Vec<Block<S, M>>) -> Result<(), BlockPoolError> {
self._add_blocks(blocks)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
/// Blocking version of [`BlockPool::add_blocks`].
pub(crate) fn add_blocks_blocking(
&self,
blocks: Vec<Block<S, M>>,
) -> Result<(), BlockPoolError> {
self._add_blocks(blocks)?
.recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
fn _add_blocks(&self, blocks: Vec<Block<S, M>>) -> UnaryResponse<()> {
let (req, resp_rx) = Unary::<_, ()>::make_request(blocks);
self.ctrl_tx
.send(ControlRequest::AddBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
Ok(resp_rx)
}
/// Attempts to allocate a specified number of free blocks from the [`InactiveBlockPool`].
///
/// Blocks acquired this way are returned as [`MutableBlock`]s, granting unique ownership
/// and allowing modification. Dropping a [`MutableBlock`] automatically returns it
/// to the [`InactiveBlockPool`].
///
/// # Arguments
///
/// * `count` - The number of blocks to allocate.
///
/// # Returns
///
/// A [`Result`] containing:
/// - `Ok(Vec<MutableBlock<S, M>>)`: If successful, a vector of allocated mutable blocks.
/// - `Err(BlockPoolError)`: If not enough blocks are available in the inactive pool.
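///
/// # Example
///
/// A minimal sketch (error handling elided):
///
/// ```rust,ignore
/// let blocks = pool.allocate_blocks(2).await?;
/// assert_eq!(blocks.len(), 2);
/// // Dropping the blocks returns them to the inactive pool.
/// drop(blocks);
/// ```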
pub async fn allocate_blocks(
&self,
count: usize,
) -> Result<Vec<MutableBlock<S, M>>, BlockPoolError> {
self._allocate_blocks(count)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
/// Blocking version of [`BlockPool::allocate_blocks`].
pub fn allocate_blocks_blocking(
&self,
count: usize,
) -> Result<Vec<MutableBlock<S, M>>, BlockPoolError> {
self._allocate_blocks(count)?
.recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
fn _allocate_blocks(
&self,
count: usize,
) -> UnaryResponse<Result<Vec<MutableBlock<S, M>>, BlockPoolError>> {
// Create the request
let (req, resp_rx) =
Unary::<_, Result<Vec<MutableBlock<S, M>>, BlockPoolError>>::make_request(count);
// Issue the request
self.priority_tx
.send(PriorityRequest::AllocateBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
// Await a response
Ok(resp_rx)
}
/// Registers a vector of [`MutableBlock`]s (presumably after filling them) with the pool,
/// making them available for sharing via the [`ActiveBlockPool`].
///
/// For each block, if another block with the same sequence hash already exists in the
/// active pool, an [`ImmutableBlock`] pointing to the existing block is returned instead,
/// and the provided block is implicitly dropped (returned to the [`InactiveBlockPool`]).
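///
/// # Example
///
/// A minimal sketch, assuming `block` is a filled and committed [`MutableBlock`]:
///
/// ```rust,ignore
/// let immutable = pool.register_blocks(vec![block]).await?;
/// assert!(immutable[0].state().is_registered());
/// ```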
pub async fn register_blocks(
&self,
blocks: Vec<MutableBlock<S, M>>,
) -> ImmutableBlocksResult<S, M> {
self._register_blocks(blocks)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
/// Blocking version of [`BlockPool::register_blocks`].
pub fn register_blocks_blocking(
&self,
blocks: Vec<MutableBlock<S, M>>,
) -> ImmutableBlocksResult<S, M> {
self._register_blocks(blocks)?
.recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
}
fn _register_blocks(
&self,
blocks: Vec<MutableBlock<S, M>>,
) -> UnaryResponse<ImmutableBlocksResult<S, M>> {
// Make the request
let (req, resp_rx) = Unary::<_, ImmutableBlocksResult<S, M>>::make_request(blocks);
// Issue the request
self.priority_tx
.send(PriorityRequest::RegisterBlocks(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
// Await a response
Ok(resp_rx)
}
/// Attempts to match the given [`SequenceHash`]es to existing blocks, checking
/// both the active and inactive pools.
///
/// For each hash, the [`ActiveBlockPool`] is checked first. If a valid strong reference
/// exists, an [`ImmutableBlock`] cloned from it is returned. If the weak reference exists
/// but is stale, it's removed.
///
/// If not found in the active pool, the [`InactiveBlockPool`] is checked. If found there,
/// the block is moved to the active pool (tracked by a weak reference) and returned
/// as a new [`ImmutableBlock`].
///
/// Matching stops at the first hash that is not found.
///
/// # Arguments
///
/// * `sequence_hashes` - The [`SequenceHash`]es to look for.
///
/// # Returns
///
/// A [`Vec<ImmutableBlock<S, M>>`] of the matched blocks, in order; it may be shorter
/// than `sequence_hashes` if not every hash was found.
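///
/// # Example
///
/// A minimal sketch, assuming `hash` came from a previously registered block:
///
/// ```rust,ignore
/// let matched = pool.match_sequence_hashes(&[hash]).await?;
/// if matched.is_empty() {
///     // Cache miss: allocate and fill new blocks instead.
/// }
/// ```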
pub async fn match_sequence_hashes(
&self,
sequence_hashes: &[SequenceHash],
) -> ImmutableBlocksResult<S, M> {
self._match_sequence_hashes(sequence_hashes)?
.await
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
/// Blocking version of [`BlockPool::match_sequence_hashes`].
pub fn match_sequence_hashes_blocking(
&self,
sequence_hashes: &[SequenceHash],
) -> ImmutableBlocksResult<S, M> {
self._match_sequence_hashes(sequence_hashes)?
.recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)
}
fn _match_sequence_hashes(
&self,
sequence_hashes: &[SequenceHash],
) -> UnaryResponse<Vec<ImmutableBlock<S, M>>> {
// Create the request
let (req, resp_rx) =
Unary::<_, Vec<ImmutableBlock<S, M>>>::make_request(sequence_hashes.into());
// Issue the request
self.priority_tx
.send(PriorityRequest::MatchSequenceHashes(req))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
// Await a response
Ok(resp_rx)
}
}
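/// Mutable state owned by the progress engine: the active and inactive pools,
/// the block registry, the block-return channel, and the event manager.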
struct State<S: Storage, M: BlockMetadata> {
active: ActiveBlockPool<S, M>,
inactive: InactiveBlockPool<S, M>,
registry: BlockRegistry,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, M>>,
event_manager: Arc<dyn EventManager>,
}
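/// Single-threaded event loop that owns the pool [`State`] and serializes all
/// operations on it: priority requests (allocate/register/match), control
/// requests (add blocks), returned blocks, and cancellation.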
struct ProgressEngine<S: Storage, M: BlockMetadata> {
priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, M>>,
ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, M>>,
cancel_token: CancellationToken,
state: State<S, M>,
return_rx: tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>,
}
#[cfg(test)]
mod tests {
use crate::block_manager::block::BlockExt;
use super::super::block::{BasicMetadata, Blocks};
use super::super::layout::tests::setup_layout;
use super::*;
/// Helper method to build a [`BlockPool`] with a [`ProgressEngine`] for unit testing
impl<S: Storage, M: BlockMetadata> BlockPoolArgsBuilder<S, M> {
fn build_with_progress_engine(
self,
) -> anyhow::Result<(BlockPool<S, M>, ProgressEngine<S, M>)> {
let args = self.build_internal()?;
let (event_manager, cancel_token, blocks) = args.dissolve();
let (pool, progress_engine) =
BlockPool::with_progress_engine(event_manager, cancel_token, blocks);
Ok((pool, progress_engine))
}
}
#[tokio::test]
async fn test_block_pool_state() {
let layout = setup_layout(None).unwrap();
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
let (_pool, mut progress) = BlockPool::builder()
.blocks(blocks)
.build_with_progress_engine()
.unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 7);
let blocks = progress.state.allocate_blocks(1).unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 6);
assert_eq!(blocks.len(), 1);
drop(blocks);
progress.step().await;
assert_eq!(progress.state.inactive.available_blocks(), 7);
let mut blocks = progress.state.allocate_blocks(1).unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 6);
assert_eq!(blocks.len(), 1);
let mut block = blocks.pop().unwrap();
block.init_sequence(1337).unwrap();
block.add_token(1).unwrap();
block.add_token(2).unwrap();
block.add_token(3).unwrap();
block.add_token(4).unwrap();
assert!(block.add_token(5).is_err());
}
#[tokio::test]
async fn test_block_pool() {
let layout = setup_layout(None).unwrap();
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
let (pool, mut progress) = BlockPool::builder()
.blocks(blocks)
.build_with_progress_engine()
.unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 7);
let pool_clone = pool.clone();
let allocate_1_block =
tokio::spawn(async move { pool_clone.allocate_blocks(1).await.unwrap() });
progress.step().await;
let blocks = allocate_1_block.await.unwrap();
assert_eq!(progress.state.inactive.available_blocks(), 6);
assert_eq!(blocks.len(), 1);
// drop the single block
drop(blocks);
// check before and after the progress engine step
assert_eq!(progress.state.inactive.available_blocks(), 6);
progress.step().await;
assert_eq!(progress.state.inactive.available_blocks(), 7);
}
#[test]
fn test_block_pool_blocking() {
const EXPECTED_SEQUENCE_HASH: u64 = 14643705804678351452;
// Create a new layout
let layout = setup_layout(None).unwrap();
// Create the Blocks
let blocks = Blocks::<_, BasicMetadata>::new(layout, 42, 0)
.unwrap()
.into_blocks()
.unwrap();
// Create the BlockPool and add the blocks
let pool = BlockPool::builder().blocks(blocks).build().unwrap();
// All blocks should be in the Reset/Empty state
// No blocks should match the expected sequence hash
let matched_blocks = pool
.match_sequence_hashes_blocking(&[EXPECTED_SEQUENCE_HASH])
.unwrap();
assert_eq!(matched_blocks.len(), 0);
// Allocate a single block from the pool
let mut mutable_blocks = pool.allocate_blocks_blocking(1).unwrap();
assert_eq!(mutable_blocks.len(), 1);
let mut block = mutable_blocks.pop().unwrap();
// Initialize the sequence on the block with a salt hash
block.init_sequence(1337).unwrap();
// Add some tokens to the block - our page_size is 4
block.add_token(1).unwrap();
block.add_token(2).unwrap();
block.add_token(3).unwrap();
block.add_token(4).unwrap();
// Should fail because we don't have space in the block
assert!(block.add_token(5).is_err());
// Commit the block - this will generate a sequence hash
// This will put the block in a Complete state
block.commit().unwrap();
assert!(block.state().is_complete()); // perhaps rename to Committed
let sequence_hash = block.sequence_hash().unwrap();
assert_eq!(sequence_hash, EXPECTED_SEQUENCE_HASH);
// Register the block
// We provide a mutable block to the register_blocks function
// This will take ownership of the block and return an immutable block
let mut immutable_blocks = pool.register_blocks_blocking(vec![block]).unwrap();
let block = immutable_blocks.pop().unwrap();
assert!(block.state().is_registered());
assert_eq!(block.sequence_hash().unwrap(), sequence_hash);
// Dropping the immutable block should return the block to the pool
// However, the block should remain in the BlockPool as an inactive block until it is reused
// or promoted back to an immutable block by being matched with a sequence hash
drop(block);
// Get the list of ImmutableBlocks that match the sequence hash
let matched = pool
.match_sequence_hashes_blocking(&[sequence_hash])
.unwrap();
assert_eq!(matched.len(), 1);
assert_eq!(matched[0].sequence_hash().unwrap(), sequence_hash);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
/// Manages active blocks being used by sequences
pub struct ActiveBlockPool<S: Storage, M: BlockMetadata> {
pub(super) map: HashMap<SequenceHash, Weak<MutableBlock<S, M>>>,
}
impl<S: Storage, M: BlockMetadata> ActiveBlockPool<S, M> {
pub fn new() -> Self {
Self {
map: HashMap::new(),
}
}
pub fn register(
&mut self,
block: MutableBlock<S, M>,
) -> Result<ImmutableBlock<S, M>, BlockPoolError> {
if !block.state().is_registered() {
return Err(BlockPoolError::InvalidMutableBlock(
"block is not registered".to_string(),
));
}
let sequence_hash = block.sequence_hash().map_err(|_| {
BlockPoolError::InvalidMutableBlock("block has no sequence hash".to_string())
})?;
let shared = Arc::new(block);
match self.map.entry(sequence_hash) {
std::collections::hash_map::Entry::Occupied(mut entry) => {
let weak = entry.get();
if let Some(arc) = weak.upgrade() {
Ok(ImmutableBlock::new(arc))
} else {
// Weak reference is no longer alive, update it in the map
entry.insert(Arc::downgrade(&shared));
Ok(ImmutableBlock::new(shared))
}
}
std::collections::hash_map::Entry::Vacant(entry) => {
entry.insert(Arc::downgrade(&shared));
Ok(ImmutableBlock::new(shared))
}
}
}
pub fn remove(&mut self, block: &mut Block<S, M>) {
if let Ok(sequence_hash) = block.sequence_hash() {
if let Some(weak) = self.map.get(&sequence_hash) {
if let Some(_arc) = weak.upgrade() {
block.reset();
return;
}
self.map.remove(&sequence_hash);
}
}
}
pub fn match_sequence_hash(
&mut self,
sequence_hash: SequenceHash,
) -> Option<ImmutableBlock<S, M>> {
if let Some(weak) = self.map.get(&sequence_hash) {
if let Some(arc) = weak.upgrade() {
Some(ImmutableBlock::new(arc))
} else {
// Weak reference is no longer alive, remove it from the map
self.map.remove(&sequence_hash);
None
}
} else {
None
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::block_manager::block::BlockState;
use super::*;
use tracing::instrument;
#[derive(Default)]
pub struct InactiveBlockPool<S: Storage, M: BlockMetadata> {
// Direct lookup by sequence_hash
lookup_map: HashMap<SequenceHash, Block<S, M>>,
// Ordered by timestamp (oldest first)
priority_set: BTreeSet<PriorityKey<M>>,
// Fully Uninitialized
uninitialized_set: VecDeque<Block<S, M>>,
// Return Tick
return_tick: u64,
// Total blocks
total_blocks: u64,
}
impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
/// Creates a new, empty [`InactiveBlockPool`].
///
/// # Returns
///
/// A new instance of [`InactiveBlockPool`].
pub(crate) fn new() -> Self {
Self {
lookup_map: HashMap::new(),
priority_set: BTreeSet::new(),
uninitialized_set: VecDeque::new(),
return_tick: 0,
total_blocks: 0,
}
}
/// Returns the total number of blocks managed by this pool (both available and acquired).
///
/// # Returns
///
/// The total block count as a [`u64`].
pub fn total_blocks(&self) -> u64 {
self.total_blocks
}
/// Returns the number of blocks currently available in the pool.
///
/// This is calculated dynamically based on the blocks in the [`uninitialized_set`]
/// and the [`lookup_map`].
///
/// # Returns
///
/// The available block count as a [`u64`].
pub fn available_blocks(&self) -> u64 {
self.uninitialized_set.len() as u64 + self.lookup_map.len() as u64
}
/// Inserts a block into the pool using its sequence hash for potential reuse.
///
/// If an entry with the same priority key already exists in the [`priority_set`],
/// the block is reset and moved to the [`uninitialized_set`].
/// If an entry with the same sequence hash already exists in the [`lookup_map`]
/// (but not the priority set - indicating an inconsistency), the block is reset
/// and moved to the [`uninitialized_set`].
/// Otherwise, the block is added to both the [`lookup_map`] and the [`priority_set`].
///
/// # Arguments
///
/// * `block` - The block to insert ([`Block<S, M>`]).
/// * `sequence_hash` - The sequence hash associated with the block's content ([`SequenceHash`]).
#[instrument(level = "trace", skip(self, block), fields(sequence_hash = ?sequence_hash))]
fn insert_with_sequence_hash(&mut self, block: Block<S, M>, sequence_hash: SequenceHash) {
let priority_key = PriorityKey::new(block.metadata().clone(), sequence_hash);
if self.priority_set.contains(&priority_key) {
tracing::trace!("multiple entries with the same priority key, resetting block and inserting into uninitialized set");
let mut block = block;
block.reset();
self.uninitialized_set.push_back(block);
} else if let std::collections::hash_map::Entry::Vacant(e) =
self.lookup_map.entry(sequence_hash)
{
tracing::trace!("inserting block to map and priority set");
self.priority_set.insert(priority_key);
e.insert(block);
} else {
tracing::trace!("multiple entries in lookup map with the same sequence hash, inserting into uninitialized set");
let mut block = block;
block.reset();
self.uninitialized_set.push_back(block);
}
}
/// Internal helper to insert a block into the appropriate internal collection
/// based on its current state.
///
/// - [`BlockState::Reset`], [`BlockState::Partial`], [`BlockState::Complete`] states result in the block being reset and added
/// to the `uninitialized_set`.
/// - [`BlockState::Registered`] state results in the block being added via [`insert_with_sequence_hash`].
///
/// # Arguments
///
/// * `block` - The block to insert ([`Block<S, M>`]).
#[instrument(level = "trace", skip(self, block), fields(block_state = ?block.state()))]
fn insert(&mut self, block: Block<S, M>) {
tracing::trace!("Inserting block into available pool");
// If we already have an entry for this sequence hash or the block is reset,
// we need to move it to the uninitialized set
match block.state() {
BlockState::Reset => {
self.uninitialized_set.push_back(block);
}
BlockState::Partial(_) => {
let mut block = block;
block.reset();
self.uninitialized_set.push_back(block);
}
BlockState::Complete(_) => {
let mut block = block;
block.reset();
self.uninitialized_set.push_back(block);
}
BlockState::Registered(state) => {
let sequence_hash = state.sequence_hash();
self.insert_with_sequence_hash(block, sequence_hash);
}
}
}
/// Adds multiple blocks to the pool.
///
/// Each block is reset before being inserted. The total block count is updated.
///
/// # Arguments
///
/// * `blocks` - A vector of blocks ([`Block<S, M>`]) to add.
#[instrument(level = "debug", skip(self, blocks))]
pub fn add_blocks(&mut self, blocks: Vec<Block<S, M>>) {
let count = blocks.len();
tracing::debug!(count, "Adding blocks to pool");
for (i, mut block) in blocks.into_iter().enumerate() {
tracing::trace!(current = i + 1, total = count, "Processing block");
block.reset();
self.insert(block);
}
self.total_blocks += count as u64;
}
/// Adds multiple blocks to the pool.
///
/// The state of the blocks is not reset.
///
/// # Arguments
///
/// * `blocks` - A vector of blocks ([`Block<S, M>`]) to add.
#[instrument(level = "debug", skip(self, blocks))]
pub fn add_blocks_with_state(&mut self, blocks: Vec<Block<S, M>>) {
let count = blocks.len();
tracing::debug!(count, "Adding blocks to pool");
self.total_blocks += count as u64;
self.return_blocks(blocks);
}
/// Returns a single block to the pool.
///
/// Increments the internal return tick, updates the block's metadata,
/// and inserts the block back into the appropriate internal collection.
///
/// # Arguments
///
/// * `block` - The block ([`Block<S, M>`]) to return.
#[instrument(level = "debug", skip(self, block))]
pub fn return_block(&mut self, mut block: Block<S, M>) {
// increment the return tick
self.return_tick += 1;
// update the metadata
block.metadata_on_returned(self.return_tick);
// insert the block into the pool
self.insert(block);
}
/// Returns multiple blocks to the pool.
///
/// Iterates through the blocks in reverse order (tail to head) and calls
/// `return_block` for each one.
///
/// # Arguments
///
/// * `blocks` - A vector of blocks ([`Block<S, M>`]) to return.
#[instrument(level = "debug", skip(self, blocks))]
pub fn return_blocks(&mut self, blocks: Vec<Block<S, M>>) {
let count = blocks.len();
tracing::debug!(count, "Returning blocks to pool");
// return the block to the pool from tail to head
for (i, block) in blocks.into_iter().rev().enumerate() {
tracing::trace!(current = i + 1, total = count, "Returning block");
// Note: return_block has its own instrumentation
self.return_block(block);
}
}
/// Attempts to remove and return a block associated with the given sequence hash
/// from the [`lookup_map`] and [`priority_set`].
///
/// # Arguments
///
/// * `sequence_hash` - The sequence hash ([`SequenceHash`]) of the block to take.
///
/// # Returns
///
/// An [`Option<Block<S, M>>`] containing the block if found, otherwise `None`.
#[instrument(level = "trace", skip(self), fields(sequence_hash = ?sequence_hash))]
fn take_with_sequence_hash(&mut self, sequence_hash: SequenceHash) -> Option<Block<S, M>> {
match self.lookup_map.remove(&sequence_hash) {
Some(block) => {
// Remove from priority set
let priority_key = PriorityKey::new(block.metadata().clone(), sequence_hash);
self.priority_set.remove(&priority_key);
Some(block)
}
None => None,
}
}
/// Attempts to find and take a block matching the given sequence hash.
///
/// This is a convenience wrapper around `take_with_sequence_hash`.
///
/// # Arguments
///
/// * `sequence_hash` - The sequence hash ([`SequenceHash`]) to match.
///
/// # Returns
///
/// An [`Option<Block<S, M>>`] containing the block if found, otherwise `None`.
#[instrument(level = "debug", skip(self), fields(sequence_hash = ?sequence_hash))]
pub fn match_sequence_hash(&mut self, sequence_hash: SequenceHash) -> Option<Block<S, M>> {
self.take_with_sequence_hash(sequence_hash)
}
/// Attempts to find and take multiple blocks matching a sequence of hashes.
///
/// Iterates through the provided hashes and takes blocks using `take_with_sequence_hash`.
/// Stops if a hash is not found.
///
/// # Arguments
///
/// * `sequence_hashes` - A vector of sequence hashes ([`SequenceHash`]) to match.
///
/// # Returns
///
/// A vector containing the blocks ([`Block<S, M>`]) that were successfully matched and taken.
/// The vector may be shorter than `sequence_hashes` if not all hashes were found.
#[instrument(level = "debug", skip(self, sequence_hashes), fields(num_hashes = sequence_hashes.len()))]
pub fn match_sequence_hashes(
&mut self,
sequence_hashes: Vec<SequenceHash>,
) -> Vec<Block<S, M>> {
let total_hashes = sequence_hashes.len();
let mut matched_blocks = Vec::with_capacity(total_hashes);
for (i, hash) in sequence_hashes.into_iter().enumerate() {
tracing::trace!(current = i + 1, total = total_hashes, sequence_hash = ?hash, "Attempting to match sequence hash");
// Note: take_with_sequence_hash has its own instrumentation
if let Some(block) = self.take_with_sequence_hash(hash) {
tracing::trace!(current = i + 1, total = total_hashes, sequence_hash = ?hash, "Matched sequence hash");
matched_blocks.push(block);
} else {
tracing::trace!(current = i + 1, total = total_hashes, sequence_hash = ?hash, "Sequence hash not found, stopping match");
break;
}
}
matched_blocks
}
/// Attempts to find and take multiple blocks matching a sequence of `TokenBlock`s.
///
/// Extracts sequence hashes from the [`TokenBlock`]s and calls [`take_with_sequence_hash`].
/// Stops if a hash is not found.
///
/// # Arguments
///
/// * `token_blocks` - A slice of [`TokenBlock`]s to match.
///
/// # Returns
///
/// A vector containing the blocks ([`Block<S, M>`]) that were successfully matched and taken.
/// The vector may be shorter than `token_blocks` if not all corresponding hashes were found.
#[instrument(level = "debug", skip(self, token_blocks), fields(num_token_blocks = token_blocks.len()))]
pub fn match_token_blocks(&mut self, token_blocks: &[TokenBlock]) -> Vec<Block<S, M>> {
let total_blocks = token_blocks.len();
let mut matched_blocks = Vec::with_capacity(total_blocks);
tracing::debug!("Attempting to match {} token blocks", total_blocks);
for (i, token_block) in token_blocks.iter().enumerate() {
let sequence_hash = token_block.sequence_hash();
tracing::trace!(sequence_hash = ?sequence_hash, "Attempting to match token block hash {}/{}", i + 1, total_blocks);
if let Some(block) = self.take_with_sequence_hash(sequence_hash) {
tracing::trace!(sequence_hash = ?sequence_hash, "Matched token block hash");
matched_blocks.push(block);
} else {
tracing::trace!(sequence_hash = ?sequence_hash, "Token block hash not found, stopping match");
break;
}
}
tracing::debug!(
"Matched {} of {} token blocks",
matched_blocks.len(),
total_blocks
);
matched_blocks
}
/// Acquires a single free block from the pool.
///
/// Prioritizes blocks from the [`uninitialized_set`] first, then takes the
/// lowest priority block from the [`priority_set`] (and [`lookup_map`]).
/// If a block is taken from the priority set, it is reset.
///
/// # Returns
///
/// An [`Option<Block<S, M>>`] containing a free block if available, otherwise `None`.
///
/// # Panics
///
/// This function can panic if there is an inconsistency between the [`priority_set`]
/// and [`lookup_map`] (i.e., a key exists in the set but not the map). This indicates
/// a bug in the pool's internal logic.
#[instrument(level = "debug", skip(self))]
pub fn acquire_free_block(&mut self) -> Option<Block<S, M>> {
// First try uninitialized blocks - these are often part of sequences
// that have been arranged in the correct order
if let Some(mut block) = self.uninitialized_set.pop_front() {
tracing::trace!("Acquired uninitialized block");
self.return_tick += 1;
block.metadata_on_acquired(self.return_tick);
return Some(block);
}
// if we have blocks in the priority set, pop the first (it's sorted by priority)
// a fatal error will occur if the block is not found in the lookup map
if let Some(key) = self.priority_set.pop_first() {
tracing::trace!("Acquired priority/registered block map; resetting block");
match self.lookup_map.remove(&key.sequence_hash()) {
Some(mut block) => {
block.reset();
self.return_tick += 1;
block.metadata_on_acquired(self.return_tick);
Some(block)
}
None => {
panic!(
"Block from priority set not found in lookup map! Inconsistency detected."
);
}
}
} else {
// No blocks available in either set
None
}
}
/// Acquires a specified number of free blocks from the pool.
///
/// Checks if enough blocks are available and then calls [`acquire_free_block`] repeatedly.
///
/// # Arguments
///
/// * `count` - The number of free blocks to acquire.
///
/// # Returns
///
/// A [`Result`] containing:
/// - `Ok(Vec<Block<S, M>>)`: A vector of the acquired blocks if successful.
/// - `Err(BlockPoolError::NotEnoughBlocksAvailable)`: If the requested number
///   of blocks is not available, or if an inconsistency occurred during acquisition.
///
/// # Panics
///
/// This function can panic if [`acquire_free_block`] panics due to internal inconsistencies.
#[instrument(level = "debug", skip(self))]
pub fn acquire_free_blocks(
&mut self,
count: usize,
) -> Result<Vec<Block<S, M>>, BlockPoolError> {
if count == 0 {
return Ok(Vec::new());
}
let mut blocks = Vec::with_capacity(count);
let available_now = self.uninitialized_set.len() + self.lookup_map.len();
tracing::debug!(
available_now,
requested = count,
"Attempting to acquire free blocks"
);
if count > available_now {
tracing::debug!(
available_now,
requested = count,
"Insufficient blocks available"
);
return Err(BlockPoolError::NotEnoughBlocksAvailable(
count,
available_now,
));
}
for i in 0..count {
tracing::trace!(current = i + 1, total = count, "Acquiring free block");
// Directly call the logic in acquire_free_block
// Note: acquire_free_block has its own instrumentation
if let Some(block) = self.acquire_free_block() {
blocks.push(block);
} else {
// This should not happen if the initial availability check passed, since we
// hold `&mut self` throughout the loop. If it does, it indicates an internal
// inconsistency; break and report the shortfall via the length check below.
tracing::error!(
    requested = count,
    acquired = blocks.len(),
    available_at_start = available_now,
    current_available = self.uninitialized_set.len() + self.lookup_map.len(),
    "Insufficient blocks during acquisition loop despite initial check."
);
break;
}
}
let acquired_count = blocks.len();
tracing::debug!(
acquired_count,
requested = count,
"Finished acquiring blocks"
);
// If the loop broke early due to an unexpected `None` from acquire_free_block,
// report the shortfall as an error rather than returning a partial allocation.
if acquired_count != count {
return Err(BlockPoolError::NotEnoughBlocksAvailable(
count,
blocks.len(),
));
}
Ok(blocks)
}
}
#[cfg(test)]
pub(crate) mod tests {
use crate::{
block_manager::{
block::{registry::BlockRegistry, state::CompleteState, Blocks, PrivateBlockExt},
events::NullEventManager,
layout::{BlockLayout, FullyContiguous, LayoutConfigBuilder},
storage::tests::{NullDeviceAllocator, NullDeviceStorage},
},
tokens::{Token, Tokens},
};
use super::*;
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)]
pub struct TestMetadata {
priority: u32,
returned_tick: u64,
acquired_tick: u64,
}
impl BlockMetadata for TestMetadata {
fn on_acquired(&mut self, tick: u64) {
self.acquired_tick = tick;
}
fn on_returned(&mut self, tick: u64) {
self.returned_tick = tick;
}
fn reset_metadata(&mut self) {
self.priority = 0;
}
}
type TestPriorityKey = PriorityKey<TestMetadata>;
fn make_priority_key(
priority: u32,
returned_tick: u64,
sequence_hash: SequenceHash,
) -> TestPriorityKey {
TestPriorityKey::new(
TestMetadata {
priority,
returned_tick,
acquired_tick: 0,
},
sequence_hash,
)
}
#[test]
fn test_priority_key_ord() {
let mut map = BTreeSet::new();
let hash1 = SequenceHash::from(1u64);
let hash2 = SequenceHash::from(2u64);
let hash3 = SequenceHash::from(3u64);
map.insert(make_priority_key(0, 2, hash1));
map.insert(make_priority_key(1, 1, hash2));
map.insert(make_priority_key(0, 3, hash3));
// Test popping from the map to verify ordering
let first_key = map.pop_first().unwrap();
assert_eq!(first_key.metadata().priority, 0);
assert_eq!(first_key.metadata().returned_tick, 2);
assert_eq!(first_key.sequence_hash(), hash1);
let second_key = map.pop_first().unwrap();
assert_eq!(second_key.metadata().priority, 0);
assert_eq!(second_key.metadata().returned_tick, 3);
assert_eq!(second_key.sequence_hash(), hash3);
let third_key = map.pop_first().unwrap();
assert_eq!(third_key.metadata().priority, 1);
assert_eq!(third_key.metadata().returned_tick, 1);
assert_eq!(third_key.sequence_hash(), hash2);
// Map should now be empty
assert!(map.is_empty());
}
// Helper function to create a sequence of tokens
pub fn create_token_sequence(values: &[u32]) -> Tokens {
let tokens: Vec<Token> = values.iter().map(|&v| Token::from(v)).collect();
Tokens::from(tokens)
}
/// Creates a block collection with the given number of blocks.
pub fn create_block_collection(
num_blocks: usize,
) -> Blocks<impl BlockLayout<StorageType = NullDeviceStorage>, TestMetadata> {
let config = LayoutConfigBuilder::default()
.num_blocks(num_blocks)
.num_layers(61)
.page_size(16)
.inner_dim(576)
.build()
.unwrap();
let layout = FullyContiguous::allocate(config, &NullDeviceAllocator)
.expect("Failed to allocate layout/storage");
Blocks::<_, TestMetadata>::new(layout, 42, 0).unwrap()
}
/// Creates a vector of Blocks from a token sequence and block size.
/// Each block is initialized to the Complete state and then Registered.
pub fn create_blocks(
tokens: Tokens,
block_size: usize,
) -> Vec<Block<NullDeviceStorage, TestMetadata>> {
let (token_blocks, _partial_token_block) =
tokens.into_sequence(block_size, None).into_parts();
let num_blocks = token_blocks.len();
if num_blocks == 0 {
return Vec::new();
}
let mut blocks = create_block_collection(num_blocks).into_blocks().unwrap();
let event_manager = NullEventManager::new();
let mut registry = BlockRegistry::new(event_manager);
// Iterate through the generated TokenBlocks and the template Blocks,
// setting the state and registering each one.
for (block, token_block) in blocks.iter_mut().zip(token_blocks.into_iter()) {
assert!(block.state().is_reset()); // Start with empty blocks
block.update_state(BlockState::Complete(CompleteState::new(token_block)));
block
.register(&mut registry)
.expect("Failed to register block in test helper");
assert!(block.state().is_registered()); // Ensure registration worked
}
blocks
}
pub fn create_block_pool(
num_blocks: usize,
) -> InactiveBlockPool<NullDeviceStorage, TestMetadata> {
let mut pool = InactiveBlockPool::new();
let blocks = create_block_collection(num_blocks).into_blocks().unwrap();
pool.add_blocks(blocks);
pool
}
pub fn acquire_blocks(
tokens: Tokens,
block_size: usize,
pool: &mut InactiveBlockPool<NullDeviceStorage, TestMetadata>,
) -> (Vec<Block<NullDeviceStorage, TestMetadata>>, usize) {
let (mut token_blocks, _partial_token_block) =
tokens.into_sequence(block_size, None).into_parts();
let total_complete_blocks = token_blocks.len();
// this will match the token_blocks to any matching blocks in the inactive pool
// these blocks have the same sequence hash as the token_blocks, thus no updates are needed
let mut matched_blocks = pool.match_token_blocks(&token_blocks);
let matched_block_count = matched_blocks.len();
let event_manager = NullEventManager::new();
let mut registry = BlockRegistry::new(event_manager);
// all matched blocks should be in the registered state
for block in &mut matched_blocks {
assert!(block.state().is_registered());
}
// drain the matched blocks from the token_blocks
token_blocks.drain(0..matched_block_count);
assert_eq!(
token_blocks.len() + matched_blocks.len(),
total_complete_blocks
);
// try to acquire the remaining blocks
let mut unmatched_blocks = pool.acquire_free_blocks(token_blocks.len()).unwrap();
assert_eq!(unmatched_blocks.len(), token_blocks.len());
for unmatched in &unmatched_blocks {
assert!(unmatched.state().is_reset());
}
for (unmatched, token_block) in unmatched_blocks.iter_mut().zip(token_blocks.into_iter()) {
assert!(unmatched.state().is_reset());
unmatched.update_state(BlockState::Complete(CompleteState::new(token_block)));
unmatched.register(&mut registry).unwrap();
assert!(unmatched.state().is_registered());
}
let mut blocks = matched_blocks;
blocks.extend(unmatched_blocks);
(blocks, matched_block_count)
}
#[test]
fn test_block_pool_lifecycle() {
dynamo_runtime::logging::init();
const PAGE_SIZE: usize = 2;
let mut pool = create_block_pool(10);
assert_eq!(pool.total_blocks(), 10);
assert_eq!(pool.available_blocks(), 10);
let blocks = pool.acquire_free_blocks(10).unwrap();
assert_eq!(blocks.len(), 10);
assert_eq!(pool.total_blocks(), 10);
assert_eq!(pool.available_blocks(), 0);
pool.return_blocks(blocks);
assert_eq!(pool.total_blocks(), 10);
assert_eq!(pool.available_blocks(), 10);
let tokens = create_token_sequence(&[1, 2, 3, 4]);
let (blocks, matched_block_count) = acquire_blocks(tokens.clone(), PAGE_SIZE, &mut pool);
assert_eq!(blocks.len(), 2);
assert_eq!(matched_block_count, 0);
assert_eq!(pool.available_blocks(), 8);
pool.return_blocks(blocks);
assert_eq!(pool.total_blocks(), 10);
assert_eq!(pool.available_blocks(), 10);
let (blocks, matched_block_count) = acquire_blocks(tokens.clone(), PAGE_SIZE, &mut pool);
assert_eq!(blocks.len(), 2);
assert_eq!(matched_block_count, 2);
assert_eq!(pool.available_blocks(), 8);
pool.return_blocks(blocks);
assert_eq!(pool.total_blocks(), 10);
assert_eq!(pool.available_blocks(), 10);
let blocks = pool.acquire_free_blocks(10).unwrap();
for block in &blocks {
assert!(block.state().is_reset());
}
}
#[test]
fn test_basic_sequence_matching() {
let mut pool = InactiveBlockPool::new();
// Create a sequence of 4 tokens split into blocks of 2
let sequence = create_token_sequence(&[1, 2, 3, 4]);
let blocks = create_blocks(sequence, 2);
assert_eq!(blocks.len(), 2);
// Match the blocks in sequence
let hashes: Vec<_> = blocks
.iter()
.map(|b| {
b.sequence_hash()
.expect("Block should have a sequence hash in this test")
})
.collect();
// Insert blocks into pool
pool.add_blocks_with_state(blocks);
assert_eq!(pool.total_blocks(), 2);
assert_eq!(pool.available_blocks(), 2);
// Match the blocks in sequence
let matched = pool.match_sequence_hashes(hashes.clone());
assert_eq!(matched.len(), 2);
assert_eq!(pool.total_blocks(), 2);
assert_eq!(pool.available_blocks(), 0);
// Validate the blocks are in the correct order and match the sequence hashes
assert_eq!(matched[0].sequence_hash().unwrap(), hashes[0]);
assert_eq!(matched[1].sequence_hash().unwrap(), hashes[1]);
// Return blocks in reverse order (tail to root)
pool.return_blocks(matched);
assert_eq!(pool.total_blocks(), 2);
assert_eq!(pool.available_blocks(), 2);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
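/// Eviction-ordering key for registered blocks held in the inactive pool,
/// pairing a block's metadata (which drives the ordering) with its sequence hash.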
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PriorityKey<M: BlockMetadata> {
metadata: M,
sequence_hash: SequenceHash,
}
impl<M: BlockMetadata> PriorityKey<M> {
pub(crate) fn new(metadata: M, sequence_hash: SequenceHash) -> Self {
Self {
metadata,
sequence_hash,
}
}
pub fn sequence_hash(&self) -> SequenceHash {
self.sequence_hash
}
#[allow(dead_code)]
pub fn metadata(&self) -> &M {
&self.metadata
}
#[allow(dead_code)]
pub fn update_metadata(&mut self, metadata: M) {
self.metadata = metadata;
}
}
// Customize `Ord`/`PartialOrd` so keys sort first by metadata priority
// (lowest to highest), then by returned_tick (lowest to highest), with the
// sequence hash as a final tie-breaker.
impl<M: BlockMetadata> PartialOrd for PriorityKey<M> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl<M: BlockMetadata> Ord for PriorityKey<M> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.metadata
.cmp(&other.metadata)
.then(self.sequence_hash.cmp(&other.sequence_hash))
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::block_manager::{
block::{registry::BlockRegistationError, BlockState, PrivateBlockExt},
events::Publisher,
};
use super::*;
impl<S: Storage, M: BlockMetadata> State<S, M> {
fn new(
event_manager: Arc<dyn EventManager>,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, M>>,
) -> Self {
Self {
active: ActiveBlockPool::new(),
inactive: InactiveBlockPool::new(),
registry: BlockRegistry::new(event_manager.clone()),
return_tx,
event_manager,
}
}
async fn handle_priority_request(
&mut self,
req: PriorityRequest<S, M>,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>,
) {
match req {
PriorityRequest::AllocateBlocks(req) => {
let (count, resp_tx) = req.dissolve();
let blocks = self.allocate_blocks(count);
if resp_tx.send(blocks).is_err() {
tracing::error!("failed to send response to allocate blocks");
}
}
PriorityRequest::RegisterBlocks(req) => {
let (blocks, resp_tx) = req.dissolve();
let immutable_blocks = self.register_blocks(blocks, return_rx).await;
if resp_tx.send(immutable_blocks).is_err() {
tracing::error!("failed to send response to register blocks");
}
}
PriorityRequest::MatchSequenceHashes(req) => {
let (sequence_hashes, resp_tx) = req.dissolve();
let immutable_blocks = self.match_sequence_hashes(sequence_hashes, return_rx).await;
if resp_tx.send(immutable_blocks).is_err() {
tracing::error!("failed to send response to match sequence hashes");
}
}
}
}
fn handle_control_request(&mut self, req: ControlRequest<S, M>) {
match req {
ControlRequest::AddBlocks(blocks) => {
let (blocks, resp_rx) = blocks.dissolve();
self.inactive.add_blocks(blocks);
if resp_rx.send(()).is_err() {
tracing::error!("failed to send response to add blocks");
}
}
}
}
fn handle_return_block(&mut self, block: Block<S, M>) {
self.return_block(block);
}
/// Waits on the return channel for the block with the given sequence hash, which is
/// guaranteed to be returned to the pool in the near future. Any other blocks received
/// while waiting are handled normally. The caller takes ownership of the matching block.
async fn wait_for_returned_block(
&mut self,
sequence_hash: SequenceHash,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>,
) -> Block<S, M> {
while let Some(block) = return_rx.recv().await {
if matches!(block.state(), BlockState::Registered(handle) if handle.sequence_hash() == sequence_hash)
{
return block;
}
self.handle_return_block(block);
}
unreachable!("this should be unreachable");
}
pub fn allocate_blocks(
&mut self,
count: usize,
) -> Result<Vec<MutableBlock<S, M>>, BlockPoolError> {
let available_blocks = self.inactive.available_blocks() as usize;
if available_blocks < count {
tracing::debug!(
"not enough blocks available, requested: {}, available: {}",
count,
available_blocks
);
return Err(BlockPoolError::NotEnoughBlocksAvailable(
count,
available_blocks,
));
}
let mut blocks = Vec::with_capacity(count);
for _ in 0..count {
if let Some(block) = self.inactive.acquire_free_block() {
blocks.push(MutableBlock::new(block, self.return_tx.clone()));
}
}
Ok(blocks)
}
pub async fn register_blocks(
&mut self,
blocks: Vec<MutableBlock<S, M>>,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>,
) -> Result<Vec<ImmutableBlock<S, M>>, BlockPoolError> {
let expected_len = blocks.len();
let mut immutable_blocks = Vec::new();
// RAII object that collects all the publish handles and publishes them when dropped
let mut publish_handles = self.publisher();
for mut block in blocks.into_iter() {
let sequence_hash = block.sequence_hash()?;
// If the block is already registered, acquire a clone of the immutable block
if let Some(immutable) = self.active.match_sequence_hash(sequence_hash) {
immutable_blocks.push(immutable);
continue;
}
let mutable = if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash)
{
assert!(matches!(raw_block.state(), BlockState::Registered(_)));
MutableBlock::new(raw_block, self.return_tx.clone())
} else {
// Attempt to register the block.
// On the rare chance that the block is already registered but still in the
// process of being returned, wait for it to come back and reuse it.
let result = block.register(&mut self.registry);
match result {
Ok(handle) => {
publish_handles.take_handle(handle);
block
}
Err(BlockRegistationError::BlockAlreadyRegistered(_)) => {
// Block is already registered, wait for it to be returned
let raw_block =
self.wait_for_returned_block(sequence_hash, return_rx).await;
MutableBlock::new(raw_block, self.return_tx.clone())
}
Err(e) => {
return Err(BlockPoolError::FailedToRegisterBlock(e.to_string()));
}
}
};
let immutable = self.active.register(mutable)?;
immutable_blocks.push(immutable);
}
assert_eq!(immutable_blocks.len(), expected_len);
Ok(immutable_blocks)
}
async fn match_sequence_hashes(
&mut self,
sequence_hashes: Vec<SequenceHash>,
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>,
) -> Vec<ImmutableBlock<S, M>> {
let mut immutable_blocks = Vec::new();
for sequence_hash in sequence_hashes {
if !self.registry.is_registered(sequence_hash) {
return immutable_blocks;
}
// the block is registered, so get it from one of:
// 1. the active pool
// 2. the inactive pool
// 3. the return channel
if let Some(immutable) = self.active.match_sequence_hash(sequence_hash) {
immutable_blocks.push(immutable);
continue;
}
let raw_block =
if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash) {
raw_block
} else {
self.wait_for_returned_block(sequence_hash, return_rx).await
};
// this assert allows us to skip the error checking on the active pool registration step
assert!(matches!(raw_block.state(), BlockState::Registered(_)));
let mutable = MutableBlock::new(raw_block, self.return_tx.clone());
let immutable = self
.active
.register(mutable)
.expect("unable to register block; should ever happen");
immutable_blocks.push(immutable);
}
immutable_blocks
}
/// Returns a block to the inactive pool
pub fn return_block(&mut self, mut block: Block<S, M>) {
self.active.remove(&mut block);
self.inactive.return_block(block);
}
fn publisher(&self) -> Publisher {
Publisher::new(self.event_manager.clone())
}
}
impl<S: Storage, M: BlockMetadata> ProgressEngine<S, M> {
pub fn new(
event_manager: Arc<dyn EventManager>,
priority_rx: tokio::sync::mpsc::UnboundedReceiver<PriorityRequest<S, M>>,
ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, M>>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, M>>,
) -> Self {
let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel();
let mut state = State::<S, M>::new(event_manager, return_tx);
tracing::debug!(count = blocks.len(), "adding blocks to inactive pool");
state.inactive.add_blocks(blocks);
Self {
priority_rx,
ctrl_rx,
cancel_token,
state,
return_rx,
}
}
pub async fn step(&mut self) -> bool {
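// `biased` makes the select poll the arms in order: priority requests are
// served before control requests, which are served before returned blocks;
// cancellation is checked last.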
tokio::select! {
biased;
Some(priority_req) = self.priority_rx.recv(), if !self.priority_rx.is_closed() => {
self.state.handle_priority_request(priority_req, &mut self.return_rx).await;
}
Some(req) = self.ctrl_rx.recv(), if !self.ctrl_rx.is_closed() => {
self.state.handle_control_request(req);
}
Some(block) = self.return_rx.recv() => {
self.state.handle_return_block(block);
}
_ = self.cancel_token.cancelled() => {
return false;
}
}
true
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use super::{block::Block, config::NixlOptions};
use cudarc::driver::CudaStream;
use std::sync::Arc;
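/// Shared context for block transfers: an optional NIXL agent for remote
/// transfers and a CUDA stream for device-side copies.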
pub struct TransferContext {
nixl_agent: Option<NixlAgent>,
stream: Arc<CudaStream>,
}
impl TransferContext {
pub fn new(nixl_agent: Option<NixlAgent>, stream: Arc<CudaStream>) -> Self {
Self { nixl_agent, stream }
}
pub fn nixl_agent(&self) -> Option<&NixlAgent> {
self.nixl_agent.as_ref()
}
pub fn stream(&self) -> &Arc<CudaStream> {
&self.stream
}
}
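/// Core state of the KV block manager: the per-storage block pools (host and
/// device), the local NIXL block set, and any imported remote block sets,
/// keyed by worker ID.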
#[allow(dead_code)]
pub struct KvBlockManagerState<Metadata: BlockMetadata> {
worker_id: WorkerID,
cancellation_token: CancellationToken,
nixl_agent: Option<NixlAgent>,
nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>,
host_pool: Option<BlockPool<PinnedStorage, Metadata>>,
device_pool: Option<BlockPool<DeviceStorage, Metadata>>,
local_block_set: NixlBlockSet,
remote_block_sets: RwLock<HashMap<WorkerID, HashMap<usize, RemoteBlocks>>>,
}
impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
pub fn new(config: KvBlockManagerConfig) -> Result<Arc<Self>> {
config
.runtime
.validate()
.context("Validating runtime config")?;
config.model.validate().context("Validating model config")?;
let worker_id = config.runtime.worker_id;
let cancellation_token = config.runtime.cancellation_token;
// Create a map of NIXL backends
let mut nixl_backends: HashMap<String, Arc<nixl_sys::Backend>> = HashMap::new();
// Create a NIXL agent if NIXL is enabled and instantiate requested backends
// TODO: Build a map of NIXL backends to block pools/sets
let nixl_agent = match config.runtime.nixl {
NixlOptions::Enabled => {
tracing::debug!("Creating NIXL agent");
let agent = NixlAgent::new(&worker_id.to_string())?;
tracing::debug!("Creating NIXL backends");
let (_ucx_mem_list1, ucx_params) = agent.get_plugin_params("UCX")?;
let backend = agent.create_backend("UCX", &ucx_params)?;
nixl_backends.insert("UCX".to_string(), Arc::new(backend));
Some(agent)
}
NixlOptions::EnabledWithAgent(agent) => Some(agent),
NixlOptions::Disabled => None,
};
// Initialize model-specific layout config. The layout_builder is incomplete at this point.
// We will clone this builder and apply the storage-specific configs to each clone in the
// following steps.
let model = &config.model;
let mut layout_builder = LayoutConfig::builder();
layout_builder
.num_layers(model.num_layers)
.page_size(model.page_size)
.inner_dim(model.inner_dim)
.dtype(model.dtype);
let mut next_block_set_idx = 0;
let mut local_block_set = block::nixl::NixlBlockSet::new(worker_id);
// Create the host block pool if a host layout is provided
let (host_pool, host_blocks) = if let Some(config) = config.host_layout {
next_block_set_idx += 1;
tracing::debug!("Constructing host pool.");
let layout = create_layout(layout_builder.clone(), config, nixl_agent.as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>(
layout,
next_block_set_idx,
cancellation_token.clone(),
worker_id,
)?;
(Some(pool), Some(blocks))
} else {
tracing::debug!("No host layout provided; will not allocate host blocks.");
(None, None)
};
// Create the device block pool if a device layout is provided
let (device_pool, device_blocks) = if let Some(config) = config.device_layout {
next_block_set_idx += 1;
tracing::debug!("Constructing device pool.");
let layout = create_layout(layout_builder.clone(), config, nixl_agent.as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>(
layout,
next_block_set_idx,
cancellation_token.clone(),
worker_id,
)?;
(Some(pool), Some(blocks))
} else {
tracing::debug!("No device layout provided; will not allocate device blocks.");
(None, None)
};
// Finalize the local block set by adding NIXL metadata
if let Some(nixl_agent) = &nixl_agent {
tracing::debug!("Finalize NixlBlockSet: adding NIXL metadata.");
local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?);
}
let state = Arc::new(Self {
worker_id,
cancellation_token,
nixl_agent,
nixl_backends,
host_pool,
device_pool,
local_block_set,
remote_block_sets: RwLock::new(HashMap::new()),
});
if let Some(mut blocks) = host_blocks {
blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone());
});
state
.host_pool
.as_ref()
.unwrap()
.add_blocks_blocking(blocks)?;
}
if let Some(mut blocks) = device_blocks {
blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone());
});
state
.device_pool
.as_ref()
.unwrap()
.add_blocks_blocking(blocks)?;
}
Ok(state)
}
/// Exports the local blockset configuration as a serialized object.
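///
/// A sketch of the two-worker handshake; the transport that carries the
/// serialized blockset between workers is assumed:
///
/// ```rust,ignore
/// // On worker A:
/// let serialized = state_a.export_local_blockset()?;
/// // ... ship `serialized` to worker B out of band ...
/// // On worker B:
/// state_b.import_remote_blockset(serialized)?;
/// ```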
pub fn export_local_blockset(&self) -> Result<SerializedNixlBlockSet> {
SerializedNixlBlockSet::try_from(&self.local_block_set)
.context("Failed to serialize local blockset")
}
/// Imports a remote blockset configuration from a serialized object.
// TODO: NIXL will validate every descriptor list against the memory registration list for
// a given agent; this can be an expensive operation. To avoid this, NIXL offers the ability
// to generate "partial pre-validated (PPV)" descriptor lists. However, to support per-block and per-layer
// PPV lists we will need as many as `num_layers + 1` PPV lists per block:
// - one for representing the entire block
// - one for representing each layer individually
//
// A deeper dive into the performance impact of PPV lists is required to determine if this is
// the best approach.
//
// If PPV lists are valuable, it might be beneficial to lazily instantiate them when they are
// needed; alternatively, we could generate the entire PPV list for each block at import time.
pub fn import_remote_blockset(
&self,
serialized_blockset: SerializedNixlBlockSet,
) -> Result<()> {
let remote = NixlBlockSet::try_from(serialized_blockset)
.context("Failed to deserialize remote blockset")?;
let (block_sets, metadata, worker_id) = remote.dissolve();
tracing::debug!("Importing remote blockset from worker {}", worker_id);
assert_ne!(
worker_id, self.worker_id,
"Cannot import blockset from self"
);
let agent = self
.nixl_agent
.as_ref()
.ok_or_else(|| anyhow::anyhow!("NIXL agent not initialized"))?;
let mut remote_block_sets = self.remote_block_sets.write().unwrap();
if remote_block_sets.contains_key(&worker_id) {
anyhow::bail!(
"Worker ID {} already exists; cannot update remote blockset",
worker_id
);
}
let mut inner_map = HashMap::new();
for (block_set_idx, block_set_layout) in block_sets {
// Deserialize the individual layout and create RemoteBlocks
let remote_blocks =
RemoteBlocks::from_serialized(block_set_layout.clone(), block_set_idx, worker_id)?;
// check the storage type of the remote blocks
let layout = remote_blocks.layout();
let storage = layout.storage();
let storage = storage
.first()
.ok_or_else(|| anyhow::anyhow!("No storage found in remote blockset"))?;
match storage.mem_type() {
MemType::Dram => {
tracing::trace!(block_set_idx, "Detected Host/DRAM remote descriptor");
}
MemType::Vram => {
tracing::trace!(block_set_idx, "Detected GPU/Device/VRAM remote descriptor");
}
_ => {
tracing::warn!(
block_set_idx,
"Detected unknown remote descriptor; skipping blockset..."
);
continue;
}
}
inner_map.insert(block_set_idx, remote_blocks);
}
let agent_id = agent
.load_remote_md(&metadata)
.context("Loading remote metadata")?;
// Convert the agent_id returned by NIXL (a String) into a WorkerID (u64)
let agent_id: WorkerID = agent_id
.parse()
.context("Failed to parse agent ID string into WorkerID (u64)")?;
assert_eq!(agent_id, worker_id, "Mismatch with remote worker ID");
remote_block_sets.insert(worker_id, inner_map);
Ok(())
}
/// Get a [`Vec<RemoteBlock<IsImmutable>>`] from a [`BlockDescriptorList`]
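///
/// A minimal usage sketch (the descriptor list `bds` is assumed to have been
/// received from a remote worker whose blockset was already imported):
///
/// ```ignore
/// let blocks = state.get_remote_blocks_immutable(&bds)?;
/// // Immutable remote blocks can serve as copy sources, never as targets.
/// ```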
pub fn get_remote_blocks_immutable(
&self,
bds: &BlockDescriptorList,
) -> Result<Vec<RemoteBlock<IsImmutable>>> {
// no checks - we can always create an immutable remote block even if the bds is mutable
self.get_remote_blocks::<IsImmutable>(bds)
}
/// Get a [`Vec<RemoteBlock<IsMutable>>`] from a [`BlockDescriptorList`]
pub fn get_remote_blocks_mutable(
&self,
bds: &BlockDescriptorList,
) -> Result<Vec<RemoteBlock<IsMutable>>> {
if bds.mutability() == BlockMutability::Mutable {
self.get_remote_blocks::<IsMutable>(bds)
} else {
anyhow::bail!("Cannot get mutable remote blocks for immutable block descriptor set");
}
}
/// Generate a [`Vec<RemoteBlock>`] from a [`BlockDescriptorList`]
fn get_remote_blocks<M: MutabilityKind>(
&self,
bds: &BlockDescriptorList,
) -> Result<Vec<RemoteBlock<M>>> {
// Get a read lock on the remote block sets
let remote_block_sets = self.remote_block_sets.read().unwrap();
// validate we have loaded a remote blockset for the worker and the specific block_set_idx
let remote_blocks = remote_block_sets
.get(&bds.worker_id())
.and_then(|map| map.get(&bds.block_set_idx()))
.ok_or_else(|| {
anyhow::anyhow!(
"No remote blockset found for worker {} and block_set_idx {}",
bds.worker_id(),
bds.block_set_idx()
)
})?;
// Iterate through indices, call .block() for each, and collect results.
// The collect::<Result<...>>() handles potential errors from .block()
let blocks: Vec<block::nixl::RemoteBlock<M>> = bds
.block_indices()
.iter()
.map(|block_idx| remote_blocks.block(*block_idx))
.collect::<Result<Vec<_>, _>>()?;
Ok(blocks)
}
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
self.host_pool.as_ref()
}
pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> {
self.device_pool.as_ref()
}
pub fn worker_id(&self) -> WorkerID {
self.worker_id
}
}
impl<Metadata: BlockMetadata> std::fmt::Debug for KvBlockManagerState<Metadata> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "KvBlockManagerState")
}
}
fn create_layout<S: Storage + NixlRegisterableStorage>(
mut builder: LayoutConfigBuilder,
config: KvManagerLayoutConfig<S>,
nixl_agent: Option<&NixlAgent>,
) -> Result<Arc<dyn NixlLayout<StorageType = S>>> {
let layout = builder.num_blocks(config.num_blocks).build()?;
if let Some(storage) = config.storage {
let mut layout = layout.create_layout(config.layout_type, storage)?;
if let Some(nixl_agent) = nixl_agent {
layout.nixl_register(nixl_agent, None)?;
}
return Ok(Arc::new(layout));
}
if let Some(allocator) = config.allocator {
let mut layout = layout.allocate_layout(config.layout_type, allocator)?;
if let Some(nixl_agent) = nixl_agent {
layout.nixl_register(nixl_agent, None)?;
}
return Ok(Arc::new(layout));
}
anyhow::bail!("failed to create layout");
}
#[expect(clippy::type_complexity)]
fn create_block_pool<S: Storage + NixlRegisterableStorage, M: BlockMetadata>(
layout: Arc<dyn NixlLayout<StorageType = S>>,
block_set_idx: usize,
cancellation_token: CancellationToken,
worker_id: WorkerID,
) -> Result<(BlockPool<S, M>, Vec<Block<S, M>>)> {
let blocks = block::layout_to_blocks::<_, M>(layout, block_set_idx, worker_id)?;
let pool = BlockPool::<S, M>::builder()
.cancel_token(cancellation_token)
.build()?;
Ok((pool, blocks))
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![deny(missing_docs)]
//! # Storage Management
//!
//! This module provides a unified interface for managing different types of memory storage used in the block manager.
//! It handles various memory types including system memory, pinned memory, device memory, and remote storage through NIXL.
//!
//! ## Core Concepts
//!
//! ### Storage Types
//! The module defines the [`Storage`] trait, which is implemented for all storage types. The primary module provides a
//! [`Storage`] implementation for system memory via [`SystemStorage`].
//!
//! CUDA support is provided via the [`cuda`] module.
//! NIXL support is provided via the [`nixl`] module.
//!
//! ### Memory Registration
//! Storage objects can be registered with external libraries (like NIXL) through the [`RegisterableStorage`] trait.
//! This registration process:
//! - Creates a registration handle that ties the external library's state to the storage's lifetime
//! - Ensures proper cleanup through the [`Drop`] implementation of [`RegistrationHandles`]
//! - Provides a safe way to manage external library resources
//!
//! ### Safety and Performance
//! The module emphasizes:
//! - Memory safety through proper lifetime management
//! - Thread safety with appropriate trait bounds
//! - Performance optimization for different memory types
//! - Automatic resource cleanup
//!
//! ## Usage
//!
//! Storage objects are typically created through their respective allocators:
//! ```rust
//! use dynamo_llm::block_manager::storage::{SystemAllocator, StorageAllocator};
//!
//! let system_allocator = SystemAllocator::default();
//! let storage = system_allocator.allocate(1024).unwrap();
//! ```
//!
//! For registering with external libraries:
//! ```rust
//! use dynamo_llm::block_manager::storage::{
//! PinnedAllocator, StorageAllocator,
//! nixl::NixlRegisterableStorage
//! };
//! use nixl_sys::Agent as NixlAgent;
//!
//! // Create a NIXL agent
//! let agent = NixlAgent::new("my_agent").unwrap();
//!
//! let mut storage = PinnedAllocator::default().allocate(1024).unwrap();
//! storage.nixl_register(&agent, None).unwrap();
//! ```
//!
//! ## Implementation Details
//!
//! The module uses several key traits to provide a unified interface:
//! - [`Storage`] - Core trait for memory access
//! - [`RegisterableStorage`] - Support for external library registration
//! - [`StorageMemset`] - Memory initialization operations
//! - [`StorageAllocator`] - Factory for creating storage instances
pub mod cuda;
pub mod nixl;
pub use cuda::*;
use std::{
alloc::{alloc_zeroed, dealloc, Layout},
collections::HashMap,
fmt::Debug,
ptr::NonNull,
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Result type for storage operations
pub type StorageResult<T> = std::result::Result<T, StorageError>;
/// Represents the type of storage used for a block
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum StorageType {
/// System memory
System,
/// CUDA device memory
Device(u32),
/// CUDA page-locked host memory
Pinned,
/// Remote memory accessible through NIXL
Nixl,
/// Null storage
Null,
}
/// A block that is local to the current worker
pub trait Local {}
/// A block that is remote to the current worker
pub trait Remote {}
/// Marker trait for [`Storage`] types that can be accessed by the standard
/// mechanisms of the system, e.g. `memcpy`, `memset`, etc.
pub trait SystemAccessible: Storage {}
/// Errors that can occur during storage operations
#[derive(Debug, Error)]
#[allow(missing_docs)]
pub enum StorageError {
#[error("Storage allocation failed: {0}")]
AllocationFailed(String),
#[error("Storage not accessible: {0}")]
NotAccessible(String),
#[error("Invalid storage configuration: {0}")]
InvalidConfig(String),
#[error("Storage operation failed: {0}")]
OperationFailed(String),
#[error("Registration key already exists: {0}")]
RegistrationKeyExists(String),
#[error("Handle not found for key: {0}")]
HandleNotFound(String),
#[error("CUDA error: {0}")]
CudaError(#[from] cudarc::driver::DriverError),
#[error("NIXL error: {0}")]
NixlError(#[from] nixl_sys::NixlError),
}
/// Core storage trait that provides access to memory regions
pub trait Storage: Debug + Send + Sync + 'static {
/// Returns the type of storage
fn storage_type(&self) -> StorageType;
/// Returns the address of the storage
fn addr(&self) -> u64;
/// Returns the total size of the storage in bytes
fn size(&self) -> usize;
/// Get a raw pointer to the storage
///
/// # Safety
/// The caller must ensure:
/// - The pointer is not used after the storage is dropped
/// - Access patterns respect the storage's thread safety model
unsafe fn as_ptr(&self) -> *const u8;
/// Get a raw mutable pointer to the storage
///
/// # Safety
/// The caller must ensure:
/// - The pointer is not used after the storage is dropped
/// - No other references exist while the pointer is in use
/// - Access patterns respect the storage's thread safety model
unsafe fn as_mut_ptr(&mut self) -> *mut u8;
}
/// Extension trait for storage types that support memory setting operations
pub trait StorageMemset: Storage {
/// Sets a region of memory to a specific value
///
/// # Arguments
/// * `value` - The value to set (will be truncated to u8)
/// * `offset` - Offset in bytes from the start of the storage
/// * `size` - Number of bytes to set
///
/// # Safety
/// The caller must ensure:
/// - offset + size <= self.size()
/// - No other references exist to the memory region being set
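///
/// # Example
///
/// A short sketch using [`SystemStorage`] (which allocates zeroed memory):
///
/// ```ignore
/// let mut storage = SystemStorage::new(1024)?;
/// // Fill the first 128 bytes with 0xFF; the remainder stays zeroed.
/// storage.memset(0xFF, 0, 128)?;
/// ```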
fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<(), StorageError>;
}
/// Registerable storage is a [Storage] that can be associated with one or more
/// [RegistationHandle]s.
///
/// The core concept here is that the storage might be registered with a library
/// like NIXL, or some other custom library, which may make system calls against the
/// virtual addresses of the storage.
///
/// Before the [Storage] is dropped, the [RegistationHandle]s should be released.
///
/// The behavior is enforced via the [Drop] implementation for [RegistrationHandles].
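///
/// A minimal sketch of a custom handle; the `LoggingHandle` type here is
/// hypothetical and exists only to illustrate the registration flow:
///
/// ```ignore
/// struct LoggingHandle;
/// impl RegistationHandle for LoggingHandle {
///     fn release(&mut self) {
///         tracing::debug!("external registration released");
///     }
/// }
///
/// let mut storage = SystemStorage::new(1024)?;
/// storage.register("logger", Box::new(LoggingHandle))?;
/// assert!(storage.is_registered("logger"));
/// // The handle is released automatically when `storage` is dropped.
/// ```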
pub trait RegisterableStorage: Storage + Send + Sync + 'static {
/// Register a handle with a key
/// If a handle with the same key already exists, an error is returned
fn register(
&mut self,
key: &str,
handle: Box<dyn RegistationHandle>,
) -> Result<(), StorageError>;
/// Check if a handle is registered with a key
fn is_registered(&self, key: &str) -> bool;
/// Get a reference to the registration handle for a key
fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle>;
}
/// Designed to be implemented by any type that can be used as a handle to a
/// [RegisterableStorage].
///
/// See [RegisterableStorage] for more details.
pub trait RegistationHandle: std::any::Any + Send + Sync + 'static {
/// Release the [RegistationHandle].
/// This should be called when the external registration of this storage
/// is no longer needed.
///
/// Note: All [RegistrationHandle]s should be explicitly released before
/// the [Storage] is dropped.
fn release(&mut self);
}
/// A collection of [RegistrationHandle]s for a [RegisterableStorage].
///
/// This is used to ensure that all [RegistrationHandle]s are explicitly released
/// before the [RegisterableStorage] is dropped.
#[derive(Default)]
pub struct RegistrationHandles {
handles: HashMap<String, Box<dyn RegistationHandle>>,
}
impl RegistrationHandles {
/// Create a new [RegistrationHandles] instance
pub fn new() -> Self {
Self {
handles: HashMap::new(),
}
}
/// Register a handle with a key
/// If a handle with the same key already exists, an error is returned
pub fn register(
&mut self,
key: &str,
handle: Box<dyn RegistationHandle>,
) -> Result<(), StorageError> {
let key = key.to_string();
if self.handles.contains_key(&key) {
return Err(StorageError::RegistrationKeyExists(key));
}
self.handles.insert(key, handle);
Ok(())
}
/// Release all handles
fn release(&mut self) {
for handle in self.handles.values_mut() {
handle.release();
}
self.handles.clear();
}
/// Check if a handle is registered with a key
fn is_registered(&self, key: &str) -> bool {
self.handles.contains_key(key)
}
/// Get a reference to the registration handle for a key
fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle> {
self.handles.get(key).map(|h| h.as_ref())
}
}
impl std::fmt::Debug for RegistrationHandles {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"RegistrationHandles {{ count: {:?} }}",
self.handles.len()
)
}
}
impl Drop for RegistrationHandles {
fn drop(&mut self) {
if !self.handles.is_empty() {
panic!("RegistrationHandles dropped with {} handles remaining; RegistrationHandles::release() needs to be explicitly called", self.handles.len());
}
}
}
/// Trait for types that can allocate specific Storage implementations.
pub trait StorageAllocator<S: Storage>: Send + Sync {
/// Allocate storage of the specific type `S` with the given size in bytes.
fn allocate(&self, size: usize) -> Result<S, StorageError>;
}
/// System memory storage implementation backed by standard heap-allocated, zero-initialized memory
#[derive(Debug)]
pub struct SystemStorage {
ptr: NonNull<u8>,
layout: Layout,
len: usize,
handles: RegistrationHandles,
}
unsafe impl Send for SystemStorage {}
unsafe impl Sync for SystemStorage {}
impl Local for SystemStorage {}
impl SystemAccessible for SystemStorage {}
impl SystemStorage {
/// Create a new system storage with the given size
///
/// # Safety
/// This function allocates memory that will be freed when the SystemStorage is dropped.
pub fn new(size: usize) -> Result<Self, StorageError> {
// Create layout for the allocation, ensuring proper alignment
let layout =
Layout::array::<u8>(size).map_err(|e| StorageError::AllocationFailed(e.to_string()))?;
// Allocate zeroed memory
let ptr = unsafe {
NonNull::new(alloc_zeroed(layout))
.ok_or_else(|| StorageError::AllocationFailed("memory allocation failed".into()))?
};
Ok(Self {
ptr,
layout,
len: size,
handles: RegistrationHandles::new(),
})
}
}
impl Drop for SystemStorage {
fn drop(&mut self) {
self.handles.release();
unsafe {
dealloc(self.ptr.as_ptr(), self.layout);
}
}
}
impl Storage for SystemStorage {
fn storage_type(&self) -> StorageType {
StorageType::System
}
fn addr(&self) -> u64 {
self.ptr.as_ptr() as u64
}
fn size(&self) -> usize {
self.len
}
unsafe fn as_ptr(&self) -> *const u8 {
self.ptr.as_ptr()
}
unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
self.ptr.as_ptr()
}
}
impl StorageMemset for SystemStorage {
fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<(), StorageError> {
if offset + size > self.len {
return Err(StorageError::OperationFailed(
"memset: offset + size > storage size".into(),
));
}
unsafe {
let ptr = self.ptr.as_ptr().add(offset);
std::ptr::write_bytes(ptr, value, size);
}
Ok(())
}
}
impl RegisterableStorage for SystemStorage {
fn register(
&mut self,
key: &str,
handle: Box<dyn RegistationHandle>,
) -> Result<(), StorageError> {
self.handles.register(key, handle)
}
fn is_registered(&self, key: &str) -> bool {
self.handles.is_registered(key)
}
fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle> {
self.handles.registration_handle(key)
}
}
/// Allocator for SystemStorage
#[derive(Debug, Default, Clone, Copy)]
pub struct SystemAllocator;
impl StorageAllocator<SystemStorage> for SystemAllocator {
fn allocate(&self, size: usize) -> Result<SystemStorage, StorageError> {
SystemStorage::new(size)
}
}
#[allow(missing_docs)]
pub mod tests {
use super::*;
#[derive(Debug)]
pub struct NullDeviceStorage {
size: u64,
}
impl NullDeviceStorage {
pub fn new(size: u64) -> Self {
Self { size }
}
}
impl Storage for NullDeviceStorage {
fn storage_type(&self) -> StorageType {
StorageType::Null
}
fn addr(&self) -> u64 {
0
}
fn size(&self) -> usize {
self.size as usize
}
unsafe fn as_ptr(&self) -> *const u8 {
std::ptr::null()
}
unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
std::ptr::null_mut()
}
}
pub struct NullDeviceAllocator;
impl StorageAllocator<NullDeviceStorage> for NullDeviceAllocator {
fn allocate(&self, size: usize) -> Result<NullDeviceStorage, StorageError> {
Ok(NullDeviceStorage::new(size as u64))
}
}
#[derive(Debug)]
pub struct NullHostStorage {
size: u64,
}
impl NullHostStorage {
pub fn new(size: u64) -> Self {
Self { size }
}
}
impl Storage for NullHostStorage {
fn storage_type(&self) -> StorageType {
StorageType::Null
}
fn addr(&self) -> u64 {
0
}
fn size(&self) -> usize {
self.size as usize
}
unsafe fn as_ptr(&self) -> *const u8 {
std::ptr::null()
}
unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
std::ptr::null_mut()
}
}
pub struct NullHostAllocator;
impl StorageAllocator<NullHostStorage> for NullHostAllocator {
fn allocate(&self, size: usize) -> Result<NullHostStorage, StorageError> {
Ok(NullHostStorage::new(size as u64))
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # CUDA Storage Support
//!
//! This module provides CUDA-specific storage implementations for the block manager.
//! It is conditionally compiled based on the `cuda` feature flag.
//!
//! ## Features
//!
//! The following types are available when the `cuda` feature is enabled:
//! - [`PinnedStorage`] - Page-locked host memory for efficient GPU transfers
//! - [`DeviceStorage`] - Direct GPU memory allocation
//!
//! ## Storage Allocators
//!
//! The module provides allocators for each storage type:
//! - [`PinnedAllocator`] - Creates pinned host memory allocations
//! - [`DeviceAllocator`] - Creates device memory allocations
//!
//! ## CUDA Context Management
//!
//! The module provides a singleton [`Cuda`] type for managing CUDA contexts:
//! - Thread-safe context management
//! - Lazy initialization of device contexts
//! - Automatic cleanup of resources
//!
//! ## Usage
//!
//! ### Using Allocators
//! ```rust
//! use dynamo_llm::block_manager::storage::{DeviceAllocator, PinnedAllocator, StorageAllocator};
//!
//! // Create a pinned memory allocator
//! let pinned_allocator = PinnedAllocator::default();
//! let pinned_storage = pinned_allocator.allocate(1024).unwrap();
//!
//! // Create a device memory allocator for a specific device
//! let device_allocator = DeviceAllocator::new(1).unwrap(); // Use device 1
//! let device_storage = device_allocator.allocate(1024).unwrap();
//! ```
//!
//! ### Memory Operations
//! ```rust
//! use dynamo_llm::block_manager::storage::{
//! PinnedAllocator, StorageAllocator, Storage, StorageMemset
//! };
//!
//! // Allocate pinned host memory
//! let mut storage = PinnedAllocator::default().allocate(1024).unwrap();
//!
//! // Initialize memory
//! storage.memset(0, 0, 1024).unwrap();
//!
//! // Access memory through raw pointers (requires unsafe)
//! unsafe {
//! let ptr = storage.as_mut_ptr();
//! // Use the pointer...
//! }
//! ```
//!
//! ## Safety
//!
//! All CUDA operations are wrapped in safe Rust interfaces that ensure:
//! - Proper resource cleanup
//! - Thread safety
//! - Memory alignment requirements
//! - Error handling for CUDA operations
use super::*;
use std::{
collections::HashMap,
sync::{Arc, Mutex, OnceLock},
};
use cudarc::driver::{sys, CudaContext};
/// Trait for [Storage] types that can be accessed by CUDA
pub trait CudaAccessible: Storage {}
/// Trait for types that can provide a CUDA context.
pub trait CudaContextProivder {
/// Get a reference to the [`CudaContext`].
fn cuda_context(&self) -> &Arc<CudaContext>;
}
/// Singleton for managing CUDA contexts.
pub struct Cuda {
contexts: HashMap<usize, Arc<CudaContext>>,
}
impl Cuda {
// Private constructor
fn new() -> Self {
Self {
contexts: HashMap::new(),
}
}
/// Get a CUDA context for a specific device_id.
/// If the context does not exist, it will return None.
///
/// This will not lazily instantiate a context for a device; use
/// [Cuda::get_or_init_device] to create one if needed.
pub fn device(device_id: usize) -> Option<Arc<CudaContext>> {
Cuda::instance()
.lock()
.unwrap()
.get_existing_context(device_id)
}
/// Get or initialize a CUDA context for a specific device_id.
/// If the context does not exist, an attempt is made to create it, which may fail.
///
/// This will lazily instantiate a context for a device. Use
/// [Cuda::device] to look up an existing context without creating one.
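///
/// ```ignore
/// // Contexts are created once per device and shared thereafter; a second
/// // call returns a clone of the same Arc<CudaContext>.
/// let ctx = Cuda::get_or_init_device(0)?;
/// let same = Cuda::get_or_init_device(0)?;
/// assert!(std::sync::Arc::ptr_eq(&ctx, &same));
/// ```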
pub fn get_or_init_device(device_id: usize) -> Result<Arc<CudaContext>, StorageError> {
Cuda::instance().lock().unwrap().get_context(device_id)
}
/// Check if a CUDA context exists for a specific device_id.
pub fn is_initialized(device_id: usize) -> bool {
Cuda::instance().lock().unwrap().has_context(device_id)
}
// Get the singleton instance
fn instance() -> &'static Mutex<Cuda> {
static INSTANCE: OnceLock<Mutex<Cuda>> = OnceLock::new();
INSTANCE.get_or_init(|| Mutex::new(Cuda::new()))
}
// Get or create a CUDA context for a specific device
fn get_context(&mut self, device_id: usize) -> Result<Arc<CudaContext>, StorageError> {
// Check if we already have a context for this device
if let Some(ctx) = self.contexts.get(&device_id) {
return Ok(ctx.clone());
}
// Create a new context for this device
let ctx = CudaContext::new(device_id)?;
// Store the context
self.contexts.insert(device_id, ctx.clone());
Ok(ctx)
}
// Get a context if it exists, but don't create one
fn get_existing_context(&self, device_id: usize) -> Option<Arc<CudaContext>> {
self.contexts.get(&device_id).cloned()
}
// Check if a context exists for a device
fn has_context(&self, device_id: usize) -> bool {
self.contexts.contains_key(&device_id)
}
}
/// Pinned host memory storage using CUDA page-locked memory
#[derive(Debug)]
pub struct PinnedStorage {
ptr: u64,
size: usize,
handles: RegistrationHandles,
ctx: Arc<CudaContext>,
}
impl Local for PinnedStorage {}
impl SystemAccessible for PinnedStorage {}
impl CudaAccessible for PinnedStorage {}
impl PinnedStorage {
/// Create a new pinned storage with the given size
pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> {
unsafe {
ctx.bind_to_thread().map_err(StorageError::CudaError)?;
let ptr = cudarc::driver::result::malloc_host(size, sys::CU_MEMHOSTALLOC_WRITECOMBINED)
.map_err(StorageError::CudaError)?;
let ptr = ptr as *mut u8;
assert!(!ptr.is_null(), "Failed to allocate pinned memory");
assert!(ptr.is_aligned(), "Pinned memory is not aligned");
assert!(size < isize::MAX as usize);
let ptr = ptr as u64;
Ok(Self {
ptr,
size,
handles: RegistrationHandles::new(),
ctx: ctx.clone(),
})
}
}
}
impl Drop for PinnedStorage {
fn drop(&mut self) {
self.handles.release();
unsafe { cudarc::driver::result::free_host(self.ptr as _) }.unwrap();
}
}
impl Storage for PinnedStorage {
fn storage_type(&self) -> StorageType {
StorageType::Pinned
}
fn addr(&self) -> u64 {
self.ptr
}
fn size(&self) -> usize {
self.size
}
unsafe fn as_ptr(&self) -> *const u8 {
self.ptr as *const u8
}
unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
self.ptr as *mut u8
}
}
impl CudaContextProivder for PinnedStorage {
fn cuda_context(&self) -> &Arc<CudaContext> {
&self.ctx
}
}
impl RegisterableStorage for PinnedStorage {
fn register(
&mut self,
key: &str,
handle: Box<dyn RegistationHandle>,
) -> Result<(), StorageError> {
self.handles.register(key, handle)
}
fn is_registered(&self, key: &str) -> bool {
self.handles.is_registered(key)
}
fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle> {
self.handles.registration_handle(key)
}
}
impl StorageMemset for PinnedStorage {
fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<(), StorageError> {
if offset + size > self.size {
return Err(StorageError::OperationFailed(
"memset: offset + size > storage size".into(),
));
}
unsafe {
let ptr = (self.ptr as *mut u8).add(offset);
std::ptr::write_bytes(ptr, value, size);
}
Ok(())
}
}
/// Allocator for PinnedStorage
pub struct PinnedAllocator {
ctx: Arc<CudaContext>,
}
impl Default for PinnedAllocator {
fn default() -> Self {
Self {
ctx: Cuda::get_or_init_device(0).expect("Failed to create CUDA context"),
}
}
}
impl PinnedAllocator {
/// Create a new pinned allocator
pub fn new() -> Result<Self, StorageError> {
Ok(Self {
ctx: Cuda::get_or_init_device(0)?,
})
}
}
impl StorageAllocator<PinnedStorage> for PinnedAllocator {
fn allocate(&self, size: usize) -> Result<PinnedStorage, StorageError> {
PinnedStorage::new(&self.ctx, size)
}
}
/// CUDA device memory storage
#[derive(Debug)]
pub struct DeviceStorage {
ptr: u64,
size: usize,
ctx: Arc<CudaContext>,
handles: RegistrationHandles,
}
impl Local for DeviceStorage {}
impl CudaAccessible for DeviceStorage {}
impl DeviceStorage {
/// Create a new device storage with the given size
pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> {
ctx.bind_to_thread().map_err(StorageError::CudaError)?;
let ptr =
unsafe { cudarc::driver::result::malloc_sync(size).map_err(StorageError::CudaError)? };
Ok(Self {
ptr,
size,
ctx: ctx.clone(),
handles: RegistrationHandles::new(),
})
}
/// Get the CUDA context
pub fn context(&self) -> &Arc<CudaContext> {
&self.ctx
}
}
impl Storage for DeviceStorage {
fn storage_type(&self) -> StorageType {
StorageType::Device(self.ctx.cu_device() as u32)
}
fn addr(&self) -> u64 {
self.ptr
}
fn size(&self) -> usize {
self.size
}
unsafe fn as_ptr(&self) -> *const u8 {
self.ptr as *const u8
}
unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
self.ptr as *mut u8
}
}
impl CudaContextProivder for DeviceStorage {
fn cuda_context(&self) -> &Arc<CudaContext> {
&self.ctx
}
}
impl Drop for DeviceStorage {
fn drop(&mut self) {
self.handles.release();
unsafe { cudarc::driver::result::free_sync(self.ptr as _) }.unwrap();
}
}
impl RegisterableStorage for DeviceStorage {
fn register(
&mut self,
key: &str,
handle: Box<dyn RegistationHandle>,
) -> Result<(), StorageError> {
self.handles.register(key, handle)
}
fn is_registered(&self, key: &str) -> bool {
self.handles.is_registered(key)
}
fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle> {
self.handles.registration_handle(key)
}
}
/// Allocator for DeviceStorage
pub struct DeviceAllocator {
ctx: Arc<CudaContext>,
}
impl Default for DeviceAllocator {
fn default() -> Self {
Self {
ctx: CudaContext::new(0).expect("Failed to create CUDA context"),
}
}
}
impl DeviceAllocator {
/// Create a new device allocator
pub fn new(device_id: usize) -> Result<Self, StorageError> {
Ok(Self {
ctx: Cuda::get_or_init_device(device_id)?,
})
}
/// Get the CUDA context
pub fn ctx(&self) -> &Arc<CudaContext> {
&self.ctx
}
}
impl StorageAllocator<DeviceStorage> for DeviceAllocator {
fn allocate(&self, size: usize) -> Result<DeviceStorage, StorageError> {
DeviceStorage::new(&self.ctx, size)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # NIXL Storage Support
//!
//! This module provides NIXL-specific storage implementations and integration for the block manager.
//! It is conditionally compiled based on the `nixl` feature flag.
//!
//! ## Features
//!
//! The following functionality is available when the `nixl` feature is enabled:
//! - [`NixlStorage`] - Remote memory representation
//! - [`NixlRegisterableStorage`] - Trait for NIXL-compatible storage types
//! - Integration with the NIXL agent system for remote memory access
//!
//! ## Memory Registration
//!
//! The module extends the core storage types with NIXL registration capabilities:
//! - Automatic registration handle management
//! - Memory type mapping between storage and NIXL types
//! - Device ID tracking for GPU memory
//!
//! ## Usage
//!
//! ```rust
//! use dynamo_llm::block_manager::storage::{
//! PinnedAllocator, StorageAllocator,
//! nixl::NixlRegisterableStorage
//! };
//! use nixl_sys::Agent as NixlAgent;
//!
//! // Create a NIXL agent
//! let agent = NixlAgent::new("my_agent").unwrap();
//!
//! // Create storage using an allocator
//! let pinned_allocator = PinnedAllocator::default();
//! let mut storage = pinned_allocator.allocate(1024).unwrap();
//!
//! // Initially no NIXL descriptors are available
//! assert!(unsafe { storage.as_nixl_descriptor() }.is_none());
//!
//! // Register with NIXL
//! storage.nixl_register(&agent, None).unwrap();
//!
//! // Now we can get NIXL descriptors
//! // NIXL descriptors are not owned by the storage, so we need to access them
//! // through an unsafe method.
//! if let Some(nixl_desc) = unsafe { storage.as_nixl_descriptor() } {
//! // Use NIXL memory region
//! println!("NIXL memory at addr: {}", nixl_desc.addr());
//! println!("Memory type: {:?}", nixl_desc.mem_type());
//! println!("Device ID: {}", nixl_desc.device_id());
//! }
//! ```
//!
//! ## Safety
//!
//! The module ensures safe interaction with NIXL by:
//! - Managing registration lifetimes
//! - Validating memory types and device IDs
//! - Providing type-safe interfaces for remote memory access
//! - Automatic cleanup of NIXL resources
pub use nixl_sys::{
Agent as NixlAgent, MemType, MemoryRegion, NixlDescriptor, OptArgs,
RegistrationHandle as NixlRegistrationHandle,
};
use derive_getters::Getters;
use serde::{Deserialize, Serialize};
use super::{
CudaContextProivder, DeviceStorage, PinnedStorage, RegistationHandle, RegisterableStorage,
Remote, Storage, StorageError, StorageType, SystemStorage,
};
/// Marker trait for storage types that can be accessed by NIXL.
///
/// This trait is different from [`NixlRegisterableStorage`] which has further restrictions
/// that the [`Storage`] must be [`RegisterableStorage`].
///
/// Remote memory described by [`NixlStorage`] is [`NixlAccessible`] but is not [`NixlRegisterableStorage`]
/// due to the fact it represents memory that is registered to another NIXL agent.
pub trait NixlAccessible {}
impl StorageType {
/// Get the NIXL memory type for a given storage type.
pub fn nixl_mem_type(&self) -> MemType {
match self {
StorageType::System => MemType::Dram,
StorageType::Pinned => MemType::Dram,
StorageType::Device(_) => MemType::Vram,
StorageType::Nixl => MemType::Unknown,
StorageType::Null => MemType::Unknown,
}
}
/// Get the NIXL device ID for a given storage type.
pub fn nixl_device_id(&self) -> u64 {
match self {
StorageType::System => 0,
StorageType::Pinned => 0,
StorageType::Device(id) => *id as u64,
StorageType::Nixl => 0,
StorageType::Null => 0,
}
}
}
impl RegistationHandle for NixlRegistrationHandle {
fn release(&mut self) {
if let Err(e) = self.deregister() {
tracing::error!("Failed to deregister Nixl storage: {}", e);
}
}
}
/// Extension to the [`RegisterableStorage`] trait for NIXL-compatible storage.
pub trait NixlRegisterableStorage: RegisterableStorage + NixlDescriptor + Sized {
/// Register the storage with the NIXL agent.
fn nixl_register(
&mut self,
agent: &NixlAgent,
opt_args: Option<&OptArgs>,
) -> Result<(), StorageError> {
let handle = Box::new(agent.register_memory(self, opt_args)?);
// Store the registration handle under the "nixl" key
self.register("nixl", handle)
}
/// Check if the storage is registered with the NIXL agent.
fn is_nixl_registered(&self) -> bool {
self.is_registered("nixl")
}
/// Get the NIXL agent name for the storage.
fn nixl_agent_name(&self) -> Option<String> {
// Get the registration handle associated with "nixl".
self.registration_handle("nixl")
// If a handle exists, attempt to downcast it.
.and_then(|handle_box| {
// Cast the trait object &dyn RegistationHandle to &dyn Any
// then attempt to downcast to the concrete NixlRegistrationHandle type.
// Note: This requires RegistationHandle: Any + 'static
(handle_box as &dyn std::any::Any)
.downcast_ref::<NixlRegistrationHandle>()
// If downcast succeeds, get the agent name.
.map(|nixl_handle| nixl_handle.agent_name())
})?
}
/// If the underlying storage is NIXL-compatible, return descriptions of the NIXL memory regions.
/// This is used for serialization/deserialization of NIXL-specific layouts.
///
/// # Safety
///
/// This function is unsafe because ownership of the storage is not transferred; the
/// returned descriptor must not outlive the storage it describes.
unsafe fn as_nixl_descriptor(&self) -> Option<NixlStorage> {
if self.is_nixl_registered() {
Some(NixlStorage {
addr: self.addr(),
size: MemoryRegion::size(self),
mem_type: self.mem_type(),
device_id: self.device_id(),
})
} else {
None
}
}
}
/// NIXL-compatible storage
///
/// This object does not own any memory, it is meant to hold descriptions
/// of non-local/remote memory.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Getters)]
pub struct NixlStorage {
addr: u64,
size: usize,
mem_type: MemType,
device_id: u64,
}
impl Remote for NixlStorage {}
impl NixlAccessible for NixlStorage {}
impl Storage for NixlStorage {
fn storage_type(&self) -> StorageType {
StorageType::Nixl
}
fn addr(&self) -> u64 {
self.addr
}
fn size(&self) -> usize {
self.size
}
unsafe fn as_ptr(&self) -> *const u8 {
self.addr as *const u8
}
unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
self.addr as *mut u8
}
}
impl MemoryRegion for NixlStorage {
unsafe fn as_ptr(&self) -> *const u8 {
self.addr as *const u8
}
fn size(&self) -> usize {
self.size
}
}
impl NixlDescriptor for NixlStorage {
fn mem_type(&self) -> MemType {
self.mem_type
}
fn device_id(&self) -> u64 {
self.device_id
}
}
// SystemStorage
impl NixlRegisterableStorage for SystemStorage {}
impl MemoryRegion for SystemStorage {
unsafe fn as_ptr(&self) -> *const u8 {
self.ptr.as_ptr()
}
fn size(&self) -> usize {
self.len
}
}
impl NixlDescriptor for SystemStorage {
fn mem_type(&self) -> MemType {
MemType::Dram
}
fn device_id(&self) -> u64 {
0
}
}
// PinnedStorage
impl NixlAccessible for PinnedStorage {}
impl NixlRegisterableStorage for PinnedStorage {}
impl MemoryRegion for PinnedStorage {
unsafe fn as_ptr(&self) -> *const u8 {
Storage::as_ptr(self)
}
fn size(&self) -> usize {
Storage::size(self)
}
}
impl NixlDescriptor for PinnedStorage {
fn mem_type(&self) -> MemType {
MemType::Dram
}
fn device_id(&self) -> u64 {
0
}
}
// DeviceStorage
impl NixlAccessible for DeviceStorage {}
impl NixlRegisterableStorage for DeviceStorage {}
impl MemoryRegion for DeviceStorage {
unsafe fn as_ptr(&self) -> *const u8 {
Storage::as_ptr(self)
}
fn size(&self) -> usize {
Storage::size(self)
}
}
impl NixlDescriptor for DeviceStorage {
fn mem_type(&self) -> MemType {
MemType::Vram
}
fn device_id(&self) -> u64 {
CudaContextProivder::cuda_context(self).cu_device() as u64
}
}
......@@ -13,4 +13,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod dtype;
pub mod versioned;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
/// Represents the data type of tensor elements
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DType {
FP8,
FP16,
BF16,
FP32,
U8,
U16,
U32,
U64,
I8,
I16,
I32,
I64,
}
impl DType {
/// Get the size of the data type in bytes
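///
/// For example:
///
/// ```ignore
/// assert_eq!(DType::FP16.size_in_bytes(), 2);
/// assert_eq!(DType::FP8.size_in_bytes(), 1);
/// ```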
pub fn size_in_bytes(&self) -> usize {
match self {
DType::FP32 => 4,
DType::FP16 => 2,
DType::BF16 => 2,
DType::FP8 => 1,
DType::U8 => 1,
DType::U16 => 2,
DType::U32 => 4,
DType::U64 => 8,
DType::I8 => 1,
DType::I16 => 2,
DType::I32 => 4,
DType::I64 => 8,
}
}
}
......@@ -42,7 +42,7 @@ use crate::{
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
},
tokens::TokenBlockSequence,
};
use dynamo_runtime::traits::events::EventSubscriber;
......@@ -149,14 +149,13 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
let isl_tokens = request.tokens.len();
let block_size = self.block_size;
// Compute the block hashes for the request tokens
let (complete_blocks, _partial_block) =
TokenBlockSequence::split_tokens(&request.tokens, block_size, 1337_u64);
let local_block_hashes = complete_blocks
.into_iter()
.map(|block| LocalBlockHash(block.block_hash()))
.collect();
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
......
......@@ -27,7 +27,6 @@ pub mod http;
pub mod hub;
pub mod key_value_store;
pub mod kv_router;
pub mod model_card;
pub mod model_type;
pub mod preprocessor;
......@@ -38,7 +37,8 @@ pub mod tokenizers;
pub mod tokens;
pub mod types;
mod local_model;
pub use local_model::LocalModel;
#[cfg(feature = "cuda_kv")]
pub mod kv;
#[cfg(feature = "block-manager")]
pub mod block_manager;
......@@ -13,20 +13,44 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(dead_code)]
//! Types and utilities for handling sequences of tokens, including block creation and hashing.
use bytemuck::cast_slice;
use derive_getters::Dissolve;
use rayon::prelude::*;
use std::ops::Range;
/// A token is represented as a 32-bit unsigned integer.
pub type Token = u32;
/// A salt used for hashing, represented as a vector of bytes.
/// This might encode model architecture, weights, PEFT info, etc.
pub type Salt = Vec<u8>;
/// A 64-bit hash of the salt, computed using [`compute_hash_v2`] with a seed of 0.
/// Used as the initial seed for subsequent block hashes.
pub type SaltHash = u64;
/// A 64-bit hash computed only from the tokens within a single block.
/// It uses [`compute_hash_v2`] with the [`SaltHash`] as the seed.
pub type BlockHash = u64;
/// A 64-bit sequence-aware hash.
/// It combines the previous block's [`SequenceHash`] (or the [`SaltHash`] for the first block)
/// with the current block's [`BlockHash`] using [`compute_hash_v2`] and the [`SaltHash`] as the seed.
pub type SequenceHash = u64;
/// Computes a hash of the data using the given seed.
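///
/// A sketch of how the hashes chain together for block hashing (this mirrors
/// [`TokenBlockChunk::new`] and [`TokenBlock::from_chunk`]; variable names are
/// illustrative):
///
/// ```ignore
/// let salt_hash = compute_hash_v2(&salt, 0);
/// let block_hash = compute_hash_v2(cast_slice(&block_tokens), salt_hash);
/// // For non-root blocks; the first block's sequence hash is its block hash.
/// let sequence_hash =
///     compute_hash_v2(cast_slice(&[parent_sequence_hash, block_hash]), salt_hash);
/// ```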
pub fn compute_hash_v2(data: &[u8], seed: u64) -> u64 {
xxhash_rust::xxh3::xxh3_64_with_seed(data, seed)
}
/// A collection of tokens, represented as a `Vec<Token>`.
///
/// Provides convenience methods for conversion and manipulation.
#[derive(Debug, Clone, Dissolve, Default, Eq)]
pub struct Tokens(Vec<Token>);
impl AsRef<[Token]> for Tokens {
......@@ -35,12 +59,6 @@ impl AsRef<[Token]> for Tokens {
}
}
impl AsMut<[Token]> for Tokens {
fn as_mut(&mut self) -> &mut [Token] {
&mut self.0
}
}
impl std::ops::Deref for Tokens {
type Target = [Token];
......@@ -49,28 +67,12 @@ impl std::ops::Deref for Tokens {
}
}
impl std::ops::DerefMut for Tokens {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl std::borrow::Borrow<[Token]> for Tokens {
fn borrow(&self) -> &[Token] {
&self.0
}
}
impl IntoIterator for Tokens {
type Item = Token;
type IntoIter = std::vec::IntoIter<Token>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl From<Vec<Token>> for Tokens {
fn from(tokens: Vec<Token>) -> Self {
Tokens(tokens)
......@@ -84,12 +86,14 @@ impl From<&[Token]> for Tokens {
}
impl From<Vec<i32>> for Tokens {
/// Converts `Vec<i32>` to `Tokens`, casting each `i32` to `u32`.
fn from(tokens: Vec<i32>) -> Self {
Tokens(tokens.into_iter().map(|t| t as u32).collect())
}
}
impl From<&[i32]> for Tokens {
/// Converts `&[i32]` to `Tokens`, casting each `i32` to `u32`.
fn from(tokens: &[i32]) -> Self {
Tokens(tokens.iter().map(|&t| t as u32).collect())
}
......@@ -101,53 +105,248 @@ impl From<Tokens> for Vec<Token> {
}
}
// PartialEq implementations for comparing Tokens with Vec<Token> and &[Token]
// (Generated implementations are usually sufficient, but explicit ones can be clearer)
impl PartialEq<Vec<Token>> for Tokens {
fn eq(&self, other: &Vec<Token>) -> bool {
self.0 == *other
}
}
impl PartialEq<Tokens> for Vec<Token> {
fn eq(&self, other: &Tokens) -> bool {
*self == other.0
}
}
impl PartialEq<[Token]> for Tokens {
fn eq(&self, other: &[Token]) -> bool {
self.0.as_slice() == other
}
}
impl PartialEq<Tokens> for &[Token] {
fn eq(&self, other: &Tokens) -> bool {
*self == other.0.as_slice()
}
}
impl PartialEq for Tokens {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
// A blanket PartialEq<&[T]> where T: Into<Token> + Copy would be more general,
// but implementing specifically for &[Token] is sufficient for the tests.
impl PartialEq<&[Token]> for Tokens {
fn eq(&self, other: &&[Token]) -> bool {
self.0.as_slice() == *other
}
}
impl Tokens {
/// Consumes the [`Tokens`] object and creates a [`TokenBlockSequence`].
///
/// The sequence is initialized with the provided tokens, splitting them into blocks
/// of the specified `block_size` using the given `salt_hash` (or 0 if `None`).
///
/// # Arguments
///
/// * `block_size` - The fixed size for each [`TokenBlock`].
/// * `salt_hash` - An optional [`SaltHash`] used as the base seed for hashing. Defaults to 0.
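///
/// # Example
///
/// ```ignore
/// let tokens = Tokens::from(vec![1u32, 2, 3, 4, 5]);
/// // With block_size = 4 and the default salt, this yields one complete
/// // block of [1, 2, 3, 4] and leaves token 5 in the partial block.
/// let sequence = tokens.into_sequence(4, None);
/// ```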
pub fn into_sequence(
self,
block_size: usize,
salt_hash: Option<SaltHash>,
) -> TokenBlockSequence {
TokenBlockSequence::new(self, block_size, salt_hash)
}
}
/// Errors that can occur during [`PartialTokenBlock`] operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum TokenBlockError {
/// The operation could not be completed because the block is full.
#[error("TokenBlock is full")]
Full,
/// The operation requires a full block, but the block is incomplete.
#[error("TokenBlock is incomplete")]
Incomplete,
/// The operation could not be completed because the block is empty.
#[error("TokenBlock is empty")]
Empty,
/// The operation requires more tokens than are currently in the block.
#[error("TokenBlock has insufficient tokens")]
InsufficientTokens,
}
/// Represents a partially filled block of tokens within a sequence.
///
/// This structure accumulates tokens until it reaches the specified `block_size`,
/// at which point it can be [`commit`](PartialTokenBlock::commit)ted into a full [`TokenBlock`].
#[derive(Debug)] // No Clone: intended to be unique within a sequence
pub struct PartialTokenBlock {
tokens: Tokens,
block_size: usize,
salt_hash: SaltHash,
parent_sequence_hash: Option<SequenceHash>,
}
impl PartialTokenBlock {
/// Creates the first partial block (root) for a new sequence.
///
/// # Arguments
///
/// * `block_size` - The fixed size for blocks in this sequence.
/// * `salt_hash` - The [`SaltHash`] for the sequence.
pub(crate) fn create_sequence_root(block_size: usize, salt_hash: SaltHash) -> Self {
Self {
tokens: Tokens::default(),
block_size,
salt_hash,
parent_sequence_hash: None, // Root has no parent
}
}
/// Attempts to push a single token onto the block.
///
/// # Arguments
///
/// * `token` - The [`Token`] to push.
///
/// # Returns
///
/// * `Ok(())` - If the token was successfully added.
/// * `Err(TokenBlockError::Full)` - If the block already contains `block_size` tokens.
pub(crate) fn push_token(&mut self, token: Token) -> Result<(), TokenBlockError> {
if self.tokens.0.len() >= self.block_size {
return Err(TokenBlockError::Full);
}
self.tokens.0.push(token);
Ok(())
}
/// Attempts to push multiple tokens onto the block from a [`Tokens`] object.
///
/// Tokens are added until the block is full or all input tokens are consumed.
///
/// # Arguments
///
/// * `tokens` - The [`Tokens`] to push.
///
/// # Returns
///
/// A new [`Tokens`] object containing any tokens that did not fit.
/// If all tokens were added, the returned object will be empty.
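///
/// A small sketch (assumes `block_size = 4` and an initially empty partial block):
///
/// ```ignore
/// let leftover = partial.push_tokens(Tokens::from(vec![1u32, 2, 3, 4, 5]));
/// // The block is now full; one token did not fit.
/// assert_eq!(leftover.len(), 1);
/// assert_eq!(partial.remaining(), 0);
/// ```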
pub(crate) fn push_tokens(&mut self, tokens: Tokens) -> Tokens {
let remaining_space = self.remaining();
if remaining_space == 0 {
return tokens; // Block is already full
}
if tokens.0.len() <= remaining_space {
// All tokens fit
self.tokens.0.extend(tokens.0);
Tokens::default() // No remaining tokens
} else {
// Only some tokens fit
let (to_add, remaining) = tokens.0.split_at(remaining_space);
self.tokens.0.extend_from_slice(to_add);
Tokens(remaining.to_vec()) // Return the leftover tokens
}
}
/// Attempts to remove the last token from the block.
///
/// # Returns
///
/// * `Ok(())` - If a token was successfully removed.
/// * `Err(TokenBlockError::Empty)` - If the block was already empty.
pub(crate) fn pop_token(&mut self) -> Result<(), TokenBlockError> {
if self.tokens.0.is_empty() {
return Err(TokenBlockError::Empty);
}
self.tokens.0.pop();
Ok(())
}
/// Attempts to remove the last `count` tokens from the block.
///
/// # Arguments
///
/// * `count` - The number of tokens to remove.
///
/// # Returns
///
/// * `Ok(())` - If the specified number of tokens were successfully removed.
/// * `Err(TokenBlockError::InsufficientTokens)` - If `count` is greater than the number of tokens in the block.
pub(crate) fn pop_tokens(&mut self, count: usize) -> Result<(), TokenBlockError> {
if self.tokens.0.len() < count {
return Err(TokenBlockError::InsufficientTokens);
}
self.tokens.0.truncate(self.tokens.0.len() - count);
Ok(())
}
/// Attempts to commit the current partial block into a full [`TokenBlock`].
///
/// This operation consumes the tokens within the partial block.
/// After a successful commit, this `PartialTokenBlock` instance is reset
/// to represent the *next* partial block in the sequence, inheriting the
/// sequence hash from the block just committed.
///
/// # Returns
///
/// * `Ok(TokenBlock)` - The newly created full [`TokenBlock`].
/// * `Err(TokenBlockError::Incomplete)` - If the block does not contain exactly `block_size` tokens.
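///
/// A small sketch (assumes the partial block currently holds exactly
/// `block_size` tokens):
///
/// ```ignore
/// let block = partial.commit()?;
/// // `partial` is reset and now chains to `block` via its sequence hash.
/// assert!(partial.is_empty());
/// ```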
pub(crate) fn commit(&mut self) -> Result<TokenBlock, TokenBlockError> {
if self.tokens.0.len() != self.block_size {
// Check for exact size match for committing
return Err(TokenBlockError::Incomplete);
}
// Take ownership of the tokens, leaving the internal tokens empty
let tokens = std::mem::take(&mut self.tokens);
let chunk = TokenBlockChunk::new(tokens, self.salt_hash);
let block = TokenBlock::from_chunk(chunk, self.parent_sequence_hash);
// Reset self to be the next block in the sequence
self.parent_sequence_hash = Some(block.sequence_hash());
// self.tokens is already empty due to mem::take
// self.block_size and self.salt_hash remain the same
Ok(block)
}
/// Returns the number of additional tokens required to fill the block.
pub fn remaining(&self) -> usize {
// Use saturating_sub to prevent underflow if len somehow exceeds block_size
self.block_size.saturating_sub(self.tokens.0.len())
}
/// Returns the number of tokens currently in the block.
pub fn len(&self) -> usize {
self.tokens.0.len()
}
/// Returns `true` if the block contains no tokens.
pub fn is_empty(&self) -> bool {
self.tokens.0.is_empty()
}
/// Returns a reference to the tokens currently in the block.
pub fn tokens(&self) -> &Tokens {
&self.tokens
}
}
// Deref allows treating &PartialTokenBlock like &Tokens for read-only access.
impl std::ops::Deref for PartialTokenBlock {
type Target = Tokens;
......@@ -156,262 +355,1149 @@ impl std::ops::Deref for PartialTokenBlock {
}
}
/// An intermediate structure holding a chunk of tokens destined to become a [`TokenBlock`].
///
/// This calculates the [`BlockHash`] but does not compute the final [`SequenceHash`],
/// allowing chunks to be processed independently (e.g., in parallel).
#[derive(Debug)] // No Clone: temporary intermediate value
struct TokenBlockChunk {
tokens: Tokens,
salt_hash: SaltHash,
block_hash: BlockHash,
}
impl TokenBlockChunk {
/// Creates a new chunk from [`Tokens`], calculating the [`BlockHash`].
fn new(tokens: Tokens, salt_hash: SaltHash) -> Self {
let block_hash = compute_hash_v2(cast_slice(&tokens), salt_hash);
Self {
tokens,
salt_hash,
block_hash,
}
}
/// Creates a new chunk from a slice of `&[Token]`, calculating the [`BlockHash`].
fn from_tokens(tokens: &[Token], salt_hash: SaltHash) -> Self {
let block_hash = compute_hash_v2(cast_slice(tokens), salt_hash);
Self {
tokens: tokens.into(), // Converts slice to owned Tokens
salt_hash,
block_hash,
}
}
}
/// Represents a completed, immutable block of tokens with associated hashes.
///
/// Contains exactly `block_size` tokens and includes the [`SaltHash`], [`BlockHash`],
/// [`SequenceHash`], and optionally the parent's [`SequenceHash`].
#[derive(Debug, Clone, Default, PartialEq)] // Add PartialEq for tests
pub struct TokenBlock {
tokens: Tokens,
salt_hash: SaltHash,
block_hash: BlockHash,
sequence_hash: SequenceHash,
parent_sequence_hash: Option<SequenceHash>,
}
impl TokenBlock {
/// Creates a new [`PartialTokenBlock`] representing the block immediately following this one.
///
/// The new partial block will have the correct `parent_sequence_hash` set.
pub fn next_block(&self) -> PartialTokenBlock {
PartialTokenBlock {
tokens: Tokens::default(),
block_size: self.tokens.len(), // a completed block always holds exactly block_size tokens
salt_hash: self.salt_hash,
parent_sequence_hash: Some(self.sequence_hash), // Link to this block
}
}
/// Finalizes a [`TokenBlock`] from a [`TokenBlockChunk`] and the parent's sequence hash.
///
/// This computes the final [`SequenceHash`] for the block.
fn from_chunk(chunk: TokenBlockChunk, parent_sequence_hash: Option<SequenceHash>) -> Self {
let sequence_hash = match parent_sequence_hash {
Some(parent) => {
// Combine parent sequence hash and current block hash
compute_hash_v2(cast_slice(&[parent, chunk.block_hash]), chunk.salt_hash)
}
None => {
// First block: sequence hash is just the block hash
chunk.block_hash
}
};
Self {
tokens: chunk.tokens,
salt_hash: chunk.salt_hash,
block_hash: chunk.block_hash,
sequence_hash,
parent_sequence_hash,
}
}
/// Returns a reference to the tokens in this block.
pub fn tokens(&self) -> &Tokens {
&self.tokens
}
/// Returns the salt hash used for this block's hashing.
pub fn salt_hash(&self) -> SaltHash {
self.salt_hash
}
/// Returns the hash of only the tokens within this block.
pub fn block_hash(&self) -> BlockHash {
self.block_hash
}
/// Returns the sequence-aware hash for this block.
pub fn sequence_hash(&self) -> SequenceHash {
self.sequence_hash
}
/// Returns the sequence hash of the preceding block, if any.
pub fn parent_sequence_hash(&self) -> Option<SequenceHash> {
self.parent_sequence_hash
}
}
/// Represents a sequence of tokens, segmented into fixed-size, hashed blocks.
///
/// This structure manages a series of completed [`TokenBlock`]s and one
/// [`PartialTokenBlock`] for accumulating incoming tokens.
/// It provides methods for appending tokens (`append`, `extend`), removing tokens
/// (`pop`, `truncate`, `unwind`), and accessing sequence information.
///
/// Hashing incorporates an initial [`SaltHash`] to ensure uniqueness across different
/// contexts (e.g., different models, PEFTs).
///
/// Key Hashes:
/// - [`BlockHash`]: Hash of tokens within a single block (seeded by [`SaltHash`]).
/// - [`SequenceHash`]: Hash combining the previous block's [`SequenceHash`] and the current
/// block's [`BlockHash`] (also seeded by [`SaltHash`]).
#[derive(Debug)]
pub struct TokenBlockSequence {
blocks: Vec<TokenBlock>,
current_block: PartialTokenBlock,
salt_hash: SaltHash,
}
impl TokenBlockSequence {
/// Creates a new [`TokenBlockSequence`] from an initial set of tokens.
///
/// The tokens are split into blocks of `block_size`. Any remaining tokens
/// form the initial `current_block`.
///
/// # Arguments
///
/// * `tokens` - The initial [`Tokens`] for the sequence.
/// * `block_size` - The fixed size for each [`TokenBlock`]. Must be greater than 0.
/// * `salt_hash` - An optional [`SaltHash`]. Defaults to 0 if `None`.
///
/// # Panics
///
/// Panics if `block_size` is 0.
pub fn new(tokens: Tokens, block_size: usize, salt_hash: Option<SaltHash>) -> Self {
assert!(block_size > 0, "block_size must be greater than 0");
let salt_hash = salt_hash.unwrap_or(0);
let (blocks, current_block) = Self::split_tokens(&tokens, block_size, salt_hash);
Self {
blocks,
current_block,
salt_hash,
}
}
/// Extends the sequence with the given tokens, potentially completing multiple blocks.
///
/// This method processes all tokens from the input [`Tokens`] object.
/// If adding tokens causes one or more blocks to become full, they are committed
/// and added to the internal list of completed blocks. Note that commits are lazy:
/// a block that becomes exactly full with no tokens left over is not committed until
/// a later call overflows it.
///
/// # Arguments
///
/// * `tokens` - The [`Tokens`] object containing the tokens to extend the sequence with.
///
/// # Returns
///
/// * `Ok(Some(Range<usize>))` - The range of indices in the `blocks` vector corresponding
/// to the blocks completed during this `extend` operation.
/// * `Ok(None)` - If no blocks were completed.
/// * `Err(TokenBlockError)` - If an internal error occurs during commit.
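///
/// # Example
///
/// ```ignore
/// // Illustrative sketch mirroring `test_extend` Case 3 below: starting from an
/// // empty sequence with block_size 4, six tokens complete the block at index 0.
/// let completed = seq.extend(Tokens::from(vec![1, 2, 3, 4, 5, 6])).unwrap();
/// assert_eq!(completed, Some(0..1));
/// assert_eq!(seq.current_block().tokens().as_ref(), &[5, 6]);
/// ```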
pub fn extend(&mut self, tokens: Tokens) -> Result<Option<Range<usize>>, TokenBlockError> {
let start_block_index = self.blocks.len();
let mut tokens_to_append = tokens;
while !tokens_to_append.is_empty() {
let remaining_in_current = self.current_block.remaining();
if remaining_in_current == 0 {
// Current block is full, commit it first
let new_block = self.current_block.commit()?;
self.blocks.push(new_block);
// Continue loop to add tokens to the *new* current_block
}
// Push as many tokens as possible into the current (potentially new) block
let available_tokens = tokens_to_append;
tokens_to_append = self.current_block.push_tokens(available_tokens);
// Check if the current block *became* full after pushing tokens
if self.current_block.remaining() == 0 && !tokens_to_append.is_empty() {
// If it became full AND there are still more tokens to append,
// commit it now so the next loop iteration starts with a fresh block.
let new_block = self.current_block.commit()?;
self.blocks.push(new_block);
}
// If it became full and there are NO more tokens, the loop will exit,
// and the block remains partial but full, ready for the next append/commit.
}
let end_block_index = self.blocks.len();
if start_block_index == end_block_index {
Ok(None) // No blocks were completed
} else {
Ok(Some(start_block_index..end_block_index))
}
}
/// Appends a single token to the sequence.
///
/// If the current partial block is already full, it is committed first and the index
/// of the newly completed block is returned. Because commits are lazy, the token that
/// exactly fills a block returns `Ok(None)`; the completion is reported by the next append.
///
/// This method is equivalent to calling [`extend`] with a single-token [`Tokens`] object.
///
/// # Arguments
///
/// * `token` - The [`Token`] to append.
///
/// # Returns
///
/// * `Ok(Some(usize))` - The index of the block that was just completed.
/// * `Ok(None)` - No block was completed by adding this token.
/// * `Err(TokenBlockError)` - If an internal error occurs during processing.
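///
/// # Example
///
/// ```ignore
/// // Illustrative sketch of the lazy commit, mirroring `test_append_single_token`
/// // below (block_size 4, tokens 9..=11 already in the partial block):
/// assert_eq!(sequence.append(12).unwrap(), None); // block now full, not yet committed
/// assert_eq!(sequence.append(13).unwrap(), Some(2)); // overflow commits block index 2
/// ```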
pub fn append(&mut self, token: Token) -> Result<Option<usize>, TokenBlockError> {
// Create a single-token Tokens object
let tokens = Tokens::from(vec![token]);
// Call extend
let range_option = self.extend(tokens)?;
// Convert the range to Option<usize>
match range_option {
None => Ok(None),
Some(range) => {
// Since we only added one token, the range can only be empty or have one element.
// If it's not empty, it must be `n..(n+1)`.
assert_eq!(range.len(), 1, "Appending a single token completed more than one block, which should be impossible.");
Ok(Some(range.start))
}
}
}
/// Shortens the sequence, keeping the first `len` tokens and removing the rest.
///
/// If `len` is greater than the sequence's current length, this has no effect.
///
/// This operation is analogous to `Vec::truncate`.
/// It may involve removing tokens from the current partial block, removing entire
/// completed blocks, and adjusting the current partial block
/// to reflect the new end of the sequence.
///
/// # Arguments
///
/// * `len` - The number of tokens to keep.
///
/// # Returns
///
/// * `Ok(())` - If the sequence was successfully truncated.
/// * `Err(TokenBlockError::InsufficientTokens)` - This error should ideally not occur if `len`
/// is correctly checked against `total_tokens`, but the underlying `pop_tokens` might return it.
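///
/// # Example
///
/// ```ignore
/// // Illustrative sketch mirroring `test_truncate` Case 3 below: 10 tokens with
/// // block_size 4, truncated to 7, leave one full block and a partial [5, 6, 7].
/// seq.truncate(7).unwrap();
/// assert_eq!(seq.blocks().len(), 1);
/// assert_eq!(seq.current_block().tokens().as_ref(), &[5, 6, 7]);
/// ```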
pub fn truncate(&mut self, len: usize) -> Result<(), TokenBlockError> {
let current_total_len = self.total_tokens();
if len >= current_total_len {
return Ok(()); // Nothing to truncate
}
let n = current_total_len - len; // Number of tokens to remove
// This inner block handles the actual removal logic based on `n` tokens to remove.
{
let current_len = self.current_block.len();
// Avoid division by zero if block_size is somehow 0 (though asserted in new)
let block_size = self.current_block.block_size.max(1);
if n <= current_len {
// Only need to pop from the current partial block
self.current_block.pop_tokens(n)?;
} else {
// Need to pop from full blocks as well
let tokens_to_pop_from_blocks = n - current_len;
// Calculate how many blocks are affected (including the one partially popped)
let num_blocks_to_affect = tokens_to_pop_from_blocks.div_ceil(block_size);
// Check if we need to pop more blocks than available (should be prevented by initial len check)
if num_blocks_to_affect > self.blocks.len() {
// This indicates an inconsistency between total_tokens() and internal state.
debug_assert!(
false,
"Truncate calculation error: trying to pop too many blocks."
);
return Err(TokenBlockError::InsufficientTokens);
}
// Determine the index of the block that will be the source for the new partial block
let source_block_index = self.blocks.len() - num_blocks_to_affect;
// Calculate how many tokens to keep from that source block
let num_full_blocks_completely_popped = num_blocks_to_affect - 1;
let num_tokens_to_pop_from_source_block =
tokens_to_pop_from_blocks - num_full_blocks_completely_popped * block_size;
let num_tokens_to_keep_in_new_partial =
block_size.saturating_sub(num_tokens_to_pop_from_source_block);
// Get the tokens for the new partial block
let new_partial_tokens = if num_tokens_to_keep_in_new_partial > 0 {
self.blocks[source_block_index].tokens().as_ref()
[..num_tokens_to_keep_in_new_partial]
.to_vec()
} else {
Vec::new()
};
// Truncate the blocks vector to remove popped blocks
self.blocks.truncate(source_block_index);
// Update the current_block state
self.current_block.tokens = Tokens(new_partial_tokens);
// Correctly set the parent hash based on the *new* last block
self.current_block.parent_sequence_hash =
self.blocks.last().map(|b| b.sequence_hash());
// salt_hash and block_size remain the same for current_block
}
}
Ok(())
}
/// Removes the last `count` tokens from the sequence.
///
/// This is a convenience method that calculates the required length and calls [`truncate`].
///
/// # Arguments
///
/// * `count` - The number of tokens to remove from the end.
///
/// # Returns
///
/// * `Ok(())` - If the tokens were successfully removed.
/// * `Err(TokenBlockError::InsufficientTokens)` - If `count` is greater than the total
/// number of tokens in the sequence. (`count == total_tokens()` is allowed and empties
/// the sequence.)
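///
/// # Example
///
/// ```ignore
/// // Illustrative sketch mirroring `test_unwind` below: removing 3 of 10 tokens
/// // (block_size 4) crosses a block boundary.
/// seq.unwind(3).unwrap();
/// assert_eq!(seq.total_tokens(), 7);
/// ```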
pub fn unwind(&mut self, count: usize) -> Result<(), TokenBlockError> {
let current_total_len = self.total_tokens();
if count > current_total_len {
// Allow count == current_total_len, which truncates to 0.
return Err(TokenBlockError::InsufficientTokens);
}
// number of tokens remaining in the sequence after undoing the given count
let len = current_total_len - count;
self.truncate(len)
}
/// Removes the last token from the sequence and returns it, or [`None`] if it is empty.
///
/// This operation is analogous to `Vec::pop`.
///
/// # Returns
///
/// * `Some(Token)` - The last token, if the sequence was not empty.
/// * `None` - If the sequence was empty.
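///
/// # Example
///
/// ```ignore
/// // Illustrative sketch: popping across a block boundary, as in `test_pop` below.
/// let mut seq = Tokens::from(vec![1, 2, 3, 4, 5]).into_sequence(4, None);
/// assert_eq!(seq.pop(), Some(5));
/// assert_eq!(seq.pop(), Some(4)); // unwinds the completed block [1, 2, 3, 4]
/// assert_eq!(seq.current_block().tokens().as_ref(), &[1, 2, 3]);
/// ```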
pub fn pop(&mut self) -> Option<Token> {
let current_total_len = self.total_tokens();
if current_total_len == 0 {
return None;
}
// Determine the last token. It must be in the current_block if current_block is not empty.
// If current_block is empty, it must be the last token of the last full block.
let last_token = if !self.current_block.tokens.is_empty() {
// Last token is in the partial block
*self
.current_block
.tokens
.last()
.expect("Current block checked for non-empty")
} else {
// Current block is empty, sequence is not. Must be in the last full block.
let last_block = self
.blocks
.last()
.expect("Sequence is not empty but has no blocks and empty current block?");
*last_block
.tokens()
.last()
.expect("Last block cannot be empty")
};
// Truncate the sequence by one element.
// We expect this to succeed since we know the length > 0.
match self.truncate(current_total_len - 1) {
Ok(_) => Some(last_token),
Err(_) => {
// This should be logically impossible if total_tokens() and truncate() are correct.
// Panic in debug, return None in release as a fallback, though it indicates a bug.
debug_assert!(
false,
"truncate failed unexpectedly after checking length in pop"
);
None
}
}
}
/// Returns a slice containing all the completed [`TokenBlock`]s in the sequence.
pub fn blocks(&self) -> &[TokenBlock] {
&self.blocks
}
/// Returns a reference to the last completed [`TokenBlock`] in the sequence, if any.
pub fn last_complete_block(&self) -> Option<&TokenBlock> {
self.blocks.last()
}
/// Returns a reference to the current [`PartialTokenBlock`] where new tokens are added.
pub fn current_block(&self) -> &PartialTokenBlock {
&self.current_block
}
/// Consumes the sequence and returns its parts: a `Vec` of completed blocks and the final partial block.
pub fn into_parts(self) -> (Vec<TokenBlock>, PartialTokenBlock) {
(self.blocks, self.current_block)
}
/// Returns the [`SaltHash`] used for this sequence.
pub fn salt_hash(&self) -> SaltHash {
self.salt_hash
}
/// Returns the total number of tokens in the sequence (sum of tokens in all completed blocks
/// plus tokens in the current partial block).
pub fn total_tokens(&self) -> usize {
let block_size = self.current_block.block_size;
(self.blocks.len() * block_size) + self.current_block.len()
}
/// Splits a [`Tokens`] object into a vector of completed blocks and a final partial block.
///
/// This is primarily used internally by [`TokenBlockSequence::new`] but can be used externally.
///
/// # Arguments
///
/// * `tokens` - The slice of tokens to split.
/// * `block_size` - The size of each block.
/// * `salt_hash` - The [`SaltHash`] to use for hashing.
///
/// # Returns
///
/// A tuple containing `(Vec<TokenBlock>, PartialTokenBlock)`.
///
/// # Panics
///
/// Panics if `block_size` is 0.
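///
/// # Example
///
/// ```ignore
/// // Illustrative sketch: five tokens with block_size 4 yield one completed block
/// // and a partial block holding the remainder [5].
/// let (blocks, partial) = TokenBlockSequence::split_tokens(&[1, 2, 3, 4, 5], 4, 0);
/// assert_eq!(blocks.len(), 1);
/// assert_eq!(partial.tokens().as_ref(), &[5]);
/// ```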
pub fn split_tokens(
tokens: &[Token],
block_size: usize,
salt_hash: u64,
) -> (Vec<TokenBlock>, PartialTokenBlock) {
assert!(block_size > 0, "block_size must be greater than 0");
// Use Rayon for parallel computation of block chunks (hashes)
let chunks: Vec<TokenBlockChunk> = tokens
.as_ref()
.par_chunks_exact(block_size)
.map(|chunk| TokenBlockChunk::from_tokens(chunk, salt_hash))
.collect();
let mut result_blocks = Vec::with_capacity(chunks.len());
let mut last_sequence_hash: Option<SequenceHash> = None;
// Sequentially combine chunks to compute sequence hashes
for chunk in chunks {
let new_block = TokenBlock::from_chunk(chunk, last_sequence_hash);
last_sequence_hash = Some(new_block.sequence_hash());
result_blocks.push(new_block);
}
// Handle any remaining tokens
let remainder = tokens.as_ref().chunks_exact(block_size).remainder();
let current_block = PartialTokenBlock {
tokens: remainder.into(),
block_size,
salt_hash,
// Parent hash is the sequence hash of the last *full* block computed
parent_sequence_hash: last_sequence_hash,
};
(result_blocks, current_block)
}
}
impl PartialEq<Vec<Token>> for Tokens {
fn eq(&self, other: &Vec<Token>) -> bool {
self.0 == *other
}
}
impl PartialEq<Tokens> for Vec<Token> {
fn eq(&self, other: &Tokens) -> bool {
*self == other.0
}
}
impl PartialEq<[Token]> for Tokens {
fn eq(&self, other: &[Token]) -> bool {
self.0.as_slice() == other
}
}
impl PartialEq<Tokens> for &[Token] {
fn eq(&self, other: &Tokens) -> bool {
*self == other.0.as_slice()
}
}
impl PartialEq<Vec<Token>> for &Tokens {
fn eq(&self, other: &Vec<Token>) -> bool {
self.0 == *other
}
}
impl<'a> PartialEq<&'a Tokens> for Vec<Token> {
fn eq(&self, other: &&'a Tokens) -> bool {
*self == other.0
}
}
impl PartialEq<[Token]> for &Tokens {
fn eq(&self, other: &[Token]) -> bool {
self.0.as_slice() == other
}
}
impl<'a> PartialEq<&'a [Token]> for Tokens {
fn eq(&self, other: &&'a [Token]) -> bool {
self.0.as_slice() == *other
}
}
impl PartialEq for Tokens {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl Eq for Tokens {}
#[cfg(test)]
mod tests {
use super::*;
use bytemuck::cast_slice;
// Helper to create a sequence for testing
fn create_test_sequence(
initial_tokens: &[Token],
block_size: usize,
salt_hash: Option<SaltHash>,
) -> TokenBlockSequence {
TokenBlockSequence::new(Tokens::from(initial_tokens), block_size, salt_hash)
}
// Helper to get expected hashes (replace with actual calculated values if needed)
const TEST_SALT_HASH: SaltHash = 1337;
const HASH_1_4: BlockHash = 14643705804678351452; // hash([1,2,3,4], 1337)
const SEQ_HASH_1_4: SequenceHash = HASH_1_4;
const HASH_5_8: BlockHash = 16777012769546811212; // hash([5,6,7,8], 1337)
const SEQ_HASH_5_8: SequenceHash = 4945711292740353085; // hash([SEQ_HASH_1_4, HASH_5_8], 1337)
const HASH_9_12: BlockHash = 483935686894639516; // hash([9,10,11,12], 1337)
const SEQ_HASH_9_12: SequenceHash = 12583592247330656132; // hash([SEQ_HASH_5_8, HASH_9_12], 1337)
#[test]
fn test_validate_hash_constants() {
let salt = TEST_SALT_HASH;
// Block 1: [1, 2, 3, 4]
let tokens_1_4 = &[1u32, 2, 3, 4];
let computed_hash_1_4 = compute_hash_v2(cast_slice(tokens_1_4), salt);
assert_eq!(computed_hash_1_4, HASH_1_4, "Mismatch for HASH_1_4");
// First block's sequence hash is its block hash
assert_eq!(computed_hash_1_4, SEQ_HASH_1_4, "Mismatch for SEQ_HASH_1_4");
// Block 2: [5, 6, 7, 8]
let tokens_5_8 = &[5u32, 6, 7, 8];
let computed_hash_5_8 = compute_hash_v2(cast_slice(tokens_5_8), salt);
assert_eq!(computed_hash_5_8, HASH_5_8, "Mismatch for HASH_5_8");
let computed_seq_hash_5_8 = compute_hash_v2(cast_slice(&[SEQ_HASH_1_4, HASH_5_8]), salt);
assert_eq!(
computed_seq_hash_5_8, SEQ_HASH_5_8,
"Mismatch for SEQ_HASH_5_8"
);
// Block 3: [9, 10, 11, 12]
let tokens_9_12 = &[9u32, 10, 11, 12];
let computed_hash_9_12 = compute_hash_v2(cast_slice(tokens_9_12), salt);
assert_eq!(computed_hash_9_12, HASH_9_12, "Mismatch for HASH_9_12");
let computed_seq_hash_9_12 = compute_hash_v2(cast_slice(&[SEQ_HASH_5_8, HASH_9_12]), salt);
assert_eq!(
computed_seq_hash_9_12, SEQ_HASH_9_12,
"Mismatch for SEQ_HASH_9_12"
);
}
#[test]
fn test_tokens_from() {
let vec_u32: Vec<u32> = vec![1, 2, 3];
let tokens_u32: Tokens = vec_u32.clone().into();
assert_eq!(tokens_u32.0, vec_u32);
let slice_u32: &[u32] = &[4, 5];
let tokens_slice_u32: Tokens = slice_u32.into();
assert_eq!(tokens_slice_u32.0, vec![4, 5]);
let vec_i32: Vec<i32> = vec![-1, 0, 1]; // Note: -1 becomes large u32
let tokens_i32: Tokens = vec_i32.into();
assert_eq!(tokens_i32.0, vec![u32::MAX, 0, 1]);
let slice_i32: &[i32] = &[100, 200];
let tokens_slice_i32: Tokens = slice_i32.into();
assert_eq!(tokens_slice_i32.0, vec![100, 200]);
let into_vec: Vec<u32> = tokens_slice_i32.into();
assert_eq!(into_vec, vec![100, 200]);
}
#[test]
fn test_tokens_equality() {
let tokens = Tokens::from(vec![1, 2, 3]);
assert_eq!(tokens, vec![1, 2, 3]);
assert_eq!(vec![1, 2, 3], tokens);
assert_eq!(tokens, &[1, 2, 3][..]);
assert_eq!(&[1, 2, 3][..], tokens);
assert_eq!(tokens, Tokens::from(vec![1, 2, 3]));
assert_ne!(tokens, Tokens::from(vec![1, 2, 4]));
}
#[test]
fn test_tokens_deref_asref() {
let tokens = Tokens::from(vec![10, 20, 30]);
// Deref to &[Token]
assert_eq!(tokens.len(), 3);
assert_eq!(tokens[1], 20);
let slice: &[Token] = &tokens;
assert_eq!(slice, &[10, 20, 30]);
// AsRef<[Token]>
let as_ref_slice: &[Token] = tokens.as_ref();
assert_eq!(as_ref_slice, &[10, 20, 30]);
// Borrow<[Token]>
let borrowed_slice: &[Token] = std::borrow::Borrow::borrow(&tokens);
assert_eq!(borrowed_slice, &[10, 20, 30]);
}
#[test]
fn test_tokens_into_sequence() {
let tokens = Tokens::from(vec![1, 2, 3, 4, 5]);
let seq = tokens.into_sequence(3, Some(TEST_SALT_HASH));
assert_eq!(seq.blocks().len(), 1);
assert_eq!(seq.blocks[0].tokens().as_ref(), &[1, 2, 3]);
assert_eq!(seq.current_block().tokens().as_ref(), &[4, 5]);
assert_eq!(seq.salt_hash(), TEST_SALT_HASH);
}
#[test]
fn test_partial_block_ops() {
let mut partial = PartialTokenBlock::create_sequence_root(3, TEST_SALT_HASH);
assert_eq!(partial.len(), 0);
assert_eq!(partial.remaining(), 3);
assert!(partial.is_empty());
// Push tokens
assert!(partial.push_token(1).is_ok());
assert_eq!(partial.len(), 1);
assert_eq!(partial.remaining(), 2);
let remaining = partial.push_tokens(Tokens::from(vec![2, 3, 4]));
assert_eq!(partial.len(), 3);
assert_eq!(partial.remaining(), 0);
assert_eq!(remaining.as_ref(), &[4]); // Token 4 didn't fit
assert_eq!(partial.tokens().as_ref(), &[1, 2, 3]);
// Push when full
assert_eq!(partial.push_token(5), Err(TokenBlockError::Full));
let remaining_full = partial.push_tokens(Tokens::from(vec![5]));
assert_eq!(remaining_full.as_ref(), &[5]);
// Pop tokens
assert!(partial.pop_token().is_ok());
assert_eq!(partial.len(), 2);
assert_eq!(partial.tokens().as_ref(), &[1, 2]);
assert!(partial.pop_tokens(2).is_ok());
assert!(partial.is_empty());
// Pop when empty
assert_eq!(partial.pop_token(), Err(TokenBlockError::Empty));
assert_eq!(
partial.pop_tokens(1),
Err(TokenBlockError::InsufficientTokens)
);
// Commit incomplete
assert!(partial.push_token(10).is_ok());
assert_eq!(partial.commit(), Err(TokenBlockError::Incomplete));
// Commit complete
assert!(partial.push_token(11).is_ok());
assert!(partial.push_token(12).is_ok());
assert_eq!(partial.len(), 3);
let commit_result = partial.commit();
assert!(commit_result.is_ok());
let committed_block = commit_result.unwrap();
assert_eq!(committed_block.tokens().as_ref(), &[10, 11, 12]);
// Check state after commit (partial block is now the next one)
assert!(partial.is_empty());
assert_eq!(
partial.parent_sequence_hash,
Some(committed_block.sequence_hash())
);
assert_eq!(partial.block_size, 3);
}
#[test]
fn test_token_block_creation_and_hashes() {
let salt = TEST_SALT_HASH;
let tokens1 = Tokens::from(vec![1, 2, 3, 4]);
let chunk1 = TokenBlockChunk::new(tokens1.clone(), salt);
let block1 = TokenBlock::from_chunk(chunk1, None);
assert_eq!(block1.tokens(), &tokens1);
assert_eq!(block1.salt_hash(), salt);
assert_eq!(block1.parent_sequence_hash(), None);
assert_eq!(block1.block_hash(), HASH_1_4);
assert_eq!(block1.sequence_hash(), SEQ_HASH_1_4); // First block seq_hash == block_hash
let tokens2 = Tokens::from(vec![5, 6, 7, 8]);
let chunk2 = TokenBlockChunk::new(tokens2.clone(), salt);
let block2 = TokenBlock::from_chunk(chunk2, block1.parent_sequence_hash()); // Incorrect parent
// Sequence hash should differ if parent is wrong
assert_ne!(block2.sequence_hash(), SEQ_HASH_5_8);
let chunk2_correct = TokenBlockChunk::new(tokens2.clone(), salt);
let block2_correct = TokenBlock::from_chunk(chunk2_correct, Some(block1.sequence_hash()));
assert_eq!(block2_correct.tokens(), &tokens2);
assert_eq!(block2_correct.salt_hash(), salt);
assert_eq!(
block2_correct.parent_sequence_hash(),
Some(block1.sequence_hash())
);
assert_eq!(block2_correct.block_hash(), HASH_5_8);
assert_eq!(block2_correct.sequence_hash(), SEQ_HASH_5_8);
}
#[test]
fn test_tokens_conversions() {
// Test From<Vec<Token>> for Tokens
let vec = vec![1, 2, 3, 4, 5];
let tokens: Tokens = vec.clone().into();
assert_eq!(tokens.0, vec);
fn test_new_sequence() {
// Empty initial tokens
let seq_empty = create_test_sequence(&[], 4, Some(TEST_SALT_HASH));
assert!(seq_empty.blocks().is_empty());
assert!(seq_empty.current_block().is_empty());
assert_eq!(seq_empty.total_tokens(), 0);
assert_eq!(seq_empty.salt_hash(), TEST_SALT_HASH);
assert_eq!(seq_empty.current_block().parent_sequence_hash, None);
// Less than one block
let seq_partial = create_test_sequence(&[1, 2], 4, Some(TEST_SALT_HASH));
assert!(seq_partial.blocks().is_empty());
assert_eq!(seq_partial.current_block().tokens().as_ref(), &[1, 2]);
assert_eq!(seq_partial.total_tokens(), 2);
assert_eq!(seq_partial.current_block().parent_sequence_hash, None);
// Exactly one block
let seq_one_block = create_test_sequence(&[1, 2, 3, 4], 4, Some(TEST_SALT_HASH));
assert_eq!(seq_one_block.blocks().len(), 1);
assert!(seq_one_block.current_block().is_empty());
assert_eq!(seq_one_block.total_tokens(), 4);
assert_eq!(seq_one_block.blocks[0].tokens().as_ref(), &[1, 2, 3, 4]);
assert_eq!(seq_one_block.blocks[0].sequence_hash(), SEQ_HASH_1_4);
assert_eq!(
seq_one_block.current_block().parent_sequence_hash,
Some(SEQ_HASH_1_4)
);
// More than one block
let seq_multi = create_test_sequence(&[1, 2, 3, 4, 5, 6, 7, 8, 9], 4, Some(TEST_SALT_HASH));
assert_eq!(seq_multi.blocks().len(), 2);
assert_eq!(seq_multi.current_block().tokens().as_ref(), &[9]);
assert_eq!(seq_multi.total_tokens(), 9);
assert_eq!(seq_multi.blocks[0].sequence_hash(), SEQ_HASH_1_4);
assert_eq!(seq_multi.blocks[1].sequence_hash(), SEQ_HASH_5_8);
assert_eq!(
seq_multi.current_block().parent_sequence_hash,
Some(SEQ_HASH_5_8)
);
// No salt hash
let seq_no_salt = create_test_sequence(&[1, 2, 3, 4, 5], 4, None);
assert_eq!(seq_no_salt.salt_hash(), 0);
assert_eq!(seq_no_salt.blocks().len(), 1);
assert_ne!(seq_no_salt.blocks[0].block_hash(), HASH_1_4); // Hash differs with salt 0
assert_eq!(seq_no_salt.current_block().tokens().as_ref(), &[5]);
}
#[test]
#[should_panic]
fn test_new_sequence_zero_block_size() {
let _ = create_test_sequence(&[1], 0, None);
}
#[test]
fn test_append_single_token() {
let mut sequence =
create_test_sequence(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 4, Some(TEST_SALT_HASH));
assert_eq!(sequence.blocks().len(), 2);
assert_eq!(sequence.current_block().tokens.len(), 2);
assert_eq!(sequence.current_block().tokens, vec![9, 10]);
assert_eq!(
sequence.current_block().parent_sequence_hash,
Some(SEQ_HASH_5_8)
);
// Append token 11 - should not complete a block
let completed_idx = sequence.append(11).unwrap();
assert_eq!(completed_idx, None);
assert_eq!(sequence.blocks().len(), 2);
assert_eq!(sequence.current_block().tokens.as_ref(), &[9, 10, 11]);
// Append token 12 - fills the current block, but the commit is lazy
let completed_idx = sequence.append(12).unwrap();
assert_eq!(completed_idx, None); // Lazy commit: extend returns None
assert_eq!(sequence.blocks().len(), 2); // Block 2 not added yet
assert_eq!(sequence.current_block.tokens.as_ref(), &[9, 10, 11, 12]); // Current block is now full
assert_eq!(sequence.current_block.remaining(), 0);
assert_eq!(
sequence.current_block().parent_sequence_hash,
Some(SEQ_HASH_5_8)
); // Still linked to block 1
// NOW appending 13 should first commit block 2, then add 13 to the new current
let completed_idx_13 = sequence.append(13).unwrap();
assert_eq!(completed_idx_13, Some(2)); // Block 2 (index 2) was completed by this append
assert_eq!(sequence.blocks.len(), 3); // Now 3 blocks committed
assert_eq!(sequence.blocks[2].tokens().as_ref(), &[9, 10, 11, 12]); // Verify committed block 2
assert_eq!(sequence.blocks[2].sequence_hash(), SEQ_HASH_9_12);
assert_eq!(sequence.current_block.tokens.as_ref(), &[13]); // New current block has 13
assert_eq!(sequence.current_block.remaining(), 3);
assert_eq!(
sequence.current_block.parent_sequence_hash,
Some(SEQ_HASH_9_12)
); // Linked to new block 2
}
#[test]
fn test_extend() {
let block_size = 4;
let salt_hash = Some(TEST_SALT_HASH);
// Case 1: Extend less than block size
let mut seq1 = create_test_sequence(&[], block_size, salt_hash);
let tokens1 = Tokens::from(vec![1, 2]);
let completed1 = seq1.extend(tokens1).unwrap();
assert_eq!(completed1, None); // No blocks completed
assert_eq!(seq1.blocks.len(), 0);
assert_eq!(seq1.current_block.tokens.as_ref(), &[1, 2]);
assert_eq!(seq1.current_block.remaining(), 2);
// Case 2: Extend exactly block size
let mut seq2 = create_test_sequence(&[], block_size, salt_hash);
let tokens2 = Tokens::from(vec![1, 2, 3, 4]);
let completed2 = seq2.extend(tokens2).unwrap();
assert_eq!(completed2, None); // Block is full but not committed yet
assert_eq!(seq2.blocks.len(), 0); // No blocks committed
assert_eq!(seq2.current_block.tokens.as_ref(), &[1, 2, 3, 4]); // Current block is full
assert_eq!(seq2.current_block.remaining(), 0);
assert_eq!(seq2.current_block.parent_sequence_hash, None); // Still the root block
// Case 3: Extend more than block size, less than two blocks
let mut seq3 = create_test_sequence(&[], block_size, salt_hash);
let tokens3 = Tokens::from(vec![1, 2, 3, 4, 5, 6]);
let completed3 = seq3.extend(tokens3).unwrap();
assert_eq!(completed3, Some(0..1)); // Block at index 0 completed
assert_eq!(seq3.blocks.len(), 1);
assert_eq!(seq3.current_block.tokens.as_ref(), &[5, 6]); // Partial block has remainder
assert_eq!(seq3.blocks[0].tokens().as_ref(), &[1, 2, 3, 4]);
assert_eq!(seq3.current_block.parent_sequence_hash, Some(SEQ_HASH_1_4));
assert_eq!(seq3.current_block.remaining(), 2);
// Case 4: Extend exactly two blocks
let mut seq4 = create_test_sequence(&[], block_size, salt_hash);
let tokens4 = Tokens::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
let completed4 = seq4.extend(tokens4).unwrap();
assert_eq!(completed4, Some(0..1)); // Only block 0 is committed
assert_eq!(seq4.blocks.len(), 1); // Only 1 block committed
assert_eq!(seq4.current_block.tokens.as_ref(), &[5, 6, 7, 8]); // Current block holds the second block's tokens
assert_eq!(seq4.current_block.remaining(), 0); // Current block is full
assert_eq!(seq4.blocks[0].tokens().as_ref(), &[1, 2, 3, 4]);
assert_eq!(seq4.blocks[0].sequence_hash(), SEQ_HASH_1_4);
assert_eq!(seq4.current_block.parent_sequence_hash, Some(SEQ_HASH_1_4)); // Parent is the first block
// Case 5: Extend multiple times, completing blocks across calls
let mut seq5 = create_test_sequence(&[], block_size, salt_hash);
let tokens5a = Tokens::from(vec![1, 2]);
let completed5a = seq5.extend(tokens5a).unwrap();
assert_eq!(completed5a, None);
assert_eq!(seq5.blocks.len(), 0);
assert_eq!(seq5.current_block.tokens.as_ref(), &[1, 2]);
let tokens5b = Tokens::from(vec![3, 4, 5]);
let completed5b = seq5.extend(tokens5b).unwrap();
assert_eq!(completed5b, Some(0..1)); // Block at index 0 completed
assert_eq!(seq5.blocks.len(), 1);
assert_eq!(seq5.current_block.tokens.as_ref(), &[5]);
assert_eq!(seq5.blocks[0].tokens().as_ref(), &[1, 2, 3, 4]);
assert_eq!(seq5.current_block.parent_sequence_hash, Some(SEQ_HASH_1_4));
assert_eq!(seq5.current_block.remaining(), 3);
let tokens5c = Tokens::from(vec![6, 7, 8, 9, 10]);
let completed5c = seq5.extend(tokens5c).unwrap();
assert_eq!(completed5c, Some(1..2)); // Block at index 1 completed
assert_eq!(seq5.blocks.len(), 2);
assert_eq!(seq5.current_block.tokens.as_ref(), &[9, 10]);
assert_eq!(seq5.blocks[1].tokens().as_ref(), &[5, 6, 7, 8]);
assert_eq!(seq5.current_block.parent_sequence_hash, Some(SEQ_HASH_5_8));
assert_eq!(seq5.current_block.remaining(), 2);
// Case 6: Extend empty tokens
let mut seq6 = create_test_sequence(&[1], block_size, salt_hash);
let completed6 = seq6.extend(Tokens::default()).unwrap();
assert_eq!(completed6, None);
assert_eq!(seq6.blocks.len(), 0);
assert_eq!(seq6.current_block.tokens.as_ref(), &[1]);
assert_eq!(seq6.total_tokens(), 1);
// Case 7: Extend fills current exactly, no remainder
let mut seq7 = create_test_sequence(&[1, 2], block_size, salt_hash);
let tokens7 = Tokens::from(vec![3, 4]);
let completed7 = seq7.extend(tokens7).unwrap();
assert_eq!(completed7, None); // Block is full but not committed yet
assert_eq!(seq7.blocks.len(), 0);
assert_eq!(seq7.current_block.tokens.as_ref(), &[1, 2, 3, 4]); // Current block is full
assert_eq!(seq7.current_block.remaining(), 0);
assert_eq!(seq7.total_tokens(), 4);
assert_eq!(seq7.current_block.parent_sequence_hash, None); // Still the root block
}
#[test]
fn test_truncate() {
let block_size = 4;
let salt_hash = Some(TEST_SALT_HASH);
let initial_tokens = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; // 10 tokens
// Case 1: Truncate within current block (len 9)
let mut seq1 = create_test_sequence(initial_tokens, block_size, salt_hash);
assert!(seq1.truncate(9).is_ok());
assert_eq!(seq1.total_tokens(), 9);
assert_eq!(seq1.blocks().len(), 2);
assert_eq!(seq1.current_block().tokens.as_ref(), &[9]);
assert_eq!(
seq1.current_block().parent_sequence_hash,
Some(SEQ_HASH_5_8)
);
// Case 2: Truncate to exact block boundary (len 8)
let mut seq2 = create_test_sequence(initial_tokens, block_size, salt_hash);
assert!(seq2.truncate(8).is_ok());
assert_eq!(seq2.total_tokens(), 8);
assert_eq!(seq2.blocks().len(), 2);
assert!(seq2.current_block().tokens.is_empty());
assert_eq!(
seq2.current_block().parent_sequence_hash,
Some(SEQ_HASH_5_8)
);
// Case 3: Truncate into last full block (len 7)
let mut seq3 = create_test_sequence(initial_tokens, block_size, salt_hash);
assert!(seq3.truncate(7).is_ok());
assert_eq!(seq3.total_tokens(), 7);
assert_eq!(seq3.blocks().len(), 1); // Block [5,6,7,8] removed conceptually
assert_eq!(seq3.current_block().tokens.as_ref(), &[5, 6, 7]); // Kept 3 from [5,6,7,8]
assert_eq!(
seq3.current_block().parent_sequence_hash,
Some(SEQ_HASH_1_4)
); // Parent is hash of [1,2,3,4]
assert_eq!(seq3.blocks()[0].tokens().as_ref(), &[1, 2, 3, 4]);
// Case 4: Truncate removing full block(s) exactly (len 4)
let mut seq4 = create_test_sequence(initial_tokens, block_size, salt_hash);
assert!(seq4.truncate(4).is_ok());
assert_eq!(seq4.total_tokens(), 4);
assert_eq!(seq4.blocks().len(), 1); // Block [5,6,7,8] removed
assert!(seq4.current_block().tokens.is_empty()); // New partial based on block [1,2,3,4]
assert_eq!(
seq4.current_block().parent_sequence_hash,
Some(SEQ_HASH_1_4)
);
assert_eq!(seq4.blocks()[0].tokens().as_ref(), &[1, 2, 3, 4]);
// Case 5: Truncate into first block (len 3)
let mut seq5 = create_test_sequence(initial_tokens, block_size, salt_hash);
assert!(seq5.truncate(3).is_ok());
assert_eq!(seq5.total_tokens(), 3);
assert!(seq5.blocks().is_empty()); // Both blocks removed conceptually
assert_eq!(seq5.current_block().tokens.as_ref(), &[1, 2, 3]); // Kept 3 from [1,2,3,4]
assert_eq!(seq5.current_block().parent_sequence_hash, None); // No parent
// Case 6: Truncate to zero length (len 0)
let mut seq6 = create_test_sequence(initial_tokens, block_size, salt_hash);
assert!(seq6.truncate(0).is_ok());
assert_eq!(seq6.total_tokens(), 0);
assert!(seq6.blocks().is_empty());
assert!(seq6.current_block().tokens.is_empty());
assert_eq!(seq6.current_block().parent_sequence_hash, None);
// Case 7: Truncate to length greater than current (len 11)
let mut seq7 = create_test_sequence(initial_tokens, block_size, salt_hash);
let original_state = (seq7.blocks.clone(), seq7.current_block.tokens.clone()); // Clone for state check
assert!(seq7.truncate(11).is_ok()); // Should have no effect
assert_eq!(seq7.total_tokens(), 10);
assert_eq!(seq7.blocks, original_state.0);
assert_eq!(seq7.current_block.tokens, original_state.1);
// Case 8: Truncate to current length (len 10)
let mut seq8 = create_test_sequence(initial_tokens, block_size, salt_hash);
let original_state = (seq8.blocks.clone(), seq8.current_block.tokens.clone());
assert!(seq8.truncate(10).is_ok());
assert_eq!(seq8.total_tokens(), 10);
assert_eq!(seq8.blocks, original_state.0);
assert_eq!(seq8.current_block.tokens, original_state.1);
// Case 9: Truncate an empty sequence to 0
let mut seq9 = create_test_sequence(&[], block_size, salt_hash);
assert!(seq9.truncate(0).is_ok());
assert_eq!(seq9.total_tokens(), 0);
assert!(seq9.blocks().is_empty());
assert!(seq9.current_block().tokens.is_empty());
// Case 10: Truncate on exact block boundary when current is empty (len 4)
let tokens10 = &[1, 2, 3, 4, 5, 6, 7, 8]; // 8 tokens
let mut seq10 = create_test_sequence(tokens10, block_size, salt_hash);
assert_eq!(seq10.total_tokens(), 8);
assert!(seq10.current_block().is_empty());
assert!(seq10.truncate(4).is_ok()); // Remove block [5, 6, 7, 8]
assert_eq!(seq10.total_tokens(), 4);
assert_eq!(seq10.blocks().len(), 1);
assert!(seq10.current_block().tokens.is_empty());
assert_eq!(
seq10.current_block().parent_sequence_hash,
Some(SEQ_HASH_1_4)
);
// Case 11: Truncate into first block when current is empty (len 3)
let tokens11 = &[1, 2, 3, 4, 5, 6, 7, 8]; // 8 tokens
let mut seq11 = create_test_sequence(tokens11, block_size, salt_hash);
assert!(seq11.truncate(3).is_ok()); // Pop block [5,6,7,8] + 1 from [1,2,3,4]
assert_eq!(seq11.total_tokens(), 3);
assert!(seq11.blocks().is_empty());
assert_eq!(seq11.current_block().tokens.as_ref(), &[1, 2, 3]); // Kept 3 from [1,2,3,4]
assert_eq!(seq11.current_block().parent_sequence_hash, None);
}
#[test]
fn test_unwind() {
let block_size = 4;
let salt_hash = Some(TEST_SALT_HASH);
let initial_tokens = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; // 10 tokens
// Unwind 0
let mut seq = create_test_sequence(initial_tokens, block_size, salt_hash);
assert!(seq.unwind(0).is_ok());
assert_eq!(seq.total_tokens(), 10);
// Unwind 1
let mut seq = create_test_sequence(initial_tokens, block_size, salt_hash);
assert!(seq.unwind(1).is_ok());
assert_eq!(seq.total_tokens(), 9);
assert_eq!(seq.current_block.tokens.as_ref(), &[9]);
// Unwind 3 (crosses boundary)
let mut seq = create_test_sequence(initial_tokens, block_size, salt_hash);
assert!(seq.unwind(3).is_ok());
assert_eq!(seq.total_tokens(), 7);
assert_eq!(seq.blocks.len(), 1);
assert_eq!(seq.current_block.tokens.as_ref(), &[5, 6, 7]);
// Unwind all (10)
let mut seq = create_test_sequence(initial_tokens, block_size, salt_hash);
assert!(seq.unwind(10).is_ok());
assert_eq!(seq.total_tokens(), 0);
assert!(seq.blocks.is_empty());
assert!(seq.current_block.is_empty());
// Unwind more than available (11)
let mut seq = create_test_sequence(initial_tokens, block_size, salt_hash);
assert_eq!(seq.unwind(11), Err(TokenBlockError::InsufficientTokens));
assert_eq!(seq.total_tokens(), 10); // State unchanged
// Unwind from empty
let mut seq_empty = create_test_sequence(&[], block_size, salt_hash);
assert_eq!(
seq_empty.unwind(1),
Err(TokenBlockError::InsufficientTokens)
);
}
#[test]
fn test_pop() {
let block_size = 4;
let salt_hash = Some(TEST_SALT_HASH);
let initial_tokens = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; // 10 tokens
let mut seq = create_test_sequence(initial_tokens, block_size, salt_hash);
// Pop 10
assert_eq!(seq.pop(), Some(10));
assert_eq!(seq.total_tokens(), 9);
assert_eq!(seq.current_block.tokens.as_ref(), &[9]);
assert_eq!(seq.blocks.len(), 2);
// Pop 9
assert_eq!(seq.pop(), Some(9));
assert_eq!(seq.total_tokens(), 8);
assert!(seq.current_block.is_empty());
assert_eq!(seq.blocks.len(), 2);
assert_eq!(seq.current_block.parent_sequence_hash, Some(SEQ_HASH_5_8));
// Pop 8 (crosses boundary)
assert_eq!(seq.pop(), Some(8));
assert_eq!(seq.total_tokens(), 7);
assert_eq!(seq.current_block.tokens.as_ref(), &[5, 6, 7]);
assert_eq!(seq.blocks.len(), 1);
assert_eq!(seq.current_block.parent_sequence_hash, Some(SEQ_HASH_1_4));
// Pop remaining partial (7, 6, 5)
assert_eq!(seq.pop(), Some(7));
assert_eq!(seq.pop(), Some(6));
assert_eq!(seq.pop(), Some(5));
assert_eq!(seq.total_tokens(), 4);
assert!(seq.current_block.is_empty());
assert_eq!(seq.blocks.len(), 1);
assert_eq!(seq.current_block.parent_sequence_hash, Some(SEQ_HASH_1_4));
// Pop 4 (crosses boundary)
assert_eq!(seq.pop(), Some(4));
assert_eq!(seq.total_tokens(), 3);
assert_eq!(seq.current_block.tokens.as_ref(), &[1, 2, 3]);
assert!(seq.blocks.is_empty());
assert_eq!(seq.current_block.parent_sequence_hash, None);
// Pop 3, 2, 1
assert_eq!(seq.pop(), Some(3));
assert_eq!(seq.pop(), Some(2));
assert_eq!(seq.pop(), Some(1));
assert_eq!(seq.total_tokens(), 0);
assert!(seq.current_block.is_empty());
assert!(seq.blocks.is_empty());
// Pop from empty
assert_eq!(seq.pop(), None);
assert_eq!(seq.total_tokens(), 0);
}
#[test]
fn test_total_tokens() {
let block_size = 3;
let salt_hash = Some(TEST_SALT_HASH);
let mut seq = create_test_sequence(&[], block_size, salt_hash);
assert_eq!(seq.total_tokens(), 0);
seq.extend(Tokens::from(vec![1, 2])).unwrap();
assert_eq!(seq.total_tokens(), 2);
seq.append(3).unwrap(); // Completes block 0
assert_eq!(seq.total_tokens(), 3);
seq.extend(Tokens::from(vec![4, 5, 6, 7])).unwrap(); // Completes block 1, partial [7]
assert_eq!(seq.total_tokens(), 7);
seq.pop().unwrap(); // Removes 7
assert_eq!(seq.total_tokens(), 6);
seq.truncate(4).unwrap(); // Keep [1,2,3,4]
assert_eq!(seq.total_tokens(), 4);
seq.unwind(2).unwrap(); // Keep [1,2]
assert_eq!(seq.total_tokens(), 2);
}
#[test]
fn test_push_tokens_partial_block() {
let mut partial = PartialTokenBlock::create_sequence_root(4, 1337);
let tokens = Tokens(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
let remaining = partial.push_tokens(tokens);
assert_eq!(partial.tokens.len(), 4);
assert_eq!(remaining.len(), 6);
}
}