Commit f04359cf authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: global kv block manager (#45)

parent 530a6be0
...@@ -193,7 +193,7 @@ dependencies = [ ...@@ -193,7 +193,7 @@ dependencies = [
"once_cell", "once_cell",
"pin-project", "pin-project",
"portable-atomic", "portable-atomic",
"rand", "rand 0.8.5",
"regex", "regex",
"ring", "ring",
"rustls-native-certs 0.7.3", "rustls-native-certs 0.7.3",
...@@ -232,7 +232,7 @@ dependencies = [ ...@@ -232,7 +232,7 @@ dependencies = [
"derive_builder", "derive_builder",
"eventsource-stream", "eventsource-stream",
"futures", "futures",
"rand", "rand 0.8.5",
"reqwest 0.12.14", "reqwest 0.12.14",
"reqwest-eventsource", "reqwest-eventsource",
"secrecy", "secrecy",
...@@ -496,7 +496,7 @@ dependencies = [ ...@@ -496,7 +496,7 @@ dependencies = [
"getrandom 0.2.15", "getrandom 0.2.15",
"instant", "instant",
"pin-project-lite", "pin-project-lite",
"rand", "rand 0.8.5",
"tokio", "tokio",
] ]
...@@ -535,9 +535,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" ...@@ -535,9 +535,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]] [[package]]
name = "base64ct" name = "base64ct"
version = "1.7.1" version = "1.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb97d56060ee67d285efb8001fec9d2a4c710c32efd2e14b5cbb5ba71930fc2d" checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3"
[[package]] [[package]]
name = "bindgen" name = "bindgen"
...@@ -747,12 +747,12 @@ dependencies = [ ...@@ -747,12 +747,12 @@ dependencies = [
[[package]] [[package]]
name = "candle-core" name = "candle-core"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f" source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
dependencies = [ dependencies = [
"byteorder", "byteorder",
"candle-kernels", "candle-kernels",
"candle-metal-kernels", "candle-metal-kernels",
"cudarc", "cudarc 0.13.9 (registry+https://github.com/rust-lang/crates.io-index)",
"float8", "float8",
"gemm", "gemm",
"half", "half",
...@@ -760,7 +760,7 @@ dependencies = [ ...@@ -760,7 +760,7 @@ dependencies = [
"metal", "metal",
"num-traits", "num-traits",
"num_cpus", "num_cpus",
"rand", "rand 0.8.5",
"rand_distr", "rand_distr",
"rayon", "rayon",
"safetensors", "safetensors",
...@@ -781,7 +781,7 @@ dependencies = [ ...@@ -781,7 +781,7 @@ dependencies = [
"indicatif", "indicatif",
"log", "log",
"num_cpus", "num_cpus",
"rand", "rand 0.8.5",
"reqwest 0.12.14", "reqwest 0.12.14",
"rustls 0.23.23", "rustls 0.23.23",
"serde", "serde",
...@@ -794,7 +794,7 @@ dependencies = [ ...@@ -794,7 +794,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-kernels" name = "candle-kernels"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f" source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
dependencies = [ dependencies = [
"bindgen_cuda 0.1.5", "bindgen_cuda 0.1.5",
] ]
...@@ -802,7 +802,7 @@ dependencies = [ ...@@ -802,7 +802,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-metal-kernels" name = "candle-metal-kernels"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f" source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
dependencies = [ dependencies = [
"metal", "metal",
"once_cell", "once_cell",
...@@ -813,7 +813,7 @@ dependencies = [ ...@@ -813,7 +813,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-nn" name = "candle-nn"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f" source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
dependencies = [ dependencies = [
"candle-core", "candle-core",
"candle-metal-kernels", "candle-metal-kernels",
...@@ -1212,6 +1212,14 @@ dependencies = [ ...@@ -1212,6 +1212,14 @@ dependencies = [
"libloading", "libloading",
] ]
[[package]]
name = "cudarc"
version = "0.13.9"
source = "git+https://github.com/coreylowman/cudarc.git?rev=8c52e735b55bf8e979e1a16bd85e3dfe4f87c9fe#8c52e735b55bf8e979e1a16bd85e3dfe4f87c9fe"
dependencies = [
"libloading",
]
[[package]] [[package]]
name = "curve25519-dalek" name = "curve25519-dalek"
version = "4.1.3" version = "4.1.3"
...@@ -1494,12 +1502,6 @@ dependencies = [ ...@@ -1494,12 +1502,6 @@ dependencies = [
"syn 2.0.100", "syn 2.0.100",
] ]
[[package]]
name = "doctest-file"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aac81fa3e28d21450aa4d2ac065992ba96a1d7303efbce51a95f4fd175b67562"
[[package]] [[package]]
name = "dunce" name = "dunce"
version = "1.0.5" version = "1.0.5"
...@@ -1535,9 +1537,12 @@ dependencies = [ ...@@ -1535,9 +1537,12 @@ dependencies = [
"bindgen 0.70.1", "bindgen 0.70.1",
"blake3", "blake3",
"bs62", "bs62",
"bytemuck",
"bytes", "bytes",
"chrono", "chrono",
"cmake", "cmake",
"cudarc 0.13.9 (git+https://github.com/coreylowman/cudarc.git?rev=8c52e735b55bf8e979e1a16bd85e3dfe4f87c9fe)",
"derive-getters",
"derive_builder", "derive_builder",
"dynamo-runtime", "dynamo-runtime",
"either", "either",
...@@ -1553,11 +1558,14 @@ dependencies = [ ...@@ -1553,11 +1558,14 @@ dependencies = [
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"mistralrs", "mistralrs",
"ndarray",
"prometheus", "prometheus",
"proptest", "proptest",
"pyo3", "pyo3",
"pyo3-async-runtimes", "pyo3-async-runtimes",
"pythonize", "pythonize",
"rand 0.9.0",
"rayon",
"regex", "regex",
"reqwest 0.12.14", "reqwest 0.12.14",
"rstest 0.18.2", "rstest 0.18.2",
...@@ -1639,7 +1647,7 @@ dependencies = [ ...@@ -1639,7 +1647,7 @@ dependencies = [
"nuid", "nuid",
"once_cell", "once_cell",
"prometheus", "prometheus",
"rand", "rand 0.8.5",
"regex", "regex",
"rstest 0.23.0", "rstest 0.23.0",
"serde", "serde",
...@@ -1963,10 +1971,10 @@ version = "0.1.3" ...@@ -1963,10 +1971,10 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a14d1c3d88fdab81b5886c34b2a424b3fba5123564ea415ff98b160cf44c432f" checksum = "a14d1c3d88fdab81b5886c34b2a424b3fba5123564ea415ff98b160cf44c432f"
dependencies = [ dependencies = [
"cudarc", "cudarc 0.13.9 (registry+https://github.com/rust-lang/crates.io-index)",
"half", "half",
"num-traits", "num-traits",
"rand", "rand 0.8.5",
"rand_distr", "rand_distr",
] ]
...@@ -2401,7 +2409,7 @@ dependencies = [ ...@@ -2401,7 +2409,7 @@ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
"crunchy", "crunchy",
"num-traits", "num-traits",
"rand", "rand 0.8.5",
"rand_distr", "rand_distr",
] ]
...@@ -2439,13 +2447,6 @@ version = "0.5.0" ...@@ -2439,13 +2447,6 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hello_world"
version = "0.1.0"
dependencies = [
"dynamo-runtime",
]
[[package]] [[package]]
name = "hermit-abi" name = "hermit-abi"
version = "0.3.9" version = "0.3.9"
...@@ -2466,7 +2467,7 @@ dependencies = [ ...@@ -2466,7 +2467,7 @@ dependencies = [
"log", "log",
"native-tls", "native-tls",
"num_cpus", "num_cpus",
"rand", "rand 0.8.5",
"reqwest 0.12.14", "reqwest 0.12.14",
"serde", "serde",
"serde_json", "serde_json",
...@@ -2983,19 +2984,6 @@ dependencies = [ ...@@ -2983,19 +2984,6 @@ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
] ]
[[package]]
name = "interprocess"
version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d941b405bd2322993887859a8ee6ac9134945a24ec5ec763a8a962fc64dfec2d"
dependencies = [
"doctest-file",
"libc",
"recvmsg",
"widestring",
"windows-sys 0.52.0",
]
[[package]] [[package]]
name = "inventory" name = "inventory"
version = "0.3.20" version = "0.3.20"
...@@ -3394,6 +3382,16 @@ version = "0.8.4" ...@@ -3394,6 +3382,16 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]]
name = "matrixmultiply"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
dependencies = [
"autocfg",
"rawpointer",
]
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.7.4" version = "2.7.4"
...@@ -3453,7 +3451,7 @@ dependencies = [ ...@@ -3453,7 +3451,7 @@ dependencies = [
"opentelemetry", "opentelemetry",
"opentelemetry-prometheus", "opentelemetry-prometheus",
"prometheus", "prometheus",
"rand", "rand 0.8.5",
"reqwest 0.11.27", "reqwest 0.11.27",
"serde", "serde",
"serde_json", "serde_json",
...@@ -3573,7 +3571,7 @@ dependencies = [ ...@@ -3573,7 +3571,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs" name = "mistralrs"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"candle-core", "candle-core",
...@@ -3584,7 +3582,7 @@ dependencies = [ ...@@ -3584,7 +3582,7 @@ dependencies = [
"image", "image",
"indexmap 2.8.0", "indexmap 2.8.0",
"mistralrs-core", "mistralrs-core",
"rand", "rand 0.8.5",
"reqwest 0.12.14", "reqwest 0.12.14",
"serde", "serde",
"serde_json", "serde_json",
...@@ -3594,7 +3592,7 @@ dependencies = [ ...@@ -3594,7 +3592,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-core" name = "mistralrs-core"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
dependencies = [ dependencies = [
"akin", "akin",
"anyhow", "anyhow",
...@@ -3622,7 +3620,6 @@ dependencies = [ ...@@ -3622,7 +3620,6 @@ dependencies = [
"image", "image",
"indexmap 2.8.0", "indexmap 2.8.0",
"indicatif", "indicatif",
"interprocess",
"itertools 0.13.0", "itertools 0.13.0",
"llguidance", "llguidance",
"lrtable", "lrtable",
...@@ -3635,7 +3632,7 @@ dependencies = [ ...@@ -3635,7 +3632,7 @@ dependencies = [
"objc", "objc",
"once_cell", "once_cell",
"radix_trie", "radix_trie",
"rand", "rand 0.8.5",
"rand_isaac", "rand_isaac",
"rayon", "rayon",
"regex", "regex",
...@@ -3645,7 +3642,6 @@ dependencies = [ ...@@ -3645,7 +3642,6 @@ dependencies = [
"safetensors", "safetensors",
"schemars", "schemars",
"serde", "serde",
"serde-big-array",
"serde_json", "serde_json",
"serde_plain", "serde_plain",
"serde_yaml", "serde_yaml",
...@@ -3668,7 +3664,7 @@ dependencies = [ ...@@ -3668,7 +3664,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-paged-attn" name = "mistralrs-paged-attn"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bindgen_cuda 0.1.6", "bindgen_cuda 0.1.6",
...@@ -3683,7 +3679,7 @@ dependencies = [ ...@@ -3683,7 +3679,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-quant" name = "mistralrs-quant"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
dependencies = [ dependencies = [
"bindgen_cuda 0.1.5", "bindgen_cuda 0.1.5",
"byteorder", "byteorder",
...@@ -3709,7 +3705,7 @@ dependencies = [ ...@@ -3709,7 +3705,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-vision" name = "mistralrs-vision"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
dependencies = [ dependencies = [
"candle-core", "candle-core",
"image", "image",
...@@ -3759,6 +3755,21 @@ dependencies = [ ...@@ -3759,6 +3755,21 @@ dependencies = [
"tempfile", "tempfile",
] ]
[[package]]
name = "ndarray"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer",
]
[[package]] [[package]]
name = "neli" name = "neli"
version = "0.6.5" version = "0.6.5"
...@@ -3874,7 +3885,7 @@ version = "3.0.0" ...@@ -3874,7 +3885,7 @@ version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4abdf1789932b85dc39446e27f45a1064a30f9e19a2b872b1d09bd59283f85f3" checksum = "4abdf1789932b85dc39446e27f45a1064a30f9e19a2b872b1d09bd59283f85f3"
dependencies = [ dependencies = [
"rand", "rand 0.8.5",
"serde", "serde",
"thiserror 1.0.69", "thiserror 1.0.69",
] ]
...@@ -3913,7 +3924,7 @@ dependencies = [ ...@@ -3913,7 +3924,7 @@ dependencies = [
"ed25519-dalek", "ed25519-dalek",
"getrandom 0.2.15", "getrandom 0.2.15",
"log", "log",
"rand", "rand 0.8.5",
"signatory", "signatory",
] ]
...@@ -3952,7 +3963,7 @@ version = "0.5.0" ...@@ -3952,7 +3963,7 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc895af95856f929163a0aa20c26a78d26bfdc839f51b9d5aa7a5b79e52b7e83" checksum = "fc895af95856f929163a0aa20c26a78d26bfdc839f51b9d5aa7a5b79e52b7e83"
dependencies = [ dependencies = [
"rand", "rand 0.8.5",
] ]
[[package]] [[package]]
...@@ -4087,9 +4098,9 @@ dependencies = [ ...@@ -4087,9 +4098,9 @@ dependencies = [
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.21.0" version = "1.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cde51589ab56b20a6f686b2c68f7a0bd6add753d697abf720d63f8db3ab7b1ad" checksum = "d75b0bedcc4fe52caa0e03d9f1151a323e4aa5e2d78ba3580400cd3c9e2bc4bc"
[[package]] [[package]]
name = "onig" name = "onig"
...@@ -4211,7 +4222,7 @@ dependencies = [ ...@@ -4211,7 +4222,7 @@ dependencies = [
"opentelemetry_api", "opentelemetry_api",
"ordered-float", "ordered-float",
"percent-encoding", "percent-encoding",
"rand", "rand 0.8.5",
"regex", "regex",
"thiserror 1.0.69", "thiserror 1.0.69",
] ]
...@@ -4479,9 +4490,9 @@ dependencies = [ ...@@ -4479,9 +4490,9 @@ dependencies = [
[[package]] [[package]]
name = "prettyplease" name = "prettyplease"
version = "0.2.30" version = "0.2.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1ccf34da56fc294e7d4ccf69a85992b7dfb826b7cf57bac6a70bba3494cc08a" checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"syn 2.0.100", "syn 2.0.100",
...@@ -4566,8 +4577,8 @@ dependencies = [ ...@@ -4566,8 +4577,8 @@ dependencies = [
"bitflags 2.9.0", "bitflags 2.9.0",
"lazy_static", "lazy_static",
"num-traits", "num-traits",
"rand", "rand 0.8.5",
"rand_chacha", "rand_chacha 0.3.1",
"rand_xorshift", "rand_xorshift",
"regex-syntax 0.8.5", "regex-syntax 0.8.5",
"rusty-fork", "rusty-fork",
...@@ -4816,7 +4827,7 @@ checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" ...@@ -4816,7 +4827,7 @@ checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [ dependencies = [
"bytes", "bytes",
"getrandom 0.2.15", "getrandom 0.2.15",
"rand", "rand 0.8.5",
"ring", "ring",
"rustc-hash 2.1.1", "rustc-hash 2.1.1",
"rustls 0.23.23", "rustls 0.23.23",
...@@ -4868,8 +4879,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -4868,8 +4879,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [ dependencies = [
"libc", "libc",
"rand_chacha", "rand_chacha 0.3.1",
"rand_core", "rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
"zerocopy 0.8.23",
] ]
[[package]] [[package]]
...@@ -4879,7 +4901,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -4879,7 +4901,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [ dependencies = [
"ppv-lite86", "ppv-lite86",
"rand_core", "rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core 0.9.3",
] ]
[[package]] [[package]]
...@@ -4891,6 +4923,15 @@ dependencies = [ ...@@ -4891,6 +4923,15 @@ dependencies = [
"getrandom 0.2.15", "getrandom 0.2.15",
] ]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
"getrandom 0.3.1",
]
[[package]] [[package]]
name = "rand_distr" name = "rand_distr"
version = "0.4.3" version = "0.4.3"
...@@ -4898,7 +4939,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -4898,7 +4939,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
dependencies = [ dependencies = [
"num-traits", "num-traits",
"rand", "rand 0.8.5",
] ]
[[package]] [[package]]
...@@ -4907,7 +4948,7 @@ version = "0.3.0" ...@@ -4907,7 +4948,7 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fac4373cd91b4f55722c553fb0f286edbb81ef3ff6eec7b99d1898a4110a0b28" checksum = "fac4373cd91b4f55722c553fb0f286edbb81ef3ff6eec7b99d1898a4110a0b28"
dependencies = [ dependencies = [
"rand_core", "rand_core 0.6.4",
] ]
[[package]] [[package]]
...@@ -4916,7 +4957,7 @@ version = "0.3.0" ...@@ -4916,7 +4957,7 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f"
dependencies = [ dependencies = [
"rand_core", "rand_core 0.6.4",
] ]
[[package]] [[package]]
...@@ -4928,6 +4969,12 @@ dependencies = [ ...@@ -4928,6 +4969,12 @@ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
] ]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]] [[package]]
name = "rayon" name = "rayon"
version = "1.10.0" version = "1.10.0"
...@@ -4965,12 +5012,6 @@ version = "0.5.5" ...@@ -4965,12 +5012,6 @@ version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]]
name = "recvmsg"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3edd4d5d42c92f0a659926464d4cce56b562761267ecf0f469d85b7de384175"
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.10" version = "0.5.10"
...@@ -5268,7 +5309,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -5268,7 +5309,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14" checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14"
dependencies = [ dependencies = [
"quote", "quote",
"rand", "rand 0.8.5",
"syn 2.0.100", "syn 2.0.100",
] ]
...@@ -5639,15 +5680,6 @@ dependencies = [ ...@@ -5639,15 +5680,6 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde-big-array"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "serde-pickle" name = "serde-pickle"
version = "1.2.0" version = "1.2.0"
...@@ -5769,17 +5801,6 @@ dependencies = [ ...@@ -5769,17 +5801,6 @@ dependencies = [
"unsafe-libyaml", "unsafe-libyaml",
] ]
[[package]]
name = "service_metrics"
version = "0.1.0"
dependencies = [
"dynamo-runtime",
"futures",
"serde",
"serde_json",
"tokio",
]
[[package]] [[package]]
name = "sha2" name = "sha2"
version = "0.10.8" version = "0.10.8"
...@@ -5849,7 +5870,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -5849,7 +5870,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1e303f8205714074f6068773f0e29527e0453937fe837c9717d066635b65f31" checksum = "c1e303f8205714074f6068773f0e29527e0453937fe837c9717d066635b65f31"
dependencies = [ dependencies = [
"pkcs8", "pkcs8",
"rand_core", "rand_core 0.6.4",
"signature", "signature",
"zeroize", "zeroize",
] ]
...@@ -5861,7 +5882,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -5861,7 +5882,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de"
dependencies = [ dependencies = [
"digest", "digest",
"rand_core", "rand_core 0.6.4",
] ]
[[package]] [[package]]
...@@ -6346,7 +6367,7 @@ dependencies = [ ...@@ -6346,7 +6367,7 @@ dependencies = [
"monostate", "monostate",
"onig", "onig",
"paste", "paste",
"rand", "rand 0.8.5",
"rayon", "rayon",
"rayon-cond", "rayon-cond",
"regex", "regex",
...@@ -6465,7 +6486,7 @@ dependencies = [ ...@@ -6465,7 +6486,7 @@ dependencies = [
"futures-sink", "futures-sink",
"http 1.3.1", "http 1.3.1",
"httparse", "httparse",
"rand", "rand 0.8.5",
"ring", "ring",
"rustls-native-certs 0.8.1", "rustls-native-certs 0.8.1",
"rustls-pki-types", "rustls-pki-types",
...@@ -6617,7 +6638,7 @@ dependencies = [ ...@@ -6617,7 +6638,7 @@ dependencies = [
"indexmap 1.9.3", "indexmap 1.9.3",
"pin-project", "pin-project",
"pin-project-lite", "pin-project-lite",
"rand", "rand 0.8.5",
"slab", "slab",
"tokio", "tokio",
"tokio-util", "tokio-util",
...@@ -7187,12 +7208,6 @@ dependencies = [ ...@@ -7187,12 +7208,6 @@ dependencies = [
"rustix 0.38.44", "rustix 0.38.44",
] ]
[[package]]
name = "widestring"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7219d36b6eac893fa81e84ebe06485e7dcbb616177469b142df14f1f4deb1311"
[[package]] [[package]]
name = "winapi" name = "winapi"
version = "0.2.8" version = "0.2.8"
......
...@@ -136,7 +136,7 @@ dependencies = [ ...@@ -136,7 +136,7 @@ dependencies = [
"once_cell", "once_cell",
"pin-project", "pin-project",
"portable-atomic", "portable-atomic",
"rand", "rand 0.8.5",
"regex", "regex",
"ring", "ring",
"rustls-native-certs 0.7.3", "rustls-native-certs 0.7.3",
...@@ -175,7 +175,7 @@ dependencies = [ ...@@ -175,7 +175,7 @@ dependencies = [
"derive_builder", "derive_builder",
"eventsource-stream", "eventsource-stream",
"futures", "futures",
"rand", "rand 0.8.5",
"reqwest", "reqwest",
"reqwest-eventsource", "reqwest-eventsource",
"secrecy", "secrecy",
...@@ -367,7 +367,7 @@ dependencies = [ ...@@ -367,7 +367,7 @@ dependencies = [
"getrandom 0.2.15", "getrandom 0.2.15",
"instant", "instant",
"pin-project-lite", "pin-project-lite",
"rand", "rand 0.8.5",
"tokio", "tokio",
] ]
...@@ -400,9 +400,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" ...@@ -400,9 +400,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]] [[package]]
name = "base64ct" name = "base64ct"
version = "1.7.2" version = "1.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8faa168b8c4ffca39c2699e772943af41ec2b75fb1683dda07b28a6d285c53dc" checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3"
[[package]] [[package]]
name = "bindgen" name = "bindgen"
...@@ -966,9 +966,11 @@ dependencies = [ ...@@ -966,9 +966,11 @@ dependencies = [
"bindgen", "bindgen",
"blake3", "blake3",
"bs62", "bs62",
"bytemuck",
"bytes", "bytes",
"chrono", "chrono",
"cmake", "cmake",
"derive-getters",
"derive_builder", "derive_builder",
"dynamo-runtime", "dynamo-runtime",
"either", "either",
...@@ -984,6 +986,8 @@ dependencies = [ ...@@ -984,6 +986,8 @@ dependencies = [
"pyo3", "pyo3",
"pyo3-async-runtimes", "pyo3-async-runtimes",
"pythonize", "pythonize",
"rand 0.9.0",
"rayon",
"regex", "regex",
"semver", "semver",
"serde", "serde",
...@@ -1053,7 +1057,7 @@ dependencies = [ ...@@ -1053,7 +1057,7 @@ dependencies = [
"nuid", "nuid",
"once_cell", "once_cell",
"prometheus", "prometheus",
"rand", "rand 0.8.5",
"regex", "regex",
"serde", "serde",
"serde_json", "serde_json",
...@@ -1497,7 +1501,7 @@ dependencies = [ ...@@ -1497,7 +1501,7 @@ dependencies = [
"indicatif", "indicatif",
"libc", "libc",
"log", "log",
"rand", "rand 0.8.5",
"serde", "serde",
"serde_json", "serde_json",
"thiserror 2.0.12", "thiserror 2.0.12",
...@@ -2251,7 +2255,7 @@ version = "3.0.0" ...@@ -2251,7 +2255,7 @@ version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4abdf1789932b85dc39446e27f45a1064a30f9e19a2b872b1d09bd59283f85f3" checksum = "4abdf1789932b85dc39446e27f45a1064a30f9e19a2b872b1d09bd59283f85f3"
dependencies = [ dependencies = [
"rand", "rand 0.8.5",
"serde", "serde",
"thiserror 1.0.69", "thiserror 1.0.69",
] ]
...@@ -2279,7 +2283,7 @@ dependencies = [ ...@@ -2279,7 +2283,7 @@ dependencies = [
"ed25519-dalek", "ed25519-dalek",
"getrandom 0.2.15", "getrandom 0.2.15",
"log", "log",
"rand", "rand 0.8.5",
"signatory", "signatory",
] ]
...@@ -2309,7 +2313,7 @@ version = "0.5.0" ...@@ -2309,7 +2313,7 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc895af95856f929163a0aa20c26a78d26bfdc839f51b9d5aa7a5b79e52b7e83" checksum = "fc895af95856f929163a0aa20c26a78d26bfdc839f51b9d5aa7a5b79e52b7e83"
dependencies = [ dependencies = [
"rand", "rand 0.8.5",
] ]
[[package]] [[package]]
...@@ -2372,9 +2376,9 @@ dependencies = [ ...@@ -2372,9 +2376,9 @@ dependencies = [
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.21.0" version = "1.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cde51589ab56b20a6f686b2c68f7a0bd6add753d697abf720d63f8db3ab7b1ad" checksum = "d75b0bedcc4fe52caa0e03d9f1151a323e4aa5e2d78ba3580400cd3c9e2bc4bc"
[[package]] [[package]]
name = "onig" name = "onig"
...@@ -2570,9 +2574,9 @@ dependencies = [ ...@@ -2570,9 +2574,9 @@ dependencies = [
[[package]] [[package]]
name = "prettyplease" name = "prettyplease"
version = "0.2.30" version = "0.2.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1ccf34da56fc294e7d4ccf69a85992b7dfb826b7cf57bac6a70bba3494cc08a" checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"syn 2.0.100", "syn 2.0.100",
...@@ -2822,7 +2826,7 @@ checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" ...@@ -2822,7 +2826,7 @@ checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [ dependencies = [
"bytes", "bytes",
"getrandom 0.2.15", "getrandom 0.2.15",
"rand", "rand 0.8.5",
"ring", "ring",
"rustc-hash 2.1.1", "rustc-hash 2.1.1",
"rustls", "rustls",
...@@ -2864,8 +2868,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -2864,8 +2868,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [ dependencies = [
"libc", "libc",
"rand_chacha", "rand_chacha 0.3.1",
"rand_core", "rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
"zerocopy",
] ]
[[package]] [[package]]
...@@ -2875,7 +2890,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -2875,7 +2890,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [ dependencies = [
"ppv-lite86", "ppv-lite86",
"rand_core", "rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core 0.9.3",
] ]
[[package]] [[package]]
...@@ -2887,6 +2912,15 @@ dependencies = [ ...@@ -2887,6 +2912,15 @@ dependencies = [
"getrandom 0.2.15", "getrandom 0.2.15",
] ]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
"getrandom 0.3.1",
]
[[package]] [[package]]
name = "rayon" name = "rayon"
version = "1.10.0" version = "1.10.0"
...@@ -3402,7 +3436,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -3402,7 +3436,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1e303f8205714074f6068773f0e29527e0453937fe837c9717d066635b65f31" checksum = "c1e303f8205714074f6068773f0e29527e0453937fe837c9717d066635b65f31"
dependencies = [ dependencies = [
"pkcs8", "pkcs8",
"rand_core", "rand_core 0.6.4",
"signature", "signature",
"zeroize", "zeroize",
] ]
...@@ -3414,7 +3448,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -3414,7 +3448,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de"
dependencies = [ dependencies = [
"digest", "digest",
"rand_core", "rand_core 0.6.4",
] ]
[[package]] [[package]]
...@@ -3717,7 +3751,7 @@ dependencies = [ ...@@ -3717,7 +3751,7 @@ dependencies = [
"monostate", "monostate",
"onig", "onig",
"paste", "paste",
"rand", "rand 0.8.5",
"rayon", "rayon",
"rayon-cond", "rayon-cond",
"regex", "regex",
...@@ -3806,7 +3840,7 @@ dependencies = [ ...@@ -3806,7 +3840,7 @@ dependencies = [
"futures-sink", "futures-sink",
"http", "http",
"httparse", "httparse",
"rand", "rand 0.8.5",
"ring", "ring",
"rustls-native-certs 0.8.1", "rustls-native-certs 0.8.1",
"rustls-pki-types", "rustls-pki-types",
...@@ -3931,7 +3965,7 @@ dependencies = [ ...@@ -3931,7 +3965,7 @@ dependencies = [
"indexmap 1.9.3", "indexmap 1.9.3",
"pin-project", "pin-project",
"pin-project-lite", "pin-project-lite",
"rand", "rand 0.8.5",
"slab", "slab",
"tokio", "tokio",
"tokio-util", "tokio-util",
......
...@@ -22,6 +22,7 @@ license.workspace = true ...@@ -22,6 +22,7 @@ license.workspace = true
homepage.workspace = true homepage.workspace = true
[features] [features]
default = []
mistralrs = ["dep:mistralrs"] mistralrs = ["dep:mistralrs"]
llamacpp = ["dep:llama-cpp-2"] llamacpp = ["dep:llama-cpp-2"]
sglang = ["dep:async_zmq"] sglang = ["dep:async_zmq"]
...@@ -29,6 +30,7 @@ sentencepiece = ["dep:sentencepiece"] ...@@ -29,6 +30,7 @@ sentencepiece = ["dep:sentencepiece"]
vllm = ["dep:async_zmq"] vllm = ["dep:async_zmq"]
python = ["dep:pyo3-async-runtimes", "dep:pythonize"] python = ["dep:pyo3-async-runtimes", "dep:pythonize"]
trtllm = [] trtllm = []
cuda_kv = ["dep:cudarc", "dep:ndarray"]
cuda = ["mistralrs/cuda", "llama-cpp-2/cuda"] cuda = ["mistralrs/cuda", "llama-cpp-2/cuda"]
metal = ["mistralrs/metal", "llama-cpp-2/metal"] metal = ["mistralrs/metal", "llama-cpp-2/metal"]
...@@ -45,6 +47,7 @@ async-trait = { workspace = true } ...@@ -45,6 +47,7 @@ async-trait = { workspace = true }
bytes = { workspace = true } bytes = { workspace = true }
derive_builder = {workspace = true } derive_builder = {workspace = true }
futures = { workspace = true } futures = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
tokio = { workspace = true } tokio = { workspace = true }
...@@ -58,7 +61,18 @@ strum = { workspace = true } ...@@ -58,7 +61,18 @@ strum = { workspace = true }
async-openai = "0.27.2" async-openai = "0.27.2"
blake3 = "1" blake3 = "1"
bytemuck = "1.22"
derive-getters = "0.5"
rand = "0.9"
regex = "1" regex = "1"
rayon = "1"
# kv_cuda
cudarc = { git = "https://github.com/coreylowman/cudarc.git", rev = "8c52e735b55bf8e979e1a16bd85e3dfe4f87c9fe", features = ["cuda-12040"], optional = true }
ndarray = { version = "0.16", optional = true }
# candle-core = { version = "0.8.3", features = ["cuda"], optional = true }
# half = "2.4.1"
pyo3 = { version = "0.23.3", default-features = false, features = [ pyo3 = { version = "0.23.3", default-features = false, features = [
"macros", "macros",
"experimental-async", "experimental-async",
...@@ -84,7 +98,7 @@ prometheus = { version = "0.13" } ...@@ -84,7 +98,7 @@ prometheus = { version = "0.13" }
# mistralrs # mistralrs
either = { version = "1.13" } either = { version = "1.13" }
indexmap = { version = "2.6" } indexmap = { version = "2.6" }
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "a691154bb", optional = true } mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "5e689c9", optional = true }
# sglang # sglang
async_zmq = { version = "0.4.0", optional = true } async_zmq = { version = "0.4.0", optional = true }
......
...@@ -13,9 +13,125 @@ ...@@ -13,9 +13,125 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#[cfg(not(feature = "trtllm"))] #[cfg(not(feature = "cuda_kv"))]
fn main() {} fn main() {}
#[cfg(feature = "cuda_kv")]
fn main() {
use std::{path::PathBuf, process::Command};
println!("cargo:rerun-if-changed=src/kernels/block_copy.cu");
// first do a which nvcc, if it is in the path
// if so, we don't need to set the cuda_lib
let nvcc = Command::new("which").arg("nvcc").output().unwrap();
let cuda_lib = if nvcc.status.success() {
println!("cargo:info=nvcc found in path");
// Extract the path from nvcc location by removing "bin/nvcc"
let nvcc_path = String::from_utf8_lossy(&nvcc.stdout).trim().to_string();
let path = PathBuf::from(nvcc_path);
if let Some(parent) = path.parent() {
// Remove "nvcc"
if let Some(cuda_root) = parent.parent() {
// Remove "bin"
cuda_root.to_string_lossy().to_string()
} else {
// Fallback to CUDA_ROOT or default if path extraction fails
get_cuda_root_or_default()
}
} else {
// Fallback to CUDA_ROOT or default if path extraction fails
get_cuda_root_or_default()
}
} else {
println!("cargo:warning=nvcc not found in path");
get_cuda_root_or_default()
};
println!("cargo:info=Using CUDA installation at: {}", cuda_lib);
let cuda_lib_path = PathBuf::from(&cuda_lib).join("lib64");
println!("cargo:info=Using CUDA libs: {}", cuda_lib_path.display());
println!("cargo:rustc-link-search=native={}", cuda_lib_path.display());
// Link against multiple CUDA libraries
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=cuda");
println!("cargo:rustc-link-lib=dylib=cudadevrt");
// Make sure the CUDA libraries are found before other system libraries
println!(
"cargo:rustc-link-arg=-Wl,-rpath,{}",
cuda_lib_path.display()
);
// Create kernels directory for output if it doesn't exist
std::fs::create_dir_all("src/kernels").unwrap_or_else(|_| {
println!("Kernels directory already exists");
});
// Compile CUDA code
let output = Command::new("nvcc")
.arg("src/kernels/block_copy.cu")
.arg("-O3")
.arg("--compiler-options")
.arg("-fPIC")
.arg("-o")
.arg("src/kernels/libblock_copy.o")
.arg("-c")
.output()
.expect("Failed to compile CUDA code");
if !output.status.success() {
panic!(
"Failed to compile CUDA kernel: {}",
String::from_utf8_lossy(&output.stderr)
);
}
// Create static library
#[cfg(target_os = "windows")]
{
Command::new("lib")
.arg("/OUT:src/kernels/block_copy.lib")
.arg("src/kernels/libblock_copy.o")
.output()
.expect("Failed to create static library");
println!("cargo:rustc-link-search=native=src/kernels");
println!("cargo:rustc-link-lib=static=block_copy");
}
#[cfg(not(target_os = "windows"))]
{
Command::new("ar")
.arg("rcs")
.arg("src/kernels/libblock_copy.a")
.arg("src/kernels/libblock_copy.o")
.output()
.expect("Failed to create static library");
println!("cargo:rustc-link-search=native=src/kernels");
println!("cargo:rustc-link-lib=static=block_copy");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=cuda");
println!("cargo:rustc-link-lib=dylib=cudadevrt");
}
}
#[cfg(feature = "cuda_kv")]
fn get_cuda_root_or_default() -> String {
match std::env::var("CUDA_ROOT") {
Ok(path) => path,
Err(_) => {
// Default locations based on OS
if cfg!(target_os = "windows") {
"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8".to_string()
} else {
"/usr/local/cuda".to_string()
}
}
}
}
#[cfg(feature = "trtllm")] #[cfg(feature = "trtllm")]
fn main() { fn main() {
extern crate bindgen; extern crate bindgen;
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod layer;
pub mod manager;
pub mod reserved;
pub mod reuse;
pub mod sequence;
pub mod storage;
// #[cfg(feature = "cuda_kv")]
// pub mod storage;
use reserved::*;
use std::{
collections::{BTreeMap, HashMap, VecDeque},
sync::{atomic::AtomicU64, Arc, RwLock},
};
use async_trait::async_trait;
use derive_getters::Dissolve;
use dynamo_runtime::{
raise,
utils::pool::{PoolExt, PoolItem, PoolValue, Returnable, SharedPoolItem},
Result,
};
use crate::tokens::{PartialTokenBlock, SequenceHash, TokenBlock, Tokens};
use tracing as log;
pub type UniqueBlock = PoolItem<KvBlock>;
pub type SharedBlock = SharedPoolItem<KvBlock>;
#[derive(Default)]
pub struct KvBlock {
token_block: TokenBlock,
priority: u32,
return_tick: u64,
}
// pub struct KvStorage {
// data: u64,
// size: usize,
// layer_idx: usize,
// block_idx: usize,
// /// The layout of the tensor
// layout: layer::KvLayer,
// }
impl KvBlock {
/// Creates a new KvBlock with the given token block
pub fn new(token_block: TokenBlock) -> Self {
Self {
token_block,
priority: 0,
return_tick: 0,
// storage: None,
}
}
/// Updates the token block
pub fn update_token_block(&mut self, token_block: TokenBlock) {
self.token_block = token_block;
}
/// Resets the block to its initial state
pub(crate) fn reset(&mut self) {
self.token_block = TokenBlock::default();
self.priority = 0;
self.return_tick = 0;
// self.storage = None;
// self.storage_state = StorageState::Absent;
}
}
impl Returnable for KvBlock {
fn on_return(&mut self) {}
}
pub struct KvBlockConfig {}
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use reuse::AvailableBlocks;
/// Manages the reservation and priority reuse of kv blocks for a single storage type,
/// e.g. a GPU, host memory.
pub struct KvStorageManager {
available_blocks: AvailableBlocks,
inflight_blocks: ReservedBlocks,
block_size: usize,
}
impl KvStorageManager {
pub async fn new(block_size: usize) -> Self {
Self {
available_blocks: AvailableBlocks::new().await,
inflight_blocks: ReservedBlocks::new(block_size),
block_size,
}
}
pub async fn prepare_prefill_sequence(&mut self, tokens: Tokens) -> Result<PrefillMatched> {
log::debug!("adding request with {} tokens", tokens.len());
let seq = tokens.into_sequence(self.block_size);
let (blocks, tail_block) = seq.into_parts();
log::debug!(
"request translates to {} blocks; remaining tokens: {}",
blocks.len(),
tail_block.tokens().len()
);
// first match blocks to inflight blocks
let mut inflight_blocks = self.inflight_blocks.match_token_blocks(&blocks)?;
log::debug!("matched {} inflight blocks", inflight_blocks.len());
// shift the blocks to the left by the number of inflight blocks
let unmatched_blocks = &blocks[inflight_blocks.len()..];
let unmatched_hashes = unmatched_blocks
.iter()
.map(|b| b.sequence_hash())
.collect::<Vec<_>>();
// match the remaining blocks to freed gpu blocks (available_blocks)
let unregistered_blocks = self.available_blocks.match_blocks(unmatched_hashes).await?;
log::debug!("matched {} freed blocks", unregistered_blocks.len());
// the blocks from the freed blocks pool must be registered as inflight blocks
// todo - we might have to register the list of unregistered blocks as a single transaction
for block in unregistered_blocks {
inflight_blocks.push(self.inflight_blocks.register(block)?);
}
// the remaining blocks are the unmatched blocks
let remaining_blocks = blocks.into_iter().skip(inflight_blocks.len()).collect();
Ok(PrefillMatched {
inflight_blocks,
remaining_blocks,
tail_block,
})
}
pub async fn prepare_prefill_offload(
&mut self,
matched: PrefillMatched,
) -> Result<PrefillOffload> {
let (inflight_blocks, remaining_blocks, tail_block) = matched.dissolve();
let mut blocks_to_reuse = self
.available_blocks
.take_blocks(remaining_blocks.len() as u32 + 1)
.await?;
if blocks_to_reuse.len() != remaining_blocks.len() + 1 {
raise!(
"expected {} blocks, got {}",
remaining_blocks.len() + 1,
blocks_to_reuse.len()
);
}
// update the blocks_to_reuse with the token block from remaining_blocks
let complete_prefill_blocks: Vec<UniqueBlock> = remaining_blocks
.into_iter()
.map(|b| {
let mut block = blocks_to_reuse.pop().unwrap();
block.update_token_block(b);
block
})
.collect();
assert_eq!(blocks_to_reuse.len(), 1);
let tail_kv_block = blocks_to_reuse.pop().unwrap();
let tail_prefill_block = PartialKvBlock {
token_block: tail_block,
kv_block: tail_kv_block,
};
Ok(PrefillOffload {
inflight_blocks,
complete_prefill_blocks,
tail_prefill_block,
})
}
}
#[derive(Dissolve)]
pub struct PartialKvBlock {
token_block: PartialTokenBlock,
kv_block: UniqueBlock,
}
#[derive(Dissolve)]
pub struct PrefillMatched {
inflight_blocks: Vec<ReservedBlock>,
remaining_blocks: Vec<TokenBlock>,
tail_block: PartialTokenBlock,
}
#[derive(Dissolve)]
pub struct PrefillOffload {
inflight_blocks: Vec<ReservedBlock>,
complete_prefill_blocks: Vec<UniqueBlock>,
tail_prefill_block: PartialKvBlock,
}
// #[cfg(test)]
// mod tests {
// use super::*;
// use dynamo_runtime::logging::init;
// #[tokio::test]
// async fn test() {
// init();
// let mut manager = KvStorageManager::new(2);
// for _ in 0..100 {
// manager.available_blocks.insert(KvBlock::default());
// }
// let tokens = Tokens::from([0_i32, 1, 2, 3, 4, 5, 6, 7, 8].as_ref());
// // this is good for the scheduler to make a local decision as it now knows how many
// // net-new blocks need to be prefilled
// let sequence = manager.prepare_prefill_sequence(tokens).unwrap();
// assert_eq!(sequence.inflight_blocks.len(), 0);
// assert_eq!(sequence.remaining_blocks.len(), 4);
// }
// }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Weak;
use super::*;
type ReservedBlockMap = Arc<RwLock<HashMap<SequenceHash, Weak<ReservedBlockInner>>>>;
#[derive(Clone)]
pub struct ReservedBlock {
inner: Arc<ReservedBlockInner>,
}
impl ReservedBlock {
fn new(inner: Arc<ReservedBlockInner>) -> Self {
Self { inner }
}
pub fn inflight_count(&self) -> usize {
Arc::strong_count(&self.inner)
}
}
impl std::ops::Deref for ReservedBlock {
type Target = SharedBlock;
fn deref(&self) -> &Self::Target {
&self.inner.block
}
}
struct ReservedBlockInner {
block: SharedBlock,
map: ReservedBlockMap,
}
impl Drop for ReservedBlockInner {
fn drop(&mut self) {
let sequence_hash = self.block.token_block.sequence_hash();
let mut map = self.map.write().unwrap();
let val = map.remove(&sequence_hash);
if let Some(inner) = val {
if inner.strong_count() > 0 {
// this was not the weak pointer we were looking for
map.insert(sequence_hash, inner);
}
}
}
}
/// [ReservedBlocks] is a collection of inflight blocks that are actively being used
pub struct ReservedBlocks {
block_size: usize,
blocks: ReservedBlockMap,
}
impl ReservedBlocks {
pub fn new(block_size: usize) -> Self {
Self {
block_size,
blocks: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn match_sequence_hashes(
&self,
sequence_hashes: &[SequenceHash],
) -> Result<Vec<ReservedBlock>> {
let mut inflight_blocks = Vec::new();
let map = self.blocks.read().unwrap();
for sequence_hash in sequence_hashes {
if let Some(inner) = map.get(sequence_hash) {
if let Some(inner) = inner.upgrade() {
inflight_blocks.push(ReservedBlock::new(inner.clone()));
} else {
break;
}
} else {
break;
}
}
Ok(inflight_blocks)
}
/// Match the list of blocks to inflight blocks
///
/// This will return a [Vec<ReservedBlock>] that match the sequence hashes
/// in the order of the token blocks.
///
/// The matching is done in order, with the first block in the list being the first
/// block in the token blocks list.
///
/// If a block is not found, the function will return the list of matched blocks
/// and the remaining blocks will not be included.
pub fn match_token_blocks(&self, token_blocks: &[TokenBlock]) -> Result<Vec<ReservedBlock>> {
let mut inflight_blocks = Vec::new();
let map = self.blocks.read().unwrap();
for token_block in token_blocks {
let sequence_hash = token_block.sequence_hash();
if let Some(inner) = map.get(&sequence_hash) {
if let Some(inner) = inner.upgrade() {
inflight_blocks.push(ReservedBlock::new(inner.clone()));
} else {
break;
}
} else {
break;
}
}
Ok(inflight_blocks)
}
pub fn register(&mut self, block: UniqueBlock) -> Result<ReservedBlock> {
let sequence_hash = block.token_block.sequence_hash();
let shared = block.into_shared();
if shared.token_block.tokens().len() != self.block_size {
raise!("Block size mismatch");
}
// if the block already exists, we drop the block the user passed in and return the existing block
// this should return the passed in block to the free pool
let mut map = self.blocks.write().unwrap();
if let Some(existing_block) = map.get(&sequence_hash) {
// return an ReservedBlock with the existing block
// the passed in block will be dropped and returned to the pool
// this could happen if two sequences are building the same block at the same time,
// the first sequence to finish and register the block will insert it into the map
if let Some(inner) = existing_block.upgrade() {
return Ok(ReservedBlock::new(inner.clone()));
}
}
// Insert the new block and create an ReservedBlock from it
let inner = Arc::new(ReservedBlockInner {
block: shared,
map: self.blocks.clone(),
});
map.insert(sequence_hash, Arc::downgrade(&inner));
Ok(ReservedBlock::new(inner))
}
}
#[cfg(test)]
mod tests {
use super::*;
use super::reuse::tests::{create_blocks, create_token_sequence};
use super::reuse::AvailableBlocks;
#[tokio::test]
async fn test_reserved_blocks() {
let available_blocks = AvailableBlocks::new().await;
let mut reserved_blocks = ReservedBlocks::new(2);
// Create two sequences with different priorities
let seq1 = create_token_sequence(&[1, 2, 3, 4]);
let seq2 = create_token_sequence(&[5, 6, 7, 8]);
// This is creating new KvBlock; this is will be done when the block manager is initialized
// but since we are not using the block manager in this test, we need to create them manually
let blocks1 = create_blocks(seq1, 2);
let blocks2 = create_blocks(seq2, 2);
// Insert Sequence 2
for block in blocks2.into_iter().rev() {
available_blocks.insert(block).await.unwrap();
}
// Insert Sequence 1
for block in blocks1.into_iter().rev() {
available_blocks.insert(block).await.unwrap();
}
available_blocks.fence().await.unwrap();
assert_eq!(available_blocks.total_blocks(), 4);
assert_eq!(available_blocks.available_blocks(), 4);
// Initialize of the KvBlocks is complete - there are 4 blocks with state in the available pool
// Mimic a request for 2 tokens and test the block matching sequence
// This pattern will be used in the KvBlockManager
let req1 = create_token_sequence(&[1, 2]);
let seq1 = req1.into_sequence(2);
let (blocks, tail_block) = seq1.into_parts();
assert_eq!(blocks.len(), 1);
assert_eq!(tail_block.tokens().len(), 0);
let matched = reserved_blocks.match_token_blocks(&blocks).unwrap();
assert_eq!(matched.len(), 0);
let matched = available_blocks.match_token_blocks(&blocks).await.unwrap();
assert_eq!(matched.len(), 1);
// possible update the api to take a vec of unique blocks and return a vec of reserved blocks
let reserved: Vec<ReservedBlock> = matched
.into_iter()
.map(|unique_block| reserved_blocks.register(unique_block).unwrap())
.collect();
assert_eq!(reserved.len(), 1);
assert_eq!(reserved[0].inflight_count(), 1);
assert_eq!(available_blocks.available_blocks(), 3);
// request 2
// reuse blocks
// match blocks to the reserved blocks get a new reserved block which should have a ref count of 2
let reserved2 = reserved_blocks.match_token_blocks(&blocks).unwrap();
assert_eq!(reserved2.len(), 1);
assert_eq!(reserved2[0].inflight_count(), 2);
assert_eq!(available_blocks.available_blocks(), 3);
drop(reserved2);
available_blocks.fence().await.unwrap();
assert_eq!(reserved[0].inflight_count(), 1);
assert_eq!(available_blocks.available_blocks(), 3);
drop(reserved);
available_blocks.fence().await.unwrap();
assert_eq!(available_blocks.available_blocks(), 4);
}
}
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
This diff is collapsed.
...@@ -84,6 +84,10 @@ pub type WorkerId = i64; ...@@ -84,6 +84,10 @@ pub type WorkerId = i64;
/// A shared reference to a [`RadixBlock`]. /// A shared reference to a [`RadixBlock`].
type SharedRadixBlock = Rc<RefCell<RadixBlock>>; type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
pub fn compute_hash(data: &[u8]) -> u64 {
xxh3::xxh3_64_with_seed(data, XXH3_SEED)
}
/// Compute the hash of a local block. /// Compute the hash of a local block.
/// ///
/// ### Arguments /// ### Arguments
...@@ -94,7 +98,7 @@ type SharedRadixBlock = Rc<RefCell<RadixBlock>>; ...@@ -94,7 +98,7 @@ type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
/// ///
/// A `LocalBlockHash` representing the computed hash. /// A `LocalBlockHash` representing the computed hash.
pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash { pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
LocalBlockHash(xxh3::xxh3_64_with_seed(data, XXH3_SEED)) LocalBlockHash(compute_hash(data))
} }
// /// Updated version of the `compute_block_hash` function that included the lora_id // /// Updated version of the `compute_block_hash` function that included the lora_id
......
...@@ -29,4 +29,8 @@ pub mod model_type; ...@@ -29,4 +29,8 @@ pub mod model_type;
pub mod preprocessor; pub mod preprocessor;
pub mod protocols; pub mod protocols;
pub mod tokenizers; pub mod tokenizers;
pub mod tokens;
pub mod types; pub mod types;
#[cfg(feature = "cuda_kv")]
pub mod kv;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::kv_router::indexer::compute_hash;
use bytemuck::cast_slice;
use derive_getters::{Dissolve, Getters};
use rayon::prelude::*;
pub type Token = u32;
/// A hash of the only the tokens within a block computed from [compute_hash].
pub type BlockHash = u64;
/// A sequence aware hash that combines the previous block's sequence hash with the current block's hash.
pub type SequenceHash = u64;
#[derive(Debug, Clone, Dissolve, Default)]
pub struct Tokens(Vec<Token>);
impl AsRef<[Token]> for Tokens {
fn as_ref(&self) -> &[Token] {
&self.0
}
}
impl std::ops::Deref for Tokens {
type Target = [Token];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::borrow::Borrow<[Token]> for Tokens {
fn borrow(&self) -> &[Token] {
&self.0
}
}
impl From<Vec<Token>> for Tokens {
fn from(tokens: Vec<Token>) -> Self {
Tokens(tokens)
}
}
impl From<&[Token]> for Tokens {
fn from(tokens: &[Token]) -> Self {
Tokens(tokens.to_vec())
}
}
impl From<Vec<i32>> for Tokens {
fn from(tokens: Vec<i32>) -> Self {
Tokens(tokens.into_iter().map(|t| t as u32).collect())
}
}
impl From<&[i32]> for Tokens {
fn from(tokens: &[i32]) -> Self {
Tokens(tokens.iter().map(|&t| t as u32).collect())
}
}
impl From<Tokens> for Vec<Token> {
fn from(tokens: Tokens) -> Self {
tokens.0
}
}
impl Tokens {
pub fn into_sequence(self, block_size: usize) -> TokenSequence {
TokenSequence::new(self, block_size)
}
}
pub struct PartialTokenBlock {
tokens: Tokens,
block_size: usize,
parent_sequence_hash: Option<SequenceHash>,
}
impl PartialTokenBlock {
/// Push a token onto the block, if the block is full, return a new [TokenBlock]
/// and reset the incomplete block
pub fn push_token(&mut self, token: Token) -> Option<TokenBlock> {
self.tokens.0.push(token);
if self.tokens.0.len() == self.block_size {
let block = std::mem::take(&mut self.tokens);
let block_hash = compute_hash(cast_slice(&block));
let sequence_hash = compute_hash(bytemuck::cast_slice(&[
self.parent_sequence_hash.unwrap_or_default(),
block_hash,
]));
Some(TokenBlock {
tokens: block,
sequence_hash,
block_hash,
parent_sequence_hash: self.parent_sequence_hash,
})
} else {
None
}
}
pub fn tokens(&self) -> &Tokens {
&self.tokens
}
}
impl std::ops::Deref for PartialTokenBlock {
type Target = Tokens;
fn deref(&self) -> &Self::Target {
&self.tokens
}
}
#[derive(Debug, Clone, Getters, Default)]
pub struct TokenBlock {
tokens: Tokens,
#[getter(copy)]
block_hash: BlockHash,
#[getter(copy)]
sequence_hash: SequenceHash,
#[getter(copy)]
parent_sequence_hash: Option<SequenceHash>,
}
pub struct TokenSequence {
blocks: Vec<TokenBlock>,
current_block: PartialTokenBlock,
}
impl TokenSequence {
pub fn new(tokens: Tokens, block_size: usize) -> Self {
let (blocks, current_block) = Self::split_tokens(tokens, block_size);
Self {
blocks,
current_block,
}
}
pub fn push_token(&mut self, token: Token) -> Option<&TokenBlock> {
if let Some(block) = self.current_block.push_token(token) {
self.blocks.push(block);
self.blocks.last()
} else {
None
}
}
pub fn blocks(&self) -> &[TokenBlock] {
&self.blocks
}
pub fn current_block(&self) -> &PartialTokenBlock {
&self.current_block
}
pub fn into_parts(self) -> (Vec<TokenBlock>, PartialTokenBlock) {
(self.blocks, self.current_block)
}
pub fn split_tokens(tokens: Tokens, block_size: usize) -> (Vec<TokenBlock>, PartialTokenBlock) {
// Use rayon's parallel iterator to process chunks in parallel
let mut blocks: Vec<TokenBlock> = tokens
.par_chunks_exact(block_size)
.map(|chunk| TokenBlock {
tokens: chunk.to_vec().into(),
sequence_hash: 0,
block_hash: compute_hash(cast_slice(chunk)),
parent_sequence_hash: None,
})
.collect();
blocks[0].sequence_hash = blocks[0].block_hash;
// compute the sequence hash for each block
// this is the sequence hash of the previous block with the current block's hash
for i in 1..blocks.len() {
let previous_block = &blocks[i - 1];
let parent_sequence_hash = previous_block.sequence_hash;
let vals = &[parent_sequence_hash, blocks[i].block_hash];
blocks[i].sequence_hash = compute_hash(bytemuck::cast_slice(vals));
blocks[i].parent_sequence_hash = Some(parent_sequence_hash);
}
let remainder = tokens.chunks_exact(block_size).remainder();
let next_block = PartialTokenBlock {
tokens: remainder.into(),
block_size,
parent_sequence_hash: blocks.last().map(|b| b.sequence_hash),
};
(blocks, next_block)
}
}
impl PartialEq<Vec<Token>> for Tokens {
fn eq(&self, other: &Vec<Token>) -> bool {
self.0 == *other
}
}
impl PartialEq<Tokens> for Vec<Token> {
fn eq(&self, other: &Tokens) -> bool {
*self == other.0
}
}
impl PartialEq<[Token]> for Tokens {
fn eq(&self, other: &[Token]) -> bool {
self.0.as_slice() == other
}
}
impl PartialEq<Tokens> for &[Token] {
fn eq(&self, other: &Tokens) -> bool {
*self == other.0.as_slice()
}
}
impl PartialEq<Vec<Token>> for &Tokens {
fn eq(&self, other: &Vec<Token>) -> bool {
self.0 == *other
}
}
impl<'a> PartialEq<&'a Tokens> for Vec<Token> {
fn eq(&self, other: &&'a Tokens) -> bool {
*self == other.0
}
}
impl PartialEq<[Token]> for &Tokens {
fn eq(&self, other: &[Token]) -> bool {
self.0.as_slice() == other
}
}
impl<'a> PartialEq<&'a [Token]> for Tokens {
fn eq(&self, other: &&'a [Token]) -> bool {
self.0.as_slice() == *other
}
}
impl PartialEq for Tokens {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl Eq for Tokens {}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tokens_slice_operations() {
let tokens = Tokens(vec![1, 2, 3, 4, 5]);
// Test AsRef<[Token]>
let slice: &[Token] = tokens.as_ref();
assert_eq!(slice, &[1, 2, 3, 4, 5]);
// Test Deref
assert_eq!(tokens.len(), 5);
assert_eq!(tokens[0], 1);
assert_eq!(tokens[4], 5);
// Test iteration
let sum: u32 = tokens.iter().sum();
assert_eq!(sum, 15);
// Test slicing
let slice = &tokens[1..4];
assert_eq!(slice, &[2, 3, 4]);
// Test Borrow
let borrowed: &[Token] = std::borrow::Borrow::borrow(&tokens);
assert_eq!(borrowed, &[1, 2, 3, 4, 5]);
// Test with functions that accept &[Token]
fn takes_slice(slice: &[Token]) -> usize {
slice.len()
}
assert_eq!(takes_slice(&tokens), 5);
}
#[test]
fn test_tokens_conversions() {
// Test From<Vec<Token>> for Tokens
let vec = vec![1, 2, 3, 4, 5];
let tokens: Tokens = vec.clone().into();
assert_eq!(tokens.0, vec);
// Test Into<Vec<Token>> for Tokens
let tokens = Tokens(vec![6, 7, 8, 9, 10]);
let vec: Vec<Token> = tokens.into();
assert_eq!(vec, vec![6, 7, 8, 9, 10]);
// Test From<&[Token]> for Tokens
let slice: &[Token] = &[11, 12, 13];
let tokens: Tokens = slice.into();
assert_eq!(tokens.0, vec![11, 12, 13]);
// Test From<Vec<i32>> for Tokens
let i32_values = vec![100_i32, 200_i32, 300_i32];
let tokens: Tokens = i32_values.into();
assert_eq!(tokens.0, vec![100, 200, 300]);
// Test From<&[i32]> for Tokens
let i32_slice: &[i32] = &[400_i32, 500_i32, 600_i32];
let tokens: Tokens = i32_slice.into();
assert_eq!(tokens.0, vec![400, 500, 600]);
}
#[test]
fn test_tokens_blocks() {
let tokens = Tokens(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
let sequence = TokenSequence::new(tokens, 4);
assert_eq!(sequence.blocks().len(), 2);
assert_eq!(sequence.current_block().len(), 2);
assert_eq!(sequence.blocks()[0].tokens(), vec![1, 2, 3, 4]);
assert_eq!(sequence.blocks()[0].block_hash(), 14643705804678351452);
assert_eq!(sequence.blocks()[0].sequence_hash(), 14643705804678351452);
println!("blocks[0]: {:?}", sequence.blocks()[0]);
assert_eq!(sequence.blocks()[1].tokens(), vec![5, 6, 7, 8]);
assert_eq!(sequence.blocks()[1].block_hash(), 16777012769546811212);
assert_eq!(sequence.blocks()[1].sequence_hash(), 4945711292740353085);
println!("blocks[1]: {:?}", sequence.blocks()[1]);
assert_eq!(sequence.current_block().tokens(), vec![9, 10]);
let mut sequence = sequence;
let new_block = sequence.push_token(11);
assert!(new_block.is_none());
assert_eq!(sequence.blocks().len(), 2);
let new_block = sequence.push_token(12);
assert!(new_block.is_some());
assert_eq!(sequence.blocks().len(), 3);
assert_eq!(sequence.current_block().tokens().len(), 0);
println!("blocks[2]: {:?}", sequence.blocks()[2]);
let (blocks, mut current_block) = sequence.into_parts();
let new_block = current_block.push_token(13);
assert!(new_block.is_none());
assert_eq!(current_block.tokens().len(), 1);
let new_block = current_block.push_token(14);
assert!(new_block.is_none());
assert_eq!(current_block.tokens().len(), 2);
let new_block = current_block.push_token(15);
assert!(new_block.is_none());
assert_eq!(current_block.tokens().len(), 3);
let new_block = current_block.push_token(16);
assert!(new_block.is_some());
assert_eq!(blocks.len(), 3);
assert_eq!(current_block.tokens().len(), 0);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
This diff is collapsed.
...@@ -239,9 +239,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" ...@@ -239,9 +239,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]] [[package]]
name = "base64ct" name = "base64ct"
version = "1.7.1" version = "1.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb97d56060ee67d285efb8001fec9d2a4c710c32efd2e14b5cbb5ba71930fc2d" checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
...@@ -1010,9 +1010,9 @@ dependencies = [ ...@@ -1010,9 +1010,9 @@ dependencies = [
[[package]] [[package]]
name = "http" name = "http"
version = "1.2.0" version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
dependencies = [ dependencies = [
"bytes", "bytes",
"fnv", "fnv",
...@@ -1031,12 +1031,12 @@ dependencies = [ ...@@ -1031,12 +1031,12 @@ dependencies = [
[[package]] [[package]]
name = "http-body-util" name = "http-body-util"
version = "0.1.2" version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-util", "futures-core",
"http", "http",
"http-body", "http-body",
"pin-project-lite", "pin-project-lite",
...@@ -1056,9 +1056,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" ...@@ -1056,9 +1056,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]] [[package]]
name = "humantime" name = "humantime"
version = "2.1.0" version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f"
[[package]] [[package]]
name = "hyper" name = "hyper"
...@@ -1634,9 +1634,9 @@ dependencies = [ ...@@ -1634,9 +1634,9 @@ dependencies = [
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.21.0" version = "1.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cde51589ab56b20a6f686b2c68f7a0bd6add753d697abf720d63f8db3ab7b1ad" checksum = "d75b0bedcc4fe52caa0e03d9f1151a323e4aa5e2d78ba3580400cd3c9e2bc4bc"
[[package]] [[package]]
name = "openssl-probe" name = "openssl-probe"
...@@ -1792,9 +1792,9 @@ dependencies = [ ...@@ -1792,9 +1792,9 @@ dependencies = [
[[package]] [[package]]
name = "prettyplease" name = "prettyplease"
version = "0.2.30" version = "0.2.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1ccf34da56fc294e7d4ccf69a85992b7dfb826b7cf57bac6a70bba3494cc08a" checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"syn 2.0.100", "syn 2.0.100",
...@@ -1919,9 +1919,9 @@ checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" ...@@ -1919,9 +1919,9 @@ checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94"
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.39" version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1f1914ce909e1658d9907913b4b91947430c7d9be598b15a1912935b8c04801" checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
] ]
...@@ -2031,9 +2031,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" ...@@ -2031,9 +2031,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]] [[package]]
name = "ring" name = "ring"
version = "0.17.13" version = "0.17.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70ac5d832aa16abd7d1def883a8545280c20a60f523a370aa3a9617c2b8550ee" checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7"
dependencies = [ dependencies = [
"cc", "cc",
"cfg-if 1.0.0", "cfg-if 1.0.0",
...@@ -2563,9 +2563,9 @@ dependencies = [ ...@@ -2563,9 +2563,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.44.0" version = "1.44.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9975ea0f48b5aa3972bf2d888c238182458437cc2a19374b81b25cdf1023fb3a" checksum = "f382da615b842244d4b8738c82ed1275e6c5dd90c459a30941cd07080b06c91a"
dependencies = [ dependencies = [
"backtrace", "backtrace",
"bytes", "bytes",
...@@ -2613,9 +2613,9 @@ dependencies = [ ...@@ -2613,9 +2613,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.7.13" version = "0.7.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-core", "futures-core",
...@@ -3203,9 +3203,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" ...@@ -3203,9 +3203,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]] [[package]]
name = "winnow" name = "winnow"
version = "0.7.3" version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1" checksum = "0e97b544156e9bebe1a0ffbc03484fc1ffe3100cbce3ffb17eac35f7cdd7ab36"
dependencies = [ dependencies = [
"memchr", "memchr",
] ]
......
...@@ -15,4 +15,5 @@ ...@@ -15,4 +15,5 @@
pub use tokio::time::{Duration, Instant}; pub use tokio::time::{Duration, Instant};
pub mod pool;
pub mod stream; pub mod stream;
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment