Unverified Commit 976bb70a authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: add KVBM memory management enhancements (DIS-1311) (#5532)

parent 57bdfea9
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Tensor abstraction built on top of MemoryDescriptor.
//!
//! A tensor is memory with shape, stride, and element size metadata.
//! The underlying memory could be externally owned, self-owned, or a view.
use super::nixl::{self, NixlDescriptor};
use super::{MemoryDescriptor, StorageKind};
use std::any::Any;
use std::sync::Arc;
/// A tensor is memory with shape, stride, and element size metadata.
///
/// This trait extends [`MemoryDescriptor`] with tensor-specific metadata.
/// The underlying memory could be externally owned, self-owned, or a view.
///
/// # Shape and Stride
///
/// - `shape()` returns the number of elements in each dimension
/// - `stride()` returns the number of elements to skip when incrementing each dimension
/// - `element_size()` returns the number of bytes per element
///
/// Strides are counted in elements, not bytes; multiply by `element_size()`
/// to turn an element offset into a byte offset.
///
/// For a contiguous tensor with shape `[2, 3, 4]`:
/// - stride would be `[12, 4, 1]` (row-major/C order)
/// - total elements = 2 * 3 * 4 = 24
/// - total bytes = 24 * element_size()
///
/// NOTE(review): nothing in this trait enforces that `shape()` and `stride()`
/// return slices of the same length; presumably implementors are expected to
/// keep them equal — confirm.
pub trait TensorDescriptor: MemoryDescriptor {
/// Shape of the tensor (number of elements per dimension).
///
/// The slice length is the rank of the tensor.
fn shape(&self) -> &[usize];
/// Stride of the tensor (elements to skip per dimension).
///
/// `stride[i]` indicates how many elements to skip when incrementing dimension `i`.
fn stride(&self) -> &[usize];
/// Number of bytes per element.
fn element_size(&self) -> usize;
}
// =============================================================================
// Helper methods for TensorDescriptor
// =============================================================================
/// Extension trait providing derived helpers for tensor descriptors.
///
/// Every method has a default body computed purely from
/// [`TensorDescriptor::shape`], [`TensorDescriptor::stride`] and the storage
/// kind, and the blanket impl below makes the helpers available on every
/// `TensorDescriptor`, including unsized trait objects.
pub trait TensorDescriptorExt: TensorDescriptor {
    /// Total number of elements in the tensor (product of shape).
    ///
    /// A rank-0 tensor (empty shape) yields `1`, the empty product.
    fn numel(&self) -> usize {
        self.shape().iter().product()
    }

    /// Number of dimensions (rank).
    fn ndim(&self) -> usize {
        self.shape().len()
    }

    /// Check if tensor is contiguous in memory (row-major/C order).
    ///
    /// A tensor is contiguous if the last dimension has stride 1 and each
    /// preceding dimension has a stride equal to the product of the sizes of
    /// all following dimensions.
    ///
    /// Returns `false` (instead of panicking or ignoring entries) when
    /// `shape()` and `stride()` have different lengths, because such a
    /// descriptor cannot describe a row-major layout. A rank-0 tensor with an
    /// empty stride is contiguous.
    fn is_contiguous(&self) -> bool {
        let shape = self.shape();
        let stride = self.stride();

        // Guard against malformed descriptors: indexing `stride` by the shape's
        // indices would otherwise go out of bounds or skip trailing entries.
        if shape.len() != stride.len() {
            return false;
        }

        // Walk from the innermost dimension outwards, tracking the stride a
        // dense layout would need. `None` means the running product no longer
        // fits in `usize`, so no remaining stride can match it.
        let mut expected = Some(1usize);
        for (&dim, &step) in shape.iter().rev().zip(stride.iter().rev()) {
            if expected != Some(step) {
                return false;
            }
            expected = step.checked_mul(dim);
        }
        true
    }

    /// Compute the contiguous stride for the current shape.
    ///
    /// Returns the stride that would make this tensor contiguous
    /// (row-major/C order). The result has one entry per dimension and is
    /// empty for a rank-0 tensor.
    fn contiguous_stride(&self) -> Vec<usize> {
        let shape = self.shape();
        let mut stride = vec![1; shape.len()];
        // The innermost stride stays 1; fill the rest right-to-left.
        // `saturating_sub` keeps the range empty for rank-0 and rank-1 shapes.
        for i in (0..shape.len().saturating_sub(1)).rev() {
            stride[i] = stride[i + 1] * shape[i + 1];
        }
        stride
    }

    /// Returns the CUDA device ID if the tensor is on a CUDA device.
    ///
    /// Only [`StorageKind::Device`] yields `Some`; every other storage kind
    /// yields `None`.
    fn cuda_device_id(&self) -> Option<usize> {
        match self.storage_kind() {
            StorageKind::Device(idx) => Some(idx as usize),
            _ => None,
        }
    }
}

// Blanket impl for all TensorDescriptor types
impl<T: TensorDescriptor + ?Sized> TensorDescriptorExt for T {}
// =============================================================================
// Arc<dyn TensorDescriptor> support for NixlRegisterExt
// =============================================================================
impl nixl::NixlCompatible for Arc<dyn TensorDescriptor> {
    /// Returns `(base pointer, size, NIXL memory type, device id)` for this tensor.
    ///
    /// The memory type and device id are derived from the storage kind:
    /// - `Device(idx)` maps to `Vram` with `idx` as the device id
    /// - `System` and `Pinned` both map to `Dram` with device id `0`
    /// - `Disk(fd)` maps to `File`, passing the `Disk` payload through as the id
    fn nixl_params(&self) -> (*const u8, usize, nixl::MemType, u64) {
        let (mem_type, device_id) = match self.storage_kind() {
            StorageKind::Device(idx) => (nixl::MemType::Vram, idx as u64),
            StorageKind::System | StorageKind::Pinned => (nixl::MemType::Dram, 0),
            StorageKind::Disk(fd) => (nixl::MemType::File, fd),
        };
        let base = self.addr() as *const u8;
        (base, self.size(), mem_type, device_id)
    }
}
impl MemoryDescriptor for Arc<dyn TensorDescriptor> {
fn addr(&self) -> usize {
(**self).addr()
}
fn size(&self) -> usize {
(**self).size()
}
fn storage_kind(&self) -> StorageKind {
(**self).storage_kind()
}
fn as_any(&self) -> &dyn Any {
self
}
fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
None
}
}
impl TensorDescriptor for Arc<dyn TensorDescriptor> {
fn shape(&self) -> &[usize] {
(**self).shape()
}
fn stride(&self) -> &[usize] {
(**self).stride()
}
fn element_size(&self) -> usize {
(**self).element_size()
}
}
// =============================================================================
// Arc<dyn TensorDescriptor + Send + Sync> support
// =============================================================================
impl nixl::NixlCompatible for Arc<dyn TensorDescriptor + Send + Sync> {
    /// Returns `(base pointer, size, NIXL memory type, device id)` for this tensor.
    ///
    /// The memory type and device id are derived from the storage kind:
    /// - `Device(idx)` maps to `Vram` with `idx` as the device id
    /// - `System` and `Pinned` both map to `Dram` with device id `0`
    /// - `Disk(fd)` maps to `File`, passing the `Disk` payload through as the id
    fn nixl_params(&self) -> (*const u8, usize, nixl::MemType, u64) {
        let (mem_type, device_id) = match self.storage_kind() {
            StorageKind::Device(idx) => (nixl::MemType::Vram, idx as u64),
            StorageKind::System | StorageKind::Pinned => (nixl::MemType::Dram, 0),
            StorageKind::Disk(fd) => (nixl::MemType::File, fd),
        };
        let base = self.addr() as *const u8;
        (base, self.size(), mem_type, device_id)
    }
}
impl MemoryDescriptor for Arc<dyn TensorDescriptor + Send + Sync> {
fn addr(&self) -> usize {
(**self).addr()
}
fn size(&self) -> usize {
(**self).size()
}
fn storage_kind(&self) -> StorageKind {
(**self).storage_kind()
}
fn as_any(&self) -> &dyn Any {
self
}
fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
None
}
}
impl TensorDescriptor for Arc<dyn TensorDescriptor + Send + Sync> {
fn shape(&self) -> &[usize] {
(**self).shape()
}
fn stride(&self) -> &[usize] {
(**self).stride()
}
fn element_size(&self) -> usize {
(**self).element_size()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Element width, in bytes, shared by the fixtures built via [`TestTensor::new`].
    const ELEMENT_SIZE: usize = 4;

    /// Simple tensor fixture that reports `StorageKind::System`.
    ///
    /// Its fields are plain `usize`/`Vec<usize>` values, so the type is
    /// automatically `Send + Sync` and can back either `Arc` flavour.
    #[derive(Debug)]
    struct TestTensor {
        addr: usize,
        size: usize,
        shape: Vec<usize>,
        stride: Vec<usize>,
        element_size: usize,
    }

    impl TestTensor {
        /// Builds a fixture at address `0x1000` with 4-byte elements.
        ///
        /// `size` is derived as `numel * ELEMENT_SIZE`; tests that need other
        /// values override individual fields with struct-update syntax.
        fn new(shape: &[usize], stride: &[usize]) -> Self {
            Self {
                addr: 0x1000,
                size: shape.iter().product::<usize>() * ELEMENT_SIZE,
                shape: shape.to_vec(),
                stride: stride.to_vec(),
                element_size: ELEMENT_SIZE,
            }
        }
    }

    impl MemoryDescriptor for TestTensor {
        fn addr(&self) -> usize {
            self.addr
        }
        fn size(&self) -> usize {
            self.size
        }
        fn storage_kind(&self) -> StorageKind {
            StorageKind::System
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
            None
        }
    }

    impl TensorDescriptor for TestTensor {
        fn shape(&self) -> &[usize] {
            &self.shape
        }
        fn stride(&self) -> &[usize] {
            &self.stride
        }
        fn element_size(&self) -> usize {
            self.element_size
        }
    }

    /// Tensor fixture that reports `StorageKind::Device` with a configurable index.
    #[derive(Debug)]
    struct DeviceTensor {
        addr: usize,
        size: usize,
        shape: Vec<usize>,
        stride: Vec<usize>,
        element_size: usize,
        device_id: u32,
    }

    impl MemoryDescriptor for DeviceTensor {
        fn addr(&self) -> usize {
            self.addr
        }
        fn size(&self) -> usize {
            self.size
        }
        fn storage_kind(&self) -> StorageKind {
            StorageKind::Device(self.device_id)
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
            None
        }
    }

    impl TensorDescriptor for DeviceTensor {
        fn shape(&self) -> &[usize] {
            &self.shape
        }
        fn stride(&self) -> &[usize] {
            &self.stride
        }
        fn element_size(&self) -> usize {
            self.element_size
        }
    }

    #[test]
    fn test_numel() {
        let tensor = TestTensor::new(&[2, 3, 4], &[12, 4, 1]);
        assert_eq!(tensor.numel(), 24);
    }

    #[test]
    fn test_ndim() {
        let tensor = TestTensor::new(&[2, 3, 4], &[12, 4, 1]);
        assert_eq!(tensor.ndim(), 3);
    }

    #[test]
    fn test_is_contiguous_true() {
        // Row-major stride for shape [2, 3, 4].
        let tensor = TestTensor::new(&[2, 3, 4], &[12, 4, 1]);
        assert!(tensor.is_contiguous());
    }

    #[test]
    fn test_is_contiguous_false() {
        // Outer stride of 24 leaves a gap between slices of the first dimension.
        let tensor = TestTensor::new(&[2, 3, 4], &[24, 4, 1]);
        assert!(!tensor.is_contiguous());
    }

    #[test]
    fn test_contiguous_stride() {
        // The reported stride is non-contiguous; the helper ignores it.
        let tensor = TestTensor::new(&[2, 3, 4], &[24, 4, 1]);
        assert_eq!(tensor.contiguous_stride(), vec![12, 4, 1]);
    }

    #[test]
    fn test_empty_tensor() {
        // Empty shape, i.e. a rank-0 tensor.
        let tensor = TestTensor::new(&[], &[]);
        assert_eq!(tensor.numel(), 1); // Empty product is 1
        assert_eq!(tensor.ndim(), 0);
        assert!(tensor.is_contiguous());
    }

    #[test]
    fn test_1d_tensor_contiguous() {
        let tensor = TestTensor::new(&[10], &[1]);
        assert_eq!(tensor.numel(), 10);
        assert_eq!(tensor.ndim(), 1);
        assert!(tensor.is_contiguous());
        assert_eq!(tensor.contiguous_stride(), vec![1]);
    }

    #[test]
    fn test_1d_tensor_non_contiguous() {
        // Strided access (every other element).
        let tensor = TestTensor::new(&[10], &[2]);
        assert!(!tensor.is_contiguous());
    }

    #[test]
    fn test_2d_tensor() {
        let tensor = TestTensor::new(&[2, 3], &[3, 1]);
        assert_eq!(tensor.numel(), 6);
        assert_eq!(tensor.ndim(), 2);
        assert!(tensor.is_contiguous());
    }

    #[test]
    fn test_high_dimensional_tensor() {
        // 5D tensor with its row-major stride.
        let shape = [2, 3, 4, 5, 6];
        let stride = [360, 120, 30, 6, 1];
        let tensor = TestTensor::new(&shape, &stride);
        assert_eq!(tensor.numel(), 720);
        assert_eq!(tensor.ndim(), 5);
        assert!(tensor.is_contiguous());
        assert_eq!(tensor.contiguous_stride(), vec![360, 120, 30, 6, 1]);
    }

    #[test]
    fn test_tensor_with_size_1_dimensions() {
        // Singleton dimensions still need the exact row-major stride.
        let tensor = TestTensor::new(&[1, 3, 1, 4], &[12, 4, 4, 1]);
        assert_eq!(tensor.numel(), 12);
        assert_eq!(tensor.ndim(), 4);
        assert!(tensor.is_contiguous());
    }

    #[test]
    fn test_contiguous_stride_empty() {
        let tensor = TestTensor::new(&[], &[]);
        assert!(tensor.contiguous_stride().is_empty());
    }

    #[test]
    fn test_contiguous_stride_1d() {
        let tensor = TestTensor::new(&[5], &[1]);
        assert_eq!(tensor.contiguous_stride(), vec![1]);
    }

    #[test]
    fn test_cuda_device_id_system() {
        let tensor = TestTensor::new(&[10], &[1]);
        assert_eq!(tensor.cuda_device_id(), None);
    }

    #[test]
    fn test_cuda_device_id_device() {
        let tensor = DeviceTensor {
            addr: 0x1000,
            size: 100,
            shape: vec![10],
            stride: vec![1],
            element_size: 4,
            device_id: 2,
        };
        assert_eq!(tensor.cuda_device_id(), Some(2));
    }

    #[test]
    fn test_arc_tensor_descriptor() {
        let arc: Arc<dyn TensorDescriptor> = Arc::new(TestTensor::new(&[2, 3, 4], &[12, 4, 1]));
        assert_eq!(arc.addr(), 0x1000);
        assert_eq!(arc.size(), 24 * 4);
        assert_eq!(arc.shape(), &[2, 3, 4]);
        assert_eq!(arc.stride(), &[12, 4, 1]);
        assert_eq!(arc.element_size(), 4);
        assert_eq!(arc.storage_kind(), StorageKind::System);
        assert!(arc.nixl_descriptor().is_none());
    }

    #[test]
    fn test_arc_tensor_send_sync() {
        // `TestTensor` is auto-`Send + Sync`, so no dedicated type or
        // `unsafe impl` is needed to build the `Send + Sync` trait object.
        let tensor = TestTensor {
            addr: 0x2000,
            size: 100,
            ..TestTensor::new(&[10], &[1])
        };
        let arc: Arc<dyn TensorDescriptor + Send + Sync> = Arc::new(tensor);
        assert_eq!(arc.addr(), 0x2000);
        assert_eq!(arc.size(), 100);
        assert_eq!(arc.shape(), &[10]);
        assert_eq!(arc.stride(), &[1]);
        assert_eq!(arc.element_size(), 4);
    }

    #[test]
    fn test_tensor_shape_stride_element_size() {
        let tensor = TestTensor::new(&[3, 4], &[4, 1]);
        assert_eq!(tensor.shape(), &[3, 4]);
        assert_eq!(tensor.stride(), &[4, 1]);
        assert_eq!(tensor.element_size(), 4);
    }

    #[test]
    fn test_tensor_numel_single_element() {
        let tensor = TestTensor::new(&[1, 1, 1], &[1, 1, 1]);
        assert_eq!(tensor.numel(), 1);
    }
}
......@@ -7,13 +7,13 @@ use super::*;
/// Helper function to validate NIXL descriptor consistency.
///
/// For any MemoryDescription that returns Some from nixl_descriptor(),
/// For any MemoryDescriptor that returns Some from nixl_descriptor(),
/// this validates that the descriptor's addr and size match the memory region's addr and size.
///
/// # Panics
/// Panics if descriptor values don't match memory region values.
#[allow(dead_code)]
fn validate_nixl_descriptor<M: MemoryDescription>(memory: &M) {
fn validate_nixl_descriptor<M: MemoryDescriptor>(memory: &M) {
if let Some(desc) = memory.nixl_descriptor() {
assert_eq!(
desc.addr as usize,
......@@ -32,6 +32,186 @@ fn validate_nixl_descriptor<M: MemoryDescription>(memory: &M) {
}
}
// ========== StorageKind tests ==========
/// `Device(n)` exposes `n` as its CUDA device index.
#[test]
fn test_storage_kind_cuda_device_index_device() {
    assert_eq!(StorageKind::Device(3).cuda_device_index(), Some(3));
}

/// System memory has no CUDA device index.
#[test]
fn test_storage_kind_cuda_device_index_system() {
    assert_eq!(StorageKind::System.cuda_device_index(), None);
}

/// Pinned memory has no CUDA device index.
#[test]
fn test_storage_kind_cuda_device_index_pinned() {
    assert_eq!(StorageKind::Pinned.cuda_device_index(), None);
}

/// Disk storage has no CUDA device index.
#[test]
fn test_storage_kind_cuda_device_index_disk() {
    assert_eq!(StorageKind::Disk(123).cuda_device_index(), None);
}

/// Only `Device` answers true to `is_cuda`.
#[test]
fn test_storage_kind_is_cuda() {
    assert!(StorageKind::Device(0).is_cuda());
    for other in [StorageKind::System, StorageKind::Pinned, StorageKind::Disk(1)] {
        assert!(!other.is_cuda());
    }
}

/// Only `System` answers true to `is_system`.
#[test]
fn test_storage_kind_is_system() {
    assert!(StorageKind::System.is_system());
    for other in [StorageKind::Device(0), StorageKind::Pinned, StorageKind::Disk(1)] {
        assert!(!other.is_system());
    }
}

/// Only `Pinned` answers true to `is_pinned`.
#[test]
fn test_storage_kind_is_pinned() {
    assert!(StorageKind::Pinned.is_pinned());
    for other in [StorageKind::System, StorageKind::Device(0), StorageKind::Disk(1)] {
        assert!(!other.is_pinned());
    }
}

/// Only `Disk` answers true to `is_disk`.
#[test]
fn test_storage_kind_is_disk() {
    assert!(StorageKind::Disk(1).is_disk());
    for other in [StorageKind::System, StorageKind::Pinned, StorageKind::Device(0)] {
        assert!(!other.is_disk());
    }
}
// ========== Buffer tests ==========
/// A buffer built from system storage reports that storage's size and kind.
#[test]
fn test_buffer_new() {
    let buffer = Buffer::new(SystemStorage::new(1024).unwrap());
    assert_eq!(buffer.size(), 1024);
    assert_eq!(buffer.storage_kind(), StorageKind::System);
}

/// `Buffer::from_arc` accepts an already type-erased descriptor.
#[test]
fn test_buffer_from_arc() {
    use std::sync::Arc;
    let erased: Arc<dyn MemoryDescriptor> = Arc::new(SystemStorage::new(2048).unwrap());
    let buffer = Buffer::from_arc(erased);
    assert_eq!(buffer.size(), 2048);
}

/// An `Arc<dyn MemoryDescriptor>` converts into a `Buffer` via `Into`.
#[test]
fn test_buffer_from_impl() {
    use std::sync::Arc;
    let erased: Arc<dyn MemoryDescriptor> = Arc::new(SystemStorage::new(512).unwrap());
    let buffer: Buffer = erased.into();
    assert_eq!(buffer.size(), 512);
}

/// `MemoryDescriptor` methods are callable directly on a `Buffer`.
#[test]
fn test_buffer_deref() {
    let buffer = Buffer::new(SystemStorage::new(1024).unwrap());
    // Deref allows calling MemoryDescriptor methods directly
    assert_eq!(buffer.size(), 1024);
}

/// The `Debug` output names the type and mentions its size and address.
#[test]
fn test_buffer_debug() {
    let buffer = Buffer::new(SystemStorage::new(1024).unwrap());
    let rendered = format!("{:?}", buffer);
    for needle in ["Buffer", "size", "addr"] {
        assert!(rendered.contains(needle));
    }
}

/// A cloned buffer reports the same address and size as the original.
#[test]
fn test_buffer_clone() {
    let original = Buffer::new(SystemStorage::new(1024).unwrap());
    let duplicate = original.clone();
    assert_eq!(original.addr(), duplicate.addr());
    assert_eq!(original.size(), duplicate.size());
}
// ========== MemoryRegion tests ==========
/// `MemoryRegion::new` stores its arguments in the public fields.
#[test]
fn test_memory_region_new() {
    let created = MemoryRegion::new(0x1000, 4096);
    assert_eq!(created.addr, 0x1000);
    assert_eq!(created.size, 4096);
}

/// The accessor methods return the constructor arguments.
#[test]
fn test_memory_region_accessors() {
    let created = MemoryRegion::new(0x2000, 8192);
    assert_eq!(created.addr(), 0x2000);
    assert_eq!(created.size(), 8192);
}

/// A zero address is accepted and reported unchanged.
#[test]
fn test_memory_region_zero_address() {
    let at_zero = MemoryRegion::new(0, 1024);
    assert_eq!(at_zero.addr(), 0);
    assert_eq!(at_zero.size(), 1024);
}

/// A zero size is accepted and reported unchanged.
#[test]
fn test_memory_region_zero_size() {
    let empty = MemoryRegion::new(0x1000, 0);
    assert_eq!(empty.addr(), 0x1000);
    assert_eq!(empty.size(), 0);
}

/// Duplicating a region yields a value with identical accessors.
#[test]
fn test_memory_region_clone() {
    let original = MemoryRegion::new(0x3000, 2048);
    // Plain assignment duplicates the value; `original` stays usable below.
    let duplicate = original;
    assert_eq!(original.addr(), duplicate.addr());
    assert_eq!(original.size(), duplicate.size());
}

/// Regions compare equal when both address and size match.
#[test]
fn test_memory_region_eq() {
    let base = MemoryRegion::new(0x1000, 4096);
    let same = MemoryRegion::new(0x1000, 4096);
    let other_addr = MemoryRegion::new(0x2000, 4096);
    assert_eq!(base, same);
    assert_ne!(base, other_addr);
}

/// The `Debug` output names the type.
#[test]
fn test_memory_region_debug() {
    let rendered = format!("{:?}", MemoryRegion::new(0x1000, 4096));
    assert!(rendered.contains("MemoryRegion"));
}
// ========== create_buffer helper tests ==========
/// `create_buffer` wraps system storage and preserves its size and kind.
#[test]
fn test_create_buffer_helper() {
    let buffer = create_buffer(SystemStorage::new(1024).unwrap());
    assert_eq!(buffer.size(), 1024);
    assert_eq!(buffer.storage_kind(), StorageKind::System);
}
// ========== Original tests ==========
#[test]
fn test_system_storage() {
let storage = SystemStorage::new(1024).unwrap();
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
/// Device a torch tensor reports itself to be on.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TorchDevice {
    /// A CUDA device, identified by its index.
    Cuda(usize),
    /// Any non-CUDA device, carried as a string.
    Other(String),
}

impl TorchDevice {
    /// Returns `true` only for the [`TorchDevice::Cuda`] variant.
    pub fn is_cuda(&self) -> bool {
        self.cuda_device_index().is_some()
    }

    /// Returns the CUDA device index, or `None` for [`TorchDevice::Other`].
    pub fn cuda_device_index(&self) -> Option<usize> {
        match *self {
            Self::Cuda(index) => Some(index),
            Self::Other(_) => None,
        }
    }
}
/// Read-only view of a tensor: its device, data pointer, byte size, shape, and stride.
///
/// Implementors must also be `Debug`, `Send`, and `Sync`.
///
/// NOTE(review): only the signatures are visible here; the unit and layout
/// notes on the methods below are assumptions to confirm against implementors.
pub trait TorchTensor: std::fmt::Debug + Send + Sync {
/// Device the tensor reports itself to be on.
fn device(&self) -> TorchDevice;
/// Data pointer as a `u64` — presumably the address of the first element; TODO confirm.
fn data_ptr(&self) -> u64;
/// Size of the tensor in bytes (per the method name; TODO confirm whether this covers the whole allocation or only the viewed elements).
fn size_bytes(&self) -> usize;
/// Shape of the tensor, returned as an owned vector with one entry per dimension.
fn shape(&self) -> Vec<usize>;
/// Stride of the tensor, returned as an owned vector — presumably in elements rather than bytes; TODO confirm.
fn stride(&self) -> Vec<usize>;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment