Unverified Commit 86892a05 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

refactor: dedup memory module with dynamo-memory; re-exports preserve current t… (#6345)


Co-authored-by: default avatarRyan Olson <rolson@nvidia.com>
Co-authored-by: default avatarOlga Andreeva <124622579+oandreeva-nv@users.noreply.github.com>
parent 8245633a
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
pub mod topology; //! Re-export NUMA utilities from dynamo-memory.
pub mod worker_pool; pub use dynamo_memory::numa::*;
use nix::libc;
use serde::{Deserialize, Serialize};
use std::{mem, process::Command};
/// Check if NUMA optimization is enabled via environment variable
///
/// Set `DYN_KVBM_ENABLE_NUMA=1` to enable NUMA-aware allocation.
/// Default: disabled (opt-in)
pub fn is_numa_enabled() -> bool {
std::env::var("DYN_KVBM_ENABLE_NUMA")
.map(|v| v == "1" || v.to_lowercase() == "true")
.unwrap_or(false)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NumaNode(pub u32);
impl NumaNode {
pub const UNKNOWN: NumaNode = NumaNode(u32::MAX);
pub fn is_unknown(&self) -> bool {
self.0 == u32::MAX
}
}
impl std::fmt::Display for NumaNode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.is_unknown() {
write!(f, "UNKNOWN")
} else {
write!(f, "NumaNode({})", self.0)
}
}
}
/// Get the current CPU's NUMA node
///
/// Uses the Linux `getcpu` syscall to determine which NUMA node the current CPU belongs to.
/// Returns `NumaNode::UNKNOWN` if the syscall fails.
pub fn get_current_cpu_numa_node() -> NumaNode {
unsafe {
let mut cpu: libc::c_uint = 0;
let mut node: libc::c_uint = 0;
// getcpu syscall: int getcpu(unsigned *cpu, unsigned *node, struct getcpu_cache *tcache);
let result = libc::syscall(
libc::SYS_getcpu,
&mut cpu,
&mut node,
std::ptr::null_mut::<libc::c_void>(),
);
if result == 0 {
NumaNode(node)
} else {
NumaNode::UNKNOWN
}
}
}
/// Get NUMA node for device (GPU) memory
///
/// For GPU memory, the NUMA affinity depends on which PCIe bus the GPU is attached to.
/// This can be queried via nvidia-smi.
pub fn get_device_numa_node(device_id: u32) -> NumaNode {
// Use nvidia-smi topo to get NUMA ID of nearest CPU
// This directly returns the NUMA node
let output = match Command::new("nvidia-smi")
.args([
"topo",
"--get-numa-id-of-nearby-cpu",
"-i",
&device_id.to_string(),
])
.output()
{
Ok(out) if out.status.success() => out,
_ => {
tracing::warn!("nvidia-smi failed for GPU {}, using heuristic", device_id);
return NumaNode(device_id % 2);
}
};
if let Ok(stdout) = std::str::from_utf8(&output.stdout)
&& let Some(line) = stdout.lines().next()
&& let Some(numa_str) = line.split(':').nth(1)
&& let Ok(node) = numa_str.trim().parse::<u32>()
{
tracing::trace!("GPU {} on NUMA node {}", device_id, node);
return NumaNode(node);
}
tracing::warn!("Failed to get NUMA node for GPU {}", device_id);
NumaNode::UNKNOWN
}
/// Pin the current thread to a specific NUMA node's CPUs
///
/// This sets the CPU affinity for the calling thread to only run on CPUs
/// belonging to the specified NUMA node. This is critical for ensuring
/// that memory allocations follow the first-touch policy on the correct node.
pub fn pin_thread_to_numa_node(node: NumaNode) -> Result<(), String> {
let topology =
topology::get_numa_topology().map_err(|e| format!("Can not get NUMA topology: {}", e))?;
let cpus = topology
.cpus_for_node(node.0)
.ok_or_else(|| format!("No CPUs found for NUMA node {}", node.0))?;
if cpus.is_empty() {
return Err(format!("No CPUs found for NUMA node {}", node.0));
}
unsafe {
let mut cpu_set: libc::cpu_set_t = mem::zeroed();
for cpu in cpus {
libc::CPU_SET(*cpu, &mut cpu_set);
}
let result = libc::sched_setaffinity(
0, // current thread
mem::size_of::<libc::cpu_set_t>(),
&cpu_set,
);
if result != 0 {
let err = std::io::Error::last_os_error();
return Err(format!("Failed to set CPU affinity: {}", err));
}
}
Ok(())
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
// ── NumaNode tests ──────────────────────────────────────────────────
#[test] #[test]
fn test_numa_node_equality() { fn test_numa_node_equality() {
let node0a = NumaNode(0); let node0a = NumaNode(0);
...@@ -170,37 +39,14 @@ mod tests { ...@@ -170,37 +39,14 @@ mod tests {
#[test] #[test]
fn test_numa_node_serialization() { fn test_numa_node_serialization() {
// Verify NumaNode can be serialized (important for benchmarking)
let node = NumaNode(1); let node = NumaNode(1);
let json = serde_json::to_string(&node).unwrap(); let json = serde_json::to_string(&node).unwrap();
let deserialized: NumaNode = serde_json::from_str(&json).unwrap(); let deserialized: NumaNode = serde_json::from_str(&json).unwrap();
assert_eq!(node, deserialized); assert_eq!(node, deserialized);
} }
#[test]
fn test_get_current_cpu_numa_node() {
// Should either return a valid node or UNKNOWN
let node = get_current_cpu_numa_node();
// If not unknown, should be a reasonable NUMA node number (< 8 on most systems)
if !node.is_unknown() {
assert!(node.0 < 8, "NUMA node {} seems unreasonably high", node.0);
}
}
#[test]
fn test_get_device_numa_node_valid_gpu() {
// Test GPU 0 detection
let node = get_device_numa_node(0);
// Should return either a valid node (0-7) or use heuristic (gpu_id % 2)
// On dual-socket systems, GPU 0 typically on node 0 or 1
println!("GPU 0 detected on NUMA node: {}", node.0);
}
#[test] #[test]
fn test_numa_node_hash() { fn test_numa_node_hash() {
// Verify NumaNode can be used as a HashMap key
use std::collections::HashMap; use std::collections::HashMap;
let mut map = HashMap::new(); let mut map = HashMap::new();
...@@ -214,7 +60,6 @@ mod tests { ...@@ -214,7 +60,6 @@ mod tests {
#[test] #[test]
fn test_numa_node_copy_clone() { fn test_numa_node_copy_clone() {
// Verify NumaNode is Copy and Clone
let node1 = NumaNode(5); let node1 = NumaNode(5);
let node2 = node1; // Copy let node2 = node1; // Copy
let node3 = node1; // Clone let node3 = node1; // Clone
...@@ -223,4 +68,111 @@ mod tests { ...@@ -223,4 +68,111 @@ mod tests {
assert_eq!(node1, node3); assert_eq!(node1, node3);
assert_eq!(node2, node3); assert_eq!(node2, node3);
} }
// ── System detection tests ──────────────────────────────────────────
#[test]
fn test_get_current_cpu_numa_node() {
let node = get_current_cpu_numa_node();
if !node.is_unknown() {
assert!(node.0 < 8, "NUMA node {} seems unreasonably high", node.0);
}
}
#[test]
fn test_get_device_numa_node_valid_gpu() {
let node = get_device_numa_node(0);
println!("GPU 0 detected on NUMA node: {}", node.0);
}
// ── Worker pool tests ───────────────────────────────────────────────
//
// NumaWorker and NumaWorkerPool::new() are private in dynamo-memory,
// so these tests go through the public NumaWorkerPool::global() API.
/// Check if CUDA is available for testing
fn is_cuda_available() -> bool {
if std::process::Command::new("nvidia-smi")
.arg("--query-gpu=count")
.arg("--format=csv,noheader")
.output()
.is_err()
{
return false;
}
crate::block_manager::storage::cuda::Cuda::device_or_create(0).is_ok()
}
#[test]
fn test_worker_pool_singleton() {
let pool1 = worker_pool::NumaWorkerPool::global();
let pool2 = worker_pool::NumaWorkerPool::global();
assert!(std::ptr::eq(pool1, pool2));
}
#[test]
fn test_worker_pool_allocate() {
if !is_cuda_available() {
eprintln!("Skipping test_worker_pool_allocate: CUDA not available");
return;
}
let pool = worker_pool::NumaWorkerPool::global();
unsafe {
let ptr = pool.allocate_pinned_for_gpu(8192, 0).unwrap();
assert!(!ptr.is_null());
cudarc::driver::result::free_host(ptr as *mut std::ffi::c_void).unwrap();
}
}
#[test]
fn test_worker_pool_reuse() {
if !is_cuda_available() {
eprintln!("Skipping test_worker_pool_reuse: CUDA not available");
return;
}
let pool = worker_pool::NumaWorkerPool::global();
unsafe {
let ptr1 = pool.allocate_pinned_for_gpu(1024, 0).unwrap();
let ptr2 = pool.allocate_pinned_for_gpu(1024, 0).unwrap();
assert!(!ptr1.is_null());
assert!(!ptr2.is_null());
assert_ne!(ptr1, ptr2);
cudarc::driver::result::free_host(ptr1 as *mut std::ffi::c_void).unwrap();
cudarc::driver::result::free_host(ptr2 as *mut std::ffi::c_void).unwrap();
}
}
#[test]
fn test_zero_size_allocation() {
if !is_cuda_available() {
eprintln!("Skipping test_zero_size_allocation: CUDA not available");
return;
}
let pool = worker_pool::NumaWorkerPool::global();
let result = pool.allocate_pinned_for_gpu(0, 0);
assert!(result.is_err());
assert!(result.unwrap_err().contains("zero"));
}
#[test]
fn test_pinned_allocation_api() {
let pool = worker_pool::NumaWorkerPool::global();
unsafe {
if let Ok(ptr) = pool.allocate_pinned_for_gpu(1024, 0) {
assert!(!ptr.is_null());
cudarc::driver::result::free_host(ptr as *mut std::ffi::c_void).unwrap();
}
}
}
} }
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NUMA topology detection
//!
//! This module provides utilities to read the actual CPU-to-NUMA mapping from the system,
//! replacing heuristic assumptions with real topology data.
use std::collections::HashMap;
use std::fs;
/// Global cached topology
static TOPOLOGY: std::sync::OnceLock<Result<NumaTopology, String>> = std::sync::OnceLock::new();
/// Represents the CPU topology for NUMA nodes
pub struct NumaTopology {
/// Maps NUMA node ID -> list of CPU IDs
node_to_cpus: HashMap<u32, Vec<usize>>,
/// Maps CPU ID -> NUMA node ID
cpu_to_node: HashMap<usize, u32>,
}
impl NumaTopology {
/// Read NUMA topology from sysfs
pub fn from_sysfs() -> Result<Self, String> {
let mut node_to_cpus: HashMap<u32, Vec<usize>> = HashMap::new();
let mut cpu_to_node: HashMap<usize, u32> = HashMap::new();
// TODO: Read /sys/devices/system/node directory
let node_dir = std::path::Path::new("/sys/devices/system/node");
if !node_dir.exists() {
return Err("Node directory not found".to_string());
}
let entries =
fs::read_dir(node_dir).map_err(|e| format!("Failed to read node directory: {}", e))?;
for entry in entries.flatten() {
let path = entry.path();
let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
// Only process "nodeN" directories
if !name.starts_with("node") {
continue;
}
// Extract node number
let node_id: u32 = name[4..]
.parse()
.map_err(|_| format!("Invalid node dir: {}", name))?;
// Read cpulist file
let cpulist_path = path.join("cpulist");
if !cpulist_path.exists() {
continue;
}
let cpulist = fs::read_to_string(&cpulist_path)
.map_err(|e| format!("Failed to read {}: {}", cpulist_path.display(), e))?;
let cpus = parse_cpulist(cpulist.trim())?;
// Populate both maps
for cpu in &cpus {
cpu_to_node.insert(*cpu, node_id);
}
node_to_cpus.insert(node_id, cpus);
}
if node_to_cpus.is_empty() {
return Err("No NUMA nodes found".to_string());
}
Ok(Self {
node_to_cpus,
cpu_to_node,
})
}
/// Get all CPUs for a NUMA node
pub fn cpus_for_node(&self, node_id: u32) -> Option<&[usize]> {
self.node_to_cpus.get(&node_id).map(|v| v.as_slice())
}
/// Get NUMA node for a CPU
pub fn node_for_cpu(&self, cpu_id: usize) -> Option<u32> {
self.cpu_to_node.get(&cpu_id).copied()
}
/// Get number of NUMA nodes
pub fn num_nodes(&self) -> usize {
self.node_to_cpus.len()
}
/// Check if single-node system
pub fn is_single_node(&self) -> bool {
self.num_nodes() == 1
}
}
/// Parse Linux cpulist format
/// Examples:
/// "0-15" -> [0,1,2,...,15]
/// "0,4,8" -> [0,4,8]
/// "0-3,8-11" -> [0,1,2,3,8,9,10,11]
fn parse_cpulist(cpulist: &str) -> Result<Vec<usize>, String> {
let mut cpus = Vec::new();
for part in cpulist.split(',') {
if part.contains('-') {
// Range: "0-15"
let range: Vec<&str> = part.split('-').collect();
if range.len() != 2 {
return Err(format!("Invalid range: {}", part));
}
let start: usize = range[0]
.parse()
.map_err(|_| format!("Invalid CPU ID: {}", range[0]))?;
let end: usize = range[1]
.parse()
.map_err(|_| format!("Invalid CPU ID: {}", range[1]))?;
for cpu in start..=end {
cpus.push(cpu);
}
} else {
// Single CPU
let cpu: usize = part
.parse()
.map_err(|_| format!("Invalid CPU ID: {}", part))?;
cpus.push(cpu);
}
}
cpus.sort_unstable();
cpus.dedup();
Ok(cpus)
}
/// Get the global NUMA topology (cached after first call)
///
/// Returns an error if NUMA topology cannot be read from sysfs. This indicates either:
/// - System doesn't support NUMA
/// - `/sys` is not mounted (e.g., restricted container)
/// - Kernel NUMA support is disabled
///
/// Callers should handle errors gracefully by disabling NUMA optimizations.
pub fn get_numa_topology() -> Result<&'static NumaTopology, &'static str> {
TOPOLOGY
.get_or_init(NumaTopology::from_sysfs)
.as_ref()
.map_err(|e| {
tracing::warn!("NUMA topology unavailable: {}", e);
"NUMA topology unavailable"
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_cpulist_range() {
let cpus = parse_cpulist("0-3").unwrap();
assert_eq!(cpus, vec![0, 1, 2, 3]);
}
#[test]
fn test_parse_cpulist_list() {
let cpus = parse_cpulist("0,4,8").unwrap();
assert_eq!(cpus, vec![0, 4, 8]);
}
#[test]
fn test_parse_cpulist_mixed() {
let cpus = parse_cpulist("0-2,8,16-17").unwrap();
assert_eq!(cpus, vec![0, 1, 2, 8, 16, 17]);
}
#[test]
fn test_parse_cpulist_ht() {
// Hyperthreading: 0-15,32-47 (physical cores 0-15, HT siblings 32-47)
let cpus = parse_cpulist("0-15,32-47").unwrap();
assert_eq!(cpus.len(), 32);
assert_eq!(cpus[0], 0);
assert_eq!(cpus[15], 15);
assert_eq!(cpus[16], 32);
assert_eq!(cpus[31], 47);
}
#[test]
fn test_parse_cpulist_real_numa_system() {
// Real dual-socket system with hyperthreading (discovered pattern)
// Node 0: CPUs 0-15, 128-143
let cpus = parse_cpulist("0-15,128-143").unwrap();
assert_eq!(cpus.len(), 32);
assert_eq!(cpus[0], 0);
assert_eq!(cpus[15], 15);
assert_eq!(cpus[16], 128);
assert_eq!(cpus[31], 143);
// Node 1: CPUs 16-31, 144-159
let cpus = parse_cpulist("16-31,144-159").unwrap();
assert_eq!(cpus.len(), 32);
assert_eq!(cpus[0], 16);
assert_eq!(cpus[15], 31);
assert_eq!(cpus[16], 144);
assert_eq!(cpus[31], 159);
}
#[test]
fn test_parse_cpulist_out_of_order() {
// Test that parser handles out-of-order input (seen in some systems)
let cpus = parse_cpulist("4,2,0,1,3").unwrap();
assert_eq!(cpus, vec![0, 1, 2, 3, 4]); // Should be sorted
}
#[test]
fn test_parse_cpulist_duplicates() {
// Test deduplication (in case kernel reports duplicates)
let cpus = parse_cpulist("0-2,1-3").unwrap();
assert_eq!(cpus, vec![0, 1, 2, 3]); // Should remove duplicates
}
#[test]
fn test_parse_cpulist_empty() {
// Edge case: empty cpulist
let result = parse_cpulist("");
assert!(result.is_err() || result.unwrap().is_empty());
}
#[test]
fn test_parse_cpulist_single_cpu() {
// Single CPU node (uncommon but valid)
let cpus = parse_cpulist("5").unwrap();
assert_eq!(cpus, vec![5]);
}
#[test]
fn test_topology_bidirectional_lookup() {
// Test that node->cpu and cpu->node mappings are consistent
let mut node_to_cpus = std::collections::HashMap::new();
let mut cpu_to_node = std::collections::HashMap::new();
node_to_cpus.insert(0, vec![0, 1, 2, 3]);
node_to_cpus.insert(1, vec![4, 5, 6, 7]);
for (node, cpus) in &node_to_cpus {
for cpu in cpus {
cpu_to_node.insert(*cpu, *node);
}
}
let topology = NumaTopology {
node_to_cpus,
cpu_to_node,
};
// Verify forward lookup (node -> cpus)
assert_eq!(topology.cpus_for_node(0), Some(&[0, 1, 2, 3][..]));
assert_eq!(topology.cpus_for_node(1), Some(&[4, 5, 6, 7][..]));
// Verify reverse lookup (cpu -> node)
assert_eq!(topology.node_for_cpu(0), Some(0));
assert_eq!(topology.node_for_cpu(3), Some(0));
assert_eq!(topology.node_for_cpu(4), Some(1));
assert_eq!(topology.node_for_cpu(7), Some(1));
// Verify unknown CPU
assert_eq!(topology.node_for_cpu(999), None);
}
}
...@@ -74,7 +74,7 @@ use std::{ ...@@ -74,7 +74,7 @@ use std::{
sync::{Arc, Mutex, OnceLock}, sync::{Arc, Mutex, OnceLock},
}; };
use cudarc::driver::{CudaContext, sys}; use cudarc::driver::CudaContext;
use crate::block_manager::numa_allocator; use crate::block_manager::numa_allocator;
...@@ -91,7 +91,12 @@ use crate::block_manager::numa_allocator; ...@@ -91,7 +91,12 @@ use crate::block_manager::numa_allocator;
unsafe fn malloc_host_prefer_writecombined(size: usize) -> Result<*mut u8, StorageError> { unsafe fn malloc_host_prefer_writecombined(size: usize) -> Result<*mut u8, StorageError> {
// First, try write-combined allocation (optimal for PCIe systems) // First, try write-combined allocation (optimal for PCIe systems)
// SAFETY: Caller guarantees a valid CUDA context is bound to the current thread // SAFETY: Caller guarantees a valid CUDA context is bound to the current thread
match unsafe { cudarc::driver::result::malloc_host(size, sys::CU_MEMHOSTALLOC_WRITECOMBINED) } { match unsafe {
cudarc::driver::result::malloc_host(
size,
cudarc::driver::sys::CU_MEMHOSTALLOC_WRITECOMBINED,
)
} {
Ok(ptr) => Ok(ptr as *mut u8), Ok(ptr) => Ok(ptr as *mut u8),
Err(_) => { Err(_) => {
// Write-combined not supported (e.g., Grace Hopper/Blackwell), // Write-combined not supported (e.g., Grace Hopper/Blackwell),
...@@ -205,7 +210,7 @@ impl PinnedStorage { ...@@ -205,7 +210,7 @@ impl PinnedStorage {
unsafe { unsafe {
ctx.bind_to_thread().map_err(StorageError::Cuda)?; ctx.bind_to_thread().map_err(StorageError::Cuda)?;
// Try NUMA-aware allocation if enabled, otherwise use direct allocation // Try NUMA-aware allocation if enabled, otherwise use direct allocation.
let ptr = if numa_allocator::is_numa_enabled() { let ptr = if numa_allocator::is_numa_enabled() {
let device_id = ctx.cu_device() as u32; let device_id = ctx.cu_device() as u32;
match numa_allocator::worker_pool::NumaWorkerPool::global() match numa_allocator::worker_pool::NumaWorkerPool::global()
...@@ -627,14 +632,8 @@ mod tests { ...@@ -627,14 +632,8 @@ mod tests {
} }
} }
/// Test PinnedStorage::new with NUMA disabled (the direct allocation path) /// Test PinnedStorage::new with NUMA disabled (the direct allocation path).
///
/// This test confirms that when `DYN_KVBM_ENABLE_NUMA` is not set,
/// PinnedStorage::new uses the direct malloc_host_prefer_writecombined path
/// (lines 222-224 in the source).
#[test] #[test]
// `remove_var` is not thread-safe, so we need to run this test in a serial context
// #[serial]
fn test_pinned_storage_new_without_numa() { fn test_pinned_storage_new_without_numa() {
// Verify NUMA is actually disabled for this test // Verify NUMA is actually disabled for this test
assert!( assert!(
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Storage actions.
use super::{MemoryRegion, StorageError};
/// Extension trait for storage types that support memory setting operations
pub trait Memset: MemoryRegion {
/// Sets a region of memory to a specific value
///
/// # Arguments
/// * `value` - The value to set (will be truncated to u8)
/// * `offset` - Offset in bytes from the start of the storage
/// * `size` - Number of bytes to set
///
/// # Safety
/// The caller must ensure:
/// - offset + size <= self.size()
/// - No other references exist to the memory region being set
fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<(), StorageError>;
}
/// Extension trait for storage types that support slicing operations
pub trait Slice {
/// Returns an immutable byte slice view of the entire storage region
///
/// # Safety
/// The caller must ensure:
/// - The memory region is valid and initialized
/// - No concurrent mutable access occurs while the slice is in use
fn as_slice(&self) -> Result<&[u8], StorageError>;
/// Returns an immutable byte slice view of a subregion
///
/// # Arguments
/// * `offset` - Offset in bytes from the start of the storage
/// * `len` - Number of bytes to slice
///
/// # Safety
/// The caller must ensure:
/// - offset + len <= self.size()
/// - The memory region is valid and initialized
/// - No concurrent mutable access occurs while the slice is in use
fn slice(&self, offset: usize, len: usize) -> Result<&[u8], StorageError> {
let slice = self.as_slice()?;
// validate offset and len
if offset.saturating_add(len) > slice.len() {
return Err(StorageError::Unsupported("slice out of bounds".into()));
}
slice
.get(offset..offset.saturating_add(len))
.ok_or_else(|| StorageError::Unsupported("slice out of bounds".into()))
}
/// Returns a typed immutable slice view of the entire storage region
///
/// # Safety
/// The caller must ensure:
/// - The memory region is valid and initialized
/// - The memory is properly aligned for type T
/// - The size is a multiple of `size_of::<T>()`
/// - No concurrent mutable access occurs while the slice is in use
/// - The data represents valid values of type T
fn as_slice_typed<T>(&self) -> Result<&[T], StorageError> {
let bytes = self.as_slice()?;
let ptr = bytes.as_ptr() as *const T;
let len = bytes.len() / std::mem::size_of::<T>();
if !(bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<T>()) {
return Err(StorageError::Unsupported(format!(
"memory not aligned for type (required alignment: {})",
std::mem::align_of::<T>()
)));
}
if bytes.len() % std::mem::size_of::<T>() != 0 {
return Err(StorageError::Unsupported(format!(
"size {} is not a multiple of type size {}",
bytes.len(),
std::mem::size_of::<T>()
)));
}
// SAFETY: Caller guarantees memory is valid, aligned, and properly initialized for T
Ok(unsafe { std::slice::from_raw_parts(ptr, len) })
}
/// Returns a typed immutable slice view of a subregion
///
/// # Arguments
/// * `offset` - Offset in bytes from the start of the storage
/// * `len` - Number of elements of type T to slice
///
/// # Safety
/// The caller must ensure:
/// - offset + (len * size_of::<T>()) <= self.size()
/// - offset is properly aligned for type T
/// - The memory region is valid and initialized
/// - No concurrent mutable access occurs while the slice is in use
/// - The data represents valid values of type T
fn slice_typed<T>(&self, offset: usize, len: usize) -> Result<&[T], StorageError> {
let type_size = std::mem::size_of::<T>();
let byte_len = len
.checked_mul(type_size)
.ok_or_else(|| StorageError::Unsupported("length overflow".into()))?;
let bytes = self.slice(offset, byte_len)?;
let ptr = bytes.as_ptr() as *const T;
if !(bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<T>()) {
return Err(StorageError::Unsupported(format!(
"memory not aligned for type (required alignment: {})",
std::mem::align_of::<T>()
)));
}
// SAFETY: Caller guarantees memory is valid, aligned, and properly initialized for T
Ok(unsafe { std::slice::from_raw_parts(ptr, len) })
}
}
pub trait SliceMut {
/// Returns a mutable byte slice view of the entire storage region
///
/// # Safety
/// The caller must ensure:
/// - The memory region is valid
/// - No other references (mutable or immutable) exist to this memory region
fn as_slice_mut(&mut self) -> Result<&mut [u8], StorageError>;
/// Returns a mutable byte slice view of a subregion
///
/// # Arguments
/// * `offset` - Offset in bytes from the start of the storage
/// * `len` - Number of bytes to slice
///
/// # Safety
/// The caller must ensure:
/// - offset + len <= self.size()
/// - The memory region is valid
/// - No other references (mutable or immutable) exist to this memory region
fn slice_mut(&mut self, offset: usize, len: usize) -> Result<&mut [u8], StorageError> {
let slice = self.as_slice_mut()?;
// validate offset and len
if offset.saturating_add(len) > slice.len() {
return Err(StorageError::Unsupported("slice out of bounds".into()));
}
slice
.get_mut(offset..offset.saturating_add(len))
.ok_or_else(|| StorageError::Unsupported("slice out of bounds".into()))
}
/// Returns a typed mutable slice view of the entire storage region
///
/// # Safety
/// The caller must ensure:
/// - The memory region is valid
/// - The memory is properly aligned for type T
/// - The size is a multiple of `size_of::<T>()`
/// - No other references (mutable or immutable) exist to this memory region
fn as_slice_typed_mut<T>(&mut self) -> Result<&mut [T], StorageError> {
let bytes = self.as_slice_mut()?;
let ptr = bytes.as_mut_ptr() as *mut T;
let len = bytes.len() / std::mem::size_of::<T>();
if !(bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<T>()) {
return Err(StorageError::Unsupported(format!(
"memory not aligned for type (required alignment: {})",
std::mem::align_of::<T>()
)));
}
if bytes.len() % std::mem::size_of::<T>() != 0 {
return Err(StorageError::Unsupported(format!(
"size {} is not a multiple of type size {}",
bytes.len(),
std::mem::size_of::<T>()
)));
}
// SAFETY: Caller guarantees memory is valid, aligned, and no aliasing
Ok(unsafe { std::slice::from_raw_parts_mut(ptr, len) })
}
/// Returns a typed mutable slice view of a subregion
///
/// # Arguments
/// * `offset` - Offset in bytes from the start of the storage
/// * `len` - Number of elements of type T to slice
///
/// # Safety
/// The caller must ensure:
/// - offset + (len * size_of::<T>()) <= self.size()
/// - offset is properly aligned for type T
/// - The memory region is valid
/// - No other references (mutable or immutable) exist to this memory region
fn slice_typed_mut<T>(&mut self, offset: usize, len: usize) -> Result<&mut [T], StorageError> {
let type_size = std::mem::size_of::<T>();
let byte_len = len
.checked_mul(type_size)
.ok_or_else(|| StorageError::Unsupported("length overflow".into()))?;
let bytes = self.slice_mut(offset, byte_len)?;
let ptr = bytes.as_mut_ptr() as *mut T;
if !(bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<T>()) {
return Err(StorageError::Unsupported(format!(
"memory not aligned for type (required alignment: {})",
std::mem::align_of::<T>()
)));
}
// SAFETY: Caller guarantees memory is valid, aligned, and no aliasing
Ok(unsafe { std::slice::from_raw_parts_mut(ptr, len) })
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! CUDA device memory storage.
use super::{MemoryRegion, Result, StorageError, StorageKind};
use cudarc::driver::CudaContext;
use std::any::Any;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
/// Get or create a CUDA context for the given device.
fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
static CONTEXTS: OnceLock<Mutex<HashMap<u32, Arc<CudaContext>>>> = OnceLock::new();
let mut map = CONTEXTS.get_or_init(Default::default).lock().unwrap();
if let Some(existing) = map.get(&device_id) {
return Ok(existing.clone());
}
let ctx = CudaContext::new(device_id as usize)?;
map.insert(device_id, ctx.clone());
Ok(ctx)
}
/// CUDA device memory allocated via cudaMalloc.
#[derive(Debug)]
pub struct DeviceStorage {
ctx: Arc<CudaContext>,
ptr: u64,
device_id: u32,
len: usize,
}
unsafe impl Send for DeviceStorage {}
unsafe impl Sync for DeviceStorage {}
impl DeviceStorage {
/// Allocate new device memory of the given size.
///
/// # Arguments
/// * `len` - Size in bytes to allocate
/// * `device_id` - CUDA device on which to allocate
pub fn new(len: usize, device_id: u32) -> Result<Self> {
if len == 0 {
return Err(StorageError::AllocationFailed(
"zero-sized allocations are not supported".into(),
));
}
let ctx = cuda_context(device_id)?;
ctx.bind_to_thread().map_err(StorageError::Cuda)?;
let ptr = unsafe { cudarc::driver::result::malloc_sync(len).map_err(StorageError::Cuda)? };
Ok(Self {
ctx,
ptr,
device_id,
len,
})
}
/// Get the device pointer value.
pub fn device_ptr(&self) -> u64 {
self.ptr
}
/// Get the CUDA device ID this memory is allocated on.
pub fn device_id(&self) -> u32 {
self.device_id
}
}
impl Drop for DeviceStorage {
fn drop(&mut self) {
if let Err(e) = self.ctx.bind_to_thread() {
tracing::debug!("failed to bind CUDA context for free: {e}");
}
unsafe {
if let Err(e) = cudarc::driver::result::free_sync(self.ptr) {
tracing::debug!("failed to free device memory: {e}");
}
};
}
}
impl MemoryRegion for DeviceStorage {
fn addr(&self) -> usize {
self.device_ptr() as usize
}
fn size(&self) -> usize {
self.len
}
fn storage_kind(&self) -> StorageKind {
StorageKind::Device(self.device_id)
}
fn as_any(&self) -> &dyn Any {
self
}
}
// Support for NIXL registration
impl super::registered::NixlCompatible for DeviceStorage {
fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
(
self.ptr as *const u8,
self.len,
nixl_sys::MemType::Vram,
self.device_id as u64,
)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Disk-backed memory storage using memory-mapped files.
use super::{MemoryRegion, Result, StorageError, StorageKind};
use std::any::Any;
use std::path::{Path, PathBuf};
use core::ffi::c_char;
use nix::fcntl::{FallocateFlags, fallocate};
use nix::unistd::unlink;
use std::ffi::CString;
const DISK_CACHE_KEY: &str = "DYN_KVBM_DISK_CACHE_DIR";
const DEFAULT_DISK_CACHE_DIR: &str = "/tmp/";
#[derive(Debug)]
pub struct DiskStorage {
fd: u64,
path: PathBuf,
size: usize,
unlinked: bool,
}
impl DiskStorage {
pub fn new(size: usize) -> Result<Self> {
// We need to open our file with some special flags that aren't supported by the tempfile crate.
// Instead, we'll use the mkostemp function to create a temporary file with the correct flags.
let specified_dir =
std::env::var(DISK_CACHE_KEY).unwrap_or_else(|_| DEFAULT_DISK_CACHE_DIR.to_string());
let file_path = Path::new(&specified_dir).join("dynamo-kvbm-disk-cache-XXXXXX");
Self::new_at(file_path, size)
}
pub fn new_at(path: impl AsRef<Path>, len: usize) -> Result<Self> {
if len == 0 {
return Err(StorageError::AllocationFailed(
"zero-sized allocations are not supported".into(),
));
}
let file_path = path.as_ref().to_path_buf();
if !file_path.exists() {
std::fs::create_dir_all(file_path.parent().unwrap()).unwrap();
}
tracing::debug!("Allocating disk cache file at {}", file_path.display());
let path_str = file_path.to_str().unwrap();
let is_template = path_str.contains("XXXXXX");
let (raw_fd, actual_path) = if is_template {
// Template path - use mkostemp to generate unique filename
let template = CString::new(path_str).unwrap();
let mut template_bytes = template.into_bytes_with_nul();
let fd = unsafe {
nix::libc::mkostemp(
template_bytes.as_mut_ptr() as *mut c_char,
nix::libc::O_RDWR | nix::libc::O_DIRECT,
)
};
if fd == -1 {
return Err(StorageError::AllocationFailed(format!(
"mkostemp failed: {}",
std::io::Error::last_os_error()
)));
}
// Extract the actual path created by mkostemp
let actual = PathBuf::from(
CString::from_vec_with_nul(template_bytes)
.unwrap()
.to_str()
.unwrap(),
);
(fd, actual)
} else {
// Specific path - use open with O_CREAT
let path_cstr = CString::new(path_str).unwrap();
let fd = unsafe {
nix::libc::open(
path_cstr.as_ptr(),
nix::libc::O_CREAT | nix::libc::O_RDWR | nix::libc::O_DIRECT,
0o644,
)
};
if fd == -1 {
return Err(StorageError::AllocationFailed(format!(
"open failed: {}",
std::io::Error::last_os_error()
)));
}
(fd, file_path)
};
// We need to use fallocate to actually allocate the storage and create the blocks on disk.
fallocate(raw_fd, FallocateFlags::empty(), 0, len as i64).map_err(|e| {
StorageError::AllocationFailed(format!("Failed to allocate temp file: {}", e))
})?;
Ok(Self {
fd: raw_fd as u64,
path: actual_path,
size: len,
unlinked: false,
})
}
pub fn fd(&self) -> u64 {
self.fd
}
pub fn path(&self) -> &Path {
self.path.as_path()
}
/// Unlink our temp file.
/// This means that when this process terminates, the file will be automatically deleted by the OS.
/// Unfortunately, GDS requires that files we try to register must be linked.
/// To get around this, we unlink the file only after we've registered it with NIXL.
pub fn unlink(&mut self) -> Result<()> {
if self.unlinked {
return Ok(());
}
unlink(self.path.as_path())
.map_err(|e| StorageError::AllocationFailed(format!("Failed to unlink file: {}", e)))?;
self.unlinked = true;
Ok(())
}
pub fn unlinked(&self) -> bool {
self.unlinked
}
}
impl Drop for DiskStorage {
fn drop(&mut self) {
let _ = self.unlink();
}
}
impl MemoryRegion for DiskStorage {
fn addr(&self) -> usize {
0
}
fn size(&self) -> usize {
self.size
}
fn storage_kind(&self) -> StorageKind {
StorageKind::Disk(self.fd)
}
fn as_any(&self) -> &dyn Any {
self
}
}
// Support for NIXL registration
impl super::registered::NixlCompatible for DiskStorage {
fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
#[cfg(unix)]
{
// Use file descriptor as device_id for MemType::File
(
std::ptr::null(),
self.size,
nixl_sys::MemType::File,
self.fd,
)
}
#[cfg(not(unix))]
{
// On non-Unix systems, we can't get the file descriptor easily
// Return device_id as 0 - registration will fail on these systems
(
self.mmap.as_ptr(),
self.mmap.len(),
nixl_sys::MemType::File,
0,
)
}
}
}
// mod mmap {
// use super::*;
// #[cfg(unix)]
// use std::os::unix::io::AsRawFd;
// use memmap2::{MmapMut, MmapOptions};
// use std::fs::{File, OpenOptions};
// use tempfile::NamedTempFile;
// /// Disk-backed storage using memory-mapped files.
// #[derive(Debug)]
// pub struct MemMappedFileStorage {
// _file: File, // Keep file alive for the lifetime of the mmap
// mmap: MmapMut,
// path: PathBuf,
// #[cfg(unix)]
// fd: i32,
// }
// unsafe impl Send for MemMappedFileStorage {}
// unsafe impl Sync for MemMappedFileStorage {}
// impl MemMappedFileStorage {
// /// Create new disk storage with a temporary file.
// pub fn new_temp(len: usize) -> Result<Self> {
// if len == 0 {
// return Err(StorageError::AllocationFailed(
// "zero-sized allocations are not supported".into(),
// ));
// }
// // Create temporary file
// let temp_file = NamedTempFile::new()?;
// let path = temp_file.path().to_path_buf();
// let file = temp_file.into_file();
// // Set file size
// file.set_len(len as u64)?;
// #[cfg(unix)]
// let fd = file.as_raw_fd();
// // Memory map the file
// let mmap = unsafe { MmapOptions::new().len(len).map_mut(&file)? };
// Ok(Self {
// _file: file,
// mmap,
// path,
// #[cfg(unix)]
// fd,
// })
// }
// /// Create new disk storage with a specific file path.
// pub fn new_at(path: impl AsRef<Path>, len: usize) -> Result<Self> {
// if len == 0 {
// return Err(StorageError::AllocationFailed(
// "zero-sized allocations are not supported".into(),
// ));
// }
// let path = path.as_ref().to_path_buf();
// // Create or open file
// let file = OpenOptions::new()
// .read(true)
// .write(true)
// .create(true)
// .open(&path)?;
// // Set file size
// file.set_len(len as u64)?;
// #[cfg(unix)]
// let fd = file.as_raw_fd();
// // Memory map the file
// let mmap = unsafe { MmapOptions::new().len(len).map_mut(&file)? };
// Ok(Self {
// _file: file,
// mmap,
// path,
// #[cfg(unix)]
// fd,
// })
// }
// /// Get the path to the backing file.
// pub fn path(&self) -> &Path {
// &self.path
// }
// /// Get the file descriptor (Unix only).
// #[cfg(unix)]
// pub fn fd(&self) -> i32 {
// self.fd
// }
// /// Get a pointer to the memory-mapped region.
// ///
// /// # Safety
// /// The caller must ensure the pointer is not used after this storage is dropped.
// pub unsafe fn as_ptr(&self) -> *const u8 {
// self.mmap.as_ptr()
// }
// /// Get a mutable pointer to the memory-mapped region.
// ///
// /// # Safety
// /// The caller must ensure the pointer is not used after this storage is dropped
// /// and that there are no other references to this memory.
// pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
// self.mmap.as_mut_ptr()
// }
// }
// impl MemoryRegion for MemMappedFileStorage {
// fn addr(&self) -> usize {
// self.mmap.as_ptr() as usize
// }
// fn size(&self) -> usize {
// self.mmap.len()
// }
// fn storage_kind(&self) -> StorageKind {
// StorageKind::Disk
// }
// fn as_any(&self) -> &dyn Any {
// self
// }
// }
// // Support for NIXL registration
// impl super::super::registered::NixlCompatible for MemMappedFileStorage {
// fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
// #[cfg(unix)]
// {
// // Use file descriptor as device_id for MemType::File
// (
// self.mmap.as_ptr(),
// self.mmap.len(),
// nixl_sys::MemType::File,
// self.fd as u64,
// )
// }
// #[cfg(not(unix))]
// {
// // On non-Unix systems, we can't get the file descriptor easily
// // Return device_id as 0 - registration will fail on these systems
// (
// self.mmap.as_ptr(),
// self.mmap.len(),
// nixl_sys::MemType::File,
// 0,
// )
// }
// }
// }
// }
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//! Clean, minimal storage API for v2 block manager. //! Re-export memory types from dynamo-memory with backwards-compatible names.
//! //!
//! This module provides a simplified storage abstraction with: //! This module previously contained its own storage implementations, which have
//! - Single trait for type erasure (`MemoryRegion`) //! been consolidated into the `dynamo-memory` crate. Types are re-exported here
//! - Concrete storage types (no trait implementations required) //! with compatibility aliases to preserve the existing API.
//! - Composition-based NIXL registration via `NixlRegistered<T>` wrapper
//! - RAII with proper drop ordering (registration handle drops before memory)
pub mod actions;
mod device;
mod disk;
mod pinned;
mod registered;
mod system;
mod torch;
#[cfg(test)] // Re-export actions module from dynamo-memory
mod tests; pub use dynamo_memory::actions;
pub use device::DeviceStorage; // Keep local torch types (unique to block_manager)
pub use disk::DiskStorage; mod torch;
pub use pinned::PinnedStorage;
pub use registered::{
NixlCompatible, NixlDescriptor, NixlRegistered, RegisteredView, register_with_nixl,
};
pub use system::SystemStorage;
pub use torch::{TorchDevice, TorchTensor}; pub use torch::{TorchDevice, TorchTensor};
use serde::{Deserialize, Serialize}; // Keep tests
use std::any::Any; #[cfg(test)]
use std::fmt; mod tests;
use std::sync::Arc;
use thiserror::Error;
/// Result type for storage operations.
pub type Result<T> = std::result::Result<T, StorageError>;
/// Errors that can occur during storage operations.
#[derive(Debug, Error)]
pub enum StorageError {
#[error("allocation failed: {0}")]
AllocationFailed(String),
#[error("registration failed: {0}")]
RegistrationFailed(String),
#[error("operation failed: {0}")]
OperationFailed(String),
#[error("unsupported operation: {0}")]
Unsupported(String),
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
// #[cfg(feature = "cuda")]
#[error("CUDA error: {0}")]
Cuda(#[from] cudarc::driver::DriverError),
#[error("NIXL error: {0}")]
Nixl(#[from] nixl_sys::NixlError),
}
/// Storage type classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StorageKind {
/// System memory (malloc)
System,
/// CUDA pinned host memory // === Core types with name aliases ===
// #[cfg(feature = "cuda")]
Pinned,
/// CUDA device memory with device ID /// The core trait (was `MemoryRegion` here, now `MemoryDescriptor` in dynamo-memory).
// #[cfg(feature = "cuda")] pub use dynamo_memory::MemoryDescriptor as MemoryRegion;
Device(u32),
/// Disk-backed memory (mmap) /// The simple descriptor struct (was `MemoryDescriptor` here, now `MemoryRegion` in dynamo-memory).
Disk(u64), pub use dynamo_memory::MemoryRegion as MemoryDescriptor;
}
/// Core trait for memory regions that can be type-erased.
///
/// This is the only trait in the storage API. Concrete storage types
/// implement this trait to enable type erasure via `Arc<dyn MemoryRegion>`.
pub trait MemoryRegion: Send + Sync + fmt::Debug {
/// Base address of the memory region.
fn addr(&self) -> usize;
/// Size of the memory region in bytes. // === Storage types (same names) ===
fn size(&self) -> usize; pub use dynamo_memory::{
DeviceStorage, DiskStorage, PinnedStorage, StorageError, StorageKind, SystemStorage,
};
/// Type of storage backing this region. // === NIXL types ===
fn storage_kind(&self) -> StorageKind; pub use dynamo_memory::nixl::{
NixlCompatible, NixlDescriptor, NixlRegistered, RegisteredView, register_with_nixl,
};
/// Enable downcasting to concrete type. // === Compatibility aliases ===
fn as_any(&self) -> &dyn Any;
/// Get the NIXL descriptor for this memory region. /// Result type for storage operations.
fn nixl_descriptor(&self) -> Option<NixlDescriptor> { pub type Result<T> = std::result::Result<T, StorageError>;
None
}
}
/// Type-erased memory region for use in layouts. /// Type-erased memory region for use in layouts.
pub type OwnedMemoryRegion = Arc<dyn MemoryRegion>; pub type OwnedMemoryRegion = std::sync::Arc<dyn MemoryRegion>;
/// Helper function to convert concrete storage to type-erased form. /// Helper function to convert concrete storage to type-erased form.
pub fn erase_storage<S: MemoryRegion + 'static>(storage: S) -> OwnedMemoryRegion { pub fn erase_storage<S: MemoryRegion + 'static>(storage: S) -> OwnedMemoryRegion {
Arc::new(storage) std::sync::Arc::new(storage)
} }
/// Simple memory region descriptor. /// An offset view into an existing memory region.
///
/// This wraps an `OwnedMemoryRegion` with an offset and length to represent
/// a sub-region of the original allocation.
#[derive(Debug)] #[derive(Debug)]
pub struct OffsetMemoryRegion { pub struct OffsetMemoryRegion {
base: OwnedMemoryRegion, base: OwnedMemoryRegion,
...@@ -172,35 +110,11 @@ impl MemoryRegion for OffsetMemoryRegion { ...@@ -172,35 +110,11 @@ impl MemoryRegion for OffsetMemoryRegion {
self.base.storage_kind() self.base.storage_kind()
} }
fn as_any(&self) -> &dyn Any { fn as_any(&self) -> &dyn std::any::Any {
self self
} }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct MemoryDescriptor {
pub addr: usize,
pub size: usize,
}
impl MemoryDescriptor {
pub fn new(addr: usize, size: usize) -> Self {
Self { addr, size }
}
#[inline]
pub fn addr(&self) -> usize {
self.addr
}
#[inline] fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
pub fn size(&self) -> usize { None
self.size
}
}
impl actions::Slice for MemoryDescriptor {
fn as_slice(&self) -> Result<&[u8]> {
Ok(unsafe { std::slice::from_raw_parts(self.addr as *const u8, self.size) })
} }
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! CUDA pinned host memory storage.
use super::{MemoryRegion, Result, StorageError, StorageKind, actions};
use cudarc::driver::CudaContext;
use cudarc::driver::sys;
use std::any::Any;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
/// Get or create a CUDA context for the given device.
fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
static CONTEXTS: OnceLock<Mutex<HashMap<u32, Arc<CudaContext>>>> = OnceLock::new();
let mut map = CONTEXTS.get_or_init(Default::default).lock().unwrap();
if let Some(existing) = map.get(&device_id) {
return Ok(existing.clone());
}
let ctx = CudaContext::new(device_id as usize)?;
map.insert(device_id, ctx.clone());
Ok(ctx)
}
/// CUDA pinned host memory allocated via cudaHostAlloc.
#[derive(Debug)]
pub struct PinnedStorage {
ptr: usize,
len: usize,
ctx: Arc<CudaContext>,
}
unsafe impl Send for PinnedStorage {}
unsafe impl Sync for PinnedStorage {}
impl PinnedStorage {
/// Allocate new pinned memory of the given size.
///
/// # Arguments
/// * `len` - Size in bytes to allocate
/// * `device_id` - CUDA device to associate with the allocation
pub fn new(len: usize) -> Result<Self> {
if len == 0 {
return Err(StorageError::AllocationFailed(
"zero-sized allocations are not supported".into(),
));
}
let ctx = cuda_context(0)?;
let ptr = unsafe {
ctx.bind_to_thread().map_err(StorageError::Cuda)?;
let ptr = cudarc::driver::result::malloc_host(len, sys::CU_MEMHOSTALLOC_WRITECOMBINED)
.map_err(StorageError::Cuda)?;
let ptr = ptr as *mut u8;
assert!(!ptr.is_null(), "Failed to allocate pinned memory");
assert!(ptr.is_aligned(), "Pinned memory is not aligned");
assert!(len < isize::MAX as usize);
ptr as usize
};
Ok(Self { ptr, len, ctx })
}
/// Get a pointer to the underlying memory.
///
/// # Safety
/// The caller must ensure the pointer is not used after this storage is dropped.
pub unsafe fn as_ptr(&self) -> *const u8 {
self.ptr as *const u8
}
/// Get a mutable pointer to the underlying memory.
///
/// # Safety
/// The caller must ensure the pointer is not used after this storage is dropped
/// and that there are no other references to this memory.
pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
self.ptr as *mut u8
}
}
impl Drop for PinnedStorage {
fn drop(&mut self) {
if let Err(e) = self.ctx.bind_to_thread() {
tracing::debug!("failed to bind CUDA context for free: {e}");
}
unsafe {
if let Err(e) = cudarc::driver::result::free_host(self.ptr as _) {
tracing::debug!("failed to free pinned memory: {e}");
}
};
}
}
impl MemoryRegion for PinnedStorage {
fn addr(&self) -> usize {
unsafe { self.as_ptr() as usize }
}
fn size(&self) -> usize {
self.len
}
fn storage_kind(&self) -> StorageKind {
StorageKind::Pinned
}
fn as_any(&self) -> &dyn Any {
self
}
}
// Support for NIXL registration
impl super::registered::NixlCompatible for PinnedStorage {
fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
let ptr = unsafe { self.as_ptr() };
(ptr, self.len, nixl_sys::MemType::Dram, 0)
}
}
impl actions::Memset for PinnedStorage {
fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<()> {
if offset + size > self.len {
return Err(StorageError::OperationFailed(
"memset: offset + size > storage size".into(),
));
}
unsafe {
let ptr = (self.ptr as *mut u8).add(offset);
std::ptr::write_bytes(ptr, value, size);
}
Ok(())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NIXL registration wrapper for storage types.
use super::{MemoryRegion, StorageKind};
use nixl_sys::{Agent as NixlAgent, MemType, OptArgs, RegistrationHandle};
use std::any::Any;
use std::fmt;
/// Trait for storage types that can be registered with NIXL.
pub trait NixlCompatible {
/// Get parameters needed for NIXL registration.
///
/// Returns (ptr, size, mem_type, device_id)
fn nixl_params(&self) -> (*const u8, usize, MemType, u64);
}
/// NIXL descriptor containing registration information.
#[derive(Debug, Clone)]
pub struct NixlDescriptor {
pub addr: u64,
pub size: usize,
pub mem_type: MemType,
pub device_id: u64,
}
impl nixl_sys::MemoryRegion for NixlDescriptor {
unsafe fn as_ptr(&self) -> *const u8 {
self.addr as *const u8
}
fn size(&self) -> usize {
self.size
}
}
impl nixl_sys::NixlDescriptor for NixlDescriptor {
fn mem_type(&self) -> MemType {
self.mem_type
}
fn device_id(&self) -> u64 {
self.device_id
}
}
/// View trait for accessing registration information without unwrapping.
pub trait RegisteredView {
/// Get the name of the NIXL agent that registered this memory.
fn agent_name(&self) -> &str;
/// Get the NIXL descriptor for this registered memory.
fn descriptor(&self) -> NixlDescriptor;
}
/// Wrapper for storage that has been registered with NIXL.
///
/// This wrapper ensures proper drop order: the registration handle is
/// dropped before the storage, ensuring deregistration happens before
/// the memory is freed.
pub struct NixlRegistered<S: NixlCompatible> {
storage: S,
handle: Option<RegistrationHandle>,
agent_name: String,
}
impl<S: NixlCompatible> Drop for NixlRegistered<S> {
fn drop(&mut self) {
// Explicitly drop the registration handle first
drop(self.handle.take());
// Storage drops naturally after
}
}
impl<S: NixlCompatible + fmt::Debug> fmt::Debug for NixlRegistered<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("NixlRegistered")
.field("storage", &self.storage)
.field("agent_name", &self.agent_name)
.field("handle", &self.handle.is_some())
.finish()
}
}
impl<S: MemoryRegion + NixlCompatible + 'static> MemoryRegion for NixlRegistered<S> {
fn addr(&self) -> usize {
self.storage.addr()
}
fn size(&self) -> usize {
self.storage.size()
}
fn storage_kind(&self) -> StorageKind {
self.storage.storage_kind()
}
fn as_any(&self) -> &dyn Any {
self
}
fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
Some(self.descriptor())
}
}
impl<S: MemoryRegion + NixlCompatible> RegisteredView for NixlRegistered<S> {
fn agent_name(&self) -> &str {
&self.agent_name
}
fn descriptor(&self) -> NixlDescriptor {
let (ptr, size, mem_type, device_id) = self.storage.nixl_params();
NixlDescriptor {
addr: ptr as u64,
size,
mem_type,
device_id,
}
}
}
impl<S: MemoryRegion + NixlCompatible> NixlRegistered<S> {
/// Get a reference to the underlying storage.
pub fn storage(&self) -> &S {
&self.storage
}
/// Get a mutable reference to the underlying storage.
pub fn storage_mut(&mut self) -> &mut S {
&mut self.storage
}
/// Check if the registration handle is still valid.
pub fn is_registered(&self) -> bool {
self.handle.is_some()
}
/// Consume this wrapper and return the underlying storage.
///
/// This will deregister the storage from NIXL.
pub fn into_storage(mut self) -> S {
// Manually drop the handle first
self.handle = None;
// Now we can move out the storage
// We need to use mem::forget to prevent Drop from running
let storage = std::mem::replace(&mut self.storage, unsafe { std::mem::zeroed() });
std::mem::forget(self);
storage
}
}
/// Register storage with a NIXL agent.
///
/// This consumes the storage and returns a `NixlRegistered` wrapper that
/// manages the registration lifetime. The registration handle will be
/// automatically dropped when the wrapper is dropped, ensuring proper
/// cleanup order.
///
/// # Arguments
/// * `storage` - The storage to register (consumed)
/// * `agent` - The NIXL agent to register with
/// * `opt` - Optional arguments for registration
///
/// # Returns
/// A `NixlRegistered` wrapper containing the storage and registration handle.
pub fn register_with_nixl<S>(
storage: S,
agent: &NixlAgent,
opt: Option<&OptArgs>,
) -> std::result::Result<NixlRegistered<S>, S>
where
S: MemoryRegion + NixlCompatible,
{
// Get NIXL parameters
let (ptr, size, mem_type, device_id) = storage.nixl_params();
// Create a NIXL descriptor for registration
let descriptor = NixlDescriptor {
addr: ptr as u64,
size,
mem_type,
device_id,
};
match agent.register_memory(&descriptor, opt) {
Ok(handle) => Ok(NixlRegistered {
storage,
handle: Some(handle),
agent_name: agent.name().to_string(),
}),
Err(_) => Err(storage),
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! System memory storage backed by malloc.
use super::{MemoryRegion, Result, StorageError, StorageKind, actions};
use std::any::Any;
use std::ptr::NonNull;
use nix::libc;
/// System memory allocated via malloc.
#[derive(Debug)]
pub struct SystemStorage {
ptr: NonNull<u8>,
len: usize,
}
unsafe impl Send for SystemStorage {}
unsafe impl Sync for SystemStorage {}
impl SystemStorage {
/// Allocate new system memory of the given size.
pub fn new(len: usize) -> Result<Self> {
if len == 0 {
return Err(StorageError::AllocationFailed(
"zero-sized allocations are not supported".into(),
));
}
let mut ptr: *mut libc::c_void = std::ptr::null_mut();
// We need 4KB alignment here for NIXL disk transfers to work.
// The O_DIRECT flag is required for GDS.
// However, a limitation of this flag is that all operations involving disk
// (both read and write) must be page-aligned.
// Pinned memory is already page-aligned, so we only need to align system memory.
// TODO(jthomson04): Is page size always 4KB?
// SAFETY: malloc returns suitably aligned memory or null on failure.
let result = unsafe { libc::posix_memalign(&mut ptr, 4096, len) };
if result != 0 {
return Err(StorageError::AllocationFailed(format!(
"posix_memalign failed for size {}",
len
)));
}
let ptr = NonNull::new(ptr as *mut u8).ok_or_else(|| {
StorageError::AllocationFailed(format!("malloc failed for size {}", len))
})?;
// Zero-initialize the memory
unsafe {
std::ptr::write_bytes(ptr.as_ptr(), 0, len);
}
Ok(Self { ptr, len })
}
/// Get a pointer to the underlying memory.
///
/// # Safety
/// The caller must ensure the pointer is not used after this storage is dropped.
pub unsafe fn as_ptr(&self) -> *const u8 {
self.ptr.as_ptr()
}
/// Get a mutable pointer to the underlying memory.
///
/// # Safety
/// The caller must ensure the pointer is not used after this storage is dropped
/// and that there are no other references to this memory.
pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
self.ptr.as_ptr()
}
}
impl Drop for SystemStorage {
fn drop(&mut self) {
// SAFETY: pointer was allocated by malloc.
unsafe {
libc::free(self.ptr.as_ptr() as *mut libc::c_void);
}
}
}
impl MemoryRegion for SystemStorage {
fn addr(&self) -> usize {
self.ptr.as_ptr() as usize
}
fn size(&self) -> usize {
self.len
}
fn storage_kind(&self) -> StorageKind {
StorageKind::System
}
fn as_any(&self) -> &dyn Any {
self
}
}
// Support for NIXL registration
impl super::registered::NixlCompatible for SystemStorage {
fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
(self.ptr.as_ptr(), self.len, nixl_sys::MemType::Dram, 0)
}
}
impl actions::Memset for SystemStorage {
fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<()> {
if offset + size > self.len {
return Err(StorageError::OperationFailed(
"memset: offset + size > storage size".into(),
));
}
unsafe {
let ptr = self.ptr.as_ptr().add(offset);
std::ptr::write_bytes(ptr, value, size);
}
Ok(())
}
}
impl actions::Slice for SystemStorage {
fn as_slice(&self) -> Result<&[u8]> {
Ok(unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) })
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//! Tests for the storage-next module. //! Tests for the storage module (re-exported from dynamo-memory).
use super::*; use super::*;
...@@ -107,18 +107,11 @@ mod cuda_tests { ...@@ -107,18 +107,11 @@ mod cuda_tests {
} }
} }
// Tests for NIXL registration would require a real NIXL agent,
// so we'll skip those for now. In practice, you'd mock the agent
// or use integration tests.
#[cfg(all(feature = "testing-nixl", feature = "testing-cuda"))] #[cfg(all(feature = "testing-nixl", feature = "testing-cuda"))]
mod nixl_tests { mod nixl_tests {
use super::super::registered::register_with_nixl;
use super::*; use super::*;
use nixl_sys::Agent as NixlAgent; use nixl_sys::Agent as NixlAgent;
// These tests would require a mock NIXL agent or real NIXL setup
// Placeholder for now
#[test] #[test]
fn test_nixl_registration() { fn test_nixl_registration() {
let pinned = PinnedStorage::new(2048).unwrap(); let pinned = PinnedStorage::new(2048).unwrap();
......
...@@ -287,4 +287,8 @@ impl MemoryRegion for RemoteMemoryDescriptor { ...@@ -287,4 +287,8 @@ impl MemoryRegion for RemoteMemoryDescriptor {
fn as_any(&self) -> &dyn Any { fn as_any(&self) -> &dyn Any {
self self
} }
fn nixl_descriptor(&self) -> Option<crate::block_manager::v2::memory::NixlDescriptor> {
None
}
} }
...@@ -42,6 +42,9 @@ impl MemoryRegion for MockMemory { ...@@ -42,6 +42,9 @@ impl MemoryRegion for MockMemory {
fn as_any(&self) -> &dyn Any { fn as_any(&self) -> &dyn Any {
self self
} }
fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
None
}
} }
/// Mock memory region for testing serialization /// Mock memory region for testing serialization
......
...@@ -195,7 +195,16 @@ fn fill_memory_region( ...@@ -195,7 +195,16 @@ fn fill_memory_region(
mod tests { mod tests {
use super::super::tests::*; use super::super::tests::*;
use super::*; use super::*;
use crate::block_manager::v2::memory::actions::Slice;
/// Get a byte slice from a MemoryDescriptor.
///
/// # Safety
/// The memory region must be valid and no mutable references may exist.
unsafe fn descriptor_as_slice(
desc: &crate::block_manager::v2::memory::MemoryDescriptor,
) -> &[u8] {
unsafe { std::slice::from_raw_parts(desc.addr as *const u8, desc.size) }
}
#[test] #[test]
fn test_fill_blocks_constant() { fn test_fill_blocks_constant() {
...@@ -208,15 +217,9 @@ mod tests { ...@@ -208,15 +217,9 @@ mod tests {
fill_blocks(&physical, &[0, 1], FillPattern::Constant(42)).unwrap(); fill_blocks(&physical, &[0, 1], FillPattern::Constant(42)).unwrap();
// Verify all bytes are set to 42 // Verify all bytes are set to 42
assert!( let mr = physical.memory_region(0, 0, 0).unwrap();
physical let mr_slice = unsafe { descriptor_as_slice(&mr) };
.memory_region(0, 0, 0) assert!(mr_slice.iter().all(|&b| b == 42));
.unwrap()
.as_slice()
.unwrap()
.iter()
.all(|&b| b == 42)
);
} }
#[test] #[test]
...@@ -230,7 +233,7 @@ mod tests { ...@@ -230,7 +233,7 @@ mod tests {
fill_blocks(&physical, &[0, 1], FillPattern::Sequential).unwrap(); fill_blocks(&physical, &[0, 1], FillPattern::Sequential).unwrap();
let mr = physical.memory_region(0, 0, 0).unwrap(); let mr = physical.memory_region(0, 0, 0).unwrap();
let mr_slice = mr.as_slice().unwrap(); let mr_slice = unsafe { descriptor_as_slice(&mr) };
// Verify pattern is applied (spot check a few bytes) // Verify pattern is applied (spot check a few bytes)
let first_byte = mr_slice[0]; let first_byte = mr_slice[0];
...@@ -239,7 +242,7 @@ mod tests { ...@@ -239,7 +242,7 @@ mod tests {
assert_eq!(second_byte, first_byte.wrapping_add(1)); assert_eq!(second_byte, first_byte.wrapping_add(1));
let mr = physical.memory_region(1, 1, 0).unwrap(); let mr = physical.memory_region(1, 1, 0).unwrap();
let mr_slice = mr.as_slice().unwrap(); let mr_slice = unsafe { descriptor_as_slice(&mr) };
let first_byte = mr_slice[0]; let first_byte = mr_slice[0];
let second_byte = mr_slice[1]; let second_byte = mr_slice[1];
...@@ -261,10 +264,14 @@ mod tests { ...@@ -261,10 +264,14 @@ mod tests {
fill_layers(&physical, &[1], 0..1, FillPattern::Constant(100)).unwrap(); fill_layers(&physical, &[1], 0..1, FillPattern::Constant(100)).unwrap();
fill_layers(&physical, &[1], 1..2, FillPattern::Constant(101)).unwrap(); fill_layers(&physical, &[1], 1..2, FillPattern::Constant(101)).unwrap();
let mr_00 = physical.memory_region(0, 0, 0).unwrap().as_slice().unwrap()[0]; let desc = physical.memory_region(0, 0, 0).unwrap();
let mr_01 = physical.memory_region(0, 1, 0).unwrap().as_slice().unwrap()[0]; let mr_00 = unsafe { descriptor_as_slice(&desc) }[0];
let mr_10 = physical.memory_region(1, 0, 0).unwrap().as_slice().unwrap()[0]; let desc = physical.memory_region(0, 1, 0).unwrap();
let mr_11 = physical.memory_region(1, 1, 0).unwrap().as_slice().unwrap()[0]; let mr_01 = unsafe { descriptor_as_slice(&desc) }[0];
let desc = physical.memory_region(1, 0, 0).unwrap();
let mr_10 = unsafe { descriptor_as_slice(&desc) }[0];
let desc = physical.memory_region(1, 1, 0).unwrap();
let mr_11 = unsafe { descriptor_as_slice(&desc) }[0];
assert_eq!(mr_00, 0); assert_eq!(mr_00, 0);
assert_eq!(mr_01, 1); assert_eq!(mr_01, 1);
assert_eq!(mr_10, 100); assert_eq!(mr_10, 100);
......
...@@ -195,7 +195,7 @@ impl<S: MemoryDescriptor + NixlCompatible> NixlRegistered<S> { ...@@ -195,7 +195,7 @@ impl<S: MemoryDescriptor + NixlCompatible> NixlRegistered<S> {
/// A `NixlRegistered` wrapper containing the storage and registration handle. /// A `NixlRegistered` wrapper containing the storage and registration handle.
pub fn register_with_nixl<S>( pub fn register_with_nixl<S>(
storage: S, storage: S,
agent: &NixlAgent, agent: &Agent,
opt: Option<&OptArgs>, opt: Option<&OptArgs>,
) -> std::result::Result<NixlRegistered<S>, S> ) -> std::result::Result<NixlRegistered<S>, S>
where where
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment