Unverified Commit 42ce6931 authored by Harrison Saturley-Hall, committed by GitHub

feat: kv block manager (#965) (#1021)


Co-authored-by: Ryan Olson <ryanolson@users.noreply.github.com>
parent cafc74eb
...@@ -47,15 +47,18 @@ jobs: ...@@ -47,15 +47,18 @@ jobs:
GITHUB_TOKEN: ${{ secrets.CI_TOKEN }} GITHUB_TOKEN: ${{ secrets.CI_TOKEN }}
run: | run: |
./container/build.sh --tag ${{ steps.define_image_tag.outputs.image_tag }} --target ci_minimum --framework ${{ matrix.framework }} ./container/build.sh --tag ${{ steps.define_image_tag.outputs.image_tag }} --target ci_minimum --framework ${{ matrix.framework }}
- name: Run Rust checks (llm/block-manager)
run: |
docker run -w /workspace/lib/llm --name ${{ env.CONTAINER_ID }}_rust_checks ${{ steps.define_image_tag.outputs.image_tag }} bash -ec 'rustup component add rustfmt clippy && cargo fmt -- --check && cargo clippy --features block-manager --no-deps --all-targets -- -D warnings && cargo test --locked --all-targets --features=block-manager'
- name: Run pytest - name: Run pytest
env: env:
PYTEST_MARKS: "pre_merge or mypy" PYTEST_MARKS: "pre_merge or mypy"
run: | run: |
docker run -w /workspace --name ${{ env.CONTAINER_ID }} ${{ steps.define_image_tag.outputs.image_tag }} pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m "${{ env.PYTEST_MARKS }}" docker run -w /workspace --name ${{ env.CONTAINER_ID }}_pytest ${{ steps.define_image_tag.outputs.image_tag }} pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m "${{ env.PYTEST_MARKS }}"
- name: Copy test report from test Container - name: Copy test report from test Container
if: always() if: always()
run: | run: |
docker cp ${{ env.CONTAINER_ID }}:/workspace/${{ env.PYTEST_XML_FILE }} . docker cp ${{ env.CONTAINER_ID }}_pytest:/workspace/${{ env.PYTEST_XML_FILE }} .
- name: Archive test report - name: Archive test report
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
if: always() if: always()
......
...@@ -23,8 +23,6 @@ on: ...@@ -23,8 +23,6 @@ on:
# Run this workflow on pull requests targeting main but only if files in runtime/rust change. # Run this workflow on pull requests targeting main but only if files in runtime/rust change.
pull_request: pull_request:
branches:
- main
paths: paths:
- .github/workflows/pre-merge-rust.yml - .github/workflows/pre-merge-rust.yml
- 'lib/runtime/**' - 'lib/runtime/**'
......
...@@ -513,6 +513,26 @@ dependencies = [ ...@@ -513,6 +513,26 @@ dependencies = [
"which", "which",
] ]
[[package]]
name = "bindgen"
version = "0.71.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
dependencies = [
"bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 2.1.1",
"shlex",
"syn 2.0.100",
]
[[package]] [[package]]
name = "bindgen_cuda" name = "bindgen_cuda"
version = "0.1.5" version = "0.1.5"
...@@ -1606,6 +1626,8 @@ dependencies = [ ...@@ -1606,6 +1626,8 @@ dependencies = [
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"ndarray", "ndarray",
"nixl-sys",
"oneshot",
"prometheus", "prometheus",
"proptest", "proptest",
"rand 0.9.1", "rand 0.9.1",
...@@ -3379,7 +3401,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -3379,7 +3401,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
"windows-targets 0.52.6", "windows-targets 0.48.5",
] ]
[[package]] [[package]]
...@@ -3441,7 +3463,7 @@ version = "0.1.103" ...@@ -3441,7 +3463,7 @@ version = "0.1.103"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b4ae3037b7d9b9fab9fd7905aeb04e214acb300599fa1ee698d6f759ee530f9" checksum = "8b4ae3037b7d9b9fab9fd7905aeb04e214acb300599fa1ee698d6f759ee530f9"
dependencies = [ dependencies = [
"bindgen", "bindgen 0.69.5",
"cc", "cc",
"cmake", "cmake",
"find_cuda_helper", "find_cuda_helper",
...@@ -4057,6 +4079,21 @@ dependencies = [ ...@@ -4057,6 +4079,21 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "nixl-sys"
version = "0.2.1-rc.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4698155ec791ae888482b0c1c5d19792a1b457167fa8ec0ffa487ffef8c8e5a"
dependencies = [
"bindgen 0.71.1",
"cc",
"libc",
"pkg-config",
"serde",
"thiserror 2.0.12",
"tracing",
]
[[package]] [[package]]
name = "nkeys" name = "nkeys"
version = "0.4.4" version = "0.4.4"
...@@ -4282,6 +4319,12 @@ version = "1.21.3" ...@@ -4282,6 +4319,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "oneshot"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea"
[[package]] [[package]]
name = "onig" name = "onig"
version = "6.4.0" version = "6.4.0"
......
...@@ -60,6 +60,7 @@ futures = { version = "0.3" } ...@@ -60,6 +60,7 @@ futures = { version = "0.3" }
hf-hub = { version = "0.4.2", default-features = false, features = ["tokio", "rustls-tls"] } hf-hub = { version = "0.4.2", default-features = false, features = ["tokio", "rustls-tls"] }
humantime = { version = "2.2.0" } humantime = { version = "2.2.0" }
libc = { version = "0.2" } libc = { version = "0.2" }
oneshot = { version = "0.1.11", features = ["std", "async"] }
prometheus = { version = "0.14" } prometheus = { version = "0.14" }
rand = { version = "0.9.0" } rand = { version = "0.9.0" }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
...@@ -77,6 +78,7 @@ uuid = { version = "1", features = ["v4", "serde"] } ...@@ -77,6 +78,7 @@ uuid = { version = "1", features = ["v4", "serde"] }
url = {version = "2.5", features = ["serde"]} url = {version = "2.5", features = ["serde"]}
xxhash-rust = { version = "0.8", features = ["xxh3", "const_xxh3"] } xxhash-rust = { version = "0.8", features = ["xxh3", "const_xxh3"] }
[profile.dev.package] [profile.dev.package]
insta.opt-level = 3 insta.opt-level = 3
......
...@@ -108,6 +108,12 @@ WORKDIR /workspace ...@@ -108,6 +108,12 @@ WORKDIR /workspace
# Copy nixl source, and use commit hash as cache hint # Copy nixl source, and use commit hash as cache hint
COPY --from=nixl_base /opt/nixl /opt/nixl COPY --from=nixl_base /opt/nixl /opt/nixl
COPY --from=nixl_base /opt/nixl/commit.txt /opt/nixl/commit.txt COPY --from=nixl_base /opt/nixl/commit.txt /opt/nixl/commit.txt
RUN cd /opt/nixl && \
mkdir build && \
meson setup build/ --prefix=/usr/local/nixl && \
cd build/ && \
ninja && \
ninja install
### NATS & ETCD SETUP ### ### NATS & ETCD SETUP ###
# nats # nats
...@@ -310,6 +316,7 @@ ARG RELEASE_BUILD ...@@ -310,6 +316,7 @@ ARG RELEASE_BUILD
WORKDIR /workspace WORKDIR /workspace
RUN yum update -y \ RUN yum update -y \
&& yum install -y llvm-toolset \
&& yum install -y python3.12-devel \ && yum install -y python3.12-devel \
&& yum install -y protobuf-compiler \ && yum install -y protobuf-compiler \
&& yum clean all \ && yum clean all \
...@@ -322,6 +329,7 @@ ENV RUSTUP_HOME=/usr/local/rustup \ ...@@ -322,6 +329,7 @@ ENV RUSTUP_HOME=/usr/local/rustup \
COPY --from=base $RUSTUP_HOME $RUSTUP_HOME COPY --from=base $RUSTUP_HOME $RUSTUP_HOME
COPY --from=base $CARGO_HOME $CARGO_HOME COPY --from=base $CARGO_HOME $CARGO_HOME
COPY --from=base /usr/local/nixl /opt/nvidia/nvda_nixl
COPY --from=base /workspace /workspace COPY --from=base /workspace /workspace
COPY --from=base $VIRTUAL_ENV $VIRTUAL_ENV COPY --from=base $VIRTUAL_ENV $VIRTUAL_ENV
ENV PATH=$CARGO_HOME/bin:$VIRTUAL_ENV/bin:$PATH ENV PATH=$CARGO_HOME/bin:$VIRTUAL_ENV/bin:$PATH
...@@ -342,7 +350,7 @@ COPY launch /workspace/launch ...@@ -342,7 +350,7 @@ COPY launch /workspace/launch
COPY deploy/sdk /workspace/deploy/sdk COPY deploy/sdk /workspace/deploy/sdk
# Build Rust crate binaries packaged with the wheel # Build Rust crate binaries packaged with the wheel
RUN cargo build --release --locked --features mistralrs,python \ RUN cargo build --release --locked --features mistralrs,python,dynamo-llm/block-manager \
-p dynamo-run \ -p dynamo-run \
-p llmctl \ -p llmctl \
# Multiple http named crates are present in dependencies, need to specify the path # Multiple http named crates are present in dependencies, need to specify the path
...@@ -370,6 +378,7 @@ WORKDIR /workspace ...@@ -370,6 +378,7 @@ WORKDIR /workspace
COPY --from=wheel_builder /workspace/dist/ /workspace/dist/ COPY --from=wheel_builder /workspace/dist/ /workspace/dist/
COPY --from=wheel_builder /workspace/target/ /workspace/target/ COPY --from=wheel_builder /workspace/target/ /workspace/target/
COPY --from=wheel_builder /opt/nvidia/nvda_nixl /opt/nvidia/nvda_nixl
# Copy Cargo cache to avoid re-downloading dependencies # Copy Cargo cache to avoid re-downloading dependencies
COPY --from=wheel_builder $CARGO_HOME $CARGO_HOME COPY --from=wheel_builder $CARGO_HOME $CARGO_HOME
...@@ -377,7 +386,7 @@ COPY . /workspace ...@@ -377,7 +386,7 @@ COPY . /workspace
# Build rest of the crates # Build rest of the crates
# Need to figure out rust caching to avoid rebuilding and remove exclude flags # Need to figure out rust caching to avoid rebuilding and remove exclude flags
RUN cargo build --release --locked --workspace \ RUN cargo build --release --locked --features block-manager --workspace \
--exclude dynamo-run \ --exclude dynamo-run \
--exclude llmctl \ --exclude llmctl \
--exclude file://$PWD/components/http \ --exclude file://$PWD/components/http \
...@@ -405,6 +414,7 @@ RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/la ...@@ -405,6 +414,7 @@ RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/la
# Tell vllm to use the Dynamo LLM C API for KV Cache Routing # Tell vllm to use the Dynamo LLM C API for KV Cache Routing
ENV VLLM_KV_CAPI_PATH=/opt/dynamo/bindings/lib/libdynamo_llm_capi.so ENV VLLM_KV_CAPI_PATH=/opt/dynamo/bindings/lib/libdynamo_llm_capi.so
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/nvidia/nvda_nixl/lib/x86_64-linux-gnu/
########################################## ##########################################
########## Perf Analyzer Image ########### ########## Perf Analyzer Image ###########
......
...@@ -1058,6 +1058,7 @@ dependencies = [ ...@@ -1058,6 +1058,7 @@ dependencies = [
"memmap2", "memmap2",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"oneshot",
"prometheus", "prometheus",
"rand 0.9.1", "rand 0.9.1",
"rayon", "rayon",
...@@ -2859,6 +2860,12 @@ version = "1.21.3" ...@@ -2859,6 +2860,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "oneshot"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea"
[[package]] [[package]]
name = "onig" name = "onig"
version = "6.4.0" version = "6.4.0"
......
...@@ -51,6 +51,7 @@ module-name = "dynamo._core" ...@@ -51,6 +51,7 @@ module-name = "dynamo._core"
manifest-path = "Cargo.toml" manifest-path = "Cargo.toml"
python-packages = ["dynamo"] python-packages = ["dynamo"]
python-source = "src" python-source = "src"
features = ["dynamo-llm/block-manager"]
[build-system] [build-system]
requires = ["maturin>=1.0,<2.0", "patchelf"] requires = ["maturin>=1.0,<2.0", "patchelf"]
......
...@@ -27,7 +27,10 @@ description = "Dynamo LLM Library" ...@@ -27,7 +27,10 @@ description = "Dynamo LLM Library"
[features] [features]
default = [] default = []
cuda_kv = ["dep:cudarc", "dep:ndarray"] testing-full = ["testing-cuda", "testing-nixl"]
testing-cuda = ["dep:cudarc"]
testing-nixl = ["dep:nixl-sys"]
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray"]
sentencepiece = ["dep:sentencepiece"] sentencepiece = ["dep:sentencepiece"]
[dependencies] [dependencies]
...@@ -48,6 +51,7 @@ etcd-client = { workspace = true } ...@@ -48,6 +51,7 @@ etcd-client = { workspace = true }
futures = { workspace = true } futures = { workspace = true }
hf-hub = { workspace = true } hf-hub = { workspace = true }
rand = { workspace = true } rand = { workspace = true }
oneshot = { workspace = true }
prometheus = { workspace = true } prometheus = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
...@@ -72,8 +76,9 @@ derive-getters = "0.5" ...@@ -72,8 +76,9 @@ derive-getters = "0.5"
regex = "1" regex = "1"
rayon = "1" rayon = "1"
# kv_cuda # block_manager
cudarc = { version = "0.16.2", features = ["cuda-12040"], optional = true } nixl-sys = { version = "0.2.1-rc.1", optional = true }
cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true }
ndarray = { version = "0.16", optional = true } ndarray = { version = "0.16", optional = true }
# protocols # protocols
......
...@@ -13,121 +13,128 @@ ...@@ -13,121 +13,128 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#[cfg(not(feature = "cuda_kv"))]
fn main() {}
#[cfg(feature = "cuda_kv")]
fn main() { fn main() {
use std::{path::PathBuf, process::Command}; println!("cargo:warning=Building with CUDA KV off");
println!("cargo:rerun-if-changed=src/kernels/block_copy.cu");
// first do a which nvcc, if it is in the path
// if so, we don't need to set the cuda_lib
let nvcc = Command::new("which").arg("nvcc").output().unwrap();
let cuda_lib = if nvcc.status.success() {
println!("cargo:info=nvcc found in path");
// Extract the path from nvcc location by removing "bin/nvcc"
let nvcc_path = String::from_utf8_lossy(&nvcc.stdout).trim().to_string();
let path = PathBuf::from(nvcc_path);
if let Some(parent) = path.parent() {
// Remove "nvcc"
if let Some(cuda_root) = parent.parent() {
// Remove "bin"
cuda_root.to_string_lossy().to_string()
} else {
// Fallback to CUDA_ROOT or default if path extraction fails
get_cuda_root_or_default()
}
} else {
// Fallback to CUDA_ROOT or default if path extraction fails
get_cuda_root_or_default()
}
} else {
println!("cargo:warning=nvcc not found in path");
get_cuda_root_or_default()
};
println!("cargo:info=Using CUDA installation at: {}", cuda_lib);
let cuda_lib_path = PathBuf::from(&cuda_lib).join("lib64");
println!("cargo:info=Using CUDA libs: {}", cuda_lib_path.display());
println!("cargo:rustc-link-search=native={}", cuda_lib_path.display());
// Link against multiple CUDA libraries
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=cuda");
println!("cargo:rustc-link-lib=dylib=cudadevrt");
// Make sure the CUDA libraries are found before other system libraries
println!(
"cargo:rustc-link-arg=-Wl,-rpath,{}",
cuda_lib_path.display()
);
// Create kernels directory for output if it doesn't exist
std::fs::create_dir_all("src/kernels").unwrap_or_else(|_| {
println!("Kernels directory already exists");
});
// Compile CUDA code
let output = Command::new("nvcc")
.arg("src/kernels/block_copy.cu")
.arg("-O3")
.arg("--compiler-options")
.arg("-fPIC")
.arg("-o")
.arg("src/kernels/libblock_copy.o")
.arg("-c")
.output()
.expect("Failed to compile CUDA code");
if !output.status.success() {
panic!(
"Failed to compile CUDA kernel: {}",
String::from_utf8_lossy(&output.stderr)
);
}
// Create static library
#[cfg(target_os = "windows")]
{
Command::new("lib")
.arg("/OUT:src/kernels/block_copy.lib")
.arg("src/kernels/libblock_copy.o")
.output()
.expect("Failed to create static library");
println!("cargo:rustc-link-search=native=src/kernels");
println!("cargo:rustc-link-lib=static=block_copy");
}
#[cfg(not(target_os = "windows"))]
{
Command::new("ar")
.arg("rcs")
.arg("src/kernels/libblock_copy.a")
.arg("src/kernels/libblock_copy.o")
.output()
.expect("Failed to create static library");
println!("cargo:rustc-link-search=native=src/kernels");
println!("cargo:rustc-link-lib=static=block_copy");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=cuda");
println!("cargo:rustc-link-lib=dylib=cudadevrt");
}
} }
#[cfg(feature = "cuda_kv")] // NOTE: Preserving this build.rs for reference. We may want to re-enable
fn get_cuda_root_or_default() -> String { // custom kernel compilation in the future.
match std::env::var("CUDA_ROOT") {
Ok(path) => path, // #[cfg(not(feature = "cuda_kv"))]
Err(_) => { // fn main() {}
// Default locations based on OS
if cfg!(target_os = "windows") { // #[cfg(feature = "cuda_kv")]
"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8".to_string() // fn main() {
} else { // use std::{path::PathBuf, process::Command};
"/usr/local/cuda".to_string()
} // println!("cargo:rerun-if-changed=src/kernels/block_copy.cu");
}
} // // first do a which nvcc, if it is in the path
} // // if so, we don't need to set the cuda_lib
// let nvcc = Command::new("which").arg("nvcc").output().unwrap();
// let cuda_lib = if nvcc.status.success() {
// println!("cargo:info=nvcc found in path");
// // Extract the path from nvcc location by removing "bin/nvcc"
// let nvcc_path = String::from_utf8_lossy(&nvcc.stdout).trim().to_string();
// let path = PathBuf::from(nvcc_path);
// if let Some(parent) = path.parent() {
// // Remove "nvcc"
// if let Some(cuda_root) = parent.parent() {
// // Remove "bin"
// cuda_root.to_string_lossy().to_string()
// } else {
// // Fallback to CUDA_ROOT or default if path extraction fails
// get_cuda_root_or_default()
// }
// } else {
// // Fallback to CUDA_ROOT or default if path extraction fails
// get_cuda_root_or_default()
// }
// } else {
// println!("cargo:warning=nvcc not found in path");
// get_cuda_root_or_default()
// };
// println!("cargo:info=Using CUDA installation at: {}", cuda_lib);
// let cuda_lib_path = PathBuf::from(&cuda_lib).join("lib64");
// println!("cargo:info=Using CUDA libs: {}", cuda_lib_path.display());
// println!("cargo:rustc-link-search=native={}", cuda_lib_path.display());
// // Link against multiple CUDA libraries
// println!("cargo:rustc-link-lib=dylib=cudart");
// println!("cargo:rustc-link-lib=dylib=cuda");
// println!("cargo:rustc-link-lib=dylib=cudadevrt");
// // Make sure the CUDA libraries are found before other system libraries
// println!(
// "cargo:rustc-link-arg=-Wl,-rpath,{}",
// cuda_lib_path.display()
// );
// // Create kernels directory for output if it doesn't exist
// std::fs::create_dir_all("src/kernels").unwrap_or_else(|_| {
// println!("Kernels directory already exists");
// });
// // Compile CUDA code
// let output = Command::new("nvcc")
// .arg("src/kernels/block_copy.cu")
// .arg("-O3")
// .arg("--compiler-options")
// .arg("-fPIC")
// .arg("-o")
// .arg("src/kernels/libblock_copy.o")
// .arg("-c")
// .output()
// .expect("Failed to compile CUDA code");
// if !output.status.success() {
// panic!(
// "Failed to compile CUDA kernel: {}",
// String::from_utf8_lossy(&output.stderr)
// );
// }
// // Create static library
// #[cfg(target_os = "windows")]
// {
// Command::new("lib")
// .arg("/OUT:src/kernels/block_copy.lib")
// .arg("src/kernels/libblock_copy.o")
// .output()
// .expect("Failed to create static library");
// println!("cargo:rustc-link-search=native=src/kernels");
// println!("cargo:rustc-link-lib=static=block_copy");
// }
// #[cfg(not(target_os = "windows"))]
// {
// Command::new("ar")
// .arg("rcs")
// .arg("src/kernels/libblock_copy.a")
// .arg("src/kernels/libblock_copy.o")
// .output()
// .expect("Failed to create static library");
// println!("cargo:rustc-link-search=native=src/kernels");
// println!("cargo:rustc-link-lib=static=block_copy");
// println!("cargo:rustc-link-lib=dylib=cudart");
// println!("cargo:rustc-link-lib=dylib=cuda");
// println!("cargo:rustc-link-lib=dylib=cudadevrt");
// }
// }
// #[cfg(feature = "cuda_kv")]
// fn get_cuda_root_or_default() -> String {
// match std::env::var("CUDA_ROOT") {
// Ok(path) => path,
// Err(_) => {
// // Default locations based on OS
// if cfg!(target_os = "windows") {
// "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8".to_string()
// } else {
// "/usr/local/cuda".to_string()
// }
// }
// }
// }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Block Manager for LLM KV Cache
//!
//! This module provides functionality for managing KV blocks in LLM attention
//! mechanisms. It handles storage allocation, block management, and safe access
//! patterns for both system memory and remote (NIXL) storage.
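//!
//! A minimal construction sketch (illustrative only; the pool sizes and model
//! shape are assumptions borrowed from the tests at the bottom of this file):
//!
//! ```ignore
//! let config = KvBlockManagerConfig::builder()
//!     .runtime(KvManagerRuntimeConfig::builder().worker_id(0).build()?)
//!     .model(
//!         KvManagerModelConfig::builder()
//!             .num_layers(3)
//!             .page_size(4)
//!             .inner_dim(16)
//!             .build()?,
//!     )
//!     .host_layout(
//!         KvManagerLayoutConfig::builder()
//!             .num_blocks(16)
//!             .allocator(storage::PinnedAllocator::default())
//!             .build()?,
//!     )
//!     .build()?;
//! let manager = ReferenceBlockManager::new(config)?;
//! ```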
mod config;
mod state;
pub mod block;
pub mod events;
pub mod layout;
pub mod pool;
pub mod storage;
pub use crate::common::dtype::DType;
pub use block::{
nixl::{
AsBlockDescriptorSet, BlockDescriptorList, IsImmutable, IsMutable, MutabilityKind,
RemoteBlock,
},
transfer::{BlockTransferEngineV1, TransferRequestPut},
BasicMetadata, BlockMetadata, Blocks,
};
pub use config::*;
pub use layout::{nixl::NixlLayout, LayoutConfig, LayoutConfigBuilder, LayoutError, LayoutType};
pub use pool::BlockPool;
pub use storage::{
nixl::NixlRegisterableStorage, DeviceStorage, PinnedStorage, Storage, StorageAllocator,
};
pub use tokio_util::sync::CancellationToken;
use anyhow::{Context, Result};
use block::nixl::{BlockMutability, NixlBlockSet, RemoteBlocks, SerializedNixlBlockSet};
use derive_builder::Builder;
use nixl_sys::Agent as NixlAgent;
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use storage::nixl::MemType;
use validator::Validate;
pub type WorkerID = u64;
pub type ReferenceBlockManager = KvBlockManager<BasicMetadata>;
/// Represents the different cache levels for KV blocks
pub enum CacheLevel {
/// Represents KV blocks in GPU memory
G1,
/// Represents KV blocks in CPU memory
G2,
/// Represents KV blocks in Local NVMe storage
G3,
/// Represents KV blocks in Remote NVMe storage
G4,
}
// When we construct the pool:
// 1. instantiate the runtime,
// 2. build layout::LayoutConfigs for each of the requested storage types
// 3. register the layouts with the NIXL agent if enabled
// 4. construct a Blocks object for each layout providing a unique block_set_idx
// for each layout type.
// 5. initialize the pools for each set of blocks
pub struct KvBlockManager<Metadata: BlockMetadata> {
state: Arc<state::KvBlockManagerState<Metadata>>,
cancellation_token: CancellationToken,
}
impl<Metadata: BlockMetadata> KvBlockManager<Metadata> {
/// Create a new [KvBlockManager]
///
/// The returned object is a frontend to the [KvBlockManager] and owns the cancellation
/// token. When this object is dropped, the cancellation token is cancelled, beginning
/// the graceful shutdown of the block manager's internal state.
pub fn new(config: KvBlockManagerConfig) -> Result<Self> {
let mut config = config;
// The frontend of the KvBlockManager will take ownership of the cancellation token
// and will be responsible for cancelling the task when the KvBlockManager is dropped
let cancellation_token = config.runtime.cancellation_token.clone();
// The internal state will use a child token of the original token
config.runtime.cancellation_token = cancellation_token.child_token();
// Create the internal state
let state = state::KvBlockManagerState::new(config)?;
Ok(Self {
state,
cancellation_token,
})
}
/// Exports the local blockset configuration as a serialized object.
pub fn export_local_blockset(&self) -> Result<SerializedNixlBlockSet> {
self.state.export_local_blockset()
}
/// Imports a remote blockset configuration from a serialized object.
pub fn import_remote_blockset(
&self,
serialized_blockset: SerializedNixlBlockSet,
) -> Result<()> {
self.state.import_remote_blockset(serialized_blockset)
}
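// Illustrative exchange sketch (not a definitive flow): in dynamo the
// serialized blocksets travel over the discovery plane, as exercised by
// `test_reference_block_managers` at the bottom of this file.
//
// let blockset_a = kvbm_a.export_local_blockset()?;
// kvbm_b.import_remote_blockset(blockset_a)?;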
/// Get a [`Vec<RemoteBlock<IsImmutable>>`] from a [`BlockDescriptorList`]
pub fn get_remote_blocks_immutable(
&self,
bds: &BlockDescriptorList,
) -> Result<Vec<RemoteBlock<IsImmutable>>> {
self.state.get_remote_blocks_immutable(bds)
}
/// Get a [`Vec<RemoteBlock<IsMutable>>`] from a [`BlockDescriptorList`]
pub fn get_remote_blocks_mutable(
&self,
bds: &BlockDescriptorList,
) -> Result<Vec<RemoteBlock<IsMutable>>> {
self.state.get_remote_blocks_mutable(bds)
}
/// Get a reference to the host block pool
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
self.state.host()
}
/// Get a reference to the device block pool
pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> {
self.state.device()
}
/// Get the worker ID
pub fn worker_id(&self) -> WorkerID {
self.state.worker_id()
}
}
impl<Metadata: BlockMetadata> Drop for KvBlockManager<Metadata> {
fn drop(&mut self) {
self.cancellation_token.cancel();
}
}
#[cfg(all(test, feature = "testing-full"))]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU64, Ordering};
// Atomic Counter for Worker ID
static WORKER_ID: AtomicU64 = AtomicU64::new(1337);
fn create_reference_block_manager() -> ReferenceBlockManager {
let worker_id = WORKER_ID.fetch_add(1, Ordering::SeqCst);
let config = KvBlockManagerConfig::builder()
.runtime(
KvManagerRuntimeConfig::builder()
.worker_id(worker_id)
.build()
.unwrap(),
)
.model(
KvManagerModelConfig::builder()
.num_layers(3)
.page_size(4)
.inner_dim(16)
.build()
.unwrap(),
)
.host_layout(
KvManagerLayoutConfig::builder()
.num_blocks(16)
.allocator(storage::PinnedAllocator::default())
.build()
.unwrap(),
)
.device_layout(
KvManagerLayoutConfig::builder()
.num_blocks(8)
.allocator(storage::DeviceAllocator::new(0).unwrap())
.build()
.unwrap(),
)
.build()
.unwrap();
ReferenceBlockManager::new(config).unwrap()
}
#[tokio::test]
async fn test_reference_block_manager_inherited_async_runtime() {
dynamo_runtime::logging::init();
let _block_manager = create_reference_block_manager();
}
#[test]
fn test_reference_block_manager_blocking() {
dynamo_runtime::logging::init();
let _block_manager = create_reference_block_manager();
}
// This test mimics the behavior of two unique kvbm workers exchanging blocksets.
// Each KvBlockManager is a unique worker in this test; each has its own resources,
// including its own worker_id, nixl_agent, and block pools.
//
// This test is meant to mimic the behavior of the basic nixl integration test found here:
// https://github.com/ai-dynamo/nixl/blob/main/src/bindings/rust/src/tests.rs
#[tokio::test]
async fn test_reference_block_managers() {
dynamo_runtime::logging::init();
// create two block managers - mimics two unique dynamo workers
let kvbm_0 = create_reference_block_manager();
let kvbm_1 = create_reference_block_manager();
assert_ne!(kvbm_0.worker_id(), kvbm_1.worker_id());
// in dynamo, we would exchange the blocksets via the discovery plane
let blockset_0 = kvbm_0.export_local_blockset().unwrap();
let blockset_1 = kvbm_1.export_local_blockset().unwrap();
// in dynamo, we would be watching the discovery plane for remote blocksets
kvbm_0.import_remote_blockset(blockset_1).unwrap();
kvbm_1.import_remote_blockset(blockset_0).unwrap();
// Worker 0
// Allocate 4 mutable blocks on the host
let blocks_0 = kvbm_0.host().unwrap().allocate_blocks(4).await.unwrap();
// Create a BlockDescriptorList for the mutable blocks
// let blockset_0 = BlockDescriptorList::from_mutable_blocks(&blocks_0).unwrap();
let blockset_0 = blocks_0.as_block_descriptor_set().unwrap();
// Worker 1
// Create a RemoteBlock list from blockset_0
let _blocks_1 = kvbm_1.host().unwrap().allocate_blocks(4).await.unwrap();
let mut _remote_blocks_0 = kvbm_1.get_remote_blocks_mutable(&blockset_0).unwrap();
// TODO(#967) - Enable with TransferEngine
// // Create a TransferRequestPut for the mutable blocks
// let transfer_request = TransferRequestPut::new(&blocks_0, &mut remote_blocks_0).unwrap();
// // Validate blocks - this could be an expensive operation
// // TODO: Create an ENV trigger debug flag which will call this on every transfer request
// // In this case, we expect an error because we have overlapping blocks: we are sending to/from the same blocks
// // because we are using the wrong target (an artifact of the test setup allowing variables to cross what would be
// // worker boundaries)
// assert!(transfer_request.validate_blocks().is_err());
// // This is proper request - PUT from worker 1 (local) to worker 0 (remote)
// let transfer_request = TransferRequestPut::new(&blocks_1, &mut remote_blocks_0).unwrap();
// assert!(transfer_request.validate_blocks().is_ok());
// // Execute the transfer request
// transfer_request.execute().unwrap();
// let mut put_request = PutRequestBuilder::<_, _>::builder();
// put_request.from(&blocks_1).to(&mut remote_blocks_0);
// // Create a Put request direct between two local blocks
// // split the blocks into two vecs each with 2 blocks
// let mut blocks_1 = blocks_1;
// let slice_0 = blocks_1.split_off(2);
// let mut slice_1 = blocks_1;
// let transfer_request = TransferRequestPut::new(&slice_0, &mut slice_1).unwrap();
// assert!(transfer_request.validate_blocks().is_ok());
// // Execute the transfer request
// transfer_request.execute().unwrap();
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod registry;
pub mod state;
pub mod transfer;
pub mod view;
pub use crate::tokens::TokenBlockError;
pub use anyhow::Result;
use nixl_sys::NixlDescriptor;
pub use state::{BlockState, BlockStateInvalid};
use crate::block_manager::{
state::{KvBlockManagerState as BlockManager, TransferContext},
storage::{Local, Remote, Storage},
};
use crate::tokens::{SaltHash, SequenceHash, Token, TokenBlock, Tokens};
use transfer::{Immutable, Mutable, Readable, Writable};
use super::{
events::PublishHandle,
layout::{BlockLayout, LayoutError, LayoutType},
storage::StorageType,
WorkerID,
};
use std::{
fmt::Debug,
ops::{Deref, DerefMut},
sync::Arc,
};
use thiserror::Error;
mod private {
pub struct PrivateToken;
}
/// A unique identifier for a block
pub type BlockId = usize;
/// A unique identifier for a block set
pub type BlockSetId = usize;
/// Result type for Block operations
pub type BlockResult<T> = std::result::Result<T, BlockError>;
/// Errors specific to block storage operations
#[derive(Debug, Error)]
pub enum BlockError {
#[error(transparent)]
Layout(#[from] LayoutError),
#[error("Invalid state: {0}")]
InvalidState(String),
#[error(transparent)]
Other(#[from] anyhow::Error),
}
pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync + 'static {
/// Called when the block is acquired from the pool
fn on_acquired(&mut self, tick: u64);
/// Called when the block is returned to the pool
fn on_returned(&mut self, tick: u64);
/// Resets the metadata to the default value
/// After this is called, [BlockMetadata::is_reset()] should return true
fn reset_metadata(&mut self);
}
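// An illustrative custom implementation sketch (hypothetical; `BasicMetadata`
// below is the reference implementation in this module). It counts how many
// times a block has been reused, which a pool policy could consult:
//
// #[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd)]
// struct ReuseCountMetadata {
//     reuse_count: u64,
// }
//
// impl BlockMetadata for ReuseCountMetadata {
//     fn on_acquired(&mut self, _tick: u64) {
//         self.reuse_count += 1;
//     }
//     fn on_returned(&mut self, _tick: u64) {}
//     fn reset_metadata(&mut self) {
//         self.reuse_count = 0;
//     }
// }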
/// Marker trait for types that are mutable blocks
pub trait WritableBlock: BlockDataProviderMut {
type StorageType: Storage + NixlDescriptor;
fn storage_type_id(&self) -> std::any::TypeId {
std::any::TypeId::of::<<Self as WritableBlock>::StorageType>()
}
}
/// Marker trait for types that are immutable blocks
pub trait ReadableBlock: BlockDataProvider {
type StorageType: Storage + NixlDescriptor;
fn storage_type_id(&self) -> std::any::TypeId {
std::any::TypeId::of::<<Self as ReadableBlock>::StorageType>()
}
fn transfer_context(&self) -> &TransferContext {
unimplemented!()
}
}
pub trait ReadableBlocks {}
impl<T: ReadableBlock> ReadableBlocks for Vec<T> {}
impl<T: ReadableBlock> ReadableBlocks for [T] {}
impl<T: ReadableBlock> ReadableBlocks for &[T] {}
pub trait WritableBlocks {}
impl<T: WritableBlock> WritableBlocks for Vec<T> {}
impl<T: WritableBlock> WritableBlocks for [T] {}
impl<T: WritableBlock> WritableBlocks for &[T] {}
/// Blanket trait for anything that can be viewed as a slice of blocks
pub trait AsBlockSlice<'a, B: 'a> {
fn as_block_slice(&'a self) -> &'a [B];
}
/// Blanket trait for anything that can be viewed as a mutable slice of blocks
pub trait AsBlockMutSlice<'a, B: 'a> {
fn as_block_mut_slice(&'a mut self) -> &'a mut [B];
}
/// Blanket trait for anything that can be converted into a mutable block
pub trait IntoWritableBlocks<M: BlockMetadata> {
type Output: WritableBlocks;
fn into_writable_blocks(self, manager: &BlockManager<M>) -> BlockResult<Self::Output>;
}
impl<T: WritableBlocks, M: BlockMetadata> IntoWritableBlocks<M> for T {
type Output = T;
fn into_writable_blocks(self, _manager: &BlockManager<M>) -> BlockResult<Self::Output> {
Ok(self)
}
}
pub trait IntoReadableBlocks<M: BlockMetadata> {
type Output: ReadableBlocks;
fn into_readable_blocks(self, manager: &BlockManager<M>) -> BlockResult<Self::Output>;
}
impl<T: ReadableBlocks, M: BlockMetadata> IntoReadableBlocks<M> for T {
type Output = T;
fn into_readable_blocks(self, _manager: &BlockManager<M>) -> BlockResult<Self::Output> {
Ok(self)
}
}
/// A block with storage and associated metadata/state
#[derive(Debug)]
pub struct Block<S: Storage, M: BlockMetadata> {
data: BlockData<S>,
metadata: M,
state: BlockState,
manager: Option<Arc<BlockManager<M>>>,
}
impl<S: Storage, M: BlockMetadata> Block<S, M> {
/// Create a new block with default metadata/state
pub fn new(data: BlockData<S>, metadata: M) -> BlockResult<Self> {
Ok(Self {
data,
metadata,
state: BlockState::Reset,
manager: None,
})
}
pub fn sequence_hash(&self) -> Result<SequenceHash, BlockError> {
match self.state() {
BlockState::Complete(state) => Ok(state.token_block().sequence_hash()),
BlockState::Registered(state) => Ok(state.sequence_hash()),
_ => Err(BlockError::InvalidState(
"Block is not complete".to_string(),
)),
}
}
pub(crate) fn reset(&mut self) {
self.state = BlockState::Reset;
self.metadata.reset_metadata();
}
pub(crate) fn set_manager(&mut self, manager: Arc<BlockManager<M>>) {
self.manager = Some(manager);
}
// TODO(#967) - Enable with TransferEngine
#[allow(dead_code)]
pub(crate) fn manager(&self) -> Option<&Arc<BlockManager<M>>> {
self.manager.as_ref()
}
/// Get the metadata of the block
pub fn metadata(&self) -> &M {
&self.metadata
}
/// Update the metadata of the block
pub fn update_metadata(&mut self, metadata: M) {
self.metadata = metadata;
}
/// Update the state of the block
#[allow(dead_code)]
pub(crate) fn update_state(&mut self, state: BlockState) {
self.state = state;
}
/// Get a reference to the state of the block
pub fn state(&self) -> &BlockState {
&self.state
}
/// Get the number of blocks in the underlying layout
pub fn num_blocks(&self) -> usize {
self.data.layout.num_blocks()
}
/// Get the number of layers in the block
pub fn num_layers(&self) -> usize {
self.data.layout.num_layers()
}
/// Get the page size (number of tokens per block)
pub fn page_size(&self) -> usize {
self.data.layout.page_size()
}
/// Get the inner dimension of the block
pub fn inner_dim(&self) -> usize {
self.data.layout.inner_dim()
}
pub(crate) fn metadata_on_acquired(&mut self, tick: u64) {
self.metadata.on_acquired(tick);
}
pub(crate) fn metadata_on_returned(&mut self, tick: u64) {
self.metadata.on_returned(tick);
}
}
pub(crate) trait PrivateBlockExt {
fn register(
&mut self,
registry: &mut registry::BlockRegistry,
) -> Result<PublishHandle, registry::BlockRegistationError>;
}
impl<S: Storage, M: BlockMetadata> PrivateBlockExt for Block<S, M> {
fn register(
&mut self,
registry: &mut registry::BlockRegistry,
) -> Result<PublishHandle, registry::BlockRegistationError> {
registry.register_block(&mut self.state)
}
}
pub trait BlockExt {
/// Reset the state of the block
fn reset(&mut self);
/// Initialize a sequence on the block using a [SaltHash]
///
/// The block must be in the [BlockState::Reset] state.
///
/// After initialization, the block will be in the [BlockState::Partial] state.
fn init_sequence(&mut self, salt_hash: SaltHash) -> Result<()>;
/// Appends a single token to the block if it is in the Partial state and not full.
/// Returns `Err` if the block is not Partial or already full.
fn add_token(&mut self, token: Token) -> Result<()>;
/// Appends multiple tokens to the block if it is in the Partial state
/// and has enough remaining capacity for *all* provided tokens.
/// The block must be in the [BlockState::Partial] state.
/// Returns `Err` if the block is not Partial or if there isn't enough space.
fn add_tokens(&mut self, tokens: Tokens) -> Result<Tokens>;
/// Removes the last token from the block.
/// Requires the block to be in the Partial state and not empty.
/// Returns `Err` otherwise.
fn pop_token(&mut self) -> Result<()>;
/// Removes the last `count` tokens from the block.
/// Requires the block to be in the Partial state and have at least `count` tokens.
/// Returns `Err` otherwise.
fn pop_tokens(&mut self, count: usize) -> Result<()>;
/// Commit the block
/// Requires the block to be in the [BlockState::Partial] state and completely full.
/// Transitions the state to [BlockState::Complete]. Returns `Err` otherwise.
fn commit(&mut self) -> Result<()>;
/// Apply a [TokenBlock] to the block
/// Requires the block to be in the [BlockState::Reset] state.
///
/// Additionally, the [TokenBlock] must match the [BlockLayout::page_size()]
/// Transitions the state to [BlockState::Complete]. Returns `Err` otherwise.
fn apply_token_block(&mut self, token_block: TokenBlock) -> Result<()>;
/// Returns the number of tokens currently in the block.
fn len(&self) -> usize;
/// Returns the number of additional tokens that can be added (only valid for Partial state).
fn remaining(&self) -> usize;
/// Returns true if the block contains no tokens (only true for Reset or empty Partial state).
fn is_empty(&self) -> bool;
/// Returns true if the block is full.
fn is_full(&self) -> bool;
/// Returns a list of tokens in the block.
fn tokens(&self) -> Option<&Tokens>;
}
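// Illustrative lifecycle sketch (names mirror this trait; the pool handle and
// token values are assumptions for the example):
//
// let mut block = /* a MutableBlock obtained from a BlockPool */;
// block.init_sequence(salt_hash)?; // Reset -> Partial
// block.add_tokens(Tokens::from(vec![1, 2, 3, 4]))?; // fill up to page_size()
// block.commit()?; // Partial -> Complete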
impl<S: Storage, M: BlockMetadata> BlockExt for Block<S, M> {
fn reset(&mut self) {
Block::reset(self);
}
fn init_sequence(&mut self, salt_hash: SaltHash) -> Result<()> {
Ok(self
.state
.initialize_sequence(self.page_size(), salt_hash)?)
}
fn add_token(&mut self, token: Token) -> Result<()> {
self.state.add_token(token)
}
fn add_tokens(&mut self, tokens: Tokens) -> Result<Tokens> {
self.state.add_tokens(tokens)
}
fn pop_token(&mut self) -> Result<()> {
self.state.pop_token()
}
fn pop_tokens(&mut self, count: usize) -> Result<()> {
self.state.pop_tokens(count)
}
fn commit(&mut self) -> Result<()> {
self.state.commit()
}
fn apply_token_block(&mut self, token_block: TokenBlock) -> Result<()> {
if self.page_size() != token_block.tokens().len() {
return Err(BlockStateInvalid(format!(
"TokenBlock size ({}) does not match Block page size ({})",
token_block.tokens().len(),
self.page_size()
))
.into());
}
self.state.apply_token_block(token_block)
}
fn len(&self) -> usize {
match self.state.len() {
Some(len) => len,
None => self.page_size(),
}
}
fn remaining(&self) -> usize {
self.state.remaining()
}
fn is_empty(&self) -> bool {
self.state.is_empty()
}
fn is_full(&self) -> bool {
self.len() == self.page_size()
}
fn tokens(&self) -> Option<&Tokens> {
self.state.tokens()
}
}
pub trait BlockDataExt<S: Storage + NixlDescriptor> {
/// Returns true if the block data is fully contiguous
fn is_fully_contiguous(&self) -> bool;
/// Returns the number of layers in the block
fn num_layers(&self) -> usize;
/// Get a read-only view of this block's storage for a layer
fn layer_view(&self, layer_idx: usize) -> BlockResult<view::LayerView<S>>;
/// Get a mutable view of this block's storage for a layer
fn layer_view_mut(&mut self, layer_idx: usize) -> BlockResult<view::LayerViewMut<S>>;
/// Get a read-only view of this block's storage
fn block_view(&self) -> BlockResult<view::BlockView<S>>;
/// Get a mutable view of this block's storage
fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>>;
}
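// Illustrative access sketch (`data` is assumed to be a `BlockData<S>`):
//
// let layer0 = data.layer_view(0)?; // read-only view of layer 0
// let whole = data.block_view()?; // Err unless the layout is fully contiguous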
/// Individual block storage - cannot be cloned to ensure uniqueness
#[derive(Debug)]
pub struct BlockData<S: Storage> {
layout: Arc<dyn BlockLayout<StorageType = S>>,
block_idx: usize,
block_set_idx: usize,
worker_id: WorkerID,
}
impl<S> BlockData<S>
where
S: Storage,
{
/// Create a new block storage
pub(crate) fn new(
layout: Arc<dyn BlockLayout<StorageType = S>>,
block_idx: usize,
block_set_idx: usize,
worker_id: WorkerID,
) -> Self {
Self {
layout,
block_idx,
block_set_idx,
worker_id,
}
}
pub fn storage_type(&self) -> StorageType {
self.layout.storage_type()
}
}
impl<S: Storage + NixlDescriptor> BlockDataExt<S> for BlockData<S>
where
S: Storage + NixlDescriptor,
{
fn is_fully_contiguous(&self) -> bool {
self.layout.layout_type() == LayoutType::FullyContiguous
}
fn num_layers(&self) -> usize {
self.layout.num_layers()
}
fn layer_view(&self, layer_idx: usize) -> BlockResult<view::LayerView<S>> {
let offset = self.layout.memory_region_addr(self.block_idx, layer_idx)?;
unsafe { view::LayerView::new(self, offset as usize, self.layout.memory_region_size()) }
}
fn layer_view_mut(&mut self, layer_idx: usize) -> BlockResult<view::LayerViewMut<S>> {
let offset = self.layout.memory_region_addr(self.block_idx, layer_idx)?;
unsafe { view::LayerViewMut::new(self, offset as usize, self.layout.memory_region_size()) }
}
fn block_view(&self) -> BlockResult<view::BlockView<S>> {
if self.is_fully_contiguous() {
let offset = self.layout.memory_region_addr(self.block_idx, 0)?;
let size = self.layout.memory_region_size() * self.layout.num_layers();
unsafe { view::BlockView::new(self, offset as usize, size) }
} else {
Err(BlockError::InvalidState(
"Block is not fully contiguous".to_string(),
))
}
}
fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<S>> {
if self.is_fully_contiguous() {
let offset = self.layout.memory_region_addr(self.block_idx, 0)?;
let size = self.layout.memory_region_size() * self.layout.num_layers();
unsafe { view::BlockViewMut::new(self, offset as usize, size) }
} else {
Err(BlockError::InvalidState(
"Block is not fully contiguous".to_string(),
))
}
}
}
pub trait BlockDataProvider {
type StorageType: Storage + NixlDescriptor;
fn block_data(&self, _: private::PrivateToken) -> &BlockData<Self::StorageType>;
}
pub trait BlockDataProviderMut: BlockDataProvider {
fn block_data_mut(&mut self, _: private::PrivateToken) -> &mut BlockData<Self::StorageType>;
}
#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd)]
pub struct BasicMetadata {
priority: u32,
returned_tick: u64,
acquired_tick: u64,
}
impl BlockMetadata for BasicMetadata {
fn on_acquired(&mut self, tick: u64) {
self.acquired_tick = tick;
}
fn on_returned(&mut self, tick: u64) {
self.returned_tick = tick;
}
fn reset_metadata(&mut self) {
self.priority = 0;
}
}
/// Collection that holds shared storage and layout
#[derive(Debug)]
pub struct Blocks<L: BlockLayout, M: BlockMetadata> {
layout: Box<L>,
metadata: std::marker::PhantomData<M>,
block_set_idx: usize,
worker_id: WorkerID,
}
impl<L: BlockLayout + 'static, M: BlockMetadata> Blocks<L, M> {
/// Create a new block storage collection
pub fn new(layout: L, block_set_idx: usize, worker_id: WorkerID) -> BlockResult<Self> {
let layout = Box::new(layout);
Ok(Self {
layout,
metadata: std::marker::PhantomData,
block_set_idx,
worker_id,
})
}
/// Convert collection into Vec<Block> with default metadata/state
pub fn into_blocks(self) -> BlockResult<Vec<Block<L::StorageType, M>>> {
// convert box to arc
let layout: Arc<dyn BlockLayout<StorageType = L::StorageType>> = Arc::new(*self.layout);
layout_to_blocks(layout, self.block_set_idx, self.worker_id)
}
}
pub(crate) fn layout_to_blocks<S: Storage, M: BlockMetadata>(
layout: Arc<dyn BlockLayout<StorageType = S>>,
block_set_idx: usize,
worker_id: WorkerID,
) -> BlockResult<Vec<Block<S, M>>> {
(0..layout.num_blocks())
.map(|idx| {
let data = BlockData::new(layout.clone(), idx, block_set_idx, worker_id);
Block::new(data, M::default())
})
.collect()
}
pub struct MutableBlock<S: Storage, M: BlockMetadata> {
block: Option<Block<S, M>>,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, M>>,
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> WritableBlock for MutableBlock<S, M> {
type StorageType = S;
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> ReadableBlock for MutableBlock<S, M> {
type StorageType = S;
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Writable for MutableBlock<S, M> {}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Readable for MutableBlock<S, M> {}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Mutable for MutableBlock<S, M> {}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Local for MutableBlock<S, M> {}
impl<S: Storage, M: BlockMetadata> MutableBlock<S, M> {
pub(crate) fn new(
block: Block<S, M>,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, M>>,
) -> Self {
Self {
block: Some(block),
return_tx,
}
}
}
impl<S: Storage, M: BlockMetadata> std::fmt::Debug for MutableBlock<S, M> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MutableBlock {{ block: {:?} }}", self.block)
}
}
impl<S: Storage, M: BlockMetadata> Drop for MutableBlock<S, M> {
fn drop(&mut self) {
if let Some(block) = self.block.take() {
if self.return_tx.send(block).is_err() {
tracing::warn!("block pool shutdown before block was returned");
}
}
}
}
impl<S: Storage, M: BlockMetadata> Deref for MutableBlock<S, M> {
type Target = Block<S, M>;
fn deref(&self) -> &Self::Target {
self.block.as_ref().expect("block was dropped")
}
}
impl<S: Storage, M: BlockMetadata> DerefMut for MutableBlock<S, M> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.block.as_mut().expect("block was dropped")
}
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataProvider for MutableBlock<S, M> {
type StorageType = S;
fn block_data(&self, _: private::PrivateToken) -> &BlockData<S> {
&self.block.as_ref().expect("block was dropped").data
}
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataProviderMut for MutableBlock<S, M> {
fn block_data_mut(&mut self, _: private::PrivateToken) -> &mut BlockData<S> {
&mut self.block.as_mut().expect("block was dropped").data
}
}
impl<'a, S: Storage + NixlDescriptor, M: BlockMetadata> AsBlockSlice<'a, MutableBlock<S, M>>
for [MutableBlock<S, M>]
{
fn as_block_slice(&'a self) -> &'a [MutableBlock<S, M>] {
self
}
}
impl<'a, S: Storage + NixlDescriptor, M: BlockMetadata> AsBlockSlice<'a, MutableBlock<S, M>>
for Vec<MutableBlock<S, M>>
{
fn as_block_slice(&'a self) -> &'a [MutableBlock<S, M>] {
self.as_slice()
}
}
impl<'a, S: Storage + NixlDescriptor, M: BlockMetadata> AsBlockMutSlice<'a, MutableBlock<S, M>>
for [MutableBlock<S, M>]
{
fn as_block_mut_slice(&'a mut self) -> &'a mut [MutableBlock<S, M>] {
self
}
}
impl<'a, S: Storage + NixlDescriptor, M: BlockMetadata> AsBlockMutSlice<'a, MutableBlock<S, M>>
for Vec<MutableBlock<S, M>>
{
fn as_block_mut_slice(&'a mut self) -> &'a mut [MutableBlock<S, M>] {
self.as_mut_slice()
}
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> IntoWritableBlocks<M> for MutableBlock<S, M> {
type Output = Vec<MutableBlock<S, M>>;
fn into_writable_blocks(self, _manager: &BlockManager<M>) -> BlockResult<Self::Output> {
Ok(vec![self])
}
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> IntoReadableBlocks<M> for MutableBlock<S, M> {
type Output = Vec<MutableBlock<S, M>>;
fn into_readable_blocks(self, _manager: &BlockManager<M>) -> BlockResult<Self::Output> {
Ok(vec![self])
}
}
#[derive(Debug)]
pub struct ImmutableBlock<S: Storage, M: BlockMetadata> {
block: Arc<MutableBlock<S, M>>,
}
impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> {
pub(crate) fn new(block: Arc<MutableBlock<S, M>>) -> Self {
Self { block }
}
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> ReadableBlock for ImmutableBlock<S, M> {
type StorageType = S;
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Readable for ImmutableBlock<S, M> {}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Immutable for ImmutableBlock<S, M> {}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> Local for ImmutableBlock<S, M> {}
impl<S: Storage, M: BlockMetadata> Deref for ImmutableBlock<S, M> {
type Target = Block<S, M>;
fn deref(&self) -> &Self::Target {
self.block
.as_ref()
.block
.as_ref()
.expect("block was dropped")
}
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> BlockDataProvider for ImmutableBlock<S, M> {
type StorageType = S;
fn block_data(&self, _: private::PrivateToken) -> &BlockData<S> {
&self
.block
.as_ref()
.block
.as_ref()
.expect("block was dropped")
.data
}
}
impl<S: Storage + NixlDescriptor, M: BlockMetadata> IntoReadableBlocks<M> for ImmutableBlock<S, M> {
type Output = Vec<ImmutableBlock<S, M>>;
fn into_readable_blocks(self, _manager: &BlockManager<M>) -> BlockResult<Self::Output> {
Ok(vec![self])
}
}
impl<'a, S: Storage + NixlDescriptor, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>>
for [ImmutableBlock<S, M>]
{
fn as_block_slice(&'a self) -> &'a [ImmutableBlock<S, M>] {
self
}
}
impl<'a, S: Storage, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>>
for Vec<ImmutableBlock<S, M>>
{
fn as_block_slice(&'a self) -> &'a [ImmutableBlock<S, M>] {
self.as_slice()
}
}
pub mod nixl {
use super::*;
use super::view::{BlockKind, Kind, LayerKind};
use super::super::{
layout::nixl::{NixlLayout, SerializedNixlBlockLayout},
storage::nixl::{MemType, NixlRegisterableStorage, NixlStorage},
WorkerID,
};
use derive_getters::{Dissolve, Getters};
use nixl_sys::{Agent as NixlAgent, MemoryRegion, NixlDescriptor, OptArgs};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// --- Mutability Marker ---
pub trait MutabilityKind: Debug + Clone + Copy + Send + Sync + 'static {}
#[derive(Debug, Clone, Copy)]
pub struct IsMutable;
impl MutabilityKind for IsMutable {}
#[derive(Debug, Clone, Copy)]
pub struct IsImmutable;
impl MutabilityKind for IsImmutable {}
impl<L: NixlLayout, M: BlockMetadata> Blocks<L, M>
where
L::StorageType: NixlRegisterableStorage,
{
/// Register the blocks with a NIXL agent
pub fn nixl_register(
&mut self,
agent: &NixlAgent,
opt_args: Option<&OptArgs>,
) -> anyhow::Result<()> {
self.layout.nixl_register(agent, opt_args)
}
}
/// A unified, lifetime-bound descriptor containing information needed for NIXL operations.
/// Typed by Kind (Block/Layer) and Mutability (IsMutable/IsImmutable).
#[derive(Copy, Clone)] // Can be Copy/Clone as it holds basic data + markers
pub struct NixlMemoryDescriptor<'a, K: Kind, M: MutabilityKind> {
addr: u64,
size: usize,
mem_type: MemType,
device_id: u64,
_lifetime: std::marker::PhantomData<&'a ()>, // Binds the descriptor's lifetime to 'a
_kind: std::marker::PhantomData<K>, // Stores the Kind marker type
_mutability: std::marker::PhantomData<M>, // Stores the Mutability marker type
}
// Helper function to get the short type name
pub(crate) fn short_type_name<T>() -> &'static str {
let name = core::any::type_name::<T>();
name.split("::").last().unwrap_or(name)
}
// Implement Debug manually to avoid bounds on K/M
impl<K: Kind, M: MutabilityKind> Debug for NixlMemoryDescriptor<'_, K, M> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NixlMemoryDescriptor")
.field("addr", &self.addr)
.field("size", &self.size)
.field("mem_type", &self.mem_type)
.field("device_id", &self.device_id)
.field("kind", &short_type_name::<K>()) // Show marker types
.field("mutability", &short_type_name::<M>())
.finish()
}
}
impl<K: Kind, M: MutabilityKind> NixlMemoryDescriptor<'_, K, M> {
/// Creates a new NixlMemoryDescriptor. Typically called via conversion methods.
#[inline]
pub(crate) fn new(addr: u64, size: usize, mem_type: MemType, device_id: u64) -> Self {
Self {
addr,
size,
mem_type,
device_id,
_lifetime: std::marker::PhantomData,
_kind: std::marker::PhantomData,
_mutability: std::marker::PhantomData,
}
}
}
impl<K: Kind, M: MutabilityKind> MemoryRegion for NixlMemoryDescriptor<'_, K, M> {
unsafe fn as_ptr(&self) -> *const u8 {
self.addr as *const u8
}
fn size(&self) -> usize {
self.size
}
}
impl<K: Kind, M: MutabilityKind> NixlDescriptor for NixlMemoryDescriptor<'_, K, M> {
fn mem_type(&self) -> MemType {
self.mem_type
}
fn device_id(&self) -> u64 {
self.device_id
}
}
pub trait NixlBlockDataImmutable<S: Storage + NixlDescriptor>: BlockDataExt<S> {
/// Get the NIXL memory descriptor for the entire block
fn as_block_descriptor(
&self,
) -> BlockResult<NixlMemoryDescriptor<'_, BlockKind, IsImmutable>>;
/// Get the NIXL memory descriptor for a specific layer
fn as_layer_descriptor(
&self,
layer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>>;
}
pub trait NixlBlockDataMutable<S: Storage + NixlDescriptor>:
BlockDataExt<S> + NixlBlockDataImmutable<S>
{
/// Get the NIXL memory descriptor for the entire block
fn as_block_descriptor_mut(
&mut self,
) -> BlockResult<NixlMemoryDescriptor<'_, BlockKind, IsMutable>>;
/// Get the NIXL memory descriptor for a specific layer
fn as_layer_descriptor_mut(
&mut self,
layer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>>;
}
impl<S: Storage + NixlDescriptor> NixlBlockDataImmutable<S> for BlockData<S> {
fn as_block_descriptor(
&self,
) -> BlockResult<NixlMemoryDescriptor<'_, BlockKind, IsImmutable>> {
Ok(self.block_view()?.as_nixl_descriptor())
}
fn as_layer_descriptor(
&self,
layer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>> {
Ok(self.layer_view(layer_idx)?.as_nixl_descriptor())
}
}
impl<S: Storage + NixlDescriptor> NixlBlockDataMutable<S> for BlockData<S> {
fn as_block_descriptor_mut(
&mut self,
) -> BlockResult<NixlMemoryDescriptor<'_, BlockKind, IsMutable>> {
Ok(self.block_view_mut()?.as_nixl_descriptor_mut())
}
fn as_layer_descriptor_mut(
&mut self,
layer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>> {
Ok(self.layer_view_mut(layer_idx)?.as_nixl_descriptor_mut())
}
}
/// Error type for NixlBlockSet serialization/deserialization failures.
#[derive(Debug, Error)]
pub enum NixlSerializationError {
#[error("Serialization failed: {0}")]
Serialize(#[from] serde_json::Error),
}
/// A strongly-typed wrapper for serialized NixlBlockSet data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedNixlBlockSet(Vec<u8>);
impl TryFrom<&NixlBlockSet> for SerializedNixlBlockSet {
type Error = NixlSerializationError;
/// Serializes a NixlBlockSet into SerializedNixlBlockSet.
fn try_from(value: &NixlBlockSet) -> Result<Self, Self::Error> {
let bytes = serde_json::to_vec(value)?;
Ok(SerializedNixlBlockSet(bytes))
}
}
impl TryFrom<NixlBlockSet> for SerializedNixlBlockSet {
type Error = NixlSerializationError;
/// Serializes a NixlBlockSet into SerializedNixlBlockSet, consuming the original.
fn try_from(value: NixlBlockSet) -> Result<Self, Self::Error> {
let bytes = serde_json::to_vec(&value)?;
Ok(SerializedNixlBlockSet(bytes))
}
}
impl TryFrom<&SerializedNixlBlockSet> for NixlBlockSet {
type Error = NixlSerializationError;
/// Deserializes SerializedNixlBlockSet into a NixlBlockSet.
fn try_from(value: &SerializedNixlBlockSet) -> Result<Self, Self::Error> {
let block_set = serde_json::from_slice(&value.0)?;
Ok(block_set)
}
}
impl TryFrom<SerializedNixlBlockSet> for NixlBlockSet {
type Error = NixlSerializationError;
/// Deserializes SerializedNixlBlockSet into a NixlBlockSet, consuming the original.
fn try_from(value: SerializedNixlBlockSet) -> Result<Self, Self::Error> {
let block_set = serde_json::from_slice(&value.0)?;
Ok(block_set)
}
}
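// Hedged sketch: round-tripping a NixlBlockSet through its serialized wrapper
// using the TryFrom impls above. The helper name is illustrative.
#[allow(dead_code)]
fn example_block_set_roundtrip(
    set: &NixlBlockSet,
) -> Result<NixlBlockSet, NixlSerializationError> {
    let serialized = SerializedNixlBlockSet::try_from(set)?;
    NixlBlockSet::try_from(&serialized)
}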
#[derive(Clone, serde::Serialize, serde::Deserialize, Dissolve)]
pub struct NixlBlockSet {
    /// Maps each block set index to its serialized layout
block_sets: HashMap<usize, SerializedNixlBlockLayout>,
/// Captures the NIXL metadata from [nixl_sys::Agent::get_local_md]
nixl_metadata: Vec<u8>,
/// Worker ID
worker_id: u64,
}
impl NixlBlockSet {
pub fn new(worker_id: u64) -> Self {
Self {
block_sets: HashMap::new(),
nixl_metadata: Vec::new(),
worker_id,
}
}
pub fn worker_id(&self) -> u64 {
self.worker_id
}
    /// Get the map of block set indices to serialized layouts
pub fn block_sets(&self) -> &HashMap<usize, SerializedNixlBlockLayout> {
&self.block_sets
}
    /// Add a serialized layout under the given block set index
pub fn add_block_set(
&mut self,
block_set_idx: usize,
serialized_layout: SerializedNixlBlockLayout,
) {
self.block_sets.insert(block_set_idx, serialized_layout);
}
/// Get the NIXL metadata
pub fn get_nixl_metadata(&self) -> &Vec<u8> {
&self.nixl_metadata
}
/// Set the NIXL metadata
pub fn set_nixl_metadata(&mut self, nixl_metadata: Vec<u8>) {
self.nixl_metadata = nixl_metadata;
}
}
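// Hedged sketch: assembling a NixlBlockSet for exchange with a peer. The worker
// ID, block set index, and metadata bytes are placeholders.
#[allow(dead_code)]
fn example_build_block_set(
    layout: SerializedNixlBlockLayout,
    nixl_metadata: Vec<u8>,
) -> NixlBlockSet {
    let mut set = NixlBlockSet::new(0);
    set.add_block_set(0, layout);
    set.set_nixl_metadata(nixl_metadata);
    set
}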
#[derive(Debug, Clone)]
pub struct RemoteBlocks {
layout: Arc<dyn BlockLayout<StorageType = NixlStorage>>,
block_set_idx: usize,
worker_id: WorkerID,
}
impl RemoteBlocks {
pub fn new(
layout: Arc<dyn BlockLayout<StorageType = NixlStorage>>,
block_set_idx: usize,
worker_id: WorkerID,
) -> Self {
Self {
layout,
block_set_idx,
worker_id,
}
}
pub fn from_serialized(
serialized: SerializedNixlBlockLayout,
block_set_idx: usize,
worker_id: WorkerID,
) -> BlockResult<Self> {
let layout = serialized.deserialize()?;
Ok(Self::new(layout, block_set_idx, worker_id))
}
pub fn block<M: MutabilityKind>(&self, block_idx: usize) -> BlockResult<RemoteBlock<M>> {
if block_idx >= self.layout.num_blocks() {
return Err(BlockError::InvalidState(format!(
"block index out of bounds: {} >= {}",
block_idx,
self.layout.num_blocks()
)));
}
Ok(RemoteBlock::new(
self.layout.clone(),
block_idx,
self.block_set_idx,
self.worker_id,
))
}
/// Get the layout of the remote blocks
pub fn layout(&self) -> &dyn BlockLayout<StorageType = NixlStorage> {
self.layout.as_ref()
}
}
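// Hedged sketch: reconstructing remote blocks from a peer's serialized layout,
// then taking an immutable handle to block 0. Index values are illustrative.
#[allow(dead_code)]
fn example_remote_block(
    serialized: SerializedNixlBlockLayout,
    worker_id: WorkerID,
) -> BlockResult<RemoteBlock<IsImmutable>> {
    let remote = RemoteBlocks::from_serialized(serialized, 0, worker_id)?;
    remote.block::<IsImmutable>(0)
}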
pub type ImmutableRemoteBlock = RemoteBlock<IsImmutable>;
pub type MutableRemoteBlock = RemoteBlock<IsMutable>;
pub struct RemoteBlock<M: MutabilityKind> {
data: BlockData<NixlStorage>,
_mutability: std::marker::PhantomData<M>,
}
impl<M: MutabilityKind> Remote for RemoteBlock<M> {}
impl<M: MutabilityKind> ReadableBlock for RemoteBlock<M> {
type StorageType = NixlStorage;
}
impl WritableBlock for RemoteBlock<IsMutable> {
type StorageType = NixlStorage;
}
impl<M: MutabilityKind> RemoteBlock<M> {
pub fn new(
layout: Arc<dyn BlockLayout<StorageType = NixlStorage>>,
block_idx: usize,
block_set_idx: usize,
worker_id: WorkerID,
) -> Self {
let data = BlockData::new(layout, block_idx, block_set_idx, worker_id);
Self {
data,
_mutability: std::marker::PhantomData,
}
}
}
impl<M: MutabilityKind> BlockDataExt<NixlStorage> for RemoteBlock<M> {
fn is_fully_contiguous(&self) -> bool {
self.data.is_fully_contiguous()
}
fn num_layers(&self) -> usize {
self.data.num_layers()
}
fn layer_view(&self, layer_idx: usize) -> BlockResult<view::LayerView<NixlStorage>> {
self.data.layer_view(layer_idx)
}
fn layer_view_mut(
&mut self,
layer_idx: usize,
) -> BlockResult<view::LayerViewMut<NixlStorage>> {
self.data.layer_view_mut(layer_idx)
}
fn block_view(&self) -> BlockResult<view::BlockView<NixlStorage>> {
self.data.block_view()
}
fn block_view_mut(&mut self) -> BlockResult<view::BlockViewMut<NixlStorage>> {
self.data.block_view_mut()
}
}
impl<M: MutabilityKind> BlockDataProvider for RemoteBlock<M> {
type StorageType = NixlStorage;
fn block_data(&self, _: private::PrivateToken) -> &BlockData<NixlStorage> {
&self.data
}
}
impl<M: MutabilityKind> NixlBlockDataImmutable<NixlStorage> for RemoteBlock<M> {
fn as_block_descriptor(
&self,
) -> BlockResult<NixlMemoryDescriptor<'_, BlockKind, IsImmutable>> {
self.data.as_block_descriptor()
}
fn as_layer_descriptor(
&self,
layer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>> {
self.data.as_layer_descriptor(layer_idx)
}
}
impl BlockDataProviderMut for RemoteBlock<IsMutable> {
fn block_data_mut(&mut self, _: private::PrivateToken) -> &mut BlockData<NixlStorage> {
&mut self.data
}
}
impl NixlBlockDataMutable<NixlStorage> for RemoteBlock<IsMutable> {
fn as_block_descriptor_mut(
&mut self,
) -> BlockResult<NixlMemoryDescriptor<'_, BlockKind, IsMutable>> {
self.data.as_block_descriptor_mut()
}
fn as_layer_descriptor_mut(
&mut self,
layer_idx: usize,
) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsMutable>> {
self.data.as_layer_descriptor_mut(layer_idx)
}
}
impl<'a, M: MutabilityKind> AsBlockSlice<'a, RemoteBlock<M>> for [RemoteBlock<M>] {
fn as_block_slice(&'a self) -> &'a [RemoteBlock<M>] {
self
}
}
impl<'a, M: MutabilityKind> AsBlockSlice<'a, RemoteBlock<M>> for Vec<RemoteBlock<M>> {
fn as_block_slice(&'a self) -> &'a [RemoteBlock<M>] {
self.as_slice()
}
}
impl<'a> AsBlockMutSlice<'a, RemoteBlock<IsMutable>> for [RemoteBlock<IsMutable>] {
fn as_block_mut_slice(&'a mut self) -> &'a mut [RemoteBlock<IsMutable>] {
self
}
}
impl<'a> AsBlockMutSlice<'a, RemoteBlock<IsMutable>> for Vec<RemoteBlock<IsMutable>> {
fn as_block_mut_slice(&'a mut self) -> &'a mut [RemoteBlock<IsMutable>] {
self.as_mut_slice()
}
}
/// Defines the intended access pattern for a block represented by a descriptor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BlockMutability {
Immutable,
Mutable,
}
/// Describes a single block for identification and potential remote access setup.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BlockDescriptor {
pub worker_id: WorkerID,
pub block_set_idx: usize,
pub block_idx: usize,
pub mutability: BlockMutability,
}
// Placeholder Trait: Real pool handles must provide this info.
// This trait allows BlockDescriptorList constructors to be generic.
pub trait BlockHandleInfo {
fn worker_id(&self) -> WorkerID; // Needs access to the parent KvBlockManager's ID
fn block_set_idx(&self) -> usize;
fn block_idx(&self) -> usize;
}
impl<S: Storage> BlockHandleInfo for BlockData<S> {
fn worker_id(&self) -> WorkerID {
self.worker_id
}
fn block_set_idx(&self) -> usize {
self.block_set_idx
}
fn block_idx(&self) -> usize {
self.block_idx
}
}
impl<S: Storage, M: BlockMetadata> BlockHandleInfo for Block<S, M> {
fn worker_id(&self) -> WorkerID {
self.data.worker_id
}
fn block_set_idx(&self) -> usize {
self.data.block_set_idx
}
fn block_idx(&self) -> usize {
self.data.block_idx
}
}
/// A validated, homogeneous, and serializable collection of BlockDescriptors.
/// Primarily used to describe sets of remote blocks for transfer operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Getters)]
pub struct BlockDescriptorList {
#[getter(copy)]
worker_id: WorkerID,
#[getter(copy)]
block_set_idx: usize,
#[getter(copy)]
mutability: BlockMutability,
block_indices: Vec<usize>,
// TODO: Consider storing MemType explicitly if it cannot be reliably
// derived from block_set_idx via the NixlBlockSet on the receiving side.
}
impl<M: BlockMetadata> IntoWritableBlocks<M> for BlockDescriptorList {
type Output = Vec<RemoteBlock<IsMutable>>;
fn into_writable_blocks(self, manager: &BlockManager<M>) -> BlockResult<Self::Output> {
Ok(manager.get_remote_blocks_mutable(&self)?)
}
}
#[derive(Debug, Error)]
pub enum BlockDescriptorSetError {
#[error("Input block list cannot be empty")]
EmptyInput,
#[error(
"Blocks in the input list are not homogeneous (worker_id, block_set_idx mismatch)"
)]
NotHomogeneous,
#[error("Serialization failed: {0}")]
SerializationError(#[from] serde_json::Error),
#[error(
"An invalid block handle was encountered (block may have been dropped prematurely)"
)]
InvalidBlockHandle,
}
impl BlockDescriptorList {
/// Creates a new validated BlockDescriptorList from a slice of block handles.
/// Ensures all handles belong to the same worker and block set.
fn new<S: Storage>(
blocks: &[&BlockData<S>], // Use the generic trait bound
mutability: BlockMutability,
) -> Result<Self, BlockDescriptorSetError> {
if blocks.is_empty() {
return Err(BlockDescriptorSetError::EmptyInput);
}
let first = blocks[0];
let worker_id = first.worker_id();
let block_set_idx = first.block_set_idx();
let mut block_indices = Vec::with_capacity(blocks.len());
block_indices.push(first.block_idx());
for block in blocks.iter().skip(1) {
// Validate homogeneity
if block.worker_id() != worker_id || block.block_set_idx() != block_set_idx {
return Err(BlockDescriptorSetError::NotHomogeneous);
}
block_indices.push(block.block_idx());
}
// TODO: Potentially validate MemType derived from block_set_idx here if possible
Ok(Self {
worker_id,
block_set_idx,
mutability,
block_indices,
})
}
/// Creates a BlockDescriptorList representing immutable blocks.
pub fn from_immutable_blocks<S: Storage, M: BlockMetadata>(
blocks: &[ImmutableBlock<S, M>],
) -> Result<Self, BlockDescriptorSetError> {
// Map each block handle to Option<&BlockData>,
// then convert Option to Result (treating None as an error),
// finally collect into Result<Vec<&BlockData>, Error>.
let data: Vec<&BlockData<S>> = blocks
.iter()
.map(|b| b.block.block.as_ref().map(|inner_b| &inner_b.data))
.map(|opt| opt.ok_or(BlockDescriptorSetError::InvalidBlockHandle))
.collect::<Result<Vec<&BlockData<S>>, _>>()?;
Self::new(&data, BlockMutability::Immutable)
}
/// Creates a BlockDescriptorList representing mutable blocks.
pub fn from_mutable_blocks<S: Storage, M: BlockMetadata>(
blocks: &[MutableBlock<S, M>],
) -> Result<Self, BlockDescriptorSetError> {
// Map each block handle to Option<&BlockData>,
// then convert Option to Result (treating None as an error),
// finally collect into Result<Vec<&BlockData>, Error>.
let data: Vec<&BlockData<S>> = blocks
.iter()
.map(|b| b.block.as_ref().map(|inner_b| &inner_b.data))
.map(|opt| opt.ok_or(BlockDescriptorSetError::InvalidBlockHandle))
.collect::<Result<Vec<&BlockData<S>>, _>>()?;
Self::new(&data, BlockMutability::Mutable)
}
// /// Serializes the BlockDescriptorList into a byte vector.
// pub fn serialize(&self) -> Result<Vec<u8>, BlockDescriptorSetError> {
// Ok(serde_json::to_vec(self)?)
// }
// /// Deserializes a BlockDescriptorList from a byte slice.
// pub fn deserialize(data: &[u8]) -> Result<Self, BlockDescriptorSetError> {
// Ok(serde_json::from_slice(data)?)
// }
}
pub trait AsBlockDescriptorSet {
type Block;
fn as_block_descriptor_set(&self) -> Result<BlockDescriptorList, BlockDescriptorSetError>;
}
impl<S, M> AsBlockDescriptorSet for [ImmutableBlock<S, M>]
where
S: Storage,
M: BlockMetadata,
{
type Block = ImmutableBlock<S, M>;
fn as_block_descriptor_set(&self) -> Result<BlockDescriptorList, BlockDescriptorSetError> {
BlockDescriptorList::from_immutable_blocks(self)
}
}
impl<S, M> AsBlockDescriptorSet for [MutableBlock<S, M>]
where
S: Storage,
M: BlockMetadata,
{
type Block = MutableBlock<S, M>;
fn as_block_descriptor_set(&self) -> Result<BlockDescriptorList, BlockDescriptorSetError> {
BlockDescriptorList::from_mutable_blocks(self)
}
}
impl<T> AsBlockDescriptorSet for Vec<T>
where
[T]: AsBlockDescriptorSet<Block = T>,
{
type Block = T;
fn as_block_descriptor_set(&self) -> Result<BlockDescriptorList, BlockDescriptorSetError> {
self.as_slice().as_block_descriptor_set()
}
}
impl<T, const N: usize> AsBlockDescriptorSet for [T; N]
where
[T]: AsBlockDescriptorSet<Block = T>,
{
type Block = T;
fn as_block_descriptor_set(&self) -> Result<BlockDescriptorList, BlockDescriptorSetError> {
self.as_slice().as_block_descriptor_set()
}
}
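// Hedged sketch: collapsing a set of immutable block handles into a serializable
// descriptor list suitable for a transfer request.
#[allow(dead_code)]
fn example_descriptor_list<S: Storage, M: BlockMetadata>(
    blocks: &[ImmutableBlock<S, M>],
) -> Result<BlockDescriptorList, BlockDescriptorSetError> {
    blocks.as_block_descriptor_set()
}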
}
#[cfg(test)]
mod tests {
use super::*;
use super::nixl::*;
use super::super::layout::{
nixl::{NixlLayout, SerializedNixlBlockLayout, ToSerializedNixlBlockLayout},
tests::setup_layout,
FullyContiguous, LayoutConfig,
};
use crate::block_manager::storage::SystemAllocator;
use crate::tokens::TokenBlockSequence;
use dynamo_runtime::logging::init as init_logging;
use nixl_sys::Agent as NixlAgent;
const BLOCK_SIZE: usize = 4;
const SALT_HASH: SaltHash = 12345;
// Helper to create a default reset block
fn create_reset_block() -> Block<impl Storage, BasicMetadata> {
let layout = setup_layout(None).unwrap();
let data = BlockData::new(Arc::new(layout), 0, 42, 0);
Block::new(data, BasicMetadata::default()).unwrap()
}
// Helper to create a complete TokenBlock for testing apply_token_block
fn create_full_token_block() -> TokenBlock {
let tokens = Tokens::from(vec![1, 2, 3, 4]);
let salt_hash = SALT_HASH;
let block_size = BLOCK_SIZE;
let (mut blocks, _) =
TokenBlockSequence::split_tokens(tokens.as_ref(), block_size, salt_hash);
blocks.pop().unwrap()
}
#[test]
fn test_block_state_transitions_and_ops() {
let mut block = create_reset_block();
assert!(matches!(block.state(), BlockState::Reset));
// --- Reset State --- //
assert!(block.add_token(1).is_err(), "Append on Reset should fail");
assert!(
block.add_tokens(Tokens::from(vec![1])).is_err(),
"Extend on Reset should fail"
);
assert!(block.commit().is_err(), "Commit on Reset should fail");
assert!(block.pop_token().is_err(), "Pop on Reset should fail");
assert!(
block.pop_tokens(1).is_err(),
"Pop tokens on Reset should fail"
);
// --- Reset -> Partial (via init_sequence) --- //
assert!(block.init_sequence(SALT_HASH).is_ok());
assert!(matches!(block.state(), BlockState::Partial(_)));
// --- Partial State --- //
let invalid_block = create_full_token_block();
assert!(
block.apply_token_block(invalid_block).is_err(),
"Apply block on Partial should fail"
);
// Append tokens
assert!(block.add_token(1).is_ok()); // 1
assert!(block.add_token(2).is_ok()); // 1, 2
assert!(block.add_tokens(Tokens::from(vec![3])).is_ok()); // 1, 2, 3
assert_eq!(block.len(), 3);
        // Extend beyond capacity: the block absorbs what fits and returns the overflow
        let new_tokens = Tokens::from(vec![4, 5]);
        assert_eq!(block.add_tokens(new_tokens.clone()).unwrap().as_ref(), &[5]); // 1, 2, 3, 4
        // Extending the now-full block is an Ok no-op that returns every token
        assert!(block.add_tokens(Tokens::from(vec![4])).is_ok()); // still 1, 2, 3, 4
        assert_eq!(block.len(), BLOCK_SIZE);
// Append when full (should fail)
assert!(block.add_token(5).is_err(), "Append on full Partial block");
// Pop tokens
assert!(block.pop_token().is_ok()); // After pop: 1, 2, 3
assert_eq!(block.len(), 3);
// Pop multiple tokens
assert!(block.pop_tokens(2).is_ok()); // After pop: [1]
assert_eq!(block.len(), 1);
// Pop too many tokens (should fail)
assert!(block.pop_tokens(2).is_err(), "Pop too many tokens");
assert_eq!(block.len(), 1);
// Pop last token
assert!(block.pop_token().is_ok()); // empty
assert_eq!(block.len(), 0);
assert!(block.is_empty());
// Fill block again for commit
assert!(block.add_tokens(Tokens::from(vec![1, 2, 3, 4])).is_ok());
assert_eq!(block.len(), BLOCK_SIZE);
// --- Partial -> Complete (via commit) --- //
assert!(block.commit().is_ok());
assert!(matches!(block.state(), BlockState::Complete(_)));
assert_eq!(block.tokens().unwrap().as_ref(), &[1, 2, 3, 4]);
// --- Complete State --- //
assert!(
block.init_sequence(SALT_HASH).is_err(),
"Init sequence on Complete should fail"
);
assert!(
block.add_token(5).is_err(),
"Append on Complete should fail"
);
assert!(
block.add_tokens(Tokens::from(vec![5])).is_err(),
"Extend on Complete should fail"
);
assert!(block.commit().is_err(), "Commit on Complete should fail");
assert!(block.pop_token().is_err(), "Pop on Complete should fail");
assert!(
block.pop_tokens(1).is_err(),
"Pop tokens on Complete should fail"
);
let invalid_block = create_full_token_block();
assert!(
block.apply_token_block(invalid_block).is_err(),
"Apply block on Complete should fail"
);
// --- Complete -> Reset (via reset) --- //
block.reset();
assert!(matches!(block.state(), BlockState::Reset));
// --- Reset -> Complete (via apply_token_block) --- //
let full_block = create_full_token_block();
assert!(block.apply_token_block(full_block.clone()).is_ok());
assert!(matches!(block.state(), BlockState::Complete(_)));
let applied_tokens = block.tokens().unwrap();
assert_eq!(applied_tokens, full_block.tokens());
// Testing applying to a non-reset state:
let mut non_reset_block = create_reset_block();
non_reset_block.init_sequence(SALT_HASH).unwrap(); // Put in Partial state
assert!(
non_reset_block.apply_token_block(full_block).is_err(),
"Apply block to non-reset state"
);
}
#[test]
fn test_block_state_incomplete_commit() {
// Commit incomplete block (should fail)
let mut partial_block = create_reset_block();
partial_block.init_sequence(SALT_HASH).unwrap();
partial_block.add_token(1).unwrap();
partial_block.add_tokens(Tokens::from(vec![2, 3])).unwrap();
assert_eq!(partial_block.len(), 3);
assert!(
partial_block.commit().is_err(),
"Commit on incomplete Partial block"
);
}
#[test]
fn test_error_types() {
let mut block = create_reset_block();
block.init_sequence(SALT_HASH).unwrap();
// Fill the block
block.add_tokens(Tokens::from(vec![1, 2, 3, 4])).unwrap();
// Append when full
let append_err = block.add_token(5).unwrap_err();
assert!(append_err.is::<TokenBlockError>());
assert_eq!(
*append_err.downcast_ref::<TokenBlockError>().unwrap(),
TokenBlockError::Full
);
        // add_tokens fills what it can and hands back the tokens that did not fit;
        // on an already-full block, every token is returned
let new_tokens = Tokens::from(vec![5]);
let ret_tokens = block.add_tokens(new_tokens.clone()).unwrap();
assert_eq!(new_tokens, ret_tokens);
// Commit when full (should succeed)
block.commit().unwrap();
// Commit when Complete
let commit_err = block.commit().unwrap_err();
assert!(commit_err.is::<BlockStateInvalid>());
// Reset and test pop empty
block.reset();
block.init_sequence(SALT_HASH).unwrap();
let pop_err = block.pop_token().unwrap_err();
assert!(pop_err.is::<TokenBlockError>());
assert_eq!(
*pop_err.downcast_ref::<TokenBlockError>().unwrap(),
TokenBlockError::Empty
);
let pop_tokens_err = block.pop_tokens(1).unwrap_err();
assert!(pop_tokens_err.is::<TokenBlockError>());
assert_eq!(
*pop_tokens_err.downcast_ref::<TokenBlockError>().unwrap(),
TokenBlockError::InsufficientTokens
);
// Test commit incomplete
block.add_token(1).unwrap();
let commit_incomplete_err = block.commit().unwrap_err();
assert!(commit_incomplete_err.is::<TokenBlockError>());
assert_eq!(
*commit_incomplete_err
.downcast_ref::<TokenBlockError>()
.unwrap(),
TokenBlockError::Incomplete
);
}
#[test]
fn test_nixl_block_data_ext() {
init_logging();
let config = LayoutConfig::builder()
.num_blocks(10)
.num_layers(2)
.page_size(4)
.inner_dim(13)
.build()
.unwrap();
let mut layout = FullyContiguous::allocate(config, &SystemAllocator).unwrap();
let agent = NixlAgent::new("test").unwrap();
tracing::info!("Registering layout");
layout.nixl_register(&agent, None).unwrap();
tracing::info!("Layout registered");
let serialized = layout.serialize().unwrap();
let layout = Arc::new(layout);
let data = BlockData::new(layout.clone(), 0, 42, 0);
assert_eq!(data.block_idx(), 0);
assert_eq!(data.block_set_idx(), 42);
let block_desc = data.as_block_descriptor().unwrap();
println!("Block descriptor: {:?}", block_desc);
let data = BlockData::new(layout.clone(), 1, 42, 0);
assert_eq!(data.block_idx(), 1);
assert_eq!(data.block_set_idx(), 42);
let block_desc = data.as_block_descriptor().unwrap();
println!("Block descriptor: {:?}", block_desc);
let remote_layout = SerializedNixlBlockLayout::deserialize(&serialized).unwrap();
println!("Nixl layout: {:?}", remote_layout);
let remote_block = RemoteBlock::<IsMutable>::new(remote_layout.clone(), 0, 42, 0);
let remote_desc = remote_block.as_block_descriptor().unwrap();
println!("Remote Descriptor: {:?}", remote_desc);
        // drop(layout);
        tracing::info!("Test complete");
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::HashMap,
sync::{Arc, Weak},
};
use super::super::events::{EventManager, EventReleaseManager, PublishHandle};
use super::state::BlockState;
use crate::tokens::{BlockHash, SequenceHash, TokenBlock};
use derive_getters::Getters;
#[derive(Debug, thiserror::Error)]
pub enum BlockRegistationError {
#[error("Block already registered")]
BlockAlreadyRegistered(SequenceHash),
#[error("Invalid state: {0}")]
InvalidState(String),
}
/// Error returned when an attempt is made to unregister a block that is still active.
#[derive(Debug, thiserror::Error)]
#[error("Failed to unregister block: {0}")]
pub struct UnregisterFailure(SequenceHash);
pub struct BlockRegistry {
blocks: HashMap<SequenceHash, Weak<RegistrationHandle>>,
event_manager: Arc<dyn EventManager>,
}
impl BlockRegistry {
pub fn new(event_manager: Arc<dyn EventManager>) -> Self {
Self {
blocks: HashMap::new(),
event_manager,
}
}
pub fn is_registered(&self, sequence_hash: SequenceHash) -> bool {
if let Some(handle) = self.blocks.get(&sequence_hash) {
if let Some(_handle) = handle.upgrade() {
return true;
}
}
false
}
pub fn register_block(
&mut self,
block_state: &mut BlockState,
) -> Result<PublishHandle, BlockRegistationError> {
match block_state {
BlockState::Reset => Err(BlockRegistationError::InvalidState(
"Block is in Reset state".to_string(),
)),
BlockState::Partial(_partial) => Err(BlockRegistationError::InvalidState(
"Block is in Partial state".to_string(),
)),
BlockState::Complete(state) => {
let sequence_hash = state.token_block().sequence_hash();
if let Some(handle) = self.blocks.get(&sequence_hash) {
if let Some(_handle) = handle.upgrade() {
return Err(BlockRegistationError::BlockAlreadyRegistered(sequence_hash));
}
}
// Create the [RegistrationHandle] and [PublishHandle]
let publish_handle =
Self::create_publish_handle(state.token_block(), self.event_manager.clone());
let reg_handle = publish_handle.remove_handle();
// Insert the [RegistrationHandle] into the registry
self.blocks
.insert(sequence_hash, Arc::downgrade(&reg_handle));
// Update the [BlockState] to [BlockState::Registered]
let _ = std::mem::replace(block_state, BlockState::Registered(reg_handle));
Ok(publish_handle)
}
BlockState::Registered(registered) => Err(
BlockRegistationError::BlockAlreadyRegistered(registered.sequence_hash()),
),
}
}
pub fn unregister_block(
&mut self,
sequence_hash: SequenceHash,
) -> Result<(), UnregisterFailure> {
if let Some(handle) = self.blocks.get(&sequence_hash) {
if handle.upgrade().is_none() {
self.blocks.remove(&sequence_hash);
return Ok(());
} else {
return Err(UnregisterFailure(sequence_hash));
}
}
Ok(())
}
fn create_publish_handle(
token_block: &TokenBlock,
event_manager: Arc<dyn EventManager>,
) -> PublishHandle {
let reg_handle = RegistrationHandle::from_token_block(token_block, event_manager.clone());
PublishHandle::new(reg_handle, event_manager)
}
}
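// Hedged sketch: registering a completed block and reacting to the duplicate
// case. Dropping the returned PublishHandle publishes the Register event.
#[allow(dead_code)]
fn example_register(registry: &mut BlockRegistry, state: &mut BlockState) {
    match registry.register_block(state) {
        Ok(publish_handle) => drop(publish_handle), // publish immediately
        Err(BlockRegistationError::BlockAlreadyRegistered(hash)) => {
            // Another live handle already owns this sequence hash.
            let _ = hash;
        }
        Err(_) => { /* block was not in the Complete state */ }
    }
}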
#[derive(Getters)]
pub struct RegistrationHandle {
#[getter(copy)]
block_hash: BlockHash,
#[getter(copy)]
sequence_hash: SequenceHash,
#[getter(copy)]
parent_sequence_hash: Option<SequenceHash>,
#[getter(skip)]
release_manager: Arc<dyn EventReleaseManager>,
}
impl RegistrationHandle {
fn from_token_block(
token_block: &TokenBlock,
release_manager: Arc<dyn EventReleaseManager>,
) -> Self {
Self {
block_hash: token_block.block_hash(),
sequence_hash: token_block.sequence_hash(),
parent_sequence_hash: token_block.parent_sequence_hash(),
release_manager,
}
}
}
impl std::fmt::Debug for RegistrationHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"RegistrationHandle {{ sequence_hash: {}; block_hash: {}; parent_sequence_hash: {:?} }}",
self.sequence_hash, self.block_hash, self.parent_sequence_hash
)
}
}
impl Drop for RegistrationHandle {
fn drop(&mut self) {
self.release_manager.block_release(self);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::block_manager::events::tests::{EventType, MockEventManager};
use crate::tokens::{TokenBlockSequence, Tokens};
fn create_sequence() -> TokenBlockSequence {
let tokens = Tokens::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
// NOTE: 1337 was the original seed, so we are temporarily using that here to prove the logic has not changed
let sequence = TokenBlockSequence::new(tokens, 4, Some(1337_u64));
assert_eq!(sequence.blocks().len(), 2);
assert_eq!(sequence.current_block().len(), 2);
assert_eq!(sequence.blocks()[0].tokens(), &vec![1, 2, 3, 4]);
assert_eq!(sequence.blocks()[0].sequence_hash(), 14643705804678351452);
assert_eq!(sequence.blocks()[1].tokens(), &vec![5, 6, 7, 8]);
assert_eq!(sequence.blocks()[1].sequence_hash(), 4945711292740353085);
assert_eq!(sequence.current_block().tokens(), &vec![9, 10]);
sequence
}
#[test]
fn test_mock_event_manager_with_single_publish_handle() {
let sequence = create_sequence();
let (event_manager, mut rx) = MockEventManager::new();
let publish_handle =
BlockRegistry::create_publish_handle(&sequence.blocks()[0], event_manager.clone());
// no event should have been triggered
assert!(rx.try_recv().is_err());
        // we should get two events when this is dropped, since we never took ownership of the RegistrationHandle
drop(publish_handle);
// the first event should be a Register event
let events = rx.try_recv().unwrap();
assert_eq!(events.len(), 1);
assert_eq!(
events[0],
EventType::Register(sequence.blocks()[0].sequence_hash())
);
// the second event should be a Remove event
let events = rx.try_recv().unwrap();
assert_eq!(events.len(), 1);
assert_eq!(
events[0],
EventType::Remove(sequence.blocks()[0].sequence_hash())
);
// there should be no more events
assert!(rx.try_recv().is_err());
}
#[test]
fn test_mock_event_manager_single_publish_handle_removed() {
let sequence = create_sequence();
let block_to_test = &sequence.blocks()[0];
let expected_sequence_hash = block_to_test.sequence_hash();
let (event_manager, mut rx) = MockEventManager::new();
let publish_handle =
BlockRegistry::create_publish_handle(block_to_test, event_manager.clone());
// Remove the registration handle before dropping the publish handle
let reg_handle = publish_handle.remove_handle();
// no event should have been triggered yet
assert!(rx.try_recv().is_err());
// Drop the publish handle - it SHOULD trigger a Register event now because remove_handle doesn't disarm
drop(publish_handle);
let register_events = rx.try_recv().unwrap();
assert_eq!(
register_events.len(),
1,
"Register event should be triggered on PublishHandle drop"
);
assert_eq!(
register_events[0],
EventType::Register(expected_sequence_hash),
"Expected Register event"
);
// Drop the registration handle - this SHOULD trigger the Remove event
drop(reg_handle);
let events = rx.try_recv().unwrap();
assert_eq!(events.len(), 1);
assert_eq!(
events[0],
EventType::Remove(expected_sequence_hash),
"Only Remove event should be triggered"
);
// there should be no more events
assert!(rx.try_recv().is_err());
}
#[test]
fn test_mock_event_manager_publisher_multiple_handles_removed() {
let sequence = create_sequence();
let block1 = &sequence.blocks()[0];
let block2 = &sequence.blocks()[1];
let hash1 = block1.sequence_hash();
let hash2 = block2.sequence_hash();
let (event_manager, mut rx) = MockEventManager::new();
let mut publisher = event_manager.publisher();
let publish_handle1 = BlockRegistry::create_publish_handle(block1, event_manager.clone());
let publish_handle2 = BlockRegistry::create_publish_handle(block2, event_manager.clone());
// Remove handles before adding to publisher
let reg_handle1 = publish_handle1.remove_handle();
let reg_handle2 = publish_handle2.remove_handle();
// Add disarmed handles to publisher
publisher.take_handle(publish_handle1);
publisher.take_handle(publish_handle2);
// no events yet
assert!(rx.try_recv().is_err());
// Drop the publisher - should trigger a single Publish event with both Register events
drop(publisher);
let events = rx.try_recv().unwrap();
assert_eq!(
events.len(),
2,
"Should receive two Register events in one batch"
);
// Order isn't guaranteed, so check for both
assert!(events.contains(&EventType::Register(hash1)));
assert!(events.contains(&EventType::Register(hash2)));
// no more events immediately after publish
assert!(rx.try_recv().is_err());
// Drop registration handles individually - should trigger Remove events
drop(reg_handle1);
let events1 = rx.try_recv().unwrap();
assert_eq!(events1.len(), 1);
assert_eq!(events1[0], EventType::Remove(hash1));
drop(reg_handle2);
let events2 = rx.try_recv().unwrap();
assert_eq!(events2.len(), 1);
assert_eq!(events2[0], EventType::Remove(hash2));
// no more events
assert!(rx.try_recv().is_err());
}
#[test]
fn test_publisher_empty_drop() {
let (event_manager, mut rx) = MockEventManager::new();
let publisher = event_manager.publisher();
drop(publisher);
// No events should be sent
assert!(rx.try_recv().is_err());
}
#[test]
fn test_publisher_publish_multiple_times() {
let sequence = create_sequence();
let block1 = &sequence.blocks()[0];
let hash1 = block1.sequence_hash();
let (event_manager, mut rx) = MockEventManager::new();
let mut publisher = event_manager.publisher();
let publish_handle1 = BlockRegistry::create_publish_handle(block1, event_manager.clone());
publisher.take_handle(publish_handle1);
// First publish call
publisher.publish();
let events = rx.try_recv().unwrap();
assert_eq!(events.len(), 1);
assert_eq!(events[0], EventType::Register(hash1));
// The RegistrationHandle Arc was taken by the publisher and dropped after the publish call
// So, the Remove event should follow immediately.
let remove_events = rx.try_recv().unwrap();
assert_eq!(
remove_events.len(),
1,
"Remove event should be triggered after publish consumes the handle"
);
assert_eq!(
remove_events[0],
EventType::Remove(hash1),
"Expected Remove event"
);
// Second publish call (should do nothing as handles were taken)
publisher.publish();
assert!(rx.try_recv().is_err());
// Drop publisher (should also do nothing)
drop(publisher);
assert!(rx.try_recv().is_err());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use derive_getters::Getters;
use super::registry::RegistrationHandle;
use super::Result;
use crate::tokens::{PartialTokenBlock, SaltHash, Token, TokenBlock, Tokens};
#[derive(Debug, thiserror::Error)]
#[error("Block state is invalid: {0}")]
pub struct BlockStateInvalid(pub String);
#[derive(Debug)]
pub enum BlockState {
Reset,
Partial(PartialState),
Complete(CompleteState),
Registered(Arc<RegistrationHandle>),
}
impl BlockState {
pub fn initialize_sequence(
&mut self,
page_size: usize,
salt_hash: SaltHash,
) -> Result<(), BlockStateInvalid> {
if !matches!(self, BlockState::Reset) {
return Err(BlockStateInvalid("Block is not reset".to_string()));
}
let block = PartialTokenBlock::create_sequence_root(page_size, salt_hash);
*self = BlockState::Partial(PartialState::new(block));
Ok(())
}
pub fn add_token(&mut self, token: Token) -> Result<()> {
match self {
BlockState::Partial(state) => Ok(state.block.push_token(token)?),
_ => Err(BlockStateInvalid("Block is not partial".to_string()))?,
}
}
pub fn add_tokens(&mut self, tokens: Tokens) -> Result<Tokens> {
match self {
BlockState::Partial(state) => Ok(state.block.push_tokens(tokens)),
_ => Err(BlockStateInvalid("Block is not partial".to_string()))?,
}
}
pub fn pop_token(&mut self) -> Result<()> {
match self {
BlockState::Partial(state) => {
state.block.pop_token()?;
Ok(())
}
_ => Err(BlockStateInvalid("Block is not partial".to_string()))?,
}
}
pub fn pop_tokens(&mut self, count: usize) -> Result<()> {
match self {
BlockState::Partial(state) => {
state.block.pop_tokens(count)?;
Ok(())
}
_ => Err(BlockStateInvalid("Block is not partial".to_string()))?,
}
}
pub fn commit(&mut self) -> Result<()> {
match self {
BlockState::Partial(state) => {
let token_block = state.block.commit()?;
*self = BlockState::Complete(CompleteState::new(token_block));
Ok(())
}
_ => Err(BlockStateInvalid("Block is not partial".to_string()))?,
}
}
pub fn apply_token_block(&mut self, token_block: TokenBlock) -> Result<()> {
match self {
BlockState::Reset => {
*self = BlockState::Complete(CompleteState::new(token_block));
Ok(())
}
_ => Err(BlockStateInvalid("Block is not reset".to_string()))?,
}
}
    /// Returns the number of tokens currently in the block, or `None` for Registered blocks.
pub fn len(&self) -> Option<usize> {
match self {
BlockState::Reset => Some(0),
BlockState::Partial(state) => Some(state.block.len()),
BlockState::Complete(state) => Some(state.token_block.tokens().len()),
BlockState::Registered(_) => None,
}
}
/// Returns the number of additional tokens that can be added.
pub fn remaining(&self) -> usize {
match self {
BlockState::Partial(state) => state.block.remaining(),
_ => 0, // Reset, Complete, Registered have 0 remaining capacity
}
}
/// Returns true if the block contains no tokens.
pub fn is_empty(&self) -> bool {
match self {
BlockState::Reset => true,
BlockState::Partial(state) => state.block.is_empty(),
BlockState::Complete(_) => false, // Always full
BlockState::Registered(_) => false, // Always full
}
}
    /// Returns the tokens for Partial or Complete states; Reset and Registered return `None`.
pub fn tokens(&self) -> Option<&Tokens> {
match self {
BlockState::Reset | BlockState::Registered(_) => None,
BlockState::Partial(state) => Some(state.block.tokens()),
BlockState::Complete(state) => Some(state.token_block.tokens()),
}
}
    /// Returns true if the block is in the Reset state
pub fn is_reset(&self) -> bool {
matches!(self, BlockState::Reset)
}
/// Returns true if the block is in the complete or registered state
pub fn is_complete(&self) -> bool {
matches!(self, BlockState::Complete(_) | BlockState::Registered(_))
}
/// Returns true if the block is in the registered state
pub fn is_registered(&self) -> bool {
matches!(self, BlockState::Registered(_state))
}
}
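// Hedged sketch of the state machine above: Reset -> Partial -> Complete, with a
// page size of 4 and an arbitrary salt hash. Names are illustrative.
#[allow(dead_code)]
fn example_state_lifecycle() -> Result<()> {
    let salt: SaltHash = 12345;
    let mut state = BlockState::Reset;
    state.initialize_sequence(4, salt)?; // Reset -> Partial
    let _overflow = state.add_tokens(Tokens::from(vec![1, 2, 3, 4]))?; // fill the page
    debug_assert_eq!(state.len(), Some(4));
    state.commit()?; // Partial -> Complete
    debug_assert!(state.is_complete());
    Ok(())
}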
#[derive(Debug)]
pub struct PartialState {
block: PartialTokenBlock,
}
impl PartialState {
pub fn new(block: PartialTokenBlock) -> Self {
Self { block }
}
}
#[derive(Debug, Getters)]
pub struct CompleteState {
token_block: TokenBlock,
}
impl CompleteState {
pub fn new(token_block: TokenBlock) -> Self {
Self { token_block }
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod cuda;
mod memcpy;
mod nixl;
mod strategy;
use super::nixl::{IsMutable, NixlBlockDataImmutable, NixlBlockDataMutable, RemoteBlock};
use super::*;
use crate::block_manager::storage::{
nixl::{NixlRegisterableStorage, NixlStorage},
DeviceStorage, PinnedStorage, SystemStorage,
};
use cudarc::driver::CudaStream;
use std::ops::Range;
pub use crate::block_manager::storage::{CudaAccessible, Local, Remote};
pub use async_trait::async_trait;
/// A block that can be the target of a write
pub trait Writable {}
/// A block that can be the source of a read
pub trait Readable {}
pub trait Mutable: Readable + Writable {}
pub trait Immutable: Readable {}
#[derive(Debug)]
pub enum BlockTarget {
Source,
Destination,
}
#[derive(Debug, thiserror::Error)]
pub enum TransferError {
#[error("Builder configuration error: {0}")]
BuilderError(String),
#[error("Transfer execution failed: {0}")]
ExecutionError(String),
#[error("Incompatible block types provided: {0}")]
IncompatibleTypes(String),
#[error("Mismatched source/destination counts: {0} sources, {1} destinations")]
CountMismatch(usize, usize),
#[error("Block operation failed: {0}")]
BlockError(#[from] BlockError),
// TODO: Add NIXL specific errors
#[error("No blocks provided")]
NoBlocksProvided,
#[error("Mismatched {0:?} block set index: {1} != {2}")]
MismatchedBlockSetIndex(BlockTarget, usize, usize),
#[error("Mismatched {0:?} worker ID: {1} != {2}")]
MismatchedWorkerID(BlockTarget, usize, usize),
#[error(transparent)]
Other(#[from] anyhow::Error),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferStrategy {
Memcpy,
CudaAsyncH2D,
CudaAsyncD2H,
CudaAsyncD2D,
CudaBlockingH2D,
CudaBlockingD2H,
NixlWrite, // aka PUT
NixlRead, // aka GET
Invalid,
}
/// Trait for determining the transfer strategy for writing from a local
/// source to a destination that may be local or remote.
pub trait WriteToStrategy<Target> {
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::Invalid
}
}
/// Trait for determining the transfer strategy for reading from a
/// `Source`, which may be local or remote, into `Self`, which must
/// be both local and writable.
pub trait ReadFromStrategy<Source> {
fn read_from_strategy() -> TransferStrategy {
TransferStrategy::Invalid
}
}
impl<RB: ReadableBlock, WB: WritableBlock> WriteToStrategy<WB> for RB
where
<RB as ReadableBlock>::StorageType: Local + WriteToStrategy<<WB as WritableBlock>::StorageType>,
{
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
<<RB as ReadableBlock>::StorageType as WriteToStrategy<
<WB as WritableBlock>::StorageType,
>>::write_to_strategy()
}
}
impl<WB: WritableBlock, RB: ReadableBlock> ReadFromStrategy<RB> for WB
where
<RB as ReadableBlock>::StorageType: Remote,
<WB as WritableBlock>::StorageType: NixlRegisterableStorage,
{
#[inline(always)]
fn read_from_strategy() -> TransferStrategy {
TransferStrategy::NixlRead
}
}
pub trait WriteTo<Target> {
fn write_to(&self, dst: &mut Target, notify: Option<String>) -> Result<(), TransferError>;
}
impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for RB
where
RB: WriteToStrategy<WB> + Local,
{
fn write_to(&self, dst: &mut WB, notify: Option<String>) -> Result<(), TransferError> {
let ctx = self.transfer_context();
match Self::write_to_strategy() {
TransferStrategy::Memcpy => memcpy::copy_block(self, dst),
TransferStrategy::CudaAsyncH2D
| TransferStrategy::CudaAsyncD2H
| TransferStrategy::CudaAsyncD2D => {
cuda::copy_block(self, dst, ctx.stream().as_ref(), RB::write_to_strategy())
}
TransferStrategy::NixlWrite => Ok(nixl::write_block_to(self, dst, ctx, notify)?),
_ => Err(TransferError::IncompatibleTypes(format!(
"Unsupported copy strategy: {:?}",
RB::write_to_strategy()
))),
}
// dispatch_copy_to(self, dst, self.transfer_context())
}
}
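// Hedged sketch: a fully generic call site. The concrete TransferStrategy is
// resolved at compile time from the two storage types, so the same call covers
// the memcpy, CUDA, and NIXL paths.
#[allow(dead_code)]
fn example_write_to<RB, WB>(src: &RB, dst: &mut WB) -> Result<(), TransferError>
where
    RB: WriteTo<WB>,
{
    // `notify` is only meaningful for NIXL transfers; None elsewhere.
    src.write_to(dst, None)
}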
#[derive(Default)]
pub struct GetXferRequestBuilder<
'xfer,
Source: BlockDataProvider,
Target: BlockDataProviderMut + Local,
> {
_src: Option<&'xfer [Source]>,
_dst: Option<&'xfer [Target]>,
}
// impl<'xfer, Source: BlockDataProvider, Target: BlockDataProviderMut + Local>
// GetXferRequestBuilder<'xfer, Source, Target>
// {
// fn new(state: Arc<BlockTransferEngineState>) -> Self {
// Self {
// src: None,
// dst: None,
// }
// }
// pub fn from(&mut self, local_or_remote_blocks: &'xfer [Target]) -> &mut Self {
// self.dst = Some(local_or_remote_blocks);
// self
// }
// pub fn to(&mut self, local_mutable_blocks: &'xfer [Source]) -> &mut Self {
// self.src = Some(local_mutable_blocks);
// self
// }
// }
pub struct PutXferRequestBuilder<
'xfer,
Source: BlockDataProvider + Local,
Target: BlockDataProviderMut,
> {
_src: Option<&'xfer [Source]>,
_dst: Option<&'xfer [Target]>,
}
// impl<'xfer, Source: BlockDataProvider + Local, Target: BlockDataProviderMut>
// PutXferRequestBuilder<'xfer, Source, Target>
// {
// fn new(state: Arc<BlockTransferEngineState>) -> Self {
// Self {
// src: None,
// dst: None,
// }
// }
// pub fn from(&mut self, local_blocks: &'xfer [Source]) -> &mut Self {
// self.src = Some(local_blocks);
// self
// }
// pub fn to(&mut self, local_or_remote: &'xfer [Target]) -> &mut Self {
// self.dst = Some(local_or_remote);
// self
// }
// }
// #[async_trait]
// impl<'xfer, Target: BlockDataProviderMut + Local>
// AsyncBlockTransferEngine<RemoteBlock<IsImmutable>, Target>
// for GetXferRequestBuilder<'xfer, RemoteBlock<IsImmutable>, Target>
// where
// Target: BlockDataProviderMut + Local + Send + Sync,
// {
// async fn execute(self) -> Result<()> {
// unimplemented!()
// }
// }
// #[async_trait]
// impl<'xfer, Source, Target> AsyncBlockTransferEngine<Source, Target>
// for GetXferRequestBuilder<'xfer, Source, Target>
// where
// Source: BlockDataProvider + Local + Send + Sync,
// Target: BlockDataProviderMut + Local + Send + Sync,
// {
// async fn execute(self) -> Result<()> {
// unimplemented!()
// }
// }
// pub trait BlockCopyTo<Target:BlockDataProviderMut + Local>: BlockDataProvider + Local {
// fn copy_blocks
#[async_trait]
pub trait AsyncBlockTransferEngine<Source: BlockDataProvider, Target: BlockDataProviderMut + Local>
{
async fn execute(self) -> anyhow::Result<()>;
}
pub trait BlockTransferEngineV1<Source: BlockDataProvider, Target: BlockDataProviderMut> {
fn prepare(&mut self) -> Result<(), TransferError> {
Ok(())
}
fn execute(self) -> Result<(), TransferError>;
}
// memcpy transfer engine
// - System -> System
// - Pinned -> Pinned
// cuda memcpy transfer engine
// - Pinned -> Device
// - Device -> Pinned
// - Device -> Device
// nixl memcpy transfer engine
// - NixlRegisterableStorage -> Nixl
// - Nixl -> NixlRegisterableStorage
// where System, Pinned, Device are NixlRegisterableStorage
// Placeholder for the actual transfer plan
#[derive(Debug)]
pub struct TransferRequestPut<
'a,
Source: BlockDataProvider + Local,
Destination: BlockDataProviderMut,
> {
sources: &'a [Source],
destinations: &'a mut [Destination],
}
// --- NIXL PUT Transfer Implementation ---
impl<Source> BlockTransferEngineV1<Source, RemoteBlock<IsMutable>>
for TransferRequestPut<'_, Source, RemoteBlock<IsMutable>>
where
Source: BlockDataProvider + Local, // + NixlBlockDataMutable<Source::StorageType>,
Source::StorageType: NixlRegisterableStorage,
{
fn execute(self) -> Result<(), TransferError> {
self.validate_counts()?;
tracing::info!("Executing NIXL PUT transfer request");
// TODO: Get NixlAgent handle
for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter_mut()) {
let src_data = src_block.block_data(private::PrivateToken);
let src_nixl_desc = src_data.as_block_descriptor()?;
let dst_data = dst_block.block_data_mut(private::PrivateToken);
let dst_nixl_desc = dst_data.as_block_descriptor_mut()?;
// TODO: Perform NIXL PUT operation
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "NIXL PUT block");
tracing::trace!(src_desc = ?src_nixl_desc, dst_desc = ?dst_nixl_desc, "NIXL PUT block");
}
Ok(())
}
}
impl<'a, Source, Destination> TransferRequestPut<'a, Source, Destination>
where
Source: BlockDataProvider + Local,
Destination: BlockDataProviderMut,
{
pub fn new(
sources: &'a [Source],
destinations: &'a mut [Destination],
) -> Result<Self, TransferError> {
let transfer_request = Self {
sources,
destinations,
};
transfer_request.validate_counts()?;
Ok(transfer_request)
}
/// Validate blocks
///
/// For a put, we can have duplicate blocks on the source side, but all destinations must be unique
/// For all transfers, the source and destination block sets must be disjoint.
pub fn validate_blocks(&self) -> Result<(), TransferError> {
let mut src_set = std::collections::HashSet::new();
let mut dst_set = std::collections::HashSet::new();
for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter()) {
let src_data = src_block.block_data(private::PrivateToken);
let dst_data = dst_block.block_data(private::PrivateToken);
src_set.insert((
src_data.block_set_idx,
src_data.block_idx,
src_data.worker_id,
));
dst_set.insert((
dst_data.block_set_idx,
dst_data.block_idx,
dst_data.worker_id,
));
}
if dst_set.len() != self.destinations.len() {
return Err(TransferError::BuilderError(
"Duplicate destination blocks".to_string(),
));
}
// the intersection of src_set and dst_set must be empty
if !src_set.is_disjoint(&dst_set) {
return Err(TransferError::BuilderError(
"Duplicate one or more duplicate entries in source and destination list"
.to_string(),
));
}
Ok(())
}
/// Common validation for all PUT requests.
fn validate_counts(&self) -> Result<(), TransferError> {
if self.sources.len() != self.destinations.len() {
Err(TransferError::CountMismatch(
self.sources.len(),
self.destinations.len(),
))
} else if self.sources.is_empty() {
Err(TransferError::BuilderError(
"Sources cannot be empty".to_string(),
))
} else if self.destinations.is_empty() {
Err(TransferError::BuilderError(
"Destinations cannot be empty".to_string(),
))
} else {
Ok(())
}
}
}
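// Hedged sketch: building, validating, and executing a NIXL PUT from local
// sources to remote mutable blocks. Type parameters are illustrative.
#[allow(dead_code)]
fn example_put<Source>(
    sources: &[Source],
    destinations: &mut [RemoteBlock<IsMutable>],
) -> Result<(), TransferError>
where
    Source: BlockDataProvider + Local,
    Source::StorageType: NixlRegisterableStorage,
{
    let request = TransferRequestPut::new(sources, destinations)?;
    request.validate_blocks()?;
    request.execute()
}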
// // --- Local Transfer Implementations ---
// // Local Pinned -> Pinned
// impl<'a, MSource: BlockMetadata, MDest: BlockMetadata>
// TransferRequestPut<
// 'a,
// ImmutableBlock<PinnedStorage, MSource>,
// MutableBlock<PinnedStorage, MDest>,
// >
// {
// pub fn execute(mut self) -> Result<(), TransferError> {
// self.validate_counts()?;
// tracing::info!("Executing local transfer: Pinned -> Pinned");
// for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter_mut()) {
// let src_data = src_block.block_data(private::PrivateToken);
// let dst_data = dst_block.block_data_mut(private::PrivateToken);
// // TODO: Implement layer-wise or block-wise CUDA memcpy H2H or std::ptr::copy
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "Copying block");
// }
// Ok(())
// }
// }
// // Local Pinned -> Device
// impl<'a, MSource: BlockMetadata, MDest: BlockMetadata>
// TransferRequestPut<
// 'a,
// ImmutableBlock<PinnedStorage, MSource>,
// MutableBlock<DeviceStorage, MDest>,
// >
// {
// pub fn execute(mut self) -> Result<(), TransferError> {
// self.validate_counts()?;
// tracing::info!("Executing local transfer: Pinned -> Device");
// for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter_mut()) {
// let src_data = src_block.block_data(private::PrivateToken);
// let dst_data = dst_block.block_data_mut(private::PrivateToken);
// // TODO: Implement layer-wise or block-wise CUDA memcpy H2D
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "Copying block");
// }
// Ok(())
// }
// }
// // Local Device -> Pinned
// impl<'a, MSource: BlockMetadata, MDest: BlockMetadata>
// TransferRequestPut<
// 'a,
// ImmutableBlock<DeviceStorage, MSource>,
// MutableBlock<PinnedStorage, MDest>,
// >
// {
// pub fn execute(mut self) -> Result<(), TransferError> {
// self.validate_counts()?;
// tracing::info!("Executing local transfer: Device -> Pinned");
// for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter_mut()) {
// let src_data = src_block.block_data(private::PrivateToken);
// let dst_data = dst_block.block_data_mut(private::PrivateToken);
// // TODO: Implement layer-wise or block-wise CUDA memcpy D2H
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "Copying block");
// }
// Ok(())
// }
// }
// // Local Device -> Device
// impl<'a, MSource: BlockMetadata, MDest: BlockMetadata>
// TransferRequestPut<
// 'a,
// ImmutableBlock<DeviceStorage, MSource>,
// MutableBlock<DeviceStorage, MDest>,
// >
// {
// pub fn execute(mut self) -> Result<(), TransferError> {
// self.validate_counts()?;
// tracing::info!("Executing local transfer: Device -> Device");
// for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter_mut()) {
// let src_data = src_block.block_data(private::PrivateToken);
// let dst_data = dst_block.block_data_mut(private::PrivateToken);
// // TODO: Implement layer-wise or block-wise CUDA memcpy D2D
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "Copying block");
// }
// Ok(())
// }
// }
// pub fn dispatch_copy_to<RB, WB>(
// src: &RB,
// dst: &mut WB,
// ctx: &TransferContext,
// ) -> Result<(), TransferError>
// where
// RB: ReadableBlock,
// WB: WritableBlock,
// // Ensure the necessary capability traits are implemented for the storage types
// // Note: These bounds aren't strictly *required* for the TypeId check,
// // but help ensure the backend calls will compile if a match occurs.
// // RB::Storage: SystemAccessible + CudaAccessible, // Might be too restrictive, apply within match arms
// // WB::Storage: SystemAccessible + CudaAccessible,
// {
// let src_type = src.storage_type_id();
// let dst_type = dst.storage_type_id();
// match (src_type, dst_type) {
// // === Memcpy Cases ===
// (s, d)
// if (s == TypeId::of::<SystemStorage>() && d == TypeId::of::<SystemStorage>())
// || (s == TypeId::of::<PinnedStorage>() && d == TypeId::of::<SystemStorage>())
// || (s == TypeId::of::<SystemStorage>() && d == TypeId::of::<PinnedStorage>())
// || (s == TypeId::of::<PinnedStorage>() && d == TypeId::of::<PinnedStorage>()) =>
// {
// memcpy::memcpy_block(src, dst)
// }
// // === CUDA Cases ===
// (s, d)
// if (s == TypeId::of::<PinnedStorage>() && d == TypeId::of::<DeviceStorage>())
// || (s == TypeId::of::<DeviceStorage>() && d == TypeId::of::<PinnedStorage>())
// || (s == TypeId::of::<DeviceStorage>() && d == TypeId::of::<DeviceStorage>()) =>
// {
// cuda::cuda_memcpy_block(src, dst, ctx.stream().as_ref())
// // let stream = stream.ok_or_else(|| {
// // TransferError::BuilderError("CUDA stream required for this transfer".into())
// // })?;
// // if is_cuda_compatible::<RB, WB>() {
// // tracing::debug!("Dispatching copy using CUDA");
// // cuda::cuda_memcpy_block(src_provider, dst_provider, stream) // Assumes cuda_memcpy_block is generic
// // } else {
// // Err(TransferError::IncompatibleTypes(
// // "CUDA copy requires CudaAccessible storage".into(),
// // ))
// // }
// }
// // === NIXL Cases ===
// (s, d)
// if d == TypeId::of::<NixlStorage>()
// && (s == TypeId::of::<SystemStorage>()
// || s == TypeId::of::<PinnedStorage>()
// || s == TypeId::of::<DeviceStorage>()) =>
// {
// unimplemented!()
// // tracing::debug!("Dispatching copy using NIXL PUT");
// // // TODO: Implement NIXL PUT logic
// // // You might need a specific NIXL transfer function here.
// // // Example: nixl::nixl_put_block(src_provider, dst_provider)
// // Err(TransferError::ExecutionError(
// // "NIXL PUT not yet implemented".into(),
// // ))
// }
// // TODO: Add NIXL GET cases (Nixl -> System/Pinned/Device)
// // === Error Case ===
// _ => Err(TransferError::IncompatibleTypes(format!(
// "Unsupported storage combination for copy: {:?} -> {:?}",
// std::any::type_name::<<RB as ReadableBlock>::StorageType>(), // Requires nightly or use debug print
// std::any::type_name::<<WB as WritableBlock>::StorageType>()
// ))),
// }
// }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn write_to_strategy() {
// System to ...
assert_eq!(
<SystemStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
TransferStrategy::Memcpy
);
assert_eq!(
<SystemStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
TransferStrategy::Memcpy
);
assert_eq!(
<SystemStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
TransferStrategy::CudaBlockingH2D
);
assert_eq!(
<SystemStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite
);
// Pinned to ...
assert_eq!(
<PinnedStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
TransferStrategy::Memcpy
);
assert_eq!(
<PinnedStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
TransferStrategy::Memcpy
);
assert_eq!(
<PinnedStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
TransferStrategy::CudaAsyncH2D
);
assert_eq!(
<PinnedStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite
);
// Device to ...
assert_eq!(
<DeviceStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
TransferStrategy::CudaBlockingD2H
);
assert_eq!(
<DeviceStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
TransferStrategy::CudaAsyncD2H
);
assert_eq!(
<DeviceStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
TransferStrategy::CudaAsyncD2D
);
assert_eq!(
<DeviceStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite
);
// Nixl to ... should fail to compile
// assert_eq!(
// <NixlStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use super::TransferError;
use crate::block_manager::storage::{DeviceStorage, PinnedStorage};
use anyhow::Result;
use cudarc::driver::result as cuda_result;
use std::ops::Range;
type CudaMemcpyFnPtr = unsafe fn(
src_ptr: *const u8,
dst_ptr: *mut u8,
size: usize,
stream: &CudaStream,
) -> Result<(), TransferError>;
fn cuda_memcpy_fn_ptr(strategy: &TransferStrategy) -> Result<CudaMemcpyFnPtr, TransferError> {
match strategy {
TransferStrategy::CudaAsyncH2D => Ok(cuda_memcpy_h2d),
TransferStrategy::CudaAsyncD2H => Ok(cuda_memcpy_d2h),
TransferStrategy::CudaAsyncD2D => Ok(cuda_memcpy_d2d),
_ => Err(TransferError::ExecutionError(
"Unsupported copy strategy for CUDA memcpy async".into(),
)),
}
}
/// Copy a block from a source to a destination using CUDA memcpy
pub fn copy_block<'a, Source, Destination>(
sources: &'a Source,
destinations: &'a mut Destination,
stream: &CudaStream,
strategy: TransferStrategy,
) -> Result<(), TransferError>
where
Source: BlockDataProvider,
Destination: BlockDataProviderMut,
{
let src_data = sources.block_data(private::PrivateToken);
let dst_data = destinations.block_data_mut(private::PrivateToken);
let memcpy_fn = cuda_memcpy_fn_ptr(&strategy)?;
#[cfg(debug_assertions)]
{
let expected_strategy =
expected_strategy::<Source::StorageType, Destination::StorageType>();
assert_eq!(strategy, expected_strategy);
}
if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
let src_view = src_data.block_view()?;
let mut dst_view = dst_data.block_view_mut()?;
debug_assert_eq!(src_view.size(), dst_view.size());
unsafe {
memcpy_fn(
src_view.as_ptr(),
dst_view.as_mut_ptr(),
src_view.size(),
stream,
)?;
}
} else {
assert_eq!(src_data.num_layers(), dst_data.num_layers());
copy_layers(
0..src_data.num_layers(),
sources,
destinations,
stream,
strategy,
)?;
}
Ok(())
}
/// Copy a range of layers from a source to a destination using CUDA memcpy
pub fn copy_layers<'a, Source, Destination>(
layer_range: Range<usize>,
sources: &'a Source,
destinations: &'a mut Destination,
stream: &CudaStream,
strategy: TransferStrategy,
) -> Result<(), TransferError>
where
Source: BlockDataProvider,
Destination: BlockDataProviderMut,
{
let src_data = sources.block_data(private::PrivateToken);
let dst_data = destinations.block_data_mut(private::PrivateToken);
let memcpy_fn = cuda_memcpy_fn_ptr(&strategy)?;
#[cfg(debug_assertions)]
{
let expected_strategy =
expected_strategy::<Source::StorageType, Destination::StorageType>();
assert_eq!(strategy, expected_strategy);
}
for layer_idx in layer_range {
let src_view = src_data.layer_view(layer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx)?;
debug_assert_eq!(src_view.size(), dst_view.size());
unsafe {
memcpy_fn(
src_view.as_ptr(),
dst_view.as_mut_ptr(),
src_view.size(),
stream,
)?;
}
}
Ok(())
}
/// Helper function to compute the expected `TransferStrategy` for a pair of storage types
// Allow dead code because it's used in debug assertions
#[allow(dead_code)]
fn expected_strategy<Source: Storage, Dest: Storage>() -> TransferStrategy {
match (
std::any::TypeId::of::<Source>(),
std::any::TypeId::of::<Dest>(),
) {
(src, dst)
if src == std::any::TypeId::of::<PinnedStorage>()
&& dst == std::any::TypeId::of::<DeviceStorage>() =>
{
TransferStrategy::CudaAsyncH2D
}
(src, dst)
if src == std::any::TypeId::of::<DeviceStorage>()
&& dst == std::any::TypeId::of::<PinnedStorage>() =>
{
TransferStrategy::CudaAsyncD2H
}
(src, dst)
if src == std::any::TypeId::of::<DeviceStorage>()
&& dst == std::any::TypeId::of::<DeviceStorage>() =>
{
TransferStrategy::CudaAsyncD2D
}
_ => TransferStrategy::Invalid,
}
}
/// H2D Implementation
#[inline(always)]
unsafe fn cuda_memcpy_h2d(
src_ptr: *const u8,
dst_ptr: *mut u8,
size: usize,
stream: &CudaStream,
) -> Result<(), TransferError> {
debug_assert!(!src_ptr.is_null(), "Source host pointer is null");
debug_assert!(!dst_ptr.is_null(), "Destination device pointer is null");
    debug_assert!(
        (src_ptr as usize + size <= dst_ptr as usize)
            || (dst_ptr as usize + size <= src_ptr as usize),
        "Source and destination memory regions must not overlap for H2D copy"
    );
let src_slice = std::slice::from_raw_parts(src_ptr, size);
cuda_result::memcpy_htod_async(dst_ptr as u64, src_slice, stream.cu_stream())
.map_err(|e| TransferError::ExecutionError(format!("CUDA H2D memcpy failed: {}", e)))?;
Ok(())
}
/// D2H Implementation
#[inline(always)]
unsafe fn cuda_memcpy_d2h(
src_ptr: *const u8,
dst_ptr: *mut u8,
size: usize,
stream: &CudaStream,
) -> Result<(), TransferError> {
debug_assert!(!src_ptr.is_null(), "Source device pointer is null");
debug_assert!(!dst_ptr.is_null(), "Destination host pointer is null");
    debug_assert!(
        (src_ptr as usize + size <= dst_ptr as usize)
            || (dst_ptr as usize + size <= src_ptr as usize),
        "Source and destination memory regions must not overlap for D2H copy"
    );
let dst_slice = std::slice::from_raw_parts_mut(dst_ptr, size);
cuda_result::memcpy_dtoh_async(dst_slice, src_ptr as u64, stream.cu_stream())
.map_err(|e| TransferError::ExecutionError(format!("CUDA D2H memcpy failed: {}", e)))?;
Ok(())
}
/// D2D Implementation
#[inline(always)]
unsafe fn cuda_memcpy_d2d(
src_ptr: *const u8,
dst_ptr: *mut u8,
size: usize,
stream: &CudaStream,
) -> Result<(), TransferError> {
debug_assert!(!src_ptr.is_null(), "Source device pointer is null");
debug_assert!(!dst_ptr.is_null(), "Destination device pointer is null");
debug_assert!(
(src_ptr as usize + size <= dst_ptr as usize)
|| (dst_ptr as usize + size <= src_ptr as usize),
"Source and destination device memory regions must not overlap for D2D copy"
);
cuda_result::memcpy_dtod_async(dst_ptr as u64, src_ptr as u64, size, stream.cu_stream())
.map_err(|e| TransferError::ExecutionError(format!("CUDA D2D memcpy failed: {}", e)))?;
Ok(())
}
#[cfg(all(test, feature = "testing-cuda"))]
mod tests {
use super::*;
use crate::block_manager::storage::{
DeviceAllocator, PinnedAllocator, StorageAllocator, StorageMemset,
};
#[test]
fn test_memset_and_transfer() {
// Create allocators
let device_allocator = DeviceAllocator::default();
let pinned_allocator = PinnedAllocator::default();
let ctx = device_allocator.ctx().clone();
// Create CUDA stream
let stream = ctx.new_stream().unwrap();
// Allocate host and device memory
let mut host = pinned_allocator.allocate(1024).unwrap();
let mut device = device_allocator.allocate(1024).unwrap();
// Set a pattern in host memory
StorageMemset::memset(&mut host, 42, 0, 1024).unwrap();
// Verify host memory was set correctly
unsafe {
let ptr = host.as_ptr();
let slice = std::slice::from_raw_parts(ptr, 1024);
assert!(slice.iter().all(|&x| x == 42));
}
// Copy host to device
unsafe {
cuda_memcpy_h2d(host.as_ptr(), device.as_mut_ptr(), 1024, stream.as_ref()).unwrap();
}
// Synchronize to ensure H2D copy is complete
stream.synchronize().unwrap();
// Clear host memory
StorageMemset::memset(&mut host, 0, 0, 1024).unwrap();
// Verify host memory was cleared
unsafe {
let ptr = host.as_ptr();
let slice = std::slice::from_raw_parts(ptr, 1024);
assert!(slice.iter().all(|&x| x == 0));
}
// Copy back from device to host
unsafe {
cuda_memcpy_d2h(device.as_ptr(), host.as_mut_ptr(), 1024, stream.as_ref()).unwrap();
}
// Synchronize to ensure D2H copy is complete before verifying
stream.synchronize().unwrap();
// Verify the original pattern was restored
unsafe {
let ptr = host.as_ptr();
let slice = std::slice::from_raw_parts(ptr, 1024);
assert!(slice.iter().all(|&x| x == 42));
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
/// Copy a block from a source to a destination using memcpy
pub fn copy_block<'a, Source, Destination>(
sources: &'a Source,
destinations: &'a mut Destination,
) -> Result<(), TransferError>
where
Source: ReadableBlock,
Destination: WritableBlock,
{
let src_data = sources.block_data(private::PrivateToken);
let dst_data = destinations.block_data_mut(private::PrivateToken);
if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
let src_view = src_data.block_view()?;
let mut dst_view = dst_data.block_view_mut()?;
debug_assert_eq!(src_view.size(), dst_view.size());
unsafe {
memcpy(src_view.as_ptr(), dst_view.as_mut_ptr(), src_view.size());
}
} else {
assert_eq!(src_data.num_layers(), dst_data.num_layers());
copy_layers(0..src_data.num_layers(), sources, destinations)?;
}
Ok(())
}
/// Copy a range of layers from a source to a destination using memcpy
pub fn copy_layers<'a, Source, Destination>(
layer_range: Range<usize>,
sources: &'a Source,
destinations: &'a mut Destination,
) -> Result<(), TransferError>
where
Source: ReadableBlock,
// <Source as ReadableBlock>::StorageType: SystemAccessible + Local,
Destination: WritableBlock,
// <Destination as WritableBlock>::StorageType: SystemAccessible + Local,
{
let src_data = sources.block_data(private::PrivateToken);
let dst_data = destinations.block_data_mut(private::PrivateToken);
for layer_idx in layer_range {
let src_view = src_data.layer_view(layer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx)?;
debug_assert_eq!(src_view.size(), dst_view.size());
unsafe {
memcpy(src_view.as_ptr(), dst_view.as_mut_ptr(), src_view.size());
}
}
Ok(())
}
#[inline(always)]
unsafe fn memcpy(src_ptr: *const u8, dst_ptr: *mut u8, size: usize) {
debug_assert!(
(src_ptr as usize + size <= dst_ptr as usize)
|| (dst_ptr as usize + size <= src_ptr as usize),
"Source and destination memory regions must not overlap for copy_nonoverlapping"
);
std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, size);
}
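// Illustrative usage (a sketch): `src` and `dst` are assumed to be
// host-resident blocks implementing `ReadableBlock`/`WritableBlock`, which
// this module does not construct.
//
//     // Fully contiguous blocks take the single-memcpy fast path; otherwise
//     // the copy falls back to one memcpy per layer.
//     copy_block(&src, &mut dst)?;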
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use anyhow::Result;
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp};
use std::ops::Range;
/// Copy a block from a source to a destination using a NIXL write transfer
pub fn write_block_to<'a, Source, Destination>(
src: &'a Source,
dst: &'a mut Destination,
ctx: &TransferContext,
notify: Option<String>,
) -> Result<()>
where
Source: BlockDataProvider,
Destination: BlockDataProviderMut,
{
let src_data = src.block_data(private::PrivateToken);
let dst_data = dst.block_data_mut(private::PrivateToken);
if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
let nixl_agent = ctx.nixl_agent().expect("NIXL agent not found");
let remote_worker_id = dst_data.worker_id.to_string();
let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
let src_desc = src_data.block_view()?.as_nixl_descriptor();
let dst_desc = dst_data.block_view_mut()?.as_nixl_descriptor_mut();
unsafe {
src_dl.add_desc(
src_desc.as_ptr() as usize,
src_desc.size(),
src_desc.device_id(),
)?;
dst_dl.add_desc(
dst_desc.as_ptr() as usize,
dst_desc.size(),
dst_desc.device_id(),
)?;
}
let xfer_req =
nixl_agent.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &remote_worker_id, None)?;
let mut xfer_args = OptArgs::new()?;
if let Some(notify) = notify {
xfer_args.set_has_notification(true)?;
xfer_args.set_notification_message(notify.as_bytes())?;
}
let mut status = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
tracing::span!(tracing::Level::DEBUG, "Waiting for transfer to complete").in_scope(|| {
while status {
status = nixl_agent.get_xfer_status(&xfer_req).unwrap();
}
});
} else {
assert_eq!(src_data.num_layers(), dst_data.num_layers());
write_layers_to(0..src_data.num_layers(), src, dst, ctx, notify)?;
}
Ok(())
}
/// Copy a range of layers from a source to a destination using a NIXL write transfer
pub fn write_layers_to<'a, Source, Destination>(
layer_range: Range<usize>,
src: &'a Source,
dst: &'a mut Destination,
ctx: &TransferContext,
notify: Option<String>,
) -> Result<()>
where
Source: BlockDataProvider,
Destination: BlockDataProviderMut,
{
let src_data = src.block_data(private::PrivateToken);
let dst_data = dst.block_data_mut(private::PrivateToken);
let nixl_agent = ctx.nixl_agent().expect("NIXL agent not found");
let remote_worker_id = dst_data.worker_id.to_string();
let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
// #[cfg(debug_assertions)]
// {
// let expected_strategy = <<Source as BlockDataProvider>::StorageType as WriteToStrategy<
// Destination::StorageType,
// >>::write_to_strategy();
// assert_eq!(strategy, expected_strategy);
// }
for layer_idx in layer_range {
let src_view = src_data.layer_view(layer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx)?;
debug_assert_eq!(src_view.size(), dst_view.size());
let src_desc = src_view.as_nixl_descriptor();
let dst_desc = dst_view.as_nixl_descriptor_mut();
unsafe {
src_dl.add_desc(
src_desc.as_ptr() as usize,
src_desc.size(),
src_desc.device_id(),
)?;
dst_dl.add_desc(
dst_desc.as_ptr() as usize,
dst_desc.size(),
dst_desc.device_id(),
)?;
}
}
let mut xfer_args = OptArgs::new()?;
if let Some(notify) = notify {
xfer_args.set_has_notification(true)?;
xfer_args.set_notification_message(notify.as_bytes())?;
}
let xfer_req = nixl_agent.create_xfer_req(
XferOp::Write,
&src_dl,
&dst_dl,
&remote_worker_id,
Some(&xfer_args),
)?;
let mut status = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
tracing::span!(tracing::Level::DEBUG, "Waiting for transfer to complete").in_scope(|| {
while status {
status = nixl_agent.get_xfer_status(&xfer_req).unwrap();
}
});
Ok(())
}
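// Illustrative usage (a sketch): `src`, `dst`, and `ctx` are assumed to be a
// local source block, a (possibly remote) destination block, and a
// `TransferContext` holding a registered NIXL agent.
//
//     // Attaching a notification message lets the remote worker observe
//     // completion of the write.
//     write_block_to(&src, &mut dst, &ctx, Some("xfer-complete".to_string()))?;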
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! This module implements the `WriteToStrategy` and `ReadFromStrategy` traits
//! for the common storage types.
use super::*;
impl WriteToStrategy<SystemStorage> for SystemStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::Memcpy
}
}
impl WriteToStrategy<PinnedStorage> for SystemStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::Memcpy
}
}
impl WriteToStrategy<DeviceStorage> for SystemStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::CudaBlockingH2D
}
}
impl WriteToStrategy<SystemStorage> for PinnedStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::Memcpy
}
}
impl WriteToStrategy<PinnedStorage> for PinnedStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::Memcpy
}
}
impl WriteToStrategy<DeviceStorage> for PinnedStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::CudaAsyncH2D
}
}
impl WriteToStrategy<SystemStorage> for DeviceStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::CudaBlockingD2H
}
}
impl WriteToStrategy<PinnedStorage> for DeviceStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::CudaAsyncD2H
}
}
impl WriteToStrategy<DeviceStorage> for DeviceStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::CudaAsyncD2D
}
}
impl<S: Storage + Local> WriteToStrategy<NixlStorage> for S {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl<S> ReadFromStrategy<S> for SystemStorage
where
S: WriteToStrategy<SystemStorage> + Storage + Local,
{
#[inline(always)]
fn read_from_strategy() -> TransferStrategy {
S::write_to_strategy()
}
}
impl<S> ReadFromStrategy<S> for PinnedStorage
where
S: WriteToStrategy<PinnedStorage> + Storage + Local,
{
#[inline(always)]
fn read_from_strategy() -> TransferStrategy {
S::write_to_strategy()
}
}
impl<S> ReadFromStrategy<S> for DeviceStorage
where
S: WriteToStrategy<DeviceStorage> + Storage + Local,
{
#[inline(always)]
fn read_from_strategy() -> TransferStrategy {
S::write_to_strategy()
}
}
impl<S: Storage + Local> ReadFromStrategy<NixlStorage> for S {
#[inline(always)]
fn read_from_strategy() -> TransferStrategy {
TransferStrategy::NixlRead
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn write_to_strategy() {
// System to ...
assert_eq!(
<SystemStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
TransferStrategy::Memcpy
);
assert_eq!(
<SystemStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
TransferStrategy::Memcpy
);
assert_eq!(
<SystemStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
TransferStrategy::CudaBlockingH2D
);
assert_eq!(
<SystemStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite
);
// Pinned to ...
assert_eq!(
<PinnedStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
TransferStrategy::Memcpy
);
assert_eq!(
<PinnedStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
TransferStrategy::Memcpy
);
assert_eq!(
<PinnedStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
TransferStrategy::CudaAsyncH2D
);
assert_eq!(
<PinnedStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite
);
// Device to ...
assert_eq!(
<DeviceStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
TransferStrategy::CudaBlockingD2H
);
assert_eq!(
<DeviceStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
TransferStrategy::CudaAsyncD2H
);
assert_eq!(
<DeviceStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
TransferStrategy::CudaAsyncD2D
);
assert_eq!(
<DeviceStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite
);
// Nixl to ... should fail to compile
// assert_eq!(
// <NixlStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
}
#[test]
fn read_from_strategy() {
// System to ...
assert_eq!(
<SystemStorage as ReadFromStrategy<SystemStorage>>::read_from_strategy(),
TransferStrategy::Memcpy
);
assert_eq!(
<SystemStorage as ReadFromStrategy<PinnedStorage>>::read_from_strategy(),
TransferStrategy::Memcpy
);
assert_eq!(
<SystemStorage as ReadFromStrategy<DeviceStorage>>::read_from_strategy(),
TransferStrategy::CudaBlockingD2H
);
assert_eq!(
<SystemStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
TransferStrategy::NixlRead
);
// Pinned to ...
assert_eq!(
<PinnedStorage as ReadFromStrategy<SystemStorage>>::read_from_strategy(),
TransferStrategy::Memcpy
);
assert_eq!(
<PinnedStorage as ReadFromStrategy<PinnedStorage>>::read_from_strategy(),
TransferStrategy::Memcpy
);
assert_eq!(
<PinnedStorage as ReadFromStrategy<DeviceStorage>>::read_from_strategy(),
TransferStrategy::CudaAsyncD2H
);
assert_eq!(
<PinnedStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
TransferStrategy::NixlRead
);
// Device to ...
assert_eq!(
<DeviceStorage as ReadFromStrategy<SystemStorage>>::read_from_strategy(),
TransferStrategy::CudaBlockingH2D
);
assert_eq!(
<DeviceStorage as ReadFromStrategy<PinnedStorage>>::read_from_strategy(),
TransferStrategy::CudaAsyncH2D
);
assert_eq!(
<DeviceStorage as ReadFromStrategy<DeviceStorage>>::read_from_strategy(),
TransferStrategy::CudaAsyncD2D
);
assert_eq!(
<DeviceStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
TransferStrategy::NixlRead
);
// Nixl to ... should fail to compile
// assert_eq!(
// <NixlStorage as ReadFromStrategy<SystemStorage>>::read_from_strategy(),
// TransferStrategy::Invalid
// );
//
// assert_eq!(
// <NixlStorage as ReadFromStrategy<PinnedStorage>>::read_from_strategy(),
// TransferStrategy::Invalid
// );
//
// assert_eq!(
// <NixlStorage as ReadFromStrategy<DeviceStorage>>::read_from_strategy(),
// TransferStrategy::Invalid
// );
//
// assert_eq!(
// <NixlStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
// TransferStrategy::Invalid
// );
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Block storage management.
//!
//! This module provides the implementation for managing collections of blocks
//! and their storage. It handles the relationship between storage, layout,
//! and individual blocks.
use super::{BlockData, BlockError, Storage};
pub trait Kind: std::marker::Sized + std::fmt::Debug + Clone + Copy + Send + Sync {}
#[derive(Debug, Clone, Copy)]
pub struct BlockKind;
impl Kind for BlockKind {}
#[derive(Debug, Clone, Copy)]
pub struct LayerKind;
impl Kind for LayerKind {}
pub type BlockView<'a, S> = MemoryView<'a, S, BlockKind>;
pub type BlockViewMut<'a, S> = MemoryViewMut<'a, S, BlockKind>;
pub type LayerView<'a, S> = MemoryView<'a, S, LayerKind>;
pub type LayerViewMut<'a, S> = MemoryViewMut<'a, S, LayerKind>;
/// Storage view that provides safe access to a region of storage
#[derive(Debug)]
pub struct MemoryView<'a, S: Storage, K: Kind> {
_block_data: &'a BlockData<S>,
addr: usize,
size: usize,
kind: std::marker::PhantomData<K>,
}
impl<'a, S, K> MemoryView<'a, S, K>
where
S: Storage,
K: Kind,
{
/// Create a new storage view
///
/// # Safety
/// The caller must ensure:
/// - addr + size <= storage.size()
/// - The view does not outlive the storage
pub(crate) unsafe fn new(
_block_data: &'a BlockData<S>,
addr: usize,
size: usize,
) -> Result<Self, BlockError> {
Ok(Self {
_block_data,
addr,
size,
kind: std::marker::PhantomData,
})
}
/// Get a raw pointer to the view's data
///
/// # Safety
/// The caller must ensure:
/// - The pointer is not used after the view is dropped
/// - Access patterns respect the storage's thread safety model
pub unsafe fn as_ptr(&self) -> *const u8 {
self.addr as *const u8
}
/// Size of the view in bytes
pub fn size(&self) -> usize {
self.size
}
}
/// Mutable storage view that provides exclusive access to a region of storage
#[derive(Debug)]
pub struct MemoryViewMut<'a, S: Storage, K: Kind> {
_block_data: &'a mut BlockData<S>,
addr: usize,
size: usize,
kind: std::marker::PhantomData<K>,
}
impl<'a, S: Storage, K: Kind> MemoryViewMut<'a, S, K> {
/// Create a new mutable storage view
///
/// # Safety
/// The caller must ensure:
/// - addr + size <= storage.size()
/// - The view does not outlive the storage
/// - No other views exist for this region
pub(crate) unsafe fn new(
_block_data: &'a mut BlockData<S>,
addr: usize,
size: usize,
) -> Result<Self, BlockError> {
Ok(Self {
_block_data,
addr,
size,
kind: std::marker::PhantomData,
})
}
/// Get a raw mutable pointer to the view's data
///
/// # Safety
/// The caller must ensure:
/// - The pointer is not used after the view is dropped
/// - No other references exist while the pointer is in use
/// - Access patterns respect the storage's thread safety model
pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
self.addr as *mut u8
}
/// Size of the view in bytes
pub fn size(&self) -> usize {
self.size
}
}
mod nixl {
use super::*;
use super::super::nixl::*;
pub use nixl_sys::{MemType, MemoryRegion, NixlDescriptor};
impl<S: Storage, K: Kind> MemoryRegion for MemoryView<'_, S, K> {
unsafe fn as_ptr(&self) -> *const u8 {
self.addr as *const u8
}
fn size(&self) -> usize {
self.size()
}
}
impl<S, K> NixlDescriptor for MemoryView<'_, S, K>
where
S: Storage + NixlDescriptor,
K: Kind,
{
fn mem_type(&self) -> MemType {
self._block_data.layout.storage_type().nixl_mem_type()
}
fn device_id(&self) -> u64 {
self._block_data.layout.storage_type().nixl_device_id()
}
}
impl<S: Storage, K: Kind> MemoryRegion for MemoryViewMut<'_, S, K> {
unsafe fn as_ptr(&self) -> *const u8 {
self.addr as *const u8
}
fn size(&self) -> usize {
self.size()
}
}
impl<S: Storage, K: Kind> NixlDescriptor for MemoryViewMut<'_, S, K>
where
S: Storage + NixlDescriptor,
K: Kind,
{
fn mem_type(&self) -> MemType {
self._block_data.layout.storage_type().nixl_mem_type()
}
fn device_id(&self) -> u64 {
self._block_data.layout.storage_type().nixl_device_id()
}
}
impl<'a, S, K> MemoryView<'a, S, K>
where
S: Storage + NixlDescriptor, // Ensure the underlying storage is a NixlDescriptor
K: Kind,
{
/// Creates an immutable NIXL memory descriptor from this view.
pub fn as_nixl_descriptor(&self) -> NixlMemoryDescriptor<'a, K, IsImmutable> {
NixlMemoryDescriptor::new(
self.addr as u64, // Address from the view
self.size(), // Size from the view
NixlDescriptor::mem_type(self), // Delegate to self's NixlDescriptor impl
NixlDescriptor::device_id(self), // Delegate to self's NixlDescriptor impl
)
}
}
impl<'a, S, K> MemoryViewMut<'a, S, K>
where
S: Storage + NixlDescriptor,
K: Kind,
{
/// Creates a mutable NIXL memory descriptor from this view.
        // Note: this method takes `&mut self`, so the mutable descriptor
        // inherits the exclusive-access guarantee of the underlying view.
pub fn as_nixl_descriptor_mut(&mut self) -> NixlMemoryDescriptor<'a, K, IsMutable> {
NixlMemoryDescriptor::new(
self.addr as u64,
self.size(),
NixlDescriptor::mem_type(self), // Delegate to self's NixlDescriptor impl
NixlDescriptor::device_id(self), // Delegate to self's NixlDescriptor impl
)
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
#[derive(Debug, Clone)]
pub enum NixlOptions {
/// Enable NIXL and create a new NIXL agent
Enabled,
/// Enable NIXL and use the provided NIXL agent
EnabledWithAgent(NixlAgent),
/// Disable NIXL
Disabled,
}
#[derive(Debug, Clone, Builder, Validate)]
#[builder(pattern = "owned")]
pub struct KvManagerRuntimeConfig {
pub worker_id: u64,
#[builder(default)]
pub cancellation_token: CancellationToken,
#[builder(default = "NixlOptions::Enabled")]
pub nixl: NixlOptions,
}
impl KvManagerRuntimeConfig {
pub fn builder() -> KvManagerRuntimeConfigBuilder {
KvManagerRuntimeConfigBuilder::default()
}
}
impl KvManagerRuntimeConfigBuilder {
pub fn enable_nixl(mut self) -> Self {
self.nixl = Some(NixlOptions::Enabled);
self
}
pub fn use_nixl_agent(mut self, agent: NixlAgent) -> Self {
self.nixl = Some(NixlOptions::EnabledWithAgent(agent));
self
}
pub fn disable_nixl(mut self) -> Self {
self.nixl = Some(NixlOptions::Disabled);
self
}
}
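// Illustrative builder usage (a sketch; the values are placeholders):
//
//     let runtime = KvManagerRuntimeConfig::builder()
//         .worker_id(0)
//         .enable_nixl() // the default; `disable_nixl()` opts out
//         .build()?;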
#[derive(Debug, Clone, Builder, Validate)]
#[builder(pattern = "owned")]
pub struct KvManagerModelConfig {
#[validate(range(min = 1))]
pub num_layers: usize,
#[validate(range(min = 1))]
pub page_size: usize,
#[validate(range(min = 1))]
pub inner_dim: usize,
#[builder(default = "DType::FP16")]
pub dtype: DType,
}
impl KvManagerModelConfig {
pub fn builder() -> KvManagerModelConfigBuilder {
KvManagerModelConfigBuilder::default()
}
}
#[derive(Builder, Validate)]
#[builder(pattern = "owned", build_fn(validate = "Self::validate"))]
pub struct KvManagerLayoutConfig<S: Storage + NixlRegisterableStorage> {
/// The number of blocks to allocate
#[validate(range(min = 1))]
pub num_blocks: usize,
/// The type of layout to use
#[builder(default = "LayoutType::FullyContiguous")]
pub layout_type: LayoutType,
/// Storage for the blocks
    /// If provided, the blocks will be allocated from the provided storage.
    /// Otherwise, the blocks will be allocated from the `allocator` below.
#[builder(default)]
pub storage: Option<Vec<S>>,
/// If provided, the blocks will be allocated from the provided allocator
/// This option is mutually exclusive with the `storage` option
#[builder(default, setter(custom))]
pub allocator: Option<Arc<dyn StorageAllocator<S>>>,
}
impl<S: Storage + NixlRegisterableStorage> KvManagerLayoutConfig<S> {
/// Create a new builder for the KvManagerLayoutConfig
pub fn builder() -> KvManagerLayoutConfigBuilder<S> {
KvManagerLayoutConfigBuilder::default()
}
}
// Implement the validation and build functions on the generated builder type
// Note: derive_builder generates KvManagerLayoutConfigBuilder<S>
impl<S: Storage + NixlRegisterableStorage> KvManagerLayoutConfigBuilder<S> {
/// Custom setter for the `allocator` field
pub fn allocator(mut self, allocator: impl StorageAllocator<S> + 'static) -> Self {
self.allocator = Some(Some(Arc::new(allocator)));
self
}
// Validation function
fn validate(&self) -> Result<(), String> {
match (self.storage.is_some(), self.allocator.is_some()) {
(true, false) | (false, true) => Ok(()), // XOR condition met
(true, true) => Err("Cannot provide both `storage` and `allocator`.".to_string()),
(false, false) => Err("Must provide either `storage` or `allocator`.".to_string()),
}
}
}
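// Illustrative builder usage (a sketch): exactly one of `storage` or
// `allocator` must be set, so the following validates and builds, while
// supplying both, or neither, makes `build()` return an error.
//
//     let layout = KvManagerLayoutConfig::builder()
//         .num_blocks(8)
//         .allocator(DeviceAllocator::default()) // allocator from the storage module
//         .build()?;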
/// Configuration for the KvBlockManager
#[derive(Builder, Validate)]
#[builder(pattern = "owned")]
pub struct KvBlockManagerConfig {
/// Runtime configuration
///
/// This provides core runtime configuration for the KvBlockManager.
pub runtime: KvManagerRuntimeConfig,
/// Model configuration
///
/// This provides model-specific configuration for the KvBlockManager, specifically,
/// the number of layers and the size of the inner dimension which is directly related
/// to the type of attention used by the model.
///
/// Included in this configuration is also the page_size, i.e. the number of tokens that will
/// be represented in each "paged" KV block.
pub model: KvManagerModelConfig,
/// Specific configuration for the device layout
///
/// This includes the number of blocks and the layout of the data into the device memory/storage.
#[builder(default, setter(strip_option))]
pub device_layout: Option<KvManagerLayoutConfig<DeviceStorage>>,
/// Specific configuration for the host layout
///
/// This includes the number of blocks and the layout of the data into the host memory/storage.
#[builder(default, setter(strip_option))]
pub host_layout: Option<KvManagerLayoutConfig<PinnedStorage>>,
}
impl KvBlockManagerConfig {
/// Create a new builder for the KvBlockManagerConfig
pub fn builder() -> KvBlockManagerConfigBuilder {
KvBlockManagerConfigBuilder::default()
}
}
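// A minimal end-to-end sketch of assembling a `KvBlockManagerConfig`. The
// sizes below are illustrative placeholders; real deployments derive them
// from the model shape and the available memory. Gated like the CUDA
// transfer tests above, since `DeviceAllocator` needs a GPU.
#[cfg(all(test, feature = "testing-cuda"))]
mod config_tests {
    use super::*;
    use crate::block_manager::storage::{DeviceAllocator, PinnedAllocator};

    #[test]
    fn build_block_manager_config() {
        let runtime = KvManagerRuntimeConfig::builder()
            .worker_id(0)
            .disable_nixl() // skip NIXL agent creation in this sketch
            .build()
            .unwrap();
        let model = KvManagerModelConfig::builder()
            .num_layers(32)
            .page_size(16)
            .inner_dim(1024)
            .build()
            .unwrap();
        let device_layout = KvManagerLayoutConfig::builder()
            .num_blocks(8)
            .allocator(DeviceAllocator::default())
            .build()
            .unwrap();
        let host_layout = KvManagerLayoutConfig::builder()
            .num_blocks(16)
            .allocator(PinnedAllocator::default())
            .build()
            .unwrap();
        let config = KvBlockManagerConfig::builder()
            .runtime(runtime)
            .model(model)
            .device_layout(device_layout)
            .host_layout(host_layout)
            .build()
            .unwrap();
        assert!(config.device_layout.is_some());
        assert!(config.host_layout.is_some());
    }
}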