Unverified Commit 3998fdcb authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: KVBM V2 Initial Migration (#3861)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
parent e64d2f09
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Tests for the storage-next module.
use super::*;
/// Two live system allocations each report their requested size, the `System`
/// storage kind, a non-zero address, and addresses distinct from one another.
#[test]
fn test_system_storage() {
    let first = SystemStorage::new(1024).unwrap();
    assert_eq!(first.size(), 1024);
    assert_eq!(first.storage_kind(), StorageKind::System);
    assert!(first.addr() != 0);

    // A second allocation made while the first is still alive must not alias it.
    let second = SystemStorage::new(2048).unwrap();
    assert_eq!(second.size(), 2048);
    assert_ne!(first.addr(), second.addr());
}
/// A zero-byte system allocation is rejected with `StorageError::AllocationFailed`.
#[test]
fn test_system_storage_zero_size() {
    let outcome = SystemStorage::new(0);
    assert!(outcome.is_err());

    let error = outcome.unwrap_err();
    assert!(matches!(error, StorageError::AllocationFailed(_)));
}
/// Disk storage created without an explicit path reports the requested size,
/// a `Disk` kind, a zero address, and a backing path that exists.
#[test]
fn test_disk_storage_temp() {
    let disk = DiskStorage::new(4096).unwrap();
    assert_eq!(disk.size(), 4096);

    let kind = disk.storage_kind();
    assert!(matches!(kind, StorageKind::Disk(_)));

    // File-backed storage has no memory address, so `addr()` is expected to be 0.
    assert_eq!(disk.addr(), 0);
    assert!(disk.path().exists());
}
/// Disk storage created at a caller-chosen path materializes the file there.
#[test]
fn test_disk_storage_at_path() {
    let scratch = tempfile::tempdir().unwrap();
    let target = scratch.path().join("test.bin");

    let disk = DiskStorage::new_at(&target, 8192).unwrap();
    assert_eq!(disk.size(), 8192);

    let kind = disk.storage_kind();
    assert!(matches!(kind, StorageKind::Disk(_)));
    assert!(target.exists());
}
/// Erasing a concrete storage into `OwnedMemoryRegion` keeps its size and kind.
#[test]
fn test_type_erasure() {
    let concrete = SystemStorage::new(1024).unwrap();
    let region: OwnedMemoryRegion = erase_storage(concrete);

    assert_eq!(region.size(), 1024);
    assert_eq!(region.storage_kind(), StorageKind::System);
}
/// `MemoryDescriptor::new` stores the address and size it was given.
#[test]
fn test_memory_descriptor() {
    let descriptor = MemoryDescriptor::new(0x1000, 4096);

    assert_eq!(descriptor.addr, 0x1000);
    assert_eq!(descriptor.size, 4096);
}
/// Storage tests that need CUDA; compiled only with the `testing-cuda` feature.
#[cfg(feature = "testing-cuda")]
mod cuda_tests {
    use super::*;

    /// Pinned allocations report their size, the `Pinned` kind and a non-zero address.
    #[test]
    fn test_pinned_storage() {
        let pinned = PinnedStorage::new(2048).unwrap();

        assert_eq!(pinned.size(), 2048);
        assert_eq!(pinned.storage_kind(), StorageKind::Pinned);
        assert!(pinned.addr() != 0);
    }

    /// A zero-byte pinned allocation is rejected with `AllocationFailed`.
    #[test]
    fn test_pinned_storage_zero_size() {
        let outcome = PinnedStorage::new(0);
        assert!(outcome.is_err());

        let error = outcome.unwrap_err();
        assert!(matches!(error, StorageError::AllocationFailed(_)));
    }

    /// Device allocations on device 0 report size, kind, address and device id.
    #[test]
    fn test_device_storage() {
        let device = DeviceStorage::new(4096, 0).unwrap();

        assert_eq!(device.size(), 4096);
        assert_eq!(device.storage_kind(), StorageKind::Device(0));
        assert!(device.addr() != 0);
        assert_eq!(device.device_id(), 0);
    }

    /// A zero-byte device allocation is rejected with `AllocationFailed`.
    #[test]
    fn test_device_storage_zero_size() {
        let outcome = DeviceStorage::new(0, 0);
        assert!(outcome.is_err());

        let error = outcome.unwrap_err();
        assert!(matches!(error, StorageError::AllocationFailed(_)));
    }
}
// Tests for NIXL registration require a real NIXL agent, so they are only
// compiled when the `testing-nixl` feature is enabled. Without that feature
// they are skipped; a mocked agent or integration tests would be the alternative.
/// NIXL registration tests; compiled only with the `testing-nixl` feature.
#[cfg(feature = "testing-nixl")]
mod nixl_tests {
    use super::super::registered::register_with_nixl;
    use super::*;
    use nixl_sys::Agent as NixlAgent;

    /// Registers a pinned allocation with a freshly created agent and checks
    /// that the registered handle reports that agent's name.
    #[test]
    fn test_nixl_registration() {
        let agent_name = "test_agent";

        let pinned = PinnedStorage::new(2048).unwrap();
        let agent = NixlAgent::new(agent_name).unwrap();

        let handle = register_with_nixl(pinned, &agent, None).unwrap();
        assert_eq!(handle.agent_name(), agent_name);
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
/// Device a tensor lives on, as far as this crate distinguishes them.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TorchDevice {
    /// A CUDA device, identified by its index.
    Cuda(usize),
    /// Any other device, described by a string.
    Other(String),
}

impl TorchDevice {
    /// Returns `true` when this is the [`TorchDevice::Cuda`] variant.
    pub fn is_cuda(&self) -> bool {
        self.cuda_device_index().is_some()
    }

    /// Returns the CUDA device index, or `None` for non-CUDA devices.
    pub fn cuda_device_index(&self) -> Option<usize> {
        if let Self::Cuda(ordinal) = self {
            Some(*ordinal)
        } else {
            None
        }
    }
}
/// Read-only view of the tensor properties this crate consumes.
///
/// Implementors must be `Debug`, `Send` and `Sync`.
pub trait TorchTensor: std::fmt::Debug + Send + Sync {
/// Device the tensor reports itself as residing on.
fn device(&self) -> TorchDevice;
/// Data pointer of the tensor, as a raw integer.
// NOTE(review): presumably the address of the first element in the address
// space of `device()` — confirm against implementors.
fn data_ptr(&self) -> u64;
/// Size of the tensor's data in bytes.
fn size_bytes(&self) -> usize;
/// Extent of each tensor dimension.
fn shape(&self) -> Vec<usize>;
/// Stride of each tensor dimension.
// NOTE(review): unit (elements vs. bytes) is not visible here — confirm.
fn stride(&self) -> Vec<usize>;
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Typed builder for constructing [`PhysicalLayout`](crate::block_manager::v2::layout::PhysicalLayout)
//! instances with strongly-typed configuration, layout selection, and memory provisioning.
//!
//! The builder enforces the three steps required to materialize a physical layout:
//! 1. Provide a [`LayoutConfig`]
//! 2. Select a concrete layout (fully contiguous or layer separate)
//! 3. Specify memory backing (either by allocating or by supplying existing regions)
//!
//! NIXL registration is always enabled. Callers must provide a [`nixl_sys::Agent`], and any memory
//! supplied to the builder must implement [`NixlCompatible`].
use crate::block_manager::v2::physical::layout::physical::PhysicalLayout;
use super::{
BlockDimension, FullyContiguousLayout, LayerSeparateLayout, Layout, LayoutConfig, MemoryRegion,
physical::NixlMetadata,
};
use crate::block_manager::v2::memory::{
DiskStorage, NixlCompatible, NixlDescriptor, OffsetMemoryRegion, OwnedMemoryRegion,
RegisteredView, StorageKind, SystemStorage, register_with_nixl,
};
use anyhow::{Result, anyhow, bail};
#[allow(unused_imports)]
use nixl_sys::Agent as RawNixlAgent;
use nixl_sys::MemType;
use std::marker::PhantomData;
use std::path::PathBuf;
use std::sync::Arc;
use crate::block_manager::v2::memory::{DeviceStorage, PinnedStorage};
use crate::block_manager::v2::physical::transfer::nixl_agent::NixlAgent;
/// Alignment, in bytes, passed to `total_allocation_size` and `create_offset_entries`
/// when builder-managed memory is sized and split into per-region views.
const REGION_ALIGNMENT: usize = 512;
/// Layout selection exposed by the builder.
///
/// The variant determines how many backing regions `compute_allocation_sizes` requests.
#[derive(Debug, Clone)]
pub enum LayoutKind {
/// A single region holding every block and layer
/// (`num_blocks * num_layers * outer_dim * page_size * inner_dim * dtype_width_bytes` bytes).
FullyContiguous,
/// One region per layer (`num_layers` regions in total); `block_dim` is forwarded to
/// `LayerSeparateLayout::new`.
LayerSeparate { block_dim: BlockDimension },
}
/// Allocation strategies for builder-managed memory.
///
/// Each variant is dispatched to the matching `allocate_*_entry` helper by `allocate_regions`.
#[derive(Debug, Clone)]
enum AllocationKind {
/// System memory, allocated through `SystemStorage`.
System,
/// Pinned host memory, allocated through `PinnedStorage`.
// NOTE: `numa_aware` is carried through but `allocate_pinned_entry` currently ignores it.
Pinned { numa_aware: bool },
/// Device memory on `device_id`, allocated through `DeviceStorage`.
Device { device_id: u32 },
/// Disk-backed storage at `path`, or at a location chosen by `DiskStorage::new` when `None`.
Disk { path: Option<PathBuf> },
}
/// Memory provisioning plan (either provided regions or an allocation request).
///
/// The plan is resolved into concrete entries by `resolve_memory_plan` during `build`.
#[derive(Debug, Clone)]
enum MemoryPlan {
/// Caller-supplied regions; their count must equal the number of required allocations.
Provided(Vec<MemoryEntry>),
/// Ask the builder to allocate (and register) memory using the given strategy.
Allocate(AllocationKind),
}
/// A backing region paired with the NIXL descriptor known for it, if any.
#[derive(Debug, Clone)]
struct MemoryEntry {
    /// Region that backs (part of) the layout.
    region: OwnedMemoryRegion,
    /// NIXL descriptor for `region`; `None` when the region is not registered.
    descriptor: Option<NixlDescriptor>,
}

impl MemoryEntry {
    /// Pair `region` with an optional, already-known descriptor.
    fn new(region: OwnedMemoryRegion, descriptor: Option<NixlDescriptor>) -> Self {
        Self { region, descriptor }
    }

    /// Fill in a missing descriptor from the region itself.
    ///
    /// # Errors
    /// Outside of test builds, fails when the region exposes no NIXL descriptor.
    /// Test builds accept unregistered regions so local-only layouts can be built.
    fn ensure_registered(self) -> Result<Self> {
        let Self { region, descriptor } = self;

        // Prefer the descriptor captured earlier; otherwise ask the region.
        let descriptor = match descriptor {
            Some(known) => Some(known),
            None => region.nixl_descriptor(),
        };

        #[cfg(not(test))]
        {
            // Production builds require every region to be NIXL registered.
            if descriptor.is_none() {
                bail!(
                    "memory region {} is not registered with NIXL",
                    region.addr()
                );
            }
        }

        Ok(Self { region, descriptor })
    }
}
/// Marker types for the builder state machine.
///
/// Builder state: no [`LayoutConfig`] has been attached yet.
pub struct NoConfig;
/// Builder state: a [`LayoutConfig`] has been attached.
pub struct HasConfig;
/// Builder state: no layout variant has been selected yet.
pub struct NoLayout;
/// Builder state: a layout variant has been selected.
pub struct HasLayout;
/// Builder state: no memory plan has been chosen yet.
pub struct NoMemory;
/// Builder state: a memory plan (provided regions or an allocation request) has been chosen.
pub struct HasMemory;
/// Default builder state type alias.
///
/// This is the state returned by `PhysicalLayoutBuilder::new`: nothing configured yet.
pub type PhysicalLayoutBuilderDefault = PhysicalLayoutBuilder<NoConfig, NoLayout, NoMemory>;
/// Typed builder enforcing configuration, layout selection, and memory provisioning phases.
///
/// `C`, `L` and `M` are compile-time markers (`NoConfig`/`HasConfig`, `NoLayout`/`HasLayout`,
/// `NoMemory`/`HasMemory`) that restrict which methods are callable; `build` is only
/// available once all three are in their `Has*` state.
pub struct PhysicalLayoutBuilder<C, L, M> {
// Agent used to register memory and to name the resulting NIXL metadata.
agent: NixlAgent,
// Set by `with_config`.
config: Option<LayoutConfig>,
// Set by `fully_contiguous` / `layer_separate`.
layout_kind: Option<LayoutKind>,
// Set by the `allocate_*` / `with_*_regions` methods.
memory_plan: Option<MemoryPlan>,
// Zero-sized markers carrying the type-state parameters.
_config: PhantomData<C>,
_layout: PhantomData<L>,
_memory: PhantomData<M>,
}
impl PhysicalLayoutBuilder<NoConfig, NoLayout, NoMemory> {
/// Create a new builder in its initial state.
pub fn new(agent: NixlAgent) -> Self {
Self {
agent,
config: None,
layout_kind: None,
memory_plan: None,
_config: PhantomData,
_layout: PhantomData,
_memory: PhantomData,
}
}
}
impl<C, L, M> PhysicalLayoutBuilder<C, L, M> {
    /// Take the builder apart, discarding its type-state markers.
    fn into_parts(
        self,
    ) -> (
        NixlAgent,
        Option<LayoutConfig>,
        Option<LayoutKind>,
        Option<MemoryPlan>,
    ) {
        let Self {
            agent,
            config,
            layout_kind,
            memory_plan,
            ..
        } = self;
        (agent, config, layout_kind, memory_plan)
    }

    /// Reassemble a builder in an arbitrary type state (`C2`, `L2`, `M2`) from raw parts.
    fn from_parts<C2, L2, M2>(
        agent: NixlAgent,
        config: Option<LayoutConfig>,
        layout_kind: Option<LayoutKind>,
        memory_plan: Option<MemoryPlan>,
    ) -> PhysicalLayoutBuilder<C2, L2, M2> {
        PhysicalLayoutBuilder {
            _config: PhantomData,
            _layout: PhantomData,
            _memory: PhantomData,
            agent,
            config,
            layout_kind,
            memory_plan,
        }
    }
}
impl<L, M> PhysicalLayoutBuilder<NoConfig, L, M> {
/// Attach the [`LayoutConfig`] required to size the layout and allocations.
pub fn with_config(self, config: LayoutConfig) -> PhysicalLayoutBuilder<HasConfig, L, M> {
let (agent, _config, layout_kind, memory_plan) = self.into_parts();
PhysicalLayoutBuilder::<HasConfig, L, M>::from_parts(
agent,
Some(config),
layout_kind,
memory_plan,
)
}
}
impl<M> PhysicalLayoutBuilder<HasConfig, NoLayout, M> {
/// Select the fully contiguous layout variant.
pub fn fully_contiguous(self) -> PhysicalLayoutBuilder<HasConfig, HasLayout, M> {
let (agent, config, _layout, memory_plan) = self.into_parts();
PhysicalLayoutBuilder::<HasConfig, HasLayout, M>::from_parts(
agent,
config,
Some(LayoutKind::FullyContiguous),
memory_plan,
)
}
/// Select the layer-separate layout variant with the provided block dimension ordering.
pub fn layer_separate(
self,
block_dim: BlockDimension,
) -> PhysicalLayoutBuilder<HasConfig, HasLayout, M> {
let (agent, config, _layout, memory_plan) = self.into_parts();
PhysicalLayoutBuilder::<HasConfig, HasLayout, M>::from_parts(
agent,
config,
Some(LayoutKind::LayerSeparate { block_dim }),
memory_plan,
)
}
}
impl PhysicalLayoutBuilder<HasConfig, HasLayout, NoMemory> {
fn set_memory_plan(
self,
plan: MemoryPlan,
) -> PhysicalLayoutBuilder<HasConfig, HasLayout, HasMemory> {
let (agent, config, layout_kind, _memory) = self.into_parts();
PhysicalLayoutBuilder::<HasConfig, HasLayout, HasMemory>::from_parts(
agent,
config,
layout_kind,
Some(plan),
)
}
pub fn allocate_system(self) -> PhysicalLayoutBuilder<HasConfig, HasLayout, HasMemory> {
self.set_memory_plan(MemoryPlan::Allocate(AllocationKind::System))
}
/// Allocate pinned (page-locked) host memory.
pub fn allocate_pinned(
self,
numa_aware: bool,
) -> PhysicalLayoutBuilder<HasConfig, HasLayout, HasMemory> {
self.set_memory_plan(MemoryPlan::Allocate(AllocationKind::Pinned { numa_aware }))
}
/// Allocate device memory on the specified CUDA device (or the context device if `None`).
pub fn allocate_device(
self,
device_id: u32,
) -> PhysicalLayoutBuilder<HasConfig, HasLayout, HasMemory> {
self.set_memory_plan(MemoryPlan::Allocate(AllocationKind::Device { device_id }))
}
/// Allocate disk-backed storage. When `path` is `None`, a temporary file is used.
pub fn allocate_disk(
self,
path: Option<PathBuf>,
) -> PhysicalLayoutBuilder<HasConfig, HasLayout, HasMemory> {
self.set_memory_plan(MemoryPlan::Allocate(AllocationKind::Disk { path }))
}
/// Use existing NIXL-compatible memory regions supplied by the caller.
pub fn with_memory_regions<S>(
self,
regions: Vec<S>,
) -> Result<PhysicalLayoutBuilder<HasConfig, HasLayout, HasMemory>>
where
S: MemoryRegion + NixlCompatible + 'static,
{
let (agent, config, layout_kind, _memory) = self.into_parts();
let entries = register_existing_regions(&agent, regions)?;
Ok(
PhysicalLayoutBuilder::<HasConfig, HasLayout, HasMemory>::from_parts(
agent,
config,
layout_kind,
Some(MemoryPlan::Provided(entries)),
),
)
}
/// Use pre-registered memory regions (already wrapped in `Arc<dyn MemoryRegion>`).
///
/// All regions must already expose a NIXL descriptor.
pub fn with_registered_regions(
self,
regions: Vec<OwnedMemoryRegion>,
) -> Result<PhysicalLayoutBuilder<HasConfig, HasLayout, HasMemory>> {
let entries = regions
.into_iter()
.enumerate()
.map(|(index, region)| {
let descriptor = region.nixl_descriptor().ok_or_else(|| {
anyhow!(
"provided memory region at index {} is not NIXL registered",
index
)
})?;
Ok(MemoryEntry::new(region, Some(descriptor)))
})
.collect::<Result<Vec<_>>>()?;
let (agent, config, layout_kind, _memory) = self.into_parts();
Ok(
PhysicalLayoutBuilder::<HasConfig, HasLayout, HasMemory>::from_parts(
agent,
config,
layout_kind,
Some(MemoryPlan::Provided(entries)),
),
)
}
}
impl PhysicalLayoutBuilder<HasConfig, HasLayout, HasMemory> {
/// Finalize the builder, constructing the [`PhysicalLayout`].
pub fn build(self) -> Result<PhysicalLayout> {
let (agent, config, layout_kind, memory_plan) = self.into_parts();
let config = config.ok_or_else(|| anyhow!("layout config missing despite type state"))?;
let layout_kind =
layout_kind.ok_or_else(|| anyhow!("layout kind missing despite type state"))?;
let memory_plan =
memory_plan.ok_or_else(|| anyhow!("memory plan missing despite type state"))?;
let required_sizes = compute_allocation_sizes(&config, &layout_kind)?;
let entries = resolve_memory_plan(&agent, memory_plan, &required_sizes)?;
validate_memory_sizes(&entries, &required_sizes)?;
let kind = derive_storage_kind(&entries)?;
let metadata = derive_nixl_metadata(&agent, &entries)?;
let layout: Arc<dyn Layout> = match layout_kind {
LayoutKind::FullyContiguous => {
let entry = entries.first().ok_or_else(|| {
anyhow!("fully contiguous layout requires a single memory region")
})?;
let layout = FullyContiguousLayout::new(config.clone(), Arc::clone(&entry.region))?;
Arc::new(layout)
}
LayoutKind::LayerSeparate { block_dim } => {
let regions: Vec<OwnedMemoryRegion> = entries
.iter()
.map(|entry| Arc::clone(&entry.region))
.collect();
let layout = LayerSeparateLayout::new(config.clone(), regions, block_dim)?;
Arc::new(layout)
}
};
Ok(PhysicalLayout::new_local(layout, kind, metadata))
}
}
/// Register every caller-supplied region with `agent`, stopping at the first failure.
fn register_existing_regions<S>(agent: &NixlAgent, regions: Vec<S>) -> Result<Vec<MemoryEntry>>
where
    S: MemoryRegion + NixlCompatible + 'static,
{
    let mut entries = Vec::with_capacity(regions.len());
    for region in regions {
        entries.push(register_storage(region, agent)?);
    }
    Ok(entries)
}
/// Turn a [`MemoryPlan`] into one memory entry per required allocation.
///
/// * `Allocate` plans are delegated to `allocate_regions`.
/// * `Provided` plans must contain exactly `sizes.len()` entries, each of which is
///   passed through `MemoryEntry::ensure_registered`.
fn resolve_memory_plan(
    agent: &NixlAgent,
    plan: MemoryPlan,
    sizes: &[usize],
) -> Result<Vec<MemoryEntry>> {
    let provided = match plan {
        MemoryPlan::Allocate(strategy) => return allocate_regions(agent, strategy, sizes),
        MemoryPlan::Provided(entries) => entries,
    };

    if provided.len() != sizes.len() {
        bail!(
            "provided memory count ({}) does not match required allocations ({})",
            provided.len(),
            sizes.len()
        );
    }

    let mut resolved = Vec::with_capacity(provided.len());
    for entry in provided {
        resolved.push(entry.ensure_registered()?);
    }
    Ok(resolved)
}
/// Allocate one backing region large enough for all `sizes` (including alignment
/// padding) and split it into one entry per requested size.
///
/// Returns an empty vector without allocating when `sizes` is empty.
fn allocate_regions(
    agent: &NixlAgent,
    strategy: AllocationKind,
    sizes: &[usize],
) -> Result<Vec<MemoryEntry>> {
    if sizes.is_empty() {
        return Ok(Vec::new());
    }

    let total_bytes = total_allocation_size(sizes, REGION_ALIGNMENT)?;

    let backing = match strategy {
        AllocationKind::System => allocate_system_entry(total_bytes, agent),
        AllocationKind::Pinned { numa_aware } => {
            allocate_pinned_entry(total_bytes, agent, numa_aware)
        }
        AllocationKind::Device { device_id } => {
            allocate_device_entry(total_bytes, agent, device_id)
        }
        AllocationKind::Disk { path } => allocate_disk_entry(total_bytes, agent, path),
    }?;

    create_offset_entries(backing, sizes, REGION_ALIGNMENT)
}
/// Allocate `size` bytes of system memory and register it with `agent`.
fn allocate_system_entry(size: usize, agent: &NixlAgent) -> Result<MemoryEntry> {
    match SystemStorage::new(size) {
        Ok(storage) => register_storage(storage, agent),
        Err(e) => Err(anyhow!(
            "failed to allocate system memory ({size} bytes): {e}"
        )),
    }
}
/// Allocate `size` bytes of pinned host memory and register it with `agent`.
///
/// `_numa_aware` is accepted but currently has no effect on the allocation.
fn allocate_pinned_entry(size: usize, agent: &NixlAgent, _numa_aware: bool) -> Result<MemoryEntry> {
    match PinnedStorage::new(size) {
        Ok(storage) => register_storage(storage, agent),
        Err(e) => Err(anyhow!(
            "failed to allocate pinned memory ({size} bytes): {e}"
        )),
    }
}
/// Allocate `size` bytes on device `device_id` and register the memory with `agent`.
fn allocate_device_entry(size: usize, agent: &NixlAgent, device_id: u32) -> Result<MemoryEntry> {
    match DeviceStorage::new(size, device_id) {
        Ok(storage) => register_storage(storage, agent),
        Err(e) => Err(anyhow!(
            "failed to allocate device memory ({size} bytes) on device {device_id}: {e}"
        )),
    }
}
/// Allocate `size` bytes of disk-backed storage and register it with `agent`.
///
/// When `path` is given the storage is created there; otherwise `DiskStorage::new`
/// picks the location.
fn allocate_disk_entry(
    size: usize,
    agent: &NixlAgent,
    path: Option<PathBuf>,
) -> Result<MemoryEntry> {
    let storage = match path {
        Some(location) => DiskStorage::new_at(&location, size).map_err(|e| {
            anyhow!(
                "failed to allocate disk storage at {}: {e}",
                location.display()
            )
        })?,
        None => {
            DiskStorage::new(size).map_err(|e| anyhow!("failed to allocate disk storage: {e}"))?
        }
    };
    register_storage(storage, agent)
}
// Test builds may leave storage unregistered: setting up NIXL + UCX is expensive, so
// registration only happens when the agent actually has a backend that can use it.
#[cfg(test)]
fn register_storage<S>(storage: S, agent: &NixlAgent) -> Result<MemoryEntry>
where
    S: MemoryRegion + NixlCompatible + 'static,
{
    let has = |backend: &str| agent.has_backend(backend);

    // Decide whether any backend available on the agent can make use of this storage kind.
    let should_register = match storage.storage_kind() {
        // Host memory: UCX or POSIX.
        StorageKind::System | StorageKind::Pinned => has("UCX") || has("POSIX"),
        // Device memory: UCX or GDS.
        StorageKind::Device(_) => has("UCX") || has("GDS_MT"),
        // Disk storage: POSIX or GDS.
        StorageKind::Disk(_) => has("POSIX") || has("GDS_MT"),
    };

    if should_register {
        let registered = register_with_nixl(storage, agent.raw_agent(), None).map_err(
            |_storage| anyhow!("failed to register memory with NIXL agent {}", agent.name()),
        )?;
        let descriptor = registered.descriptor();
        let region: OwnedMemoryRegion = Arc::new(registered);
        Ok(MemoryEntry::new(region, Some(descriptor)))
    } else {
        // No usable backend: keep the storage unregistered (no descriptor).
        let region: OwnedMemoryRegion = Arc::new(storage);
        Ok(MemoryEntry::new(region, None))
    }
}
// Production builds always register
#[cfg(not(test))]
fn register_storage<S>(storage: S, agent: &NixlAgent) -> Result<MemoryEntry>
where
S: MemoryRegion + NixlCompatible + 'static,
{
// Production builds always register for safety
match register_with_nixl(storage, agent.raw_agent(), None) {
Ok(registered) => {
let descriptor = registered.descriptor();
let region: OwnedMemoryRegion = Arc::new(registered);
Ok(MemoryEntry::new(region, Some(descriptor)))
}
Err(_storage) => bail!("failed to register memory with NIXL agent {}", agent.name()),
}
}
fn create_offset_entries(
base_entry: MemoryEntry,
sizes: &[usize],
alignment: usize,
) -> Result<Vec<MemoryEntry>> {
if sizes.is_empty() {
return Ok(Vec::new());
}
let base_region = base_entry.region;
let base_descriptor = base_entry.descriptor;
let base_addr = base_region.addr();
let base_len = base_region.size();
let mut entries = Vec::with_capacity(sizes.len());
let mut offset = 0usize;
for (index, &size) in sizes.iter().enumerate() {
let region = if index == 0 && offset == 0 && size == base_len && sizes.len() == 1 {
Arc::clone(&base_region)
} else {
let view = OffsetMemoryRegion::new(Arc::clone(&base_region), offset, size)
.map_err(|e| anyhow!("failed to create offset region: {e}"))?;
Arc::new(view) as OwnedMemoryRegion
};
let descriptor = base_descriptor
.as_ref()
.map(|descriptor| derive_descriptor(descriptor, offset, size))
.transpose()?;
entries.push(MemoryEntry::new(region, descriptor));
offset = offset
.checked_add(size)
.ok_or_else(|| anyhow!("offset computation overflow"))?;
if index + 1 < sizes.len() && alignment > 1 {
let current_addr = base_addr
.checked_add(offset)
.ok_or_else(|| anyhow!("address computation overflow"))?;
let aligned_addr = align_up(current_addr, alignment)?;
offset = aligned_addr
.checked_sub(base_addr)
.ok_or_else(|| anyhow!("alignment subtraction overflow"))?;
}
}
if offset > base_len {
bail!(
"allocated base region ({base_len} bytes) is insufficient for {offset} bytes with padding"
);
}
Ok(entries)
}
fn derive_descriptor(base: &NixlDescriptor, offset: usize, size: usize) -> Result<NixlDescriptor> {
let mut descriptor = base.clone();
descriptor.size = size;
if descriptor.mem_type != MemType::File {
descriptor.addr = descriptor
.addr
.checked_add(offset as u64)
.ok_or_else(|| anyhow!("descriptor address overflow"))?;
}
Ok(descriptor)
}
/// Compute the byte size of each backing region required by `kind`.
///
/// * Fully contiguous: one region of
///   `num_blocks * num_layers * outer_dim * page_size * inner_dim * dtype_width_bytes`.
/// * Layer separate: `num_layers` regions of
///   `num_blocks * outer_dim * page_size * inner_dim * dtype_width_bytes` each.
///
/// # Errors
/// Fails when a product overflows `usize`.
fn compute_allocation_sizes(config: &LayoutConfig, kind: &LayoutKind) -> Result<Vec<usize>> {
    let sizes = match kind {
        LayoutKind::FullyContiguous => {
            let total = mul_chain(&[
                config.num_blocks,
                config.num_layers,
                config.outer_dim,
                config.page_size,
                config.inner_dim,
                config.dtype_width_bytes,
            ])?;
            vec![total]
        }
        LayoutKind::LayerSeparate { .. } => {
            let per_layer = mul_chain(&[
                config.num_blocks,
                config.outer_dim,
                config.page_size,
                config.inner_dim,
                config.dtype_width_bytes,
            ])?;
            std::iter::repeat(per_layer)
                .take(config.num_layers)
                .collect()
        }
    };
    Ok(sizes)
}
fn mul_chain(factors: &[usize]) -> Result<usize> {
factors.iter().try_fold(1usize, |acc, &value| {
acc.checked_mul(value)
.ok_or_else(|| anyhow!("allocation size overflow during layout computation"))
})
}
/// Compute how many bytes a single backing allocation needs to hold all `sizes`.
///
/// The first region contributes its size only. Every subsequent region contributes its
/// size plus, when `alignment > 1`, a worst-case padding of `alignment - 1` bytes so the
/// region can be shifted to an aligned address.
///
/// Returns `0` for an empty `sizes`.
///
/// # Errors
/// Fails when the running total overflows `usize`.
fn total_allocation_size(sizes: &[usize], alignment: usize) -> Result<usize> {
    // `split_first` handles the empty case directly, so no unreachable
    // "at least one region" error branch is needed.
    let (&head, tail) = match sizes.split_first() {
        Some(parts) => parts,
        None => return Ok(0),
    };

    tail.iter().try_fold(head, |running, &size| {
        let with_region = running
            .checked_add(size)
            .ok_or_else(|| anyhow!("allocation size overflow during aggregation"))?;
        if alignment > 1 {
            with_region
                .checked_add(alignment - 1)
                .ok_or_else(|| anyhow!("allocation alignment padding overflow"))
        } else {
            Ok(with_region)
        }
    })
}
/// Round `value` up to the next multiple of `alignment`.
///
/// Alignments of `0` or `1` leave `value` unchanged.
///
/// # Errors
/// Fails when rounding up would overflow `usize`.
fn align_up(value: usize, alignment: usize) -> Result<usize> {
    if alignment <= 1 {
        return Ok(value);
    }

    match value % alignment {
        0 => Ok(value),
        excess => match value.checked_add(alignment - excess) {
            Some(aligned) => Ok(aligned),
            None => Err(anyhow!("alignment overflow")),
        },
    }
}
/// Check that there is one entry per required allocation and that each entry is at
/// least as large as the size required at the same position.
///
/// # Errors
/// Fails when the number of entries differs from the number of required sizes, or
/// when any entry is smaller than its required size.
fn validate_memory_sizes(entries: &[MemoryEntry], required: &[usize]) -> Result<()> {
    // `zip` stops at the shorter input, so without this check a missing region
    // would silently pass validation.
    if entries.len() != required.len() {
        bail!(
            "memory region count ({}) does not match required allocations ({})",
            entries.len(),
            required.len()
        );
    }

    for (entry, &required_size) in entries.iter().zip(required.iter()) {
        let available = entry.region.size();
        if available < required_size {
            bail!(
                "memory region too small: required {} bytes, available {} bytes",
                required_size,
                available
            );
        }
    }
    Ok(())
}
/// Return the storage kind shared by all `entries`.
///
/// # Errors
/// Fails when `entries` is empty or when any entry's kind differs from the first one's.
fn derive_storage_kind(entries: &[MemoryEntry]) -> Result<StorageKind> {
    let (head, rest) = match entries.split_first() {
        Some(parts) => parts,
        None => bail!("no memory regions available to determine storage location"),
    };
    let expected = head.region.storage_kind();

    let mismatch = rest
        .iter()
        .map(|entry| entry.region.storage_kind())
        .find(|kind| *kind != expected);

    match mismatch {
        Some(other) => bail!(
            "all memory regions must share the same storage location (found {:?} and {:?})",
            expected,
            other
        ),
        None => Ok(expected),
    }
}
/// Build the layout's [`NixlMetadata`] from the first entry that carries a descriptor.
///
/// Test builds fall back to placeholder values derived from the first entry's storage
/// kind when no entry is registered; other builds treat that situation as an error.
fn derive_nixl_metadata(agent: &NixlAgent, entries: &[MemoryEntry]) -> Result<NixlMetadata> {
    // First descriptor found, if any entry is registered.
    let descriptor_opt = entries
        .iter()
        .filter_map(|entry| entry.descriptor.as_ref())
        .next()
        .cloned();

    #[cfg(test)]
    {
        let (mem_type, device_id) = match descriptor_opt {
            Some(descriptor) => (descriptor.mem_type, descriptor.device_id),
            None => {
                // Unregistered layout: synthesize metadata for local-only use.
                let first_entry = match entries.first() {
                    Some(entry) => entry,
                    None => bail!("no memory entries"),
                };
                match first_entry.region.storage_kind() {
                    StorageKind::System | StorageKind::Pinned => (MemType::Dram, 0),
                    StorageKind::Device(id) => (MemType::Vram, id as u64),
                    StorageKind::Disk(id) => (MemType::File, id),
                }
            }
        };
        Ok(NixlMetadata::new(
            agent.name().to_string(),
            mem_type,
            device_id,
        ))
    }
    #[cfg(not(test))]
    {
        match descriptor_opt {
            Some(descriptor) => Ok(NixlMetadata::new(
                agent.name().to_string(),
                descriptor.mem_type,
                descriptor.device_id,
            )),
            None => Err(anyhow!(
                "memory entries missing NIXL registration metadata"
            )),
        }
    }
}
/// Builder tests that need a NIXL agent; compiled only for test builds with the
/// `testing-nixl` feature.
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::super::{BlockDimension, LayoutConfig};
    use super::*;
    use crate::block_manager::v2::memory::{MemoryRegion, OwnedMemoryRegion, StorageKind};
    use nixl_sys::MemType;
    use std::any::Any;
    use std::sync::Arc;

    /// Heap-backed region that claims to be NIXL registered by always returning a descriptor.
    #[derive(Debug)]
    struct TestRegisteredRegion {
        data: Vec<u8>,
        kind: StorageKind,
        descriptor: NixlDescriptor,
    }

    impl TestRegisteredRegion {
        /// Allocate `size` zeroed bytes and describe them with the given metadata.
        fn new(size: usize, kind: StorageKind, mem_type: MemType, device_id: u64) -> Self {
            let data = vec![0u8; size];
            let descriptor = NixlDescriptor {
                addr: data.as_ptr() as u64,
                size,
                mem_type,
                device_id,
            };
            Self {
                data,
                kind,
                descriptor,
            }
        }
    }

    impl MemoryRegion for TestRegisteredRegion {
        fn addr(&self) -> usize {
            self.data.as_ptr() as usize
        }

        fn size(&self) -> usize {
            self.data.len()
        }

        fn storage_kind(&self) -> StorageKind {
            self.kind
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
            Some(self.descriptor.clone())
        }
    }

    /// Small configuration shared by the tests below.
    fn make_layout_config() -> LayoutConfig {
        LayoutConfig::builder()
            .num_blocks(2)
            .num_layers(3)
            .outer_dim(2)
            .page_size(4)
            .inner_dim(8)
            .dtype_width_bytes(2)
            .build()
            .unwrap()
    }

    /// Bytes needed when every block and layer lives in one region.
    fn fully_contiguous_size(cfg: &LayoutConfig) -> usize {
        [
            cfg.num_blocks,
            cfg.num_layers,
            cfg.outer_dim,
            cfg.page_size,
            cfg.inner_dim,
            cfg.dtype_width_bytes,
        ]
        .iter()
        .product()
    }

    /// Bytes needed for a single layer's region.
    fn per_layer_size(cfg: &LayoutConfig) -> usize {
        [
            cfg.num_blocks,
            cfg.outer_dim,
            cfg.page_size,
            cfg.inner_dim,
            cfg.dtype_width_bytes,
        ]
        .iter()
        .product()
    }

    /// A "registered" system-memory DRAM region of `size` bytes on device 0.
    fn system_region(size: usize) -> OwnedMemoryRegion {
        Arc::new(TestRegisteredRegion::new(
            size,
            StorageKind::System,
            MemType::Dram,
            0,
        )) as OwnedMemoryRegion
    }

    #[test]
    fn builds_fully_contiguous_from_registered_regions() {
        let agent = NixlAgent::require_backends("builder-test-fully", &[])
            .expect("failed to create wrapped agent");
        let cfg = make_layout_config();
        let backing = system_region(fully_contiguous_size(&cfg));

        let physical = PhysicalLayoutBuilder::new(agent.clone())
            .with_config(cfg.clone())
            .fully_contiguous()
            .with_registered_regions(vec![backing])
            .expect("registered regions accepted")
            .build()
            .expect("builder should succeed");

        assert_eq!(physical.location(), StorageKind::System);
        assert!(physical.layout().as_ref().is_fully_contiguous());
        assert_eq!(physical.layout().config().num_blocks, cfg.num_blocks);
        assert_eq!(physical.layout().config().num_layers, cfg.num_layers);

        let metadata = physical.nixl_metadata();
        assert_eq!(metadata.agent_name(), agent.name());
        assert_eq!(metadata.mem_type(), MemType::Dram);
    }

    #[test]
    fn builds_layer_separate_from_registered_regions() {
        let agent = NixlAgent::require_backends("builder-test-layer", &[])
            .expect("failed to create wrapped agent");
        let cfg = make_layout_config();
        let layer_bytes = per_layer_size(&cfg);

        // One independently allocated region per layer.
        let mut layer_regions: Vec<OwnedMemoryRegion> = Vec::with_capacity(cfg.num_layers);
        for _ in 0..cfg.num_layers {
            layer_regions.push(system_region(layer_bytes));
        }

        let physical = PhysicalLayoutBuilder::new(agent.clone())
            .with_config(cfg.clone())
            .layer_separate(BlockDimension::BlockIsFirstDim)
            .with_registered_regions(layer_regions)
            .expect("registered layer regions accepted")
            .build()
            .expect("builder should succeed");

        assert_eq!(physical.location(), StorageKind::System);
        assert!(!physical.layout().as_ref().is_fully_contiguous());
        assert_eq!(physical.layout().config().num_layers, cfg.num_layers);

        let metadata = physical.nixl_metadata();
        assert_eq!(metadata.agent_name(), agent.name());
        assert_eq!(metadata.mem_type(), MemType::Dram);
    }
}
// fn context_device_id(ctx: &TransferContext) -> u32 {
// ctx.stream().context().ordinal() as u32
// }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError};
use super::InnerShape;
/// Configuration for block layouts.
///
/// Construct it through [`LayoutConfig::builder`]; the range and custom rules below are
/// enforced when `validate()` is called on the value.
#[derive(Debug, Clone, Builder, Validate, Serialize, Deserialize, PartialEq, Eq)]
pub struct LayoutConfig {
/// Number of blocks (at least 1).
#[validate(range(min = 1))]
pub num_blocks: usize,
/// Number of layers (at least 1).
#[validate(range(min = 1))]
pub num_layers: usize,
/// Number of outer dimensions (1 or 2).
#[validate(range(min = 1, max = 2))]
pub outer_dim: usize,
/// Page size (at least 1).
#[validate(range(min = 1))]
pub page_size: usize,
/// Inner dimension (at least 1).
#[validate(range(min = 1))]
pub inner_dim: usize,
/// Alignment; must be a power of two. Defaults to 1.
#[validate(custom(function = "validate_power_of_2"))]
#[builder(default = "1")]
pub alignment: usize,
/// Width of the data type in bytes; checked by `validate_dtype_width_bytes`. Defaults to 2.
#[validate(custom(function = "validate_dtype_width_bytes"))]
#[builder(default = "2")]
pub dtype_width_bytes: usize,
/// Inner shape format (NHD, HND, or Unknown). Defaults to `InnerShape::Unknown`.
#[builder(default = "InnerShape::Unknown")]
pub inner_shape: InnerShape,
}
impl LayoutConfig {
    /// Start building a [`LayoutConfig`].
    pub fn builder() -> LayoutConfigBuilder {
        LayoutConfigBuilder::default()
    }

    /// Total bytes described by this configuration:
    /// `num_blocks * num_layers * outer_dim * page_size * inner_dim * dtype_width_bytes`.
    ///
    /// The product saturates at `usize::MAX` instead of overflowing.
    pub fn required_bytes(&self) -> usize {
        let remaining_factors = [
            self.num_layers,
            self.outer_dim,
            self.page_size,
            self.inner_dim,
            self.dtype_width_bytes,
        ];
        remaining_factors
            .iter()
            .fold(self.num_blocks, |bytes, &factor| bytes.saturating_mul(factor))
    }
}
/// Which of the first two tensor dimensions, `shape[0]` or `shape[1]`, is the block
/// dimension; the other one is the outer dimension.
///
/// The outer dimension is typically:
/// - 1: MLA, or K and V stored together,
/// - 2: K and V stored separately.
///
/// The block dimension tells us the number of blocks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BlockDimension {
/// The block dimension is the first dimension of the tensor, `[n_blocks, outer_dim, inner_dim]`.
BlockIsFirstDim,
/// The block dimension is the second dimension of the tensor, `[outer_dim, n_blocks, inner_dim]`.
/// This replaces v1's `outer_contiguous == true`.
BlockIsSecondDim,
}
/// Validation function for Option<usize> to check if it's Some(power_of_2).
pub fn validate_power_of_2(alignment: usize) -> Result<(), ValidationError> {
if !alignment.is_power_of_two() {
// Return validation error if alignment is not a power of 2
return Err(validator::ValidationError::new(
"alignment_must_be_power_of_2",
));
}
// Passes validation if alignment is a power of 2
Ok(())
}
pub fn validate_dtype_width_bytes(dtype_width_bytes: usize) -> Result<(), ValidationError> {
if !dtype_width_bytes.is_power_of_two() || !(2..=8).contains(&dtype_width_bytes) {
return Err(validator::ValidationError::new(
"dtype_width_bytes_must_be_power_of_two_and_less_than_8_bytes",
));
}
Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Fully contiguous layout implementation.
//!
//! This layout stores all blocks in a single contiguous memory allocation
//! with the shape: [num_blocks, num_layers, outer_dim, page_size, inner_dim].
use anyhow::{Result, anyhow};
use std::sync::Arc;
use validator::Validate;
use super::serialize::{BlockFormat, FullyContiguousDetails, LayoutTypeDetails};
use super::{Layout, LayoutConfig, MemoryDescriptor, MemoryRegion, OwnedMemoryRegion};
/// Fully contiguous layout where all blocks are in a single allocation.
///
/// Addresses are computed as
/// `base_addr + block * block_stride + layer * layer_stride + outer * outer_stride`.
#[derive(Debug)]
pub struct FullyContiguousLayout {
    config: LayoutConfig,
    /// Base address of the allocation
    base_addr: usize,
    /// Stride between blocks in bytes
    block_stride: usize,
    /// Stride between layers in bytes
    layer_stride: usize,
    /// Stride between outer dimensions in bytes
    outer_stride: usize,
    /// Size of each memory region (page) in bytes
    region_size: usize,
    /// Owned memory region backing this layout
    memory: Arc<dyn MemoryRegion>,
    /// Format of blocks in memory
    block_format: BlockFormat,
}

impl FullyContiguousLayout {
    /// Multiply two size factors, failing instead of wrapping when the product does not
    /// fit in `usize`.
    fn checked_product(lhs: usize, rhs: usize) -> Result<usize> {
        lhs.checked_mul(rhs).ok_or_else(|| {
            anyhow!(
                "Layout size computation overflowed usize ({} * {})",
                lhs,
                rhs
            )
        })
    }

    /// Create a new fully contiguous layout.
    ///
    /// # Arguments
    /// * `config` - Layout configuration
    /// * `memory` - Owned memory region that backs this layout
    ///
    /// # Returns
    /// A new FullyContiguousLayout instance
    ///
    /// # Errors
    /// Fails when `config` does not validate, when the stride or total size computation
    /// overflows `usize`, or when `memory` is smaller than the layout requires.
    pub fn new(config: LayoutConfig, memory: Arc<dyn MemoryRegion>) -> Result<Self> {
        config.validate()?;
        let base_addr = memory.addr();

        // Strides are computed with checked arithmetic: a wrapped product would make
        // the size check below pass for a region that is actually too small.
        let region_size = Self::checked_product(
            Self::checked_product(config.page_size, config.inner_dim)?,
            config.dtype_width_bytes,
        )?;
        let outer_stride = region_size;
        let layer_stride = Self::checked_product(outer_stride, config.outer_dim)?;
        let block_stride = Self::checked_product(layer_stride, config.num_layers)?;

        // Validate that the memory region is large enough
        let required_size = Self::checked_product(block_stride, config.num_blocks)?;
        if memory.size() < required_size {
            return Err(anyhow!(
                "Memory region too small for layout. Required: {} bytes, got: {} bytes",
                required_size,
                memory.size()
            ));
        }

        Ok(Self {
            config,
            base_addr,
            block_stride,
            layer_stride,
            outer_stride,
            region_size,
            memory,
            block_format: BlockFormat::default(),
        })
    }

    /// Create a new fully contiguous layout with a specific block format.
    ///
    /// # Arguments
    /// * `config` - Layout configuration
    /// * `memory` - Owned memory region that backs this layout
    /// * `block_format` - Format of blocks in memory
    ///
    /// # Returns
    /// A new FullyContiguousLayout instance
    ///
    /// # Errors
    /// Same failure conditions as [`FullyContiguousLayout::new`].
    pub(crate) fn new_with_format(
        config: LayoutConfig,
        memory: Arc<dyn MemoryRegion>,
        block_format: BlockFormat,
    ) -> Result<Self> {
        let mut layout = Self::new(config, memory)?;
        layout.block_format = block_format;
        Ok(layout)
    }

    /// Get the block format.
    pub fn block_format(&self) -> BlockFormat {
        self.block_format
    }

    /// Calculate the address of a specific memory region.
    ///
    /// # Errors
    /// Fails when `block_id`, `layer_id` or `outer_id` is not below the corresponding
    /// count in the configuration.
    fn calculate_address(
        &self,
        block_id: usize,
        layer_id: usize,
        outer_id: usize,
    ) -> Result<usize> {
        if block_id >= self.config.num_blocks {
            return Err(anyhow!(
                "Block ID {} out of range (max: {})",
                block_id,
                self.config.num_blocks
            ));
        }
        if layer_id >= self.config.num_layers {
            return Err(anyhow!(
                "Layer ID {} out of range (max: {})",
                layer_id,
                self.config.num_layers
            ));
        }
        if outer_id >= self.config.outer_dim {
            return Err(anyhow!(
                "Outer ID {} out of range (max: {})",
                outer_id,
                self.config.outer_dim
            ));
        }

        // With all indices in range, the offset is below the total size validated in `new`.
        Ok(self.base_addr
            + block_id * self.block_stride
            + layer_id * self.layer_stride
            + outer_id * self.outer_stride)
    }

    /// Get mutable reference to the memory Arc for NIXL registration.
    pub fn memory_arc_mut(&mut self) -> &mut Arc<dyn MemoryRegion> {
        &mut self.memory
    }
}
impl Layout for FullyContiguousLayout {
    /// Configuration this layout was constructed with.
    fn config(&self) -> &LayoutConfig {
        &self.config
    }

    /// The single backing allocation, exposed as a one-element slice.
    fn memory_regions(&self) -> &[OwnedMemoryRegion] {
        let backing = &self.memory;
        std::slice::from_ref(backing)
    }

    /// Descriptor (address + `region_size`) for one `(block, layer, outer)` page.
    ///
    /// Index validation is delegated to `calculate_address`; its error is
    /// propagated unchanged.
    fn memory_region(
        &self,
        block_id: usize,
        layer_id: usize,
        outer_id: usize,
    ) -> Result<MemoryDescriptor> {
        self.calculate_address(block_id, layer_id, outer_id)
            .map(|addr| MemoryDescriptor::new(addr, self.region_size))
    }

    /// One allocation covering every block: `block_stride * num_blocks` bytes.
    fn required_allocations(&self) -> Vec<usize> {
        let total_bytes = self.block_stride * self.config.num_blocks;
        vec![total_bytes]
    }

    /// Always `true` for this layout type.
    fn is_fully_contiguous(&self) -> bool {
        true
    }

    /// Number of blocks, as configured.
    fn num_blocks(&self) -> usize {
        self.config.num_blocks
    }

    /// Number of layers, as configured.
    fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    /// Outer dimension size, as configured.
    fn outer_dim(&self) -> usize {
        self.config.outer_dim
    }

    /// Page size, as configured.
    fn page_size(&self) -> usize {
        self.config.page_size
    }

    /// Inner dimension size, as configured.
    fn inner_dim(&self) -> usize {
        self.config.inner_dim
    }

    /// Data type width in bytes, as configured.
    fn dtype_width_bytes(&self) -> usize {
        self.config.dtype_width_bytes
    }

    /// Serialization details carrying this layout's block format.
    fn serialization_details(&self) -> LayoutTypeDetails {
        let details = FullyContiguousDetails {
            block_format: self.block_format,
        };
        LayoutTypeDetails::FullyContiguous(details)
    }
}
#[cfg(test)]
mod tests {
    use super::super::tests::*;
    use super::*;

    /// Build a config with the dimensions shared by the tests in this module
    /// (outer_dim = 2, page_size = 16, inner_dim = 128, 2-byte dtype).
    fn small_config(num_blocks: usize, num_layers: usize) -> LayoutConfig {
        LayoutConfig::builder()
            .num_blocks(num_blocks)
            .num_layers(num_layers)
            .outer_dim(2)
            .page_size(16)
            .inner_dim(128)
            .dtype_width_bytes(2)
            .build()
            .unwrap()
    }

    /// A layout can be built over a region of exactly `required_bytes()` bytes.
    #[test]
    fn test_fully_contiguous_layout_creation() {
        let config = small_config(10, 4);

        let required_bytes = config.required_bytes();
        assert_eq!(required_bytes, 10 * 4 * 2 * 16 * 128 * 2);

        let memory = MockMemory::new(0x1000, required_bytes);
        let layout = FullyContiguousLayout::new(config, memory).unwrap();

        assert_eq!(layout.num_blocks(), 10);
        assert!(layout.is_fully_contiguous());
    }

    /// Region addresses follow the block/layer/outer strides from the base address.
    #[test]
    fn test_memory_region() {
        let config = small_config(2, 2);
        let memory = MockMemory::new(0x1000, config.required_bytes());
        let layout = FullyContiguousLayout::new(config.clone(), memory).unwrap();

        let region_size = config.page_size * config.inner_dim * config.dtype_width_bytes;

        // (block, layer, outer, expected byte offset from the base address)
        let cases = [
            (0, 0, 0, 0),
            (0, 0, 1, region_size),
            (0, 1, 0, 2 * region_size),
            (1, 0, 0, config.outer_dim * config.num_layers * region_size),
        ];
        for (block_id, layer_id, outer_id, offset) in cases {
            let region = layout.memory_region(block_id, layer_id, outer_id).unwrap();
            assert_eq!(region.addr, 0x1000 + offset);
            assert_eq!(region.size, region_size);
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests comparing v1 and v2 layout implementations.
//!
//! These tests validate that the new v2 layout system produces identical
//! memory region **addresses** as the proven v1 implementation.
//!
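//! **Note on Size Differences**: V1's `memory_region()` returns `layer_stride` as the
//! size (covering all outer dimensions), while V2 returns `outer_stride` (single page).
//! This is an intentional API difference - V2 provides more granular access.
//! Therefore, these tests are intended to compare addresses, not sizes.
//!
//! NOTE(review): the tests in this file currently also assert that the v1 and v2
//! region sizes are equal, which contradicts the note above. Either the note or
//! the size assertions need to be updated.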
#![cfg(test)]
use anyhow::Result;
use std::{any::Any, sync::Arc};
use crate::block_manager::{
layout::{
BlockDimension, BlockLayout, BlockLayoutConfig, GenericBlockLayout, LayoutConfig,
LayoutType,
tests::{setup_layer_separate_layout, setup_layout},
},
storage::{Storage, tests::NullDeviceStorage},
v2::storage::StorageKind,
};
use super::{
FullyContiguousLayout, LayerSeparateLayout, Layout, LayoutConfig as V2LayoutConfig,
MemoryRegion,
};
// Test constants matching v1 tests
// Each value is passed to the builder setter of the same name in `create_v1_config`
// and is also used to size the backing storage in the tests below.
const NUM_BLOCKS: usize = 7;
const NUM_LAYERS: usize = 5;
const OUTER_DIM: usize = 2;
const PAGE_SIZE: usize = 4;
const INNER_DIM: usize = 13;
const DTYPE_WIDTH_BYTES: usize = 4;
/// Wrapper to make v1 NullDeviceStorage compatible with v2 MemoryRegion trait.
///
/// The `MemoryRegion` impl forwards `addr()` and `size()` to the wrapped storage.
#[derive(Debug)]
struct V1StorageWrapper {
/// The wrapped v1 storage whose address and size are reported
storage: NullDeviceStorage,
}
impl MemoryRegion for V1StorageWrapper {
    /// Address reported by the wrapped v1 storage, cast to `usize`.
    fn addr(&self) -> usize {
        let raw_addr = self.storage.addr();
        raw_addr as usize
    }

    /// Size reported by the wrapped v1 storage.
    fn size(&self) -> usize {
        let Self { storage } = self;
        storage.size()
    }

    /// Always reported as system memory for the purposes of these tests.
    fn storage_kind(&self) -> StorageKind {
        StorageKind::System
    }

    /// Type-erased view of the wrapper itself.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// Build the v1 layout configuration from the module-level test constants,
/// with an alignment of 1.
///
/// Panics if the builder rejects the configuration.
fn create_v1_config() -> LayoutConfig {
    let built = LayoutConfig::builder()
        .num_blocks(NUM_BLOCKS)
        .num_layers(NUM_LAYERS)
        .outer_dim(OUTER_DIM)
        .page_size(PAGE_SIZE)
        .inner_dim(INNER_DIM)
        .alignment(1)
        .dtype_width_bytes(DTYPE_WIDTH_BYTES)
        .build();
    built.unwrap()
}
/// Create the v2 layout configuration by reusing the v1 configuration verbatim.
///
/// NOTE(review): this only type-checks if `V2LayoutConfig` and the v1
/// `LayoutConfig` are the same type (or one re-exports the other) — confirm.
fn create_v2_config() -> V2LayoutConfig {
    let config: V2LayoutConfig = create_v1_config();
    config
}
/// Every `(block, layer, outer)` region of the v2 fully contiguous layout must
/// report the same address and size as the v1 layout, and the layout metadata
/// must agree.
#[test]
fn test_v1_v2_fully_contiguous_equivalence() -> Result<()> {
    // Reference layout from the shared v1 test helper.
    let v1_layout = setup_layout(None)?;

    // v2 layout over a null-device buffer sized for the whole tensor.
    let v2_config = create_v2_config();
    let total_bytes =
        NUM_BLOCKS * NUM_LAYERS * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
    let backing: Arc<dyn MemoryRegion> = Arc::new(V1StorageWrapper {
        storage: NullDeviceStorage::new(total_bytes as u64),
    });
    let v2_layout = FullyContiguousLayout::new(v2_config, backing)?;

    // Visit indices in the same order as nested block -> layer -> outer loops.
    let all_indices = (0..NUM_BLOCKS).flat_map(|block_id| {
        (0..NUM_LAYERS).flat_map(move |layer_id| {
            (0..OUTER_DIM).map(move |outer_id| (block_id, layer_id, outer_id))
        })
    });
    for (block_id, layer_id, outer_id) in all_indices {
        let expected = v1_layout.memory_region(block_id, layer_id, outer_id)?;
        let actual = v2_layout.memory_region(block_id, layer_id, outer_id)?;
        assert_eq!(
            expected.addr(),
            actual.addr,
            "Address mismatch at block={}, layer={}, outer={}",
            block_id,
            layer_id,
            outer_id
        );
        assert_eq!(
            expected.size(),
            actual.size,
            "Size mismatch at block={}, layer={}, outer={}",
            block_id,
            layer_id,
            outer_id
        );
    }

    // Metadata must agree as well.
    assert_eq!(v1_layout.num_blocks(), v2_layout.num_blocks());
    assert_eq!(v1_layout.num_layers(), v2_layout.num_layers());
    assert_eq!(v1_layout.outer_dim(), v2_layout.outer_dim());
    assert_eq!(v1_layout.page_size(), v2_layout.page_size());
    assert_eq!(v1_layout.inner_dim(), v2_layout.inner_dim());
    Ok(())
}
/// Layer-separate layouts with `BlockDimension::BlockIsFirstDim` must agree
/// between v1 and v2 on metadata and on every region's address and size.
#[test]
fn test_v1_v2_layer_separate_block_contiguous_equivalence() -> Result<()> {
    let block_dim = BlockDimension::BlockIsFirstDim;

    // Reference layout from the shared v1 test helper.
    let v1_layout = setup_layer_separate_layout(None, BlockDimension::BlockIsFirstDim)?;

    // v2 layout: one null-device buffer per layer.
    let v2_config = create_v2_config();
    let per_layer_size = NUM_BLOCKS * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
    let mut memory: Vec<Arc<dyn MemoryRegion>> = Vec::with_capacity(NUM_LAYERS);
    for _ in 0..NUM_LAYERS {
        let wrapper = V1StorageWrapper {
            storage: NullDeviceStorage::new(per_layer_size as u64),
        };
        memory.push(Arc::new(wrapper) as Arc<dyn MemoryRegion>);
    }
    let v2_layout = LayerSeparateLayout::new(v2_config, memory, block_dim)?;

    // Metadata first.
    assert_eq!(v1_layout.num_blocks(), v2_layout.num_blocks());
    assert_eq!(v1_layout.num_layers(), v2_layout.num_layers());
    assert_eq!(v1_layout.outer_dim(), v2_layout.outer_dim());
    assert_eq!(v1_layout.page_size(), v2_layout.page_size());
    assert_eq!(v1_layout.inner_dim(), v2_layout.inner_dim());

    // Then every region, in nested block -> layer -> outer order.
    let all_indices = (0..NUM_BLOCKS).flat_map(|block_id| {
        (0..NUM_LAYERS).flat_map(move |layer_id| {
            (0..OUTER_DIM).map(move |outer_id| (block_id, layer_id, outer_id))
        })
    });
    for (block_id, layer_id, outer_id) in all_indices {
        let expected = v1_layout.memory_region(block_id, layer_id, outer_id)?;
        let actual = v2_layout.memory_region(block_id, layer_id, outer_id)?;
        assert_eq!(
            expected.addr(),
            actual.addr,
            "Address mismatch at block={}, layer={}, outer={} (block_contiguous)",
            block_id,
            layer_id,
            outer_id
        );
        assert_eq!(
            expected.size(),
            actual.size,
            "Size mismatch at block={}, layer={}, outer={} (block_contiguous)",
            block_id,
            layer_id,
            outer_id
        );
    }

    // Finally the layout type.
    assert!(!v2_layout.is_fully_contiguous());
    let expected_type = LayoutType::LayerSeparate {
        block_dim: BlockDimension::BlockIsFirstDim,
    };
    assert_eq!(v1_layout.layout_type(), expected_type);
    Ok(())
}
/// Layer-separate layouts with `BlockDimension::BlockIsSecondDim` must agree
/// between v1 and v2 on every region's address and size.
#[test]
fn test_v1_v2_layer_separate_outer_contiguous_equivalence() -> Result<()> {
    let block_dim = BlockDimension::BlockIsSecondDim;

    // Reference layout from the shared v1 test helper.
    let v1_layout = setup_layer_separate_layout(None, BlockDimension::BlockIsSecondDim)?;

    // v2 layout: one null-device buffer per layer.
    let v2_config = create_v2_config();
    let per_layer_size = NUM_BLOCKS * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
    let mut memory: Vec<Arc<dyn MemoryRegion>> = Vec::with_capacity(NUM_LAYERS);
    for _ in 0..NUM_LAYERS {
        let wrapper = V1StorageWrapper {
            storage: NullDeviceStorage::new(per_layer_size as u64),
        };
        memory.push(Arc::new(wrapper) as Arc<dyn MemoryRegion>);
    }
    let v2_layout = LayerSeparateLayout::new(v2_config, memory, block_dim)?;

    // Every region, in nested block -> layer -> outer order.
    let all_indices = (0..NUM_BLOCKS).flat_map(|block_id| {
        (0..NUM_LAYERS).flat_map(move |layer_id| {
            (0..OUTER_DIM).map(move |outer_id| (block_id, layer_id, outer_id))
        })
    });
    for (block_id, layer_id, outer_id) in all_indices {
        let expected = v1_layout.memory_region(block_id, layer_id, outer_id)?;
        let actual = v2_layout.memory_region(block_id, layer_id, outer_id)?;
        assert_eq!(
            expected.addr(),
            actual.addr,
            "Address mismatch at block={}, layer={}, outer={} (outer_contiguous)",
            block_id,
            layer_id,
            outer_id
        );
        assert_eq!(
            expected.size(),
            actual.size,
            "Size mismatch at block={}, layer={}, outer={} (outer_contiguous)",
            block_id,
            layer_id,
            outer_id
        );
    }

    // Layout type.
    assert!(!v2_layout.is_fully_contiguous());
    let expected_type = LayoutType::LayerSeparate {
        block_dim: BlockDimension::BlockIsSecondDim,
    };
    assert_eq!(v1_layout.layout_type(), expected_type);
    Ok(())
}
/// Adjacent indices along each axis of the v2 fully contiguous layout must be
/// separated by exactly the expected block / layer / outer stride.
#[test]
fn test_v1_v2_stride_calculations() -> Result<()> {
    // The v1 layout is only constructed (not inspected); construction errors still fail the test.
    let _v1_layout = setup_layout(None)?;

    let v2_config = create_v2_config();
    let total_bytes =
        NUM_BLOCKS * NUM_LAYERS * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
    let backing: Arc<dyn MemoryRegion> = Arc::new(V1StorageWrapper {
        storage: NullDeviceStorage::new(total_bytes as u64),
    });
    let v2_layout = FullyContiguousLayout::new(v2_config, backing)?;

    // Expected strides, innermost first.
    let outer_stride = PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
    let layer_stride = outer_stride * OUTER_DIM;
    let block_stride = layer_stride * NUM_LAYERS;

    // Consecutive blocks.
    for next_block in 1..NUM_BLOCKS {
        let block_id = next_block - 1;
        let lower = v2_layout.memory_region(block_id, 0, 0)?;
        let upper = v2_layout.memory_region(next_block, 0, 0)?;
        assert_eq!(
            upper.addr - lower.addr,
            block_stride,
            "Block stride mismatch between blocks {} and {}",
            block_id,
            next_block
        );
    }

    // Consecutive layers.
    for next_layer in 1..NUM_LAYERS {
        let layer_id = next_layer - 1;
        let lower = v2_layout.memory_region(0, layer_id, 0)?;
        let upper = v2_layout.memory_region(0, next_layer, 0)?;
        assert_eq!(
            upper.addr - lower.addr,
            layer_stride,
            "Layer stride mismatch between layers {} and {}",
            layer_id,
            next_layer
        );
    }

    // Consecutive outer dimensions.
    for next_outer in 1..OUTER_DIM {
        let outer_id = next_outer - 1;
        let lower = v2_layout.memory_region(0, 0, outer_id)?;
        let upper = v2_layout.memory_region(0, 0, next_outer)?;
        assert_eq!(
            upper.addr - lower.addr,
            outer_stride,
            "Outer stride mismatch between outer dims {} and {}",
            outer_id,
            next_outer
        );
    }
    Ok(())
}
/// Edge case: with a single block, v1 and v2 fully contiguous layouts must
/// still agree on every region's address and size.
#[test]
fn test_v1_v2_edge_case_single_block() -> Result<()> {
    let v1_config = LayoutConfig::builder()
        .num_blocks(1)
        .num_layers(NUM_LAYERS)
        .outer_dim(OUTER_DIM)
        .page_size(PAGE_SIZE)
        .inner_dim(INNER_DIM)
        .dtype_width_bytes(DTYPE_WIDTH_BYTES)
        .build()
        .unwrap();

    // v1 layout allocated through the null-device allocator.
    let v1_layout = crate::block_manager::layout::FullyContiguous::allocate(
        v1_config.clone(),
        &crate::block_manager::storage::tests::NullDeviceAllocator,
    )?;

    // v2 layout over a buffer sized for exactly one block.
    let v2_config = v1_config.clone();
    let single_block_bytes =
        NUM_LAYERS * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
    let backing: Arc<dyn MemoryRegion> = Arc::new(V1StorageWrapper {
        storage: NullDeviceStorage::new(single_block_bytes as u64),
    });
    let v2_layout = FullyContiguousLayout::new(v2_config, backing)?;

    // Block 0 across all layers and outer dims.
    let indices = (0..NUM_LAYERS)
        .flat_map(|layer_id| (0..OUTER_DIM).map(move |outer_id| (layer_id, outer_id)));
    for (layer_id, outer_id) in indices {
        let expected = v1_layout.memory_region(0, layer_id, outer_id)?;
        let actual = v2_layout.memory_region(0, layer_id, outer_id)?;
        assert_eq!(expected.addr(), actual.addr);
        assert_eq!(expected.size(), actual.size);
    }
    Ok(())
}
/// Edge case: with a single layer, v1 and v2 fully contiguous layouts must
/// still agree on every region's address and size.
#[test]
fn test_v1_v2_edge_case_single_layer() -> Result<()> {
    let v1_config = LayoutConfig::builder()
        .num_blocks(NUM_BLOCKS)
        .num_layers(1)
        .outer_dim(OUTER_DIM)
        .page_size(PAGE_SIZE)
        .inner_dim(INNER_DIM)
        .dtype_width_bytes(DTYPE_WIDTH_BYTES)
        .build()?;

    // v1 layout allocated through the null-device allocator.
    let v1_layout = crate::block_manager::layout::FullyContiguous::allocate(
        v1_config.clone(),
        &crate::block_manager::storage::tests::NullDeviceAllocator,
    )?;

    // v2 layout over a buffer sized for exactly one layer per block.
    let v2_config = v1_config.clone();
    let single_layer_bytes =
        NUM_BLOCKS * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
    let backing: Arc<dyn MemoryRegion> = Arc::new(V1StorageWrapper {
        storage: NullDeviceStorage::new(single_layer_bytes as u64),
    });
    let v2_layout = FullyContiguousLayout::new(v2_config, backing)?;

    // Layer 0 across all blocks and outer dims.
    let indices = (0..NUM_BLOCKS)
        .flat_map(|block_id| (0..OUTER_DIM).map(move |outer_id| (block_id, outer_id)));
    for (block_id, outer_id) in indices {
        let expected = v1_layout.memory_region(block_id, 0, outer_id)?;
        let actual = v2_layout.memory_region(block_id, 0, outer_id)?;
        assert_eq!(expected.addr(), actual.addr);
        assert_eq!(expected.size(), actual.size);
    }
    Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Layer-separate layout implementation.
//!
//! This layout stores each layer in its own allocation, which is the typical
//! vLLM layout. Each layer can be either block-contiguous or outer-contiguous:
//! - Block-contiguous: [num_blocks, outer_dim, page_size, inner_dim]
//! - Outer-contiguous: [outer_dim, num_blocks, page_size, inner_dim]
use anyhow::{Result, anyhow};
use std::sync::Arc;
use validator::Validate;
use super::serialize::{LayerSeparateDetails, LayoutTypeDetails};
use super::{
BlockDimension, Layout, LayoutConfig, MemoryDescriptor, MemoryRegion, OwnedMemoryRegion,
};
/// Layer-separate layout where each layer has its own allocation.
///
/// The address of a region is computed as
/// `layer_base_addrs[layer_id] + block_id * block_stride + outer_id * outer_stride`.
#[derive(Debug)]
pub struct LayerSeparateLayout {
/// Layout configuration; validated by the constructor
config: LayoutConfig,
/// Base addresses for each layer (indexed by layer ID)
layer_base_addrs: Vec<usize>,
/// Which of the block and outer dimensions comes first within a layer;
/// determines how `block_stride` and `outer_stride` are derived
block_dim: BlockDimension,
/// Stride between blocks in bytes
block_stride: usize,
/// Stride between outer dimensions in bytes
outer_stride: usize,
/// Size of each memory region (page) in bytes
/// (`page_size * inner_dim * dtype_width_bytes`)
region_size: usize,
/// Owned memory regions backing this layout (one per layer)
memory_regions: Vec<Arc<dyn MemoryRegion>>,
}
impl LayerSeparateLayout {
    /// Multiply `factors` together, returning an error instead of silently
    /// wrapping when the product does not fit in `usize`.
    fn checked_product(factors: &[usize]) -> Result<usize> {
        factors
            .iter()
            .try_fold(1usize, |acc, &factor| acc.checked_mul(factor))
            .ok_or_else(|| anyhow!("Layout size computation overflows usize: {:?}", factors))
    }

    /// Create a new layer-separate layout.
    ///
    /// # Arguments
    /// - `config` - Layout configuration
    /// - `memory` - Vector of owned memory regions (one per layer)
    /// - `block_dim` - Position of the block dimension within each layer:
    ///   - `BlockDimension::BlockIsSecondDim`: each layer is laid out as
    ///     `[outer_dim, num_blocks, page_size, inner_dim]`
    ///     (block stride = region size, outer stride = region size * `num_blocks`).
    ///   - otherwise (e.g. `BlockDimension::BlockIsFirstDim`): each layer is laid
    ///     out as `[num_blocks, outer_dim, page_size, inner_dim]`
    ///     (outer stride = region size, block stride = region size * `outer_dim`).
    ///
    /// # Returns
    /// A new LayerSeparateLayout instance
    ///
    /// # Errors
    /// Returns an error if `config` fails validation, if `memory.len()` differs
    /// from `num_layers`, if any size computation overflows `usize`, or if any
    /// region is smaller than `num_blocks * outer_dim * region_size` bytes.
    pub fn new(
        config: LayoutConfig,
        memory: Vec<Arc<dyn MemoryRegion>>,
        block_dim: BlockDimension,
    ) -> Result<Self> {
        config.validate()?;
        if memory.len() != config.num_layers {
            return Err(anyhow!(
                "Memory region count ({}) must match num_layers ({})",
                memory.len(),
                config.num_layers
            ));
        }

        // Calculate strides with checked arithmetic: a wrapped product would
        // yield a too-small `required_size` and let undersized regions pass
        // the size check below.
        let region_size = Self::checked_product(&[
            config.page_size,
            config.inner_dim,
            config.dtype_width_bytes,
        ])?;
        let (block_stride, outer_stride) = if block_dim == BlockDimension::BlockIsSecondDim {
            // Layout: [outer_dim, num_blocks, page_size, inner_dim]
            let block_stride = region_size;
            let outer_stride = Self::checked_product(&[block_stride, config.num_blocks])?;
            (block_stride, outer_stride)
        } else {
            // Layout: [num_blocks, outer_dim, page_size, inner_dim]
            let outer_stride = region_size;
            let block_stride = Self::checked_product(&[outer_stride, config.outer_dim])?;
            (block_stride, outer_stride)
        };

        // Extract base addresses and validate sizes; stops at the first
        // undersized region.
        let required_size =
            Self::checked_product(&[config.num_blocks, config.outer_dim, region_size])?;
        let layer_base_addrs = memory
            .iter()
            .enumerate()
            .map(|(i, mem)| {
                if mem.size() < required_size {
                    return Err(anyhow!(
                        "Memory region {} too small for layout. Required: {} bytes, got: {} bytes",
                        i,
                        required_size,
                        mem.size()
                    ));
                }
                Ok(mem.addr())
            })
            .collect::<Result<Vec<usize>>>()?;

        Ok(Self {
            config,
            layer_base_addrs,
            block_dim,
            block_stride,
            outer_stride,
            region_size,
            memory_regions: memory,
        })
    }

    /// Calculate the address of a specific memory region.
    ///
    /// # Errors
    /// Returns an error if `block_id`, `layer_id` or `outer_id` is not below
    /// `num_blocks`, `num_layers` or `outer_dim` respectively.
    fn calculate_address(
        &self,
        block_id: usize,
        layer_id: usize,
        outer_id: usize,
    ) -> Result<usize> {
        if block_id >= self.config.num_blocks {
            return Err(anyhow!(
                "Block ID {} out of range (max: {})",
                block_id,
                self.config.num_blocks
            ));
        }
        if layer_id >= self.config.num_layers {
            return Err(anyhow!(
                "Layer ID {} out of range (max: {})",
                layer_id,
                self.config.num_layers
            ));
        }
        if outer_id >= self.config.outer_dim {
            return Err(anyhow!(
                "Outer ID {} out of range (max: {})",
                outer_id,
                self.config.outer_dim
            ));
        }

        // `layer_id < num_layers == layer_base_addrs.len()` was checked above,
        // so this index cannot panic.
        let base_addr = self.layer_base_addrs[layer_id];
        let offset = block_id * self.block_stride + outer_id * self.outer_stride;
        Ok(base_addr + offset)
    }

    /// Get the block dimension ordering this layout was constructed with.
    pub fn block_dim(&self) -> BlockDimension {
        self.block_dim
    }

    /// Get mutable reference to the memory regions for NIXL registration.
    pub fn memory_regions_mut(&mut self) -> &mut [Arc<dyn MemoryRegion>] {
        &mut self.memory_regions
    }
}
impl Layout for LayerSeparateLayout {
    /// Configuration this layout was constructed with.
    fn config(&self) -> &LayoutConfig {
        &self.config
    }

    /// The per-layer backing allocations, in layer order.
    fn memory_regions(&self) -> &[OwnedMemoryRegion] {
        self.memory_regions.as_slice()
    }

    /// Descriptor (address + `region_size`) for one `(block, layer, outer)` page.
    ///
    /// Index validation is delegated to `calculate_address`; its error is
    /// propagated unchanged.
    fn memory_region(
        &self,
        block_id: usize,
        layer_id: usize,
        outer_id: usize,
    ) -> Result<MemoryDescriptor> {
        self.calculate_address(block_id, layer_id, outer_id)
            .map(|addr| MemoryDescriptor::new(addr, self.region_size))
    }

    /// One allocation per layer, each `num_blocks * outer_dim * region_size` bytes.
    fn required_allocations(&self) -> Vec<usize> {
        let cfg = &self.config;
        let bytes_per_layer = cfg.num_blocks * cfg.outer_dim * self.region_size;
        vec![bytes_per_layer; cfg.num_layers]
    }

    /// Always `false` for this layout type.
    fn is_fully_contiguous(&self) -> bool {
        false
    }

    /// Number of blocks, as configured.
    fn num_blocks(&self) -> usize {
        self.config.num_blocks
    }

    /// Number of layers, as configured.
    fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    /// Outer dimension size, as configured.
    fn outer_dim(&self) -> usize {
        self.config.outer_dim
    }

    /// Page size, as configured.
    fn page_size(&self) -> usize {
        self.config.page_size
    }

    /// Inner dimension size, as configured.
    fn inner_dim(&self) -> usize {
        self.config.inner_dim
    }

    /// Data type width in bytes, as configured.
    fn dtype_width_bytes(&self) -> usize {
        self.config.dtype_width_bytes
    }

    /// Serialization details carrying this layout's block dimension ordering.
    fn serialization_details(&self) -> LayoutTypeDetails {
        let details = LayerSeparateDetails {
            block_dim: self.block_dim,
        };
        LayoutTypeDetails::LayerSeparate(details)
    }
}
#[cfg(test)]
mod tests {
    use super::super::tests::*;
    use super::*;

    /// Build a config with the dimensions shared by the tests in this module
    /// (outer_dim = 2, page_size = 16, inner_dim = 128, 2-byte dtype).
    fn small_config(num_blocks: usize, num_layers: usize) -> LayoutConfig {
        LayoutConfig::builder()
            .num_blocks(num_blocks)
            .num_layers(num_layers)
            .outer_dim(2)
            .page_size(16)
            .inner_dim(128)
            .dtype_width_bytes(2)
            .build()
            .unwrap()
    }

    /// Create `num_layers` mock regions of `per_layer_size` bytes, placed
    /// back-to-back starting at address 0x1000.
    fn mock_layers(num_layers: usize, per_layer_size: usize) -> Vec<Arc<dyn MemoryRegion>> {
        let mut layers = Vec::with_capacity(num_layers);
        for layer in 0..num_layers {
            let base = 0x1000 + layer * per_layer_size;
            layers.push(MockMemory::new(base, per_layer_size) as Arc<dyn MemoryRegion>);
        }
        layers
    }

    /// A block-first layout over four layers reports four allocations.
    #[test]
    fn test_layer_separate_block_contiguous() {
        let per_layer_size = 10 * 2 * 16 * 128 * 2;
        let memory = mock_layers(4, per_layer_size);

        let layout =
            LayerSeparateLayout::new(small_config(10, 4), memory, BlockDimension::BlockIsFirstDim)
                .unwrap();

        assert_eq!(layout.num_blocks(), 10);
        assert!(!layout.is_fully_contiguous());
        assert_eq!(layout.required_allocations().len(), 4);
    }

    /// A block-second layout can be constructed over the same memory shape.
    #[test]
    fn test_layer_separate_outer_contiguous() {
        let per_layer_size = 10 * 2 * 16 * 128 * 2;
        let memory = mock_layers(4, per_layer_size);

        let layout =
            LayerSeparateLayout::new(small_config(10, 4), memory, BlockDimension::BlockIsSecondDim)
                .unwrap();

        assert_eq!(layout.num_blocks(), 10);
        assert!(!layout.is_fully_contiguous());
    }

    /// Region addresses start at the owning layer's base and advance by the outer stride.
    #[test]
    fn test_memory_region() {
        let per_layer_size = 2 * 2 * 16 * 128 * 2;
        let memory = mock_layers(2, per_layer_size);

        let layout =
            LayerSeparateLayout::new(small_config(2, 2), memory, BlockDimension::BlockIsFirstDim)
                .unwrap();

        let region_size = 16 * 128 * 2;

        // (block, layer, outer, expected address)
        let cases = [
            // Layer 0's base address
            (0, 0, 0, 0x1000),
            // Layer 1's base address
            (0, 1, 0, 0x1000 + per_layer_size),
            // Offset within layer 0
            (0, 0, 1, 0x1000 + region_size),
        ];
        for (block_id, layer_id, outer_id, expected_addr) in cases {
            let region = layout.memory_region(block_id, layer_id, outer_id).unwrap();
            assert_eq!(region.addr, expected_addr);
            assert_eq!(region.size, region_size);
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Decoupled layout system for block management.
//!
//! This module provides a simplified layout abstraction that:
//! - Maps block IDs to physical memory regions (address + size)
//! - Decouples memory regions from storage type information
//! - Specifies allocation requirements without performing allocation
//! - Uses trait objects for memory ownership
pub(crate) mod builder;
mod config;
mod fully_contiguous;
mod layer_separate;
mod physical;
mod serialize;
mod validation;
#[cfg(test)]
pub(super) mod tests;
// #[cfg(test)]
// mod integration_tests;
pub use builder::{LayoutKind, PhysicalLayoutBuilder};
pub use config::{BlockDimension, LayoutConfig};
pub use fully_contiguous::FullyContiguousLayout;
pub use layer_separate::LayerSeparateLayout;
pub use physical::{NixlMetadata, PhysicalLayout};
pub use serialize::{
BlockFormat, FullyContiguousDetails, LayerSeparateDetails, LayoutDescriptor, LayoutTypeDetails,
};
pub use validation::{TensorFormat, validate_tensor_shapes, validate_tensor_strides};
// mod registration;
// pub use registration::{RegisteredLayout, RegisteredStorageMetadata, RegistrationManager};
use anyhow::Result;
use serde::{Deserialize, Serialize};
pub use crate::block_manager::v2::memory::{MemoryDescriptor, MemoryRegion, OwnedMemoryRegion};
/// Core layout trait for mapping block IDs to memory regions.
///
/// Layouts specify how KV cache blocks are organized in memory without
/// performing allocation themselves. They provide:
/// - Memory region lookup for specific blocks
/// - Allocation requirements for external allocators
/// - Metadata about block organization
pub trait Layout: Send + Sync + std::fmt::Debug {
/// Get the configuration for this layout.
fn config(&self) -> &LayoutConfig;
/// Get the root memory regions backing this layout.
///
/// These regions correspond to the concrete allocations that store the layout's data.
/// Implementations that derive memory procedurally can return an empty slice.
fn memory_regions(&self) -> &[OwnedMemoryRegion];
/// Get the memory region for a specific block_id, layer_id, outer_id.
///
/// Returns a [`MemoryDescriptor`] (address and size) for the contiguous region
/// specified by the given block_id, layer_id, outer_id.
///
/// # Arguments
/// * `block_id` - The ID of the block to query (0..num_blocks)
/// * `layer_id` - The ID of the layer to query (0..num_layers)
/// * `outer_id` - The ID of the outer dimension to query (0..outer_dim)
///
/// # Errors
/// The implementations in this module return an error when any index is
/// outside the range given above.
fn memory_region(
&self,
block_id: usize,
layer_id: usize,
outer_id: usize,
) -> Result<MemoryDescriptor>;
/// Get the allocation requirements for this layout.
///
/// Returns a vector of allocation sizes needed to back this layout.
/// For fully contiguous layouts, this will be a single size.
/// For layer-separate layouts, this will contain one size per layer.
///
/// # Returns
/// Vector of allocation sizes in bytes.
fn required_allocations(&self) -> Vec<usize>;
/// Check if this layout uses fully contiguous memory.
///
/// Fully contiguous layouts have all blocks in a single allocation,
/// which enables certain optimizations.
fn is_fully_contiguous(&self) -> bool;
/// Get the total number of blocks in this layout.
fn num_blocks(&self) -> usize;
/// Get the number of layers per block.
fn num_layers(&self) -> usize;
/// Get the outer dimension size.
///
/// In typical KV cache layouts, this is often 2 (for K and V),
/// but can be 1 for architectures like MLA.
fn outer_dim(&self) -> usize;
/// Get the page size (often corresponds to block size in tokens).
fn page_size(&self) -> usize;
/// Get the inner dimension size.
///
/// This is typically the hidden size divided by tensor parallel size.
fn inner_dim(&self) -> usize;
/// Get the data type width in bytes.
fn dtype_width_bytes(&self) -> usize;
/// Get serialization details for this layout type.
///
/// This provides the layout-type-specific information needed to serialize
/// and reconstruct the layout on a remote node.
fn serialization_details(&self) -> serialize::LayoutTypeDetails;
}
/// Inner shape format for tensor layout.
///
/// Describes the ordering of the token, head and head-dimension axes inside
/// a single block.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum InnerShape {
/// Unknown shape - fallback when we can't determine the format
Unknown,
/// NHD format: [block_size, num_heads, head_dim]
/// Common for attention layers where N=tokens, H=heads, D=dimension
NHD,
/// HND format: [num_heads, block_size, head_dim]
/// Alternative layout with heads first
HND,
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Physical layout types that combine abstract layouts with storage location metadata.
use super::{
FullyContiguousLayout, LayerSeparateLayout, Layout, MemoryDescriptor,
builder::{PhysicalLayoutBuilder, PhysicalLayoutBuilderDefault},
serialize::{LayoutDescriptor, LayoutTypeDetails},
};
use crate::block_manager::v2::memory::{MemoryRegion, StorageKind};
use anyhow::{Result, anyhow};
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::sync::Arc;
use crate::block_manager::v2::physical::transfer::nixl_agent::NixlAgent;
/// Runtime representation of a layout with its physical storage location.
///
/// A `PhysicalLayout` wraps an abstract [`Layout`] with information about where
/// its memory physically resides (GPU, host, disk) and the NIXL metadata it is
/// registered under.
/// This enables the transfer system to select appropriate copy strategies and build
/// NIXL transfer descriptors.
///
/// NOTE(review): earlier wording said this also records whether the layout is
/// local or remote; no such field is present on this struct — confirm whether
/// locality is tracked elsewhere.
#[derive(Debug, Clone)]
pub struct PhysicalLayout {
/// The abstract layout defining memory organization
layout: Arc<dyn Layout>,
/// Physical storage location (System, Device, Pinned, Disk)
location: StorageKind,
/// NIXL registration metadata
nixl_metadata: NixlMetadata,
}
/// NIXL registration metadata attached to a [`PhysicalLayout`].
///
/// Serializable so that it can be embedded in a layout descriptor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NixlMetadata {
/// Name of the NIXL agent
agent_name: String,
/// NIXL memory type
mem_type: nixl_sys::MemType,
/// Device identifier
device_id: u64,
}
impl NixlMetadata {
pub fn new(agent_name: String, mem_type: nixl_sys::MemType, device_id: u64) -> Self {
Self {
agent_name,
mem_type,
device_id,
}
}
pub fn agent_name(&self) -> &str {
&self.agent_name
}
pub fn mem_type(&self) -> nixl_sys::MemType {
self.mem_type
}
pub fn device_id(&self) -> u64 {
self.device_id
}
}
impl PhysicalLayout {
/// Create a typed builder that enforces NIXL registration.
///
/// # Arguments
/// * `agent` - The NIXL agent handed to the builder
pub fn builder(agent: NixlAgent) -> PhysicalLayoutBuilderDefault {
PhysicalLayoutBuilder::new(agent)
}
/// Create a new local physical layout.
///
/// The three arguments are stored as-is; no validation is performed here.
///
/// # Arguments
/// * `layout` - The abstract layout to wrap
/// * `location` - Where the layout's memory resides
/// * `nixl_metadata` - NIXL registration metadata for the layout's memory
pub(crate) fn new_local(
layout: Arc<dyn Layout>,
location: StorageKind,
nixl_metadata: NixlMetadata,
) -> Self {
Self {
layout,
location,
nixl_metadata,
}
}
// /// Create a new remote physical layout from a descriptor.
// ///
// /// # Arguments
// /// * `layout` - The abstract layout to wrap
// /// * `location` - Where the layout's memory resides (on remote node)
// /// * `remote_agent` - Name of the NIXL agent on the remote node
// pub fn new_remote(
// layout: Arc<dyn Layout>,
// location: StorageKind,
// remote_agent: String,
// ) -> Self {
// let metadata = NixlMetadata::new(
// remote_agent.clone(),
// location.to_nixl_mem_type(),
// location.device_id(),
// );
// let registrations = vec![RegisteredStorageMetadata::new(
// metadata.agent_name().to_string(),
// location,
// )];
// Self {
// layout,
// location,
// locality: Locality::Remote(remote_agent),
// nixl_metadata: Some(metadata),
// registered: registrations,
// }
// }
/// Get the underlying layout.
///
/// Returns a reference to the shared `Arc`; clone it to obtain another handle.
pub fn layout(&self) -> &Arc<dyn Layout> {
&self.layout
}
/// Get the storage location.
///
/// Returned by value.
pub fn location(&self) -> StorageKind {
self.location
}
/// Get the NIXL metadata.
///
/// Returns a borrow of the metadata stored at construction.
pub fn nixl_metadata(&self) -> &NixlMetadata {
&self.nixl_metadata
}
/// Get the memory descriptor (address and size) for a region of the wrapped layout.
///
/// This is a direct delegation to [`Layout::memory_region`]; the returned
/// descriptor carries no location information — use [`Self::location`] for that.
///
/// # Arguments
/// * `block_id` - Block identifier
/// * `layer_id` - Layer identifier
/// * `outer_id` - Outer dimension identifier
///
/// # Errors
/// Propagates any error returned by the wrapped layout.
pub fn memory_region(
&self,
block_id: usize,
layer_id: usize,
outer_id: usize,
) -> Result<MemoryDescriptor> {
self.layout.memory_region(block_id, layer_id, outer_id)
}
/// Convert this physical layout into a serializable [`LayoutDescriptor`].
///
/// The descriptor captures everything a remote node needs to reconstruct the
/// layout: the layout configuration, one address/size descriptor per root
/// memory region, the storage location, the NIXL metadata, and the
/// layout-type-specific details reported by the layout itself.
///
/// # Returns
/// A serializable representation of this layout, stamped with
/// `LayoutDescriptor::CURRENT_VERSION`.
pub fn to_descriptor(&self) -> Result<LayoutDescriptor> {
    let layout = &self.layout;

    // One (addr, size) descriptor per root allocation, in the layout's own order.
    let root_regions = layout.memory_regions();
    let memory_descriptors = root_regions
        .iter()
        .map(|region| {
            let (addr, size) = (region.addr(), region.size());
            MemoryDescriptor { addr, size }
        })
        .collect();

    // Layout-type-specific details come from the layout implementation.
    let layout_type_details = layout.serialization_details();

    let descriptor = LayoutDescriptor {
        version: LayoutDescriptor::CURRENT_VERSION,
        layout_config: layout.config().clone(),
        location: self.location,
        nixl_metadata: self.nixl_metadata.clone(),
        memory_descriptors,
        layout_type_details,
    };
    Ok(descriptor)
}
/// Reconstruct a physical layout from serialized data received from a remote node.
///
/// This creates a new `PhysicalLayout` from a `LayoutDescriptor`. The reconstructed
/// layout will have memory descriptors that point to the remote node's memory,
/// allowing NIXL to build RDMA descriptors for remote access.
///
/// # Arguments
/// * `serialized` - Serialized layout data from a remote node
///
/// # Returns
/// A new `PhysicalLayout` representing the remote layout
///
/// # Note
/// The memory regions in the reconstructed layout are not valid for local access;
/// they represent remote memory addresses and are used to build NIXL transfer descriptors.
pub fn from_descriptor(serialized: LayoutDescriptor) -> Result<Self> {
// Validate version
if serialized.version > LayoutDescriptor::CURRENT_VERSION {
return Err(anyhow!(
"Unsupported serialization version: {}. Maximum supported: {}",
serialized.version,
LayoutDescriptor::CURRENT_VERSION
));
}
// Create remote memory regions from descriptors
let remote_regions: Vec<Arc<dyn MemoryRegion>> = serialized
.memory_descriptors
.iter()
.map(|desc| {
Arc::new(RemoteMemoryDescriptor {
addr: desc.addr,
size: desc.size,
storage_kind: serialized.location,
}) as Arc<dyn MemoryRegion>
})
.collect();
// Reconstruct the layout based on type
let layout: Arc<dyn Layout> = match serialized.layout_type_details {
LayoutTypeDetails::FullyContiguous(details) => {
if remote_regions.len() != 1 {
return Err(anyhow!(
"FullyContiguous layout requires exactly 1 memory region, got {}",
remote_regions.len()
));
}
let layout = FullyContiguousLayout::new_with_format(
serialized.layout_config.clone(),
remote_regions[0].clone(),
details.block_format,
)?;
Arc::new(layout)
}
LayoutTypeDetails::LayerSeparate(details) => {
if remote_regions.len() != serialized.layout_config.num_layers {
return Err(anyhow!(
"LayerSeparate layout requires {} memory regions (one per layer), got {}",
serialized.layout_config.num_layers,
remote_regions.len()
));
}
let layout = LayerSeparateLayout::new(
serialized.layout_config.clone(),
remote_regions,
details.block_dim,
)?;
Arc::new(layout)
}
};
Ok(Self {
layout,
location: serialized.location,
nixl_metadata: serialized.nixl_metadata,
})
}
}
/// A memory region that represents remote memory addresses.
///
/// This type is used when reconstructing layouts from serialized data.
/// The addresses are not valid for local access but can be used to
/// build NIXL transfer descriptors for remote memory access.
#[derive(Debug)]
struct RemoteMemoryDescriptor {
/// Start address of the region, copied from the serialized descriptor.
addr: usize,
/// Size of the region, copied from the serialized descriptor.
size: usize,
/// Storage kind, copied from the serialized layout's `location`.
storage_kind: StorageKind,
}
// `MemoryRegion` view over a remote region: every accessor simply reports
// the value stored when the descriptor was reconstructed.
impl MemoryRegion for RemoteMemoryDescriptor {
// Stored start address of the region.
fn addr(&self) -> usize {
self.addr
}
// Stored size of the region.
fn size(&self) -> usize {
self.size
}
// Stored storage kind of the region.
fn storage_kind(&self) -> StorageKind {
self.storage_kind
}
// Expose the concrete type for downcasting.
fn as_any(&self) -> &dyn Any {
self
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Serialization types for physical layouts.
//!
//! This module provides types for serializing and deserializing physical layouts
//! so they can be transmitted to remote nodes and reconstructed there for RDMA operations.
use super::physical::NixlMetadata;
use super::{BlockDimension, LayoutConfig};
use crate::block_manager::v2::memory::{MemoryDescriptor, StorageKind};
use anyhow::Result;
use serde::{Deserialize, Serialize};
/// Format of blocks in a fully contiguous layout.
///
/// This enum describes how the blocks are organized and formatted in memory.
/// Currently only `Operational` is supported, but future variants may include
/// different compression schemes or memory layouts.
///
/// `BlockFormat::default()` is `Operational`; the default is derived via the
/// `#[default]` variant attribute rather than a hand-written `impl Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum BlockFormat {
    /// Standard operational format - blocks are stored in their normal, uncompressed form.
    #[default]
    Operational,
}
/// Details specific to fully contiguous layouts.
///
/// Carried inside `LayoutTypeDetails::FullyContiguous`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FullyContiguousDetails {
/// Format of the blocks in memory
pub block_format: BlockFormat,
}
/// Details specific to layer-separate layouts.
///
/// Carried inside `LayoutTypeDetails::LayerSeparate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerSeparateDetails {
/// Block dimension ordering (block-first or block-second)
pub block_dim: BlockDimension,
}
/// Layout-type-specific details.
///
/// This enum captures the information that differs between layout types
/// and is needed to reconstruct the layout on a remote node.
///
/// The variant that is present selects which concrete layout is rebuilt
/// from a `LayoutDescriptor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LayoutTypeDetails {
/// Fully contiguous layout details
FullyContiguous(FullyContiguousDetails),
/// Layer-separate layout details
LayerSeparate(LayerSeparateDetails),
}
/// Serializable representation of a physical layout.
///
/// This structure contains all information needed to reconstruct a layout
/// on a remote node, including:
/// - Layout configuration (dimensions, sizes, etc.)
/// - Storage location and NIXL metadata
/// - Memory descriptors for all regions backing this layout
/// - Layout-type-specific details
///
/// The serialized form can be transmitted over the network and used to
/// build NIXL transfer descriptors for remote memory access.
///
/// NOTE(review): `LocalLayoutDescriptor` embeds this type through
/// `#[bincode(with_serde)]`; presumably that encoding depends on field order,
/// so confirm before reordering or inserting fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayoutDescriptor {
/// Serialization format version (for future compatibility).
/// `PhysicalLayout::from_descriptor` rejects values greater than
/// `CURRENT_VERSION`.
pub version: u32,
/// Layout configuration
pub layout_config: LayoutConfig,
/// Storage location
pub location: StorageKind,
/// NIXL metadata from the source node
pub nixl_metadata: NixlMetadata,
/// Memory descriptors for all regions backing this layout
pub memory_descriptors: Vec<MemoryDescriptor>,
/// Layout-type-specific details
pub layout_type_details: LayoutTypeDetails,
}
impl LayoutDescriptor {
    /// Version number stamped on descriptors produced by this build.
    pub const CURRENT_VERSION: u32 = 1;

    /// Encode this descriptor as a JSON string.
    ///
    /// # Errors
    /// Returns an error carrying the `serde_json` failure text if encoding fails.
    pub fn to_json(&self) -> Result<String> {
        match serde_json::to_string(self) {
            Ok(json) => Ok(json),
            Err(e) => Err(anyhow::anyhow!(
                "failed to serialize layout to JSON: {}",
                e
            )),
        }
    }

    /// Encode this descriptor as UTF-8 JSON bytes.
    ///
    /// # Errors
    /// Returns an error carrying the `serde_json` failure text if encoding fails.
    pub fn to_json_bytes(&self) -> Result<Vec<u8>> {
        match serde_json::to_vec(self) {
            Ok(bytes) => Ok(bytes),
            Err(e) => Err(anyhow::anyhow!(
                "failed to serialize layout to JSON bytes: {}",
                e
            )),
        }
    }

    /// Decode a descriptor from a JSON string.
    ///
    /// # Arguments
    /// * `json` - JSON text previously produced by [`Self::to_json`] or equivalent
    ///
    /// # Errors
    /// Returns an error carrying the `serde_json` failure text if decoding fails.
    pub fn from_json(json: &str) -> Result<Self> {
        match serde_json::from_str(json) {
            Ok(descriptor) => Ok(descriptor),
            Err(e) => Err(anyhow::anyhow!(
                "failed to deserialize layout from JSON: {}",
                e
            )),
        }
    }

    /// Decode a descriptor from UTF-8 JSON bytes.
    ///
    /// # Arguments
    /// * `bytes` - JSON bytes previously produced by [`Self::to_json_bytes`] or equivalent
    ///
    /// # Errors
    /// Returns an error carrying the `serde_json` failure text if decoding fails.
    pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
        match serde_json::from_slice(bytes) {
            Ok(descriptor) => Ok(descriptor),
            Err(e) => Err(anyhow::anyhow!(
                "failed to deserialize layout from JSON bytes: {}",
                e
            )),
        }
    }

    /// Borrow the layout configuration.
    pub fn layout_config(&self) -> &LayoutConfig {
        &self.layout_config
    }

    /// Storage location recorded in the descriptor.
    pub fn location(&self) -> StorageKind {
        self.location
    }

    /// Borrow the NIXL metadata recorded by the source node.
    pub fn nixl_metadata(&self) -> &NixlMetadata {
        &self.nixl_metadata
    }

    /// Borrow the memory descriptors as a slice.
    pub fn memory_descriptors(&self) -> &[MemoryDescriptor] {
        self.memory_descriptors.as_slice()
    }

    /// Borrow the layout-type-specific details.
    pub fn layout_type_details(&self) -> &LayoutTypeDetails {
        &self.layout_type_details
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Layout configuration shared by the descriptor tests.
    fn make_test_config() -> LayoutConfig {
        LayoutConfig::builder()
            .num_blocks(10)
            .num_layers(4)
            .outer_dim(2)
            .page_size(16)
            .inner_dim(128)
            .dtype_width_bytes(2)
            .build()
            .unwrap()
    }

    /// `BlockFormat::default()` is the operational format.
    #[test]
    fn test_block_format_default() {
        let default_format = BlockFormat::default();
        assert_eq!(default_format, BlockFormat::Operational);
    }

    /// A fully contiguous descriptor survives a JSON string round trip.
    #[test]
    fn test_serialized_layout_json_roundtrip() {
        let original = LayoutDescriptor {
            version: LayoutDescriptor::CURRENT_VERSION,
            layout_config: make_test_config(),
            location: StorageKind::System,
            nixl_metadata: NixlMetadata::new("test_agent".to_string(), nixl_sys::MemType::Dram, 0),
            memory_descriptors: vec![MemoryDescriptor::new(0x1000, 4096)],
            layout_type_details: LayoutTypeDetails::FullyContiguous(FullyContiguousDetails {
                block_format: BlockFormat::Operational,
            }),
        };

        // Encode to text and decode again.
        let text = original.to_json().unwrap();
        let decoded = LayoutDescriptor::from_json(&text).unwrap();

        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.layout_config, original.layout_config);
        assert_eq!(decoded.location, original.location);
        assert_eq!(
            decoded.nixl_metadata.agent_name(),
            original.nixl_metadata.agent_name()
        );
        assert_eq!(decoded.memory_descriptors.len(), 1);
    }

    /// A layer-separate descriptor survives a JSON bytes round trip.
    #[test]
    fn test_serialized_layout_json_bytes_roundtrip() {
        let original = LayoutDescriptor {
            version: LayoutDescriptor::CURRENT_VERSION,
            layout_config: make_test_config(),
            location: StorageKind::System,
            nixl_metadata: NixlMetadata::new("test_agent".to_string(), nixl_sys::MemType::Vram, 5),
            memory_descriptors: vec![
                MemoryDescriptor::new(0x1000, 2048),
                MemoryDescriptor::new(0x2000, 2048),
            ],
            layout_type_details: LayoutTypeDetails::LayerSeparate(LayerSeparateDetails {
                block_dim: BlockDimension::BlockIsFirstDim,
            }),
        };

        // Encode to bytes and decode again.
        let encoded = original.to_json_bytes().unwrap();
        let decoded = LayoutDescriptor::from_json_bytes(&encoded).unwrap();

        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.nixl_metadata.device_id(), 5);
        assert_eq!(decoded.memory_descriptors.len(), 2);
    }

    /// The `FullyContiguous` variant keeps its payload through serde JSON.
    #[test]
    fn test_fully_contiguous_details_serialization() {
        let original = LayoutTypeDetails::FullyContiguous(FullyContiguousDetails {
            block_format: BlockFormat::Operational,
        });
        let text = serde_json::to_string(&original).unwrap();
        let decoded: LayoutTypeDetails = serde_json::from_str(&text).unwrap();

        let LayoutTypeDetails::FullyContiguous(details) = decoded else {
            panic!("Expected FullyContiguous variant")
        };
        assert_eq!(details.block_format, BlockFormat::Operational);
    }

    /// The `LayerSeparate` variant keeps its payload through serde JSON.
    #[test]
    fn test_layer_separate_details_serialization() {
        let original = LayoutTypeDetails::LayerSeparate(LayerSeparateDetails {
            block_dim: BlockDimension::BlockIsSecondDim,
        });
        let text = serde_json::to_string(&original).unwrap();
        let decoded: LayoutTypeDetails = serde_json::from_str(&text).unwrap();

        let LayoutTypeDetails::LayerSeparate(details) = decoded else {
            panic!("Expected LayerSeparate variant")
        };
        assert_eq!(details.block_dim, BlockDimension::BlockIsSecondDim);
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for layout serialization.
//!
//! These tests verify the complete serialization and deserialization flow,
//! ensuring that layouts can be transmitted to remote nodes and reconstructed
//! with all necessary metadata intact.
use crate::block_manager::v2::memory::{
MemoryRegion, NixlDescriptor, OwnedMemoryRegion, StorageKind,
};
use crate::block_manager::v2::physical::layout::physical::PhysicalLayout;
use crate::block_manager::v2::physical::layout::{BlockDimension, LayoutConfig, LayoutDescriptor};
use crate::block_manager::v2::physical::transfer::nixl_agent::NixlAgent;
use std::any::Any;
use std::sync::Arc;
// Simple mock implementation for testing: it only records an address and a
// size and never allocates or touches real memory.
#[derive(Debug)]
pub struct MockMemory {
// Address reported by `addr()`.
addr: usize,
// Size reported by `size()`.
size: usize,
}
impl MockMemory {
    /// Create a reference-counted mock region reporting the given address and size.
    pub fn new(addr: usize, size: usize) -> Arc<Self> {
        let region = Self { addr, size };
        Arc::new(region)
    }
}
// `MemoryRegion` implementation that reports the stored address and size and
// always claims to be system memory.
impl MemoryRegion for MockMemory {
// Stored address.
fn addr(&self) -> usize {
self.addr
}
// Stored size.
fn size(&self) -> usize {
self.size
}
// The mock is always reported as `StorageKind::System`.
fn storage_kind(&self) -> StorageKind {
StorageKind::System
}
// Expose the concrete type for downcasting.
fn as_any(&self) -> &dyn Any {
self
}
}
/// Mock memory region for testing serialization
///
/// Unlike `MockMemory`, this region carries a configurable storage kind and a
/// `NixlDescriptor`, which it hands out from `nixl_descriptor()`.
#[derive(Debug)]
struct TestMemoryRegion {
/// Address reported by `addr()`.
addr: usize,
/// Size reported by `size()`.
size: usize,
/// Storage kind reported by `storage_kind()`.
kind: StorageKind,
/// Descriptor cloned out by `nixl_descriptor()`.
descriptor: NixlDescriptor,
}
impl TestMemoryRegion {
    /// Create a reference-counted test region.
    ///
    /// The attached NIXL descriptor mirrors `addr` and `size` and is always
    /// tagged as DRAM on device 0, regardless of `kind`.
    fn new(addr: usize, size: usize, kind: StorageKind) -> Arc<Self> {
        let descriptor = NixlDescriptor {
            addr: addr as u64,
            size,
            mem_type: nixl_sys::MemType::Dram,
            device_id: 0,
        };
        let region = Self {
            addr,
            size,
            kind,
            descriptor,
        };
        Arc::new(region)
    }
}
// `MemoryRegion` implementation backed entirely by the stored fields.
impl MemoryRegion for TestMemoryRegion {
// Stored address.
fn addr(&self) -> usize {
self.addr
}
// Stored size.
fn size(&self) -> usize {
self.size
}
// Storage kind chosen at construction time.
fn storage_kind(&self) -> StorageKind {
self.kind
}
// Expose the concrete type for downcasting.
fn as_any(&self) -> &dyn Any {
self
}
// Always returns `Some` with a clone of the stored descriptor.
fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
Some(self.descriptor.clone())
}
}
/// Build the layout configuration shared by the round-trip tests:
/// 10 blocks, 4 layers, outer dimension 2, page size 16, inner dimension 128
/// and a dtype width of 2 bytes.
///
/// Panics if the builder rejects the configuration.
fn make_test_config() -> LayoutConfig {
LayoutConfig::builder()
.num_blocks(10)
.num_layers(4)
.outer_dim(2)
.page_size(16)
.inner_dim(128)
.dtype_width_bytes(2)
.build()
.unwrap()
}
/// Round-trips a fully contiguous layout through `LayoutDescriptor` and JSON
/// text, then rebuilds a `PhysicalLayout` from the decoded descriptor.
#[test]
fn test_fully_contiguous_layout_serialization_roundtrip() {
    let agent = NixlAgent::require_backends("test-fc-serialize", &[])
        .expect("failed to create wrapped agent");
    let config = make_test_config();

    // A single region has to hold every block, layer and outer slot.
    let total_bytes = [
        config.num_blocks,
        config.num_layers,
        config.outer_dim,
        config.page_size,
        config.inner_dim,
        config.dtype_width_bytes,
    ]
    .iter()
    .product::<usize>();

    // Back the layout with one mock region at a known address.
    let backing = TestMemoryRegion::new(0x10000, total_bytes, StorageKind::System);
    let regions = vec![backing as OwnedMemoryRegion];

    let source_layout = PhysicalLayout::builder(agent)
        .with_config(config.clone())
        .fully_contiguous()
        .with_registered_regions(regions)
        .expect("failed to provide regions")
        .build()
        .expect("failed to build layout");

    // Layout -> descriptor.
    let descriptor = source_layout
        .to_descriptor()
        .expect("failed to serialize layout");

    assert_eq!(descriptor.version, LayoutDescriptor::CURRENT_VERSION);
    assert_eq!(descriptor.layout_config, config);
    assert_eq!(descriptor.location, StorageKind::System);
    assert_eq!(descriptor.memory_descriptors.len(), 1);
    let only_region = &descriptor.memory_descriptors[0];
    assert_eq!(only_region.addr, 0x10000);
    assert_eq!(only_region.size, total_bytes);

    // Descriptor -> JSON text.
    let json = descriptor.to_json().expect("failed to serialize to JSON");
    assert!(json.contains("\"version\":1"));
    assert!(json.contains("\"num_blocks\":10"));

    // JSON text -> descriptor.
    let decoded = LayoutDescriptor::from_json(&json).expect("failed to deserialize from JSON");

    assert_eq!(decoded.version, descriptor.version);
    assert_eq!(decoded.layout_config, descriptor.layout_config);
    assert_eq!(decoded.location, descriptor.location);
    assert_eq!(
        decoded.memory_descriptors.len(),
        descriptor.memory_descriptors.len()
    );

    // Descriptor -> layout.
    let rebuilt = PhysicalLayout::from_descriptor(decoded).expect("failed to reconstruct layout");

    assert_eq!(rebuilt.layout().config(), &config);
    assert_eq!(rebuilt.location(), StorageKind::System);
    assert_eq!(rebuilt.layout().num_blocks(), 10);
    assert_eq!(rebuilt.layout().num_layers(), 4);
    assert!(rebuilt.layout().is_fully_contiguous());
}
/// Round-trips a layer-separate layout through `LayoutDescriptor` and JSON
/// bytes, then rebuilds a `PhysicalLayout` from the decoded descriptor.
#[test]
fn test_layer_separate_layout_serialization_roundtrip() {
    let agent = NixlAgent::require_backends("test-ls-serialize", &[])
        .expect("failed to create wrapped agent");
    let config = make_test_config();

    // Each layer lives in its own region of this size.
    let layer_bytes = [
        config.num_blocks,
        config.outer_dim,
        config.page_size,
        config.inner_dim,
        config.dtype_width_bytes,
    ]
    .iter()
    .product::<usize>();

    // Regions are laid out back to back starting at 0x10000.
    let layer_addr = |layer: usize| 0x10000 + layer * layer_bytes;

    let mut regions: Vec<OwnedMemoryRegion> = Vec::with_capacity(config.num_layers);
    for layer in 0..config.num_layers {
        let region = TestMemoryRegion::new(layer_addr(layer), layer_bytes, StorageKind::System);
        regions.push(region as OwnedMemoryRegion);
    }

    let source_layout = PhysicalLayout::builder(agent)
        .with_config(config.clone())
        .layer_separate(BlockDimension::BlockIsFirstDim)
        .with_registered_regions(regions)
        .expect("failed to provide regions")
        .build()
        .expect("failed to build layout");

    // Layout -> descriptor.
    let descriptor = source_layout
        .to_descriptor()
        .expect("failed to serialize layout");

    assert_eq!(descriptor.version, LayoutDescriptor::CURRENT_VERSION);
    assert_eq!(descriptor.layout_config, config);
    assert_eq!(descriptor.memory_descriptors.len(), 4); // One per layer

    // Every descriptor must mirror the region it was built from.
    for (layer, desc) in descriptor.memory_descriptors.iter().enumerate() {
        assert_eq!(desc.addr, layer_addr(layer));
        assert_eq!(desc.size, layer_bytes);
    }

    // Descriptor -> JSON bytes -> descriptor.
    let encoded = descriptor
        .to_json_bytes()
        .expect("failed to serialize to JSON bytes");
    let decoded = LayoutDescriptor::from_json_bytes(&encoded)
        .expect("failed to deserialize from JSON bytes");

    assert_eq!(decoded.version, descriptor.version);
    assert_eq!(decoded.layout_config, descriptor.layout_config);
    assert_eq!(
        decoded.memory_descriptors.len(),
        descriptor.memory_descriptors.len()
    );

    // Descriptor -> layout.
    let rebuilt = PhysicalLayout::from_descriptor(decoded).expect("failed to reconstruct layout");

    assert_eq!(rebuilt.layout().config(), &config);
    assert_eq!(rebuilt.location(), StorageKind::System);
    assert_eq!(rebuilt.layout().num_blocks(), 10);
    assert_eq!(rebuilt.layout().num_layers(), 4);
    assert!(!rebuilt.layout().is_fully_contiguous());
}
/// A layout rebuilt from its descriptor must compute the same per-slot
/// addresses and sizes as the fully contiguous layout it came from.
#[test]
fn test_memory_region_calculation_after_deserialization() {
    let agent = NixlAgent::require_backends("test-memory-calc", &[])
        .expect("failed to create wrapped agent");
    let config = LayoutConfig::builder()
        .num_blocks(2)
        .num_layers(2)
        .outer_dim(2)
        .page_size(4)
        .inner_dim(8)
        .dtype_width_bytes(2)
        .build()
        .unwrap();

    let total_bytes = [
        config.num_blocks,
        config.num_layers,
        config.outer_dim,
        config.page_size,
        config.inner_dim,
        config.dtype_width_bytes,
    ]
    .iter()
    .product::<usize>();

    let backing = TestMemoryRegion::new(0x1000, total_bytes, StorageKind::System);
    let regions = vec![backing as OwnedMemoryRegion];
    let source_layout = PhysicalLayout::builder(agent)
        .with_config(config.clone())
        .fully_contiguous()
        .with_registered_regions(regions)
        .expect("failed to provide regions")
        .build()
        .expect("failed to build layout");

    // Layout -> descriptor -> layout, without going through JSON.
    let descriptor = source_layout
        .to_descriptor()
        .expect("failed to serialize");
    let rebuilt = PhysicalLayout::from_descriptor(descriptor).expect("failed to reconstruct");

    // The first slot starts at the base address and spans one page.
    let slot_bytes = config.page_size * config.inner_dim * config.dtype_width_bytes;
    let first = rebuilt
        .memory_region(0, 0, 0)
        .expect("failed to get memory region");
    assert_eq!(first.addr, 0x1000);
    assert_eq!(first.size, slot_bytes);

    // Slot (1, 1, 1) is one block stride, one layer stride and one outer
    // stride past the base address.
    let layer_stride = config.outer_dim * slot_bytes;
    let block_stride = config.num_layers * layer_stride;
    let offset = rebuilt
        .memory_region(1, 1, 1)
        .expect("failed to get memory region");
    assert_eq!(offset.addr, 0x1000 + block_stride + layer_stride + slot_bytes);
}
/// `PhysicalLayout::from_descriptor` must reject a descriptor from a newer
/// format version and accept the same descriptor once the version is current.
#[test]
fn test_version_check_on_deserialization() {
    use crate::block_manager::v2::memory::MemoryDescriptor;
    use crate::block_manager::v2::physical::layout::physical::NixlMetadata;
    use crate::block_manager::v2::physical::layout::{
        BlockFormat, FullyContiguousDetails, LayoutTypeDetails,
    };

    let config = make_test_config();

    // Size of the single region a fully contiguous layout needs.
    let total_bytes = [
        config.num_blocks,
        config.num_layers,
        config.outer_dim,
        config.page_size,
        config.inner_dim,
        config.dtype_width_bytes,
    ]
    .iter()
    .product::<usize>();

    let mut descriptor = LayoutDescriptor {
        version: 999, // Future version
        layout_config: config.clone(),
        location: StorageKind::System,
        nixl_metadata: NixlMetadata::new("test".to_string(), nixl_sys::MemType::Dram, 0),
        memory_descriptors: vec![],
        layout_type_details: LayoutTypeDetails::FullyContiguous(FullyContiguousDetails {
            block_format: BlockFormat::Operational,
        }),
    };

    // A future version is refused before anything else is inspected.
    let rejected = PhysicalLayout::from_descriptor(descriptor.clone());
    assert!(rejected.is_err());
    let message = rejected.unwrap_err().to_string();
    assert!(message.contains("Unsupported serialization version"));

    // With the current version and a correctly sized region it succeeds.
    descriptor.version = LayoutDescriptor::CURRENT_VERSION;
    descriptor.memory_descriptors = vec![MemoryDescriptor::new(0x1000, total_bytes)];
    let accepted = PhysicalLayout::from_descriptor(descriptor);
    if let Err(e) = &accepted {
        eprintln!("Error during deserialization: {}", e);
    }
    assert!(
        accepted.is_ok(),
        "Expected successful deserialization, got error: {:?}",
        accepted.err()
    );
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Tensor validation utilities for layout creation.
use anyhow::{Result, anyhow};
use std::sync::Arc;
use crate::block_manager::v2::memory::TorchTensor;
/// Format of tensor layout (for future TP translation).
///
/// This is the value produced by `validate_tensor_strides`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorFormat {
/// NHD format: [N, H, D] where N=block_size, H=heads, D=hidden
NHD,
/// HND format: [H, N, D] where H=heads, N=block_size, D=hidden
HND,
/// Unknown or ambiguous format
Unknown,
}
/// Validate tensor strides and detect the tensor format.
///
/// For every tensor this checks that it has at least two stride entries and
/// that no stride is larger than the one before it (equal neighbours are
/// accepted). It then applies a heuristic to the first two strides of tensors
/// with three or more dimensions to guess NHD versus HND, which matters for
/// future tensor parallel (TP) translation.
///
/// NOTE(review): because the ordering check above guarantees
/// `stride[0] >= stride[1]`, the `HND` outcome of the heuristic cannot be
/// reached as written; the function can only yield `NHD` or `Unknown`.
/// The reported format is also the verdict for the last tensor that produced
/// one, so mixed formats across `tensors` go unnoticed. Confirm the intended
/// behaviour before relying on the result.
///
/// # Arguments
/// * `tensors` - Slice of tensors to validate
///
/// # Returns
/// The detected tensor format (NHD, HND, or Unknown)
///
/// # Errors
/// * `tensors` is empty.
/// * A tensor has fewer than two stride entries.
/// * A tensor has a stride larger than its predecessor.
pub fn validate_tensor_strides(tensors: &[Arc<dyn TorchTensor>]) -> Result<TensorFormat> {
    if tensors.is_empty() {
        return Err(anyhow!("Cannot validate empty tensor list"));
    }

    let mut detected = TensorFormat::Unknown;
    for tensor in tensors {
        let stride = tensor.stride();
        let shape = tensor.shape();

        if stride.len() < 2 {
            return Err(anyhow!(
                "Tensor must have at least 2 dimensions, got stride: {:?}",
                stride
            ));
        }

        // Each stride is compared with its predecessor; the first entry has
        // nothing before it and therefore can never be the offender.
        for i in 1..stride.len() {
            if stride[i] > stride[i - 1] {
                return Err(anyhow!(
                    "Tensor strides must be monotonically decreasing (until inner dimension). \
                     Got stride: {:?} at position {}",
                    stride,
                    i
                ));
            }
        }

        // Heuristic: a smaller leading stride suggests HND, a larger one NHD;
        // equal strides leave the current verdict untouched.
        if shape.len() >= 3 {
            match stride[0].cmp(&stride[1]) {
                std::cmp::Ordering::Less => detected = TensorFormat::HND,
                std::cmp::Ordering::Greater => detected = TensorFormat::NHD,
                std::cmp::Ordering::Equal => {}
            }
        }
    }

    Ok(detected)
}
/// Check that every tensor reports the same shape as the first one.
///
/// # Arguments
/// * `tensors` - Slice of tensors to validate
///
/// # Returns
/// The shape of the first tensor, which all others were found to match
///
/// # Errors
/// * `tensors` is empty.
/// * Any tensor's shape differs from the first tensor's shape; the error
///   names both the expected and the offending shape.
pub fn validate_tensor_shapes(tensors: &[Arc<dyn TorchTensor>]) -> Result<Vec<usize>> {
    let Some((first, rest)) = tensors.split_first() else {
        return Err(anyhow!("Cannot validate empty tensor list"));
    };

    let expected = first.shape();

    // Stop at the first tensor whose shape disagrees with the reference.
    if let Some(mismatch) = rest.iter().find(|tensor| tensor.shape() != expected) {
        return Err(anyhow!(
            "All tensors must have the same shape. Expected {:?}, got {:?}",
            expected,
            mismatch.shape()
        ));
    }

    Ok(expected)
}
/// Collapse a shape into a single element count by multiplying all of its
/// dimensions together.
///
/// # Arguments
/// * `shape` - Dimension sizes; may be empty
///
/// # Returns
/// The product of all entries, or `1` for an empty shape.
#[allow(dead_code)]
pub fn determine_compressed_shape(shape: &[usize]) -> usize {
    let mut total = 1;
    for &dim in shape {
        total *= dim;
    }
    total
}
#[cfg(test)]
mod tests {
// Intentionally empty for now: exercising the validators requires a mock
// `TorchTensor` implementation, which can be added here if the testing
// infrastructure needs it.
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Layout handle type encoding worker ID and layout ID.
use bincode::{Decode, Encode};
/// Unique handle for a layout combining worker_id and layout_id.
///
/// The handle encodes:
/// - Bits 0-63: worker_id (u64)
/// - Bits 64-79: layout_id (u16)
/// - Bits 80-127: Reserved (48 bits, currently unused)
///
/// Equality and hashing are derived over the whole `u128`, so two handles
/// built with `from_u128` that differ only in the reserved bits compare as
/// different even though their `worker_id` and `layout_id` match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Encode, Decode)]
pub struct LayoutHandle(u128);
impl LayoutHandle {
    /// Bit position at which the layout id starts inside the raw value.
    const LAYOUT_SHIFT: u32 = 64;
    /// Mask selecting the worker id (the low 64 bits).
    const WORKER_MASK: u128 = u64::MAX as u128;
    /// Mask selecting the layout id once it has been shifted down.
    const LAYOUT_MASK: u128 = u16::MAX as u128;

    /// Build a handle from its two components.
    ///
    /// # Arguments
    /// * `worker_id` - Unique identifier for the worker, stored in bits 0-63
    /// * `layout_id` - Layout identifier within the worker, stored in bits 64-79
    pub fn new(worker_id: u64, layout_id: u16) -> Self {
        let worker_bits = u128::from(worker_id);
        let layout_bits = u128::from(layout_id) << Self::LAYOUT_SHIFT;
        Self(worker_bits | layout_bits)
    }

    /// Worker id stored in the low 64 bits of the handle.
    pub fn worker_id(&self) -> u64 {
        (self.0 & Self::WORKER_MASK) as u64
    }

    /// Layout id stored in bits 64-79 of the handle.
    pub fn layout_id(&self) -> u16 {
        ((self.0 >> Self::LAYOUT_SHIFT) & Self::LAYOUT_MASK) as u16
    }

    /// Raw `u128` representation of the handle.
    pub fn as_u128(&self) -> u128 {
        self.0
    }

    /// Wrap a raw `u128` as a handle without inspecting or masking it.
    pub fn from_u128(value: u128) -> Self {
        Self(value)
    }
}
/// Human-readable form showing the decoded worker and layout ids.
impl std::fmt::Display for LayoutHandle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let worker = self.worker_id();
        let layout = self.layout_id();
        write!(f, "LayoutHandle(worker={}, layout={})", worker, layout)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both components can be read back after packing.
    #[test]
    fn test_handle_encoding() {
        let (worker_id, layout_id) = (0x1234_5678_9ABC_DEF0u64, 0x4242u16);
        let packed = LayoutHandle::new(worker_id, layout_id);
        assert_eq!(packed.worker_id(), worker_id);
        assert_eq!(packed.layout_id(), layout_id);
    }

    /// A handle survives conversion to and from its raw `u128`.
    #[test]
    fn test_handle_roundtrip() {
        let original = LayoutHandle::new(42, 100);
        let restored = LayoutHandle::from_u128(original.as_u128());
        assert_eq!(original, restored);
        assert_eq!(restored.worker_id(), 42);
        assert_eq!(restored.layout_id(), 100);
    }

    /// The largest representable ids do not bleed into each other.
    #[test]
    fn test_handle_max_values() {
        let packed = LayoutHandle::new(u64::MAX, u16::MAX);
        assert_eq!(packed.worker_id(), u64::MAX);
        assert_eq!(packed.layout_id(), u16::MAX);
    }

    /// A handle survives a bincode encode/decode cycle.
    #[test]
    fn test_handle_bincode_roundtrip() {
        let original = LayoutHandle::new(999, 42);
        let config = bincode::config::standard();
        let bytes = bincode::encode_to_vec(original, config).unwrap();
        let (restored, _): (LayoutHandle, _) =
            bincode::decode_from_slice(&bytes, config).unwrap();
        assert_eq!(original, restored);
    }

    /// The display text mentions both ids.
    #[test]
    fn test_handle_display() {
        let rendered = format!("{}", LayoutHandle::new(123, 456));
        assert!(rendered.contains("123"));
        assert!(rendered.contains("456"));
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Local layout wrapper with handle and metadata.
use std::ops::Deref;
use super::handle::LayoutHandle;
use crate::block_manager::v2::physical::layout::PhysicalLayout;
/// A local physical layout with an assigned handle.
///
/// This wraps a `PhysicalLayout` that exists on the local worker,
/// associating it with a unique handle that combines the worker_id
/// and a locally-assigned layout_id.
///
/// This type is cheap to clone as `PhysicalLayout` contains `Arc` internally.
#[derive(Debug, Clone)]
pub struct LocalLayout {
/// Handle identifying this layout (worker id + layout id).
handle: LayoutHandle,
/// The wrapped physical layout; also reachable through `Deref`.
layout: PhysicalLayout,
}
// NOTE(review): presumably `dead_code` is allowed because not every accessor
// has a caller yet - confirm and drop the attribute once they do.
#[allow(dead_code)]
impl LocalLayout {
/// Create a new local layout.
///
/// # Arguments
/// * `handle` - Unique handle for this layout
/// * `layout` - The physical layout
pub fn new(handle: LayoutHandle, layout: PhysicalLayout) -> Self {
Self { handle, layout }
}
/// Get the handle for this layout (returned by value).
pub fn handle(&self) -> LayoutHandle {
self.handle
}
/// Get a reference to the physical layout.
pub fn layout(&self) -> &PhysicalLayout {
&self.layout
}
/// Get the worker_id from the handle.
pub fn worker_id(&self) -> u64 {
self.handle.worker_id()
}
/// Get the layout_id from the handle.
pub fn layout_id(&self) -> u16 {
self.handle.layout_id()
}
/// Consume this local layout and return the physical layout.
/// The handle is dropped.
pub fn into_layout(self) -> PhysicalLayout {
self.layout
}
}
// Dereferences to the wrapped `PhysicalLayout`, so its methods can be called
// directly on a `LocalLayout`. Note that `LocalLayout::layout` is an inherent
// method and therefore takes precedence over `PhysicalLayout::layout` when
// called through a `LocalLayout`.
impl Deref for LocalLayout {
type Target = PhysicalLayout;
fn deref(&self) -> &Self::Target {
&self.layout
}
}
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::*;
    use crate::block_manager::v2::physical::layout::{LayoutConfig, PhysicalLayout};
    use crate::block_manager::v2::physical::transfer::nixl_agent::NixlAgent;

    /// Create a NIXL agent with no required backends.
    fn create_test_agent(name: &str) -> NixlAgent {
        NixlAgent::require_backends(name, &[]).expect("failed to create wrapped agent")
    }

    /// Build a small, system-allocated, fully contiguous layout.
    fn make_test_layout() -> PhysicalLayout {
        let agent = create_test_agent("test-local");
        let config = LayoutConfig::builder()
            .num_blocks(2)
            .num_layers(2)
            .outer_dim(2)
            .page_size(4)
            .inner_dim(8)
            .dtype_width_bytes(2)
            .build()
            .unwrap();

        PhysicalLayout::builder(agent)
            .with_config(config)
            .fully_contiguous()
            .allocate_system()
            .build()
            .unwrap()
    }

    /// The wrapper reports the handle it was built with and its components.
    #[test]
    fn test_local_layout_creation() {
        let handle = LayoutHandle::new(42, 100);
        let wrapped = LocalLayout::new(handle, make_test_layout());

        assert_eq!(wrapped.handle(), handle);
        assert_eq!(wrapped.worker_id(), 42);
        assert_eq!(wrapped.layout_id(), 100);
    }

    /// The wrapper can be consumed to recover the physical layout.
    #[test]
    fn test_local_layout_into_layout() {
        let handle = LayoutHandle::new(1, 2);
        let wrapped = LocalLayout::new(handle, make_test_layout());

        // Succeeds as long as the layout can be moved back out.
        let _recovered = wrapped.into_layout();
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Serialization types for exporting/importing layout metadata with NIXL integration.
use super::handle::LayoutHandle;
use crate::block_manager::v2::physical::layout::LayoutDescriptor;
use anyhow::Result;
use bincode::{Decode, Encode};
use bytes::Bytes;
/// Worker identification combining worker_id and NIXL agent name.
///
/// Encoded with bincode as part of `RdmaLayoutDescriptors`.
#[derive(Debug, Clone, Encode, Decode, PartialEq, Eq)]
pub struct WorkerAddress {
/// Unique identifier for this worker
pub worker_id: u64,
/// NIXL agent name on this worker
pub nixl_agent_name: String,
}
impl WorkerAddress {
/// Create a new worker address.
///
/// # Arguments
/// * `worker_id` - Unique identifier for the worker
/// * `nixl_agent_name` - Name of the NIXL agent on that worker (taken by value)
pub fn new(worker_id: u64, nixl_agent_name: String) -> Self {
Self {
worker_id,
nixl_agent_name,
}
}
}
/// Local layout descriptor with its assigned handle from the TransportManager.
///
/// The handle is encoded natively by bincode, while the layout is encoded
/// through its Serde implementation.
#[derive(Debug, Clone, Encode, Decode)]
pub struct LocalLayoutDescriptor {
/// Unique handle for this layout
pub handle: LayoutHandle,
/// Serialized layout data (uses Serde, bridged via bincode)
#[bincode(with_serde)]
pub layout: LayoutDescriptor,
}
impl LocalLayoutDescriptor {
/// Create a new serialized layout with handle.
///
/// # Arguments
/// * `handle` - Handle assigned to the layout
/// * `layout` - Descriptor of the layout being exported
pub fn new(handle: LayoutHandle, layout: LayoutDescriptor) -> Self {
Self { handle, layout }
}
}
/// The set of [`LocalLayoutDescriptor`] that are RDMA enabled. This object packages the detail
/// about the layouts and the NIXL RDMA metadata required to reconstruct the layouts and access
/// the memory via NIXL RDMA.
///
/// This is the structure that `SerializedLayout::pack` encodes and
/// `SerializedLayout::unpack` decodes.
#[derive(Debug, Encode, Decode)]
pub struct RdmaLayoutDescriptors {
/// Worker identification
pub worker_address: WorkerAddress,
/// Exported NIXL metadata from nixl_sys::Agent::get_local_md()
pub nixl_metadata: Vec<u8>,
/// Serialized layouts (handle + layout data)
pub layouts: Vec<LocalLayoutDescriptor>,
}
/// Managed memory metadata package for export/import.
///
/// This is the wire format for transmitting layout metadata between workers.
/// It contains everything needed to reconstruct remote layouts and load their
/// NIXL registration data.
///
/// The wrapped `Bytes` hold a bincode-encoded `RdmaLayoutDescriptors`; the
/// contents are only validated when `unpack` is called.
pub struct SerializedLayout(Bytes);
impl SerializedLayout {
/// Pack metadata into a serialized form.
///
/// # Arguments
/// * `worker_address` - Worker identification
/// * `nixl_metadata` - NIXL metadata blob from get_local_md()
/// * `layouts` - Vector of layouts with handles to export
///
/// # Returns
/// Packed metadata ready for transmission
pub fn pack(
worker_address: WorkerAddress,
nixl_metadata: Vec<u8>,
layouts: Vec<LocalLayoutDescriptor>,
) -> Result<Self> {
let inner = RdmaLayoutDescriptors {
worker_address,
nixl_metadata,
layouts,
};
let bytes = bincode::encode_to_vec(&inner, bincode::config::standard())
.map_err(|e| anyhow::anyhow!("failed to encode managed memory metadata: {}", e))?;
Ok(Self(Bytes::from(bytes)))
}
/// Unpack metadata from serialized form.
///
/// # Returns
/// Unpacked metadata structure
pub fn unpack(&self) -> Result<RdmaLayoutDescriptors> {
let (inner, _) = bincode::decode_from_slice(&self.0, bincode::config::standard())
.map_err(|e| anyhow::anyhow!("failed to decode managed memory metadata: {}", e))?;
Ok(inner)
}
/// Get the raw bytes.
pub fn as_bytes(&self) -> &Bytes {
&self.0
}
/// Create from raw bytes.
pub fn from_bytes(bytes: Bytes) -> Self {
Self(bytes)
}
/// Get the size in bytes.
pub fn len(&self) -> usize {
self.0.len()
}
/// Check if empty.
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}
impl std::fmt::Debug for SerializedLayout {
    /// Prints only the payload size; the encoded bytes themselves are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let size_bytes = self.len();
        let mut out = f.debug_struct("SerializedLayout");
        out.field("size_bytes", &size_bytes);
        out.finish()
    }
}
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::*;
    use crate::block_manager::v2::memory::{MemoryDescriptor, StorageKind};
    use crate::block_manager::v2::physical::layout::{
        BlockFormat, FullyContiguousDetails, LayoutConfig, LayoutDescriptor, LayoutTypeDetails,
        NixlMetadata,
    };

    /// Build a small fully-contiguous, system-memory [`LayoutDescriptor`] to use as payload.
    fn make_test_serialized_layout() -> LayoutDescriptor {
        let config = LayoutConfig::builder()
            .num_blocks(2)
            .num_layers(2)
            .outer_dim(2)
            .page_size(4)
            .inner_dim(8)
            .dtype_width_bytes(2)
            .build()
            .unwrap();
        LayoutDescriptor {
            version: 1,
            layout_config: config,
            location: StorageKind::System,
            nixl_metadata: NixlMetadata::new("test".to_string(), nixl_sys::MemType::Dram, 0),
            memory_descriptors: vec![MemoryDescriptor::new(0x1000, 4096)],
            layout_type_details: LayoutTypeDetails::FullyContiguous(FullyContiguousDetails {
                block_format: BlockFormat::Operational,
            }),
        }
    }

    /// `WorkerAddress::new` stores both fields as given.
    #[test]
    fn test_worker_address() {
        let addr = WorkerAddress::new(42, "test_agent".to_string());
        assert_eq!(addr.worker_id, 42);
        assert_eq!(addr.nixl_agent_name, "test_agent");
    }

    /// `LocalLayoutDescriptor::new` keeps the handle it was given.
    #[test]
    fn test_serialized_layout_with_handle() {
        let handle = LayoutHandle::new(1, 2);
        let layout = make_test_serialized_layout();
        let with_handle = LocalLayoutDescriptor::new(handle, layout);
        assert_eq!(with_handle.handle, handle);
    }

    /// A single layout survives a pack/unpack round trip.
    #[test]
    fn test_metadata_pack_unpack() {
        let worker_address = WorkerAddress::new(100, "worker_100".to_string());
        let nixl_metadata = vec![1, 2, 3, 4, 5];
        let layouts = vec![LocalLayoutDescriptor::new(
            LayoutHandle::new(100, 1),
            make_test_serialized_layout(),
        )];
        let packed =
            SerializedLayout::pack(worker_address.clone(), nixl_metadata.clone(), layouts).unwrap();
        assert!(!packed.is_empty());
        // The original asserted `!is_empty()` twice; check `len()` instead so both
        // size accessors are exercised.
        assert_eq!(packed.len(), packed.as_bytes().len());
        let unpacked = packed.unpack().unwrap();
        assert_eq!(unpacked.worker_address, worker_address);
        assert_eq!(unpacked.nixl_metadata, nixl_metadata);
        assert_eq!(unpacked.layouts.len(), 1);
        assert_eq!(unpacked.layouts[0].handle.worker_id(), 100);
        assert_eq!(unpacked.layouts[0].handle.layout_id(), 1);
    }

    /// Several layouts round-trip with their handles and order intact.
    #[test]
    fn test_metadata_multiple_layouts() {
        let worker_address = WorkerAddress::new(200, "worker_200".to_string());
        let nixl_metadata = vec![10, 20, 30];
        let layouts = vec![
            LocalLayoutDescriptor::new(LayoutHandle::new(200, 1), make_test_serialized_layout()),
            LocalLayoutDescriptor::new(LayoutHandle::new(200, 2), make_test_serialized_layout()),
            LocalLayoutDescriptor::new(LayoutHandle::new(200, 3), make_test_serialized_layout()),
        ];
        // `layouts` is not used again, so it is moved rather than cloned.
        let packed = SerializedLayout::pack(worker_address, nixl_metadata, layouts).unwrap();
        let unpacked = packed.unpack().unwrap();
        assert_eq!(unpacked.layouts.len(), 3);
        for (i, layout) in unpacked.layouts.iter().enumerate() {
            assert_eq!(layout.handle.worker_id(), 200);
            assert_eq!(layout.handle.layout_id(), (i + 1) as u16);
        }
    }

    /// Bytes taken from one `SerializedLayout` can be re-wrapped and unpacked.
    #[test]
    fn test_metadata_from_bytes() {
        let worker_address = WorkerAddress::new(42, "test".to_string());
        let nixl_metadata = vec![1, 2, 3];
        let layouts = vec![];
        let packed = SerializedLayout::pack(worker_address, nixl_metadata, layouts).unwrap();
        let bytes = packed.as_bytes().clone();
        let restored = SerializedLayout::from_bytes(bytes);
        let unpacked = restored.unpack().unwrap();
        assert_eq!(unpacked.worker_address.worker_id, 42);
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transport manager for local and remote physical layouts with transfer execution.
mod handle;
mod local;
mod metadata;
mod remote;
pub use handle::LayoutHandle;
pub use metadata::{SerializedLayout, WorkerAddress};
pub(crate) use local::LocalLayout;
pub(crate) use metadata::LocalLayoutDescriptor;
pub(crate) use remote::RemoteLayout;
use crate::block_manager::v2::memory::StorageKind;
use crate::block_manager::v2::physical::layout::PhysicalLayout;
use crate::block_manager::v2::physical::transfer::TransferContext;
use crate::block_manager::v2::physical::transfer::context::TransferCompleteNotification;
use crate::block_manager::v2::physical::transfer::nixl_agent::NixlAgent;
use crate::block_manager::v2::physical::transfer::options::TransferOptions;
use anyhow::{Result, anyhow, bail};
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::{Arc, RwLock};
/// Public entry point for layout and transfer management.
///
/// TransportManager combines layout registration/metadata management with
/// transfer execution capabilities, providing a unified API for:
/// - Registering local layouts and obtaining handles
/// - Exporting/importing layout metadata for remote workers
/// - Executing transfers between layouts using handles
/// - Managing CUDA, NIXL, and other execution resources
///
/// Cloning is cheap: both fields are `Arc`s, so every clone shares the same
/// registry and transfer context.
#[derive(Clone)]
pub struct TransportManager {
/// Layout registry shared by all clones, guarded by a read/write lock.
registry: Arc<RwLock<LayoutRegistry>>,
/// Transfer context shared by all clones.
context: Arc<TransferContext>,
}
impl TransportManager {
/// Create a new TransportManager builder.
///
/// The builder configures the worker ID, NIXL agent, CUDA device,
/// and other execution parameters before creating the manager.
///
/// # Example
/// ```ignore
/// let manager = TransportManager::builder()
/// .worker_id(0) // NIXL agent name defaults to "worker-0"
/// .nixl_backend("ucx") // Optional: defaults to UCX from env
/// .cuda_device_id(0)
/// .build()?;
///
/// // Or with custom agent name:
/// let manager = TransportManager::builder()
/// .worker_id(0)
/// .nixl_agent_name("custom-agent")
/// .build()?;
/// ```
pub fn builder() -> crate::block_manager::v2::physical::transfer::context::TransferConfigBuilder
{
TransferContext::builder()
}
/// Create a TransportManager from a built TransferContext.
///
/// This is used internally by the builder to wrap the context
/// and create the associated registry.
pub(crate) fn from_context(context: TransferContext) -> Self {
let worker_id = context.worker_id();
let nixl_agent = context.nixl_agent().clone();
let registry = Arc::new(RwLock::new(LayoutRegistry::new(nixl_agent, worker_id)));
Self {
registry,
context: Arc::new(context),
}
}
// ===== Layout Registration and Metadata Management =====
/// Register a local physical layout and return a unique handle.
///
/// This registers the layout with the embedded memory manager, assigning
/// it a unique handle that can be used for handle-based transfers.
///
/// # Arguments
/// * `layout` - Physical layout to register
///
/// # Returns
/// Unique handle for the registered layout
///
/// # Errors
/// Returns an error if layout IDs are exhausted (u16::MAX reached)
pub fn register_layout(&self, layout: PhysicalLayout) -> Result<LayoutHandle> {
self.registry.write().unwrap().register_local(layout)
}
/// Export layout metadata for transmission to remote workers.
///
/// This exports all registered local layouts along with NIXL metadata
/// needed for remote memory registration.
///
/// # Returns
/// Packed metadata ready for transmission to remote workers
pub fn export_metadata(&self) -> Result<SerializedLayout> {
self.registry.read().unwrap().export_metadata()
}
/// Import remote layout metadata.
///
/// This loads NIXL metadata and reconstructs physical layouts from a remote
/// worker's exported metadata.
///
/// # Arguments
/// * `metadata` - Packed metadata from remote worker
///
/// # Returns
/// Vector of handles for the imported remote layouts
///
/// # Errors
/// Returns an error if the remote worker was already loaded or if metadata
/// loading/reconstruction fails
pub fn import_metadata(&self, metadata: SerializedLayout) -> Result<Vec<LayoutHandle>> {
self.registry.write().unwrap().import_metadata(metadata)
}
// ===== Handle-Based Transfer API =====
/// Transfer complete blocks between layouts using handles.
///
/// This function copies entire blocks (all layers and outer dimensions) between
/// the source and destination layouts identified by their handles. The transfer
/// strategy (memcpy, CUDA, NIXL) is automatically selected based on storage locations.
///
/// The lock on the registry is held only briefly during layout lookup,
/// then released before executing the actual transfer.
///
/// # Arguments
/// * `src_handle` - Handle to source layout
/// * `src_blocks` - Source block IDs to transfer
/// * `dst_handle` - Handle to destination layout
/// * `dst_blocks` - Destination block IDs to transfer
///
/// # Returns
/// A notification handle that can be awaited for transfer completion
///
/// # Errors
/// Returns an error if:
/// - Either handle is invalid
/// - Block IDs are out of bounds
/// - Transfer execution fails
pub fn execute_transfer(
&self,
src_handle: LayoutHandle,
src_blocks: &[usize],
dst_handle: LayoutHandle,
dst_blocks: &[usize],
options: TransferOptions,
) -> Result<TransferCompleteNotification> {
// Clone layouts inside the lock, then drop lock before transfer
let (src_layout, dst_layout) = {
let registry = self.registry.read().unwrap();
let src = registry
.get_layout(src_handle)
.ok_or_else(|| anyhow!("invalid source handle: {}", src_handle))?
.clone(); // Cheap: just Arc refcount bump
let dst = registry
.get_layout(dst_handle)
.ok_or_else(|| anyhow!("invalid destination handle: {}", dst_handle))?
.clone();
(src, dst)
}; // Lock released here
// Execute transfer with no lock held
super::transfer::executor::execute_transfer(
&src_layout,
&dst_layout,
src_blocks,
dst_blocks,
options,
&self.context,
)
}
// ===== Query Methods =====
/// Get the worker ID for this manager.
pub fn worker_id(&self) -> u64 {
self.context.worker_id()
}
/// Get handles for all locally registered layouts.
pub fn get_local_handles(&self) -> Vec<LayoutHandle> {
self.registry.read().unwrap().local_handles()
}
/// Get handles for all imported remote layouts.
pub fn get_remote_handles(&self) -> Vec<LayoutHandle> {
self.registry.read().unwrap().remote_handles()
}
// ===== Internal Methods for Testing =====
/// Get the internal transfer context (for testing only).
pub fn context(&self) -> &Arc<TransferContext> {
&self.context
}
/// Get the H2D stream (for testing only).
#[cfg(all(test, feature = "testing-cuda"))]
pub(crate) fn h2d_stream(&self) -> &std::sync::Arc<cudarc::driver::CudaStream> {
self.context.h2d_stream()
}
/// Get the D2H stream (for testing only).
#[cfg(all(test, feature = "testing-cuda"))]
#[allow(dead_code)]
pub(crate) fn d2h_stream(&self) -> &std::sync::Arc<cudarc::driver::CudaStream> {
self.context.d2h_stream()
}
/// Get the CUDA context (for testing only).
#[cfg(all(test, feature = "testing-cuda"))]
pub(crate) fn cuda_context(&self) -> &std::sync::Arc<cudarc::driver::CudaContext> {
self.context.cuda_context()
}
/// Register a CUDA event for completion (for testing only).
#[cfg(all(test, feature = "testing-cuda"))]
pub(crate) fn register_cuda_event(
&self,
event: cudarc::driver::CudaEvent,
) -> TransferCompleteNotification {
self.context.register_cuda_event(event)
}
}
/// Internal registry for local and remote physical layouts with NIXL integration.
///
/// The LayoutRegistry handles:
/// - Registering local layouts with unique handles
/// - Exporting local layout metadata for remote access
/// - Importing remote layout metadata and reconstructing layouts
/// - Managing NIXL metadata for RDMA operations
///
/// The registry does no locking of its own; `TransportManager` wraps it in an
/// `Arc<RwLock<..>>`.
#[derive(Debug)]
pub(crate) struct LayoutRegistry {
/// NIXL agent for memory registration
nixl_agent: NixlAgent,
/// Worker ID for this manager
worker_id: u64,
/// Next layout ID to assign (monotonically increasing)
// Only advanced by `register_local`, which takes `&mut self`.
next_layout_id: AtomicU16,
/// Local layouts registered on this worker
local_layouts: HashMap<LayoutHandle, LocalLayout>,
/// Remote layouts imported from other workers
remote_layouts: HashMap<LayoutHandle, RemoteLayout>,
/// Set of loaded remote workers (agent_name, worker_id) to prevent duplicates
loaded_remotes: HashSet<(String, u64)>,
}
#[expect(dead_code)]
impl LayoutRegistry {
    /// Create an empty registry.
    ///
    /// # Arguments
    /// * `nixl_agent` - NIXL agent used for exporting and loading metadata
    /// * `worker_id` - Identifier of this worker; embedded in every local handle
    pub(crate) fn new(nixl_agent: NixlAgent, worker_id: u64) -> Self {
        Self {
            nixl_agent,
            worker_id,
            next_layout_id: AtomicU16::new(0),
            local_layouts: HashMap::new(),
            remote_layouts: HashMap::new(),
            loaded_remotes: HashSet::new(),
        }
    }

    /// Register a local physical layout.
    ///
    /// Layout IDs are handed out in increasing order starting at 0. IDs
    /// `0..=65534` are valid, giving at most 65535 layouts per registry.
    ///
    /// # Arguments
    /// * `layout` - Physical layout to register
    ///
    /// # Returns
    /// Unique handle for the registered layout
    ///
    /// # Errors
    /// Returns an error once all layout IDs have been handed out. The error is
    /// sticky: every later call fails as well.
    pub(crate) fn register_local(&mut self, layout: PhysicalLayout) -> Result<LayoutHandle> {
        // `fetch_add` wraps on overflow, which used to reset the counter to 0 after
        // the first overflow error; later calls then reissued existing handles and
        // silently replaced registered layouts. `fetch_update` with `checked_add`
        // leaves the counter pinned at `u16::MAX` instead.
        let layout_id = self
            .next_layout_id
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |id| id.checked_add(1))
            .map_err(|_| {
                anyhow!("Layout ID overflow: maximum number of layouts (65535) reached")
            })?;

        let handle = LayoutHandle::new(self.worker_id, layout_id);
        self.local_layouts
            .insert(handle, LocalLayout::new(handle, layout));
        Ok(handle)
    }

    /// Export local layout metadata for transmission to remote workers.
    ///
    /// This exports:
    /// - NIXL agent metadata for remote memory registration
    /// - All layouts located in system, pinned, or device memory (other
    ///   locations, such as disk, are skipped)
    /// - Worker address information
    ///
    /// The order of the exported layouts follows `HashMap` iteration order and
    /// is therefore unspecified.
    ///
    /// # Returns
    /// Packed metadata ready for transmission
    ///
    /// # Errors
    /// Returns an error if the NIXL metadata cannot be obtained, a layout cannot
    /// be converted to a descriptor, or packing fails.
    pub(crate) fn export_metadata(&self) -> Result<SerializedLayout> {
        let nixl_metadata = self
            .nixl_agent
            .get_local_md()
            .map_err(|e| anyhow!("failed to get NIXL local metadata: {:?}", e))?;

        let worker_address = WorkerAddress::new(self.worker_id, self.nixl_agent.name().to_string());

        let mut serialized_layouts = Vec::new();
        for (handle, local_layout) in &self.local_layouts {
            let location = local_layout.layout().location();
            // Only host (system / pinned) and device layouts are exported.
            if matches!(
                location,
                StorageKind::System | StorageKind::Device(_) | StorageKind::Pinned
            ) {
                let serialized = local_layout
                    .layout()
                    .to_descriptor()
                    .map_err(|e| anyhow!("failed to serialize layout {}: {}", handle, e))?;
                serialized_layouts.push(LocalLayoutDescriptor::new(*handle, serialized));
            }
        }

        SerializedLayout::pack(worker_address, nixl_metadata, serialized_layouts)
    }

    /// Import remote layout metadata.
    ///
    /// This:
    /// - Validates the remote worker hasn't been loaded already
    /// - Loads NIXL metadata into the agent
    /// - Reconstructs physical layouts from serialized data
    /// - Stores them as remote layouts
    ///
    /// Layouts are only stored once *all* of them have been reconstructed, so a
    /// reconstruction failure leaves the set of remote layouts unchanged. The
    /// NIXL metadata loaded into the agent beforehand is not rolled back.
    ///
    /// # Arguments
    /// * `metadata` - Packed metadata from remote worker
    ///
    /// # Returns
    /// Vector of handles for the imported layouts
    ///
    /// # Errors
    /// Returns an error if:
    /// - The metadata cannot be decoded
    /// - The remote worker was already loaded
    /// - NIXL metadata loading fails
    /// - Agent name mismatch after loading
    /// - Layout reconstruction fails
    pub(crate) fn import_metadata(
        &mut self,
        metadata: SerializedLayout,
    ) -> Result<Vec<LayoutHandle>> {
        let inner = metadata.unpack()?;

        // Refuse to load the same (agent name, worker id) pair twice.
        let remote_key = (
            inner.worker_address.nixl_agent_name.clone(),
            inner.worker_address.worker_id,
        );
        if self.loaded_remotes.contains(&remote_key) {
            bail!(
                "Remote worker already loaded: {} (worker_id={})",
                remote_key.0,
                remote_key.1
            );
        }

        let returned_agent_name = self
            .nixl_agent
            .load_remote_md(&inner.nixl_metadata)
            .map_err(|e| anyhow!("failed to load remote NIXL metadata: {:?}", e))?;

        // The name reported by NIXL must match the one the sender claimed.
        if returned_agent_name != inner.worker_address.nixl_agent_name {
            bail!(
                "Agent name mismatch: expected '{}', got '{}'",
                inner.worker_address.nixl_agent_name,
                returned_agent_name
            );
        }

        // Phase 1: rebuild every layout without touching the registry, so that an
        // error part-way through cannot leave a partially imported worker behind.
        let reconstructed = inner
            .layouts
            .into_iter()
            .map(|descriptor| {
                let handle = descriptor.handle;
                PhysicalLayout::from_descriptor(descriptor.layout)
                    .map(|layout| RemoteLayout::new(handle, layout))
                    .map_err(|e| anyhow!("failed to reconstruct layout {}: {}", handle, e))
            })
            .collect::<Result<Vec<_>>>()?;

        // Phase 2: commit.
        let mut imported_handles = Vec::with_capacity(reconstructed.len());
        for remote_layout in reconstructed {
            let handle = remote_layout.handle();
            self.remote_layouts.insert(handle, remote_layout);
            imported_handles.push(handle);
        }
        self.loaded_remotes.insert(remote_key);
        Ok(imported_handles)
    }

    /// Get a local layout by handle.
    pub(crate) fn get_local(&self, handle: LayoutHandle) -> Option<&LocalLayout> {
        self.local_layouts.get(&handle)
    }

    /// Get a remote layout by handle.
    pub(crate) fn get_remote(&self, handle: LayoutHandle) -> Option<&RemoteLayout> {
        self.remote_layouts.get(&handle)
    }

    /// Get a layout by handle (either local or remote).
    ///
    /// Local layouts are consulted first; remote layouts are only searched when
    /// no local layout has this handle.
    ///
    /// # Returns
    /// Returns a reference to the PhysicalLayout if found
    pub(crate) fn get_layout(&self, handle: LayoutHandle) -> Option<&PhysicalLayout> {
        self.local_layouts
            .get(&handle)
            .map(|l| l.layout())
            .or_else(|| self.remote_layouts.get(&handle).map(|r| r.layout()))
    }

    /// Check if a handle refers to a local layout.
    pub(crate) fn is_local(&self, handle: LayoutHandle) -> bool {
        self.local_layouts.contains_key(&handle)
    }

    /// Check if a handle refers to a remote layout.
    pub(crate) fn is_remote(&self, handle: LayoutHandle) -> bool {
        self.remote_layouts.contains_key(&handle)
    }

    /// Get the number of local layouts.
    pub(crate) fn local_count(&self) -> usize {
        self.local_layouts.len()
    }

    /// Get the number of remote layouts.
    pub(crate) fn remote_count(&self) -> usize {
        self.remote_layouts.len()
    }

    /// Get the worker ID for this manager.
    pub(crate) fn worker_id(&self) -> u64 {
        self.worker_id
    }

    /// Get all local layout handles (in unspecified order).
    pub(crate) fn local_handles(&self) -> Vec<LayoutHandle> {
        self.local_layouts.keys().copied().collect()
    }

    /// Get all remote layout handles (in unspecified order).
    pub(crate) fn remote_handles(&self) -> Vec<LayoutHandle> {
        self.remote_layouts.keys().copied().collect()
    }
}
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::*;
    use crate::block_manager::v2::physical::layout::LayoutConfig;
    use crate::block_manager::v2::physical::transfer::nixl_agent::NixlAgent;

    /// Create a NIXL agent that requires no backends.
    fn make_test_agent(name: &str) -> NixlAgent {
        let no_backends = &[];
        NixlAgent::require_backends(name, no_backends).expect("failed to create wrapped agent")
    }

    /// Allocate a small fully-contiguous layout in system memory.
    fn make_test_layout(agent: &NixlAgent) -> PhysicalLayout {
        let layout_config = LayoutConfig::builder()
            .num_blocks(2)
            .num_layers(2)
            .outer_dim(2)
            .page_size(4)
            .inner_dim(8)
            .dtype_width_bytes(2)
            .build()
            .unwrap();
        let builder = PhysicalLayout::builder(agent.clone()).with_config(layout_config);
        builder
            .fully_contiguous()
            .allocate_system()
            .build()
            .unwrap()
    }

    /// A new registry reports its worker ID and holds no layouts.
    #[test]
    fn test_manager_creation() {
        let registry = LayoutRegistry::new(make_test_agent("test-manager"), 42);
        assert_eq!(registry.worker_id(), 42);
        assert_eq!(registry.local_count(), 0);
        assert_eq!(registry.remote_count(), 0);
    }

    /// The first registered layout gets ID 0 and is classified as local.
    #[test]
    fn test_register_local() {
        let agent = make_test_agent("test-register");
        let mut registry = LayoutRegistry::new(agent.clone(), 100);
        let handle = registry.register_local(make_test_layout(&agent)).unwrap();
        assert_eq!(handle.worker_id(), 100);
        assert_eq!(handle.layout_id(), 0);
        assert_eq!(registry.local_count(), 1);
        assert!(registry.is_local(handle));
        assert!(!registry.is_remote(handle));
    }

    /// Successive registrations receive consecutive IDs.
    #[test]
    fn test_register_multiple_locals() {
        let agent = make_test_agent("test-multiple");
        let mut registry = LayoutRegistry::new(agent.clone(), 1);
        let first = registry.register_local(make_test_layout(&agent)).unwrap();
        let second = registry.register_local(make_test_layout(&agent)).unwrap();
        let third = registry.register_local(make_test_layout(&agent)).unwrap();
        let ids = [first.layout_id(), second.layout_id(), third.layout_id()];
        assert_eq!(ids[0], 0);
        assert_eq!(ids[1], 1);
        assert_eq!(ids[2], 2);
        assert_eq!(registry.local_count(), 3);
    }

    /// Metadata exported by one registry can be imported by another.
    #[test]
    #[ignore] // Requires actual NIXL memory registration
    fn test_export_import_roundtrip() {
        // Source side: two registered layouts.
        let source_agent = make_test_agent("source");
        let mut source = LayoutRegistry::new(source_agent.clone(), 1);
        let handle1 = source
            .register_local(make_test_layout(&source_agent))
            .unwrap();
        let handle2 = source
            .register_local(make_test_layout(&source_agent))
            .unwrap();

        let exported = source.export_metadata().unwrap();
        assert!(!exported.is_empty());

        // Destination side: import what the source exported.
        let mut dest = LayoutRegistry::new(make_test_agent("dest"), 2);
        let imported = dest.import_metadata(exported).unwrap();

        assert_eq!(imported.len(), 2);
        assert_eq!(dest.remote_count(), 2);
        assert!(dest.is_remote(handle1));
        assert!(dest.is_remote(handle2));

        // Imported layouts are reachable through both lookup paths.
        assert!(dest.get_remote(handle1).is_some());
        assert!(dest.get_remote(handle2).is_some());
        assert!(dest.get_layout(handle1).is_some());
    }

    /// Importing the same worker's metadata twice is rejected.
    #[test]
    #[ignore] // Requires actual NIXL memory registration
    fn test_import_duplicate_remote_fails() {
        let source_agent = make_test_agent("source2");
        let mut source = LayoutRegistry::new(source_agent.clone(), 10);
        source
            .register_local(make_test_layout(&source_agent))
            .unwrap();
        let exported = source.export_metadata().unwrap();

        let mut dest = LayoutRegistry::new(make_test_agent("dest2"), 20);

        // Keep a second copy of the bytes for the repeat import.
        let exported_again = SerializedLayout::from_bytes(exported.as_bytes().clone());
        dest.import_metadata(exported).unwrap();

        let second_attempt = dest.import_metadata(exported_again);
        assert!(second_attempt.is_err());
        let message = second_attempt.unwrap_err().to_string();
        assert!(message.contains("already loaded"));
    }

    /// `local_handles` returns every registered handle.
    #[test]
    fn test_get_layout_handles() {
        let agent = make_test_agent("test-handles");
        let mut registry = LayoutRegistry::new(agent.clone(), 5);
        let h1 = registry.register_local(make_test_layout(&agent)).unwrap();
        let h2 = registry.register_local(make_test_layout(&agent)).unwrap();
        let all = registry.local_handles();
        assert_eq!(all.len(), 2);
        assert!(all.contains(&h1));
        assert!(all.contains(&h2));
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Remote layout wrapper reconstructed from imported metadata.
use super::handle::LayoutHandle;
use crate::block_manager::v2::physical::layout::PhysicalLayout;
/// A remote physical layout reconstructed from imported metadata.
///
/// This wraps a `PhysicalLayout` that was deserialized from another worker's
/// exported metadata. The layout's memory regions point to addresses on the
/// remote worker and are used for building NIXL RDMA transfer descriptors.
///
/// This type is cheap to clone as `PhysicalLayout` contains `Arc` internally.
// NOTE(review): the cheap-clone claim depends on `PhysicalLayout`'s definition,
// which lives in another module — confirm there.
#[derive(Debug, Clone)]
pub struct RemoteLayout {
/// Handle assigned by the remote worker that exported this layout.
handle: LayoutHandle,
/// Layout rebuilt from the remote worker's descriptor.
layout: PhysicalLayout,
}
#[allow(dead_code)]
impl RemoteLayout {
/// Create a new remote layout.
///
/// # Arguments
/// * `handle` - Unique handle for this layout (from remote worker)
/// * `layout` - The reconstructed physical layout
pub fn new(handle: LayoutHandle, layout: PhysicalLayout) -> Self {
Self { handle, layout }
}
/// Get the handle for this layout.
pub fn handle(&self) -> LayoutHandle {
self.handle
}
/// Get a reference to the physical layout.
pub fn layout(&self) -> &PhysicalLayout {
&self.layout
}
/// Get the worker_id from the handle (identifies the remote worker).
pub fn worker_id(&self) -> u64 {
self.handle.worker_id()
}
/// Get the layout_id from the handle.
pub fn layout_id(&self) -> u16 {
self.handle.layout_id()
}
/// Consume this remote layout and return the physical layout.
pub fn into_layout(self) -> PhysicalLayout {
self.layout
}
}
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::*;
    use crate::block_manager::v2::physical::layout::{
        LayoutConfig, LayoutDescriptor, PhysicalLayout,
    };

    /// Build a fully-contiguous system-memory descriptor whose single memory
    /// region is exactly as large as the configured layout requires.
    fn make_serialized_layout() -> LayoutDescriptor {
        use crate::block_manager::v2::memory::{MemoryDescriptor, StorageKind};
        use crate::block_manager::v2::physical::layout::{
            BlockFormat, FullyContiguousDetails, LayoutTypeDetails, NixlMetadata,
        };

        let layout_config = LayoutConfig::builder()
            .num_blocks(2)
            .num_layers(2)
            .outer_dim(2)
            .page_size(4)
            .inner_dim(8)
            .dtype_width_bytes(2)
            .build()
            .unwrap();

        // Total bytes = product of every dimension and the element width.
        let required_size = layout_config.num_blocks
            * layout_config.num_layers
            * layout_config.outer_dim
            * layout_config.page_size
            * layout_config.inner_dim
            * layout_config.dtype_width_bytes;

        let nixl_metadata =
            NixlMetadata::new("remote_agent".to_string(), nixl_sys::MemType::Dram, 0);
        let details = FullyContiguousDetails {
            block_format: BlockFormat::Operational,
        };

        LayoutDescriptor {
            version: 1,
            layout_config,
            location: StorageKind::System,
            nixl_metadata,
            memory_descriptors: vec![MemoryDescriptor::new(0x1000, required_size)],
            layout_type_details: LayoutTypeDetails::FullyContiguous(details),
        }
    }

    /// A remote layout reports the handle (and its parts) it was created with.
    #[test]
    fn test_remote_layout_creation() {
        let handle = LayoutHandle::new(999, 42);
        let layout = PhysicalLayout::from_descriptor(make_serialized_layout()).unwrap();
        let remote = RemoteLayout::new(handle, layout);
        assert_eq!(remote.handle(), handle);
        assert_eq!(remote.worker_id(), 999);
        assert_eq!(remote.layout_id(), 42);
    }

    /// `into_layout` consumes the wrapper and yields the physical layout.
    #[test]
    fn test_remote_layout_into_layout() {
        let handle = LayoutHandle::new(100, 200);
        let layout = PhysicalLayout::from_descriptor(make_serialized_layout()).unwrap();
        let _recovered = RemoteLayout::new(handle, layout).into_layout();
        // Reaching this point means the layout was handed back without panicking.
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod layout;
pub mod manager;
pub mod transfer;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer capability flags for controlling direct path enablement.
//!
//! By default, the transfer system uses a conservative staging policy where:
//! - Device can only transfer to/from Host
//! - Disk can only transfer to/from Host
//! - Host can transfer to Device, Disk, or Remote
//! - Device ↔ Device is allowed (native CUDA)
//!
//! These capability flags enable optional direct paths that bypass host staging.
use serde::{Deserialize, Serialize};
use std::sync::OnceLock;
use crate::block_manager::v2::physical::{
layout::LayoutConfig,
transfer::{
PhysicalLayout, TransferOptions, TransportManager, executor::execute_transfer,
nixl_agent::NixlAgent,
},
};
/// Cached result of the GDS probe used by
/// [`TransferCapabilities::with_gds_if_supported`]; initialised at most once.
static GDS_SUPPORTED: OnceLock<bool> = OnceLock::new();

// The documentation below previously sat above `GDS_SUPPORTED`, so rustdoc
// attached it to that private static and left the public struct undocumented.
// NOTE(review): the doctest imports via `dynamo_kvbm::v2::...` while this file
// refers to `crate::block_manager::v2::...` — confirm the example path is current.
/// Transfer capability flags controlling which direct paths are enabled.
///
/// # Default Policy (Conservative)
///
/// With all flags disabled (default), the system uses host staging:
/// - **Device → Remote**: Device → Host → Remote (2 hops)
/// - **Disk → Remote**: Disk → Host → Remote (2 hops)
/// - **Device ↔ Disk**: Device → Host → Disk (2 hops)
///
/// # Optional Direct Paths
///
/// - `allow_gds`: Enables GPU Direct Storage (Disk ↔ Device without host)
/// - `allow_gpu_rdma`: Enables GPU RDMA (Device → Remote without host)
///
/// # Example
///
/// ```
/// # use dynamo_kvbm::v2::physical::transfer::TransferCapabilities;
/// // Default conservative policy
/// let caps = TransferCapabilities::default();
/// assert!(!caps.allow_gds);
/// assert!(!caps.allow_gpu_rdma);
///
/// // Enable GDS for high-performance disk I/O
/// let caps = TransferCapabilities::default().with_gds(true);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct TransferCapabilities {
    /// Enable GPU Direct Storage (Disk ↔ Device without host staging).
    ///
    /// When enabled:
    /// - Disk → Device: Direct transfer (requires GDS support)
    /// - Device → Disk: Direct transfer (requires GDS support)
    ///
    /// When disabled (default):
    /// - Disk → Device: Disk → Host → Device (2 hops)
    /// - Device → Disk: Device → Host → Disk (2 hops)
    pub allow_gds: bool,
    /// Enable GPU RDMA (Device → Remote without host staging).
    ///
    /// When enabled:
    /// - Device → Remote: Direct NIXL transfer
    ///
    /// When disabled (default):
    /// - Device → Remote: Device → Host → Remote (2 hops)
    ///
    /// Note: This only affects Device → Remote. Host → Remote is always direct.
    pub allow_gpu_rdma: bool,
}
impl TransferCapabilities {
    /// Capabilities with every direct path disabled (same as `Default`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Capabilities with every direct path enabled.
    pub fn all_enabled() -> Self {
        Self::new().with_gds(true).with_gpu_rdma(true)
    }

    /// Return a copy with the GDS (GPU Direct Storage) flag set to `enabled`.
    pub fn with_gds(self, enabled: bool) -> Self {
        Self {
            allow_gds: enabled,
            ..self
        }
    }

    /// Probe for GDS support by attempting one small device-to-disk transfer.
    ///
    /// Creates a NIXL agent requiring the `GDS_MT` backend, allocates a
    /// single-block layout on CUDA device 0 and another on disk, and submits a
    /// transfer from the former to the latter. Any failure along the way is
    /// returned as the error.
    // NOTE(review): the notification returned by `execute_transfer` is dropped
    // without being awaited, so this appears to verify submission only —
    // confirm that is sufficient for the probe.
    fn test_gds_transfer(&self) -> anyhow::Result<()> {
        let probe_agent = NixlAgent::require_backends("agent", &["GDS_MT"])?;

        // One block of 4096 elements is enough for the probe.
        let probe_config = LayoutConfig::builder()
            .num_blocks(1)
            .num_layers(1)
            .outer_dim(1)
            .page_size(1)
            .inner_dim(4096)
            .build()?;

        let device_layout = PhysicalLayout::builder(probe_agent.clone())
            .with_config(probe_config.clone())
            .fully_contiguous()
            .allocate_device(0)
            .build()?;
        let disk_layout = PhysicalLayout::builder(probe_agent.clone())
            .with_config(probe_config)
            .fully_contiguous()
            .allocate_disk(None)
            .build()?;

        // Block 0 on both sides.
        let block_ids = vec![0];

        let manager = TransportManager::builder()
            .worker_id(0)
            .nixl_agent(probe_agent)
            .cuda_device_id(0)
            .build()?;

        execute_transfer(
            &device_layout,
            &disk_layout,
            &block_ids,
            &block_ids,
            TransferOptions::default(),
            manager.context(),
        )?;
        Ok(())
    }

    /// Return a copy whose GDS flag reflects whether the GDS probe succeeded.
    ///
    /// The probe runs at most once per process; its outcome is cached in
    /// `GDS_SUPPORTED` and reused by every later call.
    pub fn with_gds_if_supported(self) -> Self {
        let supported = *GDS_SUPPORTED.get_or_init(|| self.test_gds_transfer().is_ok());
        self.with_gds(supported)
    }

    /// Return a copy with the GPU RDMA flag set to `enabled`.
    pub fn with_gpu_rdma(self, enabled: bool) -> Self {
        Self {
            allow_gpu_rdma: enabled,
            ..self
        }
    }

    /// Whether Device ↔ Disk transfers may bypass host staging (the GDS flag).
    pub fn allows_device_disk_direct(&self) -> bool {
        let Self { allow_gds, .. } = *self;
        allow_gds
    }

    /// Whether Device → Remote transfers may bypass host staging (the GPU RDMA flag).
    pub fn allows_device_remote_direct(&self) -> bool {
        let Self { allow_gpu_rdma, .. } = *self;
        allow_gpu_rdma
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default value disables every direct path.
    #[test]
    fn test_default_capabilities() {
        let defaults = TransferCapabilities::default();
        let TransferCapabilities {
            allow_gds,
            allow_gpu_rdma,
        } = defaults;
        assert!(!allow_gds);
        assert!(!allow_gpu_rdma);
        assert!(!defaults.allows_device_disk_direct());
        assert!(!defaults.allows_device_remote_direct());
    }

    /// `all_enabled` turns on every direct path.
    #[test]
    fn test_all_enabled() {
        let everything = TransferCapabilities::all_enabled();
        let TransferCapabilities {
            allow_gds,
            allow_gpu_rdma,
        } = everything;
        assert!(allow_gds);
        assert!(allow_gpu_rdma);
        assert!(everything.allows_device_disk_direct());
        assert!(everything.allows_device_remote_direct());
    }

    /// Builder-style setters can be chained.
    #[test]
    fn test_builder_pattern() {
        let base = TransferCapabilities::new();
        let configured = base.with_gds(true).with_gpu_rdma(false);
        assert!(configured.allow_gds);
        assert!(!configured.allow_gpu_rdma);
    }

    /// Each flag controls only its own direct path.
    #[test]
    fn test_selective_enablement() {
        // GDS only
        let gds_only = TransferCapabilities::new().with_gds(true);
        assert!(gds_only.allows_device_disk_direct());
        assert!(!gds_only.allows_device_remote_direct());

        // GPU RDMA only
        let rdma_only = TransferCapabilities::new().with_gpu_rdma(true);
        assert!(!rdma_only.allows_device_disk_direct());
        assert!(rdma_only.allows_device_remote_direct());
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Block checksum computation for verification.
//!
//! This module provides utilities to compute checksums of blocks for
//! round-trip test verification.
use crate::block_manager::v2::memory::StorageKind;
use super::PhysicalLayout;
use aligned_vec::{AVec, avec};
use anyhow::{Result, anyhow};
use blake3::Hasher;
use std::{
collections::HashMap,
fs::File,
io::{Read, Seek},
mem::ManuallyDrop,
ops::Range,
os::fd::FromRawFd,
};
use cudarc::runtime::sys::{cudaMemcpy, cudaMemcpyKind};
/// Checksum of a block's contents: the string form of the finalized BLAKE3 hash
/// (`Hasher::finalize().to_string()`).
pub type BlockChecksum = String;
/// Checksum every listed block across all of its layers.
///
/// # Arguments
/// * `layout` - The physical layout containing the blocks
/// * `block_ids` - IDs of the blocks to checksum
///
/// # Returns
/// A map from block ID to that block's checksum. An ID listed more than once
/// yields a single map entry.
///
/// # Errors
/// Stops at, and returns, the first error produced while checksumming a block,
/// for example a block ID that is out of range for the layout.
pub fn compute_block_checksums(
    layout: &PhysicalLayout,
    block_ids: &[usize],
) -> Result<HashMap<usize, BlockChecksum>> {
    block_ids
        .iter()
        .map(|&id| compute_single_block_checksum(layout, id, None).map(|digest| (id, digest)))
        .collect()
}
/// Checksum only a sub-range of layers for every listed block.
///
/// # Arguments
/// * `layout` - The physical layout containing the blocks
/// * `block_ids` - IDs of the blocks to checksum
/// * `layer_range` - Layers to feed into each block's checksum
///
/// # Returns
/// A map from block ID to the checksum of that block's selected layers.
///
/// # Errors
/// Fails up front when `layer_range.end` is greater than the layout's
/// `num_layers`; otherwise returns the first error produced while
/// checksumming a block.
pub fn compute_layer_checksums(
    layout: &PhysicalLayout,
    block_ids: &[usize],
    layer_range: Range<usize>,
) -> Result<HashMap<usize, BlockChecksum>> {
    let num_layers = layout.layout().config().num_layers;
    if num_layers < layer_range.end {
        return Err(anyhow!(
            "Layer range {:?} exceeds num_layers {}",
            layer_range,
            num_layers
        ));
    }

    block_ids
        .iter()
        .map(|&id| {
            let selected = Some(layer_range.clone());
            compute_single_block_checksum(layout, id, selected).map(|digest| (id, digest))
        })
        .collect()
}
/// Compute checksum for a single block.
fn compute_single_block_checksum(
layout: &PhysicalLayout,
block_id: usize,
layer_range: Option<Range<usize>>,
) -> Result<String> {
let config = layout.layout().config();
if block_id >= config.num_blocks {
return Err(anyhow!("Block ID {} out of range", block_id));
}
let num_layers = config.num_layers;
let outer_dim = config.outer_dim;
let layers = layer_range.unwrap_or(0..num_layers);
// validate layer range
if layers.end > config.num_layers {
return Err(anyhow!(
"Layer range {:?} exceeds num_layers {}",
layers,
config.num_layers
));
}
let mut hasher = Hasher::new();
// Iterate over all layers and outer dimensions
for layer_id in layers {
for outer_id in 0..outer_dim {
let region = layout.memory_region(block_id, layer_id, outer_id)?;
match layout.location() {
StorageKind::System | StorageKind::Pinned => {
let slice = unsafe {
std::slice::from_raw_parts(region.addr() as *const u8, region.size())
};
hasher.update(slice);
}
StorageKind::Device(_) => {
let system_region: Vec<u8> = vec![0; region.size()];
unsafe {
cudaMemcpy(
system_region.as_ptr() as *mut std::ffi::c_void,
region.addr() as *const std::ffi::c_void,
region.size(),
cudaMemcpyKind::cudaMemcpyDeviceToHost,
);
}
hasher.update(system_region.as_slice());
}
StorageKind::Disk(fd) => {
let mut system_region: AVec<u8, _> = avec![[4096]| 0; region.size()];
let mut file = ManuallyDrop::new(unsafe { File::from_raw_fd(fd as i32) });
file.seek(std::io::SeekFrom::Start(region.addr() as u64))?;
file.read_exact(&mut system_region)?;
hasher.update(system_region.as_slice());
}
}
}
}
Ok(hasher.finalize().to_string())
}
#[cfg(test)]
mod tests {
    use super::super::tests::*;
    use super::*;
    use crate::block_manager::v2::physical::transfer::{FillPattern, fill_blocks};

    /// Two system-memory blocks filled with the same constant byte must produce equal
    /// checksums, and hashing a region's raw bytes must agree with hashing an
    /// equal-length vector of that byte.
    #[test]
    fn test_checksum_constant_pattern() {
        let layout = builder(2)
            .fully_contiguous()
            .allocate_system()
            .build()
            .unwrap();
        fill_blocks(&layout, &[0, 1], FillPattern::Constant(42)).unwrap();

        // Same fill pattern in both blocks => same checksum.
        let by_block = compute_block_checksums(&layout, &[0, 1]).unwrap();
        assert_eq!(by_block[&0], by_block[&1]);

        // Inspect the first region of block 0 directly.
        let region = layout.memory_region(0, 0, 0).unwrap();
        let region_bytes =
            unsafe { std::slice::from_raw_parts(region.addr() as *const u8, region.size()) };
        assert!(region_bytes.iter().all(|&byte| byte == 42));

        // One-shot hashing helper shared by both sides of the comparison.
        let digest_of = |bytes: &[u8]| {
            let mut hasher = Hasher::new();
            hasher.update(bytes);
            hasher.finalize().to_string()
        };
        let from_region = digest_of(region_bytes);
        let reference_bytes = vec![42u8; region.size()];
        let from_reference = digest_of(&reference_bytes);
        assert_eq!(from_region, from_reference);
    }

    // NOTE(review): the disabled tests below use names that do not appear in the
    // live code of this module (`create_test_layout`, `StorageLocation`,
    // `byte_count`, `new_remote`) — presumably an older API; port or remove them.

    // #[test]
    // fn test_checksum_different_patterns() {
    // let (layout, _memory) = create_test_layout(2);
    // let physical = PhysicalLayout::new_local(layout, StorageLocation::System);
    // // Fill blocks with different patterns
    // fill_blocks(&physical, &[0], FillPattern::Constant(42)).unwrap();
    // fill_blocks(&physical, &[1], FillPattern::Constant(100)).unwrap();
    // let checksums = compute_block_checksums(&physical, &[0, 1]).unwrap();
    // // Blocks should have different checksums
    // assert_ne!(checksums[&0], checksums[&1]);
    // }
    // #[test]
    // fn test_checksum_matches() {
    // let (layout1, _memory1) = create_test_layout(1);
    // let (layout2, _memory2) = create_test_layout(1);
    // let physical1 = PhysicalLayout::new_local(layout1, StorageLocation::System);
    // let physical2 = PhysicalLayout::new_local(layout2, StorageLocation::System);
    // // Fill both with same pattern
    // fill_blocks(&physical1, &[0], FillPattern::Sequential).unwrap();
    // fill_blocks(&physical2, &[0], FillPattern::Sequential).unwrap();
    // let checksum1 = compute_block_checksums(&physical1, &[0]).unwrap();
    // let checksum2 = compute_block_checksums(&physical2, &[0]).unwrap();
    // // Checksums should match (ignoring block_id)
    // assert!(checksum1[&0].matches(&checksum2[&0]));
    // }
    // #[test]
    // fn test_layer_checksums() {
    // let (layout, _memory) = create_test_layout(1);
    // let physical = PhysicalLayout::new_local(layout, StorageLocation::System);
    // // Fill entire block
    // fill_blocks(&physical, &[0], FillPattern::Sequential).unwrap();
    // // Compute checksums for different layer ranges
    // let full_checksum = compute_block_checksums(&physical, &[0]).unwrap();
    // let layer0_checksum = compute_layer_checksums(&physical, &[0], 0..1).unwrap();
    // let layer1_checksum = compute_layer_checksums(&physical, &[0], 1..2).unwrap();
    // // Layer checksums should be different from full checksum
    // assert_ne!(full_checksum[&0].byte_count, layer0_checksum[&0].byte_count);
    // assert_ne!(full_checksum[&0].byte_count, layer1_checksum[&0].byte_count);
    // // Layer 0 and Layer 1 should have same byte count (same size)
    // assert_eq!(
    // layer0_checksum[&0].byte_count,
    // layer1_checksum[&0].byte_count
    // );
    // }
    // #[test]
    // fn test_checksum_remote_layout_fails() {
    // let (layout, _memory) = create_test_layout(1);
    // let physical =
    // PhysicalLayout::new_remote(layout, StorageLocation::System, "remote".to_string());
    // let result = compute_block_checksums(&physical, &[0]);
    // assert!(result.is_err());
    // assert!(result.unwrap_err().to_string().contains("remote"));
    // }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment