Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
976bb70a
Unverified
Commit
976bb70a
authored
Feb 17, 2026
by
Ryan Olson
Committed by
GitHub
Feb 18, 2026
Browse files
feat: add KVBM memory management enhancements (DIS-1311) (#5532)
parent
57bdfea9
Changes
23
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2494 additions
and
123 deletions
+2494
-123
Cargo.lock
Cargo.lock
+1
-0
lib/memory/Cargo.toml
lib/memory/Cargo.toml
+1
-0
lib/memory/src/actions.rs
lib/memory/src/actions.rs
+236
-4
lib/memory/src/arena.rs
lib/memory/src/arena.rs
+26
-15
lib/memory/src/device.rs
lib/memory/src/device.rs
+6
-2
lib/memory/src/disk.rs
lib/memory/src/disk.rs
+14
-4
lib/memory/src/external.rs
lib/memory/src/external.rs
+164
-0
lib/memory/src/lib.rs
lib/memory/src/lib.rs
+143
-22
lib/memory/src/nixl.rs
lib/memory/src/nixl.rs
+126
-6
lib/memory/src/nixl/agent.rs
lib/memory/src/nixl/agent.rs
+111
-16
lib/memory/src/nixl/config.rs
lib/memory/src/nixl/config.rs
+113
-29
lib/memory/src/numa/mod.rs
lib/memory/src/numa/mod.rs
+271
-0
lib/memory/src/numa/topology.rs
lib/memory/src/numa/topology.rs
+281
-0
lib/memory/src/numa/worker_pool.rs
lib/memory/src/numa/worker_pool.rs
+590
-0
lib/memory/src/offset.rs
lib/memory/src/offset.rs
+208
-2
lib/memory/src/pinned.rs
lib/memory/src/pinned.rs
+63
-16
lib/memory/src/pool/cuda.rs
lib/memory/src/pool/cuda.rs
+7
-1
lib/memory/src/pool/mod.rs
lib/memory/src/pool/mod.rs
+1
-1
lib/memory/src/prelude.rs
lib/memory/src/prelude.rs
+3
-3
lib/memory/src/system.rs
lib/memory/src/system.rs
+129
-2
No files found.
Cargo.lock
View file @
976bb70a
...
...
@@ -2117,6 +2117,7 @@ dependencies = [
"nixl-sys",
"offset-allocator",
"serde",
"serde_json",
"tempfile",
"thiserror 2.0.18",
"tracing",
...
...
lib/memory/Cargo.toml
View file @
976bb70a
...
...
@@ -37,4 +37,5 @@ nix = { version = "0.30", features = ["fs"] }
offset-allocator
=
"0.2"
[dev-dependencies]
serde_json
=
{
workspace
=
true
}
tempfile
=
"3"
lib/memory/src/actions.rs
View file @
976bb70a
...
...
@@ -3,10 +3,10 @@
//! Storage actions.
use
super
::{
MemoryDescript
ion
,
StorageError
};
use
super
::{
MemoryDescript
or
,
StorageError
};
/// Extension trait for storage types that support memory setting operations
pub
trait
Memset
:
MemoryDescript
ion
{
pub
trait
Memset
:
MemoryDescript
or
{
/// Sets a region of memory to a specific value
///
/// # Arguments
...
...
@@ -22,7 +22,7 @@ pub trait Memset: MemoryDescription {
}
/// Extension trait for storage types that support slicing operations
pub
trait
Slice
:
MemoryDescript
ion
+
'static
{
pub
trait
Slice
:
MemoryDescript
or
+
'static
{
/// Returns an immutable byte slice view of the entire storage region
///
/// # Safety
...
...
@@ -133,7 +133,8 @@ pub trait Slice: MemoryDescription + 'static {
}
}
pub
trait
SliceMut
:
MemoryDescription
+
'static
{
/// Extension trait for storage types that support mutable slicing operations.
pub
trait
SliceMut
:
MemoryDescriptor
+
'static
{
/// Returns a mutable byte slice view of the entire storage region
///
/// # Safety
...
...
@@ -239,3 +240,234 @@ pub trait SliceMut: MemoryDescription + 'static {
Ok
(
unsafe
{
std
::
slice
::
from_raw_parts_mut
(
ptr
,
len
)
})
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
SystemStorage
;
// Helper to create a test storage
fn
create_storage
(
size
:
usize
)
->
SystemStorage
{
SystemStorage
::
new
(
size
)
.expect
(
"allocation failed"
)
}
// ========== Memset tests ==========
#[test]
fn
test_memset_full_region
()
{
let
mut
storage
=
create_storage
(
1024
);
storage
.memset
(
0xAB
,
0
,
1024
)
.expect
(
"memset should succeed"
);
let
slice
=
unsafe
{
storage
.as_slice
()
.expect
(
"as_slice should succeed"
)
};
assert
!
(
slice
.iter
()
.all
(|
&
b
|
b
==
0xAB
));
}
#[test]
fn
test_memset_partial_region
()
{
let
mut
storage
=
create_storage
(
1024
);
// First fill with 0x00
storage
.memset
(
0x00
,
0
,
1024
)
.expect
(
"memset should succeed"
);
// Then fill middle region with 0xFF
storage
.memset
(
0xFF
,
100
,
200
)
.expect
(
"memset should succeed"
);
let
slice
=
unsafe
{
storage
.as_slice
()
.expect
(
"as_slice should succeed"
)
};
// Check before region
assert
!
(
slice
[
..
100
]
.iter
()
.all
(|
&
b
|
b
==
0x00
));
// Check filled region
assert
!
(
slice
[
100
..
300
]
.iter
()
.all
(|
&
b
|
b
==
0xFF
));
// Check after region
assert
!
(
slice
[
300
..
]
.iter
()
.all
(|
&
b
|
b
==
0x00
));
}
#[test]
fn
test_memset_at_end
()
{
let
mut
storage
=
create_storage
(
1024
);
// Fill the last 100 bytes
storage
.memset
(
0x42
,
924
,
100
)
.expect
(
"memset should succeed"
);
let
slice
=
unsafe
{
storage
.as_slice
()
.expect
(
"as_slice should succeed"
)
};
assert
!
(
slice
[
924
..
]
.iter
()
.all
(|
&
b
|
b
==
0x42
));
}
#[test]
fn
test_memset_zero_size
()
{
let
mut
storage
=
create_storage
(
1024
);
// Zero-size memset should succeed (no-op)
storage
.memset
(
0xFF
,
500
,
0
)
.expect
(
"zero-size memset should succeed"
);
}
#[test]
fn
test_memset_out_of_bounds
()
{
let
mut
storage
=
create_storage
(
1024
);
// Try to write beyond the storage
let
result
=
storage
.memset
(
0xFF
,
900
,
200
);
assert
!
(
result
.is_err
());
}
#[test]
fn
test_memset_offset_overflow
()
{
let
mut
storage
=
create_storage
(
1024
);
// offset + size would overflow
let
result
=
storage
.memset
(
0xFF
,
usize
::
MAX
,
1
);
assert
!
(
result
.is_err
());
}
// ========== Slice tests ==========
#[test]
fn
test_as_slice_full
()
{
let
mut
storage
=
create_storage
(
1024
);
storage
.memset
(
0xCD
,
0
,
1024
)
.expect
(
"memset should succeed"
);
let
slice
=
unsafe
{
storage
.as_slice
()
.expect
(
"as_slice should succeed"
)
};
assert_eq!
(
slice
.len
(),
1024
);
assert
!
(
slice
.iter
()
.all
(|
&
b
|
b
==
0xCD
));
}
#[test]
fn
test_slice_partial
()
{
let
mut
storage
=
create_storage
(
1024
);
storage
.memset
(
0x00
,
0
,
1024
)
.expect
(
"memset should succeed"
);
storage
.memset
(
0xAA
,
100
,
50
)
.expect
(
"memset should succeed"
);
let
partial
=
storage
.slice
(
100
,
50
)
.expect
(
"slice should succeed"
);
assert_eq!
(
partial
.len
(),
50
);
assert
!
(
partial
.iter
()
.all
(|
&
b
|
b
==
0xAA
));
}
#[test]
fn
test_slice_at_start
()
{
let
storage
=
create_storage
(
1024
);
let
slice
=
storage
.slice
(
0
,
100
)
.expect
(
"slice should succeed"
);
assert_eq!
(
slice
.len
(),
100
);
}
#[test]
fn
test_slice_at_end
()
{
let
storage
=
create_storage
(
1024
);
let
slice
=
storage
.slice
(
924
,
100
)
.expect
(
"slice should succeed"
);
assert_eq!
(
slice
.len
(),
100
);
}
#[test]
fn
test_slice_zero_length
()
{
let
storage
=
create_storage
(
1024
);
let
slice
=
storage
.slice
(
500
,
0
)
.expect
(
"zero-length slice should succeed"
);
assert
!
(
slice
.is_empty
());
}
#[test]
fn
test_slice_out_of_bounds
()
{
let
storage
=
create_storage
(
1024
);
let
result
=
storage
.slice
(
900
,
200
);
assert
!
(
result
.is_err
());
}
#[test]
fn
test_slice_offset_overflow
()
{
let
storage
=
create_storage
(
1024
);
// offset + len would overflow when using saturating_add
let
result
=
storage
.slice
(
usize
::
MAX
,
1
);
assert
!
(
result
.is_err
());
}
// ========== Typed slice tests ==========
#[test]
fn
test_as_slice_typed_u32
()
{
let
mut
storage
=
create_storage
(
1024
);
// Fill with known pattern
storage
.memset
(
0x00
,
0
,
1024
)
.expect
(
"memset should succeed"
);
let
typed
:
&
[
u32
]
=
storage
.as_slice_typed
()
.expect
(
"typed slice should succeed"
);
assert_eq!
(
typed
.len
(),
256
);
// 1024 / 4
assert
!
(
typed
.iter
()
.all
(|
&
v
|
v
==
0
));
}
#[test]
fn
test_as_slice_typed_u64
()
{
let
storage
=
create_storage
(
1024
);
let
typed
:
&
[
u64
]
=
storage
.as_slice_typed
()
.expect
(
"typed slice should succeed"
);
assert_eq!
(
typed
.len
(),
128
);
// 1024 / 8
}
#[test]
fn
test_slice_typed_partial
()
{
let
mut
storage
=
create_storage
(
1024
);
storage
.memset
(
0x00
,
0
,
1024
)
.expect
(
"memset should succeed"
);
// Slice 10 u32 elements starting at offset 0
let
typed
:
&
[
u32
]
=
storage
.slice_typed
(
0
,
10
)
.expect
(
"typed slice should succeed"
);
assert_eq!
(
typed
.len
(),
10
);
}
#[test]
fn
test_slice_typed_with_offset
()
{
let
storage
=
create_storage
(
1024
);
// Slice starting at offset 64 (aligned for u64)
let
typed
:
&
[
u64
]
=
storage
.slice_typed
(
64
,
5
)
.expect
(
"typed slice should succeed"
);
assert_eq!
(
typed
.len
(),
5
);
}
#[test]
fn
test_as_slice_typed_zst_error
()
{
let
storage
=
create_storage
(
1024
);
// Zero-sized types should fail
let
result
:
Result
<&
[()],
_
>
=
storage
.as_slice_typed
();
assert
!
(
result
.is_err
());
}
#[test]
fn
test_as_slice_typed_size_not_multiple
()
{
// Create storage with size not divisible by 4
let
storage
=
create_storage
(
1023
);
let
result
:
Result
<&
[
u32
],
_
>
=
storage
.as_slice_typed
();
assert
!
(
result
.is_err
());
}
#[test]
fn
test_slice_typed_length_overflow
()
{
let
storage
=
create_storage
(
1024
);
// len * size_of::<u64>() would overflow
let
result
:
Result
<&
[
u64
],
_
>
=
storage
.slice_typed
(
0
,
usize
::
MAX
);
assert
!
(
result
.is_err
());
}
#[test]
fn
test_slice_typed_out_of_bounds
()
{
let
storage
=
create_storage
(
1024
);
// Request more elements than available
let
result
:
Result
<&
[
u64
],
_
>
=
storage
.slice_typed
(
0
,
200
);
assert
!
(
result
.is_err
());
}
}
lib/memory/src/arena.rs
View file @
976bb70a
...
...
@@ -4,14 +4,14 @@
//! # Arena Allocator
//!
//! This module provides an arena allocator for generally heap-like allocations.
//! An [`ArenaAllocator`] can be created by taking ownership of a [`MemoryDescript
ion
`] instance.
//! An [`ArenaAllocator`] can be created by taking ownership of a [`MemoryDescript
or
`] instance.
//!
//! The [`ArenaAllocator`] allocates memory contiguous regions using the [`offset_allocator`] crate,
//! which builds on [Sebastian Aaltonen's ArenaAllocator](https://github.com/sebbbi/ArenaAllocator)
use
crate
::
StorageKind
;
use
super
::{
MemoryDescript
ion
,
StorageError
};
use
super
::{
MemoryDescript
or
,
StorageError
};
use
offset_allocator
::{
Allocation
,
Allocator
};
use
std
::{
any
::
Any
,
...
...
@@ -20,6 +20,7 @@ use std::{
/// Errors specific to arena allocation.
#[derive(Debug,
thiserror::Error)]
#[allow(missing_docs)]
pub
enum
ArenaError
{
#[error(
"Page size must be a power of 2"
)]
PageSizeNotAligned
,
...
...
@@ -34,20 +35,20 @@ pub enum ArenaError {
StorageError
(
#[from]
StorageError
),
}
/// Arena allocator backed by an instance of a [`MemoryDescript
ion
`] object.
/// Arena allocator backed by an instance of a [`MemoryDescript
or
`] object.
///
/// This struct wraps an [`Allocator`] from the [`offset_allocator`] crate,
/// and provides methods for allocating memory from the storage.
///
/// The allocator is thread-safe, and the storage is shared between the allocator and the buffers.
#[derive(Clone)]
pub
struct
ArenaAllocator
<
S
:
MemoryDescript
ion
>
{
pub
struct
ArenaAllocator
<
S
:
MemoryDescript
or
>
{
storage
:
Arc
<
S
>
,
allocator
:
Arc
<
Mutex
<
Allocator
>>
,
page_size
:
u64
,
}
impl
<
S
:
MemoryDescript
ion
>
std
::
fmt
::
Debug
for
ArenaAllocator
<
S
>
{
impl
<
S
:
MemoryDescript
or
>
std
::
fmt
::
Debug
for
ArenaAllocator
<
S
>
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
write!
(
f
,
...
...
@@ -62,18 +63,24 @@ impl<S: MemoryDescription> std::fmt::Debug for ArenaAllocator<S> {
/// This struct wraps an [`Allocation`] from the [`offset_allocator`] crate,
/// and provides methods for interacting with the allocated memory.
///
/// The buffer is backed by a [`MemoryDescription`] object, and the allocation is freed when the buffer is dropped.
pub
struct
ArenaBuffer
<
S
:
MemoryDescription
>
{
/// The buffer is backed by a [`MemoryDescriptor`] object, and the allocation is freed when the buffer is dropped.
pub
struct
ArenaBuffer
<
S
:
MemoryDescriptor
>
{
/// Byte offset from the start of the backing storage.
offset
:
usize
,
/// Absolute memory address of this buffer.
address
:
usize
,
/// User-requested allocation size in bytes.
requested_size
:
usize
,
/// Shared reference to the backing storage.
storage
:
Arc
<
S
>
,
/// Internal allocation handle from the offset allocator.
allocation
:
Allocation
,
/// Shared reference to the allocator for freeing on drop.
allocator
:
Arc
<
Mutex
<
Allocator
>>
,
}
impl
<
S
:
MemoryDescript
ion
>
ArenaAllocator
<
S
>
{
/// Create a new [`ArenaAllocator`] from a [`MemoryDescript
ion
`] object and a page size.
impl
<
S
:
MemoryDescript
or
>
ArenaAllocator
<
S
>
{
/// Create a new [`ArenaAllocator`] from a [`MemoryDescript
or
`] object and a page size.
///
/// The page size must be a power of two.
///
...
...
@@ -107,7 +114,11 @@ impl<S: MemoryDescription> ArenaAllocator<S> {
})
}
/// Allocate a new [`ArenaBuffer`] from the allocator.
/// Allocates a new [`ArenaBuffer`] of the given size from this allocator.
///
/// The actual allocation may consume more pages than strictly needed due to
/// page-size rounding. Returns [`ArenaError::AllocationFailed`] if there are
/// not enough contiguous pages available.
pub
fn
allocate
(
&
self
,
size
:
usize
)
->
std
::
result
::
Result
<
ArenaBuffer
<
S
>
,
ArenaError
>
{
let
size
=
size
as
u64
;
let
pages
=
size
.div_ceil
(
self
.page_size
);
...
...
@@ -135,7 +146,7 @@ impl<S: MemoryDescription> ArenaAllocator<S> {
}
}
impl
<
S
:
MemoryDescript
ion
>
std
::
fmt
::
Debug
for
ArenaBuffer
<
S
>
{
impl
<
S
:
MemoryDescript
or
>
std
::
fmt
::
Debug
for
ArenaBuffer
<
S
>
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
write!
(
f
,
...
...
@@ -148,7 +159,7 @@ impl<S: MemoryDescription> std::fmt::Debug for ArenaBuffer<S> {
}
}
impl
<
S
:
MemoryDescript
ion
+
'static
>
MemoryDescript
ion
for
ArenaBuffer
<
S
>
{
impl
<
S
:
MemoryDescript
or
+
'static
>
MemoryDescript
or
for
ArenaBuffer
<
S
>
{
fn
addr
(
&
self
)
->
usize
{
self
.address
}
...
...
@@ -177,7 +188,7 @@ use super::nixl::{NixlCompatible, NixlDescriptor, RegisteredView};
impl
<
S
>
ArenaBuffer
<
S
>
where
S
:
MemoryDescript
ion
+
NixlCompatible
,
S
:
MemoryDescript
or
+
NixlCompatible
,
{
/// Create a NIXL descriptor for this buffer with the correct offset and size.
///
...
...
@@ -200,7 +211,7 @@ where
impl
<
S
>
ArenaBuffer
<
S
>
where
S
:
MemoryDescript
ion
+
RegisteredView
,
S
:
MemoryDescript
or
+
RegisteredView
,
{
/// Get the agent name from registered storage.
///
...
...
@@ -223,7 +234,7 @@ where
}
}
impl
<
S
:
MemoryDescript
ion
>
Drop
for
ArenaBuffer
<
S
>
{
impl
<
S
:
MemoryDescript
or
>
Drop
for
ArenaBuffer
<
S
>
{
fn
drop
(
&
mut
self
)
{
self
.allocator
.lock
()
.unwrap
()
.free
(
self
.allocation
);
}
...
...
lib/memory/src/device.rs
View file @
976bb70a
...
...
@@ -3,7 +3,7 @@
//! CUDA device memory storage.
use
super
::{
MemoryDescript
ion
,
Result
,
StorageError
,
StorageKind
,
nixl
::
NixlDescriptor
};
use
super
::{
MemoryDescript
or
,
Result
,
StorageError
,
StorageKind
,
nixl
::
NixlDescriptor
};
use
cudarc
::
driver
::
CudaContext
;
use
std
::
any
::
Any
;
use
std
::
collections
::
HashMap
;
...
...
@@ -26,9 +26,13 @@ fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
/// CUDA device memory allocated via cudaMalloc.
#[derive(Debug)]
pub
struct
DeviceStorage
{
/// CUDA context used for allocation and deallocation.
ctx
:
Arc
<
CudaContext
>
,
/// Device pointer to the allocated memory.
ptr
:
u64
,
/// CUDA device ID where memory is allocated.
device_id
:
u32
,
/// Size of the allocation in bytes.
len
:
usize
,
}
...
...
@@ -84,7 +88,7 @@ impl Drop for DeviceStorage {
}
}
impl
MemoryDescript
ion
for
DeviceStorage
{
impl
MemoryDescript
or
for
DeviceStorage
{
fn
addr
(
&
self
)
->
usize
{
self
.device_ptr
()
as
usize
}
...
...
lib/memory/src/disk.rs
View file @
976bb70a
...
...
@@ -3,7 +3,7 @@
//! Disk-backed memory storage using memory-mapped files.
use
super
::{
MemoryDescript
ion
,
Result
,
StorageError
,
StorageKind
,
nixl
::
NixlDescriptor
};
use
super
::{
MemoryDescript
or
,
Result
,
StorageError
,
StorageKind
,
nixl
::
NixlDescriptor
};
use
std
::
any
::
Any
;
use
std
::
path
::{
Path
,
PathBuf
};
...
...
@@ -16,15 +16,21 @@ use std::os::fd::BorrowedFd;
const
DISK_CACHE_KEY
:
&
str
=
"DYN_KVBM_DISK_CACHE_DIR"
;
const
DEFAULT_DISK_CACHE_DIR
:
&
str
=
"/tmp/"
;
/// Disk-backed storage using memory-mapped files with O_DIRECT support.
#[derive(Debug)]
pub
struct
DiskStorage
{
/// File descriptor for the backing file.
fd
:
u64
,
/// Path to the backing file.
path
:
PathBuf
,
/// Size of the storage in bytes.
size
:
usize
,
/// Whether the file has been unlinked from the filesystem.
unlinked
:
bool
,
}
impl
DiskStorage
{
/// Creates a new disk storage of the given size in the default cache directory.
pub
fn
new
(
size
:
usize
)
->
Result
<
Self
>
{
// We need to open our file with some special flags that aren't supported by the tempfile crate.
// Instead, we'll use the mkostemp function to create a temporary file with the correct flags.
...
...
@@ -36,6 +42,7 @@ impl DiskStorage {
Self
::
new_at
(
file_path
,
size
)
}
/// Creates a new disk storage at the specified path with the given size.
pub
fn
new_at
(
path
:
impl
AsRef
<
Path
>
,
len
:
usize
)
->
Result
<
Self
>
{
if
len
==
0
{
return
Err
(
StorageError
::
AllocationFailed
(
...
...
@@ -140,15 +147,17 @@ impl DiskStorage {
})
}
/// Returns the file descriptor of the backing file.
pub
fn
fd
(
&
self
)
->
u64
{
self
.fd
}
/// Returns the path to the backing file.
pub
fn
path
(
&
self
)
->
&
Path
{
self
.path
.as_path
()
}
/// Unlink
our temp file
.
/// Unlink
s the backing file from the filesystem
.
/// This means that when this process terminates, the file will be automatically deleted by the OS.
/// Unfortunately, GDS requires that files we try to register must be linked.
/// To get around this, we unlink the file only after we've registered it with NIXL.
...
...
@@ -163,6 +172,7 @@ impl DiskStorage {
Ok
(())
}
/// Returns whether the backing file has been unlinked from the filesystem.
pub
fn
unlinked
(
&
self
)
->
bool
{
self
.unlinked
}
...
...
@@ -177,7 +187,7 @@ impl Drop for DiskStorage {
}
}
impl
MemoryDescript
ion
for
DiskStorage
{
impl
MemoryDescript
or
for
DiskStorage
{
fn
addr
(
&
self
)
->
usize
{
0
}
...
...
@@ -345,7 +355,7 @@ impl super::nixl::NixlCompatible for DiskStorage {
// }
// }
// impl MemoryDescript
ion
for MemMappedFileStorage {
// impl MemoryDescript
or
for MemMappedFileStorage {
// fn addr(&self) -> usize {
// self.mmap.as_ptr() as usize
// }
...
...
lib/memory/src/external.rs
0 → 100644
View file @
976bb70a
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! External memory wrapper for memory allocated by external frameworks.
//!
//! This module provides `ExternalDeviceMemory` for wrapping pointers to GPU
//! memory allocated by external frameworks (e.g., vLLM's KV cache). This type
//! does NOT own the memory - ownership remains with the external framework.
//!
//! The primary use case is registering external GPU memory with NIXL for RDMA
//! transfers without copying.
use
crate
::
nixl
::{
MemType
,
NixlCompatible
,
NixlDescriptor
};
use
crate
::{
MemoryDescriptor
,
StorageKind
};
use
std
::
any
::
Any
;
use
std
::
fmt
;
/// Wrapper for externally-allocated device (GPU) memory.
///
/// This type wraps a raw pointer to GPU memory that is owned by an external
/// framework (like vLLM). It provides the necessary traits for NIXL registration
/// without taking ownership of the underlying memory.
///
/// # Safety
///
/// This type relies on the caller to guarantee that:
/// - The pointer points to valid GPU memory on the specified device
/// - The memory remains valid for the lifetime of this wrapper
/// - The memory size is exactly as specified
/// - The external framework doesn't free the memory while this wrapper exists
///
/// # Example
///
/// ```ignore
/// // vLLM allocates KV cache tensors
/// let tensor_ptr = tensor.data_ptr();
/// let tensor_size = tensor.size_bytes();
/// let device_id = tensor.device.index;
///
/// // Wrap without taking ownership
/// let external = unsafe {
/// ExternalDeviceMemory::new(tensor_ptr as *const u8, tensor_size, device_id as u64)
/// };
///
/// // Register with NIXL for RDMA
/// let registered = register_with_nixl(external, &agent, None)?;
/// ```
pub
struct
ExternalDeviceMemory
{
/// Raw pointer to externally-allocated GPU memory.
ptr
:
*
const
u8
,
/// Size of the memory region in bytes.
size
:
usize
,
/// CUDA device ID where this memory resides.
device_id
:
u64
,
}
// Safety: The external framework (e.g., vLLM) guarantees the memory remains valid
// for the lifetime of the KV cache. The pointer is only used for NIXL registration
// and transfer operations which are synchronized by the framework.
unsafe
impl
Send
for
ExternalDeviceMemory
{}
unsafe
impl
Sync
for
ExternalDeviceMemory
{}
impl
ExternalDeviceMemory
{
/// Create a wrapper for external device memory.
///
/// # Safety
///
/// Caller must ensure:
/// - `ptr` points to valid GPU memory on CUDA device `device_id`
/// - The memory remains valid for the lifetime of this wrapper
/// - The memory size is exactly `size` bytes
/// - The external framework doesn't free the memory while this wrapper exists
#[inline]
pub
unsafe
fn
new
(
ptr
:
*
const
u8
,
size
:
usize
,
device_id
:
u64
)
->
Self
{
Self
{
ptr
,
size
,
device_id
,
}
}
/// Get the raw pointer to the external memory.
#[inline]
pub
fn
as_ptr
(
&
self
)
->
*
const
u8
{
self
.ptr
}
/// Get the CUDA device ID where this memory resides.
#[inline]
pub
fn
device_id
(
&
self
)
->
u64
{
self
.device_id
}
}
impl
fmt
::
Debug
for
ExternalDeviceMemory
{
fn
fmt
(
&
self
,
f
:
&
mut
fmt
::
Formatter
<
'_
>
)
->
fmt
::
Result
{
f
.debug_struct
(
"ExternalDeviceMemory"
)
.field
(
"ptr"
,
&
format_args!
(
"{:p}"
,
self
.ptr
))
.field
(
"size"
,
&
self
.size
)
.field
(
"device_id"
,
&
self
.device_id
)
.finish
()
}
}
impl
MemoryDescriptor
for
ExternalDeviceMemory
{
#[inline]
fn
addr
(
&
self
)
->
usize
{
self
.ptr
as
usize
}
#[inline]
fn
size
(
&
self
)
->
usize
{
self
.size
}
#[inline]
fn
storage_kind
(
&
self
)
->
StorageKind
{
StorageKind
::
Device
(
self
.device_id
as
u32
)
}
fn
as_any
(
&
self
)
->
&
dyn
Any
{
self
}
fn
nixl_descriptor
(
&
self
)
->
Option
<
NixlDescriptor
>
{
// External memory doesn't have a pre-existing NIXL descriptor
// It will be registered and get one via NixlRegistered wrapper
None
}
}
impl
NixlCompatible
for
ExternalDeviceMemory
{
fn
nixl_params
(
&
self
)
->
(
*
const
u8
,
usize
,
MemType
,
u64
)
{
(
self
.ptr
,
self
.size
,
MemType
::
Vram
,
self
.device_id
)
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
test_external_device_memory_traits
()
{
// Create with a dummy pointer (not actually valid GPU memory)
let
ptr
=
0x1000
as
*
const
u8
;
let
size
=
1024
;
let
device_id
=
0
;
let
external
=
unsafe
{
ExternalDeviceMemory
::
new
(
ptr
,
size
,
device_id
)
};
// Check MemoryDescriptor
assert_eq!
(
external
.addr
(),
0x1000
);
assert_eq!
(
external
.size
(),
1024
);
assert_eq!
(
external
.storage_kind
(),
StorageKind
::
Device
(
0
));
assert
!
(
external
.nixl_descriptor
()
.is_none
());
// Check NixlCompatible
let
(
p
,
s
,
mem_type
,
dev
)
=
external
.nixl_params
();
assert_eq!
(
p
as
usize
,
0x1000
);
assert_eq!
(
s
,
1024
);
assert_eq!
(
mem_type
,
MemType
::
Vram
);
assert_eq!
(
dev
,
0
);
}
}
lib/memory/src/lib.rs
View file @
976bb70a
...
...
@@ -4,24 +4,34 @@
//! Clean, minimal storage API for v2 block manager.
//!
//! This module provides a simplified storage abstraction with:
//! - Single trait for type erasure (`MemoryDescript
ion
`)
//! - Single trait for type erasure (`MemoryDescript
or
`)
//! - Concrete storage types (no trait implementations required)
//! - Composition-based NIXL registration via `NixlRegistered<T>` wrapper
//! - RAII with proper drop ordering (registration handle drops before memory)
#![deny(missing_docs)]
pub
mod
actions
;
pub
mod
arena
;
pub
mod
nixl
;
pub
mod
numa
;
/// Offset-based buffer views into underlying storage.
pub
mod
offset
;
/// CUDA memory pool utilities.
pub
mod
pool
;
/// Common imports for working with memory types.
pub
mod
prelude
;
mod
device
;
#[cfg(target_os
=
"linux"
)]
mod
disk
;
mod
external
;
mod
pinned
;
mod
system
;
mod
tor
ch
;
mod
t
ens
or
;
#[cfg(test)]
mod
tests
;
...
...
@@ -30,9 +40,13 @@ pub use arena::{ArenaAllocator, ArenaBuffer, ArenaError};
pub
use
device
::
DeviceStorage
;
#[cfg(target_os
=
"linux"
)]
pub
use
disk
::
DiskStorage
;
pub
use
external
::
ExternalDeviceMemory
;
pub
use
numa
::{
NumaNode
,
is_numa_enabled
};
pub
use
offset
::
OffsetBuffer
;
pub
use
pinned
::
PinnedStorage
;
pub
use
pool
::{
CudaMemPool
,
CudaMemPoolBuilder
};
pub
use
system
::
SystemStorage
;
pub
use
t
orch
::{
TorchDevice
,
TorchTensor
};
pub
use
t
ensor
::{
TensorDescriptor
,
TensorDescriptorExt
};
use
serde
::{
Deserialize
,
Serialize
};
use
std
::
any
::
Any
;
...
...
@@ -43,8 +57,30 @@ use thiserror::Error;
/// Result type for storage operations.
pub
type
Result
<
T
>
=
std
::
result
::
Result
<
T
,
StorageError
>
;
/// Core trait for memory regions that can be type-erased.
///
/// This is the only trait in the storage API. Concrete storage types
/// implement this trait to enable type erasure via `Arc<dyn MemoryDescriptor>`.
pub
trait
MemoryDescriptor
:
Send
+
Sync
+
fmt
::
Debug
{
/// Base address of the memory region.
fn
addr
(
&
self
)
->
usize
;
/// Size of the memory region in bytes.
fn
size
(
&
self
)
->
usize
;
/// Type of storage backing this region.
fn
storage_kind
(
&
self
)
->
StorageKind
;
/// Enable downcasting to concrete type.
fn
as_any
(
&
self
)
->
&
dyn
Any
;
/// Get the NIXL descriptor for this memory region.
fn
nixl_descriptor
(
&
self
)
->
Option
<
nixl
::
NixlDescriptor
>
;
}
/// Errors that can occur during storage operations.
#[derive(Debug,
Error)]
#[allow(missing_docs)]
pub
enum
StorageError
{
#[error(
"allocation failed: {0}"
)]
AllocationFailed
(
String
),
...
...
@@ -87,32 +123,51 @@ pub enum StorageKind {
Disk
(
u64
),
}
/// Core trait for memory regions that can be type-erased.
///
/// This is the only trait in the storage API. Concrete storage types
/// implement this trait to enable type erasure via `Arc<dyn MemoryDescription>`.
pub
trait
MemoryDescription
:
Send
+
Sync
+
fmt
::
Debug
{
/// Base address of the memory region.
fn
addr
(
&
self
)
->
usize
;
impl
StorageKind
{
/// Returns the CUDA device index if this is device memory.
pub
fn
cuda_device_index
(
&
self
)
->
Option
<
u32
>
{
match
self
{
StorageKind
::
Device
(
idx
)
=>
Some
(
*
idx
),
_
=>
None
,
}
}
/// Size of the memory region in bytes.
fn
size
(
&
self
)
->
usize
;
/// Returns true if this is CUDA device memory.
pub
fn
is_cuda
(
&
self
)
->
bool
{
matches!
(
self
,
StorageKind
::
Device
(
_
))
}
/// Type of storage backing this region.
fn
storage_kind
(
&
self
)
->
StorageKind
;
/// Returns true if this is system memory (malloc).
pub
fn
is_system
(
&
self
)
->
bool
{
matches!
(
self
,
StorageKind
::
System
)
}
/// Enable downcasting to concrete type.
fn
as_any
(
&
self
)
->
&
dyn
Any
;
/// Returns true if this is CUDA pinned host memory.
pub
fn
is_pinned
(
&
self
)
->
bool
{
matches!
(
self
,
StorageKind
::
Pinned
)
}
/// Get the NIXL descriptor for this memory region.
fn
nixl_descriptor
(
&
self
)
->
Option
<
nixl
::
NixlDescriptor
>
;
/// Returns true if this is disk-backed memory.
pub
fn
is_disk
(
&
self
)
->
bool
{
matches!
(
self
,
StorageKind
::
Disk
(
_
))
}
}
/// Type-erased memory region for use in layouts.
#[derive(Clone)]
pub
struct
Buffer
(
Arc
<
dyn
MemoryDescription
>
);
pub
struct
Buffer
(
Arc
<
dyn
MemoryDescriptor
>
);
impl
Buffer
{
/// Wraps a concrete storage type into a type-erased [`Buffer`].
///
/// This is the primary way to create a `Buffer` from any type that
/// implements [`MemoryDescriptor`].
pub
fn
new
<
S
:
MemoryDescriptor
+
'static
>
(
memory
:
S
)
->
Self
{
Buffer
(
Arc
::
new
(
memory
))
}
}
impl
MemoryDescript
ion
for
Buffer
{
impl
MemoryDescript
or
for
Buffer
{
fn
addr
(
&
self
)
->
usize
{
self
.0
.addr
()
}
...
...
@@ -131,7 +186,7 @@ impl MemoryDescription for Buffer {
}
impl
std
::
ops
::
Deref
for
Buffer
{
type
Target
=
dyn
MemoryDescript
ion
;
type
Target
=
dyn
MemoryDescript
or
;
fn
deref
(
&
self
)
->
&
Self
::
Target
{
self
.0
.as_ref
()
...
...
@@ -149,10 +204,31 @@ impl std::fmt::Debug for Buffer {
}
/// Helper function to convert concrete storage to type-erased form.
pub
fn
create_buffer
<
S
:
MemoryDescript
ion
+
'static
>
(
memory
:
S
)
->
Buffer
{
pub
fn
create_buffer
<
S
:
MemoryDescript
or
+
'static
>
(
memory
:
S
)
->
Buffer
{
Buffer
(
Arc
::
new
(
memory
))
}
impl
Buffer
{
/// Create a Buffer from an existing Arc<dyn MemoryDescriptor>.
pub
fn
from_arc
(
arc
:
Arc
<
dyn
MemoryDescriptor
>
)
->
Self
{
Buffer
(
arc
)
}
}
// From implementations for ergonomic Buffer creation
impl
From
<
Arc
<
dyn
MemoryDescriptor
>>
for
Buffer
{
fn
from
(
arc
:
Arc
<
dyn
MemoryDescriptor
>
)
->
Self
{
Buffer
::
from_arc
(
arc
)
}
}
impl
From
<
Arc
<
dyn
nixl
::
NixlMemory
+
Send
+
Sync
>>
for
Buffer
{
fn
from
(
arc
:
Arc
<
dyn
nixl
::
NixlMemory
+
Send
+
Sync
>
)
->
Self
{
// Arc<dyn NixlMemory> implements MemoryDescriptor, so we can wrap it
Buffer
::
new
(
arc
)
}
}
/// An unowned contiguous chunk of memory, not storage specific.
#[derive(Debug,
Clone,
Copy,
PartialEq,
Eq,
Serialize,
Deserialize)]
pub
struct
MemoryRegion
{
...
...
@@ -164,17 +240,62 @@ pub struct MemoryRegion {
}
impl
MemoryRegion
{
/// Creates a new memory region with the given base address and size.
pub
fn
new
(
addr
:
usize
,
size
:
usize
)
->
Self
{
Self
{
addr
,
size
}
}
/// Returns the base address of this memory region.
#[inline]
pub
fn
addr
(
&
self
)
->
usize
{
self
.addr
}
/// Returns the size of this memory region in bytes.
#[inline]
pub
fn
size
(
&
self
)
->
usize
{
self
.size
}
/// Get a slice view of this memory region.
///
/// # Safety
/// This is unsafe because:
/// - The caller must ensure the memory region is valid and properly initialized
/// - The caller must ensure no mutable references exist to this memory
/// - The caller must ensure the memory remains valid for the lifetime of the slice
#[cfg(feature
=
"unsafe-slices"
)]
pub
unsafe
fn
as_slice
(
&
self
)
->
Result
<&
[
u8
]
>
{
if
self
.size
==
0
{
return
Ok
(
&
[]);
}
// SAFETY: Caller guarantees memory is valid
unsafe
{
Ok
(
std
::
slice
::
from_raw_parts
(
self
.addr
as
*
const
u8
,
self
.size
,
))
}
}
/// Get a mutable slice view of this memory region.
///
/// # Safety
/// This is unsafe because:
/// - The caller must ensure the memory region is valid and properly initialized
/// - The caller must ensure no other references (mutable or immutable) exist to this memory
/// - The caller must ensure the memory remains valid for the lifetime of the slice
#[cfg(feature
=
"unsafe-slices"
)]
pub
unsafe
fn
as_slice_mut
(
&
mut
self
)
->
Result
<&
mut
[
u8
]
>
{
if
self
.size
==
0
{
return
Ok
(
&
mut
[]);
}
// SAFETY: Caller guarantees memory is valid and exclusively accessible
unsafe
{
Ok
(
std
::
slice
::
from_raw_parts_mut
(
self
.addr
as
*
mut
u8
,
self
.size
,
))
}
}
}
lib/memory/src/nixl.rs
View file @
976bb70a
...
...
@@ -6,14 +6,17 @@
mod
agent
;
mod
config
;
use
super
::{
MemoryDescript
ion
,
StorageKind
};
use
super
::{
MemoryDescript
or
,
StorageKind
};
use
std
::
any
::
Any
;
use
std
::
fmt
;
use
std
::
sync
::
Arc
;
pub
use
agent
::
NixlAgent
;
pub
use
config
::
NixlBackendConfig
;
pub
use
nixl_sys
::{
MemType
,
OptArgs
,
RegistrationHandle
};
pub
use
nixl_sys
::{
Agent
,
MemType
,
NotificationMap
,
OptArgs
,
RegistrationHandle
,
XferDescList
,
XferOp
,
XferRequest
,
};
pub
use
serde
::{
Deserialize
,
Serialize
};
/// Trait for storage types that can be registered with NIXL.
...
...
@@ -24,12 +27,29 @@ pub trait NixlCompatible {
fn
nixl_params
(
&
self
)
->
(
*
const
u8
,
usize
,
MemType
,
u64
);
}
/// Combined trait for memory that can be registered with NIXL.
///
/// This supertrait enables type erasure via `Arc<dyn NixlMemory>`.
/// Any type implementing both `MemoryDescriptor` and `NixlCompatible`
/// automatically implements this trait via the blanket implementation.
pub
trait
NixlMemory
:
MemoryDescriptor
+
NixlCompatible
{}
// Blanket impl - any type with both traits automatically implements NixlMemory
impl
<
T
:
MemoryDescriptor
+
NixlCompatible
+
?
Sized
>
NixlMemory
for
T
{}
/// NIXL descriptor containing registration information.
///
/// This struct holds the information needed to describe a memory region
/// to NIXL for transfer operations.
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
pub
struct
NixlDescriptor
{
/// Base address of the memory region.
pub
addr
:
u64
,
/// Size of the memory region in bytes.
pub
size
:
usize
,
/// Type of memory (host, device, etc.).
pub
mem_type
:
MemType
,
/// Device identifier (GPU index for device memory, 0 for host memory).
pub
device_id
:
u64
,
}
...
...
@@ -91,7 +111,7 @@ impl<S: NixlCompatible + fmt::Debug> fmt::Debug for NixlRegistered<S> {
}
}
impl
<
S
:
MemoryDescript
ion
+
NixlCompatible
+
'static
>
MemoryDescript
ion
for
NixlRegistered
<
S
>
{
impl
<
S
:
MemoryDescript
or
+
NixlCompatible
+
'static
>
MemoryDescript
or
for
NixlRegistered
<
S
>
{
fn
addr
(
&
self
)
->
usize
{
self
.storage
.addr
()
}
...
...
@@ -113,7 +133,7 @@ impl<S: MemoryDescription + NixlCompatible + 'static> MemoryDescription for Nixl
}
}
impl
<
S
:
MemoryDescript
ion
+
NixlCompatible
>
RegisteredView
for
NixlRegistered
<
S
>
{
impl
<
S
:
MemoryDescript
or
+
NixlCompatible
>
RegisteredView
for
NixlRegistered
<
S
>
{
fn
agent_name
(
&
self
)
->
&
str
{
&
self
.agent_name
}
...
...
@@ -129,7 +149,7 @@ impl<S: MemoryDescription + NixlCompatible> RegisteredView for NixlRegistered<S>
}
}
impl
<
S
:
MemoryDescript
ion
+
NixlCompatible
>
NixlRegistered
<
S
>
{
impl
<
S
:
MemoryDescript
or
+
NixlCompatible
>
NixlRegistered
<
S
>
{
/// Get a reference to the underlying storage.
pub
fn
storage
(
&
self
)
->
&
S
{
&
self
.storage
...
...
@@ -179,8 +199,42 @@ pub fn register_with_nixl<S>(
opt
:
Option
<&
OptArgs
>
,
)
->
std
::
result
::
Result
<
NixlRegistered
<
S
>
,
S
>
where
S
:
MemoryDescript
ion
+
NixlCompatible
,
S
:
MemoryDescript
or
+
NixlCompatible
,
{
// let storage_kind = storage.storage_kind();
// // Determine if registration is needed based on storage type and available backends
// let should_register = match storage_kind {
// StorageKind::System | StorageKind::Pinned => {
// // System/Pinned memory needs UCX for remote transfers
// agent.has_backend("UCX") || agent.has_backend("POSIX")
// }
// StorageKind::Device(_) => {
// // Device memory needs UCX for remote transfers OR GDS for direct disk transfers
// agent.has_backend("UCX") || agent.has_backend("GDS_MT")
// }
// StorageKind::Disk(_) => {
// // Disk storage needs POSIX for regular I/O OR GDS for GPU direct I/O
// agent.has_backend("POSIX") || agent.has_backend("GDS_MT")
// } // StorageKind::Object(_) => {
// // // Object storage is always registered via NIXL's OBJ plugin
// // agent.has_backend("OBJ")
// // }
// };
// this is not true for our future object storage. so let's rethink this.
// for object, if there is no device_id or device_id is 0, then we need to register
// alternatively, the object storage holds it's own internal metadata but does not
// expose as a nixl descriptor, thus ObjectStorag will by default like all other storage
// types have a None for nixl_descriptor(), and we will use the internal
if
storage
.nixl_descriptor
()
.is_some
()
{
return
Ok
(
NixlRegistered
{
storage
,
handle
:
None
,
agent_name
:
agent
.name
()
.to_string
(),
});
}
// Get NIXL parameters
let
(
ptr
,
size
,
mem_type
,
device_id
)
=
storage
.nixl_params
();
...
...
@@ -201,3 +255,69 @@ where
Err
(
_
)
=>
Err
(
storage
),
}
}
// =============================================================================
// Arc<dyn NixlMemory> support
// =============================================================================
impl
NixlCompatible
for
Arc
<
dyn
NixlMemory
+
Send
+
Sync
>
{
fn
nixl_params
(
&
self
)
->
(
*
const
u8
,
usize
,
MemType
,
u64
)
{
(
**
self
)
.nixl_params
()
}
}
impl
MemoryDescriptor
for
Arc
<
dyn
NixlMemory
+
Send
+
Sync
>
{
fn
addr
(
&
self
)
->
usize
{
(
**
self
)
.addr
()
}
fn
size
(
&
self
)
->
usize
{
(
**
self
)
.size
()
}
fn
storage_kind
(
&
self
)
->
StorageKind
{
(
**
self
)
.storage_kind
()
}
fn
as_any
(
&
self
)
->
&
dyn
Any
{
(
**
self
)
.as_any
()
}
fn
nixl_descriptor
(
&
self
)
->
Option
<
NixlDescriptor
>
{
(
**
self
)
.nixl_descriptor
()
}
}
// =============================================================================
// Extension trait for ergonomic API
// =============================================================================
/// Extension trait providing ergonomic `.register()` method for NIXL registration.
///
/// This trait is automatically implemented for all types that implement both
/// `MemoryDescriptor` and `NixlCompatible`. Import this trait to use the
/// method syntax:
///
///
pub
trait
NixlRegisterExt
:
MemoryDescriptor
+
NixlCompatible
+
Sized
{
/// Get this memory as NIXL-registered.
///
/// This operation is idempotent - it's a no-op if the memory is already registered.
///
/// # Arguments
/// * `agent` - The NIXL agent to register with
/// * `opt` - Optional arguments for registration
///
/// # Returns
/// A `NixlRegistered` wrapper on success, or the original storage on failure.
fn
register
(
self
,
agent
:
&
NixlAgent
,
opt
:
Option
<&
OptArgs
>
,
)
->
std
::
result
::
Result
<
NixlRegistered
<
Self
>
,
Self
>
{
register_with_nixl
(
self
,
agent
,
opt
)
}
}
// Blanket impl for all compatible types
impl
<
T
:
MemoryDescriptor
+
NixlCompatible
+
Sized
>
NixlRegisterExt
for
T
{}
lib/memory/src/nixl/agent.rs
View file @
976bb70a
...
...
@@ -9,7 +9,9 @@
use
anyhow
::
Result
;
use
nixl_sys
::
Agent
;
use
std
::
collections
::
HashSet
;
use
std
::
collections
::{
HashMap
,
HashSet
};
use
crate
::
nixl
::
NixlBackendConfig
;
/// A NIXL agent wrapper that tracks which backends were successfully initialized.
///
...
...
@@ -40,26 +42,64 @@ impl NixlAgent {
})
}
/// Add a backend to the agent.
/// Creates a new agent configured with backends from the given config.
///
/// This method iterates over all backends in the config and initializes them
/// with their associated parameters. If a backend has custom parameters defined
/// in the config, those are used; otherwise, default plugin parameters are used.
pub
fn
from_nixl_backend_config
(
name
:
&
str
,
config
:
NixlBackendConfig
)
->
Result
<
Self
>
{
let
mut
agent
=
Self
::
new
(
name
)
?
;
for
(
backend
,
params
)
in
config
.iter
()
{
agent
.add_backend_with_params
(
backend
,
params
)
?
;
}
Ok
(
agent
)
}
/// Add a backend to the agent with default parameters.
pub
fn
add_backend
(
&
mut
self
,
backend
:
&
str
)
->
Result
<
()
>
{
if
self
.available_backends
.contains
(
&
backend
.to_uppercase
())
{
return
Ok
(());
self
.add_backend_with_params
(
backend
,
&
HashMap
::
new
())
}
/// Add a backend to the agent with optional custom parameters.
///
/// If `custom_params` is non-empty, those parameters are used instead of
/// the plugin defaults. If empty, default parameters from the plugin are used.
///
/// # Errors
/// Returns an error if custom parameters are provided (not yet supported until nixl_sys 0.9).
pub
fn
add_backend_with_params
(
&
mut
self
,
backend
:
&
str
,
custom_params
:
&
HashMap
<
String
,
String
>
,
)
->
Result
<
()
>
{
let
backend_upper
=
backend
.to_uppercase
();
match
self
.agent
.get_plugin_params
(
&
backend_upper
)
{
Ok
((
_
,
params
))
=>
match
self
.agent
.create_backend
(
&
backend_upper
,
&
params
)
{
Ok
(
_
)
=>
{
self
.available_backends
.insert
(
backend_upper
);
if
self
.available_backends
.contains
(
&
backend_upper
)
{
return
Ok
(());
}
Err
(
e
)
=>
{
anyhow
::
bail!
(
"Failed to create nixl backend: {}"
,
e
);
// TODO(DIS-1310): Custom params require nixl_sys 0.9+ which adds nixl_capi_params_add
if
!
custom_params
.is_empty
()
{
anyhow
::
bail!
(
"Custom NIXL backend parameters for {} are not yet supported.
\
This feature requires nixl_sys 0.9+. Params provided: {:?}"
,
backend_upper
,
custom_params
.keys
()
.collect
::
<
Vec
<
_
>>
()
);
}
},
Err
(
_
)
=>
{
anyhow
::
bail!
(
"No {} plugin found"
,
backend_upper
);
// Get default params from plugin
let
(
_
,
params
)
=
match
self
.agent
.get_plugin_params
(
&
backend_upper
)
{
Ok
(
result
)
=>
result
,
Err
(
_
)
=>
anyhow
::
bail!
(
"No {} plugin found"
,
backend_upper
),
};
match
self
.agent
.create_backend
(
&
backend_upper
,
&
params
)
{
Ok
(
_
)
=>
{
self
.available_backends
.insert
(
backend_upper
);
Ok
(())
}
Err
(
e
)
=>
anyhow
::
bail!
(
"Failed to create nixl backend: {}"
,
e
),
}
Ok
(())
}
/// Create a NIXL agent requiring ALL specified backends to be available.
...
...
@@ -200,4 +240,59 @@ mod tests {
let
result
=
NixlAgent
::
with_backends
(
"test_strict_fail"
,
&
[
"UCX"
,
"DUDE"
]);
assert
!
(
result
.is_err
());
}
#[test]
fn
test_add_backend_with_empty_params
()
{
let
mut
agent
=
NixlAgent
::
new
(
"test_empty_params"
)
.expect
(
"Failed to create agent"
);
// Empty params should work (uses plugin defaults)
let
result
=
agent
.add_backend_with_params
(
"UCX"
,
&
HashMap
::
new
());
assert
!
(
result
.is_ok
());
assert
!
(
agent
.has_backend
(
"UCX"
));
}
#[test]
fn
test_add_backend_with_custom_params_fails
()
{
let
mut
agent
=
NixlAgent
::
new
(
"test_custom_params"
)
.expect
(
"Failed to create agent"
);
// Custom params should fail until nixl_sys 0.9
let
mut
params
=
HashMap
::
new
();
params
.insert
(
"some_key"
.to_string
(),
"some_value"
.to_string
());
let
result
=
agent
.add_backend_with_params
(
"UCX"
,
&
params
);
assert
!
(
result
.is_err
());
let
err_msg
=
result
.unwrap_err
()
.to_string
();
assert
!
(
err_msg
.contains
(
"not yet supported"
));
assert
!
(
err_msg
.contains
(
"nixl_sys 0.9"
));
assert
!
(
err_msg
.contains
(
"some_key"
));
}
#[test]
fn
test_from_nixl_backend_config_with_custom_params_fails
()
{
// Config with custom params should fail
let
mut
params
=
HashMap
::
new
();
params
.insert
(
"threads"
.to_string
(),
"4"
.to_string
());
let
config
=
NixlBackendConfig
::
default
()
.with_backend_params
(
"UCX"
,
params
);
let
result
=
NixlAgent
::
from_nixl_backend_config
(
"test_config_params"
,
config
);
assert
!
(
result
.is_err
());
let
err_msg
=
result
.unwrap_err
()
.to_string
();
assert
!
(
err_msg
.contains
(
"not yet supported"
));
assert
!
(
err_msg
.contains
(
"threads"
));
}
#[test]
fn
test_from_nixl_backend_config_with_empty_params
()
{
// Config with empty params should work
let
config
=
NixlBackendConfig
::
default
()
.with_backend
(
"UCX"
);
let
result
=
NixlAgent
::
from_nixl_backend_config
(
"test_config_empty"
,
config
);
assert
!
(
result
.is_ok
());
let
agent
=
result
.unwrap
();
assert
!
(
agent
.has_backend
(
"UCX"
));
}
}
lib/memory/src/nixl/config.rs
View file @
976bb70a
...
...
@@ -4,10 +4,11 @@
//! NIXL backend configuration with Figment support.
//!
//! This module provides configuration extraction for NIXL backends from
//! environment variables with the pattern: `DYN_KVBM_NIXL_BACKEND_<backend>
_<key>
=<value>`
//! environment variables with the pattern: `DYN_KVBM_NIXL_BACKEND_<backend>=<value>`
use
anyhow
::{
Result
,
bail
};
use
std
::
collections
::
HashSet
;
use
serde
::{
Deserialize
,
Serialize
};
use
std
::
collections
::
HashMap
;
use
dynamo_config
::
parse_bool
;
...
...
@@ -19,16 +20,40 @@ use dynamo_config::parse_bool;
/// - Valid values: true/false, 1/0, on/off, yes/no (case-insensitive)
/// - Invalid values (e.g., "maybe", "random") will cause an error
/// - Custom params (e.g., `DYN_KVBM_NIXL_BACKEND_UCX_PARAM1=value`) will cause an error
#[derive(Debug,
Clone,
Default)]
///
/// # Data Structure
///
/// Uses a single HashMap where:
/// - Key presence = backend is enabled
/// - Value (inner HashMap) = backend-specific parameters (empty = defaults)
///
/// # TOML Example
///
/// ```toml
/// [backends.UCX]
/// # UCX with default params (empty map)
///
/// [backends.GDS]
/// threads = "4"
/// buffer_size = "1048576"
/// ```
#[derive(Debug,
Clone,
Default,
Serialize,
Deserialize)]
pub
struct
NixlBackendConfig
{
/// Set of enabled backends (just backend names, no custom params yet)
backends
:
HashSet
<
String
>
,
/// Map of backend name (uppercase) -> optional parameters.
///
/// If a backend is present in the map, it's enabled.
/// The inner HashMap contains optional override parameters.
/// An empty inner map means use default parameters.
#[serde(default)]
backends
:
HashMap
<
String
,
HashMap
<
String
,
String
>>
,
}
impl
NixlBackendConfig
{
/// Create a new empty configuration.
pub
fn
new
()
->
Self
{
Self
::
default
()
/// Creates a new configuration with the given backends.
///
/// For an empty configuration with no backends, use [`Default::default()`].
pub
fn
new
(
backends
:
HashMap
<
String
,
HashMap
<
String
,
String
>>
)
->
Self
{
Self
{
backends
}
}
/// Create configuration from environment variables.
...
...
@@ -40,7 +65,7 @@ impl NixlBackendConfig {
/// - Custom parameters are detected (not yet supported)
/// - Invalid boolean values are provided (must be truthy or falsey)
pub
fn
from_env
()
->
Result
<
Self
>
{
let
mut
backends
=
Hash
Set
::
new
();
let
mut
backends
=
Hash
Map
::
new
();
// Extract all environment variables that match our pattern
for
(
key
,
value
)
in
std
::
env
::
vars
()
{
...
...
@@ -59,7 +84,7 @@ impl NixlBackendConfig {
let
backend_name
=
remainder
.to_uppercase
();
match
parse_bool
(
&
value
)
{
Ok
(
true
)
=>
{
backends
.insert
(
backend_name
);
backends
.insert
(
backend_name
,
HashMap
::
new
()
);
}
Ok
(
false
)
=>
{
// Explicitly disabled, don't add to backends
...
...
@@ -70,39 +95,59 @@ impl NixlBackendConfig {
}
}
// Default to UCX if no backends specified
if
backends
.is_empty
()
{
backends
.insert
(
"UCX"
.to_string
());
}
Ok
(
Self
{
backends
})
}
/// Add a backend to the configuration.
///
/// Backend names will be converted to uppercase for consistency.
/// Add a backend with default parameters.
/// Backend name is normalized to uppercase.
pub
fn
with_backend
(
mut
self
,
backend
:
impl
Into
<
String
>
)
->
Self
{
self
.backends
.insert
(
backend
.into
()
.to_uppercase
());
self
.backends
.insert
(
backend
.into
()
.to_uppercase
(),
HashMap
::
new
());
self
}
/// Add a backend with custom parameters.
/// Backend name is normalized to uppercase.
pub
fn
with_backend_params
(
mut
self
,
backend
:
impl
Into
<
String
>
,
params
:
HashMap
<
String
,
String
>
,
)
->
Self
{
self
.backends
.insert
(
backend
.into
()
.to_uppercase
(),
params
);
self
}
/// Get the set of enabled backends.
pub
fn
backends
(
&
self
)
->
&
HashSet
<
String
>
{
&
self
.backends
/// Get the list of enabled backend names (uppercase).
pub
fn
backends
(
&
self
)
->
Vec
<
String
>
{
self
.backends
.keys
()
.cloned
()
.collect
()
}
/// Get parameters for a specific backend.
/// Backend name is normalized to uppercase for lookup.
///
/// Returns None if the backend is not enabled.
pub
fn
backend_params
(
&
self
,
backend
:
&
str
)
->
Option
<&
HashMap
<
String
,
String
>>
{
self
.backends
.get
(
&
backend
.to_uppercase
())
}
/// Check if a specific backend is enabled.
pub
fn
has_backend
(
&
self
,
backend
:
&
str
)
->
bool
{
self
.backends
.contains
(
&
backend
.to_uppercase
())
self
.backends
.contains
_key
(
&
backend
.to_uppercase
())
}
/// Merge another configuration into this one.
///
/// Backends from the other configuration will be added to this one.
/// If both have the same backend, params from `other` take precedence.
pub
fn
merge
(
mut
self
,
other
:
NixlBackendConfig
)
->
Self
{
self
.backends
.extend
(
other
.backends
);
self
}
/// Iterate over all enabled backends and their parameters.
pub
fn
iter
(
&
self
)
->
impl
Iterator
<
Item
=
(
&
String
,
&
HashMap
<
String
,
String
>
)
>
{
self
.backends
.iter
()
}
}
#[cfg(test)]
...
...
@@ -111,13 +156,19 @@ mod tests {
#[test]
fn
test_new_config_is_empty
()
{
let
config
=
NixlBackendConfig
::
new
();
assert
!
(
config
.backends
()
.is_empty
());
let
config
=
NixlBackendConfig
::
default
();
assert_eq!
(
config
.backends
()
.len
(),
0
);
}
#[test]
fn
test_default_is_empty
()
{
let
config
=
NixlBackendConfig
::
default
();
assert
!
(
config
.backends
()
.is_empty
());
// default() has no backends
}
#[test]
fn
test_with_backend
()
{
let
config
=
NixlBackendConfig
::
new
()
let
config
=
NixlBackendConfig
::
default
()
.with_backend
(
"ucx"
)
.with_backend
(
"gds_mt"
);
...
...
@@ -128,10 +179,30 @@ mod tests {
assert
!
(
!
config
.has_backend
(
"other"
));
}
#[test]
fn
test_with_backend_params
()
{
let
mut
params
=
HashMap
::
new
();
params
.insert
(
"threads"
.to_string
(),
"4"
.to_string
());
params
.insert
(
"buffer_size"
.to_string
(),
"1048576"
.to_string
());
let
config
=
NixlBackendConfig
::
default
()
.with_backend
(
"UCX"
)
.with_backend_params
(
"GDS"
,
params
);
// UCX should have empty params
let
ucx_params
=
config
.backend_params
(
"UCX"
)
.unwrap
();
assert
!
(
ucx_params
.is_empty
());
// GDS should have custom params
let
gds_params
=
config
.backend_params
(
"GDS"
)
.unwrap
();
assert_eq!
(
gds_params
.get
(
"threads"
),
Some
(
&
"4"
.to_string
()));
assert_eq!
(
gds_params
.get
(
"buffer_size"
),
Some
(
&
"1048576"
.to_string
()));
}
#[test]
fn
test_merge_configs
()
{
let
config1
=
NixlBackendConfig
::
new
()
.with_backend
(
"ucx"
);
let
config2
=
NixlBackendConfig
::
new
()
.with_backend
(
"gds"
);
let
config1
=
NixlBackendConfig
::
default
()
.with_backend
(
"ucx"
);
let
config2
=
NixlBackendConfig
::
default
()
.with_backend
(
"gds"
);
let
merged
=
config1
.merge
(
config2
);
...
...
@@ -141,7 +212,7 @@ mod tests {
#[test]
fn
test_backend_name_case_insensitive
()
{
let
config
=
NixlBackendConfig
::
new
()
let
config
=
NixlBackendConfig
::
default
()
.with_backend
(
"ucx"
)
.with_backend
(
"Gds_mt"
)
.with_backend
(
"OTHER"
);
...
...
@@ -154,6 +225,19 @@ mod tests {
assert
!
(
config
.has_backend
(
"other"
));
}
#[test]
fn
test_iter
()
{
let
mut
params
=
HashMap
::
new
();
params
.insert
(
"key"
.to_string
(),
"value"
.to_string
());
let
config
=
NixlBackendConfig
::
default
()
.with_backend
(
"UCX"
)
.with_backend_params
(
"GDS"
,
params
);
let
items
:
Vec
<
_
>
=
config
.iter
()
.collect
();
assert_eq!
(
items
.len
(),
2
);
}
// Note: Testing from_env() would require setting environment variables,
// which is challenging in unit tests. This is better tested with integration tests.
}
lib/memory/src/numa/mod.rs
0 → 100644
View file @
976bb70a
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NUMA-aware memory allocation utilities.
//!
//! This module provides utilities for NUMA-aware memory allocation, which is critical
//! for optimal performance on multi-socket systems with GPUs. Memory allocated on the
//! NUMA node closest to the target GPU has significantly lower access latency.
//!
//! ## Architecture
//!
//! - [`NumaNode`]: Represents a NUMA node ID
//! - [`topology`]: Reads CPU-to-NUMA mapping from `/sys/devices/system/node`
//! - [`worker_pool`]: Dedicated worker threads pinned to specific NUMA nodes
//!
//! ## Usage
//!
//! NUMA optimization is opt-in via environment variable:
//! ```bash
//! export DYN_KVBM_ENABLE_NUMA=1
//! ```
//!
//! When enabled, pinned memory allocations are routed through NUMA workers
//! that are pinned to the target GPU's NUMA node, ensuring first-touch policy
//! places pages on the correct node.
pub
mod
topology
;
pub
mod
worker_pool
;
use
nix
::
libc
;
use
serde
::{
Deserialize
,
Serialize
};
use
std
::{
mem
,
process
::
Command
};
/// Check if NUMA optimization is enabled via environment variable
///
/// Set `DYN_KVBM_ENABLE_NUMA=1` to enable NUMA-aware allocation.
/// Default: disabled (opt-in)
pub
fn
is_numa_enabled
()
->
bool
{
std
::
env
::
var
(
"DYN_KVBM_ENABLE_NUMA"
)
.map
(|
v
|
v
==
"1"
||
v
.to_lowercase
()
==
"true"
)
.unwrap_or
(
false
)
}
/// Represents a NUMA node identifier.
///
/// NUMA nodes are typically numbered 0, 1, 2, etc. corresponding to physical
/// CPU sockets. Use [`NumaNode::UNKNOWN`] when the node cannot be determined.
#[derive(Debug,
Clone,
Copy,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize)]
pub
struct
NumaNode
(
pub
u32
);
impl
NumaNode
{
/// Sentinel value for unknown NUMA node.
pub
const
UNKNOWN
:
NumaNode
=
NumaNode
(
u32
::
MAX
);
/// Returns true if this represents an unknown NUMA node.
pub
fn
is_unknown
(
&
self
)
->
bool
{
self
.0
==
u32
::
MAX
}
}
impl
std
::
fmt
::
Display
for
NumaNode
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
if
self
.is_unknown
()
{
write!
(
f
,
"UNKNOWN"
)
}
else
{
write!
(
f
,
"NumaNode({})"
,
self
.0
)
}
}
}
/// Get the current CPU's NUMA node.
///
/// Uses the Linux `getcpu` syscall to determine which NUMA node the current CPU belongs to.
/// Returns [`NumaNode::UNKNOWN`] if the syscall fails.
pub
fn
get_current_cpu_numa_node
()
->
NumaNode
{
unsafe
{
let
mut
cpu
:
libc
::
c_uint
=
0
;
let
mut
node
:
libc
::
c_uint
=
0
;
// getcpu syscall: int getcpu(unsigned *cpu, unsigned *node, struct getcpu_cache *tcache);
let
result
=
libc
::
syscall
(
libc
::
SYS_getcpu
,
&
mut
cpu
,
&
mut
node
,
std
::
ptr
::
null_mut
::
<
libc
::
c_void
>
(),
);
if
result
==
0
{
NumaNode
(
node
)
}
else
{
NumaNode
::
UNKNOWN
}
}
}
/// Get NUMA node for a GPU device.
///
/// For GPU memory, the NUMA affinity depends on which PCIe bus the GPU is attached to.
/// This is queried via nvidia-smi. Falls back to a heuristic (device_id % 2) if nvidia-smi
/// is unavailable.
///
/// # Arguments
/// * `device_id` - CUDA device index (0, 1, 2, ...)
///
/// # Returns
/// The NUMA node closest to the specified GPU, or a heuristic fallback.
pub
fn
get_device_numa_node
(
device_id
:
u32
)
->
NumaNode
{
// Use nvidia-smi topo to get NUMA ID of nearest CPU
// This directly returns the NUMA node
let
output
=
match
Command
::
new
(
"nvidia-smi"
)
.args
([
"topo"
,
"--get-numa-id-of-nearby-cpu"
,
"-i"
,
&
device_id
.to_string
(),
])
.output
()
{
Ok
(
out
)
if
out
.status
.success
()
=>
out
,
_
=>
{
tracing
::
warn!
(
"nvidia-smi failed for GPU {}, using heuristic"
,
device_id
);
return
NumaNode
(
device_id
%
2
);
}
};
if
let
Ok
(
stdout
)
=
std
::
str
::
from_utf8
(
&
output
.stdout
)
&&
let
Some
(
line
)
=
stdout
.lines
()
.next
()
&&
let
Some
(
numa_str
)
=
line
.split
(
':'
)
.nth
(
1
)
&&
let
Ok
(
node
)
=
numa_str
.trim
()
.parse
::
<
u32
>
()
{
tracing
::
trace!
(
"GPU {} on NUMA node {}"
,
device_id
,
node
);
return
NumaNode
(
node
);
}
tracing
::
warn!
(
"Failed to get NUMA node for GPU {}"
,
device_id
);
NumaNode
::
UNKNOWN
}
/// Pin the current thread to a specific NUMA node's CPUs.
///
/// This sets the CPU affinity for the calling thread to only run on CPUs
/// belonging to the specified NUMA node. This is critical for ensuring
/// that memory allocations follow the first-touch policy on the correct node.
///
/// # Arguments
/// * `node` - The NUMA node to pin the thread to
///
/// # Errors
/// Returns an error if:
/// - NUMA topology cannot be read
/// - No CPUs are found for the specified node
/// - The `sched_setaffinity` syscall fails
pub
fn
pin_thread_to_numa_node
(
node
:
NumaNode
)
->
Result
<
(),
String
>
{
let
topology
=
topology
::
get_numa_topology
()
.map_err
(|
e
|
format!
(
"Can not get NUMA topology: {}"
,
e
))
?
;
let
cpus
=
topology
.cpus_for_node
(
node
.0
)
.ok_or_else
(||
format!
(
"No CPUs found for NUMA node {}"
,
node
.0
))
?
;
if
cpus
.is_empty
()
{
return
Err
(
format!
(
"No CPUs found for NUMA node {}"
,
node
.0
));
}
unsafe
{
let
mut
cpu_set
:
libc
::
cpu_set_t
=
mem
::
zeroed
();
for
cpu
in
cpus
{
libc
::
CPU_SET
(
*
cpu
,
&
mut
cpu_set
);
}
let
result
=
libc
::
sched_setaffinity
(
0
,
// current thread
mem
::
size_of
::
<
libc
::
cpu_set_t
>
(),
&
cpu_set
,
);
if
result
!=
0
{
let
err
=
std
::
io
::
Error
::
last_os_error
();
return
Err
(
format!
(
"Failed to set CPU affinity: {}"
,
err
));
}
}
Ok
(())
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
test_numa_node_equality
()
{
let
node0a
=
NumaNode
(
0
);
let
node0b
=
NumaNode
(
0
);
let
node1
=
NumaNode
(
1
);
assert_eq!
(
node0a
,
node0b
);
assert_ne!
(
node0a
,
node1
);
}
#[test]
fn
test_numa_node_unknown
()
{
let
unknown
=
NumaNode
::
UNKNOWN
;
assert
!
(
unknown
.is_unknown
());
assert_eq!
(
unknown
.0
,
u32
::
MAX
);
let
valid
=
NumaNode
(
0
);
assert
!
(
!
valid
.is_unknown
());
}
#[test]
fn
test_numa_node_display
()
{
assert_eq!
(
format!
(
"{}"
,
NumaNode
(
0
)),
"NumaNode(0)"
);
assert_eq!
(
format!
(
"{}"
,
NumaNode
(
7
)),
"NumaNode(7)"
);
assert_eq!
(
format!
(
"{}"
,
NumaNode
::
UNKNOWN
),
"UNKNOWN"
);
}
#[test]
fn
test_numa_node_serialization
()
{
// Verify NumaNode can be serialized (important for benchmarking)
let
node
=
NumaNode
(
1
);
let
json
=
serde_json
::
to_string
(
&
node
)
.unwrap
();
let
deserialized
:
NumaNode
=
serde_json
::
from_str
(
&
json
)
.unwrap
();
assert_eq!
(
node
,
deserialized
);
}
#[test]
fn
test_get_current_cpu_numa_node
()
{
// Should either return a valid node or UNKNOWN
let
node
=
get_current_cpu_numa_node
();
// If not unknown, should be a reasonable NUMA node number (< 8 on most systems)
if
!
node
.is_unknown
()
{
assert
!
(
node
.0
<
8
,
"NUMA node {} seems unreasonably high"
,
node
.0
);
}
}
#[test]
fn
test_get_device_numa_node_valid_gpu
()
{
// Test GPU 0 detection
let
node
=
get_device_numa_node
(
0
);
// Should return either a valid node (0-7) or use heuristic (gpu_id % 2)
// On dual-socket systems, GPU 0 typically on node 0 or 1
println!
(
"GPU 0 detected on NUMA node: {}"
,
node
.0
);
}
#[test]
fn
test_numa_node_hash
()
{
// Verify NumaNode can be used as a HashMap key
use
std
::
collections
::
HashMap
;
let
mut
map
=
HashMap
::
new
();
map
.insert
(
NumaNode
(
0
),
"node0"
);
map
.insert
(
NumaNode
(
1
),
"node1"
);
assert_eq!
(
map
.get
(
&
NumaNode
(
0
)),
Some
(
&
"node0"
));
assert_eq!
(
map
.get
(
&
NumaNode
(
1
)),
Some
(
&
"node1"
));
assert_eq!
(
map
.get
(
&
NumaNode
(
2
)),
None
);
}
#[test]
fn
test_numa_node_copy_clone
()
{
// Verify NumaNode is Copy and Clone
let
node1
=
NumaNode
(
5
);
let
node2
=
node1
;
// Copy
let
node3
=
node1
;
// Clone
assert_eq!
(
node1
,
node2
);
assert_eq!
(
node1
,
node3
);
assert_eq!
(
node2
,
node3
);
}
}
lib/memory/src/numa/topology.rs
0 → 100644
View file @
976bb70a
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NUMA topology detection
//!
//! This module provides utilities to read the actual CPU-to-NUMA mapping from the system,
//! replacing heuristic assumptions with real topology data.
use
std
::
collections
::
HashMap
;
use
std
::
fs
;
/// Global cached topology
static
TOPOLOGY
:
std
::
sync
::
OnceLock
<
Result
<
NumaTopology
,
String
>>
=
std
::
sync
::
OnceLock
::
new
();
/// Represents the CPU topology for NUMA nodes.
///
/// This struct provides bidirectional lookup between NUMA nodes and CPUs,
/// read from the Linux sysfs interface at `/sys/devices/system/node/`.
pub
struct
NumaTopology
{
/// Maps NUMA node ID -> list of CPU IDs
node_to_cpus
:
HashMap
<
u32
,
Vec
<
usize
>>
,
/// Maps CPU ID -> NUMA node ID
cpu_to_node
:
HashMap
<
usize
,
u32
>
,
}
impl
NumaTopology
{
/// Read NUMA topology from sysfs.
///
/// Parses `/sys/devices/system/node/node*/cpulist` to build the CPU-to-NUMA mapping.
pub
fn
from_sysfs
()
->
Result
<
Self
,
String
>
{
let
mut
node_to_cpus
:
HashMap
<
u32
,
Vec
<
usize
>>
=
HashMap
::
new
();
let
mut
cpu_to_node
:
HashMap
<
usize
,
u32
>
=
HashMap
::
new
();
let
node_dir
=
std
::
path
::
Path
::
new
(
"/sys/devices/system/node"
);
if
!
node_dir
.exists
()
{
return
Err
(
"Node directory not found"
.to_string
());
}
let
entries
=
fs
::
read_dir
(
node_dir
)
.map_err
(|
e
|
format!
(
"Failed to read node directory: {}"
,
e
))
?
;
for
entry
in
entries
.flatten
()
{
let
path
=
entry
.path
();
let
name
=
path
.file_name
()
.and_then
(|
n
|
n
.to_str
())
.unwrap_or
(
""
);
// Only process "nodeN" directories
if
!
name
.starts_with
(
"node"
)
{
continue
;
}
// Extract node number
let
node_id
:
u32
=
name
[
4
..
]
.parse
()
.map_err
(|
_
|
format!
(
"Invalid node dir: {}"
,
name
))
?
;
// Read cpulist file
let
cpulist_path
=
path
.join
(
"cpulist"
);
if
!
cpulist_path
.exists
()
{
continue
;
}
let
cpulist
=
fs
::
read_to_string
(
&
cpulist_path
)
.map_err
(|
e
|
format!
(
"Failed to read {}: {}"
,
cpulist_path
.display
(),
e
))
?
;
let
cpus
=
parse_cpulist
(
cpulist
.trim
())
?
;
// Populate both maps
for
cpu
in
&
cpus
{
cpu_to_node
.insert
(
*
cpu
,
node_id
);
}
node_to_cpus
.insert
(
node_id
,
cpus
);
}
if
node_to_cpus
.is_empty
()
{
return
Err
(
"No NUMA nodes found"
.to_string
());
}
Ok
(
Self
{
node_to_cpus
,
cpu_to_node
,
})
}
/// Returns all CPU IDs belonging to the given NUMA node.
///
/// Returns `None` if the node ID is not in the topology.
pub
fn
cpus_for_node
(
&
self
,
node_id
:
u32
)
->
Option
<&
[
usize
]
>
{
self
.node_to_cpus
.get
(
&
node_id
)
.map
(|
v
|
v
.as_slice
())
}
/// Returns the NUMA node ID that contains the given CPU.
///
/// Returns `None` if the CPU ID is not in the topology.
pub
fn
node_for_cpu
(
&
self
,
cpu_id
:
usize
)
->
Option
<
u32
>
{
self
.cpu_to_node
.get
(
&
cpu_id
)
.copied
()
}
/// Returns the number of NUMA nodes in the system.
pub
fn
num_nodes
(
&
self
)
->
usize
{
self
.node_to_cpus
.len
()
}
/// Returns `true` if this is a single-node (non-NUMA) system.
pub
fn
is_single_node
(
&
self
)
->
bool
{
self
.num_nodes
()
==
1
}
}
/// Parse Linux cpulist format.
///
/// # Examples
/// - `"0-15"` -> `[0,1,2,...,15]`
/// - `"0,4,8"` -> `[0,4,8]`
/// - `"0-3,8-11"` -> `[0,1,2,3,8,9,10,11]`
fn
parse_cpulist
(
cpulist
:
&
str
)
->
Result
<
Vec
<
usize
>
,
String
>
{
let
mut
cpus
=
Vec
::
new
();
for
part
in
cpulist
.split
(
','
)
{
if
part
.contains
(
'-'
)
{
// Range: "0-15"
let
range
:
Vec
<&
str
>
=
part
.split
(
'-'
)
.collect
();
if
range
.len
()
!=
2
{
return
Err
(
format!
(
"Invalid range: {}"
,
part
));
}
let
start
:
usize
=
range
[
0
]
.parse
()
.map_err
(|
_
|
format!
(
"Invalid CPU ID: {}"
,
range
[
0
]))
?
;
let
end
:
usize
=
range
[
1
]
.parse
()
.map_err
(|
_
|
format!
(
"Invalid CPU ID: {}"
,
range
[
1
]))
?
;
for
cpu
in
start
..=
end
{
cpus
.push
(
cpu
);
}
}
else
{
// Single CPU
let
cpu
:
usize
=
part
.parse
()
.map_err
(|
_
|
format!
(
"Invalid CPU ID: {}"
,
part
))
?
;
cpus
.push
(
cpu
);
}
}
cpus
.sort_unstable
();
cpus
.dedup
();
Ok
(
cpus
)
}
/// Get the global NUMA topology (cached after first call).
///
/// Returns an error if NUMA topology cannot be read from sysfs. This indicates either:
/// - System doesn't support NUMA
/// - `/sys` is not mounted (e.g., restricted container)
/// - Kernel NUMA support is disabled
///
/// Callers should handle errors gracefully by disabling NUMA optimizations.
pub
fn
get_numa_topology
()
->
Result
<&
'static
NumaTopology
,
&
'static
str
>
{
TOPOLOGY
.get_or_init
(
NumaTopology
::
from_sysfs
)
.as_ref
()
.map_err
(|
e
|
{
tracing
::
warn!
(
"NUMA topology unavailable: {}"
,
e
);
"NUMA topology unavailable"
})
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
test_parse_cpulist_range
()
{
let
cpus
=
parse_cpulist
(
"0-3"
)
.unwrap
();
assert_eq!
(
cpus
,
vec!
[
0
,
1
,
2
,
3
]);
}
#[test]
fn
test_parse_cpulist_list
()
{
let
cpus
=
parse_cpulist
(
"0,4,8"
)
.unwrap
();
assert_eq!
(
cpus
,
vec!
[
0
,
4
,
8
]);
}
#[test]
fn
test_parse_cpulist_mixed
()
{
let
cpus
=
parse_cpulist
(
"0-2,8,16-17"
)
.unwrap
();
assert_eq!
(
cpus
,
vec!
[
0
,
1
,
2
,
8
,
16
,
17
]);
}
#[test]
fn
test_parse_cpulist_ht
()
{
// Hyperthreading: 0-15,32-47 (physical cores 0-15, HT siblings 32-47)
let
cpus
=
parse_cpulist
(
"0-15,32-47"
)
.unwrap
();
assert_eq!
(
cpus
.len
(),
32
);
assert_eq!
(
cpus
[
0
],
0
);
assert_eq!
(
cpus
[
15
],
15
);
assert_eq!
(
cpus
[
16
],
32
);
assert_eq!
(
cpus
[
31
],
47
);
}
#[test]
fn
test_parse_cpulist_real_numa_system
()
{
// Real dual-socket system with hyperthreading (discovered pattern)
// Node 0: CPUs 0-15, 128-143
let
cpus
=
parse_cpulist
(
"0-15,128-143"
)
.unwrap
();
assert_eq!
(
cpus
.len
(),
32
);
assert_eq!
(
cpus
[
0
],
0
);
assert_eq!
(
cpus
[
15
],
15
);
assert_eq!
(
cpus
[
16
],
128
);
assert_eq!
(
cpus
[
31
],
143
);
// Node 1: CPUs 16-31, 144-159
let
cpus
=
parse_cpulist
(
"16-31,144-159"
)
.unwrap
();
assert_eq!
(
cpus
.len
(),
32
);
assert_eq!
(
cpus
[
0
],
16
);
assert_eq!
(
cpus
[
15
],
31
);
assert_eq!
(
cpus
[
16
],
144
);
assert_eq!
(
cpus
[
31
],
159
);
}
#[test]
fn
test_parse_cpulist_out_of_order
()
{
// Test that parser handles out-of-order input (seen in some systems)
let
cpus
=
parse_cpulist
(
"4,2,0,1,3"
)
.unwrap
();
assert_eq!
(
cpus
,
vec!
[
0
,
1
,
2
,
3
,
4
]);
// Should be sorted
}
#[test]
fn
test_parse_cpulist_duplicates
()
{
// Test deduplication (in case kernel reports duplicates)
let
cpus
=
parse_cpulist
(
"0-2,1-3"
)
.unwrap
();
assert_eq!
(
cpus
,
vec!
[
0
,
1
,
2
,
3
]);
// Should remove duplicates
}
#[test]
fn
test_parse_cpulist_empty
()
{
// Edge case: empty cpulist
let
result
=
parse_cpulist
(
""
);
assert
!
(
result
.is_err
()
||
result
.unwrap
()
.is_empty
());
}
#[test]
fn
test_parse_cpulist_single_cpu
()
{
// Single CPU node (uncommon but valid)
let
cpus
=
parse_cpulist
(
"5"
)
.unwrap
();
assert_eq!
(
cpus
,
vec!
[
5
]);
}
#[test]
fn
test_topology_bidirectional_lookup
()
{
// Test that node->cpu and cpu->node mappings are consistent
let
mut
node_to_cpus
=
std
::
collections
::
HashMap
::
new
();
let
mut
cpu_to_node
=
std
::
collections
::
HashMap
::
new
();
node_to_cpus
.insert
(
0
,
vec!
[
0
,
1
,
2
,
3
]);
node_to_cpus
.insert
(
1
,
vec!
[
4
,
5
,
6
,
7
]);
for
(
node
,
cpus
)
in
&
node_to_cpus
{
for
cpu
in
cpus
{
cpu_to_node
.insert
(
*
cpu
,
*
node
);
}
}
let
topology
=
NumaTopology
{
node_to_cpus
,
cpu_to_node
,
};
// Verify forward lookup (node -> cpus)
assert_eq!
(
topology
.cpus_for_node
(
0
),
Some
(
&
[
0
,
1
,
2
,
3
][
..
]));
assert_eq!
(
topology
.cpus_for_node
(
1
),
Some
(
&
[
4
,
5
,
6
,
7
][
..
]));
// Verify reverse lookup (cpu -> node)
assert_eq!
(
topology
.node_for_cpu
(
0
),
Some
(
0
));
assert_eq!
(
topology
.node_for_cpu
(
3
),
Some
(
0
));
assert_eq!
(
topology
.node_for_cpu
(
4
),
Some
(
1
));
assert_eq!
(
topology
.node_for_cpu
(
7
),
Some
(
1
));
// Verify unknown CPU
assert_eq!
(
topology
.node_for_cpu
(
999
),
None
);
}
}
lib/memory/src/numa/worker_pool.rs
0 → 100644
View file @
976bb70a
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NUMA worker pool for memory allocation with first-touch policy.
//!
//! This module provides dedicated worker threads that are pinned to specific NUMA nodes.
//!
//! ## Architecture
//!
//! - One worker thread per NUMA node (spawned lazily)
//! - Workers pin themselves on startup (immune to application thread management)
//! - Channel-based communication for allocation requests
//! - First-touch page allocation ensures correct NUMA placement
use
super
::
get_current_cpu_numa_node
;
use
cudarc
::
driver
::
CudaContext
;
use
cudarc
::
driver
::
result
::
malloc_host
;
use
cudarc
::
driver
::
sys
::
CU_MEMHOSTALLOC_DEVICEMAP
;
use
nix
::
libc
;
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
mpsc
::{
Receiver
,
Sender
,
channel
};
use
std
::
sync
::{
Arc
,
Mutex
,
OnceLock
};
use
std
::
thread
::{
self
,
JoinHandle
};
use
std
::
time
::
Duration
;
use
super
::{
NumaNode
,
get_device_numa_node
};
/// Get or create a CUDA context for the given device.
fn
cuda_context
(
device_id
:
u32
)
->
Result
<
Arc
<
CudaContext
>
,
String
>
{
static
CONTEXTS
:
OnceLock
<
Mutex
<
HashMap
<
u32
,
Arc
<
CudaContext
>>>>
=
OnceLock
::
new
();
let
mut
map
=
CONTEXTS
.get_or_init
(
Default
::
default
)
.lock
()
.unwrap
();
if
let
Some
(
existing
)
=
map
.get
(
&
device_id
)
{
return
Ok
(
existing
.clone
());
}
let
ctx
=
CudaContext
::
new
(
device_id
as
usize
)
.map_err
(|
e
|
{
format!
(
"Failed to create CUDA context for device {}: {:?}"
,
device_id
,
e
)
})
?
;
map
.insert
(
device_id
,
ctx
.clone
());
Ok
(
ctx
)
}
/// Wrapper for raw pointer that can be sent between threads.
///
/// # Safety
///
/// This wrapper allows sending raw pointers across thread boundaries. The safety contract is:
/// - The pointer is allocated by the worker thread and returned to the caller
/// - The pointer is only dereferenced by the receiver (caller), never by the sender (worker)
/// - Ownership is transferred: the caller is responsible for deallocation
/// - The pointer remains valid for the lifetime expected by the caller
struct
SendPtr
(
*
mut
u8
);
// SAFETY: The pointer ownership is transferred from worker to caller.
// The worker never accesses the pointer after sending it.
unsafe
impl
Send
for
SendPtr
{}
/// Request to allocate CUDA pinned memory on a specific NUMA node.
struct
AllocRequest
{
/// Number of bytes to allocate.
size
:
usize
,
/// Target NUMA node for allocation.
node
:
NumaNode
,
/// CUDA device ID (for context binding).
gpu_id
:
u32
,
/// Channel for sending back the allocation result.
response
:
Sender
<
AllocResult
>
,
}
/// Result of allocation.
type
AllocResult
=
Result
<
SendPtr
,
String
>
;
/// A dedicated worker thread pinned to a specific NUMA node.
struct
NumaWorker
{
node
:
NumaNode
,
request_tx
:
Option
<
Sender
<
AllocRequest
>>
,
handle
:
Option
<
JoinHandle
<
()
>>
,
}
impl
NumaWorker
{
/// Spawn a new worker thread pinned to the specified NUMA node.
fn
spawn
(
node
:
NumaNode
)
->
Result
<
Self
,
String
>
{
let
(
request_tx
,
request_rx
)
=
channel
();
let
handle
=
thread
::
Builder
::
new
()
.name
(
format!
(
"numa-worker-{}"
,
node
.0
))
.spawn
(
move
||
{
Self
::
worker_loop
(
node
,
request_rx
);
})
.map_err
(|
e
|
format!
(
"Failed to spawn worker thread: {}"
,
e
))
?
;
Ok
(
Self
{
node
,
request_tx
:
Some
(
request_tx
),
handle
:
Some
(
handle
),
})
}
/// Worker thread main loop that processes allocation requests.
///
/// On startup, the worker pins itself to the target NUMA node using
/// `sched_setaffinity`. It then processes allocation requests in a loop
/// until the channel is closed.
fn
worker_loop
(
node
:
NumaNode
,
requests
:
Receiver
<
AllocRequest
>
)
{
// First thing: pin this thread to the target NUMA node
tracing
::
trace!
(
"Pinning worker thread to node {}"
,
node
.0
);
if
let
Err
(
e
)
=
super
::
pin_thread_to_numa_node
(
node
)
{
tracing
::
error!
(
"Failed to pin worker thread to node {}: {}"
,
node
.0
,
e
);
tracing
::
error!
(
"Worker will continue but allocations may be suboptimal"
);
}
else
{
tracing
::
trace!
(
"Successfully pinned worker thread to node {}"
,
node
.0
);
// `pin_thread_to_numa_node` uses `sched_setaffinity` to set the CPU affinity mask
// but doesn't immediately migrate the thread. The scheduler will migrate at
// the next opportunity (timer tick, yield, etc).
// We yield once to give the scheduler a chance to migrate before we verify.
// This is primarily for accurate logging - allocations will happen on the right CPU
// regardless since the affinity mask prevents running on wrong CPUs.
thread
::
yield_now
();
thread
::
sleep
(
Duration
::
from_millis
(
1
));
// Verify we're on the right node
let
current_node
=
super
::
get_current_cpu_numa_node
();
tracing
::
trace!
(
"Current node after pinning: {}"
,
current_node
.0
);
if
current_node
!=
node
{
tracing
::
warn!
(
"Worker thread on node {} after pinning (expected {})"
,
current_node
.0
,
node
.0
);
}
else
{
tracing
::
trace!
(
"NUMA worker thread for node {} started and pinned"
,
node
.0
);
}
}
// Process allocation requests
loop
{
tracing
::
trace!
(
"Worker waiting for request on node {}"
,
node
.0
);
match
requests
.recv
()
{
Ok
(
req
)
=>
{
tracing
::
trace!
(
"Worker received CUDA pinned allocation request on node {}"
,
node
.0
);
let
result
=
Self
::
do_cuda_pinned_allocation
(
req
.size
,
req
.node
,
req
.gpu_id
);
match
result
{
Ok
(
SendPtr
(
ptr
))
=>
{
if
let
Err
(
_
e
)
=
req
.response
.send
(
Ok
(
SendPtr
(
ptr
)))
{
// Receiver gone: free to avoid leak
tracing
::
warn!
(
"Receiver dropped before receiving allocation, freeing {} bytes at {:p}"
,
req
.size
,
ptr
);
unsafe
{
let
_
=
cudarc
::
driver
::
result
::
free_host
(
ptr
as
*
mut
std
::
ffi
::
c_void
,
);
}
}
}
Err
(
err
)
=>
{
let
_
=
req
.response
.send
(
Err
(
err
));
}
}
}
Err
(
_
)
=>
{
// Channel closed, exit worker
tracing
::
trace!
(
"NUMA worker for node {} shutting down (channel closed)"
,
node
.0
);
break
;
}
}
}
}
/// Perform CUDA pinned memory allocation.
fn
do_cuda_pinned_allocation
(
size
:
usize
,
node
:
NumaNode
,
gpu_id
:
u32
)
->
AllocResult
{
if
size
==
0
{
return
Err
(
"Cannot allocate zero bytes"
.to_string
());
}
// Verify we're on the correct NUMA node BEFORE allocation
let
node_before
=
get_current_cpu_numa_node
();
if
node_before
!=
node
{
tracing
::
warn!
(
"Worker thread moved! Expected NUMA node {}, currently on node {}"
,
node
.0
,
node_before
.0
);
}
// Get or create CUDA context for this GPU
let
ctx
=
cuda_context
(
gpu_id
)
?
;
unsafe
{
// Bind CUDA context to this worker thread before allocation
// This ensures malloc_host has a valid context to work with
ctx
.bind_to_thread
()
.map_err
(|
e
|
format!
(
"Failed to bind CUDA context to worker thread: {:?}"
,
e
))
?
;
// Verify thread is still on correct node after CUDA context binding
let
node_after_ctx
=
get_current_cpu_numa_node
();
if
node_after_ctx
!=
node
{
tracing
::
warn!
(
"Thread moved after CUDA context bind! Expected node {}, now on node {}"
,
node
.0
,
node_after_ctx
.0
);
}
// Allocate CUDA pinned memory
// This is called from the pinned worker thread, so pages will be
// allocated on the correct NUMA node via first-touch
let
ptr
=
malloc_host
(
size
,
CU_MEMHOSTALLOC_DEVICEMAP
)
.map_err
(|
e
|
format!
(
"malloc_host failed: {:?}"
,
e
))
?
;
let
ptr
=
ptr
as
*
mut
u8
;
if
ptr
.is_null
()
{
return
Err
(
"malloc_host returned null"
.to_string
());
}
// Verify thread is STILL on correct node before touching pages
let
node_before_touch
=
get_current_cpu_numa_node
();
if
node_before_touch
!=
node
{
tracing
::
error!
(
"Thread on wrong node before first-touch! Expected {}, on node {} - memory will be misplaced!"
,
node
.0
,
node_before_touch
.0
);
}
// Touch one byte per page to trigger first-touch policy efficiently
// This is much faster than zeroing the entire region for large allocations
let
page_size
=
match
libc
::
sysconf
(
libc
::
_
SC_PAGESIZE
)
{
n
if
n
>
0
=>
n
as
usize
,
_
=>
4096
,
};
let
mut
offset
=
0u
size
;
while
offset
<
size
{
std
::
ptr
::
write_volatile
(
ptr
.add
(
offset
),
0
);
offset
=
offset
.saturating_add
(
page_size
);
}
// Ensure the last page is touched
if
size
>
0
&&
!
size
.is_multiple_of
(
page_size
)
{
std
::
ptr
::
write_volatile
(
ptr
.add
(
size
-
1
),
0
);
}
// Verify final node after touching
let
node_after_touch
=
get_current_cpu_numa_node
();
tracing
::
trace!
(
"Worker allocated {} bytes (CUDA pinned) on GPU {} (target NUMA node {}) at {:p} - thread nodes: before={} after_ctx={} before_touch={} after_touch={}"
,
size
,
gpu_id
,
node
.0
,
ptr
,
node_before
.0
,
node_after_ctx
.0
,
node_before_touch
.0
,
node_after_touch
.0
);
Ok
(
SendPtr
(
ptr
))
}
}
/// Request an allocation from this worker.
fn
allocate
(
&
self
,
size
:
usize
,
gpu_id
:
u32
)
->
AllocResult
{
let
(
response_tx
,
response_rx
)
=
channel
();
let
request
=
AllocRequest
{
size
,
node
:
self
.node
,
gpu_id
,
response
:
response_tx
,
};
self
.request_tx
.as_ref
()
.ok_or_else
(||
"Worker has been shut down"
.to_string
())
?
.send
(
request
)
.map_err
(|
_
|
"Worker thread has died"
.to_string
())
?
;
// Wait for response with dynamic timeout based on allocation size
// Large allocations take time: we account for ~1 second per GB to touch pages
// Add 10 second base + 1 second per GB
let
timeout_secs
=
10u64
+
(
size
as
u64
/
(
1024
*
1024
*
1024
));
let
timeout
=
Duration
::
from_secs
(
timeout_secs
.clamp
(
10
,
300
));
// Clamp to 10-300 seconds
tracing
::
trace!
(
"Worker pool waiting for allocation of {} MB with timeout of {} seconds"
,
size
/
(
1024
*
1024
),
timeout
.as_secs
()
);
response_rx
.recv_timeout
(
timeout
)
.map_err
(|
e
|
format!
(
"Worker timeout after {} seconds: {}"
,
timeout
.as_secs
(),
e
))
?
}
}
impl
Drop
for
NumaWorker
{
fn
drop
(
&
mut
self
)
{
tracing
::
trace!
(
"Dropping NUMA worker for node {}"
,
self
.node
.0
);
// Drop request_tx FIRST to close the channel
// This causes recv() in worker thread to return Err and exit
self
.request_tx
.take
();
tracing
::
trace!
(
"Channel closed for worker node {}"
,
self
.node
.0
);
// Now the worker thread will exit its loop
if
let
Some
(
handle
)
=
self
.handle
.take
()
{
tracing
::
trace!
(
"Waiting for worker thread {} to join"
,
self
.node
.0
);
let
_
=
handle
.join
();
tracing
::
trace!
(
"Worker thread {} joined"
,
self
.node
.0
);
}
}
}
/// Pool of NUMA workers, one per node.
///
/// This pool manages dedicated worker threads that are pinned to specific NUMA nodes.
/// When you request an allocation for a GPU, the pool automatically determines the
/// GPU's NUMA node and routes the request to the appropriate worker.
pub
struct
NumaWorkerPool
{
workers
:
Mutex
<
std
::
collections
::
HashMap
<
u32
,
Arc
<
NumaWorker
>>>
,
}
impl
NumaWorkerPool
{
fn
new
()
->
Self
{
Self
{
workers
:
Mutex
::
new
(
std
::
collections
::
HashMap
::
new
()),
}
}
/// Get the global worker pool.
///
/// The pool is created lazily on first access and lives for the entire process lifetime.
pub
fn
global
()
->
&
'static
Self
{
static
POOL
:
OnceLock
<
NumaWorkerPool
>
=
OnceLock
::
new
();
POOL
.get_or_init
(
NumaWorkerPool
::
new
)
}
/// Get or create a worker for a NUMA node.
fn
get_or_spawn_worker
(
&
self
,
node
:
NumaNode
)
->
Result
<
Arc
<
NumaWorker
>
,
String
>
{
let
mut
workers
=
self
.workers
.lock
()
.unwrap
();
if
let
Some
(
worker
)
=
workers
.get
(
&
node
.0
)
{
return
Ok
(
worker
.clone
());
}
// Spawn new worker
let
worker
=
NumaWorker
::
spawn
(
node
)
?
;
let
worker
=
Arc
::
new
(
worker
);
workers
.insert
(
node
.0
,
worker
.clone
());
tracing
::
trace!
(
"Spawned NUMA worker for node {}"
,
node
.0
);
Ok
(
worker
)
}
/// Allocate CUDA pinned memory for a specific GPU (auto-detects NUMA node).
///
/// This method:
/// 1. Determines the GPU's NUMA node via nvidia-smi
/// 2. Routes the allocation to a worker pinned to that node
/// 3. The worker allocates and touches pages to ensure first-touch placement
///
/// # Arguments
/// * `size` - Number of bytes to allocate
/// * `gpu_id` - CUDA device ID
///
/// # Returns
/// Raw pointer to the allocated memory. Caller is responsible for freeing via
/// `cudarc::driver::result::free_host`.
pub
fn
allocate_pinned_for_gpu
(
&
self
,
size
:
usize
,
gpu_id
:
u32
)
->
Result
<*
mut
u8
,
String
>
{
let
node
=
get_device_numa_node
(
gpu_id
);
tracing
::
debug!
(
"Allocating {} bytes pinned memory for GPU {} (NUMA node {})"
,
size
,
gpu_id
,
node
.0
);
let
worker
=
self
.get_or_spawn_worker
(
node
)
?
;
worker
.allocate
(
size
,
gpu_id
)
.map
(|
send_ptr
|
send_ptr
.0
)
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
numa
::{
get_current_cpu_numa_node
,
get_device_numa_node
};
/// Check if CUDA is available for testing.
fn
is_cuda_available
()
->
bool
{
// Check if nvidia-smi is available
if
std
::
process
::
Command
::
new
(
"nvidia-smi"
)
.arg
(
"--query-gpu=count"
)
.arg
(
"--format=csv,noheader"
)
.output
()
.is_err
()
{
return
false
;
}
// Try to initialize CUDA context for device 0
cuda_context
(
0
)
.is_ok
()
}
#[test]
fn
test_worker_spawn
()
{
let
node
=
NumaNode
(
0
);
let
worker
=
NumaWorker
::
spawn
(
node
);
assert
!
(
worker
.is_ok
());
}
#[test]
fn
test_worker_allocate_pinned
()
{
if
!
is_cuda_available
()
{
eprintln!
(
"Skipping test_worker_allocate_pinned: CUDA not available"
);
return
;
}
let
node
=
NumaNode
(
0
);
let
worker
=
NumaWorker
::
spawn
(
node
)
.unwrap
();
let
send_ptr
=
worker
.allocate
(
4096
,
0
)
.unwrap
();
let
ptr
=
send_ptr
.0
;
assert
!
(
!
ptr
.is_null
());
unsafe
{
cudarc
::
driver
::
result
::
free_host
(
ptr
as
*
mut
std
::
ffi
::
c_void
)
.unwrap
();
}
}
#[test]
fn
test_worker_pool
()
{
if
!
is_cuda_available
()
{
eprintln!
(
"Skipping test_worker_pool: CUDA not available"
);
return
;
}
let
pool
=
NumaWorkerPool
::
new
();
unsafe
{
let
ptr
=
pool
.allocate_pinned_for_gpu
(
8192
,
0
)
.unwrap
();
assert
!
(
!
ptr
.is_null
());
cudarc
::
driver
::
result
::
free_host
(
ptr
as
*
mut
std
::
ffi
::
c_void
)
.unwrap
();
}
}
#[test]
fn
test_worker_pool_singleton
()
{
// Verify that global() returns the same instance
let
pool1
=
NumaWorkerPool
::
global
();
let
pool2
=
NumaWorkerPool
::
global
();
// They should be the same static reference
assert
!
(
std
::
ptr
::
eq
(
pool1
,
pool2
));
}
#[test]
fn
test_worker_reuse
()
{
if
!
is_cuda_available
()
{
eprintln!
(
"Skipping test_worker_reuse: CUDA not available"
);
return
;
}
// Test that subsequent allocations for the same GPU reuse the same worker
let
pool
=
NumaWorkerPool
::
new
();
unsafe
{
// First allocation spawns worker for GPU 0
let
ptr1
=
pool
.allocate_pinned_for_gpu
(
1024
,
0
)
.unwrap
();
// Second allocation should reuse worker for GPU 0
let
ptr2
=
pool
.allocate_pinned_for_gpu
(
1024
,
0
)
.unwrap
();
assert
!
(
!
ptr1
.is_null
());
assert
!
(
!
ptr2
.is_null
());
assert_ne!
(
ptr1
,
ptr2
);
cudarc
::
driver
::
result
::
free_host
(
ptr1
as
*
mut
std
::
ffi
::
c_void
)
.unwrap
();
cudarc
::
driver
::
result
::
free_host
(
ptr2
as
*
mut
std
::
ffi
::
c_void
)
.unwrap
();
}
}
#[test]
fn
test_zero_size_allocation
()
{
// Test that zero-size allocations are rejected
let
pool
=
NumaWorkerPool
::
new
();
let
result
=
pool
.allocate_pinned_for_gpu
(
0
,
0
);
assert
!
(
result
.is_err
());
assert
!
(
result
.unwrap_err
()
.contains
(
"zero"
));
}
#[test]
fn
test_get_current_cpu_numa_node
()
{
// Test that we can detect current CPU's NUMA node
let
node
=
get_current_cpu_numa_node
();
// On a real NUMA system, should return a valid node
// On fake NUMA or single-node, might return 0 or UNKNOWN
if
!
node
.is_unknown
()
{
println!
(
"Current CPU on NUMA node: {}"
,
node
.0
);
}
else
{
println!
(
"NUMA node detection unavailable (single-node or fake NUMA)"
);
}
}
#[test]
fn
test_get_device_numa_node
()
{
// Test GPU NUMA node detection
// This will only work if nvidia-smi is available
let
node
=
get_device_numa_node
(
0
);
if
!
node
.is_unknown
()
{
println!
(
"GPU 0 on NUMA node: {}"
,
node
.0
);
// Node should be 0 or 1 on typical dual-socket systems
assert
!
(
node
.0
<=
1
||
node
.0
==
u32
::
MAX
);
}
else
{
println!
(
"GPU NUMA detection unavailable (no nvidia-smi or no GPU)"
);
}
}
#[test]
fn
test_numa_node_display
()
{
// Test Display implementation for NumaNode
let
node
=
NumaNode
(
0
);
assert_eq!
(
format!
(
"{}"
,
node
),
"NumaNode(0)"
);
let
unknown
=
NumaNode
::
UNKNOWN
;
assert_eq!
(
format!
(
"{}"
,
unknown
),
"UNKNOWN"
);
}
#[test]
fn
test_numa_node_is_unknown
()
{
let
valid
=
NumaNode
(
0
);
assert
!
(
!
valid
.is_unknown
());
let
unknown
=
NumaNode
::
UNKNOWN
;
assert
!
(
unknown
.is_unknown
());
}
#[test]
fn
test_pinned_allocation_api
()
{
// Verify the public API works for pinned allocation
let
pool
=
NumaWorkerPool
::
new
();
unsafe
{
// Test that we can allocate pinned memory for a GPU
if
let
Ok
(
ptr
)
=
pool
.allocate_pinned_for_gpu
(
1024
,
0
)
{
assert
!
(
!
ptr
.is_null
());
cudarc
::
driver
::
result
::
free_host
(
ptr
as
*
mut
std
::
ffi
::
c_void
)
.unwrap
();
}
}
}
#[test]
fn
test_worker_channel_communication
()
{
// Test that worker receives and processes requests
let
node
=
NumaNode
(
0
);
let
worker
=
NumaWorker
::
spawn
(
node
)
.unwrap
();
// Send allocation request
let
result
=
worker
.allocate
(
1024
,
0
);
// Should get a response (either success or timeout)
assert
!
(
result
.is_ok
()
||
result
.is_err
());
if
let
Ok
(
send_ptr
)
=
result
{
unsafe
{
let
ptr
=
send_ptr
.0
;
assert
!
(
!
ptr
.is_null
());
cudarc
::
driver
::
result
::
free_host
(
ptr
as
*
mut
std
::
ffi
::
c_void
)
.unwrap
();
}
}
}
}
lib/memory/src/offset.rs
View file @
976bb70a
...
...
@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use
super
::{
Any
,
Buffer
,
MemoryDescript
ion
,
Result
,
StorageError
,
StorageKind
,
nixl
::
NixlDescriptor
,
Any
,
Buffer
,
MemoryDescript
or
,
Result
,
StorageError
,
StorageKind
,
nixl
::
NixlDescriptor
,
};
/// An [`OffsetBuffer`] is a new [`Buffer`]-like object that represents a sub-region (still contiguous)
...
...
@@ -40,6 +40,26 @@ impl OffsetBuffer {
Ok
(
Self
{
base
,
offset
,
size
})
}
/// Creates an offset buffer from an absolute address within the base region.
pub
fn
from_inner_address
(
base
:
Buffer
,
address
:
usize
,
size
:
usize
)
->
Result
<
Self
>
{
// Use checked arithmetic to prevent overflow
let
end
=
address
.checked_add
(
size
)
.ok_or_else
(||
StorageError
::
Unsupported
(
"address + size overflow"
.into
()))
?
;
let
base_end
=
base
.addr
()
.checked_add
(
base
.size
())
.ok_or_else
(||
StorageError
::
Unsupported
(
"base address + size overflow"
.into
()))
?
;
// Verify address is within the base region
if
address
<
base
.addr
()
||
end
>
base_end
{
return
Err
(
StorageError
::
Unsupported
(
"address out of bounds"
.into
()));
}
let
offset
=
address
-
base
.addr
();
Self
::
new
(
base
,
offset
,
size
)
}
/// Get the offset relative to the base mapping.
pub
fn
offset
(
&
self
)
->
usize
{
self
.offset
...
...
@@ -51,7 +71,7 @@ impl OffsetBuffer {
}
}
impl
MemoryDescript
ion
for
OffsetBuffer
{
impl
MemoryDescript
or
for
OffsetBuffer
{
fn
addr
(
&
self
)
->
usize
{
self
.base
.addr
()
+
self
.offset
}
...
...
@@ -75,3 +95,189 @@ impl MemoryDescription for OffsetBuffer {
Some
(
descriptor
)
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
SystemStorage
;
fn
create_test_buffer
(
size
:
usize
)
->
Buffer
{
Buffer
::
new
(
SystemStorage
::
new
(
size
)
.expect
(
"allocation failed"
))
}
#[test]
fn
test_offset_buffer_new_valid
()
{
let
base
=
create_test_buffer
(
1024
);
let
offset_buf
=
OffsetBuffer
::
new
(
base
,
100
,
200
)
.expect
(
"should succeed"
);
assert_eq!
(
offset_buf
.offset
(),
100
);
assert_eq!
(
offset_buf
.size
(),
200
);
}
#[test]
fn
test_offset_buffer_new_zero_offset
()
{
let
base
=
create_test_buffer
(
1024
);
let
offset_buf
=
OffsetBuffer
::
new
(
base
.clone
(),
0
,
1024
)
.expect
(
"should succeed"
);
assert_eq!
(
offset_buf
.offset
(),
0
);
assert_eq!
(
offset_buf
.size
(),
1024
);
assert_eq!
(
offset_buf
.addr
(),
base
.addr
());
}
#[test]
fn
test_offset_buffer_new_at_end
()
{
let
base
=
create_test_buffer
(
1024
);
// Offset at exact end with zero size should succeed
let
offset_buf
=
OffsetBuffer
::
new
(
base
,
1024
,
0
)
.expect
(
"should succeed"
);
assert_eq!
(
offset_buf
.offset
(),
1024
);
assert_eq!
(
offset_buf
.size
(),
0
);
}
#[test]
fn
test_offset_buffer_new_invalid_offset
()
{
let
base
=
create_test_buffer
(
1024
);
// Offset beyond bounds
let
result
=
OffsetBuffer
::
new
(
base
,
1025
,
0
);
assert
!
(
result
.is_err
());
}
#[test]
fn
test_offset_buffer_new_invalid_size
()
{
let
base
=
create_test_buffer
(
1024
);
// Size exceeds remaining space
let
result
=
OffsetBuffer
::
new
(
base
,
100
,
1000
);
assert
!
(
result
.is_err
());
}
#[test]
fn
test_offset_buffer_new_size_overflow
()
{
let
base
=
create_test_buffer
(
1024
);
// offset + size would overflow usize
let
result
=
OffsetBuffer
::
new
(
base
,
usize
::
MAX
,
1
);
assert
!
(
result
.is_err
());
}
#[test]
fn
test_offset_buffer_from_inner_address_valid
()
{
let
base
=
create_test_buffer
(
1024
);
let
base_addr
=
base
.addr
();
let
offset_buf
=
OffsetBuffer
::
from_inner_address
(
base
,
base_addr
+
100
,
200
)
.expect
(
"should succeed"
);
assert_eq!
(
offset_buf
.offset
(),
100
);
assert_eq!
(
offset_buf
.size
(),
200
);
}
#[test]
fn
test_offset_buffer_from_inner_address_at_start
()
{
let
base
=
create_test_buffer
(
1024
);
let
base_addr
=
base
.addr
();
let
offset_buf
=
OffsetBuffer
::
from_inner_address
(
base
.clone
(),
base_addr
,
1024
)
.expect
(
"should succeed"
);
assert_eq!
(
offset_buf
.offset
(),
0
);
assert_eq!
(
offset_buf
.addr
(),
base
.addr
());
}
#[test]
fn
test_offset_buffer_from_inner_address_overflow
()
{
let
base
=
create_test_buffer
(
1024
);
// address + size would overflow
let
result
=
OffsetBuffer
::
from_inner_address
(
base
,
usize
::
MAX
,
1
);
assert
!
(
result
.is_err
());
}
#[test]
fn
test_offset_buffer_from_inner_address_out_of_bounds_before
()
{
let
base
=
create_test_buffer
(
1024
);
let
base_addr
=
base
.addr
();
// Address before base region
let
result
=
OffsetBuffer
::
from_inner_address
(
base
,
base_addr
.saturating_sub
(
1
),
100
);
assert
!
(
result
.is_err
());
}
#[test]
fn
test_offset_buffer_from_inner_address_out_of_bounds_after
()
{
let
base
=
create_test_buffer
(
1024
);
let
base_addr
=
base
.addr
();
// End address beyond base region
let
result
=
OffsetBuffer
::
from_inner_address
(
base
,
base_addr
+
900
,
200
);
assert
!
(
result
.is_err
());
}
#[test]
fn
test_offset_buffer_accessors
()
{
let
base
=
create_test_buffer
(
1024
);
let
base_addr
=
base
.addr
();
let
offset_buf
=
OffsetBuffer
::
new
(
base
,
256
,
512
)
.expect
(
"should succeed"
);
assert_eq!
(
offset_buf
.offset
(),
256
);
assert_eq!
(
offset_buf
.base
()
.addr
(),
base_addr
);
assert_eq!
(
offset_buf
.base
()
.size
(),
1024
);
}
#[test]
fn
test_offset_buffer_memory_descriptor_addr
()
{
let
base
=
create_test_buffer
(
1024
);
let
base_addr
=
base
.addr
();
let
offset_buf
=
OffsetBuffer
::
new
(
base
,
100
,
200
)
.expect
(
"should succeed"
);
// addr() should return base_addr + offset
assert_eq!
(
offset_buf
.addr
(),
base_addr
+
100
);
}
#[test]
fn
test_offset_buffer_memory_descriptor_size
()
{
let
base
=
create_test_buffer
(
1024
);
let
offset_buf
=
OffsetBuffer
::
new
(
base
,
100
,
200
)
.expect
(
"should succeed"
);
assert_eq!
(
offset_buf
.size
(),
200
);
}
#[test]
fn
test_offset_buffer_memory_descriptor_storage_kind
()
{
let
base
=
create_test_buffer
(
1024
);
let
base_kind
=
base
.storage_kind
();
let
offset_buf
=
OffsetBuffer
::
new
(
base
,
100
,
200
)
.expect
(
"should succeed"
);
// storage_kind should match the base
assert_eq!
(
offset_buf
.storage_kind
(),
base_kind
);
}
#[test]
fn
test_offset_buffer_as_any
()
{
let
base
=
create_test_buffer
(
1024
);
let
offset_buf
=
OffsetBuffer
::
new
(
base
,
100
,
200
)
.expect
(
"should succeed"
);
// Should be able to downcast to OffsetBuffer
let
any_ref
=
offset_buf
.as_any
();
assert
!
(
any_ref
.downcast_ref
::
<
OffsetBuffer
>
()
.is_some
());
}
#[test]
fn
test_offset_buffer_clone
()
{
let
base
=
create_test_buffer
(
1024
);
let
offset_buf
=
OffsetBuffer
::
new
(
base
,
100
,
200
)
.expect
(
"should succeed"
);
let
cloned
=
offset_buf
.clone
();
assert_eq!
(
offset_buf
.addr
(),
cloned
.addr
());
assert_eq!
(
offset_buf
.size
(),
cloned
.size
());
assert_eq!
(
offset_buf
.offset
(),
cloned
.offset
());
}
#[test]
fn
test_offset_buffer_debug
()
{
let
base
=
create_test_buffer
(
1024
);
let
offset_buf
=
OffsetBuffer
::
new
(
base
,
100
,
200
)
.expect
(
"should succeed"
);
let
debug_str
=
format!
(
"{:?}"
,
offset_buf
);
assert
!
(
debug_str
.contains
(
"OffsetBuffer"
));
assert
!
(
debug_str
.contains
(
"offset"
));
assert
!
(
debug_str
.contains
(
"size"
));
}
#[test]
fn
test_offset_buffer_nixl_descriptor_none
()
{
// SystemStorage doesn't have a NIXL descriptor
let
base
=
create_test_buffer
(
1024
);
let
offset_buf
=
OffsetBuffer
::
new
(
base
,
100
,
200
)
.expect
(
"should succeed"
);
// Should return None since base has no NIXL descriptor
assert
!
(
offset_buf
.nixl_descriptor
()
.is_none
());
}
}
lib/memory/src/pinned.rs
View file @
976bb70a
...
...
@@ -3,7 +3,7 @@
//! CUDA pinned host memory storage.
use
super
::{
MemoryDescript
ion
,
Result
,
StorageError
,
StorageKind
,
actions
,
nixl
::
NixlDescriptor
};
use
super
::{
MemoryDescript
or
,
Result
,
StorageError
,
StorageKind
,
actions
,
nixl
::
NixlDescriptor
};
use
cudarc
::
driver
::
CudaContext
;
use
cudarc
::
driver
::
sys
;
use
std
::
any
::
Any
;
...
...
@@ -27,8 +27,11 @@ fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
/// CUDA pinned host memory allocated via cudaHostAlloc.
#[derive(Debug)]
pub
struct
PinnedStorage
{
/// Host pointer to the pinned memory.
ptr
:
usize
,
/// Size of the allocation in bytes.
len
:
usize
,
/// CUDA context used for allocation and deallocation.
ctx
:
Arc
<
CudaContext
>
,
}
...
...
@@ -38,21 +41,63 @@ unsafe impl Sync for PinnedStorage {}
impl
PinnedStorage
{
/// Allocate new pinned memory of the given size.
///
/// This is a convenience method that calls `new_for_device(len, None)`.
///
/// # Arguments
/// * `len` - Size in bytes to allocate
/// * `device_id` - CUDA device to associate with the allocation
pub
fn
new
(
len
:
usize
)
->
Result
<
Self
>
{
Self
::
new_for_device
(
len
,
None
)
}
/// Allocate pinned memory, optionally NUMA-aware for a specific GPU.
///
/// When `device_id` is `Some`, the allocation is performed on a worker thread
/// pinned to the GPU's NUMA node, ensuring optimal memory placement via
/// first-touch policy, However, NUMA is only used if enabled via the
/// `DYN_KVBM_ENABLE_NUMA=1` environment variable.
///
/// When `device_id` is `None`, a direct allocation is performed on device 0.
///
/// # Arguments
/// * `len` - Size in bytes to allocate
/// * `device_id` - If Some, use NUMA-aware allocation on the GPU's NUMA node
///
/// # Errors
/// Returns an error if:
/// - `len` is 0
/// - CUDA context creation fails
/// - Memory allocation fails
pub
fn
new_for_device
(
len
:
usize
,
device_id
:
Option
<
u32
>
)
->
Result
<
Self
>
{
use
super
::
numa
;
if
len
==
0
{
return
Err
(
StorageError
::
AllocationFailed
(
"zero-sized allocations are not supported"
.into
(),
));
}
let
ctx
=
cuda_context
(
0
)
?
;
let
ptr
=
unsafe
{
let
gpu_id
=
device_id
.unwrap_or
(
0
);
let
ctx
=
cuda_context
(
gpu_id
)
?
;
let
ptr
=
match
device_id
{
Some
(
gpu_id
)
if
numa
::
is_numa_enabled
()
=>
{
// NUMA-aware allocation via worker pool
tracing
::
debug!
(
"Using NUMA-aware allocation for {} bytes on GPU {}"
,
len
,
gpu_id
);
numa
::
worker_pool
::
NumaWorkerPool
::
global
()
.allocate_pinned_for_gpu
(
len
,
gpu_id
)
.map_err
(
StorageError
::
AllocationFailed
)
?
as
usize
}
_
=>
{
// Direct allocation (no NUMA or device_id not specified)
unsafe
{
ctx
.bind_to_thread
()
.map_err
(
StorageError
::
Cuda
)
?
;
let
ptr
=
cudarc
::
driver
::
result
::
malloc_host
(
len
,
sys
::
CU_MEMHOSTALLOC_WRITECOMBINED
)
let
ptr
=
cudarc
::
driver
::
result
::
malloc_host
(
len
,
sys
::
CU_MEMHOSTALLOC_DEVICEMAP
)
.map_err
(
StorageError
::
Cuda
)
?
;
let
ptr
=
ptr
as
*
mut
u8
;
...
...
@@ -61,6 +106,8 @@ impl PinnedStorage {
assert
!
(
len
<
isize
::
MAX
as
usize
);
ptr
as
usize
}
}
};
Ok
(
Self
{
ptr
,
len
,
ctx
})
...
...
@@ -97,7 +144,7 @@ impl Drop for PinnedStorage {
}
}
impl
MemoryDescript
ion
for
PinnedStorage
{
impl
MemoryDescript
or
for
PinnedStorage
{
fn
addr
(
&
self
)
->
usize
{
unsafe
{
self
.as_ptr
()
as
usize
}
}
...
...
lib/memory/src/pool/cuda.rs
View file @
976bb70a
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c)
2024-
2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! CUDA memory pool for efficient device memory allocation in hot paths.
...
...
@@ -6,6 +6,12 @@
//! This module provides a safe wrapper around CUDA's memory pool APIs, enabling
//! fast async allocations that avoid the overhead of cudaMalloc/cudaFree per call.
//! Memory is returned to the pool on free and reused for subsequent allocations.
//!
//! # Thread Safety
//!
//! [`CudaMemPool`] uses internal locking to serialize host-side calls to the CUDA
//! driver. This is required because `cuMemAllocFromPoolAsync` is not host-thread
//! reentrant. The GPU-side operations remain stream-ordered and asynchronous.
use
anyhow
::{
Result
,
anyhow
};
use
cudarc
::
driver
::
sys
::{
...
...
lib/memory/src/pool/mod.rs
View file @
976bb70a
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c)
2025-
2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Memory pool for efficient device memory allocation in hot paths.
...
...
lib/memory/src/prelude.rs
View file @
976bb70a
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub
use
super
::
MemoryDescript
ion
;
pub
use
super
::
nixl
::{
NixlCompatible
,
RegisteredView
};
pub
use
super
::
MemoryDescript
or
;
pub
use
super
::
nixl
::{
NixlCompatible
,
NixlMemory
,
NixlRegisterExt
,
RegisteredView
};
pub
use
super
::
tensor
::{
TensorDescriptor
,
TensorDescriptorExt
};
lib/memory/src/system.rs
View file @
976bb70a
...
...
@@ -3,7 +3,7 @@
//! System memory storage backed by malloc.
use
super
::{
MemoryDescript
ion
,
Result
,
StorageError
,
StorageKind
,
actions
,
nixl
::
NixlDescriptor
};
use
super
::{
MemoryDescript
or
,
Result
,
StorageError
,
StorageKind
,
actions
,
nixl
::
NixlDescriptor
};
use
std
::
any
::
Any
;
use
std
::
ptr
::
NonNull
;
...
...
@@ -82,7 +82,7 @@ impl Drop for SystemStorage {
}
}
impl
MemoryDescript
ion
for
SystemStorage
{
impl
MemoryDescript
or
for
SystemStorage
{
fn
addr
(
&
self
)
->
usize
{
self
.ptr
.as_ptr
()
as
usize
}
...
...
@@ -139,3 +139,130 @@ impl actions::Slice for SystemStorage {
Ok
(
unsafe
{
std
::
slice
::
from_raw_parts
(
self
.ptr
.as_ptr
(),
self
.len
)
})
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
actions
::{
Memset
,
Slice
};
#[test]
fn
test_system_storage_new
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.expect
(
"allocation should succeed"
);
assert_eq!
(
storage
.size
(),
1024
);
assert
!
(
storage
.addr
()
!=
0
);
}
#[test]
fn
test_system_storage_zero_size_fails
()
{
let
result
=
SystemStorage
::
new
(
0
);
assert
!
(
result
.is_err
());
}
#[test]
fn
test_system_storage_storage_kind
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
assert_eq!
(
storage
.storage_kind
(),
StorageKind
::
System
);
}
#[test]
fn
test_system_storage_as_any
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
let
any
=
storage
.as_any
();
assert
!
(
any
.downcast_ref
::
<
SystemStorage
>
()
.is_some
());
}
#[test]
fn
test_system_storage_nixl_descriptor
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
// Unregistered storage has no NIXL descriptor
assert
!
(
storage
.nixl_descriptor
()
.is_none
());
}
#[test]
fn
test_system_storage_as_ptr
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
unsafe
{
let
ptr
=
storage
.as_ptr
();
assert
!
(
!
ptr
.is_null
());
assert_eq!
(
ptr
as
usize
,
storage
.addr
());
}
}
#[test]
fn
test_system_storage_as_mut_ptr
()
{
let
mut
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
unsafe
{
let
ptr
=
storage
.as_mut_ptr
();
assert
!
(
!
ptr
.is_null
());
assert_eq!
(
ptr
as
usize
,
storage
.addr
());
// Write and read back to verify the pointer works
*
ptr
=
0xAB
;
assert_eq!
(
*
ptr
,
0xAB
);
}
}
#[test]
fn
test_system_storage_zero_initialized
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
unsafe
{
let
slice
=
storage
.as_slice
()
.unwrap
();
// Memory should be zero-initialized
assert
!
(
slice
.iter
()
.all
(|
&
b
|
b
==
0
));
}
}
#[test]
fn
test_system_storage_memset_and_read
()
{
let
mut
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
storage
.memset
(
0xCD
,
0
,
1024
)
.unwrap
();
unsafe
{
let
slice
=
storage
.as_slice
()
.unwrap
();
assert
!
(
slice
.iter
()
.all
(|
&
b
|
b
==
0xCD
));
}
}
#[test]
fn
test_system_storage_multiple_allocations_independent
()
{
let
storage1
=
SystemStorage
::
new
(
512
)
.unwrap
();
let
storage2
=
SystemStorage
::
new
(
512
)
.unwrap
();
// Different allocations should have different addresses
assert_ne!
(
storage1
.addr
(),
storage2
.addr
());
}
#[test]
fn
test_system_storage_alignment
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
// posix_memalign allocates with 4096-byte alignment
assert
!
(
storage
.addr
()
.is_multiple_of
(
4096
));
}
#[test]
fn
test_system_storage_nixl_compatible
()
{
use
crate
::
nixl
::
NixlCompatible
;
let
storage
=
SystemStorage
::
new
(
2048
)
.unwrap
();
let
(
ptr
,
size
,
mem_type
,
device_id
)
=
storage
.nixl_params
();
assert_eq!
(
ptr
as
usize
,
storage
.addr
());
assert_eq!
(
size
,
2048
);
assert_eq!
(
mem_type
,
nixl_sys
::
MemType
::
Dram
);
assert_eq!
(
device_id
,
0
);
}
#[test]
fn
test_system_storage_large_allocation
()
{
// Allocate 1MB to test larger sizes
let
storage
=
SystemStorage
::
new
(
1024
*
1024
)
.unwrap
();
assert_eq!
(
storage
.size
(),
1024
*
1024
);
}
#[test]
fn
test_system_storage_debug
()
{
let
storage
=
SystemStorage
::
new
(
1024
)
.unwrap
();
let
debug_str
=
format!
(
"{:?}"
,
storage
);
assert
!
(
debug_str
.contains
(
"SystemStorage"
));
}
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment