Unverified Commit 09bb1c68 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: kvbm kernels (#4356)


Signed-off-by: default avatarRyan Olson <rolson@nvidia.com>
parent adbe133c
......@@ -42,7 +42,7 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.CI_TOKEN }}
run: |
./container/build.sh --tag ${{ steps.define_image_tag.outputs.image_tag }} --target dev --framework none
./container/build.sh --tag ${{ steps.define_image_tag.outputs.image_tag }} --target dev --framework none --enable-kvbm
- name: Start services with docker-compose
working-directory: ./deploy
run: |
......
......@@ -92,7 +92,7 @@ jobs:
runs-on:
group: Fastchecker
strategy:
matrix: { dir: ['.', 'lib/bindings/python', 'lib/runtime/examples', 'launch/dynamo-run'] }
matrix: { dir: ['.', 'lib/bindings/python', 'lib/runtime/examples', 'launch/dynamo-run', 'lib/kvbm-kernels'] }
permissions:
contents: read
steps:
......@@ -117,7 +117,10 @@ jobs:
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: ${{ runner.os }}-cargo-
- name: Set up Rust Toolchain Components
run: rustup component add clippy
run: rustup component add clippy rustfmt
- name: Verify Code Formatting
working-directory: ${{ matrix.dir }}
run: cargo fmt -- --check
- name: Run Clippy Checks
working-directory: ${{ matrix.dir }}
run: cargo clippy --no-deps --all-targets -- -D warnings
......
......@@ -2063,7 +2063,7 @@ dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -2208,6 +2208,17 @@ dependencies = [
"tracing",
]
[[package]]
name = "dynamo-kvbm-kernels"
version = "0.7.0"
dependencies = [
"cc",
"cudarc 0.17.8",
"dynamo-config",
"md5",
"once_cell",
]
[[package]]
name = "dynamo-llm"
version = "0.7.0"
......@@ -2630,7 +2641,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -4184,7 +4195,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
dependencies = [
"hermit-abi 0.5.2",
"libc",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -4265,7 +4276,7 @@ dependencies = [
"portable-atomic",
"portable-atomic-util",
"serde_core",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -4899,6 +4910,12 @@ dependencies = [
"rayon",
]
[[package]]
name = "md5"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0"
[[package]]
name = "memchr"
version = "2.7.6"
......@@ -5621,7 +5638,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -6750,7 +6767,7 @@ dependencies = [
"once_cell",
"socket2 0.6.1",
"tracing",
"windows-sys 0.59.0",
"windows-sys 0.60.2",
]
[[package]]
......@@ -7431,7 +7448,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -8700,7 +8717,7 @@ dependencies = [
"getrandom 0.3.4",
"once_cell",
"rustix",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -10253,7 +10270,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.48.0",
"windows-sys 0.61.2",
]
[[package]]
......
......@@ -15,6 +15,7 @@ members = [
"lib/bindings/python/codegen",
"lib/engines/*",
"lib/config",
"lib/kvbm-kernels",
]
# Exclude certain packages that are slow to build and we don't ship as flagship
# features from default build, but keep them in workspace for convenience.
......
......@@ -1145,15 +1145,6 @@ dependencies = [
"typenum",
]
[[package]]
name = "cudarc"
version = "0.16.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17200eb07e7d85a243aa1bf4569a7aa998385ba98d14833973a817a63cc86e92"
dependencies = [
"libloading",
]
[[package]]
name = "cudarc"
version = "0.17.8"
......@@ -1473,7 +1464,7 @@ dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -1580,6 +1571,24 @@ dependencies = [
"uuid",
]
[[package]]
name = "dynamo-config"
version = "0.7.0"
dependencies = [
"anyhow",
]
[[package]]
name = "dynamo-kvbm-kernels"
version = "0.7.0"
dependencies = [
"cc",
"cudarc",
"dynamo-config",
"md5",
"once_cell",
]
[[package]]
name = "dynamo-llm"
version = "0.7.0"
......@@ -1604,7 +1613,7 @@ dependencies = [
"bytes",
"candle-core",
"chrono",
"cudarc 0.17.8",
"cudarc",
"dashmap 5.5.3",
"derive-getters",
"derive_builder",
......@@ -1890,7 +1899,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -2825,7 +2834,7 @@ dependencies = [
"libc",
"percent-encoding",
"pin-project-lite",
"socket2 0.5.10",
"socket2 0.6.1",
"system-configuration",
"tokio",
"tower-service",
......@@ -3215,7 +3224,7 @@ dependencies = [
"portable-atomic",
"portable-atomic-util",
"serde_core",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -3488,9 +3497,10 @@ dependencies = [
"anyhow",
"async-stream",
"async-trait",
"cudarc 0.16.6",
"cudarc",
"derive-getters",
"dlpark",
"dynamo-kvbm-kernels",
"dynamo-llm",
"dynamo-runtime",
"either",
......@@ -3748,6 +3758,12 @@ dependencies = [
"rayon",
]
[[package]]
name = "md5"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0"
[[package]]
name = "memchr"
version = "2.7.6"
......@@ -4170,7 +4186,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -5322,7 +5338,7 @@ dependencies = [
"quinn-udp",
"rustc-hash 2.1.1",
"rustls",
"socket2 0.5.10",
"socket2 0.6.1",
"thiserror 2.0.17",
"tokio",
"tracing",
......@@ -5359,9 +5375,9 @@ dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2 0.5.10",
"socket2 0.6.1",
"tracing",
"windows-sys 0.52.0",
"windows-sys 0.60.2",
]
[[package]]
......@@ -5864,7 +5880,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -6660,7 +6676,7 @@ dependencies = [
"getrandom 0.3.4",
"once_cell",
"rustix",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -7840,7 +7856,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.48.0",
"windows-sys 0.61.2",
]
[[package]]
......
......@@ -21,11 +21,12 @@ crate-type = ["cdylib", "rlib"]
[features]
default = ["block-manager"]
block-manager = ["dynamo-llm/block-manager", "dep:dlpark", "dep:cudarc"]
block-manager = ["dynamo-llm/block-manager", "dep:dlpark", "dep:cudarc", "dep:kvbm_kernels"]
[dependencies]
dynamo-llm = { path = "../../llm" }
dynamo-runtime = { path = "../../runtime" }
kvbm_kernels = { path = "../../kvbm-kernels", package = "dynamo-kvbm-kernels", optional = true }
anyhow = { version = "1" }
async-stream = { version = "0.3" }
......@@ -68,7 +69,7 @@ pyo3-async-runtimes = { version = "0.23.0", default-features = false, features =
pythonize = "0.23"
dlpark = { version = "0.5", features = ["pyo3", "half"], optional = true }
cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true }
cudarc = { version = "0.17.1", features = ["cuda-12020"], optional = true }
prometheus = "0.14.0"
[dev-dependencies]
......
......@@ -26,7 +26,8 @@ license = { text = "Apache-2.0" }
license-files = ["LICENSE"]
requires-python = ">=3.10"
dependencies = [
"nixl==0.7.0"
"nixl==0.7.0",
"pydantic>=2.0",
]
classifiers = [
"Development Status :: 4 - Beta",
......@@ -43,6 +44,18 @@ classifiers = [
]
keywords = ["llm", "genai", "inference", "nvidia", "kvcache", "dynamo"]
[project.optional-dependencies]
test = [
"pytest>=8.3.4",
"pytest-mypy",
"pytest-asyncio",
]
dev = [
"kvbm[test]",
"maturin>=1.0,<2.0",
"patchelf",
]
[tool.maturin]
module-name = "kvbm._core"
manifest-path = "Cargo.toml"
......
......@@ -6,3 +6,4 @@
from kvbm._core import BlockManager as BlockManager
from kvbm._core import KvbmLeader as KvbmLeader
from kvbm._core import KvbmWorker as KvbmWorker
from kvbm._core import kernels as kernels
This diff is collapsed.
......@@ -13,7 +13,10 @@ use dynamo_runtime::{self as rs, RuntimeConfig, logging, traits::DistributedRunt
use dynamo_llm::{self as llm_rs};
#[cfg(feature = "block-manager")]
mod block_manager;
#[cfg(feature = "block-manager")]
mod kernels;
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
......@@ -39,6 +42,13 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
#[cfg(feature = "block-manager")]
block_manager::add_to_module(m)?;
#[cfg(feature = "block-manager")]
{
let kernels = PyModule::new(m.py(), "kernels")?;
kernels::add_to_module(&kernels)?;
m.add_submodule(&kernels)?;
}
Ok(())
}
......
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
PyTorch-based regression tests for the CUDA tensor packing kernels.
The goal is to mirror how an ML engineer would use the library, so the tests
act as both verification and documentation.
"""
from typing import List
import pytest
import torch
from kvbm import kernels as ctk
def _tolerances(dtype: torch.dtype) -> tuple[float, float]:
"""
Relax tolerances for low-precision dtypes.
fp16/bf16 round differently from fp32/fp64. Using dtype-aware tolerances
avoids spurious failures while still guarding against layout mistakes.
"""
if dtype in (torch.float16, torch.bfloat16):
return 1e-2, 1e-2
return 1e-5, 1e-5
def _make_blocks(universal: torch.Tensor, layout: str) -> List[torch.Tensor]:
"""
Reference implementation for turning a universal tensor into its block stack.
`layout` controls the per-chunk permutation:
- "NHD": expect `[nh, nl, no, nt, hd] -> [nt, nh, hd]`
- "HND": expect `[nh, nl, no, nt, hd] -> [nh, nt, hd]`
"""
nh, nl, no, nt, hd = universal.shape
blocks = []
for layer in range(nl):
for outer in range(no):
slice_ = universal[:, layer, outer, :, :].contiguous()
if layout.upper() == "NHD":
block = slice_.permute(1, 0, 2).contiguous()
elif layout.upper() == "HND":
block = slice_.contiguous()
else:
raise ValueError(f"Unsupported layout {layout}")
blocks.append(block.clone())
return blocks
def _call_with_backend(func, backend: str, *args):
"""
Helper to invoke a binding with a backend override, translating
unsupported backends into pytest skips instead of hard failures.
"""
try:
if backend is None:
func(*args)
else:
func(*args, backend=backend)
except RuntimeError as err:
if "cudaErrorNotSupported" in str(err):
pytest.skip(f"{backend} backend not supported on this runtime")
raise
@pytest.mark.parametrize("layout", ["NHD", "HND"])
@pytest.mark.parametrize(
    "dtype",
    [torch.float16, torch.bfloat16, torch.float32, torch.float64],
)
def test_block_universal_roundtrip(layout: str, dtype: torch.dtype) -> None:
    """
    Round-trip `nb` block stacks through the block<->universal kernels and
    compare each direction against the pure-PyTorch reference.

    Shapes:
    - universals: `[nb][nh, nl, no, nt, hd]`
    - blocks:     `[nb][nl * no][nt, nh, hd]` (or `[nh, nt, hd]` for HND)
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA required for these tests")

    gpu = torch.device("cuda:0")
    torch.manual_seed(0)

    nb = 3
    nh, nl, no, nt, hd = 3, 2, 2, 4, 5
    universals = [
        torch.randn(nh, nl, no, nt, hd, device=gpu, dtype=dtype) for _ in range(nb)
    ]

    # Reference block stacks come from PyTorch permutations of each universal.
    blocks = [_make_blocks(source, layout) for source in universals]
    outputs = [torch.empty_like(source) for source in universals]

    # Direction 1: block stacks -> universal via the CUDA kernels.
    ctk.block_to_universal(blocks, outputs, layout)
    torch.cuda.synchronize()

    atol, rtol = _tolerances(dtype)
    for got, want in zip(outputs, universals):
        assert torch.allclose(got, want, atol=atol, rtol=rtol)

    # Direction 2: wipe the blocks so stale data cannot mask a no-op kernel,
    # then regenerate them from the universals.
    for group in blocks:
        for chunk in group:
            chunk.zero_()

    ctk.universal_to_block(universals, blocks, layout)
    torch.cuda.synchronize()

    reference = [_make_blocks(source, layout) for source in universals]
    for got_group, want_group in zip(blocks, reference):
        for got, want in zip(got_group, want_group):
            assert torch.allclose(got, want, atol=atol, rtol=rtol)
@pytest.mark.parametrize(
    "dtype",
    [torch.float16, torch.bfloat16, torch.float32, torch.float64],
)
def test_operational_roundtrip(dtype: torch.dtype) -> None:
    """
    Check the block<->operational path in both directions.

    The operational layout is `[nl, no, inner]`, where `inner = nt * nh * hd`
    is an NHD block flattened into a single dimension.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA required for these tests")

    gpu = torch.device("cuda:0")
    torch.manual_seed(1)

    nb = 2
    nh, nl, no, nt, hd = 2, 3, 2, 4, 3
    inner = nt * nh * hd

    universals = [
        torch.randn(nh, nl, no, nt, hd, device=gpu, dtype=dtype) for _ in range(nb)
    ]
    reference_blocks = [_make_blocks(source, "NHD") for source in universals]
    # Work on copies so the reference stays intact when blocks are zeroed below.
    blocks = [[ref.clone() for ref in group] for group in reference_blocks]
    operationals = [
        torch.empty(nl, no, inner, device=gpu, dtype=dtype) for _ in range(nb)
    ]

    # Pack: block stacks -> operational buffers.
    ctk.block_to_operational(blocks, operationals)
    torch.cuda.synchronize()

    atol, rtol = _tolerances(dtype)
    for packed, group in zip(operationals, reference_blocks):
        flattened = [chunk.reshape(-1) for chunk in group]
        want = torch.stack(flattened, dim=0).view(nl, no, inner)
        assert torch.allclose(packed, want, atol=atol, rtol=rtol)

    # Unpack: clear the blocks, then restore them from the operational buffers.
    for group in blocks:
        for chunk in group:
            chunk.zero_()

    ctk.operational_to_block(operationals, blocks)
    torch.cuda.synchronize()

    for got_group, want_group in zip(blocks, reference_blocks):
        for got, want in zip(got_group, want_group):
            assert torch.allclose(got, want, atol=atol, rtol=rtol)
@pytest.mark.parametrize("backend", [None, "auto", "kernel", "async", "batch"])
def test_operational_backends(backend):
    """
    Exercise every backend override on a block -> operational -> block
    round-trip. When a backend is unavailable (e.g. batch on older runtimes)
    the test is skipped instead of failing.

    Every block of the stack is compared against the PyTorch reference, not
    just the first one, so a backend that copies only part of the stack is
    caught.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA required for these tests")

    device = torch.device("cuda:0")
    nh, nl, no, nt, hd = 2, 1, 2, 3, 4
    nb = 1
    dtype = torch.float32

    universals = [
        torch.randn(nh, nl, no, nt, hd, device=device, dtype=dtype) for _ in range(nb)
    ]
    blocks = [_make_blocks(t, "NHD") for t in universals]
    operationals = [
        torch.empty(nl, no, nt * nh * hd, device=device, dtype=dtype) for _ in range(nb)
    ]

    _call_with_backend(ctk.block_to_operational, backend, blocks, operationals)
    torch.cuda.synchronize()

    # Clear the blocks so the reverse copy has to reproduce the data.
    for block in blocks[0]:
        block.zero_()

    _call_with_backend(ctk.operational_to_block, backend, operationals, blocks)
    torch.cuda.synchronize()

    reference = _make_blocks(universals[0], "NHD")
    assert len(blocks[0]) == len(reference)
    for produced, expected in zip(blocks[0], reference):
        assert torch.allclose(produced, expected, atol=1e-5, rtol=1e-5)
def test_universal_shape_mismatch():
    """
    A block whose shape disagrees with the universal tensor must be rejected
    with a ValueError.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA required for these tests")

    gpu = torch.device("cuda:0")
    dtype = torch.float32
    universal = torch.randn(2, 2, 1, 2, 4, device=gpu, dtype=dtype)
    # For NHD the matching block shape is [nt, nh, hd] = [2, 2, 4]; the middle
    # dimension is 3 here, so the block cannot belong to `universal`.
    bad_block = torch.randn(2, 3, 4, device=gpu, dtype=dtype)
    output = torch.empty_like(universal)

    with pytest.raises(ValueError):
        ctk.block_to_universal([[bad_block]], [output], "NHD")
def test_dtype_mismatch_error():
    """
    A batch that mixes dtypes must raise TypeError rather than being
    converted silently.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA required for these tests")

    gpu = torch.device("cuda:0")
    shape = (1, 1, 1, 2, 4)
    universal_f16 = torch.randn(*shape, device=gpu, dtype=torch.float16)
    universal_f32 = torch.randn(*shape, device=gpu, dtype=torch.float32)
    mixed_universals = [universal_f16, universal_f32]
    mixed_blocks = [_make_blocks(u, "NHD") for u in mixed_universals]

    with pytest.raises(TypeError):
        ctk.block_to_universal(mixed_blocks, mixed_universals, "NHD")
def test_non_cuda_tensor_error():
    """
    CPU tensors should be rejected up-front with a ValueError.

    The input blocks are moved to the GPU with `.cuda()`, so the test needs a
    CUDA device even though the tensor under test lives on the CPU. Without
    one it is skipped, like the other tests in this module.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA required for these tests")

    device = torch.device("cpu")
    universal = torch.randn(1, 1, 1, 2, 4, device=device)
    blocks = _make_blocks(universal.cuda(), "NHD")

    with pytest.raises(ValueError):
        ctk.block_to_universal([blocks], [universal], "NHD")
def test_empty_batch_noop():
    """
    Every conversion entry point should accept an empty batch and return None.
    """
    # Layout-based conversions take the layout as the third argument.
    for convert in (ctk.block_to_universal, ctk.universal_to_block):
        assert convert([], [], "NHD") is None

    # Operational conversions take the backend (None here) as the third argument.
    for convert in (ctk.block_to_operational, ctk.operational_to_block):
        assert convert([], [], None) is None
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
[package]
name = "dynamo-kvbm-kernels"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
build = "build.rs"
[lib]
name = "kvbm_kernels"
crate-type = ["rlib", "cdylib"]
[features]
default = ["testing-cuda"]
testing-cuda = []
prebuilt-kernels = []
# python-bindings = ["pyo3"]
[dependencies]
cudarc = { workspace = true }
once_cell = "1.19"
# pyo3 = { version = "0.26", optional = true, features = ["extension-module"] }
[build-dependencies]
dynamo-config = { workspace = true }
cc = "1.0"
md5 = "0.8.0"
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Use CUDA 12.9.0 development image as base
FROM nvidia/cuda:12.9.0-devel-ubuntu22.04
# Set environment variables
ENV RUSTUP_HOME=/usr/local/rustup \
CARGO_HOME=/usr/local/cargo \
PATH=/usr/local/cargo/bin:$PATH \
CUDA_PATH=/usr/local/cuda
# Install system dependencies
RUN apt-get update && apt-get install -y \
curl \
build-essential \
pkg-config \
&& rm -rf /var/lib/apt/lists/*
# Install Rust
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
# Set working directory
WORKDIR /app
# Copy project files
COPY Cargo.toml ./
COPY src ./src
COPY build.rs ./build.rs
COPY cuda ./cuda
# Build the project
RUN cargo build --release
# Default command: run the crate's test suite
CMD ["cargo", "test"]
## Dynamo KV Block Manager Kernels
This workspace houses CUDA + Rust + Python tooling for shuttling attention
blocks between three commonly used layouts:
1. **Stacked NHD / HND blocks** – `nl * no` tensors per block, each shaped
`[nt, nh, hd]` (NHD) or `[nh, nt, hd]` (HND).
- primarily used by vLLM
2. **Operational blocks** – flattened buffers shaped `[nl, no, inner]`,
where `inner = nt * nh * hd`.
- primarily used by TensorRT LLM
- used by Dynamo's KVBM for non-device storage when no adjustment to
the layout is needed to translate to/from different TP world sizes
3. **Universal blocks** – contiguous buffers shaped `[nh, nl, no, nt, hd]`.
- move the head dimension to the front
- excellent format for storage blocks that can be used by different tp
world sizes by scattering/gathering on slices of the leading dimension
allowing for large contiguous transfers.
All kernels are batch aware: a single launch can process `nb` blocks by
walking flattened pointer tables that the host code prepares ahead of time.
Bindings are provided for both Rust and PyTorch so you can slot the kernels
into existing pipelines without living in CUDA all day.
---
### Layout Cheat Sheet
| Term | Logical Shape | Stored As | Notes |
|---------------------|----------------------------|------------------------------------|-------------------------------|
| NHD block stack | `[nl][no][nt, nh, hd]` | list of `nl * no` pointers | Inner layout = NHD |
| HND block stack | `[nl][no][nh, nt, hd]` | list of `nl * no` pointers | Inner layout = HND |
| Operational block | `[nl, no, inner]` | contiguous buffer per block | `inner = nt * nh * hd` |
| Universal block | `[nh, nl, no, nt, hd]` | contiguous buffer per block | Ideal when all dims are fixed |
> **Pointer prep**
> For each logical block you provide:
> - one universal pointer,
> - `nl * no` pointers for either NHD or HND chunks, and
> - one operational pointer (when needed).
---
### Repository Structure
```
.
├── Cargo.toml # Rust lib/bin targets
├── build.rs # NVCC build script (sm80+sm90 by default)
├── cuda/
│ └── tensor_kernels.cu # Batched CUDA kernels + memcpy fallback
├── src/
│ ├── lib.rs # Rust facade for the kernels
│ ├── main.rs # Legacy cudaMemcpyBatchAsync demo (bin)
│ └── tensor_kernels.rs # FFI wrappers + integration tests
└── run.sh / Dockerfile # Optional CUDA 12.9 container harness
```
> **Note:** Python bindings (`python.rs`) and tests have been moved to
> `lib/bindings/kvbm/` as part of the integrated `kvbm` wheel.
---
### Building the CUDA Library
The CUDA code is compiled via `nvcc` in `build.rs`. Supported architectures
default to `sm_80` (Ampere) and `sm_90` (Hopper). Override with `CUDA_ARCHS`
for broader compatibility:
```bash
# Default build (sm_80, sm_90)
cargo build
# Broader compatibility across GPU generations
CUDA_ARCHS="80,86,89,90,100" cargo build
# Common architectures:
# 80 = Ampere (A100)
# 86 = Ampere (RTX 30xx)
# 89 = Ada Lovelace (RTX 40xx, L4, L40)
# 90 = Hopper (H100, H200)
# 100 = Blackwell (B100, B200, GB200)
```
> **Prerequisites**
> - CUDA 12.1+ toolkit on PATH
> - `nvcc` and compatible driver
> - Rust stable (1.70+) with `cargo`
For rapid iteration without the Python bindings:
```bash
cargo check
cargo test fused_copy_roundtrip -- --nocapture
```
The unit test synthesizes two blocks on-device, exercises every conversion
path (block ⇄ universal ⇄ operational), and asserts lossless round-trips.
---
### Python Bindings & Tests
> **Note:** The Python bindings and tests have been migrated to the `kvbm` wheel
> at `lib/bindings/kvbm/`. Install and test using that package instead.
#### Install locally
```bash
cd lib/bindings/kvbm
uv pip install -e ".[dev]"
```
This installs the `kvbm` package with all development dependencies including
the CUDA tensor kernels, pytest, and build tools.
#### Validate against PyTorch baselines
```bash
cd lib/bindings/kvbm
pytest tests/
```
Each test synthesizes random CUDA tensors, permutes them using native PyTorch
ops, then compares the kernel output with tolerances tuned per dtype.
#### Python API Sketch
```python
import torch
from kvbm import kernels
blocks = [...] # list[list[torch.Tensor]] sized nb x (nl*no)
universals = [...] # list[torch.Tensor] sized nb
operationals = [...] # list[torch.Tensor] sized nb
kernels.block_to_universal(blocks, universals, layout="NHD")
kernels.universal_to_block(universals, blocks, layout="NHD")
kernels.block_to_operational(blocks, operationals, backend="batch") # or "async" / "kernel" / "auto"
kernels.operational_to_block(operationals, blocks, backend="auto")
```
All tensors must be CUDA-accessible from the specified device, match the
expected shapes, and be contiguous in those shapes. The bindings validate
shapes/dtypes, stage pointer tables on-device, and launch the appropriate CUDA kernel.
---
### Docker Workflow (Optional)
Need a reproducible environment? The repo includes a CUDA 12.9 container that
installs Rust and builds the project.
```bash
# Build and run the demo binary inside the container
./run.sh
# Or build manually
docker build -t kvbm-kernels .
docker run --rm --gpus all kvbm-kernels
```
To develop interactively with Python, extend the Dockerfile with your preferred
Python distribution and PyTorch wheel.
---
### Troubleshooting
| Symptom | Likely Cause / Fix |
|---------------------------------------|--------------------------------------------------------------------|
| `cudaErrorInvalidValue` on launch | Pointer counts mismatch (`nb`, `nl`, `no`) or non-contiguous input |
| Wrong values when using HND layout | Inner tensors not permuted to `[nh, nt, hd]` before passing in |
| Python bindings complain about dtype | Mixed precision in a batch; convert tensors to a common dtype |
| Kernels take unexpected time | Verify that `CUDA_ARCHS` matches your GPU to avoid JIT at runtime |
- `backend="auto"` defaults to the fused kernel, then `cudaMemcpyBatchAsync`, then `cudaMemcpyAsync`. Override if you want to benchmark a specific path.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::env;
use std::fs;
use std::io::Read;
use std::path::Path;
use std::process::Command;
/// Build-script entry point.
///
/// Selects between the prebuilt-kernel path and compiling
/// `cuda/tensor_kernels.cu` from source. Only the from-source path emits
/// linker directives for the CUDA runtime (`cudart`).
fn main() {
    println!("cargo:rerun-if-changed=cuda/tensor_kernels.cu");
    println!("cargo:rerun-if-env-changed=DYNAMO_USE_PREBUILT_KERNELS");
    println!("cargo:rerun-if-env-changed=CUDA_ARCHS");
    // CUDA_PATH / CUDA_HOME pick the library search path below. Declare them
    // so that pointing at a different toolkit re-runs this script instead of
    // reusing stale link directives.
    println!("cargo:rerun-if-env-changed=CUDA_PATH");
    println!("cargo:rerun-if-env-changed=CUDA_HOME");

    if determine_build_mode() {
        build_with_prebuilt_kernels();
        return;
    }

    build_from_source();

    // Toolkit root: CUDA_PATH wins, then CUDA_HOME, then the standard
    // install location.
    let cuda_root = env::var("CUDA_PATH")
        .or_else(|_| env::var("CUDA_HOME"))
        .unwrap_or_else(|_| "/usr/local/cuda".to_string());
    println!("cargo:rustc-link-search=native={}/lib64", cuda_root);
    println!("cargo:rustc-link-search=native={}/lib", cuda_root);
    println!("cargo:rustc-link-lib=cudart");
}
/// Decide whether this build uses the prebuilt kernels (`true`) or compiles
/// the CUDA source (`false`).
///
/// Checks are applied in order, first match wins:
/// 1. the `prebuilt-kernels` feature flag,
/// 2. the `DYNAMO_USE_PREBUILT_KERNELS` environment variable being truthy,
/// 3. `nvcc` not being available.
///
/// The decision is reported as a `cargo:warning` line.
fn determine_build_mode() -> bool {
    let (notice, use_prebuilt) = if cfg!(feature = "prebuilt-kernels") {
        (
            "cargo:warning=Using prebuilt kernels (feature flag enabled)",
            true,
        )
    } else if dynamo_config::env_is_truthy("DYNAMO_USE_PREBUILT_KERNELS") {
        (
            "cargo:warning=Using prebuilt kernels (DYNAMO_USE_PREBUILT_KERNELS set)",
            true,
        )
    } else if !is_nvcc_available() {
        ("cargo:warning=nvcc not found, using prebuilt kernels", true)
    } else {
        ("cargo:warning=Building CUDA kernels from source", false)
    };

    println!("{}", notice);
    use_prebuilt
}
/// Report whether a working `nvcc` is on `PATH`.
///
/// Runs `nvcc --version` and requires both that the process could be spawned
/// and that it exited successfully. A spawnable but failing `nvcc` (for
/// example a broken wrapper script) counts as unavailable.
fn is_nvcc_available() -> bool {
    Command::new("nvcc")
        .arg("--version")
        .output()
        .map(|out| out.status.success())
        .unwrap_or(false)
}
/// Validate the checked-in prebuilt artifacts and stage the fatbin in
/// `OUT_DIR`.
///
/// `cuda/prebuilt/tensor_kernels.md5` must hold exactly three lines: the
/// hashes of `build.rs`, `cuda/tensor_kernels.cu` and
/// `cuda/prebuilt/tensor_kernels.fatbin`, in that order. Each is recomputed
/// and compared; any mismatch aborts the build.
///
/// # Panics
/// Panics if a prebuilt file is missing, the `.md5` file does not have three
/// lines, any hash differs, or the fatbin cannot be copied to `OUT_DIR`.
fn build_with_prebuilt_kernels() {
    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
    let crate_root = Path::new(&manifest_dir);
    let cu_path = crate_root.join("cuda/tensor_kernels.cu");
    let md5_path = crate_root.join("cuda/prebuilt/tensor_kernels.md5");
    let fatbin_path = crate_root.join("cuda/prebuilt/tensor_kernels.fatbin");

    // Both prebuilt files have to be present before anything is hashed.
    if !md5_path.exists() {
        panic!(
            "Prebuilt mode requires cuda/prebuilt/tensor_kernels.md5 but it does not exist. \
             Please build with nvcc available first to generate the prebuilt artifacts."
        );
    }
    if !fatbin_path.exists() {
        panic!(
            "Prebuilt mode requires cuda/prebuilt/tensor_kernels.fatbin but it does not exist. \
             Please build with nvcc available first to generate the prebuilt artifacts."
        );
    }

    // Stored hashes: one per line, in the order build.rs, .cu, .fatbin.
    let stored_text =
        fs::read_to_string(&md5_path).expect("Failed to read cuda/prebuilt/tensor_kernels.md5");
    let stored_lines: Vec<&str> = stored_text.lines().collect();
    let (stored_build_rs_hash, stored_cu_hash, stored_fatbin_hash) = match stored_lines.as_slice()
    {
        [build_rs, cu, fatbin] => (*build_rs, *cu, *fatbin),
        other => panic!(
            "Invalid .md5 file format. Expected 3 lines (build.rs, .cu, .fatbin hashes), found {}.\n\
             Please rebuild with nvcc available to regenerate the prebuilt artifacts.",
            other.len()
        ),
    };

    // Hashes of the files as they exist right now.
    let current_build_rs_hash = compute_file_hash(&crate_root.join("build.rs"));
    let current_cu_hash = compute_file_hash(&cu_path);
    let current_fatbin_hash = compute_file_hash(&fatbin_path);

    // Collect every mismatch so a single failure message reports all of them.
    let comparisons = vec![
        (
            current_build_rs_hash != stored_build_rs_hash,
            format!(
                " build.rs: current={}, stored={}",
                current_build_rs_hash, stored_build_rs_hash
            ),
        ),
        (
            current_cu_hash != stored_cu_hash,
            format!(
                " .cu source: current={}, stored={}",
                current_cu_hash, stored_cu_hash
            ),
        ),
        (
            current_fatbin_hash != stored_fatbin_hash,
            format!(
                " .fatbin: current={}, stored={}",
                current_fatbin_hash, stored_fatbin_hash
            ),
        ),
    ];
    let mismatches: Vec<String> = comparisons
        .into_iter()
        .filter(|(differs, _)| *differs)
        .map(|(_, line)| line)
        .collect();

    if !mismatches.is_empty() {
        panic!(
            "Hash mismatch! The prebuilt .fatbin is out of sync:\n{}\n\
             Please rebuild with nvcc available to regenerate the prebuilt artifacts.",
            mismatches.join("\n")
        );
    }

    println!("cargo:warning=Hash validation passed:");
    println!("cargo:warning= build.rs: {}", current_build_rs_hash);
    println!("cargo:warning= .cu source: {}", current_cu_hash);
    println!("cargo:warning= .fatbin: {}", current_fatbin_hash);

    // Stage a copy of the fatbin in OUT_DIR and add OUT_DIR to the native
    // search path. A .fatbin cannot be linked directly, so nothing is linked
    // here; loading it is left to runtime (e.g. cuModuleLoadFatBinary).
    let out_dir = env::var("OUT_DIR").unwrap();
    let staged_fatbin = Path::new(&out_dir).join("tensor_kernels.fatbin");
    fs::copy(&fatbin_path, &staged_fatbin).expect("Failed to copy .fatbin to OUT_DIR");

    println!("cargo:rustc-link-search=native={}", out_dir);

    println!(
        "cargo:warning=Prebuilt kernel loaded from {}",
        fatbin_path.display()
    );
}
/// Compile `cuda/tensor_kernels.cu` with the `cc` crate in CUDA mode into the
/// `tensor_kernels` library, then regenerate the prebuilt `.fatbin` / `.md5`
/// artifacts for later prebuilt-mode builds.
fn build_from_source() {
    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
    let cu_path = Path::new(&manifest_dir).join("cuda/tensor_kernels.cu");
    let out_dir = env::var("OUT_DIR").unwrap();

    let mut nvcc_build = cc::Build::new();
    nvcc_build.cuda(true).file(&cu_path);

    // Fixed compiler options; `-Xcompiler -fPIC` forwards -fPIC to the host
    // compiler.
    for base_flag in ["-std=c++17", "-O3", "-Xcompiler", "-fPIC"] {
        nvcc_build.flag(base_flag);
    }

    // One -gencode flag per target architecture.
    let arch_flags = get_cuda_arch_flags();
    for arch_flag in &arch_flags {
        nvcc_build.flag(arch_flag);
    }

    nvcc_build.compile("tensor_kernels");

    // Refresh the artifacts consumed by prebuilt mode.
    generate_prebuilt_artifacts(&cu_path, &arch_flags, &out_dir);
}
/// Build the `-gencode` flags for the CUDA architectures to target.
///
/// `CUDA_ARCHS` is a comma-separated list of compute capabilities (e.g.
/// `"80,86,90"`); entries are trimmed and empty entries are skipped. When the
/// variable is unset, unreadable, or yields no entries at all (empty string,
/// only commas or whitespace), the defaults SM 80 (Ampere) and SM 90 (Hopper)
/// are used, so the build never runs with no explicit target.
fn get_cuda_arch_flags() -> Vec<String> {
    const DEFAULT_ARCHS: [&str; 2] = ["80", "90"];

    fn gencode(arch: &str) -> String {
        format!("-gencode=arch=compute_{},code=sm_{}", arch, arch)
    }

    let requested: Vec<String> = env::var("CUDA_ARCHS")
        .map(|list| {
            list.split(',')
                .map(str::trim)
                .filter(|arch| !arch.is_empty())
                .map(gencode)
                .collect()
        })
        .unwrap_or_default();

    if !requested.is_empty() {
        return requested;
    }

    // The variable was present but unusable: say so rather than silently
    // building for a different set of GPUs than the caller asked for.
    if env::var_os("CUDA_ARCHS").is_some() {
        println!(
            "cargo:warning=CUDA_ARCHS is set but lists no usable architectures; \
             falling back to the defaults (sm_80, sm_90)"
        );
    }

    DEFAULT_ARCHS.iter().copied().map(gencode).collect()
}
/// Regenerate the prebuilt artifacts under `cuda/prebuilt/` in the crate's
/// source tree.
///
/// Runs `nvcc -fatbin` on `cu_path` with `arch_flags`, writing to `out_dir`
/// first and then copying the result to `cuda/prebuilt/tensor_kernels.fatbin`.
/// It then writes `cuda/prebuilt/tensor_kernels.md5` with three lines: the
/// MD5 hashes of `build.rs`, the `.cu` source and the `.fatbin`, in that
/// order.
///
/// # Panics
/// Panics if the prebuilt directory cannot be created, `nvcc` cannot be run
/// or exits unsuccessfully, or any copy/write fails.
fn generate_prebuilt_artifacts(cu_path: &Path, arch_flags: &[String], out_dir: &str) {
    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
    let prebuilt_dir = Path::new(&manifest_dir).join("cuda/prebuilt");
    let fatbin_path = prebuilt_dir.join("tensor_kernels.fatbin");
    let md5_path = prebuilt_dir.join("tensor_kernels.md5");

    // Ensure prebuilt directory exists
    fs::create_dir_all(&prebuilt_dir).expect("Failed to create cuda/prebuilt directory");

    // Generate the .fatbin in OUT_DIR first so a failed nvcc run never leaves
    // a partial file in the source tree.
    let temp_fatbin = Path::new(out_dir).join("tensor_kernels.fatbin");
    let mut nvcc_cmd = Command::new("nvcc");
    nvcc_cmd
        .arg("-fatbin")
        .arg("-std=c++17")
        .arg("-O3")
        .arg(cu_path)
        .arg("-o")
        .arg(&temp_fatbin)
        .args(arch_flags);

    println!("cargo:warning=Generating .fatbin with nvcc...");
    let status = nvcc_cmd
        .status()
        .expect("Failed to execute nvcc for .fatbin generation");
    if !status.success() {
        panic!("nvcc failed to generate .fatbin");
    }

    // Copy .fatbin to prebuilt directory
    fs::copy(&temp_fatbin, &fatbin_path).expect("Failed to copy .fatbin to cuda/prebuilt/");

    // Hash all three files so prebuilt mode can detect any of them drifting.
    let build_rs_path = Path::new(&manifest_dir).join("build.rs");
    let build_rs_hash = compute_file_hash(&build_rs_path);
    let cu_hash = compute_file_hash(cu_path);
    let fatbin_hash = compute_file_hash(&fatbin_path);

    // Write all three hashes (one per line)
    let hashes = format!("{}\n{}\n{}\n", build_rs_hash, cu_hash, fatbin_hash);
    fs::write(&md5_path, hashes).expect("Failed to write .md5 file");

    // Cargo parses build-script output one line at a time, so each line of a
    // multi-line message needs its own `cargo:warning=` prefix. Lines without
    // the prefix are not displayed as warnings.
    println!("cargo:warning=Generated prebuilt artifacts:");
    println!("cargo:warning=  {}", fatbin_path.display());
    println!("cargo:warning=  {}", md5_path.display());
    println!("cargo:warning=build.rs hash: {}", build_rs_hash);
    println!("cargo:warning=.cu source hash: {}", cu_hash);
    println!("cargo:warning=.fatbin hash: {}", fatbin_hash);
}
/// Return the MD5 digest of the file at `path` as a lowercase hex string.
///
/// Panics with a message naming the path if the file cannot be opened or read.
fn compute_file_hash(path: &Path) -> String {
    let mut handle = match fs::File::open(path) {
        Ok(handle) => handle,
        Err(e) => panic!("Failed to open {} for hashing: {}", path.display(), e),
    };

    let mut contents = Vec::new();
    if let Err(e) = handle.read_to_end(&mut contents) {
        panic!("Failed to read {} for hashing: {}", path.display(), e);
    }

    let digest = md5::compute(&contents);
    format!("{:x}", digest)
}
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
7aac008ed704fe198ef5056a4e502069
a7d1649c148ee0366de6d19896a80116
4234cfd2ef4b283592a37a9ceed94666
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <algorithm>
#include <cstdint>
#include <type_traits>
#include <vector>
#ifndef CUDA_CALLABLE_MEMBER
#define CUDA_CALLABLE_MEMBER __host__ __device__
#endif
namespace {
/**
* There are three logical tensor views involved in these kernels:
*
* 1. Universal blocks: contiguous buffers whose logical shape is
* [nh, nl, no, nt, hd]. Every “block” is a separate pointer.
* 2. NHD/HND block stacks: `nl * no` pointers per block, each pointing
* to a chunk shaped either [nt, nh, hd] (NHD) or [nh, nt, hd] (HND).
* Stacks are arranged as `[layer][outer]`.
* 3. Operational blocks: contiguous buffers whose logical shape is
* [nl, no, inner], where inner = nt * nh * hd. These are used when
* the consumer does not care about the split between nh/nt/hd.
*
* Each kernel batch-processes `num_blocks` block pairs. All pointer
* tables are flattened on the host:
* • universal_ptrs_device : [num_blocks]
* • block_ptrs_device : [num_blocks * nl * no]
* • operational_ptrs_device: [num_blocks]
*
* This lets us launch a single grid per direction, keeps the per-block
* math regular, and avoids any per-kernel pointer chasing on the CPU.
*/
// Element type tag. The extern "C" launchers receive it as a plain int and
// switch on it to pick the template instantiation; DTypeTraits maps each tag
// to its C++ element type.
enum class TensorDataType : int {
// Maps to __half.
F16 = 0,
// Maps to __nv_bfloat16.
BF16 = 1,
// Maps to float.
F32 = 2,
// Maps to double.
F64 = 3,
};
// Element order inside one block chunk; consumed by block_inner_offset.
enum class BlockLayout : int {
// Chunk is indexed as [nt, nh, hd].
NHD = 0,
// Chunk is indexed as [nh, nt, hd].
HND = 1,
};
// Which side of an operational copy is the destination.
enum class OperationalCopyDirection : int {
// Read the block chunks, write the operational buffer (pack).
BlockToOperational = 0,
// Read the operational buffer, write the block chunks (unpack).
OperationalToBlock = 1,
};
// Compile-time map from a TensorDataType tag to the element type used to
// instantiate the kernels. The primary template is only declared, so a tag
// without a specialization below fails to compile when its ::type is used.
template <TensorDataType>
struct DTypeTraits;
// F16 -> __half
template <>
struct DTypeTraits<TensorDataType::F16> {
using type = __half;
};
// BF16 -> __nv_bfloat16
template <>
struct DTypeTraits<TensorDataType::BF16> {
using type = __nv_bfloat16;
};
// F32 -> float
template <>
struct DTypeTraits<TensorDataType::F32> {
using type = float;
};
// F64 -> double
template <>
struct DTypeTraits<TensorDataType::F64> {
using type = double;
};
// Advance a mutable pointer by `index` elements (not bytes) and return the
// resulting pointer. Callable from host and device code.
template <typename T>
CUDA_CALLABLE_MEMBER inline T*
ptr_offset(T* base, size_t index)
{
  T* const shifted = base + index;
  return shifted;
}
// Const overload: advance a read-only pointer by `index` elements (not bytes)
// and return the resulting pointer. Callable from host and device code.
template <typename T>
CUDA_CALLABLE_MEMBER inline const T*
ptr_offset(const T* base, size_t index)
{
  const T* const shifted = base + index;
  return shifted;
}
// Linear element offset of (nt_idx, nh_idx, hd_idx) inside one block chunk.
//
// NHD chunks are indexed as [nt, nh, hd]; any other layout value is treated
// as HND, i.e. [nh, nt, hd]. In both cases hd is the fastest-varying axis.
template <BlockLayout Layout>
CUDA_CALLABLE_MEMBER inline size_t
block_inner_offset(size_t nt_idx, size_t nh_idx, size_t hd_idx, size_t nt, size_t nh, size_t hd)
{
  size_t row = 0;
  if constexpr (Layout == BlockLayout::NHD) {
    row = nt_idx * nh + nh_idx;
  } else {
    row = nh_idx * nt + nt_idx;
  }
  return row * hd + hd_idx;
}
// Choose the 1-D grid size for a launch with `block_dim` threads per block
// that has to cover `total_elements` items.
//
// Returns 0 when there is nothing to do; otherwise ceil(total_elements /
// block_dim), clamped to 65535 blocks. The kernels in this file advance by
// blockDim.x * gridDim.x per iteration, so a clamped grid still visits every
// element.
//
// `block_dim` must be positive. The ceiling is computed as quotient plus a
// remainder test instead of (total + block_dim - 1) / block_dim, because that
// sum wraps around for totals close to SIZE_MAX and would collapse the grid
// to a single block.
inline int
compute_grid_dim(size_t total_elements, int block_dim)
{
  if (total_elements == 0) {
    return 0;
  }
  const size_t threads = static_cast<size_t>(block_dim);
  const size_t blocks = total_elements / threads + ((total_elements % threads != 0) ? 1 : 0);
  const size_t max_grid_dim = 65535;
  return static_cast<int>(std::min<size_t>(blocks, max_grid_dim));
}
// Gather kernel: copies every element of `num_blocks` block stacks into the
// matching universal buffers.
//
// Each loop iteration handles one universal element. The linear index is
// split into (block, position-in-block); the position is then decomposed,
// fastest axis first, into hd, nt, no, nl and finally nh, which matches a
// universal shape of [nh, nl, no, nt, hd]. The source chunk is selected by
// (nl, no) and indexed according to `Layout`.
//
//   block_chunks     : [num_blocks * block_stride] chunk pointers (sources)
//   universal_blocks : [num_blocks] universal base pointers (destinations)
//   block_stride     : chunk pointers per block in `block_chunks`
//   total_per_block  : elements per universal buffer
template <typename T, BlockLayout Layout>
__global__ void
block_to_universal_kernel(
    const T* const* block_chunks, T* const* universal_blocks, size_t block_stride, size_t total_per_block,
    size_t num_blocks, size_t nh, size_t nl, size_t no, size_t nt, size_t hd)
{
  const size_t step = blockDim.x * gridDim.x;
  const size_t element_count = total_per_block * num_blocks;

  // Grid-stride loop: the grid may be smaller than the element count.
  for (size_t linear = blockIdx.x * blockDim.x + threadIdx.x; linear < element_count; linear += step) {
    const size_t block_idx = linear / total_per_block;
    const size_t within_block = linear % total_per_block;

    const size_t hd_idx = within_block % hd;
    size_t rest = within_block / hd;
    const size_t nt_idx = rest % nt;
    rest /= nt;
    const size_t no_idx = rest % no;
    rest /= no;
    const size_t nl_idx = rest % nl;
    const size_t nh_idx = rest / nl;

    const T* src_chunk = block_chunks[block_idx * block_stride + (nl_idx * no + no_idx)];
    const size_t src_offset = block_inner_offset<Layout>(nt_idx, nh_idx, hd_idx, nt, nh, hd);
    T* dst_universal = universal_blocks[block_idx];
    dst_universal[within_block] = src_chunk[src_offset];
  }
}
// Scatter kernel, the inverse of block_to_universal_kernel: copies every
// element of `num_blocks` universal buffers back into the block chunks.
//
// Each loop iteration handles one universal element. The linear index is
// split into (block, position-in-block); the position is decomposed, fastest
// axis first, into hd, nt, no, nl and nh. The destination chunk is selected
// by (nl, no) and indexed according to `Layout`.
//
//   universal_blocks : [num_blocks] universal base pointers (sources)
//   block_chunks     : [num_blocks * block_stride] chunk pointers (destinations)
//   block_stride     : chunk pointers per block in `block_chunks`
//   total_per_block  : elements per universal buffer
template <typename T, BlockLayout Layout>
__global__ void
universal_to_block_kernel(
    const T* const* universal_blocks, T* const* block_chunks, size_t block_stride, size_t total_per_block,
    size_t num_blocks, size_t nh, size_t nl, size_t no, size_t nt, size_t hd)
{
  const size_t step = blockDim.x * gridDim.x;
  const size_t element_count = total_per_block * num_blocks;

  // Grid-stride loop: the grid may be smaller than the element count.
  for (size_t linear = blockIdx.x * blockDim.x + threadIdx.x; linear < element_count; linear += step) {
    const size_t block_idx = linear / total_per_block;
    const size_t within_block = linear % total_per_block;

    const size_t hd_idx = within_block % hd;
    size_t rest = within_block / hd;
    const size_t nt_idx = rest % nt;
    rest /= nt;
    const size_t no_idx = rest % no;
    rest /= no;
    const size_t nl_idx = rest % nl;
    const size_t nh_idx = rest / nl;

    T* dst_chunk = block_chunks[block_idx * block_stride + (nl_idx * no + no_idx)];
    const size_t dst_offset = block_inner_offset<Layout>(nt_idx, nh_idx, hd_idx, nt, nh, hd);
    const T* src_universal = universal_blocks[block_idx];
    dst_chunk[dst_offset] = src_universal[within_block];
  }
}
// Pack kernel: concatenates the chunks of each block, in chunk-table order,
// into that block's operational buffer.
//
// Each loop iteration copies one element. The linear index is split into
// (block, position-in-block); the position is split again into the chunk
// index and the offset inside that chunk.
//
//   block_chunks       : [num_blocks * block_stride] chunk pointers (sources)
//   operational_blocks : [num_blocks] operational base pointers (destinations)
//   block_stride       : chunk pointers per block in `block_chunks`
//   chunk_elements     : elements per chunk
//   total_per_block    : elements per operational buffer
template <typename T>
__global__ void
operational_pack_kernel(
    const T* const* block_chunks, T* const* operational_blocks, size_t block_stride, size_t chunk_elements,
    size_t total_per_block, size_t num_blocks)
{
  const size_t step = blockDim.x * gridDim.x;
  const size_t element_count = total_per_block * num_blocks;

  // Grid-stride loop: the grid may be smaller than the element count.
  for (size_t linear = blockIdx.x * blockDim.x + threadIdx.x; linear < element_count; linear += step) {
    const size_t block_idx = linear / total_per_block;
    const size_t within_block = linear % total_per_block;
    const size_t chunk_idx = within_block / chunk_elements;
    const size_t within_chunk = within_block % chunk_elements;

    const T* src_chunk = block_chunks[block_idx * block_stride + chunk_idx];
    T* dst_operational = operational_blocks[block_idx];
    dst_operational[within_block] = src_chunk[within_chunk];
  }
}
// Unpack kernel, the inverse of operational_pack_kernel: splits each
// operational buffer back into the chunks of its block.
//
// Each loop iteration copies one element. The linear index is split into
// (block, position-in-block); the position is split again into the chunk
// index and the offset inside that chunk.
//
//   operational_blocks : [num_blocks] operational base pointers (sources)
//   block_chunks       : [num_blocks * block_stride] chunk pointers (destinations)
//   block_stride       : chunk pointers per block in `block_chunks`
//   chunk_elements     : elements per chunk
//   total_per_block    : elements per operational buffer
template <typename T>
__global__ void
operational_unpack_kernel(
    const T* const* operational_blocks, T* const* block_chunks, size_t block_stride, size_t chunk_elements,
    size_t total_per_block, size_t num_blocks)
{
  const size_t step = blockDim.x * gridDim.x;
  const size_t element_count = total_per_block * num_blocks;

  // Grid-stride loop: the grid may be smaller than the element count.
  for (size_t linear = blockIdx.x * blockDim.x + threadIdx.x; linear < element_count; linear += step) {
    const size_t block_idx = linear / total_per_block;
    const size_t within_block = linear % total_per_block;
    const size_t chunk_idx = within_block / chunk_elements;
    const size_t within_chunk = within_block % chunk_elements;

    T* dst_chunk = block_chunks[block_idx * block_stride + chunk_idx];
    const T* src_operational = operational_blocks[block_idx];
    dst_chunk[within_chunk] = src_operational[within_block];
  }
}
// Host-side launcher for block_to_universal_kernel with element type T.
//
// Returns cudaSuccess without launching when the total element count is zero
// (this check comes first, so null tables are accepted for empty work),
// cudaErrorInvalidValue when either pointer table is null, and otherwise the
// value of cudaGetLastError() after the launch has been enqueued on `stream`.
// Any `layout` other than NHD selects the HND instantiation.
template <typename T>
cudaError_t
launch_block_to_universal_impl(
    void* const* universal_ptrs_device, const void* const* block_ptrs_device, size_t num_blocks, size_t nh, size_t nl,
    size_t no, size_t nt, size_t hd, BlockLayout layout, cudaStream_t stream)
{
  constexpr int kThreadsPerBlock = 256;

  const size_t chunks_per_block = nl * no;
  const size_t elements_per_block = nh * nl * no * nt * hd;
  const size_t element_count = elements_per_block * num_blocks;

  if (element_count == 0) {
    return cudaSuccess;
  }
  if (block_ptrs_device == nullptr || universal_ptrs_device == nullptr) {
    return cudaErrorInvalidValue;
  }

  const int blocks_in_grid = compute_grid_dim(element_count, kThreadsPerBlock);
  if (blocks_in_grid == 0) {
    return cudaSuccess;
  }

  auto src_chunks = reinterpret_cast<const T* const*>(block_ptrs_device);
  auto dst_universal = reinterpret_cast<T* const*>(universal_ptrs_device);

  switch (layout) {
    case BlockLayout::NHD:
      block_to_universal_kernel<T, BlockLayout::NHD><<<blocks_in_grid, kThreadsPerBlock, 0, stream>>>(
          src_chunks, dst_universal, chunks_per_block, elements_per_block, num_blocks, nh, nl, no, nt, hd);
      break;
    default:
      block_to_universal_kernel<T, BlockLayout::HND><<<blocks_in_grid, kThreadsPerBlock, 0, stream>>>(
          src_chunks, dst_universal, chunks_per_block, elements_per_block, num_blocks, nh, nl, no, nt, hd);
      break;
  }
  return cudaGetLastError();
}
// Host-side launcher for universal_to_block_kernel with element type T.
//
// Returns cudaSuccess without launching when the total element count is zero
// (this check comes first, so null tables are accepted for empty work),
// cudaErrorInvalidValue when either pointer table is null, and otherwise the
// value of cudaGetLastError() after the launch has been enqueued on `stream`.
// Any `layout` other than NHD selects the HND instantiation.
template <typename T>
cudaError_t
launch_block_from_universal_impl(
    const void* const* universal_ptrs_device, void* const* block_ptrs_device, size_t num_blocks, size_t nh, size_t nl,
    size_t no, size_t nt, size_t hd, BlockLayout layout, cudaStream_t stream)
{
  constexpr int kThreadsPerBlock = 256;

  const size_t chunks_per_block = nl * no;
  const size_t elements_per_block = nh * nl * no * nt * hd;
  const size_t element_count = elements_per_block * num_blocks;

  if (element_count == 0) {
    return cudaSuccess;
  }
  if (block_ptrs_device == nullptr || universal_ptrs_device == nullptr) {
    return cudaErrorInvalidValue;
  }

  const int blocks_in_grid = compute_grid_dim(element_count, kThreadsPerBlock);
  if (blocks_in_grid == 0) {
    return cudaSuccess;
  }

  auto src_universal = reinterpret_cast<const T* const*>(universal_ptrs_device);
  auto dst_chunks = reinterpret_cast<T* const*>(block_ptrs_device);

  switch (layout) {
    case BlockLayout::NHD:
      universal_to_block_kernel<T, BlockLayout::NHD><<<blocks_in_grid, kThreadsPerBlock, 0, stream>>>(
          src_universal, dst_chunks, chunks_per_block, elements_per_block, num_blocks, nh, nl, no, nt, hd);
      break;
    default:
      universal_to_block_kernel<T, BlockLayout::HND><<<blocks_in_grid, kThreadsPerBlock, 0, stream>>>(
          src_universal, dst_chunks, chunks_per_block, elements_per_block, num_blocks, nh, nl, no, nt, hd);
      break;
  }
  return cudaGetLastError();
}
// Host-side launcher for the operational pack/unpack kernels with element
// type T.
//
// Returns cudaSuccess without launching when nl * no, `inner` or `num_blocks`
// is zero (checked before the pointers), cudaErrorInvalidValue when either
// pointer table is null, and otherwise the value of cudaGetLastError() after
// the launch has been enqueued on `stream`. BlockToOperational runs the pack
// kernel; any other direction runs the unpack kernel.
template <typename T>
cudaError_t
launch_operational_copy_impl(
    void* const* operational_ptrs_device, void* const* block_ptrs_device, size_t num_blocks, size_t nl, size_t no,
    size_t inner, OperationalCopyDirection direction, cudaStream_t stream)
{
  constexpr int kThreadsPerBlock = 256;

  const size_t chunks_per_block = nl * no;
  if (chunks_per_block == 0 || inner == 0 || num_blocks == 0) {
    return cudaSuccess;
  }
  if (operational_ptrs_device == nullptr || block_ptrs_device == nullptr) {
    return cudaErrorInvalidValue;
  }

  const size_t elements_per_chunk = inner;
  const size_t elements_per_block = elements_per_chunk * chunks_per_block;
  const size_t element_count = elements_per_block * num_blocks;

  const int blocks_in_grid = compute_grid_dim(element_count, kThreadsPerBlock);
  if (blocks_in_grid == 0) {
    return cudaSuccess;
  }

  switch (direction) {
    case OperationalCopyDirection::BlockToOperational: {
      auto src_chunks = reinterpret_cast<const T* const*>(block_ptrs_device);
      auto dst_operational = reinterpret_cast<T* const*>(operational_ptrs_device);
      operational_pack_kernel<T><<<blocks_in_grid, kThreadsPerBlock, 0, stream>>>(
          src_chunks, dst_operational, chunks_per_block, elements_per_chunk, elements_per_block, num_blocks);
      break;
    }
    default: {
      auto src_operational = reinterpret_cast<const T* const*>(operational_ptrs_device);
      auto dst_chunks = reinterpret_cast<T* const*>(block_ptrs_device);
      operational_unpack_kernel<T><<<blocks_in_grid, kThreadsPerBlock, 0, stream>>>(
          src_operational, dst_chunks, chunks_per_block, elements_per_chunk, elements_per_block, num_blocks);
      break;
    }
  }
  return cudaGetLastError();
}
} // namespace
// C ABI entry point: gather `num_blocks` block stacks into universal buffers.
//
//   universal_ptrs_device : [num_blocks] universal base pointers (destinations)
//   block_ptrs_device     : flattened [num_blocks][nl * no] chunk pointers (sources)
//   nh, nl, no, nt, hd    : logical dimensions of one universal tensor
//   dtype_value           : TensorDataType tag as int
//   layout_value          : BlockLayout tag as int
//   stream                : stream the kernel is enqueued on
//
// Returns cudaErrorInvalidValue when `dtype_value` or `layout_value` is not a
// known tag; otherwise whatever the typed launcher returns. The layout tag is
// checked here because the typed launcher would run any unknown value as HND.
extern "C" cudaError_t
launch_universal_from_block(
    void* const* universal_ptrs_device, const void* const* block_ptrs_device, size_t num_blocks, size_t nh, size_t nl,
    size_t no, size_t nt, size_t hd, int dtype_value, int layout_value, cudaStream_t stream)
{
  if (layout_value != static_cast<int>(BlockLayout::NHD) && layout_value != static_cast<int>(BlockLayout::HND)) {
    return cudaErrorInvalidValue;
  }
  auto dtype = static_cast<TensorDataType>(dtype_value);
  auto layout = static_cast<BlockLayout>(layout_value);
  switch (dtype) {
    case TensorDataType::F16:
      return launch_block_to_universal_impl<typename DTypeTraits<TensorDataType::F16>::type>(
          universal_ptrs_device, block_ptrs_device, num_blocks, nh, nl, no, nt, hd, layout, stream);
    case TensorDataType::BF16:
      return launch_block_to_universal_impl<typename DTypeTraits<TensorDataType::BF16>::type>(
          universal_ptrs_device, block_ptrs_device, num_blocks, nh, nl, no, nt, hd, layout, stream);
    case TensorDataType::F32:
      return launch_block_to_universal_impl<typename DTypeTraits<TensorDataType::F32>::type>(
          universal_ptrs_device, block_ptrs_device, num_blocks, nh, nl, no, nt, hd, layout, stream);
    case TensorDataType::F64:
      return launch_block_to_universal_impl<typename DTypeTraits<TensorDataType::F64>::type>(
          universal_ptrs_device, block_ptrs_device, num_blocks, nh, nl, no, nt, hd, layout, stream);
    default:
      return cudaErrorInvalidValue;
  }
}
// C ABI entry point: scatter `num_blocks` universal buffers back into their
// block stacks.
//
//   universal_ptrs_device : [num_blocks] universal base pointers (sources)
//   block_ptrs_device     : flattened [num_blocks][nl * no] chunk pointers (destinations)
//   nh, nl, no, nt, hd    : logical dimensions of one universal tensor
//   dtype_value           : TensorDataType tag as int
//   layout_value          : BlockLayout tag as int
//   stream                : stream the kernel is enqueued on
//
// Returns cudaErrorInvalidValue when `dtype_value` or `layout_value` is not a
// known tag; otherwise whatever the typed launcher returns. The layout tag is
// checked here because the typed launcher would run any unknown value as HND.
extern "C" cudaError_t
launch_block_from_universal(
    const void* const* universal_ptrs_device, void* const* block_ptrs_device, size_t num_blocks, size_t nh, size_t nl,
    size_t no, size_t nt, size_t hd, int dtype_value, int layout_value, cudaStream_t stream)
{
  if (layout_value != static_cast<int>(BlockLayout::NHD) && layout_value != static_cast<int>(BlockLayout::HND)) {
    return cudaErrorInvalidValue;
  }
  auto dtype = static_cast<TensorDataType>(dtype_value);
  auto layout = static_cast<BlockLayout>(layout_value);
  switch (dtype) {
    case TensorDataType::F16:
      return launch_block_from_universal_impl<typename DTypeTraits<TensorDataType::F16>::type>(
          universal_ptrs_device, block_ptrs_device, num_blocks, nh, nl, no, nt, hd, layout, stream);
    case TensorDataType::BF16:
      return launch_block_from_universal_impl<typename DTypeTraits<TensorDataType::BF16>::type>(
          universal_ptrs_device, block_ptrs_device, num_blocks, nh, nl, no, nt, hd, layout, stream);
    case TensorDataType::F32:
      return launch_block_from_universal_impl<typename DTypeTraits<TensorDataType::F32>::type>(
          universal_ptrs_device, block_ptrs_device, num_blocks, nh, nl, no, nt, hd, layout, stream);
    case TensorDataType::F64:
      return launch_block_from_universal_impl<typename DTypeTraits<TensorDataType::F64>::type>(
          universal_ptrs_device, block_ptrs_device, num_blocks, nh, nl, no, nt, hd, layout, stream);
    default:
      return cudaErrorInvalidValue;
  }
}
// Backend selector for launch_operational_copy, received as a plain int.
enum class OperationalCopyBackend : int {
// Try the kernel, then the batched memcpy, then per-chunk cudaMemcpyAsync;
// unknown backend values are handled the same way.
Auto = 0,
// Use only the pack/unpack kernel.
KernelOnly = 1,
// Issue one cudaMemcpyAsync per chunk.
MemcpyAsync = 2,
// Use cudaMemcpyBatchAsync (only compiled in when CUDART_VERSION >= 12090).
MemcpyBatch = 3,
};
// C ABI entry point: copy between per-chunk block buffers and contiguous
// operational buffers, in the direction given by `direction_value`.
//
//   block_ptrs_host         : host-readable flattened [num_blocks][nl * no]
//                             table of chunk addresses (used by the memcpy
//                             backends)
//   block_ptrs_device       : the same table readable from the device (used
//                             by the kernel backend); may be null if only the
//                             memcpy backends are wanted
//   operational_ptrs_host   : host-readable [num_blocks] table of operational
//                             base addresses
//   operational_ptrs_device : the same table readable from the device
//   inner, elem_size        : elements per chunk and bytes per element; the
//                             memcpy backends copy inner * elem_size bytes
//                             per chunk, the kernel backend uses sizeof(T) of
//                             the selected dtype
//   dtype_value / direction_value / backend_value : enum tags as ints
//
// Returns cudaSuccess immediately when there is nothing to copy and
// cudaErrorInvalidValue when a required host/device table is null. In Auto
// mode (and for unknown backend values) the kernel is tried first; if a step
// reports cudaErrorNotSupported or cudaErrorInvalidValue the next backend is
// tried: batched memcpy, then per-chunk cudaMemcpyAsync.
extern "C" cudaError_t
launch_operational_copy(
    const void* const* block_ptrs_host, const void* const* block_ptrs_device, void* const* operational_ptrs_host,
    void* const* operational_ptrs_device, size_t num_blocks, size_t nl, size_t no, size_t inner, size_t elem_size,
    int dtype_value, int direction_value, int backend_value, cudaStream_t stream)
{
  auto direction = static_cast<OperationalCopyDirection>(direction_value);
  auto dtype = static_cast<TensorDataType>(dtype_value);
  auto backend = static_cast<OperationalCopyBackend>(backend_value);

  size_t chunk_count = nl * no;
  size_t chunk_bytes = inner * elem_size;
  size_t total_chunks = num_blocks * chunk_count;

  if (chunk_count == 0 || chunk_bytes == 0 || num_blocks == 0) {
    return cudaSuccess;
  }
  if (!block_ptrs_host || !operational_ptrs_host || !operational_ptrs_device) {
    return cudaErrorInvalidValue;
  }

  // Per-chunk (dst, src, size) lists for the memcpy backends. Chunk `c` of a
  // block lives at byte offset c * chunk_bytes inside its operational buffer.
  std::vector<void*> dst_ptrs(total_chunks);
  std::vector<const void*> src_ptrs(total_chunks);
  std::vector<size_t> sizes(total_chunks, chunk_bytes);

  for (size_t block = 0; block < num_blocks; ++block) {
    auto operational_base = static_cast<std::uint8_t*>(operational_ptrs_host[block]);
    for (size_t chunk = 0; chunk < chunk_count; ++chunk) {
      size_t idx = block * chunk_count + chunk;
      auto operational_ptr = operational_base + chunk * chunk_bytes;
      if (direction == OperationalCopyDirection::BlockToOperational) {
        dst_ptrs[idx] = operational_ptr;
        src_ptrs[idx] = block_ptrs_host[idx];
      } else {
        dst_ptrs[idx] = const_cast<void*>(block_ptrs_host[idx]);
        src_ptrs[idx] = operational_ptr;
      }
    }
  }

  auto launch_kernel = [&]() -> cudaError_t {
    if (!block_ptrs_device) {
      // The kernel cannot run without the device-side chunk table. Report
      // that as an error instead of success: returning success here would
      // claim a copy that never happened, and would stop Auto mode from
      // falling back to the memcpy backends.
      return cudaErrorInvalidValue;
    }
    switch (dtype) {
      case TensorDataType::F16:
        return launch_operational_copy_impl<typename DTypeTraits<TensorDataType::F16>::type>(
            operational_ptrs_device, const_cast<void* const*>(block_ptrs_device), num_blocks, nl, no, inner, direction,
            stream);
      case TensorDataType::BF16:
        return launch_operational_copy_impl<typename DTypeTraits<TensorDataType::BF16>::type>(
            operational_ptrs_device, const_cast<void* const*>(block_ptrs_device), num_blocks, nl, no, inner, direction,
            stream);
      case TensorDataType::F32:
        return launch_operational_copy_impl<typename DTypeTraits<TensorDataType::F32>::type>(
            operational_ptrs_device, const_cast<void* const*>(block_ptrs_device), num_blocks, nl, no, inner, direction,
            stream);
      case TensorDataType::F64:
        return launch_operational_copy_impl<typename DTypeTraits<TensorDataType::F64>::type>(
            operational_ptrs_device, const_cast<void* const*>(block_ptrs_device), num_blocks, nl, no, inner, direction,
            stream);
      default:
        return cudaErrorInvalidValue;
    }
  };

  // One cudaMemcpyAsync per chunk; stops at the first failing copy.
  auto launch_memcpy_async = [&]() -> cudaError_t {
    for (size_t idx = 0; idx < total_chunks; ++idx) {
      cudaError_t err = cudaMemcpyAsync(dst_ptrs[idx], src_ptrs[idx], sizes[idx], cudaMemcpyDeviceToDevice, stream);
      if (err != cudaSuccess) {
        return err;
      }
    }
    return cudaSuccess;
  };

  // Single batched submission; reports cudaErrorNotSupported when the toolkit
  // used for the build is older than 12.9.
  // NOTE(review): the parameter list of cudaMemcpyBatchAsync may differ in
  // newer toolkits -- confirm against the installed runtime headers.
  auto launch_memcpy_batch = [&]() -> cudaError_t {
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12090
    std::vector<void*> src_mut(total_chunks);
    for (size_t i = 0; i < total_chunks; ++i) {
      src_mut[i] = const_cast<void*>(src_ptrs[i]);
    }
    size_t fail_idx = 0;
    return cudaMemcpyBatchAsync(
        const_cast<void**>(dst_ptrs.data()), src_mut.data(), const_cast<size_t*>(sizes.data()), total_chunks, nullptr,
        nullptr, 0, &fail_idx, stream);
#else
    return cudaErrorNotSupported;
#endif
  };

  cudaError_t status = cudaErrorInvalidValue;
  switch (backend) {
    case OperationalCopyBackend::KernelOnly:
      status = launch_kernel();
      break;
    case OperationalCopyBackend::MemcpyAsync:
      status = launch_memcpy_async();
      break;
    case OperationalCopyBackend::MemcpyBatch:
      status = launch_memcpy_batch();
      break;
    case OperationalCopyBackend::Auto:
    default:
      status = launch_kernel();
      if (status == cudaErrorNotSupported || status == cudaErrorInvalidValue) {
        status = launch_memcpy_batch();
      }
      if (status == cudaErrorNotSupported || status == cudaErrorInvalidValue) {
        status = launch_memcpy_async();
      }
      break;
  }
  return status;
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod tensor_kernels;
pub use tensor_kernels::{
BlockLayout, OperationalCopyBackend, OperationalCopyDirection, TensorDataType,
block_from_universal, operational_copy, universal_from_block,
};
// #[cfg(feature = "python-bindings")]
// mod python;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Safe-ish wrappers around the CUDA block/universal packing kernels.
//!
//! The core ideas:
//! * A “block” represents the stack of `nl * no` tensors arranged either as NHD
//! (inner axes `[nt, nh, hd]`) or HND (inner axes `[nh, nt, hd]`).
//! * A “universal” tensor is `[nh, nl, no, nt, hd]` stored contiguously.
//! * An “operational” tensor is `[nl, no, inner]` with `inner = nt * nh * hd`.
//!
//! Host code calls these helpers with flattened pointer tables so a single
//! launch can move many logical blocks in one go.
#![allow(dead_code)]
#![allow(clippy::missing_safety_doc)]
use std::ffi::c_void;
use cudarc::runtime::sys::{cudaError_t, cudaStream_t};
/// Numeric tags passed across the FFI boundary to select the CUDA template.
///
/// The wrappers in this module pass the variant as `i32`, so the
/// discriminants must match the values the CUDA launchers switch on.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TensorDataType {
/// 16-bit IEEE half precision elements.
F16 = 0,
/// 16-bit bfloat16 elements.
BF16 = 1,
/// 32-bit float elements.
F32 = 2,
/// 64-bit float elements.
F64 = 3,
}
/// Identifies how each `[nt, nh, hd]` chunk is laid out in device memory.
///
/// Passed to the CUDA launchers as `i32`.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BlockLayout {
/// Inner axes ordered `[nt, nh, hd]`.
NHD = 0,
/// Inner axes ordered `[nh, nt, hd]`.
HND = 1,
}
/// Direction flag for copying between block stacks and operational buffers.
///
/// Passed to the CUDA launcher as `i32`.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OperationalCopyDirection {
/// Read the block chunks and write the operational buffers.
BlockToOperational = 0,
/// Read the operational buffers and write the block chunks.
OperationalToBlock = 1,
}
/// Selects how the operational copy should move data.
///
/// Passed to the CUDA launcher as `i32`.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OperationalCopyBackend {
/// Try the backends in order and fall back to the next one when a backend
/// reports `cudaErrorNotSupported` or `cudaErrorInvalidValue`.
/// Prioritizes the kernel, then batch copy, then memcpy async.
Auto = 0,
/// Force the custom CUDA kernel path.
KernelOnly = 1,
/// Issue one cudaMemcpyAsync per chunk.
MemcpyAsync = 2,
/// Invoke cudaMemcpyBatchAsync directly.
MemcpyBatch = 3,
}
// Raw declarations of the `extern "C"` launchers implemented in the CUDA
// source. `dtype`, `layout`, `direction` and `backend` are the `i32` values of
// the enums defined above.
unsafe extern "C" {
// Gathers block stacks into universal buffers.
fn launch_universal_from_block(
universal_ptrs_device: *const *mut c_void,
block_ptrs_device: *const *const c_void,
num_blocks: usize,
nh: usize,
nl: usize,
no: usize,
nt: usize,
hd: usize,
dtype: i32,
layout: i32,
stream: cudaStream_t,
) -> cudaError_t;
// Scatters universal buffers back into block stacks.
fn launch_block_from_universal(
universal_ptrs_device: *const *const c_void,
block_ptrs_device: *const *mut c_void,
num_blocks: usize,
nh: usize,
nl: usize,
no: usize,
nt: usize,
hd: usize,
dtype: i32,
layout: i32,
stream: cudaStream_t,
) -> cudaError_t;
// Copies between block stacks and operational buffers; takes both a
// host-readable and a device-readable copy of each pointer table.
// NOTE(review): the C side declares `operational_ptrs_device` as
// `void* const*`; the `*const *const c_void` used here is the same size but
// drops the inner mutability -- confirm this is intentional.
fn launch_operational_copy(
block_ptrs_host: *const *const c_void,
block_ptrs_device: *const *const c_void,
operational_ptrs_host: *const *mut c_void,
operational_ptrs_device: *const *const c_void,
num_blocks: usize,
nl: usize,
no: usize,
inner: usize,
elem_size: usize,
dtype: i32,
direction: i32,
backend: i32,
stream: cudaStream_t,
) -> cudaError_t;
}
/// Gather `num_blocks` stacks of NHD/HND chunks into universal tensors.
///
/// * `universal_device_ptrs` – device-readable table of `num_blocks` universal
///   base pointers (the destinations).
/// * `block_device_ptrs` – device-readable, flattened `[num_blocks][nl*no]`
///   table of chunk pointers (the sources).
/// * `nh, nl, no, nt, hd` – logical dimensions of each universal tensor.
/// * `dtype` / `layout` – element type and chunk layout, forwarded as `i32`.
/// * `stream` – CUDA stream used for the launch.
///
/// Returns the status reported by the CUDA launcher. The work is only
/// enqueued on `stream`; synchronize before reading the results.
///
/// # Safety
///
/// Both pointer tables, and every pointer stored in them, are dereferenced on
/// the device. They must be valid for the given dimensions and element type,
/// and must stay valid until the work enqueued on `stream` has finished.
#[allow(clippy::too_many_arguments)]
pub unsafe fn universal_from_block(
    universal_device_ptrs: *const *mut c_void,
    block_device_ptrs: *const *const c_void,
    num_blocks: usize,
    nh: usize,
    nl: usize,
    no: usize,
    nt: usize,
    hd: usize,
    dtype: TensorDataType,
    layout: BlockLayout,
    stream: cudaStream_t,
) -> cudaError_t {
    let (dtype_tag, layout_tag) = (dtype as i32, layout as i32);
    // SAFETY: the caller upholds the pointer and lifetime contract documented
    // above; the tags are the `i32` discriminants the launcher expects.
    unsafe {
        launch_universal_from_block(
            universal_device_ptrs,
            block_device_ptrs,
            num_blocks,
            nh,
            nl,
            no,
            nt,
            hd,
            dtype_tag,
            layout_tag,
            stream,
        )
    }
}
/// Scatter `num_blocks` universal tensors back into their block stacks.
///
/// * `universal_device_ptrs` – device-readable table of `num_blocks` universal
///   base pointers (the sources).
/// * `block_device_ptrs` – device-readable, flattened `[num_blocks][nl*no]`
///   table of chunk pointers (the destinations).
/// * `nh, nl, no, nt, hd` – logical dimensions of each universal tensor.
/// * `dtype` / `layout` – element type and chunk layout, forwarded as `i32`.
/// * `stream` – CUDA stream used for the launch.
///
/// Returns the status reported by the CUDA launcher. The work is only
/// enqueued on `stream`; synchronize before reading the results.
///
/// # Safety
///
/// Both pointer tables, and every pointer stored in them, are dereferenced on
/// the device. They must be valid for the given dimensions and element type,
/// and must stay valid until the work enqueued on `stream` has finished.
#[allow(clippy::too_many_arguments)]
pub unsafe fn block_from_universal(
    universal_device_ptrs: *const *const c_void,
    block_device_ptrs: *const *mut c_void,
    num_blocks: usize,
    nh: usize,
    nl: usize,
    no: usize,
    nt: usize,
    hd: usize,
    dtype: TensorDataType,
    layout: BlockLayout,
    stream: cudaStream_t,
) -> cudaError_t {
    let (dtype_tag, layout_tag) = (dtype as i32, layout as i32);
    // SAFETY: the caller upholds the pointer and lifetime contract documented
    // above; the tags are the `i32` discriminants the launcher expects.
    unsafe {
        launch_block_from_universal(
            universal_device_ptrs,
            block_device_ptrs,
            num_blocks,
            nh,
            nl,
            no,
            nt,
            hd,
            dtype_tag,
            layout_tag,
            stream,
        )
    }
}
/// Copy between block stacks and operational buffers for `num_blocks`.
///
/// `backend` selects the copy mechanism (`Auto`, `KernelOnly`, `MemcpyAsync`,
/// `MemcpyBatch`). In `Auto` mode the fused kernel is tried first, then
/// `cudaMemcpyBatchAsync` (only available when the CUDA source was built with
/// a ≥12.9 runtime), then one plain `cudaMemcpyAsync` per chunk.
///
/// * `block_ptrs_host` / `block_ptrs_device` – the flattened
///   `[num_blocks][nl*no]` table of chunk pointers, once readable from the
///   host and once readable from the device.
/// * `operational_ptrs_host` / `operational_ptrs_device` – the `num_blocks`
///   operational base pointers, again as host- and device-readable tables.
/// * `inner`, `elem_size` – elements per chunk and bytes per element.
/// * `stream` – CUDA stream the copy is enqueued on.
///
/// Returns the status reported by the CUDA launcher.
///
/// # Safety
///
/// The host tables are read on the host and the device tables on the device.
/// Every buffer address they contain must be valid for `inner * elem_size`
/// bytes per chunk and must stay valid until the work enqueued on `stream`
/// has finished.
#[allow(clippy::too_many_arguments)]
pub unsafe fn operational_copy(
    block_ptrs_host: *const *const c_void,
    block_ptrs_device: *const *const c_void,
    operational_ptrs_host: *const *mut c_void,
    operational_ptrs_device: *const *const c_void,
    num_blocks: usize,
    nl: usize,
    no: usize,
    inner: usize,
    elem_size: usize,
    dtype: TensorDataType,
    direction: OperationalCopyDirection,
    backend: OperationalCopyBackend,
    stream: cudaStream_t,
) -> cudaError_t {
    let dtype_tag = dtype as i32;
    let direction_tag = direction as i32;
    let backend_tag = backend as i32;
    // SAFETY: the caller upholds the pointer and lifetime contract documented
    // above; the tags are the `i32` discriminants the launcher expects.
    unsafe {
        launch_operational_copy(
            block_ptrs_host,
            block_ptrs_device,
            operational_ptrs_host,
            operational_ptrs_device,
            num_blocks,
            nl,
            no,
            inner,
            elem_size,
            dtype_tag,
            direction_tag,
            backend_tag,
            stream,
        )
    }
}
// GPU round-trip test for the three wrappers. Only compiled for test builds
// with the `testing-cuda` feature enabled.
#[cfg(all(test, feature = "testing-cuda"))]
mod tests {
use super::*;
use cudarc::driver::{CudaContext, CudaSlice, DevicePtr, DevicePtrMut, DriverError};
use cudarc::runtime::sys as cuda_runtime;
// Exercises, in order: block -> universal, universal -> block,
// block -> operational and operational -> block, checking the device
// contents against host-side expectations after each step.
#[test]
fn fused_copy_roundtrip() -> Result<(), DriverError> {
// Treat "no usable CUDA device" as a pass so the test can run anywhere.
let device_count = match CudaContext::device_count() {
Ok(count) => count,
Err(_) => return Ok(()),
};
if device_count <= 0 {
return Ok(());
}
let ctx = CudaContext::new(0)?;
let stream = ctx.default_stream();
let stream_raw = stream.cu_stream() as cuda_runtime::cudaStream_t;
// Small problem: 2 blocks, each with nl * no = 4 chunks of
// nt * nh * hd = 24 f32 elements.
let nh = 2usize;
let nl = 2usize;
let no = 2usize;
let nt = 3usize;
let hd = 4usize;
let inner = nt * nh * hd;
let chunk_count = nl * no;
let block_volume = nh * nl * no * nt * hd;
let operational_volume = chunk_count * inner;
let num_blocks = 2usize;
let dtype = TensorDataType::F32;
let layout = BlockLayout::NHD;
// Host copies of the chunk data, the owning device allocations, and the
// chunk addresses both as raw pointers (host table) and as integers
// (uploaded below as the device table).
let mut host_block_chunks: Vec<Vec<Vec<f32>>> = Vec::with_capacity(num_blocks);
let mut block_slices: Vec<Vec<CudaSlice<f32>>> = Vec::with_capacity(num_blocks);
let mut block_ptrs_host: Vec<*const c_void> = Vec::with_capacity(num_blocks * chunk_count);
let mut block_ptr_values: Vec<usize> = Vec::with_capacity(num_blocks * chunk_count);
for block_idx in 0..num_blocks {
let mut host_chunks_for_block = Vec::with_capacity(chunk_count);
let mut slices_for_block = Vec::with_capacity(chunk_count);
for chunk_idx in 0..chunk_count {
let global_idx = block_idx * chunk_count + chunk_idx;
// Every element gets a distinct value: its global element index + 0.25.
let mut host_chunk = Vec::with_capacity(inner);
for offset in 0..inner {
host_chunk.push((global_idx * inner + offset) as f32 + 0.25f32);
}
let slice = stream.memcpy_stod(&host_chunk)?;
{
let (ptr_raw, _guard) = slice.device_ptr(&stream);
block_ptrs_host.push(ptr_raw as usize as *const c_void);
block_ptr_values.push(ptr_raw as usize);
}
slices_for_block.push(slice);
host_chunks_for_block.push(host_chunk);
}
block_slices.push(slices_for_block);
host_block_chunks.push(host_chunks_for_block);
}
// Device-side copy of the chunk pointer table.
let block_ptrs_device = stream.memcpy_stod(block_ptr_values.as_slice())?;
// One universal buffer per block, plus the device table of their addresses.
let mut universal_slices = Vec::with_capacity(num_blocks);
let mut universal_ptr_values = Vec::with_capacity(num_blocks);
for _ in 0..num_blocks {
let mut slice = unsafe { stream.alloc::<f32>(block_volume)? };
{
let (ptr_raw, _guard) = slice.device_ptr_mut(&stream);
universal_ptr_values.push(ptr_raw as usize);
}
universal_slices.push(slice);
}
let universal_ptrs_device = stream.memcpy_stod(universal_ptr_values.as_slice())?;
// One operational buffer per block, with host and device address tables.
let mut operational_slices = Vec::with_capacity(num_blocks);
let mut operational_ptrs_host = Vec::with_capacity(num_blocks);
let mut operational_ptr_values = Vec::with_capacity(num_blocks);
for _ in 0..num_blocks {
let mut slice = unsafe { stream.alloc::<f32>(operational_volume)? };
{
let (ptr_raw, _guard) = slice.device_ptr_mut(&stream);
operational_ptrs_host.push(ptr_raw as usize as *mut c_void);
operational_ptr_values.push(ptr_raw as usize);
}
operational_slices.push(slice);
}
let operational_ptrs_device = stream.memcpy_stod(operational_ptr_values.as_slice())?;
// Block -> Universal
{
let (block_ptrs_device_raw, _block_guard) = block_ptrs_device.device_ptr(&stream);
let block_ptrs_device_ptr = block_ptrs_device_raw as usize as *const *const c_void;
let (universal_ptrs_device_raw, _univ_guard) =
universal_ptrs_device.device_ptr(&stream);
let universal_ptrs_device_ptr =
universal_ptrs_device_raw as usize as *const *mut c_void;
let status = unsafe {
super::universal_from_block(
universal_ptrs_device_ptr,
block_ptrs_device_ptr,
num_blocks,
nh,
nl,
no,
nt,
hd,
dtype,
layout,
stream_raw,
)
};
assert_eq!(status, cuda_runtime::cudaError::cudaSuccess);
}
stream.synchronize()?;
// Host-side mirror of the chunk indexing used for each layout.
let inner_offset = |nt_idx: usize, nh_idx: usize, hd_idx: usize| match layout {
BlockLayout::NHD => ((nt_idx * nh) + nh_idx) * hd + hd_idx,
BlockLayout::HND => ((nh_idx * nt) + nt_idx) * hd + hd_idx,
};
// Every universal element at [nh, nl, no, nt, hd] must equal the value that
// was written at the matching position of chunk (nl, no).
for (block_idx, universal_slice) in universal_slices.iter().enumerate().take(num_blocks) {
let host_universal = stream.memcpy_dtov(universal_slice)?;
for nh_idx in 0..nh {
for nl_idx in 0..nl {
for no_idx in 0..no {
for nt_idx in 0..nt {
for hd_idx in 0..hd {
let universal_index =
((((nh_idx * nl + nl_idx) * no + no_idx) * nt + nt_idx) * hd)
+ hd_idx;
let chunk_idx = nl_idx * no + no_idx;
let offset = inner_offset(nt_idx, nh_idx, hd_idx);
let expected = ((block_idx * chunk_count + chunk_idx) * inner
+ offset) as f32
+ 0.25f32;
let value = host_universal[universal_index];
assert!(
(value - expected).abs() < 1e-5,
"universal mismatch block {} [{} {} {} {} {}]: {} vs {}",
block_idx,
nh_idx,
nl_idx,
no_idx,
nt_idx,
hd_idx,
value,
expected
);
}
}
}
}
}
}
// Universal -> Block
// Zero the chunks first so the check below can only pass if the scatter
// actually wrote them.
for block in &mut block_slices {
for slice in block {
stream.memset_zeros(slice)?;
}
}
stream.synchronize()?;
{
let (block_ptrs_device_raw, _block_guard) = block_ptrs_device.device_ptr(&stream);
let block_ptrs_device_mut = block_ptrs_device_raw as usize as *const *mut c_void;
let (universal_ptrs_device_raw, _univ_guard) =
universal_ptrs_device.device_ptr(&stream);
let universal_ptrs_device_const =
universal_ptrs_device_raw as usize as *const *const c_void;
let status = unsafe {
super::block_from_universal(
universal_ptrs_device_const,
block_ptrs_device_mut,
num_blocks,
nh,
nl,
no,
nt,
hd,
dtype,
layout,
stream_raw,
)
};
assert_eq!(status, cuda_runtime::cudaError::cudaSuccess);
}
stream.synchronize()?;
// The chunks must again hold the original host data.
for block_idx in 0..num_blocks {
for chunk_idx in 0..chunk_count {
let host_chunk = stream.memcpy_dtov(&block_slices[block_idx][chunk_idx])?;
for (inner_idx, value) in host_chunk.iter().enumerate() {
let expected = host_block_chunks[block_idx][chunk_idx][inner_idx];
assert!(
(value - expected).abs() < 1e-5,
"block mismatch block {} chunk {} offset {}: {} vs {}",
block_idx,
chunk_idx,
inner_idx,
value,
expected
);
}
}
}
// Block -> Operational
{
let (block_ptrs_device_raw, _block_guard) = block_ptrs_device.device_ptr(&stream);
let block_ptrs_device_ptr = block_ptrs_device_raw as usize as *const *const c_void;
let (operational_ptrs_device_raw, _op_guard) =
operational_ptrs_device.device_ptr(&stream);
let operational_ptrs_device_ptr =
operational_ptrs_device_raw as usize as *const *const c_void;
let status = unsafe {
super::operational_copy(
block_ptrs_host.as_ptr(),
block_ptrs_device_ptr,
operational_ptrs_host.as_ptr(),
operational_ptrs_device_ptr,
num_blocks,
nl,
no,
inner,
std::mem::size_of::<f32>(),
dtype,
OperationalCopyDirection::BlockToOperational,
OperationalCopyBackend::Auto,
stream_raw,
)
};
assert_eq!(status, cuda_runtime::cudaError::cudaSuccess);
}
stream.synchronize()?;
// Each operational buffer must be the block's chunks laid end to end.
for block_idx in 0..num_blocks {
let host_operational = stream.memcpy_dtov(&operational_slices[block_idx])?;
for chunk_idx in 0..chunk_count {
for inner_idx in 0..inner {
let expected = host_block_chunks[block_idx][chunk_idx][inner_idx];
let value = host_operational[chunk_idx * inner + inner_idx];
assert!(
(value - expected).abs() < 1e-5,
"operational pack mismatch block {} chunk {} offset {}: {} vs {}",
block_idx,
chunk_idx,
inner_idx,
value,
expected
);
}
}
}
// Operational -> Block
// Zero the chunks again so the unpack result is not masked by stale data.
for block in &mut block_slices {
for slice in block {
stream.memset_zeros(slice)?;
}
}
stream.synchronize()?;
{
let (block_ptrs_device_raw, _block_guard) = block_ptrs_device.device_ptr(&stream);
let (operational_ptrs_device_raw, _op_guard) =
operational_ptrs_device.device_ptr(&stream);
let operational_ptrs_device_const =
operational_ptrs_device_raw as usize as *const *const c_void;
let status = unsafe {
super::operational_copy(
block_ptrs_host.as_ptr(),
block_ptrs_device_raw as usize as *const *const c_void,
operational_ptrs_host.as_ptr(),
operational_ptrs_device_const,
num_blocks,
nl,
no,
inner,
std::mem::size_of::<f32>(),
dtype,
OperationalCopyDirection::OperationalToBlock,
OperationalCopyBackend::Auto,
stream_raw,
)
};
assert_eq!(status, cuda_runtime::cudaError::cudaSuccess);
}
stream.synchronize()?;
// The chunks must once more hold the original host data.
for block_idx in 0..num_blocks {
for chunk_idx in 0..chunk_count {
let host_chunk = stream.memcpy_dtov(&block_slices[block_idx][chunk_idx])?;
for (inner_idx, value) in host_chunk.iter().enumerate() {
let expected = host_block_chunks[block_idx][chunk_idx][inner_idx];
assert!(
(value - expected).abs() < 1e-5,
"operational unpack mismatch block {} chunk {} offset {}: {} vs {}",
block_idx,
chunk_idx,
inner_idx,
value,
expected
);
}
}
}
Ok(())
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment