Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
4564a387
Unverified
Commit
4564a387
authored
May 09, 2025
by
Ryan Olson
Committed by
GitHub
May 09, 2025
Browse files
feat: kv block manager (#965)
parent
24e2cbf5
Changes
38
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4612 additions
and
127 deletions
+4612
-127
.github/workflows/pre-merge-python.yml
.github/workflows/pre-merge-python.yml
+5
-2
.github/workflows/pre-merge-rust.yml
.github/workflows/pre-merge-rust.yml
+0
-2
Cargo.lock
Cargo.lock
+45
-2
Cargo.toml
Cargo.toml
+2
-0
container/Dockerfile.vllm
container/Dockerfile.vllm
+12
-2
lib/bindings/python/Cargo.lock
lib/bindings/python/Cargo.lock
+7
-0
lib/bindings/python/pyproject.toml
lib/bindings/python/pyproject.toml
+2
-1
lib/llm/Cargo.toml
lib/llm/Cargo.toml
+8
-3
lib/llm/build.rs
lib/llm/build.rs
+122
-115
lib/llm/src/block_manager.rs
lib/llm/src/block_manager.rs
+297
-0
lib/llm/src/block_manager/block.rs
lib/llm/src/block_manager/block.rs
+1683
-0
lib/llm/src/block_manager/block/registry.rs
lib/llm/src/block_manager/block/registry.rs
+390
-0
lib/llm/src/block_manager/block/state.rs
lib/llm/src/block_manager/block/state.rs
+179
-0
lib/llm/src/block_manager/block/transfer.rs
lib/llm/src/block_manager/block/transfer.rs
+634
-0
lib/llm/src/block_manager/block/transfer/cuda.rs
lib/llm/src/block_manager/block/transfer/cuda.rs
+292
-0
lib/llm/src/block_manager/block/transfer/memcpy.rs
lib/llm/src/block_manager/block/transfer/memcpy.rs
+80
-0
lib/llm/src/block_manager/block/transfer/nixl.rs
lib/llm/src/block_manager/block/transfer/nixl.rs
+161
-0
lib/llm/src/block_manager/block/transfer/strategy.rs
lib/llm/src/block_manager/block/transfer/strategy.rs
+296
-0
lib/llm/src/block_manager/block/view.rs
lib/llm/src/block_manager/block/view.rs
+224
-0
lib/llm/src/block_manager/config.rs
lib/llm/src/block_manager/config.rs
+173
-0
No files found.
.github/workflows/pre-merge-python.yml
View file @
4564a387
...
@@ -47,15 +47,18 @@ jobs:
...
@@ -47,15 +47,18 @@ jobs:
GITHUB_TOKEN
:
${{ secrets.CI_TOKEN }}
GITHUB_TOKEN
:
${{ secrets.CI_TOKEN }}
run
:
|
run
:
|
./container/build.sh --tag ${{ steps.define_image_tag.outputs.image_tag }} --target ci_minimum --framework ${{ matrix.framework }}
./container/build.sh --tag ${{ steps.define_image_tag.outputs.image_tag }} --target ci_minimum --framework ${{ matrix.framework }}
-
name
:
Run Rust checks (llm/block-manager)
run
:
|
docker run -w /workspace/lib/llm --name ${{ env.CONTAINER_ID }}_rust_checks ${{ steps.define_image_tag.outputs.image_tag }} bash -ec 'rustup component add rustfmt clippy && cargo fmt -- --check && cargo clippy --features block-manager --no-deps --all-targets -- -D warnings && cargo test --locked --all-targets --features=block-manager'
-
name
:
Run pytest
-
name
:
Run pytest
env
:
env
:
PYTEST_MARKS
:
"
pre_merge
or
mypy"
PYTEST_MARKS
:
"
pre_merge
or
mypy"
run
:
|
run
:
|
docker run -w /workspace --name ${{ env.CONTAINER_ID }} ${{ steps.define_image_tag.outputs.image_tag }} pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m "${{ env.PYTEST_MARKS }}"
docker run -w /workspace --name ${{ env.CONTAINER_ID }}
_pytest
${{ steps.define_image_tag.outputs.image_tag }} pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m "${{ env.PYTEST_MARKS }}"
-
name
:
Copy test report from test Container
-
name
:
Copy test report from test Container
if
:
always()
if
:
always()
run
:
|
run
:
|
docker cp ${{ env.CONTAINER_ID }}:/workspace/${{ env.PYTEST_XML_FILE }} .
docker cp ${{ env.CONTAINER_ID }}
_pytest
:/workspace/${{ env.PYTEST_XML_FILE }} .
-
name
:
Archive test report
-
name
:
Archive test report
uses
:
actions/upload-artifact@v4
uses
:
actions/upload-artifact@v4
if
:
always()
if
:
always()
...
...
.github/workflows/pre-merge-rust.yml
View file @
4564a387
...
@@ -23,8 +23,6 @@ on:
...
@@ -23,8 +23,6 @@ on:
# Run this workflow on pull requests targeting main but only if files in runtime/rust change.
# Run this workflow on pull requests targeting main but only if files in runtime/rust change.
pull_request
:
pull_request
:
branches
:
-
main
paths
:
paths
:
-
.github/workflows/pre-merge-rust.yml
-
.github/workflows/pre-merge-rust.yml
-
'
lib/runtime/**'
-
'
lib/runtime/**'
...
...
Cargo.lock
View file @
4564a387
...
@@ -513,6 +513,26 @@ dependencies = [
...
@@ -513,6 +513,26 @@ dependencies = [
"which",
"which",
]
]
[[package]]
name = "bindgen"
version = "0.71.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
dependencies = [
"bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 2.1.1",
"shlex",
"syn 2.0.100",
]
[[package]]
[[package]]
name = "bindgen_cuda"
name = "bindgen_cuda"
version = "0.1.5"
version = "0.1.5"
...
@@ -1606,6 +1626,8 @@ dependencies = [
...
@@ -1606,6 +1626,8 @@ dependencies = [
"minijinja",
"minijinja",
"minijinja-contrib",
"minijinja-contrib",
"ndarray",
"ndarray",
"nixl-sys",
"oneshot",
"prometheus",
"prometheus",
"proptest",
"proptest",
"rand 0.9.1",
"rand 0.9.1",
...
@@ -3379,7 +3401,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
...
@@ -3379,7 +3401,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
dependencies = [
"cfg-if 1.0.0",
"cfg-if 1.0.0",
"windows-targets 0.
52.6
",
"windows-targets 0.
48.5
",
]
]
[[package]]
[[package]]
...
@@ -3441,7 +3463,7 @@ version = "0.1.103"
...
@@ -3441,7 +3463,7 @@ version = "0.1.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b4ae3037b7d9b9fab9fd7905aeb04e214acb300599fa1ee698d6f759ee530f9"
checksum = "8b4ae3037b7d9b9fab9fd7905aeb04e214acb300599fa1ee698d6f759ee530f9"
dependencies = [
dependencies = [
"bindgen",
"bindgen
0.69.5
",
"cc",
"cc",
"cmake",
"cmake",
"find_cuda_helper",
"find_cuda_helper",
...
@@ -4057,6 +4079,21 @@ dependencies = [
...
@@ -4057,6 +4079,21 @@ dependencies = [
"libc",
"libc",
]
]
[[package]]
name = "nixl-sys"
version = "0.2.1-rc.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4698155ec791ae888482b0c1c5d19792a1b457167fa8ec0ffa487ffef8c8e5a"
dependencies = [
"bindgen 0.71.1",
"cc",
"libc",
"pkg-config",
"serde",
"thiserror 2.0.12",
"tracing",
]
[[package]]
[[package]]
name = "nkeys"
name = "nkeys"
version = "0.4.4"
version = "0.4.4"
...
@@ -4282,6 +4319,12 @@ version = "1.21.3"
...
@@ -4282,6 +4319,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "oneshot"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea"
[[package]]
[[package]]
name = "onig"
name = "onig"
version = "6.4.0"
version = "6.4.0"
...
...
Cargo.toml
View file @
4564a387
...
@@ -60,6 +60,7 @@ futures = { version = "0.3" }
...
@@ -60,6 +60,7 @@ futures = { version = "0.3" }
hf-hub
=
{
version
=
"0.4.2"
,
default-features
=
false
,
features
=
[
"tokio"
,
"rustls-tls"
]
}
hf-hub
=
{
version
=
"0.4.2"
,
default-features
=
false
,
features
=
[
"tokio"
,
"rustls-tls"
]
}
humantime
=
{
version
=
"2.2.0"
}
humantime
=
{
version
=
"2.2.0"
}
libc
=
{
version
=
"0.2"
}
libc
=
{
version
=
"0.2"
}
oneshot
=
{
version
=
"0.1.11"
,
features
=
[
"std"
,
"async"
]
}
prometheus
=
{
version
=
"0.14"
}
prometheus
=
{
version
=
"0.14"
}
rand
=
{
version
=
"0.9.0"
}
rand
=
{
version
=
"0.9.0"
}
serde
=
{
version
=
"1"
,
features
=
["derive"]
}
serde
=
{
version
=
"1"
,
features
=
["derive"]
}
...
@@ -77,6 +78,7 @@ uuid = { version = "1", features = ["v4", "serde"] }
...
@@ -77,6 +78,7 @@ uuid = { version = "1", features = ["v4", "serde"] }
url
=
{
version
=
"2.5"
,
features
=
["serde"]
}
url
=
{
version
=
"2.5"
,
features
=
["serde"]
}
xxhash-rust
=
{
version
=
"0.8"
,
features
=
[
"xxh3"
,
"const_xxh3"
]
}
xxhash-rust
=
{
version
=
"0.8"
,
features
=
[
"xxh3"
,
"const_xxh3"
]
}
[profile.dev.package]
[profile.dev.package]
insta.opt-level
=
3
insta.opt-level
=
3
...
...
container/Dockerfile.vllm
View file @
4564a387
...
@@ -108,6 +108,12 @@ WORKDIR /workspace
...
@@ -108,6 +108,12 @@ WORKDIR /workspace
# Copy nixl source, and use commit hash as cache hint
# Copy nixl source, and use commit hash as cache hint
COPY --from=nixl_base /opt/nixl /opt/nixl
COPY --from=nixl_base /opt/nixl /opt/nixl
COPY --from=nixl_base /opt/nixl/commit.txt /opt/nixl/commit.txt
COPY --from=nixl_base /opt/nixl/commit.txt /opt/nixl/commit.txt
RUN cd /opt/nixl && \
mkdir build && \
meson setup build/ --prefix=/usr/local/nixl && \
cd build/ && \
ninja && \
ninja install
### NATS & ETCD SETUP ###
### NATS & ETCD SETUP ###
# nats
# nats
...
@@ -310,6 +316,7 @@ ARG RELEASE_BUILD
...
@@ -310,6 +316,7 @@ ARG RELEASE_BUILD
WORKDIR /workspace
WORKDIR /workspace
RUN yum update -y \
RUN yum update -y \
&& yum install -y llvm-toolset \
&& yum install -y python3.12-devel \
&& yum install -y python3.12-devel \
&& yum install -y protobuf-compiler \
&& yum install -y protobuf-compiler \
&& yum clean all \
&& yum clean all \
...
@@ -322,6 +329,7 @@ ENV RUSTUP_HOME=/usr/local/rustup \
...
@@ -322,6 +329,7 @@ ENV RUSTUP_HOME=/usr/local/rustup \
COPY --from=base $RUSTUP_HOME $RUSTUP_HOME
COPY --from=base $RUSTUP_HOME $RUSTUP_HOME
COPY --from=base $CARGO_HOME $CARGO_HOME
COPY --from=base $CARGO_HOME $CARGO_HOME
COPY --from=base /usr/local/nixl /opt/nvidia/nvda_nixl
COPY --from=base /workspace /workspace
COPY --from=base /workspace /workspace
COPY --from=base $VIRTUAL_ENV $VIRTUAL_ENV
COPY --from=base $VIRTUAL_ENV $VIRTUAL_ENV
ENV PATH=$CARGO_HOME/bin:$VIRTUAL_ENV/bin:$PATH
ENV PATH=$CARGO_HOME/bin:$VIRTUAL_ENV/bin:$PATH
...
@@ -342,7 +350,7 @@ COPY launch /workspace/launch
...
@@ -342,7 +350,7 @@ COPY launch /workspace/launch
COPY deploy/sdk /workspace/deploy/sdk
COPY deploy/sdk /workspace/deploy/sdk
# Build Rust crate binaries packaged with the wheel
# Build Rust crate binaries packaged with the wheel
RUN cargo build --release --locked --features mistralrs,python \
RUN cargo build --release --locked --features mistralrs,python
,dynamo-llm/block-manager
\
-p dynamo-run \
-p dynamo-run \
-p llmctl \
-p llmctl \
# Multiple http named crates are present in dependencies, need to specify the path
# Multiple http named crates are present in dependencies, need to specify the path
...
@@ -370,6 +378,7 @@ WORKDIR /workspace
...
@@ -370,6 +378,7 @@ WORKDIR /workspace
COPY --from=wheel_builder /workspace/dist/ /workspace/dist/
COPY --from=wheel_builder /workspace/dist/ /workspace/dist/
COPY --from=wheel_builder /workspace/target/ /workspace/target/
COPY --from=wheel_builder /workspace/target/ /workspace/target/
COPY --from=wheel_builder /opt/nvidia/nvda_nixl /opt/nvidia/nvda_nixl
# Copy Cargo cache to avoid re-downloading dependencies
# Copy Cargo cache to avoid re-downloading dependencies
COPY --from=wheel_builder $CARGO_HOME $CARGO_HOME
COPY --from=wheel_builder $CARGO_HOME $CARGO_HOME
...
@@ -377,7 +386,7 @@ COPY . /workspace
...
@@ -377,7 +386,7 @@ COPY . /workspace
# Build rest of the crates
# Build rest of the crates
# Need to figure out rust caching to avoid rebuilding and remove exclude flags
# Need to figure out rust caching to avoid rebuilding and remove exclude flags
RUN cargo build --release --locked --workspace \
RUN cargo build --release --locked
--features block-manager
--workspace \
--exclude dynamo-run \
--exclude dynamo-run \
--exclude llmctl \
--exclude llmctl \
--exclude file://$PWD/components/http \
--exclude file://$PWD/components/http \
...
@@ -405,6 +414,7 @@ RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/la
...
@@ -405,6 +414,7 @@ RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/la
# Tell vllm to use the Dynamo LLM C API for KV Cache Routing
# Tell vllm to use the Dynamo LLM C API for KV Cache Routing
ENV VLLM_KV_CAPI_PATH=/opt/dynamo/bindings/lib/libdynamo_llm_capi.so
ENV VLLM_KV_CAPI_PATH=/opt/dynamo/bindings/lib/libdynamo_llm_capi.so
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/nvidia/nvda_nixl/lib/x86_64-linux-gnu/
##########################################
##########################################
########## Perf Analyzer Image ###########
########## Perf Analyzer Image ###########
...
...
lib/bindings/python/Cargo.lock
View file @
4564a387
...
@@ -1058,6 +1058,7 @@ dependencies = [
...
@@ -1058,6 +1058,7 @@ dependencies = [
"memmap2",
"memmap2",
"minijinja",
"minijinja",
"minijinja-contrib",
"minijinja-contrib",
"oneshot",
"prometheus",
"prometheus",
"rand 0.9.1",
"rand 0.9.1",
"rayon",
"rayon",
...
@@ -2859,6 +2860,12 @@ version = "1.21.3"
...
@@ -2859,6 +2860,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "oneshot"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea"
[[package]]
[[package]]
name = "onig"
name = "onig"
version = "6.4.0"
version = "6.4.0"
...
...
lib/bindings/python/pyproject.toml
View file @
4564a387
...
@@ -51,10 +51,11 @@ module-name = "dynamo._core"
...
@@ -51,10 +51,11 @@ module-name = "dynamo._core"
manifest-path
=
"Cargo.toml"
manifest-path
=
"Cargo.toml"
python-packages
=
["dynamo"]
python-packages
=
["dynamo"]
python-source
=
"src"
python-source
=
"src"
features
=
["dynamo-llm/block-manager"]
[build-system]
[build-system]
requires
=
[
"maturin>=1.0,<2.0"
,
"patchelf"
]
requires
=
[
"maturin>=1.0,<2.0"
,
"patchelf"
]
build-backend
=
"maturin"
build-backend
=
"maturin"
[tool.uv]
[tool.uv]
config-settings
=
{
build-args
=
'--auditwheel repair --manylinux'
}
config-settings
=
{
build-args
=
'--auditwheel repair --manylinux'
}
\ No newline at end of file
lib/llm/Cargo.toml
View file @
4564a387
...
@@ -27,7 +27,10 @@ description = "Dynamo LLM Library"
...
@@ -27,7 +27,10 @@ description = "Dynamo LLM Library"
[features]
[features]
default
=
[]
default
=
[]
cuda_kv
=
[
"dep:cudarc"
,
"dep:ndarray"
]
testing-full
=
[
"testing-cuda"
,
"testing-nixl"
]
testing-cuda
=
["dep:cudarc"]
testing-nixl
=
["dep:nixl-sys"]
block-manager
=
[
"dep:nixl-sys"
,
"dep:cudarc"
,
"dep:ndarray"
]
sentencepiece
=
["dep:sentencepiece"]
sentencepiece
=
["dep:sentencepiece"]
[dependencies]
[dependencies]
...
@@ -48,6 +51,7 @@ etcd-client = { workspace = true }
...
@@ -48,6 +51,7 @@ etcd-client = { workspace = true }
futures
=
{
workspace
=
true
}
futures
=
{
workspace
=
true
}
hf-hub
=
{
workspace
=
true
}
hf-hub
=
{
workspace
=
true
}
rand
=
{
workspace
=
true
}
rand
=
{
workspace
=
true
}
oneshot
=
{
workspace
=
true
}
prometheus
=
{
workspace
=
true
}
prometheus
=
{
workspace
=
true
}
serde
=
{
workspace
=
true
}
serde
=
{
workspace
=
true
}
serde_json
=
{
workspace
=
true
}
serde_json
=
{
workspace
=
true
}
...
@@ -72,8 +76,9 @@ derive-getters = "0.5"
...
@@ -72,8 +76,9 @@ derive-getters = "0.5"
regex
=
"1"
regex
=
"1"
rayon
=
"1"
rayon
=
"1"
# kv_cuda
# block_manager
cudarc
=
{
version
=
"0.16.2"
,
features
=
["cuda-12040"]
,
optional
=
true
}
nixl-sys
=
{
version
=
"0.2.1-rc.1"
,
optional
=
true
}
cudarc
=
{
version
=
"0.16.2"
,
features
=
["cuda-12020"]
,
optional
=
true
}
ndarray
=
{
version
=
"0.16"
,
optional
=
true
}
ndarray
=
{
version
=
"0.16"
,
optional
=
true
}
# protocols
# protocols
...
...
lib/llm/build.rs
View file @
4564a387
...
@@ -13,121 +13,128 @@
...
@@ -13,121 +13,128 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#[cfg(not(feature
=
"cuda_kv"
))]
fn
main
()
{}
#[cfg(feature
=
"cuda_kv"
)]
fn
main
()
{
fn
main
()
{
use
std
::{
path
::
PathBuf
,
process
::
Command
};
println!
(
"cargo:warning=Building with CUDA KV off"
);
println!
(
"cargo:rerun-if-changed=src/kernels/block_copy.cu"
);
// first do a which nvcc, if it is in the path
// if so, we don't need to set the cuda_lib
let
nvcc
=
Command
::
new
(
"which"
)
.arg
(
"nvcc"
)
.output
()
.unwrap
();
let
cuda_lib
=
if
nvcc
.status
.success
()
{
println!
(
"cargo:info=nvcc found in path"
);
// Extract the path from nvcc location by removing "bin/nvcc"
let
nvcc_path
=
String
::
from_utf8_lossy
(
&
nvcc
.stdout
)
.trim
()
.to_string
();
let
path
=
PathBuf
::
from
(
nvcc_path
);
if
let
Some
(
parent
)
=
path
.parent
()
{
// Remove "nvcc"
if
let
Some
(
cuda_root
)
=
parent
.parent
()
{
// Remove "bin"
cuda_root
.to_string_lossy
()
.to_string
()
}
else
{
// Fallback to CUDA_ROOT or default if path extraction fails
get_cuda_root_or_default
()
}
}
else
{
// Fallback to CUDA_ROOT or default if path extraction fails
get_cuda_root_or_default
()
}
}
else
{
println!
(
"cargo:warning=nvcc not found in path"
);
get_cuda_root_or_default
()
};
println!
(
"cargo:info=Using CUDA installation at: {}"
,
cuda_lib
);
let
cuda_lib_path
=
PathBuf
::
from
(
&
cuda_lib
)
.join
(
"lib64"
);
println!
(
"cargo:info=Using CUDA libs: {}"
,
cuda_lib_path
.display
());
println!
(
"cargo:rustc-link-search=native={}"
,
cuda_lib_path
.display
());
// Link against multiple CUDA libraries
println!
(
"cargo:rustc-link-lib=dylib=cudart"
);
println!
(
"cargo:rustc-link-lib=dylib=cuda"
);
println!
(
"cargo:rustc-link-lib=dylib=cudadevrt"
);
// Make sure the CUDA libraries are found before other system libraries
println!
(
"cargo:rustc-link-arg=-Wl,-rpath,{}"
,
cuda_lib_path
.display
()
);
// Create kernels directory for output if it doesn't exist
std
::
fs
::
create_dir_all
(
"src/kernels"
)
.unwrap_or_else
(|
_
|
{
println!
(
"Kernels directory already exists"
);
});
// Compile CUDA code
let
output
=
Command
::
new
(
"nvcc"
)
.arg
(
"src/kernels/block_copy.cu"
)
.arg
(
"-O3"
)
.arg
(
"--compiler-options"
)
.arg
(
"-fPIC"
)
.arg
(
"-o"
)
.arg
(
"src/kernels/libblock_copy.o"
)
.arg
(
"-c"
)
.output
()
.expect
(
"Failed to compile CUDA code"
);
if
!
output
.status
.success
()
{
panic!
(
"Failed to compile CUDA kernel: {}"
,
String
::
from_utf8_lossy
(
&
output
.stderr
)
);
}
// Create static library
#[cfg(target_os
=
"windows"
)]
{
Command
::
new
(
"lib"
)
.arg
(
"/OUT:src/kernels/block_copy.lib"
)
.arg
(
"src/kernels/libblock_copy.o"
)
.output
()
.expect
(
"Failed to create static library"
);
println!
(
"cargo:rustc-link-search=native=src/kernels"
);
println!
(
"cargo:rustc-link-lib=static=block_copy"
);
}
#[cfg(not(target_os
=
"windows"
))]
{
Command
::
new
(
"ar"
)
.arg
(
"rcs"
)
.arg
(
"src/kernels/libblock_copy.a"
)
.arg
(
"src/kernels/libblock_copy.o"
)
.output
()
.expect
(
"Failed to create static library"
);
println!
(
"cargo:rustc-link-search=native=src/kernels"
);
println!
(
"cargo:rustc-link-lib=static=block_copy"
);
println!
(
"cargo:rustc-link-lib=dylib=cudart"
);
println!
(
"cargo:rustc-link-lib=dylib=cuda"
);
println!
(
"cargo:rustc-link-lib=dylib=cudadevrt"
);
}
}
}
#[cfg(feature
=
"cuda_kv"
)]
// NOTE: Preserving this build.rs for reference. We may want to re-enable
fn
get_cuda_root_or_default
()
->
String
{
// custom kernel compilation in the future.
match
std
::
env
::
var
(
"CUDA_ROOT"
)
{
Ok
(
path
)
=>
path
,
// #[cfg(not(feature = "cuda_kv"))]
Err
(
_
)
=>
{
// fn main() {}
// Default locations based on OS
if
cfg!
(
target_os
=
"windows"
)
{
// #[cfg(feature = "cuda_kv")]
"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8"
.to_string
()
// fn main() {
}
else
{
// use std::{path::PathBuf, process::Command};
"/usr/local/cuda"
.to_string
()
}
// println!("cargo:rerun-if-changed=src/kernels/block_copy.cu");
}
}
// // first do a which nvcc, if it is in the path
}
// // if so, we don't need to set the cuda_lib
// let nvcc = Command::new("which").arg("nvcc").output().unwrap();
// let cuda_lib = if nvcc.status.success() {
// println!("cargo:info=nvcc found in path");
// // Extract the path from nvcc location by removing "bin/nvcc"
// let nvcc_path = String::from_utf8_lossy(&nvcc.stdout).trim().to_string();
// let path = PathBuf::from(nvcc_path);
// if let Some(parent) = path.parent() {
// // Remove "nvcc"
// if let Some(cuda_root) = parent.parent() {
// // Remove "bin"
// cuda_root.to_string_lossy().to_string()
// } else {
// // Fallback to CUDA_ROOT or default if path extraction fails
// get_cuda_root_or_default()
// }
// } else {
// // Fallback to CUDA_ROOT or default if path extraction fails
// get_cuda_root_or_default()
// }
// } else {
// println!("cargo:warning=nvcc not found in path");
// get_cuda_root_or_default()
// };
// println!("cargo:info=Using CUDA installation at: {}", cuda_lib);
// let cuda_lib_path = PathBuf::from(&cuda_lib).join("lib64");
// println!("cargo:info=Using CUDA libs: {}", cuda_lib_path.display());
// println!("cargo:rustc-link-search=native={}", cuda_lib_path.display());
// // Link against multiple CUDA libraries
// println!("cargo:rustc-link-lib=dylib=cudart");
// println!("cargo:rustc-link-lib=dylib=cuda");
// println!("cargo:rustc-link-lib=dylib=cudadevrt");
// // Make sure the CUDA libraries are found before other system libraries
// println!(
// "cargo:rustc-link-arg=-Wl,-rpath,{}",
// cuda_lib_path.display()
// );
// // Create kernels directory for output if it doesn't exist
// std::fs::create_dir_all("src/kernels").unwrap_or_else(|_| {
// println!("Kernels directory already exists");
// });
// // Compile CUDA code
// let output = Command::new("nvcc")
// .arg("src/kernels/block_copy.cu")
// .arg("-O3")
// .arg("--compiler-options")
// .arg("-fPIC")
// .arg("-o")
// .arg("src/kernels/libblock_copy.o")
// .arg("-c")
// .output()
// .expect("Failed to compile CUDA code");
// if !output.status.success() {
// panic!(
// "Failed to compile CUDA kernel: {}",
// String::from_utf8_lossy(&output.stderr)
// );
// }
// // Create static library
// #[cfg(target_os = "windows")]
// {
// Command::new("lib")
// .arg("/OUT:src/kernels/block_copy.lib")
// .arg("src/kernels/libblock_copy.o")
// .output()
// .expect("Failed to create static library");
// println!("cargo:rustc-link-search=native=src/kernels");
// println!("cargo:rustc-link-lib=static=block_copy");
// }
// #[cfg(not(target_os = "windows"))]
// {
// Command::new("ar")
// .arg("rcs")
// .arg("src/kernels/libblock_copy.a")
// .arg("src/kernels/libblock_copy.o")
// .output()
// .expect("Failed to create static library");
// println!("cargo:rustc-link-search=native=src/kernels");
// println!("cargo:rustc-link-lib=static=block_copy");
// println!("cargo:rustc-link-lib=dylib=cudart");
// println!("cargo:rustc-link-lib=dylib=cuda");
// println!("cargo:rustc-link-lib=dylib=cudadevrt");
// }
// }
// #[cfg(feature = "cuda_kv")]
// fn get_cuda_root_or_default() -> String {
// match std::env::var("CUDA_ROOT") {
// Ok(path) => path,
// Err(_) => {
// // Default locations based on OS
// if cfg!(target_os = "windows") {
// "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8".to_string()
// } else {
// "/usr/local/cuda".to_string()
// }
// }
// }
// }
lib/llm/src/block_manager.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Block Manager for LLM KV Cache
//!
//! This module provides functionality for managing KV blocks in LLM attention
//! mechanisms. It handles storage allocation, block management, and safe access
//! patterns for both system memory and remote (NIXL) storage.
mod
config
;
mod
state
;
pub
mod
block
;
pub
mod
events
;
pub
mod
layout
;
pub
mod
pool
;
pub
mod
storage
;
pub
use
crate
::
common
::
dtype
::
DType
;
pub
use
block
::{
nixl
::{
AsBlockDescriptorSet
,
BlockDescriptorList
,
IsImmutable
,
IsMutable
,
MutabilityKind
,
RemoteBlock
,
},
transfer
::{
BlockTransferEngineV1
,
TransferRequestPut
},
BasicMetadata
,
BlockMetadata
,
Blocks
,
};
pub
use
config
::
*
;
pub
use
layout
::{
nixl
::
NixlLayout
,
LayoutConfig
,
LayoutConfigBuilder
,
LayoutError
,
LayoutType
};
pub
use
pool
::
BlockPool
;
pub
use
storage
::{
nixl
::
NixlRegisterableStorage
,
DeviceStorage
,
PinnedStorage
,
Storage
,
StorageAllocator
,
};
pub
use
tokio_util
::
sync
::
CancellationToken
;
use
anyhow
::{
Context
,
Result
};
use
block
::
nixl
::{
BlockMutability
,
NixlBlockSet
,
RemoteBlocks
,
SerializedNixlBlockSet
};
use
derive_builder
::
Builder
;
use
nixl_sys
::
Agent
as
NixlAgent
;
use
std
::{
collections
::
HashMap
,
sync
::{
Arc
,
RwLock
},
};
use
storage
::
nixl
::
MemType
;
use
validator
::
Validate
;
pub
type
WorkerID
=
u64
;
pub
type
ReferenceBlockManager
=
KvBlockManager
<
BasicMetadata
>
;
/// Represents the different cache levels for KV blocks
pub
enum
CacheLevel
{
/// Represents KV blocks in GPU memory
G1
,
/// Represents KV blocks in CPU memory
G2
,
/// Represents KV blocks in Local NVMe storage
G3
,
/// Represents KV blocks in Remote NVMe storage
G4
,
}
// When we construct the pool:
// 1. instantiate the runtime,
// 2. build layout::LayoutConfigs for each of the requested storage types
// 3. register the layouts with the NIXL agent if enabled
// 4. construct a Blocks object for each layout providing a unique block_set_idx
// for each layout type.
// 5. initialize the pools for each set of blocks
pub
struct
KvBlockManager
<
Metadata
:
BlockMetadata
>
{
state
:
Arc
<
state
::
KvBlockManagerState
<
Metadata
>>
,
cancellation_token
:
CancellationToken
,
}
impl
<
Metadata
:
BlockMetadata
>
KvBlockManager
<
Metadata
>
{
/// Create a new [KvBlockManager]
///
/// The returned object is a frontend to the [KvBlockManager] which owns the cancellation
/// tokens. When this object gets drop, the cancellation token will be cancelled and begin
/// the gracefully shutdown of the block managers internal state.
pub
fn
new
(
config
:
KvBlockManagerConfig
)
->
Result
<
Self
>
{
let
mut
config
=
config
;
// The frontend of the KvBlockManager will take ownership of the cancellation token
// and will be responsible for cancelling the task when the KvBlockManager is dropped
let
cancellation_token
=
config
.runtime.cancellation_token
.clone
();
// The internal state will use a child token of the original token
config
.runtime.cancellation_token
=
cancellation_token
.child_token
();
// Create the internal state
let
state
=
state
::
KvBlockManagerState
::
new
(
config
)
?
;
Ok
(
Self
{
state
,
cancellation_token
,
})
}
/// Exports the local blockset configuration as a serialized object.
pub
fn
export_local_blockset
(
&
self
)
->
Result
<
SerializedNixlBlockSet
>
{
self
.state
.export_local_blockset
()
}
/// Imports a remote blockset configuration from a serialized object.
pub
fn
import_remote_blockset
(
&
self
,
serialized_blockset
:
SerializedNixlBlockSet
,
)
->
Result
<
()
>
{
self
.state
.import_remote_blockset
(
serialized_blockset
)
}
/// Get a [`Vec<RemoteBlock<IsImmutable>>`] from a [`BlockDescriptorList`]
pub
fn
get_remote_blocks_immutable
(
&
self
,
bds
:
&
BlockDescriptorList
,
)
->
Result
<
Vec
<
RemoteBlock
<
IsImmutable
>>>
{
self
.state
.get_remote_blocks_immutable
(
bds
)
}
/// Get a [`Vec<RemoteBlock<IsMutable>>`] from a [`BlockDescriptorList`]
pub
fn
get_remote_blocks_mutable
(
&
self
,
bds
:
&
BlockDescriptorList
,
)
->
Result
<
Vec
<
RemoteBlock
<
IsMutable
>>>
{
self
.state
.get_remote_blocks_mutable
(
bds
)
}
/// Get a reference to the host block pool
pub
fn
host
(
&
self
)
->
Option
<&
BlockPool
<
PinnedStorage
,
Metadata
>>
{
self
.state
.host
()
}
/// Get a reference to the device block pool
pub
fn
device
(
&
self
)
->
Option
<&
BlockPool
<
DeviceStorage
,
Metadata
>>
{
self
.state
.device
()
}
/// Get the worker ID
pub
fn
worker_id
(
&
self
)
->
WorkerID
{
self
.state
.worker_id
()
}
}
impl
<
Metadata
:
BlockMetadata
>
Drop
for
KvBlockManager
<
Metadata
>
{
fn
drop
(
&
mut
self
)
{
self
.cancellation_token
.cancel
();
}
}
#[cfg(all(test,
feature
=
"testing-full"
))]
mod
tests
{
use
super
::
*
;
use
std
::
sync
::
atomic
::{
AtomicU64
,
Ordering
};
// Atomic Counter for Worker ID
static
WORKER_ID
:
AtomicU64
=
AtomicU64
::
new
(
1337
);
fn
create_reference_block_manager
()
->
ReferenceBlockManager
{
let
worker_id
=
WORKER_ID
.fetch_add
(
1
,
Ordering
::
SeqCst
);
let
config
=
KvBlockManagerConfig
::
builder
()
.runtime
(
KvManagerRuntimeConfig
::
builder
()
.worker_id
(
worker_id
)
.build
()
.unwrap
(),
)
.model
(
KvManagerModelConfig
::
builder
()
.num_layers
(
3
)
.page_size
(
4
)
.inner_dim
(
16
)
.build
()
.unwrap
(),
)
.host_layout
(
KvManagerLayoutConfig
::
builder
()
.num_blocks
(
16
)
.allocator
(
storage
::
PinnedAllocator
::
default
())
.build
()
.unwrap
(),
)
.device_layout
(
KvManagerLayoutConfig
::
builder
()
.num_blocks
(
8
)
.allocator
(
storage
::
DeviceAllocator
::
new
(
0
)
.unwrap
())
.build
()
.unwrap
(),
)
.build
()
.unwrap
();
ReferenceBlockManager
::
new
(
config
)
.unwrap
()
}
#[tokio::test]
async
fn
test_reference_block_manager_inherited_async_runtime
()
{
dynamo_runtime
::
logging
::
init
();
let
_
block_manager
=
create_reference_block_manager
();
}
#[test]
fn
test_reference_block_manager_blocking
()
{
dynamo_runtime
::
logging
::
init
();
let
_
block_manager
=
create_reference_block_manager
();
}
// This tests mimics the behavior of two unique kvbm workers exchanging blocksets
// Each KvBlockManager is a unique worker in this test, each has its resources including
// it's own worker_ids, nixl_agent, and block pools.
//
// This test is meant to mimic the behavior of the basic nixl integration test found here:
// https://github.com/ai-dynamo/nixl/blob/main/src/bindings/rust/src/tests.rs
#[tokio::test]
async
fn
test_reference_block_managers
()
{
dynamo_runtime
::
logging
::
init
();
// create two block managers - mimics two unique dynamo workers
let
kvbm_0
=
create_reference_block_manager
();
let
kvbm_1
=
create_reference_block_manager
();
assert_ne!
(
kvbm_0
.worker_id
(),
kvbm_1
.worker_id
());
// in dynamo, we would exchange the blocksets via the discovery plane
let
blockset_0
=
kvbm_0
.export_local_blockset
()
.unwrap
();
let
blockset_1
=
kvbm_1
.export_local_blockset
()
.unwrap
();
// in dynamo, we would be watching the discovery plane for remote blocksets
kvbm_0
.import_remote_blockset
(
blockset_1
)
.unwrap
();
kvbm_1
.import_remote_blockset
(
blockset_0
)
.unwrap
();
// Worker 0
// Allocate 4 mutable blocks on the host
let
blocks_0
=
kvbm_0
.host
()
.unwrap
()
.allocate_blocks
(
4
)
.await
.unwrap
();
// Create a BlockDescriptorList for the mutable blocks
// let blockset_0 = BlockDescriptorList::from_mutable_blocks(&blocks_0).unwrap();
let
blockset_0
=
blocks_0
.as_block_descriptor_set
()
.unwrap
();
// Worker 1
// Create a RemoteBlock list from blockset_0
let
_
blocks_1
=
kvbm_1
.host
()
.unwrap
()
.allocate_blocks
(
4
)
.await
.unwrap
();
let
mut
_
remote_blocks_0
=
kvbm_1
.get_remote_blocks_mutable
(
&
blockset_0
)
.unwrap
();
// TODO(#967) - Enable with TransferEngine
// // Create a TransferRequestPut for the mutable blocks
// let transfer_request = TransferRequestPut::new(&blocks_0, &mut remote_blocks_0).unwrap();
// // Validate blocks - this could be an expensive operation
// // TODO: Create an ENV trigger debug flag which will call this on every transfer request
// // In this case, we expect an error because we have overlapping blocks as we are sending to/from the same blocks
// // because we are using the wrong target (artifact of the test setup allowing variable to cross what woudl be
// // worker boundaries)
// assert!(transfer_request.validate_blocks().is_err());
// // This is proper request - PUT from worker 1 (local) to worker 0 (remote)
// let transfer_request = TransferRequestPut::new(&blocks_1, &mut remote_blocks_0).unwrap();
// assert!(transfer_request.validate_blocks().is_ok());
// // Execute the transfer request
// transfer_request.execute().unwrap();
// let mut put_request = PutRequestBuilder::<_, _>::builder();
// put_request.from(&blocks_1).to(&mut remote_blocks_0);
// // Create a Put request direct between two local blocks
// // split the blocks into two vecs each with 2 blocks
// let mut blocks_1 = blocks_1;
// let slice_0 = blocks_1.split_off(2);
// let mut slice_1 = blocks_1;
// let transfer_request = TransferRequestPut::new(&slice_0, &mut slice_1).unwrap();
// assert!(transfer_request.validate_blocks().is_ok());
// // Execute the transfer request
// transfer_request.execute().unwrap();
}
}
lib/llm/src/block_manager/block.rs
0 → 100644
View file @
4564a387
This diff is collapsed.
Click to expand it.
lib/llm/src/block_manager/block/registry.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::{
collections
::
HashMap
,
sync
::{
Arc
,
Weak
},
};
use
super
::
super
::
events
::{
EventManager
,
EventReleaseManager
,
PublishHandle
};
use
super
::
state
::
BlockState
;
use
crate
::
tokens
::{
BlockHash
,
SequenceHash
,
TokenBlock
};
use
derive_getters
::
Getters
;
#[derive(Debug,
thiserror::Error)]
pub
enum
BlockRegistationError
{
#[error(
"Block already registered"
)]
BlockAlreadyRegistered
(
SequenceHash
),
#[error(
"Invalid state: {0}"
)]
InvalidState
(
String
),
}
/// Error returned when an attempt is made to unregister a block that is still active.
#[derive(Debug,
thiserror::Error)]
#[error(
"Failed to unregister block: {0}"
)]
pub
struct
UnregisterFailure
(
SequenceHash
);
#[derive()]
pub
struct
BlockRegistry
{
blocks
:
HashMap
<
SequenceHash
,
Weak
<
RegistrationHandle
>>
,
event_manager
:
Arc
<
dyn
EventManager
>
,
}
impl
BlockRegistry
{
pub
fn
new
(
event_manager
:
Arc
<
dyn
EventManager
>
)
->
Self
{
Self
{
blocks
:
HashMap
::
new
(),
event_manager
,
}
}
pub
fn
is_registered
(
&
self
,
sequence_hash
:
SequenceHash
)
->
bool
{
if
let
Some
(
handle
)
=
self
.blocks
.get
(
&
sequence_hash
)
{
if
let
Some
(
_
handle
)
=
handle
.upgrade
()
{
return
true
;
}
}
false
}
pub
fn
register_block
(
&
mut
self
,
block_state
:
&
mut
BlockState
,
)
->
Result
<
PublishHandle
,
BlockRegistationError
>
{
match
block_state
{
BlockState
::
Reset
=>
Err
(
BlockRegistationError
::
InvalidState
(
"Block is in Reset state"
.to_string
(),
)),
BlockState
::
Partial
(
_
partial
)
=>
Err
(
BlockRegistationError
::
InvalidState
(
"Block is in Partial state"
.to_string
(),
)),
BlockState
::
Complete
(
state
)
=>
{
let
sequence_hash
=
state
.token_block
()
.sequence_hash
();
if
let
Some
(
handle
)
=
self
.blocks
.get
(
&
sequence_hash
)
{
if
let
Some
(
_
handle
)
=
handle
.upgrade
()
{
return
Err
(
BlockRegistationError
::
BlockAlreadyRegistered
(
sequence_hash
));
}
}
// Create the [RegistrationHandle] and [PublishHandle]
let
publish_handle
=
Self
::
create_publish_handle
(
state
.token_block
(),
self
.event_manager
.clone
());
let
reg_handle
=
publish_handle
.remove_handle
();
// Insert the [RegistrationHandle] into the registry
self
.blocks
.insert
(
sequence_hash
,
Arc
::
downgrade
(
&
reg_handle
));
// Update the [BlockState] to [BlockState::Registered]
let
_
=
std
::
mem
::
replace
(
block_state
,
BlockState
::
Registered
(
reg_handle
));
Ok
(
publish_handle
)
}
BlockState
::
Registered
(
registered
)
=>
Err
(
BlockRegistationError
::
BlockAlreadyRegistered
(
registered
.sequence_hash
()),
),
}
}
pub
fn
unregister_block
(
&
mut
self
,
sequence_hash
:
SequenceHash
,
)
->
Result
<
(),
UnregisterFailure
>
{
if
let
Some
(
handle
)
=
self
.blocks
.get
(
&
sequence_hash
)
{
if
handle
.upgrade
()
.is_none
()
{
self
.blocks
.remove
(
&
sequence_hash
);
return
Ok
(());
}
else
{
return
Err
(
UnregisterFailure
(
sequence_hash
));
}
}
Ok
(())
}
fn
create_publish_handle
(
token_block
:
&
TokenBlock
,
event_manager
:
Arc
<
dyn
EventManager
>
,
)
->
PublishHandle
{
let
reg_handle
=
RegistrationHandle
::
from_token_block
(
token_block
,
event_manager
.clone
());
PublishHandle
::
new
(
reg_handle
,
event_manager
)
}
}
#[derive(Getters)]
pub
struct
RegistrationHandle
{
#[getter(copy)]
block_hash
:
BlockHash
,
#[getter(copy)]
sequence_hash
:
SequenceHash
,
#[getter(copy)]
parent_sequence_hash
:
Option
<
SequenceHash
>
,
#[getter(skip)]
release_manager
:
Arc
<
dyn
EventReleaseManager
>
,
}
impl
RegistrationHandle
{
fn
from_token_block
(
token_block
:
&
TokenBlock
,
release_manager
:
Arc
<
dyn
EventReleaseManager
>
,
)
->
Self
{
Self
{
block_hash
:
token_block
.block_hash
(),
sequence_hash
:
token_block
.sequence_hash
(),
parent_sequence_hash
:
token_block
.parent_sequence_hash
(),
release_manager
,
}
}
}
impl
std
::
fmt
::
Debug
for
RegistrationHandle
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
write!
(
f
,
"RegistrationHandle {{ sequence_hash: {}; block_hash: {}; parent_sequence_hash: {:?} }}"
,
self
.sequence_hash
,
self
.block_hash
,
self
.parent_sequence_hash
)
}
}
impl
Drop
for
RegistrationHandle
{
fn
drop
(
&
mut
self
)
{
self
.release_manager
.block_release
(
self
);
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
block_manager
::
events
::
tests
::{
EventType
,
MockEventManager
};
use
crate
::
tokens
::{
TokenBlockSequence
,
Tokens
};
fn
create_sequence
()
->
TokenBlockSequence
{
let
tokens
=
Tokens
::
from
(
vec!
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
]);
// NOTE: 1337 was the original seed, so we are temporarily using that here to prove the logic has not changed
let
sequence
=
TokenBlockSequence
::
new
(
tokens
,
4
,
Some
(
1337_u64
));
assert_eq!
(
sequence
.blocks
()
.len
(),
2
);
assert_eq!
(
sequence
.current_block
()
.len
(),
2
);
assert_eq!
(
sequence
.blocks
()[
0
]
.tokens
(),
&
vec!
[
1
,
2
,
3
,
4
]);
assert_eq!
(
sequence
.blocks
()[
0
]
.sequence_hash
(),
14643705804678351452
);
assert_eq!
(
sequence
.blocks
()[
1
]
.tokens
(),
&
vec!
[
5
,
6
,
7
,
8
]);
assert_eq!
(
sequence
.blocks
()[
1
]
.sequence_hash
(),
4945711292740353085
);
assert_eq!
(
sequence
.current_block
()
.tokens
(),
&
vec!
[
9
,
10
]);
sequence
}
#[test]
fn
test_mock_event_manager_with_single_publish_handle
()
{
let
sequence
=
create_sequence
();
let
(
event_manager
,
mut
rx
)
=
MockEventManager
::
new
();
let
publish_handle
=
BlockRegistry
::
create_publish_handle
(
&
sequence
.blocks
()[
0
],
event_manager
.clone
());
// no event should have been triggered
assert
!
(
rx
.try_recv
()
.is_err
());
// we shoudl get two events when this is dropped, since we never took ownership of the RegistrationHandle
drop
(
publish_handle
);
// the first event should be a Register event
let
events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events
.len
(),
1
);
assert_eq!
(
events
[
0
],
EventType
::
Register
(
sequence
.blocks
()[
0
]
.sequence_hash
())
);
// the second event should be a Remove event
let
events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events
.len
(),
1
);
assert_eq!
(
events
[
0
],
EventType
::
Remove
(
sequence
.blocks
()[
0
]
.sequence_hash
())
);
// there should be no more events
assert
!
(
rx
.try_recv
()
.is_err
());
}
#[test]
fn
test_mock_event_manager_single_publish_handle_removed
()
{
let
sequence
=
create_sequence
();
let
block_to_test
=
&
sequence
.blocks
()[
0
];
let
expected_sequence_hash
=
block_to_test
.sequence_hash
();
let
(
event_manager
,
mut
rx
)
=
MockEventManager
::
new
();
let
publish_handle
=
BlockRegistry
::
create_publish_handle
(
block_to_test
,
event_manager
.clone
());
// Remove the registration handle before dropping the publish handle
let
reg_handle
=
publish_handle
.remove_handle
();
// no event should have been triggered yet
assert
!
(
rx
.try_recv
()
.is_err
());
// Drop the publish handle - it SHOULD trigger a Register event now because remove_handle doesn't disarm
drop
(
publish_handle
);
let
register_events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
register_events
.len
(),
1
,
"Register event should be triggered on PublishHandle drop"
);
assert_eq!
(
register_events
[
0
],
EventType
::
Register
(
expected_sequence_hash
),
"Expected Register event"
);
// Drop the registration handle - this SHOULD trigger the Remove event
drop
(
reg_handle
);
let
events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events
.len
(),
1
);
assert_eq!
(
events
[
0
],
EventType
::
Remove
(
expected_sequence_hash
),
"Only Remove event should be triggered"
);
// there should be no more events
assert
!
(
rx
.try_recv
()
.is_err
());
}
#[test]
fn
test_mock_event_manager_publisher_multiple_handles_removed
()
{
let
sequence
=
create_sequence
();
let
block1
=
&
sequence
.blocks
()[
0
];
let
block2
=
&
sequence
.blocks
()[
1
];
let
hash1
=
block1
.sequence_hash
();
let
hash2
=
block2
.sequence_hash
();
let
(
event_manager
,
mut
rx
)
=
MockEventManager
::
new
();
let
mut
publisher
=
event_manager
.publisher
();
let
publish_handle1
=
BlockRegistry
::
create_publish_handle
(
block1
,
event_manager
.clone
());
let
publish_handle2
=
BlockRegistry
::
create_publish_handle
(
block2
,
event_manager
.clone
());
// Remove handles before adding to publisher
let
reg_handle1
=
publish_handle1
.remove_handle
();
let
reg_handle2
=
publish_handle2
.remove_handle
();
// Add disarmed handles to publisher
publisher
.take_handle
(
publish_handle1
);
publisher
.take_handle
(
publish_handle2
);
// no events yet
assert
!
(
rx
.try_recv
()
.is_err
());
// Drop the publisher - should trigger a single Publish event with both Register events
drop
(
publisher
);
let
events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events
.len
(),
2
,
"Should receive two Register events in one batch"
);
// Order isn't guaranteed, so check for both
assert
!
(
events
.contains
(
&
EventType
::
Register
(
hash1
)));
assert
!
(
events
.contains
(
&
EventType
::
Register
(
hash2
)));
// no more events immediately after publish
assert
!
(
rx
.try_recv
()
.is_err
());
// Drop registration handles individually - should trigger Remove events
drop
(
reg_handle1
);
let
events1
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events1
.len
(),
1
);
assert_eq!
(
events1
[
0
],
EventType
::
Remove
(
hash1
));
drop
(
reg_handle2
);
let
events2
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events2
.len
(),
1
);
assert_eq!
(
events2
[
0
],
EventType
::
Remove
(
hash2
));
// no more events
assert
!
(
rx
.try_recv
()
.is_err
());
}
#[test]
fn
test_publisher_empty_drop
()
{
let
(
event_manager
,
mut
rx
)
=
MockEventManager
::
new
();
let
publisher
=
event_manager
.publisher
();
drop
(
publisher
);
// No events should be sent
assert
!
(
rx
.try_recv
()
.is_err
());
}
#[test]
fn
test_publisher_publish_multiple_times
()
{
let
sequence
=
create_sequence
();
let
block1
=
&
sequence
.blocks
()[
0
];
let
hash1
=
block1
.sequence_hash
();
let
(
event_manager
,
mut
rx
)
=
MockEventManager
::
new
();
let
mut
publisher
=
event_manager
.publisher
();
let
publish_handle1
=
BlockRegistry
::
create_publish_handle
(
block1
,
event_manager
.clone
());
publisher
.take_handle
(
publish_handle1
);
// First publish call
publisher
.publish
();
let
events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events
.len
(),
1
);
assert_eq!
(
events
[
0
],
EventType
::
Register
(
hash1
));
// The RegistrationHandle Arc was taken by the publisher and dropped after the publish call
// So, the Remove event should follow immediately.
let
remove_events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
remove_events
.len
(),
1
,
"Remove event should be triggered after publish consumes the handle"
);
assert_eq!
(
remove_events
[
0
],
EventType
::
Remove
(
hash1
),
"Expected Remove event"
);
// Second publish call (should do nothing as handles were taken)
publisher
.publish
();
assert
!
(
rx
.try_recv
()
.is_err
());
// Drop publisher (should also do nothing)
drop
(
publisher
);
assert
!
(
rx
.try_recv
()
.is_err
());
}
}
lib/llm/src/block_manager/block/state.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
sync
::
Arc
;
use
derive_getters
::
Getters
;
use
super
::
registry
::
RegistrationHandle
;
use
super
::
Result
;
use
crate
::
tokens
::{
PartialTokenBlock
,
SaltHash
,
Token
,
TokenBlock
,
Tokens
};
#[derive(Debug,
thiserror::Error)]
#[error(
"Block state is invalid: {0}"
)]
pub
struct
BlockStateInvalid
(
pub
String
);
#[derive(Debug)]
pub
enum
BlockState
{
Reset
,
Partial
(
PartialState
),
Complete
(
CompleteState
),
Registered
(
Arc
<
RegistrationHandle
>
),
}
impl
BlockState
{
pub
fn
initialize_sequence
(
&
mut
self
,
page_size
:
usize
,
salt_hash
:
SaltHash
,
)
->
Result
<
(),
BlockStateInvalid
>
{
if
!
matches!
(
self
,
BlockState
::
Reset
)
{
return
Err
(
BlockStateInvalid
(
"Block is not reset"
.to_string
()));
}
let
block
=
PartialTokenBlock
::
create_sequence_root
(
page_size
,
salt_hash
);
*
self
=
BlockState
::
Partial
(
PartialState
::
new
(
block
));
Ok
(())
}
pub
fn
add_token
(
&
mut
self
,
token
:
Token
)
->
Result
<
()
>
{
match
self
{
BlockState
::
Partial
(
state
)
=>
Ok
(
state
.block
.push_token
(
token
)
?
),
_
=>
Err
(
BlockStateInvalid
(
"Block is not partial"
.to_string
()))
?
,
}
}
pub
fn
add_tokens
(
&
mut
self
,
tokens
:
Tokens
)
->
Result
<
Tokens
>
{
match
self
{
BlockState
::
Partial
(
state
)
=>
Ok
(
state
.block
.push_tokens
(
tokens
)),
_
=>
Err
(
BlockStateInvalid
(
"Block is not partial"
.to_string
()))
?
,
}
}
pub
fn
pop_token
(
&
mut
self
)
->
Result
<
()
>
{
match
self
{
BlockState
::
Partial
(
state
)
=>
{
state
.block
.pop_token
()
?
;
Ok
(())
}
_
=>
Err
(
BlockStateInvalid
(
"Block is not partial"
.to_string
()))
?
,
}
}
pub
fn
pop_tokens
(
&
mut
self
,
count
:
usize
)
->
Result
<
()
>
{
match
self
{
BlockState
::
Partial
(
state
)
=>
{
state
.block
.pop_tokens
(
count
)
?
;
Ok
(())
}
_
=>
Err
(
BlockStateInvalid
(
"Block is not partial"
.to_string
()))
?
,
}
}
pub
fn
commit
(
&
mut
self
)
->
Result
<
()
>
{
match
self
{
BlockState
::
Partial
(
state
)
=>
{
let
token_block
=
state
.block
.commit
()
?
;
*
self
=
BlockState
::
Complete
(
CompleteState
::
new
(
token_block
));
Ok
(())
}
_
=>
Err
(
BlockStateInvalid
(
"Block is not partial"
.to_string
()))
?
,
}
}
pub
fn
apply_token_block
(
&
mut
self
,
token_block
:
TokenBlock
)
->
Result
<
()
>
{
match
self
{
BlockState
::
Reset
=>
{
*
self
=
BlockState
::
Complete
(
CompleteState
::
new
(
token_block
));
Ok
(())
}
_
=>
Err
(
BlockStateInvalid
(
"Block is not reset"
.to_string
()))
?
,
}
}
/// Returns the number of tokens currently in the block.
pub
fn
len
(
&
self
)
->
Option
<
usize
>
{
match
self
{
BlockState
::
Reset
=>
Some
(
0
),
BlockState
::
Partial
(
state
)
=>
Some
(
state
.block
.len
()),
BlockState
::
Complete
(
state
)
=>
Some
(
state
.token_block
.tokens
()
.len
()),
BlockState
::
Registered
(
_
)
=>
None
,
}
}
/// Returns the number of additional tokens that can be added.
pub
fn
remaining
(
&
self
)
->
usize
{
match
self
{
BlockState
::
Partial
(
state
)
=>
state
.block
.remaining
(),
_
=>
0
,
// Reset, Complete, Registered have 0 remaining capacity
}
}
/// Returns true if the block contains no tokens.
pub
fn
is_empty
(
&
self
)
->
bool
{
match
self
{
BlockState
::
Reset
=>
true
,
BlockState
::
Partial
(
state
)
=>
state
.block
.is_empty
(),
BlockState
::
Complete
(
_
)
=>
false
,
// Always full
BlockState
::
Registered
(
_
)
=>
false
,
// Always full
}
}
/// Returns a reference to the underlying TokenBlock if the state is Complete or Registered.
pub
fn
tokens
(
&
self
)
->
Option
<&
Tokens
>
{
match
self
{
BlockState
::
Reset
|
BlockState
::
Registered
(
_
)
=>
None
,
BlockState
::
Partial
(
state
)
=>
Some
(
state
.block
.tokens
()),
BlockState
::
Complete
(
state
)
=>
Some
(
state
.token_block
.tokens
()),
}
}
/// Returns true if the block is empty
pub
fn
is_reset
(
&
self
)
->
bool
{
matches!
(
self
,
BlockState
::
Reset
)
}
/// Returns true if the block is in the complete or registered state
pub
fn
is_complete
(
&
self
)
->
bool
{
matches!
(
self
,
BlockState
::
Complete
(
_
)
|
BlockState
::
Registered
(
_
))
}
/// Returns true if the block is in the registered state
pub
fn
is_registered
(
&
self
)
->
bool
{
matches!
(
self
,
BlockState
::
Registered
(
_
state
))
}
}
#[derive(Debug)]
pub
struct
PartialState
{
block
:
PartialTokenBlock
,
}
impl
PartialState
{
pub
fn
new
(
block
:
PartialTokenBlock
)
->
Self
{
Self
{
block
}
}
}
#[derive(Debug,
Getters)]
pub
struct
CompleteState
{
token_block
:
TokenBlock
,
}
impl
CompleteState
{
pub
fn
new
(
token_block
:
TokenBlock
)
->
Self
{
Self
{
token_block
}
}
}
lib/llm/src/block_manager/block/transfer.rs
0 → 100644
View file @
4564a387
This diff is collapsed.
Click to expand it.
lib/llm/src/block_manager/block/transfer/cuda.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
*
;
use
super
::
TransferError
;
use
crate
::
block_manager
::
storage
::{
DeviceStorage
,
PinnedStorage
};
use
anyhow
::
Result
;
use
cudarc
::
driver
::
result
as
cuda_result
;
use
std
::
ops
::
Range
;
type
CudaMemcpyFnPtr
=
unsafe
fn
(
src_ptr
:
*
const
u8
,
dst_ptr
:
*
mut
u8
,
size
:
usize
,
stream
:
&
CudaStream
,
)
->
Result
<
(),
TransferError
>
;
fn
cuda_memcpy_fn_ptr
(
strategy
:
&
TransferStrategy
)
->
Result
<
CudaMemcpyFnPtr
,
TransferError
>
{
match
strategy
{
TransferStrategy
::
CudaAsyncH2D
=>
Ok
(
cuda_memcpy_h2d
),
TransferStrategy
::
CudaAsyncD2H
=>
Ok
(
cuda_memcpy_d2h
),
TransferStrategy
::
CudaAsyncD2D
=>
Ok
(
cuda_memcpy_d2d
),
_
=>
Err
(
TransferError
::
ExecutionError
(
"Unsupported copy strategy for CUDA memcpy async"
.into
(),
)),
}
}
/// Copy a block from a source to a destination using CUDA memcpy
pub
fn
copy_block
<
'a
,
Source
,
Destination
>
(
sources
:
&
'a
Source
,
destinations
:
&
'a
mut
Destination
,
stream
:
&
CudaStream
,
strategy
:
TransferStrategy
,
)
->
Result
<
(),
TransferError
>
where
Source
:
BlockDataProvider
,
Destination
:
BlockDataProviderMut
,
{
let
src_data
=
sources
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
destinations
.block_data_mut
(
private
::
PrivateToken
);
let
memcpy_fn
=
cuda_memcpy_fn_ptr
(
&
strategy
)
?
;
#[cfg(debug_assertions)]
{
let
expected_strategy
=
expected_strategy
::
<
Source
::
StorageType
,
Destination
::
StorageType
>
();
assert_eq!
(
strategy
,
expected_strategy
);
}
if
src_data
.is_fully_contiguous
()
&&
dst_data
.is_fully_contiguous
()
{
let
src_view
=
src_data
.block_view
()
?
;
let
mut
dst_view
=
dst_data
.block_view_mut
()
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
unsafe
{
memcpy_fn
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
(),
stream
,
)
?
;
}
}
else
{
assert_eq!
(
src_data
.num_layers
(),
dst_data
.num_layers
());
copy_layers
(
0
..
src_data
.num_layers
(),
sources
,
destinations
,
stream
,
strategy
,
)
?
;
}
Ok
(())
}
/// Copy a range of layers from a source to a destination using CUDA memcpy
pub
fn
copy_layers
<
'a
,
Source
,
Destination
>
(
layer_range
:
Range
<
usize
>
,
sources
:
&
'a
Source
,
destinations
:
&
'a
mut
Destination
,
stream
:
&
CudaStream
,
strategy
:
TransferStrategy
,
)
->
Result
<
(),
TransferError
>
where
Source
:
BlockDataProvider
,
Destination
:
BlockDataProviderMut
,
{
let
src_data
=
sources
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
destinations
.block_data_mut
(
private
::
PrivateToken
);
let
memcpy_fn
=
cuda_memcpy_fn_ptr
(
&
strategy
)
?
;
#[cfg(debug_assertions)]
{
let
expected_strategy
=
expected_strategy
::
<
Source
::
StorageType
,
Destination
::
StorageType
>
();
assert_eq!
(
strategy
,
expected_strategy
);
}
for
layer_idx
in
layer_range
{
let
src_view
=
src_data
.layer_view
(
layer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
unsafe
{
memcpy_fn
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
(),
stream
,
)
?
;
}
}
Ok
(())
}
/// Helper function to perform the appropriate CUDA memcpy based on storage types
// Allow dead code because it's used in debug assertions
#[allow(dead_code)]
fn
expected_strategy
<
Source
:
Storage
,
Dest
:
Storage
>
()
->
TransferStrategy
{
match
(
std
::
any
::
TypeId
::
of
::
<
Source
>
(),
std
::
any
::
TypeId
::
of
::
<
Dest
>
(),
)
{
(
src
,
dst
)
if
src
==
std
::
any
::
TypeId
::
of
::
<
PinnedStorage
>
()
&&
dst
==
std
::
any
::
TypeId
::
of
::
<
DeviceStorage
>
()
=>
{
TransferStrategy
::
CudaAsyncH2D
}
(
src
,
dst
)
if
src
==
std
::
any
::
TypeId
::
of
::
<
DeviceStorage
>
()
&&
dst
==
std
::
any
::
TypeId
::
of
::
<
PinnedStorage
>
()
=>
{
TransferStrategy
::
CudaAsyncD2H
}
(
src
,
dst
)
if
src
==
std
::
any
::
TypeId
::
of
::
<
DeviceStorage
>
()
&&
dst
==
std
::
any
::
TypeId
::
of
::
<
DeviceStorage
>
()
=>
{
TransferStrategy
::
CudaAsyncD2D
}
_
=>
TransferStrategy
::
Invalid
,
}
}
/// H2D Implementation
#[inline(always)]
unsafe
fn
cuda_memcpy_h2d
(
src_ptr
:
*
const
u8
,
dst_ptr
:
*
mut
u8
,
size
:
usize
,
stream
:
&
CudaStream
,
)
->
Result
<
(),
TransferError
>
{
debug_assert!
(
!
src_ptr
.is_null
(),
"Source host pointer is null"
);
debug_assert!
(
!
dst_ptr
.is_null
(),
"Destination device pointer is null"
);
debug_assert!
(
(
src_ptr
as
usize
+
size
<=
dst_ptr
as
usize
)
||
(
dst_ptr
as
usize
+
size
<=
src_ptr
as
usize
),
"Source and destination device memory regions must not overlap for D2D copy"
);
let
src_slice
=
std
::
slice
::
from_raw_parts
(
src_ptr
,
size
);
cuda_result
::
memcpy_htod_async
(
dst_ptr
as
u64
,
src_slice
,
stream
.cu_stream
())
.map_err
(|
e
|
TransferError
::
ExecutionError
(
format!
(
"CUDA H2D memcpy failed: {}"
,
e
)))
?
;
Ok
(())
}
/// D2H Implementation
#[inline(always)]
unsafe
fn
cuda_memcpy_d2h
(
src_ptr
:
*
const
u8
,
dst_ptr
:
*
mut
u8
,
size
:
usize
,
stream
:
&
CudaStream
,
)
->
Result
<
(),
TransferError
>
{
debug_assert!
(
!
src_ptr
.is_null
(),
"Source device pointer is null"
);
debug_assert!
(
!
dst_ptr
.is_null
(),
"Destination host pointer is null"
);
debug_assert!
(
(
src_ptr
as
usize
+
size
<=
dst_ptr
as
usize
)
||
(
dst_ptr
as
usize
+
size
<=
src_ptr
as
usize
),
"Source and destination device memory regions must not overlap for D2D copy"
);
let
dst_slice
=
std
::
slice
::
from_raw_parts_mut
(
dst_ptr
,
size
);
cuda_result
::
memcpy_dtoh_async
(
dst_slice
,
src_ptr
as
u64
,
stream
.cu_stream
())
.map_err
(|
e
|
TransferError
::
ExecutionError
(
format!
(
"CUDA D2H memcpy failed: {}"
,
e
)))
?
;
Ok
(())
}
/// D2D Implementation
#[inline(always)]
unsafe
fn
cuda_memcpy_d2d
(
src_ptr
:
*
const
u8
,
dst_ptr
:
*
mut
u8
,
size
:
usize
,
stream
:
&
CudaStream
,
)
->
Result
<
(),
TransferError
>
{
debug_assert!
(
!
src_ptr
.is_null
(),
"Source device pointer is null"
);
debug_assert!
(
!
dst_ptr
.is_null
(),
"Destination device pointer is null"
);
debug_assert!
(
(
src_ptr
as
usize
+
size
<=
dst_ptr
as
usize
)
||
(
dst_ptr
as
usize
+
size
<=
src_ptr
as
usize
),
"Source and destination device memory regions must not overlap for D2D copy"
);
cuda_result
::
memcpy_dtod_async
(
dst_ptr
as
u64
,
src_ptr
as
u64
,
size
,
stream
.cu_stream
())
.map_err
(|
e
|
TransferError
::
ExecutionError
(
format!
(
"CUDA D2D memcpy failed: {}"
,
e
)))
?
;
Ok
(())
}
#[cfg(all(test,
feature
=
"testing-cuda"
))]
mod
tests
{
use
super
::
*
;
use
crate
::
block_manager
::
storage
::{
DeviceAllocator
,
PinnedAllocator
,
StorageAllocator
,
StorageMemset
,
};
#[test]
fn
test_memset_and_transfer
()
{
// Create allocators
let
device_allocator
=
DeviceAllocator
::
default
();
let
pinned_allocator
=
PinnedAllocator
::
default
();
let
ctx
=
device_allocator
.ctx
()
.clone
();
// Create CUDA stream
let
stream
=
ctx
.new_stream
()
.unwrap
();
// Allocate host and device memory
let
mut
host
=
pinned_allocator
.allocate
(
1024
)
.unwrap
();
let
mut
device
=
device_allocator
.allocate
(
1024
)
.unwrap
();
// Set a pattern in host memory
StorageMemset
::
memset
(
&
mut
host
,
42
,
0
,
1024
)
.unwrap
();
// Verify host memory was set correctly
unsafe
{
let
ptr
=
host
.as_ptr
();
let
slice
=
std
::
slice
::
from_raw_parts
(
ptr
,
1024
);
assert
!
(
slice
.iter
()
.all
(|
&
x
|
x
==
42
));
}
// Copy host to device
unsafe
{
cuda_memcpy_h2d
(
host
.as_ptr
(),
device
.as_mut_ptr
(),
1024
,
stream
.as_ref
())
.unwrap
();
}
// Synchronize to ensure H2D copy is complete
stream
.synchronize
()
.unwrap
();
// Clear host memory
StorageMemset
::
memset
(
&
mut
host
,
0
,
0
,
1024
)
.unwrap
();
// Verify host memory was cleared
unsafe
{
let
ptr
=
host
.as_ptr
();
let
slice
=
std
::
slice
::
from_raw_parts
(
ptr
,
1024
);
assert
!
(
slice
.iter
()
.all
(|
&
x
|
x
==
0
));
}
// Copy back from device to host
unsafe
{
cuda_memcpy_d2h
(
device
.as_ptr
(),
host
.as_mut_ptr
(),
1024
,
stream
.as_ref
())
.unwrap
();
}
// Synchronize to ensure D2H copy is complete before verifying
stream
.synchronize
()
.unwrap
();
// Verify the original pattern was restored
unsafe
{
let
ptr
=
host
.as_ptr
();
let
slice
=
std
::
slice
::
from_raw_parts
(
ptr
,
1024
);
assert
!
(
slice
.iter
()
.all
(|
&
x
|
x
==
42
));
}
}
}
lib/llm/src/block_manager/block/transfer/memcpy.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
*
;
/// Copy a block from a source to a destination using memcpy
pub
fn
copy_block
<
'a
,
Source
,
Destination
>
(
sources
:
&
'a
Source
,
destinations
:
&
'a
mut
Destination
,
)
->
Result
<
(),
TransferError
>
where
Source
:
ReadableBlock
,
Destination
:
WritableBlock
,
{
let
src_data
=
sources
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
destinations
.block_data_mut
(
private
::
PrivateToken
);
if
src_data
.is_fully_contiguous
()
&&
dst_data
.is_fully_contiguous
()
{
let
src_view
=
src_data
.block_view
()
?
;
let
mut
dst_view
=
dst_data
.block_view_mut
()
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
unsafe
{
memcpy
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
());
}
}
else
{
assert_eq!
(
src_data
.num_layers
(),
dst_data
.num_layers
());
copy_layers
(
0
..
src_data
.num_layers
(),
sources
,
destinations
)
?
;
}
Ok
(())
}
/// Copy a range of layers from a source to a destination using memcpy
pub
fn
copy_layers
<
'a
,
Source
,
Destination
>
(
layer_range
:
Range
<
usize
>
,
sources
:
&
'a
Source
,
destinations
:
&
'a
mut
Destination
,
)
->
Result
<
(),
TransferError
>
where
Source
:
ReadableBlock
,
// <Source as ReadableBlock>::StorageType: SystemAccessible + Local,
Destination
:
WritableBlock
,
// <Destination as WritableBlock>::StorageType: SystemAccessible + Local,
{
let
src_data
=
sources
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
destinations
.block_data_mut
(
private
::
PrivateToken
);
for
layer_idx
in
layer_range
{
let
src_view
=
src_data
.layer_view
(
layer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
unsafe
{
memcpy
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
());
}
}
Ok
(())
}
#[inline(always)]
unsafe
fn
memcpy
(
src_ptr
:
*
const
u8
,
dst_ptr
:
*
mut
u8
,
size
:
usize
)
{
debug_assert!
(
(
src_ptr
as
usize
+
size
<=
dst_ptr
as
usize
)
||
(
dst_ptr
as
usize
+
size
<=
src_ptr
as
usize
),
"Source and destination memory regions must not overlap for copy_nonoverlapping"
);
std
::
ptr
::
copy_nonoverlapping
(
src_ptr
,
dst_ptr
,
size
);
}
lib/llm/src/block_manager/block/transfer/nixl.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
*
;
use
anyhow
::
Result
;
use
nixl_sys
::{
MemoryRegion
,
NixlDescriptor
,
OptArgs
,
XferDescList
,
XferOp
};
use
std
::
ops
::
Range
;
/// Copy a block from a source to a destination using CUDA memcpy
pub
fn
write_block_to
<
'a
,
Source
,
Destination
>
(
src
:
&
'a
Source
,
dst
:
&
'a
mut
Destination
,
ctx
:
&
TransferContext
,
notify
:
Option
<
String
>
,
)
->
Result
<
()
>
where
Source
:
BlockDataProvider
,
Destination
:
BlockDataProviderMut
,
{
let
src_data
=
src
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
dst
.block_data_mut
(
private
::
PrivateToken
);
if
src_data
.is_fully_contiguous
()
&&
dst_data
.is_fully_contiguous
()
{
let
nixl_agent
=
ctx
.nixl_agent
()
.expect
(
"NIXL agent not found"
);
let
remote_worker_id
=
dst_data
.worker_id
.to_string
();
let
mut
src_dl
=
XferDescList
::
new
(
src_data
.storage_type
()
.nixl_mem_type
())
?
;
let
mut
dst_dl
=
XferDescList
::
new
(
dst_data
.storage_type
()
.nixl_mem_type
())
?
;
let
src_desc
=
src_data
.block_view
()
?
.as_nixl_descriptor
();
let
dst_desc
=
dst_data
.block_view_mut
()
?
.as_nixl_descriptor_mut
();
unsafe
{
src_dl
.add_desc
(
src_desc
.as_ptr
()
as
usize
,
src_desc
.size
(),
src_desc
.device_id
(),
)
?
;
dst_dl
.add_desc
(
dst_desc
.as_ptr
()
as
usize
,
dst_desc
.size
(),
dst_desc
.device_id
(),
)
?
;
}
let
xfer_req
=
nixl_agent
.create_xfer_req
(
XferOp
::
Write
,
&
src_dl
,
&
dst_dl
,
&
remote_worker_id
,
None
)
?
;
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
if
let
Some
(
notify
)
=
notify
{
xfer_args
.set_has_notification
(
true
)
?
;
xfer_args
.set_notification_message
(
notify
.as_bytes
())
?
;
}
let
mut
status
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
Some
(
&
xfer_args
))
?
;
tracing
::
span!
(
tracing
::
Level
::
DEBUG
,
"Waiting for transfer to complete"
)
.in_scope
(||
{
while
status
{
status
=
nixl_agent
.get_xfer_status
(
&
xfer_req
)
.unwrap
();
}
});
}
else
{
assert_eq!
(
src_data
.num_layers
(),
dst_data
.num_layers
());
write_layers_to
(
0
..
src_data
.num_layers
(),
src
,
dst
,
ctx
,
notify
)
?
;
}
Ok
(())
}
/// Copy a range of layers from a source to a destination using CUDA memcpy
pub
fn
write_layers_to
<
'a
,
Source
,
Destination
>
(
layer_range
:
Range
<
usize
>
,
src
:
&
'a
Source
,
dst
:
&
'a
mut
Destination
,
ctx
:
&
TransferContext
,
notify
:
Option
<
String
>
,
)
->
Result
<
()
>
where
Source
:
BlockDataProvider
,
Destination
:
BlockDataProviderMut
,
{
let
src_data
=
src
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
dst
.block_data_mut
(
private
::
PrivateToken
);
let
nixl_agent
=
ctx
.nixl_agent
()
.expect
(
"NIXL agent not found"
);
let
remote_worker_id
=
dst_data
.worker_id
.to_string
();
let
mut
src_dl
=
XferDescList
::
new
(
src_data
.storage_type
()
.nixl_mem_type
())
?
;
let
mut
dst_dl
=
XferDescList
::
new
(
dst_data
.storage_type
()
.nixl_mem_type
())
?
;
// #[cfg(debug_assertions)]
// {
// let expected_strategy = <<Source as BlockDataProvider>::StorageType as WriteToStrategy<
// Destination::StorageType,
// >>::write_to_strategy();
// assert_eq!(strategy, expected_strategy);
// }
for
layer_idx
in
layer_range
{
let
src_view
=
src_data
.layer_view
(
layer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
let
src_desc
=
src_view
.as_nixl_descriptor
();
let
dst_desc
=
dst_view
.as_nixl_descriptor_mut
();
unsafe
{
src_dl
.add_desc
(
src_desc
.as_ptr
()
as
usize
,
src_desc
.size
(),
src_desc
.device_id
(),
)
?
;
dst_dl
.add_desc
(
dst_desc
.as_ptr
()
as
usize
,
dst_desc
.size
(),
dst_desc
.device_id
(),
)
?
;
}
}
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
if
let
Some
(
notify
)
=
notify
{
xfer_args
.set_has_notification
(
true
)
?
;
xfer_args
.set_notification_message
(
notify
.as_bytes
())
?
;
}
let
xfer_req
=
nixl_agent
.create_xfer_req
(
XferOp
::
Write
,
&
src_dl
,
&
dst_dl
,
&
remote_worker_id
,
Some
(
&
xfer_args
),
)
?
;
let
mut
status
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
Some
(
&
xfer_args
))
?
;
tracing
::
span!
(
tracing
::
Level
::
DEBUG
,
"Waiting for transfer to complete"
)
.in_scope
(||
{
while
status
{
status
=
nixl_agent
.get_xfer_status
(
&
xfer_req
)
.unwrap
();
}
});
Ok
(())
}
lib/llm/src/block_manager/block/transfer/strategy.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! This module implements the `WriteToStrategy` and `ReadFromStrategy` traits
//! for the common storage types.
use
super
::
*
;
impl
WriteToStrategy
<
SystemStorage
>
for
SystemStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Memcpy
}
}
impl
WriteToStrategy
<
PinnedStorage
>
for
SystemStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Memcpy
}
}
impl
WriteToStrategy
<
DeviceStorage
>
for
SystemStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
CudaBlockingH2D
}
}
impl
WriteToStrategy
<
SystemStorage
>
for
PinnedStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Memcpy
}
}
impl
WriteToStrategy
<
PinnedStorage
>
for
PinnedStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Memcpy
}
}
impl
WriteToStrategy
<
DeviceStorage
>
for
PinnedStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
CudaAsyncH2D
}
}
impl
WriteToStrategy
<
SystemStorage
>
for
DeviceStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
CudaBlockingD2H
}
}
impl
WriteToStrategy
<
PinnedStorage
>
for
DeviceStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
CudaAsyncD2H
}
}
impl
WriteToStrategy
<
DeviceStorage
>
for
DeviceStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
CudaAsyncD2D
}
}
impl
<
S
:
Storage
+
Local
>
WriteToStrategy
<
NixlStorage
>
for
S
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
NixlWrite
}
}
impl
<
S
>
ReadFromStrategy
<
S
>
for
SystemStorage
where
S
:
WriteToStrategy
<
SystemStorage
>
+
Storage
+
Local
,
{
#[inline(always)]
fn
read_from_strategy
()
->
TransferStrategy
{
S
::
write_to_strategy
()
}
}
impl
<
S
>
ReadFromStrategy
<
S
>
for
PinnedStorage
where
S
:
WriteToStrategy
<
PinnedStorage
>
+
Storage
+
Local
,
{
#[inline(always)]
fn
read_from_strategy
()
->
TransferStrategy
{
S
::
write_to_strategy
()
}
}
impl
<
S
>
ReadFromStrategy
<
S
>
for
DeviceStorage
where
S
:
WriteToStrategy
<
DeviceStorage
>
+
Storage
+
Local
,
{
#[inline(always)]
fn
read_from_strategy
()
->
TransferStrategy
{
S
::
write_to_strategy
()
}
}
impl
<
S
:
Storage
+
Local
>
ReadFromStrategy
<
NixlStorage
>
for
S
{
#[inline(always)]
fn
read_from_strategy
()
->
TransferStrategy
{
TransferStrategy
::
NixlRead
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
write_to_strategy
()
{
// System to ...
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
SystemStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
PinnedStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
DeviceStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaBlockingH2D
);
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
);
// Pinned to ...
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
SystemStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
PinnedStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
DeviceStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaAsyncH2D
);
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
);
// Device to ...
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
SystemStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaBlockingD2H
);
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
PinnedStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaAsyncD2H
);
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
DeviceStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaAsyncD2D
);
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
);
// Nixl to ... should fail to compile
// assert_eq!(
// <NixlStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
}
#[test]
fn
read_from_strategy
()
{
// System to ...
assert_eq!
(
<
SystemStorage
as
ReadFromStrategy
<
SystemStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
SystemStorage
as
ReadFromStrategy
<
PinnedStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
SystemStorage
as
ReadFromStrategy
<
DeviceStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
CudaBlockingD2H
);
assert_eq!
(
<
SystemStorage
as
ReadFromStrategy
<
NixlStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
NixlRead
);
// Pinned to ...
assert_eq!
(
<
PinnedStorage
as
ReadFromStrategy
<
SystemStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
PinnedStorage
as
ReadFromStrategy
<
PinnedStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
PinnedStorage
as
ReadFromStrategy
<
DeviceStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
CudaAsyncD2H
);
assert_eq!
(
<
PinnedStorage
as
ReadFromStrategy
<
NixlStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
NixlRead
);
// Device to ...
assert_eq!
(
<
DeviceStorage
as
ReadFromStrategy
<
SystemStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
CudaBlockingH2D
);
assert_eq!
(
<
DeviceStorage
as
ReadFromStrategy
<
PinnedStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
CudaAsyncH2D
);
assert_eq!
(
<
DeviceStorage
as
ReadFromStrategy
<
DeviceStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
CudaAsyncD2D
);
assert_eq!
(
<
DeviceStorage
as
ReadFromStrategy
<
NixlStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
NixlRead
);
// Nixl to ... should fail to compile
// assert_eq!(
// <NixlStorage as ReadFromStrategy<SystemStorage>>::read_from_strategy(),
// TransferStrategy::Invalid
// );
//
// assert_eq!(
// <NixlStorage as ReadFromStrategy<PinnedStorage>>::read_from_strategy(),
// TransferStrategy::Invalid
// );
//
// assert_eq!(
// <NixlStorage as ReadFromStrategy<DeviceStorage>>::read_from_strategy(),
// TransferStrategy::Invalid
// );
//
// assert_eq!(
// <NixlStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
// TransferStrategy::Invalid
// );
}
}
lib/llm/src/block_manager/block/view.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Block storage management.
//!
//! This module provides the implementation for managing collections of blocks
//! and their storage. It handles the relationship between storage, layout,
//! and individual blocks.
use
super
::{
BlockData
,
BlockError
,
Storage
};
pub
trait
Kind
:
std
::
marker
::
Sized
+
std
::
fmt
::
Debug
+
Clone
+
Copy
+
Send
+
Sync
{}
#[derive(Debug,
Clone,
Copy)]
pub
struct
BlockKind
;
impl
Kind
for
BlockKind
{}
#[derive(Debug,
Clone,
Copy)]
pub
struct
LayerKind
;
impl
Kind
for
LayerKind
{}
pub
type
BlockView
<
'a
,
S
>
=
MemoryView
<
'a
,
S
,
BlockKind
>
;
pub
type
BlockViewMut
<
'a
,
S
>
=
MemoryViewMut
<
'a
,
S
,
BlockKind
>
;
pub
type
LayerView
<
'a
,
S
>
=
MemoryView
<
'a
,
S
,
LayerKind
>
;
pub
type
LayerViewMut
<
'a
,
S
>
=
MemoryViewMut
<
'a
,
S
,
LayerKind
>
;
/// Storage view that provides safe access to a region of storage
#[derive(Debug)]
pub
struct
MemoryView
<
'a
,
S
:
Storage
,
K
:
Kind
>
{
_
block_data
:
&
'a
BlockData
<
S
>
,
addr
:
usize
,
size
:
usize
,
kind
:
std
::
marker
::
PhantomData
<
K
>
,
}
impl
<
'a
,
S
,
K
>
MemoryView
<
'a
,
S
,
K
>
where
S
:
Storage
,
K
:
Kind
,
{
/// Create a new storage view
///
/// # Safety
/// The caller must ensure:
/// - addr + size <= storage.size()
/// - The view does not outlive the storage
pub
(
crate
)
unsafe
fn
new
(
_
block_data
:
&
'a
BlockData
<
S
>
,
addr
:
usize
,
size
:
usize
,
)
->
Result
<
Self
,
BlockError
>
{
Ok
(
Self
{
_
block_data
,
addr
,
size
,
kind
:
std
::
marker
::
PhantomData
,
})
}
/// Get a raw pointer to the view's data
///
/// # Safety
/// The caller must ensure:
/// - The pointer is not used after the view is dropped
/// - Access patterns respect the storage's thread safety model
pub
unsafe
fn
as_ptr
(
&
self
)
->
*
const
u8
{
self
.addr
as
*
const
u8
}
/// Size of the view in bytes
pub
fn
size
(
&
self
)
->
usize
{
self
.size
}
}
/// Mutable storage view that provides exclusive access to a region of storage
#[derive(Debug)]
pub
struct
MemoryViewMut
<
'a
,
S
:
Storage
,
K
:
Kind
>
{
_
block_data
:
&
'a
mut
BlockData
<
S
>
,
addr
:
usize
,
size
:
usize
,
kind
:
std
::
marker
::
PhantomData
<
K
>
,
}
impl
<
'a
,
S
:
Storage
,
K
:
Kind
>
MemoryViewMut
<
'a
,
S
,
K
>
{
/// Create a new mutable storage view
///
/// # Safety
/// The caller must ensure:
/// - addr + size <= storage.size()
/// - The view does not outlive the storage
/// - No other views exist for this region
pub
(
crate
)
unsafe
fn
new
(
_
block_data
:
&
'a
mut
BlockData
<
S
>
,
addr
:
usize
,
size
:
usize
,
)
->
Result
<
Self
,
BlockError
>
{
Ok
(
Self
{
_
block_data
,
addr
,
size
,
kind
:
std
::
marker
::
PhantomData
,
})
}
/// Get a raw mutable pointer to the view's data
///
/// # Safety
/// The caller must ensure:
/// - The pointer is not used after the view is dropped
/// - No other references exist while the pointer is in use
/// - Access patterns respect the storage's thread safety model
pub
unsafe
fn
as_mut_ptr
(
&
mut
self
)
->
*
mut
u8
{
self
.addr
as
*
mut
u8
}
/// Size of the view in bytes
pub
fn
size
(
&
self
)
->
usize
{
self
.size
}
}
mod
nixl
{
use
super
::
*
;
use
super
::
super
::
nixl
::
*
;
pub
use
nixl_sys
::{
MemType
,
MemoryRegion
,
NixlDescriptor
};
impl
<
S
:
Storage
,
K
:
Kind
>
MemoryRegion
for
MemoryView
<
'_
,
S
,
K
>
{
unsafe
fn
as_ptr
(
&
self
)
->
*
const
u8
{
self
.addr
as
*
const
u8
}
fn
size
(
&
self
)
->
usize
{
self
.size
()
}
}
impl
<
S
,
K
>
NixlDescriptor
for
MemoryView
<
'_
,
S
,
K
>
where
S
:
Storage
+
NixlDescriptor
,
K
:
Kind
,
{
fn
mem_type
(
&
self
)
->
MemType
{
self
._block_data.layout
.storage_type
()
.nixl_mem_type
()
}
fn
device_id
(
&
self
)
->
u64
{
self
._block_data.layout
.storage_type
()
.nixl_device_id
()
}
}
impl
<
S
:
Storage
,
K
:
Kind
>
MemoryRegion
for
MemoryViewMut
<
'_
,
S
,
K
>
{
unsafe
fn
as_ptr
(
&
self
)
->
*
const
u8
{
self
.addr
as
*
const
u8
}
fn
size
(
&
self
)
->
usize
{
self
.size
()
}
}
impl
<
S
:
Storage
,
K
:
Kind
>
NixlDescriptor
for
MemoryViewMut
<
'_
,
S
,
K
>
where
S
:
Storage
+
NixlDescriptor
,
K
:
Kind
,
{
fn
mem_type
(
&
self
)
->
MemType
{
self
._block_data.layout
.storage_type
()
.nixl_mem_type
()
}
fn
device_id
(
&
self
)
->
u64
{
self
._block_data.layout
.storage_type
()
.nixl_device_id
()
}
}
impl
<
'a
,
S
,
K
>
MemoryView
<
'a
,
S
,
K
>
where
S
:
Storage
+
NixlDescriptor
,
// Ensure the underlying storage is a NixlDescriptor
K
:
Kind
,
{
/// Creates an immutable NIXL memory descriptor from this view.
pub
fn
as_nixl_descriptor
(
&
self
)
->
NixlMemoryDescriptor
<
'a
,
K
,
IsImmutable
>
{
NixlMemoryDescriptor
::
new
(
self
.addr
as
u64
,
// Address from the view
self
.size
(),
// Size from the view
NixlDescriptor
::
mem_type
(
self
),
// Delegate to self's NixlDescriptor impl
NixlDescriptor
::
device_id
(
self
),
// Delegate to self's NixlDescriptor impl
)
}
}
impl
<
'a
,
S
,
K
>
MemoryViewMut
<
'a
,
S
,
K
>
where
S
:
Storage
+
NixlDescriptor
,
K
:
Kind
,
{
/// Creates a mutable NIXL memory descriptor from this view.
// Note: We return a mutable descriptor even from an immutable borrow (&self)
// because the underlying memory region *can* be mutated.
pub
fn
as_nixl_descriptor_mut
(
&
mut
self
)
->
NixlMemoryDescriptor
<
'a
,
K
,
IsMutable
>
{
NixlMemoryDescriptor
::
new
(
self
.addr
as
u64
,
self
.size
(),
NixlDescriptor
::
mem_type
(
self
),
// Delegate to self's NixlDescriptor impl
NixlDescriptor
::
device_id
(
self
),
// Delegate to self's NixlDescriptor impl
)
}
}
}
lib/llm/src/block_manager/config.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
*
;
#[derive(Debug,
Clone)]
pub
enum
NixlOptions
{
/// Enable NIXL and create a new NIXL agent
Enabled
,
/// Enable NIXL and use the provided NIXL agent
EnabledWithAgent
(
NixlAgent
),
/// Disable NIXL
Disabled
,
}
#[derive(Debug,
Clone,
Builder,
Validate)]
#[builder(pattern
=
"owned"
)]
pub
struct
KvManagerRuntimeConfig
{
pub
worker_id
:
u64
,
#[builder(default)]
pub
cancellation_token
:
CancellationToken
,
#[builder(default
=
"NixlOptions::Enabled"
)]
pub
nixl
:
NixlOptions
,
}
impl
KvManagerRuntimeConfig
{
pub
fn
builder
()
->
KvManagerRuntimeConfigBuilder
{
KvManagerRuntimeConfigBuilder
::
default
()
}
}
impl
KvManagerRuntimeConfigBuilder
{
pub
fn
enable_nixl
(
mut
self
)
->
Self
{
self
.nixl
=
Some
(
NixlOptions
::
Enabled
);
self
}
pub
fn
use_nixl_agent
(
mut
self
,
agent
:
NixlAgent
)
->
Self
{
self
.nixl
=
Some
(
NixlOptions
::
EnabledWithAgent
(
agent
));
self
}
pub
fn
disable_nixl
(
mut
self
)
->
Self
{
self
.nixl
=
Some
(
NixlOptions
::
Disabled
);
self
}
}
#[derive(Debug,
Clone,
Builder,
Validate)]
#[builder(pattern
=
"owned"
)]
pub
struct
KvManagerModelConfig
{
#[validate(range(min
=
1
))]
pub
num_layers
:
usize
,
#[validate(range(min
=
1
))]
pub
page_size
:
usize
,
#[validate(range(min
=
1
))]
pub
inner_dim
:
usize
,
#[builder(default
=
"DType::FP16"
)]
pub
dtype
:
DType
,
}
impl
KvManagerModelConfig
{
pub
fn
builder
()
->
KvManagerModelConfigBuilder
{
KvManagerModelConfigBuilder
::
default
()
}
}
#[derive(Builder,
Validate)]
#[builder(pattern
=
"owned"
,
build_fn(validate
=
"Self::validate"
))]
pub
struct
KvManagerLayoutConfig
<
S
:
Storage
+
NixlRegisterableStorage
>
{
/// The number of blocks to allocate
#[validate(range(min
=
1
))]
pub
num_blocks
:
usize
,
/// The type of layout to use
#[builder(default
=
"LayoutType::FullyContiguous"
)]
pub
layout_type
:
LayoutType
,
/// Storage for the blocks
/// If provided, the blocks will be allocated from the provided storage
/// Otherwise, the blocks will be allocated from
#[builder(default)]
pub
storage
:
Option
<
Vec
<
S
>>
,
/// If provided, the blocks will be allocated from the provided allocator
/// This option is mutually exclusive with the `storage` option
#[builder(default,
setter(custom))]
pub
allocator
:
Option
<
Arc
<
dyn
StorageAllocator
<
S
>>>
,
}
impl
<
S
:
Storage
+
NixlRegisterableStorage
>
KvManagerLayoutConfig
<
S
>
{
/// Create a new builder for the KvManagerLayoutConfig
pub
fn
builder
()
->
KvManagerLayoutConfigBuilder
<
S
>
{
KvManagerLayoutConfigBuilder
::
default
()
}
}
// Implement the validation and build functions on the generated builder type
// Note: derive_builder generates KvManagerBlockConfigBuilder<S>
impl
<
S
:
Storage
+
NixlRegisterableStorage
>
KvManagerLayoutConfigBuilder
<
S
>
{
/// Custom setter for the `allocator` field
pub
fn
allocator
(
mut
self
,
allocator
:
impl
StorageAllocator
<
S
>
+
'static
)
->
Self
{
self
.allocator
=
Some
(
Some
(
Arc
::
new
(
allocator
)));
self
}
// Validation function
fn
validate
(
&
self
)
->
Result
<
(),
String
>
{
match
(
self
.storage
.is_some
(),
self
.allocator
.is_some
())
{
(
true
,
false
)
|
(
false
,
true
)
=>
Ok
(()),
// XOR condition met
(
true
,
true
)
=>
Err
(
"Cannot provide both `storage` and `allocator`."
.to_string
()),
(
false
,
false
)
=>
Err
(
"Must provide either `storage` or `allocator`."
.to_string
()),
}
}
}
/// Configuration for the KvBlockManager
#[derive(Builder,
Validate)]
#[builder(pattern
=
"owned"
)]
pub
struct
KvBlockManagerConfig
{
/// Runtime configuration
///
/// This provides core runtime configuration for the KvBlockManager.
pub
runtime
:
KvManagerRuntimeConfig
,
/// Model configuration
///
/// This provides model-specific configuration for the KvBlockManager, specifically,
/// the number of layers and the size of the inner dimension which is directly related
/// to the type of attention used by the model.
///
/// Included in this configuration is also the page_size, i.e. the number of tokens that will
/// be represented in each "paged" KV block.
pub
model
:
KvManagerModelConfig
,
/// Specific configuration for the device layout
///
/// This includes the number of blocks and the layout of the data into the device memory/storage.
#[builder(default,
setter(strip_option))]
pub
device_layout
:
Option
<
KvManagerLayoutConfig
<
DeviceStorage
>>
,
/// Specific configuration for the host layout
///
/// This includes the number of blocks and the layout of the data into the host memory/storage.
#[builder(default,
setter(strip_option))]
pub
host_layout
:
Option
<
KvManagerLayoutConfig
<
PinnedStorage
>>
,
}
impl
KvBlockManagerConfig
{
/// Create a new builder for the KvBlockManagerConfig
pub
fn
builder
()
->
KvBlockManagerConfigBuilder
{
KvBlockManagerConfigBuilder
::
default
()
}
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment