Unverified Commit d2e3b66e authored by Olga Andreeva's avatar Olga Andreeva Committed by GitHub
Browse files

feat: Transition to FullyContiguous Host and Disk layouts (#3090)


Signed-off-by: default avatarOlga Andreeva <oandreeva@nvidia.com>
Signed-off-by: default avatarOlga Andreeva <124622579+oandreeva-nv@users.noreply.github.com>
Co-authored-by: default avataroandreeva-nv <oandreeva-nv@nvidia.com>
parent a5e1d45e
......@@ -343,4 +343,4 @@ RUN uv pip install maturin[patchelf] && \
uv pip install --no-deps -e .
ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
CMD []
\ No newline at end of file
CMD []
......@@ -9,4 +9,4 @@ mod worker;
pub use leader::KvbmLeader;
pub use utils::get_barrier_id_prefix;
pub use worker::{KvbmWorker, VllmTensor};
pub use worker::{KvbmWorker, PyLayoutType, VllmTensor};
......@@ -10,8 +10,44 @@ use llm_rs::block_manager::distributed::{
BlockTransferHandler as RustBlockTransferHandler, KvbmWorker as KvbmWorkerImpl,
KvbmWorkerConfig,
};
use llm_rs::block_manager::layout::LayoutType;
use llm_rs::block_manager::storage::torch::{TorchDevice, TorchTensor};
/// A wrapper around a layout type.
/// This is used to convert between the Python and Rust layout types.
#[pyclass(eq, eq_int)]
#[derive(Clone, PartialEq, Eq)]
pub enum PyLayoutType {
FullyContiguous,
LayerSeparate,
}
#[pymethods]
impl PyLayoutType {
/// String representation of the layout type
fn __str__(&self) -> &'static str {
match self {
PyLayoutType::FullyContiguous => "FullyContiguous",
PyLayoutType::LayerSeparate => "LayerSeparate",
}
}
/// Representation for debugging
fn __repr__(&self) -> String {
format!("PyLayoutType.{}", self.__str__())
}
}
impl From<PyLayoutType> for LayoutType {
fn from(py_layout: PyLayoutType) -> Self {
match py_layout {
PyLayoutType::FullyContiguous => LayoutType::FullyContiguous,
// Layout (outer_contiguous vs block_contiguous) is auto-detected from tensor shapes
PyLayoutType::LayerSeparate => LayoutType::layer_separate_auto_default(),
}
}
}
/// A wrapper around a Torch tensor.
/// We hold onto the py object to ensure it doesn't get GCed.
#[derive(Clone, Debug)]
......@@ -107,7 +143,7 @@ impl KvbmWorker {
#[pymethods]
impl KvbmWorker {
#[new]
#[pyo3(signature = (num_device_blocks, page_size, tensors, device_id=0, dtype_width_bytes=2, drt=None, layout_blocking=false))]
#[pyo3(signature = (num_device_blocks, page_size, tensors, device_id=0, dtype_width_bytes=2, drt=None, layout_blocking=false, device_layout_type=None, host_layout_type=None, disk_layout_type=None))]
fn new(
num_device_blocks: usize,
page_size: usize,
......@@ -116,6 +152,9 @@ impl KvbmWorker {
dtype_width_bytes: usize,
drt: Option<DistributedRuntime>,
layout_blocking: bool,
device_layout_type: Option<PyLayoutType>,
host_layout_type: Option<PyLayoutType>,
disk_layout_type: Option<PyLayoutType>,
) -> PyResult<Self> {
let py_drt = drt.ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err("DistributedRuntime (drt) must be provided")
......@@ -142,6 +181,21 @@ impl KvbmWorker {
.device_id(device_id)
.dtype_width_bytes(dtype_width_bytes)
.barrier_id_prefix(barrier_id_prefix)
.device_layout_type(
device_layout_type
.map(|py_layout| py_layout.into())
.unwrap_or(LayoutType::FullyContiguous),
)
.host_layout_type(
host_layout_type
.map(|py_layout| py_layout.into())
.unwrap_or(LayoutType::FullyContiguous),
)
.disk_layout_type(
disk_layout_type
.map(|py_layout| py_layout.into())
.unwrap_or(LayoutType::FullyContiguous),
)
.build()
.map_err(to_pyerr)?;
......
......@@ -19,6 +19,7 @@ use crate::{
use anyhow;
use dynamo_llm::block_manager::distributed::{KvbmWorker, KvbmWorkerConfig};
use dynamo_llm::block_manager::layout::LayoutType;
use dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics;
use dynamo_llm::block_manager::storage::torch::TorchTensor;
use dynamo_runtime::DistributedRuntime;
......@@ -144,7 +145,9 @@ impl Worker for KvConnectorWorker {
.tensors(kv_cache_tensors)
.device_id(device_id)
.dtype_width_bytes(dtype_width_bytes)
.is_fully_contiguous_layout(true)
.device_layout_type(LayoutType::FullyContiguous)
.host_layout_type(LayoutType::FullyContiguous)
.disk_layout_type(LayoutType::FullyContiguous)
.barrier_id_prefix(get_barrier_id_prefix())
.scheduler_client(Some(self.transfer_client.clone()))
.build()?;
......
......@@ -18,8 +18,10 @@ use crate::{
};
use dynamo_runtime::metrics::prometheus_names::kvbm_connector;
use crate::llm::block_manager::distributed::PyLayoutType;
use anyhow;
use dynamo_llm::block_manager::distributed::{KvbmWorker, KvbmWorkerConfig};
use dynamo_llm::block_manager::layout::LayoutType;
use dynamo_llm::block_manager::storage::torch::TorchTensor;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
......@@ -33,6 +35,9 @@ pub trait Worker: Send + Sync {
dtype_width_bytes: usize,
kv_caches: Vec<(String, Arc<VllmTensor>)>,
raw_event_handles: Vec<u64>,
device_layout_type: Option<LayoutType>,
host_layout_type: Option<LayoutType>,
disk_layout_type: Option<LayoutType>,
) -> anyhow::Result<()>;
fn bind_connector_metadata(&mut self, metadata: Vec<u8>) -> anyhow::Result<()>;
......@@ -133,6 +138,9 @@ impl Worker for KvConnectorWorker {
dtype_width_bytes: usize,
kv_caches: Vec<(String, Arc<VllmTensor>)>,
raw_event_handles: Vec<u64>,
device_layout_type: Option<LayoutType>,
host_layout_type: Option<LayoutType>,
disk_layout_type: Option<LayoutType>,
) -> anyhow::Result<()> {
if self.kvbm_worker.get().is_some() {
tracing::warn!("kvbm worker already registered");
......@@ -147,9 +155,16 @@ impl Worker for KvConnectorWorker {
// Process kv_caches in layer execution order (already sorted by layer index)
let mut vllm_tensors = Vec::new();
let mut first_tensor_shape: Option<Vec<usize>> = None;
for (layer_name, vllm_tensor) in kv_caches {
tracing::trace!("Registering KV cache layer: {layer_name}, tensor: {vllm_tensor:?}");
// Capture the shape of the first tensor for layout detection
if first_tensor_shape.is_none() {
first_tensor_shape = Some(vllm_tensor.shape());
}
// Store for later lookup by name
self.kv_cache_layers.push((layer_name, vllm_tensor.clone()));
......@@ -159,6 +174,35 @@ impl Worker for KvConnectorWorker {
self.layer_events = raw_event_handles;
// Auto-detect device layout type if not explicitly provided
let detected_device_layout_type = match device_layout_type {
Some(layout) => layout,
None => {
if let Some(ref shape) = first_tensor_shape {
match LayoutType::layer_separate_auto(shape, num_device_blocks) {
Ok(detected) => {
tracing::info!(
"Auto-detected device layout from tensor shape: {:?}",
detected
);
detected
}
Err(e) => {
tracing::warn!(
"Failed to auto-detect layout from shape {:?}: {}. Using default.",
shape,
e
);
LayoutType::layer_separate_auto_default()
}
}
} else {
tracing::warn!("No tensors available for layout detection. Using default.");
LayoutType::layer_separate_auto_default()
}
}
};
let config = KvbmWorkerConfig::builder()
.drt(self.drt.clone())
.num_device_blocks(num_device_blocks)
......@@ -168,6 +212,9 @@ impl Worker for KvConnectorWorker {
.dtype_width_bytes(dtype_width_bytes)
.barrier_id_prefix(get_barrier_id_prefix())
.scheduler_client(Some(self.transfer_client.clone()))
.device_layout_type(detected_device_layout_type)
.host_layout_type(host_layout_type.unwrap_or(LayoutType::FullyContiguous))
.disk_layout_type(disk_layout_type.unwrap_or(LayoutType::FullyContiguous))
.build()?;
let worker = self.drt.runtime().primary().block_on(async move {
......@@ -416,6 +463,7 @@ impl PyKvConnectorWorker {
Ok(Self { connector_worker })
}
#[pyo3(signature = (num_device_blocks, page_size, device_id, dtype_width_bytes, kv_caches, raw_event_handles, device_layout_type=None, host_layout_type=None, disk_layout_type=None))]
pub fn register_kv_caches(
&mut self,
num_device_blocks: usize,
......@@ -424,6 +472,9 @@ impl PyKvConnectorWorker {
dtype_width_bytes: usize,
kv_caches: Vec<(String, Py<PyAny>)>,
raw_event_handles: Vec<u64>,
device_layout_type: Option<PyLayoutType>,
host_layout_type: Option<PyLayoutType>,
disk_layout_type: Option<PyLayoutType>,
) -> PyResult<()> {
// Convert Python tensors to Rust VllmTensor objects
let mut rust_kv_caches = Vec::new();
......@@ -440,6 +491,9 @@ impl PyKvConnectorWorker {
dtype_width_bytes,
rust_kv_caches,
raw_event_handles,
device_layout_type.map(|py_layout| py_layout.into()),
host_layout_type.map(|py_layout| py_layout.into()),
disk_layout_type.map(|py_layout| py_layout.into()),
)
.map_err(to_pyerr)
}
......
......@@ -683,4 +683,455 @@ mod tests {
assert!(slice.iter().all(|&x| x == 42));
}
}
// ============================================================================
// CUDA TRANSFER TESTS FOR LAYOUT COMPATIBILITY
// ============================================================================
mod layout_transfer_tests {
use super::*;
use crate::block_manager::layout::{
FullyContiguous, GenericBlockLayout, LayerSeparate, LayoutConfig,
};
const TEST_NUM_BLOCKS: usize = 4;
const TEST_NUM_LAYERS: usize = 3;
const TEST_OUTER_DIM: usize = 2;
const TEST_PAGE_SIZE: usize = 8;
const TEST_INNER_DIM: usize = 16;
const TEST_DTYPE_WIDTH_BYTES: usize = 2;
fn create_test_config() -> LayoutConfig {
LayoutConfig {
num_blocks: TEST_NUM_BLOCKS,
num_layers: TEST_NUM_LAYERS,
outer_dim: TEST_OUTER_DIM,
page_size: TEST_PAGE_SIZE,
inner_dim: TEST_INNER_DIM,
alignment: 256, // GPU-friendly alignment
dtype_width_bytes: TEST_DTYPE_WIDTH_BYTES,
}
}
/// Test H2D transfers between FullyContiguous host and LayerSeparate device layouts
#[test]
fn test_h2d_fc_host_to_ls_device() {
let device_allocator = DeviceAllocator::default();
let pinned_allocator = PinnedAllocator::default();
let ctx = device_allocator.ctx().clone();
let stream = ctx.new_stream().unwrap();
let config = create_test_config();
// Create FullyContiguous host layout
let host_layout = FullyContiguous::allocate(config.clone(), &pinned_allocator).unwrap();
// Create LayerSeparate device layout
let device_layout = LayerSeparate::allocate(config, &device_allocator, true).unwrap();
// Test data transfer for each memory region
for block_idx in 0..TEST_NUM_BLOCKS {
for layer_idx in 0..TEST_NUM_LAYERS {
for outer_idx in 0..TEST_OUTER_DIM {
let host_region = host_layout
.memory_region(block_idx, layer_idx, outer_idx)
.unwrap();
let device_region = device_layout
.memory_region(block_idx, layer_idx, outer_idx)
.unwrap();
// Verify regions have same size
assert_eq!(
host_region.size(),
device_region.size(),
"Region size mismatch at ({}, {}, {})",
block_idx,
layer_idx,
outer_idx
);
// Create test pattern
let pattern =
((block_idx as u8) << 4) | ((layer_idx as u8) << 2) | (outer_idx as u8);
// Fill host memory with pattern
unsafe {
let host_slice = std::slice::from_raw_parts_mut(
host_region.addr() as *mut u8,
host_region.size(),
);
host_slice.fill(pattern);
}
// Transfer H2D
unsafe {
cuda_memcpy_h2d(
host_region.addr() as *const u8,
device_region.addr() as *mut u8,
host_region.size(),
stream.as_ref(),
)
.unwrap();
}
}
}
}
stream.synchronize().unwrap();
// Verify transfers by copying back and checking patterns
for block_idx in 0..TEST_NUM_BLOCKS {
for layer_idx in 0..TEST_NUM_LAYERS {
for outer_idx in 0..TEST_OUTER_DIM {
let host_region = host_layout
.memory_region(block_idx, layer_idx, outer_idx)
.unwrap();
let device_region = device_layout
.memory_region(block_idx, layer_idx, outer_idx)
.unwrap();
let expected_pattern =
((block_idx as u8) << 4) | ((layer_idx as u8) << 2) | (outer_idx as u8);
// Create temporary verification buffer
let mut verify_buffer =
pinned_allocator.allocate(host_region.size()).unwrap();
// Copy back from device
unsafe {
cuda_memcpy_d2h(
device_region.addr() as *const u8,
verify_buffer.as_mut_ptr(),
host_region.size(),
stream.as_ref(),
)
.unwrap();
}
stream.synchronize().unwrap();
// Verify pattern
unsafe {
let verify_slice = std::slice::from_raw_parts(
verify_buffer.as_ptr(),
host_region.size(),
);
assert!(
verify_slice.iter().all(|&x| x == expected_pattern),
"Pattern mismatch at ({}, {}, {}) - expected {}, got {:?}",
block_idx,
layer_idx,
outer_idx,
expected_pattern,
&verify_slice[0..std::cmp::min(8, verify_slice.len())]
);
}
}
}
}
}
/// Test D2H transfers from LayerSeparate device to FullyContiguous host
#[test]
fn test_d2h_ls_device_to_fc_host() {
let device_allocator = DeviceAllocator::default();
let pinned_allocator = PinnedAllocator::default();
let ctx = device_allocator.ctx().clone();
let stream = ctx.new_stream().unwrap();
let config = create_test_config();
// Create LayerSeparate device layout (block contiguous)
let device_layout =
LayerSeparate::allocate(config.clone(), &device_allocator, false).unwrap();
// Create FullyContiguous host layout
let host_layout = FullyContiguous::allocate(config, &pinned_allocator).unwrap();
// Initialize device memory with patterns using a temporary host buffer
for block_idx in 0..TEST_NUM_BLOCKS {
for layer_idx in 0..TEST_NUM_LAYERS {
for outer_idx in 0..TEST_OUTER_DIM {
let device_region = device_layout
.memory_region(block_idx, layer_idx, outer_idx)
.unwrap();
let pattern = ((block_idx as u8) << 4)
| ((layer_idx as u8) << 2)
| (outer_idx as u8)
| 0x80;
// Create temp buffer with pattern
let mut temp_buffer =
pinned_allocator.allocate(device_region.size()).unwrap();
unsafe {
let temp_slice = std::slice::from_raw_parts_mut(
temp_buffer.as_mut_ptr(),
device_region.size(),
);
temp_slice.fill(pattern);
}
// Copy pattern to device
unsafe {
cuda_memcpy_h2d(
temp_buffer.as_ptr(),
device_region.addr() as *mut u8,
device_region.size(),
stream.as_ref(),
)
.unwrap();
}
}
}
}
stream.synchronize().unwrap();
// Clear host layout
for block_idx in 0..TEST_NUM_BLOCKS {
for layer_idx in 0..TEST_NUM_LAYERS {
for outer_idx in 0..TEST_OUTER_DIM {
let host_region = host_layout
.memory_region(block_idx, layer_idx, outer_idx)
.unwrap();
unsafe {
let host_slice = std::slice::from_raw_parts_mut(
host_region.addr() as *mut u8,
host_region.size(),
);
host_slice.fill(0);
}
}
}
}
// Transfer D2H
for block_idx in 0..TEST_NUM_BLOCKS {
for layer_idx in 0..TEST_NUM_LAYERS {
for outer_idx in 0..TEST_OUTER_DIM {
let device_region = device_layout
.memory_region(block_idx, layer_idx, outer_idx)
.unwrap();
let host_region = host_layout
.memory_region(block_idx, layer_idx, outer_idx)
.unwrap();
unsafe {
cuda_memcpy_d2h(
device_region.addr() as *const u8,
host_region.addr() as *mut u8,
device_region.size(),
stream.as_ref(),
)
.unwrap();
}
}
}
}
stream.synchronize().unwrap();
// Verify patterns in host layout
for block_idx in 0..TEST_NUM_BLOCKS {
for layer_idx in 0..TEST_NUM_LAYERS {
for outer_idx in 0..TEST_OUTER_DIM {
let host_region = host_layout
.memory_region(block_idx, layer_idx, outer_idx)
.unwrap();
let expected_pattern = ((block_idx as u8) << 4)
| ((layer_idx as u8) << 2)
| (outer_idx as u8)
| 0x80;
unsafe {
let host_slice = std::slice::from_raw_parts(
host_region.addr() as *const u8,
host_region.size(),
);
assert!(
host_slice.iter().all(|&x| x == expected_pattern),
"Pattern mismatch at ({}, {}, {}) - expected {}, got {:?}",
block_idx,
layer_idx,
outer_idx,
expected_pattern,
&host_slice[0..std::cmp::min(8, host_slice.len())]
);
}
}
}
}
}
/// Test bidirectional transfers with layout compatibility verification
#[test]
fn test_bidirectional_layout_transfers() {
let device_allocator = DeviceAllocator::default();
let pinned_allocator = PinnedAllocator::default();
let ctx = device_allocator.ctx().clone();
let stream = ctx.new_stream().unwrap();
let config = create_test_config();
// Create both layout types
let host_fc = FullyContiguous::allocate(config.clone(), &pinned_allocator).unwrap();
let device_ls_outer =
LayerSeparate::allocate(config.clone(), &device_allocator, true).unwrap();
let device_ls_block =
LayerSeparate::allocate(config, &device_allocator, false).unwrap();
// Test round-trip: Host FC -> Device LS (outer) -> Device LS (block) -> Host FC
for block_idx in 0..TEST_NUM_BLOCKS {
for layer_idx in 0..TEST_NUM_LAYERS {
for outer_idx in 0..TEST_OUTER_DIM {
let original_pattern = ((block_idx as u8) << 4)
| ((layer_idx as u8) << 2)
| (outer_idx as u8)
| 0x40;
// Step 1: Initialize host FC with pattern
let host_region = host_fc
.memory_region(block_idx, layer_idx, outer_idx)
.unwrap();
unsafe {
let host_slice = std::slice::from_raw_parts_mut(
host_region.addr() as *mut u8,
host_region.size(),
);
host_slice.fill(original_pattern);
}
// Step 2: Transfer to device LS outer
let device_outer_region = device_ls_outer
.memory_region(block_idx, layer_idx, outer_idx)
.unwrap();
unsafe {
cuda_memcpy_h2d(
host_region.addr() as *const u8,
device_outer_region.addr() as *mut u8,
host_region.size(),
stream.as_ref(),
)
.unwrap();
}
// Step 3: Transfer between device layouts (D2D)
let device_block_region = device_ls_block
.memory_region(block_idx, layer_idx, outer_idx)
.unwrap();
unsafe {
cuda_memcpy_d2d(
device_outer_region.addr() as *const u8,
device_block_region.addr() as *mut u8,
device_outer_region.size(),
stream.as_ref(),
)
.unwrap();
}
stream.synchronize().unwrap();
// Step 4: Clear host and transfer back
unsafe {
let host_slice = std::slice::from_raw_parts_mut(
host_region.addr() as *mut u8,
host_region.size(),
);
host_slice.fill(0);
}
unsafe {
cuda_memcpy_d2h(
device_block_region.addr() as *const u8,
host_region.addr() as *mut u8,
device_block_region.size(),
stream.as_ref(),
)
.unwrap();
}
stream.synchronize().unwrap();
// Step 5: Verify pattern survived the round trip
unsafe {
let host_slice = std::slice::from_raw_parts(
host_region.addr() as *const u8,
host_region.size(),
);
assert!(
host_slice.iter().all(|&x| x == original_pattern),
"Round-trip pattern mismatch at ({}, {}, {}) - expected {}, got {:?}",
block_idx,
layer_idx,
outer_idx,
original_pattern,
&host_slice[0..std::cmp::min(8, host_slice.len())]
);
}
}
}
}
}
/// Test transfer performance and alignment impact
#[test]
fn test_layout_transfer_alignment_performance() {
let device_allocator = DeviceAllocator::default();
let pinned_allocator = PinnedAllocator::default();
let ctx = device_allocator.ctx().clone();
let stream = ctx.new_stream().unwrap();
// Test different alignments
for alignment in [1, 64, 256, 512] {
let config = LayoutConfig {
num_blocks: 2,
num_layers: 2,
outer_dim: 1,
page_size: 1024,
inner_dim: 256,
alignment,
dtype_width_bytes: 4,
};
let host_layout =
FullyContiguous::allocate(config.clone(), &pinned_allocator).unwrap();
let device_layout = FullyContiguous::allocate(config, &device_allocator).unwrap();
// Measure transfer time (basic timing)
let start = std::time::Instant::now();
for block_idx in 0..2 {
for layer_idx in 0..2 {
let host_region =
host_layout.memory_region(block_idx, layer_idx, 0).unwrap();
let device_region = device_layout
.memory_region(block_idx, layer_idx, 0)
.unwrap();
unsafe {
cuda_memcpy_h2d(
host_region.addr() as *const u8,
device_region.addr() as *mut u8,
host_region.size(),
stream.as_ref(),
)
.unwrap();
}
}
}
stream.synchronize().unwrap();
let duration = start.elapsed();
// Verify alignment was applied correctly
let region = host_layout.memory_region(0, 0, 0).unwrap();
if alignment > 1 {
assert_eq!(
region.addr() % alignment,
0,
"Memory not aligned to {} bytes",
alignment
);
}
println!("Transfer with alignment {} took {:?}", alignment, duration);
}
}
}
}
......@@ -106,8 +106,14 @@ pub struct KvbmWorkerConfig {
#[builder(default = "2")]
dtype_width_bytes: usize,
#[builder(default = false)]
is_fully_contiguous_layout: bool,
#[builder(default = "LayoutType::FullyContiguous")]
device_layout_type: LayoutType,
#[builder(default = "LayoutType::FullyContiguous")]
host_layout_type: LayoutType,
#[builder(default = "LayoutType::FullyContiguous")]
disk_layout_type: LayoutType,
#[builder(default = "String::from(\"kvbm\")")]
barrier_id_prefix: String,
......@@ -161,53 +167,51 @@ impl KvbmWorker {
)));
}
let (layout_type, num_layers, outer_dim, inner_dim) = if !config.is_fully_contiguous_layout
{
let (outer_contiguous, outer_dim) = if shape[0] >= config.num_device_blocks {
(false, shape[1])
} else if shape[1] >= config.num_device_blocks {
(true, shape[0])
} else {
return Err(anyhow::anyhow!(format!(
"Unsupported kv cache layout. Got shape: {:?}",
shape
)));
};
let num_layers = device_tensors.len();
let inner_dim = shape[2..].iter().product::<usize>() / config.page_size;
tracing::info!(
"Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}",
device_tensors.len(),
outer_dim,
config.page_size,
inner_dim
);
(
LayoutType::LayerSeparate { outer_contiguous },
num_layers,
outer_dim,
inner_dim,
)
} else {
let num_layers = shape[1];
let outer_dim = shape[2];
let inner_dim = shape[3..].iter().product::<usize>() / config.page_size;
tracing::info!(
"Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}",
num_layers,
outer_dim,
config.page_size,
inner_dim
);
(
LayoutType::FullyContiguous,
num_layers,
outer_dim,
inner_dim,
)
let (layout_type, num_layers, outer_dim, inner_dim) = match config.device_layout_type {
LayoutType::FullyContiguous => {
let num_layers = shape[1];
let outer_dim = shape[2];
let inner_dim = shape[3..].iter().product::<usize>() / config.page_size;
tracing::info!(
"Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}",
num_layers,
outer_dim,
config.page_size,
inner_dim
);
(
LayoutType::FullyContiguous,
num_layers,
outer_dim,
inner_dim,
)
}
LayoutType::LayerSeparate { outer_contiguous } => {
// Use the already-detected layout type from config (no re-detection needed)
let layout_type = config.device_layout_type;
// Extract outer_dim based on the provided outer_contiguous value
let outer_dim = if outer_contiguous {
shape[0] // Outer contiguous: [outer_dim, n_blocks, ...]
} else {
shape[1] // Block contiguous: [n_blocks, outer_dim, ...]
};
let num_layers = device_tensors.len();
let inner_dim = shape[2..].iter().product::<usize>() / config.page_size;
tracing::info!(
"Inferred layout: num_layers={}, outer_dim={}, outer_contiguous={}, page_size={}, inner_dim={}",
num_layers,
outer_dim,
outer_contiguous,
config.page_size,
inner_dim
);
(layout_type, num_layers, outer_dim, inner_dim)
}
};
let bytes_per_block =
......@@ -556,7 +560,7 @@ impl KvbmWorker {
device_layout: Box<dyn NixlLayout<StorageType = DeviceStorage>>,
mut layout_builder: LayoutConfigBuilder,
leader_data: KvbmLeaderData,
layout_type: LayoutType,
_layout_type: LayoutType,
config: KvbmWorkerConfig,
cancel_token: CancellationToken,
handler_tx: oneshot::Sender<BlockTransferHandler>,
......@@ -606,7 +610,7 @@ impl KvbmWorker {
let host_layout = layout_builder
.num_blocks(leader_data.num_host_blocks)
.build()?
.allocate_layout(layout_type, host_allocator)?;
.allocate_layout(config.host_layout_type, host_allocator)?;
Some(Self::make_layout::<_, BasicMetadata>(
host_layout,
......@@ -623,7 +627,7 @@ impl KvbmWorker {
let disk_layout = layout_builder
.num_blocks(leader_data.num_disk_blocks)
.build()?
.allocate_layout(layout_type, disk_allocator)?;
.allocate_layout(config.disk_layout_type, disk_allocator)?;
Some(Self::make_layout::<_, BasicMetadata>(
disk_layout,
......
This diff is collapsed.
......@@ -44,21 +44,63 @@ pub struct LayoutVerificationStats {
pub successful_verifications: usize,
}
/// A utility for verifying the consistency and correctness of memory layout implementations.
///
/// This verifier systematically checks all memory regions within a layout to ensure:
/// - Memory addresses are calculated correctly
/// - Memory region sizes match expected values
/// - Layout configuration is internally consistent
///
/// The verifier maintains statistics about verification results and can identify
/// critical mismatches that indicate layout implementation errors.
#[derive(Debug)]
#[allow(dead_code)]
pub struct WorkerLayoutVerifier {
stats: LayoutVerificationStats,
}
impl Default for WorkerLayoutVerifier {
fn default() -> Self {
Self::new()
}
}
#[allow(dead_code)]
impl WorkerLayoutVerifier {
// Constructor: Start with clean slate
/// Creates a new layout verifier with clean statistics.
///
/// The verifier starts with zero counts for all verification metrics
/// and is ready to verify layout consistency.
pub fn new() -> Self {
Self {
stats: LayoutVerificationStats::default(),
}
}
/// Verifies the consistency of all memory regions in a layout.
///
/// This is the main orchestrator method that systematically checks every memory region
/// in the layout to ensure consistency. It resets the internal statistics and then
/// iterates through all valid combinations of block, layer, and outer dimension indices.
///
/// # Arguments
///
/// * `layout` - The layout to verify
///
/// # Returns
///
/// A vector of verification results for each memory region, or an error if
/// verification fails for any region.
///
/// # Example
///
/// ```rust,ignore
/// let mut verifier = WorkerLayoutVerifier::new();
/// let results = verifier.verify_layout_consistency(&layout)?;
/// if verifier.has_critical_mismatches() {
/// // Handle verification failures
/// }
/// ```
pub fn verify_layout_consistency<L: GenericBlockLayout>(
&mut self,
layout: &L,
......@@ -85,6 +127,22 @@ impl WorkerLayoutVerifier {
Ok(results)
}
/// Verifies a specific memory region within a layout.
///
/// This method checks a single memory region identified by the provided indices
/// and compares the actual memory address and size against expected values.
///
/// # Arguments
///
/// * `layout` - The layout containing the memory region to verify
/// * `block_idx` - The block index (must be < layout.num_blocks())
/// * `layer_idx` - The layer index (must be < layout.num_layers())
/// * `outer_idx` - The outer dimension index (must be < layout.outer_dim())
///
/// # Returns
///
/// A verification result containing the comparison between expected and actual
/// values, or an error if the indices are invalid or layout access fails.
pub fn verify_memory_region<L: GenericBlockLayout>(
&mut self,
layout: &L,
......@@ -125,6 +183,15 @@ impl WorkerLayoutVerifier {
}
}
/// Checks if any critical mismatches were found during verification.
///
/// Critical mismatches are currently defined as size mismatches, which indicate
/// that the layout is calculating memory region sizes incorrectly. This is
/// considered more critical than address mismatches as it affects memory safety.
///
/// # Returns
///
/// `true` if any memory regions had size mismatches, `false` otherwise.
pub fn has_critical_mismatches(&self) -> bool {
self.stats.size_mismatches > 0
}
......@@ -144,7 +211,12 @@ pub fn validate_power_of_2(alignment: usize) -> Result<(), ValidationError> {
/// Helper to align a value up to the nearest multiple of alignment.
/// Alignment must be a power of 2.
#[inline(always)]
pub fn align_up(value: usize, alignment: usize) -> usize {
debug_assert!(
alignment.is_power_of_two(),
"Alignment must be a power of 2"
);
(value + alignment - 1) & !(alignment - 1)
}
......@@ -191,6 +263,7 @@ pub fn validate_storage<S: Storage, C: BlockLayoutConfig>(
Ok(base_offset)
}
/// Validate that the provided indices are within bounds for the given layout configuration
pub fn validate_indices<C: BlockLayoutConfig>(
config: &C,
block_idx: usize,
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment