Unverified Commit 42ce6931 authored by Harrison Saturley-Hall's avatar Harrison Saturley-Hall Committed by GitHub
Browse files

feat: kv block manager (#965) (#1021)


Co-authored-by: default avatarRyan Olson <ryanolson@users.noreply.github.com>
parent cafc74eb
......@@ -47,15 +47,18 @@ jobs:
GITHUB_TOKEN: ${{ secrets.CI_TOKEN }}
run: |
./container/build.sh --tag ${{ steps.define_image_tag.outputs.image_tag }} --target ci_minimum --framework ${{ matrix.framework }}
- name: Run Rust checks (llm/block-manager)
run: |
docker run -w /workspace/lib/llm --name ${{ env.CONTAINER_ID }}_rust_checks ${{ steps.define_image_tag.outputs.image_tag }} bash -ec 'rustup component add rustfmt clippy && cargo fmt -- --check && cargo clippy --features block-manager --no-deps --all-targets -- -D warnings && cargo test --locked --all-targets --features=block-manager'
- name: Run pytest
env:
PYTEST_MARKS: "pre_merge or mypy"
run: |
docker run -w /workspace --name ${{ env.CONTAINER_ID }} ${{ steps.define_image_tag.outputs.image_tag }} pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m "${{ env.PYTEST_MARKS }}"
docker run -w /workspace --name ${{ env.CONTAINER_ID }}_pytest ${{ steps.define_image_tag.outputs.image_tag }} pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m "${{ env.PYTEST_MARKS }}"
- name: Copy test report from test Container
if: always()
run: |
docker cp ${{ env.CONTAINER_ID }}:/workspace/${{ env.PYTEST_XML_FILE }} .
docker cp ${{ env.CONTAINER_ID }}_pytest:/workspace/${{ env.PYTEST_XML_FILE }} .
- name: Archive test report
uses: actions/upload-artifact@v4
if: always()
......
......@@ -23,8 +23,6 @@ on:
# Run this workflow on pull requests targeting main but only if files in runtime/rust change.
pull_request:
branches:
- main
paths:
- .github/workflows/pre-merge-rust.yml
- 'lib/runtime/**'
......
......@@ -513,6 +513,26 @@ dependencies = [
"which",
]
[[package]]
name = "bindgen"
version = "0.71.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
dependencies = [
"bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 2.1.1",
"shlex",
"syn 2.0.100",
]
[[package]]
name = "bindgen_cuda"
version = "0.1.5"
......@@ -1606,6 +1626,8 @@ dependencies = [
"minijinja",
"minijinja-contrib",
"ndarray",
"nixl-sys",
"oneshot",
"prometheus",
"proptest",
"rand 0.9.1",
......@@ -3379,7 +3401,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if 1.0.0",
"windows-targets 0.52.6",
"windows-targets 0.48.5",
]
[[package]]
......@@ -3441,7 +3463,7 @@ version = "0.1.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b4ae3037b7d9b9fab9fd7905aeb04e214acb300599fa1ee698d6f759ee530f9"
dependencies = [
"bindgen",
"bindgen 0.69.5",
"cc",
"cmake",
"find_cuda_helper",
......@@ -4057,6 +4079,21 @@ dependencies = [
"libc",
]
[[package]]
name = "nixl-sys"
version = "0.2.1-rc.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4698155ec791ae888482b0c1c5d19792a1b457167fa8ec0ffa487ffef8c8e5a"
dependencies = [
"bindgen 0.71.1",
"cc",
"libc",
"pkg-config",
"serde",
"thiserror 2.0.12",
"tracing",
]
[[package]]
name = "nkeys"
version = "0.4.4"
......@@ -4282,6 +4319,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "oneshot"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea"
[[package]]
name = "onig"
version = "6.4.0"
......
......@@ -60,6 +60,7 @@ futures = { version = "0.3" }
hf-hub = { version = "0.4.2", default-features = false, features = ["tokio", "rustls-tls"] }
humantime = { version = "2.2.0" }
libc = { version = "0.2" }
oneshot = { version = "0.1.11", features = ["std", "async"] }
prometheus = { version = "0.14" }
rand = { version = "0.9.0" }
serde = { version = "1", features = ["derive"] }
......@@ -77,6 +78,7 @@ uuid = { version = "1", features = ["v4", "serde"] }
url = {version = "2.5", features = ["serde"]}
xxhash-rust = { version = "0.8", features = ["xxh3", "const_xxh3"] }
[profile.dev.package]
insta.opt-level = 3
......
......@@ -108,6 +108,12 @@ WORKDIR /workspace
# Copy nixl source, and use commit hash as cache hint
COPY --from=nixl_base /opt/nixl /opt/nixl
COPY --from=nixl_base /opt/nixl/commit.txt /opt/nixl/commit.txt
RUN cd /opt/nixl && \
mkdir build && \
meson setup build/ --prefix=/usr/local/nixl && \
cd build/ && \
ninja && \
ninja install
### NATS & ETCD SETUP ###
# nats
......@@ -310,6 +316,7 @@ ARG RELEASE_BUILD
WORKDIR /workspace
RUN yum update -y \
&& yum install -y llvm-toolset \
&& yum install -y python3.12-devel \
&& yum install -y protobuf-compiler \
&& yum clean all \
......@@ -322,6 +329,7 @@ ENV RUSTUP_HOME=/usr/local/rustup \
COPY --from=base $RUSTUP_HOME $RUSTUP_HOME
COPY --from=base $CARGO_HOME $CARGO_HOME
COPY --from=base /usr/local/nixl /opt/nvidia/nvda_nixl
COPY --from=base /workspace /workspace
COPY --from=base $VIRTUAL_ENV $VIRTUAL_ENV
ENV PATH=$CARGO_HOME/bin:$VIRTUAL_ENV/bin:$PATH
......@@ -342,7 +350,7 @@ COPY launch /workspace/launch
COPY deploy/sdk /workspace/deploy/sdk
# Build Rust crate binaries packaged with the wheel
RUN cargo build --release --locked --features mistralrs,python \
RUN cargo build --release --locked --features mistralrs,python,dynamo-llm/block-manager \
-p dynamo-run \
-p llmctl \
# Multiple http named crates are present in dependencies, need to specify the path
......@@ -370,6 +378,7 @@ WORKDIR /workspace
COPY --from=wheel_builder /workspace/dist/ /workspace/dist/
COPY --from=wheel_builder /workspace/target/ /workspace/target/
COPY --from=wheel_builder /opt/nvidia/nvda_nixl /opt/nvidia/nvda_nixl
# Copy Cargo cache to avoid re-downloading dependencies
COPY --from=wheel_builder $CARGO_HOME $CARGO_HOME
......@@ -377,7 +386,7 @@ COPY . /workspace
# Build rest of the crates
# Need to figure out rust caching to avoid rebuilding and remove exclude flags
RUN cargo build --release --locked --workspace \
RUN cargo build --release --locked --features block-manager --workspace \
--exclude dynamo-run \
--exclude llmctl \
--exclude file://$PWD/components/http \
......@@ -405,6 +414,7 @@ RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/la
# Tell vllm to use the Dynamo LLM C API for KV Cache Routing
ENV VLLM_KV_CAPI_PATH=/opt/dynamo/bindings/lib/libdynamo_llm_capi.so
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/nvidia/nvda_nixl/lib/x86_64-linux-gnu/
##########################################
########## Perf Analyzer Image ###########
......
......@@ -1058,6 +1058,7 @@ dependencies = [
"memmap2",
"minijinja",
"minijinja-contrib",
"oneshot",
"prometheus",
"rand 0.9.1",
"rayon",
......@@ -2859,6 +2860,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "oneshot"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea"
[[package]]
name = "onig"
version = "6.4.0"
......
......@@ -51,6 +51,7 @@ module-name = "dynamo._core"
manifest-path = "Cargo.toml"
python-packages = ["dynamo"]
python-source = "src"
features = ["dynamo-llm/block-manager"]
[build-system]
requires = ["maturin>=1.0,<2.0", "patchelf"]
......
......@@ -27,7 +27,10 @@ description = "Dynamo LLM Library"
[features]
default = []
cuda_kv = ["dep:cudarc", "dep:ndarray"]
testing-full = ["testing-cuda", "testing-nixl"]
testing-cuda = ["dep:cudarc"]
testing-nixl = ["dep:nixl-sys"]
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray"]
sentencepiece = ["dep:sentencepiece"]
[dependencies]
......@@ -48,6 +51,7 @@ etcd-client = { workspace = true }
futures = { workspace = true }
hf-hub = { workspace = true }
rand = { workspace = true }
oneshot = { workspace = true }
prometheus = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
......@@ -72,8 +76,9 @@ derive-getters = "0.5"
regex = "1"
rayon = "1"
# kv_cuda
cudarc = { version = "0.16.2", features = ["cuda-12040"], optional = true }
# block_manager
nixl-sys = { version = "0.2.1-rc.1", optional = true }
cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true }
ndarray = { version = "0.16", optional = true }
# protocols
......
......@@ -13,121 +13,128 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Build script for the optional `cuda_kv` feature: compiles the
// `block_copy.cu` CUDA kernel with nvcc, archives it into a static
// library, and emits the cargo link directives for the CUDA runtime.

#[cfg(not(feature = "cuda_kv"))]
fn main() {}

#[cfg(feature = "cuda_kv")]
fn main() {
    use std::{path::PathBuf, process::Command};

    println!("cargo:rerun-if-changed=src/kernels/block_copy.cu");

    // Locate the CUDA toolkit root. Prefer the `nvcc` found on PATH;
    // otherwise fall back to $CUDA_ROOT or an OS-specific default.
    // NOTE(review): `which` is Unix-specific; on Windows this branch will
    // always fall through to the default path.
    let nvcc = Command::new("which").arg("nvcc").output().unwrap();
    let cuda_lib = if nvcc.status.success() {
        println!("cargo:info=nvcc found in path");
        // Derive the toolkit root from the nvcc location by stripping "bin/nvcc".
        let nvcc_path = String::from_utf8_lossy(&nvcc.stdout).trim().to_string();
        let path = PathBuf::from(nvcc_path);
        if let Some(parent) = path.parent() {
            // Remove "nvcc"
            if let Some(cuda_root) = parent.parent() {
                // Remove "bin"
                cuda_root.to_string_lossy().to_string()
            } else {
                // Fallback to CUDA_ROOT or default if path extraction fails
                get_cuda_root_or_default()
            }
        } else {
            // Fallback to CUDA_ROOT or default if path extraction fails
            get_cuda_root_or_default()
        }
    } else {
        println!("cargo:warning=nvcc not found in path");
        get_cuda_root_or_default()
    };

    println!("cargo:info=Using CUDA installation at: {}", cuda_lib);
    let cuda_lib_path = PathBuf::from(&cuda_lib).join("lib64");
    println!("cargo:info=Using CUDA libs: {}", cuda_lib_path.display());
    println!("cargo:rustc-link-search=native={}", cuda_lib_path.display());

    // Link against the CUDA runtime/driver libraries.
    println!("cargo:rustc-link-lib=dylib=cudart");
    println!("cargo:rustc-link-lib=dylib=cuda");
    println!("cargo:rustc-link-lib=dylib=cudadevrt");

    // Make sure the CUDA libraries are found before other system libraries.
    println!(
        "cargo:rustc-link-arg=-Wl,-rpath,{}",
        cuda_lib_path.display()
    );

    // Create the kernels output directory if it doesn't already exist.
    std::fs::create_dir_all("src/kernels").unwrap_or_else(|_| {
        println!("Kernels directory already exists");
    });

    // Compile the CUDA kernel to a position-independent object file.
    let output = Command::new("nvcc")
        .arg("src/kernels/block_copy.cu")
        .arg("-O3")
        .arg("--compiler-options")
        .arg("-fPIC")
        .arg("-o")
        .arg("src/kernels/libblock_copy.o")
        .arg("-c")
        .output()
        .expect("Failed to compile CUDA code");

    if !output.status.success() {
        panic!(
            "Failed to compile CUDA kernel: {}",
            String::from_utf8_lossy(&output.stderr)
        );
    }

    // Archive the object file into a static library and link it.
    #[cfg(target_os = "windows")]
    {
        Command::new("lib")
            .arg("/OUT:src/kernels/block_copy.lib")
            .arg("src/kernels/libblock_copy.o")
            .output()
            .expect("Failed to create static library");
        println!("cargo:rustc-link-search=native=src/kernels");
        println!("cargo:rustc-link-lib=static=block_copy");
    }

    #[cfg(not(target_os = "windows"))]
    {
        Command::new("ar")
            .arg("rcs")
            .arg("src/kernels/libblock_copy.a")
            .arg("src/kernels/libblock_copy.o")
            .output()
            .expect("Failed to create static library");
        println!("cargo:rustc-link-search=native=src/kernels");
        println!("cargo:rustc-link-lib=static=block_copy");
        println!("cargo:rustc-link-lib=dylib=cudart");
        println!("cargo:rustc-link-lib=dylib=cuda");
        println!("cargo:rustc-link-lib=dylib=cudadevrt");
    }

    // BUG FIX: this message previously read "Building with CUDA KV off",
    // but it is emitted inside the `cuda_kv`-enabled build path; report
    // the actual state instead.
    println!("cargo:warning=Building with CUDA KV enabled");
}
#[cfg(feature = "cuda_kv")]
/// Returns the CUDA toolkit root: `$CUDA_ROOT` when set, otherwise the
/// conventional install location for the current OS.
fn get_cuda_root_or_default() -> String {
    std::env::var("CUDA_ROOT").unwrap_or_else(|_| {
        // No CUDA_ROOT in the environment — use the OS default location.
        if cfg!(target_os = "windows") {
            "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8".to_string()
        } else {
            "/usr/local/cuda".to_string()
        }
    })
}
// NOTE: Preserving this build.rs for reference. We may want to re-enable
// custom kernel compilation in the future.
// #[cfg(not(feature = "cuda_kv"))]
// fn main() {}
// #[cfg(feature = "cuda_kv")]
// fn main() {
// use std::{path::PathBuf, process::Command};
// println!("cargo:rerun-if-changed=src/kernels/block_copy.cu");
// // first do a which nvcc, if it is in the path
// // if so, we don't need to set the cuda_lib
// let nvcc = Command::new("which").arg("nvcc").output().unwrap();
// let cuda_lib = if nvcc.status.success() {
// println!("cargo:info=nvcc found in path");
// // Extract the path from nvcc location by removing "bin/nvcc"
// let nvcc_path = String::from_utf8_lossy(&nvcc.stdout).trim().to_string();
// let path = PathBuf::from(nvcc_path);
// if let Some(parent) = path.parent() {
// // Remove "nvcc"
// if let Some(cuda_root) = parent.parent() {
// // Remove "bin"
// cuda_root.to_string_lossy().to_string()
// } else {
// // Fallback to CUDA_ROOT or default if path extraction fails
// get_cuda_root_or_default()
// }
// } else {
// // Fallback to CUDA_ROOT or default if path extraction fails
// get_cuda_root_or_default()
// }
// } else {
// println!("cargo:warning=nvcc not found in path");
// get_cuda_root_or_default()
// };
// println!("cargo:info=Using CUDA installation at: {}", cuda_lib);
// let cuda_lib_path = PathBuf::from(&cuda_lib).join("lib64");
// println!("cargo:info=Using CUDA libs: {}", cuda_lib_path.display());
// println!("cargo:rustc-link-search=native={}", cuda_lib_path.display());
// // Link against multiple CUDA libraries
// println!("cargo:rustc-link-lib=dylib=cudart");
// println!("cargo:rustc-link-lib=dylib=cuda");
// println!("cargo:rustc-link-lib=dylib=cudadevrt");
// // Make sure the CUDA libraries are found before other system libraries
// println!(
// "cargo:rustc-link-arg=-Wl,-rpath,{}",
// cuda_lib_path.display()
// );
// // Create kernels directory for output if it doesn't exist
// std::fs::create_dir_all("src/kernels").unwrap_or_else(|_| {
// println!("Kernels directory already exists");
// });
// // Compile CUDA code
// let output = Command::new("nvcc")
// .arg("src/kernels/block_copy.cu")
// .arg("-O3")
// .arg("--compiler-options")
// .arg("-fPIC")
// .arg("-o")
// .arg("src/kernels/libblock_copy.o")
// .arg("-c")
// .output()
// .expect("Failed to compile CUDA code");
// if !output.status.success() {
// panic!(
// "Failed to compile CUDA kernel: {}",
// String::from_utf8_lossy(&output.stderr)
// );
// }
// // Create static library
// #[cfg(target_os = "windows")]
// {
// Command::new("lib")
// .arg("/OUT:src/kernels/block_copy.lib")
// .arg("src/kernels/libblock_copy.o")
// .output()
// .expect("Failed to create static library");
// println!("cargo:rustc-link-search=native=src/kernels");
// println!("cargo:rustc-link-lib=static=block_copy");
// }
// #[cfg(not(target_os = "windows"))]
// {
// Command::new("ar")
// .arg("rcs")
// .arg("src/kernels/libblock_copy.a")
// .arg("src/kernels/libblock_copy.o")
// .output()
// .expect("Failed to create static library");
// println!("cargo:rustc-link-search=native=src/kernels");
// println!("cargo:rustc-link-lib=static=block_copy");
// println!("cargo:rustc-link-lib=dylib=cudart");
// println!("cargo:rustc-link-lib=dylib=cuda");
// println!("cargo:rustc-link-lib=dylib=cudadevrt");
// }
// }
// #[cfg(feature = "cuda_kv")]
// fn get_cuda_root_or_default() -> String {
// match std::env::var("CUDA_ROOT") {
// Ok(path) => path,
// Err(_) => {
// // Default locations based on OS
// if cfg!(target_os = "windows") {
// "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8".to_string()
// } else {
// "/usr/local/cuda".to_string()
// }
// }
// }
// }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Block Manager for LLM KV Cache
//!
//! This module provides functionality for managing KV blocks in LLM attention
//! mechanisms. It handles storage allocation, block management, and safe access
//! patterns for both system memory and remote (NIXL) storage.
mod config;
mod state;
pub mod block;
pub mod events;
pub mod layout;
pub mod pool;
pub mod storage;
pub use crate::common::dtype::DType;
pub use block::{
nixl::{
AsBlockDescriptorSet, BlockDescriptorList, IsImmutable, IsMutable, MutabilityKind,
RemoteBlock,
},
transfer::{BlockTransferEngineV1, TransferRequestPut},
BasicMetadata, BlockMetadata, Blocks,
};
pub use config::*;
pub use layout::{nixl::NixlLayout, LayoutConfig, LayoutConfigBuilder, LayoutError, LayoutType};
pub use pool::BlockPool;
pub use storage::{
nixl::NixlRegisterableStorage, DeviceStorage, PinnedStorage, Storage, StorageAllocator,
};
pub use tokio_util::sync::CancellationToken;
use anyhow::{Context, Result};
use block::nixl::{BlockMutability, NixlBlockSet, RemoteBlocks, SerializedNixlBlockSet};
use derive_builder::Builder;
use nixl_sys::Agent as NixlAgent;
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use storage::nixl::MemType;
use validator::Validate;
/// Identifier for a worker (each block manager instance carries one).
pub type WorkerID = u64;
/// Convenience alias: a [KvBlockManager] specialized with [BasicMetadata].
pub type ReferenceBlockManager = KvBlockManager<BasicMetadata>;
/// Represents the different cache levels for KV blocks.
///
/// Ordered from closest to the compute (G1, GPU memory) to the most
/// remote tier (G4, remote NVMe).
// Derives added: a public enum should be `Debug`-printable, and a plain
// fieldless enum is trivially `Copy`/comparable/hashable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CacheLevel {
    /// Represents KV blocks in GPU memory
    G1,
    /// Represents KV blocks in CPU memory
    G2,
    /// Represents KV blocks in Local NVMe storage
    G3,
    /// Represents KV blocks in Remote NVMe storage
    G4,
}
// When we construct the pool:
// 1. instantiate the runtime,
// 2. build layout::LayoutConfigs for each of the requested storage types
// 3. register the layouts with the NIXL agent if enabled
// 4. construct a Blocks object for each layout providing a unique block_set_idx
// for each layout type.
// 5. initialize the pools for each set of blocks
/// Frontend handle to the KV block manager.
///
/// Owns the root cancellation token; the shared internal state runs on a
/// child token, so dropping this frontend cancels the root token and begins
/// shutdown of the internal state (see [KvBlockManager::new] and the `Drop`
/// impl).
pub struct KvBlockManager<Metadata: BlockMetadata> {
    // Shared internal state (pools, blockset import/export, worker id).
    state: Arc<state::KvBlockManagerState<Metadata>>,
    // Root token, cancelled when this frontend is dropped.
    cancellation_token: CancellationToken,
}
impl<Metadata: BlockMetadata> KvBlockManager<Metadata> {
    /// Create a new [KvBlockManager]
    ///
    /// The returned object is a frontend to the [KvBlockManager] which owns the cancellation
    /// token. When this object is dropped, the cancellation token is cancelled, beginning the
    /// graceful shutdown of the block manager's internal state.
    ///
    /// # Errors
    /// Propagates any error from constructing the internal state.
    pub fn new(config: KvBlockManagerConfig) -> Result<Self> {
        let mut config = config;

        // The frontend of the KvBlockManager will take ownership of the cancellation token
        // and will be responsible for cancelling the task when the KvBlockManager is dropped
        let cancellation_token = config.runtime.cancellation_token.clone();

        // The internal state will use a child token of the original token,
        // so cancelling the root token (on drop) also cancels the state.
        config.runtime.cancellation_token = cancellation_token.child_token();

        // Create the internal state
        let state = state::KvBlockManagerState::new(config)?;

        Ok(Self {
            state,
            cancellation_token,
        })
    }

    /// Exports the local blockset configuration as a serialized object.
    pub fn export_local_blockset(&self) -> Result<SerializedNixlBlockSet> {
        self.state.export_local_blockset()
    }

    /// Imports a remote blockset configuration from a serialized object.
    pub fn import_remote_blockset(
        &self,
        serialized_blockset: SerializedNixlBlockSet,
    ) -> Result<()> {
        self.state.import_remote_blockset(serialized_blockset)
    }

    /// Get a [`Vec<RemoteBlock<IsImmutable>>`] from a [`BlockDescriptorList`]
    pub fn get_remote_blocks_immutable(
        &self,
        bds: &BlockDescriptorList,
    ) -> Result<Vec<RemoteBlock<IsImmutable>>> {
        self.state.get_remote_blocks_immutable(bds)
    }

    /// Get a [`Vec<RemoteBlock<IsMutable>>`] from a [`BlockDescriptorList`]
    pub fn get_remote_blocks_mutable(
        &self,
        bds: &BlockDescriptorList,
    ) -> Result<Vec<RemoteBlock<IsMutable>>> {
        self.state.get_remote_blocks_mutable(bds)
    }

    /// Get a reference to the host block pool, if one was configured.
    pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
        self.state.host()
    }

    /// Get a reference to the device block pool, if one was configured.
    pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> {
        self.state.device()
    }

    /// Get the worker ID this manager was configured with.
    pub fn worker_id(&self) -> WorkerID {
        self.state.worker_id()
    }
}
impl<Metadata: BlockMetadata> Drop for KvBlockManager<Metadata> {
    fn drop(&mut self) {
        // Cancel the root token; the internal state observes its child
        // token and begins graceful shutdown.
        self.cancellation_token.cancel();
    }
}
// Tests require both CUDA and NIXL; gated behind the `testing-full` feature.
#[cfg(all(test, feature = "testing-full"))]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    // Atomic counter so every test instance gets a unique worker ID.
    static WORKER_ID: AtomicU64 = AtomicU64::new(1337);

    // Builds a small ReferenceBlockManager (16 host blocks, 8 device blocks)
    // with a fresh worker ID per call.
    fn create_reference_block_manager() -> ReferenceBlockManager {
        let worker_id = WORKER_ID.fetch_add(1, Ordering::SeqCst);
        let config = KvBlockManagerConfig::builder()
            .runtime(
                KvManagerRuntimeConfig::builder()
                    .worker_id(worker_id)
                    .build()
                    .unwrap(),
            )
            .model(
                KvManagerModelConfig::builder()
                    .num_layers(3)
                    .page_size(4)
                    .inner_dim(16)
                    .build()
                    .unwrap(),
            )
            .host_layout(
                KvManagerLayoutConfig::builder()
                    .num_blocks(16)
                    .allocator(storage::PinnedAllocator::default())
                    .build()
                    .unwrap(),
            )
            .device_layout(
                KvManagerLayoutConfig::builder()
                    .num_blocks(8)
                    .allocator(storage::DeviceAllocator::new(0).unwrap())
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();

        ReferenceBlockManager::new(config).unwrap()
    }

    // Construction should succeed inside an inherited tokio runtime.
    #[tokio::test]
    async fn test_reference_block_manager_inherited_async_runtime() {
        dynamo_runtime::logging::init();
        let _block_manager = create_reference_block_manager();
    }

    // Construction should also succeed with no async runtime (blocking path).
    #[test]
    fn test_reference_block_manager_blocking() {
        dynamo_runtime::logging::init();
        let _block_manager = create_reference_block_manager();
    }

    // This test mimics the behavior of two unique kvbm workers exchanging blocksets.
    // Each KvBlockManager is a unique worker in this test; each has its own resources
    // including its own worker_ids, nixl_agent, and block pools.
    //
    // This test is meant to mimic the behavior of the basic nixl integration test found here:
    // https://github.com/ai-dynamo/nixl/blob/main/src/bindings/rust/src/tests.rs
    #[tokio::test]
    async fn test_reference_block_managers() {
        dynamo_runtime::logging::init();

        // create two block managers - mimics two unique dynamo workers
        let kvbm_0 = create_reference_block_manager();
        let kvbm_1 = create_reference_block_manager();

        assert_ne!(kvbm_0.worker_id(), kvbm_1.worker_id());

        // in dynamo, we would exchange the blocksets via the discovery plane
        let blockset_0 = kvbm_0.export_local_blockset().unwrap();
        let blockset_1 = kvbm_1.export_local_blockset().unwrap();

        // in dynamo, we would be watching the discovery plane for remote blocksets
        kvbm_0.import_remote_blockset(blockset_1).unwrap();
        kvbm_1.import_remote_blockset(blockset_0).unwrap();

        // Worker 0
        // Allocate 4 mutable blocks on the host
        let blocks_0 = kvbm_0.host().unwrap().allocate_blocks(4).await.unwrap();

        // Create a BlockDescriptorList for the mutable blocks
        // let blockset_0 = BlockDescriptorList::from_mutable_blocks(&blocks_0).unwrap();
        let blockset_0 = blocks_0.as_block_descriptor_set().unwrap();

        // Worker 1
        // Create a RemoteBlock list from blockset_0
        let _blocks_1 = kvbm_1.host().unwrap().allocate_blocks(4).await.unwrap();
        let mut _remote_blocks_0 = kvbm_1.get_remote_blocks_mutable(&blockset_0).unwrap();

        // TODO(#967) - Enable with TransferEngine
        // // Create a TransferRequestPut for the mutable blocks
        // let transfer_request = TransferRequestPut::new(&blocks_0, &mut remote_blocks_0).unwrap();

        // // Validate blocks - this could be an expensive operation
        // // TODO: Create an ENV trigger debug flag which will call this on every transfer request
        // // In this case, we expect an error because we have overlapping blocks as we are sending to/from the same blocks
        // // because we are using the wrong target (artifact of the test setup allowing variables to cross what would be
        // // worker boundaries)
        // assert!(transfer_request.validate_blocks().is_err());

        // // This is a proper request - PUT from worker 1 (local) to worker 0 (remote)
        // let transfer_request = TransferRequestPut::new(&blocks_1, &mut remote_blocks_0).unwrap();
        // assert!(transfer_request.validate_blocks().is_ok());

        // // Execute the transfer request
        // transfer_request.execute().unwrap();

        // let mut put_request = PutRequestBuilder::<_, _>::builder();
        // put_request.from(&blocks_1).to(&mut remote_blocks_0);

        // // Create a Put request direct between two local blocks
        // // split the blocks into two vecs each with 2 blocks
        // let mut blocks_1 = blocks_1;
        // let slice_0 = blocks_1.split_off(2);
        // let mut slice_1 = blocks_1;

        // let transfer_request = TransferRequestPut::new(&slice_0, &mut slice_1).unwrap();
        // assert!(transfer_request.validate_blocks().is_ok());

        // // Execute the transfer request
        // transfer_request.execute().unwrap();
    }
}
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::HashMap,
sync::{Arc, Weak},
};
use super::super::events::{EventManager, EventReleaseManager, PublishHandle};
use super::state::BlockState;
use crate::tokens::{BlockHash, SequenceHash, TokenBlock};
use derive_getters::Getters;
#[derive(Debug, thiserror::Error)]
pub enum BlockRegistationError {
#[error("Block already registered")]
BlockAlreadyRegistered(SequenceHash),
#[error("Invalid state: {0}")]
InvalidState(String),
}
/// Error returned when an attempt is made to unregister a block that is still active.
///
/// Carries the sequence hash of the block whose registration handle is
/// still alive (i.e. still upgradable from the registry's weak reference).
#[derive(Debug, thiserror::Error)]
#[error("Failed to unregister block: {0}")]
pub struct UnregisterFailure(SequenceHash);
/// Registry mapping sequence hashes to live block registrations.
///
/// Entries are [`Weak`] references, so the registry never keeps a
/// [`RegistrationHandle`] alive by itself; an entry is considered stale
/// once the last strong reference to its handle is dropped.
#[derive()]
pub struct BlockRegistry {
    // Weak handles: `upgrade()` succeeding means the registration is live.
    blocks: HashMap<SequenceHash, Weak<RegistrationHandle>>,
    // Source of publish handles that emit register/remove events.
    event_manager: Arc<dyn EventManager>,
}
impl BlockRegistry {
    /// Creates an empty registry backed by the given event manager.
    pub fn new(event_manager: Arc<dyn EventManager>) -> Self {
        Self {
            blocks: HashMap::new(),
            event_manager,
        }
    }

    /// Returns `true` if a registration handle for `sequence_hash` is still alive.
    pub fn is_registered(&self, sequence_hash: SequenceHash) -> bool {
        match self.blocks.get(&sequence_hash) {
            Some(weak) => weak.upgrade().is_some(),
            None => false,
        }
    }

    /// Registers a block in the `Complete` state.
    ///
    /// On success the state is transitioned to `Registered` (which owns the
    /// [RegistrationHandle]) and a [PublishHandle] is returned to the caller.
    ///
    /// # Errors
    /// Fails when the block is in the `Reset` or `Partial` state, or when a
    /// live registration already exists for its sequence hash.
    pub fn register_block(
        &mut self,
        block_state: &mut BlockState,
    ) -> Result<PublishHandle, BlockRegistationError> {
        // Guard clauses: only a Complete block may proceed.
        let state = match block_state {
            BlockState::Reset => {
                return Err(BlockRegistationError::InvalidState(
                    "Block is in Reset state".to_string(),
                ))
            }
            BlockState::Partial(_partial) => {
                return Err(BlockRegistationError::InvalidState(
                    "Block is in Partial state".to_string(),
                ))
            }
            BlockState::Registered(registered) => {
                return Err(BlockRegistationError::BlockAlreadyRegistered(
                    registered.sequence_hash(),
                ))
            }
            BlockState::Complete(state) => state,
        };

        let sequence_hash = state.token_block().sequence_hash();

        // Reject when a live (upgradable) registration already exists.
        if self.is_registered(sequence_hash) {
            return Err(BlockRegistationError::BlockAlreadyRegistered(sequence_hash));
        }

        // Create the [RegistrationHandle] and [PublishHandle].
        let publish_handle =
            Self::create_publish_handle(state.token_block(), self.event_manager.clone());
        let reg_handle = publish_handle.remove_handle();

        // Track the handle weakly, then move the block into the Registered state.
        self.blocks
            .insert(sequence_hash, Arc::downgrade(&reg_handle));
        *block_state = BlockState::Registered(reg_handle);

        Ok(publish_handle)
    }

    /// Removes a stale registry entry for `sequence_hash`.
    ///
    /// Succeeds when no entry exists or the entry is no longer upgradable
    /// (the stale entry is pruned); fails when a registration is still alive.
    pub fn unregister_block(
        &mut self,
        sequence_hash: SequenceHash,
    ) -> Result<(), UnregisterFailure> {
        match self.blocks.get(&sequence_hash) {
            Some(weak) if weak.upgrade().is_some() => Err(UnregisterFailure(sequence_hash)),
            Some(_) => {
                self.blocks.remove(&sequence_hash);
                Ok(())
            }
            None => Ok(()),
        }
    }

    // Builds a RegistrationHandle for the token block and wraps it in a
    // PublishHandle tied to the same event manager.
    fn create_publish_handle(
        token_block: &TokenBlock,
        event_manager: Arc<dyn EventManager>,
    ) -> PublishHandle {
        let handle = RegistrationHandle::from_token_block(token_block, event_manager.clone());
        PublishHandle::new(handle, event_manager)
    }
}
/// Live registration of a block, identified by its hashes.
///
/// Dropping the handle notifies the release manager (see the `Drop` impl),
/// which is how the registry's weak entries become stale.
#[derive(Getters)]
pub struct RegistrationHandle {
    // Hash of this block's own tokens.
    #[getter(copy)]
    block_hash: BlockHash,

    // Hash of the full token sequence up to and including this block.
    #[getter(copy)]
    sequence_hash: SequenceHash,

    // Sequence hash of the parent block, if any (None for the first block).
    #[getter(copy)]
    parent_sequence_hash: Option<SequenceHash>,

    // Notified on drop; skipped by the getters derive (trait object).
    #[getter(skip)]
    release_manager: Arc<dyn EventReleaseManager>,
}
impl RegistrationHandle {
    // Copies the identifying hashes out of the token block and attaches the
    // release manager that will be notified when this handle is dropped.
    fn from_token_block(
        token_block: &TokenBlock,
        release_manager: Arc<dyn EventReleaseManager>,
    ) -> Self {
        Self {
            block_hash: token_block.block_hash(),
            sequence_hash: token_block.sequence_hash(),
            parent_sequence_hash: token_block.parent_sequence_hash(),
            release_manager,
        }
    }
}
impl std::fmt::Debug for RegistrationHandle {
    /// Manual impl because `release_manager` is a trait object without `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "RegistrationHandle {{ sequence_hash: {}; block_hash: {}; parent_sequence_hash: {:?} }}",
            self.sequence_hash, self.block_hash, self.parent_sequence_hash
        ))
    }
}
impl Drop for RegistrationHandle {
    fn drop(&mut self) {
        // Notify the release manager that this block's registration is gone;
        // this is what triggers the Remove event downstream.
        self.release_manager.block_release(self);
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::block_manager::events::tests::{EventType, MockEventManager};
use crate::tokens::{TokenBlockSequence, Tokens};
fn create_sequence() -> TokenBlockSequence {
let tokens = Tokens::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
// NOTE: 1337 was the original seed, so we are temporarily using that here to prove the logic has not changed
let sequence = TokenBlockSequence::new(tokens, 4, Some(1337_u64));
assert_eq!(sequence.blocks().len(), 2);
assert_eq!(sequence.current_block().len(), 2);
assert_eq!(sequence.blocks()[0].tokens(), &vec![1, 2, 3, 4]);
assert_eq!(sequence.blocks()[0].sequence_hash(), 14643705804678351452);
assert_eq!(sequence.blocks()[1].tokens(), &vec![5, 6, 7, 8]);
assert_eq!(sequence.blocks()[1].sequence_hash(), 4945711292740353085);
assert_eq!(sequence.current_block().tokens(), &vec![9, 10]);
sequence
}
#[test]
fn test_mock_event_manager_with_single_publish_handle() {
let sequence = create_sequence();
let (event_manager, mut rx) = MockEventManager::new();
let publish_handle =
BlockRegistry::create_publish_handle(&sequence.blocks()[0], event_manager.clone());
// no event should have been triggered
assert!(rx.try_recv().is_err());
// we shoudl get two events when this is dropped, since we never took ownership of the RegistrationHandle
drop(publish_handle);
// the first event should be a Register event
let events = rx.try_recv().unwrap();
assert_eq!(events.len(), 1);
assert_eq!(
events[0],
EventType::Register(sequence.blocks()[0].sequence_hash())
);
// the second event should be a Remove event
let events = rx.try_recv().unwrap();
assert_eq!(events.len(), 1);
assert_eq!(
events[0],
EventType::Remove(sequence.blocks()[0].sequence_hash())
);
// there should be no more events
assert!(rx.try_recv().is_err());
}
#[test]
fn test_mock_event_manager_single_publish_handle_removed() {
let sequence = create_sequence();
let block_to_test = &sequence.blocks()[0];
let expected_sequence_hash = block_to_test.sequence_hash();
let (event_manager, mut rx) = MockEventManager::new();
let publish_handle =
BlockRegistry::create_publish_handle(block_to_test, event_manager.clone());
// Remove the registration handle before dropping the publish handle
let reg_handle = publish_handle.remove_handle();
// no event should have been triggered yet
assert!(rx.try_recv().is_err());
// Drop the publish handle - it SHOULD trigger a Register event now because remove_handle doesn't disarm
drop(publish_handle);
let register_events = rx.try_recv().unwrap();
assert_eq!(
register_events.len(),
1,
"Register event should be triggered on PublishHandle drop"
);
assert_eq!(
register_events[0],
EventType::Register(expected_sequence_hash),
"Expected Register event"
);
// Drop the registration handle - this SHOULD trigger the Remove event
drop(reg_handle);
let events = rx.try_recv().unwrap();
assert_eq!(events.len(), 1);
assert_eq!(
events[0],
EventType::Remove(expected_sequence_hash),
"Only Remove event should be triggered"
);
// there should be no more events
assert!(rx.try_recv().is_err());
}
#[test]
fn test_mock_event_manager_publisher_multiple_handles_removed() {
let sequence = create_sequence();
let block1 = &sequence.blocks()[0];
let block2 = &sequence.blocks()[1];
let hash1 = block1.sequence_hash();
let hash2 = block2.sequence_hash();
let (event_manager, mut rx) = MockEventManager::new();
let mut publisher = event_manager.publisher();
let publish_handle1 = BlockRegistry::create_publish_handle(block1, event_manager.clone());
let publish_handle2 = BlockRegistry::create_publish_handle(block2, event_manager.clone());
// Remove handles before adding to publisher
let reg_handle1 = publish_handle1.remove_handle();
let reg_handle2 = publish_handle2.remove_handle();
// Add disarmed handles to publisher
publisher.take_handle(publish_handle1);
publisher.take_handle(publish_handle2);
// no events yet
assert!(rx.try_recv().is_err());
// Drop the publisher - should trigger a single Publish event with both Register events
drop(publisher);
let events = rx.try_recv().unwrap();
assert_eq!(
events.len(),
2,
"Should receive two Register events in one batch"
);
// Order isn't guaranteed, so check for both
assert!(events.contains(&EventType::Register(hash1)));
assert!(events.contains(&EventType::Register(hash2)));
// no more events immediately after publish
assert!(rx.try_recv().is_err());
// Drop registration handles individually - should trigger Remove events
drop(reg_handle1);
let events1 = rx.try_recv().unwrap();
assert_eq!(events1.len(), 1);
assert_eq!(events1[0], EventType::Remove(hash1));
drop(reg_handle2);
let events2 = rx.try_recv().unwrap();
assert_eq!(events2.len(), 1);
assert_eq!(events2[0], EventType::Remove(hash2));
// no more events
assert!(rx.try_recv().is_err());
}
#[test]
fn test_publisher_empty_drop() {
let (event_manager, mut rx) = MockEventManager::new();
let publisher = event_manager.publisher();
drop(publisher);
// No events should be sent
assert!(rx.try_recv().is_err());
}
#[test]
fn test_publisher_publish_multiple_times() {
let sequence = create_sequence();
let block1 = &sequence.blocks()[0];
let hash1 = block1.sequence_hash();
let (event_manager, mut rx) = MockEventManager::new();
let mut publisher = event_manager.publisher();
let publish_handle1 = BlockRegistry::create_publish_handle(block1, event_manager.clone());
publisher.take_handle(publish_handle1);
// First publish call
publisher.publish();
let events = rx.try_recv().unwrap();
assert_eq!(events.len(), 1);
assert_eq!(events[0], EventType::Register(hash1));
// The RegistrationHandle Arc was taken by the publisher and dropped after the publish call
// So, the Remove event should follow immediately.
let remove_events = rx.try_recv().unwrap();
assert_eq!(
remove_events.len(),
1,
"Remove event should be triggered after publish consumes the handle"
);
assert_eq!(
remove_events[0],
EventType::Remove(hash1),
"Expected Remove event"
);
// Second publish call (should do nothing as handles were taken)
publisher.publish();
assert!(rx.try_recv().is_err());
// Drop publisher (should also do nothing)
drop(publisher);
assert!(rx.try_recv().is_err());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use derive_getters::Getters;
use super::registry::RegistrationHandle;
use super::Result;
use crate::tokens::{PartialTokenBlock, SaltHash, Token, TokenBlock, Tokens};
/// Error returned when an operation is attempted while the block is in an
/// incompatible [`BlockState`]; the payload describes the expected state.
#[derive(Debug, thiserror::Error)]
#[error("Block state is invalid: {0}")]
pub struct BlockStateInvalid(pub String);
/// Lifecycle state of a block's token contents.
///
/// Typical progression: `Reset` -> `Partial` -> `Complete` -> `Registered`.
#[derive(Debug)]
pub enum BlockState {
    /// No tokens; the block can begin a new sequence.
    Reset,
    /// Accumulating tokens into a partial token block.
    Partial(PartialState),
    /// A committed, full token block.
    Complete(CompleteState),
    /// A committed block held via its registry registration handle.
    Registered(Arc<RegistrationHandle>),
}
impl BlockState {
    /// Transitions `Reset` -> `Partial`, creating a new token-sequence root
    /// with the given page size and salt hash.
    ///
    /// # Errors
    /// Returns [`BlockStateInvalid`] if the block is not in the `Reset` state.
    pub fn initialize_sequence(
        &mut self,
        page_size: usize,
        salt_hash: SaltHash,
    ) -> Result<(), BlockStateInvalid> {
        if !matches!(self, BlockState::Reset) {
            return Err(BlockStateInvalid("Block is not reset".to_string()));
        }
        let block = PartialTokenBlock::create_sequence_root(page_size, salt_hash);
        *self = BlockState::Partial(PartialState::new(block));
        Ok(())
    }

    /// Appends a single token; only valid in the `Partial` state.
    pub fn add_token(&mut self, token: Token) -> Result<()> {
        match self {
            BlockState::Partial(state) => Ok(state.block.push_token(token)?),
            _ => Err(BlockStateInvalid("Block is not partial".to_string()))?,
        }
    }

    /// Appends tokens to the partial block, returning the `Tokens` result of
    /// `push_tokens` (presumably the tokens that did not fit — confirm against
    /// `PartialTokenBlock::push_tokens`). Only valid in the `Partial` state.
    pub fn add_tokens(&mut self, tokens: Tokens) -> Result<Tokens> {
        match self {
            BlockState::Partial(state) => Ok(state.block.push_tokens(tokens)),
            _ => Err(BlockStateInvalid("Block is not partial".to_string()))?,
        }
    }

    /// Removes the most recently added token; only valid in the `Partial` state.
    pub fn pop_token(&mut self) -> Result<()> {
        match self {
            BlockState::Partial(state) => {
                state.block.pop_token()?;
                Ok(())
            }
            _ => Err(BlockStateInvalid("Block is not partial".to_string()))?,
        }
    }

    /// Removes the `count` most recently added tokens; only valid in the
    /// `Partial` state.
    pub fn pop_tokens(&mut self, count: usize) -> Result<()> {
        match self {
            BlockState::Partial(state) => {
                state.block.pop_tokens(count)?;
                Ok(())
            }
            _ => Err(BlockStateInvalid("Block is not partial".to_string()))?,
        }
    }

    /// Transitions `Partial` -> `Complete` by committing the partial block.
    pub fn commit(&mut self) -> Result<()> {
        match self {
            BlockState::Partial(state) => {
                let token_block = state.block.commit()?;
                *self = BlockState::Complete(CompleteState::new(token_block));
                Ok(())
            }
            _ => Err(BlockStateInvalid("Block is not partial".to_string()))?,
        }
    }

    /// Transitions `Reset` -> `Complete` using an externally constructed
    /// [`TokenBlock`].
    pub fn apply_token_block(&mut self, token_block: TokenBlock) -> Result<()> {
        match self {
            BlockState::Reset => {
                *self = BlockState::Complete(CompleteState::new(token_block));
                Ok(())
            }
            _ => Err(BlockStateInvalid("Block is not reset".to_string()))?,
        }
    }

    /// Returns the number of tokens currently in the block, or `None` for a
    /// `Registered` block (the count is not available in that state).
    pub fn len(&self) -> Option<usize> {
        match self {
            BlockState::Reset => Some(0),
            BlockState::Partial(state) => Some(state.block.len()),
            BlockState::Complete(state) => Some(state.token_block.tokens().len()),
            BlockState::Registered(_) => None,
        }
    }

    /// Returns the number of additional tokens that can be added.
    pub fn remaining(&self) -> usize {
        match self {
            BlockState::Partial(state) => state.block.remaining(),
            _ => 0, // Reset, Complete, Registered have 0 remaining capacity
        }
    }

    /// Returns true if the block contains no tokens.
    pub fn is_empty(&self) -> bool {
        match self {
            BlockState::Reset => true,
            BlockState::Partial(state) => state.block.is_empty(),
            BlockState::Complete(_) => false,   // Always full
            BlockState::Registered(_) => false, // Always full
        }
    }

    /// Returns the tokens for `Partial` and `Complete` blocks; `None` for
    /// `Reset` and `Registered` blocks.
    /// (The previous doc comment claimed Complete-or-Registered, which did
    /// not match the implementation.)
    pub fn tokens(&self) -> Option<&Tokens> {
        match self {
            BlockState::Reset | BlockState::Registered(_) => None,
            BlockState::Partial(state) => Some(state.block.tokens()),
            BlockState::Complete(state) => Some(state.token_block.tokens()),
        }
    }

    /// Returns true if the block is in the reset state.
    pub fn is_reset(&self) -> bool {
        matches!(self, BlockState::Reset)
    }

    /// Returns true if the block is in the complete or registered state.
    pub fn is_complete(&self) -> bool {
        matches!(self, BlockState::Complete(_) | BlockState::Registered(_))
    }

    /// Returns true if the block is in the registered state.
    pub fn is_registered(&self) -> bool {
        matches!(self, BlockState::Registered(_))
    }
}
/// State payload for [`BlockState::Partial`]: a token block still being filled.
#[derive(Debug)]
pub struct PartialState {
    // The in-progress token block.
    block: PartialTokenBlock,
}
impl PartialState {
    /// Wraps an in-progress token block.
    pub fn new(block: PartialTokenBlock) -> Self {
        Self { block }
    }
}
/// State payload for [`BlockState::Complete`]: a fully committed token block.
#[derive(Debug, Getters)]
pub struct CompleteState {
    // The committed token block (read accessor derived via `Getters`).
    token_block: TokenBlock,
}
impl CompleteState {
    /// Wraps a committed token block.
    pub fn new(token_block: TokenBlock) -> Self {
        Self { token_block }
    }
}
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use super::TransferError;
use crate::block_manager::storage::{DeviceStorage, PinnedStorage};
use anyhow::Result;
use cudarc::driver::result as cuda_result;
use std::ops::Range;
/// Common signature of the async CUDA memcpy wrappers (H2D/D2H/D2D); the copy
/// is enqueued on the supplied stream.
type CudaMemcpyFnPtr = unsafe fn(
    src_ptr: *const u8,
    dst_ptr: *mut u8,
    size: usize,
    stream: &CudaStream,
) -> Result<(), TransferError>;
/// Resolves the CUDA memcpy implementation for an async transfer strategy.
///
/// # Errors
/// Returns `TransferError::ExecutionError` for any strategy that is not one
/// of the CUDA async variants.
fn cuda_memcpy_fn_ptr(strategy: &TransferStrategy) -> Result<CudaMemcpyFnPtr, TransferError> {
    let memcpy_fn: CudaMemcpyFnPtr = match strategy {
        TransferStrategy::CudaAsyncH2D => cuda_memcpy_h2d,
        TransferStrategy::CudaAsyncD2H => cuda_memcpy_d2h,
        TransferStrategy::CudaAsyncD2D => cuda_memcpy_d2d,
        _ => {
            return Err(TransferError::ExecutionError(
                "Unsupported copy strategy for CUDA memcpy async".into(),
            ))
        }
    };
    Ok(memcpy_fn)
}
/// Copy a block from a source to a destination using CUDA memcpy
///
/// Dispatches on `strategy` to the matching async CUDA memcpy (H2D/D2H/D2D).
/// If both blocks are fully contiguous a single copy covers the whole block;
/// otherwise the copy falls back to a per-layer loop via [`copy_layers`].
/// The copy is enqueued on `stream` (asynchronous with respect to the host).
pub fn copy_block<'a, Source, Destination>(
    sources: &'a Source,
    destinations: &'a mut Destination,
    stream: &CudaStream,
    strategy: TransferStrategy,
) -> Result<(), TransferError>
where
    Source: BlockDataProvider,
    Destination: BlockDataProviderMut,
{
    let src_data = sources.block_data(private::PrivateToken);
    let dst_data = destinations.block_data_mut(private::PrivateToken);
    let memcpy_fn = cuda_memcpy_fn_ptr(&strategy)?;
    // In debug builds, verify the caller-supplied strategy matches the one
    // implied by the source/destination storage types.
    #[cfg(debug_assertions)]
    {
        let expected_strategy =
            expected_strategy::<Source::StorageType, Destination::StorageType>();
        assert_eq!(strategy, expected_strategy);
    }
    if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
        // Fast path: one memcpy covering the entire block.
        let src_view = src_data.block_view()?;
        let mut dst_view = dst_data.block_view_mut()?;
        debug_assert_eq!(src_view.size(), dst_view.size());
        // SAFETY: both views are live borrows of their blocks; sizes are
        // asserted equal in debug builds.
        unsafe {
            memcpy_fn(
                src_view.as_ptr(),
                dst_view.as_mut_ptr(),
                src_view.size(),
                stream,
            )?;
        }
    } else {
        // Non-contiguous layout: copy each layer individually.
        assert_eq!(src_data.num_layers(), dst_data.num_layers());
        copy_layers(
            0..src_data.num_layers(),
            sources,
            destinations,
            stream,
            strategy,
        )?;
    }
    Ok(())
}
/// Copy a range of layers from a source to a destination using CUDA memcpy
///
/// Each layer in `layer_range` is copied with its own async memcpy enqueued
/// on `stream`.
pub fn copy_layers<'a, Source, Destination>(
    layer_range: Range<usize>,
    sources: &'a Source,
    destinations: &'a mut Destination,
    stream: &CudaStream,
    strategy: TransferStrategy,
) -> Result<(), TransferError>
where
    Source: BlockDataProvider,
    Destination: BlockDataProviderMut,
{
    let src_data = sources.block_data(private::PrivateToken);
    let dst_data = destinations.block_data_mut(private::PrivateToken);
    let memcpy_fn = cuda_memcpy_fn_ptr(&strategy)?;
    // In debug builds, verify the caller-supplied strategy matches the one
    // implied by the source/destination storage types.
    #[cfg(debug_assertions)]
    {
        let expected_strategy =
            expected_strategy::<Source::StorageType, Destination::StorageType>();
        assert_eq!(strategy, expected_strategy);
    }
    for layer_idx in layer_range {
        let src_view = src_data.layer_view(layer_idx)?;
        let mut dst_view = dst_data.layer_view_mut(layer_idx)?;
        debug_assert_eq!(src_view.size(), dst_view.size());
        // SAFETY: per-layer views are live borrows; sizes are asserted equal
        // in debug builds.
        unsafe {
            memcpy_fn(
                src_view.as_ptr(),
                dst_view.as_mut_ptr(),
                src_view.size(),
                stream,
            )?;
        }
    }
    Ok(())
}
/// Determines the transfer strategy implied by a (source, destination) pair
/// of storage types; used to validate callers in debug builds.
// Allow dead code because it's used in debug assertions
#[allow(dead_code)]
fn expected_strategy<Source: Storage, Dest: Storage>() -> TransferStrategy {
    use std::any::TypeId;

    let src = TypeId::of::<Source>();
    let dst = TypeId::of::<Dest>();
    let pinned = TypeId::of::<PinnedStorage>();
    let device = TypeId::of::<DeviceStorage>();

    if src == pinned && dst == device {
        TransferStrategy::CudaAsyncH2D
    } else if src == device && dst == pinned {
        TransferStrategy::CudaAsyncD2H
    } else if src == device && dst == device {
        TransferStrategy::CudaAsyncD2D
    } else {
        TransferStrategy::Invalid
    }
}
/// H2D Implementation
///
/// Enqueues an asynchronous host-to-device copy of `size` bytes on `stream`.
///
/// # Safety
/// `src_ptr` must point to `size` readable bytes of host memory and `dst_ptr`
/// to `size` writable bytes of device memory; both allocations must remain
/// valid until the stream operation completes.
#[inline(always)]
unsafe fn cuda_memcpy_h2d(
    src_ptr: *const u8,
    dst_ptr: *mut u8,
    size: usize,
    stream: &CudaStream,
) -> Result<(), TransferError> {
    debug_assert!(!src_ptr.is_null(), "Source host pointer is null");
    debug_assert!(!dst_ptr.is_null(), "Destination device pointer is null");
    // NOTE(review): the message previously said "D2D copy" — a copy/paste
    // error; this is the H2D path. The overlap check assumes host and device
    // pointers share one virtual address space (CUDA UVA) — confirm.
    debug_assert!(
        (src_ptr as usize + size <= dst_ptr as usize)
            || (dst_ptr as usize + size <= src_ptr as usize),
        "Source and destination memory regions must not overlap for H2D copy"
    );
    let src_slice = std::slice::from_raw_parts(src_ptr, size);
    cuda_result::memcpy_htod_async(dst_ptr as u64, src_slice, stream.cu_stream())
        .map_err(|e| TransferError::ExecutionError(format!("CUDA H2D memcpy failed: {}", e)))?;
    Ok(())
}
/// D2H Implementation
///
/// Enqueues an asynchronous device-to-host copy of `size` bytes on `stream`.
///
/// # Safety
/// `src_ptr` must point to `size` readable bytes of device memory and
/// `dst_ptr` to `size` writable bytes of host memory; both allocations must
/// remain valid until the stream operation completes.
#[inline(always)]
unsafe fn cuda_memcpy_d2h(
    src_ptr: *const u8,
    dst_ptr: *mut u8,
    size: usize,
    stream: &CudaStream,
) -> Result<(), TransferError> {
    debug_assert!(!src_ptr.is_null(), "Source device pointer is null");
    debug_assert!(!dst_ptr.is_null(), "Destination host pointer is null");
    // NOTE(review): the message previously said "D2D copy" — a copy/paste
    // error; this is the D2H path. The overlap check assumes host and device
    // pointers share one virtual address space (CUDA UVA) — confirm.
    debug_assert!(
        (src_ptr as usize + size <= dst_ptr as usize)
            || (dst_ptr as usize + size <= src_ptr as usize),
        "Source and destination memory regions must not overlap for D2H copy"
    );
    let dst_slice = std::slice::from_raw_parts_mut(dst_ptr, size);
    cuda_result::memcpy_dtoh_async(dst_slice, src_ptr as u64, stream.cu_stream())
        .map_err(|e| TransferError::ExecutionError(format!("CUDA D2H memcpy failed: {}", e)))?;
    Ok(())
}
/// D2D Implementation
///
/// Enqueues an asynchronous device-to-device copy of `size` bytes on `stream`.
///
/// # Safety
/// Both pointers must reference `size` bytes of valid device memory that
/// remains alive until the stream operation completes; the two regions must
/// not overlap.
#[inline(always)]
unsafe fn cuda_memcpy_d2d(
    src_ptr: *const u8,
    dst_ptr: *mut u8,
    size: usize,
    stream: &CudaStream,
) -> Result<(), TransferError> {
    debug_assert!(!src_ptr.is_null(), "Source device pointer is null");
    debug_assert!(!dst_ptr.is_null(), "Destination device pointer is null");
    debug_assert!(
        (src_ptr as usize + size <= dst_ptr as usize)
            || (dst_ptr as usize + size <= src_ptr as usize),
        "Source and destination device memory regions must not overlap for D2D copy"
    );
    cuda_result::memcpy_dtod_async(dst_ptr as u64, src_ptr as u64, size, stream.cu_stream())
        .map_err(|e| TransferError::ExecutionError(format!("CUDA D2D memcpy failed: {}", e)))?;
    Ok(())
}
#[cfg(all(test, feature = "testing-cuda"))]
mod tests {
    use super::*;
    use crate::block_manager::storage::{
        DeviceAllocator, PinnedAllocator, StorageAllocator, StorageMemset,
    };
    /// Round-trips a byte pattern host -> device -> host, exercising the H2D
    /// and D2H wrappers plus the stream synchronization they require.
    #[test]
    fn test_memset_and_transfer() {
        // Create allocators
        let device_allocator = DeviceAllocator::default();
        let pinned_allocator = PinnedAllocator::default();
        let ctx = device_allocator.ctx().clone();
        // Create CUDA stream
        let stream = ctx.new_stream().unwrap();
        // Allocate host and device memory
        let mut host = pinned_allocator.allocate(1024).unwrap();
        let mut device = device_allocator.allocate(1024).unwrap();
        // Set a pattern in host memory
        StorageMemset::memset(&mut host, 42, 0, 1024).unwrap();
        // Verify host memory was set correctly
        unsafe {
            let ptr = host.as_ptr();
            let slice = std::slice::from_raw_parts(ptr, 1024);
            assert!(slice.iter().all(|&x| x == 42));
        }
        // Copy host to device
        unsafe {
            cuda_memcpy_h2d(host.as_ptr(), device.as_mut_ptr(), 1024, stream.as_ref()).unwrap();
        }
        // Synchronize to ensure H2D copy is complete
        stream.synchronize().unwrap();
        // Clear host memory
        StorageMemset::memset(&mut host, 0, 0, 1024).unwrap();
        // Verify host memory was cleared
        unsafe {
            let ptr = host.as_ptr();
            let slice = std::slice::from_raw_parts(ptr, 1024);
            assert!(slice.iter().all(|&x| x == 0));
        }
        // Copy back from device to host
        unsafe {
            cuda_memcpy_d2h(device.as_ptr(), host.as_mut_ptr(), 1024, stream.as_ref()).unwrap();
        }
        // Synchronize to ensure D2H copy is complete before verifying
        stream.synchronize().unwrap();
        // Verify the original pattern was restored
        unsafe {
            let ptr = host.as_ptr();
            let slice = std::slice::from_raw_parts(ptr, 1024);
            assert!(slice.iter().all(|&x| x == 42));
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
/// Copy a block from a source to a destination using memcpy
///
/// Uses a single `copy_nonoverlapping` when both blocks are fully
/// contiguous; otherwise falls back to a per-layer copy via [`copy_layers`].
pub fn copy_block<'a, Source, Destination>(
    sources: &'a Source,
    destinations: &'a mut Destination,
) -> Result<(), TransferError>
where
    Source: ReadableBlock,
    Destination: WritableBlock,
{
    let src_data = sources.block_data(private::PrivateToken);
    let dst_data = destinations.block_data_mut(private::PrivateToken);
    if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
        // Fast path: one copy covering the entire block.
        let src_view = src_data.block_view()?;
        let mut dst_view = dst_data.block_view_mut()?;
        debug_assert_eq!(src_view.size(), dst_view.size());
        // SAFETY: views are live borrows of two distinct blocks; sizes are
        // asserted equal in debug builds.
        unsafe {
            memcpy(src_view.as_ptr(), dst_view.as_mut_ptr(), src_view.size());
        }
    } else {
        // Non-contiguous layout: copy layer by layer.
        assert_eq!(src_data.num_layers(), dst_data.num_layers());
        copy_layers(0..src_data.num_layers(), sources, destinations)?;
    }
    Ok(())
}
/// Copy a range of layers from a source to a destination using memcpy
///
/// Each layer in `layer_range` is copied with its own `copy_nonoverlapping`.
pub fn copy_layers<'a, Source, Destination>(
    layer_range: Range<usize>,
    sources: &'a Source,
    // <Source as ReadableBlock>::StorageType: SystemAccessible + Local,
    destinations: &'a mut Destination,
) -> Result<(), TransferError>
where
    Source: ReadableBlock,
    Destination: WritableBlock,
    // <Destination as WritableBlock>::StorageType: SystemAccessible + Local,
{
    let src_data = sources.block_data(private::PrivateToken);
    let dst_data = destinations.block_data_mut(private::PrivateToken);
    for layer_idx in layer_range {
        let src_view = src_data.layer_view(layer_idx)?;
        let mut dst_view = dst_data.layer_view_mut(layer_idx)?;
        debug_assert_eq!(src_view.size(), dst_view.size());
        // SAFETY: per-layer views are live borrows of two distinct blocks;
        // sizes are asserted equal in debug builds.
        unsafe {
            memcpy(src_view.as_ptr(), dst_view.as_mut_ptr(), src_view.size());
        }
    }
    Ok(())
}
/// Thin wrapper over `ptr::copy_nonoverlapping` with a debug-build overlap
/// check.
///
/// # Safety
/// `src_ptr` must be readable and `dst_ptr` writable for `size` bytes, and
/// the two regions must not overlap.
#[inline(always)]
unsafe fn memcpy(src_ptr: *const u8, dst_ptr: *mut u8, size: usize) {
    debug_assert!(
        (src_ptr as usize + size <= dst_ptr as usize)
            || (dst_ptr as usize + size <= src_ptr as usize),
        "Source and destination memory regions must not overlap for copy_nonoverlapping"
    );
    std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, size);
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use anyhow::Result;
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp};
use std::ops::Range;
/// Write a block from a local source to a (possibly remote) destination via a
/// NIXL write transfer. (The previous doc comment said "CUDA memcpy", which
/// did not match the implementation.)
///
/// If both blocks are fully contiguous the block is written with a single
/// descriptor pair; otherwise the write falls back to a per-layer transfer
/// via [`write_layers_to`]. `notify` optionally attaches a notification
/// message to the transfer.
pub fn write_block_to<'a, Source, Destination>(
    src: &'a Source,
    dst: &'a mut Destination,
    ctx: &TransferContext,
    notify: Option<String>,
) -> Result<()>
where
    Source: BlockDataProvider,
    Destination: BlockDataProviderMut,
{
    let src_data = src.block_data(private::PrivateToken);
    let dst_data = dst.block_data_mut(private::PrivateToken);
    if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
        let nixl_agent = ctx.nixl_agent().expect("NIXL agent not found");
        let remote_worker_id = dst_data.worker_id.to_string();
        let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
        let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
        let src_desc = src_data.block_view()?.as_nixl_descriptor();
        let dst_desc = dst_data.block_view_mut()?.as_nixl_descriptor_mut();
        // SAFETY: the address/size/device-id triples come from live block
        // views of the two blocks.
        unsafe {
            src_dl.add_desc(
                src_desc.as_ptr() as usize,
                src_desc.size(),
                src_desc.device_id(),
            )?;
            dst_dl.add_desc(
                dst_desc.as_ptr() as usize,
                dst_desc.size(),
                dst_desc.device_id(),
            )?;
        }
        // NOTE(review): unlike write_layers_to, the request is created with
        // `None` opt-args and the notification args are only supplied at post
        // time — confirm this is sufficient for NIXL notification delivery.
        let xfer_req =
            nixl_agent.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &remote_worker_id, None)?;
        let mut xfer_args = OptArgs::new()?;
        if let Some(notify) = notify {
            xfer_args.set_has_notification(true)?;
            xfer_args.set_notification_message(notify.as_bytes())?;
        }
        let mut status = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
        // Busy-polls until the transfer completes; blocks the calling thread.
        tracing::span!(tracing::Level::DEBUG, "Waiting for transfer to complete").in_scope(|| {
            while status {
                status = nixl_agent.get_xfer_status(&xfer_req).unwrap();
            }
        });
    } else {
        // Non-contiguous layout: one descriptor pair per layer instead.
        assert_eq!(src_data.num_layers(), dst_data.num_layers());
        write_layers_to(0..src_data.num_layers(), src, dst, ctx, notify)?;
    }
    Ok(())
}
/// Write a range of layers from a local source to a (possibly remote)
/// destination via a single NIXL write transfer. (The previous doc comment
/// said "CUDA memcpy", which did not match the implementation.)
///
/// One descriptor pair per layer is accumulated into the transfer lists, so
/// the whole range is posted as one request. `notify` optionally attaches a
/// notification message.
pub fn write_layers_to<'a, Source, Destination>(
    layer_range: Range<usize>,
    src: &'a Source,
    dst: &'a mut Destination,
    ctx: &TransferContext,
    notify: Option<String>,
) -> Result<()>
where
    Source: BlockDataProvider,
    Destination: BlockDataProviderMut,
{
    let src_data = src.block_data(private::PrivateToken);
    let dst_data = dst.block_data_mut(private::PrivateToken);
    let nixl_agent = ctx.nixl_agent().expect("NIXL agent not found");
    let remote_worker_id = dst_data.worker_id.to_string();
    let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
    let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
    // #[cfg(debug_assertions)]
    // {
    //     let expected_strategy = <<Source as BlockDataProvider>::StorageType as WriteToStrategy<
    //         Destination::StorageType,
    //     >>::write_to_strategy();
    //     assert_eq!(strategy, expected_strategy);
    // }
    for layer_idx in layer_range {
        let src_view = src_data.layer_view(layer_idx)?;
        let mut dst_view = dst_data.layer_view_mut(layer_idx)?;
        debug_assert_eq!(src_view.size(), dst_view.size());
        let src_desc = src_view.as_nixl_descriptor();
        let dst_desc = dst_view.as_nixl_descriptor_mut();
        // SAFETY: the address/size/device-id triples come from live layer
        // views of the two blocks.
        unsafe {
            src_dl.add_desc(
                src_desc.as_ptr() as usize,
                src_desc.size(),
                src_desc.device_id(),
            )?;
            dst_dl.add_desc(
                dst_desc.as_ptr() as usize,
                dst_desc.size(),
                dst_desc.device_id(),
            )?;
        }
    }
    let mut xfer_args = OptArgs::new()?;
    if let Some(notify) = notify {
        xfer_args.set_has_notification(true)?;
        xfer_args.set_notification_message(notify.as_bytes())?;
    }
    let xfer_req = nixl_agent.create_xfer_req(
        XferOp::Write,
        &src_dl,
        &dst_dl,
        &remote_worker_id,
        Some(&xfer_args),
    )?;
    let mut status = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
    // Busy-polls until the transfer completes; blocks the calling thread.
    tracing::span!(tracing::Level::DEBUG, "Waiting for transfer to complete").in_scope(|| {
        while status {
            status = nixl_agent.get_xfer_status(&xfer_req).unwrap();
        }
    });
    Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! This module implements the `WriteToStrategy` and `ReadFromStrategy` traits
//! for the common storage types.
use super::*;
/// System -> System copies use a plain host memcpy.
impl WriteToStrategy<SystemStorage> for SystemStorage {
    #[inline(always)]
    fn write_to_strategy() -> TransferStrategy {
        TransferStrategy::Memcpy
    }
}
/// System -> Pinned copies use a plain host memcpy.
impl WriteToStrategy<PinnedStorage> for SystemStorage {
    #[inline(always)]
    fn write_to_strategy() -> TransferStrategy {
        TransferStrategy::Memcpy
    }
}
/// System -> Device copies use the blocking CUDA H2D path (only pinned
/// sources map to the async variant here).
impl WriteToStrategy<DeviceStorage> for SystemStorage {
    #[inline(always)]
    fn write_to_strategy() -> TransferStrategy {
        TransferStrategy::CudaBlockingH2D
    }
}
/// Pinned -> System copies use a plain host memcpy.
impl WriteToStrategy<SystemStorage> for PinnedStorage {
    #[inline(always)]
    fn write_to_strategy() -> TransferStrategy {
        TransferStrategy::Memcpy
    }
}
/// Pinned -> Pinned copies use a plain host memcpy.
impl WriteToStrategy<PinnedStorage> for PinnedStorage {
    #[inline(always)]
    fn write_to_strategy() -> TransferStrategy {
        TransferStrategy::Memcpy
    }
}
/// Pinned -> Device copies use the async CUDA H2D path.
impl WriteToStrategy<DeviceStorage> for PinnedStorage {
    #[inline(always)]
    fn write_to_strategy() -> TransferStrategy {
        TransferStrategy::CudaAsyncH2D
    }
}
/// Device -> System copies use the blocking CUDA D2H path (only pinned
/// destinations map to the async variant here).
impl WriteToStrategy<SystemStorage> for DeviceStorage {
    #[inline(always)]
    fn write_to_strategy() -> TransferStrategy {
        TransferStrategy::CudaBlockingD2H
    }
}
/// Device -> Pinned copies use the async CUDA D2H path.
impl WriteToStrategy<PinnedStorage> for DeviceStorage {
    #[inline(always)]
    fn write_to_strategy() -> TransferStrategy {
        TransferStrategy::CudaAsyncD2H
    }
}
/// Device -> Device copies use the async CUDA D2D path.
impl WriteToStrategy<DeviceStorage> for DeviceStorage {
    #[inline(always)]
    fn write_to_strategy() -> TransferStrategy {
        TransferStrategy::CudaAsyncD2D
    }
}
/// Any local storage writes to NIXL (remote) storage via a NIXL write.
impl<S: Storage + Local> WriteToStrategy<NixlStorage> for S {
    #[inline(always)]
    fn write_to_strategy() -> TransferStrategy {
        TransferStrategy::NixlWrite
    }
}
/// Reading into `SystemStorage` from a local `S` reuses `S`'s write strategy.
impl<S> ReadFromStrategy<S> for SystemStorage
where
    S: WriteToStrategy<SystemStorage> + Storage + Local,
{
    #[inline(always)]
    fn read_from_strategy() -> TransferStrategy {
        S::write_to_strategy()
    }
}
/// Reading into `PinnedStorage` from a local `S` reuses `S`'s write strategy.
impl<S> ReadFromStrategy<S> for PinnedStorage
where
    S: WriteToStrategy<PinnedStorage> + Storage + Local,
{
    #[inline(always)]
    fn read_from_strategy() -> TransferStrategy {
        S::write_to_strategy()
    }
}
/// Reading into `DeviceStorage` from a local `S` reuses `S`'s write strategy.
impl<S> ReadFromStrategy<S> for DeviceStorage
where
    S: WriteToStrategy<DeviceStorage> + Storage + Local,
{
    #[inline(always)]
    fn read_from_strategy() -> TransferStrategy {
        S::write_to_strategy()
    }
}
/// Any local storage reads from NIXL (remote) storage via a NIXL read.
impl<S: Storage + Local> ReadFromStrategy<NixlStorage> for S {
    #[inline(always)]
    fn read_from_strategy() -> TransferStrategy {
        TransferStrategy::NixlRead
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Exhaustively checks the write-strategy matrix for local source types.
    #[test]
    fn write_to_strategy() {
        // System to ...
        assert_eq!(
            <SystemStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
            TransferStrategy::Memcpy
        );
        assert_eq!(
            <SystemStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
            TransferStrategy::Memcpy
        );
        assert_eq!(
            <SystemStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
            TransferStrategy::CudaBlockingH2D
        );
        assert_eq!(
            <SystemStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
            TransferStrategy::NixlWrite
        );
        // Pinned to ...
        assert_eq!(
            <PinnedStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
            TransferStrategy::Memcpy
        );
        assert_eq!(
            <PinnedStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
            TransferStrategy::Memcpy
        );
        assert_eq!(
            <PinnedStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
            TransferStrategy::CudaAsyncH2D
        );
        assert_eq!(
            <PinnedStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
            TransferStrategy::NixlWrite
        );
        // Device to ...
        assert_eq!(
            <DeviceStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
            TransferStrategy::CudaBlockingD2H
        );
        assert_eq!(
            <DeviceStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
            TransferStrategy::CudaAsyncD2H
        );
        assert_eq!(
            <DeviceStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
            TransferStrategy::CudaAsyncD2D
        );
        assert_eq!(
            <DeviceStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
            TransferStrategy::NixlWrite
        );
        // Nixl to ... should fail to compile
        // assert_eq!(
        //     <NixlStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
        //     TransferStrategy::Invalid
        // );
        // assert_eq!(
        //     <NixlStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
        //     TransferStrategy::Invalid
        // );
        // assert_eq!(
        //     <NixlStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
        //     TransferStrategy::Invalid
        // );
        // assert_eq!(
        //     <NixlStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
        //     TransferStrategy::Invalid
        // );
    }
    /// Checks that read strategies mirror the corresponding write strategies
    /// (reading from `S` reuses `S`'s write strategy; NIXL reads map to
    /// `NixlRead`).
    #[test]
    fn read_from_strategy() {
        // System to ...
        assert_eq!(
            <SystemStorage as ReadFromStrategy<SystemStorage>>::read_from_strategy(),
            TransferStrategy::Memcpy
        );
        assert_eq!(
            <SystemStorage as ReadFromStrategy<PinnedStorage>>::read_from_strategy(),
            TransferStrategy::Memcpy
        );
        assert_eq!(
            <SystemStorage as ReadFromStrategy<DeviceStorage>>::read_from_strategy(),
            TransferStrategy::CudaBlockingD2H
        );
        assert_eq!(
            <SystemStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
            TransferStrategy::NixlRead
        );
        // Pinned to ...
        assert_eq!(
            <PinnedStorage as ReadFromStrategy<SystemStorage>>::read_from_strategy(),
            TransferStrategy::Memcpy
        );
        assert_eq!(
            <PinnedStorage as ReadFromStrategy<PinnedStorage>>::read_from_strategy(),
            TransferStrategy::Memcpy
        );
        assert_eq!(
            <PinnedStorage as ReadFromStrategy<DeviceStorage>>::read_from_strategy(),
            TransferStrategy::CudaAsyncD2H
        );
        assert_eq!(
            <PinnedStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
            TransferStrategy::NixlRead
        );
        // Device to ...
        assert_eq!(
            <DeviceStorage as ReadFromStrategy<SystemStorage>>::read_from_strategy(),
            TransferStrategy::CudaBlockingH2D
        );
        assert_eq!(
            <DeviceStorage as ReadFromStrategy<PinnedStorage>>::read_from_strategy(),
            TransferStrategy::CudaAsyncH2D
        );
        assert_eq!(
            <DeviceStorage as ReadFromStrategy<DeviceStorage>>::read_from_strategy(),
            TransferStrategy::CudaAsyncD2D
        );
        assert_eq!(
            <DeviceStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
            TransferStrategy::NixlRead
        );
        // Nixl to ... should fail to compile
        // assert_eq!(
        //     <NixlStorage as ReadFromStrategy<SystemStorage>>::read_from_strategy(),
        //     TransferStrategy::Invalid
        // );
        //
        // assert_eq!(
        //     <NixlStorage as ReadFromStrategy<PinnedStorage>>::read_from_strategy(),
        //     TransferStrategy::Invalid
        // );
        //
        // assert_eq!(
        //     <NixlStorage as ReadFromStrategy<DeviceStorage>>::read_from_strategy(),
        //     TransferStrategy::Invalid
        // );
        //
        // assert_eq!(
        //     <NixlStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
        //     TransferStrategy::Invalid
        // );
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Block storage management.
//!
//! This module provides the implementation for managing collections of blocks
//! and their storage. It handles the relationship between storage, layout,
//! and individual blocks.
use super::{BlockData, BlockError, Storage};
/// Marker trait for the granularity a memory view covers (block vs. layer).
///
/// Implementors are zero-sized tag types consumed only through `PhantomData`,
/// so the bound set is just what a well-behaved tag type must support.
/// Note: `Sized` is named directly — it is in the prelude, so the
/// `std::marker::` path is redundant.
pub trait Kind: Sized + std::fmt::Debug + Clone + Copy + Send + Sync {}

/// Tag type selecting block-granularity views.
#[derive(Debug, Clone, Copy)]
pub struct BlockKind;
impl Kind for BlockKind {}

/// Tag type selecting layer-granularity views.
#[derive(Debug, Clone, Copy)]
pub struct LayerKind;
impl Kind for LayerKind {}
/// Immutable view spanning a whole block.
pub type BlockView<'a, S> = MemoryView<'a, S, BlockKind>;
/// Mutable view spanning a whole block.
pub type BlockViewMut<'a, S> = MemoryViewMut<'a, S, BlockKind>;
/// Immutable view spanning a single layer of a block.
pub type LayerView<'a, S> = MemoryView<'a, S, LayerKind>;
/// Mutable view spanning a single layer of a block.
pub type LayerViewMut<'a, S> = MemoryViewMut<'a, S, LayerKind>;
/// Storage view that provides safe access to a region of storage
///
/// Holds a shared borrow of the owning [`BlockData`], which both ties the
/// view's lifetime to the block and gives access to layout metadata (used
/// by the NIXL descriptor impls below).
#[derive(Debug)]
pub struct MemoryView<'a, S: Storage, K: Kind> {
    // Borrow of the owning block; underscore-prefixed because it is unused
    // in this module's inherent methods (the NIXL impls do read it).
    _block_data: &'a BlockData<S>,
    // Absolute start address of the viewed region.
    addr: usize,
    // Length of the viewed region in bytes.
    size: usize,
    // Zero-sized tag: block- vs. layer-granularity.
    kind: std::marker::PhantomData<K>,
}
impl<'a, S: Storage, K: Kind> MemoryView<'a, S, K> {
    /// Construct a view over `[addr, addr + size)` within the block's storage.
    ///
    /// # Safety
    /// The caller must guarantee that:
    /// - `addr + size <= storage.size()`
    /// - the view is dropped before the backing storage
    pub(crate) unsafe fn new(
        _block_data: &'a BlockData<S>,
        addr: usize,
        size: usize,
    ) -> Result<Self, BlockError> {
        let view = Self {
            _block_data,
            addr,
            size,
            kind: std::marker::PhantomData,
        };
        Ok(view)
    }

    /// Raw pointer to the first byte of the view.
    ///
    /// # Safety
    /// The caller must guarantee that:
    /// - the pointer is never dereferenced after the view is dropped
    /// - access follows the storage's thread-safety rules
    pub unsafe fn as_ptr(&self) -> *const u8 {
        self.addr as *const u8
    }

    /// Length of the view in bytes.
    pub fn size(&self) -> usize {
        self.size
    }
}
/// Mutable storage view that provides exclusive access to a region of storage
///
/// Holds an exclusive borrow of the owning [`BlockData`], so while this view
/// exists no other view of the same block can be created.
#[derive(Debug)]
pub struct MemoryViewMut<'a, S: Storage, K: Kind> {
    // Exclusive borrow of the owning block; underscore-prefixed because it
    // is unused in this module's inherent methods (the NIXL impls read it).
    _block_data: &'a mut BlockData<S>,
    // Absolute start address of the viewed region.
    addr: usize,
    // Length of the viewed region in bytes.
    size: usize,
    // Zero-sized tag: block- vs. layer-granularity.
    kind: std::marker::PhantomData<K>,
}
impl<'a, S, K> MemoryViewMut<'a, S, K>
where
    S: Storage,
    K: Kind,
{
    /// Construct an exclusive view over `[addr, addr + size)`.
    ///
    /// # Safety
    /// The caller must guarantee that:
    /// - `addr + size <= storage.size()`
    /// - the view is dropped before the backing storage
    /// - no other view aliases this region
    pub(crate) unsafe fn new(
        _block_data: &'a mut BlockData<S>,
        addr: usize,
        size: usize,
    ) -> Result<Self, BlockError> {
        let view = Self {
            _block_data,
            addr,
            size,
            kind: std::marker::PhantomData,
        };
        Ok(view)
    }

    /// Length of the view in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Raw mutable pointer to the first byte of the view.
    ///
    /// # Safety
    /// The caller must guarantee that:
    /// - the pointer is never used after the view is dropped
    /// - no other references to the region are live while it is in use
    /// - access follows the storage's thread-safety rules
    pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
        self.addr as *mut u8
    }
}
//! NIXL integration for memory views: exposes views as raw memory regions
//! and NIXL descriptors, delegating metadata to the block's layout.
mod nixl {
    use super::*;
    use super::super::nixl::*;
    pub use nixl_sys::{MemType, MemoryRegion, NixlDescriptor};

    // Expose an immutable view as a raw NIXL memory region.
    impl<S: Storage, K: Kind> MemoryRegion for MemoryView<'_, S, K> {
        unsafe fn as_ptr(&self) -> *const u8 {
            self.addr as *const u8
        }
        fn size(&self) -> usize {
            self.size()
        }
    }

    // NIXL metadata for an immutable view; both values come from the
    // layout's storage type, so the storage itself must be a NixlDescriptor.
    impl<S, K> NixlDescriptor for MemoryView<'_, S, K>
    where
        S: Storage + NixlDescriptor,
        K: Kind,
    {
        fn mem_type(&self) -> MemType {
            self._block_data.layout.storage_type().nixl_mem_type()
        }
        fn device_id(&self) -> u64 {
            self._block_data.layout.storage_type().nixl_device_id()
        }
    }

    // Expose a mutable view as a raw NIXL memory region.
    impl<S: Storage, K: Kind> MemoryRegion for MemoryViewMut<'_, S, K> {
        unsafe fn as_ptr(&self) -> *const u8 {
            self.addr as *const u8
        }
        fn size(&self) -> usize {
            self.size()
        }
    }

    // NIXL metadata for a mutable view; mirrors the immutable impl above.
    impl<S: Storage, K: Kind> NixlDescriptor for MemoryViewMut<'_, S, K>
    where
        S: Storage + NixlDescriptor,
        K: Kind,
    {
        fn mem_type(&self) -> MemType {
            self._block_data.layout.storage_type().nixl_mem_type()
        }
        fn device_id(&self) -> u64 {
            self._block_data.layout.storage_type().nixl_device_id()
        }
    }

    impl<'a, S, K> MemoryView<'a, S, K>
    where
        S: Storage + NixlDescriptor, // Ensure the underlying storage is a NixlDescriptor
        K: Kind,
    {
        /// Creates an immutable NIXL memory descriptor from this view.
        pub fn as_nixl_descriptor(&self) -> NixlMemoryDescriptor<'a, K, IsImmutable> {
            NixlMemoryDescriptor::new(
                self.addr as u64,                // Address from the view
                self.size(),                     // Size from the view
                NixlDescriptor::mem_type(self),  // Delegate to self's NixlDescriptor impl
                NixlDescriptor::device_id(self), // Delegate to self's NixlDescriptor impl
            )
        }
    }

    impl<'a, S, K> MemoryViewMut<'a, S, K>
    where
        S: Storage + NixlDescriptor,
        K: Kind,
    {
        /// Creates a mutable NIXL memory descriptor from this view.
        ///
        /// Takes `&mut self`, so the borrow checker prevents constructing the
        /// descriptor while any other borrow of this view is live.
        // NOTE(review): the returned descriptor carries lifetime `'a` (the
        // view's borrow of the block), not the lifetime of this `&mut self`
        // borrow — confirm repeated calls cannot yield two live mutable
        // descriptors for the same region.
        pub fn as_nixl_descriptor_mut(&mut self) -> NixlMemoryDescriptor<'a, K, IsMutable> {
            NixlMemoryDescriptor::new(
                self.addr as u64,
                self.size(),
                NixlDescriptor::mem_type(self),  // Delegate to self's NixlDescriptor impl
                NixlDescriptor::device_id(self), // Delegate to self's NixlDescriptor impl
            )
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
/// Controls whether the block manager uses NIXL, and with which agent.
#[derive(Debug, Clone)]
pub enum NixlOptions {
    /// Enable NIXL and create a new NIXL agent
    Enabled,
    /// Enable NIXL and use the provided NIXL agent
    EnabledWithAgent(NixlAgent),
    /// Disable NIXL
    Disabled,
}
/// Core runtime configuration for the KvBlockManager.
#[derive(Debug, Clone, Builder, Validate)]
#[builder(pattern = "owned")]
pub struct KvManagerRuntimeConfig {
    /// Identifier of the worker this manager instance belongs to.
    pub worker_id: u64,
    /// Cancellation token; defaults to a fresh, un-cancelled token.
    #[builder(default)]
    pub cancellation_token: CancellationToken,
    /// NIXL enablement mode; defaults to enabling NIXL with a new agent.
    #[builder(default = "NixlOptions::Enabled")]
    pub nixl: NixlOptions,
}
impl KvManagerRuntimeConfig {
    /// Create a new builder for [`KvManagerRuntimeConfig`].
    pub fn builder() -> KvManagerRuntimeConfigBuilder {
        KvManagerRuntimeConfigBuilder::default()
    }
}
impl KvManagerRuntimeConfigBuilder {
pub fn enable_nixl(mut self) -> Self {
self.nixl = Some(NixlOptions::Enabled);
self
}
pub fn use_nixl_agent(mut self, agent: NixlAgent) -> Self {
self.nixl = Some(NixlOptions::EnabledWithAgent(agent));
self
}
pub fn disable_nixl(mut self) -> Self {
self.nixl = Some(NixlOptions::Disabled);
self
}
}
/// Model-specific configuration for the KvBlockManager.
#[derive(Debug, Clone, Builder, Validate)]
#[builder(pattern = "owned")]
pub struct KvManagerModelConfig {
    /// Number of model layers represented in the KV cache (>= 1).
    #[validate(range(min = 1))]
    pub num_layers: usize,
    /// Number of tokens represented in each "paged" KV block (>= 1).
    #[validate(range(min = 1))]
    pub page_size: usize,
    /// Size of the inner dimension (>= 1); related to the model's attention
    /// mechanism — see KvBlockManagerConfig::model.
    #[validate(range(min = 1))]
    pub inner_dim: usize,
    /// Element data type of the KV cache; defaults to FP16.
    #[builder(default = "DType::FP16")]
    pub dtype: DType,
}
impl KvManagerModelConfig {
    /// Create a new builder for [`KvManagerModelConfig`].
    pub fn builder() -> KvManagerModelConfigBuilder {
        KvManagerModelConfigBuilder::default()
    }
}
/// Layout configuration for one storage tier (e.g. device or host) of the
/// KvBlockManager. Exactly one of `storage` / `allocator` must be provided;
/// the builder's `validate` enforces this.
#[derive(Builder, Validate)]
#[builder(pattern = "owned", build_fn(validate = "Self::validate"))]
pub struct KvManagerLayoutConfig<S: Storage + NixlRegisterableStorage> {
    /// The number of blocks to allocate
    #[validate(range(min = 1))]
    pub num_blocks: usize,
    /// The type of layout to use
    #[builder(default = "LayoutType::FullyContiguous")]
    pub layout_type: LayoutType,
    /// Storage for the blocks
    /// If provided, the blocks will be allocated from the provided storage
    /// Otherwise, the blocks will be allocated from the `allocator` below
    #[builder(default)]
    pub storage: Option<Vec<S>>,
    /// If provided, the blocks will be allocated from the provided allocator
    /// This option is mutually exclusive with the `storage` option
    #[builder(default, setter(custom))]
    pub allocator: Option<Arc<dyn StorageAllocator<S>>>,
}
impl<S: Storage + NixlRegisterableStorage> KvManagerLayoutConfig<S> {
    /// Create a new builder for the KvManagerLayoutConfig
    pub fn builder() -> KvManagerLayoutConfigBuilder<S> {
        KvManagerLayoutConfigBuilder::default()
    }
}
// Implement the validation and build functions on the generated builder type
// Note: derive_builder generates KvManagerLayoutConfigBuilder<S>
impl<S: Storage + NixlRegisterableStorage> KvManagerLayoutConfigBuilder<S> {
    /// Custom setter for the `allocator` field
    ///
    /// Wraps the concrete allocator in an `Arc<dyn StorageAllocator<S>>`.
    /// The double `Some` matches derive_builder's `Option<Option<_>>`
    /// representation of an optional field with a custom setter.
    pub fn allocator(mut self, allocator: impl StorageAllocator<S> + 'static) -> Self {
        self.allocator = Some(Some(Arc::new(allocator)));
        self
    }
    // Validation function: exactly one of `storage` / `allocator` must be set.
    fn validate(&self) -> Result<(), String> {
        match (self.storage.is_some(), self.allocator.is_some()) {
            (true, false) | (false, true) => Ok(()), // XOR condition met
            (true, true) => Err("Cannot provide both `storage` and `allocator`.".to_string()),
            (false, false) => Err("Must provide either `storage` or `allocator`.".to_string()),
        }
    }
}
/// Configuration for the KvBlockManager
///
/// Both layout sections are optional; a tier is only configured when its
/// corresponding `*_layout` field is set.
#[derive(Builder, Validate)]
#[builder(pattern = "owned")]
pub struct KvBlockManagerConfig {
    /// Runtime configuration
    ///
    /// This provides core runtime configuration for the KvBlockManager.
    pub runtime: KvManagerRuntimeConfig,
    /// Model configuration
    ///
    /// This provides model-specific configuration for the KvBlockManager, specifically,
    /// the number of layers and the size of the inner dimension which is directly related
    /// to the type of attention used by the model.
    ///
    /// Included in this configuration is also the page_size, i.e. the number of tokens that will
    /// be represented in each "paged" KV block.
    pub model: KvManagerModelConfig,
    /// Specific configuration for the device layout
    ///
    /// This includes the number of blocks and the layout of the data into the device memory/storage.
    #[builder(default, setter(strip_option))]
    pub device_layout: Option<KvManagerLayoutConfig<DeviceStorage>>,
    /// Specific configuration for the host layout
    ///
    /// This includes the number of blocks and the layout of the data into the host memory/storage.
    #[builder(default, setter(strip_option))]
    pub host_layout: Option<KvManagerLayoutConfig<PinnedStorage>>,
}
impl KvBlockManagerConfig {
    /// Create a new builder for the KvBlockManagerConfig
    pub fn builder() -> KvBlockManagerConfigBuilder {
        KvBlockManagerConfigBuilder::default()
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment