Unverified Commit 3998fdcb authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: KVBM V2 Initial Migration (#3861)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
parent e64d2f09
......@@ -604,6 +604,26 @@ dependencies = [
"serde",
]
[[package]]
name = "bincode"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740"
dependencies = [
"bincode_derive",
"serde",
"unty",
]
[[package]]
name = "bincode_derive"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09"
dependencies = [
"virtue",
]
[[package]]
name = "bindgen"
version = "0.71.1"
......@@ -2136,7 +2156,7 @@ dependencies = [
"async_zmq",
"axum 0.8.4",
"axum-server",
"bincode",
"bincode 2.0.1",
"bitflags 2.9.4",
"blake3",
"bs62",
......@@ -2273,7 +2293,7 @@ dependencies = [
"async-trait",
"async_zmq",
"axum 0.8.4",
"bincode",
"bincode 1.3.3",
"blake3",
"bytes",
"chrono",
......@@ -5575,9 +5595,9 @@ checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "ordered-float"
version = "5.0.0"
version = "5.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2c1f9f56e534ac6a9b8a4600bdf0f530fb393b5f393e7b4d03489c3cf0c3f01"
checksum = "7f4779c6901a562440c3786d08192c6fbda7c1c2060edd10006b05ee35d10f2d"
dependencies = [
"num-traits",
]
......@@ -7899,9 +7919,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
[[package]]
name = "symphonia"
version = "0.5.4"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "815c942ae7ee74737bb00f965fa5b5a2ac2ce7b6c01c0cc169bbeaf7abd5f5a9"
checksum = "5773a4c030a19d9bfaa090f49746ff35c75dfddfa700df7a5939d5e076a57039"
dependencies = [
"lazy_static",
"symphonia-bundle-flac",
......@@ -7917,9 +7937,9 @@ dependencies = [
[[package]]
name = "symphonia-bundle-flac"
version = "0.5.4"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72e34f34298a7308d4397a6c7fbf5b84c5d491231ce3dd379707ba673ab3bd97"
checksum = "c91565e180aea25d9b80a910c546802526ffd0072d0b8974e3ebe59b686c9976"
dependencies = [
"log",
"symphonia-core",
......@@ -7929,9 +7949,9 @@ dependencies = [
[[package]]
name = "symphonia-bundle-mp3"
version = "0.5.4"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c01c2aae70f0f1fb096b6f0ff112a930b1fb3626178fba3ae68b09dce71706d4"
checksum = "4872dd6bb56bf5eac799e3e957aa1981086c3e613b27e0ac23b176054f7c57ed"
dependencies = [
"lazy_static",
"log",
......@@ -7941,9 +7961,9 @@ dependencies = [
[[package]]
name = "symphonia-codec-pcm"
version = "0.5.4"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f395a67057c2ebc5e84d7bb1be71cce1a7ba99f64e0f0f0e303a03f79116f89b"
checksum = "4e89d716c01541ad3ebe7c91ce4c8d38a7cf266a3f7b2f090b108fb0cb031d95"
dependencies = [
"log",
"symphonia-core",
......@@ -7951,9 +7971,9 @@ dependencies = [
[[package]]
name = "symphonia-codec-vorbis"
version = "0.5.4"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a98765fb46a0a6732b007f7e2870c2129b6f78d87db7987e6533c8f164a9f30"
checksum = "f025837c309cd69ffef572750b4a2257b59552c5399a5e49707cc5b1b85d1c73"
dependencies = [
"log",
"symphonia-core",
......@@ -7962,9 +7982,9 @@ dependencies = [
[[package]]
name = "symphonia-core"
version = "0.5.4"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "798306779e3dc7d5231bd5691f5a813496dc79d3f56bf82e25789f2094e022c3"
checksum = "ea00cc4f79b7f6bb7ff87eddc065a1066f3a43fe1875979056672c9ef948c2af"
dependencies = [
"arrayvec",
"bitflags 1.3.2",
......@@ -7975,9 +7995,9 @@ dependencies = [
[[package]]
name = "symphonia-format-isomp4"
version = "0.5.4"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "abfdf178d697e50ce1e5d9b982ba1b94c47218e03ec35022d9f0e071a16dc844"
checksum = "243739585d11f81daf8dac8d9f3d18cc7898f6c09a259675fc364b382c30e0a5"
dependencies = [
"encoding_rs",
"log",
......@@ -7988,9 +8008,9 @@ dependencies = [
[[package]]
name = "symphonia-format-ogg"
version = "0.5.4"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ada3505789516bcf00fc1157c67729eded428b455c27ca370e41f4d785bfa931"
checksum = "2b4955c67c1ed3aa8ae8428d04ca8397fbef6a19b2b051e73b5da8b1435639cb"
dependencies = [
"log",
"symphonia-core",
......@@ -8000,9 +8020,9 @@ dependencies = [
[[package]]
name = "symphonia-format-riff"
version = "0.5.4"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f7be232f962f937f4b7115cbe62c330929345434c834359425e043bfd15f50"
checksum = "c2d7c3df0e7d94efb68401d81906eae73c02b40d5ec1a141962c592d0f11a96f"
dependencies = [
"extended",
"log",
......@@ -8012,9 +8032,9 @@ dependencies = [
[[package]]
name = "symphonia-metadata"
version = "0.5.4"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc622b9841a10089c5b18e99eb904f4341615d5aa55bbf4eedde1be721a4023c"
checksum = "36306ff42b9ffe6e5afc99d49e121e0bd62fe79b9db7b9681d48e29fa19e6b16"
dependencies = [
"encoding_rs",
"lazy_static",
......@@ -8024,9 +8044,9 @@ dependencies = [
[[package]]
name = "symphonia-utils-xiph"
version = "0.5.4"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "484472580fa49991afda5f6550ece662237b00c6f562c7d9638d1b086ed010fe"
checksum = "ee27c85ab799a338446b68eec77abf42e1a6f1bb490656e121c6e27bfbab9f16"
dependencies = [
"symphonia-core",
"symphonia-metadata",
......@@ -9281,6 +9301,12 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "unty"
version = "0.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae"
[[package]]
name = "ureq"
version = "2.12.1"
......@@ -9515,6 +9541,12 @@ version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "virtue"
version = "0.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1"
[[package]]
name = "vob"
version = "3.0.6"
......@@ -9730,9 +9762,9 @@ dependencies = [
[[package]]
name = "widestring"
version = "1.2.0"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd7cf3379ca1aac9eea11fba24fd7e315d621f8dfe35c8d7d2be8b793726e07d"
checksum = "72069c3113ab32ab29e5584db3c6ec55d416895e60715417b5b883a357c3e471"
[[package]]
name = "winapi"
......
......@@ -63,6 +63,7 @@ chrono = { version = "0.4", default-features = false, features = [
"now",
"serde",
] }
cudarc = { version = "0.17.1", features = ["cuda-12020"] }
derive_builder = { version = "0.20" }
derive-getters = { version = "0.5" }
either = { version = "1.13", features = ["serde"] }
......
......@@ -515,6 +515,26 @@ dependencies = [
"serde",
]
[[package]]
name = "bincode"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740"
dependencies = [
"bincode_derive",
"serde",
"unty",
]
[[package]]
name = "bincode_derive"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09"
dependencies = [
"virtue",
]
[[package]]
name = "bindgen"
version = "0.69.5"
......@@ -1103,15 +1123,6 @@ dependencies = [
"typenum",
]
[[package]]
name = "cudarc"
version = "0.16.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17200eb07e7d85a243aa1bf4569a7aa998385ba98d14833973a817a63cc86e92"
dependencies = [
"libloading",
]
[[package]]
name = "cudarc"
version = "0.17.2"
......@@ -1452,6 +1463,15 @@ dependencies = [
"uuid",
]
[[package]]
name = "dynamo-kvbm-kernels"
version = "0.6.0"
dependencies = [
"cc",
"cudarc",
"once_cell",
]
[[package]]
name = "dynamo-llm"
version = "0.6.1"
......@@ -1459,6 +1479,7 @@ dependencies = [
"ahash",
"aho-corasick",
"akin",
"aligned-vec",
"anyhow",
"async-nats",
"async-stream",
......@@ -1466,7 +1487,7 @@ dependencies = [
"async_zmq",
"axum",
"axum-server",
"bincode",
"bincode 2.0.1",
"bitflags 2.9.3",
"blake3",
"bs62",
......@@ -1474,12 +1495,13 @@ dependencies = [
"bytes",
"candle-core",
"chrono",
"cudarc 0.17.2",
"cudarc",
"dashmap",
"derive-getters",
"derive_builder",
"dialoguer",
"dynamo-async-openai",
"dynamo-kvbm-kernels",
"dynamo-parsers",
"dynamo-runtime",
"either",
......@@ -1560,7 +1582,7 @@ dependencies = [
"anyhow",
"async-stream",
"async-trait",
"cudarc 0.16.6",
"cudarc",
"derive-getters",
"dlpark",
"dynamo-async-openai",
......@@ -1602,7 +1624,7 @@ dependencies = [
"async-trait",
"async_zmq",
"axum",
"bincode",
"bincode 1.3.3",
"blake3",
"bytes",
"chrono",
......@@ -6825,6 +6847,12 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "unty"
version = "0.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae"
[[package]]
name = "ureq"
version = "2.12.1"
......@@ -6981,6 +7009,12 @@ version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "virtue"
version = "0.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1"
[[package]]
name = "walkdir"
version = "2.5.0"
......
......@@ -73,7 +73,7 @@ pyo3-async-runtimes = { version = "0.23.0", default-features = false, features =
pythonize = "0.23"
dlpark = { version = "0.5", features = ["pyo3", "half"], optional = true }
cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true }
cudarc = { version = "0.17.1", features = ["cuda-12020"], optional = true }
prometheus = "0.14.0"
......
......@@ -21,7 +21,7 @@ testing-full = ["testing-cuda", "testing-nixl"]
testing-cuda = ["dep:cudarc"]
testing-nixl = ["dep:nixl-sys"]
testing-etcd = []
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix"]
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix", "dep:aligned-vec"]
cuda = ["dep:cudarc"]
integration = ["dynamo-runtime/integration"]
......@@ -85,7 +85,7 @@ offset-allocator = "0.2"
regex = "1"
rayon = "1"
dashmap = { version = "5.5.3" }
bincode = "1"
bincode = { version = "2.0.1", features = ["serde", "derive"] }
# input/text
dialoguer = { version = "0.11", default-features = false, features = [
......@@ -94,11 +94,13 @@ dialoguer = { version = "0.11", default-features = false, features = [
] }
# block_manager
aligned-vec = { version = "0.6.4", optional = true }
nixl-sys = { version = "=0.6.0", optional = true }
cudarc = { version = "0.17.1", features = ["cuda-12020"], optional = true }
cudarc = { workspace = true, optional = true }
ndarray = { version = "0.16", optional = true }
nix = { version = "0.26", optional = true }
# protocols
unicode-segmentation = "1.12"
......@@ -163,7 +165,7 @@ insta = { version = "1.41", features = [
"redactions",
"filters",
] }
aligned-vec = "0.6.4"
lazy_static = "1.4"
[build-dependencies]
......
......@@ -7,7 +7,7 @@ mod benchmarks {
use criterion::{BenchmarkId, Criterion, criterion_group};
use cudarc::driver::{CudaContext, CudaStream};
use nixl_sys;
use tokio::runtime::Runtime;
use tokio_util::task::TaskTracker;
......
......@@ -20,6 +20,7 @@ pub mod numa_allocator;
pub mod offload;
pub mod pool;
pub mod storage;
pub mod v2;
// dynamo rt integration
pub mod controller;
......@@ -326,18 +327,6 @@ mod tests {
.unwrap()
}
pub async fn create_reference_block_manager_with_counts(
device: usize,
host: usize,
disk: usize,
) -> ReferenceBlockManager {
ReferenceBlockManager::new(create_reference_block_manager_config_with_counts(
device, host, disk,
))
.await
.unwrap()
}
#[tokio::test]
async fn test_reference_block_manager_inherited_async_runtime() {
dynamo_runtime::logging::init();
......
......@@ -563,11 +563,11 @@ pub mod v2 {
tracker.spawn(async move {
let event = ctx_clone
.record_event()
.expect(&format!("Failed to record event {}", i));
.unwrap_or_else(|_| panic!("Failed to record event {}", i));
event
.synchronize()
.await
.expect(&format!("Failed to sync event {}", i));
.unwrap_or_else(|_| panic!("Failed to sync event {}", i));
});
}
......@@ -575,26 +575,6 @@ pub mod v2 {
tracker.wait().await;
}
#[tokio::test]
async fn test_performance_baseline() {
let ctx = setup_context();
let start = std::time::Instant::now();
// Test a reasonable number of synchronizations
for _ in 0..10 {
let event = ctx.record_event().expect("Failed to record event");
event.synchronize().await.expect("Sync failed");
}
let duration = start.elapsed();
// Should complete 10 synchronizations in reasonable time (< 1ms total)
assert!(
duration < std::time::Duration::from_millis(1),
"Performance regression: took {:?} for 10 syncs",
duration
);
}
#[tokio::test]
async fn test_error_handling() {
let ctx = setup_context();
......
......@@ -185,10 +185,13 @@ struct WorkerMetadataHandler {
#[async_trait]
impl Handler for WorkerMetadataHandler {
async fn handle(&self, mut message: MessageHandle) -> anyhow::Result<()> {
let payload = bincode::serialize(&WorkerMetadata {
let payload = bincode::serde::encode_to_vec(
&WorkerMetadata {
num_device_blocks: self.num_device_blocks,
bytes_per_block: self.bytes_per_block,
})?;
},
bincode::config::standard(),
)?;
message
.reply(ZMQ_WORKER_METADATA_MESSAGE, &[payload])
.await?;
......@@ -226,8 +229,11 @@ impl Handler for LeaderMetadataHandler {
);
return Ok(());
}
let leader_meta: LeaderMetadata = match bincode::deserialize(&message.data[0]) {
Ok(m) => m,
let leader_meta: LeaderMetadata = match bincode::serde::decode_from_slice(
&message.data[0],
bincode::config::standard(),
) {
Ok((m, _)) => m,
Err(e) => {
tracing::error!("leader_metadata: bad payload: {e:#}");
return Ok(());
......
......@@ -166,14 +166,18 @@ impl ZmqActiveMessageLeader {
}
};
let workers: Vec<WorkerMetadata> = workers_payloads
.into_iter()
.map(|b| bincode::deserialize::<WorkerMetadata>(&b))
.collect::<std::result::Result<_, _>>()?;
let mut workers: Vec<WorkerMetadata> = Vec::with_capacity(workers_payloads.len());
for payload in workers_payloads {
let worker: WorkerMetadata =
bincode::serde::decode_from_slice(&payload, bincode::config::standard())?.0;
workers.push(worker);
}
// 2) Compute & broadcast LeaderMetadata; wait for ALL acks in the SAME round.
let leader_meta = make_leader_meta(&workers);
let leader_meta_bytes = bincode::serialize(&leader_meta)?;
let leader_meta_bytes =
bincode::serde::encode_to_vec(&leader_meta, bincode::config::standard())?;
loop {
if Instant::now() >= deadline {
......
......@@ -693,7 +693,7 @@ impl OffloadFiltersBuilder {
}
}
#[cfg(all(test, feature = "testing-cuda"))]
#[cfg(all(test, feature = "testing-cuda", feature = "testing-nixl"))]
mod tests {
use super::*;
......@@ -713,8 +713,7 @@ mod tests {
use nixl_sys::{MemoryRegion, NixlDescriptor};
use aligned_vec::avec;
use cudarc::runtime::sys::{cudaMemcpy, cudaMemcpyKind, cudaMemset};
use prometheus::Registry;
use cudarc::runtime::sys::{cudaDeviceSynchronize, cudaMemcpy, cudaMemcpyKind, cudaMemset};
use rstest::*;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
......@@ -1286,6 +1285,8 @@ mod tests {
// Check that this is the same block.
check_block_contents(&immutable_host_block, &device_blocks[0], 42)?;
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
Ok(())
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod kernels;
pub mod memory;
pub mod physical;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Safe-ish wrappers around the CUDA block/universal packing kernels.
//!
//! The core ideas:
//! * A “block” represents the stack of `nl * no` tensors arranged either as NHD
//! (inner axes `[nt, nh, hd]`) or HND (inner axes `[nh, nt, hd]`).
//! * A “universal” tensor is `[nh, nl, no, nt, hd]` stored contiguously.
//! * An “operational” tensor is `[nl, no, inner]` with `inner = nt * nh * hd`.
//!
//! Host code calls these helpers with flattened pointer tables so a single
//! launch can move many logical blocks in one go.
#![allow(dead_code)]
#![allow(clippy::missing_safety_doc)]
/// Numeric tags passed across the FFI boundary to select the CUDA template.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TensorDataType {
F16 = 0,
BF16 = 1,
F32 = 2,
F64 = 3,
}
/// Identifies how each `[nt, nh, hd]` chunk is laid out in device memory.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BlockLayout {
NHD = 0,
HND = 1,
}
/// Direction flag for copying between block stacks and operational buffers.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OperationalCopyDirection {
BlockToOperational = 0,
OperationalToBlock = 1,
}
/// Selects how the operational copy should move data.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OperationalCopyBackend {
/// Try cudaMemcpyBatchAsync, fall back to cudaMemcpyAsync, then the kernel.
Auto = 0,
/// Force the custom CUDA kernel path.
KernelOnly = 1,
/// Issue one cudaMemcpyAsync per chunk.
MemcpyAsync = 2,
/// Invoke cudaMemcpyBatchAsync directly.
MemcpyBatch = 3,
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Storage actions.
use super::{MemoryRegion, StorageError};
/// Extension trait for storage types that support memory setting operations
pub trait Memset: MemoryRegion {
/// Sets a region of memory to a specific value
///
/// # Arguments
/// * `value` - The value to set (will be truncated to u8)
/// * `offset` - Offset in bytes from the start of the storage
/// * `size` - Number of bytes to set
///
/// # Safety
/// The caller must ensure:
/// - offset + size <= self.size()
/// - No other references exist to the memory region being set
fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<(), StorageError>;
}
/// Extension trait for storage types that support slicing operations
pub trait Slice {
/// Returns an immutable byte slice view of the entire storage region
///
/// # Safety
/// The caller must ensure:
/// - The memory region is valid and initialized
/// - No concurrent mutable access occurs while the slice is in use
fn as_slice(&self) -> Result<&[u8], StorageError>;
/// Returns an immutable byte slice view of a subregion
///
/// # Arguments
/// * `offset` - Offset in bytes from the start of the storage
/// * `len` - Number of bytes to slice
///
/// # Safety
/// The caller must ensure:
/// - offset + len <= self.size()
/// - The memory region is valid and initialized
/// - No concurrent mutable access occurs while the slice is in use
fn slice(&self, offset: usize, len: usize) -> Result<&[u8], StorageError> {
let slice = self.as_slice()?;
// validate offset and len
if offset.saturating_add(len) > slice.len() {
return Err(StorageError::Unsupported("slice out of bounds".into()));
}
slice
.get(offset..offset.saturating_add(len))
.ok_or_else(|| StorageError::Unsupported("slice out of bounds".into()))
}
/// Returns a typed immutable slice view of the entire storage region
///
/// # Safety
/// The caller must ensure:
/// - The memory region is valid and initialized
/// - The memory is properly aligned for type T
/// - The size is a multiple of `size_of::<T>()`
/// - No concurrent mutable access occurs while the slice is in use
/// - The data represents valid values of type T
fn as_slice_typed<T>(&self) -> Result<&[T], StorageError> {
let bytes = self.as_slice()?;
let ptr = bytes.as_ptr() as *const T;
let len = bytes.len() / std::mem::size_of::<T>();
if !(bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<T>()) {
return Err(StorageError::Unsupported(format!(
"memory not aligned for type (required alignment: {})",
std::mem::align_of::<T>()
)));
}
if bytes.len() % std::mem::size_of::<T>() != 0 {
return Err(StorageError::Unsupported(format!(
"size {} is not a multiple of type size {}",
bytes.len(),
std::mem::size_of::<T>()
)));
}
// SAFETY: Caller guarantees memory is valid, aligned, and properly initialized for T
Ok(unsafe { std::slice::from_raw_parts(ptr, len) })
}
/// Returns a typed immutable slice view of a subregion
///
/// # Arguments
/// * `offset` - Offset in bytes from the start of the storage
/// * `len` - Number of elements of type T to slice
///
/// # Safety
/// The caller must ensure:
/// - offset + (len * size_of::<T>()) <= self.size()
/// - offset is properly aligned for type T
/// - The memory region is valid and initialized
/// - No concurrent mutable access occurs while the slice is in use
/// - The data represents valid values of type T
fn slice_typed<T>(&self, offset: usize, len: usize) -> Result<&[T], StorageError> {
let type_size = std::mem::size_of::<T>();
let byte_len = len
.checked_mul(type_size)
.ok_or_else(|| StorageError::Unsupported("length overflow".into()))?;
let bytes = self.slice(offset, byte_len)?;
let ptr = bytes.as_ptr() as *const T;
if !(bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<T>()) {
return Err(StorageError::Unsupported(format!(
"memory not aligned for type (required alignment: {})",
std::mem::align_of::<T>()
)));
}
// SAFETY: Caller guarantees memory is valid, aligned, and properly initialized for T
Ok(unsafe { std::slice::from_raw_parts(ptr, len) })
}
}
pub trait SliceMut {
/// Returns a mutable byte slice view of the entire storage region
///
/// # Safety
/// The caller must ensure:
/// - The memory region is valid
/// - No other references (mutable or immutable) exist to this memory region
fn as_slice_mut(&mut self) -> Result<&mut [u8], StorageError>;
/// Returns a mutable byte slice view of a subregion
///
/// # Arguments
/// * `offset` - Offset in bytes from the start of the storage
/// * `len` - Number of bytes to slice
///
/// # Safety
/// The caller must ensure:
/// - offset + len <= self.size()
/// - The memory region is valid
/// - No other references (mutable or immutable) exist to this memory region
fn slice_mut(&mut self, offset: usize, len: usize) -> Result<&mut [u8], StorageError> {
let slice = self.as_slice_mut()?;
// validate offset and len
if offset.saturating_add(len) > slice.len() {
return Err(StorageError::Unsupported("slice out of bounds".into()));
}
slice
.get_mut(offset..offset.saturating_add(len))
.ok_or_else(|| StorageError::Unsupported("slice out of bounds".into()))
}
/// Returns a typed mutable slice view of the entire storage region
///
/// # Safety
/// The caller must ensure:
/// - The memory region is valid
/// - The memory is properly aligned for type T
/// - The size is a multiple of `size_of::<T>()`
/// - No other references (mutable or immutable) exist to this memory region
fn as_slice_typed_mut<T>(&mut self) -> Result<&mut [T], StorageError> {
let bytes = self.as_slice_mut()?;
let ptr = bytes.as_mut_ptr() as *mut T;
let len = bytes.len() / std::mem::size_of::<T>();
if !(bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<T>()) {
return Err(StorageError::Unsupported(format!(
"memory not aligned for type (required alignment: {})",
std::mem::align_of::<T>()
)));
}
if bytes.len() % std::mem::size_of::<T>() != 0 {
return Err(StorageError::Unsupported(format!(
"size {} is not a multiple of type size {}",
bytes.len(),
std::mem::size_of::<T>()
)));
}
// SAFETY: Caller guarantees memory is valid, aligned, and no aliasing
Ok(unsafe { std::slice::from_raw_parts_mut(ptr, len) })
}
/// Returns a typed mutable slice view of a subregion
///
/// # Arguments
/// * `offset` - Offset in bytes from the start of the storage
/// * `len` - Number of elements of type T to slice
///
/// # Safety
/// The caller must ensure:
/// - offset + (len * size_of::<T>()) <= self.size()
/// - offset is properly aligned for type T
/// - The memory region is valid
/// - No other references (mutable or immutable) exist to this memory region
fn slice_typed_mut<T>(&mut self, offset: usize, len: usize) -> Result<&mut [T], StorageError> {
let type_size = std::mem::size_of::<T>();
let byte_len = len
.checked_mul(type_size)
.ok_or_else(|| StorageError::Unsupported("length overflow".into()))?;
let bytes = self.slice_mut(offset, byte_len)?;
let ptr = bytes.as_mut_ptr() as *mut T;
if !(bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<T>()) {
return Err(StorageError::Unsupported(format!(
"memory not aligned for type (required alignment: {})",
std::mem::align_of::<T>()
)));
}
// SAFETY: Caller guarantees memory is valid, aligned, and no aliasing
Ok(unsafe { std::slice::from_raw_parts_mut(ptr, len) })
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! CUDA device memory storage.
use super::{MemoryRegion, Result, StorageError, StorageKind};
use cudarc::driver::CudaContext;
use std::any::Any;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
/// Get or create a CUDA context for the given device.
fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
static CONTEXTS: OnceLock<Mutex<HashMap<u32, Arc<CudaContext>>>> = OnceLock::new();
let mut map = CONTEXTS.get_or_init(Default::default).lock().unwrap();
if let Some(existing) = map.get(&device_id) {
return Ok(existing.clone());
}
let ctx = CudaContext::new(device_id as usize)?;
map.insert(device_id, ctx.clone());
Ok(ctx)
}
/// CUDA device memory allocated via cudaMalloc.
#[derive(Debug)]
pub struct DeviceStorage {
ctx: Arc<CudaContext>,
ptr: u64,
device_id: u32,
len: usize,
}
unsafe impl Send for DeviceStorage {}
unsafe impl Sync for DeviceStorage {}
impl DeviceStorage {
/// Allocate new device memory of the given size.
///
/// # Arguments
/// * `len` - Size in bytes to allocate
/// * `device_id` - CUDA device on which to allocate
pub fn new(len: usize, device_id: u32) -> Result<Self> {
if len == 0 {
return Err(StorageError::AllocationFailed(
"zero-sized allocations are not supported".into(),
));
}
let ctx = cuda_context(device_id)?;
ctx.bind_to_thread().map_err(StorageError::Cuda)?;
let ptr = unsafe { cudarc::driver::result::malloc_sync(len).map_err(StorageError::Cuda)? };
Ok(Self {
ctx,
ptr,
device_id,
len,
})
}
/// Get the device pointer value.
pub fn device_ptr(&self) -> u64 {
self.ptr
}
/// Get the CUDA device ID this memory is allocated on.
pub fn device_id(&self) -> u32 {
self.device_id
}
}
impl Drop for DeviceStorage {
fn drop(&mut self) {
if let Err(e) = self.ctx.bind_to_thread() {
tracing::debug!("failed to bind CUDA context for free: {e}");
}
unsafe {
if let Err(e) = cudarc::driver::result::free_sync(self.ptr) {
tracing::debug!("failed to free device memory: {e}");
}
};
}
}
impl MemoryRegion for DeviceStorage {
fn addr(&self) -> usize {
self.device_ptr() as usize
}
fn size(&self) -> usize {
self.len
}
fn storage_kind(&self) -> StorageKind {
StorageKind::Device(self.device_id)
}
fn as_any(&self) -> &dyn Any {
self
}
}
// Support for NIXL registration
impl super::registered::NixlCompatible for DeviceStorage {
fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
(
self.ptr as *const u8,
self.len,
nixl_sys::MemType::Vram,
self.device_id as u64,
)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Disk-backed memory storage using memory-mapped files.
use super::{MemoryRegion, Result, StorageError, StorageKind};
use std::any::Any;
use std::path::{Path, PathBuf};
use core::ffi::c_char;
use nix::fcntl::{FallocateFlags, fallocate};
use nix::unistd::unlink;
use std::ffi::CString;
const DISK_CACHE_KEY: &str = "DYN_KVBM_DISK_CACHE_DIR";
const DEFAULT_DISK_CACHE_DIR: &str = "/tmp/";
#[derive(Debug)]
pub struct DiskStorage {
fd: u64,
path: PathBuf,
size: usize,
unlinked: bool,
}
impl DiskStorage {
pub fn new(size: usize) -> Result<Self> {
// We need to open our file with some special flags that aren't supported by the tempfile crate.
// Instead, we'll use the mkostemp function to create a temporary file with the correct flags.
let specified_dir =
std::env::var(DISK_CACHE_KEY).unwrap_or_else(|_| DEFAULT_DISK_CACHE_DIR.to_string());
let file_path = Path::new(&specified_dir).join("dynamo-kvbm-disk-cache-XXXXXX");
Self::new_at(file_path, size)
}
pub fn new_at(path: impl AsRef<Path>, len: usize) -> Result<Self> {
if len == 0 {
return Err(StorageError::AllocationFailed(
"zero-sized allocations are not supported".into(),
));
}
let file_path = path.as_ref().to_path_buf();
if !file_path.exists() {
std::fs::create_dir_all(file_path.parent().unwrap()).unwrap();
}
tracing::debug!("Allocating disk cache file at {}", file_path.display());
let path_str = file_path.to_str().unwrap();
let is_template = path_str.contains("XXXXXX");
let (raw_fd, actual_path) = if is_template {
// Template path - use mkostemp to generate unique filename
let template = CString::new(path_str).unwrap();
let mut template_bytes = template.into_bytes_with_nul();
let fd = unsafe {
nix::libc::mkostemp(
template_bytes.as_mut_ptr() as *mut c_char,
nix::libc::O_RDWR | nix::libc::O_DIRECT,
)
};
if fd == -1 {
return Err(StorageError::AllocationFailed(format!(
"mkostemp failed: {}",
std::io::Error::last_os_error()
)));
}
// Extract the actual path created by mkostemp
let actual = PathBuf::from(
CString::from_vec_with_nul(template_bytes)
.unwrap()
.to_str()
.unwrap(),
);
(fd, actual)
} else {
// Specific path - use open with O_CREAT
let path_cstr = CString::new(path_str).unwrap();
let fd = unsafe {
nix::libc::open(
path_cstr.as_ptr(),
nix::libc::O_CREAT | nix::libc::O_RDWR | nix::libc::O_DIRECT,
0o644,
)
};
if fd == -1 {
return Err(StorageError::AllocationFailed(format!(
"open failed: {}",
std::io::Error::last_os_error()
)));
}
(fd, file_path)
};
// We need to use fallocate to actually allocate the storage and create the blocks on disk.
fallocate(raw_fd, FallocateFlags::empty(), 0, len as i64).map_err(|e| {
StorageError::AllocationFailed(format!("Failed to allocate temp file: {}", e))
})?;
Ok(Self {
fd: raw_fd as u64,
path: actual_path,
size: len,
unlinked: false,
})
}
pub fn fd(&self) -> u64 {
self.fd
}
pub fn path(&self) -> &Path {
self.path.as_path()
}
/// Unlink our temp file.
/// This means that when this process terminates, the file will be automatically deleted by the OS.
/// Unfortunately, GDS requires that files we try to register must be linked.
/// To get around this, we unlink the file only after we've registered it with NIXL.
pub fn unlink(&mut self) -> Result<()> {
if self.unlinked {
return Ok(());
}
unlink(self.path.as_path())
.map_err(|e| StorageError::AllocationFailed(format!("Failed to unlink file: {}", e)))?;
self.unlinked = true;
Ok(())
}
pub fn unlinked(&self) -> bool {
self.unlinked
}
}
impl Drop for DiskStorage {
fn drop(&mut self) {
let _ = self.unlink();
}
}
impl MemoryRegion for DiskStorage {
fn addr(&self) -> usize {
0
}
fn size(&self) -> usize {
self.size
}
fn storage_kind(&self) -> StorageKind {
StorageKind::Disk(self.fd)
}
fn as_any(&self) -> &dyn Any {
self
}
}
// Support for NIXL registration
impl super::registered::NixlCompatible for DiskStorage {
fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
#[cfg(unix)]
{
// Use file descriptor as device_id for MemType::File
(
std::ptr::null(),
self.size,
nixl_sys::MemType::File,
self.fd,
)
}
#[cfg(not(unix))]
{
// On non-Unix systems, we can't get the file descriptor easily
// Return device_id as 0 - registration will fail on these systems
(
self.mmap.as_ptr(),
self.mmap.len(),
nixl_sys::MemType::File,
0,
)
}
}
}
// mod mmap {
// use super::*;
// #[cfg(unix)]
// use std::os::unix::io::AsRawFd;
// use memmap2::{MmapMut, MmapOptions};
// use std::fs::{File, OpenOptions};
// use tempfile::NamedTempFile;
// /// Disk-backed storage using memory-mapped files.
// #[derive(Debug)]
// pub struct MemMappedFileStorage {
// _file: File, // Keep file alive for the lifetime of the mmap
// mmap: MmapMut,
// path: PathBuf,
// #[cfg(unix)]
// fd: i32,
// }
// unsafe impl Send for MemMappedFileStorage {}
// unsafe impl Sync for MemMappedFileStorage {}
// impl MemMappedFileStorage {
// /// Create new disk storage with a temporary file.
// pub fn new_temp(len: usize) -> Result<Self> {
// if len == 0 {
// return Err(StorageError::AllocationFailed(
// "zero-sized allocations are not supported".into(),
// ));
// }
// // Create temporary file
// let temp_file = NamedTempFile::new()?;
// let path = temp_file.path().to_path_buf();
// let file = temp_file.into_file();
// // Set file size
// file.set_len(len as u64)?;
// #[cfg(unix)]
// let fd = file.as_raw_fd();
// // Memory map the file
// let mmap = unsafe { MmapOptions::new().len(len).map_mut(&file)? };
// Ok(Self {
// _file: file,
// mmap,
// path,
// #[cfg(unix)]
// fd,
// })
// }
// /// Create new disk storage with a specific file path.
// pub fn new_at(path: impl AsRef<Path>, len: usize) -> Result<Self> {
// if len == 0 {
// return Err(StorageError::AllocationFailed(
// "zero-sized allocations are not supported".into(),
// ));
// }
// let path = path.as_ref().to_path_buf();
// // Create or open file
// let file = OpenOptions::new()
// .read(true)
// .write(true)
// .create(true)
// .open(&path)?;
// // Set file size
// file.set_len(len as u64)?;
// #[cfg(unix)]
// let fd = file.as_raw_fd();
// // Memory map the file
// let mmap = unsafe { MmapOptions::new().len(len).map_mut(&file)? };
// Ok(Self {
// _file: file,
// mmap,
// path,
// #[cfg(unix)]
// fd,
// })
// }
// /// Get the path to the backing file.
// pub fn path(&self) -> &Path {
// &self.path
// }
// /// Get the file descriptor (Unix only).
// #[cfg(unix)]
// pub fn fd(&self) -> i32 {
// self.fd
// }
// /// Get a pointer to the memory-mapped region.
// ///
// /// # Safety
// /// The caller must ensure the pointer is not used after this storage is dropped.
// pub unsafe fn as_ptr(&self) -> *const u8 {
// self.mmap.as_ptr()
// }
// /// Get a mutable pointer to the memory-mapped region.
// ///
// /// # Safety
// /// The caller must ensure the pointer is not used after this storage is dropped
// /// and that there are no other references to this memory.
// pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
// self.mmap.as_mut_ptr()
// }
// }
// impl MemoryRegion for MemMappedFileStorage {
// fn addr(&self) -> usize {
// self.mmap.as_ptr() as usize
// }
// fn size(&self) -> usize {
// self.mmap.len()
// }
// fn storage_kind(&self) -> StorageKind {
// StorageKind::Disk
// }
// fn as_any(&self) -> &dyn Any {
// self
// }
// }
// // Support for NIXL registration
// impl super::super::registered::NixlCompatible for MemMappedFileStorage {
// fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
// #[cfg(unix)]
// {
// // Use file descriptor as device_id for MemType::File
// (
// self.mmap.as_ptr(),
// self.mmap.len(),
// nixl_sys::MemType::File,
// self.fd as u64,
// )
// }
// #[cfg(not(unix))]
// {
// // On non-Unix systems, we can't get the file descriptor easily
// // Return device_id as 0 - registration will fail on these systems
// (
// self.mmap.as_ptr(),
// self.mmap.len(),
// nixl_sys::MemType::File,
// 0,
// )
// }
// }
// }
// }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Clean, minimal storage API for v2 block manager.
//!
//! This module provides a simplified storage abstraction with:
//! - Single trait for type erasure (`MemoryRegion`)
//! - Concrete storage types (no trait implementations required)
//! - Composition-based NIXL registration via `NixlRegistered<T>` wrapper
//! - RAII with proper drop ordering (registration handle drops before memory)
pub mod actions;
mod device;
mod disk;
mod pinned;
mod registered;
mod system;
mod torch;
#[cfg(test)]
mod tests;
pub use device::DeviceStorage;
pub use disk::DiskStorage;
pub use pinned::PinnedStorage;
pub use registered::{
NixlCompatible, NixlDescriptor, NixlRegistered, RegisteredView, register_with_nixl,
};
pub use system::SystemStorage;
pub use torch::{TorchDevice, TorchTensor};
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::fmt;
use std::sync::Arc;
use thiserror::Error;
/// Result type for storage operations.
pub type Result<T> = std::result::Result<T, StorageError>;
/// Errors that can occur during storage operations.
#[derive(Debug, Error)]
pub enum StorageError {
#[error("allocation failed: {0}")]
AllocationFailed(String),
#[error("registration failed: {0}")]
RegistrationFailed(String),
#[error("operation failed: {0}")]
OperationFailed(String),
#[error("unsupported operation: {0}")]
Unsupported(String),
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
// #[cfg(feature = "cuda")]
#[error("CUDA error: {0}")]
Cuda(#[from] cudarc::driver::DriverError),
#[error("NIXL error: {0}")]
Nixl(#[from] nixl_sys::NixlError),
}
/// Storage type classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StorageKind {
/// System memory (malloc)
System,
/// CUDA pinned host memory
// #[cfg(feature = "cuda")]
Pinned,
/// CUDA device memory with device ID
// #[cfg(feature = "cuda")]
Device(u32),
/// Disk-backed memory (mmap)
Disk(u64),
}
/// Core trait for memory regions that can be type-erased.
///
/// This is the only trait in the storage API. Concrete storage types
/// implement this trait to enable type erasure via `Arc<dyn MemoryRegion>`.
pub trait MemoryRegion: Send + Sync + fmt::Debug {
/// Base address of the memory region.
fn addr(&self) -> usize;
/// Size of the memory region in bytes.
fn size(&self) -> usize;
/// Type of storage backing this region.
fn storage_kind(&self) -> StorageKind;
/// Enable downcasting to concrete type.
fn as_any(&self) -> &dyn Any;
/// Get the NIXL descriptor for this memory region.
fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
None
}
}
/// Type-erased memory region for use in layouts.
pub type OwnedMemoryRegion = Arc<dyn MemoryRegion>;
/// Helper function to convert concrete storage to type-erased form.
pub fn erase_storage<S: MemoryRegion + 'static>(storage: S) -> OwnedMemoryRegion {
Arc::new(storage)
}
/// Simple memory region descriptor.
#[derive(Debug)]
pub struct OffsetMemoryRegion {
base: OwnedMemoryRegion,
offset: usize,
len: usize,
}
impl OffsetMemoryRegion {
/// Create a new offset view into an existing memory region.
///
/// Returns an error if the offset and length exceed the bounds of the base region.
pub fn new(base: OwnedMemoryRegion, offset: usize, len: usize) -> Result<Self> {
let end = offset
.checked_add(len)
.ok_or_else(|| StorageError::Unsupported("offset overflow".into()))?;
if end > base.size() {
return Err(StorageError::Unsupported(
"offset region exceeds base allocation bounds".into(),
));
}
Ok(Self { base, offset, len })
}
/// Get the offset relative to the base mapping.
pub fn offset(&self) -> usize {
self.offset
}
/// Get the length of the offset region.
pub fn len(&self) -> usize {
self.len
}
/// Check if the offset region is empty.
pub fn is_empty(&self) -> bool {
self.len == 0
}
/// Access the underlying base region.
pub fn base(&self) -> &OwnedMemoryRegion {
&self.base
}
}
impl MemoryRegion for OffsetMemoryRegion {
fn addr(&self) -> usize {
self.base.addr() + self.offset
}
fn size(&self) -> usize {
self.len
}
fn storage_kind(&self) -> StorageKind {
self.base.storage_kind()
}
fn as_any(&self) -> &dyn Any {
self
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct MemoryDescriptor {
pub addr: usize,
pub size: usize,
}
impl MemoryDescriptor {
pub fn new(addr: usize, size: usize) -> Self {
Self { addr, size }
}
#[inline]
pub fn addr(&self) -> usize {
self.addr
}
#[inline]
pub fn size(&self) -> usize {
self.size
}
}
impl actions::Slice for MemoryDescriptor {
fn as_slice(&self) -> Result<&[u8]> {
Ok(unsafe { std::slice::from_raw_parts(self.addr as *const u8, self.size) })
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! CUDA pinned host memory storage.
use super::{MemoryRegion, Result, StorageError, StorageKind, actions};
use cudarc::driver::CudaContext;
use cudarc::driver::sys;
use std::any::Any;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
/// Get or create a CUDA context for the given device.
fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
static CONTEXTS: OnceLock<Mutex<HashMap<u32, Arc<CudaContext>>>> = OnceLock::new();
let mut map = CONTEXTS.get_or_init(Default::default).lock().unwrap();
if let Some(existing) = map.get(&device_id) {
return Ok(existing.clone());
}
let ctx = CudaContext::new(device_id as usize)?;
map.insert(device_id, ctx.clone());
Ok(ctx)
}
/// CUDA pinned host memory allocated via cudaHostAlloc.
#[derive(Debug)]
pub struct PinnedStorage {
ptr: usize,
len: usize,
ctx: Arc<CudaContext>,
}
unsafe impl Send for PinnedStorage {}
unsafe impl Sync for PinnedStorage {}
impl PinnedStorage {
/// Allocate new pinned memory of the given size.
///
/// # Arguments
/// * `len` - Size in bytes to allocate
/// * `device_id` - CUDA device to associate with the allocation
pub fn new(len: usize) -> Result<Self> {
if len == 0 {
return Err(StorageError::AllocationFailed(
"zero-sized allocations are not supported".into(),
));
}
let ctx = cuda_context(0)?;
let ptr = unsafe {
ctx.bind_to_thread().map_err(StorageError::Cuda)?;
let ptr = cudarc::driver::result::malloc_host(len, sys::CU_MEMHOSTALLOC_WRITECOMBINED)
.map_err(StorageError::Cuda)?;
let ptr = ptr as *mut u8;
assert!(!ptr.is_null(), "Failed to allocate pinned memory");
assert!(ptr.is_aligned(), "Pinned memory is not aligned");
assert!(len < isize::MAX as usize);
ptr as usize
};
Ok(Self { ptr, len, ctx })
}
/// Get a pointer to the underlying memory.
///
/// # Safety
/// The caller must ensure the pointer is not used after this storage is dropped.
pub unsafe fn as_ptr(&self) -> *const u8 {
self.ptr as *const u8
}
/// Get a mutable pointer to the underlying memory.
///
/// # Safety
/// The caller must ensure the pointer is not used after this storage is dropped
/// and that there are no other references to this memory.
pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
self.ptr as *mut u8
}
}
impl Drop for PinnedStorage {
fn drop(&mut self) {
if let Err(e) = self.ctx.bind_to_thread() {
tracing::debug!("failed to bind CUDA context for free: {e}");
}
unsafe {
if let Err(e) = cudarc::driver::result::free_host(self.ptr as _) {
tracing::debug!("failed to free pinned memory: {e}");
}
};
}
}
impl MemoryRegion for PinnedStorage {
fn addr(&self) -> usize {
unsafe { self.as_ptr() as usize }
}
fn size(&self) -> usize {
self.len
}
fn storage_kind(&self) -> StorageKind {
StorageKind::Pinned
}
fn as_any(&self) -> &dyn Any {
self
}
}
// Support for NIXL registration
impl super::registered::NixlCompatible for PinnedStorage {
fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
let ptr = unsafe { self.as_ptr() };
(ptr, self.len, nixl_sys::MemType::Dram, 0)
}
}
impl actions::Memset for PinnedStorage {
fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<()> {
if offset + size > self.len {
return Err(StorageError::OperationFailed(
"memset: offset + size > storage size".into(),
));
}
unsafe {
let ptr = (self.ptr as *mut u8).add(offset);
std::ptr::write_bytes(ptr, value, size);
}
Ok(())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NIXL registration wrapper for storage types.
use super::{MemoryRegion, StorageKind};
use nixl_sys::{Agent as NixlAgent, MemType, OptArgs, RegistrationHandle};
use std::any::Any;
use std::fmt;
/// Trait for storage types that can be registered with NIXL.
pub trait NixlCompatible {
/// Get parameters needed for NIXL registration.
///
/// Returns (ptr, size, mem_type, device_id)
fn nixl_params(&self) -> (*const u8, usize, MemType, u64);
}
/// NIXL descriptor containing registration information.
#[derive(Debug, Clone)]
pub struct NixlDescriptor {
pub addr: u64,
pub size: usize,
pub mem_type: MemType,
pub device_id: u64,
}
impl nixl_sys::MemoryRegion for NixlDescriptor {
unsafe fn as_ptr(&self) -> *const u8 {
self.addr as *const u8
}
fn size(&self) -> usize {
self.size
}
}
impl nixl_sys::NixlDescriptor for NixlDescriptor {
fn mem_type(&self) -> MemType {
self.mem_type
}
fn device_id(&self) -> u64 {
self.device_id
}
}
/// View trait for accessing registration information without unwrapping.
pub trait RegisteredView {
/// Get the name of the NIXL agent that registered this memory.
fn agent_name(&self) -> &str;
/// Get the NIXL descriptor for this registered memory.
fn descriptor(&self) -> NixlDescriptor;
}
/// Wrapper for storage that has been registered with NIXL.
///
/// This wrapper ensures proper drop order: the registration handle is
/// dropped before the storage, ensuring deregistration happens before
/// the memory is freed.
pub struct NixlRegistered<S: NixlCompatible> {
storage: S,
handle: Option<RegistrationHandle>,
agent_name: String,
}
impl<S: NixlCompatible> Drop for NixlRegistered<S> {
fn drop(&mut self) {
// Explicitly drop the registration handle first
drop(self.handle.take());
// Storage drops naturally after
}
}
impl<S: NixlCompatible + fmt::Debug> fmt::Debug for NixlRegistered<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("NixlRegistered")
.field("storage", &self.storage)
.field("agent_name", &self.agent_name)
.field("handle", &self.handle.is_some())
.finish()
}
}
impl<S: MemoryRegion + NixlCompatible + 'static> MemoryRegion for NixlRegistered<S> {
fn addr(&self) -> usize {
self.storage.addr()
}
fn size(&self) -> usize {
self.storage.size()
}
fn storage_kind(&self) -> StorageKind {
self.storage.storage_kind()
}
fn as_any(&self) -> &dyn Any {
self
}
fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
Some(self.descriptor())
}
}
impl<S: MemoryRegion + NixlCompatible> RegisteredView for NixlRegistered<S> {
fn agent_name(&self) -> &str {
&self.agent_name
}
fn descriptor(&self) -> NixlDescriptor {
let (ptr, size, mem_type, device_id) = self.storage.nixl_params();
NixlDescriptor {
addr: ptr as u64,
size,
mem_type,
device_id,
}
}
}
impl<S: MemoryRegion + NixlCompatible> NixlRegistered<S> {
/// Get a reference to the underlying storage.
pub fn storage(&self) -> &S {
&self.storage
}
/// Get a mutable reference to the underlying storage.
pub fn storage_mut(&mut self) -> &mut S {
&mut self.storage
}
/// Check if the registration handle is still valid.
pub fn is_registered(&self) -> bool {
self.handle.is_some()
}
/// Consume this wrapper and return the underlying storage.
///
/// This will deregister the storage from NIXL.
pub fn into_storage(mut self) -> S {
// Manually drop the handle first
self.handle = None;
// Now we can move out the storage
// We need to use mem::forget to prevent Drop from running
let storage = std::mem::replace(&mut self.storage, unsafe { std::mem::zeroed() });
std::mem::forget(self);
storage
}
}
/// Register storage with a NIXL agent.
///
/// This consumes the storage and returns a `NixlRegistered` wrapper that
/// manages the registration lifetime. The registration handle will be
/// automatically dropped when the wrapper is dropped, ensuring proper
/// cleanup order.
///
/// # Arguments
/// * `storage` - The storage to register (consumed)
/// * `agent` - The NIXL agent to register with
/// * `opt` - Optional arguments for registration
///
/// # Returns
/// A `NixlRegistered` wrapper containing the storage and registration handle.
pub fn register_with_nixl<S>(
storage: S,
agent: &NixlAgent,
opt: Option<&OptArgs>,
) -> std::result::Result<NixlRegistered<S>, S>
where
S: MemoryRegion + NixlCompatible,
{
// Get NIXL parameters
let (ptr, size, mem_type, device_id) = storage.nixl_params();
// Create a NIXL descriptor for registration
let descriptor = NixlDescriptor {
addr: ptr as u64,
size,
mem_type,
device_id,
};
match agent.register_memory(&descriptor, opt) {
Ok(handle) => Ok(NixlRegistered {
storage,
handle: Some(handle),
agent_name: agent.name().to_string(),
}),
Err(_) => Err(storage),
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! System memory storage backed by malloc.
use super::{MemoryRegion, Result, StorageError, StorageKind, actions};
use std::any::Any;
use std::ptr::NonNull;
use nix::libc;
/// System memory allocated via malloc.
#[derive(Debug)]
pub struct SystemStorage {
ptr: NonNull<u8>,
len: usize,
}
unsafe impl Send for SystemStorage {}
unsafe impl Sync for SystemStorage {}
impl SystemStorage {
/// Allocate new system memory of the given size.
pub fn new(len: usize) -> Result<Self> {
if len == 0 {
return Err(StorageError::AllocationFailed(
"zero-sized allocations are not supported".into(),
));
}
let mut ptr: *mut libc::c_void = std::ptr::null_mut();
// We need 4KB alignment here for NIXL disk transfers to work.
// The O_DIRECT flag is required for GDS.
// However, a limitation of this flag is that all operations involving disk
// (both read and write) must be page-aligned.
// Pinned memory is already page-aligned, so we only need to align system memory.
// TODO(jthomson04): Is page size always 4KB?
// SAFETY: malloc returns suitably aligned memory or null on failure.
let result = unsafe { libc::posix_memalign(&mut ptr, 4096, len) };
if result != 0 {
return Err(StorageError::AllocationFailed(format!(
"posix_memalign failed for size {}",
len
)));
}
let ptr = NonNull::new(ptr as *mut u8).ok_or_else(|| {
StorageError::AllocationFailed(format!("malloc failed for size {}", len))
})?;
// Zero-initialize the memory
unsafe {
std::ptr::write_bytes(ptr.as_ptr(), 0, len);
}
Ok(Self { ptr, len })
}
/// Get a pointer to the underlying memory.
///
/// # Safety
/// The caller must ensure the pointer is not used after this storage is dropped.
pub unsafe fn as_ptr(&self) -> *const u8 {
self.ptr.as_ptr()
}
/// Get a mutable pointer to the underlying memory.
///
/// # Safety
/// The caller must ensure the pointer is not used after this storage is dropped
/// and that there are no other references to this memory.
pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
self.ptr.as_ptr()
}
}
impl Drop for SystemStorage {
fn drop(&mut self) {
// SAFETY: pointer was allocated by malloc.
unsafe {
libc::free(self.ptr.as_ptr() as *mut libc::c_void);
}
}
}
impl MemoryRegion for SystemStorage {
fn addr(&self) -> usize {
self.ptr.as_ptr() as usize
}
fn size(&self) -> usize {
self.len
}
fn storage_kind(&self) -> StorageKind {
StorageKind::System
}
fn as_any(&self) -> &dyn Any {
self
}
}
// Support for NIXL registration
impl super::registered::NixlCompatible for SystemStorage {
fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
(self.ptr.as_ptr(), self.len, nixl_sys::MemType::Dram, 0)
}
}
impl actions::Memset for SystemStorage {
fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<()> {
if offset + size > self.len {
return Err(StorageError::OperationFailed(
"memset: offset + size > storage size".into(),
));
}
unsafe {
let ptr = self.ptr.as_ptr().add(offset);
std::ptr::write_bytes(ptr, value, size);
}
Ok(())
}
}
impl actions::Slice for SystemStorage {
fn as_slice(&self) -> Result<&[u8]> {
Ok(unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) })
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment