"examples/python_rs/vscode:/vscode.git/clone" did not exist on "0439d3b51d8b7e2577a623e491b857e7d4ff90c2"
Commit f04359cf authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: global kv block manager (#45)

parent 530a6be0
@@ -193,7 +193,7 @@ dependencies = [
 "once_cell",
 "pin-project",
 "portable-atomic",
-"rand",
+"rand 0.8.5",
 "regex",
 "ring",
 "rustls-native-certs 0.7.3",
@@ -232,7 +232,7 @@ dependencies = [
 "derive_builder",
 "eventsource-stream",
 "futures",
-"rand",
+"rand 0.8.5",
 "reqwest 0.12.14",
 "reqwest-eventsource",
 "secrecy",
@@ -496,7 +496,7 @@ dependencies = [
 "getrandom 0.2.15",
 "instant",
 "pin-project-lite",
-"rand",
+"rand 0.8.5",
 "tokio",
 ]
@@ -535,9 +535,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
 [[package]]
 name = "base64ct"
-version = "1.7.1"
+version = "1.7.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bb97d56060ee67d285efb8001fec9d2a4c710c32efd2e14b5cbb5ba71930fc2d"
+checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3"

 [[package]]
 name = "bindgen"
@@ -747,12 +747,12 @@ dependencies = [
 [[package]]
 name = "candle-core"
 version = "0.8.0"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
 dependencies = [
 "byteorder",
 "candle-kernels",
 "candle-metal-kernels",
-"cudarc",
+"cudarc 0.13.9 (registry+https://github.com/rust-lang/crates.io-index)",
 "float8",
 "gemm",
 "half",
@@ -760,7 +760,7 @@ dependencies = [
 "metal",
 "num-traits",
 "num_cpus",
-"rand",
+"rand 0.8.5",
 "rand_distr",
 "rayon",
 "safetensors",
@@ -781,7 +781,7 @@ dependencies = [
 "indicatif",
 "log",
 "num_cpus",
-"rand",
+"rand 0.8.5",
 "reqwest 0.12.14",
 "rustls 0.23.23",
 "serde",
@@ -794,7 +794,7 @@ dependencies = [
 [[package]]
 name = "candle-kernels"
 version = "0.8.0"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
 dependencies = [
 "bindgen_cuda 0.1.5",
 ]
@@ -802,7 +802,7 @@ dependencies = [
 [[package]]
 name = "candle-metal-kernels"
 version = "0.8.0"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
 dependencies = [
 "metal",
 "once_cell",
@@ -813,7 +813,7 @@ dependencies = [
 [[package]]
 name = "candle-nn"
 version = "0.8.0"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
 dependencies = [
 "candle-core",
 "candle-metal-kernels",
@@ -1212,6 +1212,14 @@ dependencies = [
 "libloading",
 ]

+[[package]]
+name = "cudarc"
+version = "0.13.9"
+source = "git+https://github.com/coreylowman/cudarc.git?rev=8c52e735b55bf8e979e1a16bd85e3dfe4f87c9fe#8c52e735b55bf8e979e1a16bd85e3dfe4f87c9fe"
+dependencies = [
+ "libloading",
+]
+
 [[package]]
 name = "curve25519-dalek"
 version = "4.1.3"
@@ -1494,12 +1502,6 @@ dependencies = [
 "syn 2.0.100",
 ]

-[[package]]
-name = "doctest-file"
-version = "1.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "aac81fa3e28d21450aa4d2ac065992ba96a1d7303efbce51a95f4fd175b67562"
-
 [[package]]
 name = "dunce"
 version = "1.0.5"
@@ -1535,9 +1537,12 @@ dependencies = [
 "bindgen 0.70.1",
 "blake3",
 "bs62",
+"bytemuck",
 "bytes",
 "chrono",
 "cmake",
+"cudarc 0.13.9 (git+https://github.com/coreylowman/cudarc.git?rev=8c52e735b55bf8e979e1a16bd85e3dfe4f87c9fe)",
+"derive-getters",
 "derive_builder",
 "dynamo-runtime",
 "either",
@@ -1553,11 +1558,14 @@ dependencies = [
 "minijinja",
 "minijinja-contrib",
 "mistralrs",
+"ndarray",
 "prometheus",
 "proptest",
 "pyo3",
 "pyo3-async-runtimes",
 "pythonize",
+"rand 0.9.0",
+"rayon",
 "regex",
 "reqwest 0.12.14",
 "rstest 0.18.2",
@@ -1639,7 +1647,7 @@ dependencies = [
 "nuid",
 "once_cell",
 "prometheus",
-"rand",
+"rand 0.8.5",
 "regex",
 "rstest 0.23.0",
 "serde",
@@ -1963,10 +1971,10 @@ version = "0.1.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a14d1c3d88fdab81b5886c34b2a424b3fba5123564ea415ff98b160cf44c432f"
 dependencies = [
-"cudarc",
+"cudarc 0.13.9 (registry+https://github.com/rust-lang/crates.io-index)",
 "half",
 "num-traits",
-"rand",
+"rand 0.8.5",
 "rand_distr",
 ]
@@ -2401,7 +2409,7 @@ dependencies = [
 "cfg-if 1.0.0",
 "crunchy",
 "num-traits",
-"rand",
+"rand 0.8.5",
 "rand_distr",
 ]
@@ -2439,13 +2447,6 @@ version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"

-[[package]]
-name = "hello_world"
-version = "0.1.0"
-dependencies = [
- "dynamo-runtime",
-]
-
 [[package]]
 name = "hermit-abi"
 version = "0.3.9"
@@ -2466,7 +2467,7 @@ dependencies = [
 "log",
 "native-tls",
 "num_cpus",
-"rand",
+"rand 0.8.5",
 "reqwest 0.12.14",
 "serde",
 "serde_json",
@@ -2983,19 +2984,6 @@ dependencies = [
 "cfg-if 1.0.0",
 ]

-[[package]]
-name = "interprocess"
-version = "2.2.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d941b405bd2322993887859a8ee6ac9134945a24ec5ec763a8a962fc64dfec2d"
-dependencies = [
- "doctest-file",
- "libc",
- "recvmsg",
- "widestring",
- "windows-sys 0.52.0",
-]
-
 [[package]]
 name = "inventory"
 version = "0.3.20"
@@ -3394,6 +3382,16 @@ version = "0.8.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"

+[[package]]
+name = "matrixmultiply"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
+dependencies = [
+ "autocfg",
+ "rawpointer",
+]
+
 [[package]]
 name = "memchr"
 version = "2.7.4"
@@ -3453,7 +3451,7 @@ dependencies = [
 "opentelemetry",
 "opentelemetry-prometheus",
 "prometheus",
-"rand",
+"rand 0.8.5",
 "reqwest 0.11.27",
 "serde",
 "serde_json",
@@ -3573,7 +3571,7 @@ dependencies = [
 [[package]]
 name = "mistralrs"
 version = "0.4.0"
-source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
+source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
 dependencies = [
 "anyhow",
 "candle-core",
@@ -3584,7 +3582,7 @@ dependencies = [
 "image",
 "indexmap 2.8.0",
 "mistralrs-core",
-"rand",
+"rand 0.8.5",
 "reqwest 0.12.14",
 "serde",
 "serde_json",
@@ -3594,7 +3592,7 @@ dependencies = [
 [[package]]
 name = "mistralrs-core"
 version = "0.4.0"
-source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
+source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
 dependencies = [
 "akin",
 "anyhow",
@@ -3622,7 +3620,6 @@ dependencies = [
 "image",
 "indexmap 2.8.0",
 "indicatif",
-"interprocess",
 "itertools 0.13.0",
 "llguidance",
 "lrtable",
@@ -3635,7 +3632,7 @@ dependencies = [
 "objc",
 "once_cell",
 "radix_trie",
-"rand",
+"rand 0.8.5",
 "rand_isaac",
 "rayon",
 "regex",
@@ -3645,7 +3642,6 @@ dependencies = [
 "safetensors",
 "schemars",
 "serde",
-"serde-big-array",
 "serde_json",
 "serde_plain",
 "serde_yaml",
@@ -3668,7 +3664,7 @@ dependencies = [
 [[package]]
 name = "mistralrs-paged-attn"
 version = "0.4.0"
-source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
+source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
 dependencies = [
 "anyhow",
 "bindgen_cuda 0.1.6",
@@ -3683,7 +3679,7 @@ dependencies = [
 [[package]]
 name = "mistralrs-quant"
 version = "0.4.0"
-source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
+source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
 dependencies = [
 "bindgen_cuda 0.1.5",
 "byteorder",
@@ -3709,7 +3705,7 @@ dependencies = [
 [[package]]
 name = "mistralrs-vision"
 version = "0.4.0"
-source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
+source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
 dependencies = [
 "candle-core",
 "image",
@@ -3759,6 +3755,21 @@ dependencies = [
 "tempfile",
 ]

+[[package]]
+name = "ndarray"
+version = "0.16.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
+dependencies = [
+ "matrixmultiply",
+ "num-complex",
+ "num-integer",
+ "num-traits",
+ "portable-atomic",
+ "portable-atomic-util",
+ "rawpointer",
+]
+
 [[package]]
 name = "neli"
 version = "0.6.5"
@@ -3874,7 +3885,7 @@ version = "3.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4abdf1789932b85dc39446e27f45a1064a30f9e19a2b872b1d09bd59283f85f3"
 dependencies = [
-"rand",
+"rand 0.8.5",
 "serde",
 "thiserror 1.0.69",
 ]
@@ -3913,7 +3924,7 @@ dependencies = [
 "ed25519-dalek",
 "getrandom 0.2.15",
 "log",
-"rand",
+"rand 0.8.5",
 "signatory",
 ]
@@ -3952,7 +3963,7 @@ version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fc895af95856f929163a0aa20c26a78d26bfdc839f51b9d5aa7a5b79e52b7e83"
 dependencies = [
-"rand",
+"rand 0.8.5",
 ]

 [[package]]
@@ -4087,9 +4098,9 @@ dependencies = [
 [[package]]
 name = "once_cell"
-version = "1.21.0"
+version = "1.21.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cde51589ab56b20a6f686b2c68f7a0bd6add753d697abf720d63f8db3ab7b1ad"
+checksum = "d75b0bedcc4fe52caa0e03d9f1151a323e4aa5e2d78ba3580400cd3c9e2bc4bc"

 [[package]]
 name = "onig"
@@ -4211,7 +4222,7 @@ dependencies = [
 "opentelemetry_api",
 "ordered-float",
 "percent-encoding",
-"rand",
+"rand 0.8.5",
 "regex",
 "thiserror 1.0.69",
 ]
@@ -4479,9 +4490,9 @@ dependencies = [
 [[package]]
 name = "prettyplease"
-version = "0.2.30"
+version = "0.2.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f1ccf34da56fc294e7d4ccf69a85992b7dfb826b7cf57bac6a70bba3494cc08a"
+checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb"
 dependencies = [
 "proc-macro2",
 "syn 2.0.100",
@@ -4566,8 +4577,8 @@ dependencies = [
 "bitflags 2.9.0",
 "lazy_static",
 "num-traits",
-"rand",
-"rand_chacha",
+"rand 0.8.5",
+"rand_chacha 0.3.1",
 "rand_xorshift",
 "regex-syntax 0.8.5",
 "rusty-fork",
@@ -4816,7 +4827,7 @@ checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
 dependencies = [
 "bytes",
 "getrandom 0.2.15",
-"rand",
+"rand 0.8.5",
 "ring",
 "rustc-hash 2.1.1",
 "rustls 0.23.23",
@@ -4868,8 +4879,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
 dependencies = [
 "libc",
-"rand_chacha",
-"rand_core",
+"rand_chacha 0.3.1",
+"rand_core 0.6.4",
 ]

+[[package]]
+name = "rand"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94"
+dependencies = [
+ "rand_chacha 0.9.0",
+ "rand_core 0.9.3",
+ "zerocopy 0.8.23",
+]
+
 [[package]]
@@ -4879,7 +4901,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
 dependencies = [
 "ppv-lite86",
-"rand_core",
+"rand_core 0.6.4",
 ]

+[[package]]
+name = "rand_chacha"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
+dependencies = [
+ "ppv-lite86",
+ "rand_core 0.9.3",
+]
+
 [[package]]
@@ -4891,6 +4923,15 @@ dependencies = [
 "getrandom 0.2.15",
 ]

+[[package]]
+name = "rand_core"
+version = "0.9.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
+dependencies = [
+ "getrandom 0.3.1",
+]
+
 [[package]]
 name = "rand_distr"
 version = "0.4.3"
@@ -4898,7 +4939,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
 dependencies = [
 "num-traits",
-"rand",
+"rand 0.8.5",
 ]

 [[package]]
@@ -4907,7 +4948,7 @@ version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fac4373cd91b4f55722c553fb0f286edbb81ef3ff6eec7b99d1898a4110a0b28"
 dependencies = [
-"rand_core",
+"rand_core 0.6.4",
 ]

 [[package]]
@@ -4916,7 +4957,7 @@ version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f"
 dependencies = [
-"rand_core",
+"rand_core 0.6.4",
 ]

 [[package]]
@@ -4928,6 +4969,12 @@ dependencies = [
 "bitflags 1.3.2",
 ]

+[[package]]
+name = "rawpointer"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
+
 [[package]]
 name = "rayon"
 version = "1.10.0"
@@ -4965,12 +5012,6 @@ version = "0.5.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"

-[[package]]
-name = "recvmsg"
-version = "1.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d3edd4d5d42c92f0a659926464d4cce56b562761267ecf0f469d85b7de384175"
-
 [[package]]
 name = "redox_syscall"
 version = "0.5.10"
@@ -5268,7 +5309,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14"
 dependencies = [
 "quote",
-"rand",
+"rand 0.8.5",
 "syn 2.0.100",
 ]
@@ -5639,15 +5680,6 @@ dependencies = [
 "serde_derive",
 ]

-[[package]]
-name = "serde-big-array"
-version = "0.5.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f"
-dependencies = [
- "serde",
-]
-
 [[package]]
 name = "serde-pickle"
 version = "1.2.0"
@@ -5769,17 +5801,6 @@ dependencies = [
 "unsafe-libyaml",
 ]

-[[package]]
-name = "service_metrics"
-version = "0.1.0"
-dependencies = [
- "dynamo-runtime",
- "futures",
- "serde",
- "serde_json",
- "tokio",
-]
-
 [[package]]
 name = "sha2"
 version = "0.10.8"
@@ -5849,7 +5870,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c1e303f8205714074f6068773f0e29527e0453937fe837c9717d066635b65f31"
 dependencies = [
 "pkcs8",
-"rand_core",
+"rand_core 0.6.4",
 "signature",
 "zeroize",
 ]
@@ -5861,7 +5882,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de"
 dependencies = [
 "digest",
-"rand_core",
+"rand_core 0.6.4",
 ]

 [[package]]
@@ -6346,7 +6367,7 @@ dependencies = [
 "monostate",
 "onig",
 "paste",
-"rand",
+"rand 0.8.5",
 "rayon",
 "rayon-cond",
 "regex",
@@ -6465,7 +6486,7 @@ dependencies = [
 "futures-sink",
 "http 1.3.1",
 "httparse",
-"rand",
+"rand 0.8.5",
 "ring",
 "rustls-native-certs 0.8.1",
 "rustls-pki-types",
@@ -6617,7 +6638,7 @@ dependencies = [
 "indexmap 1.9.3",
 "pin-project",
 "pin-project-lite",
-"rand",
+"rand 0.8.5",
 "slab",
 "tokio",
 "tokio-util",
@@ -7187,12 +7208,6 @@ dependencies = [
 "rustix 0.38.44",
 ]

-[[package]]
-name = "widestring"
-version = "1.1.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7219d36b6eac893fa81e84ebe06485e7dcbb616177469b142df14f1f4deb1311"
-
 [[package]]
 name = "winapi"
 version = "0.2.8"
......
@@ -136,7 +136,7 @@ dependencies = [
 "once_cell",
 "pin-project",
 "portable-atomic",
-"rand",
+"rand 0.8.5",
 "regex",
 "ring",
 "rustls-native-certs 0.7.3",
@@ -175,7 +175,7 @@ dependencies = [
 "derive_builder",
 "eventsource-stream",
 "futures",
-"rand",
+"rand 0.8.5",
 "reqwest",
 "reqwest-eventsource",
 "secrecy",
@@ -367,7 +367,7 @@ dependencies = [
 "getrandom 0.2.15",
 "instant",
 "pin-project-lite",
-"rand",
+"rand 0.8.5",
 "tokio",
 ]
@@ -400,9 +400,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
 [[package]]
 name = "base64ct"
-version = "1.7.2"
+version = "1.7.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8faa168b8c4ffca39c2699e772943af41ec2b75fb1683dda07b28a6d285c53dc"
+checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3"

 [[package]]
 name = "bindgen"
@@ -966,9 +966,11 @@ dependencies = [
 "bindgen",
 "blake3",
 "bs62",
+"bytemuck",
 "bytes",
 "chrono",
 "cmake",
+"derive-getters",
 "derive_builder",
 "dynamo-runtime",
 "either",
@@ -984,6 +986,8 @@ dependencies = [
 "pyo3",
 "pyo3-async-runtimes",
 "pythonize",
+"rand 0.9.0",
+"rayon",
 "regex",
 "semver",
 "serde",
@@ -1053,7 +1057,7 @@ dependencies = [
 "nuid",
 "once_cell",
 "prometheus",
-"rand",
+"rand 0.8.5",
 "regex",
 "serde",
 "serde_json",
@@ -1497,7 +1501,7 @@ dependencies = [
 "indicatif",
 "libc",
 "log",
-"rand",
+"rand 0.8.5",
 "serde",
 "serde_json",
 "thiserror 2.0.12",
@@ -2251,7 +2255,7 @@ version = "3.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4abdf1789932b85dc39446e27f45a1064a30f9e19a2b872b1d09bd59283f85f3"
 dependencies = [
-"rand",
+"rand 0.8.5",
 "serde",
 "thiserror 1.0.69",
 ]
@@ -2279,7 +2283,7 @@ dependencies = [
 "ed25519-dalek",
 "getrandom 0.2.15",
 "log",
-"rand",
+"rand 0.8.5",
 "signatory",
 ]
@@ -2309,7 +2313,7 @@ version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fc895af95856f929163a0aa20c26a78d26bfdc839f51b9d5aa7a5b79e52b7e83"
 dependencies = [
-"rand",
+"rand 0.8.5",
 ]

 [[package]]
@@ -2372,9 +2376,9 @@ dependencies = [
 [[package]]
 name = "once_cell"
-version = "1.21.0"
+version = "1.21.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cde51589ab56b20a6f686b2c68f7a0bd6add753d697abf720d63f8db3ab7b1ad"
+checksum = "d75b0bedcc4fe52caa0e03d9f1151a323e4aa5e2d78ba3580400cd3c9e2bc4bc"

 [[package]]
 name = "onig"
@@ -2570,9 +2574,9 @@ dependencies = [
 [[package]]
 name = "prettyplease"
-version = "0.2.30"
+version = "0.2.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f1ccf34da56fc294e7d4ccf69a85992b7dfb826b7cf57bac6a70bba3494cc08a"
+checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb"
 dependencies = [
 "proc-macro2",
 "syn 2.0.100",
@@ -2822,7 +2826,7 @@ checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
 dependencies = [
 "bytes",
 "getrandom 0.2.15",
-"rand",
+"rand 0.8.5",
 "ring",
 "rustc-hash 2.1.1",
 "rustls",
@@ -2864,8 +2868,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
 dependencies = [
 "libc",
-"rand_chacha",
-"rand_core",
+"rand_chacha 0.3.1",
+"rand_core 0.6.4",
 ]

+[[package]]
+name = "rand"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94"
+dependencies = [
+ "rand_chacha 0.9.0",
+ "rand_core 0.9.3",
+ "zerocopy",
+]
+
 [[package]]
@@ -2875,7 +2890,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
 dependencies = [
 "ppv-lite86",
-"rand_core",
+"rand_core 0.6.4",
 ]

+[[package]]
+name = "rand_chacha"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
+dependencies = [
+ "ppv-lite86",
+ "rand_core 0.9.3",
+]
+
 [[package]]
@@ -2887,6 +2912,15 @@ dependencies = [
 "getrandom 0.2.15",
 ]

+[[package]]
+name = "rand_core"
+version = "0.9.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
+dependencies = [
+ "getrandom 0.3.1",
+]
+
 [[package]]
 name = "rayon"
 version = "1.10.0"
@@ -3402,7 +3436,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c1e303f8205714074f6068773f0e29527e0453937fe837c9717d066635b65f31"
 dependencies = [
 "pkcs8",
-"rand_core",
+"rand_core 0.6.4",
 "signature",
 "zeroize",
 ]
@@ -3414,7 +3448,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de"
 dependencies = [
 "digest",
-"rand_core",
+"rand_core 0.6.4",
 ]

 [[package]]
@@ -3717,7 +3751,7 @@ dependencies = [
 "monostate",
 "onig",
 "paste",
-"rand",
+"rand 0.8.5",
 "rayon",
 "rayon-cond",
 "regex",
@@ -3806,7 +3840,7 @@ dependencies = [
 "futures-sink",
 "http",
 "httparse",
-"rand",
+"rand 0.8.5",
 "ring",
 "rustls-native-certs 0.8.1",
 "rustls-pki-types",
@@ -3931,7 +3965,7 @@ dependencies = [
 "indexmap 1.9.3",
 "pin-project",
 "pin-project-lite",
-"rand",
+"rand 0.8.5",
 "slab",
 "tokio",
 "tokio-util",
......
@@ -22,6 +22,7 @@ license.workspace = true
 homepage.workspace = true

 [features]
+default = []
 mistralrs = ["dep:mistralrs"]
 llamacpp = ["dep:llama-cpp-2"]
 sglang = ["dep:async_zmq"]
@@ -29,6 +30,7 @@ sentencepiece = ["dep:sentencepiece"]
 vllm = ["dep:async_zmq"]
 python = ["dep:pyo3-async-runtimes", "dep:pythonize"]
 trtllm = []
+cuda_kv = ["dep:cudarc", "dep:ndarray"]
 cuda = ["mistralrs/cuda", "llama-cpp-2/cuda"]
 metal = ["mistralrs/metal", "llama-cpp-2/metal"]
@@ -45,6 +47,7 @@ async-trait = { workspace = true }
 bytes = { workspace = true }
 derive_builder = { workspace = true }
 futures = { workspace = true }
 serde = { workspace = true }
 thiserror = { workspace = true }
 tokio = { workspace = true }
@@ -58,7 +61,18 @@ strum = { workspace = true }
 async-openai = "0.27.2"
 blake3 = "1"
+bytemuck = "1.22"
+derive-getters = "0.5"
+rand = "0.9"
 regex = "1"
+rayon = "1"
+
+# kv_cuda
+cudarc = { git = "https://github.com/coreylowman/cudarc.git", rev = "8c52e735b55bf8e979e1a16bd85e3dfe4f87c9fe", features = ["cuda-12040"], optional = true }
+ndarray = { version = "0.16", optional = true }
+# candle-core = { version = "0.8.3", features = ["cuda"], optional = true }
+# half = "2.4.1"
+
 pyo3 = { version = "0.23.3", default-features = false, features = [
 "macros",
 "experimental-async",
@@ -84,7 +98,7 @@ prometheus = { version = "0.13" }
 # mistralrs
 either = { version = "1.13" }
 indexmap = { version = "2.6" }
-mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "a691154bb", optional = true }
+mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "5e689c9", optional = true }

 # sglang
 async_zmq = { version = "0.4.0", optional = true }
......
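With the new optional cuda_kv feature wired in above, the CUDA block-copy path stays opt-in: a build that exercises it would look something like "cargo build --features cuda_kv" run against this crate (the exact package name and workspace layout are not shown in this excerpt).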
@@ -13,9 +13,125 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#[cfg(not(feature = "trtllm"))]
+#[cfg(not(feature = "cuda_kv"))]
 fn main() {}

+#[cfg(feature = "cuda_kv")]
+fn main() {
+    use std::{path::PathBuf, process::Command};
+
+    println!("cargo:rerun-if-changed=src/kernels/block_copy.cu");
+
+    // Check whether `nvcc` is on the PATH (via `which`, so Unix-like hosts only);
+    // if it is, derive the CUDA root from its location instead of CUDA_ROOT.
+    let nvcc = Command::new("which")
+        .arg("nvcc")
+        .output()
+        .ok()
+        .filter(|output| output.status.success());
+
+    let cuda_lib = if let Some(nvcc) = nvcc {
+        println!("cargo:info=nvcc found in path");
+        // Extract the CUDA root from the nvcc location by removing "bin/nvcc"
+        let nvcc_path = String::from_utf8_lossy(&nvcc.stdout).trim().to_string();
+        let path = PathBuf::from(nvcc_path);
+        if let Some(parent) = path.parent() {
+            // Remove "nvcc"
+            if let Some(cuda_root) = parent.parent() {
+                // Remove "bin"
+                cuda_root.to_string_lossy().to_string()
+            } else {
+                // Fall back to CUDA_ROOT or the default if path extraction fails
+                get_cuda_root_or_default()
+            }
+        } else {
+            // Fall back to CUDA_ROOT or the default if path extraction fails
+            get_cuda_root_or_default()
+        }
+    } else {
+        println!("cargo:warning=nvcc not found in path");
+        get_cuda_root_or_default()
+    };
+
+    println!("cargo:info=Using CUDA installation at: {}", cuda_lib);
+
+    let cuda_lib_path = PathBuf::from(&cuda_lib).join("lib64");
+    println!("cargo:info=Using CUDA libs: {}", cuda_lib_path.display());
+    println!("cargo:rustc-link-search=native={}", cuda_lib_path.display());
+
+    // Link against the CUDA runtime, driver, and device runtime libraries
+    println!("cargo:rustc-link-lib=dylib=cudart");
+    println!("cargo:rustc-link-lib=dylib=cuda");
+    println!("cargo:rustc-link-lib=dylib=cudadevrt");
+
+    // Make sure the CUDA libraries are found before other system libraries
+    println!(
+        "cargo:rustc-link-arg=-Wl,-rpath,{}",
+        cuda_lib_path.display()
+    );
+
+    // Create the kernels output directory if it doesn't exist
+    std::fs::create_dir_all("src/kernels").unwrap_or_else(|_| {
+        println!("Kernels directory already exists");
+    });
+
+    // Compile the CUDA kernel to an object file
+    let output = Command::new("nvcc")
+        .arg("src/kernels/block_copy.cu")
+        .arg("-O3")
+        .arg("--compiler-options")
+        .arg("-fPIC")
+        .arg("-o")
+        .arg("src/kernels/libblock_copy.o")
+        .arg("-c")
+        .output()
+        .expect("Failed to compile CUDA code");
+
+    if !output.status.success() {
+        panic!(
+            "Failed to compile CUDA kernel: {}",
+            String::from_utf8_lossy(&output.stderr)
+        );
+    }
+
+    // Archive the object file into a static library
+    #[cfg(target_os = "windows")]
+    {
+        Command::new("lib")
+            .arg("/OUT:src/kernels/block_copy.lib")
+            .arg("src/kernels/libblock_copy.o")
+            .output()
+            .expect("Failed to create static library");
+        println!("cargo:rustc-link-search=native=src/kernels");
+        println!("cargo:rustc-link-lib=static=block_copy");
+    }
+
+    #[cfg(not(target_os = "windows"))]
+    {
+        Command::new("ar")
+            .arg("rcs")
+            .arg("src/kernels/libblock_copy.a")
+            .arg("src/kernels/libblock_copy.o")
+            .output()
+            .expect("Failed to create static library");
+        println!("cargo:rustc-link-search=native=src/kernels");
+        println!("cargo:rustc-link-lib=static=block_copy");
+        println!("cargo:rustc-link-lib=dylib=cudart");
+        println!("cargo:rustc-link-lib=dylib=cuda");
+        println!("cargo:rustc-link-lib=dylib=cudadevrt");
+    }
+}
+
+#[cfg(feature = "cuda_kv")]
+fn get_cuda_root_or_default() -> String {
+    match std::env::var("CUDA_ROOT") {
+        Ok(path) => path,
+        Err(_) => {
+            // Default install locations by OS
+            if cfg!(target_os = "windows") {
+                "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8".to_string()
+            } else {
+                "/usr/local/cuda".to_string()
+            }
+        }
+    }
+}
+
 #[cfg(feature = "trtllm")]
 fn main() {
     extern crate bindgen;
......
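The build script above links the compiled kernels as a static library named block_copy; the Rust side then needs matching extern "C" declarations. A minimal sketch of what those bindings could look like, assuming the symbol names from the extern "C" block in the CUDA source below (the actual Rust module is not part of this excerpt):

use std::os::raw::{c_int, c_void};

// Opaque handle to the C++ CopyStream defined in block_copy.cu.
#[repr(C)]
pub struct CopyStream {
    _private: [u8; 0],
}

extern "C" {
    pub fn copy_stream_create(cs: *mut *mut CopyStream, num_layers: c_int, num_blocks: c_int) -> c_int;
    pub fn copy_stream_destroy(cs: *mut CopyStream) -> c_int;
    pub fn copy_stream_prepare_block_ids(
        cs: *mut CopyStream,
        src_block_ids: *mut c_int,
        dst_block_ids: *mut c_int,
        num_blocks: c_int,
    ) -> c_int;
    pub fn copy_stream_memcpy(
        cs: *mut CopyStream,
        src_data: *const c_void,
        dst_data: *mut c_void,
        prefix_dim: c_int,
        suffix_dim: c_int,
        elem_size: c_int,
        src_block_dim: c_int,
        dst_block_dim: c_int,
    ) -> c_int;
    pub fn copy_stream_sync(cs: *mut CopyStream) -> c_int;
}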
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <stdint.h>
#include <stdio.h>
#include <cstring>
#include <memory>
#include <vector>
// Error checking macro
#define CUDA_CHECK(call) \
do { \
cudaError_t error = call; \
if (error != cudaSuccess) { \
fprintf(stderr, "CUDA error at %s:%d - %s\n", __FILE__, __LINE__, cudaGetErrorString(error)); \
return error; \
} \
} while (0)
// Number of elements to process per thread
#define ELEMENTS_PER_THREAD 4
// Use cache-line sized chunks when possible
#define CACHE_LINE_SIZE 128 // 128 bytes for most GPUs
// Optimized kernel that processes elements in a dimension-aware manner
__global__ void
copy_blocks_kernel(
const void* src_data, void* dst_data, const int* src_block_ids, const int* dst_block_ids, int num_block_pairs,
int prefix_dim, int suffix_dim, int elem_size, size_t src_prefix_stride, size_t src_block_stride,
size_t src_suffix_stride, size_t dst_prefix_stride, size_t dst_block_stride, size_t dst_suffix_stride)
{
// Calculate the total number of elements to process
const size_t total_elements = (size_t)prefix_dim * num_block_pairs * suffix_dim;
// Calculate the total number of bytes in the suffix part
const size_t bytes_per_suffix = (size_t)suffix_dim * elem_size;
// Calculate how many cache-line sized chunks per suffix part
const size_t chunks_per_suffix = (bytes_per_suffix + CACHE_LINE_SIZE - 1) / CACHE_LINE_SIZE;
const size_t elements_per_chunk = CACHE_LINE_SIZE / elem_size;
const bool is_perfect_chunk = (bytes_per_suffix % CACHE_LINE_SIZE) == 0;
// Get global thread index
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
// Each thread processes ELEMENTS_PER_THREAD chunk indices
    // Cast before multiplying so large tensors don't overflow 32-bit arithmetic
    const size_t start_chunk = (size_t)thread_idx * ELEMENTS_PER_THREAD;
    const size_t total_chunks = (size_t)prefix_dim * num_block_pairs * chunks_per_suffix;
// Early exit if completely out of range
if (start_chunk >= total_chunks) {
return;
}
// Process multiple chunks per thread
for (int chunk_offset = 0; chunk_offset < ELEMENTS_PER_THREAD; chunk_offset++) {
// Current chunk index
size_t chunk_idx = start_chunk + chunk_offset;
// Check if this chunk is within bounds
if (chunk_idx >= total_chunks) {
return; // No more chunks to process
}
// Decompose chunk index into prefix, block, and suffix chunks
size_t blocks_chunks = num_block_pairs * chunks_per_suffix;
size_t prefix_idx = chunk_idx / blocks_chunks;
size_t remainder = chunk_idx % blocks_chunks;
size_t block_pair_idx = remainder / chunks_per_suffix;
size_t chunk_in_suffix = remainder % chunks_per_suffix;
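        // Worked example (illustrative numbers): with prefix_dim = 2,
        // num_block_pairs = 3 and chunks_per_suffix = 4, blocks_chunks = 12,
        // so chunk_idx = 17 decomposes to prefix_idx = 1, remainder = 5,
        // block_pair_idx = 1, chunk_in_suffix = 1.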
// Bounds check
if (prefix_idx >= prefix_dim || block_pair_idx >= num_block_pairs) {
continue; // Skip this chunk
}
// Get the actual source and destination block IDs
int src_block_id = src_block_ids[block_pair_idx];
int dst_block_id = dst_block_ids[block_pair_idx];
// Calculate element offset within the suffix dimension
size_t suffix_elem_offset = chunk_in_suffix * CACHE_LINE_SIZE / elem_size;
// Calculate the byte offset using explicit strides for each dimension
size_t src_byte_offset =
prefix_idx * src_prefix_stride + src_block_id * src_block_stride + suffix_elem_offset * src_suffix_stride;
size_t dst_byte_offset =
prefix_idx * dst_prefix_stride + dst_block_id * dst_block_stride + suffix_elem_offset * dst_suffix_stride;
// Calculate elements to copy in this chunk
size_t elements_to_copy = elements_per_chunk;
if (!is_perfect_chunk && chunk_in_suffix == chunks_per_suffix - 1) {
// Last chunk might be smaller
elements_to_copy = suffix_dim - suffix_elem_offset;
}
// Copy data based on element size for better performance
if (elem_size == 2 && (elements_to_copy % 2 == 0)) {
// Use 32-bit loads/stores for 16-bit data when possible (half precision)
const uint32_t* src_ptr = (const uint32_t*)((const char*)src_data + src_byte_offset);
uint32_t* dst_ptr = (uint32_t*)((char*)dst_data + dst_byte_offset);
for (size_t i = 0; i < elements_to_copy / 2; i++) {
dst_ptr[i] = src_ptr[i];
}
// } else if (elem_size == 1 && (elements_to_copy % 4 == 0)) {
        // // Use 32-bit loads/stores for 8-bit data when possible
// const uint32_t* src_ptr = (const uint32_t*)((const char*)src_data + src_byte_offset);
// uint32_t* dst_ptr = (uint32_t*)((char*)dst_data + dst_byte_offset);
// for (size_t i = 0; i < elements_to_copy / 4; i++) {
// dst_ptr[i] = src_ptr[i];
// }
} else if (elem_size == 2) {
// Handle 16-bit elements one by one if necessary
const uint16_t* src_ptr = (const uint16_t*)((const char*)src_data + src_byte_offset);
uint16_t* dst_ptr = (uint16_t*)((char*)dst_data + dst_byte_offset);
for (size_t i = 0; i < elements_to_copy; i++) {
dst_ptr[i] = src_ptr[i];
}
} else if (elem_size == 4) {
// Copy 32-bit elements (float, int32)
const uint32_t* src_ptr = (const uint32_t*)((const char*)src_data + src_byte_offset);
uint32_t* dst_ptr = (uint32_t*)((char*)dst_data + dst_byte_offset);
for (size_t i = 0; i < elements_to_copy; i++) {
dst_ptr[i] = src_ptr[i];
}
} else if (elem_size == 8) {
// Copy 64-bit elements (double, int64)
const uint64_t* src_ptr = (const uint64_t*)((const char*)src_data + src_byte_offset);
uint64_t* dst_ptr = (uint64_t*)((char*)dst_data + dst_byte_offset);
for (size_t i = 0; i < elements_to_copy; i++) {
dst_ptr[i] = src_ptr[i];
}
} else {
// For other element sizes, copy byte by byte
const char* src_ptr = (const char*)src_data + src_byte_offset;
char* dst_ptr = (char*)dst_data + dst_byte_offset;
for (size_t i = 0; i < elements_to_copy * elem_size; i++) {
dst_ptr[i] = src_ptr[i];
}
}
}
}
// Simplified launcher that uses the 3D tensor view
extern "C" cudaError_t
copy_blocks_launcher_3d(
const void* src_data, void* dst_data, const int* d_src_block_ids, const int* d_dst_block_ids, int num_block_pairs,
int prefix_dim, int suffix_dim, int elem_size, int src_block_dim, int dst_block_dim, cudaStream_t stream)
{
// Validate inputs
if (src_data == NULL || dst_data == NULL) {
fprintf(stderr, "NULL data pointers\n");
return cudaErrorInvalidValue;
}
if (d_src_block_ids == NULL || d_dst_block_ids == NULL) {
fprintf(stderr, "NULL device block ID pointers\n");
return cudaErrorInvalidValue;
}
if (num_block_pairs <= 0) {
fprintf(stderr, "Invalid number of block pairs: %d\n", num_block_pairs);
return cudaErrorInvalidValue;
}
if (prefix_dim <= 0 || suffix_dim <= 0 || elem_size <= 0) {
fprintf(stderr, "Invalid dimensions: prefix=%d, suffix=%d, elem=%d\n", prefix_dim, suffix_dim, elem_size);
return cudaErrorInvalidValue;
}
// Calculate row-major strides internally
size_t src_suffix_stride = elem_size;
size_t dst_suffix_stride = elem_size;
size_t src_block_stride = suffix_dim * src_suffix_stride;
size_t dst_block_stride = suffix_dim * dst_suffix_stride;
size_t src_prefix_stride = src_block_dim * src_block_stride;
size_t dst_prefix_stride = dst_block_dim * dst_block_stride;
// // Optional debug output
// printf(
// "Tensor dims: prefix=%d, src_blocks=%d, dst_blocks=%d, suffix=%d, elem_size=%d\n", prefix_dim, src_blocks_dim,
// dst_blocks_dim, suffix_dim, elem_size);
// printf(
// "Calculated strides: src_prefix=%zu, src_block=%zu, src_suffix=%zu\n", src_prefix_stride, src_block_stride,
// src_suffix_stride);
// Calculate total number of bytes to copy
size_t total_bytes = (size_t)prefix_dim * num_block_pairs * suffix_dim * elem_size;
// Calculate number of cache-line sized chunks
size_t bytes_per_suffix = (size_t)suffix_dim * elem_size;
size_t chunks_per_suffix = (bytes_per_suffix + CACHE_LINE_SIZE - 1) / CACHE_LINE_SIZE;
    size_t total_chunks = (size_t)prefix_dim * num_block_pairs * chunks_per_suffix;
// Adjust grid size to account for multiple elements per thread
int total_threads = (total_chunks + ELEMENTS_PER_THREAD - 1) / ELEMENTS_PER_THREAD;
int cuda_block_size = 256;
int grid_size = (total_threads + cuda_block_size - 1) / cuda_block_size;
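    // Example (illustrative numbers): 100000 chunks with ELEMENTS_PER_THREAD = 4
    // need 25000 threads, i.e. ceil(25000 / 256) = 98 blocks of 256 threads each.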
// Validate grid size
if (grid_size <= 0) {
fprintf(stderr, "Invalid grid size: %d\n", grid_size);
return cudaErrorInvalidValue;
}
// Launch kernel on specified stream
copy_blocks_kernel<<<grid_size, cuda_block_size, 0, stream>>>(
src_data, dst_data, d_src_block_ids, d_dst_block_ids, num_block_pairs, prefix_dim, suffix_dim, elem_size,
src_prefix_stride, src_block_stride, src_suffix_stride, dst_prefix_stride, dst_block_stride, dst_suffix_stride);
// Check for kernel launch errors immediately
cudaError_t kernel_error = cudaGetLastError();
if (kernel_error != cudaSuccess) {
fprintf(stderr, "Kernel execution error: %s\n", cudaGetErrorString(kernel_error));
return kernel_error;
}
return cudaSuccess;
}
extern "C" cudaError_t
copy_blocks_memcpy_3d(
const void* src_data, void* dst_data, const int* h_src_block_ids, const int* h_dst_block_ids, int num_block_pairs,
int prefix_dim, int suffix_dim, int elem_size, int src_block_dim, int dst_block_dim, cudaStream_t stream)
{
// Validate inputs
if (src_data == NULL || dst_data == NULL) {
fprintf(stderr, "NULL data pointers\n");
return cudaErrorInvalidValue;
}
if (h_src_block_ids == NULL || h_dst_block_ids == NULL) {
fprintf(stderr, "NULL host block ID pointers\n");
return cudaErrorInvalidValue;
}
if (num_block_pairs <= 0) {
fprintf(stderr, "Invalid number of block pairs: %d\n", num_block_pairs);
return cudaErrorInvalidValue;
}
if (prefix_dim <= 0 || suffix_dim <= 0 || elem_size <= 0) {
fprintf(stderr, "Invalid dimensions: prefix=%d, suffix=%d, elem=%d\n", prefix_dim, suffix_dim, elem_size);
return cudaErrorInvalidValue;
}
// Calculate row-major strides for source and destination
    size_t suffix_size_bytes = (size_t)suffix_dim * elem_size;
size_t src_block_stride = suffix_size_bytes;
size_t dst_block_stride = suffix_size_bytes;
size_t src_prefix_stride = src_block_dim * src_block_stride;
size_t dst_prefix_stride = dst_block_dim * dst_block_stride;
size_t count = 0;
// Loop through all prefix dimensions and block pairs
for (int prefix_idx = 0; prefix_idx < prefix_dim; prefix_idx++) {
for (int pair_idx = 0; pair_idx < num_block_pairs; pair_idx++) {
int src_block_id = h_src_block_ids[pair_idx];
int dst_block_id = h_dst_block_ids[pair_idx];
// Calculate byte offsets
size_t src_offset = prefix_idx * src_prefix_stride + src_block_id * src_block_stride;
size_t dst_offset = prefix_idx * dst_prefix_stride + dst_block_id * dst_block_stride;
// Copy the suffix data in one call (it's contiguous)
const void* src_ptr = static_cast<const char*>(src_data) + src_offset;
void* dst_ptr = static_cast<char*>(dst_data) + dst_offset;
cudaError_t error = cudaMemcpyAsync(dst_ptr, src_ptr, suffix_size_bytes, cudaMemcpyDefault, stream);
if (error != cudaSuccess) {
return error;
}
count += suffix_size_bytes;
}
}
return cudaSuccess;
}
// New function for 3D tensor copy blocks operation
extern "C" cudaError_t
copy_blocks_3d(
const void* src_data, void* dst_data, const int* h_src_block_ids, const int* h_dst_block_ids, int num_block_pairs,
int prefix_dim, int src_blocks_dim, int dst_blocks_dim, int suffix_dim, int elem_size)
{
#ifdef USE_KERNEL
// Allocate device memory for block IDs
int* d_src_block_ids = NULL;
int* d_dst_block_ids = NULL;
CUDA_CHECK(cudaMalloc(&d_src_block_ids, num_block_pairs * sizeof(int)));
CUDA_CHECK(cudaMalloc(&d_dst_block_ids, num_block_pairs * sizeof(int)));
CUDA_CHECK(
cudaMemcpyAsync(d_src_block_ids, h_src_block_ids, num_block_pairs * sizeof(int), cudaMemcpyHostToDevice, 0));
CUDA_CHECK(
cudaMemcpyAsync(d_dst_block_ids, h_dst_block_ids, num_block_pairs * sizeof(int), cudaMemcpyHostToDevice, 0));
// Launch kernel with explicit strides
cudaError_t result = copy_blocks_launcher_3d(
src_data, dst_data, d_src_block_ids, d_dst_block_ids, num_block_pairs, prefix_dim, suffix_dim, elem_size,
src_blocks_dim, dst_blocks_dim, 0);
// Handle errors from kernel launch
if (result != cudaSuccess) {
cudaFree(d_src_block_ids);
cudaFree(d_dst_block_ids);
return result;
}
#else
    cudaError_t result = copy_blocks_memcpy_3d(
        src_data, dst_data, h_src_block_ids, h_dst_block_ids, num_block_pairs, prefix_dim, suffix_dim, elem_size,
        src_blocks_dim, dst_blocks_dim, 0);

    // Propagate errors from the memcpy path as well
    if (result != cudaSuccess) {
        return result;
    }
#endif
// Wait for completion
CUDA_CHECK(cudaStreamSynchronize(0));
#ifdef USE_KERNEL
// Clean up
cudaFree(d_src_block_ids);
cudaFree(d_dst_block_ids);
#endif
return cudaSuccess;
}
// TODO: Refactor the driver code to take pointers for the device block_id arrays
// TODO: Maintain a blocking driver, but also provide a non-blocking driver
//
// We will have N copies of the CopyStream struct which we will put in a reusable
// pool. Acquiring a CopyStream will let you perform a copy for a KV attention layer.
//
// From Rust or Python we'll execute this on a thread that is allowed to block. We'll
// await the CUDA event for completion and report the return code on the driver.
//
// TODO: decide whether we need a pool of streams or a single stream.
//
// We should be able to decouple this from the forward pass. The only condition is that
// a new forward pass cannot start until the last copy has completed.
//
// To that end, we might want to tie this copy kernel to the stream used for the forward pass.
struct CopyStream {
// Device block arrays
int* d_src_blocks;
int* d_dst_blocks;
// Host copies of block arrays
int* h_src_blocks;
int* h_dst_blocks;
int num_blocks;
cudaStream_t stream;
cudaEvent_t start_event;
cudaEvent_t stop_event;
CopyStream(int num_layers, int num_blocks);
~CopyStream();
void reset();
};
CopyStream::CopyStream(int num_layers, int num_blocks)
{
    (void)num_layers;  // reserved for future per-layer bookkeeping

    // Null everything first so the destructor is safe even after a
    // partially failed construction.
    d_src_blocks = nullptr;
    d_dst_blocks = nullptr;
    h_src_blocks = nullptr;
    h_dst_blocks = nullptr;
    this->num_blocks = 0;
    stream = nullptr;
    start_event = nullptr;
    stop_event = nullptr;

    cudaError_t status;

    // Allocate device memory; on failure, return early and let the
    // destructor release whatever was acquired so far.
    status = cudaMalloc(&d_src_blocks, num_blocks * sizeof(int));
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(status));
        return;
    }
    status = cudaMalloc(&d_dst_blocks, num_blocks * sizeof(int));
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(status));
        return;
    }

    // Allocate host memory
    h_src_blocks = (int*)malloc(num_blocks * sizeof(int));
    h_dst_blocks = (int*)malloc(num_blocks * sizeof(int));
    if (!h_src_blocks || !h_dst_blocks) {
        fprintf(stderr, "Host memory allocation failed\n");
        return;
    }

    status = cudaStreamCreate(&stream);
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(status));
        return;
    }

    // Create events
    status = cudaEventCreateWithFlags(&start_event, cudaEventDisableTiming);
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(status));
        return;
    }
    status = cudaEventCreateWithFlags(&stop_event, cudaEventDisableTiming);
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(status));
        return;
    }
}
CopyStream::~CopyStream()
{
    free(h_src_blocks);
    free(h_dst_blocks);
    cudaFree(d_src_blocks);
    cudaFree(d_dst_blocks);
    if (start_event) {
        cudaEventDestroy(start_event);
    }
    if (stop_event) {
        cudaEventDestroy(stop_event);
    }
    if (stream) {
        cudaStreamDestroy(stream);
    }
}
extern "C" {
int cuda_malloc_host(void** ptr, size_t size);
int cuda_free_host(void* ptr);
int cuda_memcpy_async(void* dst, const void* src, size_t count, cudaStream_t stream);
int
copy_stream_create(CopyStream** stream, int num_layers, int num_blocks)
{
*stream = new CopyStream(num_layers, num_blocks);
return 0;
}
int
copy_stream_destroy(CopyStream* stream)
{
delete stream;
return 0;
}
int
copy_stream_prepare_block_ids(CopyStream* cs, int* src_block_ids, int* dst_block_ids, int num_blocks)
{
    // Make host copies (assumes num_blocks does not exceed the capacity
    // allocated in copy_stream_create)
memcpy(cs->h_src_blocks, src_block_ids, num_blocks * sizeof(int));
memcpy(cs->h_dst_blocks, dst_block_ids, num_blocks * sizeof(int));
// Copy to device (for kernel-based implementation)
CUDA_CHECK(
cudaMemcpyAsync(cs->d_src_blocks, src_block_ids, num_blocks * sizeof(int), cudaMemcpyHostToDevice, cs->stream));
CUDA_CHECK(
cudaMemcpyAsync(cs->d_dst_blocks, dst_block_ids, num_blocks * sizeof(int), cudaMemcpyHostToDevice, cs->stream));
cs->num_blocks = num_blocks;
return 0;
}
int
copy_stream_launch(
CopyStream* cs, const void* src_data, void* dst_data, int prefix_dim, int suffix_dim, int elem_size,
int src_block_dim, int dst_block_dim)
{
return copy_blocks_launcher_3d(
src_data, dst_data, cs->d_src_blocks, cs->d_dst_blocks, cs->num_blocks, prefix_dim, suffix_dim, elem_size,
src_block_dim, dst_block_dim, cs->stream);
}
int
copy_stream_memcpy(
CopyStream* cs, const void* src_data, void* dst_data, int prefix_dim, int suffix_dim, int elem_size,
int src_block_dim, int dst_block_dim)
{
return copy_blocks_memcpy_3d(
src_data, dst_data, cs->h_src_blocks, cs->h_dst_blocks, cs->num_blocks, prefix_dim, suffix_dim, elem_size,
src_block_dim, dst_block_dim, cs->stream);
}
int
copy_stream_sync(CopyStream* cs)
{
// synchronize on the stream
CUDA_CHECK(cudaStreamSynchronize(cs->stream));
return cudaSuccess;
}
int
cuda_malloc_host(void** ptr, size_t size)
{
CUDA_CHECK(cudaHostAlloc(ptr, size, cudaHostAllocDefault));
return cudaSuccess;
}
int
cuda_free_host(void* ptr)
{
CUDA_CHECK(cudaFreeHost(ptr));
return cudaSuccess;
}
int
cuda_memcpy_async(void* dst, const void* src, size_t count, cudaStream_t stream)
{
CUDA_CHECK(cudaMemcpyAsync(dst, src, count, cudaMemcpyDefault, stream));
return cudaSuccess;
}
int
cuda_memcpy_sync(void* dst, const void* src, size_t count)
{
CUDA_CHECK(cudaMemcpy(dst, src, count, cudaMemcpyDefault));
return cudaSuccess;
}
}
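// Typical call sequence for the C API above (a sketch; error handling omitted):
//
//   CopyStream* cs = nullptr;
//   copy_stream_create(&cs, num_layers, num_blocks);
//   copy_stream_prepare_block_ids(cs, src_ids, dst_ids, num_pairs);
//   copy_stream_memcpy(cs, src, dst, prefix_dim, suffix_dim, elem_size,
//                      src_block_dim, dst_block_dim);  // one call per layer
//   copy_stream_sync(cs);
//   copy_stream_destroy(cs);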
/// This accepts a 6D tensor with dimensions that represent a tensor to be distributed
/// across tensor parallel ranks.
///
/// The dimensions of the source tensor are expected to be:
/// dims[0]: kv or block (depending on KvLayout)
/// dims[1]: block or kv (depending on KvLayout)
/// dims[2]: block_size (sequence length) # aka bs
/// dims[3]: scatter_factor (dst_tp_size / src_tp_size)
/// dims[4]: num_heads / (src_tp_size * scatter_factor) # aka dst_num_heads or dnh
/// dims[5]: head_size # aka hs
///
/// The permutation applied is (3, 0, 1, 2, 4, 5) which transforms
/// the tensor:
/// - from: [kv/block, block/kv, bs, scatter_factor, dnh, hs] to
/// - to: [scatter_factor, kv/block, block/kv, bs, dnh, hs].
///
/// This transformation effectively distributes the heads dimension across
/// tensor parallel ranks, where we transform from src_tp_size to dst_tp_size,
/// with dst_tp_size > src_tp_size.
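//
// Worked example (illustrative sizes): for dims = [2, 8, 4, 2, 2, 64]
// (kv, blocks, block_size, scatter_factor, dst_num_heads, head_size) with
// block_dim_index = 1, each inner copy moves dnh * hs = 2 * 64 elements at a
// time, and the element at src index [kv, block, bs, scatter, ...] lands at
// dst index [scatter, kv, block, bs, ...].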
int
permute_scatter_memcpy(
const void* src, // source data
void* dst, // destination data
const uint32_t* dims, // 6d dimensions of source tensor
uint32_t num_dims, // semi-redundant, size of the dims array, must be 6
uint32_t elem_size, // element size in bytes
uint32_t block_dim_index, // which dimension represents blocks
uint32_t src_block_dim, // the dimension of the source blocks
uint32_t dst_block_dim, // the dimension of the destination blocks
int* src_block_ids, // from state: the block IDs to copy
int* dst_block_ids, // from state: the block IDs to copy
uint32_t num_blocks, // from state: the number of blocks to copy
cudaStream_t stream // from state: the stream to use
)
{
if (num_dims != 6) {
printf("ERROR: num_dims must be 6\n");
return -1;
}
if (block_dim_index != 0 && block_dim_index != 1) {
printf("ERROR: block_dim_index must be 0 or 1\n");
return -2;
}
uint32_t kv_dim_index = block_dim_index == 0 ? 1 : 0;
// expect dims[block_dim_index] == src_block_dim
// expect dims[kv_dim_index] == 2
if (dims[block_dim_index] != src_block_dim) {
printf("ERROR: dims[block_dim_index] must be equal to src_block_dim\n");
return -3;
}
if (dims[kv_dim_index] != 2) {
printf("ERROR: dims[kv_dim_index] must be 2\n");
return -4;
}
size_t src_shape[5];
size_t dst_shape[5];
src_shape[block_dim_index] = src_block_dim;
src_shape[kv_dim_index] = dims[kv_dim_index];
src_shape[2] = dims[2];
src_shape[3] = dims[3];
src_shape[4] = dims[4] * dims[5];
dst_shape[0] = dims[3]; // scatter factor
dst_shape[block_dim_index + 1] = dst_block_dim;
dst_shape[kv_dim_index + 1] = dims[kv_dim_index];
dst_shape[3] = dims[2]; // block size
dst_shape[4] = dims[4] * dims[5];
size_t src_strides[5];
size_t dst_strides[5];
src_strides[4] = elem_size;
dst_strides[4] = elem_size;
// Compute source strides (row-major order, innermost to outermost)
for (int i = 3; i >= 0; i--) {
src_strides[i] = src_strides[i + 1] * src_shape[i + 1];
}
// Compute destination strides based on permuted dimensions
for (int i = 3; i >= 0; i--) {
dst_strides[i] = dst_strides[i + 1] * dst_shape[i + 1];
}
#ifdef DEBUG
printf("src_shape: ");
for (int i = 0; i < 5; i++) {
printf("%zu ", src_shape[i]);
}
printf("\n");
printf("src_strides: ");
for (int i = 0; i < 5; i++) {
printf("%zu ", src_strides[i]);
}
printf("\n");
printf("dst_shape: ");
for (int i = 0; i < 5; i++) {
printf("%zu ", dst_shape[i]);
}
printf("\n");
printf("dst_strides: ");
for (int i = 0; i < 5; i++) {
printf("%zu ", dst_strides[i]);
}
printf("\n");
#endif
size_t copy_size_bytes = dims[4] * dims[5] * elem_size;
// we start by computing the full offset for each inner copy block
size_t src_idx[5];
size_t dst_idx[5];
// notes:
// - in the outer two loops, the index for the dst is shifted by one since we moved the
// scatter dimension to the front [0]
const char* src_ptr = (const char*)src;
char* dst_ptr = (char*)dst;
// loop over blocks
for (int block = 0; block < num_blocks; block++) {
src_idx[block_dim_index] = src_block_ids[block];
dst_idx[block_dim_index + 1] = dst_block_ids[block];
// loop over the kv dimension
for (int kv = 0; kv < src_shape[kv_dim_index]; kv++) {
src_idx[kv_dim_index] = kv;
dst_idx[kv_dim_index + 1] = kv;
// loop over block size
for (int block_size = 0; block_size < src_shape[2]; block_size++) {
src_idx[2] = block_size;
dst_idx[3] = block_size;
// loop over scatter factor
for (int scatter = 0; scatter < src_shape[3]; scatter++) {
src_idx[3] = scatter;
dst_idx[0] = scatter;
src_idx[4] = 0;
dst_idx[4] = 0;
size_t src_offset = 0;
size_t dst_offset = 0;
for (int i = 0; i < 5; i++) {
src_offset += src_idx[i] * src_strides[i];
dst_offset += dst_idx[i] * dst_strides[i];
}
auto rc =
cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset, copy_size_bytes, cudaMemcpyDefault, stream);
if (rc != cudaSuccess) {
printf("ERROR: cudaMemcpyAsync failed with error code %d\n", rc);
return -5;
}
}
}
}
}
return 0;
}
// C API wrapper for the permutation function
extern "C" int
copy_stream_scatter(
CopyStream* cs, // the copy stream
const void* src_data, // the source data (single layer)
void* dst_data, // the destination data (single layer)
const uint32_t* dims, // 6d dimensions of source tensor
uint32_t num_dims, // semi-redundant, size of the dims array, must be 6
uint32_t elem_size, // element size in bytes
uint32_t block_dim_index, // which dimension represents blocks; either 0 or 1
uint32_t src_block_dim, // number of blocks in the src tensor (should match dims[block_dim_index])
uint32_t dst_block_dim // number of blocks in the dst tensor
)
{
return permute_scatter_memcpy(
src_data, //
dst_data, //
dims, //
num_dims, //
elem_size, //
block_dim_index, //
src_block_dim, //
dst_block_dim, //
cs->h_src_blocks, //
cs->h_dst_blocks, //
cs->num_blocks, //
cs->stream //
);
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod layer;
pub mod manager;
pub mod reserved;
pub mod reuse;
pub mod sequence;
pub mod storage;
use reserved::*;
use std::{
collections::{BTreeMap, HashMap, VecDeque},
sync::{atomic::AtomicU64, Arc, RwLock},
};
use async_trait::async_trait;
use derive_getters::Dissolve;
use dynamo_runtime::{
raise,
utils::pool::{PoolExt, PoolItem, PoolValue, Returnable, SharedPoolItem},
Result,
};
use crate::tokens::{PartialTokenBlock, SequenceHash, TokenBlock, Tokens};
use tracing as log;
pub type UniqueBlock = PoolItem<KvBlock>;
pub type SharedBlock = SharedPoolItem<KvBlock>;
#[derive(Default)]
pub struct KvBlock {
token_block: TokenBlock,
priority: u32,
return_tick: u64,
}
// pub struct KvStorage {
// data: u64,
// size: usize,
// layer_idx: usize,
// block_idx: usize,
// /// The layout of the tensor
// layout: layer::KvLayer,
// }
impl KvBlock {
/// Creates a new KvBlock with the given token block
pub fn new(token_block: TokenBlock) -> Self {
Self {
token_block,
priority: 0,
return_tick: 0,
// storage: None,
}
}
/// Updates the token block
pub fn update_token_block(&mut self, token_block: TokenBlock) {
self.token_block = token_block;
}
/// Resets the block to its initial state
pub(crate) fn reset(&mut self) {
self.token_block = TokenBlock::default();
self.priority = 0;
self.return_tick = 0;
// self.storage = None;
// self.storage_state = StorageState::Absent;
}
}
impl Returnable for KvBlock {
fn on_return(&mut self) {}
}
pub struct KvBlockConfig {}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! While KV blocks can be formed from any memory allocation, we highly encourage you to
//! use large slabs of pinned or device memory.
//!
//! The primary reason for this is to map efficiently onto the RDMA transport layers, which
//! perform better and have less overhead when using a few large regions of registered
//! memory rather than many smaller regions.
//!
//! To this end, if you use the BYO-memory option, we encourage you to allocate either
//! a single large tensor or a set of tensors, one per layer, to map effectively onto the
//! NIXL dataplane.
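//!
//! A minimal sketch of the allocation flow using the builders defined below
//! (sizes are illustrative):
//!
//! ```ignore
//! let model = KvModelDetailsBuilder::default()
//!     .number_of_layers(32)
//!     .number_of_heads(8)
//!     .head_size(128)
//!     .dtype(DType::F32)
//!     .build()?;
//! let details = KvBlockDetailsBuilder::default()
//!     .layout(KvLayout::BlockFirst)
//!     .block_size(64)
//!     .model_details(model)
//!     .build()?;
//! // one large pinned slab is allocated per layer
//! let blocks = KvBlockStorage::allocate(64, details, StorageType::Pinned)?;
//! ```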
use derive_builder::Builder;
use dynamo_runtime::{error, raise, utils::pool::Returnable, ErrorContext, Result};
use std::{ptr::NonNull, sync::Arc};
use validator::{Validate, ValidationError};
use super::storage::{DType, OwnedStorage, Storage, StorageType, TensorView};
extern "C" {
fn copy_blocks_3d(
src_data: *const std::ffi::c_void,
dst_data: *mut std::ffi::c_void,
h_src_block_ids: *const std::os::raw::c_int,
h_dst_block_ids: *const std::os::raw::c_int,
num_block_pairs: std::os::raw::c_int,
prefix_dim: std::os::raw::c_int,
src_blocks: std::os::raw::c_int,
dst_blocks: std::os::raw::c_int,
suffix_dim: std::os::raw::c_int,
elem_size: std::os::raw::c_int,
) -> std::os::raw::c_int;
fn copy_stream_create(
stream: *mut *mut std::ffi::c_void,
num_layers: std::os::raw::c_int,
num_blocks: std::os::raw::c_int,
) -> std::os::raw::c_int;
fn copy_stream_prepare_block_ids(
cs: *mut std::ffi::c_void,
src_block_ids: *const std::os::raw::c_int,
dst_block_ids: *const std::os::raw::c_int,
num_block_pairs: std::os::raw::c_int,
) -> std::os::raw::c_int;
#[allow(unused)]
fn copy_stream_launch(
cs: *mut std::ffi::c_void,
src_data: *const std::ffi::c_void,
dst_data: *mut std::ffi::c_void,
prefix_dim: std::os::raw::c_int,
suffix_dim: std::os::raw::c_int,
elem_size: std::os::raw::c_int,
src_block_dim: std::os::raw::c_int,
dst_block_dim: std::os::raw::c_int,
) -> std::os::raw::c_int;
fn copy_stream_memcpy(
cs: *mut std::ffi::c_void,
src_data: *const std::ffi::c_void,
dst_data: *mut std::ffi::c_void,
prefix_dim: std::os::raw::c_int,
suffix_dim: std::os::raw::c_int,
elem_size: std::os::raw::c_int,
src_block_dim: std::os::raw::c_int,
dst_block_dim: std::os::raw::c_int,
) -> std::os::raw::c_int;
fn copy_stream_sync(cs: *mut std::ffi::c_void) -> std::os::raw::c_int;
fn copy_stream_scatter(
cs: *mut std::ffi::c_void,
src_data: *const std::ffi::c_void,
dst_data: *mut std::ffi::c_void,
dims: *const u32,
num_dims: u32,
elem_size: u32,
block_dim_index: u32,
src_block_dim: u32,
dst_block_dim: u32,
) -> std::os::raw::c_int;
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KvLayout {
/// Tensor is laid out as [kv, block, block_size, head, head_dim]
KvFirst,
/// Tensor is laid out as [block, kv, block_size, head, head_dim]
BlockFirst,
}
#[derive(Debug, Clone, Builder, PartialEq, Eq)]
pub struct KvModelDetails {
/// The number of layers in the model
number_of_layers: usize,
/// The number of heads in the tensor
number_of_heads: usize,
/// The size of each head in the tensor
head_size: usize,
/// Data type of the tensor
dtype: DType,
}
impl KvModelDetails {
pub fn number_of_elements_per_token_per_layer(&self) -> usize {
2 * self.number_of_heads * self.head_size
}
pub fn bytes_per_token_per_layer(&self) -> usize {
self.number_of_elements_per_token_per_layer() * self.dtype.size_in_bytes()
}
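// illustrative example: 8 heads with head_size 128 in f32 gives
// 2 * 8 * 128 = 2048 elements per token per layer, i.e. 2048 * 4 = 8192 bytes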
// pub fn number_of_elements_per_token(&self) -> usize {
// self.number_of_elements_per_token_per_layer() * self.number_of_layers
// }
// pub fn bytes_per_token(&self) -> usize {
// self.number_of_elements_per_token() * self.dtype.size_in_bytes()
// }
}
#[derive(Debug, Clone, Builder, Validate)]
#[validate(schema(function = "validate_block_details", skip_on_field_errors = true))]
pub struct KvBlockDetails {
/// The layout of the tensor
layout: KvLayout,
/// The size of each block in the tensor
block_size: usize,
/// The rank of the current process in the tensor parallel group
#[builder(default = "0")]
tp_rank: usize,
/// The size of the tensor parallel group
#[builder(default = "1")]
tp_size: usize,
/// The details of the model
model_details: KvModelDetails,
}
impl KvBlockDetails {
pub fn bytes_per_token_block_per_layer(&self) -> usize {
(self.model_details.bytes_per_token_per_layer() * self.block_size) / self.tp_size
}
pub fn is_compatible(&self, other: &KvBlockDetails) -> bool {
self.layout == other.layout
&& self.block_size == other.block_size
&& self.tp_size == other.tp_size
&& self.model_details == other.model_details
}
pub fn prefix_dim(&self) -> usize {
match self.layout {
KvLayout::KvFirst => 2,
KvLayout::BlockFirst => 1,
}
}
pub fn suffix_dim(&self) -> usize {
let suffix_dim = self.block_size
* (self.model_details.number_of_heads / self.tp_size)
* self.model_details.head_size;
match self.layout {
KvLayout::KvFirst => suffix_dim,
KvLayout::BlockFirst => 2 * suffix_dim,
}
}
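// Illustrative example: with KvFirst, block_size = 8, 4 heads, tp_size = 1 and
// head_size = 8, suffix_dim = 8 * 4 * 8 = 256 (512 for BlockFirst, where the kv
// pair folds into the contiguous suffix), so the copy kernel sees each layer as
// a 3D [prefix_dim, num_blocks, suffix_dim] view.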
pub fn elem_size(&self) -> usize {
self.model_details.dtype.size_in_bytes()
}
}
fn validate_block_details(block_details: &KvBlockDetails) -> Result<(), ValidationError> {
// tp size must evenly divide the number of heads
if block_details.model_details.number_of_heads % block_details.tp_size != 0 {
return Err(ValidationError::new("tp_size must evenly divide num_heads"));
}
if block_details.tp_rank >= block_details.tp_size {
return Err(ValidationError::new("tp_rank must be less than tp_size"));
}
if block_details.tp_size > block_details.model_details.number_of_heads {
return Err(ValidationError::new("tp_size must be less than num_heads"));
}
Ok(())
}
#[derive(Debug, Builder, Validate)]
#[validate(schema(function = "validate_kv_layer", skip_on_field_errors = true))]
pub struct KvLayer {
/// The layout of the tensor
layout: KvLayout,
/// The storage of the tensor
storage: OwnedStorage,
/// The number of blocks in the tensor
#[validate(range(min = 1))]
number_of_blocks: usize,
/// The size of each block in the tensor
#[validate(range(min = 1))]
block_size: usize,
/// The number of heads in the tensor of the canonical model
/// The actual number for this layer is this number divided by tp_size
#[validate(range(min = 1))]
number_of_heads: usize,
/// The size of each head in the tensor
#[validate(range(min = 1))]
head_size: usize,
/// DataType
dtype: DType,
/// The tensor parallel size (default is 1)
#[builder(default = 1)]
tp_size: usize,
/// The tensor parallel rank (default is 0)
#[builder(default = 0)]
tp_rank: usize,
}
fn validate_kv_layer(layer: &KvLayer) -> Result<(), ValidationError> {
if layer.number_of_heads % layer.tp_size != 0 {
return Err(ValidationError::new(
"number_of_heads must be divisible by tp_size",
));
}
if layer.tp_rank >= layer.tp_size {
return Err(ValidationError::new("tp_rank must be less than tp_size"));
}
if layer.tp_size > layer.number_of_heads {
return Err(ValidationError::new(
"tp_size must be less than number_of_heads",
));
}
let dims = layer.layer_shape();
let elements = dims.iter().product::<usize>();
let bytes = elements * layer.dtype.size_in_bytes();
if layer.storage.storage_size() < bytes {
return Err(ValidationError::new(
"storage must be at least as large as the layer",
));
}
Ok(())
}
impl Storage for KvLayer {
fn storage_type(&self) -> StorageType {
self.storage.storage_type()
}
fn get_pointer(&self) -> u64 {
self.storage.get_pointer()
}
fn storage_size(&self) -> usize {
self.storage.storage_size()
}
}
impl KvLayer {
fn from_storage(
block_details: &KvBlockDetails,
number_of_blocks: usize,
storage: OwnedStorage,
) -> Result<Self> {
let layer = Self {
storage,
number_of_blocks,
layout: block_details.layout.clone(),
block_size: block_details.block_size,
number_of_heads: block_details.model_details.number_of_heads,
head_size: block_details.model_details.head_size,
dtype: block_details.model_details.dtype,
tp_size: block_details.tp_size,
tp_rank: block_details.tp_rank,
};
layer.validate()?;
Ok(layer)
}
/// Get the shape of the layer
pub fn layer_shape(&self) -> [usize; 5] {
match self.layout {
KvLayout::KvFirst => [
2, // K and V as first dimension
self.number_of_blocks,
self.block_size,
self.number_of_heads / self.tp_size,
self.head_size,
],
KvLayout::BlockFirst => [
self.number_of_blocks,
2,
self.block_size,
self.number_of_heads / self.tp_size,
self.head_size,
],
}
}
/// Get a view of the layer
pub fn view(&self) -> Result<TensorView<'_, Self, 5>> {
// Calculate dimensions based on layout
let dims = self.layer_shape();
// Verify dimensions make sense
if self.number_of_heads % self.tp_size != 0 {
raise!(
"Number of heads ({}) is not divisible by tp_size ({})",
self.number_of_heads,
self.tp_size
);
}
// Log dimensions for debugging
tracing::debug!(
"Creating TensorView with dims: {:?}, dtype: {:?}, size: {}",
dims,
self.dtype,
self.dtype.size_in_bytes()
);
// Create and return the view
let view = TensorView::new(self, dims, self.dtype.size_in_bytes())
.map_err(|e| anyhow::anyhow!("{}", e))?;
Ok(view)
}
/// Perform a copy of blocks from one layer to another.
/// This launches a CUDA kernel to perform the copy.
pub fn copy_blocks_to(
&self,
src_block_ids: &[usize],
dst: &mut KvLayer,
dst_block_ids: &[usize],
) -> Result<()> {
if src_block_ids.len() != dst_block_ids.len() {
raise!("src_block_ids and dst_block_ids must have the same length");
}
if self.layout != dst.layout {
raise!("src and dst must have the same layout");
}
match (self.storage.storage_type(), dst.storage.storage_type()) {
(StorageType::Pinned, StorageType::Pinned) => {
raise!("Pinned to Pinned copy not implemented");
}
(StorageType::Pinned, StorageType::Device(_)) => {}
(StorageType::Device(_), StorageType::Pinned) => {}
(StorageType::Device(_), StorageType::Device(_)) => {
raise!("Device to Device copy not implemented");
}
(StorageType::System, _) => {
raise!("Copies from System memory are not implemented");
}
(_, StorageType::System) => {
raise!("Copies to System memory are not implemented");
}
};
let h_src_block_ids = src_block_ids
.iter()
.map(|id| *id as i32)
.collect::<Vec<_>>();
let h_dst_block_ids = dst_block_ids
.iter()
.map(|id| *id as i32)
.collect::<Vec<_>>();
let num_block_pairs = src_block_ids.len() as i32;
let prefix_dim = match self.layout {
KvLayout::KvFirst => 2,
KvLayout::BlockFirst => 1,
};
let suffix_dim = self.head_size * (self.number_of_heads / self.tp_size) * self.block_size;
let suffix_dim = match self.layout {
KvLayout::KvFirst => suffix_dim,
KvLayout::BlockFirst => 2 * suffix_dim,
};
let elem_size = self.dtype.size_in_bytes();
let src_blocks = self.number_of_blocks as i32;
let dst_blocks = dst.number_of_blocks as i32;
unsafe {
let rc = copy_blocks_3d(
self.storage.get_pointer() as *const std::ffi::c_void,
dst.storage.get_pointer() as *mut std::ffi::c_void,
h_src_block_ids.as_ptr() as *const std::os::raw::c_int,
h_dst_block_ids.as_ptr() as *const std::os::raw::c_int,
num_block_pairs,
prefix_dim,
src_blocks,
dst_blocks,
suffix_dim as i32,
elem_size as i32,
);
if rc != 0 {
raise!("Failed to copy blocks");
}
}
Ok(())
}
}
/// This object holds the set of layers that store the KV cache
#[derive(Debug)]
pub struct KvBlockStorage {
/// The details of the model
block_details: KvBlockDetails,
/// The type of storage
storage_type: StorageType,
/// Number of blocks
number_of_blocks: usize,
/// Layers
layers: Vec<KvLayer>,
}
impl KvBlockStorage {
/// Create a new KvBlockStorage object
/// This allows you to bring in a set of layers that are already allocated
pub fn from_layers(layers: Vec<KvLayer>) -> Result<Self> {
if layers.is_empty() {
raise!("Layers must not be empty");
}
// validate all layers have the same type
let storage_type = layers[0].storage.storage_type();
for layer in &layers {
if layer.storage.storage_type() != storage_type {
raise!("All layers must have the same storage type");
}
}
// validate all layers have the same number of blocks
let number_of_blocks = layers[0].number_of_blocks;
for layer in &layers {
if layer.number_of_blocks != number_of_blocks {
raise!("All layers must have the same number of blocks");
}
}
// extract the details from the first layer, construct ModelDetails, BlockDetails
let model_details = KvModelDetailsBuilder::default()
.number_of_layers(layers.len())
.number_of_heads(layers[0].number_of_heads)
.head_size(layers[0].head_size)
.dtype(layers[0].dtype)
.build()?;
let block_details = KvBlockDetailsBuilder::default()
.layout(layers[0].layout.clone())
.block_size(layers[0].block_size)
.model_details(model_details.clone())
.tp_size(layers[0].tp_size)
.tp_rank(layers[0].tp_rank)
.build()?;
block_details.validate()?;
let bytes_per_token_block = block_details.bytes_per_token_block_per_layer();
// validate all layers have enough capacity to hold the block data
for layer in &layers {
if layer.storage.storage_size() < bytes_per_token_block * number_of_blocks {
raise!("All layers must have enough capacity to hold the block data");
}
}
Ok(Self {
block_details,
storage_type,
number_of_blocks,
layers,
})
}
/// Given a number of blocks and the block details, allocate the storage for the layers
pub fn allocate(
number_of_blocks: usize,
block_details: KvBlockDetails,
storage_type: StorageType,
) -> Result<Self> {
block_details.validate()?;
// determine the number of blocks
let bytes = block_details.bytes_per_token_block_per_layer() * number_of_blocks;
let mut layers = Vec::new();
// for each layer, create a device storage object, then for a kv layer
for layer in 0..block_details.model_details.number_of_layers {
let storage = OwnedStorage::create(bytes, storage_type.clone()).with_context(|| {
error!("Failed to allocate memory for KV BlockStorage for layer {layer}")
})?;
let layer = KvLayer::from_storage(&block_details, number_of_blocks, storage)?;
layers.push(layer);
}
Self::from_layers(layers).context("Validating KvBlockStorage")
}
/// Get an immutable reference to a layer
pub fn layer(&self, layer: usize) -> Result<&KvLayer> {
if layer >= self.layers.len() {
raise!(
"Layer index {} out of bounds (max {})",
layer,
self.layers.len() - 1
);
}
Ok(&self.layers[layer])
}
/// Get a mutable reference to a layer
pub fn layer_mut(&mut self, layer: usize) -> Result<&mut KvLayer> {
if layer >= self.layers.len() {
raise!(
"Layer index {} out of bounds (max {})",
layer,
self.layers.len() - 1
);
}
Ok(&mut self.layers[layer])
}
pub fn number_of_blocks(&self) -> usize {
self.number_of_blocks
}
pub fn storage_type(&self) -> StorageType {
self.storage_type.clone()
}
}
/// This struct holds the details of the layers to be copied.
/// We should not have to recompute this for each copy stream, only once for each
/// block set and each direction.
///
/// If we have two block sets -- one host, one device -- then we need two copies of
/// this object: one for H2D copies and another for D2H copies.
///
/// If we address directly into flash storage, then we will need another pair for H2F
/// and F2H copies.
///
/// Note: when we register two block sets, we need to validate that the suffix dimension
/// is equivalent in both sets.
///
/// We may in the future need to add src_suffix_dim, dst_suffix_dim, src_suffix_stride and
/// dst_suffix_stride to the details to support non-unit strides. Today, the copy kernel
/// does not support that.
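///
/// A minimal usage sketch (mirroring the tests at the bottom of this file):
///
/// ```ignore
/// let h2d = CopyStreamBlockMap::new(&h_blocks, &d_blocks)?; // for H2D copies
/// let d2h = CopyStreamBlockMap::new(&d_blocks, &h_blocks)?; // for D2H copies
/// ```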
#[derive(Debug, Clone, Builder, Default, Validate)]
#[validate(schema(
function = "validate_copy_stream_layer_details",
skip_on_field_errors = true
))]
pub struct CopyStreamBlockMap {
/// The source layer pointers
src_layer_ptrs: Vec<u64>,
/// The destination layer pointers
dst_layer_ptrs: Vec<u64>,
/// The non-contiguous dimension above the block_dimension
#[validate(range(min = 1))]
prefix_dim: i32,
/// The size of the source blocks dimension in the layer shape
#[validate(range(min = 1))]
src_block_dim: i32,
/// The size of the destination blocks dimension in the layer shape
#[validate(range(min = 1))]
dst_block_dim: i32,
/// The contiguous dimension below the block_dimension
#[validate(range(min = 1))]
suffix_dim: i32,
/// The element size in bytes
#[validate(range(min = 1, max = 8))]
elem_size: i32,
}
impl CopyStreamBlockMap {
pub fn new(src: &KvBlockStorage, dst: &KvBlockStorage) -> Result<Arc<Self>> {
if !src.block_details.is_compatible(&dst.block_details) {
return Err(error!("src and dst must have compatible block details"));
}
let src_layer_ptrs: Vec<u64> = src
.layers
.iter()
.map(|l| l.storage.get_pointer())
.collect::<Vec<_>>();
let dst_layer_ptrs: Vec<u64> = dst
.layers
.iter()
.map(|l| l.storage.get_pointer())
.collect::<Vec<_>>();
if src_layer_ptrs.len() != dst_layer_ptrs.len() {
return Err(error!("src and dst must have the same number of layers"));
}
let prefix_dim = src.block_details.prefix_dim() as i32;
let suffix_dim = src.block_details.suffix_dim() as i32;
let elem_size = src.block_details.elem_size() as i32;
let src_block_dim = src.number_of_blocks as i32;
let dst_block_dim = dst.number_of_blocks as i32;
let details = Self {
src_layer_ptrs,
dst_layer_ptrs,
prefix_dim,
src_block_dim,
dst_block_dim,
suffix_dim,
elem_size,
};
details.validate()?;
Ok(Arc::new(details))
}
}
fn validate_copy_stream_layer_details(
layer_details: &CopyStreamBlockMap,
) -> Result<(), ValidationError> {
if layer_details.src_layer_ptrs.is_empty() {
return Err(ValidationError::new("src_layer_ptrs must not be empty"));
}
if layer_details.dst_layer_ptrs.is_empty() {
return Err(ValidationError::new("dst_layer_ptrs must not be empty"));
}
if layer_details.src_layer_ptrs.len() != layer_details.dst_layer_ptrs.len() {
return Err(ValidationError::new(
"src_layer_ptrs and dst_layer_ptrs must have the same length",
));
}
Ok(())
}
#[derive(Debug)]
pub struct CopyStreamContext {
/// Pointer to the C++ copy stream object
c_handle: NonNull<std::ffi::c_void>,
/// Maximum number of layers used to initialize the C++ object
max_num_layers: usize,
/// Maximum number of blocks used to initialize the C++ object
max_num_blocks: usize,
/// Whether the layers have been staged
staged_layers: bool,
/// Whether the block ids have been staged
staged_block_ids: bool,
/// Doorbells for each layer
layer_doorbells: Vec<bool>,
// block ids
src_block_ids: Vec<i32>,
dst_block_ids: Vec<i32>,
// layer details
layer_details: Arc<CopyStreamBlockMap>,
}
impl CopyStreamContext {
pub fn new(max_num_layers: usize, max_num_blocks: usize) -> Result<Self> {
let mut c_handle = std::ptr::null_mut();
let rc = unsafe {
copy_stream_create(&mut c_handle, max_num_layers as i32, max_num_blocks as i32)
};
if rc != 0 {
return Err(error!("Failed to create copy stream"));
}
let layer_doorbells = vec![false; max_num_layers];
Ok(Self {
c_handle: NonNull::new(c_handle).ok_or(error!("Failed to create copy stream"))?,
max_num_layers,
max_num_blocks,
staged_layers: false,
staged_block_ids: false,
layer_doorbells,
// block ids
src_block_ids: Vec::new(),
dst_block_ids: Vec::new(),
// layer details
layer_details: Arc::new(CopyStreamBlockMap::default()),
})
}
pub fn src_block_dim(&self) -> usize {
self.layer_details.src_block_dim as usize
}
pub fn dst_block_dim(&self) -> usize {
self.layer_details.dst_block_dim as usize
}
}
unsafe impl Send for CopyStream {}
unsafe impl Sync for CopyStream {}
/// This object holds a stateful copy stream for the copy_blocks_3d kernel
/// Each instance will hold:
/// - device memory for the block ids
/// - a cuda stream
/// - a cuda event
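///
/// Typical flow (see the tests below): `prepare_block_map`, then `prepare_block_ids`,
/// then `trigger_layer` / `trigger_all_layers`, then `sync_stream`, and finally
/// `reset` before the stream is reused with a new block map.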
pub struct CopyStream {
state: CopyStreamContext,
}
impl CopyStream {
pub fn new(num_layers: usize, num_blocks: usize) -> Result<Self> {
let state = CopyStreamContext::new(num_layers, num_blocks)?;
Ok(Self { state })
}
/// Prepare the layer pointers for the copy kernel.
/// The `details` argument (see [CopyStreamBlockMap]) carries:
/// - src_layer_ptrs: the source layer pointers
/// - dst_layer_ptrs: the destination layer pointers
/// - prefix_dim: the non-contiguous dimension above the block dimension
/// - src_block_dim: the size of the source blocks dimension in the layer shape
/// - dst_block_dim: the size of the destination blocks dimension in the layer shape
/// - suffix_dim: the contiguous dimension below the block dimension
/// - elem_size: the element size in bytes
pub fn prepare_block_map(&mut self, details: Arc<CopyStreamBlockMap>) -> Result<()> {
let state = &mut self.state;
let layer_count = details.src_layer_ptrs.len();
if state.max_num_layers < layer_count {
return Err(error!(
"Number of layers {} exceeds max number of layers {}",
layer_count, state.max_num_layers
));
}
if state.staged_layers {
return Err(error!("Layers already loaded"));
}
state.staged_layers = true;
state.layer_details = details;
assert!(state.layer_doorbells.len() >= layer_count);
state
.layer_doorbells
.iter_mut()
.for_each(|doorbell| *doorbell = false);
Ok(())
}
/// Prepare the block ids for the copy kernel
/// See [CopyStreamBlockMap] for more details
/// - src_block_ids: the source block ids
/// - dst_block_ids: the destination block ids
pub fn prepare_block_ids(
&mut self,
src_block_ids: Vec<i32>,
dst_block_ids: Vec<i32>,
) -> Result<()> {
if src_block_ids.len() != dst_block_ids.len() {
return Err(error!(
"src_block_ids and dst_block_ids must have the same length"
));
}
// the unique block id test runs only in debug builds, as it adds some overhead
#[cfg(debug_assertions)]
{
// validate that the dst block ids are unique
let dst_block_ids_set: std::collections::HashSet<_> = dst_block_ids.iter().collect();
if dst_block_ids_set.len() != dst_block_ids.len() {
return Err(error!("dst_block_ids must be unique"));
}
// validate that the src block ids are unique
let src_block_ids_set: std::collections::HashSet<_> = src_block_ids.iter().collect();
if src_block_ids_set.len() != src_block_ids.len() {
return Err(error!("src_block_ids must be unique"));
}
}
let state = &mut self.state;
if state.max_num_blocks < src_block_ids.len() {
return Err(error!(
"Number of blocks {} exceeds max number of blocks {}",
src_block_ids.len(),
state.max_num_blocks
));
}
if !state.staged_layers {
return Err(error!("Layers must be loaded before preparing block ids"));
}
if state.staged_block_ids {
return Err(error!("Block ids already loaded"));
}
// we need to copy the block ids into the state so we don't have to block on the async
// transfer of the lists from host to device
state.src_block_ids = src_block_ids;
state.dst_block_ids = dst_block_ids;
// transfer the block ids to the device
// this can be done without blocking because the copy stream state owns the block id
// buffers for the lifetime of the transfer
let rc = unsafe {
copy_stream_prepare_block_ids(
state.c_handle.as_ptr(),
state.src_block_ids.as_ptr() as *const std::os::raw::c_int,
state.dst_block_ids.as_ptr() as *const std::os::raw::c_int,
state.src_block_ids.len() as i32,
)
};
if rc != 0 {
return Err(error!("Failed to prepare block ids"));
}
state.staged_block_ids = true;
Ok(())
}
pub fn trigger_layer(&mut self, layer: usize) -> Result<()> {
let state = &mut self.state;
if !state.staged_layers {
return Err(error!("Layers must be loaded before triggering a layer"));
}
if !state.staged_block_ids {
return Err(error!("Block ids must be loaded before triggering a layer"));
}
if layer >= state.layer_details.src_layer_ptrs.len() {
return Err(error!(
"layer index {} out of bounds (max {})",
layer,
state.layer_details.src_layer_ptrs.len() - 1
));
}
if state.layer_doorbells[layer] {
tracing::trace!("layer {} already triggered; this is a no-op", layer);
return Ok(());
}
let cs = state.c_handle.as_ptr();
let src_data = state.layer_details.src_layer_ptrs[layer] as *const std::ffi::c_void;
let dst_data = state.layer_details.dst_layer_ptrs[layer] as *mut std::ffi::c_void;
let rc = unsafe {
copy_stream_memcpy(
cs,
src_data,
dst_data,
state.layer_details.prefix_dim,
state.layer_details.suffix_dim,
state.layer_details.elem_size,
state.layer_details.src_block_dim,
state.layer_details.dst_block_dim,
)
};
if rc != 0 {
return Err(error!("Failed to execute layer {} copy", layer));
}
state.layer_doorbells[layer] = true;
Ok(())
}
pub fn layer_count(&self) -> usize {
self.state.layer_details.src_layer_ptrs.len()
}
pub fn trigger_all_layers(&mut self) -> Result<()> {
let layer_count = self.layer_count();
for layer in 0..layer_count {
self.trigger_layer(layer)?;
}
Ok(())
}
pub fn sync_stream(&mut self) -> Result<()> {
let state = &mut self.state;
let cs = state.c_handle.as_ptr();
let rc = unsafe { copy_stream_sync(cs) };
if rc != 0 {
return Err(error!("Failed to synchronize copy stream"));
}
Ok(())
}
/// Performs a fused permute copy that redistributes KV heads across tensor parallel ranks.
/// This accepts a 5D tensor with a known src_tp_size and a dst_tp_size.
///
/// The dimensions of the src_data are expected to be consistent with the
/// KvLayout, where the first two dimensions are the kv dimension and the block
/// dimension depending on the KvLayout.
///
/// dim0: kv or block
/// dim1: block or kv
/// dim2: block_size
/// dim3: num_heads / src_tp_size
/// dim4: head_size
///
/// Note: the incoming tensor's dim3 is already the number of heads per TP rank
/// for the source TP size (src_tp_size).
///
/// A scatter will always transform from tpX -> tpY where X < Y.
///
/// The scale of the transformation `scatter_factor` is given by `dst_tp_size / src_tp_size`.
///
/// This results in a reshaping of the tensor view to 6 dimensions, but the memory layout
/// is still contiguous in memory.
///
/// src6d_dim0: kv or block
/// src6d_dim1: block or kv
/// src6d_dim2: block_size
/// src6d_dim3: scatter_factor
/// src6d_dim4: (model_num_heads / src_tp_size) / scatter_factor
/// src6d_dim5: head_size
///
/// The 6D tensor has exactly the same storage requirement and data layout as the 5D tensor.
///
/// The `scatter_copy_layer` method will then do a fused permute copy from src_data to dst_data
/// with the following permutation: (3, 0, 1, 2, 4, 5) of the src6d dimensions.
///
/// The resulting dst6d dimensions are:
///
/// dst6d_dim0: scatter_factor
/// dst6d_dim1: kv or block
/// dst6d_dim2: block or kv
/// dst6d_dim3: block_size
/// dst6d_dim4: (model_num_heads / src_tp_size) / scatter_factor
/// dst6d_dim5: head_size
///
/// These are fully contiguous in memory.
///
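/// Worked example (sizes taken from the basic permute test below): with
/// number_of_heads = 16, src_tp_size = 1 and dst_tp_size = 4, the scatter_factor
/// is 4, so a KvFirst 5D view [2, blocks, block_size, 16, head_size] is
/// reinterpreted as the 6D view [2, blocks, block_size, 4, 4, head_size] and
/// permuted so that each of the 4 destination ranks receives a contiguous
/// [2, blocks, block_size, 4, head_size] slice.
///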
pub fn scatter_copy_layer(
&mut self,
layer: usize,
dims: &[usize], // 5d dimensions of source tensor per the description above
elem_size: usize,
block_dim_index: usize, // which dimension holds the blocks (0 or 1)
src_tp_size: usize,
dst_tp_size: usize,
) -> Result<()> {
// validate the dimensions
if dims.len() != 5 {
return Err(error!("Expected 5 dimensions for src_data"));
}
// validate the block_dim_index is 0 or 1
if block_dim_index > 1 {
return Err(error!("block_dim_index must be 0 or 1"));
}
// validate the elem_size is supported (should be > 0 and <= 8)
if elem_size == 0 || elem_size > 8 {
return Err(error!(
"elem_size must be greater than 0 and less or equal to 8 bytes"
));
}
// validate src_tp_size < dst_tp_size and both are powers of 2
if src_tp_size >= dst_tp_size {
return Err(error!("src_tp_size must be less than dst_tp_size"));
}
if src_tp_size & (src_tp_size - 1) != 0 {
return Err(error!("src_tp_size must be a power of 2"));
}
if dst_tp_size & (dst_tp_size - 1) != 0 {
return Err(error!("dst_tp_size must be a power of 2"));
}
let scatter_factor = dst_tp_size / src_tp_size;
let state = &mut self.state;
if !state.staged_layers {
return Err(error!(
"Layers must be loaded before performing permutation"
));
}
if !state.staged_block_ids {
return Err(error!(
"Block IDs must be loaded before performing permutation"
));
}
// check layer index is valid
if layer >= state.layer_details.src_layer_ptrs.len() {
return Err(error!(
"layer index {} out of bounds (max {})",
layer,
state.layer_details.src_layer_ptrs.len() - 1
));
}
let src_data = state.layer_details.src_layer_ptrs[layer] as *const std::ffi::c_void;
let dst_data = state.layer_details.dst_layer_ptrs[layer] as *mut std::ffi::c_void;
// prepare 6d dimensions
let mut src_6d_dims = vec![0_u32; 6];
// populate src_6d_dims
src_6d_dims[0] = dims[0] as u32;
src_6d_dims[1] = dims[1] as u32;
src_6d_dims[2] = dims[2] as u32;
src_6d_dims[3] = scatter_factor as u32;
src_6d_dims[4] = (dims[3] / scatter_factor) as u32;
src_6d_dims[5] = dims[4] as u32;
tracing::debug!("scatter_factor: {}", scatter_factor);
tracing::debug!("src_6d_dims: {:?}", src_6d_dims);
// the state has the layers src/dst pointers and src/dst block ids
let cs = state.c_handle.as_ptr();
let rc = unsafe {
copy_stream_scatter(
cs,
src_data,
dst_data,
src_6d_dims.as_ptr(),
6_u32,
elem_size as u32,
block_dim_index as u32,
self.state.src_block_dim() as u32,
self.state.dst_block_dim() as u32,
)
};
if rc != 0 {
return Err(error!("Failed to execute tensor permutation"));
}
Ok(())
}
pub fn reset(&mut self) {
self.state.staged_block_ids = false;
self.state.staged_layers = false;
}
}
impl Returnable for CopyStream {
fn on_return(&mut self) {
// reset the staged flags
self.reset()
}
}
#[cfg(test)]
mod tests {
use std::time::Instant;
use super::*;
use cudarc::driver::CudaContext;
use ndarray::prelude::*;
use rstest::rstest;
impl CopyStream {
fn reuse(&mut self) -> Result<()> {
self.state
.layer_doorbells
.iter_mut()
.for_each(|doorbell| *doorbell = false);
Ok(())
}
}
impl KvBlockStorage {
// only works with host/pinned fp32 tensors
fn fill_layer_with_block_id(&mut self) -> Result<()> {
if let StorageType::Device(_) = self.storage_type() {
raise!("fill_layer_with_block_id not implemented for device storage");
}
// get number of blocks
let num_blocks = self.number_of_blocks();
let num_layers = self.layers.len();
let layout = self.block_details.layout.clone();
for ilayer in 0..num_layers {
let layer = self.layer_mut(ilayer)?;
let mut view = layer.view()?;
let mut nd_view = view.as_ndarray_view_mut::<f32>()?;
for iblock in 0..num_blocks {
match &layout {
KvLayout::KvFirst => nd_view
.slice_mut(s![.., iblock, .., .., ..])
.fill(iblock as f32),
KvLayout::BlockFirst => {
nd_view
.slice_mut(s![iblock, .., .., .., ..])
.fill(iblock as f32);
}
}
}
}
Ok(())
}
}
#[rstest]
#[test]
fn test_kv_block_storage_kv_first() -> Result<()> {
let device = CudaContext::new(0)?;
let model_details = KvModelDetailsBuilder::default()
.number_of_layers(2)
.number_of_heads(4)
.head_size(8)
.dtype(DType::F32)
.build()?;
let block_details = KvBlockDetailsBuilder::default()
.layout(KvLayout::KvFirst)
.block_size(8)
.tp_size(1)
.tp_rank(0)
.model_details(model_details)
.build()?;
// Create the storage blocks
let mut h_blocks =
KvBlockStorage::allocate(32, block_details.clone(), StorageType::Pinned)?;
let mut d_blocks = KvBlockStorage::allocate(
32,
block_details.clone(),
StorageType::Device(device.clone()),
)?;
println!("Allocated pinned and device blocks");
println!("Letting layer 0 on host to be 1s");
// Use separate scopes to manage borrows
{
// Get a mutable reference to a layer
let layer = h_blocks.layer_mut(0)?;
// Create a mutable view and work with it
let mut view = layer.view()?;
// Get shape information before creating the ndarray view
let shape = *view.shape();
println!("TensorView shape: {:?}", shape);
// Create and use the mutable ndarray view in its own scope
{
let mut nd_view = view.as_ndarray_view_mut::<f32>()?;
let ones = ndarray::Array::from_shape_fn(nd_view.dim(), |_| 1.0);
nd_view.assign(&ones);
// Verify some values while we have the view
assert_eq!(nd_view[[0, 0, 0, 0, 0]], 1.0);
assert_eq!(nd_view[[1, 0, 0, 0, 0]], 1.0);
}
// nd_view is dropped here, releasing the mutable borrow
}
// Copy data to device
let stream = device.new_stream()?;
println!("Copying data to device");
{
let h_view = h_blocks.layer(0)?.view().unwrap();
let mut d_view = d_blocks.layer_mut(0)?.view().unwrap();
h_view.copy_to_view_blocking(&mut d_view).unwrap();
stream.synchronize().unwrap();
}
println!("Setting all values on host back to 0");
// Set all values on host back to 0
{
let mut h_layer = h_blocks.layer_mut(0)?.view()?;
let mut nd_view = h_layer.as_ndarray_view_mut::<f32>()?;
let zeros = ndarray::Array::from_shape_fn(nd_view.dim(), |_| 0.0);
nd_view.assign(&zeros);
assert_eq!(nd_view[[0, 0, 0, 0, 0]], 0.0);
assert_eq!(nd_view[[1, 0, 0, 0, 0]], 0.0);
}
println!("Copying data back to host");
// Copy data back to host
{
let d_view = d_blocks.layer(0)?.view()?;
let mut h_view = h_blocks.layer_mut(0)?.view()?;
d_view.copy_to_view_blocking(&mut h_view)?;
stream.synchronize()?;
}
println!("Verifying host data is 1");
// Verify the host data is now back to 1
{
let h_layer = h_blocks.layer(0)?.view()?;
let nd_view = h_layer.as_ndarray_view::<f32>()?;
assert_eq!(nd_view[[0, 0, 0, 0, 0]], 1.0);
assert_eq!(nd_view[[1, 0, 0, 0, 0]], 1.0);
}
Ok(())
}
#[rstest]
#[test]
fn test_kv_block_storage_kv_first_direct() -> Result<()> {
let device = CudaContext::new(0)?;
let model_details = KvModelDetailsBuilder::default()
.number_of_layers(2)
.number_of_heads(4)
.head_size(8)
.dtype(DType::F32)
.build()?;
let block_details = KvBlockDetailsBuilder::default()
.layout(KvLayout::KvFirst)
.block_size(8)
.tp_size(1)
.tp_rank(0)
.model_details(model_details)
.build()?;
// Create the storage blocks
let mut h_blocks =
KvBlockStorage::allocate(32, block_details.clone(), StorageType::Pinned)?;
let mut d_blocks = KvBlockStorage::allocate(
32,
block_details.clone(),
StorageType::Device(device.clone()),
)?;
println!("Allocated pinned and device blocks");
println!("Letting layer 0 on host to be 1s");
// Use separate scopes to manage borrows
{
// Get a mutable reference to a layer
let layer = h_blocks.layer_mut(0)?;
// Create a mutable view and work with it
let mut view = layer.view()?;
// Get shape information before creating the ndarray view
let shape = *view.shape();
println!("TensorView shape: {:?}", shape);
// Create and use the mutable ndarray view in its own scope
{
let mut nd_view = view.as_ndarray_view_mut::<f32>()?;
let ones = ndarray::Array::from_shape_fn(nd_view.dim(), |_| 1.0);
nd_view.assign(&ones);
// Verify some values while we have the view
assert_eq!(nd_view[[0, 0, 0, 0, 0]], 1.0);
assert_eq!(nd_view[[1, 0, 0, 0, 0]], 1.0);
}
// nd_view is dropped here, releasing the mutable borrow
}
println!("Copying data to device");
{
let blocks = (0..32).collect::<Vec<_>>();
let h_layer = h_blocks.layer(0).unwrap();
let d_layer = d_blocks.layer_mut(0).unwrap();
h_layer.copy_blocks_to(&blocks, d_layer, &blocks).unwrap();
}
println!("Setting all values on host back to 0");
// Set all values on host back to 0
{
let mut h_layer = h_blocks.layer_mut(0)?.view()?;
let mut nd_view = h_layer.as_ndarray_view_mut::<f32>()?;
let zeros = ndarray::Array::from_shape_fn(nd_view.dim(), |_| 0.0);
nd_view.assign(&zeros);
assert_eq!(nd_view[[0, 0, 0, 0, 0]], 0.0);
assert_eq!(nd_view[[1, 0, 0, 0, 0]], 0.0);
}
println!("Copying data back to host");
// Copy data back to host
{
let blocks = (0..32).collect::<Vec<_>>();
let h_layer = h_blocks.layer_mut(0).unwrap();
let d_layer = d_blocks.layer(0).unwrap();
d_layer.copy_blocks_to(&blocks, h_layer, &blocks).unwrap();
}
println!("Verifying host data is 1");
// Verify the host data is now back to 1
{
let h_layer = h_blocks.layer(0)?.view()?;
let nd_view = h_layer.as_ndarray_view::<f32>()?;
assert_eq!(nd_view[[0, 0, 0, 0, 0]], 1.0);
assert_eq!(nd_view[[1, 0, 0, 0, 0]], 1.0);
}
Ok(())
}
#[rstest]
#[case(KvLayout::KvFirst)]
#[case(KvLayout::BlockFirst)]
#[test]
fn test_kv_block_storage_layouts(#[case] layout: KvLayout) -> Result<()> {
let device = CudaContext::new(0)?;
let layout_name = match layout {
KvLayout::KvFirst => "KvFirst",
KvLayout::BlockFirst => "BlockFirst",
};
println!("Testing layout: {}", layout_name);
let number_of_blocks = 8;
let model_details = KvModelDetailsBuilder::default()
.number_of_layers(2)
.number_of_heads(2)
.head_size(2)
.dtype(DType::F32)
.build()?;
let block_details = KvBlockDetailsBuilder::default()
.layout(layout)
.block_size(4)
.tp_size(1)
.tp_rank(0)
.model_details(model_details)
.build()?;
// Create the storage blocks
let mut h_blocks =
KvBlockStorage::allocate(number_of_blocks, block_details.clone(), StorageType::Pinned)?;
let mut d_blocks = KvBlockStorage::allocate(
number_of_blocks,
block_details.clone(),
StorageType::Device(device.clone()),
)?;
println!("Allocated pinned and device blocks");
println!("Letting layer 0 on host to be 1s");
let layout = h_blocks.layer(0).unwrap().layout.clone();
let shape = *h_blocks.layer(0).unwrap().view().unwrap().shape();
println!("shape: {:?}", shape);
h_blocks.fill_layer_with_block_id().unwrap();
// test that our test function fill_layer_with_block_id works
{
// Get a mutable reference to a layer
let layer = h_blocks.layer_mut(0)?;
// Create a mutable view and work with it
let mut view = layer.view()?;
// Create and use the mutable ndarray view in its own scope
{
let nd_view = view.as_ndarray_view_mut::<f32>()?;
// iterate over nd_view and check that the values equal the block index:
// all k and v entries of block N have value N
match layout {
KvLayout::KvFirst => {
assert_eq!(nd_view[[0, 0, 0, 0, 0]], 0.0);
assert_eq!(nd_view[[1, 0, 0, 0, 0]], 0.0);
assert_eq!(nd_view[[1, 0, 1, 1, 1]], 0.0);
assert_eq!(nd_view[[0, 2, 0, 0, 0]], 2.0);
assert_eq!(nd_view[[1, 2, 0, 0, 0]], 2.0);
assert_eq!(nd_view[[1, 2, 1, 1, 1]], 2.0);
}
KvLayout::BlockFirst => {
assert_eq!(nd_view[[0, 0, 0, 0, 0]], 0.0);
assert_eq!(nd_view[[0, 1, 1, 1, 1]], 0.0);
assert_eq!(nd_view[[1, 0, 0, 0, 0]], 1.0);
assert_eq!(nd_view[[1, 1, 1, 1, 1]], 1.0);
}
}
}
// nd_view is dropped here, releasing the mutable borrow
}
// Copy data to device
let stream = device.new_stream()?;
println!("Copying data to device");
{
let blocks = (0..number_of_blocks).collect::<Vec<_>>();
let h_layer = h_blocks.layer(0).unwrap();
let d_layer = d_blocks.layer_mut(0).unwrap();
h_layer.copy_blocks_to(&blocks, d_layer, &blocks).unwrap();
stream.synchronize().unwrap();
}
println!("Setting all values on host back to 0");
// Set all values on host back to 0
{
let mut h_layer = h_blocks.layer_mut(0)?.view()?;
let mut nd_view = h_layer.as_ndarray_view_mut::<f32>()?;
nd_view.fill(0.0);
assert_eq!(nd_view[[0, 0, 0, 0, 0]], 0.0);
assert_eq!(nd_view[[1, 0, 0, 0, 0]], 0.0);
assert_eq!(nd_view[[0, 1, 1, 1, 1]], 0.0);
assert_eq!(nd_view[[1, 1, 1, 1, 1]], 0.0);
}
println!("Copying data back to host");
let src_blocks = &[1, 2, 2, 3, 5];
let dst_blocks = &[0, 3, 2, 1, 4];
// Copy data back to host
{
let h_layer = h_blocks.layer_mut(0).unwrap();
let d_layer = d_blocks.layer(0).unwrap();
d_layer
.copy_blocks_to(src_blocks, h_layer, dst_blocks)
.unwrap();
stream.synchronize().unwrap();
}
println!("Verifying host data is 1");
// Verify the host data is not back to 1
{
let h_layer = h_blocks.layer(0)?.view()?;
let nd_view = h_layer.as_ndarray_view::<f32>()?;
println!("nd_view: {:?}", nd_view);
// validate
for i in 0..src_blocks.len() {
println!(
"Validating src block {} -> dst block {}",
src_blocks[i], dst_blocks[i]
);
let expected_value = src_blocks[i] as f32;
match layout {
KvLayout::KvFirst => {
assert_eq!(nd_view[[0, dst_blocks[i], 0, 0, 0]], expected_value);
assert_eq!(nd_view[[1, dst_blocks[i], 0, 0, 0]], expected_value);
assert_eq!(nd_view[[0, dst_blocks[i], 1, 1, 1]], expected_value);
assert_eq!(nd_view[[1, dst_blocks[i], 1, 1, 1]], expected_value);
}
KvLayout::BlockFirst => {
assert_eq!(nd_view[[dst_blocks[i], 0, 0, 0, 0]], expected_value);
assert_eq!(nd_view[[dst_blocks[i], 0, 1, 1, 1]], expected_value);
assert_eq!(nd_view[[dst_blocks[i], 1, 0, 0, 0]], expected_value);
assert_eq!(nd_view[[dst_blocks[i], 1, 1, 1, 1]], expected_value);
}
}
}
}
Ok(())
}
#[rstest]
#[case(KvLayout::KvFirst)]
#[case(KvLayout::BlockFirst)]
#[test]
fn test_kv_block_copy_stream_validated(#[case] layout: KvLayout) -> Result<()> {
println!("Testing block copy stream with validation");
let device = CudaContext::new(0)?;
// Set up a small tensor for testing with validation
let number_of_layers = 2;
let number_of_heads = 2;
let head_size = 2;
let number_of_cpu_blocks = 8;
let number_of_gpu_blocks = 4;
let block_size = 2;
println!("Test configuration:");
println!(" Number of layers: {}", number_of_layers);
println!(" Number of heads: {}", number_of_heads);
println!(" Head size: {}", head_size);
println!(" Block size: {}", block_size);
println!(" CPU blocks: {}", number_of_cpu_blocks);
println!(" GPU blocks: {}", number_of_gpu_blocks);
let model_details = KvModelDetailsBuilder::default()
.number_of_layers(number_of_layers)
.number_of_heads(number_of_heads)
.head_size(head_size)
.dtype(DType::F32) // Use F32 for easier validation
.build()?;
let block_details = KvBlockDetailsBuilder::default()
.layout(layout.clone())
.block_size(block_size)
.tp_size(1)
.tp_rank(0)
.model_details(model_details)
.build()?;
// Create the storage blocks
let mut h_blocks = KvBlockStorage::allocate(
number_of_cpu_blocks,
block_details.clone(),
StorageType::Pinned,
)?;
h_blocks.fill_layer_with_block_id().unwrap();
let d_blocks = KvBlockStorage::allocate(
number_of_gpu_blocks,
block_details.clone(),
StorageType::Device(device.clone()),
)?;
// Set up block mapping for copying
let h2d_block_map = CopyStreamBlockMap::new(&h_blocks, &d_blocks)?;
let d2h_block_map = CopyStreamBlockMap::new(&d_blocks, &h_blocks)?;
let mut copy_stream = CopyStream::new(number_of_layers, number_of_gpu_blocks)?;
// Convert to i32 for the API
let src_block_ids: Vec<i32> = (0..number_of_gpu_blocks).map(|id| id as i32).collect();
let dst_block_ids: Vec<i32> = (0..number_of_gpu_blocks).map(|id| id as i32).collect();
// Test H2D copy
copy_stream.prepare_block_map(h2d_block_map.clone())?;
copy_stream.prepare_block_ids(src_block_ids.clone(), dst_block_ids.clone())?;
println!("Copying data from host to device");
copy_stream.trigger_all_layers()?;
copy_stream.sync_stream()?;
copy_stream.reset();
// Clear host blocks to verify D2H copy later
println!("Clearing host blocks to verify D2H copy");
for layer_idx in 0..number_of_layers {
let layer = h_blocks.layer_mut(layer_idx)?;
let mut view = layer.view()?;
let mut nd_view = view.as_ndarray_view_mut::<f32>()?;
nd_view.fill(0.0);
}
// Now copy back from device to host.
// Reverse the src block ids; this copies the gpu blocks back to the cpu in reverse order.
//
// this should map:
// - gpu:3 -> cpu:0
// - gpu:2 -> cpu:1
// - gpu:1 -> cpu:2
// - gpu:0 -> cpu:3
let src_block_ids: Vec<i32> = (0..number_of_gpu_blocks)
.rev()
.map(|id| id as i32)
.collect();
copy_stream.prepare_block_map(d2h_block_map)?;
copy_stream.prepare_block_ids(src_block_ids.clone(), dst_block_ids.clone())?;
println!("Copying data back from device to host");
copy_stream.trigger_all_layers()?;
copy_stream.sync_stream()?;
// Verify the data transfer
for layer_idx in 0..number_of_layers {
let host_layer = h_blocks.layer(layer_idx)?;
let host_view = host_layer.view()?;
let host_nd_view = host_view.as_ndarray_view::<f32>()?;
// Validate that each destination block contains the expected values from the source block
for (src_block_id, dst_block_id) in src_block_ids.iter().zip(dst_block_ids.iter()) {
let src_block_value = *src_block_id as f32;
match layout {
KvLayout::KvFirst => {
// In KvFirst layout, blocks are in dimension 1
let block_slice =
host_nd_view.slice(s![.., *dst_block_id as usize, .., .., ..]);
for &value in block_slice.iter() {
assert_eq!(
value, src_block_value,
"Block validation failed: src_block {} -> dst_block {}",
src_block_id, dst_block_id
);
}
}
KvLayout::BlockFirst => {
// In BlockFirst layout, blocks are in dimension 0
let block_slice =
host_nd_view.slice(s![*dst_block_id as usize, .., .., .., ..]);
for &value in block_slice.iter() {
assert_eq!(
value, src_block_value,
"Block validation failed: src_block {} -> dst_block {}",
src_block_id, dst_block_id
);
}
}
}
println!(
"Validated: src_block {} -> dst_block {}",
src_block_id, dst_block_id
);
}
}
println!("Transfer validation successful");
Ok(())
}
#[rstest]
#[case(KvLayout::KvFirst, true)]
#[case(KvLayout::KvFirst, false)]
#[case(KvLayout::BlockFirst, true)]
#[case(KvLayout::BlockFirst, false)]
#[test]
fn bench_kv_block_copy_stream(#[case] layout: KvLayout, #[case] is_h2d: bool) -> Result<()> {
let layout_name = match layout {
KvLayout::KvFirst => "KvFirst",
KvLayout::BlockFirst => "BlockFirst",
};
let direction = if is_h2d { "H2D" } else { "D2H" };
println!("Testing layout: {} direction: {}", layout_name, direction);
let device = CudaContext::new(0)?;
// Sizes kept modest so the performance test runs quickly
let number_of_layers = 32;
let number_of_heads = 8;
let head_size = 128;
let number_of_cpu_blocks = 64;
let number_of_gpu_blocks = 64;
let block_size = 64;
println!("Test configuration:");
println!(" Number of layers: {}", number_of_layers);
println!(" Number of heads: {}", number_of_heads);
println!(" Head size: {}", head_size);
println!(" Block size: {}", block_size);
println!(" CPU blocks: {}", number_of_cpu_blocks);
println!(" GPU blocks: {}", number_of_gpu_blocks);
let model_details = KvModelDetailsBuilder::default()
.number_of_layers(number_of_layers)
.number_of_heads(number_of_heads)
.head_size(head_size)
.dtype(DType::F32) // Use F32 for easier validation
.build()?;
let block_details = KvBlockDetailsBuilder::default()
.layout(layout.clone())
.block_size(block_size)
.tp_size(1)
.tp_rank(0)
.model_details(model_details)
.build()?;
// Create the storage blocks
let h_blocks = KvBlockStorage::allocate(
number_of_cpu_blocks,
block_details.clone(),
StorageType::Pinned,
)?;
let d_blocks = KvBlockStorage::allocate(
number_of_gpu_blocks,
block_details.clone(),
StorageType::Device(device.clone()),
)?;
let h2d_block_map = CopyStreamBlockMap::new(&h_blocks, &d_blocks).unwrap();
let d2h_block_map = CopyStreamBlockMap::new(&d_blocks, &h_blocks).unwrap();
let mut copy_stream = CopyStream::new(number_of_layers, number_of_gpu_blocks).unwrap();
// block list 0..64 as i32
let mut block_list: Vec<i32> = (0..number_of_gpu_blocks).map(|x| x as i32).collect();
// randomize the block list
use rand::seq::SliceRandom;
let mut rng = rand::rng();
block_list.shuffle(&mut rng);
let src_block_ids = block_list.clone();
block_list.shuffle(&mut rng);
let dst_block_ids = block_list.clone();
// Select the appropriate block map based on direction
if is_h2d {
copy_stream.prepare_block_map(h2d_block_map).unwrap();
} else {
copy_stream.prepare_block_map(d2h_block_map).unwrap();
}
copy_stream
.prepare_block_ids(src_block_ids, dst_block_ids)
.unwrap();
let timer = Instant::now();
copy_stream.trigger_all_layers().unwrap();
copy_stream.sync_stream().unwrap();
let duration = timer.elapsed();
println!("Time taken: {:?}", duration);
let iterations = 10;
let timer = Instant::now();
for _ in 0..iterations {
copy_stream.trigger_all_layers().unwrap();
copy_stream.reuse().unwrap();
}
copy_stream.sync_stream().unwrap();
let duration = timer.elapsed();
println!("Time taken: {:?}", duration);
let single_layer_gpu_storage_size = d_blocks.layers[0].storage.storage_size();
let total_gpu_storage_size = single_layer_gpu_storage_size * number_of_layers;
println!(
"Total GPU storage size: {:.2} MB ({} bytes)",
total_gpu_storage_size as f64 / (1024.0 * 1024.0),
total_gpu_storage_size
);
println!(
"Transfer rate: {:.2} GB/s",
(iterations * total_gpu_storage_size) as f64
/ (1024.0 * 1024.0 * 1024.0 * duration.as_secs_f64())
);
Ok(())
}
#[rstest]
#[case(KvLayout::KvFirst)]
#[case(KvLayout::BlockFirst)]
#[test]
fn test_kv_tensor_permute_basic(#[case] layout: KvLayout) -> Result<()> {
println!("Testing simple tensor permutation");
let device = CudaContext::new(0)?;
// Set up a small tensor for testing permutation
let number_of_blocks = 8;
let block_size = 2;
let number_of_heads = 16;
let head_size = 6;
let src_tp_size = 1;
let dst_tp_size = 4;
let scatter_factor = dst_tp_size / src_tp_size;
println!("Test configuration:");
println!(" Layout: {:?}", layout);
println!(" Number of blocks: {}", number_of_blocks);
println!(" Block size: {}", block_size);
println!(" Number of heads: {}", number_of_heads);
println!(" Head size: {}", head_size);
println!(" Source TP size: {}", src_tp_size);
println!(" Destination TP size: {}", dst_tp_size);
let (_kv_dim_idx, block_dim_idx) = match layout {
KvLayout::KvFirst => (0, 1),
KvLayout::BlockFirst => (1, 0),
};
let model_details = KvModelDetailsBuilder::default()
.number_of_layers(1)
.number_of_heads(number_of_heads)
.head_size(head_size)
.dtype(DType::F32)
.build()?;
let block_details = KvBlockDetailsBuilder::default()
.layout(layout.clone())
.block_size(block_size)
.tp_size(src_tp_size)
.tp_rank(0)
.model_details(model_details.clone())
.build()?;
// 1. Allocate storage for source blocks
let mut src_blocks =
KvBlockStorage::allocate(number_of_blocks, block_details.clone(), StorageType::Pinned)?;
// 2. Create a view for the source layer and initialize it
{
let src_layer = src_blocks.layer_mut(0)?;
let mut src_view = src_layer.view()?;
let src_shape = *src_view.shape();
println!("Source tensor shape: {:?}", src_shape);
println!("Scatter factor: {}", scatter_factor);
// println!("Source tensor strides: {:?}", src_view.strides());
// println!("Source tensor byte_strides: {:?}", src_view.byte_strides());
// 3. Initialize the source tensor with known values
// For KV layout [2, number_of_blocks, block_size, number_of_heads/tp_size, head_size]
// values will be: kv*1000 + block*100 + bs*10 + head*1 + dim*0.1
let mut nd_view = src_view.as_ndarray_view_mut::<f32>()?;
for idx_0 in 0..src_shape[0] {
for idx_1 in 0..src_shape[1] {
for bs_idx in 0..src_shape[2] {
for head_idx in 0..src_shape[3] {
for head_dim_idx in 0..src_shape[4] {
let value = 1000.0 * idx_0 as f32
+ 100.0 * idx_1 as f32
+ 10.0 * bs_idx as f32
+ 1.0 * head_idx as f32
+ 0.1 * head_dim_idx as f32;
nd_view[[idx_0, idx_1, bs_idx, head_idx, head_dim_idx]] = value;
// println!(
// "Init: [idx_0={}, idx_1={}, bs={}, h={}, hd={}] = {}",
// idx_0, idx_1, bs_idx, head_idx, head_dim_idx, value
// );
}
}
}
}
}
}
// Get source layer shape for later use
let src_layer = src_blocks.layer(0)?;
let src_view = src_layer.view()?;
let src_shape = *src_view.shape();
// 4. Allocate device storage for the test
let dst_blocks = KvBlockStorage::allocate(
number_of_blocks,
block_details.clone(),
StorageType::Device(device.clone()),
)?;
// 5. Create a copy stream for the tensor transpose
let mut copy_stream = CopyStream::new(1, number_of_blocks)?;
// Set up block mapping - for the permutation, the same block ids are used for source and destination
let src_block_ids: Vec<i32> = (0..number_of_blocks).map(|id| id as i32).collect();
let dst_block_ids: Vec<i32> = (0..number_of_blocks).map(|id| id as i32).collect();
let h2d_block_map = CopyStreamBlockMap::new(&src_blocks, &dst_blocks)?;
let d2h_block_map = CopyStreamBlockMap::new(&dst_blocks, &src_blocks)?;
copy_stream.prepare_block_map(h2d_block_map)?;
copy_stream.prepare_block_ids(src_block_ids.clone(), dst_block_ids.clone())?;
// 6. Set up the dimensions for the scatter copy, which reshards the head
// dimension from src_tp_size to dst_tp_size partitions
let dims = src_shape.to_vec();
let elem_size = model_details.dtype.size_in_bytes();
// println!("Dimensions: {:?}", dims);
// println!("Source strides (bytes): {:?}", src_strides);
// println!("Element size: {} bytes", elem_size);
// 7. Execute the scatter copy
copy_stream.scatter_copy_layer(
0,
&dims,
elem_size,
block_dim_idx,
src_tp_size,
dst_tp_size,
)?;
copy_stream.sync_stream()?;
// 8. Preserve a copy of the source tensor for validation
let src_layer = src_blocks.layer(0)?;
let src_view = src_layer.view()?;
let src_nd = src_view.as_ndarray_view::<f32>()?;
let expected = src_nd.to_owned();
// reshape expected to match the src layout in 6d
let expected_6d = expected.into_shape_with_order([
src_shape[0],
src_shape[1],
src_shape[2],
scatter_factor,
src_shape[3] / scatter_factor,
src_shape[4],
])?;
let expected = expected_6d.permuted_axes([3, 0, 1, 2, 4, 5]).into_dyn();
// 9. Copy results back to host for verification
copy_stream.on_return(); // reset
copy_stream.prepare_block_map(d2h_block_map)?;
copy_stream.prepare_block_ids(src_block_ids.clone(), dst_block_ids.clone())?;
copy_stream.trigger_all_layers()?;
copy_stream.sync_stream()?;
// the host data should be updated with the values from the device
let src_layer = src_blocks.layer(0)?;
let src_view = src_layer.view()?;
let src_nd = src_view.as_ndarray_view::<f32>()?;
let actual = src_nd.to_owned();
// reshape actual to match the dst layout in 6d
let actual = actual.into_shape_with_order([
scatter_factor,
src_shape[0],
src_shape[1],
src_shape[2],
src_shape[3] / scatter_factor,
src_shape[4],
])?;
// 10. Validate
// the shapes of expected and actual should be the same
assert_eq!(expected.shape(), actual.shape());
println!("Output Shape: {:?}", actual.shape());
// check that the values are the same
// Compare all elements with a small epsilon to account for potential floating point differences
let epsilon = 1e-6;
let mut all_match = true;
for (idx, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() {
if (expected_val - actual_val).abs() > epsilon {
println!(
"Mismatch at index {}: expected {}, got {}",
idx, expected_val, actual_val
);
all_match = false;
break;
}
}
assert!(all_match, "Tensor values don't match after permutation");
Ok(())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use reuse::AvailableBlocks;
/// Manages the reservation and priority reuse of kv blocks for a single storage type,
/// e.g. a GPU or host memory.
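///
/// A minimal usage sketch (block size is illustrative; assumes a Tokio runtime
/// and an available-blocks pool that has already been populated):
///
/// ```rust,ignore
/// let mut manager = KvStorageManager::new(64).await;
/// // `tokens` is a Tokens sequence built elsewhere
/// let matched = manager.prepare_prefill_sequence(tokens).await?;
/// let offload = manager.prepare_prefill_offload(matched).await?;
/// ```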
pub struct KvStorageManager {
available_blocks: AvailableBlocks,
inflight_blocks: ReservedBlocks,
block_size: usize,
}
impl KvStorageManager {
pub async fn new(block_size: usize) -> Self {
Self {
available_blocks: AvailableBlocks::new().await,
inflight_blocks: ReservedBlocks::new(block_size),
block_size,
}
}
pub async fn prepare_prefill_sequence(&mut self, tokens: Tokens) -> Result<PrefillMatched> {
log::debug!("adding request with {} tokens", tokens.len());
let seq = tokens.into_sequence(self.block_size);
let (blocks, tail_block) = seq.into_parts();
log::debug!(
"request translates to {} blocks; remaining tokens: {}",
blocks.len(),
tail_block.tokens().len()
);
// first match blocks to inflight blocks
let mut inflight_blocks = self.inflight_blocks.match_token_blocks(&blocks)?;
log::debug!("matched {} inflight blocks", inflight_blocks.len());
// skip the prefix of blocks already matched to inflight blocks
let unmatched_blocks = &blocks[inflight_blocks.len()..];
let unmatched_hashes = unmatched_blocks
.iter()
.map(|b| b.sequence_hash())
.collect::<Vec<_>>();
// match the remaining blocks to freed gpu blocks (available_blocks)
let unregistered_blocks = self.available_blocks.match_blocks(unmatched_hashes).await?;
log::debug!("matched {} freed blocks", unregistered_blocks.len());
// the blocks from the freed blocks pool must be registered as inflight blocks
// todo - we might have to register the list of unregistered blocks as a single transaction
for block in unregistered_blocks {
inflight_blocks.push(self.inflight_blocks.register(block)?);
}
// the remaining blocks are the unmatched blocks
let remaining_blocks = blocks.into_iter().skip(inflight_blocks.len()).collect();
Ok(PrefillMatched {
inflight_blocks,
remaining_blocks,
tail_block,
})
}
pub async fn prepare_prefill_offload(
&mut self,
matched: PrefillMatched,
) -> Result<PrefillOffload> {
let (inflight_blocks, remaining_blocks, tail_block) = matched.dissolve();
let mut blocks_to_reuse = self
.available_blocks
.take_blocks(remaining_blocks.len() as u32 + 1)
.await?;
if blocks_to_reuse.len() != remaining_blocks.len() + 1 {
raise!(
"expected {} blocks, got {}",
remaining_blocks.len() + 1,
blocks_to_reuse.len()
);
}
// update the blocks_to_reuse with the token block from remaining_blocks
let complete_prefill_blocks: Vec<UniqueBlock> = remaining_blocks
.into_iter()
.map(|b| {
let mut block = blocks_to_reuse.pop().unwrap();
block.update_token_block(b);
block
})
.collect();
assert_eq!(blocks_to_reuse.len(), 1);
let tail_kv_block = blocks_to_reuse.pop().unwrap();
let tail_prefill_block = PartialKvBlock {
token_block: tail_block,
kv_block: tail_kv_block,
};
Ok(PrefillOffload {
inflight_blocks,
complete_prefill_blocks,
tail_prefill_block,
})
}
}
#[derive(Dissolve)]
pub struct PartialKvBlock {
token_block: PartialTokenBlock,
kv_block: UniqueBlock,
}
#[derive(Dissolve)]
pub struct PrefillMatched {
inflight_blocks: Vec<ReservedBlock>,
remaining_blocks: Vec<TokenBlock>,
tail_block: PartialTokenBlock,
}
#[derive(Dissolve)]
pub struct PrefillOffload {
inflight_blocks: Vec<ReservedBlock>,
complete_prefill_blocks: Vec<UniqueBlock>,
tail_prefill_block: PartialKvBlock,
}
// #[cfg(test)]
// mod tests {
// use super::*;
// use dynamo_runtime::logging::init;
// #[tokio::test]
// async fn test() {
// init();
// let mut manager = KvStorageManager::new(2);
// for _ in 0..100 {
// manager.available_blocks.insert(KvBlock::default());
// }
// let tokens = Tokens::from([0_i32, 1, 2, 3, 4, 5, 6, 7, 8].as_ref());
// // this is good for the scheduler to make a local decision as it now knows how many
// // net-new blocks need to be prefilled
// let sequence = manager.prepare_prefill_sequence(tokens).unwrap();
// assert_eq!(sequence.inflight_blocks.len(), 0);
// assert_eq!(sequence.remaining_blocks.len(), 4);
// }
// }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Weak;
use super::*;
type ReservedBlockMap = Arc<RwLock<HashMap<SequenceHash, Weak<ReservedBlockInner>>>>;
#[derive(Clone)]
pub struct ReservedBlock {
inner: Arc<ReservedBlockInner>,
}
impl ReservedBlock {
fn new(inner: Arc<ReservedBlockInner>) -> Self {
Self { inner }
}
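/// Number of live [ReservedBlock] handles sharing this block
/// (i.e. the strong count of the inner Arc).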
pub fn inflight_count(&self) -> usize {
Arc::strong_count(&self.inner)
}
}
impl std::ops::Deref for ReservedBlock {
type Target = SharedBlock;
fn deref(&self) -> &Self::Target {
&self.inner.block
}
}
struct ReservedBlockInner {
block: SharedBlock,
map: ReservedBlockMap,
}
impl Drop for ReservedBlockInner {
fn drop(&mut self) {
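// The map entry may already point at a *newer* inner registered under the
// same sequence hash; in that case the removed weak pointer still has a
// positive strong count and is reinserted below.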
let sequence_hash = self.block.token_block.sequence_hash();
let mut map = self.map.write().unwrap();
let val = map.remove(&sequence_hash);
if let Some(inner) = val {
if inner.strong_count() > 0 {
// this was not the weak pointer we were looking for
map.insert(sequence_hash, inner);
}
}
}
}
/// [ReservedBlocks] is a collection of inflight blocks that are actively being used
pub struct ReservedBlocks {
block_size: usize,
blocks: ReservedBlockMap,
}
impl ReservedBlocks {
pub fn new(block_size: usize) -> Self {
Self {
block_size,
blocks: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn match_sequence_hashes(
&self,
sequence_hashes: &[SequenceHash],
) -> Result<Vec<ReservedBlock>> {
let mut inflight_blocks = Vec::new();
let map = self.blocks.read().unwrap();
for sequence_hash in sequence_hashes {
if let Some(inner) = map.get(sequence_hash) {
if let Some(inner) = inner.upgrade() {
inflight_blocks.push(ReservedBlock::new(inner.clone()));
} else {
break;
}
} else {
break;
}
}
Ok(inflight_blocks)
}
/// Match the list of blocks to inflight blocks
///
/// This will return a [Vec<ReservedBlock>] that match the sequence hashes
/// in the order of the token blocks.
///
/// The matching is done in order, with the first block in the list being the first
/// block in the token blocks list.
///
/// If a block is not found, the function will return the list of matched blocks
/// and the remaining blocks will not be included.
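///
/// For example, if blocks for hashes `[A, B, D]` are inflight and the token
/// blocks hash to `[A, B, C, D]`, only `[A, B]` are returned: matching stops
/// at the first miss (`C`), so `D` is not matched even though it is present.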
pub fn match_token_blocks(&self, token_blocks: &[TokenBlock]) -> Result<Vec<ReservedBlock>> {
let mut inflight_blocks = Vec::new();
let map = self.blocks.read().unwrap();
for token_block in token_blocks {
let sequence_hash = token_block.sequence_hash();
if let Some(inner) = map.get(&sequence_hash) {
if let Some(inner) = inner.upgrade() {
inflight_blocks.push(ReservedBlock::new(inner.clone()));
} else {
break;
}
} else {
break;
}
}
Ok(inflight_blocks)
}
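/// Register a [UniqueBlock] as an inflight [ReservedBlock].
///
/// If a live entry already exists for the same sequence hash, the passed-in
/// block is dropped (returning it to the free pool) and a handle to the
/// existing block is returned instead.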
pub fn register(&mut self, block: UniqueBlock) -> Result<ReservedBlock> {
let sequence_hash = block.token_block.sequence_hash();
let shared = block.into_shared();
if shared.token_block.tokens().len() != self.block_size {
raise!("Block size mismatch");
}
// if the block already exists, we drop the block the user passed in and return the existing block
// this should return the passed in block to the free pool
let mut map = self.blocks.write().unwrap();
if let Some(existing_block) = map.get(&sequence_hash) {
// return a ReservedBlock wrapping the existing block
// the passed in block will be dropped and returned to the pool
// this could happen if two sequences are building the same block at the same time,
// the first sequence to finish and register the block will insert it into the map
if let Some(inner) = existing_block.upgrade() {
return Ok(ReservedBlock::new(inner.clone()));
}
}
// Insert the new block and create a ReservedBlock from it
let inner = Arc::new(ReservedBlockInner {
block: shared,
map: self.blocks.clone(),
});
map.insert(sequence_hash, Arc::downgrade(&inner));
Ok(ReservedBlock::new(inner))
}
}
#[cfg(test)]
mod tests {
use super::*;
use super::reuse::tests::{create_blocks, create_token_sequence};
use super::reuse::AvailableBlocks;
#[tokio::test]
async fn test_reserved_blocks() {
let available_blocks = AvailableBlocks::new().await;
let mut reserved_blocks = ReservedBlocks::new(2);
// Create two sequences with different priorities
let seq1 = create_token_sequence(&[1, 2, 3, 4]);
let seq2 = create_token_sequence(&[5, 6, 7, 8]);
// This creates new KvBlocks; normally this will be done when the block manager is initialized,
// but since we are not using the block manager in this test, we need to create them manually
let blocks1 = create_blocks(seq1, 2);
let blocks2 = create_blocks(seq2, 2);
// Insert Sequence 2
for block in blocks2.into_iter().rev() {
available_blocks.insert(block).await.unwrap();
}
// Insert Sequence 1
for block in blocks1.into_iter().rev() {
available_blocks.insert(block).await.unwrap();
}
available_blocks.fence().await.unwrap();
assert_eq!(available_blocks.total_blocks(), 4);
assert_eq!(available_blocks.available_blocks(), 4);
// Initialization of the KvBlocks is complete - there are 4 blocks with state in the available pool
// Mimic a request for 2 tokens and test the block matching sequence
// This pattern will be used in the KvBlockManager
let req1 = create_token_sequence(&[1, 2]);
let seq1 = req1.into_sequence(2);
let (blocks, tail_block) = seq1.into_parts();
assert_eq!(blocks.len(), 1);
assert_eq!(tail_block.tokens().len(), 0);
let matched = reserved_blocks.match_token_blocks(&blocks).unwrap();
assert_eq!(matched.len(), 0);
let matched = available_blocks.match_token_blocks(&blocks).await.unwrap();
assert_eq!(matched.len(), 1);
// possible improvement: update the API to take a Vec of unique blocks and return a Vec of reserved blocks
let reserved: Vec<ReservedBlock> = matched
.into_iter()
.map(|unique_block| reserved_blocks.register(unique_block).unwrap())
.collect();
assert_eq!(reserved.len(), 1);
assert_eq!(reserved[0].inflight_count(), 1);
assert_eq!(available_blocks.available_blocks(), 3);
// request 2: reuse the same blocks
// matching against the reserved blocks yields a new ReservedBlock, which should have a ref count of 2
let reserved2 = reserved_blocks.match_token_blocks(&blocks).unwrap();
assert_eq!(reserved2.len(), 1);
assert_eq!(reserved2[0].inflight_count(), 2);
assert_eq!(available_blocks.available_blocks(), 3);
drop(reserved2);
available_blocks.fence().await.unwrap();
assert_eq!(reserved[0].inflight_count(), 1);
assert_eq!(available_blocks.available_blocks(), 3);
drop(reserved);
available_blocks.fence().await.unwrap();
assert_eq!(available_blocks.available_blocks(), 4);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # KV Block Available Pool
//!
//! The Available Pool manages KV blocks that are not actively in use but retain their previous state.
//!
//! ## Key Features:
//!
//! - **State Preservation**: Blocks in the pool maintain their previous state and can be reused.
//!
//! - **Priority-Based FIFO**: Blocks are returned in first-in, first-out order within their priority levels.
//! Lower priority values are processed first, allowing important blocks to be retained longer.
//!
//! - **State Matching**: Blocks can be matched against their previous state instead of being taken randomly,
//! enabling efficient reuse of blocks with specific sequence hashes.
//!
//! - **Priority Management**: Priorities can be applied to blocks based on their sequence hash,
//! requiring some external knowledge of the block's characteristics.
//!
//! - **State Management**: Blocks can have their states wiped clean/reset individually or in groups.
//! The entire pool can also be reset as needed.
//!
//! - **Synchronization**: Fence operations ensure all higher priority operations have completed
//! before proceeding. Note that this is not a true fence - higher priority operations issued
//! after the fence will still be processed before the fence completes.
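//!
//! A minimal usage sketch (assumes a Tokio runtime and `KvBlock` values
//! constructed elsewhere; see the tests at the bottom of this file):
//!
//! ```rust,ignore
//! let pool = AvailableBlocks::new().await;
//! pool.insert(block).await?;
//! pool.fence().await?; // wait until the insert has been processed
//! let matched = pool.match_blocks(hashes).await?; // ordered prefix match
//! let taken = pool.take_blocks(2).await?;         // priority, then FIFO
//! ```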
use std::sync::atomic::Ordering;
use dynamo_runtime::utils::pool::ReturnHandle;
use tokio::{
sync::{mpsc, oneshot},
task::JoinHandle,
};
use super::*;
pub struct AvailableBlocks {
match_tx: mpsc::UnboundedSender<MatchRequest>,
control_tx: mpsc::UnboundedSender<ControlRequest>,
fence_tx: mpsc::UnboundedSender<oneshot::Sender<()>>,
return_handle: Arc<ReturnHandleImpl>,
total_blocks: Arc<AtomicU64>,
available_blocks: Arc<AtomicU64>,
join_handle: JoinHandle<()>,
}
impl AvailableBlocks {
pub fn total_blocks(&self) -> u64 {
self.total_blocks.load(Ordering::SeqCst)
}
pub fn available_blocks(&self) -> u64 {
self.available_blocks.load(Ordering::SeqCst)
}
pub fn is_active(&self) -> bool {
!self.join_handle.is_finished()
}
pub async fn match_blocks(&self, hashes: Vec<SequenceHash>) -> Result<Vec<PoolItem<KvBlock>>> {
let (tx, rx) = oneshot::channel();
if self
.match_tx
.send(MatchRequest::MatchMultiple(MatchMultiple {
hashes,
return_handle: self.return_handle.clone(),
tx,
}))
.is_err()
{
raise!("failed to send match request; channel closed");
}
let matched_blocks = rx.await?;
Ok(matched_blocks)
}
pub async fn match_token_blocks(
&self,
token_blocks: &[TokenBlock],
) -> Result<Vec<PoolItem<KvBlock>>> {
let hashes: Vec<u64> = token_blocks.iter().map(|b| b.sequence_hash()).collect();
self.match_blocks(hashes).await
}
pub async fn take_blocks(&self, count: u32) -> Result<Vec<PoolItem<KvBlock>>> {
let (tx, rx) = oneshot::channel();
if self
.match_tx
.send(MatchRequest::Take(Take {
count,
return_handle: self.return_handle.clone(),
tx,
}))
.is_err()
{
raise!("failed to send take request; channel closed");
}
let matched_blocks = rx.await?;
Ok(matched_blocks)
}
pub async fn insert(&self, block: KvBlock) -> Result<()> {
let (tx, rx) = oneshot::channel();
if self
.control_tx
.send(ControlRequest::Insert(InsertControl { block, tx }))
.is_err()
{
raise!("failed to send insert request; channel closed");
}
rx.await?;
Ok(())
}
pub async fn update_single(&self, update: UpdateBlock) -> Result<()> {
let (tx, rx) = oneshot::channel();
if self
.control_tx
.send(ControlRequest::UpdateSingle(UpdateSingleControl {
update,
tx,
}))
.is_err()
{
raise!("failed to send update single request; channel closed");
}
rx.await?;
Ok(())
}
pub async fn update_multiple(&self, updates: Vec<UpdateBlock>) -> Result<()> {
let (tx, rx) = oneshot::channel();
if self
.control_tx
.send(ControlRequest::UpdateMultiple(UpdateMultipleControl {
updates,
tx,
}))
.is_err()
{
raise!("failed to send update multiple request; channel closed");
}
rx.await?;
Ok(())
}
pub async fn reset(&self, sequence_hashes: Vec<SequenceHash>) -> Result<()> {
let (tx, rx) = oneshot::channel();
if self
.control_tx
.send(ControlRequest::Reset(ResetControl {
sequence_hashes,
tx,
}))
.is_err()
{
raise!("failed to send reset request; channel closed");
}
rx.await?;
Ok(())
}
pub async fn reset_all(&self) -> Result<()> {
let (tx, rx) = oneshot::channel();
if self
.control_tx
.send(ControlRequest::ResetAll(ResetAllControl { tx }))
.is_err()
{
raise!("failed to send reset all request; channel closed");
}
rx.await?;
Ok(())
}
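/// Wait until all operations issued before this call have been processed
/// by the progress engine; see the module-level note on fence semantics.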
pub async fn fence(&self) -> Result<()> {
let (tx, rx) = oneshot::channel();
if self.fence_tx.send(tx).is_err() {
raise!("failed to send fence request; channel closed");
}
rx.await?;
Ok(())
}
}
struct ReturnHandleImpl {
return_tx: mpsc::UnboundedSender<PoolValue<KvBlock>>,
}
impl ReturnHandle<KvBlock> for ReturnHandleImpl {
fn return_to_pool(&self, value: PoolValue<KvBlock>) {
if self.return_tx.send(value).is_err() {
log::trace!("Failed to return block to pool");
}
}
}
impl AvailableBlocks {
pub async fn new() -> Self {
let (match_tx, match_rx) = mpsc::unbounded_channel();
let (return_tx, return_rx) = mpsc::unbounded_channel();
let (control_tx, control_rx) = mpsc::unbounded_channel();
let (fence_tx, fence_rx) = mpsc::unbounded_channel();
let total_blocks = Arc::new(AtomicU64::new(0));
let available_blocks = Arc::new(AtomicU64::new(0));
let return_tx_clone = return_tx.clone();
let return_handle = Arc::new(ReturnHandleImpl {
return_tx: return_tx_clone,
});
let join_handle = tokio::spawn(progress_engine(
match_rx,
return_rx,
control_rx,
fence_rx,
total_blocks.clone(),
available_blocks.clone(),
));
Self {
match_tx,
control_tx,
fence_tx,
return_handle,
total_blocks,
available_blocks,
join_handle,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct PriorityKey {
priority: u32,
return_tick: u64,
sequence_hash: SequenceHash,
}
// customize Ord and PartialOrd to sort first by priority (lowest to highest), then by return_tick (lowest to highest)
impl PartialOrd for PriorityKey {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for PriorityKey {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.priority
.cmp(&other.priority)
.then(self.return_tick.cmp(&other.return_tick))
// tiebreak on sequence_hash so the total order stays consistent with the derived Eq
.then(self.sequence_hash.cmp(&other.sequence_hash))
}
}
impl From<&KvBlock> for PriorityKey {
fn from(block: &KvBlock) -> Self {
Self {
priority: block.priority,
return_tick: block.return_tick,
sequence_hash: block.token_block.sequence_hash(),
}
}
}
#[derive(Default)]
struct AvailableBlocksState {
// Direct lookup by sequence_hash
lookup_map: HashMap<SequenceHash, PoolValue<KvBlock>>,
// Ordered by priority (lowest first), then by return_tick (oldest first)
priority_set: BTreeMap<PriorityKey, SequenceHash>,
// Blocks with no reusable state (uninitialized or duplicate hashes); plain FIFO
uninitialized_set: VecDeque<PoolValue<KvBlock>>,
// Monotonically increasing tick assigned on each insert/return
return_tick: u64,
// Total blocks
total_blocks: Arc<AtomicU64>,
// Available blocks
available_blocks: Arc<AtomicU64>,
}
impl AvailableBlocksState {
fn new(total_blocks: Arc<AtomicU64>, available_blocks: Arc<AtomicU64>) -> Self {
Self {
lookup_map: HashMap::new(),
priority_set: BTreeMap::new(),
uninitialized_set: VecDeque::new(),
return_tick: 0,
total_blocks,
available_blocks,
}
}
// Insert an item with a given key and sequence_hash
fn insert(&mut self, block: PoolValue<KvBlock>) {
let sequence_hash = block.token_block.sequence_hash();
log::debug!("inserting block {:?} into available blocks", sequence_hash);
// If we already have an entry for this sequence hash, we need to move it to the uninitialized set
// the lookup map has only one entry per sequence hash
if self.lookup_map.contains_key(&sequence_hash) || sequence_hash == 0u64 {
log::debug!("inserting block {:?} into the uninitialized set", sequence_hash);
self.uninitialized_set.push_back(block);
return;
}
// Insert into the priority set
let key = PriorityKey::from(&*block);
let check_multiple_entries = self.priority_set.insert(key, sequence_hash);
assert!(
check_multiple_entries.is_none(),
"fatal error: multiple entries for the same sequence hash in timestamp set"
);
// Add to the lookup map
let check_multiple_entries = self.lookup_map.insert(sequence_hash, block);
assert!(
check_multiple_entries.is_none(),
"fatal error: multiple entries for the same sequence hash in lookup map"
);
}
fn take_with_sequence_hash(
&mut self,
sequence_hash: SequenceHash,
) -> Option<PoolValue<KvBlock>> {
match self.lookup_map.remove(&sequence_hash) {
Some(block) => {
// Remove from the priority set
self.priority_set.remove(&PriorityKey::from(&*block));
Some(block)
}
None => None,
}
}
fn match_hashes(
&mut self,
hashes: Vec<SequenceHash>,
return_handle: Arc<ReturnHandleImpl>,
) -> Vec<PoolItem<KvBlock>> {
let mut matched_blocks = Vec::with_capacity(hashes.len());
for hash in hashes {
if let Some(block) = self.take_with_sequence_hash(hash) {
matched_blocks.push(self.create_pool_item(block, return_handle.clone()));
} else {
break;
}
}
self.available_blocks
.fetch_sub(matched_blocks.len() as u64, Ordering::SeqCst);
matched_blocks
}
fn handle_match_single(&mut self, match_single: MatchSingle) {
let (hash, return_handle, tx) = match_single.dissolve();
let matched_blocks = self.match_hashes(vec![hash], return_handle);
let optional_single = matched_blocks.into_iter().next();
// Send the result back through the channel
if tx.send(optional_single).is_err() {
log::trace!("Failed to send matched block to requester");
}
}
fn handle_match_multiple(&mut self, match_multiple: MatchMultiple) {
let (hashes, return_handle, tx) = match_multiple.dissolve();
let matched_blocks = self.match_hashes(hashes, return_handle);
// Send the matched blocks back through the channel
if tx.send(matched_blocks).is_err() {
log::trace!("Failed to send matched blocks to requester");
}
}
fn take(&mut self) -> Option<PoolValue<KvBlock>> {
// First take uninitialized blocks - they carry no reusable state,
// so taking them never evicts a cached block
if let Some(block) = self.uninitialized_set.pop_front() {
return Some(block);
}
// if we have blocks in the priority set, pop the first (it's sorted by priority)
// a fatal error will occur if the block is not found in the lookup map
if let Some((_key, sequence_hash)) = self.priority_set.pop_first() {
let block = match self.lookup_map.remove(&sequence_hash) {
Some(block) => block,
None => {
panic!("block from priority set not found in lookup map");
}
};
return Some(block);
}
None
}
fn handle_take(&mut self, take: Take) {
let (count, return_handle, tx) = take.dissolve();
let mut taken_blocks = Vec::with_capacity(count as usize);
for _ in 0..count {
if let Some(block) = self.take() {
taken_blocks.push(self.create_pool_item(block, return_handle.clone()));
} else {
break;
}
}
self.available_blocks.fetch_sub(
taken_blocks.len() as u64,
std::sync::atomic::Ordering::SeqCst,
);
// Send the result back through the channel
if tx.send(taken_blocks).is_err() {
log::trace!("Failed to send matched blocks to requester");
}
}
fn handle_match_request(&mut self, match_request: MatchRequest) {
match match_request {
MatchRequest::MatchSingle(match_single) => self.handle_match_single(match_single),
MatchRequest::MatchMultiple(match_multiple) => {
self.handle_match_multiple(match_multiple)
}
MatchRequest::Take(take) => self.handle_take(take),
}
}
fn handle_control_request(&mut self, control_request: ControlRequest) {
match control_request {
ControlRequest::Insert(insert) => {
let (block, tx) = insert.dissolve();
self.handle_insert(block);
if tx.send(()).is_err() {
log::trace!("Failed to send insert ack; receiver dropped");
}
}
ControlRequest::UpdateSingle(update_single) => {
let (update, tx) = update_single.dissolve();
self.handle_update_single(update);
if tx.send(()).is_err() {
log::trace!("Failed to send update single ack; receiver dropped");
}
}
ControlRequest::UpdateMultiple(update_multiple) => {
let (updates, tx) = update_multiple.dissolve();
self.handle_update_multiple(updates);
if tx.send(()).is_err() {
log::trace!("Failed to send update multiple ack; receiver dropped");
}
}
ControlRequest::Reset(reset) => {
let (sequence_hashes, tx) = reset.dissolve();
self.handle_reset(sequence_hashes);
if tx.send(()).is_err() {
log::trace!("Failed to send reset ack; receiver dropped");
}
}
ControlRequest::ResetAll(reset_all) => {
let tx = reset_all.dissolve();
self.handle_reset_all();
if tx.send(()).is_err() {
log::trace!("Failed to send reset all ack; receiver dropped");
}
}
}
}
fn handle_insert(&mut self, block: KvBlock) {
self.available_blocks
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
self.total_blocks
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
self.return_tick += 1;
// update the return tick
let mut block = block;
block.return_tick = self.return_tick;
self.insert(PoolValue::Direct(block));
}
fn handle_return(&mut self, block: PoolValue<KvBlock>) {
self.available_blocks
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
self.return_tick += 1;
// update the return tick
let mut block = block;
block.return_tick = self.return_tick;
self.insert(block);
}
fn handle_update_single(&mut self, update: UpdateBlock) {
self.update_block(vec![update]);
}
fn handle_update_multiple(&mut self, updates: Vec<UpdateBlock>) {
self.update_block(updates);
}
fn update_block(&mut self, updates: Vec<UpdateBlock>) {
for update in updates {
if let Some(mut block) = self.take_with_sequence_hash(update.hash) {
if let Some(priority) = update.priority {
block.priority = priority;
}
// if let Some(deadline) = update.deadline {
// block.set_deadline(deadline);
// }
self.insert(block);
}
}
}
fn handle_reset(&mut self, sequence_hashes: Vec<SequenceHash>) {
for hash in sequence_hashes {
if let Some(mut block) = self.take_with_sequence_hash(hash) {
block.reset();
self.insert(block);
}
}
}
fn handle_reset_all(&mut self) {
// for all blocks in the priority set, reset them
while let Some((_key, sequence_hash)) = self.priority_set.pop_first() {
if let Some(mut block) = self.lookup_map.remove(&sequence_hash) {
block.reset();
self.insert(block);
} else {
panic!("block from priority set not found in lookup map");
}
}
}
}
#[async_trait]
impl PoolExt<KvBlock> for AvailableBlocksState {}
#[derive(Dissolve)]
pub struct MatchSingle {
hash: SequenceHash,
return_handle: Arc<ReturnHandleImpl>,
tx: oneshot::Sender<Option<UniqueBlock>>,
}
#[derive(Dissolve)]
pub struct MatchMultiple {
hashes: Vec<SequenceHash>,
return_handle: Arc<ReturnHandleImpl>,
tx: oneshot::Sender<Vec<UniqueBlock>>,
}
#[derive(Dissolve)]
pub struct Take {
count: u32,
return_handle: Arc<ReturnHandleImpl>,
tx: oneshot::Sender<Vec<UniqueBlock>>,
}
pub enum MatchRequest {
MatchSingle(MatchSingle),
MatchMultiple(MatchMultiple),
Take(Take),
}
pub struct UpdateBlock {
hash: SequenceHash,
priority: Option<u32>,
}
#[derive(Dissolve)]
pub struct InsertControl {
block: KvBlock,
tx: oneshot::Sender<()>,
}
#[derive(Dissolve)]
pub struct UpdateSingleControl {
update: UpdateBlock,
tx: oneshot::Sender<()>,
}
#[derive(Dissolve)]
pub struct UpdateMultipleControl {
updates: Vec<UpdateBlock>,
tx: oneshot::Sender<()>,
}
#[derive(Dissolve)]
pub struct ResetControl {
sequence_hashes: Vec<SequenceHash>,
tx: oneshot::Sender<()>,
}
#[derive(Dissolve)]
pub struct ResetAllControl {
tx: oneshot::Sender<()>,
}
pub enum ControlRequest {
Insert(InsertControl),
UpdateSingle(UpdateSingleControl),
UpdateMultiple(UpdateMultipleControl),
Reset(ResetControl),
ResetAll(ResetAllControl),
}
pub async fn progress_engine(
match_rx: mpsc::UnboundedReceiver<MatchRequest>,
return_rx: mpsc::UnboundedReceiver<PoolValue<KvBlock>>,
ctrl_rx: mpsc::UnboundedReceiver<ControlRequest>,
fence_rx: mpsc::UnboundedReceiver<oneshot::Sender<()>>,
total_blocks: Arc<AtomicU64>,
available_blocks: Arc<AtomicU64>,
) {
let mut match_rx = match_rx;
let mut return_rx = return_rx;
let mut ctrl_rx = ctrl_rx;
let mut fence_rx = fence_rx;
let mut state = AvailableBlocksState::new(total_blocks, available_blocks);
loop {
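// `biased` polls the branches in declaration order: match requests drain
// before returned blocks, control requests, and finally fence acks, which
// is why a fence only orders against work queued before it was issued.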
tokio::select! {
biased;
Some(match_req) = match_rx.recv(), if !match_rx.is_closed() => {
state.handle_match_request(match_req);
}
Some(block) = return_rx.recv(), if !return_rx.is_closed() => {
state.handle_return(block);
}
Some(req) = ctrl_rx.recv(), if !ctrl_rx.is_closed() => {
state.handle_control_request(req);
}
Some(tx) = fence_rx.recv() => {
if tx.send(()).is_err() {
log::trace!("Failed to send fence ack; receiver dropped");
}
}
}
}
}
#[cfg(test)]
pub(crate) mod tests {
use crate::tokens::Token;
use super::*;
#[test]
fn test_priority_key_ord() {
let mut map = BTreeMap::new();
let hash1 = SequenceHash::from(1u64);
let hash2 = SequenceHash::from(2u64);
let hash3 = SequenceHash::from(3u64);
map.insert(
PriorityKey {
priority: 0,
return_tick: 1,
sequence_hash: hash1,
},
"value1",
);
map.insert(
PriorityKey {
priority: 1,
return_tick: 0,
sequence_hash: hash2,
},
"value2",
);
map.insert(
PriorityKey {
priority: 0,
return_tick: 2,
sequence_hash: hash3,
},
"value3",
);
let keys: Vec<_> = map.keys().collect();
// Priority is the primary sort key (0 before 1)
assert_eq!(keys[0].priority, 0);
assert_eq!(keys[1].priority, 0);
assert_eq!(keys[2].priority, 1);
// For same priority, return_tick is the secondary sort key
assert_eq!(keys[0].return_tick, 1);
assert_eq!(keys[1].return_tick, 2);
// Test popping from the map to verify ordering
let (first_key, first_value) = map.pop_first().unwrap();
assert_eq!(first_key.priority, 0);
assert_eq!(first_key.return_tick, 1);
assert_eq!(first_key.sequence_hash, hash1);
assert_eq!(first_value, "value1");
let (second_key, second_value) = map.pop_first().unwrap();
assert_eq!(second_key.priority, 0);
assert_eq!(second_key.return_tick, 2);
assert_eq!(second_key.sequence_hash, hash3);
assert_eq!(second_value, "value3");
let (third_key, third_value) = map.pop_first().unwrap();
assert_eq!(third_key.priority, 1);
assert_eq!(third_key.return_tick, 0);
assert_eq!(third_key.sequence_hash, hash2);
assert_eq!(third_value, "value2");
// Map should now be empty
assert!(map.is_empty());
}
// Helper function to create a sequence of tokens
pub fn create_token_sequence(values: &[u32]) -> Tokens {
let tokens: Vec<Token> = values.iter().map(|&v| Token::from(v)).collect();
Tokens::from(tokens)
}
// Helper to create blocks from a sequence with given size
pub fn create_blocks(sequence: Tokens, block_size: usize) -> Vec<KvBlock> {
let (blocks, _) = sequence.into_sequence(block_size).into_parts();
blocks
.into_iter()
.map(|token_block| KvBlock {
token_block,
..Default::default()
})
.collect()
}
#[tokio::test]
async fn test_basic_sequence_matching() {
let pool = AvailableBlocks::new().await;
// Create a sequence of 4 tokens split into blocks of 2
let sequence = create_token_sequence(&[1, 2, 3, 4]);
let blocks = create_blocks(sequence, 2);
assert_eq!(blocks.len(), 2);
// Match the blocks in sequence
let hashes: Vec<_> = blocks
.iter()
.map(|b| b.token_block.sequence_hash())
.collect();
// Insert blocks into pool
for block in blocks {
pool.insert(block).await.unwrap();
}
pool.fence().await.unwrap();
assert_eq!(pool.total_blocks(), 2);
assert_eq!(pool.available_blocks(), 2);
// Match the blocks in sequence
let matched = pool.match_blocks(hashes.clone()).await.unwrap();
assert_eq!(matched.len(), 2);
assert_eq!(pool.total_blocks(), 2);
assert_eq!(pool.available_blocks(), 0);
// Validate the blocks are in the correct order and match the sequence hashes
assert_eq!(matched[0].token_block.sequence_hash(), hashes[0]);
assert_eq!(matched[1].token_block.sequence_hash(), hashes[1]);
// Return blocks in reverse order (tail to root)
for block in matched.into_iter().rev() {
drop(block); // This will trigger return_to_pool
}
pool.fence().await.unwrap();
assert_eq!(pool.total_blocks(), 2);
assert_eq!(pool.available_blocks(), 2);
}
#[tokio::test]
async fn test_equal_priority_taking() {
let pool = AvailableBlocks::new().await;
// Create two sequences with different priorities
let seq1 = create_token_sequence(&[1, 2, 3, 4]);
let seq2 = create_token_sequence(&[5, 6, 7, 8]);
let mut blocks1 = create_blocks(seq1, 2);
let mut blocks2 = create_blocks(seq2, 2);
for block in blocks1.iter_mut() {
block.priority = 1;
}
for block in blocks2.iter_mut() {
block.priority = 1;
}
// All priorities are equal, so first-in, first-out order applies
// Insert Sequence 2
for block in blocks2.into_iter().rev() {
pool.insert(block).await.unwrap();
}
// Insert Sequence 1
for block in blocks1.into_iter().rev() {
pool.insert(block).await.unwrap();
}
pool.fence().await.unwrap();
let blocks = pool.take_blocks(4).await.unwrap();
assert_eq!(blocks.len(), 4);
// Validate the blocks are in the correct order
assert_eq!(blocks[0].token_block.tokens()[0], 7);
assert_eq!(blocks[1].token_block.tokens()[0], 5);
assert_eq!(blocks[2].token_block.tokens()[0], 3);
assert_eq!(blocks[3].token_block.tokens()[0], 1);
}
#[tokio::test]
async fn test_priority_taking() {
let pool = AvailableBlocks::new().await;
// Create two sequences with different priorities
let seq1 = create_token_sequence(&[1, 2, 3, 4]);
let seq2 = create_token_sequence(&[5, 6, 7, 8]);
let mut blocks1 = create_blocks(seq1, 2);
let mut blocks2 = create_blocks(seq2, 2);
for block in blocks1.iter_mut() {
block.priority = 1;
}
for block in blocks2.iter_mut() {
block.priority = 2;
}
// Here priorities differ: sequence 2 (priority 2) is inserted first,
// but sequence 1 (priority 1) carries the lower priority value.
// Lower priority values are taken first, so the sequence 1 blocks
// should come out before the sequence 2 blocks.
// Insert Sequence 2
for block in blocks2.into_iter().rev() {
pool.insert(block).await.unwrap();
}
// Insert Sequence 1
for block in blocks1.into_iter().rev() {
pool.insert(block).await.unwrap();
}
pool.fence().await.unwrap();
let blocks = pool.take_blocks(4).await.unwrap();
assert_eq!(blocks.len(), 4);
// Validate the blocks are in the correct order
assert_eq!(blocks[0].token_block.tokens()[0], 3);
assert_eq!(blocks[1].token_block.tokens()[0], 1);
assert_eq!(blocks[2].token_block.tokens()[0], 7);
assert_eq!(blocks[3].token_block.tokens()[0], 5);
}
#[tokio::test]
async fn test_priority_taking_after_update() {
let pool = AvailableBlocks::new().await;
// Create two sequences with different priorities
let seq1 = create_token_sequence(&[1, 2, 3, 4]);
let seq2 = create_token_sequence(&[5, 6, 7, 8]);
let mut blocks1 = create_blocks(seq1, 2);
let mut blocks2 = create_blocks(seq2, 2);
for block in blocks1.iter_mut() {
block.priority = 1;
}
for block in blocks2.iter_mut() {
block.priority = 1;
}
// record hash of blocks 2
// insert blocks 2, then blocks 1
// update priority of blocks 2 to 2 using the update api
// pull 4 blocks and test order
let block_hashes = blocks2
.iter()
.map(|b| b.token_block.sequence_hash())
.collect::<Vec<_>>();
// Insert Sequence 2
for block in blocks2.into_iter().rev() {
pool.insert(block).await.unwrap();
}
// Insert Sequence 1
for block in blocks1.into_iter().rev() {
pool.insert(block).await.unwrap();
}
pool.fence().await.unwrap();
// Update priority of blocks 2 to 2
pool.update_multiple(
block_hashes
.into_iter()
.map(|h| UpdateBlock {
hash: h,
priority: Some(2),
})
.collect(),
)
.await
.unwrap();
pool.fence().await.unwrap();
let blocks = pool.take_blocks(4).await.unwrap();
assert_eq!(blocks.len(), 4);
// Validate the blocks are in the correct order
assert_eq!(blocks[0].token_block.tokens()[0], 3);
assert_eq!(blocks[1].token_block.tokens()[0], 1);
assert_eq!(blocks[2].token_block.tokens()[0], 7);
assert_eq!(blocks[3].token_block.tokens()[0], 5);
}
#[tokio::test]
async fn test_reset_all() {
let pool = AvailableBlocks::new().await;
// Create two sequences with different priorities
let seq1 = create_token_sequence(&[1, 2, 3, 4]);
let seq2 = create_token_sequence(&[5, 6, 7, 8]);
let mut blocks1 = create_blocks(seq1, 2);
let mut blocks2 = create_blocks(seq2, 2);
for block in blocks1.iter_mut() {
block.priority = 1;
}
for block in blocks2.iter_mut() {
block.priority = 1;
}
// record hash of blocks 2
let block_hashes = blocks2
.iter()
.map(|b| b.token_block.sequence_hash())
.collect::<Vec<_>>();
// Insert Sequence 2
for block in blocks2.into_iter().rev() {
pool.insert(block).await.unwrap();
}
// Insert Sequence 1
for block in blocks1.into_iter().rev() {
pool.insert(block).await.unwrap();
}
// Reset All
pool.reset_all().await.unwrap();
pool.fence().await.unwrap();
// Try to match from block 2 hashes, expect no matches
let matched = pool.match_blocks(block_hashes).await.unwrap();
assert_eq!(matched.len(), 0);
}
#[tokio::test]
async fn test_reset_block2() {
let pool = AvailableBlocks::new().await;
// Create two sequences with different priorities
let seq1 = create_token_sequence(&[1, 2, 3, 4]);
let seq2 = create_token_sequence(&[5, 6, 7, 8]);
let mut blocks1 = create_blocks(seq1, 2);
let mut blocks2 = create_blocks(seq2, 2);
for block in blocks1.iter_mut() {
block.priority = 1;
}
for block in blocks2.iter_mut() {
block.priority = 1;
}
// record hash of blocks 2
let block2_hashes = blocks2
.iter()
.map(|b| b.token_block.sequence_hash())
.collect::<Vec<_>>();
let block1_hashes = blocks1
.iter()
.map(|b| b.token_block.sequence_hash())
.collect::<Vec<_>>();
// Insert Sequence 2
for block in blocks2.into_iter().rev() {
pool.insert(block).await.unwrap();
}
// Insert Sequence 1
for block in blocks1.into_iter().rev() {
pool.insert(block).await.unwrap();
}
// Reset Block 2
pool.reset(block2_hashes.clone()).await.unwrap();
pool.fence().await.unwrap();
// Try to match from block 2 hashes, expect no matches
let matched = pool.match_blocks(block2_hashes).await.unwrap();
assert_eq!(matched.len(), 0);
let matched = pool.match_blocks(block1_hashes).await.unwrap();
assert_eq!(matched.len(), 2);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Storage object representing large single slabs of bytes.
//!
//! There are three types denoted by [StorageType]
//!
//! - [StorageType::Device]: A pointer to a device memory allocation
//! - [StorageType::Pinned]: A pointer to a pinned memory allocation from cudaMallocHost
//! - [StorageType::System]: A pointer to a system memory allocation from malloc/calloc or
//! other forms of heap allocation.
//!
//! Use [StorageType::System] on Grace and other embedded platforms.
//!
//! Use [StorageType::Pinned] and [StorageType::Device] on traditional x86 platforms.
//!
//! WARNING: [Storage] and [OwnedStorage] are not Rust-safe objects. For KV blocks, we use
//! [Storage]-like slabs to form [KvLayers][super::layer::KvLayer], both of which do not
//! conform to Rust's ownership or safety guarantees.
//!
//! The underlying CUDA kernels have their own ownership policies, but these are not
//! guarantees, nor are they enforceable at this level by the Rust compiler.
//!
//! The first unit of ownership that will be Rust safe is the [KvBlock][super::KvBlock].
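//!
//! A minimal allocation sketch (sizes are illustrative; assumes a
//! CUDA-capable host for the pinned allocation):
//!
//! ```rust,ignore
//! let storage = OwnedStorage::create(4096, StorageType::Pinned)?;
//! let view = storage.view([2, 8, 64], DType::F32)?; // 1024 x f32 = 4096 bytes
//! assert_eq!(view.num_elements(), 2 * 8 * 64);
//! ```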
use bs62::num_traits;
use cudarc::driver::{CudaContext, CudaSlice, CudaStream, DevicePtr};
use dynamo_runtime::{error, raise, Result};
use ndarray::{ArrayViewMut, IxDyn};
use std::any::Any;
use std::ffi::c_void;
use std::ptr::NonNull;
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StorageType {
Device(Arc<CudaContext>),
Pinned,
System, // todo: for grace
}
/// Represents the data type of tensor elements
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
F32,
F16,
BF16,
FP8,
U8,
U16,
U32,
U64,
I8,
I16,
I32,
I64,
}
impl DType {
/// Get the size of the data type in bytes
pub fn size_in_bytes(&self) -> usize {
match self {
DType::F32 => 4,
DType::F16 => 2,
DType::BF16 => 2,
DType::FP8 => 1,
DType::U8 => 1,
DType::U16 => 2,
DType::U32 => 4,
DType::U64 => 8,
DType::I8 => 1,
DType::I16 => 2,
DType::I32 => 4,
DType::I64 => 8,
}
}
}
extern "C" {
fn cuda_malloc_host(ptr: *mut *mut c_void, size: usize) -> i32;
fn cuda_free_host(ptr: *mut c_void) -> i32;
fn cuda_memcpy_async(
dst: *mut c_void,
src: *const c_void,
count: usize,
stream: *mut c_void,
) -> i32;
fn cuda_memcpy_sync(dst: *mut c_void, src: *const c_void, count: usize) -> i32;
}
pub trait Storage: std::fmt::Debug {
/// Get memory pointer as a u64 for direct indexing
fn get_pointer(&self) -> u64;
/// Get the total storage size in bytes
fn storage_size(&self) -> usize;
/// Get the storage type of the tensor
fn storage_type(&self) -> StorageType;
/// Create a view of the tensor
fn view<const D: usize>(
&self,
shape: [usize; D],
dtype: DType,
) -> Result<TensorView<'_, Self, D>>
where
Self: Sized,
{
TensorView::new(self, shape, dtype.size_in_bytes())
}
}
#[derive(Clone)]
pub struct OwnedStorage {
storage: Arc<dyn Storage>,
}
impl OwnedStorage {
pub fn new(storage: Arc<dyn Storage>) -> Self {
Self { storage }
}
pub fn create(bytes: usize, storage_type: StorageType) -> Result<Self> {
match storage_type {
StorageType::Device(device) => Self::create_device_array(bytes, device),
StorageType::Pinned => Self::create_pinned_array(bytes),
StorageType::System => {
raise!("System memory not yet supported");
}
}
}
pub fn create_device_array(bytes: usize, device: Arc<CudaContext>) -> Result<Self> {
let device_storage = DeviceStorageOwned::new(bytes, device)?;
Ok(Self::new(Arc::new(device_storage)))
}
pub fn create_pinned_array(bytes: usize) -> Result<Self> {
let pinned_memory = CudaPinnedMemory::new(bytes)?;
Ok(Self::new(Arc::new(pinned_memory)))
}
pub fn byo_device_array(
device_ptr: u64,
bytes: usize,
device: Arc<CudaContext>,
owner: Arc<dyn Any + Send + Sync>,
) -> Result<Self> {
let device_storage = DeviceStorageFromAny::new(owner, device_ptr, bytes, device);
Ok(Self::new(Arc::new(device_storage)))
}
}
impl Storage for OwnedStorage {
fn get_pointer(&self) -> u64 {
self.storage.get_pointer()
}
fn storage_size(&self) -> usize {
self.storage.storage_size()
}
fn storage_type(&self) -> StorageType {
self.storage.storage_type()
}
}
impl std::fmt::Debug for OwnedStorage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OwnedStorage")
.field("storage_type", &self.storage.storage_type())
.finish()
}
}
pub struct DeviceStorageOwned {
bytes: usize,
cuda_device: Arc<CudaContext>,
cuda_slice: Arc<CudaSlice<u8>>,
}
impl DeviceStorageOwned {
pub fn new(bytes: usize, device: Arc<CudaContext>) -> Result<Self> {
let cuda_slice = device.default_stream().alloc_zeros::<u8>(bytes)?;
device.default_stream().synchronize()?;
Ok(Self {
bytes,
cuda_device: device,
cuda_slice: Arc::new(cuda_slice),
})
}
pub fn device_ptr(&self) -> *const c_void {
let ptr = self.cuda_slice.device_ptr();
(*ptr) as *const c_void
}
pub fn context(&self) -> Arc<CudaContext> {
self.cuda_device.clone()
}
}
impl Storage for DeviceStorageOwned {
fn get_pointer(&self) -> u64 {
self.device_ptr() as u64
}
fn storage_size(&self) -> usize {
self.bytes
}
fn storage_type(&self) -> StorageType {
StorageType::Device(self.cuda_device.clone())
}
}
impl std::fmt::Debug for DeviceStorageOwned {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Storage")
.field("storage_type", &self.storage_type())
.field("storage_size", &self.storage_size())
.finish()
}
}
/// Direct wrapper around CUDA pinned memory
pub struct CudaPinnedMemory {
/// Raw pointer to the pinned memory
ptr: NonNull<c_void>,
/// Size in bytes
bytes: usize,
}
unsafe impl Send for CudaPinnedMemory {}
unsafe impl Sync for CudaPinnedMemory {}
impl CudaPinnedMemory {
/// Allocate new pinned memory using CUDA
pub fn new(bytes: usize) -> Result<Self> {
if bytes == 0 {
raise!("Bytes must be greater than 0");
}
let mut ptr: *mut c_void = std::ptr::null_mut();
let result = unsafe { cuda_malloc_host(&mut ptr, bytes) };
if result != 0 {
raise!("Failed to allocate pinned memory");
}
// Safety: We just checked that the allocation succeeded
let ptr =
NonNull::new(ptr).ok_or_else(|| anyhow::anyhow!("Null pointer after allocation"))?;
// Zero out the memory
unsafe {
std::ptr::write_bytes(ptr.as_ptr() as *mut u8, 0, bytes);
}
Ok(Self { ptr, bytes })
}
/// Get raw pointer
pub fn as_ptr(&self) -> *const c_void {
self.ptr.as_ptr()
}
/// Get mutable raw pointer
pub fn as_mut_ptr(&mut self) -> *mut c_void {
self.ptr.as_ptr()
}
/// Get size in bytes
pub fn size(&self) -> usize {
self.bytes
}
}
impl Drop for CudaPinnedMemory {
fn drop(&mut self) {
let result = unsafe { cuda_free_host(self.ptr.as_ptr()) };
if result != 0 {
eprintln!("Failed to free pinned memory");
}
}
}
// Implement Storage trait for the new CudaPinnedMemory
impl Storage for CudaPinnedMemory {
fn get_pointer(&self) -> u64 {
self.ptr.as_ptr() as u64
}
fn storage_size(&self) -> usize {
self.bytes
}
fn storage_type(&self) -> StorageType {
StorageType::Pinned
}
}
impl std::fmt::Debug for CudaPinnedMemory {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CudaPinnedMemory")
.field("ptr", &(self.ptr.as_ptr() as usize))
.field("bytes", &self.bytes)
.field("storage_type", &self.storage_type())
.finish()
}
}
/// A view into tensor data with statically-known dimension count
#[derive(Clone)]
pub struct TensorView<'a, T: Storage, const D: usize> {
/// The underlying tensor storage
storage: &'a T,
/// Shape of the view (dimensions)
shape: [usize; D],
/// Strides for each dimension (in elements, not bytes)
strides: [usize; D],
/// Strides for each dimension (in bytes)
byte_strides: [usize; D],
/// Offset from the start of the storage, in bytes
offset: usize,
/// Element size in bytes
element_size: usize,
/// Total elements in this view
total_elements: usize,
}
impl<'a, T: Storage, const D: usize> TensorView<'a, T, D> {
/// Create a new tensor view from storage and shape
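///
/// Strides are row-major. For example, with `shape = [2, 3, 4]` and
/// `element_size = 4` (F32):
///
/// ```text
/// strides      = [12, 4, 1]  // elements: product of the trailing dims
/// byte_strides = [48, 16, 4] // strides * element_size
/// ```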
pub fn new(storage: &'a T, shape: [usize; D], element_size: usize) -> Result<Self> {
// Calculate row-major strides (in elements)
let mut strides = [0; D];
let mut byte_strides = [0; D];
if D > 0 {
strides[D - 1] = 1; // Rightmost dimension is contiguous (elements)
byte_strides[D - 1] = element_size; // Rightmost dimension in bytes
// Calculate remaining strides
for i in (0..D - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
byte_strides[i] = strides[i] * element_size;
}
}
// Calculate total elements
let total_elements = shape.iter().product();
// Validate that the view fits within the storage
if total_elements * element_size > storage.storage_size() {
return Err(error!(
"Shape {:?} requires {} bytes, but storage only has {} bytes",
shape,
total_elements * element_size,
storage.storage_size()
));
}
Ok(Self {
storage,
shape,
strides,
byte_strides,
offset: 0,
element_size,
total_elements,
})
}
/// Create a new tensor view with custom strides
pub fn with_strides(
storage: &'a T,
shape: [usize; D],
strides: [usize; D],
offset: usize,
element_size: usize,
) -> Result<Self, String> {
// Calculate byte strides using iterator
let byte_strides = strides.map(|stride| stride * element_size);
// Calculate total elements
let total_elements = shape.iter().product();
// Validate that the view fits within the storage:
// compute one past the last byte this view can access (start of the
// furthest element plus one element)
let end_offset = if D > 0 {
offset + Self::calculate_max_offset(&shape, &byte_strides) + element_size
} else {
offset
};
if end_offset > storage.storage_size() {
return Err(format!(
"View would access up to byte offset {}, but storage size is only {} bytes",
end_offset,
storage.storage_size()
));
}
Ok(Self {
storage,
shape,
strides,
byte_strides,
offset,
element_size,
total_elements,
})
}
/// Calculate the maximum byte offset that will be accessed by this view
fn calculate_max_offset(shape: &[usize; D], byte_strides: &[usize; D]) -> usize {
// Calculate the maximum offset by positioning at the furthest element
shape
.iter()
.zip(byte_strides.iter())
.map(|(&dim_size, &stride)| {
if dim_size > 0 {
(dim_size - 1) * stride
} else {
0
}
})
.sum()
}
/// Get the shape of the tensor view
pub fn shape(&self) -> &[usize; D] {
&self.shape
}
/// Get the strides of the tensor view (in elements)
pub fn strides(&self) -> &[usize; D] {
&self.strides
}
/// Get the byte strides of the tensor view
pub fn byte_strides(&self) -> &[usize; D] {
&self.byte_strides
}
/// Get the element size in bytes
pub fn element_size(&self) -> usize {
self.element_size
}
/// Validate indices against tensor shape
fn validate_indices(&self, indices: &[usize; D]) -> Result<(), String> {
for (dim, (&idx, &dim_size)) in indices.iter().zip(self.shape.iter()).enumerate() {
if idx >= dim_size {
return Err(format!(
"Index {} out of bounds for dimension {} with size {}",
idx, dim, dim_size
));
}
}
Ok(())
}
/// Calculate flat index from multi-dimensional indices (in elements)
pub fn flat_index(&self, indices: &[usize; D]) -> Result<usize, String> {
self.validate_indices(indices)?;
// Calculate flat index using zip for better performance
let flat_idx = indices
.iter()
.zip(self.strides.iter())
.fold(0, |acc, (&idx, &stride)| acc + idx * stride);
Ok(flat_idx)
}
/// Calculate byte offset for indices
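///
/// The result is `self.offset + sum(indices[i] * byte_strides[i])`; e.g.
/// with `byte_strides = [48, 16, 4]`, indices `[1, 2, 3]` resolve to
/// `1*48 + 2*16 + 3*4 = 92` bytes past the view's base offset.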
pub fn byte_offset(&self, indices: &[usize; D]) -> Result<usize> {
self.validate_indices(indices)
.map_err(|e| error!("{}", e))?;
// Calculate byte offset directly using byte_strides
let offset = indices
.iter()
.zip(self.byte_strides.iter())
.fold(self.offset, |acc, (&idx, &stride)| acc + idx * stride);
Ok(offset)
}
/// Get the absolute memory address for indices
pub fn address(&self, indices: &[usize; D]) -> Result<u64> {
let byte_offset = self.byte_offset(indices)?;
Ok(self.storage.get_pointer() + byte_offset as u64)
}
/// Check if indices are in bounds without calculating offset
pub fn in_bounds(&self, indices: &[usize; D]) -> bool {
indices
.iter()
.zip(self.shape.iter())
.all(|(&idx, &dim_size)| idx < dim_size)
}
/// Get the element value at the specified indices (for host-accessible tensors)
pub fn get_element<E: bytemuck::Pod + Copy>(&self, indices: &[usize; D]) -> Result<E> {
match self.storage.storage_type() {
StorageType::Device(_) => {
return Err(error!("Cannot directly access elements from device tensor"))
}
StorageType::System | StorageType::Pinned => {}
};
if std::mem::size_of::<E>() != self.element_size {
return Err(error!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<E>(),
self.element_size
));
}
let offset = self.byte_offset(indices)?;
let ptr = (self.storage.get_pointer() as *const u8).wrapping_add(offset) as *const E;
// Safety: We've validated the type size and the indices are in bounds
let value = unsafe { *ptr };
Ok(value)
}
/// Set the element value at the specified indices (for host-accessible tensors)
pub fn set_element<E: bytemuck::Pod + Copy>(
&mut self,
indices: &[usize; D],
value: E,
) -> Result<()> {
match self.storage.storage_type() {
StorageType::Device(_) => return Err(error!("Cannot directly modify device tensor")),
StorageType::System | StorageType::Pinned => {}
};
if std::mem::size_of::<E>() != self.element_size {
return Err(error!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<E>(),
self.element_size
));
}
let offset = self.byte_offset(indices)?;
let ptr = (self.storage.get_pointer() as *mut u8).wrapping_add(offset) as *mut E;
// Safety: We've validated the type size and the indices are in bounds
unsafe { *ptr = value };
Ok(())
}
/// Fill the tensor with a single value (for host-accessible tensors)
pub fn fill<E: bytemuck::Pod + Copy>(&mut self, value: E) -> Result<()> {
match self.storage.storage_type() {
StorageType::Device(_) => return Err(error!("Cannot directly modify device tensor")),
StorageType::System | StorageType::Pinned => {}
};
if std::mem::size_of::<E>() != self.element_size {
return Err(error!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<E>(),
self.element_size
));
}
if !self.is_contiguous() {
return Err(error!("Cannot fill non-contiguous tensor"));
}
let ptr = (self.storage.get_pointer() as *mut u8).wrapping_add(self.offset) as *mut E;
let len = self.total_elements;
// Safety: We've validated the type size and ensured contiguity
unsafe {
let slice = std::slice::from_raw_parts_mut(ptr, len);
slice.fill(value);
}
Ok(())
}
/// Check if the tensor has a standard row-major contiguous layout
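///
/// Sketch (illustrative): a `[2, 3]` view of 4-byte elements is contiguous
/// iff its element strides are `[3, 1]` and its byte strides are `[12, 4]`.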
pub fn is_contiguous(&self) -> bool {
if D == 0 {
return true;
}
let mut expected_stride = 1;
let mut expected_byte_stride = self.element_size;
for i in (0..D).rev() {
if self.strides[i] != expected_stride || self.byte_strides[i] != expected_byte_stride {
return false;
}
expected_stride *= self.shape[i];
expected_byte_stride *= self.shape[i];
}
true
}
/// Get the total number of elements in the view
pub fn num_elements(&self) -> usize {
self.total_elements
}
/// Get the pointer to the data
pub fn data(&self) -> u64 {
self.storage.get_pointer()
}
/// Get the total size in bytes
pub fn size_in_bytes(&self) -> usize {
self.total_elements * self.element_size
}
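/// Blocking copy into another view with identical shape and strides.
/// Both views must be contiguous; the transfer goes through `cuda_memcpy_sync`.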
pub fn copy_to_view_blocking<S: Storage>(
&self,
dst_view: &mut TensorView<'_, S, D>,
) -> Result<()> {
// validate same shape and strides
if self.shape != dst_view.shape || self.strides != dst_view.strides {
raise!(
"Shape or strides mismatch: {:?} vs {:?}",
self.shape,
dst_view.shape
);
}
if !self.is_contiguous() {
raise!("Source is not contiguous");
}
if !dst_view.is_contiguous() {
raise!("Destination is not contiguous");
}
assert_eq!(self.size_in_bytes(), dst_view.size_in_bytes());
tracing::debug!("Copying from {:?} to {:?}", self, dst_view);
let rc = unsafe {
cuda_memcpy_sync(
dst_view.data() as *mut c_void,
self.data() as *const c_void,
self.size_in_bytes(),
)
};
if rc != 0 {
raise!("cudaMemcpyAsync failed");
}
Ok(())
}
/// Create a sliced view of this tensor along a dimension
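///
/// A minimal usage sketch (illustrative only, not from the original tests):
/// ```ignore
/// // Keep rows 1..3 of the middle dimension of a [2, 3, 4] view.
/// let sliced = view.slice(1, 1, Some(3))?;
/// assert_eq!(sliced.shape(), &[2, 2, 4]);
/// // Strides are unchanged, so the slice is generally non-contiguous.
/// ```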
pub fn slice(&self, dim: usize, start: usize, end: Option<usize>) -> Result<Self, String> {
if dim >= D {
return Err(format!(
"Dimension {} out of bounds for tensor with {} dimensions",
dim, D
));
}
let end_idx = end.unwrap_or(self.shape[dim]);
if end_idx > self.shape[dim] {
return Err(format!(
"End index {} out of bounds for dimension {} with size {}",
end_idx, dim, self.shape[dim]
));
}
if start >= end_idx {
return Err(format!(
"Invalid slice range: start={}, end={}",
start, end_idx
));
}
// Create a new shape array with the sliced dimension
let mut new_shape = self.shape;
new_shape[dim] = end_idx - start;
// Calculate the offset for the start of the slice (in bytes)
let new_offset = self.offset + start * self.byte_strides[dim];
// Create a new view with the same strides but updated shape and offset
Ok(Self {
storage: self.storage,
shape: new_shape,
strides: self.strides,
byte_strides: self.byte_strides,
offset: new_offset,
element_size: self.element_size,
total_elements: new_shape.iter().product(),
})
}
pub fn as_ndarray_view<DT>(&self) -> Result<ndarray::ArrayView<'_, DT, IxDyn>>
// where
// DT: bytemuck::Pod,
{
match self.storage.storage_type() {
StorageType::Device(_) => raise!("Cannot convert device tensor to ndarray"),
StorageType::System | StorageType::Pinned => {}
};
self.as_unsafe_ndarray_view::<DT>()
}
pub(crate) fn as_unsafe_ndarray_view<DT>(&self) -> Result<ndarray::ArrayView<'_, DT, IxDyn>>
// where
// DT: bytemuck::Pod,
{
// validate DT matches bytes per element
if std::mem::size_of::<DT>() != self.element_size {
return Err(anyhow::anyhow!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<DT>(),
self.element_size
));
}
if !self.is_contiguous() {
raise!("Cannot convert non-contiguous tensor to ndarray");
}
// Create a slice from the raw pointer, honoring the view's byte offset
let ptr = (self.storage.get_pointer() as *const u8).wrapping_add(self.offset) as *const DT;
let size = self.shape.iter().product::<usize>();
let slice = unsafe { std::slice::from_raw_parts::<DT>(ptr, size) };
// Create an ndarray view from the slice
// Convert our shape array to ndarray's Dim type
let dim = ndarray::IxDyn(&self.shape);
let array = ndarray::ArrayView::from_shape(dim, slice)?;
Ok(array)
}
/// Convert to a mutable ndarray view
pub fn as_ndarray_view_mut<DT>(&mut self) -> Result<ArrayViewMut<'_, DT, IxDyn>>
where
DT: bytemuck::Pod,
{
match self.storage.storage_type() {
StorageType::Device(_) => {
return Err(anyhow::anyhow!("Cannot convert device tensor to ndarray"))
}
StorageType::System | StorageType::Pinned => {}
};
// validate DT matches bytes per element
if std::mem::size_of::<DT>() != self.element_size {
return Err(anyhow::anyhow!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<DT>(),
self.element_size
));
}
if !self.is_contiguous() {
return Err(anyhow::anyhow!(
"Cannot convert non-contiguous tensor to ndarray"
));
}
// Get the pointer to the data plus offset
let ptr =
(self.storage.get_pointer() as *mut DT).wrapping_add(self.offset / self.element_size);
let size = self.shape.iter().product::<usize>();
// Create a mutable slice from the raw pointer
let slice = unsafe { std::slice::from_raw_parts_mut(ptr, size) };
// Create an ndarray view from the slice - use the same pattern as the immutable version
let dim = ndarray::IxDyn(&self.shape);
let array = ndarray::ArrayViewMut::from_shape(dim, slice)?;
Ok(array)
}
/// Returns the storage type of the underlying tensor
pub fn storage_type(&self) -> StorageType {
self.storage.storage_type()
}
/// Returns an iterator over all valid indices for this tensor
/// This is useful for iterating through all elements in the tensor
pub fn indices_iter(&self) -> impl Iterator<Item = [usize; D]> + '_ {
let shape = self.shape;
let total = self.total_elements;
(0..total).map(move |idx| tensor_indexing::unflatten_index(idx, &shape))
}
/// Maps a function over all elements in the tensor (for host-accessible tensors)
/// Returns a new Vec containing the results
pub fn map_elements<E, R, F>(&self, f: F) -> Result<Vec<R>>
where
E: bytemuck::Pod + Copy,
F: Fn(E) -> R,
{
match self.storage.storage_type() {
StorageType::Device(_) => {
return Err(error!("Cannot directly access elements from device tensor"))
}
StorageType::System | StorageType::Pinned => {}
};
if std::mem::size_of::<E>() != self.element_size {
return Err(error!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<E>(),
self.element_size
));
}
if !self.is_contiguous() {
return Err(error!("Cannot map over elements of non-contiguous tensor"));
}
let ptr = (self.storage.get_pointer() as *const u8).wrapping_add(self.offset) as *const E;
let len = self.total_elements;
// Safety: We've validated the type size and ensured contiguity
let result = unsafe {
let slice = std::slice::from_raw_parts(ptr, len);
slice.iter().map(|&e| f(e)).collect()
};
Ok(result)
}
/// Gets a slice of the underlying data if it's contiguous and on the host
pub fn as_slice<E: bytemuck::Pod>(&self) -> Result<&[E]> {
match self.storage.storage_type() {
StorageType::Device(_) => return Err(error!("Cannot get slice from device tensor")),
StorageType::System | StorageType::Pinned => {}
};
if std::mem::size_of::<E>() != self.element_size {
return Err(error!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<E>(),
self.element_size
));
}
if !self.is_contiguous() {
return Err(error!("Cannot get slice from non-contiguous tensor"));
}
let ptr = (self.storage.get_pointer() as *const u8).wrapping_add(self.offset) as *const E;
let len = self.total_elements;
// Safety: We've validated the type size, alignment, and ensured contiguity
let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
Ok(slice)
}
/// Gets a mutable slice of the underlying data if it's contiguous and on the host
pub fn as_slice_mut<E: bytemuck::Pod>(&mut self) -> Result<&mut [E]> {
match self.storage.storage_type() {
StorageType::Device(_) => {
return Err(error!("Cannot get mutable slice from device tensor"))
}
StorageType::System | StorageType::Pinned => {}
};
if std::mem::size_of::<E>() != self.element_size {
return Err(error!(
"Type size mismatch: {} vs {}",
std::mem::size_of::<E>(),
self.element_size
));
}
if !self.is_contiguous() {
return Err(error!(
"Cannot get mutable slice from non-contiguous tensor"
));
}
let ptr = (self.storage.get_pointer() as *mut u8).wrapping_add(self.offset) as *mut E;
let len = self.total_elements;
// Safety: We've validated the type size, alignment, and ensured contiguity
let slice = unsafe { std::slice::from_raw_parts_mut(ptr, len) };
Ok(slice)
}
/// Copy data from host tensor (self) to device tensor (device_view)
///
/// This is a convenience method for copying data from a host tensor to a device tensor.
/// Both tensors must have the same shape, element size, and total number of elements.
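///
/// A minimal usage sketch (illustrative; `host_view`, `device_view`, and
/// `stream` are assumed to exist):
/// ```ignore
/// host_view.h2d(&mut device_view, &stream)?;
/// stream.synchronize()?; // the copy is asynchronous on `stream`
/// ```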
pub fn h2d<S: Storage>(
&self,
device_view: &mut TensorView<'_, S, D>,
stream: &CudaStream,
) -> Result<()> {
// Ensure self is a host tensor
match self.storage.storage_type() {
StorageType::Device(_) => {
return Err(error!("Source must be a host tensor (System or Pinned)"))
}
StorageType::System | StorageType::Pinned => {}
};
// Ensure device_view is a device tensor
match device_view.storage_type() {
StorageType::Device(_) => {}
_ => return Err(error!("Destination must be a device tensor")),
};
// Validate shape and element size
if self.shape != device_view.shape {
return Err(error!(
"Shape mismatch: {:?} vs {:?}",
self.shape, device_view.shape
));
}
if self.element_size != device_view.element_size {
return Err(error!(
"Element size mismatch: {} vs {}",
self.element_size, device_view.element_size
));
}
// Ensure contiguity for both tensors
if !self.is_contiguous() {
return Err(error!("Source tensor must be contiguous"));
}
if !device_view.is_contiguous() {
return Err(error!("Destination tensor must be contiguous"));
}
// Get pointers with proper offsets
let src_ptr =
(self.storage.get_pointer() as *const u8).wrapping_add(self.offset) as *const c_void;
let dst_ptr = (device_view.storage.get_pointer() as *mut u8)
.wrapping_add(device_view.offset) as *mut c_void;
let size_in_bytes = self.size_in_bytes();
let stream_id = stream.cu_stream();
// Perform the upload operation
let rc =
unsafe { cuda_memcpy_async(dst_ptr, src_ptr, size_in_bytes, stream_id as *mut c_void) };
if rc != 0 {
return Err(error!(
"cudaMemcpyAsync failed during host-to-device transfer"
));
}
Ok(())
}
/// Copy data from device tensor (self) to host tensor (host_view)
///
/// This is a convenience method for copying data from a device tensor to a host tensor.
/// Both tensors must have the same shape, element size, and total number of elements.
pub fn d2h<S: Storage>(
&self,
host_view: &mut TensorView<'_, S, D>,
stream: &CudaStream,
) -> Result<()> {
// Ensure self is a device tensor
match self.storage.storage_type() {
StorageType::Device(_) => {}
_ => return Err(error!("Source must be a device tensor")),
};
// Ensure host_view is a host tensor
match host_view.storage_type() {
StorageType::Device(_) => {
return Err(error!(
"Destination must be a host tensor (System or Pinned)"
))
}
StorageType::System | StorageType::Pinned => {}
};
// Validate shape and element size
if self.shape != host_view.shape {
return Err(error!(
"Shape mismatch: {:?} vs {:?}",
self.shape, host_view.shape
));
}
if self.element_size != host_view.element_size {
return Err(error!(
"Element size mismatch: {} vs {}",
self.element_size, host_view.element_size
));
}
// Ensure contiguity for both tensors
if !self.is_contiguous() {
return Err(error!("Source tensor must be contiguous"));
}
if !host_view.is_contiguous() {
return Err(error!("Destination tensor must be contiguous"));
}
// Get pointers with proper offsets
let src_ptr =
(self.storage.get_pointer() as *const u8).wrapping_add(self.offset) as *const c_void;
let dst_ptr = (host_view.storage.get_pointer() as *mut u8).wrapping_add(host_view.offset)
as *mut c_void;
let size_in_bytes = self.size_in_bytes();
let stream_id = stream.cu_stream();
// Perform the download operation
let rc =
unsafe { cuda_memcpy_async(dst_ptr, src_ptr, size_in_bytes, stream_id as *mut c_void) };
if rc != 0 {
return Err(error!(
"cudaMemcpyAsync failed during device-to-host transfer"
));
}
Ok(())
}
/// Convert the tensor view to a new owned ndarray tensor in host memory
/// This is not a performant operation, and should only be used for testing
pub fn to_owned<DT: std::fmt::Debug + Clone + num_traits::Zero>(
&self,
) -> Result<ndarray::Array<DT, IxDyn>> {
match self.storage.storage_type() {
StorageType::System | StorageType::Pinned => {
let nd = self.as_ndarray_view::<DT>()?;
Ok(nd.to_owned())
}
StorageType::Device(_device) => {
// create an ndarray with the same shape and element size
let shape = self.shape.to_vec();
// Create an ndarray with the correct shape
let dim = ndarray::IxDyn(&shape);
// Create an uninitialized array with the correct shape
let mut nd = ndarray::Array::<DT, _>::zeros(dim);
println!("Copying from device to host");
println!("Before copy Values: {:?}", nd);
let rc = unsafe {
cuda_memcpy_sync(
nd.as_mut_ptr() as *mut c_void,
self.storage.get_pointer() as *const c_void,
self.size_in_bytes(),
)
};
if rc != 0 {
return Err(error!(
"cudaMemcpyAsync failed during device-to-host transfer"
));
}
println!("After copy Values: {:?}", nd);
Ok(nd)
}
}
}
}
impl<T: Storage, const D: usize> std::fmt::Debug for TensorView<'_, T, D> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TensorView")
.field("shape", &self.shape)
.field("strides", &self.strides)
.field("byte_strides", &self.byte_strides)
.field("offset", &self.offset)
.field("element_size", &self.element_size)
.field("total_elements", &self.total_elements)
.field("storage_type", &self.storage.storage_type())
.finish()
}
}
// Indexing helpers with updated byte stride handling
pub mod tensor_indexing {
/// Converts a flat index to multidimensional indices for a given shape
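///
/// Sketch (illustrative): for a row-major shape `[2, 3]`, flat index 4 maps
/// to indices `[1, 1]`.
/// ```ignore
/// assert_eq!(unflatten_index(4, &[2, 3]), [1, 1]);
/// ```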
pub fn unflatten_index<const D: usize>(flat_idx: usize, shape: &[usize; D]) -> [usize; D] {
let mut indices = [0; D];
let mut remaining = flat_idx;
// Calculate strides for the shape
let mut strides = [0; D];
if D > 0 {
strides[D - 1] = 1;
for i in (0..D - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
}
// Calculate indices using strides
for (i, &stride) in strides.iter().enumerate() {
indices[i] = remaining / stride;
remaining %= stride;
}
indices
}
/// Calculates row-major strides for a given shape (element strides, not byte strides)
pub fn calculate_strides<const D: usize>(shape: &[usize; D]) -> [usize; D] {
let mut strides = [0; D];
if D > 0 {
strides[D - 1] = 1; // Rightmost dimension is contiguous
for i in (0..D - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
}
strides
}
/// Calculates row-major byte strides for a given shape and element size
pub fn calculate_byte_strides<const D: usize>(
shape: &[usize; D],
element_size: usize,
) -> [usize; D] {
let mut byte_strides = [0; D];
if D > 0 {
byte_strides[D - 1] = element_size; // Rightmost dimension is contiguous
for i in (0..D - 1).rev() {
byte_strides[i] = byte_strides[i + 1] * shape[i + 1];
}
}
byte_strides
}
}
/// Storage that wraps external device memory with metadata provided externally
/// This is unsafe as it trusts that the provided device pointer and sizes are valid
#[derive(Debug)]
pub struct DeviceStorageFromAny {
/// The original object that owns the memory (e.g., a PyObject)
source: Arc<dyn Any + Send + Sync>,
/// Device pointer to the data
device_ptr: u64,
/// Size of each element in bytes
bytes: usize,
/// CUDA device ordinal
device: Arc<CudaContext>,
}
impl DeviceStorageFromAny {
/// Create a new DeviceStorageFromAny wrapper
///
/// # Safety
///
/// This is unsafe because it trusts that:
/// 1. The device_ptr is a valid CUDA device pointer
/// 2. The device_ptr points to at least elements * bytes_per_element bytes of valid memory
/// 3. The memory remains valid for the lifetime of this object
/// 4. The device_id corresponds to the device where the memory is allocated
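///
/// A hypothetical sketch (names invented for illustration):
/// ```ignore
/// let storage = DeviceStorageFromAny::new(
///     Arc::new(framework_tensor), // keeps the owning object alive
///     device_ptr,                 // u64 CUDA device pointer
///     num_bytes,                  // total size of the allocation
///     cuda_ctx.clone(),
/// );
/// ```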
pub fn new(
source: Arc<dyn Any + Send + Sync>,
device_ptr: u64,
bytes: usize,
device: Arc<CudaContext>,
) -> Self {
Self {
source,
device_ptr,
bytes,
device,
}
}
/// Get the original source object as Any
pub fn source(&self) -> &Arc<dyn Any + Send + Sync> {
&self.source
}
/// Try to downcast the source to a specific type
pub fn downcast_source<T: 'static + Send + Sync>(&self) -> Option<&T> {
self.source.downcast_ref::<T>()
}
}
impl Storage for DeviceStorageFromAny {
fn get_pointer(&self) -> u64 {
self.device_ptr
}
fn storage_size(&self) -> usize {
self.bytes
}
fn storage_type(&self) -> StorageType {
StorageType::Device(self.device.clone())
}
}
#[cfg(test)]
mod tests {
use super::*;
// Mock implementation of Storage for testing
#[derive(Debug)]
struct MockTensor {
data_ptr: u64,
storage_size_bytes: usize,
}
impl Storage for MockTensor {
fn get_pointer(&self) -> u64 {
self.data_ptr
}
fn storage_size(&self) -> usize {
self.storage_size_bytes
}
fn storage_type(&self) -> StorageType {
StorageType::System
}
}
#[test]
fn test_tensor_view_creation() {
// Create a mock tensor with sufficient storage
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 96, // 24 elements * 4 bytes
};
// Create a 3D tensor view with F32 elements
let shape = [2, 3, 4];
let element_size = 4; // F32 size
let view = TensorView::<_, 3>::new(&mock_tensor, shape, element_size).unwrap();
// Verify shape and strides
assert_eq!(view.shape(), &[2, 3, 4]);
assert_eq!(view.strides(), &[12, 4, 1]);
assert_eq!(view.byte_strides(), &[48, 16, 4]);
assert_eq!(view.num_elements(), 24);
assert_eq!(view.size_in_bytes(), 96);
assert!(view.is_contiguous());
}
#[test]
fn test_tensor_view_indexing() {
// Error shows: "Shape [2, 3, 4] requires 96 bytes, but storage only has 24 bytes"
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 96, // Increase from 24 to 96 bytes
};
// Create a 3D tensor view
let shape = [2, 3, 4];
let view = TensorView::<_, 3>::new(&mock_tensor, shape, 4).unwrap();
// Test flat index calculations
assert_eq!(view.flat_index(&[0, 0, 0]).unwrap(), 0);
assert_eq!(view.flat_index(&[0, 0, 1]).unwrap(), 1);
assert_eq!(view.flat_index(&[0, 1, 0]).unwrap(), 4);
assert_eq!(view.flat_index(&[1, 0, 0]).unwrap(), 12);
assert_eq!(view.flat_index(&[1, 2, 3]).unwrap(), 23);
// Test byte offset calculations
assert_eq!(view.byte_offset(&[0, 0, 0]).unwrap(), 0);
assert_eq!(view.byte_offset(&[0, 0, 1]).unwrap(), 4);
assert_eq!(view.byte_offset(&[0, 1, 0]).unwrap(), 16);
assert_eq!(view.byte_offset(&[1, 0, 0]).unwrap(), 48);
// Test absolute address calculations
assert_eq!(view.address(&[0, 0, 0]).unwrap(), 0x1000);
assert_eq!(view.address(&[0, 0, 1]).unwrap(), 0x1004);
assert_eq!(view.address(&[1, 2, 3]).unwrap(), 0x1000 + 23 * 4);
}
#[test]
fn test_tensor_view_slicing() {
// Error shows: "Shape [2, 3, 4] requires 96 bytes, but storage only has 24 bytes"
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 96, // Increase from 24 to 96 bytes
};
// Create a 3D tensor view
let shape = [2, 3, 4];
let view = TensorView::<_, 3>::new(&mock_tensor, shape, 4).unwrap();
// Create a slice along dimension 1 (the middle dimension)
let sliced = view.slice(1, 1, Some(3)).unwrap();
// Verify the slice properties
assert_eq!(sliced.shape(), &[2, 2, 4]); // Dimension 1 reduced from 3 to 2
assert_eq!(sliced.strides(), &[12, 4, 1]); // Strides remain the same
assert_eq!(sliced.byte_strides(), &[48, 16, 4]); // Byte strides remain the same
assert_eq!(sliced.offset, 16); // Offset is now 4 elements (16 bytes)
// Test addressing in the slice
assert_eq!(sliced.address(&[0, 0, 0]).unwrap(), 0x1000 + 16);
assert_eq!(
sliced.address(&[1, 1, 3]).unwrap(),
0x1000 + 16 + 48 + 16 + 12
);
}
#[test]
fn test_tensor_views_with_custom_strides() {
// Error shows: "Shape [2, 3, 4] requires 96 bytes, but storage only has 24 bytes"
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 96, // Increase from 24 to 96 bytes
};
// Total storage: 24 elements * 4 bytes = 96 bytes
let shape = [2, 3, 4];
// CASE 1: Standard contiguous view
let contiguous_view = TensorView::<_, 3>::new(&mock_tensor, shape, 4).unwrap();
assert!(contiguous_view.is_contiguous());
assert_eq!(contiguous_view.strides(), &[12, 4, 1]);
assert_eq!(contiguous_view.byte_strides(), &[48, 16, 4]);
// CASE 2: Non-contiguous but within bounds
let smaller_shape = [2, 2, 4]; // 16 elements instead of 24
// These are the contiguous strides for shape [2, 3, 4] but NOT for [2, 2, 4]
// For shape [2, 2, 4], contiguous strides would be [8, 4, 1]
let non_contiguous_strides = [12, 4, 1];
let non_contiguous = TensorView::<_, 3>::with_strides(
&mock_tensor,
smaller_shape,
non_contiguous_strides,
0,
4,
)
.unwrap();
// It should NOT be contiguous since the strides don't match the shape
assert!(!non_contiguous.is_contiguous());
assert_eq!(non_contiguous.strides(), &[12, 4, 1]);
assert_eq!(non_contiguous.byte_strides(), &[48, 16, 4]);
// Test accessing the last element to confirm it's within bounds
let last_index = [1, 1, 3];
let byte_offset = non_contiguous.byte_offset(&last_index).unwrap();
assert_eq!(byte_offset, (12 + 4 + 3) * 4);
assert!(
byte_offset < mock_tensor.storage_size(),
"Byte offset {} should be less than storage size {}",
byte_offset,
mock_tensor.storage_size()
);
// CASE 3: Non-contiguous that exceeds bounds
// Using strides that will exceed the tensor's storage
// 1*16 + 2*4 + 3*1 = 16 + 8 + 3 = 27 elements, which is beyond our 24 elements
let invalid_custom_strides = [16, 4, 1];
let result =
TensorView::<_, 3>::with_strides(&mock_tensor, shape, invalid_custom_strides, 0, 4);
// Verify we get the expected error
assert!(result.is_err());
let error_msg = result.unwrap_err();
assert!(
error_msg.contains("would access up to byte offset 108"),
"Expected error about exceeding storage, got: {}",
error_msg
);
}
#[test]
fn test_tensor_view_with_offset() {
// Error shows: "View would access up to byte offset 108, but storage size is only 40 bytes"
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 120, // Increase from 40 to 120 bytes
};
// Shape is smaller than the full tensor to allow for offset
let shape = [2, 3, 4]; // 24 elements total
// Create a view with an offset of 4 elements (16 bytes)
let offset_view =
TensorView::<_, 3>::with_strides(&mock_tensor, shape, [12, 4, 1], 16, 4).unwrap();
// The view should still be contiguous
assert!(offset_view.is_contiguous());
// Check offset is preserved
assert_eq!(offset_view.offset, 16);
// Test accessing the first element
let first_byte_offset = offset_view.byte_offset(&[0, 0, 0]).unwrap();
assert_eq!(first_byte_offset, 16); // Should be at the offset
// Test accessing the last element
let last_byte_offset = offset_view.byte_offset(&[1, 2, 3]).unwrap();
assert_eq!(last_byte_offset, 16 + (12 + 2 * 4 + 3) * 4);
// Creating a view with an offset that would exceed the tensor size should fail
let result = TensorView::<_, 3>::with_strides(
&mock_tensor,
shape,
[12, 4, 1],
80, // offset 80 + view extent 92 = 172 bytes, exceeding the 120-byte storage
4,
);
assert!(result.is_err());
let error_msg = result.unwrap_err();
assert!(
error_msg.contains("would access up to byte offset"),
"Expected error about exceeding storage, got: {}",
error_msg
);
}
#[test]
fn test_in_bounds_method() {
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 96, // 24 elements * 4 bytes
};
let shape = [2, 3, 4];
let view = TensorView::<_, 3>::new(&mock_tensor, shape, 4).unwrap();
// Test valid indices
assert!(view.in_bounds(&[0, 0, 0]));
assert!(view.in_bounds(&[1, 2, 3]));
// Test out-of-bounds indices
assert!(!view.in_bounds(&[2, 0, 0])); // First dimension too large
assert!(!view.in_bounds(&[0, 3, 0])); // Second dimension too large
assert!(!view.in_bounds(&[0, 0, 4])); // Third dimension too large
assert!(!view.in_bounds(&[2, 3, 4])); // All dimensions too large
}
#[test]
fn test_validate_indices() {
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 96, // 24 elements * 4 bytes
};
let shape = [2, 3, 4];
let view = TensorView::<_, 3>::new(&mock_tensor, shape, 4).unwrap();
// Test valid indices
assert!(view.validate_indices(&[0, 0, 0]).is_ok());
assert!(view.validate_indices(&[1, 2, 3]).is_ok());
// Test out-of-bounds indices
assert!(view.validate_indices(&[2, 0, 0]).is_err());
assert!(view.validate_indices(&[0, 3, 0]).is_err());
assert!(view.validate_indices(&[0, 0, 4]).is_err());
}
#[test]
fn test_indices_iter() {
let mock_tensor = MockTensor {
data_ptr: 0x1000,
storage_size_bytes: 24, // 6 elements * 4 bytes
};
// Create a 2x3 tensor
let shape = [2, 3];
let view = TensorView::<_, 2>::new(&mock_tensor, shape, 4).unwrap();
// Collect all indices from the iterator
let indices: Vec<[usize; 2]> = view.indices_iter().collect();
// Expected indices in row-major order
let expected_indices = vec![[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]];
assert_eq!(indices, expected_indices);
}
/// Real memory test for get_element and set_element
#[test]
fn test_get_set_element() {
use std::sync::{Arc, Mutex};
// Create a real memory tensor
#[derive(Debug)]
struct RealDataMock {
data: Arc<Mutex<Vec<u8>>>,
}
impl RealDataMock {
fn new(size_bytes: usize) -> Self {
Self {
data: Arc::new(Mutex::new(vec![0u8; size_bytes])),
}
}
}
impl Storage for RealDataMock {
fn get_pointer(&self) -> u64 {
self.data.lock().unwrap().as_ptr() as u64
}
fn storage_size(&self) -> usize {
self.data.lock().unwrap().len()
}
fn storage_type(&self) -> StorageType {
StorageType::System
}
}
// Create a 2x3 tensor with f32 elements
let real_tensor = RealDataMock::new(24); // 6 elements * 4 bytes
let shape = [2, 3];
let mut view = TensorView::<_, 2>::new(&real_tensor, shape, 4).unwrap();
// Set some values using set_element
view.set_element::<f32>(&[0, 0], 1.0).unwrap();
view.set_element::<f32>(&[0, 1], 2.0).unwrap();
view.set_element::<f32>(&[1, 2], 6.0).unwrap();
// Read them back with get_element
assert_eq!(view.get_element::<f32>(&[0, 0]).unwrap(), 1.0);
assert_eq!(view.get_element::<f32>(&[0, 1]).unwrap(), 2.0);
assert_eq!(view.get_element::<f32>(&[1, 2]).unwrap(), 6.0);
// Default values should be 0.0
assert_eq!(view.get_element::<f32>(&[0, 2]).unwrap(), 0.0);
assert_eq!(view.get_element::<f32>(&[1, 0]).unwrap(), 0.0);
assert_eq!(view.get_element::<f32>(&[1, 1]).unwrap(), 0.0);
}
#[test]
fn test_fill_method() {
use std::sync::{Arc, Mutex};
// Create a real memory tensor
#[derive(Debug)]
struct RealDataMock {
data: Arc<Mutex<Vec<u8>>>,
}
impl RealDataMock {
fn new(size_bytes: usize) -> Self {
Self {
data: Arc::new(Mutex::new(vec![0u8; size_bytes])),
}
}
}
impl Storage for RealDataMock {
fn get_pointer(&self) -> u64 {
self.data.lock().unwrap().as_ptr() as u64
}
fn storage_size(&self) -> usize {
self.data.lock().unwrap().len()
}
fn storage_type(&self) -> StorageType {
StorageType::System
}
}
// Create a 2x3 tensor with f32 elements
let real_tensor = RealDataMock::new(24); // 6 elements * 4 bytes
let shape = [2, 3];
let mut view = TensorView::<_, 2>::new(&real_tensor, shape, 4).unwrap();
// Fill with value 42.5
view.fill::<f32>(42.5).unwrap();
// Check all elements
for i in 0..2 {
for j in 0..3 {
assert_eq!(view.get_element::<f32>(&[i, j]).unwrap(), 42.5);
}
}
}
#[test]
fn test_map_elements() {
use std::sync::{Arc, Mutex};
// Create a real memory tensor
#[derive(Debug)]
struct RealDataMock {
data: Arc<Mutex<Vec<u8>>>,
}
impl RealDataMock {
fn new(size_bytes: usize) -> Self {
Self {
data: Arc::new(Mutex::new(vec![0u8; size_bytes])),
}
}
fn set_f32_values(&self, values: &[f32]) {
let mut data = self.data.lock().unwrap();
for (i, val) in values.iter().enumerate() {
let bytes = val.to_ne_bytes();
data[i * 4..(i + 1) * 4].copy_from_slice(&bytes);
}
}
}
impl Storage for RealDataMock {
fn get_pointer(&self) -> u64 {
self.data.lock().unwrap().as_ptr() as u64
}
fn storage_size(&self) -> usize {
self.data.lock().unwrap().len()
}
fn storage_type(&self) -> StorageType {
StorageType::System
}
}
// Create a 2x3 tensor with f32 elements
let real_tensor = RealDataMock::new(24); // 6 elements * 4 bytes
// Set up some initial values
let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
real_tensor.set_f32_values(&values);
let shape = [2, 3];
let view = TensorView::<_, 2>::new(&real_tensor, shape, 4).unwrap();
// Apply a function to map each element
let doubled: Vec<f32> = view.map_elements::<f32, f32, _>(|x| x * 2.0).unwrap();
// Check results
let expected = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0];
assert_eq!(doubled, expected);
// Map to a different type
let as_ints: Vec<i32> = view.map_elements::<f32, i32, _>(|x| x as i32).unwrap();
let expected_ints = [1, 2, 3, 4, 5, 6];
assert_eq!(as_ints, expected_ints);
}
#[test]
fn test_as_slice() {
use std::sync::{Arc, Mutex};
// Create a real memory tensor
#[derive(Debug)]
struct RealDataMock {
data: Arc<Mutex<Vec<u8>>>,
}
impl RealDataMock {
fn new(size_bytes: usize) -> Self {
Self {
data: Arc::new(Mutex::new(vec![0u8; size_bytes])),
}
}
fn set_f32_values(&self, values: &[f32]) {
let mut data = self.data.lock().unwrap();
for (i, val) in values.iter().enumerate() {
let bytes = val.to_ne_bytes();
data[i * 4..(i + 1) * 4].copy_from_slice(&bytes);
}
}
}
impl Storage for RealDataMock {
fn get_pointer(&self) -> u64 {
self.data.lock().unwrap().as_ptr() as u64
}
fn storage_size(&self) -> usize {
self.data.lock().unwrap().len()
}
fn storage_type(&self) -> StorageType {
StorageType::System
}
}
// Create a 2x3 tensor with f32 elements
let real_tensor = RealDataMock::new(24); // 6 elements * 4 bytes
// Set up some initial values
let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
real_tensor.set_f32_values(&values);
let shape = [2, 3];
let view = TensorView::<_, 2>::new(&real_tensor, shape, 4).unwrap();
// Get a slice and verify contents
let slice = view.as_slice::<f32>().unwrap();
assert_eq!(slice, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
// Get a mutable view
let mut mut_view = TensorView::<_, 2>::new(&real_tensor, shape, 4).unwrap();
// Get a mutable slice and modify contents
{
let mut_slice = mut_view.as_slice_mut::<f32>().unwrap();
mut_slice[0] = 10.0;
mut_slice[5] = 60.0;
}
// Verify changes through the original view
assert_eq!(mut_view.get_element::<f32>(&[0, 0]).unwrap(), 10.0);
assert_eq!(mut_view.get_element::<f32>(&[1, 2]).unwrap(), 60.0);
}
#[test]
fn test_ndarray_view_with_real_data() {
use std::sync::{Arc, Mutex};
// Create a mock tensor with real memory for testing
#[derive(Debug)]
struct RealDataMock {
// Using Vec<u8> to store the raw bytes
data: Arc<Mutex<Vec<u8>>>,
element_size_bytes: usize,
}
impl RealDataMock {
// Add the missing new method
fn new(num_elements: usize, element_size: usize) -> Self {
// Create a zeroed buffer of the right size
let buffer = vec![0u8; num_elements * element_size];
Self {
data: Arc::new(Mutex::new(buffer)),
element_size_bytes: element_size,
}
}
// Helper to set a specific element's value
fn set_element_value(&self, index: usize, value: u32) {
let mut data = self.data.lock().unwrap();
let bytes = value.to_ne_bytes();
let offset = index * self.element_size_bytes;
for i in 0..std::mem::size_of::<u32>() {
data[offset + i] = bytes[i];
}
}
}
impl Storage for RealDataMock {
fn get_pointer(&self) -> u64 {
// Get the raw pointer to the start of our vector's data
self.data.lock().unwrap().as_ptr() as u64
}
fn storage_size(&self) -> usize {
self.data.lock().unwrap().len()
}
fn storage_type(&self) -> StorageType {
StorageType::System
}
}
// Create a mock tensor with u32 elements (4 bytes each)
let shape = [2, 3, 4]; // 24 elements total
let num_elements = shape.iter().product();
let element_size = std::mem::size_of::<u32>();
let mock_tensor = RealDataMock::new(num_elements, element_size);
// Create a tensor view
let view = TensorView::<_, 3>::new(&mock_tensor, shape, 4).unwrap();
// Create an ndarray view
let ndarray_view = view.as_ndarray_view::<u32>().unwrap();
// Verify all elements are zero (initial state)
for &value in ndarray_view.iter() {
assert_eq!(value, 0, "Expected all initial values to be 0");
}
// Create a mutable clone of the data to work with
let data_arc = mock_tensor.data.clone();
// Set all elements to 42 by modifying the underlying storage
{
let mut data = data_arc.lock().unwrap();
for i in 0..num_elements {
let offset = i * element_size;
let bytes = 42u32.to_ne_bytes();
for j in 0..element_size {
data[offset + j] = bytes[j];
}
}
}
// Create another ndarray view to see the changes
let updated_view = view.as_ndarray_view::<u32>().unwrap();
// Verify all elements are now 42
for &value in updated_view.iter() {
assert_eq!(value, 42, "Expected all values to be 42 after update");
}
// Change just the first element back to 0
mock_tensor.set_element_value(0, 0);
// Create another ndarray view to see the effect of our change
let final_view = view.as_ndarray_view::<u32>().unwrap();
// The first element should be 0, others should remain 42
assert_eq!(final_view[[0, 0, 0]], 0, "First element should be 0");
assert_eq!(updated_view[[0, 0, 0]], 0, "First element should be 0");
// Check some of the other elements to ensure they're still 42
assert_eq!(
final_view[[0, 0, 1]],
42,
"Element [0,0,1] should still be 42"
);
assert_eq!(final_view[[1, 2, 3]], 42, "Last element should still be 42");
// Count the number of zeros (should be exactly 1)
let zero_count = final_view.iter().filter(|&&x| x == 0).count();
assert_eq!(zero_count, 1, "There should be exactly one zero element");
// Count the number of 42s (should be num_elements - 1)
let forty_two_count = final_view.iter().filter(|&&x| x == 42).count();
assert_eq!(
forty_two_count,
num_elements - 1,
"All other elements should be 42"
);
}
#[test]
fn test_host_device_transfers() {
use cudarc::driver::CudaContext;
// Initialize CUDA
let context = CudaContext::new(0).unwrap();
let stream = context.default_stream();
// Create a host tensor with f32 elements (6 elements)
let pinned_storage = OwnedStorage::create_pinned_array(6 * 4).unwrap();
// Create a host tensor view
let shape = [2, 3];
let mut host_view = TensorView::<_, 2>::new(&pinned_storage, shape, 4).unwrap();
// Set some values
let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
for i in 0..2 {
for j in 0..3 {
host_view
.set_element::<f32>(&[i, j], values[i * 3 + j])
.unwrap();
}
}
// Create a device tensor
let device_storage = OwnedStorage::create_device_array(6 * 4, context.clone()).unwrap();
let mut device_view = TensorView::<_, 2>::new(&device_storage, shape, 4).unwrap();
// Copy from host to device using h2d method
host_view.h2d(&mut device_view, &stream).unwrap();
// Create another host tensor for receiving data back
let pinned_storage2 = OwnedStorage::create_pinned_array(6 * 4).unwrap();
let mut host_view2 = TensorView::<_, 2>::new(&pinned_storage2, shape, 4).unwrap();
// Copy from device to host using d2h method
device_view.d2h(&mut host_view2, &stream).unwrap();
stream.synchronize().unwrap();
// Verify the data was correctly transferred
for i in 0..2 {
for j in 0..3 {
assert_eq!(
host_view2.get_element::<f32>(&[i, j]).unwrap(),
values[i * 3 + j]
);
}
}
// Test with new values
let new_values = [10.0f32, 20.0, 30.0, 40.0, 50.0, 60.0];
// Fill host view with new values
for i in 0..2 {
for j in 0..3 {
host_view
.set_element::<f32>(&[i, j], new_values[i * 3 + j])
.unwrap();
}
}
// Upload to device
host_view.h2d(&mut device_view, &stream).unwrap();
// Download to host view and check values
device_view.d2h(&mut host_view2, &stream).unwrap();
stream.synchronize().unwrap();
// Verify the data
for i in 0..2 {
for j in 0..3 {
assert_eq!(
host_view2.get_element::<f32>(&[i, j]).unwrap(),
new_values[i * 3 + j]
);
}
}
}
}
@@ -84,6 +84,10 @@ pub type WorkerId = i64;
/// A shared reference to a [`RadixBlock`].
type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
+pub fn compute_hash(data: &[u8]) -> u64 {
+    xxh3::xxh3_64_with_seed(data, XXH3_SEED)
+}
/// Compute the hash of a local block.
///
/// ### Arguments
@@ -94,7 +98,7 @@ type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
///
/// A `LocalBlockHash` representing the computed hash.
pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
-    LocalBlockHash(xxh3::xxh3_64_with_seed(data, XXH3_SEED))
+    LocalBlockHash(compute_hash(data))
}
// /// Updated version of the `compute_block_hash` function that included the lora_id
@@ -29,4 +29,8 @@ pub mod model_type;
pub mod preprocessor;
pub mod protocols;
pub mod tokenizers;
+pub mod tokens;
pub mod types;
+#[cfg(feature = "cuda_kv")]
+pub mod kv;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::kv_router::indexer::compute_hash;
use bytemuck::cast_slice;
use derive_getters::{Dissolve, Getters};
use rayon::prelude::*;
pub type Token = u32;
/// A hash of only the tokens within a block, computed via [compute_hash].
pub type BlockHash = u64;
/// A sequence-aware hash that combines the previous block's sequence hash with the current block's hash.
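///
/// Illustrative chaining scheme (a sketch of what `split_tokens` below computes):
///
/// ```text
/// seq_hash[0] = block_hash[0]
/// seq_hash[i] = compute_hash(bytes of [seq_hash[i-1], block_hash[i]])
/// ```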
pub type SequenceHash = u64;
#[derive(Debug, Clone, Dissolve, Default)]
pub struct Tokens(Vec<Token>);
impl AsRef<[Token]> for Tokens {
fn as_ref(&self) -> &[Token] {
&self.0
}
}
impl std::ops::Deref for Tokens {
type Target = [Token];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::borrow::Borrow<[Token]> for Tokens {
fn borrow(&self) -> &[Token] {
&self.0
}
}
impl From<Vec<Token>> for Tokens {
fn from(tokens: Vec<Token>) -> Self {
Tokens(tokens)
}
}
impl From<&[Token]> for Tokens {
fn from(tokens: &[Token]) -> Self {
Tokens(tokens.to_vec())
}
}
impl From<Vec<i32>> for Tokens {
fn from(tokens: Vec<i32>) -> Self {
Tokens(tokens.into_iter().map(|t| t as u32).collect())
}
}
impl From<&[i32]> for Tokens {
fn from(tokens: &[i32]) -> Self {
Tokens(tokens.iter().map(|&t| t as u32).collect())
}
}
impl From<Tokens> for Vec<Token> {
fn from(tokens: Tokens) -> Self {
tokens.0
}
}
impl Tokens {
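/// Split this token stream into hashed, fixed-size blocks plus a partial tail.
///
/// A minimal usage sketch (illustrative only):
/// ```ignore
/// let seq = Tokens::from(vec![1u32, 2, 3, 4, 5]).into_sequence(2);
/// assert_eq!(seq.blocks().len(), 2);        // [1, 2] and [3, 4]
/// assert_eq!(seq.current_block().len(), 1); // [5]
/// ```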
pub fn into_sequence(self, block_size: usize) -> TokenSequence {
TokenSequence::new(self, block_size)
}
}
pub struct PartialTokenBlock {
tokens: Tokens,
block_size: usize,
parent_sequence_hash: Option<SequenceHash>,
}
impl PartialTokenBlock {
/// Push a token onto the block, if the block is full, return a new [TokenBlock]
/// and reset the incomplete block
pub fn push_token(&mut self, token: Token) -> Option<TokenBlock> {
self.tokens.0.push(token);
if self.tokens.0.len() == self.block_size {
let block = std::mem::take(&mut self.tokens);
let block_hash = compute_hash(cast_slice(&block));
// Chain the parent's sequence hash with this block's hash; a block with no
// parent uses its own block hash, matching `split_tokens` below.
let sequence_hash = match self.parent_sequence_hash {
Some(parent) => compute_hash(bytemuck::cast_slice(&[parent, block_hash])),
None => block_hash,
};
let parent_sequence_hash = self.parent_sequence_hash;
// The next partial block must chain off the block just completed.
self.parent_sequence_hash = Some(sequence_hash);
Some(TokenBlock {
tokens: block,
sequence_hash,
block_hash,
parent_sequence_hash,
})
} else {
None
}
}
pub fn tokens(&self) -> &Tokens {
&self.tokens
}
}
impl std::ops::Deref for PartialTokenBlock {
type Target = Tokens;
fn deref(&self) -> &Self::Target {
&self.tokens
}
}
#[derive(Debug, Clone, Getters, Default)]
pub struct TokenBlock {
tokens: Tokens,
#[getter(copy)]
block_hash: BlockHash,
#[getter(copy)]
sequence_hash: SequenceHash,
#[getter(copy)]
parent_sequence_hash: Option<SequenceHash>,
}
pub struct TokenSequence {
blocks: Vec<TokenBlock>,
current_block: PartialTokenBlock,
}
impl TokenSequence {
pub fn new(tokens: Tokens, block_size: usize) -> Self {
let (blocks, current_block) = Self::split_tokens(tokens, block_size);
Self {
blocks,
current_block,
}
}
pub fn push_token(&mut self, token: Token) -> Option<&TokenBlock> {
if let Some(block) = self.current_block.push_token(token) {
self.blocks.push(block);
self.blocks.last()
} else {
None
}
}
pub fn blocks(&self) -> &[TokenBlock] {
&self.blocks
}
pub fn current_block(&self) -> &PartialTokenBlock {
&self.current_block
}
pub fn into_parts(self) -> (Vec<TokenBlock>, PartialTokenBlock) {
(self.blocks, self.current_block)
}
pub fn split_tokens(tokens: Tokens, block_size: usize) -> (Vec<TokenBlock>, PartialTokenBlock) {
// Use rayon's parallel iterator to process chunks in parallel
let mut blocks: Vec<TokenBlock> = tokens
.par_chunks_exact(block_size)
.map(|chunk| TokenBlock {
tokens: chunk.to_vec().into(),
sequence_hash: 0,
block_hash: compute_hash(cast_slice(chunk)),
parent_sequence_hash: None,
})
.collect();
// The first block (if any) has no parent, so its sequence hash is its block hash
if let Some(first) = blocks.first_mut() {
first.sequence_hash = first.block_hash;
}
// compute the sequence hash for each block
// this is the sequence hash of the previous block with the current block's hash
for i in 1..blocks.len() {
let previous_block = &blocks[i - 1];
let parent_sequence_hash = previous_block.sequence_hash;
let vals = &[parent_sequence_hash, blocks[i].block_hash];
blocks[i].sequence_hash = compute_hash(bytemuck::cast_slice(vals));
blocks[i].parent_sequence_hash = Some(parent_sequence_hash);
}
let remainder = tokens.chunks_exact(block_size).remainder();
let next_block = PartialTokenBlock {
tokens: remainder.into(),
block_size,
parent_sequence_hash: blocks.last().map(|b| b.sequence_hash),
};
(blocks, next_block)
}
}
impl PartialEq<Vec<Token>> for Tokens {
fn eq(&self, other: &Vec<Token>) -> bool {
self.0 == *other
}
}
impl PartialEq<Tokens> for Vec<Token> {
fn eq(&self, other: &Tokens) -> bool {
*self == other.0
}
}
impl PartialEq<[Token]> for Tokens {
fn eq(&self, other: &[Token]) -> bool {
self.0.as_slice() == other
}
}
impl PartialEq<Tokens> for &[Token] {
fn eq(&self, other: &Tokens) -> bool {
*self == other.0.as_slice()
}
}
impl PartialEq<Vec<Token>> for &Tokens {
fn eq(&self, other: &Vec<Token>) -> bool {
self.0 == *other
}
}
impl<'a> PartialEq<&'a Tokens> for Vec<Token> {
fn eq(&self, other: &&'a Tokens) -> bool {
*self == other.0
}
}
impl PartialEq<[Token]> for &Tokens {
fn eq(&self, other: &[Token]) -> bool {
self.0.as_slice() == other
}
}
impl<'a> PartialEq<&'a [Token]> for Tokens {
fn eq(&self, other: &&'a [Token]) -> bool {
self.0.as_slice() == *other
}
}
impl PartialEq for Tokens {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl Eq for Tokens {}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tokens_slice_operations() {
let tokens = Tokens(vec![1, 2, 3, 4, 5]);
// Test AsRef<[Token]>
let slice: &[Token] = tokens.as_ref();
assert_eq!(slice, &[1, 2, 3, 4, 5]);
// Test Deref
assert_eq!(tokens.len(), 5);
assert_eq!(tokens[0], 1);
assert_eq!(tokens[4], 5);
// Test iteration
let sum: u32 = tokens.iter().sum();
assert_eq!(sum, 15);
// Test slicing
let slice = &tokens[1..4];
assert_eq!(slice, &[2, 3, 4]);
// Test Borrow
let borrowed: &[Token] = std::borrow::Borrow::borrow(&tokens);
assert_eq!(borrowed, &[1, 2, 3, 4, 5]);
// Test with functions that accept &[Token]
fn takes_slice(slice: &[Token]) -> usize {
slice.len()
}
assert_eq!(takes_slice(&tokens), 5);
}
#[test]
fn test_tokens_conversions() {
// Test From<Vec<Token>> for Tokens
let vec = vec![1, 2, 3, 4, 5];
let tokens: Tokens = vec.clone().into();
assert_eq!(tokens.0, vec);
// Test Into<Vec<Token>> for Tokens
let tokens = Tokens(vec![6, 7, 8, 9, 10]);
let vec: Vec<Token> = tokens.into();
assert_eq!(vec, vec![6, 7, 8, 9, 10]);
// Test From<&[Token]> for Tokens
let slice: &[Token] = &[11, 12, 13];
let tokens: Tokens = slice.into();
assert_eq!(tokens.0, vec![11, 12, 13]);
// Test From<Vec<i32>> for Tokens
let i32_values = vec![100_i32, 200_i32, 300_i32];
let tokens: Tokens = i32_values.into();
assert_eq!(tokens.0, vec![100, 200, 300]);
// Test From<&[i32]> for Tokens
let i32_slice: &[i32] = &[400_i32, 500_i32, 600_i32];
let tokens: Tokens = i32_slice.into();
assert_eq!(tokens.0, vec![400, 500, 600]);
}
#[test]
fn test_tokens_blocks() {
let tokens = Tokens(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
let sequence = TokenSequence::new(tokens, 4);
assert_eq!(sequence.blocks().len(), 2);
assert_eq!(sequence.current_block().len(), 2);
assert_eq!(sequence.blocks()[0].tokens(), vec![1, 2, 3, 4]);
assert_eq!(sequence.blocks()[0].block_hash(), 14643705804678351452);
assert_eq!(sequence.blocks()[0].sequence_hash(), 14643705804678351452);
println!("blocks[0]: {:?}", sequence.blocks()[0]);
assert_eq!(sequence.blocks()[1].tokens(), vec![5, 6, 7, 8]);
assert_eq!(sequence.blocks()[1].block_hash(), 16777012769546811212);
assert_eq!(sequence.blocks()[1].sequence_hash(), 4945711292740353085);
println!("blocks[1]: {:?}", sequence.blocks()[1]);
assert_eq!(sequence.current_block().tokens(), vec![9, 10]);
let mut sequence = sequence;
let new_block = sequence.push_token(11);
assert!(new_block.is_none());
assert_eq!(sequence.blocks().len(), 2);
let new_block = sequence.push_token(12);
assert!(new_block.is_some());
assert_eq!(sequence.blocks().len(), 3);
assert_eq!(sequence.current_block().tokens().len(), 0);
println!("blocks[2]: {:?}", sequence.blocks()[2]);
let (blocks, mut current_block) = sequence.into_parts();
let new_block = current_block.push_token(13);
assert!(new_block.is_none());
assert_eq!(current_block.tokens().len(), 1);
let new_block = current_block.push_token(14);
assert!(new_block.is_none());
assert_eq!(current_block.tokens().len(), 2);
let new_block = current_block.push_token(15);
assert!(new_block.is_none());
assert_eq!(current_block.tokens().len(), 3);
let new_block = current_block.push_token(16);
assert!(new_block.is_some());
assert_eq!(blocks.len(), 3);
assert_eq!(current_block.tokens().len(), 0);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// //! Prototype KV Manager
// //!
// //! The KV Manager will be linked to three components:
// //! - ForwardPassTask / Scheduler
// //! - On each forward pass, any slot that has completed a block will:
// //! - Add the block to the Persistence Engine
// //! - Acquire a new block to continue generating
// //! - Persistence Engine
// //! - Will perform copies from GPU memory to CPU memory and possibly CPU memory
// //! to some global flash storage
// //! - Prefill Descriptor Manager
// //! - New requests that require prefill offload will acquire leases on any shared
// //! blocks and any "net new" blocks that need to be populated from the prefill
// //! instance.
// //!
// use async_trait::async_trait;
// use bytemuck::cast_slice;
// use rayon::prelude::*;
// use std::collections::{BTreeMap, BTreeSet, BinaryHeap, HashMap, VecDeque};
// use std::sync::Arc;
// use tokio::{
// sync::{Mutex, Notify},
// time::Instant,
// };
// use triton_distributed_llm::kv_router::indexer::compute_block_hash;
// use triton_distributed_llm::kv_router::protocols::LocalBlockHash;
// use dynamo_runtime::utils::pool::{
// Pool, PoolExt, PoolItem, PoolValue, Returnable, SharedPoolItem,
// };
// pub trait Storage {}
// pub type BlockHash = u64;
// pub type SequenceHash = u64;
// pub type Token = u32;
// pub struct Tokens(Vec<Token>);
// pub struct TokenBlock {
// tokens: Tokens,
// sequence_hash: SequenceHash,
// block_hash: LocalBlockHash,
// sequence_position: u32,
// priority: Option<u8>,
// reserved_deadline: Option<Instant>,
// }
// impl Tokens {
// pub fn blocks(&self, block_size: usize) -> Vec<TokenBlock> {
// // split the tokens into blocks of the given size
// // todo: determine how and when to parallelize the block creation
// // we can hash the local chunks in parallel
// // Use rayon's parallel iterator to process chunks in parallel
// self.0
// .par_chunks_exact(block_size)
// .map(|chunk| TokenBlock {
// tokens: Tokens(chunk.to_vec()),
// sequence_hash: 0,
// block_hash: compute_block_hash(cast_slice(chunk)),
// sequence_position: 0,
// priority: None,
// reserved_deadline: None,
// })
// .collect()
// }
// }
// pub struct KvBlock<T: Storage> {
// sequence_hash: SequenceHash,
// block_hash: BlockHash,
// sequence_position: u32,
// reserved_deadline: Option<Instant>,
// storage: Arc<T>,
// }
// pub struct SampleKvStorage {}
// impl Storage for SampleKvStorage {}
// pub type Block = KvBlock<SampleKvStorage>;
// impl Returnable for Block {}
// pub type UniqueBlock = PoolItem<Block, Pool<Block>>;
// pub type SharedBlock = SharedPoolItem<Block, Pool<Block>>;
// /// A wrapper around a time-critical item that will determine the amount of elapsed/walltime
// /// since the item was created. The `deadline` is optional and if not set, the item will be
// /// considered to have no time constraints. If the `deadline` is set, the item will
// /// increment a [prometheus::Counter] when the deadline is exceeded.
// ///
// /// In this manner, we can monitor the time-criticality of the item and take action if it is
// /// taking too long to process.
// // pub struct TimeCritical<T> {
// // // pub timestamp: Instant,
// // // pub item: T,
// // // pub deadline: Option<Instant>,
// // }
// pub struct Sequence {
// tokens: Vec<u32>,
// shared_blocks: Vec<SharedBlock>,
// current_block: UniqueBlock,
// }
// /// Adapt the KvIndexer to hold Block information
// pub struct DeviceRadixTree {}
// /// Adapt the KvIndexer to hold Block information
// pub struct HostRadixTree {}
// /// Owner of the radix trees and the block pool
// pub struct KvBlockManager {}
// /// The [Scheduler] is responsible for determining which [Sequence] objects should be
// /// scheduled for the next forward pass.
// ///
// /// The [Scheduler] will prepare a [Sequence] object for each request and pass that [Sequence]
// /// to either the [ForwardPassEngine] or the [PrefillHandler] depending on the size of the
// /// ISL and "net-new" tokens that need to be prefilled to the [Sequence].
// ///
// /// The [Scheduler] may have multiple [Sequences][Sequence] offloaded to the [PrefillHandler];
// /// however, some care must be taken that this number is not "too large", as the blocks
// /// held by an offloaded [Sequence] cannot be reused or repurposed by eviction.
// pub struct Scheduler {
// // slots: BTreeMap<u64, Sequence>,
// // pending: VecDeque<Sequence>,
// }
// /// The [ForwardPassEngine] is responsible for scheduling the forward pass of the model.
// /// It will receive requests from the scheduler, each carrying the set of SharedBlocks
// /// associated with the current request, tied to a Sequence object.
// ///
// /// The [ForwardPassEngine] appends new tokens to the current block of the [Sequence]. When
// /// the current block is full, it is converted to an immutable [SharedBlock] and a copy/clone
// /// is passed to the [PersistenceEngine] via an mpsc::Sender<TimeCritical<SharedBlock>>.
// ///
// /// The [ForwardPassEngine] should spawn async tasks per forward pass to evaluate the potential
// /// of each [Sequence] and determine how many blocks it could return to the [FreePool] if it was
// /// evicted.
// ///
// /// We only want to evict a [Sequence] if it can free enough blocks to be worth the overhead of
// /// evicting it and most critically, that we have persisted all evicted blocks in host memory.
// /// This will avoid the need to re-prefill the blocks when the sequence is rescheduled.
// ///
// /// The [ForwardPassEngine] should also evaluate the potential of each [Sequence] to be
// /// prefilled and if so, it will return a [PrefillHandler] to the caller.
// pub struct ForwardPassEngine {
// // scheduler: Scheduler,
// // kv_manager: KvBlockManager,
// // persistence_engine: PersistenceEngine,
// }
// /// The [PersistenceEngine] is responsible for copying blocks from GPU memory to
// /// to either host memory or some form of persistent storage.
// ///
// /// The [PersistenceEngine] will have an mpsc receiver of SharedBlock. Each block can
// /// be handled independently and freed after the copy is complete.
// ///
// /// We must time each SharedBlock as it enters the channel, so perhaps we wrap the incoming
// /// SharedBlock in a timestamped context.
// ///
// /// Holding SharedBlocks forbids their reuse, so we need to carefully and accurately monitor
// /// the state of this engine so it is not starving the ForwardPass [Scheduler] of free blocks.
// pub struct PersistenceEngine {}
// /// The [PrefillHandler] is responsible for acquiring blocks from the [KvBlockManager] for a
// /// given request. The input sequence length will be evaluated and two sets of blocks will be
// /// returned to the caller:
// /// - Vec<SharedBlock>
// /// - Vec<UniqueBlock>
// ///
// /// The `Vec<SharedBlock>` are the blocks that matched the inflight radix tree. Acquiring a
// /// [SharedBlock] ensures that the block cannot be returned to the [FreePool].
// ///
// /// The `Vec<UniqueBlock>` are the new blocks that are not present in the inflight radix tree
// /// which need to be prefilled. The decision to prefill locally via chunking or to offload to
// /// dedicated prefill workers can be made once the target destinations for the KV are determined.
// pub struct PrefillHandler {}
// /// The [MigrationEngine] is responsible for migrating blocks from one physical machine to another.
// /// In an ideal world, this transfer is over NVLink or ConnectX InfiniBand; however, any reasonably
// /// fast transfer will suffice.
// ///
// /// The [MigrationEngine] spawns tasks that operate in two paradigms:
// /// - RDMA Passive Source: The task will acquire [SharedBlocks][SharedBlock] from the [KvBlockManager]
// /// and hold them until an RDMA GET COMPLETION notification is received. Essentially, the task which
// /// holds the [SharedBlocks][SharedBlock] is simply responsible for ensuring the memory is pinned
// /// and not returned to the [FreePool] over the duration of the RDMA GET.
// /// - RDMA Active Puller: The task will receive a set of [SharedBlocks][SharedBlock]. The block list
// /// is a set of block_ids and a remote target. The task will initiate the RDMA GETs via the NIXL
// /// library and then wait for completion. Upon completion, an event or active message will be
// /// triggered on each RDMA Passive Source to signal task completion and resource dereferencing.
// ///
// pub struct MigrationEngine {}
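// // Sketch of the RDMA Passive Source task described above. A tokio oneshot stands in
// // for the RDMA GET COMPLETION notification; the real signal would come from the
// // transport (e.g. NIXL), so treat this shape as an assumption:
// //
// // async fn passive_source(blocks: Vec<SharedBlock>, completion: tokio::sync::oneshot::Receiver<()>) {
// //     // Holding the SharedBlocks keeps the memory pinned and out of the FreePool
// //     // for the duration of the remote RDMA GET.
// //     let _ = completion.await;
// //     drop(blocks); // dereference: the blocks may now be reused or evicted
// // }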
// // when in a hashset, PriorityBlockReference must be unique by block_id and sorted by:
// // - priority (lowest to highest)
// // - sequence_position (highest to lowest)
// //
// // - all lower priority items must be evicted before higher priority items
// // - all items with the same priority must be evicted in sequence order, with the highest
// //   sequence position evicted first
// //
// // within a sequence, priorities must be ordered: you cannot have a block at a lower sequence
// // position with a lower priority than a later block. the same is true for deadlines.
// #[derive(Debug, Clone, Eq)]
// struct PriorityBlockReference {
// block_id: SequenceHash,
// sequence_position: u32,
// priority: u8,
// }
// struct TimeAwareBlockReference {
// block_id: SequenceHash,
// sequence_position: u32,
// evict_deadline: Instant,
// priority: u8,
// }
// impl PartialEq for PriorityBlockReference {
// fn eq(&self, other: &Self) -> bool {
// self.block_id == other.block_id
// }
// }
// impl std::hash::Hash for PriorityBlockReference {
// fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
// self.block_id.hash(state);
// }
// }
// // Example usage:
// // let priority_set: HashSet<PriorityBlockReference> = HashSet::new();
// //
// // // To get items in sequence_position order:
// // let mut sorted_refs: Vec<&PriorityBlockReference> = priority_set.iter().collect();
// // sorted_refs.sort_by(|a, b| a.sequence_position.cmp(&b.sequence_position));
// // A key that defines the ordering
// #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
// struct PriorityKey {
// // For PriorityReference
// priority: u8,
// sequence_position: u32,
// // Unique identifier to break ties and ensure uniqueness
// block_hash: BlockHash,
// }
// impl PriorityKey {
// fn new_priority(block: &Block, priority: u8) -> Self {
// Self {
// priority,
// sequence_position: block.sequence_position,
// block_hash: block.block_hash,
// }
// }
// }
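// // Note: with the derived lexicographic Ord above, a BTreeMap's first entry is the
// // lowest priority *and lowest* sequence_position. To realize "highest sequence
// // position evicted first" within a priority class, the position can be wrapped in
// // std::cmp::Reverse, e.g. (illustrative sketch, not part of this commit):
// //
// // use std::cmp::Reverse;
// //
// // #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
// // struct EvictionKey {
// //     priority: u8,                    // lowest priority evicted first
// //     sequence_position: Reverse<u32>, // highest position evicted first
// //     block_hash: BlockHash,           // uniqueness / tie-breaker
// // }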
// // A key that defines deadline-based ordering
// //
// // Sort by deadline, then priority, then sequence_position, then sequence_hash
// #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
// struct DeadlineKey {
// deadline: Instant,
// priority: u8,
// sequence_position: u32,
// sequence_hash: SequenceHash,
// }
// impl DeadlineKey {
// fn new_deadline(block: &Block, priority: u8) -> Self {
// Self {
// deadline: block
// .reserved_deadline
// // far-future sentinel; u64::MAX seconds would overflow the Instant addition
// .unwrap_or_else(|| Instant::now() + std::time::Duration::from_secs(60 * 60 * 24 * 365)),
// priority,
// sequence_position: block.sequence_position,
// sequence_hash: block.sequence_hash,
// }
// }
// }
// // Define a struct that combines ordered access with direct lookup.
// // (In this sketch, `OrderKey` stands in for either PriorityKey or DeadlineKey.)
// #[derive(Default)]
// pub struct OrderedLookupSet<T> {
// // Items ordered by key
// ordered_set: BTreeMap<OrderKey, T>,
// // Direct lookup: sequence_hash -> every key belonging to that sequence
// lookup_map: HashMap<SequenceHash, Vec<OrderKey>>,
// }
// impl<T> OrderedLookupSet<T> {
// // Insert an item with a given key and sequence_hash
// pub fn insert(&mut self, key: OrderKey, sequence_hash: SequenceHash, item: T) {
// // Add to the ordered set
// self.ordered_set.insert(key.clone(), item);
// // Add to the lookup map
// self.lookup_map.entry(sequence_hash).or_default().push(key);
// }
// // Remove an item by its key
// pub fn remove_by_key(&mut self, key: &OrderKey) -> Option<T> {
// self.ordered_set.remove(key)
// }
// // Remove an item by sequence_hash and block_hash
// pub fn remove_by_hash(
// &mut self,
// sequence_hash: SequenceHash,
// block_hash: BlockHash,
// ) -> Option<T> {
// // Find the key in the lookup map
// if let Some(keys) = self.lookup_map.get_mut(&sequence_hash) {
// // Find the key with the matching block_hash
// if let Some(pos) = keys.iter().position(|k| k.block_hash == block_hash) {
// // Remove the key from the lookup map
// let key = keys.remove(pos);
// // If this was the last key for this sequence_hash, remove the entry
// if keys.is_empty() {
// self.lookup_map.remove(&sequence_hash);
// }
// // Remove and return the item from the ordered set
// return self.ordered_set.remove(&key);
// }
// }
// None
// }
// // Pop the highest priority item (first in order)
// pub fn pop_first(&mut self) -> Option<(OrderKey, T)> {
// if let Some((key, item)) = self.ordered_set.first_key_value() {
// let key_clone = key.clone();
// let sequence_hash = self.get_sequence_hash(&key_clone)?;
// // Remove from the ordered set
// let item = self.ordered_set.remove(&key_clone)?;
// // Remove from the lookup map
// if let Some(keys) = self.lookup_map.get_mut(&sequence_hash) {
// if let Some(pos) = keys.iter().position(|k| k == &key_clone) {
// keys.remove(pos);
// // If this was the last key for this sequence_hash, remove the entry
// if keys.is_empty() {
// self.lookup_map.remove(&sequence_hash);
// }
// }
// }
// Some((key_clone, item))
// } else {
// None
// }
// }
// // Helper method to find the sequence_hash for a key
// fn get_sequence_hash(&self, key: &OrderKey) -> Option<SequenceHash> {
// for (hash, keys) in &self.lookup_map {
// if keys.iter().any(|k| k == key) {
// return Some(*hash);
// }
// }
// None
// }
// // Get all items for a sequence_hash
// pub fn get_by_sequence_hash(&self, sequence_hash: SequenceHash) -> Vec<&T> {
// if let Some(keys) = self.lookup_map.get(&sequence_hash) {
// keys.iter()
// .filter_map(|key| self.ordered_set.get(key))
// .collect()
// } else {
// Vec::new()
// }
// }
// }
// // Now update the AvailableBlocks implementation
// #[derive(Default)]
// pub struct AvailableBlocks {
// // Map from sequence_hash to blocks
// sequence_map: BTreeMap<SequenceHash, Vec<UniqueBlock>>,
// // Ordered by priority with lookup by sequence_hash
// priority_set: OrderedLookupSet<UniqueBlock>,
// // Ordered by deadline with lookup by sequence_hash
// deadline_set: OrderedLookupSet<UniqueBlock>,
// }
// impl AvailableBlocks {
// // Add a block to the available blocks
// pub fn add_block(&mut self, block: UniqueBlock) {
// let block_ref = &*block; // Deref to get the Block
// let sequence_hash = block_ref.sequence_hash;
// let priority = calculate_priority(block_ref);
// // Create keys for our sets
// let priority_key = OrderKey::new_priority(block_ref, priority);
// let deadline_key = DeadlineKey::new_deadline(block_ref, priority);
// // Add to the sequence map
// self.sequence_map
// .entry(sequence_hash)
// .or_default()
// .push(block.clone());
// // Add to our sets
// self.priority_set
// .insert(priority_key, sequence_hash, block.clone());
// // For deadline_set, we'd need a similar implementation with DeadlineKey
// // self.deadline_set.insert(deadline_key, sequence_hash, block);
// }
// // Get the highest priority block
// pub fn pop_highest_priority(&mut self) -> Option<UniqueBlock> {
// if let Some((key, block)) = self.priority_set.pop_first() {
// // Remove from sequence map
// if let Some(blocks) = self.sequence_map.get_mut(&block.sequence_hash) {
// if let Some(pos) = blocks.iter().position(|b| b.block_hash == key.block_hash) {
// blocks.remove(pos);
// }
// }
// // Remove from deadline set
// // self.deadline_set.remove_by_hash(block.sequence_hash, key.block_hash);
// Some(block)
// } else {
// None
// }
// }
// // Get all blocks for a sequence
// pub fn get_blocks_by_sequence(&self, sequence_hash: SequenceHash) -> Vec<&UniqueBlock> {
// self.priority_set.get_by_sequence_hash(sequence_hash)
// }
// }
// // Helper function to calculate priority based on block details
// fn calculate_priority(block: &Block) -> u8 {
// // Implement your priority calculation logic here
// 0 // Default priority
// }
// async fn available_block_progress_engine(
// request_rx: Receiver<BlockRequest>,
// return_rx: Receiver<Block>,
// ) {
// let mut available_blocks = AvailableBlocks::default();
// // TODO: drive the loop: serve requests from request_rx, reinsert blocks from return_rx
// }
@@ -239,9 +239,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"

 [[package]]
 name = "base64ct"
-version = "1.7.1"
+version = "1.7.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bb97d56060ee67d285efb8001fec9d2a4c710c32efd2e14b5cbb5ba71930fc2d"
+checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3"

 [[package]]
 name = "bitflags"
@@ -1010,9 +1010,9 @@ dependencies = [

 [[package]]
 name = "http"
-version = "1.2.0"
+version = "1.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea"
+checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
 dependencies = [
  "bytes",
  "fnv",
@@ -1031,12 +1031,12 @@ dependencies = [

 [[package]]
 name = "http-body-util"
-version = "0.1.2"
+version = "0.1.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f"
+checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
 dependencies = [
  "bytes",
- "futures-util",
+ "futures-core",
  "http",
  "http-body",
  "pin-project-lite",
@@ -1056,9 +1056,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"

 [[package]]
 name = "humantime"
-version = "2.1.0"
+version = "2.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
+checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f"

 [[package]]
 name = "hyper"
@@ -1634,9 +1634,9 @@ dependencies = [

 [[package]]
 name = "once_cell"
-version = "1.21.0"
+version = "1.21.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cde51589ab56b20a6f686b2c68f7a0bd6add753d697abf720d63f8db3ab7b1ad"
+checksum = "d75b0bedcc4fe52caa0e03d9f1151a323e4aa5e2d78ba3580400cd3c9e2bc4bc"

 [[package]]
 name = "openssl-probe"
@@ -1792,9 +1792,9 @@ dependencies = [

 [[package]]
 name = "prettyplease"
-version = "0.2.30"
+version = "0.2.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f1ccf34da56fc294e7d4ccf69a85992b7dfb826b7cf57bac6a70bba3494cc08a"
+checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb"
 dependencies = [
  "proc-macro2",
  "syn 2.0.100",
@@ -1919,9 +1919,9 @@ checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94"

 [[package]]
 name = "quote"
-version = "1.0.39"
+version = "1.0.40"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c1f1914ce909e1658d9907913b4b91947430c7d9be598b15a1912935b8c04801"
+checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
 dependencies = [
  "proc-macro2",
 ]
@@ -2031,9 +2031,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"

 [[package]]
 name = "ring"
-version = "0.17.13"
+version = "0.17.14"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "70ac5d832aa16abd7d1def883a8545280c20a60f523a370aa3a9617c2b8550ee"
+checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7"
 dependencies = [
  "cc",
  "cfg-if 1.0.0",
@@ -2563,9 +2563,9 @@ dependencies = [

 [[package]]
 name = "tokio"
-version = "1.44.0"
+version = "1.44.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9975ea0f48b5aa3972bf2d888c238182458437cc2a19374b81b25cdf1023fb3a"
+checksum = "f382da615b842244d4b8738c82ed1275e6c5dd90c459a30941cd07080b06c91a"
 dependencies = [
  "backtrace",
  "bytes",
@@ -2613,9 +2613,9 @@ dependencies = [

 [[package]]
 name = "tokio-util"
-version = "0.7.13"
+version = "0.7.14"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078"
+checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034"
 dependencies = [
  "bytes",
  "futures-core",
@@ -3203,9 +3203,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"

 [[package]]
 name = "winnow"
-version = "0.7.3"
+version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1"
+checksum = "0e97b544156e9bebe1a0ffbc03484fc1ffe3100cbce3ffb17eac35f7cdd7ab36"
 dependencies = [
  "memchr",
 ]
@@ -15,4 +15,5 @@
 pub use tokio::time::{Duration, Instant};

+pub mod pool;
 pub mod stream;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::VecDeque;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::sync::Mutex;
use tokio::sync::Notify;
/// Trait for items that can be returned to a pool
pub trait Returnable: Send + Sync + 'static {
/// Called when an item is returned to the pool
fn on_return(&mut self) {}
}
pub trait ReturnHandle<T: Returnable>: Send + Sync + 'static {
fn return_to_pool(&self, value: PoolValue<T>);
}
/// Enum to hold either a Box<T> or T directly
pub enum PoolValue<T: Returnable> {
Boxed(Box<T>),
Direct(T),
}
impl<T: Returnable> PoolValue<T> {
/// Create a new PoolValue from a boxed item
pub fn from_boxed(value: Box<T>) -> Self {
PoolValue::Boxed(value)
}
/// Create a new PoolValue from a direct item
pub fn from_direct(value: T) -> Self {
PoolValue::Direct(value)
}
/// Get a reference to the underlying item
pub fn get(&self) -> &T {
match self {
PoolValue::Boxed(boxed) => boxed.as_ref(),
PoolValue::Direct(direct) => direct,
}
}
/// Get a mutable reference to the underlying item
pub fn get_mut(&mut self) -> &mut T {
match self {
PoolValue::Boxed(boxed) => boxed.as_mut(),
PoolValue::Direct(direct) => direct,
}
}
/// Call on_return on the underlying item
pub fn on_return(&mut self) {
self.get_mut().on_return();
}
}
impl<T: Returnable> Deref for PoolValue<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.get()
}
}
impl<T: Returnable> DerefMut for PoolValue<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.get_mut()
}
}
// Private module to restrict access to PoolItem constructor
mod private {
// This type can only be constructed within this module
#[derive(Clone, Copy)]
pub struct PoolItemToken(());
impl PoolItemToken {
pub(super) fn new() -> Self {
PoolItemToken(())
}
}
}
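// Because `PoolItemToken` can only be constructed inside the module above, a
// `PoolItem` can only be created through this module (a sealed-constructor
// pattern): pools are the sole source of `PoolItem`s.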
/// Core trait defining pool operations
pub trait PoolExt<T: Returnable>: Send + Sync + 'static {
/// Create a new PoolItem (only available to implementors)
fn create_pool_item(
&self,
value: PoolValue<T>,
handle: Arc<dyn ReturnHandle<T>>,
) -> PoolItem<T> {
PoolItem::new(value, handle)
}
}
/// An item borrowed from a pool
pub struct PoolItem<T: Returnable> {
value: Option<PoolValue<T>>,
handle: Arc<dyn ReturnHandle<T>>,
_token: private::PoolItemToken,
}
impl<T: Returnable> PoolItem<T> {
/// Create a new PoolItem (only available within this module)
fn new(value: PoolValue<T>, handle: Arc<dyn ReturnHandle<T>>) -> Self {
Self {
value: Some(value),
handle,
_token: private::PoolItemToken::new(),
}
}
/// Convert this unique PoolItem into a shared reference
pub fn into_shared(self) -> SharedPoolItem<T> {
SharedPoolItem {
inner: Arc::new(self),
}
}
/// Check if this item still contains a value
pub fn has_value(&self) -> bool {
self.value.is_some()
}
}
impl<T: Returnable> Deref for PoolItem<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.value.as_ref().unwrap().get()
}
}
impl<T: Returnable> DerefMut for PoolItem<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.value.as_mut().unwrap().get_mut()
}
}
impl<T: Returnable> Drop for PoolItem<T> {
fn drop(&mut self) {
if let Some(mut value) = self.value.take() {
value.on_return();
// Use blocking version for drop
self.handle.return_to_pool(value);
}
}
}
/// A shared reference to a pooled item
pub struct SharedPoolItem<T: Returnable> {
inner: Arc<PoolItem<T>>,
}
impl<T: Returnable> Clone for SharedPoolItem<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<T: Returnable> SharedPoolItem<T> {
/// Get a reference to the underlying item
pub fn get(&self) -> &T {
self.inner.value.as_ref().unwrap().get()
}
pub fn strong_count(&self) -> usize {
Arc::strong_count(&self.inner)
}
}
impl<T: Returnable> Deref for SharedPoolItem<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.inner.value.as_ref().unwrap().get()
}
}
/// Standard pool implementation
pub struct Pool<T: Returnable> {
state: Arc<PoolState<T>>,
capacity: usize,
}
struct PoolState<T: Returnable> {
pool: Arc<Mutex<VecDeque<PoolValue<T>>>>,
available: Arc<Notify>,
}
impl<T: Returnable> ReturnHandle<T> for PoolState<T> {
fn return_to_pool(&self, value: PoolValue<T>) {
let mut pool = self.pool.lock().unwrap();
pool.push_back(value);
self.available.notify_one();
}
}
impl<T: Returnable> Pool<T> {
/// Create a new pool with the given initial elements
pub fn new(initial_elements: Vec<PoolValue<T>>) -> Self {
let capacity = initial_elements.len();
let pool = initial_elements
.into_iter()
.collect::<VecDeque<PoolValue<T>>>();
let state = Arc::new(PoolState {
pool: Arc::new(Mutex::new(pool)),
available: Arc::new(Notify::new()),
});
Self { state, capacity }
}
/// Create a new pool with initial boxed elements
pub fn new_boxed(initial_elements: Vec<Box<T>>) -> Self {
let initial_values = initial_elements
.into_iter()
.map(PoolValue::from_boxed)
.collect();
Self::new(initial_values)
}
/// Create a new pool with initial direct elements
pub fn new_direct(initial_elements: Vec<T>) -> Self {
let initial_values = initial_elements
.into_iter()
.map(PoolValue::from_direct)
.collect();
Self::new(initial_values)
}
async fn try_acquire(&self) -> Option<PoolItem<T>> {
let mut pool = self.state.pool.lock().unwrap();
pool.pop_front()
.map(|value| PoolItem::new(value, self.state.clone()))
}
async fn acquire(&self) -> PoolItem<T> {
loop {
if let Some(guard) = self.try_acquire().await {
return guard;
}
self.state.available.notified().await;
}
}
fn notify_return(&self) {
self.state.available.notify_one();
}
fn capacity(&self) -> usize {
self.capacity
}
}
impl<T: Returnable> PoolExt<T> for Pool<T> {}
impl<T: Returnable> Clone for Pool<T> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
capacity: self.capacity,
}
}
}
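// Intended RAII flow, as a comment-only sketch (`acquire`/`try_acquire` are
// currently module-private; the tests below exercise the runnable versions):
//
//     let pool = Pool::new_direct(vec![1u32, 2, 3]);
//     let item = pool.acquire().await;  // PoolItem<u32>: unique, mutable access
//     let shared = item.into_shared();  // SharedPoolItem<u32>: cheaply Clone-able
//     drop(shared);                     // last clone drops -> on_return() resets the
//                                       // value and it re-enters the pool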
#[cfg(test)]
mod tests {
use super::*;
use tokio::time::{timeout, Duration};
// Implement Returnable for u32 just for testing
impl Returnable for u32 {
fn on_return(&mut self) {
*self = 0;
tracing::debug!("Resetting u32 to 0");
}
}
#[tokio::test]
async fn test_acquire_release() {
let initial_elements = vec![
PoolValue::Direct(1),
PoolValue::Direct(2),
PoolValue::Direct(3),
PoolValue::Direct(4),
PoolValue::Direct(5),
];
let pool = Pool::new(initial_elements);
// Acquire an element from the pool
if let Some(mut item) = pool.try_acquire().await {
assert_eq!(*item, 1); // It should be the first element we put in
// Modify the value
*item += 10;
assert_eq!(*item, 11);
// The item will be dropped at the end of this scope,
// and the value will be returned to the pool
}
// Acquire all remaining elements and the one we returned
let mut values = Vec::new();
let mut items = Vec::new();
while let Some(item) = pool.try_acquire().await {
values.push(*item);
items.push(item);
}
// The last element in `values` should be the one we returned, reset to 0 by `on_return`
assert_eq!(values, vec![2, 3, 4, 5, 0]);
// Test the awaitable acquire
let pool_clone = pool.clone();
let task = tokio::spawn(async move {
let first_acquired = pool_clone.acquire().await;
assert_eq!(*first_acquired, 0);
});
timeout(Duration::from_secs(1), task)
.await
.expect_err("Expected timeout");
// Drop the guards to return the PoolItems to the pool.
items.clear();
let pool_clone = pool.clone();
let task = tokio::spawn(async move {
let first_acquired = pool_clone.acquire().await;
assert_eq!(*first_acquired, 0);
});
// Now the task should be able to finish.
timeout(Duration::from_secs(1), task)
.await
.expect("Task did not complete in time")
.unwrap();
}
#[tokio::test]
async fn test_shared_items() {
let initial_elements = vec![
PoolValue::Direct(1),
// PoolValue::Direct(2),
// PoolValue::Direct(3),
];
let pool = Pool::new(initial_elements);
// Acquire and convert to shared
let mut item = pool.acquire().await;
*item += 10; // Modify before sharing
let shared = item.into_shared();
assert_eq!(*shared, 11);
// Create a clone of the shared item
let shared_clone = shared.clone();
assert_eq!(*shared_clone, 11);
// Drop the original shared item
drop(shared);
// Clone should still be valid
assert_eq!(*shared_clone, 11);
// Drop the clone
drop(shared_clone);
// Now we should be able to acquire the item again
let item = pool.acquire().await;
assert_eq!(*item, 0); // Value should have been reset by on_return
}
#[tokio::test]
async fn test_boxed_values() {
let initial_elements = vec![
PoolValue::Boxed(Box::new(1)),
// PoolValue::Boxed(Box::new(2)),
// PoolValue::Boxed(Box::new(3)),
];
let pool = Pool::new(initial_elements);
// Acquire an element from the pool
let mut item = pool.acquire().await;
assert_eq!(*item, 1);
// Modify and return to pool
*item += 10;
drop(item);
// Should get the reset value when acquired again
let item = pool.acquire().await;
assert_eq!(*item, 0);
}
#[tokio::test]
async fn test_pool_item_creation() {
let pool = Pool::new(vec![PoolValue::Direct(1)]);
// This works - acquiring from the pool
let item = pool.acquire().await;
assert_eq!(*item, 1);
// This would not compile - can't create PoolItem directly
// let invalid_item = PoolItem {
// value: Some(PoolValue::Direct(2)),
// pool: pool.clone(),
// _token: /* can't create this */
// };
}
}