"examples/backends/trtllm/vscode:/vscode.git/clone" did not exist on "8bd37c96d6899b321730c2433c12fe5d1748b654"
Unverified Commit 76042959 authored by Ryan Olson, committed by GitHub
Browse files

feat(kvbm): add testing utilities to kvbm-logical and kvbm-physical (#6691)


Signed-off-by: Ryan Olson <rolson@nvidia.com>
parent 146eb3b4
......@@ -410,9 +410,9 @@ dependencies = [
[[package]]
name = "aws-lc-rs"
version = "1.16.0"
version = "1.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9a7b350e3bb1767102698302bc37256cbd48422809984b98d292c40e2579aa9"
checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf"
dependencies = [
"aws-lc-sys",
"zeroize",
......@@ -420,9 +420,9 @@ dependencies = [
[[package]]
name = "aws-lc-sys"
version = "0.37.1"
version = "0.38.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b092fe214090261288111db7a2b2c2118e5a7f30dc2569f1732c4069a6840549"
checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e"
dependencies = [
"cc",
"cmake",
......@@ -2266,9 +2266,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "erased-serde"
version = "0.4.9"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89e8918065695684b2b0702da20382d5ae6065cf3327bc2d6436bd49a71ce9f3"
checksum = "d2add8a07dd6a8d93ff627029c51de145e12686fbc36ecb298ac22e74cf02dec"
dependencies = [
"serde",
"serde_core",
......@@ -3575,9 +3575,9 @@ checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
[[package]]
name = "jiff"
version = "0.2.21"
version = "0.2.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3e3d65f018c6ae946ab16e80944b97096ed73c35b221d1c478a6c81d8f57940"
checksum = "819b44bc7c87d9117eb522f14d46e918add69ff12713c475946b0a29363ed1c2"
dependencies = [
"jiff-static",
"jiff-tzdb-platform",
......@@ -3590,9 +3590,9 @@ dependencies = [
[[package]]
name = "jiff-static"
version = "0.2.21"
version = "0.2.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a17c2b211d863c7fde02cbea8a3c1a439b98e109286554f2860bdded7ff83818"
checksum = "470252db18ecc35fd766c0891b1e3ec6cbbcd62507e85276c01bf75d8e94d4a1"
dependencies = [
"proc-macro2",
"quote",
......@@ -5862,12 +5862,9 @@ dependencies = [
[[package]]
name = "pxfm"
version = "0.1.27"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7186d3822593aa4393561d186d1393b3923e9d6163d3fbfd6e825e3e6cf3e6a8"
dependencies = [
"num-traits",
]
checksum = "b5a041e753da8b807c9255f28de81879c78c876392ff2469cde94799b2896b9d"
[[package]]
name = "py_literal"
......@@ -7722,9 +7719,9 @@ dependencies = [
[[package]]
name = "tokio-macros"
version = "2.6.0"
version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5"
checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c"
dependencies = [
"proc-macro2",
"quote",
......@@ -9524,9 +9521,9 @@ dependencies = [
[[package]]
name = "zlib-rs"
version = "0.6.2"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c745c48e1007337ed136dc99df34128b9faa6ed542d80a1c673cf55a6d7236c8"
checksum = "3be3d40e40a133f9c916ee3f9f4fa2d9d63435b5fbe1bfc6d9dae0aa0ada1513"
[[package]]
name = "zmij"
......
......@@ -380,9 +380,9 @@ dependencies = [
[[package]]
name = "aws-lc-rs"
version = "1.16.0"
version = "1.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9a7b350e3bb1767102698302bc37256cbd48422809984b98d292c40e2579aa9"
checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf"
dependencies = [
"aws-lc-sys",
"zeroize",
......@@ -390,9 +390,9 @@ dependencies = [
[[package]]
name = "aws-lc-sys"
version = "0.37.1"
version = "0.38.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b092fe214090261288111db7a2b2c2118e5a7f30dc2569f1732c4069a6840549"
checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e"
dependencies = [
"cc",
"cmake",
......@@ -1869,9 +1869,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "erased-serde"
version = "0.4.9"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89e8918065695684b2b0702da20382d5ae6065cf3327bc2d6436bd49a71ce9f3"
checksum = "d2add8a07dd6a8d93ff627029c51de145e12686fbc36ecb298ac22e74cf02dec"
dependencies = [
"serde",
"serde_core",
......@@ -6601,9 +6601,9 @@ dependencies = [
[[package]]
name = "tokio-macros"
version = "2.6.0"
version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5"
checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c"
dependencies = [
"proc-macro2",
"quote",
......
......@@ -380,9 +380,9 @@ dependencies = [
[[package]]
name = "aws-lc-rs"
version = "1.16.0"
version = "1.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9a7b350e3bb1767102698302bc37256cbd48422809984b98d292c40e2579aa9"
checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf"
dependencies = [
"aws-lc-sys",
"zeroize",
......@@ -390,9 +390,9 @@ dependencies = [
[[package]]
name = "aws-lc-sys"
version = "0.37.1"
version = "0.38.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b092fe214090261288111db7a2b2c2118e5a7f30dc2569f1732c4069a6840549"
checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e"
dependencies = [
"cc",
"cmake",
......@@ -1904,9 +1904,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "erased-serde"
version = "0.4.9"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89e8918065695684b2b0702da20382d5ae6065cf3327bc2d6436bd49a71ce9f3"
checksum = "d2add8a07dd6a8d93ff627029c51de145e12686fbc36ecb298ac22e74cf02dec"
dependencies = [
"serde",
"serde_core",
......@@ -6659,9 +6659,9 @@ dependencies = [
[[package]]
name = "tokio-macros"
version = "2.6.0"
version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5"
checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c"
dependencies = [
"proc-macro2",
"quote",
......
......@@ -3,9 +3,17 @@
//! Test manager creation helpers.
use crate::blocks::BlockMetadata;
use std::marker::PhantomData;
use std::sync::Arc;
use anyhow::Result;
use crate::SequenceHash;
use crate::blocks::{BlockMetadata, BlockRegistry};
use crate::events::EventsManager;
use crate::manager::{BlockManager, FrequencyTrackingCapacity};
use crate::registry::BlockRegistry;
use super::token_blocks;
/// Create a basic test manager with LRU backend.
pub fn create_test_manager<T: BlockMetadata>(block_count: usize) -> BlockManager<T> {
......@@ -39,3 +47,304 @@ pub fn create_test_manager_with_block_size<T: BlockMetadata>(
.build()
.expect("Should build manager")
}
/// Builder for a test [`BlockRegistry`], optionally wired to an [`EventsManager`].
///
/// # Example
///
/// ```ignore
/// // Simple registry
/// let registry = TestRegistryBuilder::new().build();
///
/// // With events manager
/// let events_manager = Arc::new(EventsManager::builder().build());
/// let registry = TestRegistryBuilder::new()
///     .events_manager(events_manager)
///     .build();
///
/// // With custom frequency tracking
/// let registry = TestRegistryBuilder::new()
///     .frequency_tracking(FrequencyTrackingCapacity::Large)
///     .build();
/// ```
pub struct TestRegistryBuilder {
    /// Events manager passed to the registry builder when set.
    events_manager: Option<Arc<EventsManager>>,
    /// Capacity used to create the registry's frequency tracker.
    frequency_tracking: FrequencyTrackingCapacity,
}

impl Default for TestRegistryBuilder {
    /// Equivalent to [`TestRegistryBuilder::new`].
    ///
    /// Implemented by hand (rather than derived) so that `default()` and
    /// `new()` always agree on the `Medium` frequency-tracking capacity,
    /// independent of `FrequencyTrackingCapacity`'s own `Default`.
    fn default() -> Self {
        Self::new()
    }
}

impl TestRegistryBuilder {
    /// Creates a builder with no events manager and `Medium` frequency tracking.
    pub fn new() -> Self {
        Self {
            events_manager: None,
            frequency_tracking: FrequencyTrackingCapacity::Medium,
        }
    }

    /// Sets the events manager that will be attached to the registry.
    pub fn events_manager(mut self, manager: Arc<EventsManager>) -> Self {
        self.events_manager = Some(manager);
        self
    }

    /// Sets the frequency tracking capacity.
    ///
    /// Default: Medium
    pub fn frequency_tracking(mut self, capacity: FrequencyTrackingCapacity) -> Self {
        self.frequency_tracking = capacity;
        self
    }

    /// Builds the [`BlockRegistry`].
    ///
    /// The frequency tracker is always installed; the events manager is only
    /// forwarded to the registry builder when one was provided.
    pub fn build(self) -> BlockRegistry {
        let mut builder =
            BlockRegistry::builder().frequency_tracker(self.frequency_tracking.create_tracker());

        if let Some(events_manager) = self.events_manager {
            builder = builder.event_manager(events_manager);
        }

        builder.build()
    }
}
/// Builder that assembles a [`BlockManager`] for tests.
///
/// # Example
///
/// ```ignore
/// // Simple manager (creates its own registry)
/// let manager = TestManagerBuilder::<G1>::new()
///     .block_count(100)
///     .block_size(4)
///     .build();
///
/// // With explicit registry (for events integration)
/// let events_manager = Arc::new(EventsManager::builder().build());
/// let registry = TestRegistryBuilder::new()
///     .events_manager(events_manager.clone())
///     .build();
/// let manager = TestManagerBuilder::<G1>::new()
///     .block_count(100)
///     .block_size(4)
///     .registry(registry)
///     .build();
///
/// // Convenience: with events manager (creates registry internally)
/// let manager = TestManagerBuilder::<G1>::new()
///     .block_count(100)
///     .block_size(4)
///     .events_manager(events_manager)
///     .build();
/// ```
pub struct TestManagerBuilder<T: BlockMetadata> {
    /// Number of blocks in the pool; required by `build()`.
    block_count: Option<usize>,
    /// Tokens per block; required by `build()`.
    block_size: Option<usize>,
    /// Explicit registry; when absent one is created in `build()`.
    registry: Option<BlockRegistry>,
    /// Events manager used only for an auto-created registry.
    events_manager: Option<Arc<EventsManager>>,
    /// Frequency tracking capacity used only for an auto-created registry.
    frequency_tracking: FrequencyTrackingCapacity,
    _phantom: PhantomData<T>,
}

impl<T: BlockMetadata> Default for TestManagerBuilder<T> {
    /// Same as [`TestManagerBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<T: BlockMetadata> TestManagerBuilder<T> {
    /// Creates a builder with nothing set and `Medium` frequency tracking.
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
            frequency_tracking: FrequencyTrackingCapacity::Medium,
            events_manager: None,
            registry: None,
            block_size: None,
            block_count: None,
        }
    }

    /// Sets the number of blocks in the pool.
    pub fn block_count(self, count: usize) -> Self {
        Self {
            block_count: Some(count),
            ..self
        }
    }

    /// Sets the tokens per block (must be power of 2, 1-1024).
    pub fn block_size(self, size: usize) -> Self {
        Self {
            block_size: Some(size),
            ..self
        }
    }

    /// Sets the registry to use.
    ///
    /// When no registry is supplied, `build()` creates one from the
    /// `frequency_tracking` and `events_manager` settings.
    pub fn registry(self, registry: BlockRegistry) -> Self {
        Self {
            registry: Some(registry),
            ..self
        }
    }

    /// Sets the events manager used when the registry is created internally.
    ///
    /// This setting is ignored if `registry()` is also called.
    pub fn events_manager(self, manager: Arc<EventsManager>) -> Self {
        Self {
            events_manager: Some(manager),
            ..self
        }
    }

    /// Sets the frequency tracking capacity for an auto-created registry.
    ///
    /// Ignored if `registry()` is called.
    ///
    /// Default: Medium
    pub fn frequency_tracking(self, capacity: FrequencyTrackingCapacity) -> Self {
        Self {
            frequency_tracking: capacity,
            ..self
        }
    }

    /// Builds the [`BlockManager`] with an LRU backend.
    ///
    /// # Panics
    ///
    /// Panics if `block_count` or `block_size` are not set, or if the
    /// underlying manager builder reports an error.
    pub fn build(self) -> BlockManager<T> {
        let Self {
            block_count,
            block_size,
            registry,
            events_manager,
            frequency_tracking,
            _phantom: _,
        } = self;

        // Required settings are checked in the same order as they are documented.
        let block_count = block_count.expect("block_count is required");
        let block_size = block_size.expect("block_size is required");

        // An explicit registry wins; otherwise assemble one from the
        // frequency-tracking and events-manager settings.
        let registry = match registry {
            Some(explicit) => explicit,
            None => {
                let registry_builder =
                    TestRegistryBuilder::new().frequency_tracking(frequency_tracking);
                match events_manager {
                    Some(manager) => registry_builder.events_manager(manager).build(),
                    None => registry_builder.build(),
                }
            }
        };

        BlockManager::<T>::builder()
            .block_count(block_count)
            .block_size(block_size)
            .registry(registry)
            .with_lru_backend()
            .build()
            .expect("Should build test manager")
    }
}
/// Fill a [`BlockManager`] with the given token blocks and report their sequence hashes.
///
/// Steps performed:
/// 1. Allocate one block per entry in `token_blocks`
/// 2. Complete each allocated block with its token block (paired by index)
/// 3. Register the completed blocks with the manager
/// 4. Drop the handles returned by registration (the original note says this
///    returns the blocks to the inactive pool)
///
/// # Returns
/// The sequence hashes of the registered blocks, in input order.
///
/// # Errors
/// Fails if the manager cannot allocate enough blocks or if completing any
/// block fails.
pub fn populate_manager_with_blocks<T: BlockMetadata>(
    manager: &BlockManager<T>,
    token_blocks: &[dynamo_tokens::TokenBlock],
) -> Result<Vec<SequenceHash>> {
    let allocated = match manager.allocate_blocks(token_blocks.len()) {
        Some(allocated) => allocated,
        None => {
            return Err(anyhow::anyhow!(
                "Failed to allocate {} blocks",
                token_blocks.len()
            ));
        }
    };

    // Stop at the first block that fails to complete.
    let mut completed = Vec::new();
    for (block, token_block) in allocated.into_iter().zip(token_blocks.iter()) {
        let done = block
            .complete(token_block)
            .map_err(|e| anyhow::anyhow!("Failed to complete block: {:?}", e))?;
        completed.push(done);
    }

    // Hashes are read before registration consumes the completed blocks.
    let mut seq_hashes: Vec<SequenceHash> = Vec::with_capacity(completed.len());
    for block in completed.iter() {
        seq_hashes.push(block.sequence_hash());
    }

    // The registered handles are released right away; only the hashes are returned.
    drop(manager.register_blocks(completed));

    Ok(seq_hashes)
}
/// Convenience helper: build a manager and fill it with sequential token blocks.
///
/// # Arguments
/// * `block_count` - Number of blocks (used for both the pool size and the
///   number of token blocks generated)
/// * `block_size` - Tokens per block
/// * `start_token` - Starting token value for sequence
/// * `registry` - Registry handed to the manager builder
///
/// # Returns
/// (BlockManager, Vec<SequenceHash>)
pub fn create_and_populate_manager<T: BlockMetadata>(
    block_count: usize,
    block_size: usize,
    start_token: u32,
    registry: BlockRegistry,
) -> Result<(BlockManager<T>, Vec<SequenceHash>)> {
    let builder = TestManagerBuilder::<T>::new()
        .block_count(block_count)
        .block_size(block_size)
        .registry(registry);
    let manager = builder.build();

    let sequence = token_blocks::create_token_sequence(block_count, block_size, start_token);
    let hashes = populate_manager_with_blocks(&manager, sequence.blocks())?;

    Ok((manager, hashes))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal metadata type used to instantiate the generic helpers.
    #[derive(Clone, Debug)]
    struct TestMetadata;

    /// A freshly built manager reports the configured size and has every block available.
    #[test]
    fn test_create_test_manager() {
        let builder = TestManagerBuilder::<TestMetadata>::new()
            .block_count(100)
            .block_size(16);
        let manager = builder.build();

        assert_eq!(manager.total_blocks(), 100);
        assert_eq!(manager.block_size(), 16);
        assert_eq!(manager.available_blocks(), 100);
    }

    /// Populating returns one hash per block and leaves all blocks available.
    #[test]
    fn test_populate_manager_with_blocks() {
        let builder = TestManagerBuilder::<TestMetadata>::new()
            .block_count(50)
            .block_size(4);
        let manager = builder.build();

        let token_seq = token_blocks::create_token_sequence(10, 4, 0);
        let populated = populate_manager_with_blocks(&manager, token_seq.blocks());
        let seq_hashes = populated.expect("Should populate");

        assert_eq!(seq_hashes.len(), 10);
        // Blocks should be in inactive pool after population
        assert_eq!(manager.available_blocks(), 50);
    }

    /// The one-shot helper yields a full manager whose blocks can all be matched by hash.
    #[test]
    fn test_create_and_populate_manager() {
        let registry = TestRegistryBuilder::new().build();
        let created = create_and_populate_manager::<TestMetadata>(32, 4, 100, registry);
        let (manager, hashes) = created.expect("Should create");

        assert_eq!(hashes.len(), 32);
        assert_eq!(manager.total_blocks(), 32);
        assert_eq!(manager.available_blocks(), 32);

        // Verify blocks can be matched
        let matched = manager.match_blocks(&hashes);
        assert_eq!(matched.len(), 32);
    }
}
......@@ -55,10 +55,15 @@ pub const TEST_SALT: u64 = 42;
// ============================================================================
// pub items — usable by downstream crates via the testing feature
pub use managers::{create_test_manager, create_test_manager_with_block_size};
pub use managers::{
TestManagerBuilder, TestRegistryBuilder, create_and_populate_manager, create_test_manager,
create_test_manager_with_block_size, populate_manager_with_blocks,
};
pub use sequences::BlockSequenceBuilder;
pub use token_blocks::{
create_iota_token_block, create_test_token_block, sequential_tokens, tokens_for_id,
create_disjoint_sequences, create_iota_token_block, create_sequential_block,
create_test_token_block, create_token_block, create_token_sequence, default_request_salt_hash,
generate_sequence_hashes, sequential_tokens, tokens_for_id,
};
// pub(crate) items — internal helpers using internal types
......
......@@ -3,7 +3,9 @@
//! Token block creation helpers for tests.
use dynamo_tokens::{TokenBlock, TokenBlockSequence};
use dynamo_tokens::{TokenBlock, TokenBlockSequence, compute_hash_v2};
use crate::{KvbmSequenceHashProvider, SequenceHash};
use super::TEST_SALT;
......@@ -38,3 +40,175 @@ pub fn sequential_tokens(start: u32, count: usize) -> Vec<u32> {
pub fn tokens_for_id(id: u64) -> Vec<u32> {
    // Four consecutive values starting at `id`, each narrowed to u32 with `as`.
    (0..4u64).map(|offset| (id + offset) as u32).collect()
}
/// Salt hash used for requests that carry neither a salt nor a lora name.
///
/// Per the original note, this is intended to equal the hash computed by
/// `Request::new()` when salt=None and lora_name=None.
pub fn default_request_salt_hash() -> u64 {
    // Per the original note: SaltPayload { salt: None, lora_name: None }
    // serializes to "{}", which is hashed with seed 0.
    let empty_payload = b"{}";
    compute_hash_v2(empty_payload, 0)
}
/// Build a single [`TokenBlock`] holding exactly the given tokens.
///
/// The block is salted with [`default_request_salt_hash`] so that it lines up
/// with blocks produced for requests that specify no salt.
///
/// # Panics
///
/// Panics if no complete block was produced and committing the partial block fails.
pub fn create_token_block(tokens: &[u32]) -> TokenBlock {
    let salt = default_request_salt_hash();
    // The block size equals the token count, so the whole slice is one block.
    let sequence = TokenBlockSequence::from_slice(tokens, tokens.len() as u32, Some(salt));

    match sequence.blocks().first() {
        Some(block) => block.clone(),
        None => {
            // No complete block was produced: commit the partial block instead.
            let mut partial = sequence.into_parts().1;
            partial.commit().expect("Should be able to commit")
        }
    }
}
/// Create a token block with sequential tokens starting from `start`.
///
/// # Arguments
/// * `start` - Starting token value
/// * `count` - Number of tokens to generate
pub fn create_sequential_block(start: u32, count: usize) -> TokenBlock {
let tokens: Vec<u32> = (start..start + count as u32).collect();
create_token_block(&tokens)
}
/// Create a token sequence with multiple blocks.
///
/// Uses the default request salt hash to match blocks created by
/// requests with no salt parameter.
///
/// # Arguments
/// * `num_blocks` - Number of blocks to create
/// * `block_size` - Tokens per block
/// * `start_token` - Starting token value
///
/// # Returns
/// A TokenBlockSequence containing the requested blocks.
pub fn create_token_sequence(
num_blocks: usize,
block_size: usize,
start_token: u32,
) -> TokenBlockSequence {
let salt = default_request_salt_hash();
let total_tokens = num_blocks * block_size;
let tokens: Vec<u32> = (start_token..start_token + total_tokens as u32).collect();
TokenBlockSequence::from_slice(&tokens, block_size as u32, Some(salt))
}
/// Collect the KVBM sequence hash of every block in `token_sequence`, in block order.
pub fn generate_sequence_hashes(token_sequence: &TokenBlockSequence) -> Vec<SequenceHash> {
    let mut hashes = Vec::new();
    for block in token_sequence.blocks().iter() {
        hashes.push(block.kvbm_sequence_hash());
    }
    hashes
}
/// Create multiple disjoint token sequences with gaps between them.
///
/// This is useful for testing contiguous subsequence detection, where you need
/// blocks at non-consecutive positions with gaps between them.
///
/// # Arguments
/// * `segments` - Vec of (num_blocks, start_token) pairs. Each segment creates
/// consecutive blocks starting at the given token.
/// * `block_size` - Tokens per block
///
/// # Returns
/// A tuple of (Vec<TokenBlock>, Vec<SequenceHash>) containing all blocks and
/// their hashes from all segments, sorted by position.
pub fn create_disjoint_sequences(
segments: Vec<(usize, u32)>,
block_size: usize,
) -> (Vec<TokenBlock>, Vec<SequenceHash>) {
let mut all_blocks = Vec::new();
let mut all_hashes = Vec::new();
for (num_blocks, start_token) in segments {
let token_sequence = create_token_sequence(num_blocks, block_size, start_token);
let blocks = token_sequence.blocks().to_vec();
let hashes = generate_sequence_hashes(&token_sequence);
all_blocks.extend(blocks);
all_hashes.extend(hashes);
}
// Sort by position to maintain order
let mut combined: Vec<_> = all_blocks.into_iter().zip(all_hashes).collect();
combined.sort_by_key(|(_, hash)| hash.position());
combined.into_iter().unzip()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// A block built from four tokens holds four tokens.
    #[test]
    fn test_create_token_block() {
        let block = create_token_block(&[1, 2, 3, 4]);
        assert_eq!(block.tokens().len(), 4);
    }

    /// A sequential block of `count` tokens holds `count` tokens.
    #[test]
    fn test_create_sequential_block() {
        let block = create_sequential_block(100, 4);
        assert_eq!(block.tokens().len(), 4);
    }

    /// The sequence has the requested number of blocks, each of the requested size.
    #[test]
    fn test_create_token_sequence() {
        let sequence = create_token_sequence(10, 4, 0);
        assert_eq!(sequence.blocks().len(), 10);

        // Check the size of the first block
        let first_block = &sequence.blocks()[0];
        assert_eq!(first_block.tokens().len(), 4);
    }

    /// One hash per block, and no two hashes collide.
    #[test]
    fn test_generate_sequence_hashes() {
        let sequence = create_token_sequence(5, 4, 100);
        let hashes = generate_sequence_hashes(&sequence);
        assert_eq!(hashes.len(), 5);

        // Verify hashes are unique
        let unique_hashes: HashSet<_> = hashes.iter().collect();
        assert_eq!(unique_hashes.len(), 5);
    }

    /// Blocks from all segments are merged and ordered by per-segment position.
    #[test]
    fn test_create_disjoint_sequences() {
        // Create 3 segments with different token ranges
        let segments = vec![
            (2, 0),   // 2 blocks starting at token 0
            (2, 100), // 2 blocks starting at token 100
            (3, 200), // 3 blocks starting at token 200
        ];
        let (blocks, hashes) = create_disjoint_sequences(segments, 4);

        // Should have 7 total blocks
        assert_eq!(blocks.len(), 7);
        assert_eq!(hashes.len(), 7);

        // All hashes should be unique (different token content = different hashes)
        let unique_hashes: HashSet<_> = hashes.iter().collect();
        assert_eq!(unique_hashes.len(), 7);

        // Positions are relative within each segment's TokenBlockSequence
        for (hash, expected) in hashes.iter().zip([0, 0, 0, 1, 1, 1, 2]) {
            assert_eq!(hash.position(), expected);
        }
    }
}
......@@ -35,6 +35,7 @@ validator = { workspace = true }
[features]
default = []
collectives = []
testing = []
testing-kvbm = []
testing-nixl-gds = []
......
......@@ -11,6 +11,9 @@ pub use transfer::{TransferConfig, TransferOptions};
pub use kvbm_common::BlockId;
pub type SequenceHash = kvbm_common::SequenceHash;
#[cfg(any(test, feature = "testing"))]
pub mod testing;
#[cfg(test)]
#[cfg(not(feature = "testing-kvbm"))]
mod sentinel {
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Physical layout and transfer testing utilities.
mod physical;
pub use physical::*;
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Physical layout and transfer testing utilities.
//!
//! This module provides reusable test infrastructure for:
//! - Creating physical layouts with various storage types
//! - Creating TransferManagers with UCX backend for RDMA tests
//! - Filling blocks with test patterns and computing checksums
//! - Verifying data integrity after transfers
use anyhow::Result;
use std::collections::HashMap;
use crate::BlockId;
use crate::{
layout::{BlockDimension, LayoutConfig, PhysicalLayout},
manager::{LayoutHandle, TransferManager},
transfer::{
BlockChecksum, FillPattern, NixlAgent, StorageKind, TransferCapabilities,
compute_block_checksums, compute_layer_checksums, fill_blocks, fill_layers,
},
};
// =============================================================================
// Flexible Backend Agent Builder
// =============================================================================
/// Wraps a [`NixlAgent`] together with the names of the backends that were
/// successfully initialized on it.
///
/// Tests can query backend availability and skip themselves when a backend
/// they depend on is missing.
///
/// # Example
///
/// ```ignore
/// // Flexible - won't fail if UCX unavailable
/// let agent = TestAgentBuilder::new("test")
///     .try_backend("UCX")
///     .try_backend("POSIX")  // Always available
///     .build()?;
///
/// // Check what's available
/// if !agent.has_backend("UCX") {
///     eprintln!("Skipping RDMA test - UCX unavailable");
///     return Ok(());
/// }
/// ```
pub struct TestAgent {
    /// The wrapped agent.
    agent: NixlAgent,
    /// Names of the backends that initialized successfully.
    available_backends: Vec<String>,
}

impl TestAgent {
    /// Reports whether `backend` was successfully initialized (ASCII case-insensitive).
    pub fn has_backend(&self, backend: &str) -> bool {
        for name in &self.available_backends {
            if name.eq_ignore_ascii_case(backend) {
                return true;
            }
        }
        false
    }

    /// Lists the backends that were successfully initialized.
    pub fn available_backends(&self) -> &[String] {
        self.available_backends.as_slice()
    }

    /// Unwraps this value into the underlying [`NixlAgent`].
    pub fn into_nixl_agent(self) -> NixlAgent {
        let TestAgent { agent, .. } = self;
        agent
    }

    /// Borrows the underlying [`NixlAgent`].
    pub fn nixl_agent(&self) -> &NixlAgent {
        &self.agent
    }
}

impl std::ops::Deref for TestAgent {
    type Target = NixlAgent;

    /// Lets a `TestAgent` be used wherever a `&NixlAgent` is expected.
    fn deref(&self) -> &NixlAgent {
        &self.agent
    }
}
/// Builds a [`TestAgent`], separating required backends from optional ones.
///
/// `NixlAgent::with_backends()` is described as failing if ANY backend is
/// unavailable; this builder instead fails only for backends registered with
/// `require_backend` and tolerates failures of those registered with
/// `try_backend`.
///
/// # Example
///
/// ```ignore
/// // RDMA tests - UCX is required
/// let agent = TestAgentBuilder::new("rdma-test")
///     .require_backend("UCX")  // Fails if unavailable
///     .try_backend("POSIX")    // Optional
///     .build()?;
///
/// // Disk tests - POSIX only, no GDS requirement
/// let agent = TestAgentBuilder::new("disk-test")
///     .try_backend("POSIX")
///     .build()?;
/// ```
#[derive(Default)]
pub struct TestAgentBuilder {
    /// Agent name; `None` only when constructed through `Default`.
    name: Option<String>,
    /// Backends whose initialization failure is tolerated.
    try_backends: Vec<String>,
    /// Backends whose initialization failure aborts `build()`.
    required_backends: Vec<String>,
}

impl TestAgentBuilder {
    /// Creates a builder for an agent called `name`, with no backends registered.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            required_backends: Vec::new(),
            try_backends: Vec::new(),
            name: Some(name.into()),
        }
    }

    /// Registers an optional backend. If it is unavailable, build() still
    /// succeeds but `has_backend(name)` returns false.
    pub fn try_backend(mut self, backend: impl Into<String>) -> Self {
        let backend: String = backend.into();
        self.try_backends.push(backend);
        self
    }

    /// Registers a mandatory backend. If it is unavailable, build() fails.
    ///
    /// Intended for tests that cannot run without a specific backend,
    /// like RDMA tests requiring UCX.
    pub fn require_backend(mut self, backend: impl Into<String>) -> Self {
        let backend: String = backend.into();
        self.required_backends.push(backend);
        self
    }

    /// Builds the [`TestAgent`].
    ///
    /// Backend names are upper-cased before being passed to the agent.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - No agent name was set
    /// - Agent creation fails
    /// - Any required backend fails to initialize
    pub fn build(self) -> Result<TestAgent> {
        let Self {
            name,
            try_backends,
            required_backends,
        } = self;

        let name = match name {
            Some(name) => name,
            None => return Err(anyhow::anyhow!("Agent name is required")),
        };

        let mut agent = NixlAgent::new(&name)?;
        let mut available_backends: Vec<String> = Vec::new();

        // Required backends come first; the first failure aborts the build.
        for backend in required_backends.iter() {
            let backend_upper = backend.to_uppercase();
            if let Err(e) = agent.add_backend(&backend_upper) {
                return Err(anyhow::anyhow!(
                    "Required backend {} unavailable: {}. \
                     Use try_backend() if this backend is optional.",
                    backend_upper,
                    e
                ));
            }
            available_backends.push(backend_upper);
        }

        // Optional backends: failures are logged at debug level and skipped.
        for backend in try_backends.iter() {
            let backend_upper = backend.to_uppercase();

            // Nothing to do when the backend was already initialized above.
            let already_added = available_backends
                .iter()
                .any(|b| b.eq_ignore_ascii_case(&backend_upper));

            if !already_added {
                match agent.add_backend(&backend_upper) {
                    Err(e) => {
                        tracing::debug!(
                            "Optional backend {} unavailable: {} - continuing without it",
                            backend_upper,
                            e
                        );
                    }
                    Ok(_) => {
                        tracing::debug!("Initialized optional backend: {}", backend_upper);
                        available_backends.push(backend_upper);
                    }
                }
            }
        }

        Ok(TestAgent {
            agent,
            available_backends,
        })
    }
}
// =============================================================================
// Transfer Checksums Helper
// =============================================================================
/// Captures checksums for source blocks to enable verification after transfer.
///
/// This provides a cleaner pattern for the fill->transfer->verify workflow.
///
/// # Example
///
/// ```ignore
/// // Capture source checksums
/// let src = TransferChecksums::fill_and_capture(&src_layout, &src_ids, FillPattern::Sequential)?;
///
/// // ... execute transfer ...
///
/// // Verify destination matches
/// src.verify_against(&dst_layout, &dst_ids)?;
/// ```
pub struct TransferChecksums {
checksums: HashMap<BlockId, BlockChecksum>,
block_ids: Vec<BlockId>,
}
impl TransferChecksums {
/// Fill blocks with a pattern and capture their checksums.
pub fn fill_and_capture(
layout: &PhysicalLayout,
block_ids: &[BlockId],
pattern: FillPattern,
) -> Result<Self> {
let checksums = fill_and_checksum(layout, block_ids, pattern)?;
Ok(Self {
checksums,
block_ids: block_ids.to_vec(),
})
}
/// Capture checksums without filling (for already-filled blocks).
pub fn capture(layout: &PhysicalLayout, block_ids: &[BlockId]) -> Result<Self> {
let checksums = crate::transfer::compute_block_checksums(layout, block_ids)?;
Ok(Self {
checksums,
block_ids: block_ids.to_vec(),
})
}
/// Returns the captured checksums.
pub fn checksums(&self) -> &HashMap<BlockId, BlockChecksum> {
&self.checksums
}
/// Returns the block IDs for which checksums were captured.
pub fn block_ids(&self) -> &[BlockId] {
&self.block_ids
}
/// Verify that destination blocks match the captured source checksums.
///
/// The destination block IDs must be the same length as source block IDs.
/// Comparison is done positionally: src_blocks[i] is compared with dst_ids[i].
pub fn verify_against(&self, dst_layout: &PhysicalLayout, dst_ids: &[BlockId]) -> Result<()> {
verify_checksums_by_position(&self.checksums, &self.block_ids, dst_layout, dst_ids)
}
/// Verify that destination blocks match using specific layers only.
pub fn verify_layers_against(
&self,
src_layout: &PhysicalLayout,
dst_layout: &PhysicalLayout,
dst_ids: &[BlockId],
layer_range: std::ops::Range<usize>,
) -> Result<()> {
let src_layer_checksums =
compute_layer_checksums(src_layout, &self.block_ids, layer_range.clone())?;
verify_layer_checksums_by_position(
&src_layer_checksums,
&self.block_ids,
dst_layout,
dst_ids,
layer_range,
)
}
}
// =============================================================================
// Layout Types
// =============================================================================
/// Which physical layout a parameterized test should build.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayoutKind {
    /// Fully contiguous layout.
    FC,
    /// Layer-wise layout (also called layer-separate).
    LW,
}
/// Pairs a layout kind with a storage kind to describe a test layout.
#[derive(Debug, Clone, Copy)]
pub struct LayoutSpec {
    /// Layout arrangement to build.
    pub kind: LayoutKind,
    /// Storage the layout is allocated in.
    pub storage: StorageKind,
}

impl LayoutSpec {
    /// Creates a spec from its two components.
    pub fn new(kind: LayoutKind, storage: StorageKind) -> Self {
        LayoutSpec { storage, kind }
    }
}
/// Default layout configuration shared by most tests.
///
/// Fixed dimensions:
/// - 2 layers
/// - outer_dim=2 (K&V separate)
/// - page_size=16
/// - inner_dim=128
/// - dtype_width=2 (bf16)
///
/// # Panics
///
/// Panics if the configuration builder rejects these values.
pub fn standard_config(num_blocks: usize) -> LayoutConfig {
    const NUM_LAYERS: usize = 2;
    const OUTER_DIM: usize = 2;
    const PAGE_SIZE: usize = 16;
    const INNER_DIM: usize = 128;
    const DTYPE_WIDTH_BYTES: usize = 2;

    let built = LayoutConfig::builder()
        .num_blocks(num_blocks)
        .num_layers(NUM_LAYERS)
        .outer_dim(OUTER_DIM)
        .page_size(PAGE_SIZE)
        .inner_dim(INNER_DIM)
        .dtype_width_bytes(DTYPE_WIDTH_BYTES)
        .build();

    built.expect("standard config should build")
}
/// Build a layout configuration with caller-chosen dimensions (used by RDMA tests).
///
/// # Arguments
/// * `num_blocks` - Number of blocks in the layout
/// * `num_layers` - Number of transformer layers
/// * `outer_dim` - Outer dimension (2 for K&V separate)
/// * `page_size` - Tokens per block/page
/// * `inner_dim` - Hidden dimension
/// * `dtype_width` - Data type width in bytes
///
/// # Panics
///
/// Panics if the configuration builder rejects these values.
pub fn custom_config(
    num_blocks: usize,
    num_layers: usize,
    outer_dim: usize,
    page_size: usize,
    inner_dim: usize,
    dtype_width: usize,
) -> LayoutConfig {
    let built = LayoutConfig::builder()
        .num_blocks(num_blocks)
        .num_layers(num_layers)
        .outer_dim(outer_dim)
        .page_size(page_size)
        .inner_dim(inner_dim)
        .dtype_width_bytes(dtype_width)
        .build();

    built.expect("custom config should build")
}
/// Create a test NIXL agent with no backends.
///
/// Suitable for tests that do not need a specific NIXL backend.
///
/// # Panics
/// Panics if the agent cannot be constructed.
pub fn create_test_agent(name: &str) -> NixlAgent {
    let created = NixlAgent::new(name);
    created.expect("Failed to create agent")
}
/// Create a test NIXL agent with specific backends (strict - all must succeed).
///
/// # Arguments
/// * `name` - Agent name (must be unique for RDMA addressing)
/// * `backends` - List of backends to enable (e.g., &["UCX"])
///
/// # Errors
/// Returns whatever error `NixlAgent::with_backends` reports.
pub fn create_test_agent_with_backends(name: &str, backends: &[&str]) -> Result<NixlAgent> {
    let agent = NixlAgent::with_backends(name, backends)?;
    Ok(agent)
}
/// Create a fully contiguous physical layout on the given storage, using the
/// dimensions from [`standard_config`].
///
/// # Panics
/// Panics if the standard config or the layout cannot be built.
pub fn create_fc_layout(
    agent: NixlAgent,
    storage_kind: StorageKind,
    num_blocks: usize,
) -> PhysicalLayout {
    let config = standard_config(num_blocks);
    create_fc_layout_with_config(agent, storage_kind, config)
}
/// Create a fully contiguous physical layout from a caller-supplied config.
///
/// The allocation call is chosen from `storage_kind`. Pinned and disk
/// allocations are requested with `None` as their argument; the payload of
/// `StorageKind::Disk` is not used.
///
/// # Panics
/// Panics if the layout builder fails.
pub fn create_fc_layout_with_config(
    agent: NixlAgent,
    storage_kind: StorageKind,
    config: LayoutConfig,
) -> PhysicalLayout {
    let contiguous = PhysicalLayout::builder(agent)
        .with_config(config)
        .fully_contiguous();

    // Pick the allocator per storage kind, then unwrap the build result once.
    let built = match storage_kind {
        StorageKind::System => contiguous.allocate_system().build(),
        StorageKind::Pinned => contiguous.allocate_pinned(None).build(),
        StorageKind::Device(device_id) => contiguous.allocate_device(device_id).build(),
        StorageKind::Disk(_) => contiguous.allocate_disk(None).build(),
    };
    built.unwrap()
}
/// Create a layer-separate physical layout on the given storage, using the
/// dimensions from [`standard_config`].
///
/// # Panics
/// Panics if the standard config or the layout cannot be built.
pub fn create_lw_layout(
    agent: NixlAgent,
    storage_kind: StorageKind,
    num_blocks: usize,
) -> PhysicalLayout {
    let config = standard_config(num_blocks);
    create_lw_layout_with_config(agent, storage_kind, config)
}
/// Create a layer-separate physical layout from a caller-supplied config.
///
/// The layout is built with `BlockDimension::BlockIsFirstDim`. The allocation
/// call is chosen from `storage_kind`. Pinned and disk allocations are
/// requested with `None` as their argument; the payload of
/// `StorageKind::Disk` is not used.
///
/// # Panics
/// Panics if the layout builder fails.
pub fn create_lw_layout_with_config(
    agent: NixlAgent,
    storage_kind: StorageKind,
    config: LayoutConfig,
) -> PhysicalLayout {
    let layered = PhysicalLayout::builder(agent)
        .with_config(config)
        .layer_separate(BlockDimension::BlockIsFirstDim);

    // Pick the allocator per storage kind, then unwrap the build result once.
    let built = match storage_kind {
        StorageKind::System => layered.allocate_system().build(),
        StorageKind::Pinned => layered.allocate_pinned(None).build(),
        StorageKind::Device(device_id) => layered.allocate_device(device_id).build(),
        StorageKind::Disk(_) => layered.allocate_disk(None).build(),
    };
    built.unwrap()
}
/// Create a physical layout based on the specification.
pub fn create_layout(agent: NixlAgent, spec: LayoutSpec, num_blocks: usize) -> PhysicalLayout {
match spec.kind {
LayoutKind::FC => create_fc_layout(agent, spec.storage, num_blocks),
LayoutKind::LW => create_lw_layout(agent, spec.storage, num_blocks),
}
}
/// Create a physical layout based on specification with custom config.
pub fn create_layout_with_config(
agent: NixlAgent,
spec: LayoutSpec,
config: LayoutConfig,
) -> PhysicalLayout {
match spec.kind {
LayoutKind::FC => create_fc_layout_with_config(agent, spec.storage, config),
LayoutKind::LW => create_lw_layout_with_config(agent, spec.storage, config),
}
}
/// Create a TransferManager for testing.
///
/// The manager is always configured with CUDA device id 0.
///
/// # Arguments
/// * `agent` - NIXL agent (should have backends configured)
/// * `capabilities` - Optional transfer capabilities; `None` falls back to
///   `TransferCapabilities::default()`
///
/// # Errors
/// Returns the error produced by the TransferManager builder.
pub fn create_transfer_manager(
    agent: NixlAgent,
    capabilities: Option<TransferCapabilities>,
) -> Result<TransferManager> {
    let caps = capabilities.unwrap_or_default();
    TransferManager::builder()
        .capabilities(caps)
        .nixl_agent(agent)
        .cuda_device_id(0)
        .build()
}
/// Create a TransferManager with UCX backend for RDMA tests.
///
/// The manager is always configured with CUDA device id 0.
///
/// # Arguments
/// * `agent_name` - Unique agent name for RDMA addressing
///
/// Note: The worker_id is derived from the event system. For explicit worker_id
/// control, use the TransferManager builder directly with a custom event system.
///
/// # Errors
/// Returns an error if the UCX-backed agent cannot be created or the
/// TransferManager builder fails.
pub fn create_rdma_transfer_manager(agent_name: &str) -> Result<TransferManager> {
    const RDMA_BACKENDS: [&str; 1] = ["UCX"];

    let ucx_agent = create_test_agent_with_backends(agent_name, &RDMA_BACKENDS)?;
    TransferManager::builder()
        .nixl_agent(ucx_agent)
        .cuda_device_id(0)
        .build()
}
/// Fill the given blocks with `pattern`, then return a checksum per block.
///
/// This can only be called on System or Pinned layouts.
///
/// # Errors
/// Returns the first error from filling or from checksum computation.
pub fn fill_and_checksum(
    layout: &PhysicalLayout,
    block_ids: &[BlockId],
    pattern: FillPattern,
) -> Result<HashMap<BlockId, BlockChecksum>> {
    fill_blocks(layout, block_ids, pattern)
        .and_then(|()| compute_block_checksums(layout, block_ids))
}
/// Fill only the layers in `layer_range` of the given blocks with `pattern`,
/// then return a checksum per block computed over that same layer range.
///
/// This can only be called on System or Pinned layouts.
///
/// # Errors
/// Returns the first error from filling or from checksum computation.
pub fn fill_layers_and_checksum(
    layout: &PhysicalLayout,
    block_ids: &[BlockId],
    layer_range: std::ops::Range<usize>,
    pattern: FillPattern,
) -> Result<HashMap<BlockId, BlockChecksum>> {
    // The range is consumed by each call, so hand the fill step its own copy.
    let fill_range = layer_range.clone();
    fill_layers(layout, block_ids, fill_range, pattern)?;

    let checksums = compute_layer_checksums(layout, block_ids, layer_range)?;
    Ok(checksums)
}
/// Check that each destination block's checksum equals the checksum recorded
/// for the source block at the same position.
///
/// Source and destination id slices are paired index by index
/// (src[i] was transferred to dst[i]).
///
/// # Errors
/// Returns an error if the destination checksums cannot be computed.
///
/// # Panics
/// Panics if the slices differ in length, if a checksum is missing for any
/// block id, or if a paired checksum differs.
pub fn verify_checksums_by_position(
    src_checksums: &HashMap<BlockId, BlockChecksum>,
    src_block_ids: &[BlockId],
    dst_layout: &PhysicalLayout,
    dst_block_ids: &[BlockId],
) -> Result<()> {
    assert_eq!(
        src_block_ids.len(),
        dst_block_ids.len(),
        "Source and destination block arrays must have same length"
    );

    let actual_by_block = compute_block_checksums(dst_layout, dst_block_ids)?;

    for (src_id, dst_id) in src_block_ids.iter().zip(dst_block_ids) {
        let src_checksum = match src_checksums.get(src_id) {
            Some(checksum) => checksum,
            None => panic!("Missing source checksum for block {}", src_id),
        };
        let dst_checksum = match actual_by_block.get(dst_id) {
            Some(checksum) => checksum,
            None => panic!("Missing destination checksum for block {}", dst_id),
        };
        assert_eq!(
            src_checksum, dst_checksum,
            "Checksum mismatch: src[{}] != dst[{}]: {} != {}",
            src_id, dst_id, src_checksum, dst_checksum
        );
    }

    Ok(())
}
/// Check, for the layers in `layer_range` only, that each destination block's
/// checksum equals the checksum recorded for the source block at the same
/// position.
///
/// # Errors
/// Returns an error if the destination layer checksums cannot be computed.
///
/// # Panics
/// Panics if the slices differ in length, if a checksum is missing for any
/// block id, or if a paired checksum differs.
pub fn verify_layer_checksums_by_position(
    src_checksums: &HashMap<BlockId, BlockChecksum>,
    src_block_ids: &[BlockId],
    dst_layout: &PhysicalLayout,
    dst_block_ids: &[BlockId],
    layer_range: std::ops::Range<usize>,
) -> Result<()> {
    assert_eq!(
        src_block_ids.len(),
        dst_block_ids.len(),
        "Source and destination block arrays must have same length"
    );

    let actual_by_block = compute_layer_checksums(dst_layout, dst_block_ids, layer_range)?;

    src_block_ids
        .iter()
        .zip(dst_block_ids.iter())
        .for_each(|(src_id, dst_id)| {
            let src_checksum = match src_checksums.get(src_id) {
                Some(found) => found,
                None => panic!("Missing source checksum for block {}", src_id),
            };
            let dst_checksum = match actual_by_block.get(dst_id) {
                Some(found) => found,
                None => panic!("Missing destination checksum for block {}", dst_id),
            };
            assert_eq!(
                src_checksum, dst_checksum,
                "Checksum mismatch: src[{}] != dst[{}]: {} != {}",
                src_id, dst_id, src_checksum, dst_checksum
            );
        });

    Ok(())
}
/// Fill guard blocks with `pattern` and return their checksums so a later
/// call can confirm they were left untouched.
///
/// Guard blocks are blocks adjacent to transfer destinations that should
/// remain unchanged during transfers.
///
/// # Errors
/// Returns the first error from filling or from checksum computation.
pub fn create_guard_blocks(
    layout: &PhysicalLayout,
    guard_block_ids: &[BlockId],
    pattern: FillPattern,
) -> Result<HashMap<BlockId, BlockChecksum>> {
    fill_blocks(layout, guard_block_ids, pattern)?;

    let baseline = compute_block_checksums(layout, guard_block_ids)?;
    Ok(baseline)
}
/// Recompute guard block checksums and confirm they still match the expected
/// values captured before the transfer.
///
/// # Errors
/// Returns an error if the checksums cannot be computed, or if any guard
/// block's current checksum differs from its expected one.
///
/// # Panics
/// Panics if a guard block id has no entry in `expected_checksums` or in the
/// freshly computed checksum map.
pub fn verify_guard_blocks_unchanged(
    layout: &PhysicalLayout,
    guard_block_ids: &[BlockId],
    expected_checksums: &HashMap<BlockId, BlockChecksum>,
) -> Result<()> {
    let current_checksums = compute_block_checksums(layout, guard_block_ids)?;

    for block_id in guard_block_ids.iter().copied() {
        let expected = match expected_checksums.get(&block_id) {
            Some(checksum) => checksum,
            None => panic!("Missing expected checksum for guard block {}", block_id),
        };
        let current = match current_checksums.get(&block_id) {
            Some(checksum) => checksum,
            None => panic!("Missing current checksum for guard block {}", block_id),
        };

        if expected == current {
            continue;
        }
        anyhow::bail!(
            "Guard block {} was modified during transfer! Expected: {}, Got: {}",
            block_id,
            expected,
            current
        );
    }

    Ok(())
}
// =============================================================================
// TransferManager-based helpers (for registered layouts)
// =============================================================================
/// Fill blocks of a layout registered with a TransferManager.
///
/// Accesses the internal registry directly (only available in-crate) and
/// holds its read lock for the duration of the fill.
/// This can only be called on System or Pinned layouts.
///
/// # Errors
/// Returns an error if `handle` is not present in the registry or if the fill
/// fails.
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn fill_manager_blocks(
    manager: &TransferManager,
    handle: LayoutHandle,
    block_ids: &[BlockId],
    pattern: FillPattern,
) -> Result<()> {
    let registry_guard = manager.registry().read().unwrap();
    match registry_guard.get_layout(handle) {
        Some(layout) => fill_blocks(layout, block_ids, pattern),
        None => Err(anyhow::anyhow!("Layout not found: {:?}", handle)),
    }
}
/// Compute per-block checksums for a layout registered with a TransferManager.
///
/// Accesses the internal registry directly (only available in-crate) and
/// holds its read lock while the checksums are computed.
///
/// # Errors
/// Returns an error if `handle` is not present in the registry or if checksum
/// computation fails.
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn compute_manager_checksums(
    manager: &TransferManager,
    handle: LayoutHandle,
    block_ids: &[BlockId],
) -> Result<HashMap<BlockId, BlockChecksum>> {
    let registry_guard = manager.registry().read().unwrap();
    registry_guard.get_layout(handle).map_or_else(
        || Err(anyhow::anyhow!("Layout not found: {:?}", handle)),
        |layout| compute_block_checksums(layout, block_ids),
    )
}
/// Fill blocks of a registered layout and return their checksums, all under a
/// single registry read lock.
///
/// Accesses the internal registry directly (only available in-crate).
/// This can only be called on System or Pinned layouts.
///
/// # Errors
/// Returns an error if `handle` is not present in the registry, or if filling
/// or checksum computation fails.
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn fill_and_checksum_manager(
    manager: &TransferManager,
    handle: LayoutHandle,
    block_ids: &[BlockId],
    pattern: FillPattern,
) -> Result<HashMap<BlockId, BlockChecksum>> {
    let registry_guard = manager.registry().read().unwrap();
    let layout = match registry_guard.get_layout(handle) {
        Some(found) => found,
        None => return Err(anyhow::anyhow!("Layout not found: {:?}", handle)),
    };

    fill_blocks(layout, block_ids, pattern)?;
    compute_block_checksums(layout, block_ids)
}
/// Smoke tests for the layout and checksum helpers in this module.
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests {
    use super::*;

    /// A fully contiguous layout on system memory reports itself as contiguous.
    #[test]
    fn test_create_fc_layout_system() {
        let agent = create_test_agent("test_fc_system");
        let layout = create_fc_layout(agent, StorageKind::System, 4);
        assert!(layout.layout().as_ref().is_fully_contiguous());
    }

    /// A layer-separate layout on system memory does not report as contiguous.
    #[test]
    fn test_create_lw_layout_system() {
        let agent = create_test_agent("test_lw_system");
        let layout = create_lw_layout(agent, StorageKind::System, 4);
        assert!(!layout.layout().as_ref().is_fully_contiguous());
    }

    /// Filling three blocks yields three checksums, all different.
    #[test]
    fn test_fill_and_checksum() {
        let agent = create_test_agent("test_fill_checksum");
        let layout = create_fc_layout(agent, StorageKind::System, 4);
        let block_ids = vec![0, 1, 2];
        let checksums = fill_and_checksum(&layout, &block_ids, FillPattern::Sequential).unwrap();
        assert_eq!(checksums.len(), 3);

        // Each block should have a unique checksum with sequential pattern.
        // Compare every pair: `a != b || b != c` would still pass when two of
        // the three collide. The pairwise check also does not depend on the
        // map's arbitrary iteration order.
        let values: Vec<_> = checksums.values().collect();
        for (idx, first) in values.iter().enumerate() {
            for second in &values[idx + 1..] {
                assert!(
                    first != second,
                    "Sequential fill produced duplicate block checksums"
                );
            }
        }
    }

    /// `custom_config` forwards every argument to the matching config field.
    #[test]
    fn test_custom_config() {
        let config = custom_config(32, 3, 2, 4, 64, 2);
        assert_eq!(config.num_blocks, 32);
        assert_eq!(config.num_layers, 3);
        assert_eq!(config.outer_dim, 2);
        assert_eq!(config.page_size, 4);
        assert_eq!(config.inner_dim, 64);
        assert_eq!(config.dtype_width_bytes, 2);
    }
}
......@@ -191,9 +191,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "aws-lc-rs"
version = "1.16.0"
version = "1.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9a7b350e3bb1767102698302bc37256cbd48422809984b98d292c40e2579aa9"
checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf"
dependencies = [
"aws-lc-sys",
"zeroize",
......@@ -201,9 +201,9 @@ dependencies = [
[[package]]
name = "aws-lc-sys"
version = "0.37.1"
version = "0.38.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b092fe214090261288111db7a2b2c2118e5a7f30dc2569f1732c4069a6840549"
checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e"
dependencies = [
"cc",
"cmake",
......@@ -3627,9 +3627,9 @@ dependencies = [
[[package]]
name = "tokio-macros"
version = "2.6.0"
version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5"
checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c"
dependencies = [
"proc-macro2",
"quote",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment