Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
4564a387
Unverified
Commit
4564a387
authored
May 09, 2025
by
Ryan Olson
Committed by
GitHub
May 09, 2025
Browse files
feat: kv block manager (#965)
parent
24e2cbf5
Changes
38
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4612 additions
and
127 deletions
+4612
-127
.github/workflows/pre-merge-python.yml
.github/workflows/pre-merge-python.yml
+5
-2
.github/workflows/pre-merge-rust.yml
.github/workflows/pre-merge-rust.yml
+0
-2
Cargo.lock
Cargo.lock
+45
-2
Cargo.toml
Cargo.toml
+2
-0
container/Dockerfile.vllm
container/Dockerfile.vllm
+12
-2
lib/bindings/python/Cargo.lock
lib/bindings/python/Cargo.lock
+7
-0
lib/bindings/python/pyproject.toml
lib/bindings/python/pyproject.toml
+2
-1
lib/llm/Cargo.toml
lib/llm/Cargo.toml
+8
-3
lib/llm/build.rs
lib/llm/build.rs
+122
-115
lib/llm/src/block_manager.rs
lib/llm/src/block_manager.rs
+297
-0
lib/llm/src/block_manager/block.rs
lib/llm/src/block_manager/block.rs
+1683
-0
lib/llm/src/block_manager/block/registry.rs
lib/llm/src/block_manager/block/registry.rs
+390
-0
lib/llm/src/block_manager/block/state.rs
lib/llm/src/block_manager/block/state.rs
+179
-0
lib/llm/src/block_manager/block/transfer.rs
lib/llm/src/block_manager/block/transfer.rs
+634
-0
lib/llm/src/block_manager/block/transfer/cuda.rs
lib/llm/src/block_manager/block/transfer/cuda.rs
+292
-0
lib/llm/src/block_manager/block/transfer/memcpy.rs
lib/llm/src/block_manager/block/transfer/memcpy.rs
+80
-0
lib/llm/src/block_manager/block/transfer/nixl.rs
lib/llm/src/block_manager/block/transfer/nixl.rs
+161
-0
lib/llm/src/block_manager/block/transfer/strategy.rs
lib/llm/src/block_manager/block/transfer/strategy.rs
+296
-0
lib/llm/src/block_manager/block/view.rs
lib/llm/src/block_manager/block/view.rs
+224
-0
lib/llm/src/block_manager/config.rs
lib/llm/src/block_manager/config.rs
+173
-0
No files found.
.github/workflows/pre-merge-python.yml
View file @
4564a387
...
@@ -47,15 +47,18 @@ jobs:
...
@@ -47,15 +47,18 @@ jobs:
GITHUB_TOKEN
:
${{ secrets.CI_TOKEN }}
GITHUB_TOKEN
:
${{ secrets.CI_TOKEN }}
run
:
|
run
:
|
./container/build.sh --tag ${{ steps.define_image_tag.outputs.image_tag }} --target ci_minimum --framework ${{ matrix.framework }}
./container/build.sh --tag ${{ steps.define_image_tag.outputs.image_tag }} --target ci_minimum --framework ${{ matrix.framework }}
-
name
:
Run Rust checks (llm/block-manager)
run
:
|
docker run -w /workspace/lib/llm --name ${{ env.CONTAINER_ID }}_rust_checks ${{ steps.define_image_tag.outputs.image_tag }} bash -ec 'rustup component add rustfmt clippy && cargo fmt -- --check && cargo clippy --features block-manager --no-deps --all-targets -- -D warnings && cargo test --locked --all-targets --features=block-manager'
-
name
:
Run pytest
-
name
:
Run pytest
env
:
env
:
PYTEST_MARKS
:
"
pre_merge
or
mypy"
PYTEST_MARKS
:
"
pre_merge
or
mypy"
run
:
|
run
:
|
docker run -w /workspace --name ${{ env.CONTAINER_ID }} ${{ steps.define_image_tag.outputs.image_tag }} pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m "${{ env.PYTEST_MARKS }}"
docker run -w /workspace --name ${{ env.CONTAINER_ID }}
_pytest
${{ steps.define_image_tag.outputs.image_tag }} pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m "${{ env.PYTEST_MARKS }}"
-
name
:
Copy test report from test Container
-
name
:
Copy test report from test Container
if
:
always()
if
:
always()
run
:
|
run
:
|
docker cp ${{ env.CONTAINER_ID }}:/workspace/${{ env.PYTEST_XML_FILE }} .
docker cp ${{ env.CONTAINER_ID }}
_pytest
:/workspace/${{ env.PYTEST_XML_FILE }} .
-
name
:
Archive test report
-
name
:
Archive test report
uses
:
actions/upload-artifact@v4
uses
:
actions/upload-artifact@v4
if
:
always()
if
:
always()
...
...
.github/workflows/pre-merge-rust.yml
View file @
4564a387
...
@@ -23,8 +23,6 @@ on:
...
@@ -23,8 +23,6 @@ on:
# Run this workflow on pull requests targeting main but only if files in runtime/rust change.
# Run this workflow on pull requests targeting main but only if files in runtime/rust change.
pull_request
:
pull_request
:
branches
:
-
main
paths
:
paths
:
-
.github/workflows/pre-merge-rust.yml
-
.github/workflows/pre-merge-rust.yml
-
'
lib/runtime/**'
-
'
lib/runtime/**'
...
...
Cargo.lock
View file @
4564a387
...
@@ -513,6 +513,26 @@ dependencies = [
...
@@ -513,6 +513,26 @@ dependencies = [
"which",
"which",
]
]
[[package]]
name = "bindgen"
version = "0.71.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
dependencies = [
"bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 2.1.1",
"shlex",
"syn 2.0.100",
]
[[package]]
[[package]]
name = "bindgen_cuda"
name = "bindgen_cuda"
version = "0.1.5"
version = "0.1.5"
...
@@ -1606,6 +1626,8 @@ dependencies = [
...
@@ -1606,6 +1626,8 @@ dependencies = [
"minijinja",
"minijinja",
"minijinja-contrib",
"minijinja-contrib",
"ndarray",
"ndarray",
"nixl-sys",
"oneshot",
"prometheus",
"prometheus",
"proptest",
"proptest",
"rand 0.9.1",
"rand 0.9.1",
...
@@ -3379,7 +3401,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
...
@@ -3379,7 +3401,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
dependencies = [
"cfg-if 1.0.0",
"cfg-if 1.0.0",
"windows-targets 0.
52.6
",
"windows-targets 0.
48.5
",
]
]
[[package]]
[[package]]
...
@@ -3441,7 +3463,7 @@ version = "0.1.103"
...
@@ -3441,7 +3463,7 @@ version = "0.1.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b4ae3037b7d9b9fab9fd7905aeb04e214acb300599fa1ee698d6f759ee530f9"
checksum = "8b4ae3037b7d9b9fab9fd7905aeb04e214acb300599fa1ee698d6f759ee530f9"
dependencies = [
dependencies = [
"bindgen",
"bindgen
0.69.5
",
"cc",
"cc",
"cmake",
"cmake",
"find_cuda_helper",
"find_cuda_helper",
...
@@ -4057,6 +4079,21 @@ dependencies = [
...
@@ -4057,6 +4079,21 @@ dependencies = [
"libc",
"libc",
]
]
[[package]]
name = "nixl-sys"
version = "0.2.1-rc.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4698155ec791ae888482b0c1c5d19792a1b457167fa8ec0ffa487ffef8c8e5a"
dependencies = [
"bindgen 0.71.1",
"cc",
"libc",
"pkg-config",
"serde",
"thiserror 2.0.12",
"tracing",
]
[[package]]
[[package]]
name = "nkeys"
name = "nkeys"
version = "0.4.4"
version = "0.4.4"
...
@@ -4282,6 +4319,12 @@ version = "1.21.3"
...
@@ -4282,6 +4319,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "oneshot"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea"
[[package]]
[[package]]
name = "onig"
name = "onig"
version = "6.4.0"
version = "6.4.0"
...
...
Cargo.toml
View file @
4564a387
...
@@ -60,6 +60,7 @@ futures = { version = "0.3" }
...
@@ -60,6 +60,7 @@ futures = { version = "0.3" }
hf-hub
=
{
version
=
"0.4.2"
,
default-features
=
false
,
features
=
[
"tokio"
,
"rustls-tls"
]
}
hf-hub
=
{
version
=
"0.4.2"
,
default-features
=
false
,
features
=
[
"tokio"
,
"rustls-tls"
]
}
humantime
=
{
version
=
"2.2.0"
}
humantime
=
{
version
=
"2.2.0"
}
libc
=
{
version
=
"0.2"
}
libc
=
{
version
=
"0.2"
}
oneshot
=
{
version
=
"0.1.11"
,
features
=
[
"std"
,
"async"
]
}
prometheus
=
{
version
=
"0.14"
}
prometheus
=
{
version
=
"0.14"
}
rand
=
{
version
=
"0.9.0"
}
rand
=
{
version
=
"0.9.0"
}
serde
=
{
version
=
"1"
,
features
=
["derive"]
}
serde
=
{
version
=
"1"
,
features
=
["derive"]
}
...
@@ -77,6 +78,7 @@ uuid = { version = "1", features = ["v4", "serde"] }
...
@@ -77,6 +78,7 @@ uuid = { version = "1", features = ["v4", "serde"] }
url
=
{
version
=
"2.5"
,
features
=
["serde"]
}
url
=
{
version
=
"2.5"
,
features
=
["serde"]
}
xxhash-rust
=
{
version
=
"0.8"
,
features
=
[
"xxh3"
,
"const_xxh3"
]
}
xxhash-rust
=
{
version
=
"0.8"
,
features
=
[
"xxh3"
,
"const_xxh3"
]
}
[profile.dev.package]
[profile.dev.package]
insta.opt-level
=
3
insta.opt-level
=
3
...
...
container/Dockerfile.vllm
View file @
4564a387
...
@@ -108,6 +108,12 @@ WORKDIR /workspace
...
@@ -108,6 +108,12 @@ WORKDIR /workspace
# Copy nixl source, and use commit hash as cache hint
# Copy nixl source, and use commit hash as cache hint
COPY --from=nixl_base /opt/nixl /opt/nixl
COPY --from=nixl_base /opt/nixl /opt/nixl
COPY --from=nixl_base /opt/nixl/commit.txt /opt/nixl/commit.txt
COPY --from=nixl_base /opt/nixl/commit.txt /opt/nixl/commit.txt
RUN cd /opt/nixl && \
mkdir build && \
meson setup build/ --prefix=/usr/local/nixl && \
cd build/ && \
ninja && \
ninja install
### NATS & ETCD SETUP ###
### NATS & ETCD SETUP ###
# nats
# nats
...
@@ -310,6 +316,7 @@ ARG RELEASE_BUILD
...
@@ -310,6 +316,7 @@ ARG RELEASE_BUILD
WORKDIR /workspace
WORKDIR /workspace
RUN yum update -y \
RUN yum update -y \
&& yum install -y llvm-toolset \
&& yum install -y python3.12-devel \
&& yum install -y python3.12-devel \
&& yum install -y protobuf-compiler \
&& yum install -y protobuf-compiler \
&& yum clean all \
&& yum clean all \
...
@@ -322,6 +329,7 @@ ENV RUSTUP_HOME=/usr/local/rustup \
...
@@ -322,6 +329,7 @@ ENV RUSTUP_HOME=/usr/local/rustup \
COPY --from=base $RUSTUP_HOME $RUSTUP_HOME
COPY --from=base $RUSTUP_HOME $RUSTUP_HOME
COPY --from=base $CARGO_HOME $CARGO_HOME
COPY --from=base $CARGO_HOME $CARGO_HOME
COPY --from=base /usr/local/nixl /opt/nvidia/nvda_nixl
COPY --from=base /workspace /workspace
COPY --from=base /workspace /workspace
COPY --from=base $VIRTUAL_ENV $VIRTUAL_ENV
COPY --from=base $VIRTUAL_ENV $VIRTUAL_ENV
ENV PATH=$CARGO_HOME/bin:$VIRTUAL_ENV/bin:$PATH
ENV PATH=$CARGO_HOME/bin:$VIRTUAL_ENV/bin:$PATH
...
@@ -342,7 +350,7 @@ COPY launch /workspace/launch
...
@@ -342,7 +350,7 @@ COPY launch /workspace/launch
COPY deploy/sdk /workspace/deploy/sdk
COPY deploy/sdk /workspace/deploy/sdk
# Build Rust crate binaries packaged with the wheel
# Build Rust crate binaries packaged with the wheel
RUN cargo build --release --locked --features mistralrs,python \
RUN cargo build --release --locked --features mistralrs,python
,dynamo-llm/block-manager
\
-p dynamo-run \
-p dynamo-run \
-p llmctl \
-p llmctl \
# Multiple http named crates are present in dependencies, need to specify the path
# Multiple http named crates are present in dependencies, need to specify the path
...
@@ -370,6 +378,7 @@ WORKDIR /workspace
...
@@ -370,6 +378,7 @@ WORKDIR /workspace
COPY --from=wheel_builder /workspace/dist/ /workspace/dist/
COPY --from=wheel_builder /workspace/dist/ /workspace/dist/
COPY --from=wheel_builder /workspace/target/ /workspace/target/
COPY --from=wheel_builder /workspace/target/ /workspace/target/
COPY --from=wheel_builder /opt/nvidia/nvda_nixl /opt/nvidia/nvda_nixl
# Copy Cargo cache to avoid re-downloading dependencies
# Copy Cargo cache to avoid re-downloading dependencies
COPY --from=wheel_builder $CARGO_HOME $CARGO_HOME
COPY --from=wheel_builder $CARGO_HOME $CARGO_HOME
...
@@ -377,7 +386,7 @@ COPY . /workspace
...
@@ -377,7 +386,7 @@ COPY . /workspace
# Build rest of the crates
# Build rest of the crates
# Need to figure out rust caching to avoid rebuilding and remove exclude flags
# Need to figure out rust caching to avoid rebuilding and remove exclude flags
RUN cargo build --release --locked --workspace \
RUN cargo build --release --locked
--features block-manager
--workspace \
--exclude dynamo-run \
--exclude dynamo-run \
--exclude llmctl \
--exclude llmctl \
--exclude file://$PWD/components/http \
--exclude file://$PWD/components/http \
...
@@ -405,6 +414,7 @@ RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/la
...
@@ -405,6 +414,7 @@ RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/la
# Tell vllm to use the Dynamo LLM C API for KV Cache Routing
# Tell vllm to use the Dynamo LLM C API for KV Cache Routing
ENV VLLM_KV_CAPI_PATH=/opt/dynamo/bindings/lib/libdynamo_llm_capi.so
ENV VLLM_KV_CAPI_PATH=/opt/dynamo/bindings/lib/libdynamo_llm_capi.so
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/nvidia/nvda_nixl/lib/x86_64-linux-gnu/
##########################################
##########################################
########## Perf Analyzer Image ###########
########## Perf Analyzer Image ###########
...
...
lib/bindings/python/Cargo.lock
View file @
4564a387
...
@@ -1058,6 +1058,7 @@ dependencies = [
...
@@ -1058,6 +1058,7 @@ dependencies = [
"memmap2",
"memmap2",
"minijinja",
"minijinja",
"minijinja-contrib",
"minijinja-contrib",
"oneshot",
"prometheus",
"prometheus",
"rand 0.9.1",
"rand 0.9.1",
"rayon",
"rayon",
...
@@ -2859,6 +2860,12 @@ version = "1.21.3"
...
@@ -2859,6 +2860,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "oneshot"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea"
[[package]]
[[package]]
name = "onig"
name = "onig"
version = "6.4.0"
version = "6.4.0"
...
...
lib/bindings/python/pyproject.toml
View file @
4564a387
...
@@ -51,6 +51,7 @@ module-name = "dynamo._core"
...
@@ -51,6 +51,7 @@ module-name = "dynamo._core"
manifest-path
=
"Cargo.toml"
manifest-path
=
"Cargo.toml"
python-packages
=
["dynamo"]
python-packages
=
["dynamo"]
python-source
=
"src"
python-source
=
"src"
features
=
["dynamo-llm/block-manager"]
[build-system]
[build-system]
requires
=
[
"maturin>=1.0,<2.0"
,
"patchelf"
]
requires
=
[
"maturin>=1.0,<2.0"
,
"patchelf"
]
...
...
lib/llm/Cargo.toml
View file @
4564a387
...
@@ -27,7 +27,10 @@ description = "Dynamo LLM Library"
...
@@ -27,7 +27,10 @@ description = "Dynamo LLM Library"
[features]
[features]
default
=
[]
default
=
[]
cuda_kv
=
[
"dep:cudarc"
,
"dep:ndarray"
]
testing-full
=
[
"testing-cuda"
,
"testing-nixl"
]
testing-cuda
=
["dep:cudarc"]
testing-nixl
=
["dep:nixl-sys"]
block-manager
=
[
"dep:nixl-sys"
,
"dep:cudarc"
,
"dep:ndarray"
]
sentencepiece
=
["dep:sentencepiece"]
sentencepiece
=
["dep:sentencepiece"]
[dependencies]
[dependencies]
...
@@ -48,6 +51,7 @@ etcd-client = { workspace = true }
...
@@ -48,6 +51,7 @@ etcd-client = { workspace = true }
futures
=
{
workspace
=
true
}
futures
=
{
workspace
=
true
}
hf-hub
=
{
workspace
=
true
}
hf-hub
=
{
workspace
=
true
}
rand
=
{
workspace
=
true
}
rand
=
{
workspace
=
true
}
oneshot
=
{
workspace
=
true
}
prometheus
=
{
workspace
=
true
}
prometheus
=
{
workspace
=
true
}
serde
=
{
workspace
=
true
}
serde
=
{
workspace
=
true
}
serde_json
=
{
workspace
=
true
}
serde_json
=
{
workspace
=
true
}
...
@@ -72,8 +76,9 @@ derive-getters = "0.5"
...
@@ -72,8 +76,9 @@ derive-getters = "0.5"
regex
=
"1"
regex
=
"1"
rayon
=
"1"
rayon
=
"1"
# kv_cuda
# block_manager
cudarc
=
{
version
=
"0.16.2"
,
features
=
["cuda-12040"]
,
optional
=
true
}
nixl-sys
=
{
version
=
"0.2.1-rc.1"
,
optional
=
true
}
cudarc
=
{
version
=
"0.16.2"
,
features
=
["cuda-12020"]
,
optional
=
true
}
ndarray
=
{
version
=
"0.16"
,
optional
=
true
}
ndarray
=
{
version
=
"0.16"
,
optional
=
true
}
# protocols
# protocols
...
...
lib/llm/build.rs
View file @
4564a387
...
@@ -13,121 +13,128 @@
...
@@ -13,121 +13,128 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#[cfg(not(feature
=
"cuda_kv"
))]
fn
main
()
{}
#[cfg(feature
=
"cuda_kv"
)]
fn
main
()
{
fn
main
()
{
use
std
::{
path
::
PathBuf
,
process
::
Command
};
println!
(
"cargo:warning=Building with CUDA KV off"
);
println!
(
"cargo:rerun-if-changed=src/kernels/block_copy.cu"
);
// first do a which nvcc, if it is in the path
// if so, we don't need to set the cuda_lib
let
nvcc
=
Command
::
new
(
"which"
)
.arg
(
"nvcc"
)
.output
()
.unwrap
();
let
cuda_lib
=
if
nvcc
.status
.success
()
{
println!
(
"cargo:info=nvcc found in path"
);
// Extract the path from nvcc location by removing "bin/nvcc"
let
nvcc_path
=
String
::
from_utf8_lossy
(
&
nvcc
.stdout
)
.trim
()
.to_string
();
let
path
=
PathBuf
::
from
(
nvcc_path
);
if
let
Some
(
parent
)
=
path
.parent
()
{
// Remove "nvcc"
if
let
Some
(
cuda_root
)
=
parent
.parent
()
{
// Remove "bin"
cuda_root
.to_string_lossy
()
.to_string
()
}
else
{
// Fallback to CUDA_ROOT or default if path extraction fails
get_cuda_root_or_default
()
}
}
else
{
// Fallback to CUDA_ROOT or default if path extraction fails
get_cuda_root_or_default
()
}
}
else
{
println!
(
"cargo:warning=nvcc not found in path"
);
get_cuda_root_or_default
()
};
println!
(
"cargo:info=Using CUDA installation at: {}"
,
cuda_lib
);
let
cuda_lib_path
=
PathBuf
::
from
(
&
cuda_lib
)
.join
(
"lib64"
);
println!
(
"cargo:info=Using CUDA libs: {}"
,
cuda_lib_path
.display
());
println!
(
"cargo:rustc-link-search=native={}"
,
cuda_lib_path
.display
());
// Link against multiple CUDA libraries
println!
(
"cargo:rustc-link-lib=dylib=cudart"
);
println!
(
"cargo:rustc-link-lib=dylib=cuda"
);
println!
(
"cargo:rustc-link-lib=dylib=cudadevrt"
);
// Make sure the CUDA libraries are found before other system libraries
println!
(
"cargo:rustc-link-arg=-Wl,-rpath,{}"
,
cuda_lib_path
.display
()
);
// Create kernels directory for output if it doesn't exist
std
::
fs
::
create_dir_all
(
"src/kernels"
)
.unwrap_or_else
(|
_
|
{
println!
(
"Kernels directory already exists"
);
});
// Compile CUDA code
let
output
=
Command
::
new
(
"nvcc"
)
.arg
(
"src/kernels/block_copy.cu"
)
.arg
(
"-O3"
)
.arg
(
"--compiler-options"
)
.arg
(
"-fPIC"
)
.arg
(
"-o"
)
.arg
(
"src/kernels/libblock_copy.o"
)
.arg
(
"-c"
)
.output
()
.expect
(
"Failed to compile CUDA code"
);
if
!
output
.status
.success
()
{
panic!
(
"Failed to compile CUDA kernel: {}"
,
String
::
from_utf8_lossy
(
&
output
.stderr
)
);
}
// Create static library
#[cfg(target_os
=
"windows"
)]
{
Command
::
new
(
"lib"
)
.arg
(
"/OUT:src/kernels/block_copy.lib"
)
.arg
(
"src/kernels/libblock_copy.o"
)
.output
()
.expect
(
"Failed to create static library"
);
println!
(
"cargo:rustc-link-search=native=src/kernels"
);
println!
(
"cargo:rustc-link-lib=static=block_copy"
);
}
#[cfg(not(target_os
=
"windows"
))]
{
Command
::
new
(
"ar"
)
.arg
(
"rcs"
)
.arg
(
"src/kernels/libblock_copy.a"
)
.arg
(
"src/kernels/libblock_copy.o"
)
.output
()
.expect
(
"Failed to create static library"
);
println!
(
"cargo:rustc-link-search=native=src/kernels"
);
println!
(
"cargo:rustc-link-lib=static=block_copy"
);
println!
(
"cargo:rustc-link-lib=dylib=cudart"
);
println!
(
"cargo:rustc-link-lib=dylib=cuda"
);
println!
(
"cargo:rustc-link-lib=dylib=cudadevrt"
);
}
}
}
#[cfg(feature
=
"cuda_kv"
)]
// NOTE: Preserving this build.rs for reference. We may want to re-enable
fn
get_cuda_root_or_default
()
->
String
{
// custom kernel compilation in the future.
match
std
::
env
::
var
(
"CUDA_ROOT"
)
{
Ok
(
path
)
=>
path
,
// #[cfg(not(feature = "cuda_kv"))]
Err
(
_
)
=>
{
// fn main() {}
// Default locations based on OS
if
cfg!
(
target_os
=
"windows"
)
{
// #[cfg(feature = "cuda_kv")]
"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8"
.to_string
()
// fn main() {
}
else
{
// use std::{path::PathBuf, process::Command};
"/usr/local/cuda"
.to_string
()
}
// println!("cargo:rerun-if-changed=src/kernels/block_copy.cu");
}
}
// // first do a which nvcc, if it is in the path
}
// // if so, we don't need to set the cuda_lib
// let nvcc = Command::new("which").arg("nvcc").output().unwrap();
// let cuda_lib = if nvcc.status.success() {
// println!("cargo:info=nvcc found in path");
// // Extract the path from nvcc location by removing "bin/nvcc"
// let nvcc_path = String::from_utf8_lossy(&nvcc.stdout).trim().to_string();
// let path = PathBuf::from(nvcc_path);
// if let Some(parent) = path.parent() {
// // Remove "nvcc"
// if let Some(cuda_root) = parent.parent() {
// // Remove "bin"
// cuda_root.to_string_lossy().to_string()
// } else {
// // Fallback to CUDA_ROOT or default if path extraction fails
// get_cuda_root_or_default()
// }
// } else {
// // Fallback to CUDA_ROOT or default if path extraction fails
// get_cuda_root_or_default()
// }
// } else {
// println!("cargo:warning=nvcc not found in path");
// get_cuda_root_or_default()
// };
// println!("cargo:info=Using CUDA installation at: {}", cuda_lib);
// let cuda_lib_path = PathBuf::from(&cuda_lib).join("lib64");
// println!("cargo:info=Using CUDA libs: {}", cuda_lib_path.display());
// println!("cargo:rustc-link-search=native={}", cuda_lib_path.display());
// // Link against multiple CUDA libraries
// println!("cargo:rustc-link-lib=dylib=cudart");
// println!("cargo:rustc-link-lib=dylib=cuda");
// println!("cargo:rustc-link-lib=dylib=cudadevrt");
// // Make sure the CUDA libraries are found before other system libraries
// println!(
// "cargo:rustc-link-arg=-Wl,-rpath,{}",
// cuda_lib_path.display()
// );
// // Create kernels directory for output if it doesn't exist
// std::fs::create_dir_all("src/kernels").unwrap_or_else(|_| {
// println!("Kernels directory already exists");
// });
// // Compile CUDA code
// let output = Command::new("nvcc")
// .arg("src/kernels/block_copy.cu")
// .arg("-O3")
// .arg("--compiler-options")
// .arg("-fPIC")
// .arg("-o")
// .arg("src/kernels/libblock_copy.o")
// .arg("-c")
// .output()
// .expect("Failed to compile CUDA code");
// if !output.status.success() {
// panic!(
// "Failed to compile CUDA kernel: {}",
// String::from_utf8_lossy(&output.stderr)
// );
// }
// // Create static library
// #[cfg(target_os = "windows")]
// {
// Command::new("lib")
// .arg("/OUT:src/kernels/block_copy.lib")
// .arg("src/kernels/libblock_copy.o")
// .output()
// .expect("Failed to create static library");
// println!("cargo:rustc-link-search=native=src/kernels");
// println!("cargo:rustc-link-lib=static=block_copy");
// }
// #[cfg(not(target_os = "windows"))]
// {
// Command::new("ar")
// .arg("rcs")
// .arg("src/kernels/libblock_copy.a")
// .arg("src/kernels/libblock_copy.o")
// .output()
// .expect("Failed to create static library");
// println!("cargo:rustc-link-search=native=src/kernels");
// println!("cargo:rustc-link-lib=static=block_copy");
// println!("cargo:rustc-link-lib=dylib=cudart");
// println!("cargo:rustc-link-lib=dylib=cuda");
// println!("cargo:rustc-link-lib=dylib=cudadevrt");
// }
// }
// #[cfg(feature = "cuda_kv")]
// fn get_cuda_root_or_default() -> String {
// match std::env::var("CUDA_ROOT") {
// Ok(path) => path,
// Err(_) => {
// // Default locations based on OS
// if cfg!(target_os = "windows") {
// "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8".to_string()
// } else {
// "/usr/local/cuda".to_string()
// }
// }
// }
// }
lib/llm/src/block_manager.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Block Manager for LLM KV Cache
//!
//! This module provides functionality for managing KV blocks in LLM attention
//! mechanisms. It handles storage allocation, block management, and safe access
//! patterns for both system memory and remote (NIXL) storage.
mod
config
;
mod
state
;
pub
mod
block
;
pub
mod
events
;
pub
mod
layout
;
pub
mod
pool
;
pub
mod
storage
;
pub
use
crate
::
common
::
dtype
::
DType
;
pub
use
block
::{
nixl
::{
AsBlockDescriptorSet
,
BlockDescriptorList
,
IsImmutable
,
IsMutable
,
MutabilityKind
,
RemoteBlock
,
},
transfer
::{
BlockTransferEngineV1
,
TransferRequestPut
},
BasicMetadata
,
BlockMetadata
,
Blocks
,
};
pub
use
config
::
*
;
pub
use
layout
::{
nixl
::
NixlLayout
,
LayoutConfig
,
LayoutConfigBuilder
,
LayoutError
,
LayoutType
};
pub
use
pool
::
BlockPool
;
pub
use
storage
::{
nixl
::
NixlRegisterableStorage
,
DeviceStorage
,
PinnedStorage
,
Storage
,
StorageAllocator
,
};
pub
use
tokio_util
::
sync
::
CancellationToken
;
use
anyhow
::{
Context
,
Result
};
use
block
::
nixl
::{
BlockMutability
,
NixlBlockSet
,
RemoteBlocks
,
SerializedNixlBlockSet
};
use
derive_builder
::
Builder
;
use
nixl_sys
::
Agent
as
NixlAgent
;
use
std
::{
collections
::
HashMap
,
sync
::{
Arc
,
RwLock
},
};
use
storage
::
nixl
::
MemType
;
use
validator
::
Validate
;
pub
type
WorkerID
=
u64
;
pub
type
ReferenceBlockManager
=
KvBlockManager
<
BasicMetadata
>
;
/// Represents the different cache levels for KV blocks
pub
enum
CacheLevel
{
/// Represents KV blocks in GPU memory
G1
,
/// Represents KV blocks in CPU memory
G2
,
/// Represents KV blocks in Local NVMe storage
G3
,
/// Represents KV blocks in Remote NVMe storage
G4
,
}
// When we construct the pool:
// 1. instantiate the runtime,
// 2. build layout::LayoutConfigs for each of the requested storage types
// 3. register the layouts with the NIXL agent if enabled
// 4. construct a Blocks object for each layout providing a unique block_set_idx
// for each layout type.
// 5. initialize the pools for each set of blocks
pub
struct
KvBlockManager
<
Metadata
:
BlockMetadata
>
{
state
:
Arc
<
state
::
KvBlockManagerState
<
Metadata
>>
,
cancellation_token
:
CancellationToken
,
}
impl
<
Metadata
:
BlockMetadata
>
KvBlockManager
<
Metadata
>
{
/// Create a new [KvBlockManager]
///
/// The returned object is a frontend to the [KvBlockManager] which owns the cancellation
/// tokens. When this object gets drop, the cancellation token will be cancelled and begin
/// the gracefully shutdown of the block managers internal state.
pub
fn
new
(
config
:
KvBlockManagerConfig
)
->
Result
<
Self
>
{
let
mut
config
=
config
;
// The frontend of the KvBlockManager will take ownership of the cancellation token
// and will be responsible for cancelling the task when the KvBlockManager is dropped
let
cancellation_token
=
config
.runtime.cancellation_token
.clone
();
// The internal state will use a child token of the original token
config
.runtime.cancellation_token
=
cancellation_token
.child_token
();
// Create the internal state
let
state
=
state
::
KvBlockManagerState
::
new
(
config
)
?
;
Ok
(
Self
{
state
,
cancellation_token
,
})
}
/// Exports the local blockset configuration as a serialized object.
pub
fn
export_local_blockset
(
&
self
)
->
Result
<
SerializedNixlBlockSet
>
{
self
.state
.export_local_blockset
()
}
/// Imports a remote blockset configuration from a serialized object.
pub
fn
import_remote_blockset
(
&
self
,
serialized_blockset
:
SerializedNixlBlockSet
,
)
->
Result
<
()
>
{
self
.state
.import_remote_blockset
(
serialized_blockset
)
}
/// Get a [`Vec<RemoteBlock<IsImmutable>>`] from a [`BlockDescriptorList`]
pub
fn
get_remote_blocks_immutable
(
&
self
,
bds
:
&
BlockDescriptorList
,
)
->
Result
<
Vec
<
RemoteBlock
<
IsImmutable
>>>
{
self
.state
.get_remote_blocks_immutable
(
bds
)
}
/// Get a [`Vec<RemoteBlock<IsMutable>>`] from a [`BlockDescriptorList`]
pub
fn
get_remote_blocks_mutable
(
&
self
,
bds
:
&
BlockDescriptorList
,
)
->
Result
<
Vec
<
RemoteBlock
<
IsMutable
>>>
{
self
.state
.get_remote_blocks_mutable
(
bds
)
}
/// Get a reference to the host block pool
pub
fn
host
(
&
self
)
->
Option
<&
BlockPool
<
PinnedStorage
,
Metadata
>>
{
self
.state
.host
()
}
/// Get a reference to the device block pool
pub
fn
device
(
&
self
)
->
Option
<&
BlockPool
<
DeviceStorage
,
Metadata
>>
{
self
.state
.device
()
}
/// Get the worker ID
pub
fn
worker_id
(
&
self
)
->
WorkerID
{
self
.state
.worker_id
()
}
}
impl
<
Metadata
:
BlockMetadata
>
Drop
for
KvBlockManager
<
Metadata
>
{
fn
drop
(
&
mut
self
)
{
self
.cancellation_token
.cancel
();
}
}
#[cfg(all(test,
feature
=
"testing-full"
))]
mod
tests
{
use
super
::
*
;
use
std
::
sync
::
atomic
::{
AtomicU64
,
Ordering
};
// Atomic Counter for Worker ID
static
WORKER_ID
:
AtomicU64
=
AtomicU64
::
new
(
1337
);
fn
create_reference_block_manager
()
->
ReferenceBlockManager
{
let
worker_id
=
WORKER_ID
.fetch_add
(
1
,
Ordering
::
SeqCst
);
let
config
=
KvBlockManagerConfig
::
builder
()
.runtime
(
KvManagerRuntimeConfig
::
builder
()
.worker_id
(
worker_id
)
.build
()
.unwrap
(),
)
.model
(
KvManagerModelConfig
::
builder
()
.num_layers
(
3
)
.page_size
(
4
)
.inner_dim
(
16
)
.build
()
.unwrap
(),
)
.host_layout
(
KvManagerLayoutConfig
::
builder
()
.num_blocks
(
16
)
.allocator
(
storage
::
PinnedAllocator
::
default
())
.build
()
.unwrap
(),
)
.device_layout
(
KvManagerLayoutConfig
::
builder
()
.num_blocks
(
8
)
.allocator
(
storage
::
DeviceAllocator
::
new
(
0
)
.unwrap
())
.build
()
.unwrap
(),
)
.build
()
.unwrap
();
ReferenceBlockManager
::
new
(
config
)
.unwrap
()
}
#[tokio::test]
async
fn
test_reference_block_manager_inherited_async_runtime
()
{
dynamo_runtime
::
logging
::
init
();
let
_
block_manager
=
create_reference_block_manager
();
}
#[test]
fn
test_reference_block_manager_blocking
()
{
dynamo_runtime
::
logging
::
init
();
let
_
block_manager
=
create_reference_block_manager
();
}
// This tests mimics the behavior of two unique kvbm workers exchanging blocksets
// Each KvBlockManager is a unique worker in this test, each has its resources including
// it's own worker_ids, nixl_agent, and block pools.
//
// This test is meant to mimic the behavior of the basic nixl integration test found here:
// https://github.com/ai-dynamo/nixl/blob/main/src/bindings/rust/src/tests.rs
#[tokio::test]
async
fn
test_reference_block_managers
()
{
dynamo_runtime
::
logging
::
init
();
// create two block managers - mimics two unique dynamo workers
let
kvbm_0
=
create_reference_block_manager
();
let
kvbm_1
=
create_reference_block_manager
();
assert_ne!
(
kvbm_0
.worker_id
(),
kvbm_1
.worker_id
());
// in dynamo, we would exchange the blocksets via the discovery plane
let
blockset_0
=
kvbm_0
.export_local_blockset
()
.unwrap
();
let
blockset_1
=
kvbm_1
.export_local_blockset
()
.unwrap
();
// in dynamo, we would be watching the discovery plane for remote blocksets
kvbm_0
.import_remote_blockset
(
blockset_1
)
.unwrap
();
kvbm_1
.import_remote_blockset
(
blockset_0
)
.unwrap
();
// Worker 0
// Allocate 4 mutable blocks on the host
let
blocks_0
=
kvbm_0
.host
()
.unwrap
()
.allocate_blocks
(
4
)
.await
.unwrap
();
// Create a BlockDescriptorList for the mutable blocks
// let blockset_0 = BlockDescriptorList::from_mutable_blocks(&blocks_0).unwrap();
let
blockset_0
=
blocks_0
.as_block_descriptor_set
()
.unwrap
();
// Worker 1
// Create a RemoteBlock list from blockset_0
let
_
blocks_1
=
kvbm_1
.host
()
.unwrap
()
.allocate_blocks
(
4
)
.await
.unwrap
();
let
mut
_
remote_blocks_0
=
kvbm_1
.get_remote_blocks_mutable
(
&
blockset_0
)
.unwrap
();
// TODO(#967) - Enable with TransferEngine
// // Create a TransferRequestPut for the mutable blocks
// let transfer_request = TransferRequestPut::new(&blocks_0, &mut remote_blocks_0).unwrap();
// // Validate blocks - this could be an expensive operation
// // TODO: Create an ENV trigger debug flag which will call this on every transfer request
// // In this case, we expect an error because we have overlapping blocks as we are sending to/from the same blocks
// // because we are using the wrong target (artifact of the test setup allowing variable to cross what woudl be
// // worker boundaries)
// assert!(transfer_request.validate_blocks().is_err());
// // This is proper request - PUT from worker 1 (local) to worker 0 (remote)
// let transfer_request = TransferRequestPut::new(&blocks_1, &mut remote_blocks_0).unwrap();
// assert!(transfer_request.validate_blocks().is_ok());
// // Execute the transfer request
// transfer_request.execute().unwrap();
// let mut put_request = PutRequestBuilder::<_, _>::builder();
// put_request.from(&blocks_1).to(&mut remote_blocks_0);
// // Create a Put request direct between two local blocks
// // split the blocks into two vecs each with 2 blocks
// let mut blocks_1 = blocks_1;
// let slice_0 = blocks_1.split_off(2);
// let mut slice_1 = blocks_1;
// let transfer_request = TransferRequestPut::new(&slice_0, &mut slice_1).unwrap();
// assert!(transfer_request.validate_blocks().is_ok());
// // Execute the transfer request
// transfer_request.execute().unwrap();
}
}
lib/llm/src/block_manager/block.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub
mod
registry
;
pub
mod
state
;
pub
mod
transfer
;
pub
mod
view
;
pub
use
crate
::
tokens
::
TokenBlockError
;
pub
use
anyhow
::
Result
;
use
nixl_sys
::
NixlDescriptor
;
pub
use
state
::{
BlockState
,
BlockStateInvalid
};
use
crate
::
block_manager
::{
state
::{
KvBlockManagerState
as
BlockManager
,
TransferContext
},
storage
::{
Local
,
Remote
,
Storage
},
};
use
crate
::
tokens
::{
SaltHash
,
SequenceHash
,
Token
,
TokenBlock
,
Tokens
};
use
transfer
::{
Immutable
,
Mutable
,
Readable
,
Writable
};
use
super
::{
events
::
PublishHandle
,
layout
::{
BlockLayout
,
LayoutError
,
LayoutType
},
storage
::
StorageType
,
WorkerID
,
};
use
std
::{
fmt
::
Debug
,
ops
::{
Deref
,
DerefMut
},
sync
::
Arc
,
};
use
thiserror
::
Error
;
mod
private
{
pub
struct
PrivateToken
;
}
/// A unique identifier for a block
pub
type
BlockId
=
usize
;
/// A unique identifier for a block set
pub
type
BlockSetId
=
usize
;
/// Result type for Block operations
pub
type
BlockResult
<
T
>
=
std
::
result
::
Result
<
T
,
BlockError
>
;
/// Errors specific to block storage operations
#[derive(Debug,
Error)]
pub
enum
BlockError
{
#[error(transparent)]
Layout
(
#[from]
LayoutError
),
#[error(
"Invalid state: {0}"
)]
InvalidState
(
String
),
#[error(transparent)]
Other
(
#[from]
anyhow
::
Error
),
}
pub
trait
BlockMetadata
:
Default
+
std
::
fmt
::
Debug
+
Clone
+
Ord
+
Send
+
Sync
+
'static
{
/// Called when the block is acquired from the pool
fn
on_acquired
(
&
mut
self
,
tick
:
u64
);
/// Called when the block is returned to the pool
fn
on_returned
(
&
mut
self
,
tick
:
u64
);
/// Resets the metadata to the default value
/// If called, the [BlockMetadata::is_reset()] should return true
fn
reset_metadata
(
&
mut
self
);
}
/// Marker trait for types that are mutable blocks
pub
trait
WritableBlock
:
BlockDataProviderMut
{
type
StorageType
:
Storage
+
NixlDescriptor
;
fn
storage_type_id
(
&
self
)
->
std
::
any
::
TypeId
{
std
::
any
::
TypeId
::
of
::
<<
Self
as
WritableBlock
>
::
StorageType
>
()
}
}
/// Marker trait for types that are immutable blocks
pub
trait
ReadableBlock
:
BlockDataProvider
{
type
StorageType
:
Storage
+
NixlDescriptor
;
fn
storage_type_id
(
&
self
)
->
std
::
any
::
TypeId
{
std
::
any
::
TypeId
::
of
::
<<
Self
as
ReadableBlock
>
::
StorageType
>
()
}
fn
transfer_context
(
&
self
)
->
&
TransferContext
{
unimplemented!
()
}
}
pub
trait
ReadableBlocks
{}
impl
<
T
:
ReadableBlock
>
ReadableBlocks
for
Vec
<
T
>
{}
impl
<
T
:
ReadableBlock
>
ReadableBlocks
for
[
T
]
{}
impl
<
T
:
ReadableBlock
>
ReadableBlocks
for
&
[
T
]
{}
pub
trait
WritableBlocks
{}
impl
<
T
:
WritableBlock
>
WritableBlocks
for
Vec
<
T
>
{}
impl
<
T
:
WritableBlock
>
WritableBlocks
for
[
T
]
{}
impl
<
T
:
WritableBlock
>
WritableBlocks
for
&
[
T
]
{}
/// Blanket trait for anything that can be viewed as a slice of blocks
pub
trait
AsBlockSlice
<
'a
,
B
:
'a
>
{
fn
as_block_slice
(
&
'a
self
)
->
&
'a
[
B
];
}
/// Blanket trait for anything that can be viewed as a mutable slice of blocks
pub
trait
AsBlockMutSlice
<
'a
,
B
:
'a
>
{
fn
as_block_mut_slice
(
&
'a
mut
self
)
->
&
'a
mut
[
B
];
}
/// Blanket trait for anything that can be converted into a mutable block
pub
trait
IntoWritableBlocks
<
M
:
BlockMetadata
>
{
type
Output
:
WritableBlocks
;
fn
into_writable_blocks
(
self
,
manager
:
&
BlockManager
<
M
>
)
->
BlockResult
<
Self
::
Output
>
;
}
impl
<
T
:
WritableBlocks
,
M
:
BlockMetadata
>
IntoWritableBlocks
<
M
>
for
T
{
type
Output
=
T
;
fn
into_writable_blocks
(
self
,
_
manager
:
&
BlockManager
<
M
>
)
->
BlockResult
<
Self
::
Output
>
{
Ok
(
self
)
}
}
pub
trait
IntoReadableBlocks
<
M
:
BlockMetadata
>
{
type
Output
:
ReadableBlocks
;
fn
into_readable_blocks
(
self
,
manager
:
&
BlockManager
<
M
>
)
->
BlockResult
<
Self
::
Output
>
;
}
impl
<
T
:
ReadableBlocks
,
M
:
BlockMetadata
>
IntoReadableBlocks
<
M
>
for
T
{
type
Output
=
T
;
fn
into_readable_blocks
(
self
,
_
manager
:
&
BlockManager
<
M
>
)
->
BlockResult
<
Self
::
Output
>
{
Ok
(
self
)
}
}
/// A block with storage and associated metadata/state
#[derive(Debug)]
pub
struct
Block
<
S
:
Storage
,
M
:
BlockMetadata
>
{
data
:
BlockData
<
S
>
,
metadata
:
M
,
state
:
BlockState
,
manager
:
Option
<
Arc
<
BlockManager
<
M
>>>
,
}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
Block
<
S
,
M
>
{
/// Create a new block with default metadata/state
pub
fn
new
(
data
:
BlockData
<
S
>
,
metadata
:
M
)
->
BlockResult
<
Self
>
{
Ok
(
Self
{
data
,
metadata
,
state
:
BlockState
::
Reset
,
manager
:
None
,
})
}
pub
fn
sequence_hash
(
&
self
)
->
Result
<
SequenceHash
,
BlockError
>
{
match
self
.state
()
{
BlockState
::
Complete
(
state
)
=>
Ok
(
state
.token_block
()
.sequence_hash
()),
BlockState
::
Registered
(
state
)
=>
Ok
(
state
.sequence_hash
()),
_
=>
Err
(
BlockError
::
InvalidState
(
"Block is not complete"
.to_string
(),
)),
}
}
pub
(
crate
)
fn
reset
(
&
mut
self
)
{
self
.state
=
BlockState
::
Reset
;
self
.metadata
.reset_metadata
();
}
pub
(
crate
)
fn
set_manager
(
&
mut
self
,
manager
:
Arc
<
BlockManager
<
M
>>
)
{
self
.manager
=
Some
(
manager
);
}
// TODO(#967) - Enable with TransferEngine
#[allow(dead_code)]
pub
(
crate
)
fn
manager
(
&
self
)
->
Option
<&
Arc
<
BlockManager
<
M
>>>
{
self
.manager
.as_ref
()
}
/// Get the metadata of the block
pub
fn
metadata
(
&
self
)
->
&
M
{
&
self
.metadata
}
/// Update the metadata of the block
pub
fn
update_metadata
(
&
mut
self
,
metadata
:
M
)
{
self
.metadata
=
metadata
;
}
/// Update the state of the block
#[allow(dead_code)]
pub
(
crate
)
fn
update_state
(
&
mut
self
,
state
:
BlockState
)
{
self
.state
=
state
;
}
/// Get a reference to the state of the block
pub
fn
state
(
&
self
)
->
&
BlockState
{
&
self
.state
}
/// Get the number of blocks in the block
pub
fn
num_blocks
(
&
self
)
->
usize
{
self
.data.layout
.num_blocks
()
}
/// Get the number of layers in the block
pub
fn
num_layers
(
&
self
)
->
usize
{
self
.data.layout
.num_layers
()
}
/// Get the size of each block in the block
pub
fn
page_size
(
&
self
)
->
usize
{
self
.data.layout
.page_size
()
}
/// Get the inner dimension of the block
pub
fn
inner_dim
(
&
self
)
->
usize
{
self
.data.layout
.inner_dim
()
}
pub
(
crate
)
fn
metadata_on_acquired
(
&
mut
self
,
tick
:
u64
)
{
self
.metadata
.on_acquired
(
tick
);
}
pub
(
crate
)
fn
metadata_on_returned
(
&
mut
self
,
tick
:
u64
)
{
self
.metadata
.on_returned
(
tick
);
}
}
pub
(
crate
)
trait
PrivateBlockExt
{
fn
register
(
&
mut
self
,
registry
:
&
mut
registry
::
BlockRegistry
,
)
->
Result
<
PublishHandle
,
registry
::
BlockRegistationError
>
;
}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
PrivateBlockExt
for
Block
<
S
,
M
>
{
fn
register
(
&
mut
self
,
registry
:
&
mut
registry
::
BlockRegistry
,
)
->
Result
<
PublishHandle
,
registry
::
BlockRegistationError
>
{
registry
.register_block
(
&
mut
self
.state
)
}
}
pub
trait
BlockExt
{
/// Reset the state of the block
fn
reset
(
&
mut
self
);
/// Initialize a sequence on the block using a [SaltHash]
///
/// The block must be in the [BlockState::Reset] state.
///
/// After initialization, the block will be in the [BlockState::Partial] state.
fn
init_sequence
(
&
mut
self
,
salt_hash
:
SaltHash
)
->
Result
<
()
>
;
/// Appends a single token to the block if it is in the Partial state and not full.
/// Returns `Err` if the block is not Partial or already full.
fn
add_token
(
&
mut
self
,
token
:
Token
)
->
Result
<
()
>
;
/// Appends multiple tokens to the block if it is in the Partial state
/// and has enough remaining capacity for *all* provided tokens.
/// The block must be in the [BlockState::Partial] state.
/// Returns `Err` if the block is not Partial or if there isn't enough space.
fn
add_tokens
(
&
mut
self
,
tokens
:
Tokens
)
->
Result
<
Tokens
>
;
/// Removes the last token from the block.
/// Requires the block to be in the Partial state and not empty.
/// Returns `Err` otherwise.
fn
pop_token
(
&
mut
self
)
->
Result
<
()
>
;
/// Removes the last `count` tokens from the block.
/// Requires the block to be in the Partial state and have at least `count` tokens.
/// Returns `Err` otherwise.
fn
pop_tokens
(
&
mut
self
,
count
:
usize
)
->
Result
<
()
>
;
/// Commit the block
/// Requires the block to be in the [BlockState::Partial] state and completely full.
/// Transitions the state to [BlockState::Complete]. Returns `Err` otherwise.
fn
commit
(
&
mut
self
)
->
Result
<
()
>
;
/// Apply a [TokenBlock] to the block
/// Requires the block to be in the [BlockState::Reset] state.
///
/// Additionally, the [TokenBlock] must match the [BlockLayout::page_size()]
/// Transitions the state to [BlockState::Complete]. Returns `Err` otherwise.
fn
apply_token_block
(
&
mut
self
,
token_block
:
TokenBlock
)
->
Result
<
()
>
;
/// Returns the number of tokens currently in the block.
fn
len
(
&
self
)
->
usize
;
/// Returns the number of additional tokens that can be added (only valid for Partial state).
fn
remaining
(
&
self
)
->
usize
;
/// Returns true if the block contains no tokens (only true for Reset or empty Partial state).
fn
is_empty
(
&
self
)
->
bool
;
/// Returns true if the block is full.
fn
is_full
(
&
self
)
->
bool
;
/// Returns a list of tokens in the block.
fn
tokens
(
&
self
)
->
Option
<&
Tokens
>
;
}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
BlockExt
for
Block
<
S
,
M
>
{
fn
reset
(
&
mut
self
)
{
Block
::
reset
(
self
);
}
fn
init_sequence
(
&
mut
self
,
salt_hash
:
SaltHash
)
->
Result
<
()
>
{
Ok
(
self
.state
.initialize_sequence
(
self
.page_size
(),
salt_hash
)
?
)
}
fn
add_token
(
&
mut
self
,
token
:
Token
)
->
Result
<
()
>
{
self
.state
.add_token
(
token
)
}
fn
add_tokens
(
&
mut
self
,
tokens
:
Tokens
)
->
Result
<
Tokens
>
{
self
.state
.add_tokens
(
tokens
)
}
fn
pop_token
(
&
mut
self
)
->
Result
<
()
>
{
self
.state
.pop_token
()
}
fn
pop_tokens
(
&
mut
self
,
count
:
usize
)
->
Result
<
()
>
{
self
.state
.pop_tokens
(
count
)
}
fn
commit
(
&
mut
self
)
->
Result
<
()
>
{
self
.state
.commit
()
}
fn
apply_token_block
(
&
mut
self
,
token_block
:
TokenBlock
)
->
Result
<
()
>
{
if
self
.page_size
()
!=
token_block
.tokens
()
.len
()
{
return
Err
(
BlockStateInvalid
(
format!
(
"TokenBlock size ({}) does not match Block page size ({})"
,
token_block
.tokens
()
.len
(),
self
.page_size
()
))
.into
());
}
self
.state
.apply_token_block
(
token_block
)
}
fn
len
(
&
self
)
->
usize
{
match
self
.state
.len
()
{
Some
(
len
)
=>
len
,
None
=>
self
.page_size
(),
}
}
fn
remaining
(
&
self
)
->
usize
{
self
.state
.remaining
()
}
fn
is_empty
(
&
self
)
->
bool
{
self
.state
.is_empty
()
}
fn
is_full
(
&
self
)
->
bool
{
self
.len
()
==
self
.page_size
()
}
fn
tokens
(
&
self
)
->
Option
<&
Tokens
>
{
self
.state
.tokens
()
}
}
pub
trait
BlockDataExt
<
S
:
Storage
+
NixlDescriptor
>
{
/// Returns true if the block data is fully contiguous
fn
is_fully_contiguous
(
&
self
)
->
bool
;
/// Returns the number of layers in the block
fn
num_layers
(
&
self
)
->
usize
;
/// Get a read-only view of this block's storage for a layer
fn
layer_view
(
&
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
;
/// Get a mutable view of this block's storage for a layer
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
;
/// Get a read-only view of this block's storage
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
S
>>
;
/// Get a mutable view of this block's storage
fn
block_view_mut
(
&
mut
self
)
->
BlockResult
<
view
::
BlockViewMut
<
S
>>
;
}
/// Individual block storage - cannot be cloned to ensure uniqueness
#[derive(Debug)]
pub
struct
BlockData
<
S
:
Storage
>
{
layout
:
Arc
<
dyn
BlockLayout
<
StorageType
=
S
>>
,
block_idx
:
usize
,
block_set_idx
:
usize
,
worker_id
:
WorkerID
,
}
impl
<
S
>
BlockData
<
S
>
where
S
:
Storage
,
{
/// Create a new block storage
pub
(
crate
)
fn
new
(
layout
:
Arc
<
dyn
BlockLayout
<
StorageType
=
S
>>
,
block_idx
:
usize
,
block_set_idx
:
usize
,
worker_id
:
WorkerID
,
)
->
Self
{
Self
{
layout
,
block_idx
,
block_set_idx
,
worker_id
,
}
}
pub
fn
storage_type
(
&
self
)
->
StorageType
{
self
.layout
.storage_type
()
}
}
impl
<
S
:
Storage
+
NixlDescriptor
>
BlockDataExt
<
S
>
for
BlockData
<
S
>
where
S
:
Storage
+
NixlDescriptor
,
{
fn
is_fully_contiguous
(
&
self
)
->
bool
{
self
.layout
.layout_type
()
==
LayoutType
::
FullyContiguous
}
fn
num_layers
(
&
self
)
->
usize
{
self
.layout
.num_layers
()
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
S
>>
{
let
offset
=
self
.layout
.memory_region_addr
(
self
.block_idx
,
layer_idx
)
?
;
unsafe
{
view
::
LayerView
::
new
(
self
,
offset
as
usize
,
self
.layout
.memory_region_size
())
}
}
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerViewMut
<
S
>>
{
let
offset
=
self
.layout
.memory_region_addr
(
self
.block_idx
,
layer_idx
)
?
;
unsafe
{
view
::
LayerViewMut
::
new
(
self
,
offset
as
usize
,
self
.layout
.memory_region_size
())
}
}
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
S
>>
{
if
self
.is_fully_contiguous
()
{
let
offset
=
self
.layout
.memory_region_addr
(
self
.block_idx
,
0
)
?
;
let
size
=
self
.layout
.memory_region_size
()
*
self
.layout
.num_layers
();
unsafe
{
view
::
BlockView
::
new
(
self
,
offset
as
usize
,
size
)
}
}
else
{
Err
(
BlockError
::
InvalidState
(
"Block is not fully contiguous"
.to_string
(),
))
}
}
fn
block_view_mut
(
&
mut
self
)
->
BlockResult
<
view
::
BlockViewMut
<
S
>>
{
if
self
.is_fully_contiguous
()
{
let
offset
=
self
.layout
.memory_region_addr
(
self
.block_idx
,
0
)
?
;
let
size
=
self
.layout
.memory_region_size
()
*
self
.layout
.num_layers
();
unsafe
{
view
::
BlockViewMut
::
new
(
self
,
offset
as
usize
,
size
)
}
}
else
{
Err
(
BlockError
::
InvalidState
(
"Block is not fully contiguous"
.to_string
(),
))
}
}
}
pub
trait
BlockDataProvider
{
type
StorageType
:
Storage
+
NixlDescriptor
;
fn
block_data
(
&
self
,
_
:
private
::
PrivateToken
)
->
&
BlockData
<
Self
::
StorageType
>
;
}
pub
trait
BlockDataProviderMut
:
BlockDataProvider
{
fn
block_data_mut
(
&
mut
self
,
_
:
private
::
PrivateToken
)
->
&
mut
BlockData
<
Self
::
StorageType
>
;
}
#[derive(Clone,
Debug,
Default,
Eq,
PartialEq,
Ord,
PartialOrd)]
pub
struct
BasicMetadata
{
priority
:
u32
,
returned_tick
:
u64
,
acquired_tick
:
u64
,
}
impl
BlockMetadata
for
BasicMetadata
{
fn
on_acquired
(
&
mut
self
,
tick
:
u64
)
{
self
.acquired_tick
=
tick
;
}
fn
on_returned
(
&
mut
self
,
tick
:
u64
)
{
self
.returned_tick
=
tick
;
}
fn
reset_metadata
(
&
mut
self
)
{
self
.priority
=
0
;
}
}
/// Collection that holds shared storage and layout
#[derive(Debug)]
pub
struct
Blocks
<
L
:
BlockLayout
,
M
:
BlockMetadata
>
{
layout
:
Box
<
L
>
,
metadata
:
std
::
marker
::
PhantomData
<
M
>
,
block_set_idx
:
usize
,
worker_id
:
WorkerID
,
}
impl
<
L
:
BlockLayout
+
'static
,
M
:
BlockMetadata
>
Blocks
<
L
,
M
>
{
/// Create a new block storage collection
pub
fn
new
(
layout
:
L
,
block_set_idx
:
usize
,
worker_id
:
WorkerID
)
->
BlockResult
<
Self
>
{
let
layout
=
Box
::
new
(
layout
);
Ok
(
Self
{
layout
,
metadata
:
std
::
marker
::
PhantomData
,
block_set_idx
,
worker_id
,
})
}
/// Convert collection into Vec<Block> with default metadata/state
pub
fn
into_blocks
(
self
)
->
BlockResult
<
Vec
<
Block
<
L
::
StorageType
,
M
>>>
{
// convert box to arc
let
layout
:
Arc
<
dyn
BlockLayout
<
StorageType
=
L
::
StorageType
>>
=
Arc
::
new
(
*
self
.layout
);
layout_to_blocks
(
layout
,
self
.block_set_idx
,
self
.worker_id
)
}
}
pub
(
crate
)
fn
layout_to_blocks
<
S
:
Storage
,
M
:
BlockMetadata
>
(
layout
:
Arc
<
dyn
BlockLayout
<
StorageType
=
S
>>
,
block_set_idx
:
usize
,
worker_id
:
WorkerID
,
)
->
BlockResult
<
Vec
<
Block
<
S
,
M
>>>
{
(
0
..
layout
.num_blocks
())
.map
(|
idx
|
{
let
data
=
BlockData
::
new
(
layout
.clone
(),
idx
,
block_set_idx
,
worker_id
);
Block
::
new
(
data
,
M
::
default
())
})
.collect
()
}
pub
struct
MutableBlock
<
S
:
Storage
,
M
:
BlockMetadata
>
{
block
:
Option
<
Block
<
S
,
M
>>
,
return_tx
:
tokio
::
sync
::
mpsc
::
UnboundedSender
<
Block
<
S
,
M
>>
,
}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
WritableBlock
for
MutableBlock
<
S
,
M
>
{
type
StorageType
=
S
;
}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
ReadableBlock
for
MutableBlock
<
S
,
M
>
{
type
StorageType
=
S
;
}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
Writable
for
MutableBlock
<
S
,
M
>
{}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
Readable
for
MutableBlock
<
S
,
M
>
{}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
Mutable
for
MutableBlock
<
S
,
M
>
{}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
Local
for
MutableBlock
<
S
,
M
>
{}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
MutableBlock
<
S
,
M
>
{
pub
(
crate
)
fn
new
(
block
:
Block
<
S
,
M
>
,
return_tx
:
tokio
::
sync
::
mpsc
::
UnboundedSender
<
Block
<
S
,
M
>>
,
)
->
Self
{
Self
{
block
:
Some
(
block
),
return_tx
,
}
}
}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
std
::
fmt
::
Debug
for
MutableBlock
<
S
,
M
>
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
write!
(
f
,
"MutableBlock {{ block: {:?} }}"
,
self
.block
)
}
}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
Drop
for
MutableBlock
<
S
,
M
>
{
fn
drop
(
&
mut
self
)
{
if
let
Some
(
block
)
=
self
.block
.take
()
{
if
self
.return_tx
.send
(
block
)
.is_err
()
{
tracing
::
warn!
(
"block pool shutdown before block was returned"
);
}
}
}
}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
Deref
for
MutableBlock
<
S
,
M
>
{
type
Target
=
Block
<
S
,
M
>
;
fn
deref
(
&
self
)
->
&
Self
::
Target
{
self
.block
.as_ref
()
.expect
(
"block was dropped"
)
}
}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
DerefMut
for
MutableBlock
<
S
,
M
>
{
fn
deref_mut
(
&
mut
self
)
->
&
mut
Self
::
Target
{
self
.block
.as_mut
()
.expect
(
"block was dropped"
)
}
}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
BlockDataProvider
for
MutableBlock
<
S
,
M
>
{
type
StorageType
=
S
;
fn
block_data
(
&
self
,
_
:
private
::
PrivateToken
)
->
&
BlockData
<
S
>
{
&
self
.block
.as_ref
()
.expect
(
"block was dropped"
)
.data
}
}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
BlockDataProviderMut
for
MutableBlock
<
S
,
M
>
{
fn
block_data_mut
(
&
mut
self
,
_
:
private
::
PrivateToken
)
->
&
mut
BlockData
<
S
>
{
&
mut
self
.block
.as_mut
()
.expect
(
"block was dropped"
)
.data
}
}
impl
<
'a
,
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
AsBlockSlice
<
'a
,
MutableBlock
<
S
,
M
>>
for
[
MutableBlock
<
S
,
M
>
]
{
fn
as_block_slice
(
&
'a
self
)
->
&
'a
[
MutableBlock
<
S
,
M
>
]
{
self
}
}
impl
<
'a
,
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
AsBlockSlice
<
'a
,
MutableBlock
<
S
,
M
>>
for
Vec
<
MutableBlock
<
S
,
M
>>
{
fn
as_block_slice
(
&
'a
self
)
->
&
'a
[
MutableBlock
<
S
,
M
>
]
{
self
.as_slice
()
}
}
impl
<
'a
,
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
AsBlockMutSlice
<
'a
,
MutableBlock
<
S
,
M
>>
for
[
MutableBlock
<
S
,
M
>
]
{
fn
as_block_mut_slice
(
&
'a
mut
self
)
->
&
'a
mut
[
MutableBlock
<
S
,
M
>
]
{
self
}
}
impl
<
'a
,
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
AsBlockMutSlice
<
'a
,
MutableBlock
<
S
,
M
>>
for
Vec
<
MutableBlock
<
S
,
M
>>
{
fn
as_block_mut_slice
(
&
'a
mut
self
)
->
&
'a
mut
[
MutableBlock
<
S
,
M
>
]
{
self
.as_mut_slice
()
}
}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
IntoWritableBlocks
<
M
>
for
MutableBlock
<
S
,
M
>
{
type
Output
=
Vec
<
MutableBlock
<
S
,
M
>>
;
fn
into_writable_blocks
(
self
,
_
manager
:
&
BlockManager
<
M
>
)
->
BlockResult
<
Self
::
Output
>
{
Ok
(
vec!
[
self
])
}
}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
IntoReadableBlocks
<
M
>
for
MutableBlock
<
S
,
M
>
{
type
Output
=
Vec
<
MutableBlock
<
S
,
M
>>
;
fn
into_readable_blocks
(
self
,
_
manager
:
&
BlockManager
<
M
>
)
->
BlockResult
<
Self
::
Output
>
{
Ok
(
vec!
[
self
])
}
}
#[derive(Debug)]
pub
struct
ImmutableBlock
<
S
:
Storage
,
M
:
BlockMetadata
>
{
block
:
Arc
<
MutableBlock
<
S
,
M
>>
,
}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
ImmutableBlock
<
S
,
M
>
{
pub
(
crate
)
fn
new
(
block
:
Arc
<
MutableBlock
<
S
,
M
>>
)
->
Self
{
Self
{
block
}
}
}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
ReadableBlock
for
ImmutableBlock
<
S
,
M
>
{
type
StorageType
=
S
;
}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
Readable
for
ImmutableBlock
<
S
,
M
>
{}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
Immutable
for
ImmutableBlock
<
S
,
M
>
{}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
Local
for
ImmutableBlock
<
S
,
M
>
{}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
Deref
for
ImmutableBlock
<
S
,
M
>
{
type
Target
=
Block
<
S
,
M
>
;
fn
deref
(
&
self
)
->
&
Self
::
Target
{
self
.block
.as_ref
()
.block
.as_ref
()
.expect
(
"block was dropped"
)
}
}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
BlockDataProvider
for
ImmutableBlock
<
S
,
M
>
{
type
StorageType
=
S
;
fn
block_data
(
&
self
,
_
:
private
::
PrivateToken
)
->
&
BlockData
<
S
>
{
&
self
.block
.as_ref
()
.block
.as_ref
()
.expect
(
"block was dropped"
)
.data
}
}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
IntoReadableBlocks
<
M
>
for
ImmutableBlock
<
S
,
M
>
{
type
Output
=
Vec
<
ImmutableBlock
<
S
,
M
>>
;
fn
into_readable_blocks
(
self
,
_
manager
:
&
BlockManager
<
M
>
)
->
BlockResult
<
Self
::
Output
>
{
Ok
(
vec!
[
self
])
}
}
impl
<
'a
,
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
AsBlockSlice
<
'a
,
ImmutableBlock
<
S
,
M
>>
for
[
ImmutableBlock
<
S
,
M
>
]
{
fn
as_block_slice
(
&
'a
self
)
->
&
'a
[
ImmutableBlock
<
S
,
M
>
]
{
self
}
}
impl
<
'a
,
S
:
Storage
,
M
:
BlockMetadata
>
AsBlockSlice
<
'a
,
ImmutableBlock
<
S
,
M
>>
for
Vec
<
ImmutableBlock
<
S
,
M
>>
{
fn
as_block_slice
(
&
'a
self
)
->
&
'a
[
ImmutableBlock
<
S
,
M
>
]
{
self
.as_slice
()
}
}
pub
mod
nixl
{
use
super
::
*
;
use
super
::
view
::{
BlockKind
,
Kind
,
LayerKind
};
use
super
::
super
::{
layout
::
nixl
::{
NixlLayout
,
SerializedNixlBlockLayout
},
storage
::
nixl
::{
MemType
,
NixlRegisterableStorage
,
NixlStorage
},
WorkerID
,
};
use
derive_getters
::{
Dissolve
,
Getters
};
use
nixl_sys
::{
Agent
as
NixlAgent
,
MemoryRegion
,
NixlDescriptor
,
OptArgs
};
use
serde
::{
Deserialize
,
Serialize
};
use
std
::
collections
::
HashMap
;
// --- Mutability Marker ---
pub
trait
MutabilityKind
:
Debug
+
Clone
+
Copy
+
Send
+
Sync
+
'static
{}
#[derive(Debug,
Clone,
Copy)]
pub
struct
IsMutable
;
impl
MutabilityKind
for
IsMutable
{}
#[derive(Debug,
Clone,
Copy)]
pub
struct
IsImmutable
;
impl
MutabilityKind
for
IsImmutable
{}
impl
<
L
:
NixlLayout
,
M
:
BlockMetadata
>
Blocks
<
L
,
M
>
where
L
::
StorageType
:
NixlRegisterableStorage
,
{
/// Register the blocks with an NIXL agent
pub
fn
nixl_register
(
&
mut
self
,
agent
:
&
NixlAgent
,
opt_args
:
Option
<&
OptArgs
>
,
)
->
anyhow
::
Result
<
()
>
{
self
.layout
.nixl_register
(
agent
,
opt_args
)
}
}
/// A unified, lifetime-bound descriptor containing information needed for NIXL operations.
/// Typed by Kind (Block/Layer) and Mutability (IsMutable/IsImmutable).
#[derive(Copy,
Clone)]
// Can be Copy/Clone as it holds basic data + markers
pub
struct
NixlMemoryDescriptor
<
'a
,
K
:
Kind
,
M
:
MutabilityKind
>
{
addr
:
u64
,
size
:
usize
,
mem_type
:
MemType
,
device_id
:
u64
,
_
lifetime
:
std
::
marker
::
PhantomData
<&
'a
()
>
,
// Binds the descriptor's lifetime to 'a
_
kind
:
std
::
marker
::
PhantomData
<
K
>
,
// Stores the Kind marker type
_
mutability
:
std
::
marker
::
PhantomData
<
M
>
,
// Stores the Mutability marker type
}
// Helper function to get the short type name
pub
(
crate
)
fn
short_type_name
<
T
>
()
->
&
'static
str
{
let
name
=
core
::
any
::
type_name
::
<
T
>
();
name
.split
(
"::"
)
.last
()
.unwrap_or
(
name
)
}
// Implement Debug manually to avoid bounds on K/M
impl
<
K
:
Kind
,
M
:
MutabilityKind
>
Debug
for
NixlMemoryDescriptor
<
'_
,
K
,
M
>
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
f
.debug_struct
(
"NixlMemoryDescriptor"
)
.field
(
"addr"
,
&
self
.addr
)
.field
(
"size"
,
&
self
.size
)
.field
(
"mem_type"
,
&
self
.mem_type
)
.field
(
"device_id"
,
&
self
.device_id
)
.field
(
"kind"
,
&
short_type_name
::
<
K
>
())
// Show marker types
.field
(
"mutability"
,
&
short_type_name
::
<
M
>
())
.finish
()
}
}
impl
<
K
:
Kind
,
M
:
MutabilityKind
>
NixlMemoryDescriptor
<
'_
,
K
,
M
>
{
/// Creates a new NixlMemoryDescriptor. Typically called via conversion methods.
#[inline]
pub
(
crate
)
fn
new
(
addr
:
u64
,
size
:
usize
,
mem_type
:
MemType
,
device_id
:
u64
)
->
Self
{
Self
{
addr
,
size
,
mem_type
,
device_id
,
_
lifetime
:
std
::
marker
::
PhantomData
,
_
kind
:
std
::
marker
::
PhantomData
,
_
mutability
:
std
::
marker
::
PhantomData
,
}
}
}
impl
<
K
:
Kind
,
M
:
MutabilityKind
>
MemoryRegion
for
NixlMemoryDescriptor
<
'_
,
K
,
M
>
{
unsafe
fn
as_ptr
(
&
self
)
->
*
const
u8
{
self
.addr
as
*
const
u8
}
fn
size
(
&
self
)
->
usize
{
self
.size
}
}
impl
<
K
:
Kind
,
M
:
MutabilityKind
>
NixlDescriptor
for
NixlMemoryDescriptor
<
'_
,
K
,
M
>
{
fn
mem_type
(
&
self
)
->
MemType
{
self
.mem_type
}
fn
device_id
(
&
self
)
->
u64
{
self
.device_id
}
}
pub
trait
NixlBlockDataImmutable
<
S
:
Storage
+
NixlDescriptor
>
:
BlockDataExt
<
S
>
{
/// Get the NIXL memory descriptor for the entire block
fn
as_block_descriptor
(
&
self
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
BlockKind
,
IsImmutable
>>
;
/// Get the NIXL memory descriptor for a specific layer
fn
as_layer_descriptor
(
&
self
,
layer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsImmutable
>>
;
}
pub
trait
NixlBlockDataMutable
<
S
:
Storage
+
NixlDescriptor
>
:
BlockDataExt
<
S
>
+
NixlBlockDataImmutable
<
S
>
{
/// Get the NIXL memory descriptor for the entire block
fn
as_block_descriptor_mut
(
&
mut
self
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
BlockKind
,
IsMutable
>>
;
/// Get the NIXL memory descriptor for a specific layer
fn
as_layer_descriptor_mut
(
&
mut
self
,
layer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsMutable
>>
;
}
impl
<
S
:
Storage
+
NixlDescriptor
>
NixlBlockDataImmutable
<
S
>
for
BlockData
<
S
>
{
fn
as_block_descriptor
(
&
self
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
BlockKind
,
IsImmutable
>>
{
Ok
(
self
.block_view
()
?
.as_nixl_descriptor
())
}
fn
as_layer_descriptor
(
&
self
,
layer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsImmutable
>>
{
Ok
(
self
.layer_view
(
layer_idx
)
?
.as_nixl_descriptor
())
}
}
impl
<
S
:
Storage
+
NixlDescriptor
>
NixlBlockDataMutable
<
S
>
for
BlockData
<
S
>
{
fn
as_block_descriptor_mut
(
&
mut
self
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
BlockKind
,
IsMutable
>>
{
Ok
(
self
.block_view_mut
()
?
.as_nixl_descriptor_mut
())
}
fn
as_layer_descriptor_mut
(
&
mut
self
,
layer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsMutable
>>
{
Ok
(
self
.layer_view_mut
(
layer_idx
)
?
.as_nixl_descriptor_mut
())
}
}
/// Error type for NixlBlockSet serialization/deserialization failures.
#[derive(Debug,
Error)]
pub
enum
NixlSerializationError
{
#[error(
"Serialization failed: {0}"
)]
Serialize
(
#[from]
serde_json
::
Error
),
}
/// A strongly-typed wrapper for serialized NixlBlockSet data.
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
pub
struct
SerializedNixlBlockSet
(
Vec
<
u8
>
);
impl
TryFrom
<&
NixlBlockSet
>
for
SerializedNixlBlockSet
{
type
Error
=
NixlSerializationError
;
/// Serializes a NixlBlockSet into SerializedNixlBlockSet.
fn
try_from
(
value
:
&
NixlBlockSet
)
->
Result
<
Self
,
Self
::
Error
>
{
let
bytes
=
serde_json
::
to_vec
(
value
)
?
;
Ok
(
SerializedNixlBlockSet
(
bytes
))
}
}
impl
TryFrom
<
NixlBlockSet
>
for
SerializedNixlBlockSet
{
type
Error
=
NixlSerializationError
;
/// Serializes a NixlBlockSet into SerializedNixlBlockSet, consuming the original.
fn
try_from
(
value
:
NixlBlockSet
)
->
Result
<
Self
,
Self
::
Error
>
{
let
bytes
=
serde_json
::
to_vec
(
&
value
)
?
;
Ok
(
SerializedNixlBlockSet
(
bytes
))
}
}
impl
TryFrom
<&
SerializedNixlBlockSet
>
for
NixlBlockSet
{
type
Error
=
NixlSerializationError
;
/// Deserializes SerializedNixlBlockSet into a NixlBlockSet.
fn
try_from
(
value
:
&
SerializedNixlBlockSet
)
->
Result
<
Self
,
Self
::
Error
>
{
let
block_set
=
serde_json
::
from_slice
(
&
value
.0
)
?
;
Ok
(
block_set
)
}
}
impl
TryFrom
<
SerializedNixlBlockSet
>
for
NixlBlockSet
{
type
Error
=
NixlSerializationError
;
/// Deserializes SerializedNixlBlockSet into a NixlBlockSet, consuming the original.
fn
try_from
(
value
:
SerializedNixlBlockSet
)
->
Result
<
Self
,
Self
::
Error
>
{
let
block_set
=
serde_json
::
from_slice
(
&
value
.0
)
?
;
Ok
(
block_set
)
}
}
#[derive(Clone,
serde::Serialize,
serde::Deserialize,
Dissolve)]
pub
struct
NixlBlockSet
{
/// The block set index
block_sets
:
HashMap
<
usize
,
SerializedNixlBlockLayout
>
,
/// Captures the NIXL metadata from [nixl_sys::Agent::get_local_md]
nixl_metadata
:
Vec
<
u8
>
,
/// Worker ID
worker_id
:
u64
,
}
impl
NixlBlockSet
{
pub
fn
new
(
worker_id
:
u64
)
->
Self
{
Self
{
block_sets
:
HashMap
::
new
(),
nixl_metadata
:
Vec
::
new
(),
worker_id
,
}
}
pub
fn
worker_id
(
&
self
)
->
u64
{
self
.worker_id
}
/// Get the block set for a given block set index
pub
fn
block_sets
(
&
self
)
->
&
HashMap
<
usize
,
SerializedNixlBlockLayout
>
{
&
self
.block_sets
}
/// Add a block set to the block set
pub
fn
add_block_set
(
&
mut
self
,
block_set_idx
:
usize
,
serialized_layout
:
SerializedNixlBlockLayout
,
)
{
self
.block_sets
.insert
(
block_set_idx
,
serialized_layout
);
}
/// Get the NIXL metadata
pub
fn
get_nixl_metadata
(
&
self
)
->
&
Vec
<
u8
>
{
&
self
.nixl_metadata
}
/// Set the NIXL metadata
pub
fn
set_nixl_metadata
(
&
mut
self
,
nixl_metadata
:
Vec
<
u8
>
)
{
self
.nixl_metadata
=
nixl_metadata
;
}
}
#[derive(Debug,
Clone)]
pub
struct
RemoteBlocks
{
layout
:
Arc
<
dyn
BlockLayout
<
StorageType
=
NixlStorage
>>
,
block_set_idx
:
usize
,
worker_id
:
WorkerID
,
}
impl
RemoteBlocks
{
pub
fn
new
(
layout
:
Arc
<
dyn
BlockLayout
<
StorageType
=
NixlStorage
>>
,
block_set_idx
:
usize
,
worker_id
:
WorkerID
,
)
->
Self
{
Self
{
layout
,
block_set_idx
,
worker_id
,
}
}
pub
fn
from_serialized
(
serialized
:
SerializedNixlBlockLayout
,
block_set_idx
:
usize
,
worker_id
:
WorkerID
,
)
->
BlockResult
<
Self
>
{
let
layout
=
serialized
.deserialize
()
?
;
Ok
(
Self
::
new
(
layout
,
block_set_idx
,
worker_id
))
}
pub
fn
block
<
M
:
MutabilityKind
>
(
&
self
,
block_idx
:
usize
)
->
BlockResult
<
RemoteBlock
<
M
>>
{
if
block_idx
>=
self
.layout
.num_blocks
()
{
return
Err
(
BlockError
::
InvalidState
(
format!
(
"block index out of bounds: {} >= {}"
,
block_idx
,
self
.layout
.num_blocks
()
)));
}
Ok
(
RemoteBlock
::
new
(
self
.layout
.clone
(),
block_idx
,
self
.block_set_idx
,
self
.worker_id
,
))
}
/// Get the layout of the remote blocks
pub
fn
layout
(
&
self
)
->
&
dyn
BlockLayout
<
StorageType
=
NixlStorage
>
{
self
.layout
.as_ref
()
}
}
pub
type
ImmutableRemoteBlock
=
RemoteBlock
<
IsImmutable
>
;
pub
type
MutableRemoteBlock
=
RemoteBlock
<
IsMutable
>
;
pub
struct
RemoteBlock
<
M
:
MutabilityKind
>
{
data
:
BlockData
<
NixlStorage
>
,
_
mutability
:
std
::
marker
::
PhantomData
<
M
>
,
}
impl
<
M
:
MutabilityKind
>
Remote
for
RemoteBlock
<
M
>
{}
impl
<
M
:
MutabilityKind
>
ReadableBlock
for
RemoteBlock
<
M
>
{
type
StorageType
=
NixlStorage
;
}
impl
WritableBlock
for
RemoteBlock
<
IsMutable
>
{
type
StorageType
=
NixlStorage
;
}
impl
<
M
:
MutabilityKind
>
RemoteBlock
<
M
>
{
pub
fn
new
(
layout
:
Arc
<
dyn
BlockLayout
<
StorageType
=
NixlStorage
>>
,
block_idx
:
usize
,
block_set_idx
:
usize
,
worker_id
:
WorkerID
,
)
->
Self
{
let
data
=
BlockData
::
new
(
layout
,
block_idx
,
block_set_idx
,
worker_id
);
Self
{
data
,
_
mutability
:
std
::
marker
::
PhantomData
,
}
}
}
impl
<
M
:
MutabilityKind
>
BlockDataExt
<
NixlStorage
>
for
RemoteBlock
<
M
>
{
fn
is_fully_contiguous
(
&
self
)
->
bool
{
self
.data
.is_fully_contiguous
()
}
fn
num_layers
(
&
self
)
->
usize
{
self
.data
.num_layers
()
}
fn
layer_view
(
&
self
,
layer_idx
:
usize
)
->
BlockResult
<
view
::
LayerView
<
NixlStorage
>>
{
self
.data
.layer_view
(
layer_idx
)
}
fn
layer_view_mut
(
&
mut
self
,
layer_idx
:
usize
,
)
->
BlockResult
<
view
::
LayerViewMut
<
NixlStorage
>>
{
self
.data
.layer_view_mut
(
layer_idx
)
}
fn
block_view
(
&
self
)
->
BlockResult
<
view
::
BlockView
<
NixlStorage
>>
{
self
.data
.block_view
()
}
fn
block_view_mut
(
&
mut
self
)
->
BlockResult
<
view
::
BlockViewMut
<
NixlStorage
>>
{
self
.data
.block_view_mut
()
}
}
impl
<
M
:
MutabilityKind
>
BlockDataProvider
for
RemoteBlock
<
M
>
{
type
StorageType
=
NixlStorage
;
fn
block_data
(
&
self
,
_
:
private
::
PrivateToken
)
->
&
BlockData
<
NixlStorage
>
{
&
self
.data
}
}
impl
<
M
:
MutabilityKind
>
NixlBlockDataImmutable
<
NixlStorage
>
for
RemoteBlock
<
M
>
{
fn
as_block_descriptor
(
&
self
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
BlockKind
,
IsImmutable
>>
{
self
.data
.as_block_descriptor
()
}
fn
as_layer_descriptor
(
&
self
,
layer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsImmutable
>>
{
self
.data
.as_layer_descriptor
(
layer_idx
)
}
}
impl
BlockDataProviderMut
for
RemoteBlock
<
IsMutable
>
{
fn
block_data_mut
(
&
mut
self
,
_
:
private
::
PrivateToken
)
->
&
mut
BlockData
<
NixlStorage
>
{
&
mut
self
.data
}
}
impl
NixlBlockDataMutable
<
NixlStorage
>
for
RemoteBlock
<
IsMutable
>
{
fn
as_block_descriptor_mut
(
&
mut
self
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
BlockKind
,
IsMutable
>>
{
self
.data
.as_block_descriptor_mut
()
}
fn
as_layer_descriptor_mut
(
&
mut
self
,
layer_idx
:
usize
,
)
->
BlockResult
<
NixlMemoryDescriptor
<
'_
,
LayerKind
,
IsMutable
>>
{
self
.data
.as_layer_descriptor_mut
(
layer_idx
)
}
}
impl
<
'a
,
M
:
MutabilityKind
>
AsBlockSlice
<
'a
,
RemoteBlock
<
M
>>
for
[
RemoteBlock
<
M
>
]
{
fn
as_block_slice
(
&
'a
self
)
->
&
'a
[
RemoteBlock
<
M
>
]
{
self
}
}
impl
<
'a
,
M
:
MutabilityKind
>
AsBlockSlice
<
'a
,
RemoteBlock
<
M
>>
for
Vec
<
RemoteBlock
<
M
>>
{
fn
as_block_slice
(
&
'a
self
)
->
&
'a
[
RemoteBlock
<
M
>
]
{
self
.as_slice
()
}
}
impl
<
'a
>
AsBlockMutSlice
<
'a
,
RemoteBlock
<
IsMutable
>>
for
[
RemoteBlock
<
IsMutable
>
]
{
fn
as_block_mut_slice
(
&
'a
mut
self
)
->
&
'a
mut
[
RemoteBlock
<
IsMutable
>
]
{
self
}
}
impl
<
'a
>
AsBlockMutSlice
<
'a
,
RemoteBlock
<
IsMutable
>>
for
Vec
<
RemoteBlock
<
IsMutable
>>
{
fn
as_block_mut_slice
(
&
'a
mut
self
)
->
&
'a
mut
[
RemoteBlock
<
IsMutable
>
]
{
self
.as_mut_slice
()
}
}
/// Defines the intended access pattern for a block represented by a descriptor.
#[derive(Debug,
Clone,
Copy,
PartialEq,
Eq,
Serialize,
Deserialize)]
pub
enum
BlockMutability
{
Immutable
,
Mutable
,
}
/// Describes a single block for identification and potential remote access setup.
#[derive(Debug,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize)]
pub
struct
BlockDescriptor
{
pub
worker_id
:
WorkerID
,
pub
block_set_idx
:
usize
,
pub
block_idx
:
usize
,
pub
mutability
:
BlockMutability
,
}
// Placeholder Trait: Real pool handles must provide this info.
// This trait allows BlockDescriptorList constructors to be generic.
pub
trait
BlockHandleInfo
{
fn
worker_id
(
&
self
)
->
WorkerID
;
// Needs access to the parent KvBlockManager's ID
fn
block_set_idx
(
&
self
)
->
usize
;
fn
block_idx
(
&
self
)
->
usize
;
}
impl
<
S
:
Storage
>
BlockHandleInfo
for
BlockData
<
S
>
{
fn
worker_id
(
&
self
)
->
WorkerID
{
self
.worker_id
}
fn
block_set_idx
(
&
self
)
->
usize
{
self
.block_set_idx
}
fn
block_idx
(
&
self
)
->
usize
{
self
.block_idx
}
}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
BlockHandleInfo
for
Block
<
S
,
M
>
{
fn
worker_id
(
&
self
)
->
WorkerID
{
self
.data.worker_id
}
fn
block_set_idx
(
&
self
)
->
usize
{
self
.data.block_set_idx
}
fn
block_idx
(
&
self
)
->
usize
{
self
.data.block_idx
}
}
/// A validated, homogeneous, and serializable collection of BlockDescriptors.
/// Primarily used to describe sets of remote blocks for transfer operations.
#[derive(Debug,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
Getters)]
pub
struct
BlockDescriptorList
{
#[getter(copy)]
worker_id
:
WorkerID
,
#[getter(copy)]
block_set_idx
:
usize
,
#[getter(copy)]
mutability
:
BlockMutability
,
block_indices
:
Vec
<
usize
>
,
// TODO: Consider storing MemType explicitly if it cannot be reliably
// derived from block_set_idx via the NixlBlockSet on the receiving side.
}
impl
<
M
:
BlockMetadata
>
IntoWritableBlocks
<
M
>
for
BlockDescriptorList
{
type
Output
=
Vec
<
RemoteBlock
<
IsMutable
>>
;
fn
into_writable_blocks
(
self
,
manager
:
&
BlockManager
<
M
>
)
->
BlockResult
<
Self
::
Output
>
{
Ok
(
manager
.get_remote_blocks_mutable
(
&
self
)
?
)
}
}
#[derive(Debug,
Error)]
pub
enum
BlockDescriptorSetError
{
#[error(
"Input block list cannot be empty"
)]
EmptyInput
,
#[error(
"Blocks in the input list are not homogeneous (worker_id, block_set_idx mismatch)"
)]
NotHomogeneous
,
#[error(
"Serialization failed: {0}"
)]
SerializationError
(
#[from]
serde_json
::
Error
),
#[error(
"An invalid block handle was encountered (block may have been dropped prematurely)"
)]
InvalidBlockHandle
,
}
impl
BlockDescriptorList
{
/// Creates a new validated BlockDescriptorList from a slice of block handles.
/// Ensures all handles belong to the same worker and block set.
fn
new
<
S
:
Storage
>
(
blocks
:
&
[
&
BlockData
<
S
>
],
// Use the generic trait bound
mutability
:
BlockMutability
,
)
->
Result
<
Self
,
BlockDescriptorSetError
>
{
if
blocks
.is_empty
()
{
return
Err
(
BlockDescriptorSetError
::
EmptyInput
);
}
let
first
=
blocks
[
0
];
let
worker_id
=
first
.worker_id
();
let
block_set_idx
=
first
.block_set_idx
();
let
mut
block_indices
=
Vec
::
with_capacity
(
blocks
.len
());
block_indices
.push
(
first
.block_idx
());
for
block
in
blocks
.iter
()
.skip
(
1
)
{
// Validate homogeneity
if
block
.worker_id
()
!=
worker_id
||
block
.block_set_idx
()
!=
block_set_idx
{
return
Err
(
BlockDescriptorSetError
::
NotHomogeneous
);
}
block_indices
.push
(
block
.block_idx
());
}
// TODO: Potentially validate MemType derived from block_set_idx here if possible
Ok
(
Self
{
worker_id
,
block_set_idx
,
mutability
,
block_indices
,
})
}
/// Creates a BlockDescriptorList representing immutable blocks.
pub
fn
from_immutable_blocks
<
S
:
Storage
,
M
:
BlockMetadata
>
(
blocks
:
&
[
ImmutableBlock
<
S
,
M
>
],
)
->
Result
<
Self
,
BlockDescriptorSetError
>
{
// Map each block handle to Option<&BlockData>,
// then convert Option to Result (treating None as an error),
// finally collect into Result<Vec<&BlockData>, Error>.
let
data
:
Vec
<&
BlockData
<
S
>>
=
blocks
.iter
()
.map
(|
b
|
b
.block.block
.as_ref
()
.map
(|
inner_b
|
&
inner_b
.data
))
.map
(|
opt
|
opt
.ok_or
(
BlockDescriptorSetError
::
InvalidBlockHandle
))
.collect
::
<
Result
<
Vec
<&
BlockData
<
S
>>
,
_
>>
()
?
;
Self
::
new
(
&
data
,
BlockMutability
::
Immutable
)
}
/// Creates a BlockDescriptorList representing mutable blocks.
pub
fn
from_mutable_blocks
<
S
:
Storage
,
M
:
BlockMetadata
>
(
blocks
:
&
[
MutableBlock
<
S
,
M
>
],
)
->
Result
<
Self
,
BlockDescriptorSetError
>
{
// Map each block handle to Option<&BlockData>,
// then convert Option to Result (treating None as an error),
// finally collect into Result<Vec<&BlockData>, Error>.
let
data
:
Vec
<&
BlockData
<
S
>>
=
blocks
.iter
()
.map
(|
b
|
b
.block
.as_ref
()
.map
(|
inner_b
|
&
inner_b
.data
))
.map
(|
opt
|
opt
.ok_or
(
BlockDescriptorSetError
::
InvalidBlockHandle
))
.collect
::
<
Result
<
Vec
<&
BlockData
<
S
>>
,
_
>>
()
?
;
Self
::
new
(
&
data
,
BlockMutability
::
Mutable
)
}
// /// Serializes the BlockDescriptorList into a byte vector.
// pub fn serialize(&self) -> Result<Vec<u8>, BlockDescriptorSetError> {
// Ok(serde_json::to_vec(self)?)
// }
// /// Deserializes a BlockDescriptorList from a byte slice.
// pub fn deserialize(data: &[u8]) -> Result<Self, BlockDescriptorSetError> {
// Ok(serde_json::from_slice(data)?)
// }
}
pub
trait
AsBlockDescriptorSet
{
type
Block
;
fn
as_block_descriptor_set
(
&
self
)
->
Result
<
BlockDescriptorList
,
BlockDescriptorSetError
>
;
}
impl
<
S
,
M
>
AsBlockDescriptorSet
for
[
ImmutableBlock
<
S
,
M
>
]
where
S
:
Storage
,
M
:
BlockMetadata
,
{
type
Block
=
ImmutableBlock
<
S
,
M
>
;
fn
as_block_descriptor_set
(
&
self
)
->
Result
<
BlockDescriptorList
,
BlockDescriptorSetError
>
{
BlockDescriptorList
::
from_immutable_blocks
(
self
)
}
}
impl
<
S
,
M
>
AsBlockDescriptorSet
for
[
MutableBlock
<
S
,
M
>
]
where
S
:
Storage
,
M
:
BlockMetadata
,
{
type
Block
=
MutableBlock
<
S
,
M
>
;
fn
as_block_descriptor_set
(
&
self
)
->
Result
<
BlockDescriptorList
,
BlockDescriptorSetError
>
{
BlockDescriptorList
::
from_mutable_blocks
(
self
)
}
}
impl
<
T
>
AsBlockDescriptorSet
for
Vec
<
T
>
where
[
T
]:
AsBlockDescriptorSet
<
Block
=
T
>
,
{
type
Block
=
T
;
fn
as_block_descriptor_set
(
&
self
)
->
Result
<
BlockDescriptorList
,
BlockDescriptorSetError
>
{
self
.as_slice
()
.as_block_descriptor_set
()
}
}
impl
<
T
,
const
N
:
usize
>
AsBlockDescriptorSet
for
[
T
;
N
]
where
[
T
]:
AsBlockDescriptorSet
<
Block
=
T
>
,
{
type
Block
=
T
;
fn
as_block_descriptor_set
(
&
self
)
->
Result
<
BlockDescriptorList
,
BlockDescriptorSetError
>
{
self
.as_slice
()
.as_block_descriptor_set
()
}
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
super
::
nixl
::
*
;
use
super
::
super
::
layout
::{
nixl
::{
NixlLayout
,
SerializedNixlBlockLayout
,
ToSerializedNixlBlockLayout
},
tests
::
setup_layout
,
FullyContiguous
,
LayoutConfig
,
};
use
crate
::
block_manager
::
storage
::
SystemAllocator
;
use
crate
::
tokens
::
TokenBlockSequence
;
use
dynamo_runtime
::
logging
::
init
as
init_logging
;
use
nixl_sys
::
Agent
as
NixlAgent
;
const
BLOCK_SIZE
:
usize
=
4
;
const
SALT_HASH
:
SaltHash
=
12345
;
// Helper to create a default reset block
fn
create_reset_block
()
->
Block
<
impl
Storage
,
BasicMetadata
>
{
let
layout
=
setup_layout
(
None
)
.unwrap
();
let
data
=
BlockData
::
new
(
Arc
::
new
(
layout
),
0
,
42
,
0
);
Block
::
new
(
data
,
BasicMetadata
::
default
())
.unwrap
()
}
// Helper to create a complete TokenBlock for testing apply_token_block
fn
create_full_token_block
()
->
TokenBlock
{
let
tokens
=
Tokens
::
from
(
vec!
[
1
,
2
,
3
,
4
]);
let
salt_hash
=
SALT_HASH
;
let
block_size
=
BLOCK_SIZE
;
let
(
mut
blocks
,
_
)
=
TokenBlockSequence
::
split_tokens
(
tokens
.as_ref
(),
block_size
,
salt_hash
);
blocks
.pop
()
.unwrap
()
}
#[test]
fn
test_block_state_transitions_and_ops
()
{
let
mut
block
=
create_reset_block
();
assert
!
(
matches!
(
block
.state
(),
BlockState
::
Reset
));
// --- Reset State --- //
assert
!
(
block
.add_token
(
1
)
.is_err
(),
"Append on Reset should fail"
);
assert
!
(
block
.add_tokens
(
Tokens
::
from
(
vec!
[
1
]))
.is_err
(),
"Extend on Reset should fail"
);
assert
!
(
block
.commit
()
.is_err
(),
"Commit on Reset should fail"
);
assert
!
(
block
.pop_token
()
.is_err
(),
"Pop on Reset should fail"
);
assert
!
(
block
.pop_tokens
(
1
)
.is_err
(),
"Pop tokens on Reset should fail"
);
// --- Reset -> Partial (via init_sequence) --- //
assert
!
(
block
.init_sequence
(
SALT_HASH
)
.is_ok
());
assert
!
(
matches!
(
block
.state
(),
BlockState
::
Partial
(
_
)));
// --- Partial State --- //
let
invalid_block
=
create_full_token_block
();
assert
!
(
block
.apply_token_block
(
invalid_block
)
.is_err
(),
"Apply block on Partial should fail"
);
// Append tokens
assert
!
(
block
.add_token
(
1
)
.is_ok
());
// 1
assert
!
(
block
.add_token
(
2
)
.is_ok
());
// 1, 2
assert
!
(
block
.add_tokens
(
Tokens
::
from
(
vec!
[
3
]))
.is_ok
());
// 1, 2, 3
assert_eq!
(
block
.len
(),
3
);
// Extend beyond capacity (should fail)
let
new_tokens
=
Tokens
::
from
(
vec!
[
4
,
5
]);
assert_eq!
(
block
.add_tokens
(
new_tokens
.clone
())
.unwrap
()
.as_ref
(),
&
[
5
]);
// Extend to fill capacity
assert
!
(
block
.add_tokens
(
Tokens
::
from
(
vec!
[
4
]))
.is_ok
());
// 1, 2, 3, 4
assert_eq!
(
block
.len
(),
BLOCK_SIZE
);
// Append when full (should fail)
assert
!
(
block
.add_token
(
5
)
.is_err
(),
"Append on full Partial block"
);
// Pop tokens
assert
!
(
block
.pop_token
()
.is_ok
());
// After pop: 1, 2, 3
assert_eq!
(
block
.len
(),
3
);
// Pop multiple tokens
assert
!
(
block
.pop_tokens
(
2
)
.is_ok
());
// After pop: [1]
assert_eq!
(
block
.len
(),
1
);
// Pop too many tokens (should fail)
assert
!
(
block
.pop_tokens
(
2
)
.is_err
(),
"Pop too many tokens"
);
assert_eq!
(
block
.len
(),
1
);
// Pop last token
assert
!
(
block
.pop_token
()
.is_ok
());
// empty
assert_eq!
(
block
.len
(),
0
);
assert
!
(
block
.is_empty
());
// Fill block again for commit
assert
!
(
block
.add_tokens
(
Tokens
::
from
(
vec!
[
1
,
2
,
3
,
4
]))
.is_ok
());
assert_eq!
(
block
.len
(),
BLOCK_SIZE
);
// --- Partial -> Complete (via commit) --- //
assert
!
(
block
.commit
()
.is_ok
());
assert
!
(
matches!
(
block
.state
(),
BlockState
::
Complete
(
_
)));
assert_eq!
(
block
.tokens
()
.unwrap
()
.as_ref
(),
&
[
1
,
2
,
3
,
4
]);
// --- Complete State --- //
assert
!
(
block
.init_sequence
(
SALT_HASH
)
.is_err
(),
"Init sequence on Complete should fail"
);
assert
!
(
block
.add_token
(
5
)
.is_err
(),
"Append on Complete should fail"
);
assert
!
(
block
.add_tokens
(
Tokens
::
from
(
vec!
[
5
]))
.is_err
(),
"Extend on Complete should fail"
);
assert
!
(
block
.commit
()
.is_err
(),
"Commit on Complete should fail"
);
assert
!
(
block
.pop_token
()
.is_err
(),
"Pop on Complete should fail"
);
assert
!
(
block
.pop_tokens
(
1
)
.is_err
(),
"Pop tokens on Complete should fail"
);
let
invalid_block
=
create_full_token_block
();
assert
!
(
block
.apply_token_block
(
invalid_block
)
.is_err
(),
"Apply block on Complete should fail"
);
// --- Complete -> Reset (via reset) --- //
block
.reset
();
assert
!
(
matches!
(
block
.state
(),
BlockState
::
Reset
));
// --- Reset -> Complete (via apply_token_block) --- //
let
full_block
=
create_full_token_block
();
assert
!
(
block
.apply_token_block
(
full_block
.clone
())
.is_ok
());
assert
!
(
matches!
(
block
.state
(),
BlockState
::
Complete
(
_
)));
let
applied_tokens
=
block
.tokens
()
.unwrap
();
assert_eq!
(
applied_tokens
,
full_block
.tokens
());
// Testing applying to a non-reset state:
let
mut
non_reset_block
=
create_reset_block
();
non_reset_block
.init_sequence
(
SALT_HASH
)
.unwrap
();
// Put in Partial state
assert
!
(
non_reset_block
.apply_token_block
(
full_block
)
.is_err
(),
"Apply block to non-reset state"
);
}
#[test]
fn
test_block_state_incomplete_commit
()
{
// Commit incomplete block (should fail)
let
mut
partial_block
=
create_reset_block
();
partial_block
.init_sequence
(
SALT_HASH
)
.unwrap
();
partial_block
.add_token
(
1
)
.unwrap
();
partial_block
.add_tokens
(
Tokens
::
from
(
vec!
[
2
,
3
]))
.unwrap
();
assert_eq!
(
partial_block
.len
(),
3
);
assert
!
(
partial_block
.commit
()
.is_err
(),
"Commit on incomplete Partial block"
);
}
#[test]
fn
test_error_types
()
{
let
mut
block
=
create_reset_block
();
block
.init_sequence
(
SALT_HASH
)
.unwrap
();
// Fill the block
block
.add_tokens
(
Tokens
::
from
(
vec!
[
1
,
2
,
3
,
4
]))
.unwrap
();
// Append when full
let
append_err
=
block
.add_token
(
5
)
.unwrap_err
();
assert
!
(
append_err
.is
::
<
TokenBlockError
>
());
assert_eq!
(
*
append_err
.downcast_ref
::
<
TokenBlockError
>
()
.unwrap
(),
TokenBlockError
::
Full
);
// .add_tokens will try to fill the block and return the remaining tokens in the Tokens passed in
let
new_tokens
=
Tokens
::
from
(
vec!
[
5
]);
let
ret_tokens
=
block
.add_tokens
(
new_tokens
.clone
())
.unwrap
();
assert_eq!
(
new_tokens
,
ret_tokens
);
// Commit when full (should succeed)
block
.commit
()
.unwrap
();
// Commit when Complete
let
commit_err
=
block
.commit
()
.unwrap_err
();
assert
!
(
commit_err
.is
::
<
BlockStateInvalid
>
());
// Reset and test pop empty
block
.reset
();
block
.init_sequence
(
SALT_HASH
)
.unwrap
();
let
pop_err
=
block
.pop_token
()
.unwrap_err
();
assert
!
(
pop_err
.is
::
<
TokenBlockError
>
());
assert_eq!
(
*
pop_err
.downcast_ref
::
<
TokenBlockError
>
()
.unwrap
(),
TokenBlockError
::
Empty
);
let
pop_tokens_err
=
block
.pop_tokens
(
1
)
.unwrap_err
();
assert
!
(
pop_tokens_err
.is
::
<
TokenBlockError
>
());
assert_eq!
(
*
pop_tokens_err
.downcast_ref
::
<
TokenBlockError
>
()
.unwrap
(),
TokenBlockError
::
InsufficientTokens
);
// Test commit incomplete
block
.add_token
(
1
)
.unwrap
();
let
commit_incomplete_err
=
block
.commit
()
.unwrap_err
();
assert
!
(
commit_incomplete_err
.is
::
<
TokenBlockError
>
());
assert_eq!
(
*
commit_incomplete_err
.downcast_ref
::
<
TokenBlockError
>
()
.unwrap
(),
TokenBlockError
::
Incomplete
);
}
#[test]
fn
test_nixl_block_data_ext
()
{
init_logging
();
let
config
=
LayoutConfig
::
builder
()
.num_blocks
(
10
)
.num_layers
(
2
)
.page_size
(
4
)
.inner_dim
(
13
)
.build
()
.unwrap
();
let
mut
layout
=
FullyContiguous
::
allocate
(
config
,
&
SystemAllocator
)
.unwrap
();
let
agent
=
NixlAgent
::
new
(
"test"
)
.unwrap
();
tracing
::
info!
(
"Registering layout"
);
layout
.nixl_register
(
&
agent
,
None
)
.unwrap
();
tracing
::
info!
(
"Layout registered"
);
let
serialized
=
layout
.serialize
()
.unwrap
();
let
layout
=
Arc
::
new
(
layout
);
let
data
=
BlockData
::
new
(
layout
.clone
(),
0
,
42
,
0
);
assert_eq!
(
data
.block_idx
(),
0
);
assert_eq!
(
data
.block_set_idx
(),
42
);
let
block_desc
=
data
.as_block_descriptor
()
.unwrap
();
println!
(
"Block descriptor: {:?}"
,
block_desc
);
let
data
=
BlockData
::
new
(
layout
.clone
(),
1
,
42
,
0
);
assert_eq!
(
data
.block_idx
(),
1
);
assert_eq!
(
data
.block_set_idx
(),
42
);
let
block_desc
=
data
.as_block_descriptor
()
.unwrap
();
println!
(
"Block descriptor: {:?}"
,
block_desc
);
let
remote_layout
=
SerializedNixlBlockLayout
::
deserialize
(
&
serialized
)
.unwrap
();
println!
(
"Nixl layout: {:?}"
,
remote_layout
);
let
remote_block
=
RemoteBlock
::
<
IsMutable
>
::
new
(
remote_layout
.clone
(),
0
,
42
,
0
);
let
remote_desc
=
remote_block
.as_block_descriptor
()
.unwrap
();
println!
(
"Remote Descriptor: {:?}"
,
remote_desc
);
// drop(layout);
tracing
::
info!
(
"Layout dropped"
);
}
}
lib/llm/src/block_manager/block/registry.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::{
collections
::
HashMap
,
sync
::{
Arc
,
Weak
},
};
use
super
::
super
::
events
::{
EventManager
,
EventReleaseManager
,
PublishHandle
};
use
super
::
state
::
BlockState
;
use
crate
::
tokens
::{
BlockHash
,
SequenceHash
,
TokenBlock
};
use
derive_getters
::
Getters
;
#[derive(Debug,
thiserror::Error)]
pub
enum
BlockRegistationError
{
#[error(
"Block already registered"
)]
BlockAlreadyRegistered
(
SequenceHash
),
#[error(
"Invalid state: {0}"
)]
InvalidState
(
String
),
}
/// Error returned when an attempt is made to unregister a block that is still active.
#[derive(Debug,
thiserror::Error)]
#[error(
"Failed to unregister block: {0}"
)]
pub
struct
UnregisterFailure
(
SequenceHash
);
#[derive()]
pub
struct
BlockRegistry
{
blocks
:
HashMap
<
SequenceHash
,
Weak
<
RegistrationHandle
>>
,
event_manager
:
Arc
<
dyn
EventManager
>
,
}
impl
BlockRegistry
{
pub
fn
new
(
event_manager
:
Arc
<
dyn
EventManager
>
)
->
Self
{
Self
{
blocks
:
HashMap
::
new
(),
event_manager
,
}
}
pub
fn
is_registered
(
&
self
,
sequence_hash
:
SequenceHash
)
->
bool
{
if
let
Some
(
handle
)
=
self
.blocks
.get
(
&
sequence_hash
)
{
if
let
Some
(
_
handle
)
=
handle
.upgrade
()
{
return
true
;
}
}
false
}
pub
fn
register_block
(
&
mut
self
,
block_state
:
&
mut
BlockState
,
)
->
Result
<
PublishHandle
,
BlockRegistationError
>
{
match
block_state
{
BlockState
::
Reset
=>
Err
(
BlockRegistationError
::
InvalidState
(
"Block is in Reset state"
.to_string
(),
)),
BlockState
::
Partial
(
_
partial
)
=>
Err
(
BlockRegistationError
::
InvalidState
(
"Block is in Partial state"
.to_string
(),
)),
BlockState
::
Complete
(
state
)
=>
{
let
sequence_hash
=
state
.token_block
()
.sequence_hash
();
if
let
Some
(
handle
)
=
self
.blocks
.get
(
&
sequence_hash
)
{
if
let
Some
(
_
handle
)
=
handle
.upgrade
()
{
return
Err
(
BlockRegistationError
::
BlockAlreadyRegistered
(
sequence_hash
));
}
}
// Create the [RegistrationHandle] and [PublishHandle]
let
publish_handle
=
Self
::
create_publish_handle
(
state
.token_block
(),
self
.event_manager
.clone
());
let
reg_handle
=
publish_handle
.remove_handle
();
// Insert the [RegistrationHandle] into the registry
self
.blocks
.insert
(
sequence_hash
,
Arc
::
downgrade
(
&
reg_handle
));
// Update the [BlockState] to [BlockState::Registered]
let
_
=
std
::
mem
::
replace
(
block_state
,
BlockState
::
Registered
(
reg_handle
));
Ok
(
publish_handle
)
}
BlockState
::
Registered
(
registered
)
=>
Err
(
BlockRegistationError
::
BlockAlreadyRegistered
(
registered
.sequence_hash
()),
),
}
}
pub
fn
unregister_block
(
&
mut
self
,
sequence_hash
:
SequenceHash
,
)
->
Result
<
(),
UnregisterFailure
>
{
if
let
Some
(
handle
)
=
self
.blocks
.get
(
&
sequence_hash
)
{
if
handle
.upgrade
()
.is_none
()
{
self
.blocks
.remove
(
&
sequence_hash
);
return
Ok
(());
}
else
{
return
Err
(
UnregisterFailure
(
sequence_hash
));
}
}
Ok
(())
}
fn
create_publish_handle
(
token_block
:
&
TokenBlock
,
event_manager
:
Arc
<
dyn
EventManager
>
,
)
->
PublishHandle
{
let
reg_handle
=
RegistrationHandle
::
from_token_block
(
token_block
,
event_manager
.clone
());
PublishHandle
::
new
(
reg_handle
,
event_manager
)
}
}
#[derive(Getters)]
pub
struct
RegistrationHandle
{
#[getter(copy)]
block_hash
:
BlockHash
,
#[getter(copy)]
sequence_hash
:
SequenceHash
,
#[getter(copy)]
parent_sequence_hash
:
Option
<
SequenceHash
>
,
#[getter(skip)]
release_manager
:
Arc
<
dyn
EventReleaseManager
>
,
}
impl
RegistrationHandle
{
fn
from_token_block
(
token_block
:
&
TokenBlock
,
release_manager
:
Arc
<
dyn
EventReleaseManager
>
,
)
->
Self
{
Self
{
block_hash
:
token_block
.block_hash
(),
sequence_hash
:
token_block
.sequence_hash
(),
parent_sequence_hash
:
token_block
.parent_sequence_hash
(),
release_manager
,
}
}
}
impl
std
::
fmt
::
Debug
for
RegistrationHandle
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
write!
(
f
,
"RegistrationHandle {{ sequence_hash: {}; block_hash: {}; parent_sequence_hash: {:?} }}"
,
self
.sequence_hash
,
self
.block_hash
,
self
.parent_sequence_hash
)
}
}
impl
Drop
for
RegistrationHandle
{
fn
drop
(
&
mut
self
)
{
self
.release_manager
.block_release
(
self
);
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
block_manager
::
events
::
tests
::{
EventType
,
MockEventManager
};
use
crate
::
tokens
::{
TokenBlockSequence
,
Tokens
};
fn
create_sequence
()
->
TokenBlockSequence
{
let
tokens
=
Tokens
::
from
(
vec!
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
]);
// NOTE: 1337 was the original seed, so we are temporarily using that here to prove the logic has not changed
let
sequence
=
TokenBlockSequence
::
new
(
tokens
,
4
,
Some
(
1337_u64
));
assert_eq!
(
sequence
.blocks
()
.len
(),
2
);
assert_eq!
(
sequence
.current_block
()
.len
(),
2
);
assert_eq!
(
sequence
.blocks
()[
0
]
.tokens
(),
&
vec!
[
1
,
2
,
3
,
4
]);
assert_eq!
(
sequence
.blocks
()[
0
]
.sequence_hash
(),
14643705804678351452
);
assert_eq!
(
sequence
.blocks
()[
1
]
.tokens
(),
&
vec!
[
5
,
6
,
7
,
8
]);
assert_eq!
(
sequence
.blocks
()[
1
]
.sequence_hash
(),
4945711292740353085
);
assert_eq!
(
sequence
.current_block
()
.tokens
(),
&
vec!
[
9
,
10
]);
sequence
}
#[test]
fn
test_mock_event_manager_with_single_publish_handle
()
{
let
sequence
=
create_sequence
();
let
(
event_manager
,
mut
rx
)
=
MockEventManager
::
new
();
let
publish_handle
=
BlockRegistry
::
create_publish_handle
(
&
sequence
.blocks
()[
0
],
event_manager
.clone
());
// no event should have been triggered
assert
!
(
rx
.try_recv
()
.is_err
());
// we shoudl get two events when this is dropped, since we never took ownership of the RegistrationHandle
drop
(
publish_handle
);
// the first event should be a Register event
let
events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events
.len
(),
1
);
assert_eq!
(
events
[
0
],
EventType
::
Register
(
sequence
.blocks
()[
0
]
.sequence_hash
())
);
// the second event should be a Remove event
let
events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events
.len
(),
1
);
assert_eq!
(
events
[
0
],
EventType
::
Remove
(
sequence
.blocks
()[
0
]
.sequence_hash
())
);
// there should be no more events
assert
!
(
rx
.try_recv
()
.is_err
());
}
#[test]
fn
test_mock_event_manager_single_publish_handle_removed
()
{
let
sequence
=
create_sequence
();
let
block_to_test
=
&
sequence
.blocks
()[
0
];
let
expected_sequence_hash
=
block_to_test
.sequence_hash
();
let
(
event_manager
,
mut
rx
)
=
MockEventManager
::
new
();
let
publish_handle
=
BlockRegistry
::
create_publish_handle
(
block_to_test
,
event_manager
.clone
());
// Remove the registration handle before dropping the publish handle
let
reg_handle
=
publish_handle
.remove_handle
();
// no event should have been triggered yet
assert
!
(
rx
.try_recv
()
.is_err
());
// Drop the publish handle - it SHOULD trigger a Register event now because remove_handle doesn't disarm
drop
(
publish_handle
);
let
register_events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
register_events
.len
(),
1
,
"Register event should be triggered on PublishHandle drop"
);
assert_eq!
(
register_events
[
0
],
EventType
::
Register
(
expected_sequence_hash
),
"Expected Register event"
);
// Drop the registration handle - this SHOULD trigger the Remove event
drop
(
reg_handle
);
let
events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events
.len
(),
1
);
assert_eq!
(
events
[
0
],
EventType
::
Remove
(
expected_sequence_hash
),
"Only Remove event should be triggered"
);
// there should be no more events
assert
!
(
rx
.try_recv
()
.is_err
());
}
#[test]
fn
test_mock_event_manager_publisher_multiple_handles_removed
()
{
let
sequence
=
create_sequence
();
let
block1
=
&
sequence
.blocks
()[
0
];
let
block2
=
&
sequence
.blocks
()[
1
];
let
hash1
=
block1
.sequence_hash
();
let
hash2
=
block2
.sequence_hash
();
let
(
event_manager
,
mut
rx
)
=
MockEventManager
::
new
();
let
mut
publisher
=
event_manager
.publisher
();
let
publish_handle1
=
BlockRegistry
::
create_publish_handle
(
block1
,
event_manager
.clone
());
let
publish_handle2
=
BlockRegistry
::
create_publish_handle
(
block2
,
event_manager
.clone
());
// Remove handles before adding to publisher
let
reg_handle1
=
publish_handle1
.remove_handle
();
let
reg_handle2
=
publish_handle2
.remove_handle
();
// Add disarmed handles to publisher
publisher
.take_handle
(
publish_handle1
);
publisher
.take_handle
(
publish_handle2
);
// no events yet
assert
!
(
rx
.try_recv
()
.is_err
());
// Drop the publisher - should trigger a single Publish event with both Register events
drop
(
publisher
);
let
events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events
.len
(),
2
,
"Should receive two Register events in one batch"
);
// Order isn't guaranteed, so check for both
assert
!
(
events
.contains
(
&
EventType
::
Register
(
hash1
)));
assert
!
(
events
.contains
(
&
EventType
::
Register
(
hash2
)));
// no more events immediately after publish
assert
!
(
rx
.try_recv
()
.is_err
());
// Drop registration handles individually - should trigger Remove events
drop
(
reg_handle1
);
let
events1
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events1
.len
(),
1
);
assert_eq!
(
events1
[
0
],
EventType
::
Remove
(
hash1
));
drop
(
reg_handle2
);
let
events2
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events2
.len
(),
1
);
assert_eq!
(
events2
[
0
],
EventType
::
Remove
(
hash2
));
// no more events
assert
!
(
rx
.try_recv
()
.is_err
());
}
#[test]
fn
test_publisher_empty_drop
()
{
let
(
event_manager
,
mut
rx
)
=
MockEventManager
::
new
();
let
publisher
=
event_manager
.publisher
();
drop
(
publisher
);
// No events should be sent
assert
!
(
rx
.try_recv
()
.is_err
());
}
#[test]
fn
test_publisher_publish_multiple_times
()
{
let
sequence
=
create_sequence
();
let
block1
=
&
sequence
.blocks
()[
0
];
let
hash1
=
block1
.sequence_hash
();
let
(
event_manager
,
mut
rx
)
=
MockEventManager
::
new
();
let
mut
publisher
=
event_manager
.publisher
();
let
publish_handle1
=
BlockRegistry
::
create_publish_handle
(
block1
,
event_manager
.clone
());
publisher
.take_handle
(
publish_handle1
);
// First publish call
publisher
.publish
();
let
events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
events
.len
(),
1
);
assert_eq!
(
events
[
0
],
EventType
::
Register
(
hash1
));
// The RegistrationHandle Arc was taken by the publisher and dropped after the publish call
// So, the Remove event should follow immediately.
let
remove_events
=
rx
.try_recv
()
.unwrap
();
assert_eq!
(
remove_events
.len
(),
1
,
"Remove event should be triggered after publish consumes the handle"
);
assert_eq!
(
remove_events
[
0
],
EventType
::
Remove
(
hash1
),
"Expected Remove event"
);
// Second publish call (should do nothing as handles were taken)
publisher
.publish
();
assert
!
(
rx
.try_recv
()
.is_err
());
// Drop publisher (should also do nothing)
drop
(
publisher
);
assert
!
(
rx
.try_recv
()
.is_err
());
}
}
lib/llm/src/block_manager/block/state.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
sync
::
Arc
;
use
derive_getters
::
Getters
;
use
super
::
registry
::
RegistrationHandle
;
use
super
::
Result
;
use
crate
::
tokens
::{
PartialTokenBlock
,
SaltHash
,
Token
,
TokenBlock
,
Tokens
};
#[derive(Debug,
thiserror::Error)]
#[error(
"Block state is invalid: {0}"
)]
pub
struct
BlockStateInvalid
(
pub
String
);
#[derive(Debug)]
pub
enum
BlockState
{
Reset
,
Partial
(
PartialState
),
Complete
(
CompleteState
),
Registered
(
Arc
<
RegistrationHandle
>
),
}
impl
BlockState
{
pub
fn
initialize_sequence
(
&
mut
self
,
page_size
:
usize
,
salt_hash
:
SaltHash
,
)
->
Result
<
(),
BlockStateInvalid
>
{
if
!
matches!
(
self
,
BlockState
::
Reset
)
{
return
Err
(
BlockStateInvalid
(
"Block is not reset"
.to_string
()));
}
let
block
=
PartialTokenBlock
::
create_sequence_root
(
page_size
,
salt_hash
);
*
self
=
BlockState
::
Partial
(
PartialState
::
new
(
block
));
Ok
(())
}
pub
fn
add_token
(
&
mut
self
,
token
:
Token
)
->
Result
<
()
>
{
match
self
{
BlockState
::
Partial
(
state
)
=>
Ok
(
state
.block
.push_token
(
token
)
?
),
_
=>
Err
(
BlockStateInvalid
(
"Block is not partial"
.to_string
()))
?
,
}
}
pub
fn
add_tokens
(
&
mut
self
,
tokens
:
Tokens
)
->
Result
<
Tokens
>
{
match
self
{
BlockState
::
Partial
(
state
)
=>
Ok
(
state
.block
.push_tokens
(
tokens
)),
_
=>
Err
(
BlockStateInvalid
(
"Block is not partial"
.to_string
()))
?
,
}
}
pub
fn
pop_token
(
&
mut
self
)
->
Result
<
()
>
{
match
self
{
BlockState
::
Partial
(
state
)
=>
{
state
.block
.pop_token
()
?
;
Ok
(())
}
_
=>
Err
(
BlockStateInvalid
(
"Block is not partial"
.to_string
()))
?
,
}
}
pub
fn
pop_tokens
(
&
mut
self
,
count
:
usize
)
->
Result
<
()
>
{
match
self
{
BlockState
::
Partial
(
state
)
=>
{
state
.block
.pop_tokens
(
count
)
?
;
Ok
(())
}
_
=>
Err
(
BlockStateInvalid
(
"Block is not partial"
.to_string
()))
?
,
}
}
pub
fn
commit
(
&
mut
self
)
->
Result
<
()
>
{
match
self
{
BlockState
::
Partial
(
state
)
=>
{
let
token_block
=
state
.block
.commit
()
?
;
*
self
=
BlockState
::
Complete
(
CompleteState
::
new
(
token_block
));
Ok
(())
}
_
=>
Err
(
BlockStateInvalid
(
"Block is not partial"
.to_string
()))
?
,
}
}
pub
fn
apply_token_block
(
&
mut
self
,
token_block
:
TokenBlock
)
->
Result
<
()
>
{
match
self
{
BlockState
::
Reset
=>
{
*
self
=
BlockState
::
Complete
(
CompleteState
::
new
(
token_block
));
Ok
(())
}
_
=>
Err
(
BlockStateInvalid
(
"Block is not reset"
.to_string
()))
?
,
}
}
/// Returns the number of tokens currently in the block.
pub
fn
len
(
&
self
)
->
Option
<
usize
>
{
match
self
{
BlockState
::
Reset
=>
Some
(
0
),
BlockState
::
Partial
(
state
)
=>
Some
(
state
.block
.len
()),
BlockState
::
Complete
(
state
)
=>
Some
(
state
.token_block
.tokens
()
.len
()),
BlockState
::
Registered
(
_
)
=>
None
,
}
}
/// Returns the number of additional tokens that can be added.
pub
fn
remaining
(
&
self
)
->
usize
{
match
self
{
BlockState
::
Partial
(
state
)
=>
state
.block
.remaining
(),
_
=>
0
,
// Reset, Complete, Registered have 0 remaining capacity
}
}
/// Returns true if the block contains no tokens.
pub
fn
is_empty
(
&
self
)
->
bool
{
match
self
{
BlockState
::
Reset
=>
true
,
BlockState
::
Partial
(
state
)
=>
state
.block
.is_empty
(),
BlockState
::
Complete
(
_
)
=>
false
,
// Always full
BlockState
::
Registered
(
_
)
=>
false
,
// Always full
}
}
/// Returns a reference to the underlying TokenBlock if the state is Complete or Registered.
pub
fn
tokens
(
&
self
)
->
Option
<&
Tokens
>
{
match
self
{
BlockState
::
Reset
|
BlockState
::
Registered
(
_
)
=>
None
,
BlockState
::
Partial
(
state
)
=>
Some
(
state
.block
.tokens
()),
BlockState
::
Complete
(
state
)
=>
Some
(
state
.token_block
.tokens
()),
}
}
/// Returns true if the block is empty
pub
fn
is_reset
(
&
self
)
->
bool
{
matches!
(
self
,
BlockState
::
Reset
)
}
/// Returns true if the block is in the complete or registered state
pub
fn
is_complete
(
&
self
)
->
bool
{
matches!
(
self
,
BlockState
::
Complete
(
_
)
|
BlockState
::
Registered
(
_
))
}
/// Returns true if the block is in the registered state
pub
fn
is_registered
(
&
self
)
->
bool
{
matches!
(
self
,
BlockState
::
Registered
(
_
state
))
}
}
#[derive(Debug)]
pub
struct
PartialState
{
block
:
PartialTokenBlock
,
}
impl
PartialState
{
pub
fn
new
(
block
:
PartialTokenBlock
)
->
Self
{
Self
{
block
}
}
}
#[derive(Debug,
Getters)]
pub
struct
CompleteState
{
token_block
:
TokenBlock
,
}
impl
CompleteState
{
pub
fn
new
(
token_block
:
TokenBlock
)
->
Self
{
Self
{
token_block
}
}
}
lib/llm/src/block_manager/block/transfer.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod
cuda
;
mod
memcpy
;
mod
nixl
;
mod
strategy
;
use
super
::
nixl
::{
IsMutable
,
NixlBlockDataImmutable
,
NixlBlockDataMutable
,
RemoteBlock
};
use
super
::
*
;
use
crate
::
block_manager
::
storage
::{
nixl
::{
NixlRegisterableStorage
,
NixlStorage
},
DeviceStorage
,
PinnedStorage
,
SystemStorage
,
};
use
cudarc
::
driver
::
CudaStream
;
use
std
::
ops
::
Range
;
pub
use
crate
::
block_manager
::
storage
::{
CudaAccessible
,
Local
,
Remote
};
pub
use
async_trait
::
async_trait
;
/// A block that can be the target of a write
pub
trait
Writable
{}
/// A block that can be the source of a read
pub
trait
Readable
{}
pub
trait
Mutable
:
Readable
+
Writable
{}
pub
trait
Immutable
:
Readable
{}
#[derive(Debug)]
pub
enum
BlockTarget
{
Source
,
Destination
,
}
#[derive(Debug,
thiserror::Error)]
pub
enum
TransferError
{
#[error(
"Builder configuration error: {0}"
)]
BuilderError
(
String
),
#[error(
"Transfer execution failed: {0}"
)]
ExecutionError
(
String
),
#[error(
"Incompatible block types provided: {0}"
)]
IncompatibleTypes
(
String
),
#[error(
"Mismatched source/destination counts: {0} sources, {1} destinations"
)]
CountMismatch
(
usize
,
usize
),
#[error(
"Block operation failed: {0}"
)]
BlockError
(
#[from]
BlockError
),
// TODO: Add NIXL specific errors
#[error(
"No blocks provided"
)]
NoBlocksProvided
,
#[error(
"Mismatched {0:?} block set index: {1} != {2}"
)]
MismatchedBlockSetIndex
(
BlockTarget
,
usize
,
usize
),
#[error(
"Mismatched {0:?} worker ID: {1} != {2}"
)]
MismatchedWorkerID
(
BlockTarget
,
usize
,
usize
),
#[error(transparent)]
Other
(
#[from]
anyhow
::
Error
),
}
#[derive(Debug,
Clone,
Copy,
PartialEq,
Eq)]
pub
enum
TransferStrategy
{
Memcpy
,
CudaAsyncH2D
,
CudaAsyncD2H
,
CudaAsyncD2D
,
CudaBlockingH2D
,
CudaBlockingD2H
,
NixlWrite
,
// aka PUT
NixlRead
,
// aka GET
Invalid
,
}
/// Trait for determining the transfer strategy for writing from a local
/// source to a target destination which could be local or remote
pub
trait
WriteToStrategy
<
Target
>
{
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Invalid
}
}
/// Trait for determining the transfer strategy for reading from a
/// `Source` which could be local or remote into `Self` which must
/// be both local and writable.
pub
trait
ReadFromStrategy
<
Source
>
{
fn
read_from_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Invalid
}
}
impl
<
RB
:
ReadableBlock
,
WB
:
WritableBlock
>
WriteToStrategy
<
WB
>
for
RB
where
<
RB
as
ReadableBlock
>
::
StorageType
:
Local
+
WriteToStrategy
<<
WB
as
WritableBlock
>
::
StorageType
>
,
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
<<
RB
as
ReadableBlock
>
::
StorageType
as
WriteToStrategy
<
<
WB
as
WritableBlock
>
::
StorageType
,
>>
::
write_to_strategy
()
}
}
impl
<
WB
:
WritableBlock
,
RB
:
ReadableBlock
>
ReadFromStrategy
<
RB
>
for
WB
where
<
RB
as
ReadableBlock
>
::
StorageType
:
Remote
,
<
WB
as
WritableBlock
>
::
StorageType
:
NixlRegisterableStorage
,
{
#[inline(always)]
fn
read_from_strategy
()
->
TransferStrategy
{
TransferStrategy
::
NixlRead
}
}
pub
trait
WriteTo
<
Target
>
{
fn
write_to
(
&
self
,
dst
:
&
mut
Target
,
notify
:
Option
<
String
>
)
->
Result
<
(),
TransferError
>
;
}
impl
<
RB
:
ReadableBlock
,
WB
:
WritableBlock
>
WriteTo
<
WB
>
for
RB
where
RB
:
WriteToStrategy
<
WB
>
+
Local
,
{
fn
write_to
(
&
self
,
dst
:
&
mut
WB
,
notify
:
Option
<
String
>
)
->
Result
<
(),
TransferError
>
{
let
ctx
=
self
.transfer_context
();
match
Self
::
write_to_strategy
()
{
TransferStrategy
::
Memcpy
=>
memcpy
::
copy_block
(
self
,
dst
),
TransferStrategy
::
CudaAsyncH2D
|
TransferStrategy
::
CudaAsyncD2H
|
TransferStrategy
::
CudaAsyncD2D
=>
{
cuda
::
copy_block
(
self
,
dst
,
ctx
.stream
()
.as_ref
(),
RB
::
write_to_strategy
())
}
TransferStrategy
::
NixlWrite
=>
Ok
(
nixl
::
write_block_to
(
self
,
dst
,
ctx
,
notify
)
?
),
_
=>
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
"Unsupported copy strategy: {:?}"
,
RB
::
write_to_strategy
()
))),
}
// dispatch_copy_to(self, dst, self.transfer_context())
}
}
#[derive(Default)]
pub
struct
GetXferRequestBuilder
<
'xfer
,
Source
:
BlockDataProvider
,
Target
:
BlockDataProviderMut
+
Local
,
>
{
_
src
:
Option
<&
'xfer
[
Source
]
>
,
_
dst
:
Option
<&
'xfer
[
Target
]
>
,
}
// impl<'xfer, Source: BlockDataProvider, Target: BlockDataProviderMut + Local>
// GetXferRequestBuilder<'xfer, Source, Target>
// {
// fn new(state: Arc<BlockTransferEngineState>) -> Self {
// Self {
// src: None,
// dst: None,
// }
// }
// pub fn from(&mut self, local_or_remote_blocks: &'xfer [Target]) -> &mut Self {
// self.dst = Some(local_or_remote_blocks);
// self
// }
// pub fn to(&mut self, local_mutable_blocks: &'xfer [Source]) -> &mut Self {
// self.src = Some(local_mutable_blocks);
// self
// }
// }
pub
struct
PutXferRequestBuilder
<
'xfer
,
Source
:
BlockDataProvider
+
Local
,
Target
:
BlockDataProviderMut
,
>
{
_
src
:
Option
<&
'xfer
[
Source
]
>
,
_
dst
:
Option
<&
'xfer
[
Target
]
>
,
}
// impl<'xfer, Source: BlockDataProvider + Local, Target: BlockDataProviderMut>
// PutXferRequestBuilder<'xfer, Source, Target>
// {
// fn new(state: Arc<BlockTransferEngineState>) -> Self {
// Self {
// src: None,
// dst: None,
// }
// }
// pub fn from(&mut self, local_blocks: &'xfer [Source]) -> &mut Self {
// self.src = Some(local_blocks);
// self
// }
// pub fn to(&mut self, local_or_remote: &'xfer [Target]) -> &mut Self {
// self.dst = Some(local_or_remote);
// self
// }
// }
// #[async_trait]
// impl<'xfer, Target: BlockDataProviderMut + Local>
// AsyncBlockTransferEngine<RemoteBlock<IsImmutable>, Target>
// for GetXferRequestBuilder<'xfer, RemoteBlock<IsImmutable>, Target>
// where
// Target: BlockDataProviderMut + Local + Send + Sync,
// {
// async fn execute(self) -> Result<()> {
// unimplemented!()
// }
// }
// #[async_trait]
// impl<'xfer, Source, Target> AsyncBlockTransferEngine<Source, Target>
// for GetXferRequestBuilder<'xfer, Source, Target>
// where
// Source: BlockDataProvider + Local + Send + Sync,
// Target: BlockDataProviderMut + Local + Send + Sync,
// {
// async fn execute(self) -> Result<()> {
// unimplemented!()
// }
// }
// pub trait BlockCopyTo<Target:BlockDataProviderMut + Local>: BlockDataProvider + Local {
// fn copy_blocks
#[async_trait]
pub
trait
AsyncBlockTransferEngine
<
Source
:
BlockDataProvider
,
Target
:
BlockDataProviderMut
+
Local
>
{
async
fn
execute
(
self
)
->
anyhow
::
Result
<
()
>
;
}
pub
trait
BlockTransferEngineV1
<
Source
:
BlockDataProvider
,
Target
:
BlockDataProviderMut
>
{
fn
prepare
(
&
mut
self
)
->
Result
<
(),
TransferError
>
{
Ok
(())
}
fn
execute
(
self
)
->
Result
<
(),
TransferError
>
;
}
// memcpy transfer engine
// - System -> System
// - Pinned -> Pinned
// cuda memcpy transfer engine
// - Pinned -> Device
// - Device -> Pinned
// - Device -> Device
// nixl memcpy transfer engine
// - NixlRegisterableStorage -> Nixl
// - Nixl -> NixlRegisterableStorage
// where System, Pinned, Device are NixlRegisterableStorage
// Placeholder for the actual transfer plan
#[derive(Debug)]
pub
struct
TransferRequestPut
<
'a
,
Source
:
BlockDataProvider
+
Local
,
Destination
:
BlockDataProviderMut
,
>
{
sources
:
&
'a
[
Source
],
destinations
:
&
'a
mut
[
Destination
],
}
// --- NIXL PUT Transfer Implementation ---
impl
<
Source
>
BlockTransferEngineV1
<
Source
,
RemoteBlock
<
IsMutable
>>
for
TransferRequestPut
<
'_
,
Source
,
RemoteBlock
<
IsMutable
>>
where
Source
:
BlockDataProvider
+
Local
,
// + NixlBlockDataMutable<Source::StorageType>,
Source
::
StorageType
:
NixlRegisterableStorage
,
{
fn
execute
(
self
)
->
Result
<
(),
TransferError
>
{
self
.validate_counts
()
?
;
tracing
::
info!
(
"Executing NIXL PUT transfer request"
);
// TODO: Get NixlAgent handle
for
(
src_block
,
dst_block
)
in
self
.sources
.iter
()
.zip
(
self
.destinations
.iter_mut
())
{
let
src_data
=
src_block
.block_data
(
private
::
PrivateToken
);
let
src_nixl_desc
=
src_data
.as_block_descriptor
()
?
;
let
dst_data
=
dst_block
.block_data_mut
(
private
::
PrivateToken
);
let
dst_nixl_desc
=
dst_data
.as_block_descriptor_mut
()
?
;
// TODO: Perform NIXL PUT operation
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "NIXL PUT block");
tracing
::
trace!
(
src_desc
=
?
src_nixl_desc
,
dst_desc
=
?
dst_nixl_desc
,
"NIXL PUT block"
);
}
Ok
(())
}
}
impl
<
'a
,
Source
,
Destination
>
TransferRequestPut
<
'a
,
Source
,
Destination
>
where
Source
:
BlockDataProvider
+
Local
,
Destination
:
BlockDataProviderMut
,
{
pub
fn
new
(
sources
:
&
'a
[
Source
],
destinations
:
&
'a
mut
[
Destination
],
)
->
Result
<
Self
,
TransferError
>
{
let
transfer_request
=
Self
{
sources
,
destinations
,
};
transfer_request
.validate_counts
()
?
;
Ok
(
transfer_request
)
}
/// Validate blocks
///
/// For a put, we can have duplicate blocks on the source side, but all destinations must be unique
/// For all transfers, the source and destination block sets must be disjoint.
pub
fn
validate_blocks
(
&
self
)
->
Result
<
(),
TransferError
>
{
let
mut
src_set
=
std
::
collections
::
HashSet
::
new
();
let
mut
dst_set
=
std
::
collections
::
HashSet
::
new
();
for
(
src_block
,
dst_block
)
in
self
.sources
.iter
()
.zip
(
self
.destinations
.iter
())
{
let
src_data
=
src_block
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
dst_block
.block_data
(
private
::
PrivateToken
);
src_set
.insert
((
src_data
.block_set_idx
,
src_data
.block_idx
,
src_data
.worker_id
,
));
dst_set
.insert
((
dst_data
.block_set_idx
,
dst_data
.block_idx
,
dst_data
.worker_id
,
));
}
if
dst_set
.len
()
!=
self
.destinations
.len
()
{
return
Err
(
TransferError
::
BuilderError
(
"Duplicate destination blocks"
.to_string
(),
));
}
// the intersection of src_set and dst_set must be empty
if
!
src_set
.is_disjoint
(
&
dst_set
)
{
return
Err
(
TransferError
::
BuilderError
(
"Duplicate one or more duplicate entries in source and destination list"
.to_string
(),
));
}
Ok
(())
}
/// Common validation for all PUT requests.
fn
validate_counts
(
&
self
)
->
Result
<
(),
TransferError
>
{
if
self
.sources
.len
()
!=
self
.destinations
.len
()
{
Err
(
TransferError
::
CountMismatch
(
self
.sources
.len
(),
self
.destinations
.len
(),
))
}
else
if
self
.sources
.is_empty
()
{
Err
(
TransferError
::
BuilderError
(
"Sources cannot be empty"
.to_string
(),
))
}
else
if
self
.destinations
.is_empty
()
{
Err
(
TransferError
::
BuilderError
(
"Destinations cannot be empty"
.to_string
(),
))
}
else
{
Ok
(())
}
}
}
// // --- Local Transfer Implementations ---
// // Local Pinned -> Pinned
// impl<'a, MSource: BlockMetadata, MDest: BlockMetadata>
// TransferRequestPut<
// 'a,
// ImmutableBlock<PinnedStorage, MSource>,
// MutableBlock<PinnedStorage, MDest>,
// >
// {
// pub fn execute(mut self) -> Result<(), TransferError> {
// self.validate_counts()?;
// tracing::info!("Executing local transfer: Pinned -> Pinned");
// for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter_mut()) {
// let src_data = src_block.block_data(private::PrivateToken);
// let dst_data = dst_block.block_data_mut(private::PrivateToken);
// // TODO: Implement layer-wise or block-wise CUDA memcpy H2H or std::ptr::copy
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "Copying block");
// }
// Ok(())
// }
// }
// // Local Pinned -> Device
// impl<'a, MSource: BlockMetadata, MDest: BlockMetadata>
// TransferRequestPut<
// 'a,
// ImmutableBlock<PinnedStorage, MSource>,
// MutableBlock<DeviceStorage, MDest>,
// >
// {
// pub fn execute(mut self) -> Result<(), TransferError> {
// self.validate_counts()?;
// tracing::info!("Executing local transfer: Pinned -> Device");
// for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter_mut()) {
// let src_data = src_block.block_data(private::PrivateToken);
// let dst_data = dst_block.block_data_mut(private::PrivateToken);
// // TODO: Implement layer-wise or block-wise CUDA memcpy H2D
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "Copying block");
// }
// Ok(())
// }
// }
// // Local Device -> Pinned
// impl<'a, MSource: BlockMetadata, MDest: BlockMetadata>
// TransferRequestPut<
// 'a,
// ImmutableBlock<DeviceStorage, MSource>,
// MutableBlock<PinnedStorage, MDest>,
// >
// {
// pub fn execute(mut self) -> Result<(), TransferError> {
// self.validate_counts()?;
// tracing::info!("Executing local transfer: Device -> Pinned");
// for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter_mut()) {
// let src_data = src_block.block_data(private::PrivateToken);
// let dst_data = dst_block.block_data_mut(private::PrivateToken);
// // TODO: Implement layer-wise or block-wise CUDA memcpy D2H
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "Copying block");
// }
// Ok(())
// }
// }
// // Local Device -> Device
// impl<'a, MSource: BlockMetadata, MDest: BlockMetadata>
// TransferRequestPut<
// 'a,
// ImmutableBlock<DeviceStorage, MSource>,
// MutableBlock<DeviceStorage, MDest>,
// >
// {
// pub fn execute(mut self) -> Result<(), TransferError> {
// self.validate_counts()?;
// tracing::info!("Executing local transfer: Device -> Device");
// for (src_block, dst_block) in self.sources.iter().zip(self.destinations.iter_mut()) {
// let src_data = src_block.block_data(private::PrivateToken);
// let dst_data = dst_block.block_data_mut(private::PrivateToken);
// // TODO: Implement layer-wise or block-wise CUDA memcpy D2D
// tracing::trace!(src = ?(src_data.worker_id, src_data.block_set_idx, src_data.block_idx), dst = ?(dst_data.worker_id, dst_data.block_set_idx, dst_data.block_idx), "Copying block");
// }
// Ok(())
// }
// }
// pub fn dispatch_copy_to<RB, WB>(
// src: &RB,
// dst: &mut WB,
// ctx: &TransferContext,
// ) -> Result<(), TransferError>
// where
// RB: ReadableBlock,
// WB: WritableBlock,
// // Ensure the necessary capability traits are implemented for the storage types
// // Note: These bounds aren't strictly *required* for the TypeId check,
// // but help ensure the backend calls will compile if a match occurs.
// // RB::Storage: SystemAccessible + CudaAccessible, // Might be too restrictive, apply within match arms
// // WB::Storage: SystemAccessible + CudaAccessible,
// {
// let src_type = src.storage_type_id();
// let dst_type = dst.storage_type_id();
// match (src_type, dst_type) {
// // === Memcpy Cases ===
// (s, d)
// if (s == TypeId::of::<SystemStorage>() && d == TypeId::of::<SystemStorage>())
// || (s == TypeId::of::<PinnedStorage>() && d == TypeId::of::<SystemStorage>())
// || (s == TypeId::of::<SystemStorage>() && d == TypeId::of::<PinnedStorage>())
// || (s == TypeId::of::<PinnedStorage>() && d == TypeId::of::<PinnedStorage>()) =>
// {
// memcpy::memcpy_block(src, dst)
// }
// // === CUDA Cases ===
// (s, d)
// if (s == TypeId::of::<PinnedStorage>() && d == TypeId::of::<DeviceStorage>())
// || (s == TypeId::of::<DeviceStorage>() && d == TypeId::of::<PinnedStorage>())
// || (s == TypeId::of::<DeviceStorage>() && d == TypeId::of::<DeviceStorage>()) =>
// {
// cuda::cuda_memcpy_block(src, dst, ctx.stream().as_ref())
// // let stream = stream.ok_or_else(|| {
// // TransferError::BuilderError("CUDA stream required for this transfer".into())
// // })?;
// // if is_cuda_compatible::<RB, WB>() {
// // tracing::debug!("Dispatching copy using CUDA");
// // cuda::cuda_memcpy_block(src_provider, dst_provider, stream) // Assumes cuda_memcpy_block is generic
// // } else {
// // Err(TransferError::IncompatibleTypes(
// // "CUDA copy requires CudaAccessible storage".into(),
// // ))
// // }
// }
// // === NIXL Cases ===
// (s, d)
// if d == TypeId::of::<NixlStorage>()
// && (s == TypeId::of::<SystemStorage>()
// || s == TypeId::of::<PinnedStorage>()
// || s == TypeId::of::<DeviceStorage>()) =>
// {
// unimplemented!()
// // tracing::debug!("Dispatching copy using NIXL PUT");
// // // TODO: Implement NIXL PUT logic
// // // You might need a specific NIXL transfer function here.
// // // Example: nixl::nixl_put_block(src_provider, dst_provider)
// // Err(TransferError::ExecutionError(
// // "NIXL PUT not yet implemented".into(),
// // ))
// }
// // TODO: Add NIXL GET cases (Nixl -> System/Pinned/Device)
// // === Error Case ===
// _ => Err(TransferError::IncompatibleTypes(format!(
// "Unsupported storage combination for copy: {:?} -> {:?}",
// std::any::type_name::<<RB as ReadableBlock>::StorageType>(), // Requires nightly or use debug print
// std::any::type_name::<<WB as WritableBlock>::StorageType>()
// ))),
// }
// }
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
write_to_strategy
()
{
// System to ...
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
SystemStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
PinnedStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
DeviceStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaBlockingH2D
);
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
);
// Pinned to ...
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
SystemStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
PinnedStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
DeviceStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaAsyncH2D
);
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
);
// Device to ...
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
SystemStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaBlockingD2H
);
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
PinnedStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaAsyncD2H
);
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
DeviceStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaAsyncD2D
);
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
);
// Nixl to ... should fail to compile
// assert_eq!(
// <NixlStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
}
}
lib/llm/src/block_manager/block/transfer/cuda.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
*
;
use
super
::
TransferError
;
use
crate
::
block_manager
::
storage
::{
DeviceStorage
,
PinnedStorage
};
use
anyhow
::
Result
;
use
cudarc
::
driver
::
result
as
cuda_result
;
use
std
::
ops
::
Range
;
type
CudaMemcpyFnPtr
=
unsafe
fn
(
src_ptr
:
*
const
u8
,
dst_ptr
:
*
mut
u8
,
size
:
usize
,
stream
:
&
CudaStream
,
)
->
Result
<
(),
TransferError
>
;
fn
cuda_memcpy_fn_ptr
(
strategy
:
&
TransferStrategy
)
->
Result
<
CudaMemcpyFnPtr
,
TransferError
>
{
match
strategy
{
TransferStrategy
::
CudaAsyncH2D
=>
Ok
(
cuda_memcpy_h2d
),
TransferStrategy
::
CudaAsyncD2H
=>
Ok
(
cuda_memcpy_d2h
),
TransferStrategy
::
CudaAsyncD2D
=>
Ok
(
cuda_memcpy_d2d
),
_
=>
Err
(
TransferError
::
ExecutionError
(
"Unsupported copy strategy for CUDA memcpy async"
.into
(),
)),
}
}
/// Copy a block from a source to a destination using CUDA memcpy
pub
fn
copy_block
<
'a
,
Source
,
Destination
>
(
sources
:
&
'a
Source
,
destinations
:
&
'a
mut
Destination
,
stream
:
&
CudaStream
,
strategy
:
TransferStrategy
,
)
->
Result
<
(),
TransferError
>
where
Source
:
BlockDataProvider
,
Destination
:
BlockDataProviderMut
,
{
let
src_data
=
sources
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
destinations
.block_data_mut
(
private
::
PrivateToken
);
let
memcpy_fn
=
cuda_memcpy_fn_ptr
(
&
strategy
)
?
;
#[cfg(debug_assertions)]
{
let
expected_strategy
=
expected_strategy
::
<
Source
::
StorageType
,
Destination
::
StorageType
>
();
assert_eq!
(
strategy
,
expected_strategy
);
}
if
src_data
.is_fully_contiguous
()
&&
dst_data
.is_fully_contiguous
()
{
let
src_view
=
src_data
.block_view
()
?
;
let
mut
dst_view
=
dst_data
.block_view_mut
()
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
unsafe
{
memcpy_fn
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
(),
stream
,
)
?
;
}
}
else
{
assert_eq!
(
src_data
.num_layers
(),
dst_data
.num_layers
());
copy_layers
(
0
..
src_data
.num_layers
(),
sources
,
destinations
,
stream
,
strategy
,
)
?
;
}
Ok
(())
}
/// Copy a range of layers from a source to a destination using CUDA memcpy
pub
fn
copy_layers
<
'a
,
Source
,
Destination
>
(
layer_range
:
Range
<
usize
>
,
sources
:
&
'a
Source
,
destinations
:
&
'a
mut
Destination
,
stream
:
&
CudaStream
,
strategy
:
TransferStrategy
,
)
->
Result
<
(),
TransferError
>
where
Source
:
BlockDataProvider
,
Destination
:
BlockDataProviderMut
,
{
let
src_data
=
sources
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
destinations
.block_data_mut
(
private
::
PrivateToken
);
let
memcpy_fn
=
cuda_memcpy_fn_ptr
(
&
strategy
)
?
;
#[cfg(debug_assertions)]
{
let
expected_strategy
=
expected_strategy
::
<
Source
::
StorageType
,
Destination
::
StorageType
>
();
assert_eq!
(
strategy
,
expected_strategy
);
}
for
layer_idx
in
layer_range
{
let
src_view
=
src_data
.layer_view
(
layer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
unsafe
{
memcpy_fn
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
(),
stream
,
)
?
;
}
}
Ok
(())
}
/// Helper function to perform the appropriate CUDA memcpy based on storage types
// Allow dead code because it's used in debug assertions
#[allow(dead_code)]
fn
expected_strategy
<
Source
:
Storage
,
Dest
:
Storage
>
()
->
TransferStrategy
{
match
(
std
::
any
::
TypeId
::
of
::
<
Source
>
(),
std
::
any
::
TypeId
::
of
::
<
Dest
>
(),
)
{
(
src
,
dst
)
if
src
==
std
::
any
::
TypeId
::
of
::
<
PinnedStorage
>
()
&&
dst
==
std
::
any
::
TypeId
::
of
::
<
DeviceStorage
>
()
=>
{
TransferStrategy
::
CudaAsyncH2D
}
(
src
,
dst
)
if
src
==
std
::
any
::
TypeId
::
of
::
<
DeviceStorage
>
()
&&
dst
==
std
::
any
::
TypeId
::
of
::
<
PinnedStorage
>
()
=>
{
TransferStrategy
::
CudaAsyncD2H
}
(
src
,
dst
)
if
src
==
std
::
any
::
TypeId
::
of
::
<
DeviceStorage
>
()
&&
dst
==
std
::
any
::
TypeId
::
of
::
<
DeviceStorage
>
()
=>
{
TransferStrategy
::
CudaAsyncD2D
}
_
=>
TransferStrategy
::
Invalid
,
}
}
/// H2D Implementation
#[inline(always)]
unsafe
fn
cuda_memcpy_h2d
(
src_ptr
:
*
const
u8
,
dst_ptr
:
*
mut
u8
,
size
:
usize
,
stream
:
&
CudaStream
,
)
->
Result
<
(),
TransferError
>
{
debug_assert!
(
!
src_ptr
.is_null
(),
"Source host pointer is null"
);
debug_assert!
(
!
dst_ptr
.is_null
(),
"Destination device pointer is null"
);
debug_assert!
(
(
src_ptr
as
usize
+
size
<=
dst_ptr
as
usize
)
||
(
dst_ptr
as
usize
+
size
<=
src_ptr
as
usize
),
"Source and destination device memory regions must not overlap for D2D copy"
);
let
src_slice
=
std
::
slice
::
from_raw_parts
(
src_ptr
,
size
);
cuda_result
::
memcpy_htod_async
(
dst_ptr
as
u64
,
src_slice
,
stream
.cu_stream
())
.map_err
(|
e
|
TransferError
::
ExecutionError
(
format!
(
"CUDA H2D memcpy failed: {}"
,
e
)))
?
;
Ok
(())
}
/// D2H Implementation
#[inline(always)]
unsafe
fn
cuda_memcpy_d2h
(
src_ptr
:
*
const
u8
,
dst_ptr
:
*
mut
u8
,
size
:
usize
,
stream
:
&
CudaStream
,
)
->
Result
<
(),
TransferError
>
{
debug_assert!
(
!
src_ptr
.is_null
(),
"Source device pointer is null"
);
debug_assert!
(
!
dst_ptr
.is_null
(),
"Destination host pointer is null"
);
debug_assert!
(
(
src_ptr
as
usize
+
size
<=
dst_ptr
as
usize
)
||
(
dst_ptr
as
usize
+
size
<=
src_ptr
as
usize
),
"Source and destination device memory regions must not overlap for D2D copy"
);
let
dst_slice
=
std
::
slice
::
from_raw_parts_mut
(
dst_ptr
,
size
);
cuda_result
::
memcpy_dtoh_async
(
dst_slice
,
src_ptr
as
u64
,
stream
.cu_stream
())
.map_err
(|
e
|
TransferError
::
ExecutionError
(
format!
(
"CUDA D2H memcpy failed: {}"
,
e
)))
?
;
Ok
(())
}
/// D2D Implementation
#[inline(always)]
unsafe
fn
cuda_memcpy_d2d
(
src_ptr
:
*
const
u8
,
dst_ptr
:
*
mut
u8
,
size
:
usize
,
stream
:
&
CudaStream
,
)
->
Result
<
(),
TransferError
>
{
debug_assert!
(
!
src_ptr
.is_null
(),
"Source device pointer is null"
);
debug_assert!
(
!
dst_ptr
.is_null
(),
"Destination device pointer is null"
);
debug_assert!
(
(
src_ptr
as
usize
+
size
<=
dst_ptr
as
usize
)
||
(
dst_ptr
as
usize
+
size
<=
src_ptr
as
usize
),
"Source and destination device memory regions must not overlap for D2D copy"
);
cuda_result
::
memcpy_dtod_async
(
dst_ptr
as
u64
,
src_ptr
as
u64
,
size
,
stream
.cu_stream
())
.map_err
(|
e
|
TransferError
::
ExecutionError
(
format!
(
"CUDA D2D memcpy failed: {}"
,
e
)))
?
;
Ok
(())
}
#[cfg(all(test,
feature
=
"testing-cuda"
))]
mod
tests
{
use
super
::
*
;
use
crate
::
block_manager
::
storage
::{
DeviceAllocator
,
PinnedAllocator
,
StorageAllocator
,
StorageMemset
,
};
#[test]
fn
test_memset_and_transfer
()
{
// Create allocators
let
device_allocator
=
DeviceAllocator
::
default
();
let
pinned_allocator
=
PinnedAllocator
::
default
();
let
ctx
=
device_allocator
.ctx
()
.clone
();
// Create CUDA stream
let
stream
=
ctx
.new_stream
()
.unwrap
();
// Allocate host and device memory
let
mut
host
=
pinned_allocator
.allocate
(
1024
)
.unwrap
();
let
mut
device
=
device_allocator
.allocate
(
1024
)
.unwrap
();
// Set a pattern in host memory
StorageMemset
::
memset
(
&
mut
host
,
42
,
0
,
1024
)
.unwrap
();
// Verify host memory was set correctly
unsafe
{
let
ptr
=
host
.as_ptr
();
let
slice
=
std
::
slice
::
from_raw_parts
(
ptr
,
1024
);
assert
!
(
slice
.iter
()
.all
(|
&
x
|
x
==
42
));
}
// Copy host to device
unsafe
{
cuda_memcpy_h2d
(
host
.as_ptr
(),
device
.as_mut_ptr
(),
1024
,
stream
.as_ref
())
.unwrap
();
}
// Synchronize to ensure H2D copy is complete
stream
.synchronize
()
.unwrap
();
// Clear host memory
StorageMemset
::
memset
(
&
mut
host
,
0
,
0
,
1024
)
.unwrap
();
// Verify host memory was cleared
unsafe
{
let
ptr
=
host
.as_ptr
();
let
slice
=
std
::
slice
::
from_raw_parts
(
ptr
,
1024
);
assert
!
(
slice
.iter
()
.all
(|
&
x
|
x
==
0
));
}
// Copy back from device to host
unsafe
{
cuda_memcpy_d2h
(
device
.as_ptr
(),
host
.as_mut_ptr
(),
1024
,
stream
.as_ref
())
.unwrap
();
}
// Synchronize to ensure D2H copy is complete before verifying
stream
.synchronize
()
.unwrap
();
// Verify the original pattern was restored
unsafe
{
let
ptr
=
host
.as_ptr
();
let
slice
=
std
::
slice
::
from_raw_parts
(
ptr
,
1024
);
assert
!
(
slice
.iter
()
.all
(|
&
x
|
x
==
42
));
}
}
}
lib/llm/src/block_manager/block/transfer/memcpy.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
*
;
/// Copy a block from a source to a destination using memcpy
pub
fn
copy_block
<
'a
,
Source
,
Destination
>
(
sources
:
&
'a
Source
,
destinations
:
&
'a
mut
Destination
,
)
->
Result
<
(),
TransferError
>
where
Source
:
ReadableBlock
,
Destination
:
WritableBlock
,
{
let
src_data
=
sources
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
destinations
.block_data_mut
(
private
::
PrivateToken
);
if
src_data
.is_fully_contiguous
()
&&
dst_data
.is_fully_contiguous
()
{
let
src_view
=
src_data
.block_view
()
?
;
let
mut
dst_view
=
dst_data
.block_view_mut
()
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
unsafe
{
memcpy
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
());
}
}
else
{
assert_eq!
(
src_data
.num_layers
(),
dst_data
.num_layers
());
copy_layers
(
0
..
src_data
.num_layers
(),
sources
,
destinations
)
?
;
}
Ok
(())
}
/// Copy a range of layers from a source to a destination using memcpy
pub
fn
copy_layers
<
'a
,
Source
,
Destination
>
(
layer_range
:
Range
<
usize
>
,
sources
:
&
'a
Source
,
destinations
:
&
'a
mut
Destination
,
)
->
Result
<
(),
TransferError
>
where
Source
:
ReadableBlock
,
// <Source as ReadableBlock>::StorageType: SystemAccessible + Local,
Destination
:
WritableBlock
,
// <Destination as WritableBlock>::StorageType: SystemAccessible + Local,
{
let
src_data
=
sources
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
destinations
.block_data_mut
(
private
::
PrivateToken
);
for
layer_idx
in
layer_range
{
let
src_view
=
src_data
.layer_view
(
layer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
unsafe
{
memcpy
(
src_view
.as_ptr
(),
dst_view
.as_mut_ptr
(),
src_view
.size
());
}
}
Ok
(())
}
#[inline(always)]
unsafe
fn
memcpy
(
src_ptr
:
*
const
u8
,
dst_ptr
:
*
mut
u8
,
size
:
usize
)
{
debug_assert!
(
(
src_ptr
as
usize
+
size
<=
dst_ptr
as
usize
)
||
(
dst_ptr
as
usize
+
size
<=
src_ptr
as
usize
),
"Source and destination memory regions must not overlap for copy_nonoverlapping"
);
std
::
ptr
::
copy_nonoverlapping
(
src_ptr
,
dst_ptr
,
size
);
}
lib/llm/src/block_manager/block/transfer/nixl.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
*
;
use
anyhow
::
Result
;
use
nixl_sys
::{
MemoryRegion
,
NixlDescriptor
,
OptArgs
,
XferDescList
,
XferOp
};
use
std
::
ops
::
Range
;
/// Copy a block from a source to a destination using CUDA memcpy
pub
fn
write_block_to
<
'a
,
Source
,
Destination
>
(
src
:
&
'a
Source
,
dst
:
&
'a
mut
Destination
,
ctx
:
&
TransferContext
,
notify
:
Option
<
String
>
,
)
->
Result
<
()
>
where
Source
:
BlockDataProvider
,
Destination
:
BlockDataProviderMut
,
{
let
src_data
=
src
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
dst
.block_data_mut
(
private
::
PrivateToken
);
if
src_data
.is_fully_contiguous
()
&&
dst_data
.is_fully_contiguous
()
{
let
nixl_agent
=
ctx
.nixl_agent
()
.expect
(
"NIXL agent not found"
);
let
remote_worker_id
=
dst_data
.worker_id
.to_string
();
let
mut
src_dl
=
XferDescList
::
new
(
src_data
.storage_type
()
.nixl_mem_type
())
?
;
let
mut
dst_dl
=
XferDescList
::
new
(
dst_data
.storage_type
()
.nixl_mem_type
())
?
;
let
src_desc
=
src_data
.block_view
()
?
.as_nixl_descriptor
();
let
dst_desc
=
dst_data
.block_view_mut
()
?
.as_nixl_descriptor_mut
();
unsafe
{
src_dl
.add_desc
(
src_desc
.as_ptr
()
as
usize
,
src_desc
.size
(),
src_desc
.device_id
(),
)
?
;
dst_dl
.add_desc
(
dst_desc
.as_ptr
()
as
usize
,
dst_desc
.size
(),
dst_desc
.device_id
(),
)
?
;
}
let
xfer_req
=
nixl_agent
.create_xfer_req
(
XferOp
::
Write
,
&
src_dl
,
&
dst_dl
,
&
remote_worker_id
,
None
)
?
;
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
if
let
Some
(
notify
)
=
notify
{
xfer_args
.set_has_notification
(
true
)
?
;
xfer_args
.set_notification_message
(
notify
.as_bytes
())
?
;
}
let
mut
status
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
Some
(
&
xfer_args
))
?
;
tracing
::
span!
(
tracing
::
Level
::
DEBUG
,
"Waiting for transfer to complete"
)
.in_scope
(||
{
while
status
{
status
=
nixl_agent
.get_xfer_status
(
&
xfer_req
)
.unwrap
();
}
});
}
else
{
assert_eq!
(
src_data
.num_layers
(),
dst_data
.num_layers
());
write_layers_to
(
0
..
src_data
.num_layers
(),
src
,
dst
,
ctx
,
notify
)
?
;
}
Ok
(())
}
/// Copy a range of layers from a source to a destination using CUDA memcpy
pub
fn
write_layers_to
<
'a
,
Source
,
Destination
>
(
layer_range
:
Range
<
usize
>
,
src
:
&
'a
Source
,
dst
:
&
'a
mut
Destination
,
ctx
:
&
TransferContext
,
notify
:
Option
<
String
>
,
)
->
Result
<
()
>
where
Source
:
BlockDataProvider
,
Destination
:
BlockDataProviderMut
,
{
let
src_data
=
src
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
dst
.block_data_mut
(
private
::
PrivateToken
);
let
nixl_agent
=
ctx
.nixl_agent
()
.expect
(
"NIXL agent not found"
);
let
remote_worker_id
=
dst_data
.worker_id
.to_string
();
let
mut
src_dl
=
XferDescList
::
new
(
src_data
.storage_type
()
.nixl_mem_type
())
?
;
let
mut
dst_dl
=
XferDescList
::
new
(
dst_data
.storage_type
()
.nixl_mem_type
())
?
;
// #[cfg(debug_assertions)]
// {
// let expected_strategy = <<Source as BlockDataProvider>::StorageType as WriteToStrategy<
// Destination::StorageType,
// >>::write_to_strategy();
// assert_eq!(strategy, expected_strategy);
// }
for
layer_idx
in
layer_range
{
let
src_view
=
src_data
.layer_view
(
layer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
let
src_desc
=
src_view
.as_nixl_descriptor
();
let
dst_desc
=
dst_view
.as_nixl_descriptor_mut
();
unsafe
{
src_dl
.add_desc
(
src_desc
.as_ptr
()
as
usize
,
src_desc
.size
(),
src_desc
.device_id
(),
)
?
;
dst_dl
.add_desc
(
dst_desc
.as_ptr
()
as
usize
,
dst_desc
.size
(),
dst_desc
.device_id
(),
)
?
;
}
}
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
if
let
Some
(
notify
)
=
notify
{
xfer_args
.set_has_notification
(
true
)
?
;
xfer_args
.set_notification_message
(
notify
.as_bytes
())
?
;
}
let
xfer_req
=
nixl_agent
.create_xfer_req
(
XferOp
::
Write
,
&
src_dl
,
&
dst_dl
,
&
remote_worker_id
,
Some
(
&
xfer_args
),
)
?
;
let
mut
status
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
Some
(
&
xfer_args
))
?
;
tracing
::
span!
(
tracing
::
Level
::
DEBUG
,
"Waiting for transfer to complete"
)
.in_scope
(||
{
while
status
{
status
=
nixl_agent
.get_xfer_status
(
&
xfer_req
)
.unwrap
();
}
});
Ok
(())
}
lib/llm/src/block_manager/block/transfer/strategy.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! This module implements the `WriteToStrategy` and `ReadFromStrategy` traits
//! for the common storage types.
use
super
::
*
;
impl
WriteToStrategy
<
SystemStorage
>
for
SystemStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Memcpy
}
}
impl
WriteToStrategy
<
PinnedStorage
>
for
SystemStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Memcpy
}
}
impl
WriteToStrategy
<
DeviceStorage
>
for
SystemStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
CudaBlockingH2D
}
}
impl
WriteToStrategy
<
SystemStorage
>
for
PinnedStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Memcpy
}
}
impl
WriteToStrategy
<
PinnedStorage
>
for
PinnedStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Memcpy
}
}
impl
WriteToStrategy
<
DeviceStorage
>
for
PinnedStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
CudaAsyncH2D
}
}
impl
WriteToStrategy
<
SystemStorage
>
for
DeviceStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
CudaBlockingD2H
}
}
impl
WriteToStrategy
<
PinnedStorage
>
for
DeviceStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
CudaAsyncD2H
}
}
impl
WriteToStrategy
<
DeviceStorage
>
for
DeviceStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
CudaAsyncD2D
}
}
impl
<
S
:
Storage
+
Local
>
WriteToStrategy
<
NixlStorage
>
for
S
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
NixlWrite
}
}
impl
<
S
>
ReadFromStrategy
<
S
>
for
SystemStorage
where
S
:
WriteToStrategy
<
SystemStorage
>
+
Storage
+
Local
,
{
#[inline(always)]
fn
read_from_strategy
()
->
TransferStrategy
{
S
::
write_to_strategy
()
}
}
impl
<
S
>
ReadFromStrategy
<
S
>
for
PinnedStorage
where
S
:
WriteToStrategy
<
PinnedStorage
>
+
Storage
+
Local
,
{
#[inline(always)]
fn
read_from_strategy
()
->
TransferStrategy
{
S
::
write_to_strategy
()
}
}
impl
<
S
>
ReadFromStrategy
<
S
>
for
DeviceStorage
where
S
:
WriteToStrategy
<
DeviceStorage
>
+
Storage
+
Local
,
{
#[inline(always)]
fn
read_from_strategy
()
->
TransferStrategy
{
S
::
write_to_strategy
()
}
}
impl
<
S
:
Storage
+
Local
>
ReadFromStrategy
<
NixlStorage
>
for
S
{
#[inline(always)]
fn
read_from_strategy
()
->
TransferStrategy
{
TransferStrategy
::
NixlRead
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
write_to_strategy
()
{
// System to ...
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
SystemStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
PinnedStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
DeviceStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaBlockingH2D
);
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
);
// Pinned to ...
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
SystemStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
PinnedStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
DeviceStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaAsyncH2D
);
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
);
// Device to ...
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
SystemStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaBlockingD2H
);
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
PinnedStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaAsyncD2H
);
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
DeviceStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
CudaAsyncD2D
);
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
);
// Nixl to ... should fail to compile
// assert_eq!(
// <NixlStorage as WriteToStrategy<SystemStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<PinnedStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<DeviceStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
// assert_eq!(
// <NixlStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
// TransferStrategy::Invalid
// );
}
#[test]
fn
read_from_strategy
()
{
// System to ...
assert_eq!
(
<
SystemStorage
as
ReadFromStrategy
<
SystemStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
SystemStorage
as
ReadFromStrategy
<
PinnedStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
SystemStorage
as
ReadFromStrategy
<
DeviceStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
CudaBlockingD2H
);
assert_eq!
(
<
SystemStorage
as
ReadFromStrategy
<
NixlStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
NixlRead
);
// Pinned to ...
assert_eq!
(
<
PinnedStorage
as
ReadFromStrategy
<
SystemStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
PinnedStorage
as
ReadFromStrategy
<
PinnedStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
Memcpy
);
assert_eq!
(
<
PinnedStorage
as
ReadFromStrategy
<
DeviceStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
CudaAsyncD2H
);
assert_eq!
(
<
PinnedStorage
as
ReadFromStrategy
<
NixlStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
NixlRead
);
// Device to ...
assert_eq!
(
<
DeviceStorage
as
ReadFromStrategy
<
SystemStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
CudaBlockingH2D
);
assert_eq!
(
<
DeviceStorage
as
ReadFromStrategy
<
PinnedStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
CudaAsyncH2D
);
assert_eq!
(
<
DeviceStorage
as
ReadFromStrategy
<
DeviceStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
CudaAsyncD2D
);
assert_eq!
(
<
DeviceStorage
as
ReadFromStrategy
<
NixlStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
NixlRead
);
// Nixl to ... should fail to compile
// assert_eq!(
// <NixlStorage as ReadFromStrategy<SystemStorage>>::read_from_strategy(),
// TransferStrategy::Invalid
// );
//
// assert_eq!(
// <NixlStorage as ReadFromStrategy<PinnedStorage>>::read_from_strategy(),
// TransferStrategy::Invalid
// );
//
// assert_eq!(
// <NixlStorage as ReadFromStrategy<DeviceStorage>>::read_from_strategy(),
// TransferStrategy::Invalid
// );
//
// assert_eq!(
// <NixlStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
// TransferStrategy::Invalid
// );
}
}
lib/llm/src/block_manager/block/view.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Block storage management.
//!
//! This module provides the implementation for managing collections of blocks
//! and their storage. It handles the relationship between storage, layout,
//! and individual blocks.
use
super
::{
BlockData
,
BlockError
,
Storage
};
pub
trait
Kind
:
std
::
marker
::
Sized
+
std
::
fmt
::
Debug
+
Clone
+
Copy
+
Send
+
Sync
{}
#[derive(Debug,
Clone,
Copy)]
pub
struct
BlockKind
;
impl
Kind
for
BlockKind
{}
#[derive(Debug,
Clone,
Copy)]
pub
struct
LayerKind
;
impl
Kind
for
LayerKind
{}
pub
type
BlockView
<
'a
,
S
>
=
MemoryView
<
'a
,
S
,
BlockKind
>
;
pub
type
BlockViewMut
<
'a
,
S
>
=
MemoryViewMut
<
'a
,
S
,
BlockKind
>
;
pub
type
LayerView
<
'a
,
S
>
=
MemoryView
<
'a
,
S
,
LayerKind
>
;
pub
type
LayerViewMut
<
'a
,
S
>
=
MemoryViewMut
<
'a
,
S
,
LayerKind
>
;
/// Storage view that provides safe access to a region of storage
#[derive(Debug)]
pub
struct
MemoryView
<
'a
,
S
:
Storage
,
K
:
Kind
>
{
_
block_data
:
&
'a
BlockData
<
S
>
,
addr
:
usize
,
size
:
usize
,
kind
:
std
::
marker
::
PhantomData
<
K
>
,
}
impl
<
'a
,
S
,
K
>
MemoryView
<
'a
,
S
,
K
>
where
S
:
Storage
,
K
:
Kind
,
{
/// Create a new storage view
///
/// # Safety
/// The caller must ensure:
/// - addr + size <= storage.size()
/// - The view does not outlive the storage
pub
(
crate
)
unsafe
fn
new
(
_
block_data
:
&
'a
BlockData
<
S
>
,
addr
:
usize
,
size
:
usize
,
)
->
Result
<
Self
,
BlockError
>
{
Ok
(
Self
{
_
block_data
,
addr
,
size
,
kind
:
std
::
marker
::
PhantomData
,
})
}
/// Get a raw pointer to the view's data
///
/// # Safety
/// The caller must ensure:
/// - The pointer is not used after the view is dropped
/// - Access patterns respect the storage's thread safety model
pub
unsafe
fn
as_ptr
(
&
self
)
->
*
const
u8
{
self
.addr
as
*
const
u8
}
/// Size of the view in bytes
pub
fn
size
(
&
self
)
->
usize
{
self
.size
}
}
/// Mutable storage view that provides exclusive access to a region of storage
#[derive(Debug)]
pub
struct
MemoryViewMut
<
'a
,
S
:
Storage
,
K
:
Kind
>
{
_
block_data
:
&
'a
mut
BlockData
<
S
>
,
addr
:
usize
,
size
:
usize
,
kind
:
std
::
marker
::
PhantomData
<
K
>
,
}
impl
<
'a
,
S
:
Storage
,
K
:
Kind
>
MemoryViewMut
<
'a
,
S
,
K
>
{
/// Create a new mutable storage view
///
/// # Safety
/// The caller must ensure:
/// - addr + size <= storage.size()
/// - The view does not outlive the storage
/// - No other views exist for this region
pub
(
crate
)
unsafe
fn
new
(
_
block_data
:
&
'a
mut
BlockData
<
S
>
,
addr
:
usize
,
size
:
usize
,
)
->
Result
<
Self
,
BlockError
>
{
Ok
(
Self
{
_
block_data
,
addr
,
size
,
kind
:
std
::
marker
::
PhantomData
,
})
}
/// Get a raw mutable pointer to the view's data
///
/// # Safety
/// The caller must ensure:
/// - The pointer is not used after the view is dropped
/// - No other references exist while the pointer is in use
/// - Access patterns respect the storage's thread safety model
pub
unsafe
fn
as_mut_ptr
(
&
mut
self
)
->
*
mut
u8
{
self
.addr
as
*
mut
u8
}
/// Size of the view in bytes
pub
fn
size
(
&
self
)
->
usize
{
self
.size
}
}
mod
nixl
{
use
super
::
*
;
use
super
::
super
::
nixl
::
*
;
pub
use
nixl_sys
::{
MemType
,
MemoryRegion
,
NixlDescriptor
};
impl
<
S
:
Storage
,
K
:
Kind
>
MemoryRegion
for
MemoryView
<
'_
,
S
,
K
>
{
unsafe
fn
as_ptr
(
&
self
)
->
*
const
u8
{
self
.addr
as
*
const
u8
}
fn
size
(
&
self
)
->
usize
{
self
.size
()
}
}
impl
<
S
,
K
>
NixlDescriptor
for
MemoryView
<
'_
,
S
,
K
>
where
S
:
Storage
+
NixlDescriptor
,
K
:
Kind
,
{
fn
mem_type
(
&
self
)
->
MemType
{
self
._block_data.layout
.storage_type
()
.nixl_mem_type
()
}
fn
device_id
(
&
self
)
->
u64
{
self
._block_data.layout
.storage_type
()
.nixl_device_id
()
}
}
impl
<
S
:
Storage
,
K
:
Kind
>
MemoryRegion
for
MemoryViewMut
<
'_
,
S
,
K
>
{
unsafe
fn
as_ptr
(
&
self
)
->
*
const
u8
{
self
.addr
as
*
const
u8
}
fn
size
(
&
self
)
->
usize
{
self
.size
()
}
}
impl
<
S
:
Storage
,
K
:
Kind
>
NixlDescriptor
for
MemoryViewMut
<
'_
,
S
,
K
>
where
S
:
Storage
+
NixlDescriptor
,
K
:
Kind
,
{
fn
mem_type
(
&
self
)
->
MemType
{
self
._block_data.layout
.storage_type
()
.nixl_mem_type
()
}
fn
device_id
(
&
self
)
->
u64
{
self
._block_data.layout
.storage_type
()
.nixl_device_id
()
}
}
impl
<
'a
,
S
,
K
>
MemoryView
<
'a
,
S
,
K
>
where
S
:
Storage
+
NixlDescriptor
,
// Ensure the underlying storage is a NixlDescriptor
K
:
Kind
,
{
/// Creates an immutable NIXL memory descriptor from this view.
pub
fn
as_nixl_descriptor
(
&
self
)
->
NixlMemoryDescriptor
<
'a
,
K
,
IsImmutable
>
{
NixlMemoryDescriptor
::
new
(
self
.addr
as
u64
,
// Address from the view
self
.size
(),
// Size from the view
NixlDescriptor
::
mem_type
(
self
),
// Delegate to self's NixlDescriptor impl
NixlDescriptor
::
device_id
(
self
),
// Delegate to self's NixlDescriptor impl
)
}
}
impl
<
'a
,
S
,
K
>
MemoryViewMut
<
'a
,
S
,
K
>
where
S
:
Storage
+
NixlDescriptor
,
K
:
Kind
,
{
/// Creates a mutable NIXL memory descriptor from this view.
// Note: We return a mutable descriptor even from an immutable borrow (&self)
// because the underlying memory region *can* be mutated.
pub
fn
as_nixl_descriptor_mut
(
&
mut
self
)
->
NixlMemoryDescriptor
<
'a
,
K
,
IsMutable
>
{
NixlMemoryDescriptor
::
new
(
self
.addr
as
u64
,
self
.size
(),
NixlDescriptor
::
mem_type
(
self
),
// Delegate to self's NixlDescriptor impl
NixlDescriptor
::
device_id
(
self
),
// Delegate to self's NixlDescriptor impl
)
}
}
}
lib/llm/src/block_manager/config.rs
0 → 100644
View file @
4564a387
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
*
;
#[derive(Debug,
Clone)]
pub
enum
NixlOptions
{
/// Enable NIXL and create a new NIXL agent
Enabled
,
/// Enable NIXL and use the provided NIXL agent
EnabledWithAgent
(
NixlAgent
),
/// Disable NIXL
Disabled
,
}
#[derive(Debug,
Clone,
Builder,
Validate)]
#[builder(pattern
=
"owned"
)]
pub
struct
KvManagerRuntimeConfig
{
pub
worker_id
:
u64
,
#[builder(default)]
pub
cancellation_token
:
CancellationToken
,
#[builder(default
=
"NixlOptions::Enabled"
)]
pub
nixl
:
NixlOptions
,
}
impl
KvManagerRuntimeConfig
{
pub
fn
builder
()
->
KvManagerRuntimeConfigBuilder
{
KvManagerRuntimeConfigBuilder
::
default
()
}
}
impl
KvManagerRuntimeConfigBuilder
{
pub
fn
enable_nixl
(
mut
self
)
->
Self
{
self
.nixl
=
Some
(
NixlOptions
::
Enabled
);
self
}
pub
fn
use_nixl_agent
(
mut
self
,
agent
:
NixlAgent
)
->
Self
{
self
.nixl
=
Some
(
NixlOptions
::
EnabledWithAgent
(
agent
));
self
}
pub
fn
disable_nixl
(
mut
self
)
->
Self
{
self
.nixl
=
Some
(
NixlOptions
::
Disabled
);
self
}
}
#[derive(Debug,
Clone,
Builder,
Validate)]
#[builder(pattern
=
"owned"
)]
pub
struct
KvManagerModelConfig
{
#[validate(range(min
=
1
))]
pub
num_layers
:
usize
,
#[validate(range(min
=
1
))]
pub
page_size
:
usize
,
#[validate(range(min
=
1
))]
pub
inner_dim
:
usize
,
#[builder(default
=
"DType::FP16"
)]
pub
dtype
:
DType
,
}
impl
KvManagerModelConfig
{
pub
fn
builder
()
->
KvManagerModelConfigBuilder
{
KvManagerModelConfigBuilder
::
default
()
}
}
#[derive(Builder,
Validate)]
#[builder(pattern
=
"owned"
,
build_fn(validate
=
"Self::validate"
))]
pub
struct
KvManagerLayoutConfig
<
S
:
Storage
+
NixlRegisterableStorage
>
{
/// The number of blocks to allocate
#[validate(range(min
=
1
))]
pub
num_blocks
:
usize
,
/// The type of layout to use
#[builder(default
=
"LayoutType::FullyContiguous"
)]
pub
layout_type
:
LayoutType
,
/// Storage for the blocks
/// If provided, the blocks will be allocated from the provided storage
/// Otherwise, the blocks will be allocated from
#[builder(default)]
pub
storage
:
Option
<
Vec
<
S
>>
,
/// If provided, the blocks will be allocated from the provided allocator
/// This option is mutually exclusive with the `storage` option
#[builder(default,
setter(custom))]
pub
allocator
:
Option
<
Arc
<
dyn
StorageAllocator
<
S
>>>
,
}
impl
<
S
:
Storage
+
NixlRegisterableStorage
>
KvManagerLayoutConfig
<
S
>
{
/// Create a new builder for the KvManagerLayoutConfig
pub
fn
builder
()
->
KvManagerLayoutConfigBuilder
<
S
>
{
KvManagerLayoutConfigBuilder
::
default
()
}
}
// Implement the validation and build functions on the generated builder type
// Note: derive_builder generates KvManagerBlockConfigBuilder<S>
impl
<
S
:
Storage
+
NixlRegisterableStorage
>
KvManagerLayoutConfigBuilder
<
S
>
{
/// Custom setter for the `allocator` field
pub
fn
allocator
(
mut
self
,
allocator
:
impl
StorageAllocator
<
S
>
+
'static
)
->
Self
{
self
.allocator
=
Some
(
Some
(
Arc
::
new
(
allocator
)));
self
}
// Validation function
fn
validate
(
&
self
)
->
Result
<
(),
String
>
{
match
(
self
.storage
.is_some
(),
self
.allocator
.is_some
())
{
(
true
,
false
)
|
(
false
,
true
)
=>
Ok
(()),
// XOR condition met
(
true
,
true
)
=>
Err
(
"Cannot provide both `storage` and `allocator`."
.to_string
()),
(
false
,
false
)
=>
Err
(
"Must provide either `storage` or `allocator`."
.to_string
()),
}
}
}
/// Configuration for the KvBlockManager
#[derive(Builder,
Validate)]
#[builder(pattern
=
"owned"
)]
pub
struct
KvBlockManagerConfig
{
/// Runtime configuration
///
/// This provides core runtime configuration for the KvBlockManager.
pub
runtime
:
KvManagerRuntimeConfig
,
/// Model configuration
///
/// This provides model-specific configuration for the KvBlockManager, specifically,
/// the number of layers and the size of the inner dimension which is directly related
/// to the type of attention used by the model.
///
/// Included in this configuration is also the page_size, i.e. the number of tokens that will
/// be represented in each "paged" KV block.
pub
model
:
KvManagerModelConfig
,
/// Specific configuration for the device layout
///
/// This includes the number of blocks and the layout of the data into the device memory/storage.
#[builder(default,
setter(strip_option))]
pub
device_layout
:
Option
<
KvManagerLayoutConfig
<
DeviceStorage
>>
,
/// Specific configuration for the host layout
///
/// This includes the number of blocks and the layout of the data into the host memory/storage.
#[builder(default,
setter(strip_option))]
pub
host_layout
:
Option
<
KvManagerLayoutConfig
<
PinnedStorage
>>
,
}
impl
KvBlockManagerConfig
{
/// Create a new builder for the KvBlockManagerConfig
pub
fn
builder
()
->
KvBlockManagerConfigBuilder
{
KvBlockManagerConfigBuilder
::
default
()
}
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment