"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "501ef021c247eb2e82ec84843028fdcdfb157252"
Unverified Commit 3659c82e authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: kvbm v2 memory crate (#4012)


Signed-off-by: default avatarRyan Olson <rolson@nvidia.com>
parent bc02088e
...@@ -2176,6 +2176,13 @@ dependencies = [ ...@@ -2176,6 +2176,13 @@ dependencies = [
"syn 2.0.106", "syn 2.0.106",
] ]
[[package]]
name = "dynamo-config"
version = "0.6.1"
dependencies = [
"anyhow",
]
[[package]] [[package]]
name = "dynamo-engine-mistralrs" name = "dynamo-engine-mistralrs"
version = "0.6.1" version = "0.6.1"
...@@ -2291,6 +2298,23 @@ dependencies = [ ...@@ -2291,6 +2298,23 @@ dependencies = [
"zeromq", "zeromq",
] ]
[[package]]
name = "dynamo-memory"
version = "0.6.1"
dependencies = [
"anyhow",
"cudarc 0.17.3",
"dynamo-config",
"libc",
"nix 0.30.1",
"nixl-sys",
"offset-allocator",
"serde",
"tempfile",
"thiserror 2.0.16",
"tracing",
]
[[package]] [[package]]
name = "dynamo-parsers" name = "dynamo-parsers"
version = "0.6.1" version = "0.6.1"
...@@ -5277,6 +5301,18 @@ dependencies = [ ...@@ -5277,6 +5301,18 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "nix"
version = "0.30.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6"
dependencies = [
"bitflags 2.9.4",
"cfg-if 1.0.3",
"cfg_aliases",
"libc",
]
[[package]] [[package]]
name = "nixl-sys" name = "nixl-sys"
version = "0.7.0" version = "0.7.0"
......
...@@ -6,13 +6,16 @@ members = [ ...@@ -6,13 +6,16 @@ members = [
"launch/dynamo-run", "launch/dynamo-run",
"lib/llm", "lib/llm",
"lib/runtime", "lib/runtime",
"lib/config",
"lib/tokens", "lib/tokens",
"lib/memory",
"lib/async-openai", "lib/async-openai",
"lib/parsers", "lib/parsers",
"lib/bindings/c", "lib/bindings/c",
"lib/bindings/python/codegen", "lib/bindings/python/codegen",
"lib/engines/*", "lib/engines/*",
"lib/kvbm", "lib/kvbm",
"lib/config",
] ]
# Exclude certain packages that are slow to build and we don't ship as flagship # Exclude certain packages that are slow to build and we don't ship as flagship
# features from default build, but keep them in workspace for convenience. # features from default build, but keep them in workspace for convenience.
...@@ -21,6 +24,7 @@ members = [ ...@@ -21,6 +24,7 @@ members = [
default-members = [ default-members = [
"lib/llm", "lib/llm",
"lib/runtime", "lib/runtime",
"lib/config",
"lib/tokens", "lib/tokens",
"lib/async-openai", "lib/async-openai",
"lib/parsers", "lib/parsers",
...@@ -42,6 +46,7 @@ keywords = ["llm", "genai", "inference", "nvidia", "distributed"] ...@@ -42,6 +46,7 @@ keywords = ["llm", "genai", "inference", "nvidia", "distributed"]
# Local crates # Local crates
dynamo-runtime = { path = "lib/runtime", version = "0.6.1" } dynamo-runtime = { path = "lib/runtime", version = "0.6.1" }
dynamo-llm = { path = "lib/llm", version = "0.6.1" } dynamo-llm = { path = "lib/llm", version = "0.6.1" }
dynamo-config = { path = "lib/config", version = "0.6.1" }
dynamo-tokens = { path = "lib/tokens", version = "0.6.1" } dynamo-tokens = { path = "lib/tokens", version = "0.6.1" }
dynamo-async-openai = { path = "lib/async-openai", version = "0.6.1", features = [ dynamo-async-openai = { path = "lib/async-openai", version = "0.6.1", features = [
"byot", "byot",
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Manifest for the `dynamo-config` crate.
[package]
name = "dynamo-config"
# Every remaining package field is inherited from the workspace manifest
# (`<field>.workspace = true`), so it is edited there rather than here.
version.workspace = true
edition.workspace = true
description.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
keywords.workspace = true
[dependencies]
# Version and features come from the workspace dependency table.
anyhow = { workspace = true }
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Configuration utilities.
//!
//! This module provides utility functions for interpreting strings and
//! environment variables as boolean configuration values.
// ===== Environment Variable Utilities =====
/// Check if a string is truthy.
///
/// This will be used to evaluate environment variables or any other subjective
/// configuration parameters that can be set by the user that should be evaluated
/// as a boolean value.
///
/// Truthy values: "1", "true", "on", "yes" (ASCII case-insensitive). The input is
/// compared as given; surrounding whitespace is not trimmed.
///
/// Returns `false` for every other value, including falsey and invalid ones.
/// Use [`parse_bool`] if you need to error on invalid values.
pub fn is_truthy(val: &str) -> bool {
    // Compare in place rather than allocating a lowercased copy of `val` on every call.
    ["1", "true", "on", "yes"]
        .iter()
        .any(|accepted| val.eq_ignore_ascii_case(accepted))
}
/// Check if a string is falsey.
///
/// This will be used to evaluate environment variables or any other subjective
/// configuration parameters that can be set by the user that should be evaluated
/// as a boolean value (opposite of `is_truthy`).
///
/// Falsey values: "0", "false", "off", "no" (ASCII case-insensitive). The input is
/// compared as given; surrounding whitespace is not trimmed.
///
/// Returns `false` for every other value, including truthy and invalid ones.
/// Use [`parse_bool`] if you need to error on invalid values.
pub fn is_falsey(val: &str) -> bool {
    // Compare in place rather than allocating a lowercased copy of `val` on every call.
    ["0", "false", "off", "no"]
        .iter()
        .any(|accepted| val.eq_ignore_ascii_case(accepted))
}
/// Strictly interpret a string as a boolean.
///
/// Unlike [`is_truthy`] and [`is_falsey`], an unrecognised value is reported as an
/// error instead of being folded into `false`.
///
/// # Arguments
/// * `val` - The string value to parse
///
/// # Returns
/// * `Ok(true)` - `val` is accepted by [`is_truthy`]
/// * `Ok(false)` - `val` is accepted by [`is_falsey`]
/// * `Err(_)` - `val` is accepted by neither
///
/// # Example
/// ```ignore
/// assert_eq!(parse_bool("true")?, true);
/// assert_eq!(parse_bool("0")?, false);
/// assert!(parse_bool("maybe").is_err());
/// ```
pub fn parse_bool(val: &str) -> anyhow::Result<bool> {
    if is_truthy(val) {
        return Ok(true);
    }
    if is_falsey(val) {
        return Ok(false);
    }
    anyhow::bail!(
        "Invalid boolean value: '{}'. Expected one of: true/false, 1/0, on/off, yes/no",
        val
    )
}
/// Report whether an environment variable holds a truthy value.
///
/// Any failure to read the variable (for example, it is not set) yields `false`,
/// as does a value that [`is_truthy`] does not accept.
/// Use [`env_parse_bool`] if you need to distinguish between unset, valid, and invalid values.
pub fn env_is_truthy(env: &str) -> bool {
    std::env::var(env)
        .map(|value| is_truthy(&value))
        .unwrap_or(false)
}
/// Report whether an environment variable holds a falsey value.
///
/// Any failure to read the variable (for example, it is not set) yields `false`,
/// as does a value that [`is_falsey`] does not accept.
/// Use [`env_parse_bool`] if you need to distinguish between unset, valid, and invalid values.
pub fn env_is_falsey(env: &str) -> bool {
    std::env::var(env)
        .map(|value| is_falsey(&value))
        .unwrap_or(false)
}
/// Strictly interpret an environment variable as an optional boolean.
///
/// # Arguments
/// * `env` - The environment variable name
///
/// # Returns
/// * `Ok(Some(true))` - If the env var is set to a truthy value
/// * `Ok(Some(false))` - If the env var is set to a falsey value
/// * `Ok(None)` - If the env var is not set
/// * `Err(_)` - If the env var cannot be read for another reason, or if
///   [`parse_bool`] rejects its value
///
/// # Example
/// ```ignore
/// match env_parse_bool("MY_FLAG")? {
///     Some(true) => println!("enabled"),
///     Some(false) => println!("disabled"),
///     None => println!("not configured"),
/// }
/// ```
pub fn env_parse_bool(env: &str) -> anyhow::Result<Option<bool>> {
    use std::env::VarError;

    let raw = match std::env::var(env) {
        Ok(raw) => raw,
        // An absent variable is a valid "not configured" state, not an error.
        Err(VarError::NotPresent) => return Ok(None),
        Err(e) => anyhow::bail!("Failed to read environment variable {}: {}", env, e),
    };
    Ok(Some(parse_bool(&raw)?))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Name of an environment variable the tests assume is never set.
    const UNSET_VAR: &str = "DEFINITELY_NOT_SET_VAR_12345";

    #[test]
    fn test_is_truthy() {
        for accepted in ["1", "true", "True", "TRUE", "on", "ON", "yes", "YES"] {
            assert!(is_truthy(accepted), "{accepted:?} should be truthy");
        }
        for rejected in ["0", "false", "off", "no", "", "random"] {
            assert!(!is_truthy(rejected), "{rejected:?} should not be truthy");
        }
    }

    #[test]
    fn test_is_falsey() {
        for accepted in ["0", "false", "False", "FALSE", "off", "OFF", "no", "NO"] {
            assert!(is_falsey(accepted), "{accepted:?} should be falsey");
        }
        for rejected in ["1", "true", "on", "yes", "", "random"] {
            assert!(!is_falsey(rejected), "{rejected:?} should not be falsey");
        }
    }

    #[test]
    fn test_env_is_truthy_not_set() {
        // An unset variable must never be reported as truthy.
        assert!(!env_is_truthy(UNSET_VAR));
    }

    #[test]
    fn test_env_is_falsey_not_set() {
        // An unset variable must never be reported as falsey.
        assert!(!env_is_falsey(UNSET_VAR));
    }

    #[test]
    fn test_parse_bool() {
        // Truthy values
        for input in ["1", "true", "TRUE", "on", "yes"] {
            assert!(parse_bool(input).unwrap(), "{input:?} should parse as true");
        }
        // Falsey values
        for input in ["0", "false", "FALSE", "off", "no"] {
            assert!(!parse_bool(input).unwrap(), "{input:?} should parse as false");
        }
        // Invalid values
        for input in ["", "maybe", "2", "random"] {
            assert!(parse_bool(input).is_err(), "{input:?} should be rejected");
        }
    }

    #[test]
    fn test_env_parse_bool_not_set() {
        // An unset variable is reported as `None`, not as an error.
        assert_eq!(env_parse_bool(UNSET_VAR).unwrap(), None);
    }
}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Manifest for the `dynamo-memory` crate.
[package]
name = "dynamo-memory"
# These package fields are inherited from the workspace manifest.
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
[features]
# NOTE(review): the default set enables both hardware-specific test features
# below, so they are active unless a build opts out with
# `default-features = false` — confirm this is intended.
default = ["testing-all"]
# feature to enable unsafe slices of memory descriptors
# for advanced testing in other crates
unsafe-slices = []
# test features for hardware-specific tests
testing-cuda = []
testing-nixl = []
# Convenience feature that turns on every hardware-specific test feature.
testing-all = ["testing-cuda", "testing-nixl"]
[dependencies]
# Dependencies marked `workspace = true` take their version from the workspace
# dependency table; the others are pinned locally.
dynamo-config = { workspace = true }
anyhow = { workspace = true }
cudarc = { workspace = true }
nixl-sys = { version = "0.7" }
serde = { workspace = true}
thiserror = { workspace = true }
tracing = { workspace = true }
libc = { version = "0.2" }
nix = { version = "0.30", features = ["fs"] }
offset-allocator = "0.2"
[dev-dependencies]
# Only needed when building tests and examples.
tempfile = "3"
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Storage actions.
use super::{MemoryDescription, StorageError};
/// Extension trait for storage types that support memory setting operations.
pub trait Memset: MemoryDescription {
/// Sets a region of memory to a specific byte value.
///
/// # Arguments
/// * `value` - The byte to store in the region
/// * `offset` - Offset in bytes from the start of the storage
/// * `size` - Number of bytes to set
///
/// # Returns
/// `Ok(())` on success, or a [`StorageError`] chosen by the implementor.
///
/// # Caller expectations
/// This is a safe `fn`, so the points below are usage guidance rather than an
/// `unsafe` contract:
/// - offset + size <= self.size()
/// - No other references exist to the memory region being set
///
/// NOTE(review): whether an out-of-range request is rejected with an error is up
/// to each implementor and is not visible from this trait — confirm per impl.
fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<(), StorageError>;
}
/// Extension trait for storage types that support slicing operations.
///
/// Implementors supply [`Slice::as_slice`]; the provided methods build on it and
/// validate bounds, alignment and element size before handing out a view.
pub trait Slice: MemoryDescription + 'static {
    /// Returns an immutable byte slice view of the entire storage region.
    ///
    /// # Safety
    /// This is an unsafe method. The caller must ensure:
    /// - The memory region remains valid for the lifetime of the returned slice
    /// - The memory region is properly initialized
    /// - No concurrent mutable access occurs while the slice is in use
    /// - The memory backing this storage remains valid (implementors with owned
    ///   memory satisfy this, but care must be taken with unowned memory regions)
    unsafe fn as_slice(&self) -> Result<&[u8], StorageError>;

    /// Returns an immutable byte slice view of a subregion.
    ///
    /// # Arguments
    /// * `offset` - Offset in bytes from the start of the storage
    /// * `len` - Number of bytes to slice
    ///
    /// # Errors
    /// Propagates any error from [`Slice::as_slice`]. Returns
    /// [`StorageError::Unsupported`] when `offset + len` overflows or exceeds the
    /// length of the storage.
    ///
    /// # Safety
    /// This method calls [`Slice::as_slice`] internally, so the caller must ensure:
    /// - The memory region is valid and initialized
    /// - No concurrent mutable access occurs while the slice is in use
    fn slice(&self, offset: usize, len: usize) -> Result<&[u8], StorageError> {
        // SAFETY: Caller guarantees memory validity per trait's safety contract
        let bytes = unsafe { self.as_slice()? };
        match offset.checked_add(len) {
            // `offset <= end <= bytes.len()` holds here, so the index cannot panic.
            Some(end) if end <= bytes.len() => Ok(&bytes[offset..end]),
            _ => Err(StorageError::Unsupported("slice out of bounds".into())),
        }
    }

    /// Returns a typed immutable slice view of the entire storage region.
    ///
    /// # Errors
    /// Propagates any error from [`Slice::as_slice`]. Returns
    /// [`StorageError::Unsupported`] when `T` is zero-sized, when the storage is
    /// not aligned for `T`, or when its size is not a multiple of `size_of::<T>()`.
    ///
    /// # Safety
    /// The caller must ensure:
    /// - The memory region is valid and initialized
    /// - No concurrent mutable access occurs while the slice is in use
    /// - The data represents valid values of type T
    fn as_slice_typed<T: Sized>(&self) -> Result<&[T], StorageError> {
        let elem_size = std::mem::size_of::<T>();
        // SAFETY: Caller guarantees memory validity per trait's safety contract
        let bytes = unsafe { self.as_slice()? };
        if elem_size == 0 {
            return Err(StorageError::Unsupported(
                "zero-sized types are not supported".into(),
            ));
        }
        let ptr = bytes.as_ptr().cast::<T>();
        if !ptr.is_aligned() {
            return Err(StorageError::Unsupported(format!(
                "memory not aligned for type (required alignment: {})",
                std::mem::align_of::<T>()
            )));
        }
        if bytes.len() % elem_size != 0 {
            return Err(StorageError::Unsupported(format!(
                "size {} is not a multiple of type size {}",
                bytes.len(),
                elem_size
            )));
        }
        // SAFETY: alignment and size were checked above; caller guarantees the
        // memory is valid and properly initialized for T
        Ok(unsafe { std::slice::from_raw_parts(ptr, bytes.len() / elem_size) })
    }

    /// Returns a typed immutable slice view of a subregion.
    ///
    /// # Arguments
    /// * `offset` - Offset in bytes from the start of the storage
    /// * `len` - Number of elements of type T to slice
    ///
    /// # Errors
    /// Returns [`StorageError::Unsupported`] when `T` is zero-sized (consistent
    /// with [`Slice::as_slice_typed`]), when `len * size_of::<T>()` overflows,
    /// when the byte range is out of bounds, or when the start of the range is
    /// not aligned for `T`. Errors from [`Slice::as_slice`] are propagated.
    ///
    /// # Safety
    /// The caller must ensure:
    /// - The memory region is valid and initialized
    /// - No concurrent mutable access occurs while the slice is in use
    /// - The data represents valid values of type T
    fn slice_typed<T: Sized>(&self, offset: usize, len: usize) -> Result<&[T], StorageError> {
        let elem_size = std::mem::size_of::<T>();
        // Without this guard a zero-sized T would yield a slice of arbitrary length
        // that is unrelated to the storage contents.
        if elem_size == 0 {
            return Err(StorageError::Unsupported(
                "zero-sized types are not supported".into(),
            ));
        }
        let byte_len = len
            .checked_mul(elem_size)
            .ok_or_else(|| StorageError::Unsupported("length overflow".into()))?;
        let bytes = self.slice(offset, byte_len)?;
        let ptr = bytes.as_ptr().cast::<T>();
        if !ptr.is_aligned() {
            return Err(StorageError::Unsupported(format!(
                "memory not aligned for type (required alignment: {})",
                std::mem::align_of::<T>()
            )));
        }
        // SAFETY: bounds and alignment were checked above; caller guarantees the
        // memory is valid and properly initialized for T
        Ok(unsafe { std::slice::from_raw_parts(ptr, len) })
    }
}
/// Extension trait for storage types that support mutable slicing operations.
///
/// Implementors supply [`SliceMut::as_slice_mut`]; the provided methods build on
/// it and validate bounds, alignment and element size before handing out a view.
pub trait SliceMut: MemoryDescription + 'static {
    /// Returns a mutable byte slice view of the entire storage region.
    ///
    /// # Safety
    /// This is an unsafe method. The caller must ensure:
    /// - The memory region remains valid for the lifetime of the returned slice
    /// - The memory region is valid and accessible
    /// - No other references (mutable or immutable) exist to this memory region
    /// - The memory backing this storage remains valid (implementors with owned
    ///   memory satisfy this, but care must be taken with unowned memory regions)
    unsafe fn as_slice_mut(&mut self) -> Result<&mut [u8], StorageError>;

    /// Returns a mutable byte slice view of a subregion.
    ///
    /// # Arguments
    /// * `offset` - Offset in bytes from the start of the storage
    /// * `len` - Number of bytes to slice
    ///
    /// # Errors
    /// Propagates any error from [`SliceMut::as_slice_mut`]. Returns
    /// [`StorageError::Unsupported`] when `offset + len` overflows or exceeds the
    /// length of the storage.
    ///
    /// # Safety
    /// This method calls [`SliceMut::as_slice_mut`] internally, so the caller must ensure:
    /// - The memory region is valid
    /// - No other references (mutable or immutable) exist to this memory region
    fn slice_mut(&mut self, offset: usize, len: usize) -> Result<&mut [u8], StorageError> {
        // SAFETY: Caller guarantees memory validity per trait's safety contract
        let bytes = unsafe { self.as_slice_mut()? };
        match offset.checked_add(len) {
            // `offset <= end <= bytes.len()` holds here, so the index cannot panic.
            Some(end) if end <= bytes.len() => Ok(&mut bytes[offset..end]),
            _ => Err(StorageError::Unsupported("slice out of bounds".into())),
        }
    }

    /// Returns a typed mutable slice view of the entire storage region.
    ///
    /// # Errors
    /// Propagates any error from [`SliceMut::as_slice_mut`]. Returns
    /// [`StorageError::Unsupported`] when `T` is zero-sized, when the storage is
    /// not aligned for `T`, or when its size is not a multiple of `size_of::<T>()`.
    ///
    /// # Safety
    /// The caller must ensure:
    /// - The memory region is valid
    /// - No other references (mutable or immutable) exist to this memory region
    fn as_slice_typed_mut<T: Sized>(&mut self) -> Result<&mut [T], StorageError> {
        let elem_size = std::mem::size_of::<T>();
        // SAFETY: Caller guarantees memory validity per trait's safety contract
        let bytes = unsafe { self.as_slice_mut()? };
        // Reject zero-sized types up front: the length computation below divides
        // by `elem_size` and would otherwise panic with a division by zero.
        if elem_size == 0 {
            return Err(StorageError::Unsupported(
                "zero-sized types are not supported".into(),
            ));
        }
        let byte_len = bytes.len();
        let ptr = bytes.as_mut_ptr().cast::<T>();
        if !ptr.is_aligned() {
            return Err(StorageError::Unsupported(format!(
                "memory not aligned for type (required alignment: {})",
                std::mem::align_of::<T>()
            )));
        }
        if byte_len % elem_size != 0 {
            return Err(StorageError::Unsupported(format!(
                "size {} is not a multiple of type size {}",
                byte_len, elem_size
            )));
        }
        // SAFETY: alignment and size were checked above; caller guarantees the
        // memory is valid and not aliased
        Ok(unsafe { std::slice::from_raw_parts_mut(ptr, byte_len / elem_size) })
    }

    /// Returns a typed mutable slice view of a subregion.
    ///
    /// # Arguments
    /// * `offset` - Offset in bytes from the start of the storage
    /// * `len` - Number of elements of type T to slice
    ///
    /// # Errors
    /// Returns [`StorageError::Unsupported`] when `T` is zero-sized (consistent
    /// with [`SliceMut::as_slice_typed_mut`]), when `len * size_of::<T>()`
    /// overflows, when the byte range is out of bounds, or when the start of the
    /// range is not aligned for `T`. Errors from [`SliceMut::as_slice_mut`] are
    /// propagated.
    ///
    /// # Safety
    /// The caller must ensure:
    /// - The memory region is valid
    /// - No other references (mutable or immutable) exist to this memory region
    fn slice_typed_mut<T: Sized>(
        &mut self,
        offset: usize,
        len: usize,
    ) -> Result<&mut [T], StorageError> {
        let elem_size = std::mem::size_of::<T>();
        // Without this guard a zero-sized T would yield a slice of arbitrary length
        // that is unrelated to the storage contents.
        if elem_size == 0 {
            return Err(StorageError::Unsupported(
                "zero-sized types are not supported".into(),
            ));
        }
        let byte_len = len
            .checked_mul(elem_size)
            .ok_or_else(|| StorageError::Unsupported("length overflow".into()))?;
        let bytes = self.slice_mut(offset, byte_len)?;
        let ptr = bytes.as_mut_ptr().cast::<T>();
        if !ptr.is_aligned() {
            return Err(StorageError::Unsupported(format!(
                "memory not aligned for type (required alignment: {})",
                std::mem::align_of::<T>()
            )));
        }
        // SAFETY: bounds and alignment were checked above; caller guarantees the
        // memory is valid and not aliased
        Ok(unsafe { std::slice::from_raw_parts_mut(ptr, len) })
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! # Arena Allocator
//!
//! This module provides an arena allocator for generally heap-like allocations.
//! An [`ArenaAllocator`] can be created by taking ownership of a [`MemoryDescription`] instance.
//!
//! The [`ArenaAllocator`] allocates contiguous memory regions using the [`offset_allocator`] crate,
//! which builds on [Sebastian Aaltonen's OffsetAllocator](https://github.com/sebbbi/OffsetAllocator).
use crate::StorageKind;
use super::{MemoryDescription, StorageError};
use offset_allocator::{Allocation, Allocator};
use std::{
any::Any,
sync::{Arc, Mutex},
};
/// Errors specific to arena allocation.
#[derive(Debug, thiserror::Error)]
pub enum ArenaError {
/// The page size passed to `ArenaAllocator::new` is not a power of two.
#[error("Page size must be a power of 2")]
PageSizeNotAligned,
/// `ArenaAllocator::allocate` could not satisfy the request: either the page
/// count did not fit the allocator's size type or the underlying allocator
/// returned no allocation.
#[error("Allocation failed")]
AllocationFailed,
/// The number of pages in the storage did not fit the allocator's size type
/// when constructing the arena.
#[error("Failed to convert pages to u32")]
PagesNotConvertible,
/// A [`StorageError`] converted via `From` (enables `?` on storage results).
#[error("Storage error: {0}")]
StorageError(#[from] StorageError),
}
/// Arena allocator backed by an instance of a [`MemoryDescription`] object.
///
/// This struct wraps an [`Allocator`] from the [`offset_allocator`] crate,
/// and provides methods for allocating memory from the storage.
///
/// The allocator is thread-safe, and the storage is shared between the allocator and the buffers.
///
/// Cloning is derived: a clone holds the same `Arc`s, so it refers to the same
/// storage and the same allocator state.
#[derive(Clone)]
pub struct ArenaAllocator<S: MemoryDescription> {
// Backing storage, shared with every buffer handed out by `allocate`.
storage: Arc<S>,
// Page-granular offset allocator; the mutex serialises allocate/free calls.
allocator: Arc<Mutex<Allocator>>,
// Page size in bytes, kept as u64 for the page arithmetic in `allocate`.
page_size: u64,
}
impl<S: MemoryDescription> std::fmt::Debug for ArenaAllocator<S> {
    /// Formats the storage and page size; the inner allocator is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            storage, page_size, ..
        } = self;
        f.write_fmt(format_args!(
            "ArenaAllocator {{ storage: {:?}, page_size: {} }}",
            storage, page_size
        ))
    }
}
/// A buffer allocated from an [`ArenaAllocator`].
///
/// This struct wraps an [`Allocation`] from the [`offset_allocator`] crate,
/// and provides methods for interacting with the allocated memory.
///
/// The buffer is backed by a [`MemoryDescription`] object, and the allocation is freed when the buffer is dropped.
pub struct ArenaBuffer<S: MemoryDescription> {
// Byte offset of this buffer from the start of the backing storage
// (allocation page offset multiplied by the page size).
offset: usize,
// Address of the buffer: the storage's `addr()` plus `offset`.
address: usize,
// Size in bytes requested by the caller, before rounding up to whole pages.
requested_size: usize,
// Shared handle to the backing storage.
storage: Arc<S>,
// Handle passed back to the allocator's `free` when the buffer is dropped.
allocation: Allocation,
// Allocator this buffer was carved from; used by `Drop` to release the pages.
allocator: Arc<Mutex<Allocator>>,
}
impl<S: MemoryDescription> ArenaAllocator<S> {
    /// Build an [`ArenaAllocator`] that takes ownership of `storage` and manages it
    /// in pages of `page_size` bytes.
    ///
    /// The storage is split into whole pages; bytes past the last full page are
    /// never handed out. Each allocation is a run of contiguous pages whose total
    /// size is at least the requested size.
    ///
    /// The allocator is thread-safe, and the storage is shared between the allocator and the buffers.
    ///
    /// # Errors
    /// * [`ArenaError::PageSizeNotAligned`] - `page_size` is not a power of two
    /// * [`ArenaError::PagesNotConvertible`] - the page count does not fit the
    ///   allocator's size type
    pub fn new(storage: S, page_size: usize) -> std::result::Result<Self, ArenaError> {
        if !page_size.is_power_of_two() {
            return Err(ArenaError::PageSizeNotAligned);
        }

        // Integer division rounds down, which discards any trailing partial page.
        let page_count = (storage.size() / page_size)
            .try_into()
            .map_err(|_| ArenaError::PagesNotConvertible)?;

        Ok(Self {
            storage: Arc::new(storage),
            allocator: Arc::new(Mutex::new(Allocator::new(page_count))),
            page_size: page_size as u64,
        })
    }

    /// Allocate a new [`ArenaBuffer`] of at least `size` bytes from the allocator.
    ///
    /// The request is rounded up to whole pages internally, while the returned
    /// buffer reports the originally requested `size`.
    ///
    /// # Errors
    /// Returns [`ArenaError::AllocationFailed`] when the page count does not fit
    /// the allocator's size type or the allocator has no suitable free range.
    ///
    /// # Panics
    /// Panics if the allocator mutex is poisoned.
    pub fn allocate(&self, size: usize) -> std::result::Result<ArenaBuffer<S>, ArenaError> {
        let mut pool = self.allocator.lock().unwrap();
        let page_count = (size as u64)
            .div_ceil(self.page_size)
            .try_into()
            .map_err(|_| ArenaError::AllocationFailed)?;
        let allocation = pool
            .allocate(page_count)
            .ok_or(ArenaError::AllocationFailed)?;
        drop(pool);

        // The allocator works in pages; convert its page offset to bytes.
        let byte_offset = allocation.offset as u64 * self.page_size;
        let address = self.storage.addr() + byte_offset as usize;
        debug_assert!(address + size <= self.storage.addr() + self.storage.size());

        Ok(ArenaBuffer {
            offset: byte_offset as usize,
            address,
            requested_size: size,
            storage: Arc::clone(&self.storage),
            allocation,
            allocator: Arc::clone(&self.allocator),
        })
    }
}
impl<S: MemoryDescription> std::fmt::Debug for ArenaBuffer<S> {
    /// Formats the buffer's address, requested size, storage kind and the
    /// pointer identity of the allocator it was carved from.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "ArenaBuffer {{ addr: {}, size: {}, kind: {:?}, allocator: {:p} }}",
            self.address,
            self.requested_size,
            self.storage.storage_kind(),
            // The field is labelled `allocator`, so print the allocator's pointer
            // (previously the storage pointer was printed under this label).
            Arc::as_ptr(&self.allocator)
        )
    }
}
impl<S: MemoryDescription + 'static> MemoryDescription for ArenaBuffer<S> {
    /// Address of this buffer (storage address plus the buffer's byte offset).
    fn addr(&self) -> usize {
        self.address
    }

    /// Size in bytes that was requested for this buffer.
    fn size(&self) -> usize {
        self.requested_size
    }

    /// Storage kind of the backing storage.
    fn storage_kind(&self) -> StorageKind {
        self.storage.storage_kind()
    }

    /// Exposes the buffer as [`Any`].
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Descriptor of the backing storage narrowed to this buffer's address and
    /// size, or `None` when the storage provides no descriptor.
    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
        self.storage.nixl_descriptor().map(|mut descriptor| {
            descriptor.addr = self.address as u64;
            descriptor.size = self.requested_size;
            descriptor
        })
    }
}
// NIXL integration helpers
use super::nixl::{NixlCompatible, NixlDescriptor, RegisteredView};
impl<S> ArenaBuffer<S>
where
    S: MemoryDescription + NixlCompatible,
{
    /// Build a NIXL descriptor covering only this buffer's region.
    ///
    /// The base pointer, memory type and device id come from the storage's
    /// `nixl_params()`; the address is shifted by this buffer's byte offset and
    /// the size is the requested buffer size. Always returns `Some`.
    pub fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
        let (storage_ptr, _storage_size, mem_type, device_id) = self.storage.nixl_params();

        // SAFETY: NOTE(review) `offset` is computed by `ArenaAllocator::allocate`
        // from a page offset inside the storage, so the result presumably stays
        // within the storage's region — confirm `nixl_params()` reports that same region.
        let region_ptr = unsafe { storage_ptr.add(self.offset) };

        Some(NixlDescriptor {
            addr: region_ptr as u64,
            size: self.requested_size,
            mem_type,
            device_id,
        })
    }
}
impl<S> ArenaBuffer<S>
where
    S: MemoryDescription + RegisteredView,
{
    /// Agent name reported by the registered backing storage.
    ///
    /// This is a convenience method when using ArenaAllocator with NixlRegistered<T> storage.
    pub fn agent_name(&self) -> &str {
        self.storage.agent_name()
    }

    /// NIXL descriptor for this buffer, derived from the storage's descriptor.
    ///
    /// Memory type and device id are taken from the storage; the address is
    /// shifted by this buffer's byte offset and the size is the requested size.
    pub fn registered_descriptor(&self) -> NixlDescriptor {
        let storage_descriptor = self.storage.descriptor();
        let addr = storage_descriptor.addr + self.offset as u64;

        NixlDescriptor {
            addr,
            size: self.requested_size,
            mem_type: storage_descriptor.mem_type,
            device_id: storage_descriptor.device_id,
        }
    }
}
impl<S: MemoryDescription> Drop for ArenaBuffer<S> {
    /// Returns this buffer's pages to the allocator.
    fn drop(&mut self) {
        // Never panic in `drop`: a panic here while the thread is already
        // unwinding aborts the process. A poisoned mutex means another thread
        // panicked while holding the allocator; `ArenaAllocator::allocate` panics
        // on a poisoned lock, so the pages could not be reused anyway and
        // skipping the free loses nothing.
        if let Ok(mut allocator) = self.allocator.lock() {
            allocator.free(self.allocation);
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::SystemStorage;
// Fixture geometry shared by the tests below: the arena holds PAGE_COUNT pages
// of PAGE_SIZE bytes each.
const PAGE_SIZE: usize = 4096;
const PAGE_COUNT: usize = 10;
// Total bytes of backing storage requested for the test arena.
const TOTAL_STORAGE_SIZE: usize = PAGE_SIZE * PAGE_COUNT;
/// Builds the standard test arena: `TOTAL_STORAGE_SIZE` bytes of system storage
/// split into `PAGE_SIZE`-byte pages.
fn create_allocator() -> ArenaAllocator<SystemStorage> {
    let backing = SystemStorage::new(TOTAL_STORAGE_SIZE).unwrap();
    let arena = ArenaAllocator::new(backing, PAGE_SIZE);
    arena.unwrap()
}
#[test]
/// Constructing an `ArenaAllocator` with a power-of-two page size succeeds.
fn test_arena_allocator_new_success() {
    let backing = SystemStorage::new(TOTAL_STORAGE_SIZE).unwrap();
    assert!(ArenaAllocator::new(backing, PAGE_SIZE).is_ok());
}
#[test]
/// A page size that is not a power of two is rejected with
/// `ArenaError::PageSizeNotAligned`.
fn test_arena_allocator_new_invalid_page_size() {
    let backing = SystemStorage::new(TOTAL_STORAGE_SIZE).unwrap();
    let outcome = ArenaAllocator::new(backing, PAGE_SIZE + 1);
    // Matching the exact variant also covers the plain "is an error" check.
    assert!(matches!(outcome, Err(ArenaError::PageSizeNotAligned)));
}
#[test]
/// A single page-multiple allocation succeeds, reports the requested size, and,
/// being the first allocation, starts at the storage base address.
fn test_allocate_single_buffer() {
    let allocator = create_allocator();
    let requested = PAGE_SIZE * 2;

    let buffer = allocator
        .allocate(requested)
        .expect("first allocation should succeed");

    assert_eq!(buffer.size(), requested);
    assert_eq!(buffer.addr(), allocator.storage.addr()); // First allocation starts at addr
}
#[test]
/// Two live page-multiple allocations succeed, report their requested sizes, and
/// the second one starts right after the first.
fn test_allocate_multiple_buffers() {
    let allocator = create_allocator();
    let base = allocator.storage.addr();

    let first_size = PAGE_SIZE * 2;
    let first = allocator
        .allocate(first_size)
        .expect("first allocation should succeed");
    assert_eq!(first.size(), first_size);
    assert_eq!(first.addr(), base);

    // `first` is still alive here, so the next buffer must be placed after it.
    let second_size = PAGE_SIZE * 3;
    let second = allocator
        .allocate(second_size)
        .expect("second allocation should succeed");
    assert_eq!(second.size(), second_size);
    assert_eq!(second.addr(), base + first_size);
}
#[test]
/// One allocation may consume the whole arena and reports the full size.
fn test_allocate_exact_size() {
    let allocator = create_allocator();

    let buffer = allocator
        .allocate(TOTAL_STORAGE_SIZE)
        .expect("whole-arena allocation should succeed");

    assert_eq!(buffer.size(), TOTAL_STORAGE_SIZE);
}
#[test]
/// A request larger than the total storage fails with
/// `ArenaError::AllocationFailed`.
fn test_allocate_too_large() {
    let allocator = create_allocator();
    let oversized = TOTAL_STORAGE_SIZE + PAGE_SIZE;

    let outcome = allocator.allocate(oversized);

    // Matching the exact variant also covers the plain "is an error" check.
    assert!(matches!(outcome, Err(ArenaError::AllocationFailed)));
}
#[test]
/// Dropping an `ArenaBuffer` returns its pages to the arena.
/// The buffer size is chosen so that two such buffers cannot coexist; the second
/// allocation can therefore only succeed if the first buffer's pages were freed,
/// and it is expected to land at the storage base address again.
fn test_buffer_drop_and_reallocate() {
    let allocator = create_allocator();
    let base = allocator.storage.addr();
    // Six of the ten pages: two buffers of this size do not fit at once.
    let requested = PAGE_SIZE * 6;

    let first = allocator.allocate(requested).unwrap();
    assert_eq!(first.size(), requested);
    assert_eq!(first.addr(), base);
    drop(first); // frees the pages held by the first buffer

    let second = allocator
        .allocate(requested)
        .expect("allocation after drop should reuse the freed pages");
    assert_eq!(second.size(), requested);
    assert_eq!(second.addr(), base); // Should be at the start again
}
#[test]
/// Two live buffers that together use every page leave no room for even a
/// single extra page: the third request fails with `ArenaError::AllocationFailed`.
fn test_allocate_fill_and_fail() {
    let allocator = create_allocator();
    let half = TOTAL_STORAGE_SIZE / 2; // Each takes 5 pages

    let lower = allocator.allocate(half).unwrap();
    assert_eq!(lower.size(), half);

    let upper = allocator.allocate(half).unwrap();
    assert_eq!(upper.size(), half);
    assert_eq!(upper.addr(), allocator.storage.addr() + half);

    // Both halves are still held, so one more page cannot be found.
    let overflow = allocator.allocate(PAGE_SIZE);
    assert!(matches!(overflow, Err(ArenaError::AllocationFailed)));
}
#[test]
/// A one-byte allocation succeeds and the buffer reports a size of 1.
/// How many pages it really consumes is covered by the exhaustion tests.
fn test_allocate_non_page_aligned_single_byte() {
    let allocator = create_allocator();
    let requested = 1;

    let buffer = allocator.allocate(requested).unwrap();

    assert_eq!(buffer.size(), requested);
}
#[test]
/// An allocation one byte short of a page succeeds and reports the requested size.
/// How many pages it really consumes is covered by the exhaustion tests.
fn test_allocate_non_page_aligned_almost_full_page() {
    let allocator = create_allocator();
    let requested = PAGE_SIZE - 1;

    let buffer = allocator.allocate(requested).unwrap();

    assert_eq!(buffer.size(), requested);
}
#[test]
/// An allocation one byte larger than a page succeeds and reports the requested size.
/// It consumes two pages, which is covered by the exhaustion tests.
fn test_allocate_non_page_aligned_just_over_one_page() {
    let allocator = create_allocator();
    let requested = PAGE_SIZE + 1;

    let buffer = allocator.allocate(requested).unwrap();

    assert_eq!(buffer.size(), requested);
}
#[test]
/// Checks page rounding through arena exhaustion.
/// A request of `(PAGE_COUNT / 2 * PAGE_SIZE) + 1` bytes needs `(PAGE_COUNT / 2) + 1` pages.
/// The first such request must succeed; an identical second request must fail with
/// `ArenaError::AllocationFailed` because fewer pages remain than it needs.
fn test_allocate_half_plus_one_byte_twice_exhausts_arena() {
    let arena = create_allocator();
    // One byte past half of the arena, so the request rounds up by a whole page.
    let request = (PAGE_COUNT / 2 * PAGE_SIZE) + 1;

    let first_attempt = arena.allocate(request);
    assert!(first_attempt.is_ok(), "First allocation should succeed");
    let first = first_attempt.unwrap();
    assert_eq!(first.size(), request);

    // Sanity-check the rounding arithmetic the rest of the test relies on.
    let pages_needed = (request as u64).div_ceil(arena.page_size);
    assert_eq!(pages_needed, (PAGE_COUNT / 2 + 1) as u64);

    // `first` is still alive, so only PAGE_COUNT - pages_needed pages remain, which is
    // fewer than the PAGE_COUNT / 2 + 1 pages this request needs.
    let second_attempt = arena.allocate(request);
    assert!(
        second_attempt.is_err(),
        "Second allocation should fail due to insufficient pages"
    );
    assert!(matches!(second_attempt, Err(ArenaError::AllocationFailed)));
}
#[test]
/// Fills the arena with `PAGE_SIZE + 1` byte allocations, each of which rounds up to two
/// pages, and then verifies that a final one-byte request is rejected with
/// `ArenaError::AllocationFailed`.
fn test_fill_with_non_aligned_and_fail() {
    let arena = create_allocator();
    // Each request is one byte over a page, so it occupies two pages.
    let chunk = PAGE_SIZE + 1;
    // Number of two-page allocations that fit in the arena.
    let chunk_count = PAGE_COUNT / 2;

    // Keep every buffer alive until the end of the test so the pages stay occupied.
    let _held_buffers: Vec<_> = (0..chunk_count)
        .map(|idx| {
            let attempt = arena.allocate(chunk);
            assert!(attempt.is_ok(), "Allocation {} should succeed", idx + 1);
            let buffer = attempt.unwrap();
            assert_eq!(buffer.size(), chunk);
            buffer
        })
        .collect();

    // All pages are now consumed (chunk_count * 2), so even one byte must fail.
    let last_attempt = arena.allocate(1);
    assert!(
        last_attempt.is_err(),
        "Final allocation of 1 byte should fail as arena is full"
    );
    assert!(matches!(last_attempt, Err(ArenaError::AllocationFailed)));
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! CUDA device memory storage.
use super::{MemoryDescription, Result, StorageError, StorageKind, nixl::NixlDescriptor};
use cudarc::driver::CudaContext;
use std::any::Any;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
/// Return the process-wide CUDA context for `device_id`, creating and caching it on first use.
///
/// Contexts are stored in a static map keyed by device id. The map's mutex stays locked
/// while a missing context is created.
///
/// # Errors
/// Propagates the error from `CudaContext::new` when the context cannot be created.
///
/// # Panics
/// Panics if the cache mutex is poisoned.
fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
    static CONTEXTS: OnceLock<Mutex<HashMap<u32, Arc<CudaContext>>>> = OnceLock::new();

    let registry = CONTEXTS.get_or_init(Default::default);
    let mut cache = registry.lock().unwrap();

    match cache.get(&device_id) {
        Some(cached) => Ok(Arc::clone(cached)),
        None => {
            let created = CudaContext::new(device_id as usize)?;
            cache.insert(device_id, Arc::clone(&created));
            Ok(created)
        }
    }
}
/// CUDA device memory owned by this value.
///
/// The allocation is made with `cudarc::driver::result::malloc_sync` in [`DeviceStorage::new`]
/// and released with `free_sync` when the value is dropped.
#[derive(Debug)]
pub struct DeviceStorage {
// Context that is bound to the current thread before allocating and before freeing.
ctx: Arc<CudaContext>,
// Raw device pointer value returned by `malloc_sync`.
ptr: u64,
// CUDA device ordinal the context was created for.
device_id: u32,
// Requested allocation size in bytes.
len: usize,
}
// NOTE(review): the fields visible here are an `Arc` and plain integers; these manual impls
// presumably exist because the device pointer is shared across threads -- confirm they are
// still needed and that nothing relies on thread-affine state.
unsafe impl Send for DeviceStorage {}
unsafe impl Sync for DeviceStorage {}
impl DeviceStorage {
    /// Allocate `len` bytes of device memory on CUDA device `device_id`.
    ///
    /// # Arguments
    /// * `len` - Size in bytes to allocate; must be non-zero
    /// * `device_id` - CUDA device on which to allocate
    ///
    /// # Errors
    /// * `StorageError::AllocationFailed` when `len` is zero
    /// * `StorageError::Cuda` when the context cannot be obtained or bound, or the allocation fails
    pub fn new(len: usize, device_id: u32) -> Result<Self> {
        if len == 0 {
            let reason = "zero-sized allocations are not supported".into();
            return Err(StorageError::AllocationFailed(reason));
        }

        // The allocation below targets whichever context is bound to this thread.
        let ctx = cuda_context(device_id)?;
        ctx.bind_to_thread().map_err(StorageError::Cuda)?;

        let allocation = unsafe { cudarc::driver::result::malloc_sync(len) };
        let ptr = allocation.map_err(StorageError::Cuda)?;

        Ok(Self {
            ctx,
            ptr,
            device_id,
            len,
        })
    }

    /// Raw device pointer value of the allocation.
    pub fn device_ptr(&self) -> u64 {
        self.ptr
    }

    /// CUDA device ID this memory was allocated on.
    pub fn device_id(&self) -> u32 {
        self.device_id
    }
}
impl Drop for DeviceStorage {
    /// Release the device allocation.
    ///
    /// Both steps are best-effort: failures are logged at debug level and never panic.
    /// The free is attempted even when binding the context fails.
    fn drop(&mut self) {
        let bound = self.ctx.bind_to_thread();
        if let Err(e) = bound {
            tracing::debug!("failed to bind CUDA context for free: {e}");
        }

        let freed = unsafe { cudarc::driver::result::free_sync(self.ptr) };
        if let Err(e) = freed {
            tracing::debug!("failed to free device memory: {e}");
        }
    }
}
impl MemoryDescription for DeviceStorage {
    /// Device pointer value, reinterpreted as an address.
    fn addr(&self) -> usize {
        self.ptr as usize
    }

    /// Allocation size in bytes.
    fn size(&self) -> usize {
        self.len
    }

    /// Always `StorageKind::Device` carrying this allocation's device id.
    fn storage_kind(&self) -> StorageKind {
        let device = self.device_id;
        StorageKind::Device(device)
    }

    /// Expose `self` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Plain device storage carries no NIXL descriptor.
    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
        None
    }
}
// Support for NIXL registration
impl super::nixl::NixlCompatible for DeviceStorage {
    /// Registration parameters: the device pointer, the allocation length, `MemType::Vram`,
    /// and the CUDA device id widened to `u64`.
    fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
        let base = self.ptr as *const u8;
        let device = u64::from(self.device_id);
        (base, self.len, nixl_sys::MemType::Vram, device)
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Disk-backed storage using `O_DIRECT` files preallocated with `fallocate`.
use super::{MemoryDescription, Result, StorageError, StorageKind, nixl::NixlDescriptor};
use std::any::Any;
use std::path::{Path, PathBuf};
use core::ffi::c_char;
use nix::fcntl::{FallocateFlags, fallocate};
use nix::unistd::unlink;
use std::ffi::CString;
use std::os::fd::BorrowedFd;
/// Environment variable naming the directory in which `DiskStorage::new` creates its file.
const DISK_CACHE_KEY: &str = "DYN_KVBM_DISK_CACHE_DIR";
/// Directory used when `DISK_CACHE_KEY` is not set (or cannot be read).
const DEFAULT_DISK_CACHE_DIR: &str = "/tmp/";
/// A preallocated file on disk, identified by its open file descriptor.
///
/// The descriptor is closed and the file unlinked when the value is dropped.
#[derive(Debug)]
pub struct DiskStorage {
// Raw file descriptor, widened to `u64`.
fd: u64,
// Path of the backing file (the generated name when a template was used).
path: PathBuf,
// Number of bytes reserved with `fallocate`.
size: usize,
// Set once the file has been unlinked, so it is not unlinked twice.
unlinked: bool,
}
impl DiskStorage {
    /// Create a disk cache file of `size` bytes with a generated name.
    ///
    /// The directory is taken from the `DYN_KVBM_DISK_CACHE_DIR` environment variable and
    /// falls back to `/tmp/` when the variable is unset or unreadable.
    ///
    /// # Errors
    /// See [`DiskStorage::new_at`].
    pub fn new(size: usize) -> Result<Self> {
        // We need to open our file with some special flags that aren't supported by the tempfile crate.
        // Instead, we'll use the mkostemp function to create a temporary file with the correct flags.
        let specified_dir =
            std::env::var(DISK_CACHE_KEY).unwrap_or_else(|_| DEFAULT_DISK_CACHE_DIR.to_string());
        let file_path = Path::new(&specified_dir).join("dynamo-kvbm-disk-cache-XXXXXX");
        Self::new_at(file_path, size)
    }

    /// Create (or open) the file at `path` with `O_DIRECT` and reserve `len` bytes for it.
    ///
    /// A path ending in `XXXXXX` is treated as an `mkostemp` template and a unique file name
    /// is generated; any other path is opened with `O_CREAT`. Missing parent directories are
    /// created when the path does not exist yet.
    ///
    /// # Errors
    /// Returns `StorageError::AllocationFailed` when `len` is zero or does not fit in a file
    /// offset, when the path is not valid UTF-8 or contains a NUL byte, or when creating the
    /// directory, opening the file, or reserving the space fails.
    pub fn new_at(path: impl AsRef<Path>, len: usize) -> Result<Self> {
        if len == 0 {
            return Err(StorageError::AllocationFailed(
                "zero-sized allocations are not supported".into(),
            ));
        }
        // Validate the length up front: a plain `as` cast would silently wrap to a negative
        // offset for very large values.
        let alloc_len = i64::try_from(len).map_err(|_| {
            StorageError::AllocationFailed(format!(
                "requested disk cache size {len} does not fit in a file offset"
            ))
        })?;

        let file_path = path.as_ref().to_path_buf();
        let preexisting = file_path.exists();
        if !preexisting {
            let parent = file_path.parent().ok_or_else(|| {
                StorageError::AllocationFailed(format!(
                    "disk cache path {} has no parent directory",
                    file_path.display()
                ))
            })?;
            std::fs::create_dir_all(parent).map_err(|e| {
                StorageError::AllocationFailed(format!(
                    "failed to create disk cache directory {}: {e}",
                    parent.display()
                ))
            })?;
        }

        tracing::debug!("Allocating disk cache file at {}", file_path.display());

        let path_str = file_path.to_str().ok_or_else(|| {
            StorageError::AllocationFailed(format!(
                "disk cache path {} is not valid UTF-8",
                file_path.display()
            ))
        })?;
        // The path is caller-supplied, so an interior NUL is an input error, not a bug.
        let path_cstr = CString::new(path_str).map_err(|_| {
            StorageError::AllocationFailed(format!(
                "disk cache path {} contains an interior NUL byte",
                file_path.display()
            ))
        })?;

        // mkostemp only accepts templates whose *last* six characters are "XXXXXX"; a path
        // that merely contains that run elsewhere is an ordinary file name.
        let is_template = path_str.ends_with("XXXXXX");

        let (raw_fd, actual_path) = if is_template {
            // Template path - use mkostemp to generate unique filename
            let mut template_bytes = path_cstr.into_bytes_with_nul();
            // SAFETY: `template_bytes` is a writable, NUL-terminated buffer that outlives the call.
            let fd = unsafe {
                nix::libc::mkostemp(
                    template_bytes.as_mut_ptr() as *mut c_char,
                    nix::libc::O_RDWR | nix::libc::O_DIRECT,
                )
            };
            if fd == -1 {
                return Err(StorageError::AllocationFailed(format!(
                    "mkostemp failed: {}",
                    std::io::Error::last_os_error()
                )));
            }

            // Extract the actual path created by mkostemp
            let actual = CString::from_vec_with_nul(template_bytes)
                .expect("mkostemp keeps the single trailing NUL of its template")
                .into_string()
                .expect("mkostemp only rewrites the trailing XXXXXX with ASCII characters");
            (fd, PathBuf::from(actual))
        } else {
            // Specific path - use open with O_CREAT
            // SAFETY: `path_cstr` is a valid NUL-terminated string for the duration of the call.
            let fd = unsafe {
                nix::libc::open(
                    path_cstr.as_ptr(),
                    nix::libc::O_CREAT | nix::libc::O_RDWR | nix::libc::O_DIRECT,
                    0o644,
                )
            };
            if fd == -1 {
                return Err(StorageError::AllocationFailed(format!(
                    "open failed: {}",
                    std::io::Error::last_os_error()
                )));
            }
            (fd, file_path)
        };

        // We need to use fallocate to actually allocate the storage and create the blocks on disk.
        // SAFETY: `raw_fd` was just returned by a successful mkostemp/open and is still open.
        let reserved = fallocate(
            unsafe { BorrowedFd::borrow_raw(raw_fd) },
            FallocateFlags::empty(),
            0,
            alloc_len,
        );
        if let Err(e) = reserved {
            // No `DiskStorage` exists yet, so `Drop` will not run: release the descriptor
            // here, and remove the file if this call is what created it.
            let _ = nix::unistd::close(raw_fd);
            if !preexisting {
                let _ = unlink(actual_path.as_path());
            }
            return Err(StorageError::AllocationFailed(format!(
                "Failed to allocate temp file: {}",
                e
            )));
        }

        Ok(Self {
            fd: raw_fd as u64,
            path: actual_path,
            size: len,
            unlinked: false,
        })
    }

    /// File descriptor of the backing file, widened to `u64`.
    pub fn fd(&self) -> u64 {
        self.fd
    }

    /// Path of the backing file.
    pub fn path(&self) -> &Path {
        self.path.as_path()
    }

    /// Unlink our temp file.
    /// This means that when this process terminates, the file will be automatically deleted by the OS.
    /// Unfortunately, GDS requires that files we try to register must be linked.
    /// To get around this, we unlink the file only after we've registered it with NIXL.
    ///
    /// Calling this again after a successful unlink is a no-op.
    ///
    /// # Errors
    /// Returns `StorageError::AllocationFailed` when the unlink call fails.
    pub fn unlink(&mut self) -> Result<()> {
        if self.unlinked {
            return Ok(());
        }
        unlink(self.path.as_path())
            .map_err(|e| StorageError::AllocationFailed(format!("Failed to unlink file: {}", e)))?;
        self.unlinked = true;
        Ok(())
    }

    /// Whether the backing file has already been unlinked.
    pub fn unlinked(&self) -> bool {
        self.unlinked
    }
}
impl Drop for DiskStorage {
    /// Remove the backing file (best-effort) and close its descriptor.
    ///
    /// An unlink failure is ignored; a close failure is logged at debug level.
    fn drop(&mut self) {
        // Best-effort: `unlink` is a no-op when the file was already unlinked.
        let _ = self.unlink();

        let raw = self.fd as std::os::fd::RawFd;
        match nix::unistd::close(raw) {
            Ok(()) => {}
            Err(e) => {
                tracing::debug!("failed to close disk cache fd {}: {e}", self.fd);
            }
        }
    }
}
impl MemoryDescription for DiskStorage {
    /// Disk storage reports a base address of zero.
    fn addr(&self) -> usize {
        0usize
    }

    /// Number of bytes reserved for the file.
    fn size(&self) -> usize {
        self.size
    }

    /// Always `StorageKind::Disk` carrying the file descriptor.
    fn storage_kind(&self) -> StorageKind {
        let descriptor = self.fd;
        StorageKind::Disk(descriptor)
    }

    /// Expose `self` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Plain disk storage carries no NIXL descriptor.
    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
        None
    }
}
// Support for NIXL registration
impl super::nixl::NixlCompatible for DiskStorage {
    /// Registration parameters for file-backed memory: a null pointer, the reserved size,
    /// `MemType::File`, and the file descriptor in the device-id slot.
    ///
    /// This module is Unix-only (it relies on `std::os::fd` and `mkostemp`), so there is no
    /// separate non-Unix code path. The former `#[cfg(not(unix))]` arm read a `mmap` field
    /// that `DiskStorage` does not have.
    fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
        // Use file descriptor as device_id for MemType::File
        (
            std::ptr::null(),
            self.size,
            nixl_sys::MemType::File,
            self.fd,
        )
    }
}
// mod mmap {
// use super::*;
// #[cfg(unix)]
// use std::os::unix::io::AsRawFd;
// use memmap2::{MmapMut, MmapOptions};
// use std::fs::{File, OpenOptions};
// use tempfile::NamedTempFile;
// /// Disk-backed storage using memory-mapped files.
// #[derive(Debug)]
// pub struct MemMappedFileStorage {
// _file: File, // Keep file alive for the lifetime of the mmap
// mmap: MmapMut,
// path: PathBuf,
// #[cfg(unix)]
// fd: i32,
// }
// unsafe impl Send for MemMappedFileStorage {}
// unsafe impl Sync for MemMappedFileStorage {}
// impl MemMappedFileStorage {
// /// Create new disk storage with a temporary file.
// pub fn new_temp(len: usize) -> Result<Self> {
// if len == 0 {
// return Err(StorageError::AllocationFailed(
// "zero-sized allocations are not supported".into(),
// ));
// }
// // Create temporary file
// let temp_file = NamedTempFile::new()?;
// let path = temp_file.path().to_path_buf();
// let file = temp_file.into_file();
// // Set file size
// file.set_len(len as u64)?;
// #[cfg(unix)]
// let fd = file.as_raw_fd();
// // Memory map the file
// let mmap = unsafe { MmapOptions::new().len(len).map_mut(&file)? };
// Ok(Self {
// _file: file,
// mmap,
// path,
// #[cfg(unix)]
// fd,
// })
// }
// /// Create new disk storage with a specific file path.
// pub fn new_at(path: impl AsRef<Path>, len: usize) -> Result<Self> {
// if len == 0 {
// return Err(StorageError::AllocationFailed(
// "zero-sized allocations are not supported".into(),
// ));
// }
// let path = path.as_ref().to_path_buf();
// // Create or open file
// let file = OpenOptions::new()
// .read(true)
// .write(true)
// .create(true)
// .open(&path)?;
// // Set file size
// file.set_len(len as u64)?;
// #[cfg(unix)]
// let fd = file.as_raw_fd();
// // Memory map the file
// let mmap = unsafe { MmapOptions::new().len(len).map_mut(&file)? };
// Ok(Self {
// _file: file,
// mmap,
// path,
// #[cfg(unix)]
// fd,
// })
// }
// /// Get the path to the backing file.
// pub fn path(&self) -> &Path {
// &self.path
// }
// /// Get the file descriptor (Unix only).
// #[cfg(unix)]
// pub fn fd(&self) -> i32 {
// self.fd
// }
// /// Get a pointer to the memory-mapped region.
// ///
// /// # Safety
// /// The caller must ensure the pointer is not used after this storage is dropped.
// pub unsafe fn as_ptr(&self) -> *const u8 {
// self.mmap.as_ptr()
// }
// /// Get a mutable pointer to the memory-mapped region.
// ///
// /// # Safety
// /// The caller must ensure the pointer is not used after this storage is dropped
// /// and that there are no other references to this memory.
// pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
// self.mmap.as_mut_ptr()
// }
// }
// impl MemoryDescription for MemMappedFileStorage {
// fn addr(&self) -> usize {
// self.mmap.as_ptr() as usize
// }
// fn size(&self) -> usize {
// self.mmap.len()
// }
// fn storage_kind(&self) -> StorageKind {
// StorageKind::Disk
// }
// fn as_any(&self) -> &dyn Any {
// self
// }
// }
// // Support for NIXL registration
// impl super::super::registered::NixlCompatible for MemMappedFileStorage {
// fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
// #[cfg(unix)]
// {
// // Use file descriptor as device_id for MemType::File
// (
// self.mmap.as_ptr(),
// self.mmap.len(),
// nixl_sys::MemType::File,
// self.fd as u64,
// )
// }
// #[cfg(not(unix))]
// {
// // On non-Unix systems, we can't get the file descriptor easily
// // Return device_id as 0 - registration will fail on these systems
// (
// self.mmap.as_ptr(),
// self.mmap.len(),
// nixl_sys::MemType::File,
// 0,
// )
// }
// }
// }
// }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Clean, minimal storage API for v2 block manager.
//!
//! This module provides a simplified storage abstraction with:
//! - Single trait for type erasure (`MemoryDescription`)
//! - Concrete storage types (no trait implementations required)
//! - Composition-based NIXL registration via `NixlRegistered<T>` wrapper
//! - RAII with proper drop ordering (registration handle drops before memory)
pub mod actions;
pub mod arena;
pub mod nixl;
pub mod offset;
pub mod prelude;
mod device;
mod disk;
mod pinned;
mod system;
mod torch;
#[cfg(test)]
mod tests;
pub use arena::{ArenaAllocator, ArenaBuffer, ArenaError};
pub use device::DeviceStorage;
pub use disk::DiskStorage;
pub use pinned::PinnedStorage;
pub use system::SystemStorage;
pub use torch::{TorchDevice, TorchTensor};
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::fmt;
use std::sync::Arc;
use thiserror::Error;
/// Result type for storage operations; the error type is fixed to [`StorageError`].
pub type Result<T> = std::result::Result<T, StorageError>;
/// Errors that can occur during storage operations.
///
/// The string-carrying variants hold a human-readable reason; the remaining variants wrap
/// the underlying library error and can be produced with `?` through their `From` impls.
#[derive(Debug, Error)]
pub enum StorageError {
/// Memory or file space could not be allocated.
#[error("allocation failed: {0}")]
AllocationFailed(String),
/// Registering memory failed.
#[error("registration failed: {0}")]
RegistrationFailed(String),
/// A storage operation failed.
#[error("operation failed: {0}")]
OperationFailed(String),
/// The requested operation is not supported.
#[error("unsupported operation: {0}")]
Unsupported(String),
/// Wrapped standard I/O error.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
// #[cfg(feature = "cuda")]
/// Wrapped CUDA driver error.
#[error("CUDA error: {0}")]
Cuda(#[from] cudarc::driver::DriverError),
/// Wrapped NIXL error.
#[error("NIXL error: {0}")]
Nixl(#[from] nixl_sys::NixlError),
}
/// Storage type classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StorageKind {
/// System memory (malloc)
System,
/// CUDA pinned host memory
// #[cfg(feature = "cuda")]
Pinned,
/// CUDA device memory with device ID
// #[cfg(feature = "cuda")]
Device(u32),
/// Disk-backed storage; `DiskStorage` fills the payload with its file descriptor
Disk(u64),
}
/// Core trait for memory regions that can be type-erased.
///
/// Concrete storage types implement this trait to enable type erasure via
/// `Arc<dyn MemoryDescription>`. Implementors must be `Send + Sync` and `Debug`.
pub trait MemoryDescription: Send + Sync + fmt::Debug {
/// Base address of the memory region.
fn addr(&self) -> usize;
/// Size of the memory region in bytes.
fn size(&self) -> usize;
/// Type of storage backing this region.
fn storage_kind(&self) -> StorageKind;
/// Enable downcasting to concrete type.
fn as_any(&self) -> &dyn Any;
/// Get the NIXL descriptor for this memory region, or `None` when it has none.
fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor>;
}
/// Type-erased memory region for use in layouts.
///
/// Wraps an `Arc<dyn MemoryDescription>`, so cloning a `Buffer` shares the same
/// underlying region rather than copying it.
#[derive(Clone)]
pub struct Buffer(Arc<dyn MemoryDescription>);
impl MemoryDescription for Buffer {
    /// Base address reported by the wrapped region.
    fn addr(&self) -> usize {
        (*self.0).addr()
    }

    /// Size in bytes reported by the wrapped region.
    fn size(&self) -> usize {
        (*self.0).size()
    }

    /// Storage kind reported by the wrapped region.
    fn storage_kind(&self) -> StorageKind {
        (*self.0).storage_kind()
    }

    /// Downcasting handle of the wrapped region (not of the `Buffer` itself).
    fn as_any(&self) -> &dyn Any {
        (*self.0).as_any()
    }

    /// NIXL descriptor reported by the wrapped region, if any.
    fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor> {
        (*self.0).nixl_descriptor()
    }
}
impl std::ops::Deref for Buffer {
    type Target = dyn MemoryDescription;

    /// Borrow the wrapped, type-erased memory region.
    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}
impl std::fmt::Debug for Buffer {
    /// Format as `Buffer { addr, size, kind }` using the values reported by the region.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let addr = self.addr();
        let size = self.size();
        let kind = self.storage_kind();

        let mut out = f.debug_struct("Buffer");
        out.field("addr", &addr);
        out.field("size", &size);
        out.field("kind", &kind);
        out.finish()
    }
}
/// Wrap a concrete storage value in a type-erased [`Buffer`].
///
/// The storage is moved into a new `Arc`.
pub fn create_buffer<S: MemoryDescription + 'static>(memory: S) -> Buffer {
    let erased: Arc<dyn MemoryDescription> = Arc::new(memory);
    Buffer(erased)
}
/// An unowned contiguous chunk of memory, not storage specific.
///
/// This is a plain (address, length) pair: it is `Copy`, holds no reference to the
/// memory it describes, and frees nothing when dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct MemoryRegion {
/// Start address of the memory region.
pub addr: usize,
/// Size of the memory region in bytes.
pub size: usize,
}
impl MemoryRegion {
pub fn new(addr: usize, size: usize) -> Self {
Self { addr, size }
}
#[inline]
pub fn addr(&self) -> usize {
self.addr
}
#[inline]
pub fn size(&self) -> usize {
self.size
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NIXL registration wrapper for storage types.
mod agent;
mod config;
use super::{MemoryDescription, StorageKind};
use std::any::Any;
use std::fmt;
pub use agent::NixlAgent;
pub use config::NixlBackendConfig;
pub use nixl_sys::{MemType, OptArgs, RegistrationHandle};
/// Trait for storage types that can be registered with NIXL.
pub trait NixlCompatible {
/// Get parameters needed for NIXL registration.
///
/// Returns (ptr, size, mem_type, device_id)
///
/// The meaning of `ptr` and `device_id` depends on `mem_type`: for example, the disk
/// storage in this crate returns a null pointer and puts its file descriptor in the
/// `device_id` slot.
fn nixl_params(&self) -> (*const u8, usize, MemType, u64);
}
/// NIXL descriptor containing registration information.
///
/// The fields mirror the tuple returned by `NixlCompatible::nixl_params`, with the
/// pointer stored as an integer address.
#[derive(Debug, Clone)]
pub struct NixlDescriptor {
/// Address of the region (the registration pointer cast to `u64`).
pub addr: u64,
/// Size of the region in bytes.
pub size: usize,
/// NIXL memory type of the region.
pub mem_type: MemType,
/// Device identifier as supplied by the storage.
pub device_id: u64,
}
impl nixl_sys::MemoryRegion for NixlDescriptor {
unsafe fn as_ptr(&self) -> *const u8 {
self.addr as *const u8
}
fn size(&self) -> usize {
self.size
}
}
impl nixl_sys::NixlDescriptor for NixlDescriptor {
    /// NIXL memory type recorded in this descriptor.
    fn mem_type(&self) -> MemType {
        let kind = self.mem_type;
        kind
    }

    /// Device identifier recorded in this descriptor.
    fn device_id(&self) -> u64 {
        let device = self.device_id;
        device
    }
}
/// View trait for accessing registration information without unwrapping.
pub trait RegisteredView {
/// Get the name of the NIXL agent that registered this memory.
fn agent_name(&self) -> &str;
/// Get the NIXL descriptor for this registered memory.
///
/// Returns an owned descriptor value.
fn descriptor(&self) -> NixlDescriptor;
}
/// Wrapper for storage that has been registered with NIXL.
///
/// This wrapper ensures proper drop order: the registration handle is
/// dropped before the storage, ensuring deregistration happens before
/// the memory is freed.
///
/// Fields drop in declaration order, which would release `storage` first; the explicit
/// `Drop` impl therefore takes and drops the handle before the fields are dropped.
pub struct NixlRegistered<S: NixlCompatible> {
// The wrapped storage; declared (and so dropped) before `handle`.
storage: S,
// Registration handle; `None` once it has been taken and dropped.
handle: Option<RegistrationHandle>,
// Name of the agent the storage was registered with.
agent_name: String,
}
impl<S: NixlCompatible> Drop for NixlRegistered<S> {
    /// Release the registration handle before the storage field is dropped.
    fn drop(&mut self) {
        // Take the handle out so it is dropped here, ahead of the remaining fields.
        if let Some(registration) = self.handle.take() {
            drop(registration);
        }
        // `storage` and `agent_name` are dropped automatically after this returns.
    }
}
impl<S: NixlCompatible + fmt::Debug> fmt::Debug for NixlRegistered<S> {
    /// Format the storage, the agent name, and whether a registration handle is still held.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The handle itself is not printed, only whether it is present.
        let has_handle = self.handle.is_some();

        let mut out = f.debug_struct("NixlRegistered");
        out.field("storage", &self.storage);
        out.field("agent_name", &self.agent_name);
        out.field("handle", &has_handle);
        out.finish()
    }
}
impl<S: MemoryDescription + NixlCompatible + 'static> MemoryDescription for NixlRegistered<S> {
fn addr(&self) -> usize {
self.storage.addr()
}
fn size(&self) -> usize {
self.storage.size()
}
fn storage_kind(&self) -> StorageKind {
self.storage.storage_kind()
}
fn as_any(&self) -> &dyn Any {
self
}
fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
Some(self.descriptor())
}
}
impl<S: MemoryDescription + NixlCompatible> RegisteredView for NixlRegistered<S> {
    /// Name of the agent this storage was registered with.
    fn agent_name(&self) -> &str {
        self.agent_name.as_str()
    }

    /// Build a descriptor from the storage's current NIXL parameters.
    fn descriptor(&self) -> NixlDescriptor {
        let params = self.storage.nixl_params();
        NixlDescriptor {
            addr: params.0 as u64,
            size: params.1,
            mem_type: params.2,
            device_id: params.3,
        }
    }
}
impl<S: MemoryDescription + NixlCompatible> NixlRegistered<S> {
/// Get a reference to the underlying storage.
pub fn storage(&self) -> &S {
&self.storage
}
/// Get a mutable reference to the underlying storage.
pub fn storage_mut(&mut self) -> &mut S {
&mut self.storage
}
/// Check if the registration handle is still valid.
///
/// This reports whether a handle is currently held by the wrapper.
pub fn is_registered(&self) -> bool {
self.handle.is_some()
}
/// Consume this wrapper and return the underlying storage.
///
/// This will deregister the storage from NIXL.
pub fn into_storage(mut self) -> S {
// Drop the registration first, while `self` is still an ordinary value.
drop(self.handle.take());
// `NixlRegistered` implements `Drop`, so its fields cannot be moved out directly.
// Suppress the destructor and take the fields apart by hand instead.
let mut this = std::mem::ManuallyDrop::new(self);
// SAFETY: `this` is never dropped or used again, so `storage` is read out exactly once
// and `agent_name` is dropped exactly once. `handle` is already `None` at this point,
// so there is nothing left in it to release.
unsafe {
let storage = std::ptr::read(&this.storage);
std::ptr::drop_in_place(&mut this.agent_name);
storage
}
}
}
/// Register storage with a NIXL agent.
///
/// This consumes the storage and returns a `NixlRegistered` wrapper that
/// manages the registration lifetime. The registration handle will be
/// automatically dropped when the wrapper is dropped, ensuring proper
/// cleanup order.
///
/// # Arguments
/// * `storage` - The storage to register (consumed)
/// * `agent` - The NIXL agent to register with
/// * `opt` - Optional arguments for registration
///
/// # Returns
/// A `NixlRegistered` wrapper containing the storage and registration handle.
pub fn register_with_nixl<S>(
storage: S,
agent: &NixlAgent,
opt: Option<&OptArgs>,
) -> std::result::Result<NixlRegistered<S>, S>
where
S: MemoryDescription + NixlCompatible,
{
// Get NIXL parameters
let (ptr, size, mem_type, device_id) = storage.nixl_params();
// Create a NIXL descriptor for registration
let descriptor = NixlDescriptor {
addr: ptr as u64,
size,
mem_type,
device_id,
};
match agent.register_memory(&descriptor, opt) {
Ok(handle) => Ok(NixlRegistered {
storage,
handle: Some(handle),
agent_name: agent.name().to_string(),
}),
Err(_) => Err(storage),
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NIXL agent wrapper and configuration.
//!
//! This module provides:
//! - `NixlAgent`: Wrapper around nixl_sys::Agent that tracks initialized backends
//! - `NixlBackendConfig`: Configuration for NIXL backends from environment variables
use anyhow::Result;
use nixl_sys::Agent;
use std::collections::HashSet;
/// A NIXL agent wrapper that tracks which backends were successfully initialized.
///
/// This wrapper provides:
/// - Runtime validation of backend availability
/// - Clear error messages when operations need unavailable backends
/// - Single source of truth for backend state in tests and production
///
/// # Backend Tracking
///
/// Since `nixl_sys::Agent` doesn't provide a method to query active backends,
/// we track them during initialization. The `available_backends` set is populated
/// based on successful `create_backend()` calls.
#[derive(Clone, Debug)]
pub struct NixlAgent {
// The wrapped raw agent.
agent: Agent,
// Upper-cased names of the backends that were created successfully.
available_backends: HashSet<String>,
}
impl NixlAgent {
    /// Create a NIXL agent without any backends.
    ///
    /// # Errors
    /// Returns an error if the underlying agent cannot be created.
    pub fn new(name: &str) -> Result<Self> {
        let agent = Agent::new(name)?;
        Ok(Self {
            agent,
            available_backends: HashSet::new(),
        })
    }

    /// Add a backend to the agent.
    ///
    /// Backend names are compared case-insensitively (they are upper-cased first). Adding a
    /// backend that is already tracked is a no-op.
    ///
    /// # Errors
    /// Returns an error if no plugin parameters can be obtained for the backend, or if
    /// creating the backend fails.
    pub fn add_backend(&mut self, backend: &str) -> Result<()> {
        // Normalise once and reuse the result for the lookup, the NIXL calls and the insert.
        let backend_upper = backend.to_uppercase();
        if self.available_backends.contains(&backend_upper) {
            return Ok(());
        }

        // `_mems` is kept bound (rather than discarded) so it lives until the function returns.
        let (_mems, params) = self
            .agent
            .get_plugin_params(&backend_upper)
            .map_err(|_| anyhow::anyhow!("No {} plugin found", backend_upper))?;

        self.agent
            .create_backend(&backend_upper, &params)
            .map_err(|e| anyhow::anyhow!("Failed to create nixl backend: {}", e))?;

        self.available_backends.insert(backend_upper);
        Ok(())
    }

    /// Create a NIXL agent requiring ALL specified backends to be available.
    ///
    /// Every requested backend is attempted; if ANY of them fails to initialize, an error
    /// listing all failures is returned and no agent is produced. Use this in production
    /// when specific backends are mandatory.
    ///
    /// # Arguments
    /// * `name` - Agent name
    /// * `backends` - List of backend names that MUST be available
    ///
    /// # Returns
    /// A `NixlAgent` with all requested backends initialized.
    ///
    /// # Errors
    /// Returns an error if:
    /// - Agent creation fails
    /// - Any backend fails to initialize
    pub fn with_backends(name: &str, backends: &[&str]) -> Result<Self> {
        let mut agent = Self::new(name)?;
        let mut failed_backends = Vec::new();

        for backend in backends {
            let backend_upper = backend.to_uppercase();
            match agent.add_backend(&backend_upper) {
                Ok(_) => {
                    tracing::debug!("Initialized NIXL backend: {}", backend_upper);
                }
                Err(e) => {
                    tracing::error!("Failed to initialize {} backend: {}", backend_upper, e);
                    failed_backends.push((backend_upper, e.to_string()));
                }
            }
        }

        if !failed_backends.is_empty() {
            let error_details: Vec<String> = failed_backends
                .iter()
                .map(|(name, reason)| format!("{}: {}", name, reason))
                .collect();
            anyhow::bail!(
                "Failed to initialize required backends: [{}]",
                error_details.join(", ")
            );
        }

        Ok(agent)
    }

    /// Get a reference to the underlying raw NIXL agent.
    pub fn raw_agent(&self) -> &Agent {
        &self.agent
    }

    /// Consume and return the underlying raw NIXL agent.
    ///
    /// **Warning**: Once consumed, backend tracking is lost. Use this only when
    /// interfacing with code that requires `nixl_sys::Agent` directly.
    pub fn into_raw_agent(self) -> Agent {
        self.agent
    }

    /// Check if a specific backend is available (case-insensitive).
    pub fn has_backend(&self, backend: &str) -> bool {
        self.available_backends.contains(&backend.to_uppercase())
    }

    /// Get all available backends (upper-cased names).
    pub fn backends(&self) -> &HashSet<String> {
        &self.available_backends
    }

    /// Require a specific backend, returning an error if unavailable.
    ///
    /// Use this at the start of operations that need specific backends.
    ///
    /// Note: In general, you want to instantiate all your backends before you start registering memory.
    /// We may change this to a builder pattern in the future to enforce all backends are instantiated
    /// before you start registering memory.
    ///
    /// # Errors
    /// Returns an error naming the missing backend and listing the available ones.
    pub fn require_backend(&self, backend: &str) -> Result<()> {
        let backend_upper = backend.to_uppercase();
        if self.has_backend(&backend_upper) {
            Ok(())
        } else {
            anyhow::bail!(
                "Operation requires {} backend, but it was not initialized. Available backends: {:?}",
                backend_upper,
                self.available_backends
            )
        }
    }
}
// Delegate common methods to the underlying agent
impl std::ops::Deref for NixlAgent {
type Target = Agent;
fn deref(&self) -> &Self::Target {
&self.agent
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A backend that initialized successfully is reported regardless of name casing.
    #[test]
    fn test_agent_backend_tracking() {
        // These tests need a working UCX plugin.
        let ucx_agent = NixlAgent::with_backends("test", &["UCX"]).expect("Need UCX for test");

        assert!(ucx_agent.has_backend("UCX"));
        // Lookups are case-insensitive.
        assert!(ucx_agent.has_backend("ucx"));
    }

    /// `require_backend` accepts initialized backends and rejects the rest.
    #[test]
    fn test_require_backend() {
        let ucx_agent = NixlAgent::with_backends("test", &["UCX"]).expect("Need UCX for test");

        // Initialized backend.
        assert!(ucx_agent.require_backend("UCX").is_ok());

        // Backend that was never requested.
        assert!(ucx_agent.require_backend("GDS_MT").is_err());
    }

    /// `with_backends` is all-or-nothing: one failing backend fails the whole call.
    #[test]
    fn test_require_backends_strict() {
        // Succeeds when every requested backend (only UCX here) is available.
        let strict_agent =
            NixlAgent::with_backends("test_strict", &["UCX"]).expect("Failed to require backends");
        assert!(strict_agent.has_backend("UCX"));

        // "DUDE" is a made-up backend name, so this request must fail as a whole.
        let mixed = NixlAgent::with_backends("test_strict_fail", &["UCX", "DUDE"]);
        assert!(mixed.is_err());
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NIXL backend configuration with Figment support.
//!
//! This module provides configuration extraction for NIXL backends from
//! environment variables with the pattern: `DYN_KVBM_NIXL_BACKEND_<backend>_<key>=<value>`
use anyhow::{Result, bail};
use std::collections::HashSet;
use dynamo_config::parse_bool;
/// Configuration for NIXL backends.
///
/// Supports extracting backend configurations from environment variables:
/// - `DYN_KVBM_NIXL_BACKEND_UCX=true` - Enable UCX backend with default params
/// - `DYN_KVBM_NIXL_BACKEND_GDS=false` - Explicitly disable GDS backend
/// - Valid values: true/false, 1/0, on/off, yes/no (case-insensitive)
/// - Invalid values (e.g., "maybe", "random") will cause an error
/// - Custom params (e.g., `DYN_KVBM_NIXL_BACKEND_UCX_PARAM1=value`) will cause an error
///
/// When the environment enables no backend, `from_env` falls back to enabling UCX.
#[derive(Debug, Clone, Default)]
pub struct NixlBackendConfig {
/// Set of enabled backends (just backend names, no custom params yet)
backends: HashSet<String>,
}
impl NixlBackendConfig {
    /// Create a new empty configuration (no backends enabled).
    pub fn new() -> Self {
        Self::default()
    }

    /// Create configuration from environment variables.
    ///
    /// Extracts backends from `DYN_KVBM_NIXL_BACKEND_<backend>=<value>` variables.
    /// A truthy value enables the backend, a falsey value leaves it disabled. If no
    /// backend ends up enabled, `UCX` is enabled by default.
    ///
    /// Environment variables whose *name* is not valid Unicode are ignored.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The backend name is empty (`DYN_KVBM_NIXL_BACKEND_=...`)
    /// - Custom parameters are detected (not yet supported)
    /// - Invalid boolean values are provided (must be truthy or falsey), including
    ///   values that are not valid Unicode
    pub fn from_env() -> Result<Self> {
        const PREFIX: &str = "DYN_KVBM_NIXL_BACKEND_";

        let mut backends = HashSet::new();

        // `vars_os` is used rather than `vars`: `vars` panics as soon as *any* variable
        // in the process environment is not valid Unicode, even one unrelated to NIXL.
        for (key, value) in std::env::vars_os() {
            // A name that is not valid Unicode cannot be turned into a backend name.
            let Ok(key) = key.into_string() else {
                continue;
            };
            let Some(remainder) = key.strip_prefix(PREFIX) else {
                continue;
            };

            // Without this check an empty string would be recorded as an enabled backend.
            if remainder.is_empty() {
                bail!(
                    "Missing backend name in {}. \
                     Expected DYN_KVBM_NIXL_BACKEND_<backend>=<value>.",
                    key
                );
            }

            // Check if there's an underscore (indicating custom params).
            // NOTE(review): this also rejects backends whose own name contains an
            // underscore (e.g. the "gds_mt" name used with `with_backend` in the unit
            // tests), so such backends cannot currently be enabled via the environment.
            if remainder.contains('_') {
                bail!(
                    "Custom NIXL backend parameters are not yet supported. \
                     Found: {}. Please use only DYN_KVBM_NIXL_BACKEND_<backend>=true \
                     to enable backends with default parameters.",
                    key
                );
            }

            let value = match value.into_string() {
                Ok(value) => value,
                Err(_) => bail!("Invalid value for {}: value is not valid Unicode", key),
            };

            // Simple backend enablement (e.g., DYN_KVBM_NIXL_BACKEND_UCX=true)
            let backend_name = remainder.to_uppercase();
            match parse_bool(&value) {
                Ok(true) => {
                    backends.insert(backend_name);
                }
                Ok(false) => {
                    // Explicitly disabled, don't add to backends
                    continue;
                }
                Err(e) => bail!("Invalid value for {}: {}", key, e),
            }
        }

        // Default to UCX if no backends specified
        if backends.is_empty() {
            backends.insert("UCX".to_string());
        }

        Ok(Self { backends })
    }

    /// Add a backend to the configuration.
    ///
    /// Backend names will be converted to uppercase for consistency.
    pub fn with_backend(mut self, backend: impl Into<String>) -> Self {
        self.backends.insert(backend.into().to_uppercase());
        self
    }

    /// Get the set of enabled backends (names are stored uppercased).
    pub fn backends(&self) -> &HashSet<String> {
        &self.backends
    }

    /// Check if a specific backend is enabled (case-insensitive).
    pub fn has_backend(&self, backend: &str) -> bool {
        self.backends.contains(&backend.to_uppercase())
    }

    /// Merge another configuration into this one.
    ///
    /// Backends from the other configuration will be added to this one.
    pub fn merge(mut self, other: NixlBackendConfig) -> Self {
        self.backends.extend(other.backends);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A configuration built with `new()` has no backends enabled.
    #[test]
    fn test_new_config_is_empty() {
        let empty = NixlBackendConfig::new();
        assert_eq!(empty.backends().len(), 0);
    }

    // Backends added through the builder are found regardless of the case used to query.
    #[test]
    fn test_with_backend() {
        let cfg = NixlBackendConfig::new()
            .with_backend("ucx")
            .with_backend("gds_mt");

        for present in ["ucx", "UCX", "gds_mt", "GDS_MT"] {
            assert!(cfg.has_backend(present));
        }
        assert!(!cfg.has_backend("other"));
    }

    // Merging yields the union of both backend sets.
    #[test]
    fn test_merge_configs() {
        let left = NixlBackendConfig::new().with_backend("ucx");
        let right = NixlBackendConfig::new().with_backend("gds");

        let union = left.merge(right);

        for present in ["ucx", "gds"] {
            assert!(union.has_backend(present));
        }
    }

    // Names are normalized on insert and on lookup, so any casing matches.
    #[test]
    fn test_backend_name_case_insensitive() {
        let cfg = NixlBackendConfig::new()
            .with_backend("ucx")
            .with_backend("Gds_mt")
            .with_backend("OTHER");

        for present in ["UCX", "ucx", "GDS_MT", "gds_mt", "OTHER", "other"] {
            assert!(cfg.has_backend(present));
        }
    }

    // Note: Testing from_env() would require setting environment variables,
    // which is challenging in unit tests. This is better tested with integration tests.
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::{
Any, Buffer, MemoryDescription, Result, StorageError, StorageKind, nixl::NixlDescriptor,
};
/// An [`OffsetBuffer`] is a new [`Buffer`]-like object that represents a sub-region (still contiguous)
/// within an existing [`Buffer`].
///
/// Cloning an `OffsetBuffer` clones the `base` handle and copies `offset` and `size`.
#[derive(Clone)]
pub struct OffsetBuffer {
/// The buffer this view is carved out of.
base: Buffer,
/// Distance from `base.addr()` to the start of this view; `addr()` returns `base.addr() + offset`.
offset: usize,
/// Length of this view, reported by `size()`.
size: usize,
}
impl std::fmt::Debug for OffsetBuffer {
    /// Formats the view as a struct listing its `base`, `offset` and `size` fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("OffsetBuffer");
        builder.field("base", &self.base);
        builder.field("offset", &self.offset);
        builder.field("size", &self.size);
        builder.finish()
    }
}
impl OffsetBuffer {
    /// Build a view of `size` units starting `offset` units into `base`.
    ///
    /// # Errors
    /// Returns [`StorageError::Unsupported`] when `offset + size` overflows `usize` or
    /// reaches past `base.size()`.
    pub fn new(base: Buffer, offset: usize, size: usize) -> Result<Self> {
        let Some(end) = offset.checked_add(size) else {
            return Err(StorageError::Unsupported("offset overflow".into()));
        };

        if end <= base.size() {
            Ok(Self { base, offset, size })
        } else {
            Err(StorageError::Unsupported(
                "offset region exceeds base allocation bounds".into(),
            ))
        }
    }

    /// Offset of this view relative to the start of the base buffer.
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Borrow the buffer this view was created from.
    pub fn base(&self) -> &Buffer {
        &self.base
    }
}
impl MemoryDescription for OffsetBuffer {
    /// Start address of the view: the base address shifted by the view's offset.
    fn addr(&self) -> usize {
        let base_addr = self.base.addr();
        base_addr + self.offset
    }

    /// Length of the view (not of the base buffer).
    fn size(&self) -> usize {
        self.size
    }

    /// The view has the same storage kind as the buffer it was carved from.
    fn storage_kind(&self) -> StorageKind {
        self.base.storage_kind()
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The base buffer's descriptor with `addr` and `size` rewritten to describe this
    /// view, or `None` when the base buffer has no descriptor.
    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
        self.base.nixl_descriptor().map(|mut view_desc| {
            view_desc.addr = self.addr() as u64;
            view_desc.size = self.size();
            view_desc
        })
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! CUDA pinned host memory storage.
use super::{MemoryDescription, Result, StorageError, StorageKind, actions, nixl::NixlDescriptor};
use cudarc::driver::CudaContext;
use cudarc::driver::sys;
use std::any::Any;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
/// Get or create a CUDA context for the given device.
fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
static CONTEXTS: OnceLock<Mutex<HashMap<u32, Arc<CudaContext>>>> = OnceLock::new();
let mut map = CONTEXTS.get_or_init(Default::default).lock().unwrap();
if let Some(existing) = map.get(&device_id) {
return Ok(existing.clone());
}
let ctx = CudaContext::new(device_id as usize)?;
map.insert(device_id, ctx.clone());
Ok(ctx)
}
/// CUDA pinned host memory allocated via cudaHostAlloc.
///
/// The allocation is released in `Drop` after binding `ctx` to the dropping thread.
#[derive(Debug)]
pub struct PinnedStorage {
/// Address of the allocation, stored as an integer rather than a raw pointer.
ptr: usize,
/// Size of the allocation in bytes.
len: usize,
/// CUDA context the allocation was made under; kept so it can be bound again on drop.
ctx: Arc<CudaContext>,
}
// NOTE(review): `ptr` is a plain `usize`, so these manual impls only matter if
// `Arc<CudaContext>` is not already `Send`/`Sync` — confirm they are still required.
unsafe impl Send for PinnedStorage {}
unsafe impl Sync for PinnedStorage {}
impl PinnedStorage {
    /// Allocate new pinned memory of the given size.
    ///
    /// The allocation is always made under the CUDA context of device 0 and uses the
    /// `CU_MEMHOSTALLOC_WRITECOMBINED` flag.
    ///
    /// # Arguments
    /// * `len` - Size in bytes to allocate
    ///
    /// # Errors
    /// Returns [`StorageError::AllocationFailed`] if `len` is zero, if `len` is not
    /// smaller than `isize::MAX`, or if the driver hands back a null pointer.
    /// Returns [`StorageError::Cuda`] if binding the context or the allocation itself fails.
    pub fn new(len: usize) -> Result<Self> {
        if len == 0 {
            return Err(StorageError::AllocationFailed(
                "zero-sized allocations are not supported".into(),
            ));
        }
        // Validate the size *before* allocating: checking it afterwards (as an assertion)
        // would panic with the fresh allocation still outstanding and leak it.
        if len >= isize::MAX as usize {
            return Err(StorageError::AllocationFailed(format!(
                "allocation size {len} exceeds the maximum supported size"
            )));
        }

        let ctx = cuda_context(0)?;
        ctx.bind_to_thread().map_err(StorageError::Cuda)?;

        // NOTE(review): write-combined host memory is typically slow for CPU reads;
        // confirm this flag is intended for buffers the host also reads from.
        // SAFETY: the context was bound to this thread above and `len` is non-zero.
        let raw = unsafe {
            cudarc::driver::result::malloc_host(len, sys::CU_MEMHOSTALLOC_WRITECOMBINED)
        }
        .map_err(StorageError::Cuda)?;

        // No alignment check is needed: every address is suitably aligned for `u8`.
        let ptr = raw as *mut u8;
        if ptr.is_null() {
            return Err(StorageError::AllocationFailed(
                "Failed to allocate pinned memory".into(),
            ));
        }

        Ok(Self {
            ptr: ptr as usize,
            len,
            ctx,
        })
    }

    /// Get a pointer to the underlying memory.
    ///
    /// # Safety
    /// The caller must ensure the pointer is not used after this storage is dropped.
    pub unsafe fn as_ptr(&self) -> *const u8 {
        self.ptr as *const u8
    }

    /// Get a mutable pointer to the underlying memory.
    ///
    /// # Safety
    /// The caller must ensure the pointer is not used after this storage is dropped
    /// and that there are no other references to this memory.
    pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr as *mut u8
    }
}
impl Drop for PinnedStorage {
    /// Release the pinned allocation.
    ///
    /// Both steps are best-effort: a failure to bind the context or to free the memory
    /// is logged at debug level and otherwise ignored. The free is attempted even when
    /// binding the context failed.
    fn drop(&mut self) {
        match self.ctx.bind_to_thread() {
            Ok(_) => {}
            Err(e) => tracing::debug!("failed to bind CUDA context for free: {e}"),
        }

        // SAFETY: `self.ptr` came from `malloc_host` in `new` and is freed only here.
        let freed = unsafe { cudarc::driver::result::free_host(self.ptr as _) };
        if let Err(e) = freed {
            tracing::debug!("failed to free pinned memory: {e}");
        }
    }
}
impl MemoryDescription for PinnedStorage {
    /// Address of the pinned allocation.
    fn addr(&self) -> usize {
        // The address is already stored as an integer; no pointer round-trip is needed.
        self.ptr
    }

    /// Size of the allocation in bytes.
    fn size(&self) -> usize {
        self.len
    }

    /// Always [`StorageKind::Pinned`].
    fn storage_kind(&self) -> StorageKind {
        StorageKind::Pinned
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Bare pinned storage carries no NIXL descriptor, so this is always `None`.
    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
        Option::None
    }
}
// Support for NIXL registration
impl super::nixl::NixlCompatible for PinnedStorage {
    /// Describe the allocation for NIXL as `(pointer, length, memory type, device id)`:
    /// the memory type is `Dram` and the device id is `0`.
    fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
        let start = self.ptr as *const u8;
        let device_id = 0;
        (start, self.len, nixl_sys::MemType::Dram, device_id)
    }
}
impl actions::Memset for PinnedStorage {
    /// Fill `size` bytes, starting `offset` bytes into the allocation, with `value`.
    ///
    /// # Errors
    /// Returns [`StorageError::OperationFailed`] when `offset + size` overflows or
    /// reaches past the end of the allocation.
    fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<()> {
        match offset.checked_add(size) {
            None => {
                return Err(StorageError::OperationFailed(
                    "memset: offset overflow".into(),
                ));
            }
            Some(end) if end > self.len => {
                return Err(StorageError::OperationFailed(
                    "memset: offset + size > storage size".into(),
                ));
            }
            Some(_) => {}
        }

        // SAFETY: the range `[offset, offset + size)` was checked to lie within the
        // `self.len` bytes of the allocation.
        unsafe {
            let start = (self.ptr as *mut u8).add(offset);
            std::ptr::write_bytes(start, value, size);
        }
        Ok(())
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub use super::MemoryDescription;
pub use super::nixl::{NixlCompatible, RegisteredView};
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! System memory storage backed by malloc.
use super::{MemoryDescription, Result, StorageError, StorageKind, actions, nixl::NixlDescriptor};
use std::any::Any;
use std::ptr::NonNull;
/// System memory allocated via `posix_memalign` (4096-byte aligned) and released with `free`.
#[derive(Debug)]
pub struct SystemStorage {
/// Start of the allocation; never null.
ptr: NonNull<u8>,
/// Size of the allocation in bytes.
len: usize,
}
// `NonNull<u8>` is neither `Send` nor `Sync`, hence the manual impls.
// NOTE(review): soundness relies on callers not creating aliasing mutable access
// through the raw-pointer accessors — confirm against the intended usage.
unsafe impl Send for SystemStorage {}
unsafe impl Sync for SystemStorage {}
impl SystemStorage {
    /// Alignment, in bytes, requested for every allocation.
    ///
    /// We need 4KB alignment here for NIXL disk transfers to work.
    /// The O_DIRECT flag is required for GDS.
    /// However, a limitation of this flag is that all operations involving disk
    /// (both read and write) must be page-aligned.
    /// Pinned memory is already page-aligned, so we only need to align system memory.
    /// TODO(jthomson04): Is page size always 4KB?
    const ALIGNMENT: usize = 4096;

    /// Allocate new, zero-initialized system memory of the given size.
    ///
    /// # Errors
    /// Returns [`StorageError::AllocationFailed`] if `len` is zero or if
    /// `posix_memalign` fails.
    pub fn new(len: usize) -> Result<Self> {
        if len == 0 {
            return Err(StorageError::AllocationFailed(
                "zero-sized allocations are not supported".into(),
            ));
        }

        let mut raw: *mut libc::c_void = std::ptr::null_mut();
        // SAFETY: `raw` is a valid location for the out-pointer and the alignment is a
        // power of two that is a multiple of the pointer size, as posix_memalign requires.
        let rc = unsafe { libc::posix_memalign(&mut raw, Self::ALIGNMENT, len) };
        if rc != 0 {
            // posix_memalign reports failure through its return value, not errno.
            return Err(StorageError::AllocationFailed(format!(
                "posix_memalign failed for size {} (error code {})",
                len, rc
            )));
        }

        let ptr = NonNull::new(raw as *mut u8).ok_or_else(|| {
            StorageError::AllocationFailed(format!(
                "posix_memalign returned a null pointer for size {}",
                len
            ))
        })?;

        // Zero-initialize the memory
        // SAFETY: `ptr` is non-null and points to a fresh allocation of `len` bytes.
        unsafe {
            std::ptr::write_bytes(ptr.as_ptr(), 0, len);
        }

        Ok(Self { ptr, len })
    }

    /// Get a pointer to the underlying memory.
    ///
    /// # Safety
    /// The caller must ensure the pointer is not used after this storage is dropped.
    pub unsafe fn as_ptr(&self) -> *const u8 {
        self.ptr.as_ptr()
    }

    /// Get a mutable pointer to the underlying memory.
    ///
    /// # Safety
    /// The caller must ensure the pointer is not used after this storage is dropped
    /// and that there are no other references to this memory.
    pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr.as_ptr()
    }
}
impl Drop for SystemStorage {
    /// Hand the allocation back to the C allocator.
    fn drop(&mut self) {
        let raw = self.ptr.as_ptr().cast::<libc::c_void>();
        // SAFETY: the pointer was obtained from `posix_memalign` in `new`, which may be
        // released with `free`, and it is freed only here.
        unsafe { libc::free(raw) };
    }
}
impl MemoryDescription for SystemStorage {
    /// Address of the first byte of the allocation.
    fn addr(&self) -> usize {
        let start: *mut u8 = self.ptr.as_ptr();
        start as usize
    }

    /// Size of the allocation in bytes.
    fn size(&self) -> usize {
        self.len
    }

    /// Always [`StorageKind::System`].
    fn storage_kind(&self) -> StorageKind {
        StorageKind::System
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Bare system storage carries no NIXL descriptor, so this is always `None`.
    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
        Option::None
    }
}
// Support for NIXL registration
impl super::nixl::NixlCompatible for SystemStorage {
    /// Describe the allocation for NIXL as `(pointer, length, memory type, device id)`:
    /// the memory type is `Dram` and the device id is `0`.
    fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
        let start: *const u8 = self.ptr.as_ptr();
        let device_id = 0;
        (start, self.len, nixl_sys::MemType::Dram, device_id)
    }
}
impl actions::Memset for SystemStorage {
    /// Fill `size` bytes, starting `offset` bytes into the allocation, with `value`.
    ///
    /// # Errors
    /// Returns [`StorageError::OperationFailed`] when `offset + size` overflows or
    /// reaches past the end of the allocation.
    fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<()> {
        // `None` = arithmetic overflow, `Some(false)` = out of bounds, `Some(true)` = ok.
        let fits = offset.checked_add(size).map(|end| end <= self.len);
        match fits {
            None => Err(StorageError::OperationFailed(
                "memset: offset overflow".into(),
            )),
            Some(false) => Err(StorageError::OperationFailed(
                "memset: offset + size > storage size".into(),
            )),
            Some(true) => {
                // SAFETY: the range `[offset, offset + size)` lies within the
                // `self.len` bytes owned by this storage.
                unsafe {
                    let start = self.ptr.as_ptr().add(offset);
                    std::ptr::write_bytes(start, value, size);
                }
                Ok(())
            }
        }
    }
}
impl actions::Slice for SystemStorage {
    /// View the whole allocation as a byte slice.
    ///
    /// # Safety
    /// Caller must ensure no concurrent mutable access per trait contract.
    unsafe fn as_slice(&self) -> Result<&[u8]> {
        let start = self.ptr.as_ptr() as *const u8;
        let byte_len = self.len;
        // SAFETY: `SystemStorage` owns the allocation, which stays valid for as long as
        // `self` exists; `start` is non-null and valid for `byte_len` bytes, and `u8`
        // has no alignment requirement.
        let bytes = unsafe { std::slice::from_raw_parts(start, byte_len) };
        Ok(bytes)
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Tests for the storage-next module.
use super::*;
/// Check that a memory region and its NIXL descriptor agree.
///
/// When `memory.nixl_descriptor()` is `Some`, the descriptor's `addr` and `size` must
/// equal `memory.addr()` and `memory.size()`. Regions without a descriptor pass trivially.
///
/// # Panics
/// Panics if descriptor values don't match memory region values.
#[allow(dead_code)]
fn validate_nixl_descriptor<M: MemoryDescription>(memory: &M) {
    let Some(desc) = memory.nixl_descriptor() else {
        return;
    };

    let region_addr = memory.addr();
    assert_eq!(
        desc.addr as usize,
        region_addr,
        "NIXL descriptor addr ({}) does not match memory region addr ({})",
        desc.addr,
        region_addr
    );

    let region_size = memory.size();
    assert_eq!(
        desc.size,
        region_size,
        "NIXL descriptor size ({}) does not match memory region size ({})",
        desc.size,
        region_size
    );
}
// System allocations report their size and kind, and distinct allocations have
// distinct, non-zero addresses.
#[test]
fn test_system_storage() {
    let first = SystemStorage::new(1024).unwrap();
    assert_eq!(first.size(), 1024);
    assert_eq!(first.storage_kind(), StorageKind::System);
    assert_ne!(first.addr(), 0);

    // Test that we can create multiple allocations
    let second = SystemStorage::new(2048).unwrap();
    assert_eq!(second.size(), 2048);
    assert_ne!(first.addr(), second.addr());
}
// A zero-byte request is rejected with `AllocationFailed`.
#[test]
fn test_system_storage_zero_size() {
    let failure = SystemStorage::new(0).unwrap_err();
    assert!(matches!(failure, StorageError::AllocationFailed(_)));
}
// Disk storage created without an explicit path is backed by an existing file.
#[test]
fn test_disk_storage_temp() {
    let disk = DiskStorage::new(4096).unwrap();
    assert_eq!(disk.size(), 4096);
    assert!(matches!(disk.storage_kind(), StorageKind::Disk(_)));
    // Disk storage is file-backed, so addr() returns 0 (no memory address)
    assert_eq!(disk.addr(), 0);
    assert!(disk.path().exists());
}
// Disk storage created at an explicit path creates that file.
#[test]
fn test_disk_storage_at_path() {
    let scratch = tempfile::tempdir().unwrap();
    let file_path = scratch.path().join("test.bin");

    let disk = DiskStorage::new_at(&file_path, 8192).unwrap();

    assert_eq!(disk.size(), 8192);
    assert!(matches!(disk.storage_kind(), StorageKind::Disk(_)));
    assert!(file_path.exists());
}
// Wrapping storage with `create_buffer` keeps its size and storage kind.
#[test]
fn test_type_erasure() {
    let erased = create_buffer(SystemStorage::new(1024).unwrap());
    assert_eq!(erased.size(), 1024);
    assert_eq!(erased.storage_kind(), StorageKind::System);
}
// `MemoryRegion::new` stores the address and size it was given.
#[test]
fn test_memory_descriptor() {
    let region = MemoryRegion::new(0x1000, 4096);
    assert_eq!((region.addr, region.size), (0x1000, 4096));
}
// System storage that was never registered exposes no NIXL descriptor.
#[test]
fn test_system_storage_unregistered_no_nixl_descriptor() {
    let unregistered = SystemStorage::new(1024).unwrap();
    assert!(unregistered.nixl_descriptor().is_none());
}
// Disk storage that was never registered exposes no NIXL descriptor.
#[test]
fn test_disk_storage_unregistered_no_nixl_descriptor() {
    let unregistered = DiskStorage::new(4096).unwrap();
    assert!(unregistered.nixl_descriptor().is_none());
}
// Tests that need a CUDA device; compiled only with the `testing-cuda` feature.
#[cfg(feature = "testing-cuda")]
mod cuda_tests {
    use super::*;

    // Pinned allocations report their size and kind and have a non-zero address.
    #[test]
    fn test_pinned_storage() {
        let pinned = PinnedStorage::new(2048).unwrap();
        assert_eq!(pinned.size(), 2048);
        assert_eq!(pinned.storage_kind(), StorageKind::Pinned);
        assert_ne!(pinned.addr(), 0);
    }

    // A zero-byte pinned request is rejected with `AllocationFailed`.
    #[test]
    fn test_pinned_storage_zero_size() {
        let failure = PinnedStorage::new(0).unwrap_err();
        assert!(matches!(failure, StorageError::AllocationFailed(_)));
    }

    // Device allocations report their size, kind, address and device id.
    #[test]
    fn test_device_storage() {
        let device = DeviceStorage::new(4096, 0).unwrap();
        assert_eq!(device.size(), 4096);
        assert_eq!(device.storage_kind(), StorageKind::Device(0));
        assert_ne!(device.addr(), 0);
        assert_eq!(device.device_id(), 0);
    }

    // A zero-byte device request is rejected with `AllocationFailed`.
    #[test]
    fn test_device_storage_zero_size() {
        let failure = DeviceStorage::new(0, 0).unwrap_err();
        assert!(matches!(failure, StorageError::AllocationFailed(_)));
    }

    // Pinned storage that was never registered exposes no NIXL descriptor.
    #[test]
    fn test_pinned_storage_unregistered_no_nixl_descriptor() {
        let unregistered = PinnedStorage::new(1024).unwrap();
        assert!(unregistered.nixl_descriptor().is_none());
    }

    // Device storage that was never registered exposes no NIXL descriptor.
    #[test]
    fn test_device_storage_unregistered_no_nixl_descriptor() {
        let unregistered = DeviceStorage::new(4096, 0).unwrap();
        assert!(unregistered.nixl_descriptor().is_none());
    }
}
// Tests that register storage with a NIXL agent; compiled only with the
// `testing-nixl` feature. Each test creates its own agent via `NixlAgent::with_backends`.
#[cfg(feature = "testing-nixl")]
mod nixl_tests {
use super::super::nixl::{NixlAgent, RegisteredView, register_with_nixl};
use super::*;
// System Storage Tests
// Registering system storage keeps its size and kind and records the agent name.
#[test]
fn test_system_storage_registration() {
let storage = SystemStorage::new(2048).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
assert_eq!(registered.agent_name(), "test_agent");
assert_eq!(registered.size(), 2048);
assert_eq!(registered.storage_kind(), StorageKind::System);
assert!(registered.is_registered());
}
// The descriptor of registered system storage mirrors its addr/size and is
// reported as `Dram` on device 0.
#[test]
fn test_system_storage_descriptor_consistency() {
let storage = SystemStorage::new(1024).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
// Validate descriptor consistency
validate_nixl_descriptor(&registered);
// Get descriptor and validate fields
let desc = registered.descriptor();
assert_eq!(desc.addr as usize, registered.addr());
assert_eq!(desc.size, registered.size());
assert_eq!(desc.mem_type, nixl_sys::MemType::Dram);
assert_eq!(desc.device_id, 0);
}
// Note: into_storage() test removed due to implementation issue
// The current implementation uses mem::zeroed() which is invalid for types with NonNull
// TODO: Fix NixlRegistered::into_storage() implementation
// Disk Storage Tests
// Registering disk storage (POSIX backend) keeps its size and kind and records the agent name.
#[test]
fn test_disk_storage_registration() {
let storage = DiskStorage::new(4096).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["POSIX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
assert_eq!(registered.agent_name(), "test_agent");
assert_eq!(registered.size(), 4096);
assert!(matches!(registered.storage_kind(), StorageKind::Disk(_)));
assert!(registered.is_registered());
}
// The descriptor of registered disk storage mirrors its size and is reported as `File`.
#[test]
fn test_disk_storage_descriptor_consistency() {
let storage = DiskStorage::new(8192).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["POSIX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
// Validate descriptor consistency
validate_nixl_descriptor(&registered);
// Get descriptor and validate fields
let desc = registered.descriptor();
assert_eq!(desc.size, registered.size());
assert_eq!(desc.mem_type, nixl_sys::MemType::File);
}
// CUDA tests (when both testing-nixl and testing-cuda are enabled)
// NOTE(review): this module is gated on `testing-all`, while the CUDA type-erasure
// tests further down are gated on `testing-cuda` — presumably `testing-all` enables
// both features; confirm against the crate's Cargo.toml.
#[cfg(feature = "testing-all")]
mod cuda_nixl_tests {
use super::*;
// Registering pinned storage keeps its size and kind and records the agent name.
#[test]
fn test_pinned_storage_registration() {
let storage = PinnedStorage::new(2048).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
assert_eq!(registered.agent_name(), "test_agent");
assert_eq!(registered.size(), 2048);
assert_eq!(registered.storage_kind(), StorageKind::Pinned);
assert!(registered.is_registered());
}
// The descriptor of registered pinned storage mirrors its addr/size and is reported as `Dram`.
#[test]
fn test_pinned_storage_descriptor_consistency() {
let storage = PinnedStorage::new(1024).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
// Validate descriptor consistency
validate_nixl_descriptor(&registered);
// Get descriptor and validate fields
let desc = registered.descriptor();
assert_eq!(desc.addr as usize, registered.addr());
assert_eq!(desc.size, registered.size());
assert_eq!(desc.mem_type, nixl_sys::MemType::Dram);
}
// Registering device storage keeps its size and kind and records the agent name.
#[test]
fn test_device_storage_registration() {
let storage = DeviceStorage::new(4096, 0).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
assert_eq!(registered.agent_name(), "test_agent");
assert_eq!(registered.size(), 4096);
assert_eq!(registered.storage_kind(), StorageKind::Device(0));
assert!(registered.is_registered());
}
// The descriptor of registered device storage mirrors its addr/size and is
// reported as `Vram` on device 0.
#[test]
fn test_device_storage_descriptor_consistency() {
let storage = DeviceStorage::new(2048, 0).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
// Validate descriptor consistency
validate_nixl_descriptor(&registered);
// Get descriptor and validate fields
let desc = registered.descriptor();
assert_eq!(desc.addr as usize, registered.addr());
assert_eq!(desc.size, registered.size());
assert_eq!(desc.mem_type, nixl_sys::MemType::Vram);
assert_eq!(desc.device_id, 0);
}
}
// Type Erasure Tests
// A registered storage wrapped by `create_buffer` still exposes a matching descriptor.
#[test]
fn test_type_erasure_preserves_nixl_descriptor() {
let storage = SystemStorage::new(1024).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
let buffer = create_buffer(registered);
// Validate descriptor through type erasure
validate_nixl_descriptor(&buffer);
// Verify descriptor is Some and has correct values
let desc = buffer.nixl_descriptor().unwrap();
assert_eq!(desc.addr as usize, buffer.addr());
assert_eq!(desc.size, buffer.size());
}
// Same as above for registered pinned storage; also checks the storage kind survives.
#[cfg(feature = "testing-cuda")]
#[test]
fn test_type_erasure_pinned_storage() {
let storage = PinnedStorage::new(2048).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
let buffer = create_buffer(registered);
validate_nixl_descriptor(&buffer);
assert_eq!(buffer.storage_kind(), StorageKind::Pinned);
}
// Same as above for registered device storage; also checks the storage kind survives.
#[cfg(feature = "testing-cuda")]
#[test]
fn test_type_erasure_device_storage() {
let storage = DeviceStorage::new(4096, 0).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
let buffer = create_buffer(registered);
validate_nixl_descriptor(&buffer);
assert_eq!(buffer.storage_kind(), StorageKind::Device(0));
}
}
// Arena allocator tests with NIXL registration
// Each test registers one storage of TOTAL_SIZE bytes, hands it to an ArenaAllocator
// with PAGE_SIZE pages, and checks the address, size and descriptor of what it hands out.
// The expected addresses depend on the order in which allocations are made and dropped.
#[cfg(feature = "testing-nixl")]
mod arena_nixl_tests {
use super::super::arena::ArenaAllocator;
use super::super::nixl::{NixlAgent, register_with_nixl};
use super::*;
// Page size passed to the arena allocator.
const PAGE_SIZE: usize = 4096;
// Number of pages in the backing storage.
const PAGE_COUNT: usize = 10;
// Size of the backing storage in bytes.
const TOTAL_SIZE: usize = PAGE_SIZE * PAGE_COUNT;
// The first allocation starts at the base address and its descriptor matches it.
#[test]
fn test_arena_with_registered_storage_single_allocation() {
let storage = SystemStorage::new(TOTAL_SIZE).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
let base_addr = registered.addr();
let allocator = ArenaAllocator::new(registered, PAGE_SIZE).unwrap();
let buffer = allocator.allocate(PAGE_SIZE * 2).unwrap();
// Validate buffer properties
assert_eq!(buffer.size(), PAGE_SIZE * 2);
assert_eq!(buffer.addr(), base_addr); // First allocation starts at base
assert_eq!(buffer.agent_name(), "test_agent");
// Validate descriptor
let desc = buffer.registered_descriptor();
assert_eq!(desc.addr as usize, buffer.addr());
assert_eq!(desc.size, buffer.size());
}
// Successive allocations are laid out back to back, each with a matching descriptor.
#[test]
fn test_arena_with_registered_storage_multiple_allocations() {
let storage = SystemStorage::new(TOTAL_SIZE).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
let base_addr = registered.addr();
let allocator = ArenaAllocator::new(registered, PAGE_SIZE).unwrap();
// Allocate three buffers
let buffer1 = allocator.allocate(PAGE_SIZE).unwrap();
let buffer2 = allocator.allocate(PAGE_SIZE * 2).unwrap();
let buffer3 = allocator.allocate(PAGE_SIZE).unwrap();
// Validate first buffer (starts at base, uses 1 page)
assert_eq!(buffer1.size(), PAGE_SIZE);
assert_eq!(buffer1.addr(), base_addr);
// Validate second buffer (starts after buffer1, uses 2 pages)
assert_eq!(buffer2.size(), PAGE_SIZE * 2);
assert_eq!(buffer2.addr(), base_addr + PAGE_SIZE);
// Validate third buffer (starts after buffer2, uses 1 page)
assert_eq!(buffer3.size(), PAGE_SIZE);
assert_eq!(buffer3.addr(), base_addr + PAGE_SIZE * 3);
// Validate descriptors for all buffers
let desc1 = buffer1.registered_descriptor();
assert_eq!(desc1.addr as usize, buffer1.addr());
assert_eq!(desc1.size, PAGE_SIZE);
let desc2 = buffer2.registered_descriptor();
assert_eq!(desc2.addr as usize, buffer2.addr());
assert_eq!(desc2.size, PAGE_SIZE * 2);
let desc3 = buffer3.registered_descriptor();
assert_eq!(desc3.addr as usize, buffer3.addr());
assert_eq!(desc3.size, PAGE_SIZE);
}
// Buffers handed out by the arena report the name of the agent the storage was registered with.
#[test]
fn test_arena_buffer_agent_name_preservation() {
let storage = SystemStorage::new(TOTAL_SIZE).unwrap();
let agent = NixlAgent::with_backends("my_special_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
let allocator = ArenaAllocator::new(registered, PAGE_SIZE).unwrap();
let buffer = allocator.allocate(PAGE_SIZE).unwrap();
assert_eq!(buffer.agent_name(), "my_special_agent");
}
// Ten single-page allocations fill the storage, one page after another.
#[test]
fn test_arena_multiple_buffers_stress_test() {
let storage = SystemStorage::new(TOTAL_SIZE).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
let base_addr = registered.addr();
let allocator = ArenaAllocator::new(registered, PAGE_SIZE).unwrap();
// Allocate 10 single-page buffers
// Buffers are collected so that none is dropped before the loop finishes.
let mut buffers = Vec::new();
for i in 0..10 {
let buffer = allocator.allocate(PAGE_SIZE).unwrap();
assert_eq!(buffer.size(), PAGE_SIZE);
assert_eq!(buffer.addr(), base_addr + i * PAGE_SIZE);
// Validate descriptor
let desc = buffer.registered_descriptor();
assert_eq!(desc.addr as usize, buffer.addr());
assert_eq!(desc.size, PAGE_SIZE);
buffers.push(buffer);
}
}
// After a buffer is dropped, an allocation of the same size lands at the same address.
#[test]
fn test_arena_reallocation_after_drop() {
let storage = SystemStorage::new(TOTAL_SIZE).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
let base_addr = registered.addr();
let allocator = ArenaAllocator::new(registered, PAGE_SIZE).unwrap();
// Allocate and drop
{
let buffer = allocator.allocate(PAGE_SIZE * 5).unwrap();
assert_eq!(buffer.addr(), base_addr);
let desc = buffer.registered_descriptor();
assert_eq!(desc.addr as usize, base_addr);
assert_eq!(desc.size, PAGE_SIZE * 5);
} // buffer dropped here
// Reallocate same size - should reuse the space
let buffer2 = allocator.allocate(PAGE_SIZE * 5).unwrap();
assert_eq!(buffer2.addr(), base_addr);
// Validate new descriptor
let desc2 = buffer2.registered_descriptor();
assert_eq!(desc2.addr as usize, base_addr);
assert_eq!(desc2.size, PAGE_SIZE * 5);
}
// Arena tests over CUDA-backed storage; additionally require the `testing-cuda` feature.
#[cfg(feature = "testing-cuda")]
mod cuda_arena_tests {
use super::*;
// Arena buffers over registered pinned storage carry a `Dram` descriptor.
#[test]
fn test_arena_with_pinned_storage() {
let storage = PinnedStorage::new(TOTAL_SIZE).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
let allocator = ArenaAllocator::new(registered, PAGE_SIZE).unwrap();
let buffer = allocator.allocate(PAGE_SIZE * 2).unwrap();
assert_eq!(buffer.size(), PAGE_SIZE * 2);
assert_eq!(buffer.agent_name(), "test_agent");
let desc = buffer.registered_descriptor();
assert_eq!(desc.addr as usize, buffer.addr());
assert_eq!(desc.size, PAGE_SIZE * 2);
assert_eq!(desc.mem_type, nixl_sys::MemType::Dram);
}
// Arena buffers over registered device storage carry a `Vram` descriptor for device 0.
#[test]
fn test_arena_with_device_storage() {
let storage = DeviceStorage::new(TOTAL_SIZE, 0).unwrap();
let agent = NixlAgent::with_backends("test_agent", &["UCX"]).unwrap();
let registered = register_with_nixl(storage, &agent, None).unwrap();
let allocator = ArenaAllocator::new(registered, PAGE_SIZE).unwrap();
let buffer = allocator.allocate(PAGE_SIZE * 3).unwrap();
assert_eq!(buffer.size(), PAGE_SIZE * 3);
assert_eq!(buffer.agent_name(), "test_agent");
let desc = buffer.registered_descriptor();
assert_eq!(desc.addr as usize, buffer.addr());
assert_eq!(desc.size, PAGE_SIZE * 3);
assert_eq!(desc.mem_type, nixl_sys::MemType::Vram);
assert_eq!(desc.device_id, 0);
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
/// Device a tensor lives on: either a CUDA device identified by index, or any other
/// device identified by a free-form string.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TorchDevice {
    /// A CUDA device and its index.
    Cuda(usize),
    /// Any non-CUDA device.
    Other(String),
}

impl TorchDevice {
    /// `true` exactly when this is the [`TorchDevice::Cuda`] variant.
    pub fn is_cuda(&self) -> bool {
        self.cuda_device_index().is_some()
    }

    /// The CUDA device index, or `None` for non-CUDA devices.
    pub fn cuda_device_index(&self) -> Option<usize> {
        if let TorchDevice::Cuda(index) = self {
            Some(*index)
        } else {
            None
        }
    }
}
/// Minimal, read-only description of a tensor.
///
/// Implementors must be `Debug`, `Send` and `Sync`.
pub trait TorchTensor: std::fmt::Debug + Send + Sync {
/// Device the tensor is located on.
fn device(&self) -> TorchDevice;
/// Pointer to the tensor's data as an integer.
// NOTE(review): presumably the address of the first element — confirm with implementors.
fn data_ptr(&self) -> u64;
/// Size of the tensor's data in bytes.
fn size_bytes(&self) -> usize;
/// Shape of the tensor, one entry per dimension.
fn shape(&self) -> Vec<usize>;
/// Stride of the tensor, one entry per dimension.
// NOTE(review): units (elements vs. bytes) are not visible here — confirm with implementors.
fn stride(&self) -> Vec<usize>;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment