"lib/vscode:/vscode.git/clone" did not exist on "1208f017f96060cdab69668d05605bb50ad8ede7"
Unverified Commit 80256acf authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: adding outer dimension to isolate k/v blocks (#1126)

parent 7e452a2e
...@@ -46,10 +46,11 @@ pub struct BlockManager { ...@@ -46,10 +46,11 @@ pub struct BlockManager {
#[pymethods] #[pymethods]
impl BlockManager { impl BlockManager {
#[new] #[new]
#[pyo3(signature = (worker_id, num_layer, page_size, inner_dim, dtype=None, host_num_blocks=None, device_num_blocks=None, device_id=0))] #[pyo3(signature = (worker_id, num_layer, outer_dim, page_size, inner_dim, dtype=None, host_num_blocks=None, device_num_blocks=None, device_id=0))]
fn new( fn new(
worker_id: u64, worker_id: u64,
num_layer: usize, num_layer: usize,
outer_dim: usize,
page_size: usize, page_size: usize,
inner_dim: usize, inner_dim: usize,
dtype: Option<String>, dtype: Option<String>,
...@@ -65,6 +66,7 @@ impl BlockManager { ...@@ -65,6 +66,7 @@ impl BlockManager {
); );
let mut model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder() let mut model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder()
.num_layers(num_layer) .num_layers(num_layer)
.outer_dim(outer_dim)
.page_size(page_size) .page_size(page_size)
.inner_dim(inner_dim); .inner_dim(inner_dim);
let mut dtype_ = dynamo_llm::common::dtype::DType::FP16; // Default in block_manager config let mut dtype_ = dynamo_llm::common::dtype::DType::FP16; // Default in block_manager config
......
...@@ -26,6 +26,7 @@ pytestmark = pytest.mark.pre_merge ...@@ -26,6 +26,7 @@ pytestmark = pytest.mark.pre_merge
WORKER_ID = 0 WORKER_ID = 0
NUM_LAYER = 5 NUM_LAYER = 5
OUTER_DIM = 2
PAGE_SIZE = 4 PAGE_SIZE = 4
INNER_DIM = 13 INNER_DIM = 13
DTYPE, TORCH_DTYPE = "FP32", torch.float32 DTYPE, TORCH_DTYPE = "FP32", torch.float32
...@@ -34,16 +35,35 @@ DEVICE_NUM_BLOCKS = 16 ...@@ -34,16 +35,35 @@ DEVICE_NUM_BLOCKS = 16
DEVICE_ID = 0 DEVICE_ID = 0
@pytest.fixture
def block_manager():
    """Pytest fixture that yields a BlockManager built from the module-level test constants.

    Every constructor argument is supplied explicitly, including the host and
    device block counts and the device id.
    """
    # The constructor is called positionally, so the order of this tuple matters.
    constructor_args = (
        WORKER_ID,
        NUM_LAYER,
        OUTER_DIM,
        PAGE_SIZE,
        INNER_DIM,
        DTYPE,
        HOST_NUM_BLOCKS,
        DEVICE_NUM_BLOCKS,
        DEVICE_ID,
    )
    return BlockManager(*constructor_args)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_manager_initialization(): async def test_block_manager_initialization():
# Python should drop the BlockManager instance as soon as it goes out of scope, but # Python should drop the BlockManager instance as soon as it goes out of scope, but
# it may not be garbage collected immediately, depending on the garbage collector. # it may not be garbage collected immediately, depending on the garbage collector.
BlockManager(WORKER_ID, NUM_LAYER, PAGE_SIZE, INNER_DIM) BlockManager(WORKER_ID, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM)
BlockManager(WORKER_ID, NUM_LAYER, PAGE_SIZE, INNER_DIM, DTYPE) BlockManager(WORKER_ID, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM, DTYPE)
BlockManager(WORKER_ID, NUM_LAYER, PAGE_SIZE, INNER_DIM, DTYPE, HOST_NUM_BLOCKS) BlockManager(
WORKER_ID, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM, DTYPE, HOST_NUM_BLOCKS
)
BlockManager( BlockManager(
WORKER_ID, WORKER_ID,
NUM_LAYER, NUM_LAYER,
OUTER_DIM,
PAGE_SIZE, PAGE_SIZE,
INNER_DIM, INNER_DIM,
DTYPE, DTYPE,
...@@ -52,6 +72,7 @@ async def test_block_manager_initialization(): ...@@ -52,6 +72,7 @@ async def test_block_manager_initialization():
BlockManager( BlockManager(
WORKER_ID, WORKER_ID,
NUM_LAYER, NUM_LAYER,
OUTER_DIM,
PAGE_SIZE, PAGE_SIZE,
INNER_DIM, INNER_DIM,
DTYPE, DTYPE,
...@@ -61,6 +82,7 @@ async def test_block_manager_initialization(): ...@@ -61,6 +82,7 @@ async def test_block_manager_initialization():
BlockManager( BlockManager(
WORKER_ID, WORKER_ID,
NUM_LAYER, NUM_LAYER,
OUTER_DIM,
PAGE_SIZE, PAGE_SIZE,
INNER_DIM, INNER_DIM,
DTYPE, DTYPE,
...@@ -70,6 +92,7 @@ async def test_block_manager_initialization(): ...@@ -70,6 +92,7 @@ async def test_block_manager_initialization():
BlockManager( BlockManager(
WORKER_ID, WORKER_ID,
NUM_LAYER, NUM_LAYER,
OUTER_DIM,
PAGE_SIZE, PAGE_SIZE,
INNER_DIM, INNER_DIM,
DTYPE, DTYPE,
...@@ -80,17 +103,7 @@ async def test_block_manager_initialization(): ...@@ -80,17 +103,7 @@ async def test_block_manager_initialization():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_cpu_block_access(): async def test_cpu_block_access(block_manager: BlockManager):
block_manager = BlockManager(
WORKER_ID,
NUM_LAYER,
PAGE_SIZE,
INNER_DIM,
DTYPE,
HOST_NUM_BLOCKS,
DEVICE_NUM_BLOCKS,
DEVICE_ID,
)
block_count = 2 block_count = 2
block_list = block_manager.allocate_host_blocks_blocking(block_count) block_list = block_manager.allocate_host_blocks_blocking(block_count)
py_blocks = block_list.to_list() py_blocks = block_list.to_list()
...@@ -117,17 +130,7 @@ async def test_cpu_block_access(): ...@@ -117,17 +130,7 @@ async def test_cpu_block_access():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_gpu_block_access(): async def test_gpu_block_access(block_manager: BlockManager):
block_manager = BlockManager(
WORKER_ID,
NUM_LAYER,
PAGE_SIZE,
INNER_DIM,
DTYPE,
HOST_NUM_BLOCKS,
DEVICE_NUM_BLOCKS,
DEVICE_ID,
)
block_count = 6 block_count = 6
block_list = block_manager.allocate_device_blocks_blocking(block_count) block_list = block_manager.allocate_device_blocks_blocking(block_count)
py_blocks = block_list.to_list() py_blocks = block_list.to_list()
...@@ -154,17 +157,7 @@ async def test_gpu_block_access(): ...@@ -154,17 +157,7 @@ async def test_gpu_block_access():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_list_iteration(): async def test_block_list_iteration(block_manager: BlockManager):
block_manager = BlockManager(
WORKER_ID,
NUM_LAYER,
PAGE_SIZE,
INNER_DIM,
DTYPE,
HOST_NUM_BLOCKS,
DEVICE_NUM_BLOCKS,
DEVICE_ID,
)
block_count = 4 block_count = 4
block_list = block_manager.allocate_host_blocks_blocking(block_count) block_list = block_manager.allocate_host_blocks_blocking(block_count)
# Test __len__() # Test __len__()
...@@ -192,17 +185,7 @@ async def test_block_list_iteration(): ...@@ -192,17 +185,7 @@ async def test_block_list_iteration():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_copy_g1_g2(): async def test_block_copy_g1_g2(block_manager: BlockManager):
block_manager = BlockManager(
WORKER_ID,
NUM_LAYER,
PAGE_SIZE,
INNER_DIM,
DTYPE,
HOST_NUM_BLOCKS,
DEVICE_NUM_BLOCKS,
DEVICE_ID,
)
# Allocate device (G1) and host (G2) block # Allocate device (G1) and host (G2) block
host_block_list = block_manager.allocate_host_blocks_blocking(1) host_block_list = block_manager.allocate_host_blocks_blocking(1)
device_block_list = block_manager.allocate_device_blocks_blocking(1) device_block_list = block_manager.allocate_device_blocks_blocking(1)
...@@ -243,10 +226,12 @@ async def test_block_copy_g1_g2(): ...@@ -243,10 +226,12 @@ async def test_block_copy_g1_g2():
async def main(): async def main():
await test_block_manager_initialization() await test_block_manager_initialization()
await test_cpu_block_access()
await test_gpu_block_access() # todo: revise these tests to index into the block via block_id, layer_id, outer_id (k/v)
await test_block_list_iteration() # await test_cpu_block_access()
await test_block_copy_g1_g2() # await test_gpu_block_access()
# await test_block_list_iteration()
# await test_block_copy_g1_g2()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -27,6 +27,9 @@ description = "Dynamo LLM Library" ...@@ -27,6 +27,9 @@ description = "Dynamo LLM Library"
[features] [features]
default = [] default = []
# todo: enable this as default
# default = ["block-manager", "testing-full"]
testing-full = ["testing-cuda", "testing-nixl"] testing-full = ["testing-cuda", "testing-nixl"]
testing-cuda = ["dep:cudarc"] testing-cuda = ["dep:cudarc"]
testing-nixl = ["dep:nixl-sys"] testing-nixl = ["dep:nixl-sys"]
......
...@@ -203,6 +203,7 @@ mod tests { ...@@ -203,6 +203,7 @@ mod tests {
.model( .model(
KvManagerModelConfig::builder() KvManagerModelConfig::builder()
.num_layers(3) .num_layers(3)
.outer_dim(2)
.page_size(4) .page_size(4)
.inner_dim(16) .inner_dim(16)
.build() .build()
...@@ -241,6 +242,8 @@ mod tests { ...@@ -241,6 +242,8 @@ mod tests {
let _block_manager = create_reference_block_manager(); let _block_manager = create_reference_block_manager();
} }
// todo: solve the async runtime issue
#[ignore]
#[test] #[test]
fn test_reference_block_manager_blocking() { fn test_reference_block_manager_blocking() {
dynamo_runtime::logging::init(); dynamo_runtime::logging::init();
......
...@@ -393,11 +393,18 @@ pub trait BlockDataExt<S: Storage + NixlDescriptor> { ...@@ -393,11 +393,18 @@ pub trait BlockDataExt<S: Storage + NixlDescriptor> {
/// Returns the number of layers in the block /// Returns the number of layers in the block
fn num_layers(&self) -> usize; fn num_layers(&self) -> usize;
/// Returns the number of outer dimensions in the block
fn num_outer_dims(&self) -> usize;
/// Get a read-only view of this block's storage for a layer /// Get a read-only view of this block's storage for a layer
fn layer_view(&self, layer_idx: usize) -> BlockResult<view::LayerView<S>>; fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>>;
/// Get a mutable view of this block's storage for a layer /// Get a mutable view of this block's storage for a layer
fn layer_view_mut(&mut self, layer_idx: usize) -> BlockResult<view::LayerViewMut<S>>; fn layer_view_mut(
&mut self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerViewMut<S>>;
/// Get a read-only view of this block's storage /// Get a read-only view of this block's storage
fn block_view(&self) -> BlockResult<view::BlockView<S>>; fn block_view(&self) -> BlockResult<view::BlockView<S>>;
...@@ -451,21 +458,34 @@ where ...@@ -451,21 +458,34 @@ where
self.layout.num_layers() self.layout.num_layers()
} }
fn layer_view(&self, layer_idx: usize) -> BlockResult<view::LayerView<S>> { fn num_outer_dims(&self) -> usize {
let offset = self.layout.memory_region_addr(self.block_idx, layer_idx)?; self.layout.outer_dim()
unsafe { view::LayerView::new(self, offset as usize, self.layout.memory_region_size()) }
} }
fn layer_view_mut(&mut self, layer_idx: usize) -> BlockResult<view::LayerViewMut<S>> { fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>> {
let offset = self.layout.memory_region_addr(self.block_idx, layer_idx)?; let mr = self
unsafe { view::LayerViewMut::new(self, offset as usize, self.layout.memory_region_size()) } .layout
.memory_region(self.block_idx, layer_idx, outer_idx)?;
unsafe { view::LayerView::new(self, mr.addr(), mr.size()) }
}
fn layer_view_mut(
&mut self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerViewMut<S>> {
let mr = self
.layout
.memory_region(self.block_idx, layer_idx, outer_idx)?;
unsafe { view::LayerViewMut::new(self, mr.addr(), mr.size()) }
} }
fn block_view(&self) -> BlockResult<view::BlockView<S>> { fn block_view(&self) -> BlockResult<view::BlockView<S>> {
if self.is_fully_contiguous() { if self.is_fully_contiguous() {
let offset = self.layout.memory_region_addr(self.block_idx, 0)?; let mr = self.layout.memory_region(self.block_idx, 0, 0)?;
let size = self.layout.memory_region_size() * self.layout.num_layers(); let offset = mr.addr();
unsafe { view::BlockView::new(self, offset as usize, size) } let size = mr.size() * self.num_layers();
unsafe { view::BlockView::new(self, offset, size) }
} else { } else {
Err(BlockError::InvalidState( Err(BlockError::InvalidState(
"Block is not fully contiguous".to_string(), "Block is not fully contiguous".to_string(),
...@@ -475,9 +495,10 @@ where ...@@ -475,9 +495,10 @@ where
fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>> { fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>> {
if self.is_fully_contiguous() { if self.is_fully_contiguous() {
let offset = self.layout.memory_region_addr(self.block_idx, 0)?; let mr = self.layout.memory_region(self.block_idx, 0, 0)?;
let size = self.layout.memory_region_size() * self.layout.num_layers(); let offset = mr.addr();
unsafe { view::BlockViewMut::new(self, offset as usize, size) } let size = mr.size() * self.num_layers();
unsafe { view::BlockViewMut::new(self, offset, size) }
} else { } else {
Err(BlockError::InvalidState( Err(BlockError::InvalidState(
"Block is not fully contiguous".to_string(), "Block is not fully contiguous".to_string(),
...@@ -626,12 +647,20 @@ impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for MutableB ...@@ -626,12 +647,20 @@ impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for MutableB
self.data.num_layers() self.data.num_layers()
} }
fn layer_view(&self, layer_idx: usize) -> BlockResult<view::LayerView<S>> { fn num_outer_dims(&self) -> usize {
self.data.layer_view(layer_idx) self.data.num_outer_dims()
}
fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>> {
self.data.layer_view(layer_idx, outer_idx)
} }
fn layer_view_mut(&mut self, layer_idx: usize) -> BlockResult<view::LayerViewMut<S>> { fn layer_view_mut(
self.data.layer_view_mut(layer_idx) &mut self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerViewMut<S>> {
self.data.layer_view_mut(layer_idx, outer_idx)
} }
fn block_view(&self) -> BlockResult<view::BlockView<S>> { fn block_view(&self) -> BlockResult<view::BlockView<S>> {
...@@ -755,11 +784,15 @@ impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for Immutabl ...@@ -755,11 +784,15 @@ impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataExt<S> for Immutabl
self.block.num_layers() self.block.num_layers()
} }
fn layer_view(&self, layer_idx: usize) -> BlockResult<view::LayerView<S>> { fn num_outer_dims(&self) -> usize {
self.block.layer_view(layer_idx) self.block.num_outer_dims()
} }
fn layer_view_mut(&mut self, _: usize) -> BlockResult<view::LayerViewMut<S>> { fn layer_view(&self, layer_idx: usize, outer_idx: usize) -> BlockResult<view::LayerView<S>> {
self.block.layer_view(layer_idx, outer_idx)
}
fn layer_view_mut(&mut self, _: usize, _: usize) -> BlockResult<view::LayerViewMut<S>> {
// This should never be called since ImmutableBlock is immutable, // This should never be called since ImmutableBlock is immutable,
// but we need to implement the full trait // but we need to implement the full trait
Err(BlockError::InvalidState( Err(BlockError::InvalidState(
...@@ -946,6 +979,7 @@ pub mod nixl { ...@@ -946,6 +979,7 @@ pub mod nixl {
fn as_layer_descriptor( fn as_layer_descriptor(
&self, &self,
layer_idx: usize, layer_idx: usize,
outer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>>; ) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>>;
} }
...@@ -961,6 +995,7 @@ pub mod nixl { ...@@ -961,6 +995,7 @@ pub mod nixl {
fn as_layer_descriptor_mut( fn as_layer_descriptor_mut(
&mut self, &mut self,
layer_idx: usize, layer_idx: usize,
outer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>>; ) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>>;
} }
...@@ -974,8 +1009,9 @@ pub mod nixl { ...@@ -974,8 +1009,9 @@ pub mod nixl {
fn as_layer_descriptor( fn as_layer_descriptor(
&self, &self,
layer_idx: usize, layer_idx: usize,
outer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>> { ) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>> {
Ok(self.layer_view(layer_idx)?.as_nixl_descriptor()) Ok(self.layer_view(layer_idx, outer_idx)?.as_nixl_descriptor())
} }
} }
...@@ -989,8 +1025,11 @@ pub mod nixl { ...@@ -989,8 +1025,11 @@ pub mod nixl {
fn as_layer_descriptor_mut( fn as_layer_descriptor_mut(
&mut self, &mut self,
layer_idx: usize, layer_idx: usize,
outer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>> { ) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>> {
Ok(self.layer_view_mut(layer_idx)?.as_nixl_descriptor_mut()) Ok(self
.layer_view_mut(layer_idx, outer_idx)?
.as_nixl_descriptor_mut())
} }
} }
...@@ -1188,15 +1227,24 @@ pub mod nixl { ...@@ -1188,15 +1227,24 @@ pub mod nixl {
self.data.num_layers() self.data.num_layers()
} }
fn layer_view(&self, layer_idx: usize) -> BlockResult<view::LayerView<NixlStorage>> { fn num_outer_dims(&self) -> usize {
self.data.layer_view(layer_idx) self.data.num_outer_dims()
}
fn layer_view(
&self,
layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerView<NixlStorage>> {
self.data.layer_view(layer_idx, outer_idx)
} }
fn layer_view_mut( fn layer_view_mut(
&mut self, &mut self,
layer_idx: usize, layer_idx: usize,
outer_idx: usize,
) -> BlockResult<view::LayerViewMut<NixlStorage>> { ) -> BlockResult<view::LayerViewMut<NixlStorage>> {
self.data.layer_view_mut(layer_idx) self.data.layer_view_mut(layer_idx, outer_idx)
} }
fn block_view(&self) -> BlockResult<view::BlockView<NixlStorage>> { fn block_view(&self) -> BlockResult<view::BlockView<NixlStorage>> {
...@@ -1224,8 +1272,9 @@ pub mod nixl { ...@@ -1224,8 +1272,9 @@ pub mod nixl {
fn as_layer_descriptor( fn as_layer_descriptor(
&self, &self,
layer_idx: usize, layer_idx: usize,
outer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>> { ) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>> {
self.data.as_layer_descriptor(layer_idx) self.data.as_layer_descriptor(layer_idx, outer_idx)
} }
} }
...@@ -1244,8 +1293,9 @@ pub mod nixl { ...@@ -1244,8 +1293,9 @@ pub mod nixl {
fn as_layer_descriptor_mut( fn as_layer_descriptor_mut(
&mut self, &mut self,
layer_idx: usize, layer_idx: usize,
outer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>> { ) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>> {
self.data.as_layer_descriptor_mut(layer_idx) self.data.as_layer_descriptor_mut(layer_idx, outer_idx)
} }
} }
...@@ -1733,7 +1783,8 @@ mod tests { ...@@ -1733,7 +1783,8 @@ mod tests {
let config = LayoutConfig::builder() let config = LayoutConfig::builder()
.num_blocks(10) .num_blocks(10)
.num_layers(2) .num_layers(3)
.outer_dim(2)
.page_size(4) .page_size(4)
.inner_dim(13) .inner_dim(13)
.build() .build()
...@@ -1780,6 +1831,7 @@ mod tests { ...@@ -1780,6 +1831,7 @@ mod tests {
let config = LayoutConfig::builder() let config = LayoutConfig::builder()
.num_blocks(10) .num_blocks(10)
.num_layers(2) .num_layers(2)
.outer_dim(1)
.page_size(4) .page_size(4)
.inner_dim(13) .inner_dim(13)
.build() .build()
...@@ -1803,12 +1855,12 @@ mod tests { ...@@ -1803,12 +1855,12 @@ mod tests {
assert_eq!(mutable_block.num_layers(), 2); assert_eq!(mutable_block.num_layers(), 2);
// Test layer_view() // Test layer_view()
let layer_view = mutable_block.layer_view(0).unwrap(); let layer_view = mutable_block.layer_view(0, 0).unwrap();
assert_eq!(layer_view.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes assert_eq!(layer_view.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes
assert!(!unsafe { layer_view.as_ptr() }.is_null()); assert!(!unsafe { layer_view.as_ptr() }.is_null());
// Test layer_view_mut() // Test layer_view_mut()
let mut layer_view_mut = mutable_block.layer_view_mut(1).unwrap(); let mut layer_view_mut = mutable_block.layer_view_mut(1, 0).unwrap();
assert_eq!(layer_view_mut.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes assert_eq!(layer_view_mut.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes
assert!(!unsafe { layer_view_mut.as_mut_ptr() }.is_null()); assert!(!unsafe { layer_view_mut.as_mut_ptr() }.is_null());
...@@ -1833,6 +1885,7 @@ mod tests { ...@@ -1833,6 +1885,7 @@ mod tests {
let config = LayoutConfig::builder() let config = LayoutConfig::builder()
.num_blocks(10) .num_blocks(10)
.num_layers(2) .num_layers(2)
.outer_dim(1)
.page_size(4) .page_size(4)
.inner_dim(13) .inner_dim(13)
.build() .build()
...@@ -1860,7 +1913,7 @@ mod tests { ...@@ -1860,7 +1913,7 @@ mod tests {
assert_eq!(immutable_block.num_layers(), 2); assert_eq!(immutable_block.num_layers(), 2);
// Test layer_view() // Test layer_view()
let layer_view = immutable_block.layer_view(0).unwrap(); let layer_view = immutable_block.layer_view(0, 0).unwrap();
assert_eq!(layer_view.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes assert_eq!(layer_view.size(), 4 * 13 * 2); // page_size x inner_dim x dtype_bytes
assert!(!unsafe { layer_view.as_ptr() }.is_null()); assert!(!unsafe { layer_view.as_ptr() }.is_null());
...@@ -1872,7 +1925,7 @@ mod tests { ...@@ -1872,7 +1925,7 @@ mod tests {
// Test that mutable methods return errors // Test that mutable methods return errors
let mut mut_immutable_block = immutable_block; // We need a mutable reference for these tests let mut mut_immutable_block = immutable_block; // We need a mutable reference for these tests
let layer_view_mut_res = mut_immutable_block.layer_view_mut(0); let layer_view_mut_res = mut_immutable_block.layer_view_mut(0, 0);
assert!(layer_view_mut_res.is_err()); assert!(layer_view_mut_res.is_err());
if let Err(BlockError::InvalidState(msg)) = layer_view_mut_res { if let Err(BlockError::InvalidState(msg)) = layer_view_mut_res {
assert!(msg.contains("immutable block")); assert!(msg.contains("immutable block"));
......
...@@ -112,8 +112,9 @@ where ...@@ -112,8 +112,9 @@ where
} }
for layer_idx in layer_range { for layer_idx in layer_range {
let src_view = src_data.layer_view(layer_idx)?; for outer_idx in 0..src_data.num_outer_dims() {
let mut dst_view = dst_data.layer_view_mut(layer_idx)?; let src_view = src_data.layer_view(layer_idx, outer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;
debug_assert_eq!(src_view.size(), dst_view.size()); debug_assert_eq!(src_view.size(), dst_view.size());
...@@ -126,6 +127,7 @@ where ...@@ -126,6 +127,7 @@ where
)?; )?;
} }
} }
}
Ok(()) Ok(())
} }
......
...@@ -57,14 +57,16 @@ where ...@@ -57,14 +57,16 @@ where
let dst_data = destinations.block_data_mut(private::PrivateToken); let dst_data = destinations.block_data_mut(private::PrivateToken);
for layer_idx in layer_range { for layer_idx in layer_range {
let src_view = src_data.layer_view(layer_idx)?; for outer_idx in 0..src_data.num_outer_dims() {
let mut dst_view = dst_data.layer_view_mut(layer_idx)?; let src_view = src_data.layer_view(layer_idx, outer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;
debug_assert_eq!(src_view.size(), dst_view.size()); debug_assert_eq!(src_view.size(), dst_view.size());
unsafe { unsafe {
memcpy(src_view.as_ptr(), dst_view.as_mut_ptr(), src_view.size()); memcpy(src_view.as_ptr(), dst_view.as_mut_ptr(), src_view.size());
} }
} }
}
Ok(()) Ok(())
} }
......
...@@ -132,8 +132,9 @@ where ...@@ -132,8 +132,9 @@ where
// } // }
for layer_idx in layer_range { for layer_idx in layer_range {
let src_view = src_data.layer_view(layer_idx)?; for outer_idx in 0..src_data.num_outer_dims() {
let mut dst_view = dst_data.layer_view_mut(layer_idx)?; let src_view = src_data.layer_view(layer_idx, outer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;
debug_assert_eq!(src_view.size(), dst_view.size()); debug_assert_eq!(src_view.size(), dst_view.size());
...@@ -154,6 +155,7 @@ where ...@@ -154,6 +155,7 @@ where
)?; )?;
} }
} }
}
let mut xfer_args = OptArgs::new()?; let mut xfer_args = OptArgs::new()?;
......
...@@ -71,6 +71,9 @@ pub struct KvManagerModelConfig { ...@@ -71,6 +71,9 @@ pub struct KvManagerModelConfig {
#[validate(range(min = 1))] #[validate(range(min = 1))]
pub num_layers: usize, pub num_layers: usize,
#[validate(range(min = 1, max = 2))]
pub outer_dim: usize,
#[validate(range(min = 1))] #[validate(range(min = 1))]
pub page_size: usize, pub page_size: usize,
......
...@@ -84,6 +84,7 @@ ...@@ -84,6 +84,7 @@
//! let config = LayoutConfig::builder() //! let config = LayoutConfig::builder()
//! .num_blocks(10) //! .num_blocks(10)
//! .num_layers(4) //! .num_layers(4)
//! .outer_dim(1)
//! .page_size(16) //! .page_size(16)
//! .inner_dim(128) //! .inner_dim(128)
//! .dtype(DType::FP16) //! .dtype(DType::FP16)
...@@ -109,8 +110,12 @@ ...@@ -109,8 +110,12 @@
//! which extends these layout concepts for NIXL (NVIDIA Interface eXchange Layer), enabling //! which extends these layout concepts for NIXL (NVIDIA Interface eXchange Layer), enabling
//! layouts to be registered and serialized for use in distributed environments. //! layouts to be registered and serialized for use in distributed environments.
// todo: coming soon...
// pub mod distributed;
pub mod nixl; pub mod nixl;
use derive_getters::Getters;
use thiserror::Error; use thiserror::Error;
use crate::block_manager::storage::{Storage, StorageAllocator}; use crate::block_manager::storage::{Storage, StorageAllocator};
...@@ -138,6 +143,9 @@ pub enum LayoutError { ...@@ -138,6 +143,9 @@ pub enum LayoutError {
#[error("Invalid layer index: {0}")] #[error("Invalid layer index: {0}")]
InvalidLayerIndex(usize), InvalidLayerIndex(usize),
#[error("Invalid outer index: {0}")]
InvalidOuterIndex(usize),
#[error("Operation failed: {0}")] #[error("Operation failed: {0}")]
OperationFailed(String), OperationFailed(String),
...@@ -165,10 +173,18 @@ pub enum LayoutType { ...@@ -165,10 +173,18 @@ pub enum LayoutType {
// Null, // Null,
} }
/// Local Memory Region
///
/// A start address plus a length in bytes. This is the value returned by
/// `BlockLayout::memory_region`. Both fields are exposed through
/// by-value getters generated by the `Getters` derive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Getters)]
pub struct LocalMemoryRegion {
/// Start address of the region.
#[getter(copy)]
addr: usize,
/// Size of the region in bytes.
#[getter(copy)]
size: usize,
}
/// Core trait for block layouts /// Core trait for block layouts
pub trait BlockLayout: pub trait BlockLayout: BlockLayoutConfig + Send + Sync + std::fmt::Debug {
BlockLayoutConfig + BlockLayoutLookup + Send + Sync + std::fmt::Debug
{
/// The type of storage this layout uses /// The type of storage this layout uses
type StorageType: Storage; type StorageType: Storage;
...@@ -180,6 +196,21 @@ pub trait BlockLayout: ...@@ -180,6 +196,21 @@ pub trait BlockLayout:
/// Storage type for the layout /// Storage type for the layout
fn storage_type(&self) -> StorageType; fn storage_type(&self) -> StorageType;
/// Get the memory region for a specific page [page_size, inner_dim]
///
/// # Arguments
///
/// * `block_idx` - The index of the block
/// * `layer_idx` - The index of the layer
/// * `outer_idx` - The index of the outer dimension, e.g. if K and V are indexed
///   separately (`outer_dim() == 2`), this selects which of the two is addressed
///
/// # Errors
///
/// NOTE(review): the `FullyContiguous` implementation returns `InvalidBlockIndex`,
/// `InvalidLayerIndex` or `InvalidOuterIndex` when the corresponding index is out of
/// range; other implementations presumably do the same — confirm.
fn memory_region(
&self,
block_idx: usize,
layer_idx: usize,
outer_idx: usize,
) -> Result<LocalMemoryRegion, LayoutError>;
} }
/// Configuration for block layouts /// Configuration for block layouts
...@@ -193,6 +224,12 @@ pub trait BlockLayoutConfig: std::fmt::Debug { ...@@ -193,6 +224,12 @@ pub trait BlockLayoutConfig: std::fmt::Debug {
/// Returns the number of layers per block /// Returns the number of layers per block
fn num_layers(&self) -> usize; fn num_layers(&self) -> usize;
/// Returns the number of outer dimensions per block.
///
/// In some cases K and V are indexed separately; in that case there are 2 outer
/// dimensions. For MLA, this is 1.
///
/// The location of the outer dimension in the shape of the tensor layout is defined
/// by the layout type.
fn outer_dim(&self) -> usize;
/// Returns the size of each block in bytes /// Returns the size of each block in bytes
fn page_size(&self) -> usize; fn page_size(&self) -> usize;
...@@ -200,15 +237,6 @@ pub trait BlockLayoutConfig: std::fmt::Debug { ...@@ -200,15 +237,6 @@ pub trait BlockLayoutConfig: std::fmt::Debug {
fn inner_dim(&self) -> usize; fn inner_dim(&self) -> usize;
} }
/// Trait for looking up memory regions in a block layout
pub trait BlockLayoutLookup {
/// Get the memory region for a specific page [page_size, inner_dim]
fn memory_region_addr(&self, block_idx: usize, layer_idx: usize) -> Result<u64, LayoutError>;
/// Get the memory region for a specific page [page_size, inner_dim]
fn memory_region_size(&self) -> usize;
}
/// Configuration for block layouts /// Configuration for block layouts
#[derive(Debug, Clone, Builder, Validate, Serialize, Deserialize)] #[derive(Debug, Clone, Builder, Validate, Serialize, Deserialize)]
pub struct LayoutConfig { pub struct LayoutConfig {
...@@ -220,6 +248,10 @@ pub struct LayoutConfig { ...@@ -220,6 +248,10 @@ pub struct LayoutConfig {
#[validate(range(min = 1))] #[validate(range(min = 1))]
pub num_layers: usize, pub num_layers: usize,
/// Number of outer dimensions
#[validate(range(min = 1, max = 2))]
pub outer_dim: usize,
/// Page size /// Page size
#[validate(range(min = 1))] #[validate(range(min = 1))]
pub page_size: usize, pub page_size: usize,
...@@ -268,10 +300,24 @@ fn align_up(value: usize, alignment: usize) -> usize { ...@@ -268,10 +300,24 @@ fn align_up(value: usize, alignment: usize) -> usize {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct FullyContiguousConfig { pub(crate) struct FullyContiguousConfig {
inner: LayoutConfig, inner: LayoutConfig,
/// Minimum contiguous memory region size
/// Inner dimension * page size * dtype size
memory_region_size: usize, memory_region_size: usize,
/// Stride between outer dimensions
outer_dim_stride_in_bytes: usize,
/// Stride between layers
layer_stride_in_bytes: usize, layer_stride_in_bytes: usize,
/// Natural block stride
natural_block_stride: usize, natural_block_stride: usize,
/// Block stride in bytes
block_stride_in_bytes: usize, // Aligned if necessary block_stride_in_bytes: usize, // Aligned if necessary
/// Size of the layout data itself (post base offset)
layout_data_bytes: usize, // Size of the layout data itself (post base offset) layout_data_bytes: usize, // Size of the layout data itself (post base offset)
} }
...@@ -284,7 +330,8 @@ impl FullyContiguousConfig { ...@@ -284,7 +330,8 @@ impl FullyContiguousConfig {
let alignment = config.alignment; let alignment = config.alignment;
let memory_region_size = config.page_size * config.inner_dim * config.dtype.size_in_bytes(); let memory_region_size = config.page_size * config.inner_dim * config.dtype.size_in_bytes();
let layer_stride_in_bytes = memory_region_size; let outer_dim_stride_in_bytes = memory_region_size;
let layer_stride_in_bytes = outer_dim_stride_in_bytes * config.outer_dim;
let natural_block_stride = config.num_layers * layer_stride_in_bytes; let natural_block_stride = config.num_layers * layer_stride_in_bytes;
let block_stride_in_bytes = if alignment > 1 { let block_stride_in_bytes = if alignment > 1 {
...@@ -299,6 +346,7 @@ impl FullyContiguousConfig { ...@@ -299,6 +346,7 @@ impl FullyContiguousConfig {
Ok(Self { Ok(Self {
inner: config, inner: config,
memory_region_size, memory_region_size,
outer_dim_stride_in_bytes,
layer_stride_in_bytes, layer_stride_in_bytes,
natural_block_stride, natural_block_stride,
block_stride_in_bytes, block_stride_in_bytes,
...@@ -331,6 +379,10 @@ impl BlockLayoutConfig for FullyContiguousConfig { ...@@ -331,6 +379,10 @@ impl BlockLayoutConfig for FullyContiguousConfig {
self.inner.num_layers self.inner.num_layers
} }
/// Number of outer dimensions, taken from the wrapped `LayoutConfig` (`inner`).
fn outer_dim(&self) -> usize {
self.inner.outer_dim
}
fn page_size(&self) -> usize { fn page_size(&self) -> usize {
self.inner.page_size self.inner.page_size
} }
...@@ -508,6 +560,39 @@ impl<S: Storage> BlockLayout for FullyContiguous<S> { ...@@ -508,6 +560,39 @@ impl<S: Storage> BlockLayout for FullyContiguous<S> {
fn storage_type(&self) -> StorageType { fn storage_type(&self) -> StorageType {
self.storage_type.clone() self.storage_type.clone()
} }
/// Locates the memory region addressed by `(block_idx, layer_idx, outer_idx)`.
///
/// The returned region starts at
/// `storage.addr() + base_offset + block_idx * block_stride_in_bytes
///  + layer_idx * layer_stride_in_bytes + outer_idx * outer_dim_stride_in_bytes`
/// and is `memory_region_size` bytes long (all strides/sizes read from `self.config`).
///
/// # Errors
///
/// Indices are validated in block, layer, outer order; the first one that is
/// not strictly below its extent yields `LayoutError::InvalidBlockIndex`,
/// `LayoutError::InvalidLayerIndex` or `LayoutError::InvalidOuterIndex`
/// respectively, carrying the offending index.
fn memory_region(
    &self,
    block_idx: usize,
    layer_idx: usize,
    outer_idx: usize,
) -> Result<LocalMemoryRegion, LayoutError> {
    // Reject out-of-range indices before doing any address arithmetic.
    if block_idx >= self.num_blocks() {
        return Err(LayoutError::InvalidBlockIndex(block_idx));
    }
    if layer_idx >= self.num_layers() {
        return Err(LayoutError::InvalidLayerIndex(layer_idx));
    }
    if outer_idx >= self.outer_dim() {
        return Err(LayoutError::InvalidOuterIndex(outer_idx));
    }

    // Aligned base address: raw storage address plus the stored base offset.
    let base = self.storage.addr() as usize + self.base_offset;

    // Byte offset of the requested slot relative to the aligned base,
    // built from the strides held in the stored config.
    let cfg = &self.config;
    let block_bytes = block_idx * cfg.block_stride_in_bytes;
    let layer_bytes = layer_idx * cfg.layer_stride_in_bytes;
    let outer_bytes = outer_idx * cfg.outer_dim_stride_in_bytes;

    Ok(LocalMemoryRegion {
        addr: base + block_bytes + layer_bytes + outer_bytes,
        size: cfg.memory_region_size,
    })
}
} }
impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> { impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> {
...@@ -523,6 +608,10 @@ impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> { ...@@ -523,6 +608,10 @@ impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> {
self.config.inner.num_layers self.config.inner.num_layers
} }
/// Returns the `outer_dim` value from the layout's stored config (`self.config.inner`).
fn outer_dim(&self) -> usize {
self.config.inner.outer_dim
}
fn page_size(&self) -> usize { fn page_size(&self) -> usize {
self.config.inner.page_size self.config.inner.page_size
} }
...@@ -532,33 +621,6 @@ impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> { ...@@ -532,33 +621,6 @@ impl<S: Storage> BlockLayoutConfig for FullyContiguous<S> {
} }
} }
impl<S: Storage> BlockLayoutLookup for FullyContiguous<S> {
    /// Computes the start address of the region for `(block_idx, layer_idx)`.
    ///
    /// The address is `storage.addr() + base_offset
    /// + block_idx * block_stride_in_bytes + layer_idx * layer_stride_in_bytes`,
    /// with the strides taken from the stored config.
    ///
    /// # Errors
    ///
    /// Returns `LayoutError::InvalidBlockIndex` when `block_idx >= num_blocks()`,
    /// otherwise `LayoutError::InvalidLayerIndex` when `layer_idx >= num_layers()`.
    fn memory_region_addr(&self, block_idx: usize, layer_idx: usize) -> Result<u64, LayoutError> {
        // Bounds checks, block first, then layer.
        if block_idx >= self.num_blocks() {
            return Err(LayoutError::InvalidBlockIndex(block_idx));
        }
        if layer_idx >= self.num_layers() {
            return Err(LayoutError::InvalidLayerIndex(layer_idx));
        }

        // Aligned base address: storage address plus the stored base offset.
        let base = self.storage.addr() + self.base_offset as u64;

        // Offsets are computed in `usize` and widened to `u64` before summing.
        let block_bytes = (block_idx * self.config.block_stride_in_bytes) as u64;
        let layer_bytes = (layer_idx * self.config.layer_stride_in_bytes) as u64;

        Ok(base + block_bytes + layer_bytes)
    }

    /// Returns the size in bytes of a single memory region, as stored in the config.
    fn memory_region_size(&self) -> usize {
        self.config.memory_region_size
    }
}
#[allow(missing_docs)] #[allow(missing_docs)]
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
...@@ -570,6 +632,7 @@ pub mod tests { ...@@ -570,6 +632,7 @@ pub mod tests {
const NUM_BLOCKS: usize = 7; const NUM_BLOCKS: usize = 7;
const NUM_LAYERS: usize = 5; const NUM_LAYERS: usize = 5;
const OUTER_DIM: usize = 2;
const PAGE_SIZE: usize = 4; const PAGE_SIZE: usize = 4;
const INNER_DIM: usize = 13; const INNER_DIM: usize = 13;
const DTYPE: DType = DType::FP32; // Example dtype const DTYPE: DType = DType::FP32; // Example dtype
...@@ -592,6 +655,7 @@ pub mod tests { ...@@ -592,6 +655,7 @@ pub mod tests {
let config = LayoutConfig { let config = LayoutConfig {
num_blocks: NUM_BLOCKS, num_blocks: NUM_BLOCKS,
num_layers: NUM_LAYERS, num_layers: NUM_LAYERS,
outer_dim: OUTER_DIM,
page_size: PAGE_SIZE, page_size: PAGE_SIZE,
inner_dim: INNER_DIM, inner_dim: INNER_DIM,
alignment: alignment.unwrap_or(1), alignment: alignment.unwrap_or(1),
...@@ -606,6 +670,7 @@ pub mod tests { ...@@ -606,6 +670,7 @@ pub mod tests {
let config = LayoutConfig::builder() let config = LayoutConfig::builder()
.num_blocks(NUM_BLOCKS) .num_blocks(NUM_BLOCKS)
.num_layers(NUM_LAYERS) .num_layers(NUM_LAYERS)
.outer_dim(OUTER_DIM)
.page_size(PAGE_SIZE) .page_size(PAGE_SIZE)
.inner_dim(INNER_DIM) .inner_dim(INNER_DIM)
.alignment(3) .alignment(3)
...@@ -632,6 +697,7 @@ pub mod tests { ...@@ -632,6 +697,7 @@ pub mod tests {
let config = LayoutConfig { let config = LayoutConfig {
num_blocks: NUM_BLOCKS, num_blocks: NUM_BLOCKS,
num_layers: NUM_LAYERS, num_layers: NUM_LAYERS,
outer_dim: OUTER_DIM,
page_size: PAGE_SIZE, page_size: PAGE_SIZE,
inner_dim: INNER_DIM, inner_dim: INNER_DIM,
alignment: 1, alignment: 1,
...@@ -656,17 +722,11 @@ pub mod tests { ...@@ -656,17 +722,11 @@ pub mod tests {
assert_eq!(layout.num_blocks(), NUM_BLOCKS); assert_eq!(layout.num_blocks(), NUM_BLOCKS);
assert_eq!(layout.num_layers(), NUM_LAYERS); assert_eq!(layout.num_layers(), NUM_LAYERS);
assert_eq!(layout.outer_dim(), OUTER_DIM);
assert_eq!(layout.page_size(), PAGE_SIZE); assert_eq!(layout.page_size(), PAGE_SIZE);
assert_eq!(layout.inner_dim(), INNER_DIM); assert_eq!(layout.inner_dim(), INNER_DIM);
} }
// The layout must report a region size of PAGE_SIZE * INNER_DIM elements,
// each `DTYPE.size_in_bytes()` bytes wide.
#[test]
fn test_fc_memory_region_size() {
    let fc_layout = setup_layout(None).expect("Layout setup failed");
    let bytes_per_element = DTYPE.size_in_bytes();
    assert_eq!(
        fc_layout.memory_region_size(),
        PAGE_SIZE * INNER_DIM * bytes_per_element
    );
}
#[test] #[test]
fn test_fc_offset_calculation() { fn test_fc_offset_calculation() {
let layout = setup_layout(None).expect("Layout setup failed"); let layout = setup_layout(None).expect("Layout setup failed");
...@@ -680,7 +740,7 @@ pub mod tests { ...@@ -680,7 +740,7 @@ pub mod tests {
let expected_offset_0_0 = let expected_offset_0_0 =
calculate_expected_offset(base_addr, 0, 0, block_stride, layer_stride); calculate_expected_offset(base_addr, 0, 0, block_stride, layer_stride);
assert_eq!( assert_eq!(
layout.memory_region_addr(0, 0).unwrap(), layout.memory_region(0, 0, 0).unwrap().addr as u64,
expected_offset_0_0 expected_offset_0_0
); );
...@@ -689,7 +749,7 @@ pub mod tests { ...@@ -689,7 +749,7 @@ pub mod tests {
let expected_offset_0_last = let expected_offset_0_last =
calculate_expected_offset(base_addr, 0, last_layer_idx, block_stride, layer_stride); calculate_expected_offset(base_addr, 0, last_layer_idx, block_stride, layer_stride);
assert_eq!( assert_eq!(
layout.memory_region_addr(0, last_layer_idx).unwrap(), layout.memory_region(0, last_layer_idx, 0).unwrap().addr as u64,
expected_offset_0_last expected_offset_0_last
); );
...@@ -698,7 +758,7 @@ pub mod tests { ...@@ -698,7 +758,7 @@ pub mod tests {
let expected_offset_last_0 = let expected_offset_last_0 =
calculate_expected_offset(base_addr, last_block_idx, 0, block_stride, layer_stride); calculate_expected_offset(base_addr, last_block_idx, 0, block_stride, layer_stride);
assert_eq!( assert_eq!(
layout.memory_region_addr(last_block_idx, 0).unwrap(), layout.memory_region(last_block_idx, 0, 0).unwrap().addr as u64,
expected_offset_last_0 expected_offset_last_0
); );
...@@ -712,8 +772,9 @@ pub mod tests { ...@@ -712,8 +772,9 @@ pub mod tests {
); );
assert_eq!( assert_eq!(
layout layout
.memory_region_addr(last_block_idx, last_layer_idx) .memory_region(last_block_idx, last_layer_idx, 0)
.unwrap(), .unwrap()
.addr as u64,
expected_offset_last_last expected_offset_last_last
); );
...@@ -729,8 +790,9 @@ pub mod tests { ...@@ -729,8 +790,9 @@ pub mod tests {
); );
assert_eq!( assert_eq!(
layout layout
.memory_region_addr(mid_block_idx, mid_layer_idx) .memory_region(mid_block_idx, mid_layer_idx, 0)
.unwrap(), .unwrap()
.addr as u64,
expected_offset_mid_mid expected_offset_mid_mid
); );
} }
...@@ -738,7 +800,7 @@ pub mod tests { ...@@ -738,7 +800,7 @@ pub mod tests {
#[test] #[test]
fn test_fc_invalid_block_index() { fn test_fc_invalid_block_index() {
let layout = setup_layout(None).expect("Layout setup failed"); let layout = setup_layout(None).expect("Layout setup failed");
let result = layout.memory_region_addr(NUM_BLOCKS, 0); // Index == num_blocks (out of bounds) let result = layout.memory_region(NUM_BLOCKS, 0, 0); // Index == num_blocks (out of bounds)
assert!(result.is_err()); assert!(result.is_err());
assert!(matches!( assert!(matches!(
result.err().unwrap(), result.err().unwrap(),
...@@ -749,7 +811,7 @@ pub mod tests { ...@@ -749,7 +811,7 @@ pub mod tests {
#[test] #[test]
fn test_fc_invalid_layer_index() { fn test_fc_invalid_layer_index() {
let layout = setup_layout(None).expect("Layout setup failed"); let layout = setup_layout(None).expect("Layout setup failed");
let result = layout.memory_region_addr(0, NUM_LAYERS); // Index == num_layers (out of bounds) let result = layout.memory_region(0, NUM_LAYERS, 0); // Index == num_layers (out of bounds)
assert!(result.is_err()); assert!(result.is_err());
assert!(matches!( assert!(matches!(
result.err().unwrap(), result.err().unwrap(),
...@@ -757,12 +819,24 @@ pub mod tests { ...@@ -757,12 +819,24 @@ pub mod tests {
)); ));
} }
// An outer index equal to OUTER_DIM is the first out-of-bounds value and must
// be rejected with `LayoutError::InvalidOuterIndex` carrying that index.
#[test]
fn test_fc_invalid_outer_index() {
    let layout = setup_layout(None).expect("Layout setup failed");
    let lookup = layout.memory_region(0, 0, OUTER_DIM);
    assert!(lookup.is_err());
    assert!(matches!(
        lookup,
        Err(LayoutError::InvalidOuterIndex(OUTER_DIM))
    ));
}
#[test] #[test]
fn test_fc_allocation_system() { fn test_fc_allocation_system() {
init_logging(); init_logging();
let config = LayoutConfig { let config = LayoutConfig {
num_blocks: NUM_BLOCKS, num_blocks: NUM_BLOCKS,
num_layers: NUM_LAYERS, num_layers: NUM_LAYERS,
outer_dim: OUTER_DIM,
page_size: PAGE_SIZE, page_size: PAGE_SIZE,
inner_dim: INNER_DIM, inner_dim: INNER_DIM,
alignment: 1, alignment: 1,
...@@ -788,7 +862,7 @@ pub mod tests { ...@@ -788,7 +862,7 @@ pub mod tests {
assert_eq!( assert_eq!(
layout.storage.size(), layout.storage.size(),
NUM_BLOCKS * NUM_LAYERS * PAGE_SIZE * INNER_DIM * DTYPE.size_in_bytes() NUM_BLOCKS * NUM_LAYERS * OUTER_DIM * PAGE_SIZE * INNER_DIM * DTYPE.size_in_bytes()
); );
} }
...@@ -800,6 +874,7 @@ pub mod tests { ...@@ -800,6 +874,7 @@ pub mod tests {
let config = LayoutConfig { let config = LayoutConfig {
num_blocks: NUM_BLOCKS, num_blocks: NUM_BLOCKS,
num_layers: NUM_LAYERS, num_layers: NUM_LAYERS,
outer_dim: OUTER_DIM,
page_size: PAGE_SIZE, page_size: PAGE_SIZE,
inner_dim: INNER_DIM, inner_dim: INNER_DIM,
alignment: ALIGNMENT, alignment: ALIGNMENT,
...@@ -810,11 +885,11 @@ pub mod tests { ...@@ -810,11 +885,11 @@ pub mod tests {
let memory_region_size = PAGE_SIZE * INNER_DIM * DTYPE.size_in_bytes(); let memory_region_size = PAGE_SIZE * INNER_DIM * DTYPE.size_in_bytes();
assert_eq!(memory_region_size, 208); assert_eq!(memory_region_size, 208);
let natural_block_stride = NUM_LAYERS * memory_region_size; let natural_block_stride = OUTER_DIM * NUM_LAYERS * memory_region_size;
assert_eq!(natural_block_stride, 1040); assert_eq!(natural_block_stride, 2080);
let aligned_block_stride = align_up(natural_block_stride, ALIGNMENT); let aligned_block_stride = align_up(natural_block_stride, ALIGNMENT);
assert_eq!(aligned_block_stride, 1280); assert_eq!(aligned_block_stride, 2304);
// Calculate the expected *allocated* size (data + initial padding) // Calculate the expected *allocated* size (data + initial padding)
let fc_config = FullyContiguousConfig::new(config.clone()).unwrap(); let fc_config = FullyContiguousConfig::new(config.clone()).unwrap();
...@@ -844,40 +919,40 @@ pub mod tests { ...@@ -844,40 +919,40 @@ pub mod tests {
// Check alignment of block starts // Check alignment of block starts
let addr_block_0 = layout let addr_block_0 = layout
.memory_region_addr(0, 0) .memory_region(0, 0, 0)
.expect("Failed to get addr block 0"); .expect("Failed to get addr block 0");
let addr_block_1 = layout let addr_block_1 = layout
.memory_region_addr(1, 0) .memory_region(1, 0, 0)
.expect("Failed to get addr block 1"); .expect("Failed to get addr block 1");
let addr_block_2 = layout let addr_block_2 = layout
.memory_region_addr(2, 0) .memory_region(2, 0, 0)
.expect("Failed to get addr block 2"); .expect("Failed to get addr block 2");
// All blocks should now be aligned due to base_offset adjustment // All blocks should now be aligned due to base_offset adjustment
assert_eq!( assert_eq!(
addr_block_0 % ALIGNMENT as u64, addr_block_0.addr as u64 % ALIGNMENT as u64,
0, 0,
"Block 0 start address is not aligned" "Block 0 start address is not aligned"
); );
assert_eq!( assert_eq!(
addr_block_1 % ALIGNMENT as u64, addr_block_1.addr as u64 % ALIGNMENT as u64,
0, 0,
"Block 1 start address is not aligned" "Block 1 start address is not aligned"
); );
assert_eq!( assert_eq!(
addr_block_2 % ALIGNMENT as u64, addr_block_2.addr as u64 % ALIGNMENT as u64,
0, 0,
"Block 2 start address is not aligned" "Block 2 start address is not aligned"
); );
// Verify the difference matches the aligned stride // Verify the difference matches the aligned stride
assert_eq!( assert_eq!(
addr_block_1 - addr_block_0, addr_block_1.addr as u64 - addr_block_0.addr as u64,
aligned_block_stride as u64, aligned_block_stride as u64,
"Stride between block 0 and 1 mismatch" "Stride between block 0 and 1 mismatch"
); );
assert_eq!( assert_eq!(
addr_block_2 - addr_block_1, addr_block_2.addr as u64 - addr_block_1.addr as u64,
aligned_block_stride as u64, aligned_block_stride as u64,
"Stride between block 1 and 2 mismatch" "Stride between block 1 and 2 mismatch"
); );
......
...@@ -332,6 +332,7 @@ mod tests { ...@@ -332,6 +332,7 @@ mod tests {
let config = LayoutConfig::builder() let config = LayoutConfig::builder()
.num_blocks(10) .num_blocks(10)
.num_layers(2) .num_layers(2)
.outer_dim(2)
.page_size(4) .page_size(4)
.inner_dim(13) .inner_dim(13)
.build() .build()
......
...@@ -506,6 +506,7 @@ mod tests { ...@@ -506,6 +506,7 @@ mod tests {
let mut config = LayoutConfig { let mut config = LayoutConfig {
num_blocks: device_blocks, num_blocks: device_blocks,
num_layers: 8, num_layers: 8,
outer_dim: 1,
page_size: BLOCK_SIZE, page_size: BLOCK_SIZE,
inner_dim: 1024, inner_dim: 1024,
alignment: 1, alignment: 1,
......
...@@ -582,6 +582,7 @@ pub(crate) mod tests { ...@@ -582,6 +582,7 @@ pub(crate) mod tests {
let config = LayoutConfigBuilder::default() let config = LayoutConfigBuilder::default()
.num_blocks(num_blocks) .num_blocks(num_blocks)
.num_layers(61) .num_layers(61)
.outer_dim(1)
.page_size(16) .page_size(16)
.inner_dim(576) .inner_dim(576)
.build() .build()
......
...@@ -115,6 +115,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -115,6 +115,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
layout_builder layout_builder
.num_layers(model.num_layers) .num_layers(model.num_layers)
.outer_dim(model.outer_dim)
.page_size(model.page_size) .page_size(model.page_size)
.inner_dim(model.inner_dim) .inner_dim(model.inner_dim)
.dtype(model.dtype); .dtype(model.dtype);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment