Unverified Commit 976bb70a authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: add KVBM memory management enhancements (DIS-1311) (#5532)

parent 57bdfea9
...@@ -2117,6 +2117,7 @@ dependencies = [ ...@@ -2117,6 +2117,7 @@ dependencies = [
"nixl-sys", "nixl-sys",
"offset-allocator", "offset-allocator",
"serde", "serde",
"serde_json",
"tempfile", "tempfile",
"thiserror 2.0.18", "thiserror 2.0.18",
"tracing", "tracing",
......
...@@ -37,4 +37,5 @@ nix = { version = "0.30", features = ["fs"] } ...@@ -37,4 +37,5 @@ nix = { version = "0.30", features = ["fs"] }
offset-allocator = "0.2" offset-allocator = "0.2"
[dev-dependencies] [dev-dependencies]
serde_json = { workspace = true }
tempfile = "3" tempfile = "3"
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
//! Storage actions. //! Storage actions.
use super::{MemoryDescription, StorageError}; use super::{MemoryDescriptor, StorageError};
/// Extension trait for storage types that support memory setting operations /// Extension trait for storage types that support memory setting operations
pub trait Memset: MemoryDescription { pub trait Memset: MemoryDescriptor {
/// Sets a region of memory to a specific value /// Sets a region of memory to a specific value
/// ///
/// # Arguments /// # Arguments
...@@ -22,7 +22,7 @@ pub trait Memset: MemoryDescription { ...@@ -22,7 +22,7 @@ pub trait Memset: MemoryDescription {
} }
/// Extension trait for storage types that support slicing operations /// Extension trait for storage types that support slicing operations
pub trait Slice: MemoryDescription + 'static { pub trait Slice: MemoryDescriptor + 'static {
/// Returns an immutable byte slice view of the entire storage region /// Returns an immutable byte slice view of the entire storage region
/// ///
/// # Safety /// # Safety
...@@ -133,7 +133,8 @@ pub trait Slice: MemoryDescription + 'static { ...@@ -133,7 +133,8 @@ pub trait Slice: MemoryDescription + 'static {
} }
} }
pub trait SliceMut: MemoryDescription + 'static { /// Extension trait for storage types that support mutable slicing operations.
pub trait SliceMut: MemoryDescriptor + 'static {
/// Returns a mutable byte slice view of the entire storage region /// Returns a mutable byte slice view of the entire storage region
/// ///
/// # Safety /// # Safety
...@@ -239,3 +240,234 @@ pub trait SliceMut: MemoryDescription + 'static { ...@@ -239,3 +240,234 @@ pub trait SliceMut: MemoryDescription + 'static {
Ok(unsafe { std::slice::from_raw_parts_mut(ptr, len) }) Ok(unsafe { std::slice::from_raw_parts_mut(ptr, len) })
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::SystemStorage;
// Helper to create a test storage
fn create_storage(size: usize) -> SystemStorage {
SystemStorage::new(size).expect("allocation failed")
}
// ========== Memset tests ==========
#[test]
fn test_memset_full_region() {
let mut storage = create_storage(1024);
storage
.memset(0xAB, 0, 1024)
.expect("memset should succeed");
let slice = unsafe { storage.as_slice().expect("as_slice should succeed") };
assert!(slice.iter().all(|&b| b == 0xAB));
}
#[test]
fn test_memset_partial_region() {
let mut storage = create_storage(1024);
// First fill with 0x00
storage
.memset(0x00, 0, 1024)
.expect("memset should succeed");
// Then fill middle region with 0xFF
storage
.memset(0xFF, 100, 200)
.expect("memset should succeed");
let slice = unsafe { storage.as_slice().expect("as_slice should succeed") };
// Check before region
assert!(slice[..100].iter().all(|&b| b == 0x00));
// Check filled region
assert!(slice[100..300].iter().all(|&b| b == 0xFF));
// Check after region
assert!(slice[300..].iter().all(|&b| b == 0x00));
}
#[test]
fn test_memset_at_end() {
let mut storage = create_storage(1024);
// Fill the last 100 bytes
storage
.memset(0x42, 924, 100)
.expect("memset should succeed");
let slice = unsafe { storage.as_slice().expect("as_slice should succeed") };
assert!(slice[924..].iter().all(|&b| b == 0x42));
}
#[test]
fn test_memset_zero_size() {
let mut storage = create_storage(1024);
// Zero-size memset should succeed (no-op)
storage
.memset(0xFF, 500, 0)
.expect("zero-size memset should succeed");
}
#[test]
fn test_memset_out_of_bounds() {
let mut storage = create_storage(1024);
// Try to write beyond the storage
let result = storage.memset(0xFF, 900, 200);
assert!(result.is_err());
}
#[test]
fn test_memset_offset_overflow() {
let mut storage = create_storage(1024);
// offset + size would overflow
let result = storage.memset(0xFF, usize::MAX, 1);
assert!(result.is_err());
}
// ========== Slice tests ==========
#[test]
fn test_as_slice_full() {
let mut storage = create_storage(1024);
storage
.memset(0xCD, 0, 1024)
.expect("memset should succeed");
let slice = unsafe { storage.as_slice().expect("as_slice should succeed") };
assert_eq!(slice.len(), 1024);
assert!(slice.iter().all(|&b| b == 0xCD));
}
#[test]
fn test_slice_partial() {
let mut storage = create_storage(1024);
storage
.memset(0x00, 0, 1024)
.expect("memset should succeed");
storage
.memset(0xAA, 100, 50)
.expect("memset should succeed");
let partial = storage.slice(100, 50).expect("slice should succeed");
assert_eq!(partial.len(), 50);
assert!(partial.iter().all(|&b| b == 0xAA));
}
#[test]
fn test_slice_at_start() {
let storage = create_storage(1024);
let slice = storage.slice(0, 100).expect("slice should succeed");
assert_eq!(slice.len(), 100);
}
#[test]
fn test_slice_at_end() {
let storage = create_storage(1024);
let slice = storage.slice(924, 100).expect("slice should succeed");
assert_eq!(slice.len(), 100);
}
#[test]
fn test_slice_zero_length() {
let storage = create_storage(1024);
let slice = storage
.slice(500, 0)
.expect("zero-length slice should succeed");
assert!(slice.is_empty());
}
#[test]
fn test_slice_out_of_bounds() {
let storage = create_storage(1024);
let result = storage.slice(900, 200);
assert!(result.is_err());
}
#[test]
fn test_slice_offset_overflow() {
let storage = create_storage(1024);
// offset + len would overflow when using saturating_add
let result = storage.slice(usize::MAX, 1);
assert!(result.is_err());
}
// ========== Typed slice tests ==========
#[test]
fn test_as_slice_typed_u32() {
let mut storage = create_storage(1024);
// Fill with known pattern
storage
.memset(0x00, 0, 1024)
.expect("memset should succeed");
let typed: &[u32] = storage
.as_slice_typed()
.expect("typed slice should succeed");
assert_eq!(typed.len(), 256); // 1024 / 4
assert!(typed.iter().all(|&v| v == 0));
}
#[test]
fn test_as_slice_typed_u64() {
let storage = create_storage(1024);
let typed: &[u64] = storage
.as_slice_typed()
.expect("typed slice should succeed");
assert_eq!(typed.len(), 128); // 1024 / 8
}
#[test]
fn test_slice_typed_partial() {
let mut storage = create_storage(1024);
storage
.memset(0x00, 0, 1024)
.expect("memset should succeed");
// Slice 10 u32 elements starting at offset 0
let typed: &[u32] = storage
.slice_typed(0, 10)
.expect("typed slice should succeed");
assert_eq!(typed.len(), 10);
}
#[test]
fn test_slice_typed_with_offset() {
let storage = create_storage(1024);
// Slice starting at offset 64 (aligned for u64)
let typed: &[u64] = storage
.slice_typed(64, 5)
.expect("typed slice should succeed");
assert_eq!(typed.len(), 5);
}
#[test]
fn test_as_slice_typed_zst_error() {
let storage = create_storage(1024);
// Zero-sized types should fail
let result: Result<&[()], _> = storage.as_slice_typed();
assert!(result.is_err());
}
#[test]
fn test_as_slice_typed_size_not_multiple() {
// Create storage with size not divisible by 4
let storage = create_storage(1023);
let result: Result<&[u32], _> = storage.as_slice_typed();
assert!(result.is_err());
}
#[test]
fn test_slice_typed_length_overflow() {
let storage = create_storage(1024);
// len * size_of::<u64>() would overflow
let result: Result<&[u64], _> = storage.slice_typed(0, usize::MAX);
assert!(result.is_err());
}
#[test]
fn test_slice_typed_out_of_bounds() {
let storage = create_storage(1024);
// Request more elements than available
let result: Result<&[u64], _> = storage.slice_typed(0, 200);
assert!(result.is_err());
}
}
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
//! # Arena Allocator //! # Arena Allocator
//! //!
//! This module provides an arena allocator for generally heap-like allocations. //! This module provides an arena allocator for generally heap-like allocations.
//! An [`ArenaAllocator`] can be created by taking ownership of a [`MemoryDescription`] instance. //! An [`ArenaAllocator`] can be created by taking ownership of a [`MemoryDescriptor`] instance.
//! //!
//! The [`ArenaAllocator`] allocates memory contiguous regions using the [`offset_allocator`] crate, //! The [`ArenaAllocator`] allocates memory contiguous regions using the [`offset_allocator`] crate,
//! which builds on [Sebastian Aaltonen's ArenaAllocator](https://github.com/sebbbi/ArenaAllocator) //! which builds on [Sebastian Aaltonen's ArenaAllocator](https://github.com/sebbbi/ArenaAllocator)
use crate::StorageKind; use crate::StorageKind;
use super::{MemoryDescription, StorageError}; use super::{MemoryDescriptor, StorageError};
use offset_allocator::{Allocation, Allocator}; use offset_allocator::{Allocation, Allocator};
use std::{ use std::{
any::Any, any::Any,
...@@ -20,6 +20,7 @@ use std::{ ...@@ -20,6 +20,7 @@ use std::{
/// Errors specific to arena allocation. /// Errors specific to arena allocation.
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
#[allow(missing_docs)]
pub enum ArenaError { pub enum ArenaError {
#[error("Page size must be a power of 2")] #[error("Page size must be a power of 2")]
PageSizeNotAligned, PageSizeNotAligned,
...@@ -34,20 +35,20 @@ pub enum ArenaError { ...@@ -34,20 +35,20 @@ pub enum ArenaError {
StorageError(#[from] StorageError), StorageError(#[from] StorageError),
} }
/// Arena allocator backed by an instance of a [`MemoryDescription`] object. /// Arena allocator backed by an instance of a [`MemoryDescriptor`] object.
/// ///
/// This struct wraps an [`Allocator`] from the [`offset_allocator`] crate, /// This struct wraps an [`Allocator`] from the [`offset_allocator`] crate,
/// and provides methods for allocating memory from the storage. /// and provides methods for allocating memory from the storage.
/// ///
/// The allocator is thread-safe, and the storage is shared between the allocator and the buffers. /// The allocator is thread-safe, and the storage is shared between the allocator and the buffers.
#[derive(Clone)] #[derive(Clone)]
pub struct ArenaAllocator<S: MemoryDescription> { pub struct ArenaAllocator<S: MemoryDescriptor> {
storage: Arc<S>, storage: Arc<S>,
allocator: Arc<Mutex<Allocator>>, allocator: Arc<Mutex<Allocator>>,
page_size: u64, page_size: u64,
} }
impl<S: MemoryDescription> std::fmt::Debug for ArenaAllocator<S> { impl<S: MemoryDescriptor> std::fmt::Debug for ArenaAllocator<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!( write!(
f, f,
...@@ -62,18 +63,24 @@ impl<S: MemoryDescription> std::fmt::Debug for ArenaAllocator<S> { ...@@ -62,18 +63,24 @@ impl<S: MemoryDescription> std::fmt::Debug for ArenaAllocator<S> {
/// This struct wraps an [`Allocation`] from the [`offset_allocator`] crate, /// This struct wraps an [`Allocation`] from the [`offset_allocator`] crate,
/// and provides methods for interacting with the allocated memory. /// and provides methods for interacting with the allocated memory.
/// ///
/// The buffer is backed by a [`MemoryDescription`] object, and the allocation is freed when the buffer is dropped. /// The buffer is backed by a [`MemoryDescriptor`] object, and the allocation is freed when the buffer is dropped.
pub struct ArenaBuffer<S: MemoryDescription> { pub struct ArenaBuffer<S: MemoryDescriptor> {
/// Byte offset from the start of the backing storage.
offset: usize, offset: usize,
/// Absolute memory address of this buffer.
address: usize, address: usize,
/// User-requested allocation size in bytes.
requested_size: usize, requested_size: usize,
/// Shared reference to the backing storage.
storage: Arc<S>, storage: Arc<S>,
/// Internal allocation handle from the offset allocator.
allocation: Allocation, allocation: Allocation,
/// Shared reference to the allocator for freeing on drop.
allocator: Arc<Mutex<Allocator>>, allocator: Arc<Mutex<Allocator>>,
} }
impl<S: MemoryDescription> ArenaAllocator<S> { impl<S: MemoryDescriptor> ArenaAllocator<S> {
/// Create a new [`ArenaAllocator`] from a [`MemoryDescription`] object and a page size. /// Create a new [`ArenaAllocator`] from a [`MemoryDescriptor`] object and a page size.
/// ///
/// The page size must be a power of two. /// The page size must be a power of two.
/// ///
...@@ -107,7 +114,11 @@ impl<S: MemoryDescription> ArenaAllocator<S> { ...@@ -107,7 +114,11 @@ impl<S: MemoryDescription> ArenaAllocator<S> {
}) })
} }
/// Allocate a new [`ArenaBuffer`] from the allocator. /// Allocates a new [`ArenaBuffer`] of the given size from this allocator.
///
/// The actual allocation may consume more pages than strictly needed due to
/// page-size rounding. Returns [`ArenaError::AllocationFailed`] if there are
/// not enough contiguous pages available.
pub fn allocate(&self, size: usize) -> std::result::Result<ArenaBuffer<S>, ArenaError> { pub fn allocate(&self, size: usize) -> std::result::Result<ArenaBuffer<S>, ArenaError> {
let size = size as u64; let size = size as u64;
let pages = size.div_ceil(self.page_size); let pages = size.div_ceil(self.page_size);
...@@ -135,7 +146,7 @@ impl<S: MemoryDescription> ArenaAllocator<S> { ...@@ -135,7 +146,7 @@ impl<S: MemoryDescription> ArenaAllocator<S> {
} }
} }
impl<S: MemoryDescription> std::fmt::Debug for ArenaBuffer<S> { impl<S: MemoryDescriptor> std::fmt::Debug for ArenaBuffer<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!( write!(
f, f,
...@@ -148,7 +159,7 @@ impl<S: MemoryDescription> std::fmt::Debug for ArenaBuffer<S> { ...@@ -148,7 +159,7 @@ impl<S: MemoryDescription> std::fmt::Debug for ArenaBuffer<S> {
} }
} }
impl<S: MemoryDescription + 'static> MemoryDescription for ArenaBuffer<S> { impl<S: MemoryDescriptor + 'static> MemoryDescriptor for ArenaBuffer<S> {
fn addr(&self) -> usize { fn addr(&self) -> usize {
self.address self.address
} }
...@@ -177,7 +188,7 @@ use super::nixl::{NixlCompatible, NixlDescriptor, RegisteredView}; ...@@ -177,7 +188,7 @@ use super::nixl::{NixlCompatible, NixlDescriptor, RegisteredView};
impl<S> ArenaBuffer<S> impl<S> ArenaBuffer<S>
where where
S: MemoryDescription + NixlCompatible, S: MemoryDescriptor + NixlCompatible,
{ {
/// Create a NIXL descriptor for this buffer with the correct offset and size. /// Create a NIXL descriptor for this buffer with the correct offset and size.
/// ///
...@@ -200,7 +211,7 @@ where ...@@ -200,7 +211,7 @@ where
impl<S> ArenaBuffer<S> impl<S> ArenaBuffer<S>
where where
S: MemoryDescription + RegisteredView, S: MemoryDescriptor + RegisteredView,
{ {
/// Get the agent name from registered storage. /// Get the agent name from registered storage.
/// ///
...@@ -223,7 +234,7 @@ where ...@@ -223,7 +234,7 @@ where
} }
} }
impl<S: MemoryDescription> Drop for ArenaBuffer<S> { impl<S: MemoryDescriptor> Drop for ArenaBuffer<S> {
fn drop(&mut self) { fn drop(&mut self) {
self.allocator.lock().unwrap().free(self.allocation); self.allocator.lock().unwrap().free(self.allocation);
} }
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
//! CUDA device memory storage. //! CUDA device memory storage.
use super::{MemoryDescription, Result, StorageError, StorageKind, nixl::NixlDescriptor}; use super::{MemoryDescriptor, Result, StorageError, StorageKind, nixl::NixlDescriptor};
use cudarc::driver::CudaContext; use cudarc::driver::CudaContext;
use std::any::Any; use std::any::Any;
use std::collections::HashMap; use std::collections::HashMap;
...@@ -26,9 +26,13 @@ fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> { ...@@ -26,9 +26,13 @@ fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
/// CUDA device memory allocated via cudaMalloc. /// CUDA device memory allocated via cudaMalloc.
#[derive(Debug)] #[derive(Debug)]
pub struct DeviceStorage { pub struct DeviceStorage {
/// CUDA context used for allocation and deallocation.
ctx: Arc<CudaContext>, ctx: Arc<CudaContext>,
/// Device pointer to the allocated memory.
ptr: u64, ptr: u64,
/// CUDA device ID where memory is allocated.
device_id: u32, device_id: u32,
/// Size of the allocation in bytes.
len: usize, len: usize,
} }
...@@ -84,7 +88,7 @@ impl Drop for DeviceStorage { ...@@ -84,7 +88,7 @@ impl Drop for DeviceStorage {
} }
} }
impl MemoryDescription for DeviceStorage { impl MemoryDescriptor for DeviceStorage {
fn addr(&self) -> usize { fn addr(&self) -> usize {
self.device_ptr() as usize self.device_ptr() as usize
} }
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
//! Disk-backed memory storage using memory-mapped files. //! Disk-backed memory storage using memory-mapped files.
use super::{MemoryDescription, Result, StorageError, StorageKind, nixl::NixlDescriptor}; use super::{MemoryDescriptor, Result, StorageError, StorageKind, nixl::NixlDescriptor};
use std::any::Any; use std::any::Any;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
...@@ -16,15 +16,21 @@ use std::os::fd::BorrowedFd; ...@@ -16,15 +16,21 @@ use std::os::fd::BorrowedFd;
const DISK_CACHE_KEY: &str = "DYN_KVBM_DISK_CACHE_DIR"; const DISK_CACHE_KEY: &str = "DYN_KVBM_DISK_CACHE_DIR";
const DEFAULT_DISK_CACHE_DIR: &str = "/tmp/"; const DEFAULT_DISK_CACHE_DIR: &str = "/tmp/";
/// Disk-backed storage using memory-mapped files with O_DIRECT support.
#[derive(Debug)] #[derive(Debug)]
pub struct DiskStorage { pub struct DiskStorage {
/// File descriptor for the backing file.
fd: u64, fd: u64,
/// Path to the backing file.
path: PathBuf, path: PathBuf,
/// Size of the storage in bytes.
size: usize, size: usize,
/// Whether the file has been unlinked from the filesystem.
unlinked: bool, unlinked: bool,
} }
impl DiskStorage { impl DiskStorage {
/// Creates a new disk storage of the given size in the default cache directory.
pub fn new(size: usize) -> Result<Self> { pub fn new(size: usize) -> Result<Self> {
// We need to open our file with some special flags that aren't supported by the tempfile crate. // We need to open our file with some special flags that aren't supported by the tempfile crate.
// Instead, we'll use the mkostemp function to create a temporary file with the correct flags. // Instead, we'll use the mkostemp function to create a temporary file with the correct flags.
...@@ -36,6 +42,7 @@ impl DiskStorage { ...@@ -36,6 +42,7 @@ impl DiskStorage {
Self::new_at(file_path, size) Self::new_at(file_path, size)
} }
/// Creates a new disk storage at the specified path with the given size.
pub fn new_at(path: impl AsRef<Path>, len: usize) -> Result<Self> { pub fn new_at(path: impl AsRef<Path>, len: usize) -> Result<Self> {
if len == 0 { if len == 0 {
return Err(StorageError::AllocationFailed( return Err(StorageError::AllocationFailed(
...@@ -140,15 +147,17 @@ impl DiskStorage { ...@@ -140,15 +147,17 @@ impl DiskStorage {
}) })
} }
/// Returns the file descriptor of the backing file.
pub fn fd(&self) -> u64 { pub fn fd(&self) -> u64 {
self.fd self.fd
} }
/// Returns the path to the backing file.
pub fn path(&self) -> &Path { pub fn path(&self) -> &Path {
self.path.as_path() self.path.as_path()
} }
/// Unlink our temp file. /// Unlinks the backing file from the filesystem.
/// This means that when this process terminates, the file will be automatically deleted by the OS. /// This means that when this process terminates, the file will be automatically deleted by the OS.
/// Unfortunately, GDS requires that files we try to register must be linked. /// Unfortunately, GDS requires that files we try to register must be linked.
/// To get around this, we unlink the file only after we've registered it with NIXL. /// To get around this, we unlink the file only after we've registered it with NIXL.
...@@ -163,6 +172,7 @@ impl DiskStorage { ...@@ -163,6 +172,7 @@ impl DiskStorage {
Ok(()) Ok(())
} }
/// Returns whether the backing file has been unlinked from the filesystem.
pub fn unlinked(&self) -> bool { pub fn unlinked(&self) -> bool {
self.unlinked self.unlinked
} }
...@@ -177,7 +187,7 @@ impl Drop for DiskStorage { ...@@ -177,7 +187,7 @@ impl Drop for DiskStorage {
} }
} }
impl MemoryDescription for DiskStorage { impl MemoryDescriptor for DiskStorage {
fn addr(&self) -> usize { fn addr(&self) -> usize {
0 0
} }
...@@ -345,7 +355,7 @@ impl super::nixl::NixlCompatible for DiskStorage { ...@@ -345,7 +355,7 @@ impl super::nixl::NixlCompatible for DiskStorage {
// } // }
// } // }
// impl MemoryDescription for MemMappedFileStorage { // impl MemoryDescriptor for MemMappedFileStorage {
// fn addr(&self) -> usize { // fn addr(&self) -> usize {
// self.mmap.as_ptr() as usize // self.mmap.as_ptr() as usize
// } // }
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! External memory wrapper for memory allocated by external frameworks.
//!
//! This module provides `ExternalDeviceMemory` for wrapping pointers to GPU
//! memory allocated by external frameworks (e.g., vLLM's KV cache). This type
//! does NOT own the memory - ownership remains with the external framework.
//!
//! The primary use case is registering external GPU memory with NIXL for RDMA
//! transfers without copying.
use crate::nixl::{MemType, NixlCompatible, NixlDescriptor};
use crate::{MemoryDescriptor, StorageKind};
use std::any::Any;
use std::fmt;
/// Wrapper for externally-allocated device (GPU) memory.
///
/// This type wraps a raw pointer to GPU memory that is owned by an external
/// framework (like vLLM). It provides the necessary traits for NIXL registration
/// without taking ownership of the underlying memory.
///
/// # Safety
///
/// This type relies on the caller to guarantee that:
/// - The pointer points to valid GPU memory on the specified device
/// - The memory remains valid for the lifetime of this wrapper
/// - The memory size is exactly as specified
/// - The external framework doesn't free the memory while this wrapper exists
///
/// # Example
///
/// ```ignore
/// // vLLM allocates KV cache tensors
/// let tensor_ptr = tensor.data_ptr();
/// let tensor_size = tensor.size_bytes();
/// let device_id = tensor.device.index;
///
/// // Wrap without taking ownership
/// let external = unsafe {
/// ExternalDeviceMemory::new(tensor_ptr as *const u8, tensor_size, device_id as u64)
/// };
///
/// // Register with NIXL for RDMA
/// let registered = register_with_nixl(external, &agent, None)?;
/// ```
pub struct ExternalDeviceMemory {
/// Raw pointer to externally-allocated GPU memory.
ptr: *const u8,
/// Size of the memory region in bytes.
size: usize,
/// CUDA device ID where this memory resides.
device_id: u64,
}
// Safety: The external framework (e.g., vLLM) guarantees the memory remains valid
// for the lifetime of the KV cache. The pointer is only used for NIXL registration
// and transfer operations which are synchronized by the framework.
unsafe impl Send for ExternalDeviceMemory {}
unsafe impl Sync for ExternalDeviceMemory {}
impl ExternalDeviceMemory {
/// Create a wrapper for external device memory.
///
/// # Safety
///
/// Caller must ensure:
/// - `ptr` points to valid GPU memory on CUDA device `device_id`
/// - The memory remains valid for the lifetime of this wrapper
/// - The memory size is exactly `size` bytes
/// - The external framework doesn't free the memory while this wrapper exists
#[inline]
pub unsafe fn new(ptr: *const u8, size: usize, device_id: u64) -> Self {
Self {
ptr,
size,
device_id,
}
}
/// Get the raw pointer to the external memory.
#[inline]
pub fn as_ptr(&self) -> *const u8 {
self.ptr
}
/// Get the CUDA device ID where this memory resides.
#[inline]
pub fn device_id(&self) -> u64 {
self.device_id
}
}
impl fmt::Debug for ExternalDeviceMemory {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ExternalDeviceMemory")
.field("ptr", &format_args!("{:p}", self.ptr))
.field("size", &self.size)
.field("device_id", &self.device_id)
.finish()
}
}
impl MemoryDescriptor for ExternalDeviceMemory {
#[inline]
fn addr(&self) -> usize {
self.ptr as usize
}
#[inline]
fn size(&self) -> usize {
self.size
}
#[inline]
fn storage_kind(&self) -> StorageKind {
StorageKind::Device(self.device_id as u32)
}
fn as_any(&self) -> &dyn Any {
self
}
fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
// External memory doesn't have a pre-existing NIXL descriptor
// It will be registered and get one via NixlRegistered wrapper
None
}
}
impl NixlCompatible for ExternalDeviceMemory {
fn nixl_params(&self) -> (*const u8, usize, MemType, u64) {
(self.ptr, self.size, MemType::Vram, self.device_id)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_external_device_memory_traits() {
// Create with a dummy pointer (not actually valid GPU memory)
let ptr = 0x1000 as *const u8;
let size = 1024;
let device_id = 0;
let external = unsafe { ExternalDeviceMemory::new(ptr, size, device_id) };
// Check MemoryDescriptor
assert_eq!(external.addr(), 0x1000);
assert_eq!(external.size(), 1024);
assert_eq!(external.storage_kind(), StorageKind::Device(0));
assert!(external.nixl_descriptor().is_none());
// Check NixlCompatible
let (p, s, mem_type, dev) = external.nixl_params();
assert_eq!(p as usize, 0x1000);
assert_eq!(s, 1024);
assert_eq!(mem_type, MemType::Vram);
assert_eq!(dev, 0);
}
}
...@@ -4,24 +4,34 @@ ...@@ -4,24 +4,34 @@
//! Clean, minimal storage API for v2 block manager. //! Clean, minimal storage API for v2 block manager.
//! //!
//! This module provides a simplified storage abstraction with: //! This module provides a simplified storage abstraction with:
//! - Single trait for type erasure (`MemoryDescription`) //! - Single trait for type erasure (`MemoryDescriptor`)
//! - Concrete storage types (no trait implementations required) //! - Concrete storage types (no trait implementations required)
//! - Composition-based NIXL registration via `NixlRegistered<T>` wrapper //! - Composition-based NIXL registration via `NixlRegistered<T>` wrapper
//! - RAII with proper drop ordering (registration handle drops before memory) //! - RAII with proper drop ordering (registration handle drops before memory)
#![deny(missing_docs)]
pub mod actions; pub mod actions;
pub mod arena; pub mod arena;
pub mod nixl; pub mod nixl;
pub mod numa;
/// Offset-based buffer views into underlying storage.
pub mod offset; pub mod offset;
/// CUDA memory pool utilities.
pub mod pool; pub mod pool;
/// Common imports for working with memory types.
pub mod prelude; pub mod prelude;
mod device; mod device;
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
mod disk; mod disk;
mod external;
mod pinned; mod pinned;
mod system; mod system;
mod torch; mod tensor;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
...@@ -30,9 +40,13 @@ pub use arena::{ArenaAllocator, ArenaBuffer, ArenaError}; ...@@ -30,9 +40,13 @@ pub use arena::{ArenaAllocator, ArenaBuffer, ArenaError};
pub use device::DeviceStorage; pub use device::DeviceStorage;
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
pub use disk::DiskStorage; pub use disk::DiskStorage;
pub use external::ExternalDeviceMemory;
pub use numa::{NumaNode, is_numa_enabled};
pub use offset::OffsetBuffer;
pub use pinned::PinnedStorage; pub use pinned::PinnedStorage;
pub use pool::{CudaMemPool, CudaMemPoolBuilder};
pub use system::SystemStorage; pub use system::SystemStorage;
pub use torch::{TorchDevice, TorchTensor}; pub use tensor::{TensorDescriptor, TensorDescriptorExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::any::Any; use std::any::Any;
...@@ -43,8 +57,30 @@ use thiserror::Error; ...@@ -43,8 +57,30 @@ use thiserror::Error;
/// Result type for storage operations. /// Result type for storage operations.
pub type Result<T> = std::result::Result<T, StorageError>; pub type Result<T> = std::result::Result<T, StorageError>;
/// Core trait for memory regions that can be type-erased.
///
/// This is the only trait in the storage API. Concrete storage types
/// implement this trait to enable type erasure via `Arc<dyn MemoryDescriptor>`.
pub trait MemoryDescriptor: Send + Sync + fmt::Debug {
/// Base address of the memory region.
fn addr(&self) -> usize;
/// Size of the memory region in bytes.
fn size(&self) -> usize;
/// Type of storage backing this region.
fn storage_kind(&self) -> StorageKind;
/// Enable downcasting to concrete type.
fn as_any(&self) -> &dyn Any;
/// Get the NIXL descriptor for this memory region.
fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor>;
}
/// Errors that can occur during storage operations. /// Errors that can occur during storage operations.
#[derive(Debug, Error)] #[derive(Debug, Error)]
#[allow(missing_docs)]
pub enum StorageError { pub enum StorageError {
#[error("allocation failed: {0}")] #[error("allocation failed: {0}")]
AllocationFailed(String), AllocationFailed(String),
...@@ -87,32 +123,51 @@ pub enum StorageKind { ...@@ -87,32 +123,51 @@ pub enum StorageKind {
Disk(u64), Disk(u64),
} }
/// Core trait for memory regions that can be type-erased. impl StorageKind {
/// /// Returns the CUDA device index if this is device memory.
/// This is the only trait in the storage API. Concrete storage types pub fn cuda_device_index(&self) -> Option<u32> {
/// implement this trait to enable type erasure via `Arc<dyn MemoryDescription>`. match self {
pub trait MemoryDescription: Send + Sync + fmt::Debug { StorageKind::Device(idx) => Some(*idx),
/// Base address of the memory region. _ => None,
fn addr(&self) -> usize; }
}
/// Size of the memory region in bytes. /// Returns true if this is CUDA device memory.
fn size(&self) -> usize; pub fn is_cuda(&self) -> bool {
matches!(self, StorageKind::Device(_))
}
/// Type of storage backing this region. /// Returns true if this is system memory (malloc).
fn storage_kind(&self) -> StorageKind; pub fn is_system(&self) -> bool {
matches!(self, StorageKind::System)
}
/// Enable downcasting to concrete type. /// Returns true if this is CUDA pinned host memory.
fn as_any(&self) -> &dyn Any; pub fn is_pinned(&self) -> bool {
matches!(self, StorageKind::Pinned)
}
/// Get the NIXL descriptor for this memory region. /// Returns true if this is disk-backed memory.
fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor>; pub fn is_disk(&self) -> bool {
matches!(self, StorageKind::Disk(_))
}
} }
/// Type-erased memory region for use in layouts. /// Type-erased memory region for use in layouts.
#[derive(Clone)] #[derive(Clone)]
pub struct Buffer(Arc<dyn MemoryDescription>); pub struct Buffer(Arc<dyn MemoryDescriptor>);
impl Buffer {
/// Wraps a concrete storage type into a type-erased [`Buffer`].
///
/// This is the primary way to create a `Buffer` from any type that
/// implements [`MemoryDescriptor`].
pub fn new<S: MemoryDescriptor + 'static>(memory: S) -> Self {
Buffer(Arc::new(memory))
}
}
impl MemoryDescription for Buffer { impl MemoryDescriptor for Buffer {
fn addr(&self) -> usize { fn addr(&self) -> usize {
self.0.addr() self.0.addr()
} }
...@@ -131,7 +186,7 @@ impl MemoryDescription for Buffer { ...@@ -131,7 +186,7 @@ impl MemoryDescription for Buffer {
} }
impl std::ops::Deref for Buffer { impl std::ops::Deref for Buffer {
type Target = dyn MemoryDescription; type Target = dyn MemoryDescriptor;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
self.0.as_ref() self.0.as_ref()
...@@ -149,10 +204,31 @@ impl std::fmt::Debug for Buffer { ...@@ -149,10 +204,31 @@ impl std::fmt::Debug for Buffer {
} }
/// Helper function to convert concrete storage to type-erased form. /// Helper function to convert concrete storage to type-erased form.
pub fn create_buffer<S: MemoryDescription + 'static>(memory: S) -> Buffer { pub fn create_buffer<S: MemoryDescriptor + 'static>(memory: S) -> Buffer {
Buffer(Arc::new(memory)) Buffer(Arc::new(memory))
} }
impl Buffer {
/// Create a Buffer from an existing Arc<dyn MemoryDescriptor>.
pub fn from_arc(arc: Arc<dyn MemoryDescriptor>) -> Self {
Buffer(arc)
}
}
// From implementations for ergonomic Buffer creation
impl From<Arc<dyn MemoryDescriptor>> for Buffer {
fn from(arc: Arc<dyn MemoryDescriptor>) -> Self {
Buffer::from_arc(arc)
}
}
impl From<Arc<dyn nixl::NixlMemory + Send + Sync>> for Buffer {
fn from(arc: Arc<dyn nixl::NixlMemory + Send + Sync>) -> Self {
// Arc<dyn NixlMemory> implements MemoryDescriptor, so we can wrap it
Buffer::new(arc)
}
}
/// An unowned contiguous chunk of memory, not storage specific. /// An unowned contiguous chunk of memory, not storage specific.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct MemoryRegion { pub struct MemoryRegion {
...@@ -164,17 +240,62 @@ pub struct MemoryRegion { ...@@ -164,17 +240,62 @@ pub struct MemoryRegion {
} }
impl MemoryRegion { impl MemoryRegion {
/// Creates a new memory region with the given base address and size.
pub fn new(addr: usize, size: usize) -> Self { pub fn new(addr: usize, size: usize) -> Self {
Self { addr, size } Self { addr, size }
} }
/// Returns the base address of this memory region.
#[inline] #[inline]
pub fn addr(&self) -> usize { pub fn addr(&self) -> usize {
self.addr self.addr
} }
/// Returns the size of this memory region in bytes.
#[inline] #[inline]
pub fn size(&self) -> usize { pub fn size(&self) -> usize {
self.size self.size
} }
/// Get a slice view of this memory region.
///
/// # Safety
/// This is unsafe because:
/// - The caller must ensure the memory region is valid and properly initialized
/// - The caller must ensure no mutable references exist to this memory
/// - The caller must ensure the memory remains valid for the lifetime of the slice
#[cfg(feature = "unsafe-slices")]
pub unsafe fn as_slice(&self) -> Result<&[u8]> {
if self.size == 0 {
return Ok(&[]);
}
// SAFETY: Caller guarantees memory is valid
unsafe {
Ok(std::slice::from_raw_parts(
self.addr as *const u8,
self.size,
))
}
}
/// Get a mutable slice view of this memory region.
///
/// # Safety
/// This is unsafe because:
/// - The caller must ensure the memory region is valid and properly initialized
/// - The caller must ensure no other references (mutable or immutable) exist to this memory
/// - The caller must ensure the memory remains valid for the lifetime of the slice
#[cfg(feature = "unsafe-slices")]
pub unsafe fn as_slice_mut(&mut self) -> Result<&mut [u8]> {
if self.size == 0 {
return Ok(&mut []);
}
// SAFETY: Caller guarantees memory is valid and exclusively accessible
unsafe {
Ok(std::slice::from_raw_parts_mut(
self.addr as *mut u8,
self.size,
))
}
}
} }
...@@ -6,14 +6,17 @@ ...@@ -6,14 +6,17 @@
mod agent; mod agent;
mod config; mod config;
use super::{MemoryDescription, StorageKind}; use super::{MemoryDescriptor, StorageKind};
use std::any::Any; use std::any::Any;
use std::fmt; use std::fmt;
use std::sync::Arc;
pub use agent::NixlAgent; pub use agent::NixlAgent;
pub use config::NixlBackendConfig; pub use config::NixlBackendConfig;
pub use nixl_sys::{MemType, OptArgs, RegistrationHandle}; pub use nixl_sys::{
Agent, MemType, NotificationMap, OptArgs, RegistrationHandle, XferDescList, XferOp, XferRequest,
};
pub use serde::{Deserialize, Serialize}; pub use serde::{Deserialize, Serialize};
/// Trait for storage types that can be registered with NIXL. /// Trait for storage types that can be registered with NIXL.
...@@ -24,12 +27,29 @@ pub trait NixlCompatible { ...@@ -24,12 +27,29 @@ pub trait NixlCompatible {
fn nixl_params(&self) -> (*const u8, usize, MemType, u64); fn nixl_params(&self) -> (*const u8, usize, MemType, u64);
} }
/// Combined trait for memory that can be registered with NIXL.
///
/// This supertrait enables type erasure via `Arc<dyn NixlMemory>`.
/// Any type implementing both `MemoryDescriptor` and `NixlCompatible`
/// automatically implements this trait via the blanket implementation.
pub trait NixlMemory: MemoryDescriptor + NixlCompatible {}
// Blanket impl - any type with both traits automatically implements NixlMemory
impl<T: MemoryDescriptor + NixlCompatible + ?Sized> NixlMemory for T {}
/// NIXL descriptor containing registration information. /// NIXL descriptor containing registration information.
///
/// This struct holds the information needed to describe a memory region
/// to NIXL for transfer operations.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NixlDescriptor { pub struct NixlDescriptor {
/// Base address of the memory region.
pub addr: u64, pub addr: u64,
/// Size of the memory region in bytes.
pub size: usize, pub size: usize,
/// Type of memory (host, device, etc.).
pub mem_type: MemType, pub mem_type: MemType,
/// Device identifier (GPU index for device memory, 0 for host memory).
pub device_id: u64, pub device_id: u64,
} }
...@@ -91,7 +111,7 @@ impl<S: NixlCompatible + fmt::Debug> fmt::Debug for NixlRegistered<S> { ...@@ -91,7 +111,7 @@ impl<S: NixlCompatible + fmt::Debug> fmt::Debug for NixlRegistered<S> {
} }
} }
impl<S: MemoryDescription + NixlCompatible + 'static> MemoryDescription for NixlRegistered<S> { impl<S: MemoryDescriptor + NixlCompatible + 'static> MemoryDescriptor for NixlRegistered<S> {
fn addr(&self) -> usize { fn addr(&self) -> usize {
self.storage.addr() self.storage.addr()
} }
...@@ -113,7 +133,7 @@ impl<S: MemoryDescription + NixlCompatible + 'static> MemoryDescription for Nixl ...@@ -113,7 +133,7 @@ impl<S: MemoryDescription + NixlCompatible + 'static> MemoryDescription for Nixl
} }
} }
impl<S: MemoryDescription + NixlCompatible> RegisteredView for NixlRegistered<S> { impl<S: MemoryDescriptor + NixlCompatible> RegisteredView for NixlRegistered<S> {
fn agent_name(&self) -> &str { fn agent_name(&self) -> &str {
&self.agent_name &self.agent_name
} }
...@@ -129,7 +149,7 @@ impl<S: MemoryDescription + NixlCompatible> RegisteredView for NixlRegistered<S> ...@@ -129,7 +149,7 @@ impl<S: MemoryDescription + NixlCompatible> RegisteredView for NixlRegistered<S>
} }
} }
impl<S: MemoryDescription + NixlCompatible> NixlRegistered<S> { impl<S: MemoryDescriptor + NixlCompatible> NixlRegistered<S> {
/// Get a reference to the underlying storage. /// Get a reference to the underlying storage.
pub fn storage(&self) -> &S { pub fn storage(&self) -> &S {
&self.storage &self.storage
...@@ -179,8 +199,42 @@ pub fn register_with_nixl<S>( ...@@ -179,8 +199,42 @@ pub fn register_with_nixl<S>(
opt: Option<&OptArgs>, opt: Option<&OptArgs>,
) -> std::result::Result<NixlRegistered<S>, S> ) -> std::result::Result<NixlRegistered<S>, S>
where where
S: MemoryDescription + NixlCompatible, S: MemoryDescriptor + NixlCompatible,
{ {
// let storage_kind = storage.storage_kind();
// // Determine if registration is needed based on storage type and available backends
// let should_register = match storage_kind {
// StorageKind::System | StorageKind::Pinned => {
// // System/Pinned memory needs UCX for remote transfers
// agent.has_backend("UCX") || agent.has_backend("POSIX")
// }
// StorageKind::Device(_) => {
// // Device memory needs UCX for remote transfers OR GDS for direct disk transfers
// agent.has_backend("UCX") || agent.has_backend("GDS_MT")
// }
// StorageKind::Disk(_) => {
// // Disk storage needs POSIX for regular I/O OR GDS for GPU direct I/O
// agent.has_backend("POSIX") || agent.has_backend("GDS_MT")
// } // StorageKind::Object(_) => {
// // // Object storage is always registered via NIXL's OBJ plugin
// // agent.has_backend("OBJ")
// // }
// };
// this is not true for our future object storage. so let's rethink this.
// for object, if there is no device_id or device_id is 0, then we need to register
// alternatively, the object storage holds it's own internal metadata but does not
// expose as a nixl descriptor, thus ObjectStorag will by default like all other storage
// types have a None for nixl_descriptor(), and we will use the internal
if storage.nixl_descriptor().is_some() {
return Ok(NixlRegistered {
storage,
handle: None,
agent_name: agent.name().to_string(),
});
}
// Get NIXL parameters // Get NIXL parameters
let (ptr, size, mem_type, device_id) = storage.nixl_params(); let (ptr, size, mem_type, device_id) = storage.nixl_params();
...@@ -201,3 +255,69 @@ where ...@@ -201,3 +255,69 @@ where
Err(_) => Err(storage), Err(_) => Err(storage),
} }
} }
// =============================================================================
// Arc<dyn NixlMemory> support
// =============================================================================
impl NixlCompatible for Arc<dyn NixlMemory + Send + Sync> {
fn nixl_params(&self) -> (*const u8, usize, MemType, u64) {
(**self).nixl_params()
}
}
impl MemoryDescriptor for Arc<dyn NixlMemory + Send + Sync> {
fn addr(&self) -> usize {
(**self).addr()
}
fn size(&self) -> usize {
(**self).size()
}
fn storage_kind(&self) -> StorageKind {
(**self).storage_kind()
}
fn as_any(&self) -> &dyn Any {
(**self).as_any()
}
fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
(**self).nixl_descriptor()
}
}
// =============================================================================
// Extension trait for ergonomic API
// =============================================================================
/// Extension trait providing ergonomic `.register()` method for NIXL registration.
///
/// This trait is automatically implemented for all types that implement both
/// `MemoryDescriptor` and `NixlCompatible`. Import this trait to use the
/// method syntax:
///
///
pub trait NixlRegisterExt: MemoryDescriptor + NixlCompatible + Sized {
/// Get this memory as NIXL-registered.
///
/// This operation is idempotent - it's a no-op if the memory is already registered.
///
/// # Arguments
/// * `agent` - The NIXL agent to register with
/// * `opt` - Optional arguments for registration
///
/// # Returns
/// A `NixlRegistered` wrapper on success, or the original storage on failure.
fn register(
self,
agent: &NixlAgent,
opt: Option<&OptArgs>,
) -> std::result::Result<NixlRegistered<Self>, Self> {
register_with_nixl(self, agent, opt)
}
}
// Blanket impl for all compatible types
impl<T: MemoryDescriptor + NixlCompatible + Sized> NixlRegisterExt for T {}
...@@ -9,7 +9,9 @@ ...@@ -9,7 +9,9 @@
use anyhow::Result; use anyhow::Result;
use nixl_sys::Agent; use nixl_sys::Agent;
use std::collections::HashSet; use std::collections::{HashMap, HashSet};
use crate::nixl::NixlBackendConfig;
/// A NIXL agent wrapper that tracks which backends were successfully initialized. /// A NIXL agent wrapper that tracks which backends were successfully initialized.
/// ///
...@@ -40,26 +42,64 @@ impl NixlAgent { ...@@ -40,26 +42,64 @@ impl NixlAgent {
}) })
} }
/// Add a backend to the agent. /// Creates a new agent configured with backends from the given config.
///
/// This method iterates over all backends in the config and initializes them
/// with their associated parameters. If a backend has custom parameters defined
/// in the config, those are used; otherwise, default plugin parameters are used.
pub fn from_nixl_backend_config(name: &str, config: NixlBackendConfig) -> Result<Self> {
let mut agent = Self::new(name)?;
for (backend, params) in config.iter() {
agent.add_backend_with_params(backend, params)?;
}
Ok(agent)
}
/// Add a backend to the agent with default parameters.
pub fn add_backend(&mut self, backend: &str) -> Result<()> { pub fn add_backend(&mut self, backend: &str) -> Result<()> {
if self.available_backends.contains(&backend.to_uppercase()) { self.add_backend_with_params(backend, &HashMap::new())
}
/// Add a backend to the agent with optional custom parameters.
///
/// If `custom_params` is non-empty, those parameters are used instead of
/// the plugin defaults. If empty, default parameters from the plugin are used.
///
/// # Errors
/// Returns an error if custom parameters are provided (not yet supported until nixl_sys 0.9).
pub fn add_backend_with_params(
&mut self,
backend: &str,
custom_params: &HashMap<String, String>,
) -> Result<()> {
let backend_upper = backend.to_uppercase();
if self.available_backends.contains(&backend_upper) {
return Ok(()); return Ok(());
} }
let backend_upper = backend.to_uppercase();
match self.agent.get_plugin_params(&backend_upper) { // TODO(DIS-1310): Custom params require nixl_sys 0.9+ which adds nixl_capi_params_add
Ok((_, params)) => match self.agent.create_backend(&backend_upper, &params) { if !custom_params.is_empty() {
Ok(_) => { anyhow::bail!(
self.available_backends.insert(backend_upper); "Custom NIXL backend parameters for {} are not yet supported. \
} This feature requires nixl_sys 0.9+. Params provided: {:?}",
Err(e) => { backend_upper,
anyhow::bail!("Failed to create nixl backend: {}", e); custom_params.keys().collect::<Vec<_>>()
} );
}, }
Err(_) => {
anyhow::bail!("No {} plugin found", backend_upper); // Get default params from plugin
let (_, params) = match self.agent.get_plugin_params(&backend_upper) {
Ok(result) => result,
Err(_) => anyhow::bail!("No {} plugin found", backend_upper),
};
match self.agent.create_backend(&backend_upper, &params) {
Ok(_) => {
self.available_backends.insert(backend_upper);
Ok(())
} }
Err(e) => anyhow::bail!("Failed to create nixl backend: {}", e),
} }
Ok(())
} }
/// Create a NIXL agent requiring ALL specified backends to be available. /// Create a NIXL agent requiring ALL specified backends to be available.
...@@ -200,4 +240,59 @@ mod tests { ...@@ -200,4 +240,59 @@ mod tests {
let result = NixlAgent::with_backends("test_strict_fail", &["UCX", "DUDE"]); let result = NixlAgent::with_backends("test_strict_fail", &["UCX", "DUDE"]);
assert!(result.is_err()); assert!(result.is_err());
} }
#[test]
fn test_add_backend_with_empty_params() {
let mut agent = NixlAgent::new("test_empty_params").expect("Failed to create agent");
// Empty params should work (uses plugin defaults)
let result = agent.add_backend_with_params("UCX", &HashMap::new());
assert!(result.is_ok());
assert!(agent.has_backend("UCX"));
}
#[test]
fn test_add_backend_with_custom_params_fails() {
let mut agent = NixlAgent::new("test_custom_params").expect("Failed to create agent");
// Custom params should fail until nixl_sys 0.9
let mut params = HashMap::new();
params.insert("some_key".to_string(), "some_value".to_string());
let result = agent.add_backend_with_params("UCX", &params);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("not yet supported"));
assert!(err_msg.contains("nixl_sys 0.9"));
assert!(err_msg.contains("some_key"));
}
#[test]
fn test_from_nixl_backend_config_with_custom_params_fails() {
// Config with custom params should fail
let mut params = HashMap::new();
params.insert("threads".to_string(), "4".to_string());
let config = NixlBackendConfig::default().with_backend_params("UCX", params);
let result = NixlAgent::from_nixl_backend_config("test_config_params", config);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("not yet supported"));
assert!(err_msg.contains("threads"));
}
#[test]
fn test_from_nixl_backend_config_with_empty_params() {
// Config with empty params should work
let config = NixlBackendConfig::default().with_backend("UCX");
let result = NixlAgent::from_nixl_backend_config("test_config_empty", config);
assert!(result.is_ok());
let agent = result.unwrap();
assert!(agent.has_backend("UCX"));
}
} }
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
//! NIXL backend configuration with Figment support. //! NIXL backend configuration with Figment support.
//! //!
//! This module provides configuration extraction for NIXL backends from //! This module provides configuration extraction for NIXL backends from
//! environment variables with the pattern: `DYN_KVBM_NIXL_BACKEND_<backend>_<key>=<value>` //! environment variables with the pattern: `DYN_KVBM_NIXL_BACKEND_<backend>=<value>`
use anyhow::{Result, bail}; use anyhow::{Result, bail};
use std::collections::HashSet; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use dynamo_config::parse_bool; use dynamo_config::parse_bool;
...@@ -19,16 +20,40 @@ use dynamo_config::parse_bool; ...@@ -19,16 +20,40 @@ use dynamo_config::parse_bool;
/// - Valid values: true/false, 1/0, on/off, yes/no (case-insensitive) /// - Valid values: true/false, 1/0, on/off, yes/no (case-insensitive)
/// - Invalid values (e.g., "maybe", "random") will cause an error /// - Invalid values (e.g., "maybe", "random") will cause an error
/// - Custom params (e.g., `DYN_KVBM_NIXL_BACKEND_UCX_PARAM1=value`) will cause an error /// - Custom params (e.g., `DYN_KVBM_NIXL_BACKEND_UCX_PARAM1=value`) will cause an error
#[derive(Debug, Clone, Default)] ///
/// # Data Structure
///
/// Uses a single HashMap where:
/// - Key presence = backend is enabled
/// - Value (inner HashMap) = backend-specific parameters (empty = defaults)
///
/// # TOML Example
///
/// ```toml
/// [backends.UCX]
/// # UCX with default params (empty map)
///
/// [backends.GDS]
/// threads = "4"
/// buffer_size = "1048576"
/// ```
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct NixlBackendConfig { pub struct NixlBackendConfig {
/// Set of enabled backends (just backend names, no custom params yet) /// Map of backend name (uppercase) -> optional parameters.
backends: HashSet<String>, ///
/// If a backend is present in the map, it's enabled.
/// The inner HashMap contains optional override parameters.
/// An empty inner map means use default parameters.
#[serde(default)]
backends: HashMap<String, HashMap<String, String>>,
} }
impl NixlBackendConfig { impl NixlBackendConfig {
/// Create a new empty configuration. /// Creates a new configuration with the given backends.
pub fn new() -> Self { ///
Self::default() /// For an empty configuration with no backends, use [`Default::default()`].
pub fn new(backends: HashMap<String, HashMap<String, String>>) -> Self {
Self { backends }
} }
/// Create configuration from environment variables. /// Create configuration from environment variables.
...@@ -40,7 +65,7 @@ impl NixlBackendConfig { ...@@ -40,7 +65,7 @@ impl NixlBackendConfig {
/// - Custom parameters are detected (not yet supported) /// - Custom parameters are detected (not yet supported)
/// - Invalid boolean values are provided (must be truthy or falsey) /// - Invalid boolean values are provided (must be truthy or falsey)
pub fn from_env() -> Result<Self> { pub fn from_env() -> Result<Self> {
let mut backends = HashSet::new(); let mut backends = HashMap::new();
// Extract all environment variables that match our pattern // Extract all environment variables that match our pattern
for (key, value) in std::env::vars() { for (key, value) in std::env::vars() {
...@@ -59,7 +84,7 @@ impl NixlBackendConfig { ...@@ -59,7 +84,7 @@ impl NixlBackendConfig {
let backend_name = remainder.to_uppercase(); let backend_name = remainder.to_uppercase();
match parse_bool(&value) { match parse_bool(&value) {
Ok(true) => { Ok(true) => {
backends.insert(backend_name); backends.insert(backend_name, HashMap::new());
} }
Ok(false) => { Ok(false) => {
// Explicitly disabled, don't add to backends // Explicitly disabled, don't add to backends
...@@ -70,39 +95,59 @@ impl NixlBackendConfig { ...@@ -70,39 +95,59 @@ impl NixlBackendConfig {
} }
} }
// Default to UCX if no backends specified
if backends.is_empty() {
backends.insert("UCX".to_string());
}
Ok(Self { backends }) Ok(Self { backends })
} }
/// Add a backend to the configuration. /// Add a backend with default parameters.
/// /// Backend name is normalized to uppercase.
/// Backend names will be converted to uppercase for consistency.
pub fn with_backend(mut self, backend: impl Into<String>) -> Self { pub fn with_backend(mut self, backend: impl Into<String>) -> Self {
self.backends.insert(backend.into().to_uppercase()); self.backends
.insert(backend.into().to_uppercase(), HashMap::new());
self self
} }
/// Get the set of enabled backends. /// Add a backend with custom parameters.
pub fn backends(&self) -> &HashSet<String> { /// Backend name is normalized to uppercase.
&self.backends pub fn with_backend_params(
mut self,
backend: impl Into<String>,
params: HashMap<String, String>,
) -> Self {
self.backends.insert(backend.into().to_uppercase(), params);
self
}
/// Get the list of enabled backend names (uppercase).
pub fn backends(&self) -> Vec<String> {
self.backends.keys().cloned().collect()
}
/// Get parameters for a specific backend.
/// Backend name is normalized to uppercase for lookup.
///
/// Returns None if the backend is not enabled.
pub fn backend_params(&self, backend: &str) -> Option<&HashMap<String, String>> {
self.backends.get(&backend.to_uppercase())
} }
/// Check if a specific backend is enabled. /// Check if a specific backend is enabled.
pub fn has_backend(&self, backend: &str) -> bool { pub fn has_backend(&self, backend: &str) -> bool {
self.backends.contains(&backend.to_uppercase()) self.backends.contains_key(&backend.to_uppercase())
} }
/// Merge another configuration into this one. /// Merge another configuration into this one.
/// ///
/// Backends from the other configuration will be added to this one. /// Backends from the other configuration will be added to this one.
/// If both have the same backend, params from `other` take precedence.
pub fn merge(mut self, other: NixlBackendConfig) -> Self { pub fn merge(mut self, other: NixlBackendConfig) -> Self {
self.backends.extend(other.backends); self.backends.extend(other.backends);
self self
} }
/// Iterate over all enabled backends and their parameters.
pub fn iter(&self) -> impl Iterator<Item = (&String, &HashMap<String, String>)> {
self.backends.iter()
}
} }
#[cfg(test)] #[cfg(test)]
...@@ -111,13 +156,19 @@ mod tests { ...@@ -111,13 +156,19 @@ mod tests {
#[test] #[test]
fn test_new_config_is_empty() { fn test_new_config_is_empty() {
let config = NixlBackendConfig::new(); let config = NixlBackendConfig::default();
assert!(config.backends().is_empty()); assert_eq!(config.backends().len(), 0);
}
#[test]
fn test_default_is_empty() {
let config = NixlBackendConfig::default();
assert!(config.backends().is_empty()); // default() has no backends
} }
#[test] #[test]
fn test_with_backend() { fn test_with_backend() {
let config = NixlBackendConfig::new() let config = NixlBackendConfig::default()
.with_backend("ucx") .with_backend("ucx")
.with_backend("gds_mt"); .with_backend("gds_mt");
...@@ -128,10 +179,30 @@ mod tests { ...@@ -128,10 +179,30 @@ mod tests {
assert!(!config.has_backend("other")); assert!(!config.has_backend("other"));
} }
#[test]
fn test_with_backend_params() {
let mut params = HashMap::new();
params.insert("threads".to_string(), "4".to_string());
params.insert("buffer_size".to_string(), "1048576".to_string());
let config = NixlBackendConfig::default()
.with_backend("UCX")
.with_backend_params("GDS", params);
// UCX should have empty params
let ucx_params = config.backend_params("UCX").unwrap();
assert!(ucx_params.is_empty());
// GDS should have custom params
let gds_params = config.backend_params("GDS").unwrap();
assert_eq!(gds_params.get("threads"), Some(&"4".to_string()));
assert_eq!(gds_params.get("buffer_size"), Some(&"1048576".to_string()));
}
#[test] #[test]
fn test_merge_configs() { fn test_merge_configs() {
let config1 = NixlBackendConfig::new().with_backend("ucx"); let config1 = NixlBackendConfig::default().with_backend("ucx");
let config2 = NixlBackendConfig::new().with_backend("gds"); let config2 = NixlBackendConfig::default().with_backend("gds");
let merged = config1.merge(config2); let merged = config1.merge(config2);
...@@ -141,7 +212,7 @@ mod tests { ...@@ -141,7 +212,7 @@ mod tests {
#[test] #[test]
fn test_backend_name_case_insensitive() { fn test_backend_name_case_insensitive() {
let config = NixlBackendConfig::new() let config = NixlBackendConfig::default()
.with_backend("ucx") .with_backend("ucx")
.with_backend("Gds_mt") .with_backend("Gds_mt")
.with_backend("OTHER"); .with_backend("OTHER");
...@@ -154,6 +225,19 @@ mod tests { ...@@ -154,6 +225,19 @@ mod tests {
assert!(config.has_backend("other")); assert!(config.has_backend("other"));
} }
#[test]
fn test_iter() {
let mut params = HashMap::new();
params.insert("key".to_string(), "value".to_string());
let config = NixlBackendConfig::default()
.with_backend("UCX")
.with_backend_params("GDS", params);
let items: Vec<_> = config.iter().collect();
assert_eq!(items.len(), 2);
}
// Note: Testing from_env() would require setting environment variables, // Note: Testing from_env() would require setting environment variables,
// which is challenging in unit tests. This is better tested with integration tests. // which is challenging in unit tests. This is better tested with integration tests.
} }
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NUMA-aware memory allocation utilities.
//!
//! This module provides utilities for NUMA-aware memory allocation, which is critical
//! for optimal performance on multi-socket systems with GPUs. Memory allocated on the
//! NUMA node closest to the target GPU has significantly lower access latency.
//!
//! ## Architecture
//!
//! - [`NumaNode`]: Represents a NUMA node ID
//! - [`topology`]: Reads CPU-to-NUMA mapping from `/sys/devices/system/node`
//! - [`worker_pool`]: Dedicated worker threads pinned to specific NUMA nodes
//!
//! ## Usage
//!
//! NUMA optimization is opt-in via environment variable:
//! ```bash
//! export DYN_KVBM_ENABLE_NUMA=1
//! ```
//!
//! When enabled, pinned memory allocations are routed through NUMA workers
//! that are pinned to the target GPU's NUMA node, ensuring first-touch policy
//! places pages on the correct node.
pub mod topology;
pub mod worker_pool;
use nix::libc;
use serde::{Deserialize, Serialize};
use std::{mem, process::Command};
/// Check if NUMA optimization is enabled via environment variable
///
/// Set `DYN_KVBM_ENABLE_NUMA=1` to enable NUMA-aware allocation.
/// Default: disabled (opt-in)
pub fn is_numa_enabled() -> bool {
std::env::var("DYN_KVBM_ENABLE_NUMA")
.map(|v| v == "1" || v.to_lowercase() == "true")
.unwrap_or(false)
}
/// Represents a NUMA node identifier.
///
/// NUMA nodes are typically numbered 0, 1, 2, etc. corresponding to physical
/// CPU sockets. Use [`NumaNode::UNKNOWN`] when the node cannot be determined.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NumaNode(pub u32);
impl NumaNode {
/// Sentinel value for unknown NUMA node.
pub const UNKNOWN: NumaNode = NumaNode(u32::MAX);
/// Returns true if this represents an unknown NUMA node.
pub fn is_unknown(&self) -> bool {
self.0 == u32::MAX
}
}
impl std::fmt::Display for NumaNode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.is_unknown() {
write!(f, "UNKNOWN")
} else {
write!(f, "NumaNode({})", self.0)
}
}
}
/// Get the current CPU's NUMA node.
///
/// Uses the Linux `getcpu` syscall to determine which NUMA node the current CPU belongs to.
/// Returns [`NumaNode::UNKNOWN`] if the syscall fails.
pub fn get_current_cpu_numa_node() -> NumaNode {
unsafe {
let mut cpu: libc::c_uint = 0;
let mut node: libc::c_uint = 0;
// getcpu syscall: int getcpu(unsigned *cpu, unsigned *node, struct getcpu_cache *tcache);
let result = libc::syscall(
libc::SYS_getcpu,
&mut cpu,
&mut node,
std::ptr::null_mut::<libc::c_void>(),
);
if result == 0 {
NumaNode(node)
} else {
NumaNode::UNKNOWN
}
}
}
/// Get NUMA node for a GPU device.
///
/// For GPU memory, the NUMA affinity depends on which PCIe bus the GPU is attached to.
/// This is queried via nvidia-smi. Falls back to a heuristic (device_id % 2) if nvidia-smi
/// is unavailable.
///
/// # Arguments
/// * `device_id` - CUDA device index (0, 1, 2, ...)
///
/// # Returns
/// The NUMA node closest to the specified GPU, or a heuristic fallback.
pub fn get_device_numa_node(device_id: u32) -> NumaNode {
// Use nvidia-smi topo to get NUMA ID of nearest CPU
// This directly returns the NUMA node
let output = match Command::new("nvidia-smi")
.args([
"topo",
"--get-numa-id-of-nearby-cpu",
"-i",
&device_id.to_string(),
])
.output()
{
Ok(out) if out.status.success() => out,
_ => {
tracing::warn!("nvidia-smi failed for GPU {}, using heuristic", device_id);
return NumaNode(device_id % 2);
}
};
if let Ok(stdout) = std::str::from_utf8(&output.stdout)
&& let Some(line) = stdout.lines().next()
&& let Some(numa_str) = line.split(':').nth(1)
&& let Ok(node) = numa_str.trim().parse::<u32>()
{
tracing::trace!("GPU {} on NUMA node {}", device_id, node);
return NumaNode(node);
}
tracing::warn!("Failed to get NUMA node for GPU {}", device_id);
NumaNode::UNKNOWN
}
/// Pin the current thread to a specific NUMA node's CPUs.
///
/// This sets the CPU affinity for the calling thread to only run on CPUs
/// belonging to the specified NUMA node. This is critical for ensuring
/// that memory allocations follow the first-touch policy on the correct node.
///
/// # Arguments
/// * `node` - The NUMA node to pin the thread to
///
/// # Errors
/// Returns an error if:
/// - NUMA topology cannot be read
/// - No CPUs are found for the specified node
/// - The `sched_setaffinity` syscall fails
pub fn pin_thread_to_numa_node(node: NumaNode) -> Result<(), String> {
let topology =
topology::get_numa_topology().map_err(|e| format!("Can not get NUMA topology: {}", e))?;
let cpus = topology
.cpus_for_node(node.0)
.ok_or_else(|| format!("No CPUs found for NUMA node {}", node.0))?;
if cpus.is_empty() {
return Err(format!("No CPUs found for NUMA node {}", node.0));
}
unsafe {
let mut cpu_set: libc::cpu_set_t = mem::zeroed();
for cpu in cpus {
libc::CPU_SET(*cpu, &mut cpu_set);
}
let result = libc::sched_setaffinity(
0, // current thread
mem::size_of::<libc::cpu_set_t>(),
&cpu_set,
);
if result != 0 {
let err = std::io::Error::last_os_error();
return Err(format!("Failed to set CPU affinity: {}", err));
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_numa_node_equality() {
let node0a = NumaNode(0);
let node0b = NumaNode(0);
let node1 = NumaNode(1);
assert_eq!(node0a, node0b);
assert_ne!(node0a, node1);
}
#[test]
fn test_numa_node_unknown() {
let unknown = NumaNode::UNKNOWN;
assert!(unknown.is_unknown());
assert_eq!(unknown.0, u32::MAX);
let valid = NumaNode(0);
assert!(!valid.is_unknown());
}
#[test]
fn test_numa_node_display() {
assert_eq!(format!("{}", NumaNode(0)), "NumaNode(0)");
assert_eq!(format!("{}", NumaNode(7)), "NumaNode(7)");
assert_eq!(format!("{}", NumaNode::UNKNOWN), "UNKNOWN");
}
#[test]
fn test_numa_node_serialization() {
// Verify NumaNode can be serialized (important for benchmarking)
let node = NumaNode(1);
let json = serde_json::to_string(&node).unwrap();
let deserialized: NumaNode = serde_json::from_str(&json).unwrap();
assert_eq!(node, deserialized);
}
#[test]
fn test_get_current_cpu_numa_node() {
// Should either return a valid node or UNKNOWN
let node = get_current_cpu_numa_node();
// If not unknown, should be a reasonable NUMA node number (< 8 on most systems)
if !node.is_unknown() {
assert!(node.0 < 8, "NUMA node {} seems unreasonably high", node.0);
}
}
#[test]
fn test_get_device_numa_node_valid_gpu() {
// Test GPU 0 detection
let node = get_device_numa_node(0);
// Should return either a valid node (0-7) or use heuristic (gpu_id % 2)
// On dual-socket systems, GPU 0 typically on node 0 or 1
println!("GPU 0 detected on NUMA node: {}", node.0);
}
#[test]
fn test_numa_node_hash() {
// Verify NumaNode can be used as a HashMap key
use std::collections::HashMap;
let mut map = HashMap::new();
map.insert(NumaNode(0), "node0");
map.insert(NumaNode(1), "node1");
assert_eq!(map.get(&NumaNode(0)), Some(&"node0"));
assert_eq!(map.get(&NumaNode(1)), Some(&"node1"));
assert_eq!(map.get(&NumaNode(2)), None);
}
#[test]
fn test_numa_node_copy_clone() {
// Verify NumaNode is Copy and Clone
let node1 = NumaNode(5);
let node2 = node1; // Copy
let node3 = node1; // Clone
assert_eq!(node1, node2);
assert_eq!(node1, node3);
assert_eq!(node2, node3);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NUMA topology detection
//!
//! This module provides utilities to read the actual CPU-to-NUMA mapping from the system,
//! replacing heuristic assumptions with real topology data.
use std::collections::HashMap;
use std::fs;
/// Global cached topology
static TOPOLOGY: std::sync::OnceLock<Result<NumaTopology, String>> = std::sync::OnceLock::new();
/// Represents the CPU topology for NUMA nodes.
///
/// This struct provides bidirectional lookup between NUMA nodes and CPUs,
/// read from the Linux sysfs interface at `/sys/devices/system/node/`.
pub struct NumaTopology {
/// Maps NUMA node ID -> list of CPU IDs
node_to_cpus: HashMap<u32, Vec<usize>>,
/// Maps CPU ID -> NUMA node ID
cpu_to_node: HashMap<usize, u32>,
}
impl NumaTopology {
/// Read NUMA topology from sysfs.
///
/// Parses `/sys/devices/system/node/node*/cpulist` to build the CPU-to-NUMA mapping.
pub fn from_sysfs() -> Result<Self, String> {
let mut node_to_cpus: HashMap<u32, Vec<usize>> = HashMap::new();
let mut cpu_to_node: HashMap<usize, u32> = HashMap::new();
let node_dir = std::path::Path::new("/sys/devices/system/node");
if !node_dir.exists() {
return Err("Node directory not found".to_string());
}
let entries =
fs::read_dir(node_dir).map_err(|e| format!("Failed to read node directory: {}", e))?;
for entry in entries.flatten() {
let path = entry.path();
let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
// Only process "nodeN" directories
if !name.starts_with("node") {
continue;
}
// Extract node number
let node_id: u32 = name[4..]
.parse()
.map_err(|_| format!("Invalid node dir: {}", name))?;
// Read cpulist file
let cpulist_path = path.join("cpulist");
if !cpulist_path.exists() {
continue;
}
let cpulist = fs::read_to_string(&cpulist_path)
.map_err(|e| format!("Failed to read {}: {}", cpulist_path.display(), e))?;
let cpus = parse_cpulist(cpulist.trim())?;
// Populate both maps
for cpu in &cpus {
cpu_to_node.insert(*cpu, node_id);
}
node_to_cpus.insert(node_id, cpus);
}
if node_to_cpus.is_empty() {
return Err("No NUMA nodes found".to_string());
}
Ok(Self {
node_to_cpus,
cpu_to_node,
})
}
/// Returns all CPU IDs belonging to the given NUMA node.
///
/// Returns `None` if the node ID is not in the topology.
pub fn cpus_for_node(&self, node_id: u32) -> Option<&[usize]> {
self.node_to_cpus.get(&node_id).map(|v| v.as_slice())
}
/// Returns the NUMA node ID that contains the given CPU.
///
/// Returns `None` if the CPU ID is not in the topology.
pub fn node_for_cpu(&self, cpu_id: usize) -> Option<u32> {
self.cpu_to_node.get(&cpu_id).copied()
}
/// Returns the number of NUMA nodes in the system.
pub fn num_nodes(&self) -> usize {
self.node_to_cpus.len()
}
/// Returns `true` if this is a single-node (non-NUMA) system.
pub fn is_single_node(&self) -> bool {
self.num_nodes() == 1
}
}
/// Parse Linux cpulist format.
///
/// # Examples
/// - `"0-15"` -> `[0,1,2,...,15]`
/// - `"0,4,8"` -> `[0,4,8]`
/// - `"0-3,8-11"` -> `[0,1,2,3,8,9,10,11]`
fn parse_cpulist(cpulist: &str) -> Result<Vec<usize>, String> {
let mut cpus = Vec::new();
for part in cpulist.split(',') {
if part.contains('-') {
// Range: "0-15"
let range: Vec<&str> = part.split('-').collect();
if range.len() != 2 {
return Err(format!("Invalid range: {}", part));
}
let start: usize = range[0]
.parse()
.map_err(|_| format!("Invalid CPU ID: {}", range[0]))?;
let end: usize = range[1]
.parse()
.map_err(|_| format!("Invalid CPU ID: {}", range[1]))?;
for cpu in start..=end {
cpus.push(cpu);
}
} else {
// Single CPU
let cpu: usize = part
.parse()
.map_err(|_| format!("Invalid CPU ID: {}", part))?;
cpus.push(cpu);
}
}
cpus.sort_unstable();
cpus.dedup();
Ok(cpus)
}
/// Get the global NUMA topology (cached after first call).
///
/// Returns an error if NUMA topology cannot be read from sysfs. This indicates either:
/// - System doesn't support NUMA
/// - `/sys` is not mounted (e.g., restricted container)
/// - Kernel NUMA support is disabled
///
/// Callers should handle errors gracefully by disabling NUMA optimizations.
pub fn get_numa_topology() -> Result<&'static NumaTopology, &'static str> {
TOPOLOGY
.get_or_init(NumaTopology::from_sysfs)
.as_ref()
.map_err(|e| {
tracing::warn!("NUMA topology unavailable: {}", e);
"NUMA topology unavailable"
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_cpulist_range() {
let cpus = parse_cpulist("0-3").unwrap();
assert_eq!(cpus, vec![0, 1, 2, 3]);
}
#[test]
fn test_parse_cpulist_list() {
let cpus = parse_cpulist("0,4,8").unwrap();
assert_eq!(cpus, vec![0, 4, 8]);
}
#[test]
fn test_parse_cpulist_mixed() {
let cpus = parse_cpulist("0-2,8,16-17").unwrap();
assert_eq!(cpus, vec![0, 1, 2, 8, 16, 17]);
}
#[test]
fn test_parse_cpulist_ht() {
// Hyperthreading: 0-15,32-47 (physical cores 0-15, HT siblings 32-47)
let cpus = parse_cpulist("0-15,32-47").unwrap();
assert_eq!(cpus.len(), 32);
assert_eq!(cpus[0], 0);
assert_eq!(cpus[15], 15);
assert_eq!(cpus[16], 32);
assert_eq!(cpus[31], 47);
}
#[test]
fn test_parse_cpulist_real_numa_system() {
// Real dual-socket system with hyperthreading (discovered pattern)
// Node 0: CPUs 0-15, 128-143
let cpus = parse_cpulist("0-15,128-143").unwrap();
assert_eq!(cpus.len(), 32);
assert_eq!(cpus[0], 0);
assert_eq!(cpus[15], 15);
assert_eq!(cpus[16], 128);
assert_eq!(cpus[31], 143);
// Node 1: CPUs 16-31, 144-159
let cpus = parse_cpulist("16-31,144-159").unwrap();
assert_eq!(cpus.len(), 32);
assert_eq!(cpus[0], 16);
assert_eq!(cpus[15], 31);
assert_eq!(cpus[16], 144);
assert_eq!(cpus[31], 159);
}
#[test]
fn test_parse_cpulist_out_of_order() {
// Test that parser handles out-of-order input (seen in some systems)
let cpus = parse_cpulist("4,2,0,1,3").unwrap();
assert_eq!(cpus, vec![0, 1, 2, 3, 4]); // Should be sorted
}
#[test]
fn test_parse_cpulist_duplicates() {
// Test deduplication (in case kernel reports duplicates)
let cpus = parse_cpulist("0-2,1-3").unwrap();
assert_eq!(cpus, vec![0, 1, 2, 3]); // Should remove duplicates
}
#[test]
fn test_parse_cpulist_empty() {
// Edge case: empty cpulist
let result = parse_cpulist("");
assert!(result.is_err() || result.unwrap().is_empty());
}
#[test]
fn test_parse_cpulist_single_cpu() {
// Single CPU node (uncommon but valid)
let cpus = parse_cpulist("5").unwrap();
assert_eq!(cpus, vec![5]);
}
#[test]
fn test_topology_bidirectional_lookup() {
// Test that node->cpu and cpu->node mappings are consistent
let mut node_to_cpus = std::collections::HashMap::new();
let mut cpu_to_node = std::collections::HashMap::new();
node_to_cpus.insert(0, vec![0, 1, 2, 3]);
node_to_cpus.insert(1, vec![4, 5, 6, 7]);
for (node, cpus) in &node_to_cpus {
for cpu in cpus {
cpu_to_node.insert(*cpu, *node);
}
}
let topology = NumaTopology {
node_to_cpus,
cpu_to_node,
};
// Verify forward lookup (node -> cpus)
assert_eq!(topology.cpus_for_node(0), Some(&[0, 1, 2, 3][..]));
assert_eq!(topology.cpus_for_node(1), Some(&[4, 5, 6, 7][..]));
// Verify reverse lookup (cpu -> node)
assert_eq!(topology.node_for_cpu(0), Some(0));
assert_eq!(topology.node_for_cpu(3), Some(0));
assert_eq!(topology.node_for_cpu(4), Some(1));
assert_eq!(topology.node_for_cpu(7), Some(1));
// Verify unknown CPU
assert_eq!(topology.node_for_cpu(999), None);
}
}
This diff is collapsed.
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::{ use super::{
Any, Buffer, MemoryDescription, Result, StorageError, StorageKind, nixl::NixlDescriptor, Any, Buffer, MemoryDescriptor, Result, StorageError, StorageKind, nixl::NixlDescriptor,
}; };
/// An [`OffsetBuffer`] is a new [`Buffer`]-like object that represents a sub-region (still contiguous) /// An [`OffsetBuffer`] is a new [`Buffer`]-like object that represents a sub-region (still contiguous)
...@@ -40,6 +40,26 @@ impl OffsetBuffer { ...@@ -40,6 +40,26 @@ impl OffsetBuffer {
Ok(Self { base, offset, size }) Ok(Self { base, offset, size })
} }
/// Creates an offset buffer from an absolute address within the base region.
pub fn from_inner_address(base: Buffer, address: usize, size: usize) -> Result<Self> {
// Use checked arithmetic to prevent overflow
let end = address
.checked_add(size)
.ok_or_else(|| StorageError::Unsupported("address + size overflow".into()))?;
let base_end = base
.addr()
.checked_add(base.size())
.ok_or_else(|| StorageError::Unsupported("base address + size overflow".into()))?;
// Verify address is within the base region
if address < base.addr() || end > base_end {
return Err(StorageError::Unsupported("address out of bounds".into()));
}
let offset = address - base.addr();
Self::new(base, offset, size)
}
/// Get the offset relative to the base mapping. /// Get the offset relative to the base mapping.
pub fn offset(&self) -> usize { pub fn offset(&self) -> usize {
self.offset self.offset
...@@ -51,7 +71,7 @@ impl OffsetBuffer { ...@@ -51,7 +71,7 @@ impl OffsetBuffer {
} }
} }
impl MemoryDescription for OffsetBuffer { impl MemoryDescriptor for OffsetBuffer {
fn addr(&self) -> usize { fn addr(&self) -> usize {
self.base.addr() + self.offset self.base.addr() + self.offset
} }
...@@ -75,3 +95,189 @@ impl MemoryDescription for OffsetBuffer { ...@@ -75,3 +95,189 @@ impl MemoryDescription for OffsetBuffer {
Some(descriptor) Some(descriptor)
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::SystemStorage;
fn create_test_buffer(size: usize) -> Buffer {
Buffer::new(SystemStorage::new(size).expect("allocation failed"))
}
#[test]
fn test_offset_buffer_new_valid() {
let base = create_test_buffer(1024);
let offset_buf = OffsetBuffer::new(base, 100, 200).expect("should succeed");
assert_eq!(offset_buf.offset(), 100);
assert_eq!(offset_buf.size(), 200);
}
#[test]
fn test_offset_buffer_new_zero_offset() {
let base = create_test_buffer(1024);
let offset_buf = OffsetBuffer::new(base.clone(), 0, 1024).expect("should succeed");
assert_eq!(offset_buf.offset(), 0);
assert_eq!(offset_buf.size(), 1024);
assert_eq!(offset_buf.addr(), base.addr());
}
#[test]
fn test_offset_buffer_new_at_end() {
let base = create_test_buffer(1024);
// Offset at exact end with zero size should succeed
let offset_buf = OffsetBuffer::new(base, 1024, 0).expect("should succeed");
assert_eq!(offset_buf.offset(), 1024);
assert_eq!(offset_buf.size(), 0);
}
#[test]
fn test_offset_buffer_new_invalid_offset() {
let base = create_test_buffer(1024);
// Offset beyond bounds
let result = OffsetBuffer::new(base, 1025, 0);
assert!(result.is_err());
}
#[test]
fn test_offset_buffer_new_invalid_size() {
let base = create_test_buffer(1024);
// Size exceeds remaining space
let result = OffsetBuffer::new(base, 100, 1000);
assert!(result.is_err());
}
#[test]
fn test_offset_buffer_new_size_overflow() {
let base = create_test_buffer(1024);
// offset + size would overflow usize
let result = OffsetBuffer::new(base, usize::MAX, 1);
assert!(result.is_err());
}
#[test]
fn test_offset_buffer_from_inner_address_valid() {
let base = create_test_buffer(1024);
let base_addr = base.addr();
let offset_buf =
OffsetBuffer::from_inner_address(base, base_addr + 100, 200).expect("should succeed");
assert_eq!(offset_buf.offset(), 100);
assert_eq!(offset_buf.size(), 200);
}
#[test]
fn test_offset_buffer_from_inner_address_at_start() {
let base = create_test_buffer(1024);
let base_addr = base.addr();
let offset_buf = OffsetBuffer::from_inner_address(base.clone(), base_addr, 1024)
.expect("should succeed");
assert_eq!(offset_buf.offset(), 0);
assert_eq!(offset_buf.addr(), base.addr());
}
#[test]
fn test_offset_buffer_from_inner_address_overflow() {
let base = create_test_buffer(1024);
// address + size would overflow
let result = OffsetBuffer::from_inner_address(base, usize::MAX, 1);
assert!(result.is_err());
}
#[test]
fn test_offset_buffer_from_inner_address_out_of_bounds_before() {
let base = create_test_buffer(1024);
let base_addr = base.addr();
// Address before base region
let result = OffsetBuffer::from_inner_address(base, base_addr.saturating_sub(1), 100);
assert!(result.is_err());
}
#[test]
fn test_offset_buffer_from_inner_address_out_of_bounds_after() {
let base = create_test_buffer(1024);
let base_addr = base.addr();
// End address beyond base region
let result = OffsetBuffer::from_inner_address(base, base_addr + 900, 200);
assert!(result.is_err());
}
#[test]
fn test_offset_buffer_accessors() {
let base = create_test_buffer(1024);
let base_addr = base.addr();
let offset_buf = OffsetBuffer::new(base, 256, 512).expect("should succeed");
assert_eq!(offset_buf.offset(), 256);
assert_eq!(offset_buf.base().addr(), base_addr);
assert_eq!(offset_buf.base().size(), 1024);
}
#[test]
fn test_offset_buffer_memory_descriptor_addr() {
let base = create_test_buffer(1024);
let base_addr = base.addr();
let offset_buf = OffsetBuffer::new(base, 100, 200).expect("should succeed");
// addr() should return base_addr + offset
assert_eq!(offset_buf.addr(), base_addr + 100);
}
#[test]
fn test_offset_buffer_memory_descriptor_size() {
let base = create_test_buffer(1024);
let offset_buf = OffsetBuffer::new(base, 100, 200).expect("should succeed");
assert_eq!(offset_buf.size(), 200);
}
#[test]
fn test_offset_buffer_memory_descriptor_storage_kind() {
let base = create_test_buffer(1024);
let base_kind = base.storage_kind();
let offset_buf = OffsetBuffer::new(base, 100, 200).expect("should succeed");
// storage_kind should match the base
assert_eq!(offset_buf.storage_kind(), base_kind);
}
#[test]
fn test_offset_buffer_as_any() {
let base = create_test_buffer(1024);
let offset_buf = OffsetBuffer::new(base, 100, 200).expect("should succeed");
// Should be able to downcast to OffsetBuffer
let any_ref = offset_buf.as_any();
assert!(any_ref.downcast_ref::<OffsetBuffer>().is_some());
}
#[test]
fn test_offset_buffer_clone() {
let base = create_test_buffer(1024);
let offset_buf = OffsetBuffer::new(base, 100, 200).expect("should succeed");
let cloned = offset_buf.clone();
assert_eq!(offset_buf.addr(), cloned.addr());
assert_eq!(offset_buf.size(), cloned.size());
assert_eq!(offset_buf.offset(), cloned.offset());
}
#[test]
fn test_offset_buffer_debug() {
let base = create_test_buffer(1024);
let offset_buf = OffsetBuffer::new(base, 100, 200).expect("should succeed");
let debug_str = format!("{:?}", offset_buf);
assert!(debug_str.contains("OffsetBuffer"));
assert!(debug_str.contains("offset"));
assert!(debug_str.contains("size"));
}
#[test]
fn test_offset_buffer_nixl_descriptor_none() {
// SystemStorage doesn't have a NIXL descriptor
let base = create_test_buffer(1024);
let offset_buf = OffsetBuffer::new(base, 100, 200).expect("should succeed");
// Should return None since base has no NIXL descriptor
assert!(offset_buf.nixl_descriptor().is_none());
}
}
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
//! CUDA pinned host memory storage. //! CUDA pinned host memory storage.
use super::{MemoryDescription, Result, StorageError, StorageKind, actions, nixl::NixlDescriptor}; use super::{MemoryDescriptor, Result, StorageError, StorageKind, actions, nixl::NixlDescriptor};
use cudarc::driver::CudaContext; use cudarc::driver::CudaContext;
use cudarc::driver::sys; use cudarc::driver::sys;
use std::any::Any; use std::any::Any;
...@@ -27,8 +27,11 @@ fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> { ...@@ -27,8 +27,11 @@ fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
/// CUDA pinned host memory allocated via cudaHostAlloc. /// CUDA pinned host memory allocated via cudaHostAlloc.
#[derive(Debug)] #[derive(Debug)]
pub struct PinnedStorage { pub struct PinnedStorage {
/// Host pointer to the pinned memory.
ptr: usize, ptr: usize,
/// Size of the allocation in bytes.
len: usize, len: usize,
/// CUDA context used for allocation and deallocation.
ctx: Arc<CudaContext>, ctx: Arc<CudaContext>,
} }
...@@ -38,29 +41,73 @@ unsafe impl Sync for PinnedStorage {} ...@@ -38,29 +41,73 @@ unsafe impl Sync for PinnedStorage {}
impl PinnedStorage { impl PinnedStorage {
/// Allocate new pinned memory of the given size. /// Allocate new pinned memory of the given size.
/// ///
/// This is a convenience method that calls `new_for_device(len, None)`.
///
/// # Arguments /// # Arguments
/// * `len` - Size in bytes to allocate /// * `len` - Size in bytes to allocate
/// * `device_id` - CUDA device to associate with the allocation
pub fn new(len: usize) -> Result<Self> { pub fn new(len: usize) -> Result<Self> {
Self::new_for_device(len, None)
}
/// Allocate pinned memory, optionally NUMA-aware for a specific GPU.
///
/// When `device_id` is `Some`, the allocation is performed on a worker thread
/// pinned to the GPU's NUMA node, ensuring optimal memory placement via
/// first-touch policy, However, NUMA is only used if enabled via the
/// `DYN_KVBM_ENABLE_NUMA=1` environment variable.
///
/// When `device_id` is `None`, a direct allocation is performed on device 0.
///
/// # Arguments
/// * `len` - Size in bytes to allocate
/// * `device_id` - If Some, use NUMA-aware allocation on the GPU's NUMA node
///
/// # Errors
/// Returns an error if:
/// - `len` is 0
/// - CUDA context creation fails
/// - Memory allocation fails
pub fn new_for_device(len: usize, device_id: Option<u32>) -> Result<Self> {
use super::numa;
if len == 0 { if len == 0 {
return Err(StorageError::AllocationFailed( return Err(StorageError::AllocationFailed(
"zero-sized allocations are not supported".into(), "zero-sized allocations are not supported".into(),
)); ));
} }
let ctx = cuda_context(0)?; let gpu_id = device_id.unwrap_or(0);
let ptr = unsafe { let ctx = cuda_context(gpu_id)?;
ctx.bind_to_thread().map_err(StorageError::Cuda)?;
let ptr = match device_id {
let ptr = cudarc::driver::result::malloc_host(len, sys::CU_MEMHOSTALLOC_WRITECOMBINED) Some(gpu_id) if numa::is_numa_enabled() => {
.map_err(StorageError::Cuda)?; // NUMA-aware allocation via worker pool
tracing::debug!(
let ptr = ptr as *mut u8; "Using NUMA-aware allocation for {} bytes on GPU {}",
assert!(!ptr.is_null(), "Failed to allocate pinned memory"); len,
assert!(ptr.is_aligned(), "Pinned memory is not aligned"); gpu_id
assert!(len < isize::MAX as usize); );
numa::worker_pool::NumaWorkerPool::global()
ptr as usize .allocate_pinned_for_gpu(len, gpu_id)
.map_err(StorageError::AllocationFailed)? as usize
}
_ => {
// Direct allocation (no NUMA or device_id not specified)
unsafe {
ctx.bind_to_thread().map_err(StorageError::Cuda)?;
let ptr =
cudarc::driver::result::malloc_host(len, sys::CU_MEMHOSTALLOC_DEVICEMAP)
.map_err(StorageError::Cuda)?;
let ptr = ptr as *mut u8;
assert!(!ptr.is_null(), "Failed to allocate pinned memory");
assert!(ptr.is_aligned(), "Pinned memory is not aligned");
assert!(len < isize::MAX as usize);
ptr as usize
}
}
}; };
Ok(Self { ptr, len, ctx }) Ok(Self { ptr, len, ctx })
...@@ -97,7 +144,7 @@ impl Drop for PinnedStorage { ...@@ -97,7 +144,7 @@ impl Drop for PinnedStorage {
} }
} }
impl MemoryDescription for PinnedStorage { impl MemoryDescriptor for PinnedStorage {
fn addr(&self) -> usize { fn addr(&self) -> usize {
unsafe { self.as_ptr() as usize } unsafe { self.as_ptr() as usize }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//! CUDA memory pool for efficient device memory allocation in hot paths. //! CUDA memory pool for efficient device memory allocation in hot paths.
...@@ -6,6 +6,12 @@ ...@@ -6,6 +6,12 @@
//! This module provides a safe wrapper around CUDA's memory pool APIs, enabling //! This module provides a safe wrapper around CUDA's memory pool APIs, enabling
//! fast async allocations that avoid the overhead of cudaMalloc/cudaFree per call. //! fast async allocations that avoid the overhead of cudaMalloc/cudaFree per call.
//! Memory is returned to the pool on free and reused for subsequent allocations. //! Memory is returned to the pool on free and reused for subsequent allocations.
//!
//! # Thread Safety
//!
//! [`CudaMemPool`] uses internal locking to serialize host-side calls to the CUDA
//! driver. This is required because `cuMemAllocFromPoolAsync` is not host-thread
//! reentrant. The GPU-side operations remain stream-ordered and asynchronous.
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use cudarc::driver::sys::{ use cudarc::driver::sys::{
......
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//! Memory pool for efficient device memory allocation in hot paths. //! Memory pool for efficient device memory allocation in hot paths.
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
pub use super::MemoryDescription; pub use super::MemoryDescriptor;
pub use super::nixl::{NixlCompatible, NixlMemory, NixlRegisterExt, RegisteredView};
pub use super::nixl::{NixlCompatible, RegisteredView}; pub use super::tensor::{TensorDescriptor, TensorDescriptorExt};
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
//! System memory storage backed by malloc. //! System memory storage backed by malloc.
use super::{MemoryDescription, Result, StorageError, StorageKind, actions, nixl::NixlDescriptor}; use super::{MemoryDescriptor, Result, StorageError, StorageKind, actions, nixl::NixlDescriptor};
use std::any::Any; use std::any::Any;
use std::ptr::NonNull; use std::ptr::NonNull;
...@@ -82,7 +82,7 @@ impl Drop for SystemStorage { ...@@ -82,7 +82,7 @@ impl Drop for SystemStorage {
} }
} }
impl MemoryDescription for SystemStorage { impl MemoryDescriptor for SystemStorage {
fn addr(&self) -> usize { fn addr(&self) -> usize {
self.ptr.as_ptr() as usize self.ptr.as_ptr() as usize
} }
...@@ -139,3 +139,130 @@ impl actions::Slice for SystemStorage { ...@@ -139,3 +139,130 @@ impl actions::Slice for SystemStorage {
Ok(unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }) Ok(unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) })
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::actions::{Memset, Slice};
#[test]
fn test_system_storage_new() {
let storage = SystemStorage::new(1024).expect("allocation should succeed");
assert_eq!(storage.size(), 1024);
assert!(storage.addr() != 0);
}
#[test]
fn test_system_storage_zero_size_fails() {
let result = SystemStorage::new(0);
assert!(result.is_err());
}
#[test]
fn test_system_storage_storage_kind() {
let storage = SystemStorage::new(1024).unwrap();
assert_eq!(storage.storage_kind(), StorageKind::System);
}
#[test]
fn test_system_storage_as_any() {
let storage = SystemStorage::new(1024).unwrap();
let any = storage.as_any();
assert!(any.downcast_ref::<SystemStorage>().is_some());
}
#[test]
fn test_system_storage_nixl_descriptor() {
let storage = SystemStorage::new(1024).unwrap();
// Unregistered storage has no NIXL descriptor
assert!(storage.nixl_descriptor().is_none());
}
#[test]
fn test_system_storage_as_ptr() {
let storage = SystemStorage::new(1024).unwrap();
unsafe {
let ptr = storage.as_ptr();
assert!(!ptr.is_null());
assert_eq!(ptr as usize, storage.addr());
}
}
#[test]
fn test_system_storage_as_mut_ptr() {
let mut storage = SystemStorage::new(1024).unwrap();
unsafe {
let ptr = storage.as_mut_ptr();
assert!(!ptr.is_null());
assert_eq!(ptr as usize, storage.addr());
// Write and read back to verify the pointer works
*ptr = 0xAB;
assert_eq!(*ptr, 0xAB);
}
}
#[test]
fn test_system_storage_zero_initialized() {
let storage = SystemStorage::new(1024).unwrap();
unsafe {
let slice = storage.as_slice().unwrap();
// Memory should be zero-initialized
assert!(slice.iter().all(|&b| b == 0));
}
}
#[test]
fn test_system_storage_memset_and_read() {
let mut storage = SystemStorage::new(1024).unwrap();
storage.memset(0xCD, 0, 1024).unwrap();
unsafe {
let slice = storage.as_slice().unwrap();
assert!(slice.iter().all(|&b| b == 0xCD));
}
}
#[test]
fn test_system_storage_multiple_allocations_independent() {
let storage1 = SystemStorage::new(512).unwrap();
let storage2 = SystemStorage::new(512).unwrap();
// Different allocations should have different addresses
assert_ne!(storage1.addr(), storage2.addr());
}
#[test]
fn test_system_storage_alignment() {
let storage = SystemStorage::new(1024).unwrap();
// posix_memalign allocates with 4096-byte alignment
assert!(storage.addr().is_multiple_of(4096));
}
#[test]
fn test_system_storage_nixl_compatible() {
use crate::nixl::NixlCompatible;
let storage = SystemStorage::new(2048).unwrap();
let (ptr, size, mem_type, device_id) = storage.nixl_params();
assert_eq!(ptr as usize, storage.addr());
assert_eq!(size, 2048);
assert_eq!(mem_type, nixl_sys::MemType::Dram);
assert_eq!(device_id, 0);
}
#[test]
fn test_system_storage_large_allocation() {
// Allocate 1MB to test larger sizes
let storage = SystemStorage::new(1024 * 1024).unwrap();
assert_eq!(storage.size(), 1024 * 1024);
}
#[test]
fn test_system_storage_debug() {
let storage = SystemStorage::new(1024).unwrap();
let debug_str = format!("{:?}", storage);
assert!(debug_str.contains("SystemStorage"));
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment