Unverified Commit 3998fdcb authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: KVBM V2 Initial Migration (#3861)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
parent e64d2f09
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Tests for the storage-next module.
use super::*;
#[test]
fn test_system_storage() {
let storage = SystemStorage::new(1024).unwrap();
assert_eq!(storage.size(), 1024);
assert_eq!(storage.storage_kind(), StorageKind::System);
assert!(storage.addr() != 0);
// Test that we can create multiple allocations
let storage2 = SystemStorage::new(2048).unwrap();
assert_eq!(storage2.size(), 2048);
assert_ne!(storage.addr(), storage2.addr());
}
#[test]
fn test_system_storage_zero_size() {
let result = SystemStorage::new(0);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
StorageError::AllocationFailed(_)
));
}
#[test]
fn test_disk_storage_temp() {
let storage = DiskStorage::new(4096).unwrap();
assert_eq!(storage.size(), 4096);
assert!(matches!(storage.storage_kind(), StorageKind::Disk(_)));
// Disk storage is file-backed, so addr() returns 0 (no memory address)
assert_eq!(storage.addr(), 0);
assert!(storage.path().exists());
}
#[test]
fn test_disk_storage_at_path() {
let temp_dir = tempfile::tempdir().unwrap();
let path = temp_dir.path().join("test.bin");
let storage = DiskStorage::new_at(&path, 8192).unwrap();
assert_eq!(storage.size(), 8192);
assert!(matches!(storage.storage_kind(), StorageKind::Disk(_)));
assert!(path.exists());
}
#[test]
fn test_type_erasure() {
let storage = SystemStorage::new(1024).unwrap();
let erased: OwnedMemoryRegion = erase_storage(storage);
assert_eq!(erased.size(), 1024);
assert_eq!(erased.storage_kind(), StorageKind::System);
}
#[test]
fn test_memory_descriptor() {
let desc = MemoryDescriptor::new(0x1000, 4096);
assert_eq!(desc.addr, 0x1000);
assert_eq!(desc.size, 4096);
}
#[cfg(feature = "testing-cuda")]
mod cuda_tests {
use super::*;
#[test]
fn test_pinned_storage() {
let storage = PinnedStorage::new(2048).unwrap();
assert_eq!(storage.size(), 2048);
assert_eq!(storage.storage_kind(), StorageKind::Pinned);
assert!(storage.addr() != 0);
}
#[test]
fn test_pinned_storage_zero_size() {
let storage = PinnedStorage::new(0);
assert!(storage.is_err());
assert!(matches!(
storage.unwrap_err(),
StorageError::AllocationFailed(_)
));
}
#[test]
fn test_device_storage() {
let storage = DeviceStorage::new(4096, 0).unwrap();
assert_eq!(storage.size(), 4096);
assert_eq!(storage.storage_kind(), StorageKind::Device(0));
assert!(storage.addr() != 0);
assert_eq!(storage.device_id(), 0);
}
#[test]
fn test_device_storage_zero_size() {
let result = DeviceStorage::new(0, 0);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
StorageError::AllocationFailed(_)
));
}
}
// Tests for NIXL registration would require a real NIXL agent,
// so we'll skip those for now. In practice, you'd mock the agent
// or use integration tests.
#[cfg(feature = "testing-nixl")]
mod nixl_tests {
use super::super::registered::register_with_nixl;
use super::*;
use nixl_sys::Agent as NixlAgent;
// These tests would require a mock NIXL agent or real NIXL setup
// Placeholder for now
#[test]
fn test_nixl_registration() {
let pinned = PinnedStorage::new(2048).unwrap();
let agent = NixlAgent::new("test_agent").unwrap();
let registered = register_with_nixl(pinned, &agent, None).unwrap();
assert_eq!(registered.agent_name(), "test_agent");
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TorchDevice {
Cuda(usize),
Other(String),
}
impl TorchDevice {
pub fn is_cuda(&self) -> bool {
matches!(self, TorchDevice::Cuda(_))
}
pub fn cuda_device_index(&self) -> Option<usize> {
match self {
TorchDevice::Cuda(index) => Some(*index),
TorchDevice::Other(_) => None,
}
}
}
pub trait TorchTensor: std::fmt::Debug + Send + Sync {
fn device(&self) -> TorchDevice;
fn data_ptr(&self) -> u64;
fn size_bytes(&self) -> usize;
fn shape(&self) -> Vec<usize>;
fn stride(&self) -> Vec<usize>;
}
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError};
use super::InnerShape;
/// Configuration for block layouts
#[derive(Debug, Clone, Builder, Validate, Serialize, Deserialize, PartialEq, Eq)]
pub struct LayoutConfig {
/// Number of blocks
#[validate(range(min = 1))]
pub num_blocks: usize,
/// Number of layers
#[validate(range(min = 1))]
pub num_layers: usize,
/// Number of outer dimensions
#[validate(range(min = 1, max = 2))]
pub outer_dim: usize,
/// Page size
#[validate(range(min = 1))]
pub page_size: usize,
/// Inner dimension
#[validate(range(min = 1))]
pub inner_dim: usize,
/// Alignment
#[validate(custom(function = "validate_power_of_2"))]
#[builder(default = "1")]
pub alignment: usize,
/// Data type
#[validate(custom(function = "validate_dtype_width_bytes"))]
#[builder(default = "2")]
pub dtype_width_bytes: usize,
/// Inner shape format (NHD, HND, or Unknown)
#[builder(default = "InnerShape::Unknown")]
pub inner_shape: InnerShape,
}
impl LayoutConfig {
/// Builder for LayoutConfig
pub fn builder() -> LayoutConfigBuilder {
LayoutConfigBuilder::default()
}
pub fn required_bytes(&self) -> usize {
self.num_blocks
.saturating_mul(self.num_layers)
.saturating_mul(self.outer_dim)
.saturating_mul(self.page_size)
.saturating_mul(self.inner_dim)
.saturating_mul(self.dtype_width_bytes)
}
}
/// The first two dimensions of the tensor, `shape[0]` and `shape[1]`, one of those corresponds to the
/// block dimension, while the other corresponds to the outer dimension.
///
/// The outer dimension is typically:
/// - 1: MLA or K and V stored together,
/// - 2: K and V stored separately,
///
/// The block dimension tell us the number of blocks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BlockDimension {
/// The block dimension is the first dimension of the tensor, `[n_blocks, outer_dim, inner_dim]`
BlockIsFirstDim,
/// The block dimension is the second dimension of the tensor, `[outer_dim, n_blocks, inner_dim]`
/// This is a replacement for v1's `outer_contiguous` is true.
BlockIsSecondDim,
}
/// Validation function for Option<usize> to check if it's Some(power_of_2).
pub fn validate_power_of_2(alignment: usize) -> Result<(), ValidationError> {
if !alignment.is_power_of_two() {
// Return validation error if alignment is not a power of 2
return Err(validator::ValidationError::new(
"alignment_must_be_power_of_2",
));
}
// Passes validation if alignment is a power of 2
Ok(())
}
pub fn validate_dtype_width_bytes(dtype_width_bytes: usize) -> Result<(), ValidationError> {
if !dtype_width_bytes.is_power_of_two() || !(2..=8).contains(&dtype_width_bytes) {
return Err(validator::ValidationError::new(
"dtype_width_bytes_must_be_power_of_two_and_less_than_8_bytes",
));
}
Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Fully contiguous layout implementation.
//!
//! This layout stores all blocks in a single contiguous memory allocation
//! with the shape: [num_blocks, num_layers, outer_dim, page_size, inner_dim].
use anyhow::{Result, anyhow};
use std::sync::Arc;
use validator::Validate;
use super::serialize::{BlockFormat, FullyContiguousDetails, LayoutTypeDetails};
use super::{Layout, LayoutConfig, MemoryDescriptor, MemoryRegion, OwnedMemoryRegion};
/// Fully contiguous layout where all blocks are in a single allocation.
#[derive(Debug)]
pub struct FullyContiguousLayout {
config: LayoutConfig,
/// Base address of the allocation
base_addr: usize,
/// Stride between blocks in bytes
block_stride: usize,
/// Stride between layers in bytes
layer_stride: usize,
/// Stride between outer dimensions in bytes
outer_stride: usize,
/// Size of each memory region (page) in bytes
region_size: usize,
/// Owned memory region backing this layout
memory: Arc<dyn MemoryRegion>,
/// Format of blocks in memory
block_format: BlockFormat,
}
impl FullyContiguousLayout {
/// Create a new fully contiguous layout.
///
/// # Arguments
/// * `config` - Layout configuration
/// * `memory` - Owned memory region that backs this layout
///
/// # Returns
/// A new FullyContiguousLayout instance
pub fn new(config: LayoutConfig, memory: Arc<dyn MemoryRegion>) -> Result<Self> {
config.validate()?;
let base_addr = memory.addr();
// Calculate strides
let region_size = config.page_size * config.inner_dim * config.dtype_width_bytes;
let outer_stride = region_size;
let layer_stride = outer_stride * config.outer_dim;
let block_stride = layer_stride * config.num_layers;
// Validate that the memory region is large enough
let required_size = block_stride * config.num_blocks;
if memory.size() < required_size {
return Err(anyhow!(
"Memory region too small for layout. Required: {} bytes, got: {} bytes",
required_size,
memory.size()
));
}
Ok(Self {
config,
base_addr,
block_stride,
layer_stride,
outer_stride,
region_size,
memory,
block_format: BlockFormat::default(),
})
}
/// Create a new fully contiguous layout with a specific block format.
///
/// # Arguments
/// * `config` - Layout configuration
/// * `memory` - Owned memory region that backs this layout
/// * `block_format` - Format of blocks in memory
///
/// # Returns
/// A new FullyContiguousLayout instance
pub(crate) fn new_with_format(
config: LayoutConfig,
memory: Arc<dyn MemoryRegion>,
block_format: BlockFormat,
) -> Result<Self> {
let mut layout = Self::new(config, memory)?;
layout.block_format = block_format;
Ok(layout)
}
/// Get the block format.
pub fn block_format(&self) -> BlockFormat {
self.block_format
}
/// Calculate the address of a specific memory region.
fn calculate_address(
&self,
block_id: usize,
layer_id: usize,
outer_id: usize,
) -> Result<usize> {
if block_id >= self.config.num_blocks {
return Err(anyhow!(
"Block ID {} out of range (max: {})",
block_id,
self.config.num_blocks
));
}
if layer_id >= self.config.num_layers {
return Err(anyhow!(
"Layer ID {} out of range (max: {})",
layer_id,
self.config.num_layers
));
}
if outer_id >= self.config.outer_dim {
return Err(anyhow!(
"Outer ID {} out of range (max: {})",
outer_id,
self.config.outer_dim
));
}
Ok(self.base_addr
+ block_id * self.block_stride
+ layer_id * self.layer_stride
+ outer_id * self.outer_stride)
}
/// Get mutable reference to the memory Arc for NIXL registration.
pub fn memory_arc_mut(&mut self) -> &mut Arc<dyn MemoryRegion> {
&mut self.memory
}
}
impl Layout for FullyContiguousLayout {
fn config(&self) -> &LayoutConfig {
&self.config
}
fn memory_regions(&self) -> &[OwnedMemoryRegion] {
std::slice::from_ref(&self.memory)
}
fn memory_region(
&self,
block_id: usize,
layer_id: usize,
outer_id: usize,
) -> Result<MemoryDescriptor> {
let addr = self.calculate_address(block_id, layer_id, outer_id)?;
Ok(MemoryDescriptor::new(addr, self.region_size))
}
fn required_allocations(&self) -> Vec<usize> {
// Single contiguous allocation
vec![self.block_stride * self.config.num_blocks]
}
fn is_fully_contiguous(&self) -> bool {
true
}
fn num_blocks(&self) -> usize {
self.config.num_blocks
}
fn num_layers(&self) -> usize {
self.config.num_layers
}
fn outer_dim(&self) -> usize {
self.config.outer_dim
}
fn page_size(&self) -> usize {
self.config.page_size
}
fn inner_dim(&self) -> usize {
self.config.inner_dim
}
fn dtype_width_bytes(&self) -> usize {
self.config.dtype_width_bytes
}
fn serialization_details(&self) -> LayoutTypeDetails {
LayoutTypeDetails::FullyContiguous(FullyContiguousDetails {
block_format: self.block_format,
})
}
}
#[cfg(test)]
mod tests {
use super::super::tests::*;
use super::*;
#[test]
fn test_fully_contiguous_layout_creation() {
let config = LayoutConfig::builder()
.num_blocks(10)
.num_layers(4)
.outer_dim(2)
.page_size(16)
.inner_dim(128)
.dtype_width_bytes(2)
.build()
.unwrap();
let required_bytes = config.required_bytes();
assert_eq!(required_bytes, 10 * 4 * 2 * 16 * 128 * 2);
let memory = MockMemory::new(0x1000, required_bytes);
let layout = FullyContiguousLayout::new(config, memory).unwrap();
assert_eq!(layout.num_blocks(), 10);
assert!(layout.is_fully_contiguous());
}
#[test]
fn test_memory_region() {
let config = LayoutConfig::builder()
.num_blocks(2)
.num_layers(2)
.outer_dim(2)
.page_size(16)
.inner_dim(128)
.dtype_width_bytes(2)
.build()
.unwrap();
let required_size = config.required_bytes();
let memory = MockMemory::new(0x1000, required_size);
let layout = FullyContiguousLayout::new(config.clone(), memory).unwrap();
// Test accessing specific memory regions
let region_size = config.page_size * config.inner_dim * config.dtype_width_bytes;
// Block 0, Layer 0, Outer 0
let region = layout.memory_region(0, 0, 0).unwrap();
assert_eq!(region.addr, 0x1000);
assert_eq!(region.size, region_size);
// Block 0, Layer 0, Outer 1
let region = layout.memory_region(0, 0, 1).unwrap();
assert_eq!(region.addr, 0x1000 + region_size);
assert_eq!(region.size, region_size);
// Block 0, Layer 1, Outer 0
let region = layout.memory_region(0, 1, 0).unwrap();
assert_eq!(region.addr, 0x1000 + 2 * region_size);
assert_eq!(region.size, region_size);
// Block 1, Layer 0, Outer 0
let region = layout.memory_region(1, 0, 0).unwrap();
assert_eq!(
region.addr,
0x1000 + (config.outer_dim * config.num_layers * region_size)
);
assert_eq!(region.size, region_size);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests comparing v1 and v2 layout implementations.
//!
//! These tests validate that the new v2 layout system produces identical
//! memory region **addresses** as the proven v1 implementation.
//!
//! **Note on Size Differences**: V1's `memory_region()` returns `layer_stride` as the
//! size (covering all outer dimensions), while V2 returns `outer_stride` (single page).
//! This is an intentional API difference - V2 provides more granular access.
//! Therefore, these tests only compare addresses, not sizes.
#![cfg(test)]
use anyhow::Result;
use std::{any::Any, sync::Arc};
use crate::block_manager::{
layout::{
BlockDimension, BlockLayout, BlockLayoutConfig, GenericBlockLayout, LayoutConfig,
LayoutType,
tests::{setup_layer_separate_layout, setup_layout},
},
storage::{Storage, tests::NullDeviceStorage},
v2::storage::StorageKind,
};
use super::{
FullyContiguousLayout, LayerSeparateLayout, Layout, LayoutConfig as V2LayoutConfig,
MemoryRegion,
};
// Test constants matching v1 tests
const NUM_BLOCKS: usize = 7;
const NUM_LAYERS: usize = 5;
const OUTER_DIM: usize = 2;
const PAGE_SIZE: usize = 4;
const INNER_DIM: usize = 13;
const DTYPE_WIDTH_BYTES: usize = 4;
/// Wrapper to make v1 NullDeviceStorage compatible with v2 MemoryRegion trait.
#[derive(Debug)]
struct V1StorageWrapper {
storage: NullDeviceStorage,
}
impl MemoryRegion for V1StorageWrapper {
fn addr(&self) -> usize {
self.storage.addr() as usize
}
fn size(&self) -> usize {
self.storage.size()
}
fn storage_kind(&self) -> StorageKind {
StorageKind::System
}
fn as_any(&self) -> &dyn Any {
self
}
}
/// Create v1 layout configuration
fn create_v1_config() -> LayoutConfig {
LayoutConfig::builder()
.num_blocks(NUM_BLOCKS)
.num_layers(NUM_LAYERS)
.outer_dim(OUTER_DIM)
.page_size(PAGE_SIZE)
.inner_dim(INNER_DIM)
.alignment(1)
.dtype_width_bytes(DTYPE_WIDTH_BYTES)
.build()
.unwrap()
}
/// Create v2 layout configuration (equivalent to v1)
fn create_v2_config() -> V2LayoutConfig {
create_v1_config()
}
#[test]
fn test_v1_v2_fully_contiguous_equivalence() -> Result<()> {
// Create v1 layout
let v1_layout = setup_layout(None)?;
// Create v2 layout with same configuration
let v2_config = create_v2_config();
let required_size =
NUM_BLOCKS * NUM_LAYERS * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
let v1_storage = NullDeviceStorage::new(required_size as u64);
let memory = Arc::new(V1StorageWrapper {
storage: v1_storage,
}) as Arc<dyn MemoryRegion>;
let v2_layout = FullyContiguousLayout::new(v2_config, memory)?;
// Compare all memory regions
for block_id in 0..NUM_BLOCKS {
for layer_id in 0..NUM_LAYERS {
for outer_id in 0..OUTER_DIM {
let v1_region = v1_layout.memory_region(block_id, layer_id, outer_id)?;
let v2_region = v2_layout.memory_region(block_id, layer_id, outer_id)?;
assert_eq!(
v1_region.addr(),
v2_region.addr,
"Address mismatch at block={}, layer={}, outer={}",
block_id,
layer_id,
outer_id
);
assert_eq!(
v1_region.size(),
v2_region.size,
"Size mismatch at block={}, layer={}, outer={}",
block_id,
layer_id,
outer_id
);
}
}
}
// Verify metadata
assert_eq!(v1_layout.num_blocks(), v2_layout.num_blocks());
assert_eq!(v1_layout.num_layers(), v2_layout.num_layers());
assert_eq!(v1_layout.outer_dim(), v2_layout.outer_dim());
assert_eq!(v1_layout.page_size(), v2_layout.page_size());
assert_eq!(v1_layout.inner_dim(), v2_layout.inner_dim());
Ok(())
}
#[test]
fn test_v1_v2_layer_separate_block_contiguous_equivalence() -> Result<()> {
// Create v1 layout (block contiguous = !outer_contiguous)
let v1_layout = setup_layer_separate_layout(None, BlockDimension::BlockIsFirstDim)?;
// Create v2 layout with same configuration
let v2_config = create_v2_config();
let per_layer_size = NUM_BLOCKS * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
let memory: Vec<Arc<dyn MemoryRegion>> = (0..NUM_LAYERS)
.map(|_| {
Arc::new(V1StorageWrapper {
storage: NullDeviceStorage::new(per_layer_size as u64),
}) as Arc<dyn MemoryRegion>
})
.collect();
let v2_layout = LayerSeparateLayout::new(v2_config, memory, BlockDimension::BlockIsFirstDim)?;
// Verify metadata
assert_eq!(v1_layout.num_blocks(), v2_layout.num_blocks());
assert_eq!(v1_layout.num_layers(), v2_layout.num_layers());
assert_eq!(v1_layout.outer_dim(), v2_layout.outer_dim());
assert_eq!(v1_layout.page_size(), v2_layout.page_size());
assert_eq!(v1_layout.inner_dim(), v2_layout.inner_dim());
// Compare all memory regions
for block_id in 0..NUM_BLOCKS {
for layer_id in 0..NUM_LAYERS {
for outer_id in 0..OUTER_DIM {
let v1_region = v1_layout.memory_region(block_id, layer_id, outer_id)?;
let v2_region = v2_layout.memory_region(block_id, layer_id, outer_id)?;
assert_eq!(
v1_region.addr(),
v2_region.addr,
"Address mismatch at block={}, layer={}, outer={} (block_contiguous)",
block_id,
layer_id,
outer_id
);
assert_eq!(
v1_region.size(),
v2_region.size,
"Size mismatch at block={}, layer={}, outer={} (block_contiguous)",
block_id,
layer_id,
outer_id
);
}
}
}
// Verify layout type
assert!(!v2_layout.is_fully_contiguous());
assert_eq!(
v1_layout.layout_type(),
LayoutType::LayerSeparate {
block_dim: BlockDimension::BlockIsFirstDim,
}
);
Ok(())
}
#[test]
fn test_v1_v2_layer_separate_outer_contiguous_equivalence() -> Result<()> {
// Create v1 layout (outer contiguous)
let v1_layout = setup_layer_separate_layout(None, BlockDimension::BlockIsSecondDim)?;
// Create v2 layout with same configuration
let v2_config = create_v2_config();
let per_layer_size = NUM_BLOCKS * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
let memory: Vec<Arc<dyn MemoryRegion>> = (0..NUM_LAYERS)
.map(|_| {
Arc::new(V1StorageWrapper {
storage: NullDeviceStorage::new(per_layer_size as u64),
}) as Arc<dyn MemoryRegion>
})
.collect();
let v2_layout = LayerSeparateLayout::new(v2_config, memory, BlockDimension::BlockIsSecondDim)?;
// Compare all memory regions
for block_id in 0..NUM_BLOCKS {
for layer_id in 0..NUM_LAYERS {
for outer_id in 0..OUTER_DIM {
let v1_region = v1_layout.memory_region(block_id, layer_id, outer_id)?;
let v2_region = v2_layout.memory_region(block_id, layer_id, outer_id)?;
assert_eq!(
v1_region.addr(),
v2_region.addr,
"Address mismatch at block={}, layer={}, outer={} (outer_contiguous)",
block_id,
layer_id,
outer_id
);
assert_eq!(
v1_region.size(),
v2_region.size,
"Size mismatch at block={}, layer={}, outer={} (outer_contiguous)",
block_id,
layer_id,
outer_id
);
}
}
}
// Verify layout type
assert!(!v2_layout.is_fully_contiguous());
assert_eq!(
v1_layout.layout_type(),
LayoutType::LayerSeparate {
block_dim: BlockDimension::BlockIsSecondDim,
}
);
Ok(())
}
#[test]
fn test_v1_v2_stride_calculations() -> Result<()> {
// Test with a specific pattern to verify stride calculations
let _v1_layout = setup_layout(None)?;
let v2_config = create_v2_config();
let required_size =
NUM_BLOCKS * NUM_LAYERS * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
let v1_storage = NullDeviceStorage::new(required_size as u64);
let memory = Arc::new(V1StorageWrapper {
storage: v1_storage,
}) as Arc<dyn MemoryRegion>;
let v2_layout = FullyContiguousLayout::new(v2_config, memory)?;
// Calculate expected strides
let region_size = PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
let outer_stride = region_size;
let layer_stride = outer_stride * OUTER_DIM;
let block_stride = layer_stride * NUM_LAYERS;
// Test stride consistency across blocks
for block_id in 0..NUM_BLOCKS - 1 {
let region_b0 = v2_layout.memory_region(block_id, 0, 0)?;
let region_b1 = v2_layout.memory_region(block_id + 1, 0, 0)?;
assert_eq!(
region_b1.addr - region_b0.addr,
block_stride,
"Block stride mismatch between blocks {} and {}",
block_id,
block_id + 1
);
}
// Test stride consistency across layers
for layer_id in 0..NUM_LAYERS - 1 {
let region_l0 = v2_layout.memory_region(0, layer_id, 0)?;
let region_l1 = v2_layout.memory_region(0, layer_id + 1, 0)?;
assert_eq!(
region_l1.addr - region_l0.addr,
layer_stride,
"Layer stride mismatch between layers {} and {}",
layer_id,
layer_id + 1
);
}
// Test stride consistency across outer dimensions
for outer_id in 0..OUTER_DIM - 1 {
let region_o0 = v2_layout.memory_region(0, 0, outer_id)?;
let region_o1 = v2_layout.memory_region(0, 0, outer_id + 1)?;
assert_eq!(
region_o1.addr - region_o0.addr,
outer_stride,
"Outer stride mismatch between outer dims {} and {}",
outer_id,
outer_id + 1
);
}
Ok(())
}
#[test]
fn test_v1_v2_edge_case_single_block() -> Result<()> {
// Test with minimal configuration: single block
let v1_config = LayoutConfig::builder()
.num_blocks(1)
.num_layers(NUM_LAYERS)
.outer_dim(OUTER_DIM)
.page_size(PAGE_SIZE)
.inner_dim(INNER_DIM)
.dtype_width_bytes(DTYPE_WIDTH_BYTES)
.build()
.unwrap();
let v1_layout = crate::block_manager::layout::FullyContiguous::allocate(
v1_config.clone(),
&crate::block_manager::storage::tests::NullDeviceAllocator,
)?;
let v2_config = v1_config.clone();
let required_size = 1 * NUM_LAYERS * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
let v1_storage = NullDeviceStorage::new(required_size as u64);
let memory = Arc::new(V1StorageWrapper {
storage: v1_storage,
}) as Arc<dyn MemoryRegion>;
let v2_layout = FullyContiguousLayout::new(v2_config, memory)?;
// Compare the single block across all layers and outer dims
for layer_id in 0..NUM_LAYERS {
for outer_id in 0..OUTER_DIM {
let v1_region = v1_layout.memory_region(0, layer_id, outer_id)?;
let v2_region = v2_layout.memory_region(0, layer_id, outer_id)?;
assert_eq!(v1_region.addr(), v2_region.addr);
assert_eq!(v1_region.size(), v2_region.size);
}
}
Ok(())
}
#[test]
fn test_v1_v2_edge_case_single_layer() -> Result<()> {
// Test with minimal configuration: single layer
let v1_config = LayoutConfig::builder()
.num_blocks(NUM_BLOCKS)
.num_layers(1)
.outer_dim(OUTER_DIM)
.page_size(PAGE_SIZE)
.inner_dim(INNER_DIM)
.dtype_width_bytes(DTYPE_WIDTH_BYTES)
.build()?;
let v1_layout = crate::block_manager::layout::FullyContiguous::allocate(
v1_config.clone(),
&crate::block_manager::storage::tests::NullDeviceAllocator,
)?;
let v2_config = v1_config.clone();
let required_size = NUM_BLOCKS * 1 * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE_WIDTH_BYTES;
let v1_storage = NullDeviceStorage::new(required_size as u64);
let memory = Arc::new(V1StorageWrapper {
storage: v1_storage,
}) as Arc<dyn MemoryRegion>;
let v2_layout = FullyContiguousLayout::new(v2_config, memory)?;
// Compare the single layer across all blocks and outer dims
for block_id in 0..NUM_BLOCKS {
for outer_id in 0..OUTER_DIM {
let v1_region = v1_layout.memory_region(block_id, 0, outer_id)?;
let v2_region = v2_layout.memory_region(block_id, 0, outer_id)?;
assert_eq!(v1_region.addr(), v2_region.addr);
assert_eq!(v1_region.size(), v2_region.size);
}
}
Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Layer-separate layout implementation.
//!
//! This layout stores each layer in its own allocation, which is the typical
//! vLLM layout. Each layer can be either block-contiguous or outer-contiguous:
//! - Block-contiguous: [num_blocks, outer_dim, page_size, inner_dim]
//! - Outer-contiguous: [outer_dim, num_blocks, page_size, inner_dim]
use anyhow::{Result, anyhow};
use std::sync::Arc;
use validator::Validate;
use super::serialize::{LayerSeparateDetails, LayoutTypeDetails};
use super::{
BlockDimension, Layout, LayoutConfig, MemoryDescriptor, MemoryRegion, OwnedMemoryRegion,
};
/// Layer-separate layout where each layer has its own allocation.
#[derive(Debug)]
pub struct LayerSeparateLayout {
config: LayoutConfig,
/// Base addresses for each layer
layer_base_addrs: Vec<usize>,
/// Whether the outer dimension is contiguous (vs block dimensionl
block_dim: BlockDimension,
/// Stride between blocks in bytes
block_stride: usize,
/// Stride between outer dimensions in bytes
outer_stride: usize,
/// Size of each memory region (page) in bytes
region_size: usize,
/// Owned memory regions backing this layout (one per layer)
memory_regions: Vec<Arc<dyn MemoryRegion>>,
}
impl LayerSeparateLayout {
/// Create a new layer-separate layout.
///
/// # Arguments
/// - `config` - Layout configuration
/// - `memory` - Vector of owned memory regions (one per layer)
/// - `outer_contiguous` - If true, outer dimension is contiguous with the inner dimension, i.e. (num_blocks, outer_dim, ...);
/// if false, block dimension is contiguous with the inner dimension, i.e. (outer_dim, num_blocks, ...).
///
/// # Returns
/// A new LayerSeparateLayout instance
pub fn new(
config: LayoutConfig,
memory: Vec<Arc<dyn MemoryRegion>>,
block_dim: BlockDimension,
) -> Result<Self> {
config.validate()?;
if memory.len() != config.num_layers {
return Err(anyhow!(
"Memory region count ({}) must match num_layers ({})",
memory.len(),
config.num_layers
));
}
// Calculate strides
let region_size = config.page_size * config.inner_dim * config.dtype_width_bytes;
let (block_stride, outer_stride) = if block_dim == BlockDimension::BlockIsSecondDim {
// Layout: [outer_dim, num_blocks, page_size, inner_dim]
let block_stride = region_size;
let outer_stride = block_stride * config.num_blocks;
(block_stride, outer_stride)
} else {
// Layout: [num_blocks, outer_dim, page_size, inner_dim]
let outer_stride = region_size;
let block_stride = outer_stride * config.outer_dim;
(block_stride, outer_stride)
};
// Extract base addresses and validate sizes
let mut layer_base_addrs = Vec::with_capacity(config.num_layers);
let required_size = config.num_blocks * config.outer_dim * region_size;
for (i, mem) in memory.iter().enumerate() {
if mem.size() < required_size {
return Err(anyhow!(
"Memory region {} too small for layout. Required: {} bytes, got: {} bytes",
i,
required_size,
mem.size()
));
}
layer_base_addrs.push(mem.addr());
}
Ok(Self {
config,
layer_base_addrs,
block_dim,
block_stride,
outer_stride,
region_size,
memory_regions: memory,
})
}
/// Calculate the address of a specific memory region.
fn calculate_address(
&self,
block_id: usize,
layer_id: usize,
outer_id: usize,
) -> Result<usize> {
if block_id >= self.config.num_blocks {
return Err(anyhow!(
"Block ID {} out of range (max: {})",
block_id,
self.config.num_blocks
));
}
if layer_id >= self.config.num_layers {
return Err(anyhow!(
"Layer ID {} out of range (max: {})",
layer_id,
self.config.num_layers
));
}
if outer_id >= self.config.outer_dim {
return Err(anyhow!(
"Outer ID {} out of range (max: {})",
outer_id,
self.config.outer_dim
));
}
let base_addr = self.layer_base_addrs[layer_id];
let offset = block_id * self.block_stride + outer_id * self.outer_stride;
Ok(base_addr + offset)
}
pub fn block_dim(&self) -> BlockDimension {
self.block_dim
}
/// Get mutable reference to the memory regions for NIXL registration.
pub fn memory_regions_mut(&mut self) -> &mut [Arc<dyn MemoryRegion>] {
&mut self.memory_regions
}
}
impl Layout for LayerSeparateLayout {
fn config(&self) -> &LayoutConfig {
&self.config
}
fn memory_regions(&self) -> &[OwnedMemoryRegion] {
&self.memory_regions
}
fn memory_region(
&self,
block_id: usize,
layer_id: usize,
outer_id: usize,
) -> Result<MemoryDescriptor> {
let addr = self.calculate_address(block_id, layer_id, outer_id)?;
Ok(MemoryDescriptor::new(addr, self.region_size))
}
fn required_allocations(&self) -> Vec<usize> {
// One allocation per layer
let per_layer_size = self.config.num_blocks * self.config.outer_dim * self.region_size;
vec![per_layer_size; self.config.num_layers]
}
fn is_fully_contiguous(&self) -> bool {
false
}
fn num_blocks(&self) -> usize {
self.config.num_blocks
}
fn num_layers(&self) -> usize {
self.config.num_layers
}
fn outer_dim(&self) -> usize {
self.config.outer_dim
}
fn page_size(&self) -> usize {
self.config.page_size
}
fn inner_dim(&self) -> usize {
self.config.inner_dim
}
fn dtype_width_bytes(&self) -> usize {
self.config.dtype_width_bytes
}
fn serialization_details(&self) -> LayoutTypeDetails {
LayoutTypeDetails::LayerSeparate(LayerSeparateDetails {
block_dim: self.block_dim,
})
}
}
#[cfg(test)]
mod tests {
use super::super::tests::*;
use super::*;
#[test]
fn test_layer_separate_block_contiguous() {
let config = LayoutConfig::builder()
.num_blocks(10)
.num_layers(4)
.outer_dim(2)
.page_size(16)
.inner_dim(128)
.dtype_width_bytes(2)
.build()
.unwrap();
let per_layer_size = 10 * 2 * 16 * 128 * 2;
let memory: Vec<Arc<dyn MemoryRegion>> = (0..4)
.map(|i| {
MockMemory::new(0x1000 + i * per_layer_size, per_layer_size)
as Arc<dyn MemoryRegion>
})
.collect();
let layout =
LayerSeparateLayout::new(config, memory, BlockDimension::BlockIsFirstDim).unwrap();
assert_eq!(layout.num_blocks(), 10);
assert!(!layout.is_fully_contiguous());
assert_eq!(layout.required_allocations().len(), 4);
}
#[test]
fn test_layer_separate_outer_contiguous() {
let config = LayoutConfig::builder()
.num_blocks(10)
.num_layers(4)
.outer_dim(2)
.page_size(16)
.inner_dim(128)
.dtype_width_bytes(2)
.build()
.unwrap();
let per_layer_size = 10 * 2 * 16 * 128 * 2;
let memory: Vec<Arc<dyn MemoryRegion>> = (0..4)
.map(|i| {
MockMemory::new(0x1000 + i * per_layer_size, per_layer_size)
as Arc<dyn MemoryRegion>
})
.collect();
let layout =
LayerSeparateLayout::new(config, memory, BlockDimension::BlockIsSecondDim).unwrap();
assert_eq!(layout.num_blocks(), 10);
assert!(!layout.is_fully_contiguous());
}
#[test]
fn test_memory_region() {
let config = LayoutConfig::builder()
.num_blocks(2)
.num_layers(2)
.outer_dim(2)
.page_size(16)
.inner_dim(128)
.dtype_width_bytes(2)
.build()
.unwrap();
let per_layer_size = 2 * 2 * 16 * 128 * 2;
let memory: Vec<Arc<dyn MemoryRegion>> = (0..2)
.map(|i| {
MockMemory::new(0x1000 + i * per_layer_size, per_layer_size)
as Arc<dyn MemoryRegion>
})
.collect();
let layout =
LayerSeparateLayout::new(config, memory, BlockDimension::BlockIsFirstDim).unwrap();
// Test accessing specific memory regions
let region_size = 16 * 128 * 2;
// Block 0, Layer 0, Outer 0 - should be at layer 0's base address
let region = layout.memory_region(0, 0, 0).unwrap();
assert_eq!(region.addr, 0x1000);
assert_eq!(region.size, region_size);
// Block 0, Layer 1, Outer 0 - should be at layer 1's base address
let region = layout.memory_region(0, 1, 0).unwrap();
assert_eq!(region.addr, 0x1000 + per_layer_size);
assert_eq!(region.size, region_size);
// Block 0, Layer 0, Outer 1 - should be offset within layer 0
let region = layout.memory_region(0, 0, 1).unwrap();
assert_eq!(region.addr, 0x1000 + region_size);
assert_eq!(region.size, region_size);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Decoupled layout system for block management.
//!
//! This module provides a simplified layout abstraction that:
//! - Maps block IDs to physical memory regions (address + size)
//! - Decouples memory regions from storage type information
//! - Specifies allocation requirements without performing allocation
//! - Uses trait objects for memory ownership
pub(crate) mod builder;
mod config;
mod fully_contiguous;
mod layer_separate;
mod physical;
mod serialize;
mod validation;
#[cfg(test)]
pub(super) mod tests;
// #[cfg(test)]
// mod integration_tests;
pub use builder::{LayoutKind, PhysicalLayoutBuilder};
pub use config::{BlockDimension, LayoutConfig};
pub use fully_contiguous::FullyContiguousLayout;
pub use layer_separate::LayerSeparateLayout;
pub use physical::{NixlMetadata, PhysicalLayout};
pub use serialize::{
BlockFormat, FullyContiguousDetails, LayerSeparateDetails, LayoutDescriptor, LayoutTypeDetails,
};
pub use validation::{TensorFormat, validate_tensor_shapes, validate_tensor_strides};
// mod registration;
// pub use registration::{RegisteredLayout, RegisteredStorageMetadata, RegistrationManager};
use anyhow::Result;
use serde::{Deserialize, Serialize};
pub use crate::block_manager::v2::memory::{MemoryDescriptor, MemoryRegion, OwnedMemoryRegion};
/// Core layout trait for mapping block IDs to memory regions.
///
/// Layouts specify how KV cache blocks are organized in memory without
/// performing allocation themselves. They provide:
/// - Memory region lookup for specific blocks
/// - Allocation requirements for external allocators
/// - Metadata about block organization
pub trait Layout: Send + Sync + std::fmt::Debug {
/// Get the configuration for this layout.
fn config(&self) -> &LayoutConfig;
/// Get the root memory regions backing this layout.
///
/// These regions correspond to the concrete allocations that store the layout's data.
/// Implementations that derive memory procedurally can return an empty slice.
fn memory_regions(&self) -> &[OwnedMemoryRegion];
/// Get memory regions for a specific block_id, layer_id, outer_id.
///
/// Returns a [MemoryRegion] for the continuous region specified by the given block_id,
/// layer_id, outer_id.
///
/// # Arguments
/// * `block_id` - The ID of the block to query (0..num_blocks)
/// * `layer_id` - The ID of the layer to query (0..num_layers)
/// * `outer_id` - The ID of the outer dimension to query (0..outer_dim)
fn memory_region(
&self,
block_id: usize,
layer_id: usize,
outer_id: usize,
) -> Result<MemoryDescriptor>;
/// Get the allocation requirements for this layout.
///
/// Returns a vector of allocation sizes needed to back this layout.
/// For fully contiguous layouts, this will be a single size.
/// For layer-separate layouts, this will contain one size per layer.
///
/// # Returns
/// Vector of allocation sizes in bytes.
fn required_allocations(&self) -> Vec<usize>;
/// Check if this layout uses fully contiguous memory.
///
/// Fully contiguous layouts have all blocks in a single allocation,
/// which enables certain optimizations.
fn is_fully_contiguous(&self) -> bool;
/// Get the total number of blocks in this layout.
fn num_blocks(&self) -> usize;
/// Get the number of layers per block.
fn num_layers(&self) -> usize;
/// Get the outer dimension size.
///
/// In typical KV cache layouts, this is often 2 (for K and V),
/// but can be 1 for architectures like MLA.
fn outer_dim(&self) -> usize;
/// Get the page size (often corresponds to block size in tokens).
fn page_size(&self) -> usize;
/// Get the inner dimension size.
///
/// This is typically the hidden size divided by tensor parallel size.
fn inner_dim(&self) -> usize;
/// Get the data type width in bytes.
fn dtype_width_bytes(&self) -> usize;
/// Get serialization details for this layout type.
///
/// This provides the layout-type-specific information needed to serialize
/// and reconstruct the layout on a remote node.
fn serialization_details(&self) -> serialize::LayoutTypeDetails;
}
/// Inner shape format for tensor layout
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum InnerShape {
/// Unknown shape - fallback when we can't determine the format
Unknown,
/// NHD format: [block_size, num_heads, head_dim]
/// Common for attention layers where N=tokens, H=heads, D=dimension
NHD,
/// HND format: [num_heads, block_size, head_dim]
/// Alternative layout with heads first
HND,
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Physical layout types that combine abstract layouts with storage location metadata.
use super::{
FullyContiguousLayout, LayerSeparateLayout, Layout, MemoryDescriptor,
builder::{PhysicalLayoutBuilder, PhysicalLayoutBuilderDefault},
serialize::{LayoutDescriptor, LayoutTypeDetails},
};
use crate::block_manager::v2::memory::{MemoryRegion, StorageKind};
use anyhow::{Result, anyhow};
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::sync::Arc;
use crate::block_manager::v2::physical::transfer::nixl_agent::NixlAgent;
/// Runtime representation of a layout with its physical storage location.
///
/// A `PhysicalLayout` wraps an abstract [`Layout`] with information about where
/// its memory physically resides (GPU, host, disk) and whether it's local or remote.
/// This enables the transfer system to select appropriate copy strategies and build
/// NIXL transfer descriptors.
#[derive(Debug, Clone)]
pub struct PhysicalLayout {
/// The abstract layout defining memory organization
layout: Arc<dyn Layout>,
/// Physical storage location (System, Device, Pinned, Disk)
location: StorageKind,
/// NIXL registration metadata
nixl_metadata: NixlMetadata,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NixlMetadata {
agent_name: String,
mem_type: nixl_sys::MemType,
device_id: u64,
}
impl NixlMetadata {
pub fn new(agent_name: String, mem_type: nixl_sys::MemType, device_id: u64) -> Self {
Self {
agent_name,
mem_type,
device_id,
}
}
pub fn agent_name(&self) -> &str {
&self.agent_name
}
pub fn mem_type(&self) -> nixl_sys::MemType {
self.mem_type
}
pub fn device_id(&self) -> u64 {
self.device_id
}
}
impl PhysicalLayout {
/// Create a typed builder that enforces NIXL registration.
pub fn builder(agent: NixlAgent) -> PhysicalLayoutBuilderDefault {
PhysicalLayoutBuilder::new(agent)
}
/// Create a new local physical layout.
///
/// # Arguments
/// * `layout` - The abstract layout to wrap
/// * `location` - Where the layout's memory resides
pub(crate) fn new_local(
layout: Arc<dyn Layout>,
location: StorageKind,
nixl_metadata: NixlMetadata,
) -> Self {
Self {
layout,
location,
nixl_metadata,
}
}
// /// Create a new remote physical layout from a descriptor.
// ///
// /// # Arguments
// /// * `layout` - The abstract layout to wrap
// /// * `location` - Where the layout's memory resides (on remote node)
// /// * `remote_agent` - Name of the NIXL agent on the remote node
// pub fn new_remote(
// layout: Arc<dyn Layout>,
// location: StorageKind,
// remote_agent: String,
// ) -> Self {
// let metadata = NixlMetadata::new(
// remote_agent.clone(),
// location.to_nixl_mem_type(),
// location.device_id(),
// );
// let registrations = vec![RegisteredStorageMetadata::new(
// metadata.agent_name().to_string(),
// location,
// )];
// Self {
// layout,
// location,
// locality: Locality::Remote(remote_agent),
// nixl_metadata: Some(metadata),
// registered: registrations,
// }
// }
/// Get the underlying layout.
pub fn layout(&self) -> &Arc<dyn Layout> {
&self.layout
}
/// Get the storage location.
pub fn location(&self) -> StorageKind {
self.location
}
/// Get the NIXL metadata.
pub fn nixl_metadata(&self) -> &NixlMetadata {
&self.nixl_metadata
}
/// Get a memory region with location information.
///
/// # Arguments
/// * `block_id` - Block identifier
/// * `layer_id` - Layer identifier
/// * `outer_id` - Outer dimension identifier
pub fn memory_region(
&self,
block_id: usize,
layer_id: usize,
outer_id: usize,
) -> Result<MemoryDescriptor> {
self.layout.memory_region(block_id, layer_id, outer_id)
}
/// Serialize this physical layout for transmission to remote nodes.
///
/// This converts the runtime `PhysicalLayout` into a `LayoutDescriptor` that
/// contains all information needed to reconstruct the layout on a remote node,
/// including layout configuration, memory descriptors, NIXL metadata, and
/// layout-type-specific details.
///
/// # Returns
/// A serializable representation of this layout
pub fn to_descriptor(&self) -> Result<LayoutDescriptor> {
// Extract memory descriptors
let memory_descriptors = self
.layout
.memory_regions()
.iter()
.map(|region| MemoryDescriptor {
addr: region.addr(),
size: region.size(),
})
.collect();
// Get layout type details from the layout itself
let layout_type_details = self.layout.serialization_details();
Ok(LayoutDescriptor {
version: LayoutDescriptor::CURRENT_VERSION,
layout_config: self.layout.config().clone(),
location: self.location,
nixl_metadata: self.nixl_metadata.clone(),
memory_descriptors,
layout_type_details,
})
}
/// Reconstruct a physical layout from serialized data received from a remote node.
///
/// This creates a new `PhysicalLayout` from a `LayoutDescriptor`. The reconstructed
/// layout will have memory descriptors that point to the remote node's memory,
/// allowing NIXL to build RDMA descriptors for remote access.
///
/// # Arguments
/// * `serialized` - Serialized layout data from a remote node
///
/// # Returns
/// A new `PhysicalLayout` representing the remote layout
///
/// # Note
/// The memory regions in the reconstructed layout are not valid for local access;
/// they represent remote memory addresses and are used to build NIXL transfer descriptors.
pub fn from_descriptor(serialized: LayoutDescriptor) -> Result<Self> {
// Validate version
if serialized.version > LayoutDescriptor::CURRENT_VERSION {
return Err(anyhow!(
"Unsupported serialization version: {}. Maximum supported: {}",
serialized.version,
LayoutDescriptor::CURRENT_VERSION
));
}
// Create remote memory regions from descriptors
let remote_regions: Vec<Arc<dyn MemoryRegion>> = serialized
.memory_descriptors
.iter()
.map(|desc| {
Arc::new(RemoteMemoryDescriptor {
addr: desc.addr,
size: desc.size,
storage_kind: serialized.location,
}) as Arc<dyn MemoryRegion>
})
.collect();
// Reconstruct the layout based on type
let layout: Arc<dyn Layout> = match serialized.layout_type_details {
LayoutTypeDetails::FullyContiguous(details) => {
if remote_regions.len() != 1 {
return Err(anyhow!(
"FullyContiguous layout requires exactly 1 memory region, got {}",
remote_regions.len()
));
}
let layout = FullyContiguousLayout::new_with_format(
serialized.layout_config.clone(),
remote_regions[0].clone(),
details.block_format,
)?;
Arc::new(layout)
}
LayoutTypeDetails::LayerSeparate(details) => {
if remote_regions.len() != serialized.layout_config.num_layers {
return Err(anyhow!(
"LayerSeparate layout requires {} memory regions (one per layer), got {}",
serialized.layout_config.num_layers,
remote_regions.len()
));
}
let layout = LayerSeparateLayout::new(
serialized.layout_config.clone(),
remote_regions,
details.block_dim,
)?;
Arc::new(layout)
}
};
Ok(Self {
layout,
location: serialized.location,
nixl_metadata: serialized.nixl_metadata,
})
}
}
/// A memory region that represents remote memory addresses.
///
/// This type is used when reconstructing layouts from serialized data.
/// The addresses are not valid for local access but can be used to
/// build NIXL transfer descriptors for remote memory access.
#[derive(Debug)]
struct RemoteMemoryDescriptor {
addr: usize,
size: usize,
storage_kind: StorageKind,
}
impl MemoryRegion for RemoteMemoryDescriptor {
fn addr(&self) -> usize {
self.addr
}
fn size(&self) -> usize {
self.size
}
fn storage_kind(&self) -> StorageKind {
self.storage_kind
}
fn as_any(&self) -> &dyn Any {
self
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Serialization types for physical layouts.
//!
//! This module provides types for serializing and deserializing physical layouts
//! so they can be transmitted to remote nodes and reconstructed there for RDMA operations.
use super::physical::NixlMetadata;
use super::{BlockDimension, LayoutConfig};
use crate::block_manager::v2::memory::{MemoryDescriptor, StorageKind};
use anyhow::Result;
use serde::{Deserialize, Serialize};
/// Format of blocks in a fully contiguous layout.
///
/// This enum describes how the blocks are organized and formatted in memory.
/// Currently only `Operational` is supported, but future variants may include
/// different compression schemes or memory layouts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BlockFormat {
/// Standard operational format - blocks are stored in their normal, uncompressed form.
Operational,
}
impl Default for BlockFormat {
fn default() -> Self {
Self::Operational
}
}
/// Details specific to fully contiguous layouts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FullyContiguousDetails {
/// Format of the blocks in memory
pub block_format: BlockFormat,
}
/// Details specific to layer-separate layouts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerSeparateDetails {
/// Block dimension ordering (block-first or block-second)
pub block_dim: BlockDimension,
}
/// Layout-type-specific details.
///
/// This enum captures the information that differs between layout types
/// and is needed to reconstruct the layout on a remote node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LayoutTypeDetails {
/// Fully contiguous layout details
FullyContiguous(FullyContiguousDetails),
/// Layer-separate layout details
LayerSeparate(LayerSeparateDetails),
}
/// Serializable representation of a physical layout.
///
/// This structure contains all information needed to reconstruct a layout
/// on a remote node, including:
/// - Layout configuration (dimensions, sizes, etc.)
/// - Storage location and NIXL metadata
/// - Memory descriptors for all regions
/// - Layout-type-specific details
///
/// The serialized form can be transmitted over the network and used to
/// build NIXL transfer descriptors for remote memory access.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayoutDescriptor {
/// Serialization format version (for future compatibility)
pub version: u32,
/// Layout configuration
pub layout_config: LayoutConfig,
/// Storage location
pub location: StorageKind,
/// NIXL metadata from the source node
pub nixl_metadata: NixlMetadata,
/// Memory descriptors for all regions backing this layout
pub memory_descriptors: Vec<MemoryDescriptor>,
/// Layout-type-specific details
pub layout_type_details: LayoutTypeDetails,
}
impl LayoutDescriptor {
/// Current serialization version
pub const CURRENT_VERSION: u32 = 1;
/// Serialize this layout to a JSON string.
///
/// # Returns
/// JSON string representation of the layout
pub fn to_json(&self) -> Result<String> {
serde_json::to_string(self)
.map_err(|e| anyhow::anyhow!("failed to serialize layout to JSON: {}", e))
}
/// Serialize this layout to JSON bytes.
///
/// # Returns
/// UTF-8 encoded JSON bytes
pub fn to_json_bytes(&self) -> Result<Vec<u8>> {
serde_json::to_vec(self)
.map_err(|e| anyhow::anyhow!("failed to serialize layout to JSON bytes: {}", e))
}
/// Deserialize a layout from a JSON string.
///
/// # Arguments
/// * `json` - JSON string representation
///
/// # Returns
/// Deserialized layout
pub fn from_json(json: &str) -> Result<Self> {
serde_json::from_str(json)
.map_err(|e| anyhow::anyhow!("failed to deserialize layout from JSON: {}", e))
}
/// Deserialize a layout from JSON bytes.
///
/// # Arguments
/// * `bytes` - UTF-8 encoded JSON bytes
///
/// # Returns
/// Deserialized layout
pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes)
.map_err(|e| anyhow::anyhow!("failed to deserialize layout from JSON bytes: {}", e))
}
/// Get the layout configuration.
pub fn layout_config(&self) -> &LayoutConfig {
&self.layout_config
}
/// Get the storage location.
pub fn location(&self) -> StorageKind {
self.location
}
/// Get the NIXL metadata from the source node.
pub fn nixl_metadata(&self) -> &NixlMetadata {
&self.nixl_metadata
}
/// Get the memory descriptors.
pub fn memory_descriptors(&self) -> &[MemoryDescriptor] {
&self.memory_descriptors
}
/// Get the layout type details.
pub fn layout_type_details(&self) -> &LayoutTypeDetails {
&self.layout_type_details
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_test_config() -> LayoutConfig {
LayoutConfig::builder()
.num_blocks(10)
.num_layers(4)
.outer_dim(2)
.page_size(16)
.inner_dim(128)
.dtype_width_bytes(2)
.build()
.unwrap()
}
#[test]
fn test_block_format_default() {
assert_eq!(BlockFormat::default(), BlockFormat::Operational);
}
#[test]
fn test_serialized_layout_json_roundtrip() {
let layout = LayoutDescriptor {
version: LayoutDescriptor::CURRENT_VERSION,
layout_config: make_test_config(),
location: StorageKind::System,
nixl_metadata: NixlMetadata::new("test_agent".to_string(), nixl_sys::MemType::Dram, 0),
memory_descriptors: vec![MemoryDescriptor::new(0x1000, 4096)],
layout_type_details: LayoutTypeDetails::FullyContiguous(FullyContiguousDetails {
block_format: BlockFormat::Operational,
}),
};
// Test to_json/from_json
let json = layout.to_json().unwrap();
let deserialized = LayoutDescriptor::from_json(&json).unwrap();
assert_eq!(deserialized.version, layout.version);
assert_eq!(deserialized.layout_config, layout.layout_config);
assert_eq!(deserialized.location, layout.location);
assert_eq!(
deserialized.nixl_metadata.agent_name(),
layout.nixl_metadata.agent_name()
);
assert_eq!(deserialized.memory_descriptors.len(), 1);
}
#[test]
fn test_serialized_layout_json_bytes_roundtrip() {
let layout = LayoutDescriptor {
version: LayoutDescriptor::CURRENT_VERSION,
layout_config: make_test_config(),
location: StorageKind::System,
nixl_metadata: NixlMetadata::new("test_agent".to_string(), nixl_sys::MemType::Vram, 5),
memory_descriptors: vec![
MemoryDescriptor::new(0x1000, 2048),
MemoryDescriptor::new(0x2000, 2048),
],
layout_type_details: LayoutTypeDetails::LayerSeparate(LayerSeparateDetails {
block_dim: BlockDimension::BlockIsFirstDim,
}),
};
// Test to_json_bytes/from_json_bytes
let bytes = layout.to_json_bytes().unwrap();
let deserialized = LayoutDescriptor::from_json_bytes(&bytes).unwrap();
assert_eq!(deserialized.version, layout.version);
assert_eq!(deserialized.nixl_metadata.device_id(), 5);
assert_eq!(deserialized.memory_descriptors.len(), 2);
}
#[test]
fn test_fully_contiguous_details_serialization() {
let details = LayoutTypeDetails::FullyContiguous(FullyContiguousDetails {
block_format: BlockFormat::Operational,
});
let json = serde_json::to_string(&details).unwrap();
let deserialized: LayoutTypeDetails = serde_json::from_str(&json).unwrap();
match deserialized {
LayoutTypeDetails::FullyContiguous(d) => {
assert_eq!(d.block_format, BlockFormat::Operational);
}
_ => panic!("Expected FullyContiguous variant"),
}
}
#[test]
fn test_layer_separate_details_serialization() {
let details = LayoutTypeDetails::LayerSeparate(LayerSeparateDetails {
block_dim: BlockDimension::BlockIsSecondDim,
});
let json = serde_json::to_string(&details).unwrap();
let deserialized: LayoutTypeDetails = serde_json::from_str(&json).unwrap();
match deserialized {
LayoutTypeDetails::LayerSeparate(d) => {
assert_eq!(d.block_dim, BlockDimension::BlockIsSecondDim);
}
_ => panic!("Expected LayerSeparate variant"),
}
}
}
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Tensor validation utilities for layout creation.
use anyhow::{Result, anyhow};
use std::sync::Arc;
use crate::block_manager::v2::memory::TorchTensor;
/// Format of tensor layout (for future TP translation).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorFormat {
/// NHD format: [N, H, D] where N=block_size, H=heads, D=hidden
NHD,
/// HND format: [H, N, D] where H=heads, N=block_size, D=hidden
HND,
/// Unknown or ambiguous format
Unknown,
}
/// Validate tensor strides and detect format.
///
/// This function checks that tensor strides are monotonically decreasing,
/// which ensures tensor-contiguous layout. The stride validation is flexible
/// at the inner dimension boundary to accommodate different layouts.
///
/// Additionally, it attempts to detect whether the layout is NHD or HND format,
/// which is important for future tensor parallel (TP) translation.
///
/// # Arguments
/// * `tensors` - Slice of tensors to validate
///
/// # Returns
/// The detected tensor format (NHD, HND, or Unknown)
pub fn validate_tensor_strides(tensors: &[Arc<dyn TorchTensor>]) -> Result<TensorFormat> {
if tensors.is_empty() {
return Err(anyhow!("Cannot validate empty tensor list"));
}
let mut format = TensorFormat::Unknown;
for tensor in tensors {
let stride = tensor.stride();
let shape = tensor.shape();
if stride.len() < 2 {
return Err(anyhow!(
"Tensor must have at least 2 dimensions, got stride: {:?}",
stride
));
}
// Check monotonic decreasing stride
// Note: We're flexible at the combined inner dimension boundary as per requirements
let mut prev_stride = usize::MAX;
for (i, &current_stride) in stride.iter().enumerate() {
if current_stride > prev_stride {
return Err(anyhow!(
"Tensor strides must be monotonically decreasing (until inner dimension). \
Got stride: {:?} at position {}",
stride,
i
));
}
prev_stride = current_stride;
}
// Attempt to detect NHD vs HND format based on shape and stride patterns
// This is a heuristic and may need refinement based on actual usage
if shape.len() >= 3 {
// If the first dimension stride is smaller than the second, likely HND
// If the first dimension stride is larger than the second, likely NHD
if stride[0] < stride[1] {
format = TensorFormat::HND;
} else if stride[0] > stride[1] {
format = TensorFormat::NHD;
}
}
}
Ok(format)
}
/// Validate that all tensors have consistent shapes.
///
/// # Arguments
/// * `tensors` - Slice of tensors to validate
///
/// # Returns
/// The common shape shared by all tensors
pub fn validate_tensor_shapes(tensors: &[Arc<dyn TorchTensor>]) -> Result<Vec<usize>> {
if tensors.is_empty() {
return Err(anyhow!("Cannot validate empty tensor list"));
}
let first_shape = tensors[0].shape();
for tensor in &tensors[1..] {
if tensor.shape() != first_shape {
return Err(anyhow!(
"All tensors must have the same shape. Expected {:?}, got {:?}",
first_shape,
tensor.shape()
));
}
}
Ok(first_shape)
}
#[allow(dead_code)]
pub fn determine_compressed_shape(shape: &[usize]) -> usize {
shape.iter().product()
}
#[cfg(test)]
mod tests {
// Note: These tests would require mock TorchTensor implementations
// which we can add if needed for testing infrastructure
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod layout;
pub mod manager;
pub mod transfer;
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment