Unverified Commit 80256acf authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: adding outer dimension to isolate k/v blocks (#1126)

parent 7e452a2e
......@@ -46,10 +46,11 @@ pub struct BlockManager {
#[pymethods]
impl BlockManager {
#[new]
#[pyo3(signature = (worker_id, num_layer, page_size, inner_dim, dtype=None, host_num_blocks=None, device_num_blocks=None, device_id=0))]
#[pyo3(signature = (worker_id, num_layer, outer_dim, page_size, inner_dim, dtype=None, host_num_blocks=None, device_num_blocks=None, device_id=0))]
fn new(
worker_id: u64,
num_layer: usize,
outer_dim: usize,
page_size: usize,
inner_dim: usize,
dtype: Option<String>,
......@@ -65,6 +66,7 @@ impl BlockManager {
);
let mut model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder()
.num_layers(num_layer)
.outer_dim(outer_dim)
.page_size(page_size)
.inner_dim(inner_dim);
let mut dtype_ = dynamo_llm::common::dtype::DType::FP16; // Default in block_manager config
......
......@@ -26,6 +26,7 @@ pytestmark = pytest.mark.pre_merge
WORKER_ID = 0
NUM_LAYER = 5
OUTER_DIM = 2
PAGE_SIZE = 4
INNER_DIM = 13
DTYPE, TORCH_DTYPE = "FP32", torch.float32
......@@ -34,16 +35,35 @@ DEVICE_NUM_BLOCKS = 16
DEVICE_ID = 0
@pytest.fixture
def block_manager():
"""Pytest fixture for creating a BlockManager instance."""
return BlockManager(
WORKER_ID,
NUM_LAYER,
OUTER_DIM,
PAGE_SIZE,
INNER_DIM,
DTYPE,
HOST_NUM_BLOCKS,
DEVICE_NUM_BLOCKS,
DEVICE_ID,
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_manager_initialization():
# Python should drop the BlockManager instance as soon as it goes out of scope, but
# it may not be garbage collected immediately, depending on the garbage collector.
BlockManager(WORKER_ID, NUM_LAYER, PAGE_SIZE, INNER_DIM)
BlockManager(WORKER_ID, NUM_LAYER, PAGE_SIZE, INNER_DIM, DTYPE)
BlockManager(WORKER_ID, NUM_LAYER, PAGE_SIZE, INNER_DIM, DTYPE, HOST_NUM_BLOCKS)
BlockManager(WORKER_ID, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM)
BlockManager(WORKER_ID, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM, DTYPE)
BlockManager(
WORKER_ID, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM, DTYPE, HOST_NUM_BLOCKS
)
BlockManager(
WORKER_ID,
NUM_LAYER,
OUTER_DIM,
PAGE_SIZE,
INNER_DIM,
DTYPE,
......@@ -52,6 +72,7 @@ async def test_block_manager_initialization():
BlockManager(
WORKER_ID,
NUM_LAYER,
OUTER_DIM,
PAGE_SIZE,
INNER_DIM,
DTYPE,
......@@ -61,6 +82,7 @@ async def test_block_manager_initialization():
BlockManager(
WORKER_ID,
NUM_LAYER,
OUTER_DIM,
PAGE_SIZE,
INNER_DIM,
DTYPE,
......@@ -70,6 +92,7 @@ async def test_block_manager_initialization():
BlockManager(
WORKER_ID,
NUM_LAYER,
OUTER_DIM,
PAGE_SIZE,
INNER_DIM,
DTYPE,
......@@ -80,17 +103,7 @@ async def test_block_manager_initialization():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_cpu_block_access():
block_manager = BlockManager(
WORKER_ID,
NUM_LAYER,
PAGE_SIZE,
INNER_DIM,
DTYPE,
HOST_NUM_BLOCKS,
DEVICE_NUM_BLOCKS,
DEVICE_ID,
)
async def test_cpu_block_access(block_manager: BlockManager):
block_count = 2
block_list = block_manager.allocate_host_blocks_blocking(block_count)
py_blocks = block_list.to_list()
......@@ -117,17 +130,7 @@ async def test_cpu_block_access():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_gpu_block_access():
block_manager = BlockManager(
WORKER_ID,
NUM_LAYER,
PAGE_SIZE,
INNER_DIM,
DTYPE,
HOST_NUM_BLOCKS,
DEVICE_NUM_BLOCKS,
DEVICE_ID,
)
async def test_gpu_block_access(block_manager: BlockManager):
block_count = 6
block_list = block_manager.allocate_device_blocks_blocking(block_count)
py_blocks = block_list.to_list()
......@@ -154,17 +157,7 @@ async def test_gpu_block_access():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_list_iteration():
block_manager = BlockManager(
WORKER_ID,
NUM_LAYER,
PAGE_SIZE,
INNER_DIM,
DTYPE,
HOST_NUM_BLOCKS,
DEVICE_NUM_BLOCKS,
DEVICE_ID,
)
async def test_block_list_iteration(block_manager: BlockManager):
block_count = 4
block_list = block_manager.allocate_host_blocks_blocking(block_count)
# Test __len__()
......@@ -192,17 +185,7 @@ async def test_block_list_iteration():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_copy_g1_g2():
block_manager = BlockManager(
WORKER_ID,
NUM_LAYER,
PAGE_SIZE,
INNER_DIM,
DTYPE,
HOST_NUM_BLOCKS,
DEVICE_NUM_BLOCKS,
DEVICE_ID,
)
async def test_block_copy_g1_g2(block_manager: BlockManager):
# Allocate device (G1) and host (G2) block
host_block_list = block_manager.allocate_host_blocks_blocking(1)
device_block_list = block_manager.allocate_device_blocks_blocking(1)
......@@ -243,10 +226,12 @@ async def test_block_copy_g1_g2():
async def main():
await test_block_manager_initialization()
await test_cpu_block_access()
await test_gpu_block_access()
await test_block_list_iteration()
await test_block_copy_g1_g2()
# todo: revise these tests to index into the block via block_id, layer_id, outer_id (k/v)
# await test_cpu_block_access()
# await test_gpu_block_access()
# await test_block_list_iteration()
# await test_block_copy_g1_g2()
if __name__ == "__main__":
......
......@@ -27,6 +27,9 @@ description = "Dynamo LLM Library"
[features]
default = []
# todo: enable this as default
# default = ["block-manager", "testing-full"]
testing-full = ["testing-cuda", "testing-nixl"]
testing-cuda = ["dep:cudarc"]
testing-nixl = ["dep:nixl-sys"]
......
......@@ -203,6 +203,7 @@ mod tests {
.model(
KvManagerModelConfig::builder()
.num_layers(3)
.outer_dim(2)
.page_size(4)
.inner_dim(16)
.build()
......@@ -241,6 +242,8 @@ mod tests {
let _block_manager = create_reference_block_manager();
}
// todo: solve the async runtime issue
#[ignore]
#[test]
fn test_reference_block_manager_blocking() {
dynamo_runtime::logging::init();
......
......@@ -393,11 +393,18 @@ pub trait BlockDataExt<S: Storage + NixlDescriptor> {
/// Returns the number of layers in the block
fn num_layers(&self) -> usize;
/// Returns the number of outer dimensions in the block
fn num_outer_dims(&self) -> usize;
/// Get a read-only view of this block's storage for a layer
fn layer_view(&self, layer_idx: usize) -> BlockResult<view::LayerView<S>>;
fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>>;
/// Get a mutable view of this block's storage for a layer
fn layer_view_mut(&mut self, layer_idx: usize) -> BlockResult<view::LayerViewMut<S>>;
fn layer_view_mut(
&mut self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerViewMut<S>>;
/// Get a read-only view of this block's storage
fn block_view(&self) -> BlockResult<view::BlockView<S>>;
......@@ -451,21 +458,34 @@ where
self.layout.num_layers()
}
fn layer_view(&self, layer_idx: usize) -> BlockResult<view::LayerView<S>> {
let offset = self.layout.memory_region_addr(self.block_idx, layer_idx)?;
unsafe { view::LayerView::new(self, offset as usize, self.layout.memory_region_size()) }
fn num_outer_dims(&self) -> usize {
self.layout.outer_dim()
}
fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>> {
let mr = self
.layout
.memory_region(self.block_idx, layer_idx, outer_idx)?;
unsafe { view::LayerView::new(self, mr.addr(), mr.size()) }
}
fn layer_view_mut(&mut self, layer_idx: usize) -> BlockResult<view::LayerViewMut<S>> {
let offset = self.layout.memory_region_addr(self.block_idx, layer_idx)?;
unsafe { view::LayerViewMut::new(self, offset as usize, self.layout.memory_region_size()) }
fn layer_view_mut(
&mut self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerViewMut<S>> {
let mr = self
.layout
.memory_region(self.block_idx, layer_idx, outer_idx)?;
unsafe { view::LayerViewMut::new(self, mr.addr(), mr.size()) }
}
fn block_view(&self) -> BlockResult<view::BlockView<S>> {
if self.is_fully_contiguous() {
let offset = self.layout.memory_region_addr(self.block_idx, 0)?;
let size = self.layout.memory_region_size() * self.layout.num_layers();
unsafe { view::BlockView::new(self, offset as usize, size) }
let mr = self.layout.memory_region(self.block_idx, 0, 0)?;
let offset = mr.addr();
let size = mr.size() * self.num_layers();
unsafe { view::BlockView::new(self, offset, size) }
} else {
Err(BlockError::InvalidState(
"Block is not fully contiguous".to_string(),
......@@ -475,9 +495,10 @@ where
fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>> {
if self.is_fully_contiguous() {
let offset = self.layout.memory_region_addr(self.block_idx, 0)?;
let size = self.layout.memory_region_size() * self.layout.num_layers();
unsafe { view::BlockViewMut::new(self, offset as usize, size) }
let mr = self.layout.memory_region(self.block_idx, 0, 0)?;
let offset = mr.addr();
let size = mr.size() * self.num_layers();
unsafe { view::BlockViewMut::new(self, offset, size) }
} else {
Err(BlockError::InvalidState(
"Block is not fully contiguous".to_string(),
......@@ -626,12 +647,20 @@ impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for MutableB
self.data.num_layers()
}
fn layer_view(&self, layer_idx: usize) -> BlockResult<view::LayerView<S>> {
self.data.layer_view(layer_idx)
fn num_outer_dims(&self) -> usize {
self.data.num_outer_dims()
}
fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>> {
self.data.layer_view(layer_idx, outer_idx)
}
fn layer_view_mut(&mut self, layer_idx: usize) -> BlockResult<view::LayerViewMut<S>> {
self.data.layer_view_mut(layer_idx)
fn layer_view_mut(
&mut self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerViewMut<S>> {
self.data.layer_view_mut(layer_idx, outer_idx)
}
fn block_view(&self) -> BlockResult<view::BlockView<S>> {
......@@ -755,11 +784,15 @@ impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for Immutabl
self.block.num_layers()
}
fn layer_view(&self, layer_idx: usize) -> BlockResult<view::LayerView<S>> {
self.block.layer_view(layer_idx)
fn num_outer_dims(&self) -> usize {
self.block.num_outer_dims()
}
fn layer_view_mut(&mut self, _: usize) -> BlockResult<view::LayerViewMut<S>> {
fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>> {
self.block.layer_view(layer_idx, outer_idx)
}
fn layer_view_mut(&mut self, _: usize, _: usize) -> BlockResult<view::LayerViewMut<S>> {
// This should never be called since ImmutableBlock is immutable,
// but we need to implement the full trait
Err(BlockError::InvalidState(
......@@ -946,6 +979,7 @@ pub mod nixl {
fn as_layer_descriptor(
&self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>>;
}
......@@ -961,6 +995,7 @@ pub mod nixl {
fn as_layer_descriptor_mut(
&mut self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>>;
}
......@@ -974,8 +1009,9 @@ pub mod nixl {
fn as_layer_descriptor(
&self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>> {
Ok(self.layer_view(layer_idx)?.as_nixl_descriptor())
Ok(self.layer_view(layer_idx, outer_idx)?.as_nixl_descriptor())
}
}
......@@ -989,8 +1025,11 @@ pub mod nixl {
fn as_layer_descriptor_mut(
&mut self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>> {
Ok(self.layer_view_mut(layer_idx)?.as_nixl_descriptor_mut())
Ok(self
.layer_view_mut(layer_idx, outer_idx)?
.as_nixl_descriptor_mut())
}
}
......@@ -1188,15 +1227,24 @@ pub mod nixl {
self.data.num_layers()
}
fn layer_view(&self, layer_idx: usize) -> BlockResult<view::LayerView<NixlStorage>> {
self.data.layer_view(layer_idx)
fn num_outer_dims(&self) -> usize {
self.data.num_outer_dims()
}
fn layer_view(
&self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerView<NixlStorage>> {
self.data.layer_view(layer_idx, outer_idx)
}
fn layer_view_mut(
&mut self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerViewMut<NixlStorage>> {
self.data.layer_view_mut(layer_idx)
self.data.layer_view_mut(layer_idx, outer_idx)
}
fn block_view(&self) -> BlockResult<view::BlockView<NixlStorage>> {
......@@ -1224,8 +1272,9 @@ pub mod nixl {
fn as_layer_descriptor(
&self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>> {
self.data.as_layer_descriptor(layer_idx)
self.data.as_layer_descriptor(layer_idx, outer_idx)
}
}
......@@ -1244,8 +1293,9 @@ pub mod nixl {
fn as_layer_descriptor_mut(
&mut self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>> {
self.data.as_layer_descriptor_mut(layer_idx)
self.data.as_layer_descriptor_mut(layer_idx, outer_idx)
}
}
......@@ -1733,7 +1783,8 @@ mod tests {
let config = LayoutConfig::builder()
.num_blocks(10)
.num_layers(2)
.num_layers(3)
.outer_dim(2)
.page_size(4)
.inner_dim(13)
.build()
......@@ -1780,6 +1831,7 @@ mod tests {
let config = LayoutConfig::builder()
.num_blocks(10)
.num_layers(2)
.outer_dim(1)
.page_size(4)
.inner_dim(13)
.build()
......@@ -1803,12 +1855,12 @@ mod tests {
assert_eq!(mutable_block.num_layers(), 2);
// Test layer_view()
let layer_view = mutable_block.layer_view(0).unwrap();
let layer_view = mutable_block.layer_view(0, 0).unwrap();
assert_eq!(layer_view.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes
assert!(!unsafe { layer_view.as_ptr() }.is_null());
// Test layer_view_mut()
let mut layer_view_mut = mutable_block.layer_view_mut(1).unwrap();
let mut layer_view_mut = mutable_block.layer_view_mut(1, 0).unwrap();
assert_eq!(layer_view_mut.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes
assert!(!unsafe { layer_view_mut.as_mut_ptr() }.is_null());
......@@ -1833,6 +1885,7 @@ mod tests {
let config = LayoutConfig::builder()
.num_blocks(10)
.num_layers(2)
.outer_dim(1)
.page_size(4)
.inner_dim(13)
.build()
......@@ -1860,7 +1913,7 @@ mod tests {
assert_eq!(immutable_block.num_layers(), 2);
// Test layer_view()
let layer_view = immutable_block.layer_view(0).unwrap();
let layer_view = immutable_block.layer_view(0, 0).unwrap();
assert_eq!(layer_view.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes
assert!(!unsafe { layer_view.as_ptr() }.is_null());
......@@ -1872,7 +1925,7 @@ mod tests {
// Test that mutable methods return errors
let mut mut_immutable_block = immutable_block; // We need a mutable reference for these tests
let layer_view_mut_res = mut_immutable_block.layer_view_mut(0);
let layer_view_mut_res = mut_immutable_block.layer_view_mut(0, 0);
assert!(layer_view_mut_res.is_err());
if let Err(BlockError::InvalidState(msg)) = layer_view_mut_res {
assert!(msg.contains("immutable block"));
......
......@@ -112,18 +112,20 @@ where
}
for layer_idx in layer_range {
let src_view = src_data.layer_view(layer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx)?;
debug_assert_eq!(src_view.size(), dst_view.size());
unsafe {
memcpy_fn(
src_view.as_ptr(),
dst_view.as_mut_ptr(),
src_view.size(),
stream,
)?;
for outer_idx in 0..src_data.num_outer_dims() {
let src_view = src_data.layer_view(layer_idx, outer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;
debug_assert_eq!(src_view.size(), dst_view.size());
unsafe {
memcpy_fn(
src_view.as_ptr(),
dst_view.as_mut_ptr(),
src_view.size(),
stream,
)?;
}
}
}
Ok(())
......
......@@ -57,12 +57,14 @@ where
let dst_data = destinations.block_data_mut(private::PrivateToken);
for layer_idx in layer_range {
let src_view = src_data.layer_view(layer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx)?;
for outer_idx in 0..src_data.num_outer_dims() {
let src_view = src_data.layer_view(layer_idx, outer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;
debug_assert_eq!(src_view.size(), dst_view.size());
unsafe {
memcpy(src_view.as_ptr(), dst_view.as_mut_ptr(), src_view.size());
debug_assert_eq!(src_view.size(), dst_view.size());
unsafe {
memcpy(src_view.as_ptr(), dst_view.as_mut_ptr(), src_view.size());
}
}
}
Ok(())
......
......@@ -132,26 +132,28 @@ where
// }
for layer_idx in layer_range {
let src_view = src_data.layer_view(layer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx)?;
debug_assert_eq!(src_view.size(), dst_view.size());
let src_desc = src_view.as_nixl_descriptor();
let dst_desc = dst_view.as_nixl_descriptor_mut();
unsafe {
src_dl.add_desc(
src_desc.as_ptr() as usize,
src_desc.size(),
src_desc.device_id(),
)?;
dst_dl.add_desc(
dst_desc.as_ptr() as usize,
dst_desc.size(),
dst_desc.device_id(),
)?;
for outer_idx in 0..src_data.num_outer_dims() {
let src_view = src_data.layer_view(layer_idx, outer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;
debug_assert_eq!(src_view.size(), dst_view.size());
let src_desc = src_view.as_nixl_descriptor();
let dst_desc = dst_view.as_nixl_descriptor_mut();
unsafe {
src_dl.add_desc(
src_desc.as_ptr() as usize,
src_desc.size(),
src_desc.device_id(),
)?;
dst_dl.add_desc(
dst_desc.as_ptr() as usize,
dst_desc.size(),
dst_desc.device_id(),
)?;
}
}
}
......
......@@ -71,6 +71,9 @@ pub struct KvManagerModelConfig {
#[validate(range(min = 1))]
pub num_layers: usize,
#[validate(range(min = 1, max = 2))]
pub outer_dim: usize,
#[validate(range(min = 1))]
pub page_size: usize,
......
......@@ -84,6 +84,7 @@
//! let config = LayoutConfig::builder()
//! .num_blocks(10)
//! .num_layers(4)
//! .outer_dim(1)
//! .page_size(16)
//! .inner_dim(128)
//! .dtype(DType::FP16)
......@@ -109,8 +110,12 @@
//! which extends these layout concepts for NIXL (NVIDIA Interface eXchange Layer), enabling
//! layouts to be registered and serialized for use in distributed environments.
// todo: coming soon...
// pub mod distributed;
pub mod nixl;
use derive_getters::Getters;
use thiserror::Error;
use crate::block_manager::storage::{Storage, StorageAllocator};
......@@ -138,6 +143,9 @@ pub enum LayoutError {
#[error("Invalid layer index: {0}")]
InvalidLayerIndex(usize),
#[error("Invalid outer index: {0}")]
InvalidOuterIndex(usize),
#[error("Operation failed: {0}")]
OperationFailed(String),
......@@ -165,10 +173,18 @@ pub enum LayoutType {
// Null,
}
/// Local Memory Region
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Getters)]
pub struct LocalMemoryRegion {
#[getter(copy)]
addr: usize,
#[getter(copy)]
size: usize,
}
/// Core trait for block layouts
pub trait BlockLayout:
BlockLayoutConfig + BlockLayoutLookup + Send + Sync + std::fmt::Debug
{
pub trait BlockLayout: BlockLayoutConfig + Send + Sync + std::fmt::Debug {
/// The type of storage this layout uses
type StorageType: Storage;
......@@ -180,6 +196,21 @@ pub trait BlockLayout:
/// Storage type for the layout
fn storage_type(&self) -> StorageType;
/// Get the memory region for a specific page [page_size, inner_dim]
///
/// # Arguments
///
/// * `block_idx` - The index of the block
/// * `layer_idx` - The index of the layer
/// * `outer_idx` - The index of the outer dimension, e.g. if
///
fn memory_region(
&self,
block_idx: usize,
layer_idx: usize,
outer_idx: usize,
) -> Result<LocalMemoryRegion, LayoutError>;
}
/// Configuration for block layouts
......@@ -193,6 +224,12 @@ pub trait BlockLayoutConfig: std::fmt::Debug {
/// Returns the number of layers per block
fn num_layers(&self) -> usize;
/// Returns the number of outer dimensions per block
/// In some cases, K and V might be indexed separately, so in that example one might have 2 outer dimensions
/// For MLA, this is 1.
/// The location of the outer dimension in the shape of the tensor layout is defined by the layout type.
fn outer_dim(&self) -> usize;
/// Returns the size of each block in bytes
fn page_size(&self) -> usize;
......@@ -200,15 +237,6 @@ pub trait BlockLayoutConfig: std::fmt::Debug {
fn inner_dim(&self) -> usize;
}
/// Trait for looking up memory regions in a block layout
pub trait BlockLayoutLookup {
/// Get the memory region for a specific page [page_size, inner_dim]
fn memory_region_addr(&self, block_idx: usize, layer_idx: usize) -> Result<u64, LayoutError>;
/// Get the memory region for a specific page [page_size, inner_dim]
fn memory_region_size(&self) -> usize;
}
/// Configuration for block layouts
#[derive(Debug, Clone, Builder, Validate, Serialize, Deserialize)]
pub struct LayoutConfig {
......@@ -220,6 +248,10 @@ pub struct LayoutConfig {
#[validate(range(min = 1))]
pub num_layers: usize,
/// Number of outer dimensions
#[validate(range(min = 1, max = 2))]
pub outer_dim: usize,
/// Page size
#[validate(range(min = 1))]
pub page_size: usize,
......@@ -268,11 +300,25 @@ fn align_up(value: usize, alignment: usize) -> usize {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct FullyContiguousConfig {
inner: LayoutConfig,
/// Minimum contiguous memory region size
/// Inner dimension * page size * dtype size
memory_region_size: usize,
/// Stride between outer dimensions
outer_dim_stride_in_bytes: usize,
/// Stride between layers
layer_stride_in_bytes: usize,
/// Natural block stride
natural_block_stride: usize,
/// Block stride in bytes
block_stride_in_bytes: usize, // Aligned if necessary
layout_data_bytes: usize, // Size of the layout data itself (post base offset)
/// Size of the layout data itself (post base offset)
layout_data_bytes: usize, // Size of the layout data itself (post base offset)
}
impl FullyContiguousConfig {
......@@ -284,7 +330,8 @@ impl FullyContiguousConfig {
let alignment = config.alignment;
let memory_region_size = config.page_size * config.inner_dim * config.dtype.size_in_bytes();
let layer_stride_in_bytes = memory_region_size;
let outer_dim_stride_in_bytes = memory_region_size;
let layer_stride_in_bytes = outer_dim_stride_in_bytes * config.outer_dim;
let natural_block_stride = config.num_layers * layer_stride_in_bytes;
let block_stride_in_bytes = if alignment > 1 {
......@@ -299,6 +346,7 @@ impl FullyContiguousConfig {
Ok(Self {
inner: config,
memory_region_size,
outer_dim_stride_in_bytes,
layer_stride_in_bytes,
natural_block_stride,
block_stride_in_bytes,
......@@ -331,6 +379,10 @@ impl BlockLayoutConfig for FullyContiguousConfig {
self.inner.num_layers
}
fn outer_dim(&self) -> usize {
self.inner.outer_dim
}
fn page_size(&self) -> usize {
self.inner.page_size
}
......@@ -508,6 +560,39 @@ impl<S: Storage> BlockLayout for FullyContiguous<S> {
fn storage_type(&self) -> StorageType {
self.storage_type.clone()
}
fn memory_region(
&self,
block_idx: usize,
layer_idx: usize,
outer_idx: usize,
) -> Result<LocalMemoryRegion, LayoutError> {
if block_idx >= self.num_blocks() {
return Err(LayoutError::InvalidBlockIndex(block_idx));
}
if layer_idx >= self.num_layers() {
return Err(LayoutError::InvalidLayerIndex(layer_idx));
}
if outer_idx >= self.outer_dim() {
return Err(LayoutError::InvalidOuterIndex(outer_idx));
}
// Start from the aligned base address
let aligned_start_addr = self.storage.addr() as usize + self.base_offset;
// Calculate offset relative to the aligned start using stored config
let block_offset = block_idx * self.config.block_stride_in_bytes;
let layer_offset = layer_idx * self.config.layer_stride_in_bytes;
let outer_offset = outer_idx * self.config.outer_dim_stride_in_bytes;
let final_addr = aligned_start_addr + block_offset + layer_offset + outer_offset;
Ok(LocalMemoryRegion {
addr: final_addr,
size: self.config.memory_region_size,
})
}
}
impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> {
......@@ -523,6 +608,10 @@ impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> {
self.config.inner.num_layers
}
fn outer_dim(&self) -> usize {
self.config.inner.outer_dim
}
fn page_size(&self) -> usize {
self.config.inner.page_size
}
......@@ -532,33 +621,6 @@ impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> {
}
}
impl<S: Storage> BlockLayoutLookup for FullyContiguous<S> {
fn memory_region_addr(&self, block_idx: usize, layer_idx: usize) -> Result<u64, LayoutError> {
if block_idx >= self.num_blocks() {
return Err(LayoutError::InvalidBlockIndex(block_idx));
}
if layer_idx >= self.num_layers() {
return Err(LayoutError::InvalidLayerIndex(layer_idx));
}
// Start from the aligned base address
let aligned_start_addr = self.storage.addr() + self.base_offset as u64;
// Calculate offset relative to the aligned start using stored config
let block_offset = block_idx * self.config.block_stride_in_bytes;
let layer_offset = layer_idx * self.config.layer_stride_in_bytes;
let final_addr = aligned_start_addr + block_offset as u64 + layer_offset as u64;
Ok(final_addr)
}
fn memory_region_size(&self) -> usize {
// Access via stored dims
self.config.memory_region_size
}
}
#[allow(missing_docs)]
#[cfg(test)]
pub mod tests {
......@@ -570,6 +632,7 @@ pub mod tests {
const NUM_BLOCKS: usize = 7;
const NUM_LAYERS: usize = 5;
const OUTER_DIM: usize = 2;
const PAGE_SIZE: usize = 4;
const INNER_DIM: usize = 13;
const DTYPE: DType = DType::FP32; // Example dtype
......@@ -592,6 +655,7 @@ pub mod tests {
let config = LayoutConfig {
num_blocks: NUM_BLOCKS,
num_layers: NUM_LAYERS,
outer_dim: OUTER_DIM,
page_size: PAGE_SIZE,
inner_dim: INNER_DIM,
alignment: alignment.unwrap_or(1),
......@@ -606,6 +670,7 @@ pub mod tests {
let config = LayoutConfig::builder()
.num_blocks(NUM_BLOCKS)
.num_layers(NUM_LAYERS)
.outer_dim(OUTER_DIM)
.page_size(PAGE_SIZE)
.inner_dim(INNER_DIM)
.alignment(3)
......@@ -632,6 +697,7 @@ pub mod tests {
let config = LayoutConfig {
num_blocks: NUM_BLOCKS,
num_layers: NUM_LAYERS,
outer_dim: OUTER_DIM,
page_size: PAGE_SIZE,
inner_dim: INNER_DIM,
alignment: 1,
......@@ -656,17 +722,11 @@ pub mod tests {
assert_eq!(layout.num_blocks(), NUM_BLOCKS);
assert_eq!(layout.num_layers(), NUM_LAYERS);
assert_eq!(layout.outer_dim(), OUTER_DIM);
assert_eq!(layout.page_size(), PAGE_SIZE);
assert_eq!(layout.inner_dim(), INNER_DIM);
}
#[test]
fn test_fc_memory_region_size() {
let layout = setup_layout(None).expect("Layout setup failed");
let expected_region_size = PAGE_SIZE * INNER_DIM * DTYPE.size_in_bytes();
assert_eq!(layout.memory_region_size(), expected_region_size);
}
#[test]
fn test_fc_offset_calculation() {
let layout = setup_layout(None).expect("Layout setup failed");
......@@ -680,7 +740,7 @@ pub mod tests {
let expected_offset_0_0 =
calculate_expected_offset(base_addr, 0, 0, block_stride, layer_stride);
assert_eq!(
layout.memory_region_addr(0, 0).unwrap(),
layout.memory_region(0, 0, 0).unwrap().addr as u64,
expected_offset_0_0
);
......@@ -689,7 +749,7 @@ pub mod tests {
let expected_offset_0_last =
calculate_expected_offset(base_addr, 0, last_layer_idx, block_stride, layer_stride);
assert_eq!(
layout.memory_region_addr(0, last_layer_idx).unwrap(),
layout.memory_region(0, last_layer_idx, 0).unwrap().addr as u64,
expected_offset_0_last
);
......@@ -698,7 +758,7 @@ pub mod tests {
let expected_offset_last_0 =
calculate_expected_offset(base_addr, last_block_idx, 0, block_stride, layer_stride);
assert_eq!(
layout.memory_region_addr(last_block_idx, 0).unwrap(),
layout.memory_region(last_block_idx, 0, 0).unwrap().addr as u64,
expected_offset_last_0
);
......@@ -712,8 +772,9 @@ pub mod tests {
);
assert_eq!(
layout
.memory_region_addr(last_block_idx, last_layer_idx)
.unwrap(),
.memory_region(last_block_idx, last_layer_idx, 0)
.unwrap()
.addr as u64,
expected_offset_last_last
);
......@@ -729,8 +790,9 @@ pub mod tests {
);
assert_eq!(
layout
.memory_region_addr(mid_block_idx, mid_layer_idx)
.unwrap(),
.memory_region(mid_block_idx, mid_layer_idx, 0)
.unwrap()
.addr as u64,
expected_offset_mid_mid
);
}
......@@ -738,7 +800,7 @@ pub mod tests {
#[test]
fn test_fc_invalid_block_index() {
let layout = setup_layout(None).expect("Layout setup failed");
let result = layout.memory_region_addr(NUM_BLOCKS, 0); // Index == num_blocks (out of bounds)
let result = layout.memory_region(NUM_BLOCKS, 0, 0); // Index == num_blocks (out of bounds)
assert!(result.is_err());
assert!(matches!(
result.err().unwrap(),
......@@ -749,7 +811,7 @@ pub mod tests {
#[test]
fn test_fc_invalid_layer_index() {
let layout = setup_layout(None).expect("Layout setup failed");
let result = layout.memory_region_addr(0, NUM_LAYERS); // Index == num_layers (out of bounds)
let result = layout.memory_region(0, NUM_LAYERS, 0); // Index == num_layers (out of bounds)
assert!(result.is_err());
assert!(matches!(
result.err().unwrap(),
......@@ -757,12 +819,24 @@ pub mod tests {
));
}
#[test]
fn test_fc_invalid_outer_index() {
let layout = setup_layout(None).expect("Layout setup failed");
let result = layout.memory_region(0, 0, OUTER_DIM); // Index == num_outer_dims (out of bounds)
assert!(result.is_err());
assert!(matches!(
result.err().unwrap(),
LayoutError::InvalidOuterIndex(OUTER_DIM)
));
}
#[test]
fn test_fc_allocation_system() {
init_logging();
let config = LayoutConfig {
num_blocks: NUM_BLOCKS,
num_layers: NUM_LAYERS,
outer_dim: OUTER_DIM,
page_size: PAGE_SIZE,
inner_dim: INNER_DIM,
alignment: 1,
......@@ -788,7 +862,7 @@ pub mod tests {
assert_eq!(
layout.storage.size(),
NUM_BLOCKS * NUM_LAYERS * PAGE_SIZE * INNER_DIM * DTYPE.size_in_bytes()
NUM_BLOCKS * NUM_LAYERS * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE.size_in_bytes()
);
}
......@@ -800,6 +874,7 @@ pub mod tests {
let config = LayoutConfig {
num_blocks: NUM_BLOCKS,
num_layers: NUM_LAYERS,
outer_dim: OUTER_DIM,
page_size: PAGE_SIZE,
inner_dim: INNER_DIM,
alignment: ALIGNMENT,
......@@ -810,11 +885,11 @@ pub mod tests {
let memory_region_size = PAGE_SIZE * INNER_DIM * DTYPE.size_in_bytes();
assert_eq!(memory_region_size, 208);
let natural_block_stride = NUM_LAYERS * memory_region_size;
assert_eq!(natural_block_stride, 1040);
let natural_block_stride = OUTER_DIM * NUM_LAYERS * memory_region_size;
assert_eq!(natural_block_stride, 2080);
let aligned_block_stride = align_up(natural_block_stride, ALIGNMENT);
assert_eq!(aligned_block_stride, 1280);
assert_eq!(aligned_block_stride, 2304);
// Calculate the expected *allocated* size (data + initial padding)
let fc_config = FullyContiguousConfig::new(config.clone()).unwrap();
......@@ -844,40 +919,40 @@ pub mod tests {
// Check alignment of block starts
let addr_block_0 = layout
.memory_region_addr(0, 0)
.memory_region(0, 0, 0)
.expect("Failed to get addr block 0");
let addr_block_1 = layout
.memory_region_addr(1, 0)
.memory_region(1, 0, 0)
.expect("Failed to get addr block 1");
let addr_block_2 = layout
.memory_region_addr(2, 0)
.memory_region(2, 0, 0)
.expect("Failed to get addr block 2");
// All blocks should now be aligned due to base_offset adjustment
assert_eq!(
addr_block_0 % ALIGNMENT as u64,
addr_block_0.addr as u64 % ALIGNMENT as u64,
0,
"Block 0 start address is not aligned"
);
assert_eq!(
addr_block_1 % ALIGNMENT as u64,
addr_block_1.addr as u64 % ALIGNMENT as u64,
0,
"Block 1 start address is not aligned"
);
assert_eq!(
addr_block_2 % ALIGNMENT as u64,
addr_block_2.addr as u64 % ALIGNMENT as u64,
0,
"Block 2 start address is not aligned"
);
// Verify the difference matches the aligned stride
assert_eq!(
addr_block_1 - addr_block_0,
addr_block_1.addr as u64 - addr_block_0.addr as u64,
aligned_block_stride as u64,
"Stride between block 0 and 1 mismatch"
);
assert_eq!(
addr_block_2 - addr_block_1,
addr_block_2.addr as u64 - addr_block_1.addr as u64,
aligned_block_stride as u64,
"Stride between block 1 and 2 mismatch"
);
......
......@@ -332,6 +332,7 @@ mod tests {
let config = LayoutConfig::builder()
.num_blocks(10)
.num_layers(2)
.outer_dim(2)
.page_size(4)
.inner_dim(13)
.build()
......
......@@ -506,6 +506,7 @@ mod tests {
let mut config = LayoutConfig {
num_blocks: device_blocks,
num_layers: 8,
outer_dim: 1,
page_size: BLOCK_SIZE,
inner_dim: 1024,
alignment: 1,
......
......@@ -582,6 +582,7 @@ pub(crate) mod tests {
let config = LayoutConfigBuilder::default()
.num_blocks(num_blocks)
.num_layers(61)
.outer_dim(1)
.page_size(16)
.inner_dim(576)
.build()
......
......@@ -115,6 +115,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
layout_builder
.num_layers(model.num_layers)
.outer_dim(model.outer_dim)
.page_size(model.page_size)
.inner_dim(model.inner_dim)
.dtype(model.dtype);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment