Unverified Commit 61c67804 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: add kvbm-kernels crate and upgrade cudarc to 0.19 (#6309)


Signed-off-by: default avatarRyan Olson <rolson@nvidia.com>
parent 673822ea
......@@ -982,7 +982,7 @@ checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
"libloading 0.8.9",
]
[[package]]
......@@ -1448,11 +1448,12 @@ dependencies = [
[[package]]
name = "cudarc"
version = "0.17.8"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf99ab37ee7072d64d906aa2dada9a3422f1d975cdf8c8055a573bc84897ed8"
checksum = "aed81f178e780f3d5d354d12b4c5c5a484c4a9c329ecd037ac57f2a0e0648397"
dependencies = [
"libloading",
"half 2.7.1",
"libloading 0.9.0",
]
[[package]]
......@@ -1959,7 +1960,7 @@ dependencies = [
"mockito",
"modelexpress-client",
"modelexpress-common",
"ndarray",
"ndarray 0.16.1",
"nix 0.26.4",
"nixl-sys",
"object_store",
......@@ -2030,7 +2031,7 @@ dependencies = [
"dynamo-kv-router",
"dynamo-runtime",
"dynamo-tokens",
"ndarray",
"ndarray 0.16.1",
"ndarray-interp",
"ndarray-npy",
"rand 0.9.2",
......@@ -2848,6 +2849,9 @@ checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b"
dependencies = [
"cfg-if 1.0.4",
"crunchy",
"num-traits",
"rand 0.9.2",
"rand_distr",
"zerocopy",
]
......@@ -3613,9 +3617,9 @@ dependencies = [
[[package]]
name = "js-sys"
version = "0.3.86"
version = "0.3.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d36139f1c97c42c0c86a411910b04e48d4939a0376e6e0f989420cbdee0120e5"
checksum = "93f0862381daaec758576dcc22eb7bbf4d7efd67328553f3b45a412a51a3fb21"
dependencies = [
"once_cell",
"wasm-bindgen",
......@@ -3868,6 +3872,17 @@ dependencies = [
"tracing",
]
[[package]]
name = "kvbm-kernels"
version = "0.9.0"
dependencies = [
"clap 4.5.60",
"cudarc",
"half 2.7.1",
"ndarray 0.17.2",
"rand 0.9.2",
]
[[package]]
name = "kvbm-logical"
version = "0.9.0"
......@@ -3961,6 +3976,16 @@ dependencies = [
"windows-link",
]
[[package]]
name = "libloading"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60"
dependencies = [
"cfg-if 1.0.4",
"windows-link",
]
[[package]]
name = "libm"
version = "0.2.16"
......@@ -4460,13 +4485,28 @@ dependencies = [
"rawpointer",
]
[[package]]
name = "ndarray"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer",
]
[[package]]
name = "ndarray-interp"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e43087829efb5ec2736598e88587df286425b59df5a9ce991994cdd2c5855d3f"
dependencies = [
"ndarray",
"ndarray 0.16.1",
"num-traits",
"thiserror 2.0.18",
]
......@@ -4478,7 +4518,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b313788c468c49141a9d9b6131fc15f403e6ef4e8446a0b2e18f664ddb278a9"
dependencies = [
"byteorder",
"ndarray",
"ndarray 0.16.1",
"num-complex",
"num-traits",
"py_literal",
......@@ -4750,6 +4790,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]]
......@@ -5951,6 +5992,16 @@ dependencies = [
"getrandom 0.3.4",
]
[[package]]
name = "rand_distr"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
dependencies = [
"num-traits",
"rand 0.9.2",
]
[[package]]
name = "rand_xorshift"
version = "0.4.0"
......@@ -8575,9 +8626,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen"
version = "0.2.109"
version = "0.2.110"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ff9c7baef35ac3c0e17d8bfc9ad75eb62f85a2f02bccc906699dadb0aa9c622"
checksum = "1de241cdc66a9d91bd84f097039eb140cdc6eec47e0cdbaf9d932a1dd6c35866"
dependencies = [
"cfg-if 1.0.4",
"once_cell",
......@@ -8588,9 +8639,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.59"
version = "0.4.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d24699cd39db9966cf6e2ef10d2f72779c961ad905911f395ea201c3ec9f545d"
checksum = "a42e96ea38f49b191e08a1bab66c7ffdba24b06f9995b39a9dd60222e5b6f1da"
dependencies = [
"cfg-if 1.0.4",
"futures-util",
......@@ -8602,9 +8653,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.109"
version = "0.2.110"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39455e84ad887a0bbc93c116d72403f1bb0a39e37dd6f235a43e2128a0c7f1fd"
checksum = "e12fdf6649048f2e3de6d7d5ff3ced779cdedee0e0baffd7dff5cdfa3abc8a52"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
......@@ -8612,9 +8663,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.109"
version = "0.2.110"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dff4761f60b0b51fd13fec8764167b7bbcc34498ce3e52805fe1db6f2d56b6d6"
checksum = "0e63d1795c565ac3462334c1e396fd46dbf481c40f51f5072c310717bc4fb309"
dependencies = [
"bumpalo",
"proc-macro2",
......@@ -8625,9 +8676,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.109"
version = "0.2.110"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6a171c53d98021a93a474c4a4579d76ba97f9517d871bc12e27640f218b6dd"
checksum = "e9f9cdac23a5ce71f6bf9f8824898a501e511892791ea2a0c6b8568c68b9cb53"
dependencies = [
"unicode-ident",
]
......@@ -8681,9 +8732,9 @@ dependencies = [
[[package]]
name = "web-sys"
version = "0.3.86"
version = "0.3.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "668fa5d00434e890a452ab060d24e3904d1be93f7bb01b70e5603baa2b8ab23b"
checksum = "f2c7c5718134e770ee62af3b6b4a84518ec10101aad610c024b64d6ff29bb1ff"
dependencies = [
"js-sys",
"wasm-bindgen",
......
......@@ -10,6 +10,7 @@ members = [
"lib/mocker",
"lib/kv-router",
"lib/memory",
"lib/kvbm-kernels",
"lib/kvbm-logical",
"lib/async-openai",
"lib/parsers",
......@@ -41,6 +42,7 @@ dynamo-mocker = { path = "lib/mocker", version = "0.9.0" }
dynamo-kv-router = { path = "lib/kv-router", version = "0.9.0", features = ["metrics"] }
dynamo-async-openai = { path = "lib/async-openai", version = "0.9.0", features = ["byot"] }
dynamo-parsers = { path = "lib/parsers", version = "0.9.0" }
kvbm-kernels = { path = "lib/kvbm-kernels", version = "0.9.0" }
kvbm-logical = { path = "lib/kvbm-logical", version = "0.9.0" }
# External dependencies
......@@ -58,7 +60,7 @@ chrono = { version = "0.4", default-features = false, features = [
"now",
"serde",
] }
cudarc = { version = "0.17.8", features = ["cuda-12020"] }
cudarc = { version = "0.19.2", features = ["cuda-12020"] }
dashmap = { version = "6.1" }
derive_builder = { version = "0.20" }
derive-getters = { version = "0.5" }
......
......@@ -808,7 +808,7 @@ checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
"libloading 0.8.9",
]
[[package]]
......@@ -1106,20 +1106,11 @@ dependencies = [
[[package]]
name = "cudarc"
version = "0.16.6"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17200eb07e7d85a243aa1bf4569a7aa998385ba98d14833973a817a63cc86e92"
checksum = "aed81f178e780f3d5d354d12b4c5c5a484c4a9c329ecd037ac57f2a0e0648397"
dependencies = [
"libloading",
]
[[package]]
name = "cudarc"
version = "0.17.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf99ab37ee7072d64d906aa2dada9a3422f1d975cdf8c8055a573bc84897ed8"
dependencies = [
"libloading",
"libloading 0.9.0",
]
[[package]]
......@@ -1564,7 +1555,7 @@ dependencies = [
"bytemuck",
"bytes",
"chrono",
"cudarc 0.17.8",
"cudarc",
"dashmap 5.5.3",
"derive-getters",
"derive_builder",
......@@ -1632,7 +1623,7 @@ name = "dynamo-memory"
version = "0.9.0"
dependencies = [
"anyhow",
"cudarc 0.17.8",
"cudarc",
"dynamo-config",
"libc",
"nix 0.30.1",
......@@ -3018,9 +3009,9 @@ dependencies = [
[[package]]
name = "js-sys"
version = "0.3.86"
version = "0.3.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d36139f1c97c42c0c86a411910b04e48d4939a0376e6e0f989420cbdee0120e5"
checksum = "93f0862381daaec758576dcc22eb7bbf4d7efd67328553f3b45a412a51a3fb21"
dependencies = [
"once_cell",
"wasm-bindgen",
......@@ -3248,11 +3239,12 @@ name = "kvbm-py3"
version = "0.9.0"
dependencies = [
"anyhow",
"cudarc 0.16.6",
"cudarc",
"derive-getters",
"dlpark",
"dynamo-llm",
"dynamo-runtime",
"prometheus",
"pyo3",
"pyo3-async-runtimes",
"serde",
......@@ -3314,6 +3306,16 @@ dependencies = [
"windows-link",
]
[[package]]
name = "libloading"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60"
dependencies = [
"cfg-if 1.0.4",
"windows-link",
]
[[package]]
name = "libm"
version = "0.2.16"
......@@ -7400,9 +7402,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen"
version = "0.2.109"
version = "0.2.110"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ff9c7baef35ac3c0e17d8bfc9ad75eb62f85a2f02bccc906699dadb0aa9c622"
checksum = "1de241cdc66a9d91bd84f097039eb140cdc6eec47e0cdbaf9d932a1dd6c35866"
dependencies = [
"cfg-if 1.0.4",
"once_cell",
......@@ -7413,9 +7415,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.59"
version = "0.4.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d24699cd39db9966cf6e2ef10d2f72779c961ad905911f395ea201c3ec9f545d"
checksum = "a42e96ea38f49b191e08a1bab66c7ffdba24b06f9995b39a9dd60222e5b6f1da"
dependencies = [
"cfg-if 1.0.4",
"futures-util",
......@@ -7427,9 +7429,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.109"
version = "0.2.110"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39455e84ad887a0bbc93c116d72403f1bb0a39e37dd6f235a43e2128a0c7f1fd"
checksum = "e12fdf6649048f2e3de6d7d5ff3ced779cdedee0e0baffd7dff5cdfa3abc8a52"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
......@@ -7437,9 +7439,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.109"
version = "0.2.110"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dff4761f60b0b51fd13fec8764167b7bbcc34498ce3e52805fe1db6f2d56b6d6"
checksum = "0e63d1795c565ac3462334c1e396fd46dbf481c40f51f5072c310717bc4fb309"
dependencies = [
"bumpalo",
"proc-macro2",
......@@ -7450,9 +7452,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.109"
version = "0.2.110"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6a171c53d98021a93a474c4a4579d76ba97f9517d871bc12e27640f218b6dd"
checksum = "e9f9cdac23a5ce71f6bf9f8824898a501e511892791ea2a0c6b8568c68b9cb53"
dependencies = [
"unicode-ident",
]
......@@ -7506,9 +7508,9 @@ dependencies = [
[[package]]
name = "web-sys"
version = "0.3.86"
version = "0.3.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "668fa5d00434e890a452ab060d24e3904d1be93f7bb01b70e5603baa2b8ab23b"
checksum = "f2c7c5718134e770ee62af3b6b4a84518ec10101aad610c024b64d6ff29bb1ff"
dependencies = [
"js-sys",
"wasm-bindgen",
......
......@@ -56,6 +56,7 @@ pyo3-async-runtimes = { version = "0.23.0", default-features = false, features =
] }
dlpark = { version = "0.5", features = ["pyo3", "half"], optional = true }
cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true }
cudarc = { version = "0.19.2", features = ["cuda-12020"], optional = true }
prometheus = "0.14.0"
[dev-dependencies]
......@@ -826,7 +826,7 @@ checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
"libloading 0.8.9",
]
[[package]]
......@@ -1124,11 +1124,11 @@ dependencies = [
[[package]]
name = "cudarc"
version = "0.17.8"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf99ab37ee7072d64d906aa2dada9a3422f1d975cdf8c8055a573bc84897ed8"
checksum = "aed81f178e780f3d5d354d12b4c5c5a484c4a9c329ecd037ac57f2a0e0648397"
dependencies = [
"libloading",
"libloading 0.9.0",
]
[[package]]
......@@ -3344,6 +3344,16 @@ dependencies = [
"windows-link",
]
[[package]]
name = "libloading"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60"
dependencies = [
"cfg-if 1.0.4",
"windows-link",
]
[[package]]
name = "libm"
version = "0.2.16"
......
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## What This Is
kvbm-kernels is a high-performance CUDA transfer library for batched H2D, D2H, and D2D block copies used by the Dynamo KV cache system. The core API (`vectorized_copy`, `memcpy_batch`) is always available and handles the common case of moving KV cache blocks between host and device without layout changes. Fused permute-and-copy kernels for layout conversion between **Block Stack** (vLLM) and **Universal** (Dynamo storage) formats are feature-gated behind `permute_kernels`.
## Build Commands
```bash
# Default build (auto-detects nvcc -> source build; no nvcc -> stubs)
cargo build
# Build from source with custom GPU architectures
CUDA_ARCHS="80,86,89,90,100" cargo build
# Static linking (embed kernels into binary instead of .so)
cargo build --features static-kernels
# Check compilation without linking
cargo check
# Run CUDA integration tests for core transfer APIs (requires GPU + nvcc)
cargo test --features testing-cuda
# Run all CUDA integration tests including permute kernels
cargo test --features testing-cuda,permute_kernels
# Run a specific test
cargo test --features testing-cuda,permute_kernels fused_copy_roundtrip -- --nocapture --test-threads=1
# Run benchmarks (Llama 3.1 70B KV cache profile)
cargo run --example kvbench --features kvbench
```
**Environment variables**: `CUDA_ARCHS` (comma-separated SM versions), `CUDA_PTX_ARCHS` (PTX targets), `KVBM_REQUIRE_CUDA` (fail if nvcc missing), `CUDA_PATH`/`CUDA_HOME`.
## Architecture
### Two-tier build system (`build.rs`)
The build script selects one of two modes: **FromSource** (nvcc available, compiles CUDA, requires CUDA >= 12.0) or **Stubs** (no nvcc, C stubs that abort on call). Stubs set the `stub_kernels` cfg flag so tests can be conditionally skipped.
### Core transfer API (always available)
These live in `src/tensor_kernels.rs` and work on any device-visible memory (device allocations or pinned host via unified addressing):
- **`vectorized_copy`** — Batched copy of `(src, dst)` pointer pairs. Per-pair runtime alignment detection selects the widest safe vector width (int4/int2/int/char for 16/8/4/1-byte loads).
- **`memcpy_batch`** — Takes HOST arrays of src/dst pointers. Dispatches to `cudaMemcpyBatchAsync` (CUDA 12.9+) with fallback to individual `cudaMemcpyAsync` loop. Three modes: `BatchedWithFallback`, `FallbackOnly`, `BatchWithoutFallback`.
- **`is_using_stubs`** / **`is_memcpy_batch_available`** — Runtime capability queries.
### Permute kernels (feature-gated: `permute_kernels`)
These fuse layout permutation with copy for non-standard transfer paths:
- **`universal_from_block`** / **`block_from_universal`** — Permute between block stack layout (`nl*no` separate allocations, each `[nt, nh, hd]` NHD or `[nh, nt, hd]` HND) and universal layout (contiguous `[nh, nl, no, nt, hd]`).
### Source organization
- `cuda/tensor_kernels.cu` — All CUDA kernels. C++ templates on dtype (F16/BF16/F32/F64) and layout (NHD/HND), exposed via `extern "C"` functions prefixed `kvbm_kernels_launch_*` / `kvbm_kernels_memcpy_batch`.
- `cuda/stubs.c` — Abort-on-call fallbacks for all `extern "C"` symbols.
- `src/tensor_kernels.rs` — Rust FFI wrappers, enums (`TensorDataType`, `BlockLayout`, `MemcpyBatchMode`), and integration tests.
- `examples/kvbench.rs` — Benchmark harness (Llama 3.1 70B profile, CSV output).
- `scripts/plot_roofline.py` — Roofline bandwidth plots from kvbench output.
### Dimension conventions
`nl` = layers, `no` = outer chunks (2: K and V), `nh` = attention heads, `nt` = tokens per block, `hd` = head dimension.
### Pointer conventions
All pointer-list parameters (e.g. `universal_ptrs`, `src_ptrs`) must be device-accessible: allocated via `cudaMalloc` (device memory) or `cudaMallocHost` / `cuMemHostRegister` (pinned/registered/page-locked host memory).
### Cargo features
| Feature | Purpose |
|---------|---------|
| `permute_kernels` | Enable fused permute-and-copy kernels (block<->universal) |
| `testing-cuda` | Enable CUDA integration tests |
| `static-kernels` | Link as `.a` instead of `.so` |
| `kvbench` | Enable benchmark example (pulls in `clap`) |
### Test organization
- `tests/stub_build.rs` — Verifies stub behavior (gated on `stub_kernels`).
- `tests/memcpy_batch.rs` — Core transfer API roundtrip tests (H2D + D2H via pinned host memory). Gated on `testing-cuda`.
- `tests/kernel_roundtrip.rs` — Permute kernel roundtrip tests across all dtypes and layouts. Gated on `testing-cuda` + `permute_kernels`.
- Inline tests in `src/tensor_kernels.rs` — Integration tests including `universal_roundtrip`. Gated on `testing-cuda` + `permute_kernels`.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
[package]
name = "kvbm-kernels"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
repository = "https://github.com/ai-dynamo/dynamo.git"
build = "build.rs"
[lib]
name = "kvbm_kernels"
crate-type = ["rlib", "cdylib"]
[features]
default = []
# Build kernels as a static archive (.a) instead of shared library (.so).
# When enabled, the kernel code is embedded directly into the consuming crate,
# eliminating the runtime dependency on libkvbm_kernels.so.
# Note: This only affects real CUDA builds; stubs always remain dynamic.
static-kernels = []
# Enable CUDA tests - only works when real CUDA kernels are built (not stubs)
# Tests are gated with #[cfg(all(test, feature = "testing-cuda", not(stub_kernels)))]
testing-cuda = []
# Enable operational_copy, universal_from_block, block_from_universal kernels.
# These kernels perform data layout permutation and are only needed for
# non-standard transfer paths. The default vectorized_copy kernel handles
# most FC↔LW transfers efficiently without permutation.
permute_kernels = []
# Enable kvbench example (pulls in clap for CLI)
kvbench = ["dep:clap"]
[[example]]
name = "kvbench"
required-features = ["kvbench"]
[dependencies]
cudarc = { workspace = true }
# kvbench
clap = { version = "4", features = ["derive"], optional = true }
[dev-dependencies]
ndarray = "0.17.2"
half = "2"
rand = { workspace = true }
cudarc = { workspace = true, features = ["f16"] }
[build-dependencies]
## Dynamo KV Block Manager Kernels
GPU kernels for converting KV cache blocks between three memory layouts used by LLM inference frameworks. All conversions run entirely on-device via fused CUDA kernels.
### Dimensions
| Symbol | Meaning | Example |
|--------|--------------------------------|------------------|
| `nb` | Number of blocks in the batch | 1–128 |
| `nl` | Number of layers | 80 (Llama-70B) |
| `no` | Outer chunks (K and V) | 2 |
| `nh` | Number of attention heads | 32 or 64 |
| `nt` | Tokens per block | 128 or 256 |
| `hd` | Head dimension | 128 |
### Layouts
#### Block Stack (NHD or HND)
`nl * no` separate GPU allocations per block. Each allocation holds one layer's keys or values.
- **NHD shape**: `[nt, nh, hd]` — index: `(nt_idx * nh + nh_idx) * hd + hd_idx`
- **HND shape**: `[nh, nt, hd]` — index: `(nh_idx * nt + nt_idx) * hd + hd_idx`
Passed to kernels as a flat pointer table of length `nb * nl * no`.
#### Operational
Single contiguous buffer per block: `[nl, no, inner]` where `inner = nt * nh * hd`.
The three innermost dimensions (`nt`, `nh`, `hd`) are fused into one `inner` dimension. When no layout permutation is needed (same TP config, same head layout), block-to-operational is a flat copy — the cheapest conversion. Transforming to/from other layouts requires knowing the constituent dimensions.
#### Universal
Single contiguous buffer per block: `[nh, nl, no, nt, hd]`.
Heads are the outermost dimension so that tensor-parallelism resharding is a contiguous slice along `nh`. A block saved from a TP=4 deployment can be loaded into TP=8 by slicing the head dimension differently.
### Layout Cheat Sheet
| Layout | Logical Shape | Stored As | Notes |
|---------------------|----------------------------|------------------------------------|-------------------------------|
| NHD block stack | `[nl][no][nt, nh, hd]` | list of `nl * no` pointers | Inner layout = NHD |
| HND block stack | `[nl][no][nh, nt, hd]` | list of `nl * no` pointers | Inner layout = HND |
| Operational block | `[nl, no, inner]` | contiguous buffer per block | `inner = nt * nh * hd` |
| Universal block | `[nh, nl, no, nt, hd]` | contiguous buffer per block | Heads outermost for TP slicing |
### Kernel Functions
All kernels are batched: a single launch processes `nb` blocks from flat pointer tables prepared by host code.
#### Layout permutation kernels
| C API | Conversion |
|----------------------------------------------|-----------------------------|
| `kvbm_kernels_launch_universal_from_block` | Block stack → Universal |
| `kvbm_kernels_launch_block_from_universal` | Universal → Block stack |
Both accept `layout_value` (NHD=0, HND=1) and `dtype_value` (F16=0, BF16=1, F32=2, F64=3). Internally dispatched to C++ template kernels specialized on dtype and layout.
#### Standalone copy utilities
| C API | Description |
|------------------------------------------|----------------------------------------------------------|
| `kvbm_kernels_launch_vectorized_copy` | Adaptive vectorized copy (16/8/4-byte or scalar) across `num_pairs` pointer pairs |
| `kvbm_kernels_memcpy_batch` | Batched `cudaMemcpyAsync` from host pointer arrays |
| `kvbm_kernels_has_memcpy_batch_async` | Returns `true` if `cudaMemcpyBatchAsync` is available |
| `kvbm_kernels_is_stub_build` | Returns `true` if built without CUDA (stub mode) |
### Python Bindings (Planned)
Python kernel bindings are not yet implemented. The `lib/bindings/kvbm/` crate currently exposes block manager functionality only. Future work will add Python wrappers for the permute and copy kernels.
### Development
```bash
# Default build (auto-detects nvcc → source; no nvcc → stubs)
cargo build
# Custom GPU architectures
CUDA_ARCHS="80,86,89,90,100" cargo build
# Static linking
cargo build --features static-kernels
# Run CUDA integration tests (requires GPU + nvcc)
cargo test --features testing-cuda,permute_kernels
# Specific test with output
cargo test --features testing-cuda,permute_kernels fused_copy_roundtrip -- --nocapture
# Python bindings
cd lib/bindings/kvbm
uv pip install -e ".[dev]"
pytest tests/
```
**Environment variables**: `CUDA_ARCHS` (comma-separated SM versions, default `80,86,89,90,100,120`), `CUDA_PATH`/`CUDA_HOME` (toolkit root), `KVBM_REQUIRE_CUDA` (fail build if nvcc missing).
### Benchmarking
```text
root@9eb240f7ded8:/workspace/lib/kvbm-kernels# cargo run --release --example kvbench --features testing-cuda,kvbench -- --num-blocks=1,128 --tokens-per-block=16,64 --backend vectorized,batched --direction h2d
...
Running `/workspace/target/release/examples/kvbench --num-blocks=1,128 --tokens-per-block=16,64 --backend vectorized,batched --direction h2d`
KV Cache Transfer Benchmark
Model: Llama 3.1 70B (bf16)
Layers: 80, KV heads: 8, Head dim: 128, Outer dim: 2
Warmup: 10, Timed: 100
Batch API available: true
tokens_per_block: [16, 64]
num_blocks: [1, 128]
directions: [h2d]
patterns: [fc_to_fc, lw_to_fc]
backends: [vectorized, batched]
Total tests: 16
tokens_per_block,num_blocks,pattern,direction,backend,total_bytes,inner_bytes,copy_size,num_copies,median_ms,bandwidth_gbps
--- tokens_per_block=16, inner=32768 bytes (32 KB), block=5242880 bytes (5.0 MB) ---
[1/16] tpb=16 N= 1 fc_to_fc h2d vectorized ... 16,1,fc_to_fc,h2d,vectorized,5242880,32768,5242880,1,1.8686,2.81
2.81 GB/s (1.8686 ms)
[2/16] tpb=16 N= 1 fc_to_fc h2d batched ... 16,1,fc_to_fc,h2d,batched,5242880,32768,5242880,1,0.2105,24.91
24.91 GB/s (0.2105 ms)
[3/16] tpb=16 N= 1 lw_to_fc h2d vectorized ... 16,1,lw_to_fc,h2d,vectorized,5242880,32768,32768,160,0.2171,24.15
24.15 GB/s (0.2171 ms)
[4/16] tpb=16 N= 1 lw_to_fc h2d batched ... 16,1,lw_to_fc,h2d,batched,5242880,32768,32768,160,0.2775,18.89
18.89 GB/s (0.2775 ms)
[5/16] tpb=16 N=128 fc_to_fc h2d vectorized ... 16,128,fc_to_fc,h2d,vectorized,671088640,32768,5242880,128,26.6097,25.22
25.22 GB/s (26.6097 ms)
[6/16] tpb=16 N=128 fc_to_fc h2d batched ... 16,128,fc_to_fc,h2d,batched,671088640,32768,5242880,128,26.6180,25.21
25.21 GB/s (26.6180 ms)
[7/16] tpb=16 N=128 lw_to_fc h2d vectorized ... 16,128,lw_to_fc,h2d,vectorized,671088640,32768,32768,20480,26.6034,25.23
25.23 GB/s (26.6034 ms)
[8/16] tpb=16 N=128 lw_to_fc h2d batched ... 16,128,lw_to_fc,h2d,batched,671088640,32768,32768,20480,30.3346,22.12
22.12 GB/s (30.3346 ms)
--- tokens_per_block=64, inner=131072 bytes (128 KB), block=20971520 bytes (20.0 MB) ---
[9/16] tpb=64 N= 1 fc_to_fc h2d vectorized ... 64,1,fc_to_fc,h2d,vectorized,20971520,131072,20971520,1,7.5837,2.77
2.77 GB/s (7.5837 ms)
[10/16] tpb=64 N= 1 fc_to_fc h2d batched ... 64,1,fc_to_fc,h2d,batched,20971520,131072,20971520,1,0.8334,25.16
25.16 GB/s (0.8334 ms)
[11/16] tpb=64 N= 1 lw_to_fc h2d vectorized ... 64,1,lw_to_fc,h2d,vectorized,20971520,131072,131072,160,0.8407,24.95
24.95 GB/s (0.8407 ms)
[12/16] tpb=64 N= 1 lw_to_fc h2d batched ... 64,1,lw_to_fc,h2d,batched,20971520,131072,131072,160,0.9020,23.25
23.25 GB/s (0.9020 ms)
[13/16] tpb=64 N=128 fc_to_fc h2d vectorized ... 64,128,fc_to_fc,h2d,vectorized,2684354560,131072,20971520,128,106.3677,25.24
25.24 GB/s (106.3677 ms)
[14/16] tpb=64 N=128 fc_to_fc h2d batched ... 64,128,fc_to_fc,h2d,batched,2684354560,131072,20971520,128,106.3199,25.25
25.25 GB/s (106.3199 ms)
[15/16] tpb=64 N=128 lw_to_fc h2d vectorized ... 64,128,lw_to_fc,h2d,vectorized,2684354560,131072,131072,20480,106.3158,25.25
25.25 GB/s (106.3158 ms)
[16/16] tpb=64 N=128 lw_to_fc h2d batched ... 64,128,lw_to_fc,h2d,batched,2684354560,131072,131072,20480,110.0665,24.39
24.39 GB/s (110.0665 ms)
Done.
```
### Troubleshooting
| Symptom | Likely Cause / Fix |
|---------------------------------------|--------------------------------------------------------------------|
| `cudaErrorInvalidValue` on launch | Pointer counts mismatch (`nb`, `nl`, `no`) or non-contiguous input |
| Wrong values when using HND layout | Inner tensors not shaped as `[nh, nt, hd]` before passing in |
| Python bindings complain about dtype | Mixed precision in a batch; convert tensors to a common dtype |
| Kernels take unexpected time | Verify that `CUDA_ARCHS` matches your GPU to avoid JIT at runtime |
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::env;
use std::fs;
use std::path::Path;
use std::path::PathBuf;
use std::process::Command;
/// Build-script entry point.
///
/// Emits the Cargo rerun triggers, decides between compiling the real CUDA
/// kernels with nvcc and building the C stubs, and delegates to the matching
/// build helper. When the stubs are built, the `stub_kernels` cfg is set so
/// that CUDA-dependent tests can be compiled out.
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` or `OUT_DIR` is missing, or if
/// `KVBM_REQUIRE_CUDA` is set while nvcc is not available.
fn main() {
    // Register `stub_kernels` as a known cfg name before anything may set it.
    println!("cargo:rustc-check-cfg=cfg(stub_kernels)");

    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
    let out_dir = env::var("OUT_DIR").unwrap();

    // Rebuild whenever a kernel source, the stub source, or a relevant
    // environment variable changes.
    let cu_files = discover_cuda_files();
    cu_files
        .iter()
        .for_each(|source| println!("cargo:rerun-if-changed={}", source.display()));
    let stubs_source = Path::new(&manifest_dir).join("cuda/stubs.c");
    println!("cargo:rerun-if-changed={}", stubs_source.display());
    for var_name in [
        "CUDA_ARCHS",
        "CUDA_PTX_ARCHS",
        "KVBM_REQUIRE_CUDA",
        "CUDA_PATH",
        "CUDA_HOME",
    ] {
        println!("cargo:rerun-if-env-changed={}", var_name);
    }

    // The mere presence of KVBM_REQUIRE_CUDA (set by the Python bindings
    // build) makes a missing nvcc a hard error instead of a stub fallback.
    let require_cuda = env::var("KVBM_REQUIRE_CUDA").is_ok();
    let nvcc_available = is_nvcc_available();
    if require_cuda && !nvcc_available {
        panic!(
            "\n\n\
            ╔════════════════════════════════════════════════════════════════════════╗\n\
            ║ KVBM_REQUIRE_CUDA is set but nvcc is not available! ║\n\
            ║ ║\n\
            ║ Python bindings require real CUDA kernels. Please: ║\n\
            ║ 1. Install CUDA toolkit with nvcc, or ║\n\
            ║ 2. Unset KVBM_REQUIRE_CUDA for stub-only build ║\n\
            ╚════════════════════════════════════════════════════════════════════════╝\n\
            "
        );
    }

    // Static linking is opt-in through the `static-kernels` feature and only
    // affects real CUDA builds; stubs are always linked dynamically.
    let use_static = cfg!(feature = "static-kernels");

    match determine_build_mode(nvcc_available) {
        BuildMode::Stubs => {
            println!("cargo:warning=Building stub kernels (no CUDA available, dynamic linking)");
            build_stub_shared_library(&manifest_dir, &out_dir);
            // Let tests detect the stub build and skip themselves.
            println!("cargo:rustc-cfg=stub_kernels");
        }
        BuildMode::FromSource => {
            let linkage = if use_static { "static" } else { "dynamic" };
            println!(
                "cargo:warning=Building CUDA kernels from source ({} linking)",
                linkage
            );
            build_cuda_library(&cu_files, &out_dir, use_static);
        }
    }
}
/// How the kernel library is produced by this build script.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BuildMode {
    /// Compile the real CUDA kernels from source with nvcc.
    FromSource,
    /// Build the C stub library used when nvcc is not available.
    Stubs,
}
/// Pick the build mode from nvcc availability.
///
/// Returns [`BuildMode::FromSource`] when `nvcc_available` is true and
/// [`BuildMode::Stubs`] otherwise.
fn determine_build_mode(nvcc_available: bool) -> BuildMode {
    if !nvcc_available {
        return BuildMode::Stubs;
    }
    BuildMode::FromSource
}
/// Report whether a working `nvcc` can be run from `PATH`.
///
/// Runs `nvcc --version` and returns `true` only when the process could be
/// spawned *and* exited successfully. Checking the exit status keeps a broken
/// or wrapper `nvcc` that fails immediately from selecting the from-source
/// build, which would otherwise fail later during kernel compilation.
fn is_nvcc_available() -> bool {
    Command::new("nvcc")
        .arg("--version")
        .output()
        .is_ok_and(|out| out.status.success())
}
/// Build CUDA kernels from source and emit the Cargo link directives for them.
///
/// Compiles `tensor_kernels.cu` (looked up in `cu_files`) to an object file in
/// `out_dir` with nvcc, then packages it:
///
/// * `use_static == true`: a static archive `libkvbm_kernels.a` created with `ar`.
/// * `use_static == false`: a shared library `libkvbm_kernels.so` linked with nvcc.
///
/// In both cases the CUDA runtime (`cudart`) is linked; the static variant
/// additionally links `stdc++`.
///
/// # Panics
///
/// Panics if `tensor_kernels.cu` is not present in `cu_files`, or if nvcc / ar
/// cannot be executed or exit with a failure status.
fn build_cuda_library(cu_files: &[PathBuf], out_dir: &str, use_static: bool) {
    let arch_flags = get_cuda_arch_flags();

    // Only build tensor_kernels.cu into the library (it has the extern "C" functions).
    // Paths without a file stem are skipped instead of unwrapped, so a stray
    // entry cannot panic before the "not found" message below is reached.
    let tensor_kernels_path = cu_files
        .iter()
        .find(|p| p.file_stem().is_some_and(|stem| stem == "tensor_kernels"))
        .expect("tensor_kernels.cu not found");

    let out_dir_path = Path::new(out_dir);
    let obj_path = out_dir_path.join("kvbm_kernels.o");

    // Step 1: Compile to object file
    let mut nvcc_cmd = Command::new("nvcc");
    nvcc_cmd
        .args(["-m64", "-c", "-std=c++17", "-O3", "-Xcompiler", "-fPIC"])
        .arg(tensor_kernels_path)
        .arg("-o")
        .arg(&obj_path);
    for flag in &arch_flags {
        nvcc_cmd.arg(flag);
    }

    println!("cargo:warning=Compiling tensor_kernels.cu to object file...");
    let status = nvcc_cmd
        .status()
        .expect("Failed to execute nvcc for object file");
    if !status.success() {
        panic!("nvcc failed to compile tensor_kernels.cu");
    }

    // Step 2: Package the object file and remember which link kind to request.
    let link_kind = if use_static {
        // Step 2a: Create static archive
        let ar_path = out_dir_path.join("libkvbm_kernels.a");
        let mut ar_cmd = Command::new("ar");
        ar_cmd.arg("crus").arg(&ar_path).arg(&obj_path);

        println!("cargo:warning=Creating static archive libkvbm_kernels.a...");
        let status = ar_cmd
            .status()
            .expect("Failed to execute ar for static archive");
        if !status.success() {
            panic!("ar failed to create static archive");
        }
        "static"
    } else {
        // Step 2b: Link into shared library
        let so_path = out_dir_path.join("libkvbm_kernels.so");
        let mut link_cmd = Command::new("nvcc");
        link_cmd
            .arg("-shared")
            .arg("-o")
            .arg(&so_path)
            .arg(&obj_path)
            .arg("-lcudart");

        println!("cargo:warning=Linking kvbm_kernels into shared library...");
        let status = link_cmd
            .status()
            .expect("Failed to execute nvcc for linking");
        if !status.success() {
            panic!("nvcc failed to link shared library");
        }
        "dylib"
    };

    // Step 3: Link directives. The kernel library comes first so that the
    // libraries it depends on (cudart, and stdc++ for static builds) follow it.
    println!("cargo:rustc-link-search=native={}", out_dir);
    println!("cargo:rustc-link-lib={}=kvbm_kernels", link_kind);

    // Add CUDA runtime library paths and link cudart dynamically
    add_cuda_library_paths();
    println!("cargo:rustc-link-lib=cudart");

    if use_static {
        // CUDA object code compiled by nvcc contains C++ runtime symbols
        // (operator new/delete, __gxx_personality_v0, etc.)
        println!("cargo:rustc-link-lib=stdc++");
    }
}
/// Build `libkvbm_kernels.so` from `cuda/stubs.c` for machines without CUDA.
///
/// The stub source is looked up under `manifest_dir`, compiled and linked with
/// the system `cc`, and the resulting shared library is placed in `out_dir`.
/// Cargo is then told to link against it dynamically.
///
/// # Panics
///
/// Panics if `cuda/stubs.c` does not exist, or if either `cc` invocation cannot
/// be spawned or exits with a failure status.
fn build_stub_shared_library(manifest_dir: &str, out_dir: &str) {
    let stub_source = Path::new(manifest_dir).join("cuda/stubs.c");
    if !stub_source.exists() {
        panic!(
            "Stub source file not found at {}. Cannot build without CUDA.",
            stub_source.display()
        );
    }

    let out_root = Path::new(out_dir);
    let shared_lib = out_root.join("libkvbm_kernels.so");
    let object_file = out_root.join("stubs.o");

    // Compile the stubs into a position-independent object file.
    println!("cargo:warning=Compiling stubs.c...");
    let compiled = Command::new("cc")
        .args(["-c", "-fPIC", "-O2"])
        .arg(&stub_source)
        .arg("-o")
        .arg(&object_file)
        .status()
        .expect("Failed to execute cc for stubs");
    if !compiled.success() {
        panic!("Failed to compile stubs.c");
    }

    // Turn the object file into the shared library.
    println!("cargo:warning=Linking stub shared library...");
    let linked = Command::new("cc")
        .arg("-shared")
        .arg("-o")
        .arg(&shared_lib)
        .arg(&object_file)
        .status()
        .expect("Failed to link stub library");
    if !linked.success() {
        panic!("Failed to link stub shared library");
    }

    // Point Cargo at the freshly built library.
    println!("cargo:rustc-link-search=native={}", out_dir);
    println!("cargo:rustc-link-lib=dylib=kvbm_kernels");
}
/// List every `*.cu` file directly inside `$CARGO_MANIFEST_DIR/cuda`.
///
/// Entries are returned in the order the directory iterator yields them.
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` is unset, if the `cuda` directory cannot be
/// read, or if an individual directory entry cannot be read.
fn discover_cuda_files() -> Vec<PathBuf> {
    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
    fs::read_dir(Path::new(&manifest_dir).join("cuda"))
        .expect("Failed to read cuda directory")
        .map(|entry| entry.expect("Failed to read entry").path())
        .filter(|path| path.extension().is_some_and(|ext| ext == "cu"))
        .collect()
}
/// Ask `nvcc --version` for the CUDA toolkit version.
///
/// Returns `Some((major, minor))`, e.g. `(12, 8)` for CUDA 12.8, or `None` when
/// `nvcc` cannot be run or the first line containing `"release "` does not
/// carry a parsable `major.minor` pair.
fn parse_cuda_version() -> Option<(u32, u32)> {
    let output = Command::new("nvcc").arg("--version").output().ok()?;
    let text = String::from_utf8_lossy(&output.stdout);

    // The line of interest looks like:
    //   "Cuda compilation tools, release 12.8, V12.8.89"
    let (_, after_release) = text
        .lines()
        .find_map(|line| line.split_once("release "))?;
    let version = after_release.split(',').next().unwrap_or("").trim();

    // Only the first two dot-separated fields are used; anything after is ignored.
    let mut fields = version.split('.').map(|field| field.parse::<u32>().ok());
    let major = fields.next()??;
    let minor = fields.next()??;
    Some((major, minor))
}
/// Map a CUDA toolkit version to the highest compute capability this build
/// script will target with it.
///
/// CUDA 12.0 through 12.7 yield `90`; CUDA 12.8 and anything newer yield `120`.
///
/// # Panics
///
/// Panics if the major version is below 12.
fn max_supported_compute(cuda_version: (u32, u32)) -> u32 {
    let (major, minor) = cuda_version;
    if major < 12 {
        panic!("CUDA {major}.x is not supported; CUDA 12.0 or newer is required")
    }
    if major > 12 || minor >= 8 {
        120
    } else {
        90
    }
}
fn get_cuda_arch_flags() -> Vec<String> {
let mut flags = Vec::new();
let cuda_version = parse_cuda_version();
let max_compute = cuda_version.map(max_supported_compute);
if let Some((major, minor)) = cuda_version {
println!(
"cargo:warning=Detected CUDA {}.{}, max supported compute: sm_{}",
major,
minor,
max_compute.unwrap()
);
} else {
println!("cargo:warning=Could not detect CUDA version, including all architectures");
}
let explicit_archs = env::var("CUDA_ARCHS").ok();
let arch_list = explicit_archs.as_deref().unwrap_or("80,86,89,90,100,120");
for arch in arch_list.split(',') {
let arch = arch.trim();
if arch.is_empty() {
continue;
}
let arch_num: u32 = match arch.parse() {
Ok(n) => n,
Err(_) => {
println!("cargo:warning=Skipping invalid CUDA_ARCHS entry: {}", arch);
continue;
}
};
if let Some(max) = max_compute
&& arch_num > max
{
println!(
"cargo:warning=Skipping sm_{} (unsupported by detected CUDA toolkit, max: sm_{})",
arch_num, max
);
continue;
}
flags.push(format!("-gencode=arch=compute_{},code=sm_{}", arch, arch));
}
// Generate forward-compatible PTX for each major architecture family that is
// both present in the arch list and supported by the detected CUDA toolkit.
let ptx_archs_env = env::var("CUDA_PTX_ARCHS").ok();
let ptx_candidates: Vec<u32> = if let Some(ref ptx_env) = ptx_archs_env {
ptx_env
.split(',')
.filter_map(|s| s.trim().parse::<u32>().ok())
.collect()
} else {
vec![90, 100, 120]
};
for &ptx_arch in &ptx_candidates {
if let Some(max) = max_compute
&& ptx_arch > max
{
continue;
}
flags.push(format!(
"-gencode=arch=compute_{},code=compute_{}",
ptx_arch, ptx_arch
));
}
flags
}
/// Emit Cargo link-search directives for the CUDA runtime libraries.
///
/// The toolkit root is taken from `CUDA_PATH`, then `CUDA_HOME`, and finally
/// falls back to `/usr/local/cuda`. Both `<root>/lib64` and `<root>/lib` are
/// added, in that order.
fn add_cuda_library_paths() {
    let cuda_root = env::var("CUDA_PATH")
        .or_else(|_| env::var("CUDA_HOME"))
        .unwrap_or_else(|_| String::from("/usr/local/cuda"));
    for lib_dir in ["lib64", "lib"] {
        println!("cargo:rustc-link-search=native={}/{}", cuda_root, lib_dir);
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// Stub implementations for CUDA kernel functions.
// These are used when nvcc is not available, allowing the library to be built
// without CUDA. The kernel-launching and memcpy stubs abort() when called (the
// two boolean query functions simply return a constant), but the binary can be
// moved to an environment with the real .so and work correctly via LD_LIBRARY_PATH.
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
// Local stand-ins for the CUDA runtime types used in the exported signatures,
// so this file can be compiled without any CUDA headers.
// cudaError_t equivalent - cudaSuccess = 0
typedef int cudaError_t;
// cudaStream_t is an opaque pointer
typedef void* cudaStream_t;
// STUB_ABORT(name): print a fatal diagnostic naming the called entry point
// (`name` is formatted with %s, so it must be a C string) to stderr, then
// terminate the process with abort(). Wrapped in do { ... } while (0) so it
// can be used like a single statement.
#define STUB_ABORT(name) \
do { \
fprintf( \
stderr, \
"FATAL: %s called but CUDA kernels not available.\n" \
"This binary was built with stub kernels. To use CUDA:\n" \
" 1. Build with nvcc available, or\n" \
" 2. Set LD_LIBRARY_PATH to include real libkvbm_kernels.so\n", \
name); \
abort(); \
} while (0)
// Stub for kvbm_kernels_launch_universal_from_block.
// Ignores every argument, reports the missing CUDA kernels on stderr and
// aborts the process; it never returns to the caller.
cudaError_t
kvbm_kernels_launch_universal_from_block(
    void* const* universal_ptrs, const void* const* block_ptrs, size_t num_blocks, size_t nh, size_t nl, size_t no,
    size_t nt, size_t hd, int dtype_value, int layout_value, cudaStream_t stream)
{
    // Mark all parameters as intentionally unused.
    (void)universal_ptrs, (void)block_ptrs, (void)num_blocks;
    (void)nh, (void)nl, (void)no, (void)nt, (void)hd;
    (void)dtype_value, (void)layout_value, (void)stream;

    STUB_ABORT("kvbm_kernels_launch_universal_from_block");
    return 1;  // Not reached; present only to satisfy the non-void return type.
}
// Stub for kvbm_kernels_launch_block_from_universal.
// Ignores every argument, reports the missing CUDA kernels on stderr and
// aborts the process; it never returns to the caller.
cudaError_t
kvbm_kernels_launch_block_from_universal(
    const void* const* universal_ptrs, void* const* block_ptrs, size_t num_blocks, size_t nh, size_t nl, size_t no,
    size_t nt, size_t hd, int dtype_value, int layout_value, cudaStream_t stream)
{
    // Mark all parameters as intentionally unused.
    (void)universal_ptrs, (void)block_ptrs, (void)num_blocks;
    (void)nh, (void)nl, (void)no, (void)nt, (void)hd;
    (void)dtype_value, (void)layout_value, (void)stream;

    STUB_ABORT("kvbm_kernels_launch_block_from_universal");
    return 1;  // Not reached; present only to satisfy the non-void return type.
}
// Stub for kvbm_kernels_launch_vectorized_copy.
// Ignores every argument, reports the missing CUDA kernels on stderr and
// aborts the process; it never returns to the caller.
cudaError_t
kvbm_kernels_launch_vectorized_copy(
    void** src_ptrs, void** dst_ptrs, size_t copy_size_bytes, int num_pairs, cudaStream_t stream)
{
    // Mark all parameters as intentionally unused.
    (void)src_ptrs, (void)dst_ptrs;
    (void)copy_size_bytes, (void)num_pairs, (void)stream;

    STUB_ABORT("kvbm_kernels_launch_vectorized_copy");
    return 1;  // Not reached; present only to satisfy the non-void return type.
}
// Capability query: unlike the launch stubs this one is safe to call.
// The stub library has no batch-async memcpy, so the answer is always false.
bool kvbm_kernels_has_memcpy_batch_async(void)
{
    return false;
}
// Stub for kvbm_kernels_memcpy_batch.
// No CUDA operations are possible in the stub build, so this ignores every
// argument, reports the problem on stderr and aborts the process. It does not
// return an error code to the caller.
cudaError_t
kvbm_kernels_memcpy_batch(
    const void* const* src_ptrs, void* const* dst_ptrs, size_t size_per_copy, size_t num_copies, int mode,
    cudaStream_t stream)
{
    // Mark all parameters as intentionally unused.
    (void)src_ptrs, (void)dst_ptrs;
    (void)size_per_copy, (void)num_copies;
    (void)mode, (void)stream;

    STUB_ABORT("kvbm_kernels_memcpy_batch");
    return 1;  // Not reached; present only to satisfy the non-void return type.
}
// Identifies this library as the stub build (no real CUDA kernels): always
// returns true. Downstream crates can call it at runtime to decide whether
// CUDA tests should be skipped.
bool kvbm_kernels_is_stub_build(void)
{
    return true;
}
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <type_traits>
#include <vector>
// Compile-time CUDA version detection and diagnostics
// Emits a compiler message only in two situations: the toolkit is older than
// 12.9 (no cudaMemcpyBatchAsync), or CUDART_VERSION is not defined at all.
// The >= 13000 and >= 12090 branches are intentionally empty: nothing is
// reported when the batch API is available.
// STRINGIFY/TOSTRING turn the numeric CUDART_VERSION into a string literal so
// it can be spliced into the #pragma message text.
#if defined(CUDART_VERSION)
#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)
#if CUDART_VERSION >= 13000
#elif CUDART_VERSION >= 12090
#else
#pragma message("Building with CUDA " TOSTRING(CUDART_VERSION) " - cudaMemcpyBatchAsync NOT available (requires 12.9+)")
#endif
#else
#pragma message("Warning: CUDART_VERSION not defined - cannot detect CUDA version")
#endif
// Marks a function as compilable for both host and device. Guarded with
// #ifndef so an earlier definition (if any) is left untouched.
#ifndef CUDA_CALLABLE_MEMBER
#define CUDA_CALLABLE_MEMBER __host__ __device__
#endif
namespace {
/**
* There are three logical tensor views involved in these kernels:
*
* 1. Universal blocks: contiguous buffers whose logical shape is
* [nh, nl, no, nt, hd]. Every “block” is a separate pointer.
* 2. NHD/HND block stacks: `nl * no` pointers per block, each pointing
* to a chunk shaped either [nt, nh, hd] (NHD) or [nh, nt, hd] (HND).
* Stacks are arranged as `[layer][outer]`.
* 3. Operational blocks: contiguous buffers whose logical shape is
* [nl, no, inner], where inner = nt * nh * hd. These are used when
* the consumer does not care about the split between nh/nt/hd.
*
* Each kernel batch-processes `num_blocks` block pairs. All pointer
* tables are flattened on the host:
* • universal_ptrs : [num_blocks]
* • block_ptrs : [num_blocks * nl * no]
* • operational_ptrs : [num_blocks]
*
* All pointer-list parameters (e.g. `universal_ptrs`, `src_ptrs`) must be
* device-accessible: allocated via cudaMalloc (device memory) or
* cudaMallocHost / cuMemHostRegister (pinned/registered/page-locked host memory).
*
* This lets us launch a single grid per direction, keeps the per-block
* math regular, and avoids any per-kernel pointer chasing on the CPU.
*/
// Element type selector. The integer values are part of the C ABI: the
// extern "C" launchers receive them as `dtype_value` and cast to this enum,
// so they must not be renumbered.
enum class TensorDataType : int {
F16 = 0,
BF16 = 1,
F32 = 2,
F64 = 3,
};
// Memory order of one block chunk: NHD is [nt, nh, hd], HND is [nh, nt, hd]
// (see block_inner_offset). The integer values are part of the C ABI: the
// extern "C" launchers receive them as `layout_value`.
enum class BlockLayout : int {
NHD = 0,
HND = 1,
};
// Compile-time map from a TensorDataType enumerator to the C++ element type
// used to instantiate the kernels. The primary template is declared but not
// defined, so using an enumerator without a specialization fails to compile.
template <TensorDataType>
struct DTypeTraits;
// F16 -> __half
template <>
struct DTypeTraits<TensorDataType::F16> {
using type = __half;
};
// BF16 -> __nv_bfloat16
template <>
struct DTypeTraits<TensorDataType::BF16> {
using type = __nv_bfloat16;
};
// F32 -> float
template <>
struct DTypeTraits<TensorDataType::F32> {
using type = float;
};
// F64 -> double
template <>
struct DTypeTraits<TensorDataType::F64> {
using type = double;
};
// Returns the address of element `index` relative to `base` (mutable overload).
template <typename T>
CUDA_CALLABLE_MEMBER inline T*
ptr_offset(T* base, size_t index)
{
    return &base[index];
}
// Returns the address of element `index` relative to `base` (const overload).
template <typename T>
CUDA_CALLABLE_MEMBER inline const T*
ptr_offset(const T* base, size_t index)
{
    return &base[index];
}
// Linear element offset inside a single block chunk.
//   NHD: chunk is [nt, nh, hd] -> ((nt_idx * nh) + nh_idx) * hd + hd_idx
//   HND: chunk is [nh, nt, hd] -> ((nh_idx * nt) + nt_idx) * hd + hd_idx
// The layout is a template parameter, so the choice is made at compile time.
template <BlockLayout Layout>
CUDA_CALLABLE_MEMBER inline size_t
block_inner_offset(size_t nt_idx, size_t nh_idx, size_t hd_idx, size_t nt, size_t nh, size_t hd)
{
    // `row` selects the hd-wide row within the chunk.
    size_t row = 0;
    if constexpr (Layout == BlockLayout::NHD) {
        row = nt_idx * nh + nh_idx;
    } else {
        row = nh_idx * nt + nt_idx;
    }
    return row * hd + hd_idx;
}
// Number of thread blocks to launch for `total_elements` items with
// `block_dim` threads per block: ceil(total / block_dim), clamped to the
// range [1, 65535]. Returns 0 only when there is nothing to do. The kernels
// loop over any elements beyond what one pass of the grid covers.
inline int
kvbm_kernels_compute_grid_dim(size_t total_elements, int block_dim)
{
    if (total_elements == 0) {
        return 0;
    }
    const size_t threads_per_block = static_cast<size_t>(block_dim);
    const size_t needed = (total_elements + threads_per_block - 1) / threads_per_block;
    const size_t at_least_one = std::max<size_t>(needed, 1);
    return static_cast<int>(std::min<size_t>(at_least_one, 65535));
}
// log2(x) when x is a non-zero power of two, otherwise -1.
// Intended for host-side pre-computation: a caller can check whether a
// divisor allows shift/mask arithmetic instead of integer division.
// Examples: 64 -> 6, 48 -> -1, 0 -> -1.
inline int
kvbm_kernels_po2_shift(size_t x)
{
    // A power of two has exactly one bit set, so clearing the lowest set bit
    // (x & (x - 1)) leaves zero.
    const bool is_power_of_two = (x != 0) && ((x & (x - 1)) == 0);
    return is_power_of_two ? __builtin_ctzll(static_cast<unsigned long long>(x)) : -1;
}
// Gather: copy every element of every block's layer/outer chunks into that
// block's universal buffer ([nh, nl, no, nt, hd]).
//
// The work is indexed by one linear id over `total_per_block * num_blocks`
// elements, so a single launch covers all blocks. Each thread starts at its
// global thread index and advances by the total number of launched threads
// until the range is exhausted.
//
//   block_chunks     : num_blocks * block_stride chunk pointers (source)
//   universal_blocks : num_blocks universal buffers (destination)
//   block_stride     : chunk pointers per block (the launcher passes nl * no)
template <typename T, BlockLayout Layout>
__global__ void
kvbm_kernels_block_to_universal_kernel(
    const T* const* block_chunks, T* const* universal_blocks, size_t block_stride, size_t total_per_block,
    size_t num_blocks, size_t nh, size_t nl, size_t no, size_t nt, size_t hd)
{
    const size_t first = blockIdx.x * blockDim.x + threadIdx.x;
    const size_t step = blockDim.x * gridDim.x;
    const size_t total = total_per_block * num_blocks;

    for (size_t linear = first; linear < total; linear += step) {
        const size_t block_idx = linear / total_per_block;
        const size_t residual = linear % total_per_block;

        // Peel the universal coordinates off `residual`, fastest-varying
        // dimension first: hd, nt, no, nl; whatever is left is nh.
        size_t rest = residual;
        const size_t hd_idx = rest % hd;
        rest /= hd;
        const size_t nt_idx = rest % nt;
        rest /= nt;
        const size_t no_idx = rest % no;
        rest /= no;
        const size_t nl_idx = rest % nl;
        rest /= nl;
        const size_t nh_idx = rest;

        // Chunk pointers of one block are arranged as [layer][outer].
        const T* src_chunk = block_chunks[block_idx * block_stride + nl_idx * no + no_idx];
        const size_t src_offset = block_inner_offset<Layout>(nt_idx, nh_idx, hd_idx, nt, nh, hd);

        T* dst_block = universal_blocks[block_idx];
        dst_block[residual] = src_chunk[src_offset];
    }
}
// Scatter: the inverse of kvbm_kernels_block_to_universal_kernel. Every
// element of each universal buffer ([nh, nl, no, nt, hd]) is written back to
// the matching position in that block's layer/outer chunks.
//
// Indexing is identical to the gather kernel: one linear id over
// `total_per_block * num_blocks` elements, each thread starting at its global
// thread index and advancing by the total number of launched threads.
//
//   universal_blocks : num_blocks universal buffers (source)
//   block_chunks     : num_blocks * block_stride chunk pointers (destination)
//   block_stride     : chunk pointers per block (the launcher passes nl * no)
template <typename T, BlockLayout Layout>
__global__ void
kvbm_kernels_universal_to_block_kernel(
    const T* const* universal_blocks, T* const* block_chunks, size_t block_stride, size_t total_per_block,
    size_t num_blocks, size_t nh, size_t nl, size_t no, size_t nt, size_t hd)
{
    const size_t first = blockIdx.x * blockDim.x + threadIdx.x;
    const size_t step = blockDim.x * gridDim.x;
    const size_t total = total_per_block * num_blocks;

    for (size_t linear = first; linear < total; linear += step) {
        const size_t block_idx = linear / total_per_block;
        const size_t residual = linear % total_per_block;

        // Peel the universal coordinates off `residual`, fastest-varying
        // dimension first: hd, nt, no, nl; whatever is left is nh.
        size_t rest = residual;
        const size_t hd_idx = rest % hd;
        rest /= hd;
        const size_t nt_idx = rest % nt;
        rest /= nt;
        const size_t no_idx = rest % no;
        rest /= no;
        const size_t nl_idx = rest % nl;
        rest /= nl;
        const size_t nh_idx = rest;

        // Chunk pointers of one block are arranged as [layer][outer].
        T* dst_chunk = block_chunks[block_idx * block_stride + nl_idx * no + no_idx];
        const size_t dst_offset = block_inner_offset<Layout>(nt_idx, nh_idx, hd_idx, nt, nh, hd);

        const T* src_block = universal_blocks[block_idx];
        dst_chunk[dst_offset] = src_block[residual];
    }
}
// Typed launcher for the block -> universal gather kernel.
//
// Returns cudaSuccess without launching when the element count is zero,
// cudaErrorInvalidValue when either pointer table is null, and otherwise the
// result of cudaGetLastError() taken right after the launch. The launch is
// asynchronous on `stream`; this function does not synchronize.
template <typename T>
cudaError_t
kvbm_kernels_launch_block_to_universal_impl(
    void* const* universal_ptrs, const void* const* block_ptrs, size_t num_blocks, size_t nh, size_t nl, size_t no,
    size_t nt, size_t hd, BlockLayout layout, cudaStream_t stream)
{
    const size_t block_stride = nl * no;
    const size_t total_per_block = nh * nl * no * nt * hd;
    const size_t total = total_per_block * num_blocks;
    if (total == 0) {
        return cudaSuccess;
    }
    if (universal_ptrs == nullptr || block_ptrs == nullptr) {
        return cudaErrorInvalidValue;
    }

    constexpr int kBlockDim = 256;
    const int grid_dim = kvbm_kernels_compute_grid_dim(total, kBlockDim);
    if (grid_dim == 0) {
        return cudaSuccess;
    }

    // Reinterpret the untyped pointer tables with the element type T.
    auto src_chunks = reinterpret_cast<const T* const*>(block_ptrs);
    auto dst_blocks = reinterpret_cast<T* const*>(universal_ptrs);

    // The layout is a template argument of the kernel, so dispatch here.
    switch (layout) {
        case BlockLayout::NHD:
            kvbm_kernels_block_to_universal_kernel<T, BlockLayout::NHD><<<grid_dim, kBlockDim, 0, stream>>>(
                src_chunks, dst_blocks, block_stride, total_per_block, num_blocks, nh, nl, no, nt, hd);
            break;
        default:
            kvbm_kernels_block_to_universal_kernel<T, BlockLayout::HND><<<grid_dim, kBlockDim, 0, stream>>>(
                src_chunks, dst_blocks, block_stride, total_per_block, num_blocks, nh, nl, no, nt, hd);
            break;
    }
    return cudaGetLastError();
}
// Typed launcher for the universal -> block scatter kernel.
//
// Returns cudaSuccess without launching when the element count is zero,
// cudaErrorInvalidValue when either pointer table is null, and otherwise the
// result of cudaGetLastError() taken right after the launch. The launch is
// asynchronous on `stream`; this function does not synchronize.
template <typename T>
cudaError_t
kvbm_kernels_launch_block_from_universal_impl(
    const void* const* universal_ptrs, void* const* block_ptrs, size_t num_blocks, size_t nh, size_t nl, size_t no,
    size_t nt, size_t hd, BlockLayout layout, cudaStream_t stream)
{
    const size_t block_stride = nl * no;
    const size_t total_per_block = nh * nl * no * nt * hd;
    const size_t total = total_per_block * num_blocks;
    if (total == 0) {
        return cudaSuccess;
    }
    if (universal_ptrs == nullptr || block_ptrs == nullptr) {
        return cudaErrorInvalidValue;
    }

    constexpr int kBlockDim = 256;
    const int grid_dim = kvbm_kernels_compute_grid_dim(total, kBlockDim);
    if (grid_dim == 0) {
        return cudaSuccess;
    }

    // Reinterpret the untyped pointer tables with the element type T.
    auto src_blocks = reinterpret_cast<const T* const*>(universal_ptrs);
    auto dst_chunks = reinterpret_cast<T* const*>(block_ptrs);

    // The layout is a template argument of the kernel, so dispatch here.
    switch (layout) {
        case BlockLayout::NHD:
            kvbm_kernels_universal_to_block_kernel<T, BlockLayout::NHD><<<grid_dim, kBlockDim, 0, stream>>>(
                src_blocks, dst_chunks, block_stride, total_per_block, num_blocks, nh, nl, no, nt, hd);
            break;
        default:
            kvbm_kernels_universal_to_block_kernel<T, BlockLayout::HND><<<grid_dim, kBlockDim, 0, stream>>>(
                src_blocks, dst_chunks, block_stride, total_per_block, num_blocks, nh, nl, no, nt, hd);
            break;
    }
    return cudaGetLastError();
}
} // namespace
/// Fill universal buffers from block stacks (block -> universal).
///
/// @param universal_ptrs  Device-accessible table of `num_blocks` destination buffers
/// @param block_ptrs      Device-accessible table of `num_blocks * nl * no` source chunks
/// @param num_blocks      Number of blocks to convert
/// @param nh,nl,no,nt,hd  Dimensions of the universal shape [nh, nl, no, nt, hd]
/// @param dtype_value     TensorDataType as int (0 = F16, 1 = BF16, 2 = F32, 3 = F64)
/// @param layout_value    BlockLayout as int (0 = NHD, 1 = HND)
/// @param stream          CUDA stream the kernel is launched on
/// @return cudaErrorInvalidValue for an unknown dtype or layout; otherwise the
///         result of the typed launcher.
extern "C" cudaError_t
kvbm_kernels_launch_universal_from_block(
    void* const* universal_ptrs, const void* const* block_ptrs, size_t num_blocks, size_t nh, size_t nl, size_t no,
    size_t nt, size_t hd, int dtype_value, int layout_value, cudaStream_t stream)
{
    // Reject unknown layouts up front. Without this check any value other
    // than NHD would silently run the HND kernel and scramble the data.
    if (layout_value != static_cast<int>(BlockLayout::NHD) && layout_value != static_cast<int>(BlockLayout::HND)) {
        return cudaErrorInvalidValue;
    }
    const auto layout = static_cast<BlockLayout>(layout_value);

    switch (static_cast<TensorDataType>(dtype_value)) {
        case TensorDataType::F16:
            return kvbm_kernels_launch_block_to_universal_impl<typename DTypeTraits<TensorDataType::F16>::type>(
                universal_ptrs, block_ptrs, num_blocks, nh, nl, no, nt, hd, layout, stream);
        case TensorDataType::BF16:
            return kvbm_kernels_launch_block_to_universal_impl<typename DTypeTraits<TensorDataType::BF16>::type>(
                universal_ptrs, block_ptrs, num_blocks, nh, nl, no, nt, hd, layout, stream);
        case TensorDataType::F32:
            return kvbm_kernels_launch_block_to_universal_impl<typename DTypeTraits<TensorDataType::F32>::type>(
                universal_ptrs, block_ptrs, num_blocks, nh, nl, no, nt, hd, layout, stream);
        case TensorDataType::F64:
            return kvbm_kernels_launch_block_to_universal_impl<typename DTypeTraits<TensorDataType::F64>::type>(
                universal_ptrs, block_ptrs, num_blocks, nh, nl, no, nt, hd, layout, stream);
        default:
            return cudaErrorInvalidValue;
    }
}
/// Fill block stacks from universal buffers (universal -> block).
///
/// @param universal_ptrs  Device-accessible table of `num_blocks` source buffers
/// @param block_ptrs      Device-accessible table of `num_blocks * nl * no` destination chunks
/// @param num_blocks      Number of blocks to convert
/// @param nh,nl,no,nt,hd  Dimensions of the universal shape [nh, nl, no, nt, hd]
/// @param dtype_value     TensorDataType as int (0 = F16, 1 = BF16, 2 = F32, 3 = F64)
/// @param layout_value    BlockLayout as int (0 = NHD, 1 = HND)
/// @param stream          CUDA stream the kernel is launched on
/// @return cudaErrorInvalidValue for an unknown dtype or layout; otherwise the
///         result of the typed launcher.
extern "C" cudaError_t
kvbm_kernels_launch_block_from_universal(
    const void* const* universal_ptrs, void* const* block_ptrs, size_t num_blocks, size_t nh, size_t nl, size_t no,
    size_t nt, size_t hd, int dtype_value, int layout_value, cudaStream_t stream)
{
    // Reject unknown layouts up front. Without this check any value other
    // than NHD would silently run the HND kernel and scramble the data.
    if (layout_value != static_cast<int>(BlockLayout::NHD) && layout_value != static_cast<int>(BlockLayout::HND)) {
        return cudaErrorInvalidValue;
    }
    const auto layout = static_cast<BlockLayout>(layout_value);

    switch (static_cast<TensorDataType>(dtype_value)) {
        case TensorDataType::F16:
            return kvbm_kernels_launch_block_from_universal_impl<typename DTypeTraits<TensorDataType::F16>::type>(
                universal_ptrs, block_ptrs, num_blocks, nh, nl, no, nt, hd, layout, stream);
        case TensorDataType::BF16:
            return kvbm_kernels_launch_block_from_universal_impl<typename DTypeTraits<TensorDataType::BF16>::type>(
                universal_ptrs, block_ptrs, num_blocks, nh, nl, no, nt, hd, layout, stream);
        case TensorDataType::F32:
            return kvbm_kernels_launch_block_from_universal_impl<typename DTypeTraits<TensorDataType::F32>::type>(
                universal_ptrs, block_ptrs, num_blocks, nh, nl, no, nt, hd, layout, stream);
        case TensorDataType::F64:
            return kvbm_kernels_launch_block_from_universal_impl<typename DTypeTraits<TensorDataType::F64>::type>(
                universal_ptrs, block_ptrs, num_blocks, nh, nl, no, nt, hd, layout, stream);
        default:
            return cudaErrorInvalidValue;
    }
}
/// Compile-time capability query for cudaMemcpyBatchAsync.
/// Returns true when this library was compiled against a CUDA runtime whose
/// CUDART_VERSION is at least 12090 (CUDA 12.9), false otherwise, including
/// when CUDART_VERSION is not defined.
extern "C" bool
kvbm_kernels_has_memcpy_batch_async()
{
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12090
    constexpr bool kBatchApiCompiledIn = true;
#else
    constexpr bool kBatchApiCompiledIn = false;
#endif
    return kBatchApiCompiledIn;
}
/// Controls how kvbm_kernels_memcpy_batch dispatches copies.
/// The integer values are part of the C ABI: callers pass them as `mode_value`.
enum class MemcpyBatchMode : int {
BatchedWithFallback = 0, // Try cudaMemcpyBatchAsync, fall back to individual cudaMemcpyAsync on failure
FallbackOnly = 1, // Only use individual cudaMemcpyAsync loop (never attempt batch API)
BatchWithoutFallback = 2, // Try cudaMemcpyBatchAsync, return error on failure (no fallback)
};
/// Batched memcpy using cudaMemcpyBatchAsync (CUDA 12.9+) and/or individual cudaMemcpyAsync.
///
/// Takes HOST arrays of src/dst pointers - no device allocation needed.
/// Direction is auto-determined by CUDA from pointer types using cudaMemcpyDefault.
///
/// Zero work (`num_copies == 0` or `size_per_copy == 0`) returns cudaSuccess
/// before the pointer arrays are inspected; a null pointer array otherwise
/// yields cudaErrorInvalidValue. `mode_value` is cast without validation, so
/// any value other than 1 or 2 behaves like BatchedWithFallback.
///
/// @param src_ptrs Host array of source pointers
/// @param dst_ptrs Host array of destination pointers
/// @param size_per_copy Size in bytes for each copy
/// @param num_copies Number of copies to perform
/// @param mode_value Controls dispatch: 0 = BatchedWithFallback, 1 = FallbackOnly, 2 = BatchWithoutFallback
/// @param stream CUDA stream for async execution
/// @return cudaSuccess on success, cudaErrorNotSupported if batch API unavailable and mode disallows fallback
extern "C" cudaError_t
kvbm_kernels_memcpy_batch(
const void* const* src_ptrs, void* const* dst_ptrs, size_t size_per_copy, size_t num_copies, int mode_value,
cudaStream_t stream)
{
auto mode = static_cast<MemcpyBatchMode>(mode_value);
if (num_copies == 0 || size_per_copy == 0) {
return cudaSuccess;
}
if (!src_ptrs || !dst_ptrs) {
return cudaErrorInvalidValue;
}
// Per-copy path: issue one cudaMemcpyAsync per pair and stop at the first
// error. Copies already enqueued before the failure are not undone.
auto launch_memcpy_async_fallback = [&]() -> cudaError_t {
for (size_t i = 0; i < num_copies; ++i) {
cudaError_t copy_err = cudaMemcpyAsync(dst_ptrs[i], src_ptrs[i], size_per_copy, cudaMemcpyDefault, stream);
if (copy_err != cudaSuccess) {
return copy_err;
}
}
return cudaSuccess;
};
// FallbackOnly: skip batch entirely, always use individual cudaMemcpyAsync
if (mode == MemcpyBatchMode::FallbackOnly) {
return launch_memcpy_async_fallback();
}
// From here on the mode is BatchedWithFallback or BatchWithoutFallback.
// Which of the three branches below is compiled depends on CUDART_VERSION.
#if defined(CUDART_VERSION)
#if CUDART_VERSION >= 12090
// The batch API takes per-copy sizes and non-const source pointers, so
// build temporary host arrays in that shape.
std::vector<size_t> sizes(num_copies, size_per_copy);
std::vector<void*> src_ptrs_mut(num_copies);
for (size_t i = 0; i < num_copies; ++i) {
src_ptrs_mut[i] = const_cast<void*>(src_ptrs[i]);
}
// attrIdxList must have one entry per copy, mapping each to an attribute.
// We use a single attribute (index 0) for all copies.
std::vector<size_t> attr_indices(num_copies, 0);
cudaMemcpyAttributes attr = {};
attr.srcAccessOrder = cudaMemcpySrcAccessOrderStream;
#if CUDART_VERSION >= 13000
// CUDA 13.0+: 8-parameter API (no failIdx)
cudaError_t err = cudaMemcpyBatchAsync(
const_cast<void**>(dst_ptrs), src_ptrs_mut.data(), sizes.data(), num_copies, &attr, attr_indices.data(), 1,
stream);
#else
// CUDA 12.9: 9-parameter API (with failIdx)
size_t fail_idx = 0;
cudaError_t err = cudaMemcpyBatchAsync(
const_cast<void**>(dst_ptrs), src_ptrs_mut.data(), sizes.data(), num_copies, &attr, attr_indices.data(), 1,
&fail_idx, stream);
#endif
// Only these two error codes trigger the per-copy fallback; any other
// result (including success) is returned to the caller unchanged.
if (err == cudaErrorNotSupported || err == cudaErrorInvalidValue) {
if (mode == MemcpyBatchMode::BatchWithoutFallback) {
return err;
}
#ifdef KVBM_TENSOR_KERNELS_DEBUG
fprintf(
stderr, "cudaMemcpyBatchAsync failed with error %d (%s), falling back to individual cudaMemcpyAsync\n",
(int)err, cudaGetErrorString(err));
#endif
return launch_memcpy_async_fallback();
}
return err;
#else
// CUDA < 12.9: batch API not available at compile time
if (mode == MemcpyBatchMode::BatchWithoutFallback) {
return cudaErrorNotSupported;
}
#pragma message("CUDA < 12.9: Fallback to individual cudaMemcpyAsync with cudaMemcpyDefault")
return launch_memcpy_async_fallback();
#endif
#else
// CUDART_VERSION not defined
if (mode == MemcpyBatchMode::BatchWithoutFallback) {
return cudaErrorNotSupported;
}
return launch_memcpy_async_fallback();
#endif
}
/// Identifies this library as the real CUDA implementation: always returns
/// false. The stub library exports the same symbol returning true, so
/// downstream crates can call it at runtime to decide whether CUDA tests
/// should be skipped.
extern "C" bool
kvbm_kernels_is_stub_build()
{
    constexpr bool kIsStub = false;
    return kIsStub;
}
/// Copies `copy_size_in_bytes` bytes for each of `num_pairs` (src, dst)
/// pointer pairs.
///
/// Thread blocks walk the pair list starting at blockIdx.x and advancing by
/// gridDim.x, so a grid smaller than `num_pairs` still covers every pair. The
/// threads of a block cooperate on one pair at a time.
///
/// The widest element both pointers are aligned for is chosen per pair:
///   - 16 bytes (int4) when both addresses are multiples of 16
///   - 8 bytes (int2) when both addresses are multiples of 8
///   - 4 bytes (int) when both addresses are multiples of 4
/// Bytes not covered by whole elements (or everything, when no width applies)
/// are copied one byte at a time.
///
/// Source and destination pointers may be device memory or pinned host memory —
/// any memory reachable via CUDA unified addressing is valid.
__global__ void
kvbm_kernels_vectorized_copy_kernel(void** src_ptrs, void** dst_ptrs, size_t copy_size_in_bytes, int num_pairs)
{
    const int lane = threadIdx.x;
    const int lanes = blockDim.x;
    const int pairs_per_pass = gridDim.x;

    for (int pair = blockIdx.x; pair < num_pairs; pair += pairs_per_pass) {
        char* src = static_cast<char*>(src_ptrs[pair]);
        char* dst = static_cast<char*>(dst_ptrs[pair]);

        // OR-ing the addresses lets one mask test check both alignments:
        // the low bits of the result are clear only if they are clear in both.
        // Every thread of the block sees the same pair, so all take the same branch.
        const uintptr_t both_addrs = reinterpret_cast<uintptr_t>(src) | reinterpret_cast<uintptr_t>(dst);

        size_t bulk_bytes = 0;
        if ((both_addrs & 0xF) == 0 && copy_size_in_bytes >= 16) {
            // 16-byte elements.
            const size_t count = copy_size_in_bytes >> 4;
            int4* out = reinterpret_cast<int4*>(dst);
            const int4* in = reinterpret_cast<const int4*>(src);
            for (size_t i = lane; i < count; i += lanes) {
                out[i] = in[i];
            }
            bulk_bytes = count << 4;
        } else if ((both_addrs & 0x7) == 0 && copy_size_in_bytes >= 8) {
            // 8-byte elements (matches LMCache int64_t approach).
            const size_t count = copy_size_in_bytes >> 3;
            int2* out = reinterpret_cast<int2*>(dst);
            const int2* in = reinterpret_cast<const int2*>(src);
            for (size_t i = lane; i < count; i += lanes) {
                out[i] = in[i];
            }
            bulk_bytes = count << 3;
        } else if ((both_addrs & 0x3) == 0 && copy_size_in_bytes >= 4) {
            // 4-byte elements.
            const size_t count = copy_size_in_bytes >> 2;
            int* out = reinterpret_cast<int*>(dst);
            const int* in = reinterpret_cast<const int*>(src);
            for (size_t i = lane; i < count; i += lanes) {
                out[i] = in[i];
            }
            bulk_bytes = count << 2;
        }

        // Byte-wise tail: the remainder after the wide copy, or the whole
        // range when no wide width was usable.
        const size_t tail_bytes = copy_size_in_bytes - bulk_bytes;
        for (size_t i = lane; i < tail_bytes; i += lanes) {
            dst[bulk_bytes + i] = src[bulk_bytes + i];
        }
    }
}
/// Launch the vectorized copy kernel for copying between arbitrary pointer pairs.
/// This kernel automatically selects optimal vectorization (4/8/16 bytes) based on alignment.
///
/// @param src_ptrs Device-accessible pointer to array of source pointers
/// @param dst_ptrs Device-accessible pointer to array of destination pointers
/// @param copy_size_bytes Size of each copy in bytes
/// @param num_pairs Number of pointer pairs to copy
/// @param stream CUDA stream for async execution
/// @return cudaSuccess when there is nothing to copy (`num_pairs == 0` or
///         `copy_size_bytes == 0`); cudaErrorInvalidValue when `num_pairs` is
///         negative or either pointer array is null; otherwise the result of
///         cudaGetLastError() taken right after the launch.
extern "C" cudaError_t
kvbm_kernels_launch_vectorized_copy(
    void** src_ptrs, void** dst_ptrs, size_t copy_size_bytes, int num_pairs, cudaStream_t stream)
{
    if (num_pairs == 0 || copy_size_bytes == 0) {
        return cudaSuccess;
    }
    // A negative count would otherwise become the grid dimension and surface
    // as an opaque launch failure; report it as a bad argument instead.
    if (num_pairs < 0 || !src_ptrs || !dst_ptrs) {
        return cudaErrorInvalidValue;
    }
    // Use 128 threads per block, one block per pair (up to 65535 blocks).
    // The kernel loops over the pair list, so pairs beyond the cap are still copied.
    constexpr int kBlockDim = 128;
    constexpr int kMaxGridDim = 65535;
    const int grid_dim = std::min(num_pairs, kMaxGridDim);
    kvbm_kernels_vectorized_copy_kernel<<<grid_dim, kBlockDim, 0, stream>>>(
        src_ptrs, dst_ptrs, copy_size_bytes, num_pairs);
    return cudaGetLastError();
}
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! KV cache transfer benchmark.
//!
//! Compares vectorized copy kernel, cudaMemcpyBatchAsync, and individual
//! cudaMemcpyAsync for layerwise vs fully-contiguous block transfers using
//! Llama 3.1 70B KV cache dimensions.
//!
//! Output: CSV on stdout suitable for piping to a plotting script.
//!
//! Usage:
//! # Run all configurations:
//! cargo run --example kvbench --features kvbench 2>/dev/null
//!
//! # Single test:
//! cargo run --example kvbench --features kvbench -- \
//! --num-blocks 16 --tokens-per-block 32 --backend vectorized --direction d2d
//!
//! # Subset:
//! cargo run --example kvbench --features kvbench -- \
//! --num-blocks 1,4,16,64 --backend vectorized,batched --direction d2d --pattern fc_to_fc
//!
//! # Pipe to plotter:
//! cargo run --example kvbench --features kvbench 2>/dev/null | python3 scripts/plot_roofline.py
use std::ffi::c_void;
use clap::Parser;
use cudarc::driver::CudaContext;
use cudarc::runtime::sys as cuda_runtime;
use kvbm_kernels::{MemcpyBatchMode, memcpy_batch, vectorized_copy};
// ---------------------------------------------------------------------------
// Llama 3.1 70B, bf16 KV cache dimensions
// ---------------------------------------------------------------------------
// Model shape used to size the benchmark buffers (per the section header:
// Llama 3.1 70B with a bf16 KV cache).
// Number of layers.
const NUM_LAYERS: usize = 80;
// Number of KV heads.
const NUM_KV_HEADS: usize = 8;
// Elements per head.
const HEAD_DIM: usize = 128;
// Bytes per element.
const ELEM_SIZE: usize = 2; // bf16
// Size of the outer dimension: one slot for keys, one for values.
const OUTER_DIM: usize = 2; // K and V
// ---------------------------------------------------------------------------
// CLI
// ---------------------------------------------------------------------------
// Command-line options for the benchmark. Every list-valued flag is split on
// ',' (see `value_delimiter`), so a single flag can request several
// configurations at once.
// NOTE: the `///` comments on the fields are consumed by clap's derive macro
// as per-flag help text, so editing them changes the `--help` output.
/// KV cache transfer benchmark (Llama 3.1 70B, bf16).
#[derive(Parser, Debug)]
#[command(name = "kvbench", about = "KV cache transfer bandwidth benchmark")]
struct Cli {
/// Comma-separated number of blocks to benchmark.
#[arg(
long,
default_value = "1,2,4,8,16,32,64,128,256",
value_delimiter = ','
)]
num_blocks: Vec<usize>,
/// Comma-separated tokens per block values.
#[arg(long, default_value = "16,32,64", value_delimiter = ',')]
tokens_per_block: Vec<usize>,
/// Comma-separated backends: vectorized, batched, memcpy_async.
#[arg(
long,
default_value = "vectorized,batched,memcpy_async",
value_delimiter = ','
)]
backend: Vec<String>,
/// Comma-separated directions: h2d, d2h, d2d.
#[arg(long, default_value = "h2d,d2h,d2d", value_delimiter = ',')]
direction: Vec<String>,
/// Comma-separated patterns: fc_to_fc, lw_to_fc.
#[arg(long, default_value = "fc_to_fc,lw_to_fc", value_delimiter = ',')]
pattern: Vec<String>,
/// Number of warmup iterations.
#[arg(long, default_value = "10")]
warmup: usize,
/// Number of timed iterations.
#[arg(long, default_value = "100")]
iters: usize,
}
// ---------------------------------------------------------------------------
// Direct FFI for CUDA runtime functions not exposed through cudarc
// ---------------------------------------------------------------------------
// Hand-written declarations of the CUDA runtime entry points this benchmark
// calls directly. Each returns the runtime's status code as a `u32`; the
// wrappers in this file treat 0 as success.
// All of these are unsafe to call: the caller must pass valid out-pointers,
// and handles/pointers that came from the matching create/alloc call.
unsafe extern "C" {
// Page-locked host allocation and release.
fn cudaMallocHost(ptr: *mut *mut c_void, size: usize) -> u32;
fn cudaFreeHost(ptr: *mut c_void) -> u32;
// Device allocation and release.
fn cudaMalloc(ptr: *mut *mut c_void, size: usize) -> u32;
fn cudaFree(ptr: *mut c_void) -> u32;
// Event lifecycle and timing.
fn cudaEventCreate(event: *mut cuda_runtime::cudaEvent_t) -> u32;
fn cudaEventDestroy(event: cuda_runtime::cudaEvent_t) -> u32;
fn cudaEventRecord(event: cuda_runtime::cudaEvent_t, stream: cuda_runtime::cudaStream_t)
-> u32;
fn cudaEventSynchronize(event: cuda_runtime::cudaEvent_t) -> u32;
// Writes the elapsed time between `start` and `end`, in milliseconds, to `ms`.
fn cudaEventElapsedTime(
ms: *mut f32,
start: cuda_runtime::cudaEvent_t,
end: cuda_runtime::cudaEvent_t,
) -> u32;
// Stream synchronization and asynchronous copies.
fn cudaStreamSynchronize(stream: cuda_runtime::cudaStream_t) -> u32;
// `kind` is a cudaMemcpyKind value passed as a raw integer.
fn cudaMemcpyAsync(
dst: *mut c_void,
src: *const c_void,
count: usize,
kind: u32,
stream: cuda_runtime::cudaStream_t,
) -> u32;
}
const CUDA_MEMCPY_HOST_TO_DEVICE: u32 = 1;
// ---------------------------------------------------------------------------
// Memory management helpers
// ---------------------------------------------------------------------------
/// Owning handle for one pinned host allocation made with `cudaMallocHost`.
///
/// The allocation is handed back to `cudaFreeHost` when the handle is dropped.
struct PinnedBuffer {
    ptr: *mut c_void,
    _len: usize,
}

impl PinnedBuffer {
    /// Allocates `len` bytes of pinned host memory and fills them with `0xAB`.
    ///
    /// # Panics
    /// Panics if `cudaMallocHost` reports a non-zero status.
    fn new(len: usize) -> Self {
        let mut ptr = std::ptr::null_mut::<c_void>();
        let err = unsafe { cudaMallocHost(&mut ptr, len) };
        assert_eq!(err, 0, "cudaMallocHost failed: {err}");
        // A non-zero fill keeps the benchmark from measuring zero-page tricks.
        unsafe { ptr.cast::<u8>().write_bytes(0xAB, len) };
        Self { ptr, _len: len }
    }

    /// Raw address of the allocation.
    fn as_ptr(&self) -> *mut c_void {
        self.ptr
    }
}

impl Drop for PinnedBuffer {
    /// Frees the allocation; the status returned by `cudaFreeHost` is not inspected.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        unsafe { cudaFreeHost(self.ptr) };
    }
}
/// Owning handle for one device allocation made with `cudaMalloc`.
///
/// The allocation is handed back to `cudaFree` when the handle is dropped.
struct DeviceBuffer {
    ptr: *mut c_void,
    _len: usize,
}

impl DeviceBuffer {
    /// Allocates `len` bytes of device memory; the contents are left as allocated.
    ///
    /// # Panics
    /// Panics if `cudaMalloc` reports a non-zero status.
    fn new(len: usize) -> Self {
        let mut ptr = std::ptr::null_mut::<c_void>();
        let err = unsafe { cudaMalloc(&mut ptr, len) };
        assert_eq!(err, 0, "cudaMalloc failed: {err}");
        Self { ptr, _len: len }
    }

    /// Raw address of the allocation.
    fn as_ptr(&self) -> *mut c_void {
        self.ptr
    }
}

impl Drop for DeviceBuffer {
    /// Frees the allocation; the status returned by `cudaFree` is not inspected.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        unsafe { cudaFree(self.ptr) };
    }
}
/// Owning handle for a CUDA event; the event is destroyed when the handle is dropped.
struct CudaEvent {
    event: cuda_runtime::cudaEvent_t,
}

impl CudaEvent {
    /// Creates a fresh event.
    ///
    /// # Panics
    /// Panics if `cudaEventCreate` reports a non-zero status.
    fn new() -> Self {
        let mut handle: cuda_runtime::cudaEvent_t = std::ptr::null_mut();
        let err = unsafe { cudaEventCreate(&mut handle) };
        assert_eq!(err, 0, "cudaEventCreate failed: {err}");
        Self { event: handle }
    }

    /// Records this event on `stream`, panicking on a non-zero status.
    fn record(&self, stream: cuda_runtime::cudaStream_t) {
        let err = unsafe { cudaEventRecord(self.event, stream) };
        assert_eq!(err, 0, "cudaEventRecord failed: {err}");
    }

    /// Blocks until the event has completed, panicking on a non-zero status.
    fn synchronize(&self) {
        let err = unsafe { cudaEventSynchronize(self.event) };
        assert_eq!(err, 0, "cudaEventSynchronize failed: {err}");
    }

    /// Time reported by `cudaEventElapsedTime` from `start` to `self`.
    ///
    /// # Panics
    /// Panics if the runtime reports a non-zero status.
    fn elapsed_ms(&self, start: &CudaEvent) -> f32 {
        let mut elapsed = 0.0_f32;
        let err = unsafe { cudaEventElapsedTime(&mut elapsed, start.event, self.event) };
        assert_eq!(err, 0, "cudaEventElapsedTime failed: {err}");
        elapsed
    }
}

impl Drop for CudaEvent {
    /// Destroys the event; the status returned by `cudaEventDestroy` is not inspected.
    fn drop(&mut self) {
        if self.event.is_null() {
            return;
        }
        unsafe { cudaEventDestroy(self.event) };
    }
}
// ---------------------------------------------------------------------------
// Transfer direction
// ---------------------------------------------------------------------------
/// Direction of a benchmarked transfer, spelled `h2d`, `d2h` or `d2d` on the command line.
#[derive(Clone, Copy, Debug)]
enum Direction {
    H2D,
    D2H,
    D2D,
}

impl Direction {
    /// Lower-case name used in CLI parsing and in the CSV/progress output.
    fn label(&self) -> &'static str {
        match *self {
            Self::H2D => "h2d",
            Self::D2H => "d2h",
            Self::D2D => "d2d",
        }
    }

    /// Parses a label produced by [`Direction::label`]; any other text yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        // Every accepted spelling is exactly a variant's label, so search by label.
        [Self::H2D, Self::D2H, Self::D2D]
            .iter()
            .copied()
            .find(|candidate| candidate.label() == s)
    }

    /// Human-readable list of accepted spellings, used in error messages.
    fn all_labels() -> &'static str {
        "h2d, d2h, d2d"
    }
}
// ---------------------------------------------------------------------------
// Transfer pattern
// ---------------------------------------------------------------------------
/// Copy pattern of a benchmarked transfer.
///
/// `FcToFc` issues one copy per block, `LwToFc` one copy per (block, layer, outer) chunk —
/// see `build_fc_ptrs` and `build_lw_ptrs`.
#[derive(Clone, Copy, Debug)]
enum Pattern {
    FcToFc,
    LwToFc,
}

impl Pattern {
    /// Canonical name used in the CSV/progress output.
    fn label(&self) -> &'static str {
        match *self {
            Self::FcToFc => "fc_to_fc",
            Self::LwToFc => "lw_to_fc",
        }
    }

    /// Parses a canonical name or its short alias; any other text yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        let spellings = [
            ("fc_to_fc", Self::FcToFc),
            ("fc", Self::FcToFc),
            ("lw_to_fc", Self::LwToFc),
            ("lw", Self::LwToFc),
        ];
        spellings
            .iter()
            .find(|entry| entry.0 == s)
            .map(|entry| entry.1)
    }

    /// Human-readable list of accepted spellings, used in error messages.
    fn all_labels() -> &'static str {
        "fc_to_fc (or fc), lw_to_fc (or lw)"
    }
}
// ---------------------------------------------------------------------------
// Backend
// ---------------------------------------------------------------------------
/// Copy mechanism exercised by a benchmark run (see the `match backend` in `run_benchmark`).
#[derive(Clone, Copy, Debug)]
enum Backend {
    Vectorized,
    Batched,
    MemcpyAsync,
}

impl Backend {
    /// Canonical name used in the CSV/progress output.
    fn label(&self) -> &'static str {
        match *self {
            Self::Vectorized => "vectorized",
            Self::Batched => "batched",
            Self::MemcpyAsync => "memcpy_async",
        }
    }

    /// Parses a canonical name or one of its aliases; any other text yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        let spellings = [
            ("vectorized", Self::Vectorized),
            ("vec", Self::Vectorized),
            ("batched", Self::Batched),
            ("batch", Self::Batched),
            ("memcpy_async", Self::MemcpyAsync),
            ("async", Self::MemcpyAsync),
            ("memcpy", Self::MemcpyAsync),
        ];
        spellings
            .iter()
            .find(|entry| entry.0 == s)
            .map(|entry| entry.1)
    }

    /// Human-readable list of accepted spellings, used in error messages.
    fn all_labels() -> &'static str {
        "vectorized (or vec), batched (or batch), memcpy_async (or async/memcpy)"
    }
}
// ---------------------------------------------------------------------------
// Allocate src/dst memory pair for a given direction
// ---------------------------------------------------------------------------
// Source and destination buffer sets for one benchmark run. `allocate_memory` chooses
// pinned-host or device buffers for each side from the transfer direction.
struct MemoryPair {
// Buffers the copies read from (one per block).
src_bufs: SideBuffers,
// Buffers the copies write to (one per block).
dst_bufs: SideBuffers,
}
/// The per-block buffers for one side of a transfer: all pinned host or all device.
enum SideBuffers {
    Pinned(Vec<PinnedBuffer>),
    Device(Vec<DeviceBuffer>),
}

impl SideBuffers {
    /// Base address of the buffer for block `block_idx`.
    ///
    /// # Panics
    /// Panics if `block_idx` is out of range for the underlying vector.
    fn block_ptr(&self, block_idx: usize) -> *mut c_void {
        match self {
            Self::Pinned(host_blocks) => host_blocks[block_idx].as_ptr(),
            Self::Device(device_blocks) => device_blocks[block_idx].as_ptr(),
        }
    }
}
/// Allocates `num_blocks` source and `num_blocks` destination buffers of `block_size` bytes.
///
/// The memory kind of each side follows `direction`: H2D is pinned → device, D2H is
/// device → pinned, D2D is device → device. Panics if any underlying allocation fails
/// (see `PinnedBuffer::new` / `DeviceBuffer::new`).
fn allocate_memory(direction: Direction, num_blocks: usize, block_size: usize) -> MemoryPair {
    let pinned_side = || {
        SideBuffers::Pinned(
            (0..num_blocks)
                .map(|_| PinnedBuffer::new(block_size))
                .collect(),
        )
    };
    let device_side = || {
        SideBuffers::Device(
            (0..num_blocks)
                .map(|_| DeviceBuffer::new(block_size))
                .collect(),
        )
    };

    // Tuple elements evaluate left to right, so the source side is allocated first.
    let (src_bufs, dst_bufs) = match direction {
        Direction::H2D => (pinned_side(), device_side()),
        Direction::D2H => (device_side(), pinned_side()),
        Direction::D2D => (device_side(), device_side()),
    };
    MemoryPair { src_bufs, dst_bufs }
}
// ---------------------------------------------------------------------------
// Build pointer lists
// ---------------------------------------------------------------------------
/// Pointer lists for the FC<=>FC pattern.
///
/// Produces one `(src, dst)` entry per block, each being the base address of that block's
/// buffer, so every copy covers a whole block.
fn build_fc_ptrs(mem: &MemoryPair, num_blocks: usize) -> (Vec<*const c_void>, Vec<*mut c_void>) {
    (0..num_blocks)
        .map(|block| {
            let src = mem.src_bufs.block_ptr(block) as *const c_void;
            let dst = mem.dst_bufs.block_ptr(block);
            (src, dst)
        })
        .unzip()
}
/// Pointer lists for the LW<=>FC pattern.
///
/// Every block contributes `NUM_LAYERS * OUTER_DIM` entries, ordered layer-major then
/// outer. Entry `k` of a block points `k * inner` bytes past that block's base address on
/// both the source and destination side.
fn build_lw_ptrs(
    mem: &MemoryPair,
    num_blocks: usize,
    inner: usize,
) -> (Vec<*const c_void>, Vec<*mut c_void>) {
    let total = num_blocks * NUM_LAYERS * OUTER_DIM;
    let chunks_per_block = NUM_LAYERS * OUTER_DIM;
    let mut src_ptrs = Vec::with_capacity(total);
    let mut dst_ptrs = Vec::with_capacity(total);
    for block in 0..num_blocks {
        let src_base = mem.src_bufs.block_ptr(block) as *const u8;
        let dst_base = mem.dst_bufs.block_ptr(block) as *mut u8;
        // `chunk` is the flattened `(layer, outer)` index, i.e. `layer * OUTER_DIM + outer`.
        for chunk in 0..chunks_per_block {
            let offset = chunk * inner;
            unsafe {
                src_ptrs.push(src_base.add(offset) as *const c_void);
                dst_ptrs.push(dst_base.add(offset) as *mut c_void);
            }
        }
    }
    (src_ptrs, dst_ptrs)
}
// ---------------------------------------------------------------------------
// Run one benchmark configuration
// ---------------------------------------------------------------------------
/// Runs one benchmark configuration and reports `(median_ms, bandwidth_gbps)`.
///
/// Allocates `num_blocks` source/destination buffers for `direction`, builds the pointer
/// lists implied by `pattern`, then issues the copies through `backend` for `warmup_iters`
/// untimed iterations followed by `timed_iters` iterations timed with CUDA events on
/// `stream_raw`. The reported time is the median of the timed samples (the upper median for
/// an even sample count); bandwidth is `total_bytes / median` expressed in 1e9 bytes/s.
///
/// Returns `None` when `timed_iters` is zero, because there is nothing to measure.
///
/// # Panics
/// Panics if any CUDA call reports a failure, or if the vectorized backend is asked for
/// more copies than fit in the `i32` pair count its launcher takes.
fn run_benchmark(
    stream_raw: cuda_runtime::cudaStream_t,
    pattern: Pattern,
    direction: Direction,
    backend: Backend,
    tokens_per_block: usize,
    num_blocks: usize,
    warmup_iters: usize,
    timed_iters: usize,
) -> Option<(f64, f64)> {
    // Without timed iterations there are no samples to take a median of; report the
    // configuration as skipped instead of indexing into an empty sample list.
    if timed_iters == 0 {
        return None;
    }

    let inner = tokens_per_block * NUM_KV_HEADS * HEAD_DIM * ELEM_SIZE;
    let full_block_size = inner * OUTER_DIM * NUM_LAYERS;
    let total_bytes = full_block_size * num_blocks;

    // Allocate memory
    let mem = allocate_memory(direction, num_blocks, full_block_size);

    // Build pointer lists based on pattern
    let (copy_size, num_copies) = match pattern {
        Pattern::FcToFc => (full_block_size, num_blocks),
        Pattern::LwToFc => (inner, num_blocks * NUM_LAYERS * OUTER_DIM),
    };
    let (src_ptrs, dst_ptrs) = match pattern {
        Pattern::FcToFc => build_fc_ptrs(&mem, num_blocks),
        Pattern::LwToFc => build_lw_ptrs(&mem, num_blocks, inner),
    };

    // The vectorized kernel reads its pointer tables from device memory, so that backend
    // needs device-side arrays plus the pair count in the launcher's `i32` type.
    let ptr_table_bytes = num_copies * std::mem::size_of::<usize>();
    let vectorized_state = matches!(backend, Backend::Vectorized).then(|| {
        let num_pairs = i32::try_from(num_copies)
            .expect("vectorized_copy takes the number of pairs as i32");
        (
            DeviceBuffer::new(ptr_table_bytes),
            DeviceBuffer::new(ptr_table_bytes),
            num_pairs,
        )
    });

    let start_event = CudaEvent::new();
    let end_event = CudaEvent::new();
    let mut elapsed_samples: Vec<f32> = Vec::with_capacity(timed_iters);

    for iter in 0..(warmup_iters + timed_iters) {
        let is_timed = iter >= warmup_iters;
        if is_timed {
            start_event.record(stream_raw);
        }
        match backend {
            Backend::Vectorized => {
                let (src_dev, dst_dev, num_pairs) = vectorized_state
                    .as_ref()
                    .expect("device pointer tables are allocated for the vectorized backend");
                // H2D copy of pointer arrays (included in timing)
                unsafe {
                    let err = cudaMemcpyAsync(
                        src_dev.as_ptr(),
                        src_ptrs.as_ptr() as *const c_void,
                        ptr_table_bytes,
                        CUDA_MEMCPY_HOST_TO_DEVICE,
                        stream_raw,
                    );
                    assert_eq!(err, 0, "H2D ptr copy failed: {err}");
                    let err = cudaMemcpyAsync(
                        dst_dev.as_ptr(),
                        dst_ptrs.as_ptr() as *const c_void,
                        ptr_table_bytes,
                        CUDA_MEMCPY_HOST_TO_DEVICE,
                        stream_raw,
                    );
                    assert_eq!(err, 0, "H2D ptr copy failed: {err}");
                }
                // Launch vectorized copy kernel
                let status = unsafe {
                    vectorized_copy(
                        src_dev.as_ptr() as *mut *mut c_void,
                        dst_dev.as_ptr() as *mut *mut c_void,
                        copy_size,
                        *num_pairs,
                        stream_raw,
                    )
                };
                assert_eq!(
                    status,
                    cuda_runtime::cudaError::cudaSuccess,
                    "vectorized_copy failed: {status:?}"
                );
            }
            Backend::Batched => {
                let status = unsafe {
                    memcpy_batch(
                        src_ptrs.as_ptr() as *const *const c_void,
                        dst_ptrs.as_ptr() as *const *mut c_void,
                        copy_size,
                        num_copies,
                        MemcpyBatchMode::BatchedWithFallback,
                        stream_raw,
                    )
                };
                assert_eq!(
                    status,
                    cuda_runtime::cudaError::cudaSuccess,
                    "memcpy_batch (Batched) failed: {status:?}"
                );
            }
            Backend::MemcpyAsync => {
                let status = unsafe {
                    memcpy_batch(
                        src_ptrs.as_ptr() as *const *const c_void,
                        dst_ptrs.as_ptr() as *const *mut c_void,
                        copy_size,
                        num_copies,
                        MemcpyBatchMode::FallbackOnly,
                        stream_raw,
                    )
                };
                assert_eq!(
                    status,
                    cuda_runtime::cudaError::cudaSuccess,
                    "memcpy_batch (FallbackOnly) failed: {status:?}"
                );
            }
        }
        if is_timed {
            end_event.record(stream_raw);
            end_event.synchronize();
            let ms = end_event.elapsed_ms(&start_event);
            elapsed_samples.push(ms);
        }
    }

    // Sync before dropping memory. The status is checked like every other CUDA call in
    // this file so that a failed drain is not silently followed by freeing the buffers.
    let sync_status = unsafe { cudaStreamSynchronize(stream_raw) };
    assert_eq!(sync_status, 0, "cudaStreamSynchronize failed: {sync_status}");

    // Compute median. `total_cmp` gives a total order, so an unexpected NaN sample cannot
    // panic the sort.
    elapsed_samples.sort_unstable_by(f32::total_cmp);
    let median_ms = f64::from(elapsed_samples[elapsed_samples.len() / 2]);
    let bandwidth_gbps = (total_bytes as f64) / (median_ms / 1000.0) / 1e9;
    Some((median_ms, bandwidth_gbps))
}
// ---------------------------------------------------------------------------
// Parse CLI values into typed enums
// ---------------------------------------------------------------------------
/// Converts the raw `--direction` values into [`Direction`]s, preserving their order.
///
/// # Panics
/// Panics with the list of accepted spellings on the first value `Direction::from_str`
/// rejects.
fn parse_directions(raw: &[String]) -> Vec<Direction> {
    let mut parsed = Vec::with_capacity(raw.len());
    for s in raw {
        match Direction::from_str(s) {
            Some(direction) => parsed.push(direction),
            None => panic!(
                "unknown direction '{}', expected: {}",
                s,
                Direction::all_labels()
            ),
        }
    }
    parsed
}
/// Converts the raw `--pattern` values into [`Pattern`]s, preserving their order.
///
/// # Panics
/// Panics with the list of accepted spellings on the first value `Pattern::from_str`
/// rejects.
fn parse_patterns(raw: &[String]) -> Vec<Pattern> {
    let mut parsed = Vec::with_capacity(raw.len());
    for s in raw {
        match Pattern::from_str(s) {
            Some(pattern) => parsed.push(pattern),
            None => panic!(
                "unknown pattern '{}', expected: {}",
                s,
                Pattern::all_labels()
            ),
        }
    }
    parsed
}
/// Converts the raw `--backend` values into [`Backend`]s, preserving their order.
///
/// # Panics
/// Panics with the list of accepted spellings on the first value `Backend::from_str`
/// rejects.
fn parse_backends(raw: &[String]) -> Vec<Backend> {
    let mut parsed = Vec::with_capacity(raw.len());
    for s in raw {
        match Backend::from_str(s) {
            Some(backend) => parsed.push(backend),
            None => panic!(
                "unknown backend '{}', expected: {}",
                s,
                Backend::all_labels()
            ),
        }
    }
    parsed
}
// ---------------------------------------------------------------------------
// Main
// ---------------------------------------------------------------------------
// Entry point: parses the CLI, creates a CUDA context and stream on device 0, then runs
// every combination of tokens_per_block x num_blocks x direction x pattern x backend.
// Machine-readable CSV rows go to stdout (`println!`); configuration and progress go to
// stderr (`eprintln!`/`eprint!`), so stdout can be redirected to a file on its own.
fn main() {
let cli = Cli::parse();
// These panic with the list of accepted spellings on an unknown value.
let directions = parse_directions(&cli.direction);
let patterns = parse_patterns(&cli.pattern);
let backends = parse_backends(&cli.backend);
let tpb_options = &cli.tokens_per_block;
let num_blocks_options = &cli.num_blocks;
let warmup_iters = cli.warmup;
let timed_iters = cli.iters;
// Size of the full sweep; only used for the "[i/total]" progress display.
let total_tests = tpb_options.len()
* num_blocks_options.len()
* directions.len()
* patterns.len()
* backends.len();
// Initialize CUDA context
let count = CudaContext::device_count().expect("Failed to query CUDA devices");
assert!(count > 0, "No CUDA devices found");
let ctx = CudaContext::new(0).expect("Failed to create CUDA context");
let stream = ctx.new_stream().expect("Failed to create CUDA stream");
// The stream handle is cast to the runtime stream type expected by the FFI calls.
let stream_raw = stream.cu_stream() as cuda_runtime::cudaStream_t;
// Print config to stderr
eprintln!("KV Cache Transfer Benchmark");
eprintln!(" Model: Llama 3.1 70B (bf16)");
eprintln!(
" Layers: {NUM_LAYERS}, KV heads: {NUM_KV_HEADS}, Head dim: {HEAD_DIM}, Outer dim: {OUTER_DIM}"
);
eprintln!(" Warmup: {warmup_iters}, Timed: {timed_iters}");
eprintln!(
" Batch API available: {}",
kvbm_kernels::is_memcpy_batch_available()
);
eprintln!(" tokens_per_block: {:?}", tpb_options);
eprintln!(" num_blocks: {:?}", num_blocks_options);
eprintln!(
" directions: [{}]",
directions
.iter()
.map(|d| d.label())
.collect::<Vec<_>>()
.join(", ")
);
eprintln!(
" patterns: [{}]",
patterns
.iter()
.map(|p| p.label())
.collect::<Vec<_>>()
.join(", ")
);
eprintln!(
" backends: [{}]",
backends
.iter()
.map(|b| b.label())
.collect::<Vec<_>>()
.join(", ")
);
eprintln!(" Total tests: {total_tests}");
eprintln!();
// CSV header
println!(
"tokens_per_block,num_blocks,pattern,direction,backend,total_bytes,inner_bytes,copy_size,num_copies,median_ms,bandwidth_gbps"
);
let mut test_num = 0;
for &tpb in tpb_options {
// Same size formulas as `run_benchmark`; recomputed here only for reporting.
let inner = tpb * NUM_KV_HEADS * HEAD_DIM * ELEM_SIZE;
let full_block_size = inner * OUTER_DIM * NUM_LAYERS;
eprintln!(
"--- tokens_per_block={tpb}, inner={inner} bytes ({} KB), block={full_block_size} bytes ({:.1} MB) ---",
inner / 1024,
full_block_size as f64 / (1024.0 * 1024.0)
);
for &num_blocks in num_blocks_options {
let total_bytes = full_block_size * num_blocks;
for &direction in &directions {
for &pattern in &patterns {
// Mirrors the copy granularity `run_benchmark` derives from the pattern.
let (copy_size, num_copies) = match pattern {
Pattern::FcToFc => (full_block_size, num_blocks),
Pattern::LwToFc => (inner, num_blocks * NUM_LAYERS * OUTER_DIM),
};
for &backend in &backends {
test_num += 1;
// Progress prefix without a newline; the result (or SKIPPED) completes the line.
eprint!(
" [{test_num}/{total_tests}] tpb={tpb} N={num_blocks:>3} {:<8} {:<6} {:<12} ... ",
pattern.label(),
direction.label(),
backend.label(),
);
match run_benchmark(
stream_raw,
pattern,
direction,
backend,
tpb,
num_blocks,
warmup_iters,
timed_iters,
) {
Some((median_ms, bw)) => {
// One CSV row per measured configuration, in the column order of the header.
println!(
"{tpb},{num_blocks},{},{},{},{total_bytes},{inner},{copy_size},{num_copies},{median_ms:.4},{bw:.2}",
pattern.label(),
direction.label(),
backend.label(),
);
eprintln!("{bw:.2} GB/s ({median_ms:.4} ms)");
}
None => {
// No CSV row is emitted for a configuration that produced no result.
eprintln!("SKIPPED");
}
}
}
}
}
}
}
eprintln!("\nDone.");
}
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod tensor_kernels;
// Always available - core transfer functionality
pub use tensor_kernels::{
MemcpyBatchMode, is_memcpy_batch_available, is_using_stubs, memcpy_batch, vectorized_copy,
};
// Permute kernels - data layout transformation (requires permute_kernels feature)
#[cfg(feature = "permute_kernels")]
pub use tensor_kernels::{BlockLayout, TensorDataType, block_from_universal, universal_from_block};
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Safe-ish wrappers around the CUDA block/universal packing kernels.
//!
//! The core ideas:
//! * A “block” represents the stack of `nl * no` tensors arranged either as NHD
//! (inner axes `[nt, nh, hd]`) or HND (inner axes `[nh, nt, hd]`).
//! * A “universal” tensor is `[nh, nl, no, nt, hd]` stored contiguously.
//! * An “operational” tensor is `[nl, no, inner]` with `inner = nt * nh * hd`.
//!
//! All pointer-list parameters (e.g. `universal_ptrs`, `src_ptrs`) must be
//! device-accessible: allocated via `cudaMalloc` (device memory) or
//! `cudaMallocHost` / `cuMemHostRegister` (pinned/registered/page-locked host memory).
//!
//! Host code calls these helpers with flattened pointer tables so a single
//! launch can move many logical blocks in one go.
#![allow(clippy::missing_safety_doc)]
use std::ffi::c_void;
use cudarc::runtime::sys::{cudaError_t, cudaStream_t};
/// Numeric tags passed across the FFI boundary to select the CUDA template.
///
/// The wrappers in this module hand the tag to the launchers with `as i32`, so the
/// discriminant values below are part of the contract with the native side.
#[cfg(feature = "permute_kernels")]
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TensorDataType {
/// Tag for `f16` elements.
F16 = 0,
/// Tag for `bf16` elements.
BF16 = 1,
/// Tag for `f32` elements.
F32 = 2,
/// Tag for `f64` elements.
F64 = 3,
}
/// Identifies how each `[nt, nh, hd]` chunk is laid out in device memory.
///
/// Passed to the launchers as `i32`, so the discriminant values are part of the FFI contract.
#[cfg(feature = "permute_kernels")]
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BlockLayout {
/// Inner axes ordered `[nt, nh, hd]`.
NHD = 0,
/// Inner axes ordered `[nh, nt, hd]`.
HND = 1,
}
// Native launchers for the permute kernels; only declared when the `permute_kernels`
// feature is enabled. The safe-ish entry points are `universal_from_block` and
// `block_from_universal`, which forward their arguments unchanged apart from converting
// `dtype`/`layout` to `i32`.
#[cfg(feature = "permute_kernels")]
#[allow(dead_code)]
unsafe extern "C" {
// Block stacks -> universal tensors: `universal_ptrs` is the writable side.
fn kvbm_kernels_launch_universal_from_block(
universal_ptrs: *const *mut c_void,
block_ptrs: *const *const c_void,
num_blocks: usize,
nh: usize,
nl: usize,
no: usize,
nt: usize,
hd: usize,
dtype: i32,
layout: i32,
stream: cudaStream_t,
) -> cudaError_t;
// Universal tensors -> block stacks: `block_ptrs` is the writable side.
fn kvbm_kernels_launch_block_from_universal(
universal_ptrs: *const *const c_void,
block_ptrs: *const *mut c_void,
num_blocks: usize,
nh: usize,
nl: usize,
no: usize,
nt: usize,
hd: usize,
dtype: i32,
layout: i32,
stream: cudaStream_t,
) -> cudaError_t;
}
/// Controls how `memcpy_batch` dispatches copies.
///
/// The value is passed to the native side as `i32` (`mode as i32` in [`memcpy_batch`]), so
/// the discriminants are part of the FFI contract.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MemcpyBatchMode {
/// Try cudaMemcpyBatchAsync, fall back to individual cudaMemcpyAsync on failure.
BatchedWithFallback = 0,
/// Only use individual cudaMemcpyAsync loop (never attempt batch API).
FallbackOnly = 1,
/// Try cudaMemcpyBatchAsync, return error on failure (no fallback).
BatchWithoutFallback = 2,
}
// Native entry points that are always declared (no feature gate). Each has a public
// wrapper below: `vectorized_copy`, `memcpy_batch`, `is_memcpy_batch_available` and
// `is_using_stubs`.
// NOTE(review): `#[allow(dead_code)]` has no stated reason — all four declarations are
// called in this file; confirm whether it is still needed.
#[allow(dead_code)]
unsafe extern "C" {
// Note the pair count is `i32` here, unlike the `usize` count of the batch copy below.
fn kvbm_kernels_launch_vectorized_copy(
src_ptrs: *mut *mut c_void,
dst_ptrs: *mut *mut c_void,
copy_size_bytes: usize,
num_pairs: i32,
stream: cudaStream_t,
) -> cudaError_t;
// `mode` carries a `MemcpyBatchMode` discriminant.
fn kvbm_kernels_memcpy_batch(
src_ptrs: *const *const c_void,
dst_ptrs: *const *mut c_void,
size_per_copy: usize,
num_copies: usize,
mode: i32,
stream: cudaStream_t,
) -> cudaError_t;
fn kvbm_kernels_has_memcpy_batch_async() -> bool;
fn kvbm_kernels_is_stub_build() -> bool;
}
/// Check if cudaMemcpyBatchAsync is available.
///
/// Returns true if the library was compiled with CUDA 12.9+ which provides
/// the `cudaMemcpyBatchAsync` API for efficient batched memory transfers.
pub fn is_memcpy_batch_available() -> bool {
// The native query takes no arguments; this wrapper only forwards its result.
unsafe { kvbm_kernels_has_memcpy_batch_async() }
}
/// Check if this library was built with stub kernels (no real CUDA).
///
/// Returns `true` if the library is using stubs that will abort on actual CUDA calls.
/// Returns `false` if real CUDA kernels are available.
///
/// Downstream crates should use this to skip CUDA tests at runtime:
/// ```ignore
/// #[test]
/// fn my_cuda_test() {
///     if kvbm_kernels::is_using_stubs() {
///         eprintln!("Skipping CUDA test: stub kernels in use");
///         return;
///     }
///     // ... actual CUDA test code ...
/// }
/// ```
pub fn is_using_stubs() -> bool {
// The native query takes no arguments; this wrapper only forwards its result.
unsafe { kvbm_kernels_is_stub_build() }
}
/// Batched memcpy using cudaMemcpyBatchAsync (CUDA 12.9+) and/or individual cudaMemcpyAsync.
///
/// Takes HOST arrays of src/dst pointers - no device allocation needed.
/// Direction is auto-determined by CUDA from pointer types using cudaMemcpyDefault.
///
/// The `mode` parameter controls dispatch:
/// - [`MemcpyBatchMode::BatchedWithFallback`]: try batch API, fall back to individual copies on error
/// - [`MemcpyBatchMode::FallbackOnly`]: always use individual cudaMemcpyAsync loop
/// - [`MemcpyBatchMode::BatchWithoutFallback`]: try batch API, return error if unavailable
///
/// Returns the `cudaError_t` produced by the native implementation, unchanged.
///
/// # Safety
/// - `src_ptrs` must point to a valid array of `num_copies` source pointers
/// - `dst_ptrs` must point to a valid array of `num_copies` destination pointers
/// - Each source/destination pointer pair must have at least `size_per_copy` bytes accessible
/// - `stream` must be a valid CUDA stream handle
pub unsafe fn memcpy_batch(
src_ptrs: *const *const c_void,
dst_ptrs: *const *mut c_void,
size_per_copy: usize,
num_copies: usize,
mode: MemcpyBatchMode,
stream: cudaStream_t,
) -> cudaError_t {
// SAFETY: the caller upholds the contract documented above; every argument is
// forwarded unchanged except `mode`, which is converted to its `i32` discriminant.
unsafe {
kvbm_kernels_memcpy_batch(
src_ptrs,
dst_ptrs,
size_per_copy,
num_copies,
mode as i32,
stream,
)
}
}
/// Copy `num_blocks` stacks of NHD/HND tensors into universal form.
///
/// * `universal_ptrs` – device-accessible pointer to `num_blocks` universal bases.
/// * `block_ptrs` – device-accessible pointer to a flattened `[num_blocks][nl*no]`
///   table of chunk pointers.
/// * `nh, nl, no, nt, hd` – logical dimensions of each universal tensor.
/// * `dtype` – element type tag forwarded to the launcher as `i32`.
/// * `layout` – chunk layout tag forwarded to the launcher as `i32`.
/// * `stream` – CUDA stream used for the launch.
///
/// Returns the `cudaError_t` produced by the native launcher, unchanged.
///
/// # Safety
/// The pointer tables are dereferenced by native code. NOTE(review): the exact
/// requirements are not spelled out in this file — presumably every table entry must be
/// valid for the element counts implied by the dimensions and `dtype`, and `stream`
/// must be a valid CUDA stream; confirm against the kernel implementation.
#[cfg(feature = "permute_kernels")]
#[allow(clippy::too_many_arguments)]
pub unsafe fn universal_from_block(
universal_ptrs: *const *mut c_void,
block_ptrs: *const *const c_void,
num_blocks: usize,
nh: usize,
nl: usize,
no: usize,
nt: usize,
hd: usize,
dtype: TensorDataType,
layout: BlockLayout,
stream: cudaStream_t,
) -> cudaError_t {
// SAFETY: forwarded verbatim to the native launcher; the caller is responsible for the
// validity of all pointers and the stream.
unsafe {
kvbm_kernels_launch_universal_from_block(
universal_ptrs,
block_ptrs,
num_blocks,
nh,
nl,
no,
nt,
hd,
dtype as i32,
layout as i32,
stream,
)
}
}
/// Copy `num_blocks` universal tensors back into their block stacks.
///
/// Takes the same parameters as `universal_from_block`, with the mutability of the two
/// pointer tables swapped: here `universal_ptrs` is the read side and `block_ptrs` the
/// write side. `dtype` and `layout` are forwarded to the launcher as `i32`.
///
/// Returns the `cudaError_t` produced by the native launcher, unchanged.
///
/// # Safety
/// The pointer tables are dereferenced by native code. NOTE(review): the exact
/// requirements are not spelled out in this file — presumably every table entry must be
/// valid for the element counts implied by the dimensions and `dtype`, and `stream`
/// must be a valid CUDA stream; confirm against the kernel implementation.
#[cfg(feature = "permute_kernels")]
#[allow(clippy::too_many_arguments)]
pub unsafe fn block_from_universal(
universal_ptrs: *const *const c_void,
block_ptrs: *const *mut c_void,
num_blocks: usize,
nh: usize,
nl: usize,
no: usize,
nt: usize,
hd: usize,
dtype: TensorDataType,
layout: BlockLayout,
stream: cudaStream_t,
) -> cudaError_t {
// SAFETY: forwarded verbatim to the native launcher; the caller is responsible for the
// validity of all pointers and the stream.
unsafe {
kvbm_kernels_launch_block_from_universal(
universal_ptrs,
block_ptrs,
num_blocks,
nh,
nl,
no,
nt,
hd,
dtype as i32,
layout as i32,
stream,
)
}
}
/// Launch vectorized copy between arbitrary device-visible pointer pairs.
///
/// This kernel automatically selects optimal vectorization (4/8/16 bytes) based on
/// pointer alignment. It is useful for copying between non-contiguous memory regions
/// where each pair has the same copy size.
///
/// Both source and destination pointers may refer to any device-visible memory,
/// including device allocations (`cudaMalloc`) and pinned host memory
/// (`cudaMallocHost` / `cudaHostAlloc`). CUDA unified addressing resolves the
/// actual location at runtime.
///
/// # Arguments
/// * `src_ptrs` - Device-accessible pointer to array of source pointers (each pointing to device-visible memory)
/// * `dst_ptrs` - Device-accessible pointer to array of destination pointers (each pointing to device-visible memory)
/// * `copy_size_bytes` - Size of each copy in bytes (same for all pairs)
/// * `num_pairs` - Number of pointer pairs to copy (an `i32`, unlike the `usize` count of `memcpy_batch`)
/// * `stream` - CUDA stream for async execution
///
/// Returns the `cudaError_t` produced by the native launcher, unchanged.
///
/// # Safety
/// - All pointers in the src/dst arrays must be valid device-visible pointers
///   (device memory or pinned host memory)
/// - Each pointer must have at least `copy_size_bytes` bytes accessible
/// - The pointer arrays themselves must be in device memory with at least `num_pairs` entries
/// - `stream` must be a valid CUDA stream handle
///
/// NOTE(review): the Arguments section calls the pointer arrays "device-accessible"
/// while the Safety section requires them to be "in device memory"; confirm which
/// requirement the kernel actually has and align the wording.
pub unsafe fn vectorized_copy(
src_ptrs: *mut *mut c_void,
dst_ptrs: *mut *mut c_void,
copy_size_bytes: usize,
num_pairs: i32,
stream: cudaStream_t,
) -> cudaError_t {
// SAFETY: forwarded verbatim to the native launcher; the caller upholds the contract above.
unsafe {
kvbm_kernels_launch_vectorized_copy(src_ptrs, dst_ptrs, copy_size_bytes, num_pairs, stream)
}
}
// Tests are gated to only run when:
// 1. testing-cuda feature is enabled
// 2. permute_kernels feature is enabled (tests use universal kernels)
// 3. NOT using stub kernels (stub_kernels cfg is set by build.rs when no nvcc)
#[cfg(all(
test,
feature = "testing-cuda",
feature = "permute_kernels",
not(stub_kernels)
))]
mod tests {
use super::*;
use cudarc::driver::result::memset_d8_async;
use cudarc::driver::{CudaContext, CudaSlice, DevicePtr, DevicePtrMut, DriverError};
use cudarc::runtime::sys as cuda_runtime;
// Round-trip test for the permute kernels with F32 data in NHD layout:
//   1. upload `num_blocks * nl * no` chunks filled with unique values,
//   2. run block -> universal and check every element against the index formula,
//   3. overwrite the block chunks with 0xDE bytes,
//   4. run universal -> block and check the chunks match the original host data.
// Returns Ok(()) without testing anything when no CUDA device can be queried.
#[test]
fn universal_roundtrip() -> Result<(), DriverError> {
let device_count = match CudaContext::device_count() {
Ok(count) => count,
Err(_) => return Ok(()),
};
if device_count <= 0 {
return Ok(());
}
let ctx = CudaContext::new(0)?;
let stream = ctx.default_stream();
let stream_raw = stream.cu_stream() as cuda_runtime::cudaStream_t;
let nh = 2usize;
let nl = 2usize;
let no = 2usize;
let nt = 3usize;
let hd = 4usize;
// Elements per chunk, chunks per block, and elements per universal tensor.
let inner = nt * nh * hd;
let chunk_count = nl * no;
let block_volume = nh * nl * no * nt * hd;
let num_blocks = 2usize;
let dtype = TensorDataType::F32;
let layout = BlockLayout::NHD;
// Host copies of every chunk (kept for the final comparison), the device slices that
// own the chunk memory, and the flattened `[num_blocks][nl*no]` table of chunk addresses.
let mut host_block_chunks: Vec<Vec<Vec<f32>>> = Vec::with_capacity(num_blocks);
let mut block_slices: Vec<Vec<CudaSlice<f32>>> = Vec::with_capacity(num_blocks);
let mut block_ptr_values: Vec<usize> = Vec::with_capacity(num_blocks * chunk_count);
for block_idx in 0..num_blocks {
let mut host_chunks_for_block = Vec::with_capacity(chunk_count);
let mut slices_for_block = Vec::with_capacity(chunk_count);
for chunk_idx in 0..chunk_count {
let global_idx = block_idx * chunk_count + chunk_idx;
let mut host_chunk = Vec::with_capacity(inner);
for offset in 0..inner {
// Each element encodes its global position, so a misplaced element is detectable.
host_chunk.push((global_idx * inner + offset) as f32 + 0.25f32);
}
let slice = stream.clone_htod(&host_chunk)?;
{
let (ptr_raw, _guard) = slice.device_ptr(&stream);
block_ptr_values.push(ptr_raw as usize);
}
slices_for_block.push(slice);
host_chunks_for_block.push(host_chunk);
}
block_slices.push(slices_for_block);
host_block_chunks.push(host_chunks_for_block);
}
// Upload the chunk-address table so the kernel can read it.
let block_ptrs = stream.clone_htod(block_ptr_values.as_slice())?;
let mut universal_slices = Vec::with_capacity(num_blocks);
let mut universal_ptr_values = Vec::with_capacity(num_blocks);
for _ in 0..num_blocks {
let mut slice = unsafe { stream.alloc::<f32>(block_volume)? };
{
let (ptr_raw, _guard) = slice.device_ptr_mut(&stream);
universal_ptr_values.push(ptr_raw as usize);
// Poison the destination so untouched elements cannot pass the comparison by accident.
unsafe {
memset_d8_async(
ptr_raw,
0xDE,
block_volume * std::mem::size_of::<f32>(),
stream.cu_stream(),
)?;
}
}
universal_slices.push(slice);
}
let universal_ptrs = stream.clone_htod(universal_ptr_values.as_slice())?;
// Block -> Universal
{
let (block_ptrs_raw, _block_guard) = block_ptrs.device_ptr(&stream);
let block_ptrs_ptr = block_ptrs_raw as usize as *const *const c_void;
let (universal_ptrs_raw, _univ_guard) = universal_ptrs.device_ptr(&stream);
let universal_ptrs_ptr = universal_ptrs_raw as usize as *const *mut c_void;
let status = unsafe {
super::universal_from_block(
universal_ptrs_ptr,
block_ptrs_ptr,
num_blocks,
nh,
nl,
no,
nt,
hd,
dtype,
layout,
stream_raw,
)
};
assert_eq!(status, cuda_runtime::cudaError::cudaSuccess);
}
stream.synchronize()?;
// Flat offset of element (nt_idx, nh_idx, hd_idx) inside one chunk for the given layout.
let inner_offset = |nt_idx: usize, nh_idx: usize, hd_idx: usize| match layout {
BlockLayout::NHD => ((nt_idx * nh) + nh_idx) * hd + hd_idx,
BlockLayout::HND => ((nh_idx * nt) + nt_idx) * hd + hd_idx,
};
for (block_idx, universal_slice) in universal_slices.iter().enumerate().take(num_blocks) {
let host_universal = stream.clone_dtoh(universal_slice)?;
for nh_idx in 0..nh {
for nl_idx in 0..nl {
for no_idx in 0..no {
for nt_idx in 0..nt {
for hd_idx in 0..hd {
// Row-major index into the `[nh, nl, no, nt, hd]` universal tensor.
let universal_index =
((((nh_idx * nl + nl_idx) * no + no_idx) * nt + nt_idx) * hd)
+ hd_idx;
let chunk_idx = nl_idx * no + no_idx;
let offset = inner_offset(nt_idx, nh_idx, hd_idx);
// Same formula that generated the chunk data above.
let expected = ((block_idx * chunk_count + chunk_idx) * inner
+ offset) as f32
+ 0.25f32;
let value = host_universal[universal_index];
assert!(
(value - expected).abs() < 1e-5,
"universal mismatch block {} [{} {} {} {} {}]: {} vs {}",
block_idx,
nh_idx,
nl_idx,
no_idx,
nt_idx,
hd_idx,
value,
expected
);
}
}
}
}
}
}
// Universal -> Block (poison-fill destination before reverse pass)
for block in &mut block_slices {
for slice in block {
let (dptr, _guard) = slice.device_ptr_mut(&stream);
unsafe {
memset_d8_async(
dptr,
0xDE,
inner * std::mem::size_of::<f32>(),
stream.cu_stream(),
)?;
}
}
}
stream.synchronize()?;
{
let (block_ptrs_raw, _block_guard) = block_ptrs.device_ptr(&stream);
let block_ptrs_mut = block_ptrs_raw as usize as *const *mut c_void;
let (universal_ptrs_raw, _univ_guard) = universal_ptrs.device_ptr(&stream);
let universal_ptrs_const = universal_ptrs_raw as usize as *const *const c_void;
let status = unsafe {
super::block_from_universal(
universal_ptrs_const,
block_ptrs_mut,
num_blocks,
nh,
nl,
no,
nt,
hd,
dtype,
layout,
stream_raw,
)
};
assert_eq!(status, cuda_runtime::cudaError::cudaSuccess);
}
stream.synchronize()?;
// Every chunk must again hold exactly the values that were uploaded at the start.
for block_idx in 0..num_blocks {
for chunk_idx in 0..chunk_count {
let host_chunk = stream.clone_dtoh(&block_slices[block_idx][chunk_idx])?;
for (inner_idx, value) in host_chunk.iter().enumerate() {
let expected = host_block_chunks[block_idx][chunk_idx][inner_idx];
assert!(
(value - expected).abs() < 1e-5,
"block mismatch block {} chunk {} offset {}: {} vs {}",
block_idx,
chunk_idx,
inner_idx,
value,
expected
);
}
}
}
Ok(())
}
/// Test the vectorized copy kernel directly with aligned data.
///
/// Uploads four 256-byte source buffers with distinct contents, copies them into four
/// freshly allocated destination buffers with a single `vectorized_copy` launch, and
/// checks each destination byte-for-byte. Returns `Ok(())` without testing anything when
/// no CUDA device can be queried.
#[test]
fn test_vectorized_copy_aligned() -> Result<(), DriverError> {
let device_count = match CudaContext::device_count() {
Ok(count) => count,
Err(_) => return Ok(()),
};
if device_count <= 0 {
return Ok(());
}
let ctx = CudaContext::new(0)?;
let stream = ctx.default_stream();
let stream_raw = stream.cu_stream() as cuda_runtime::cudaStream_t;
// Create test data - 8-byte aligned for vectorized copy
// NOTE(review): the alignment of the device allocations is assumed, not checked here.
let num_pairs = 4;
let copy_size = 256usize; // 256 bytes, divisible by 16 for int4 vectorization
// Source data
let mut src_slices = Vec::with_capacity(num_pairs);
let mut src_ptr_values = Vec::with_capacity(num_pairs);
let mut expected_data = Vec::with_capacity(num_pairs);
for i in 0..num_pairs {
// Byte pattern depends on the pair index so swapped pairs are detected.
let data: Vec<u8> = (0..copy_size)
.map(|j| ((i * copy_size + j) % 256) as u8)
.collect();
expected_data.push(data.clone());
let slice = stream.clone_htod(&data)?;
{
let (ptr, _guard) = slice.device_ptr(&stream);
src_ptr_values.push(ptr as usize);
}
src_slices.push(slice);
}
// Destination buffers
let mut dst_slices = Vec::with_capacity(num_pairs);
let mut dst_ptr_values = Vec::with_capacity(num_pairs);
for _ in 0..num_pairs {
let mut slice = unsafe { stream.alloc::<u8>(copy_size)? };
{
let (ptr, _guard) = slice.device_ptr_mut(&stream);
dst_ptr_values.push(ptr as usize);
}
dst_slices.push(slice);
}
// Upload pointer arrays to device
let src_ptrs = stream.clone_htod(&src_ptr_values)?;
let dst_ptrs = stream.clone_htod(&dst_ptr_values)?;
// Launch vectorized copy
{
let (src_ptrs_raw, _src_guard) = src_ptrs.device_ptr(&stream);
let (dst_ptrs_raw, _dst_guard) = dst_ptrs.device_ptr(&stream);
let status = unsafe {
super::vectorized_copy(
src_ptrs_raw as usize as *mut *mut c_void,
dst_ptrs_raw as usize as *mut *mut c_void,
copy_size,
num_pairs as i32,
stream_raw,
)
};
assert_eq!(status, cuda_runtime::cudaError::cudaSuccess);
}
// Wait for the kernel before reading the destinations back.
stream.synchronize()?;
// Verify results
for i in 0..num_pairs {
let result = stream.clone_dtoh(&dst_slices[i])?;
assert_eq!(result, expected_data[i], "Mismatch at pair {}", i);
}
Ok(())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for CUDA tensor packing kernel roundtrips.
//!
//! Mirrors the Python tests in `lib/bindings/kvbm/tests/test_tensor_kernels.py`
//! using ndarray for reference permutations and cudarc for GPU memory management.
#![cfg(all(
feature = "testing-cuda",
feature = "permute_kernels",
not(stub_kernels)
))]
use std::ffi::c_void;
use std::fmt::Debug;
use std::sync::Arc;
use cudarc::driver::result::memset_d8_async;
use cudarc::driver::{
CudaContext, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, DeviceRepr, DriverError,
ValidAsZeroBits,
};
use cudarc::runtime::sys as cuda_runtime;
use half::{bf16, f16};
use kvbm_kernels::{BlockLayout, TensorDataType, block_from_universal, universal_from_block};
use ndarray::{Array5, s};
use rand::Rng;
// ---------------------------------------------------------------------------
// TestDtype trait — bridges Rust types to kernel enums + tolerances
// ---------------------------------------------------------------------------
/// Element type that the kernel roundtrip tests can be instantiated with.
///
/// Ties a host-side Rust scalar to the `TensorDataType` tag handed to the
/// kernels and to the tolerances `assert_close` applies when comparing values.
trait TestDtype: Clone + Debug + DeviceRepr + ValidAsZeroBits + 'static {
/// Kernel-side dtype tag passed to the packing kernels for this element type.
const DTYPE: TensorDataType;
/// Absolute tolerance term used by `assert_close`.
const ATOL: f64;
/// Relative tolerance term used by `assert_close` (scaled by `|expected|`).
const RTOL: f64;
/// Convert an `f64` sample into this element type.
fn from_f64(v: f64) -> Self;
/// Widen this element to `f64` for comparison.
fn to_f64(self) -> f64;
}
impl TestDtype for f16 {
const DTYPE: TensorDataType = TensorDataType::F16;
const ATOL: f64 = 1e-2;
const RTOL: f64 = 1e-2;
fn from_f64(v: f64) -> Self {
f16::from_f64(v)
}
fn to_f64(self) -> f64 {
f16::to_f64(self)
}
}
impl TestDtype for bf16 {
const DTYPE: TensorDataType = TensorDataType::BF16;
const ATOL: f64 = 1e-2;
const RTOL: f64 = 1e-2;
fn from_f64(v: f64) -> Self {
bf16::from_f64(v)
}
fn to_f64(self) -> f64 {
bf16::to_f64(self)
}
}
/// `f32` elements, compared with a 1e-5 absolute/relative tolerance.
impl TestDtype for f32 {
    const DTYPE: TensorDataType = TensorDataType::F32;
    const ATOL: f64 = 1e-5;
    const RTOL: f64 = 1e-5;

    fn from_f64(v: f64) -> Self {
        // Narrowing conversion; rounds to the nearest representable f32.
        v as f32
    }

    fn to_f64(self) -> f64 {
        // Widening f32 -> f64 is lossless.
        f64::from(self)
    }
}
/// `f64` elements, compared with a 1e-12 absolute/relative tolerance.
/// Both conversions are the identity.
impl TestDtype for f64 {
    const DTYPE: TensorDataType = TensorDataType::F64;
    const ATOL: f64 = 1e-12;
    const RTOL: f64 = 1e-12;

    fn from_f64(v: f64) -> Self {
        v
    }

    fn to_f64(self) -> f64 {
        self
    }
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Host-side reference for the block layout, built with ndarray
/// (counterpart of the Python `_make_blocks()`).
///
/// `universal` has shape `[nh, nl, no, nt, hd]`. The result holds `nl * no`
/// flattened chunks, ordered layer-major then outer. Each chunk is the
/// `[nh, nt, hd]` sub-tensor for one `(layer, outer)` pair, flattened in
/// row-major order after the layout-specific axis permutation:
/// `NHD` -> `[nt, nh, hd]`, `HND` -> `[nh, nt, hd]` (unchanged).
fn make_blocks<T: TestDtype>(universal: &Array5<T>, layout: BlockLayout) -> Vec<Vec<T>> {
    let (_heads, layers, outers, _tokens, _head_dim) = universal.dim();
    let mut chunks = Vec::with_capacity(layers * outers);
    for layer in 0..layers {
        for outer in 0..outers {
            // View of shape [nh, nt, hd] for this (layer, outer) pair.
            let view = universal.slice(s![.., layer, outer, .., ..]);
            let ordered = match layout {
                // [nh, nt, hd] -> [nt, nh, hd]
                BlockLayout::NHD => view.permuted_axes([1, 0, 2]),
                // [nh, nt, hd] stays as-is.
                BlockLayout::HND => view,
            };
            // Iterating a view visits elements in logical (row-major) order,
            // which is exactly the standard-layout flattening.
            chunks.push(ordered.iter().cloned().collect());
        }
    }
    chunks
}
/// Assert that two slices agree element by element within the dtype's
/// tolerance, in the spirit of `torch.allclose`.
///
/// An element passes when `|actual - expected| <= ATOL + RTOL * |expected|`.
/// Panics with `context` in the message on a length mismatch or on the first
/// element outside tolerance.
fn assert_close<T: TestDtype>(actual: &[T], expected: &[T], context: &str) {
    assert_eq!(
        actual.len(),
        expected.len(),
        "{context}: length mismatch ({} vs {})",
        actual.len(),
        expected.len()
    );
    let pairs = actual.iter().cloned().zip(expected.iter().cloned());
    for (i, (got, want)) in pairs.enumerate() {
        let (a_f64, e_f64) = (got.to_f64(), want.to_f64());
        let tol = T::ATOL + T::RTOL * e_f64.abs();
        let diff = (a_f64 - e_f64).abs();
        // A NaN difference fails the comparison and therefore the assertion.
        let within_tolerance = diff <= tol;
        assert!(
            within_tolerance,
            "{context}[{i}]: {a_f64} vs {e_f64} (diff={diff}, tol={tol})"
        );
    }
}
/// Create a context on device 0 and hand back its default stream, both as a
/// cudarc handle and as the raw runtime stream pointer the kernel FFI takes.
///
/// Yields `None` when the device count cannot be queried, is zero, or the
/// context cannot be created, so callers can skip GPU tests.
fn cuda_setup() -> Option<(Arc<CudaStream>, cuda_runtime::cudaStream_t)> {
    match CudaContext::device_count() {
        Ok(0) | Err(_) => return None,
        Ok(_) => {}
    }
    let ctx = CudaContext::new(0).ok()?;
    let stream = ctx.default_stream();
    let raw = stream.cu_stream() as cuda_runtime::cudaStream_t;
    Some((stream, raw))
}
// ---------------------------------------------------------------------------
// GPU allocation helpers
// ---------------------------------------------------------------------------
/// Copy every host chunk in `ref_blocks` to the device.
///
/// Returns the device slices (grouped per batch, so the caller keeps them
/// alive) together with a device-resident table of their addresses, laid out
/// batch-major in the same order as `ref_blocks`, for use by the kernel FFI.
fn upload_blocks<T: TestDtype>(
    stream: &Arc<CudaStream>,
    ref_blocks: &[Vec<Vec<T>>],
) -> Result<(Vec<Vec<CudaSlice<T>>>, CudaSlice<usize>), DriverError> {
    let batches = ref_blocks.len();
    // Capacity hint only: assumes every batch has as many chunks as the first.
    let per_batch = ref_blocks.first().map_or(0, Vec::len);
    let mut table: Vec<usize> = Vec::with_capacity(batches * per_batch);
    let mut uploaded: Vec<Vec<CudaSlice<T>>> = Vec::with_capacity(batches);
    for host_batch in ref_blocks {
        let mut device_batch = Vec::with_capacity(host_batch.len());
        for host_chunk in host_batch {
            let device_chunk = stream.clone_htod(host_chunk)?;
            {
                // The guard must be released before the slice is moved.
                let (addr, _guard) = device_chunk.device_ptr(stream);
                table.push(addr as usize);
            }
            device_batch.push(device_chunk);
        }
        uploaded.push(device_batch);
    }
    let table_on_device = stream.clone_htod(table.as_slice())?;
    Ok((uploaded, table_on_device))
}
/// Allocate `count` device buffers of `volume` elements each and fill every
/// byte with the poison value 0xDE, so stale output is easy to spot.
///
/// Returns the buffers plus a device-resident table of their addresses.
fn alloc_buffers<T: TestDtype>(
    stream: &Arc<CudaStream>,
    count: usize,
    volume: usize,
) -> Result<(Vec<CudaSlice<T>>, CudaSlice<usize>), DriverError> {
    let bytes_per_buffer = volume * std::mem::size_of::<T>();
    let mut table: Vec<usize> = Vec::with_capacity(count);
    let mut buffers: Vec<CudaSlice<T>> = Vec::with_capacity(count);
    for _ in 0..count {
        // Uninitialised allocation; it is overwritten by the memset below.
        let mut buffer = unsafe { stream.alloc::<T>(volume)? };
        {
            let (addr, _guard) = buffer.device_ptr_mut(stream);
            table.push(addr as usize);
            unsafe {
                memset_d8_async(addr, 0xDE, bytes_per_buffer, stream.cu_stream())?;
            }
        }
        buffers.push(buffer);
    }
    let table_on_device = stream.clone_htod(table.as_slice())?;
    Ok((buffers, table_on_device))
}
/// Overwrite every chunk in `block_slices` with the poison byte 0xDE.
///
/// `chunk_volume` is the number of elements per chunk; the memset covers
/// `chunk_volume * size_of::<T>()` bytes of each slice.
fn poison_fill_blocks<T: TestDtype>(
    stream: &Arc<CudaStream>,
    block_slices: &mut [Vec<CudaSlice<T>>],
    chunk_volume: usize,
) -> Result<(), DriverError> {
    let bytes_per_chunk = chunk_volume * std::mem::size_of::<T>();
    for chunk in block_slices.iter_mut().flatten() {
        let (addr, _guard) = chunk.device_ptr_mut(stream);
        unsafe {
            memset_d8_async(addr, 0xDE, bytes_per_chunk, stream.cu_stream())?;
        }
    }
    Ok(())
}
// ---------------------------------------------------------------------------
// block <-> universal roundtrip
// ---------------------------------------------------------------------------
/// Roundtrip check for the block <-> universal packing kernels with element
/// type `T` and the given block `layout`.
///
/// Steps: build random universal tensors on the host, derive reference blocks
/// with `make_blocks`, upload them, run `universal_from_block` and compare to
/// the original tensors, then poison the device blocks, run
/// `block_from_universal` and compare to the reference blocks.
///
/// Returns `Ok(())` without doing anything when `cuda_setup` reports no GPU.
fn block_universal_roundtrip_inner<T: TestDtype>(layout: BlockLayout) -> Result<(), DriverError> {
let (stream, stream_raw) = match cuda_setup() {
Some(s) => s,
None => return Ok(()),
};
// Dimensions matching the Python test.
let nh = 3usize;
let nl = 2usize;
let no = 2usize;
let nt = 4usize;
let hd = 5usize;
let nb = 3usize;
let universal_volume = nh * nl * no * nt * hd;
// Generate random universal tensors (values in [-1, 1)) and compute reference blocks.
let mut rng = rand::rng();
let universals: Vec<Array5<T>> = (0..nb)
.map(|_| {
Array5::from_shape_fn((nh, nl, no, nt, hd), |_| {
T::from_f64(rng.random::<f64>() * 2.0 - 1.0)
})
})
.collect();
let ref_blocks: Vec<Vec<Vec<T>>> = universals.iter().map(|u| make_blocks(u, layout)).collect();
// Upload reference blocks to GPU.
let (mut block_slices, block_ptrs) = upload_blocks(&stream, &ref_blocks)?;
// Allocate universal output buffers on GPU (poison-filled by `alloc_buffers`).
let (universal_slices, universal_ptrs) = alloc_buffers::<T>(&stream, nb, universal_volume)?;
// --- Forward: blocks -> universal ---
{
// The guards stay alive for the duration of the launch.
let (bp, _g1) = block_ptrs.device_ptr(&stream);
let (up, _g2) = universal_ptrs.device_ptr(&stream);
let status = unsafe {
universal_from_block(
up as usize as *const *mut c_void,
bp as usize as *const *const c_void,
nb,
nh,
nl,
no,
nt,
hd,
T::DTYPE,
layout,
stream_raw,
)
};
assert_eq!(status, cuda_runtime::cudaError::cudaSuccess);
}
// Wait for the kernel before copying results back to the host.
stream.synchronize()?;
// Verify each universal buffer matches the original tensor.
for (i, (slice, expected)) in universal_slices.iter().zip(universals.iter()).enumerate() {
let host = stream.clone_dtoh(slice)?;
let expected_flat: Vec<T> = expected.as_standard_layout().as_slice().unwrap().to_vec();
assert_close::<T>(&host, &expected_flat, &format!("universal batch {i}"));
}
// --- Reverse: poison-fill blocks, then universal -> blocks ---
// Poisoning ensures the reverse kernel really rewrites the blocks rather
// than the check passing on the data uploaded earlier.
poison_fill_blocks(&stream, &mut block_slices, nh * nt * hd)?;
stream.synchronize()?;
{
let (bp, _g1) = block_ptrs.device_ptr(&stream);
let (up, _g2) = universal_ptrs.device_ptr(&stream);
let status = unsafe {
block_from_universal(
up as usize as *const *const c_void,
bp as usize as *const *mut c_void,
nb,
nh,
nl,
no,
nt,
hd,
T::DTYPE,
layout,
stream_raw,
)
};
assert_eq!(status, cuda_runtime::cudaError::cudaSuccess);
}
stream.synchronize()?;
// Every device chunk must now match its host reference chunk again.
for (bi, (batch, ref_batch)) in block_slices.iter().zip(ref_blocks.iter()).enumerate() {
for (ci, (slice, expected)) in batch.iter().zip(ref_batch.iter()).enumerate() {
let host = stream.clone_dtoh(slice)?;
assert_close::<T>(&host, expected, &format!("block batch {bi} chunk {ci}"));
}
}
Ok(())
}
/// Stamp out one `#[test]` that runs `block_universal_roundtrip_inner` for a
/// given element type and block layout.
macro_rules! block_universal_test {
    ($test_fn:ident, $elem:ty, $block_layout:expr) => {
        #[test]
        fn $test_fn() -> Result<(), DriverError> {
            block_universal_roundtrip_inner::<$elem>($block_layout)
        }
    };
}

// One test per (layout, dtype) combination.
block_universal_test!(block_universal_roundtrip_nhd_f16, f16, BlockLayout::NHD);
block_universal_test!(block_universal_roundtrip_nhd_bf16, bf16, BlockLayout::NHD);
block_universal_test!(block_universal_roundtrip_nhd_f32, f32, BlockLayout::NHD);
block_universal_test!(block_universal_roundtrip_nhd_f64, f64, BlockLayout::NHD);
block_universal_test!(block_universal_roundtrip_hnd_f16, f16, BlockLayout::HND);
block_universal_test!(block_universal_roundtrip_hnd_bf16, bf16, BlockLayout::HND);
block_universal_test!(block_universal_roundtrip_hnd_f32, f32, BlockLayout::HND);
block_universal_test!(block_universal_roundtrip_hnd_f64, f64, BlockLayout::HND);
// ---------------------------------------------------------------------------
// Edge cases
// ---------------------------------------------------------------------------
/// With `num_blocks == 0` both packing entry points must return
/// `cudaSuccess` even though the pointer tables are null.
#[test]
fn empty_batch_noop() -> Result<(), DriverError> {
    let Some((_stream, stream_raw)) = cuda_setup() else {
        return Ok(());
    };
    let null_mut = std::ptr::null::<*mut c_void>();
    let null_const = std::ptr::null::<*const c_void>();

    // universal_from_block
    let forward = unsafe {
        universal_from_block(
            null_mut,
            null_const,
            0,
            1,
            1,
            1,
            1,
            1,
            TensorDataType::F32,
            BlockLayout::NHD,
            stream_raw,
        )
    };
    assert_eq!(forward, cuda_runtime::cudaError::cudaSuccess);

    // block_from_universal
    let reverse = unsafe {
        block_from_universal(
            null_const,
            null_mut,
            0,
            1,
            1,
            1,
            1,
            1,
            TensorDataType::F32,
            BlockLayout::NHD,
            stream_raw,
        )
    };
    assert_eq!(reverse, cuda_runtime::cudaError::cudaSuccess);
    Ok(())
}
// ---------------------------------------------------------------------------
// CPU-only validation of make_blocks reference implementation
// ---------------------------------------------------------------------------
/// CPU-only check of `make_blocks` for the NHD layout.
///
/// Every element of the universal tensor stores its own row-major linear
/// index, so each block position can be verified from index arithmetic alone.
#[test]
fn make_blocks_reference_nhd() {
    let (nh, nl, no, nt, hd) = (3usize, 2usize, 2usize, 4usize, 5usize);
    // Row-major linear index into the [nh, nl, no, nt, hd] universal tensor.
    let encode = |h: usize, l: usize, o: usize, t: usize, d: usize| -> f32 {
        ((((h * nl + l) * no + o) * nt + t) * hd + d) as f32
    };
    let universal =
        Array5::from_shape_fn((nh, nl, no, nt, hd), |(h, l, o, t, d)| encode(h, l, o, t, d));
    let blocks = make_blocks(&universal, BlockLayout::NHD);
    assert_eq!(blocks.len(), nl * no);
    for (block_idx, block) in blocks.iter().enumerate() {
        // Blocks are ordered layer-major, then outer.
        let (nl_i, no_i) = (block_idx / no, block_idx % no);
        assert_eq!(block.len(), nt * nh * hd);
        for (offset, &actual) in block.iter().enumerate() {
            // NHD block offset: [nt, nh, hd]
            let hd_i = offset % hd;
            let nh_i = (offset / hd) % nh;
            let nt_i = offset / (hd * nh);
            let expected = encode(nh_i, nl_i, no_i, nt_i, hd_i);
            assert_eq!(
                actual, expected,
                "NHD mismatch at nl={nl_i} no={no_i} nt={nt_i} nh={nh_i} hd={hd_i}"
            );
        }
    }
}
/// CPU-only check of `make_blocks` for the HND layout.
///
/// Every element of the universal tensor stores its own row-major linear
/// index, so each block position can be verified from index arithmetic alone.
#[test]
fn make_blocks_reference_hnd() {
    let (nh, nl, no, nt, hd) = (3usize, 2usize, 2usize, 4usize, 5usize);
    // Row-major linear index into the [nh, nl, no, nt, hd] universal tensor.
    let encode = |h: usize, l: usize, o: usize, t: usize, d: usize| -> f32 {
        ((((h * nl + l) * no + o) * nt + t) * hd + d) as f32
    };
    let universal =
        Array5::from_shape_fn((nh, nl, no, nt, hd), |(h, l, o, t, d)| encode(h, l, o, t, d));
    let blocks = make_blocks(&universal, BlockLayout::HND);
    assert_eq!(blocks.len(), nl * no);
    for (block_idx, block) in blocks.iter().enumerate() {
        // Blocks are ordered layer-major, then outer.
        let (nl_i, no_i) = (block_idx / no, block_idx % no);
        assert_eq!(block.len(), nh * nt * hd);
        for (offset, &actual) in block.iter().enumerate() {
            // HND block offset: [nh, nt, hd]
            let hd_i = offset % hd;
            let nt_i = (offset / hd) % nt;
            let nh_i = offset / (hd * nt);
            let expected = encode(nh_i, nl_i, no_i, nt_i, hd_i);
            assert_eq!(
                actual, expected,
                "HND mismatch at nl={nl_i} no={no_i} nh={nh_i} nt={nt_i} hd={hd_i}"
            );
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for `memcpy_batch` and the always-available query helpers
//! (`is_memcpy_batch_available`, `is_using_stubs`).
//!
//! These don't require `permute_kernels` — the functions are unconditionally
//! linked regardless of feature flags.
//!
//! Functional tests use pinned-host -> device -> pinned-host roundtrips (H2D + D2H)
//! to match the transfer patterns that `cudaMemcpyBatchAsync` is designed for.
#![cfg(all(feature = "testing-cuda", not(stub_kernels)))]
use std::ffi::c_void;
use std::sync::Arc;
use cudarc::driver::{CudaContext, CudaSlice, CudaStream, DevicePtr, DriverError};
use cudarc::runtime::sys as cuda_runtime;
use kvbm_kernels::{MemcpyBatchMode, is_memcpy_batch_available, is_using_stubs, memcpy_batch};
// Direct FFI for cudaMallocHost / cudaFreeHost.
// We bypass cudarc's runtime::sys because cudarc eagerly resolves ALL runtime
// symbols on first use, and CUDA 13.x removed `cudaGetDeviceProperties_v2`
// which causes a panic. Our test binary links against libcudart directly
// (through kvbm-kernels' build.rs), so these symbols are always available.
unsafe extern "C" {
/// Allocates `size` bytes of pinned host memory and stores the address in `*ptr`.
/// Returns a CUDA runtime status code; callers in this file treat 0 as success.
fn cudaMallocHost(ptr: *mut *mut c_void, size: usize) -> u32;
/// Releases memory previously obtained from `cudaMallocHost`.
/// Returns a CUDA runtime status code.
fn cudaFreeHost(ptr: *mut c_void) -> u32;
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Create a context on device 0 and a fresh (non-default) stream on it.
///
/// Returns the cudarc stream handle plus the same stream as a raw runtime
/// pointer, or `None` when no device is present or any setup step fails.
fn cuda_setup() -> Option<(Arc<CudaStream>, cuda_runtime::cudaStream_t)> {
    match CudaContext::device_count() {
        Ok(0) | Err(_) => return None,
        Ok(_) => {}
    }
    let ctx = CudaContext::new(0).ok()?;
    // cudaMemcpyBatchAsync does not accept the NULL (default) stream, so a
    // dedicated stream is created; a real CUstream from the driver API works.
    let stream = ctx.new_stream().ok()?;
    let raw = stream.cu_stream() as cuda_runtime::cudaStream_t;
    Some((stream, raw))
}
/// Allocate a zero-initialised device buffer of `len` bytes.
///
/// Returns the owning slice together with its raw device address as `usize`.
fn alloc_device_zeroed(
    stream: &Arc<CudaStream>,
    len: usize,
) -> Result<(CudaSlice<u8>, usize), DriverError> {
    let buffer = stream.alloc_zeros::<u8>(len)?;
    let (raw_ptr, guard) = buffer.device_ptr(stream);
    let addr = raw_ptr as usize;
    // Release the pointer guard before the slice is moved into the result.
    drop(guard);
    Ok((buffer, addr))
}
/// RAII wrapper around pinned host memory allocated with `cudaMallocHost`.
///
/// The allocation is released with `cudaFreeHost` when the value is dropped.
/// Zero-length buffers are supported: no raw-pointer access is performed for
/// them, so it does not matter whether the runtime returned a null pointer.
struct PinnedBuffer {
    /// Base address returned by `cudaMallocHost`.
    ptr: *mut c_void,
    /// Size of the allocation in bytes.
    len: usize,
}

impl PinnedBuffer {
    /// Allocate `len` bytes of pinned host memory, zeroed.
    ///
    /// # Panics
    /// Panics if `cudaMallocHost` reports an error, or reports success but
    /// yields a null pointer for a non-empty request.
    fn new_zeroed(len: usize) -> Self {
        let mut ptr: *mut c_void = std::ptr::null_mut();
        // SAFETY: `ptr` is a valid out-location for the duration of the call.
        let err = unsafe { cudaMallocHost(&mut ptr, len) };
        assert_eq!(err, 0, "cudaMallocHost failed with error {err}");
        if len > 0 {
            assert!(
                !ptr.is_null(),
                "cudaMallocHost reported success but returned a null pointer"
            );
            // SAFETY: `ptr` is non-null and points to `len` writable bytes
            // that were just allocated above.
            unsafe { std::ptr::write_bytes(ptr as *mut u8, 0, len) };
        }
        Self { ptr, len }
    }

    /// Allocate pinned host memory and fill it from `data`.
    fn from_data(data: &[u8]) -> Self {
        let buf = Self::new_zeroed(data.len());
        if !data.is_empty() {
            // SAFETY: `buf.ptr` is non-null (checked in `new_zeroed`) and
            // holds `data.len()` bytes; a fresh allocation cannot overlap
            // `data`.
            unsafe {
                std::ptr::copy_nonoverlapping(data.as_ptr(), buf.ptr as *mut u8, data.len());
            }
        }
        buf
    }

    /// Mutable raw pointer to the start of the buffer.
    fn as_ptr(&self) -> *mut c_void {
        self.ptr
    }

    /// Const raw pointer to the start of the buffer.
    fn as_const_ptr(&self) -> *const c_void {
        self.ptr as *const c_void
    }

    /// Read contents back as a `Vec<u8>`.
    ///
    /// The caller is responsible for making sure no asynchronous copy is
    /// still writing to the buffer.
    fn to_vec(&self) -> Vec<u8> {
        if self.len == 0 {
            return Vec::new();
        }
        // SAFETY: for `len > 0` the pointer is non-null and refers to `len`
        // initialised bytes owned by this buffer.
        unsafe { std::slice::from_raw_parts(self.ptr as *const u8, self.len) }.to_vec()
    }
}

impl Drop for PinnedBuffer {
    fn drop(&mut self) {
        if !self.ptr.is_null() {
            // SAFETY: `ptr` came from `cudaMallocHost` and is freed exactly
            // once. The status code is ignored: there is no way to recover
            // from a failed free inside `drop`.
            unsafe {
                cudaFreeHost(self.ptr);
            }
        }
    }
}
/// Run a pinned-host -> device -> pinned-host roundtrip via two `memcpy_batch` calls.
///
/// 1. Batch H2D: copy from `src_pinned` buffers to `device` buffers
/// 2. Batch D2H: copy from `device` buffers to `dst_pinned` (zeroed) buffers
/// 3. Assert `dst_pinned` contents match original data
///
/// Every entry of `data_sets` must be exactly `copy_size` bytes long, because
/// `copy_size` bytes are read from each pinned source buffer.
///
/// Returns both batch statuses for the caller to inspect. If the H2D batch
/// fails, the D2H batch is not attempted and its status is reported as
/// `cudaSuccess`. Stream or allocation failures surface as `Err(DriverError)`.
///
/// # Panics
/// Panics if a data set's length differs from `copy_size`, or if both
/// batches succeed and the roundtripped bytes differ from the input.
fn h2d_d2h_roundtrip(
    stream: &Arc<CudaStream>,
    raw: cuda_runtime::cudaStream_t,
    data_sets: &[Vec<u8>],
    copy_size: usize,
    mode: MemcpyBatchMode,
) -> Result<(cuda_runtime::cudaError, cuda_runtime::cudaError), DriverError> {
    let num_pairs = data_sets.len();

    // A shorter data set would make the H2D copy read past the end of its
    // pinned source buffer, so reject mismatches up front.
    for (i, data) in data_sets.iter().enumerate() {
        assert_eq!(
            data.len(),
            copy_size,
            "data set {i} has {} bytes but copy_size is {copy_size}",
            data.len()
        );
    }

    // Source: pinned host buffers filled with known data
    let src_pinned: Vec<PinnedBuffer> = data_sets
        .iter()
        .map(|d| PinnedBuffer::from_data(d))
        .collect();

    // Device buffers (zeroed)
    let mut dev_slices = Vec::with_capacity(num_pairs);
    let mut dev_addrs = Vec::with_capacity(num_pairs);
    for _ in 0..num_pairs {
        let (s, a) = alloc_device_zeroed(stream, copy_size)?;
        dev_slices.push(s);
        dev_addrs.push(a);
    }

    // Destination: zeroed pinned host buffers
    let dst_pinned: Vec<PinnedBuffer> = (0..num_pairs)
        .map(|_| PinnedBuffer::new_zeroed(copy_size))
        .collect();

    // Build pointer arrays for H2D: src = pinned host, dst = device
    let h2d_src_ptrs: Vec<*const c_void> = src_pinned.iter().map(|b| b.as_const_ptr()).collect();
    let h2d_dst_ptrs: Vec<*mut c_void> = dev_addrs.iter().map(|&a| a as *mut c_void).collect();
    let h2d_status = unsafe {
        memcpy_batch(
            h2d_src_ptrs.as_ptr() as *const *const c_void,
            h2d_dst_ptrs.as_ptr() as *const *mut c_void,
            copy_size,
            num_pairs,
            mode,
            raw,
        )
    };
    if h2d_status != cuda_runtime::cudaError::cudaSuccess {
        // Best-effort drain: anything enqueued before the failure must finish
        // before the buffers above are dropped. The reported statuses are
        // what the caller cares about, so a sync error is not propagated.
        let _ = stream.synchronize();
        return Ok((h2d_status, cuda_runtime::cudaError::cudaSuccess));
    }

    // Build pointer arrays for D2H: src = device, dst = pinned host
    let d2h_src_ptrs: Vec<*const c_void> = dev_addrs.iter().map(|&a| a as *const c_void).collect();
    let d2h_dst_ptrs: Vec<*mut c_void> = dst_pinned.iter().map(|b| b.as_ptr()).collect();
    let d2h_status = unsafe {
        memcpy_batch(
            d2h_src_ptrs.as_ptr() as *const *const c_void,
            d2h_dst_ptrs.as_ptr() as *const *mut c_void,
            copy_size,
            num_pairs,
            mode,
            raw,
        )
    };
    if d2h_status != cuda_runtime::cudaError::cudaSuccess {
        // The H2D copies were enqueued successfully and may still be reading
        // from `src_pinned`; drain them before the buffers are freed.
        let _ = stream.synchronize();
        return Ok((h2d_status, d2h_status));
    }

    // Synchronize before reading back
    stream.synchronize()?;

    // Verify roundtrip: dst_pinned should match original data
    for (i, (dst, expected)) in dst_pinned.iter().zip(data_sets.iter()).enumerate() {
        let result = dst.to_vec();
        assert_eq!(
            result,
            *expected,
            "roundtrip mismatch at pair {i}: first differing byte at position {}",
            result
                .iter()
                .zip(expected.iter())
                .position(|(a, b)| a != b)
                .unwrap_or(result.len())
        );
    }
    Ok((h2d_status, d2h_status))
}
// ---------------------------------------------------------------------------
// Query function tests
// ---------------------------------------------------------------------------
/// Real (non-stub) builds must report that stub kernels are not in use.
#[test]
fn stubs_not_active() {
// Since the file is gated on not(stub_kernels), this must be false.
assert!(!is_using_stubs());
}
/// `is_memcpy_batch_available` must not crash and must give the same answer
/// when queried twice. The result is also logged for diagnostics.
#[test]
fn availability_is_consistent() {
    let first = is_memcpy_batch_available();
    let second = is_memcpy_batch_available();
    assert_eq!(first, second);
    let comparison = if first { ">=" } else { "<" };
    eprintln!(
        "cudaMemcpyBatchAsync available: {} (CUDA {}12.9)",
        first, comparison
    );
}
// ---------------------------------------------------------------------------
// memcpy_batch edge cases (work regardless of CUDA version)
// ---------------------------------------------------------------------------
/// `memcpy_batch` with `num_copies == 0` must return `cudaSuccess` in every
/// mode, even with null pointer tables.
#[test]
fn memcpy_batch_zero_copies_noop() -> Result<(), DriverError> {
    let Some((_stream, raw)) = cuda_setup() else {
        return Ok(());
    };
    // All modes should treat zero copies as a no-op
    let modes = [
        MemcpyBatchMode::BatchedWithFallback,
        MemcpyBatchMode::FallbackOnly,
        MemcpyBatchMode::BatchWithoutFallback,
    ];
    for mode in modes {
        // size_per_copy = 128, num_copies = 0
        let status = unsafe { memcpy_batch(std::ptr::null(), std::ptr::null(), 128, 0, mode, raw) };
        assert_eq!(
            status,
            cuda_runtime::cudaError::cudaSuccess,
            "mode={mode:?}"
        );
    }
    Ok(())
}
/// `memcpy_batch` with `size_per_copy == 0` must return `cudaSuccess` in
/// every mode, even with null pointer tables.
#[test]
fn memcpy_batch_zero_size_noop() -> Result<(), DriverError> {
    let Some((_stream, raw)) = cuda_setup() else {
        return Ok(());
    };
    // All modes should treat zero size as a no-op
    let modes = [
        MemcpyBatchMode::BatchedWithFallback,
        MemcpyBatchMode::FallbackOnly,
        MemcpyBatchMode::BatchWithoutFallback,
    ];
    for mode in modes {
        // size_per_copy = 0, num_copies = 5
        let status = unsafe { memcpy_batch(std::ptr::null(), std::ptr::null(), 0, 5, mode, raw) };
        assert_eq!(
            status,
            cuda_runtime::cudaError::cudaSuccess,
            "mode={mode:?}"
        );
    }
    Ok(())
}
// ---------------------------------------------------------------------------
// memcpy_batch functional tests — H2D + D2H roundtrip with pinned memory
//
// Each test runs across all three MemcpyBatchMode variants:
// - BatchedWithFallback: always works (batch or fallback)
// - FallbackOnly: always works (individual cudaMemcpyAsync)
// - BatchWithoutFallback: only works when batch API is available (CUDA 12.9+)
// ---------------------------------------------------------------------------
/// Drive `h2d_d2h_roundtrip` once per `MemcpyBatchMode`.
///
/// Both statuses must be `cudaSuccess` for every mode, with one exception:
/// `BatchWithoutFallback` is only logged, not asserted, when
/// `is_memcpy_batch_available()` is false.
fn run_all_modes(
    stream: &Arc<CudaStream>,
    raw: cuda_runtime::cudaStream_t,
    data_sets: &[Vec<u8>],
    copy_size: usize,
) -> Result<(), DriverError> {
    let batch_available = is_memcpy_batch_available();
    let modes = [
        MemcpyBatchMode::BatchedWithFallback,
        MemcpyBatchMode::FallbackOnly,
        MemcpyBatchMode::BatchWithoutFallback,
    ];
    for mode in modes {
        let (h2d, d2h) = h2d_d2h_roundtrip(stream, raw, data_sets, copy_size, mode)?;
        let failure_tolerated =
            !batch_available && mode == MemcpyBatchMode::BatchWithoutFallback;
        if failure_tolerated {
            // Expected to fail when batch API is not available
            eprintln!(" {mode:?}: batch API unavailable, got h2d={h2d:?} (expected non-success)");
            continue;
        }
        let success = cuda_runtime::cudaError::cudaSuccess;
        assert_eq!(h2d, success, "H2D failed with mode={mode:?}");
        assert_eq!(d2h, success, "D2H failed with mode={mode:?}");
        eprintln!(" {mode:?}: OK");
    }
    Ok(())
}
/// One 256-byte pair, roundtripped H2D + D2H through `memcpy_batch` in every mode.
#[test]
fn memcpy_batch_single_copy() -> Result<(), DriverError> {
    let Some((stream, raw)) = cuda_setup() else {
        return Ok(());
    };
    let copy_size = 256;
    // Byte pattern 0, 1, ..., 255.
    let payload: Vec<u8> = (0..copy_size).map(|i| (i % 256) as u8).collect();
    run_all_modes(&stream, raw, &[payload], copy_size)
}
/// Eight independent 512-byte pairs submitted in a single batch call (all modes).
#[test]
fn memcpy_batch_multiple_copies() -> Result<(), DriverError> {
    let Some((stream, raw)) = cuda_setup() else {
        return Ok(());
    };
    let num_pairs = 8;
    let copy_size = 512;
    // Each pair gets its own byte pattern so swapped pairs would be detected.
    let mut data_sets: Vec<Vec<u8>> = Vec::with_capacity(num_pairs);
    for i in 0..num_pairs {
        let pattern = (0..copy_size).map(|j| ((i * 31 + j * 7) % 256) as u8);
        data_sets.push(pattern.collect());
    }
    run_all_modes(&stream, raw, &data_sets, copy_size)
}
/// Three pairs of 1 MiB each, to exercise large transfers (all modes).
#[test]
fn memcpy_batch_large_copy() -> Result<(), DriverError> {
    let Some((stream, raw)) = cuda_setup() else {
        return Ok(());
    };
    let copy_size = 1 << 20; // 1 MiB
    let num_pairs = 3;
    let mut data_sets: Vec<Vec<u8>> = Vec::with_capacity(num_pairs);
    for i in 0..num_pairs {
        // The modulus 251 keeps the pattern out of phase with 256-byte strides.
        let pattern = (0..copy_size).map(|j| ((i + j) % 251) as u8);
        data_sets.push(pattern.collect());
    }
    run_all_modes(&stream, raw, &data_sets, copy_size)
}
/// Four pairs of 999 bytes: a size that is not a power of two, guarding
/// against hidden alignment assumptions (all modes).
#[test]
fn memcpy_batch_odd_size() -> Result<(), DriverError> {
    let Some((stream, raw)) = cuda_setup() else {
        return Ok(());
    };
    let copy_size = 999; // not aligned to anything useful
    let num_pairs = 4;
    let mut data_sets: Vec<Vec<u8>> = Vec::with_capacity(num_pairs);
    for i in 0..num_pairs {
        let pattern = (0..copy_size).map(|j| ((i * 13 + j) % 256) as u8);
        data_sets.push(pattern.collect());
    }
    run_all_modes(&stream, raw, &data_sets, copy_size)
}
/// 256 pairs of 64 bytes each, stressing the per-pair dispatch path (all modes).
#[test]
fn memcpy_batch_many_pairs() -> Result<(), DriverError> {
    let Some((stream, raw)) = cuda_setup() else {
        return Ok(());
    };
    let num_pairs = 256;
    let copy_size = 64;
    let mut data_sets: Vec<Vec<u8>> = Vec::with_capacity(num_pairs);
    for i in 0..num_pairs {
        let pattern = (0..copy_size).map(|j| ((i + j) % 256) as u8);
        data_sets.push(pattern.collect());
    }
    run_all_modes(&stream, raw, &data_sets, copy_size)
}
// ---------------------------------------------------------------------------
// Diagnostic: mirrors NVIDIA benchmark calling pattern exactly
// ---------------------------------------------------------------------------
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Tests that only compile when stub kernels are in use (no CUDA available).
//!
//! Complementary to the `memcpy_batch::stubs_not_active` test which asserts
//! the opposite under `not(stub_kernels)`.
#![cfg(stub_kernels)]
use kvbm_kernels::{is_memcpy_batch_available, is_using_stubs};
/// Under a stub build the crate must report that stubs are in use.
#[test]
fn stubs_active() {
    let using_stubs = is_using_stubs();
    assert!(
        using_stubs,
        "expected is_using_stubs() == true under stub build"
    );
}
/// Under a stub build the batched memcpy API must be reported as unavailable.
#[test]
fn memcpy_batch_unavailable_under_stubs() {
    let batch_available = is_memcpy_batch_available();
    assert!(
        !batch_available,
        "expected is_memcpy_batch_available() == false under stub build"
    );
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment