Unverified Commit 3f99cf21 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: move ActiveSequences to kv-router and add unit tests (#6600)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 4c648b11
...@@ -1880,6 +1880,7 @@ dependencies = [ ...@@ -1880,6 +1880,7 @@ dependencies = [
"async-trait", "async-trait",
"clap 4.5.60", "clap 4.5.60",
"dashmap 6.1.0", "dashmap 6.1.0",
"derive-getters",
"dynamo-bench", "dynamo-bench",
"dynamo-mocker", "dynamo-mocker",
"dynamo-runtime", "dynamo-runtime",
......
...@@ -788,9 +788,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" ...@@ -788,9 +788,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]] [[package]]
name = "chrono" name = "chrono"
version = "0.4.43" version = "0.4.44"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0"
dependencies = [ dependencies = [
"iana-time-zone", "iana-time-zone",
"js-sys", "js-sys",
...@@ -1106,9 +1106,9 @@ dependencies = [ ...@@ -1106,9 +1106,9 @@ dependencies = [
[[package]] [[package]]
name = "cudarc" name = "cudarc"
version = "0.19.2" version = "0.19.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aed81f178e780f3d5d354d12b4c5c5a484c4a9c329ecd037ac57f2a0e0648397" checksum = "6468cb7fa330840f3ebcd8df51edc0e7bf5c18df524792ce6004c6821851cdf3"
dependencies = [ dependencies = [
"libloading 0.9.0", "libloading 0.9.0",
] ]
...@@ -1270,9 +1270,9 @@ dependencies = [ ...@@ -1270,9 +1270,9 @@ dependencies = [
[[package]] [[package]]
name = "deranged" name = "deranged"
version = "0.5.6" version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc3dc5ad92c2e2d1c193bbbbdf2ea477cb81331de4f3103f267ca18368b988c4" checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c"
dependencies = [ dependencies = [
"powerfmt", "powerfmt",
"serde_core", "serde_core",
...@@ -1520,6 +1520,7 @@ dependencies = [ ...@@ -1520,6 +1520,7 @@ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
"dashmap 6.1.0", "dashmap 6.1.0",
"derive-getters",
"dynamo-runtime", "dynamo-runtime",
"dynamo-tokens", "dynamo-tokens",
"flume", "flume",
...@@ -1532,6 +1533,7 @@ dependencies = [ ...@@ -1532,6 +1533,7 @@ dependencies = [
"tokio", "tokio",
"tokio-util", "tokio-util",
"tracing", "tracing",
"uuid",
"xxhash-rust", "xxhash-rust",
] ]
...@@ -2960,9 +2962,9 @@ checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" ...@@ -2960,9 +2962,9 @@ checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
[[package]] [[package]]
name = "jiff" name = "jiff"
version = "0.2.20" version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c867c356cc096b33f4981825ab281ecba3db0acefe60329f044c1789d94c6543" checksum = "b3e3d65f018c6ae946ab16e80944b97096ed73c35b221d1c478a6c81d8f57940"
dependencies = [ dependencies = [
"jiff-static", "jiff-static",
"jiff-tzdb-platform", "jiff-tzdb-platform",
...@@ -2975,9 +2977,9 @@ dependencies = [ ...@@ -2975,9 +2977,9 @@ dependencies = [
[[package]] [[package]]
name = "jiff-static" name = "jiff-static"
version = "0.2.20" version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7946b4325269738f270bb55b3c19ab5c5040525f83fd625259422a9d25d9be5" checksum = "a17c2b211d863c7fde02cbea8a3c1a439b98e109286554f2860bdded7ff83818"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
...@@ -3011,9 +3013,9 @@ dependencies = [ ...@@ -3011,9 +3013,9 @@ dependencies = [
[[package]] [[package]]
name = "js-sys" name = "js-sys"
version = "0.3.87" version = "0.3.90"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93f0862381daaec758576dcc22eb7bbf4d7efd67328553f3b45a412a51a3fb21" checksum = "14dc6f6450b3f6d4ed5b16327f38fed626d375a886159ca555bd7822c0c3a5a6"
dependencies = [ dependencies = [
"once_cell", "once_cell",
"wasm-bindgen", "wasm-bindgen",
...@@ -3331,7 +3333,7 @@ checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" ...@@ -3331,7 +3333,7 @@ checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616"
dependencies = [ dependencies = [
"bitflags 2.11.0", "bitflags 2.11.0",
"libc", "libc",
"redox_syscall 0.7.1", "redox_syscall 0.7.2",
] ]
[[package]] [[package]]
...@@ -3342,9 +3344,9 @@ checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" ...@@ -3342,9 +3344,9 @@ checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.11.0" version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53"
[[package]] [[package]]
name = "litemap" name = "litemap"
...@@ -4964,9 +4966,9 @@ dependencies = [ ...@@ -4964,9 +4966,9 @@ dependencies = [
[[package]] [[package]]
name = "pulldown-cmark" name = "pulldown-cmark"
version = "0.13.0" version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e8bbe1a966bd2f362681a44f6edce3c2310ac21e4d5067a6e7ec396297a6ea0" checksum = "83c41efbf8f90ac44de7f3a868f0867851d261b56291732d0cbf7cceaaeb55a6"
dependencies = [ dependencies = [
"bitflags 2.11.0", "bitflags 2.11.0",
"memchr", "memchr",
...@@ -5347,9 +5349,9 @@ dependencies = [ ...@@ -5347,9 +5349,9 @@ dependencies = [
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.7.1" version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35985aa610addc02e24fc232012c86fd11f14111180f902b67e2d5331f8ebf2b" checksum = "6d94dd2f7cd932d4dc02cc8b2b50dfd38bd079a4e5d79198b99743d7fcf9a4b4"
dependencies = [ dependencies = [
"bitflags 2.11.0", "bitflags 2.11.0",
] ]
...@@ -5410,9 +5412,9 @@ dependencies = [ ...@@ -5410,9 +5412,9 @@ dependencies = [
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
version = "0.8.9" version = "0.8.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c" checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
[[package]] [[package]]
name = "reqwest" name = "reqwest"
...@@ -5479,9 +5481,9 @@ dependencies = [ ...@@ -5479,9 +5481,9 @@ dependencies = [
[[package]] [[package]]
name = "rgb" name = "rgb"
version = "0.8.52" version = "0.8.53"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c6a884d2998352bb4daf0183589aec883f16a6da1f4dde84d8e2e9a5409a1ce" checksum = "47b34b781b31e5d73e9fbc8689c70551fd1ade9a19e3e28cfec8580a79290cc4"
[[package]] [[package]]
name = "ring" name = "ring"
...@@ -5610,22 +5612,22 @@ dependencies = [ ...@@ -5610,22 +5612,22 @@ dependencies = [
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "1.1.3" version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190"
dependencies = [ dependencies = [
"bitflags 2.11.0", "bitflags 2.11.0",
"errno", "errno",
"libc", "libc",
"linux-raw-sys 0.11.0", "linux-raw-sys 0.12.1",
"windows-sys 0.61.2", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
name = "rustls" name = "rustls"
version = "0.23.36" version = "0.23.37"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4"
dependencies = [ dependencies = [
"aws-lc-rs", "aws-lc-rs",
"log", "log",
...@@ -6031,9 +6033,9 @@ dependencies = [ ...@@ -6031,9 +6033,9 @@ dependencies = [
[[package]] [[package]]
name = "serde_with" name = "serde_with"
version = "3.16.1" version = "3.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fa237f2807440d238e0364a218270b98f767a00d3dada77b1c53ae88940e2e7" checksum = "381b283ce7bc6b476d903296fb59d0d36633652b633b27f64db4fb46dcbfc3b9"
dependencies = [ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"chrono", "chrono",
...@@ -6050,9 +6052,9 @@ dependencies = [ ...@@ -6050,9 +6052,9 @@ dependencies = [
[[package]] [[package]]
name = "serde_with_macros" name = "serde_with_macros"
version = "3.16.1" version = "3.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52a8e3ca0ca629121f70ab50f95249e5a6f925cc0f6ffe8256c45b728875706c" checksum = "a6d4e30573c8cb306ed6ab1dca8423eec9a463ea0e155f45399455e0368b27e0"
dependencies = [ dependencies = [
"darling 0.21.3", "darling 0.21.3",
"proc-macro2", "proc-macro2",
...@@ -6361,14 +6363,14 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" ...@@ -6361,14 +6363,14 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]] [[package]]
name = "tempfile" name = "tempfile"
version = "3.25.0" version = "3.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0136791f7c95b1f6dd99f9cc786b91bb81c3800b639b3478e561ddb7be95e5f1" checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0"
dependencies = [ dependencies = [
"fastrand", "fastrand",
"getrandom 0.4.1", "getrandom 0.4.1",
"once_cell", "once_cell",
"rustix 1.1.3", "rustix 1.1.4",
"windows-sys 0.61.2", "windows-sys 0.61.2",
] ]
...@@ -7418,9 +7420,9 @@ dependencies = [ ...@@ -7418,9 +7420,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen" name = "wasm-bindgen"
version = "0.2.110" version = "0.2.113"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1de241cdc66a9d91bd84f097039eb140cdc6eec47e0cdbaf9d932a1dd6c35866" checksum = "60722a937f594b7fde9adb894d7c092fc1bb6612897c46368d18e7a20208eff2"
dependencies = [ dependencies = [
"cfg-if 1.0.4", "cfg-if 1.0.4",
"once_cell", "once_cell",
...@@ -7431,9 +7433,9 @@ dependencies = [ ...@@ -7431,9 +7433,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-futures" name = "wasm-bindgen-futures"
version = "0.4.60" version = "0.4.63"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a42e96ea38f49b191e08a1bab66c7ffdba24b06f9995b39a9dd60222e5b6f1da" checksum = "8a89f4650b770e4521aa6573724e2aed4704372151bd0de9d16a3bbabb87441a"
dependencies = [ dependencies = [
"cfg-if 1.0.4", "cfg-if 1.0.4",
"futures-util", "futures-util",
...@@ -7445,9 +7447,9 @@ dependencies = [ ...@@ -7445,9 +7447,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-macro" name = "wasm-bindgen-macro"
version = "0.2.110" version = "0.2.113"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e12fdf6649048f2e3de6d7d5ff3ced779cdedee0e0baffd7dff5cdfa3abc8a52" checksum = "0fac8c6395094b6b91c4af293f4c79371c163f9a6f56184d2c9a85f5a95f3950"
dependencies = [ dependencies = [
"quote", "quote",
"wasm-bindgen-macro-support", "wasm-bindgen-macro-support",
...@@ -7455,9 +7457,9 @@ dependencies = [ ...@@ -7455,9 +7457,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-macro-support" name = "wasm-bindgen-macro-support"
version = "0.2.110" version = "0.2.113"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e63d1795c565ac3462334c1e396fd46dbf481c40f51f5072c310717bc4fb309" checksum = "ab3fabce6159dc20728033842636887e4877688ae94382766e00b180abac9d60"
dependencies = [ dependencies = [
"bumpalo", "bumpalo",
"proc-macro2", "proc-macro2",
...@@ -7468,9 +7470,9 @@ dependencies = [ ...@@ -7468,9 +7470,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-shared" name = "wasm-bindgen-shared"
version = "0.2.110" version = "0.2.113"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f9cdac23a5ce71f6bf9f8824898a501e511892791ea2a0c6b8568c68b9cb53" checksum = "de0e091bdb824da87dc01d967388880d017a0a9bc4f3bdc0d86ee9f9336e3bb5"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]
...@@ -7524,9 +7526,9 @@ dependencies = [ ...@@ -7524,9 +7526,9 @@ dependencies = [
[[package]] [[package]]
name = "web-sys" name = "web-sys"
version = "0.3.87" version = "0.3.90"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2c7c5718134e770ee62af3b6b4a84518ec10101aad610c024b64d6ff29bb1ff" checksum = "705eceb4ce901230f8625bd1d665128056ccbe4b7408faa625eec1ba80f59a97"
dependencies = [ dependencies = [
"js-sys", "js-sys",
"wasm-bindgen", "wasm-bindgen",
......
...@@ -1528,6 +1528,7 @@ dependencies = [ ...@@ -1528,6 +1528,7 @@ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
"dashmap 6.1.0", "dashmap 6.1.0",
"derive-getters",
"dynamo-runtime", "dynamo-runtime",
"dynamo-tokens", "dynamo-tokens",
"flume", "flume",
...@@ -1540,6 +1541,7 @@ dependencies = [ ...@@ -1540,6 +1541,7 @@ dependencies = [
"tokio", "tokio",
"tokio-util", "tokio-util",
"tracing", "tracing",
"uuid",
"xxhash-rust", "xxhash-rust",
] ]
......
...@@ -13,7 +13,7 @@ repository.workspace = true ...@@ -13,7 +13,7 @@ repository.workspace = true
[features] [features]
default = [] default = []
metrics = ["dep:dynamo-runtime"] metrics = ["dep:dynamo-runtime"]
bench = ["dep:clap", "dep:indicatif", "dep:serde_json", "dynamo-runtime/integration", "dep:uuid", "dep:plotters"] bench = ["dep:clap", "dep:indicatif", "dep:serde_json", "dynamo-runtime/integration", "dep:plotters"]
[dependencies] [dependencies]
# repo # repo
...@@ -24,6 +24,7 @@ dynamo-tokens = { workspace = true } ...@@ -24,6 +24,7 @@ dynamo-tokens = { workspace = true }
anyhow = { workspace = true } anyhow = { workspace = true }
async-trait = { workspace = true } async-trait = { workspace = true }
dashmap = { workspace = true } dashmap = { workspace = true }
derive-getters = { workspace = true }
prometheus = { workspace = true } prometheus = { workspace = true }
rand = { workspace = true } rand = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
...@@ -32,6 +33,7 @@ thiserror = { workspace = true } ...@@ -32,6 +33,7 @@ thiserror = { workspace = true }
tokio = { workspace = true } tokio = { workspace = true }
tokio-util = { workspace = true } tokio-util = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
uuid = { workspace = true }
xxhash-rust = { workspace = true } xxhash-rust = { workspace = true }
# dependencies # dependencies
...@@ -41,7 +43,6 @@ parking_lot = { workspace = true } ...@@ -41,7 +43,6 @@ parking_lot = { workspace = true }
# bench (optional) # bench (optional)
clap = { version = "4.5", features = ["derive"], optional = true } clap = { version = "4.5", features = ["derive"], optional = true }
indicatif = { version = "0.18.0", optional = true } indicatif = { version = "0.18.0", optional = true }
uuid = { workspace = true, optional = true }
plotters = { version = "0.3", optional = true, default-features = false, features = ["svg_backend", "line_series", "point_series", "full_palette"] } plotters = { version = "0.3", optional = true, default-features = false, features = ["svg_backend", "line_series", "point_series", "full_palette"] }
rustc-hash = "2.1.1" rustc-hash = "2.1.1"
......
...@@ -36,12 +36,12 @@ use std::time::Instant; ...@@ -36,12 +36,12 @@ use std::time::Instant;
use async_trait::async_trait; use async_trait::async_trait;
use dashmap::DashMap; use dashmap::DashMap;
use dynamo_runtime::error::DynamoError;
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
pub use dynamo_runtime::protocols::maybe_error::MaybeError; pub use dynamo_runtime::protocols::maybe_error::MaybeError;
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
use dynamo_runtime::{ use dynamo_runtime::{
component::Component, component::Component,
error::DynamoError,
metrics::{MetricsHierarchy, prometheus_names::kvrouter}, metrics::{MetricsHierarchy, prometheus_names::kvrouter},
}; };
use prometheus::{IntCounterVec, Opts}; use prometheus::{IntCounterVec, Opts};
...@@ -54,7 +54,7 @@ pub trait MaybeError { ...@@ -54,7 +54,7 @@ pub trait MaybeError {
/// Construct an instance from an error. /// Construct an instance from an error.
fn from_err(err: impl std::error::Error + 'static) -> Self; fn from_err(err: impl std::error::Error + 'static) -> Self;
/// Convert to an error instance if this represents an error. /// Convert to an error instance if this represents an error.
fn err(&self) -> Option<DynamoError>; fn err(&self) -> Option<Box<dyn std::error::Error + Send + Sync>>;
} }
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
......
...@@ -11,11 +11,13 @@ pub mod approx; ...@@ -11,11 +11,13 @@ pub mod approx;
pub mod bench_utils; pub mod bench_utils;
pub mod concurrent_radix_tree; pub mod concurrent_radix_tree;
pub mod indexer; pub mod indexer;
pub mod multi_worker_sequence;
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
pub mod naive_indexers; pub mod naive_indexers;
pub mod nested_map; pub mod nested_map;
pub mod protocols; pub mod protocols;
pub mod radix_tree; pub mod radix_tree;
pub mod sequence;
#[cfg(test)] #[cfg(test)]
pub(crate) mod test_utils; pub(crate) mod test_utils;
...@@ -23,6 +25,10 @@ pub(crate) mod test_utils; ...@@ -23,6 +25,10 @@ pub(crate) mod test_utils;
// Re-export key types for convenience // Re-export key types for convenience
pub use concurrent_radix_tree::ConcurrentRadixTree; pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer}; pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
pub use multi_worker_sequence::{
ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest,
SequenceSubscriber,
};
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
pub use naive_indexers::{InvertedIndex, NaiveNestedMap}; pub use naive_indexers::{InvertedIndex, NaiveNestedMap};
pub use nested_map::PositionalIndexer; pub use nested_map::PositionalIndexer;
...@@ -31,3 +37,4 @@ pub use protocols::{ ...@@ -31,3 +37,4 @@ pub use protocols::{
compute_block_hash_for_seq, compute_block_hash_for_seq,
}; };
pub use radix_tree::RadixTree; pub use radix_tree::RadixTree;
pub use sequence::{ActiveSequences, RequestId};
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Multi-worker extension of [`ActiveSequences`] using shared DashMap for lock-free concurrent
//! access, with pluggable event publishing and metric observation via traits.
//!
//! The two traits [`SequencePublisher`] and [`SequenceSubscriber`] abstract the runtime-specific
//! transport (e.g., NATS EventPublisher, Prometheus gauges) so that all business logic lives in
//! this crate while the runtime glue stays in `lib/llm`.
use dashmap::DashMap;
use dynamo_tokens::SequenceHash;
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use crate::protocols::{
ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores, WorkerWithDpRank,
};
use crate::sequence::{ActiveSequences, RequestId};
// ---------------------------------------------------------------------------
// Traits
// ---------------------------------------------------------------------------
/// Abstraction over event publishing and metrics observation.
///
/// Implementations provide the runtime-specific transport (e.g., NATS EventPublisher,
/// Prometheus gauges) while the business logic in [`ActiveSequencesMultiWorker`] stays
/// runtime-agnostic.
pub trait SequencePublisher: Send + Sync {
/// Publish a replica-sync event to peer routers.
fn publish_event(
&self,
event: &ActiveSequenceEvent,
) -> impl Future<Output = anyhow::Result<()>> + Send;
/// Fire-and-forget publish of an [`ActiveLoad`] metric payload.
fn publish_load(&self, load: ActiveLoad);
/// Record per-worker load in Prometheus gauges.
fn observe_load(
&self,
worker: &WorkerWithDpRank,
worker_type: &str,
blocks: usize,
tokens: usize,
);
}
/// Abstraction over event subscription for replica sync.
pub trait SequenceSubscriber: Send {
/// Receive the next replica-sync event, or `None` if the stream is closed.
fn next_event(
&mut self,
) -> impl Future<Output = Option<anyhow::Result<ActiveSequenceEvent>>> + Send;
}
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
/// Errors that can occur during sequence management operations.
#[derive(Debug, thiserror::Error)]
pub enum SequenceError {
#[error("Worker {worker:?} not found")]
WorkerNotFound { worker: WorkerWithDpRank },
#[error("Request {request_id} already exists (assigned to worker {worker:?})")]
DuplicateRequest {
request_id: String,
worker: WorkerWithDpRank,
},
#[error("Request {request_id} not found")]
RequestNotFound { request_id: String },
#[error("Failed to publish event: {0}")]
PublishFailed(#[from] anyhow::Error),
}
/// Bundled parameters for adding a request to the sequence tracker.
pub struct SequenceRequest {
pub request_id: RequestId,
pub token_sequence: Option<Vec<SequenceHash>>,
pub isl: usize,
pub overlap: u32,
pub expected_output_tokens: Option<u32>,
pub worker: WorkerWithDpRank,
pub lora_name: Option<String>,
}
// ---------------------------------------------------------------------------
// ActiveSequencesMultiWorker
// ---------------------------------------------------------------------------
/// Multi-worker extension of [`ActiveSequences`] using shared DashMap for lock-free concurrent
/// access.
///
/// Generic over `P: SequencePublisher` to decouple from runtime-specific event transport
/// and metrics infrastructure.
pub struct ActiveSequencesMultiWorker<P: SequencePublisher> {
workers: Arc<DashMap<WorkerWithDpRank, ActiveSequences>>,
request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
request_to_lora: Arc<DashMap<RequestId, String>>,
block_size: usize,
router_id: u64,
publisher: P,
replica_sync: bool,
worker_type: &'static str,
}
impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// Create a new multi-worker sequence tracker.
///
/// `dp_sizes` maps worker IDs to their data-parallel size (number of dp_ranks).
pub fn new(
publisher: P,
block_size: usize,
dp_sizes: HashMap<u64, u32>,
replica_sync: bool,
router_id: u64,
worker_type: &'static str,
) -> Self {
assert!(block_size > 1, "block_size must be greater than 1");
let workers = Arc::new(DashMap::new());
let request_to_worker = Arc::new(DashMap::new());
let request_to_lora = Arc::new(DashMap::new());
for (worker_id, dp_size) in dp_sizes {
for dp_rank in 0..dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
workers.insert(worker, ActiveSequences::new(block_size));
}
}
Self {
workers,
request_to_worker,
request_to_lora,
block_size,
router_id,
publisher,
replica_sync,
worker_type,
}
}
/// Spawn a background task that subscribes to replica-sync events from peer routers
/// and applies them to the local state.
pub fn start_replica_sync<S: SequenceSubscriber + 'static>(
self: &Arc<Self>,
subscriber: S,
cancel_token: CancellationToken,
) {
let this = Arc::clone(self);
tokio::spawn(async move {
if let Err(e) = this.run_replica_sync(subscriber, cancel_token).await {
tracing::error!("Error in active sequences events subscription: {}", e);
}
});
}
async fn run_replica_sync<S: SequenceSubscriber>(
&self,
mut subscriber: S,
cancel_token: CancellationToken,
) -> anyhow::Result<()> {
loop {
tokio::select! {
result = subscriber.next_event() => {
let Some(result) = result else {
break;
};
let Ok(event) = result else {
tracing::error!(
"Error receiving active sequence event: {}",
result.unwrap_err()
);
continue;
};
if event.router_id == self.router_id {
continue;
}
match &event.data {
ActiveSequenceEventData::AddRequest {
token_sequence,
isl,
overlap,
expected_output_tokens,
} => {
self.request_to_worker
.insert(event.request_id.clone(), event.worker);
if let Some(ref lora_name) = event.lora_name {
self.request_to_lora
.insert(event.request_id.clone(), lora_name.clone());
}
if let Some(mut entry) = self.workers.get_mut(&event.worker) {
entry.add_request(
event.request_id.clone(),
token_sequence.clone(),
*isl,
*overlap,
*expected_output_tokens,
);
} else {
tracing::warn!(
"Worker {:?} not found, cannot process AddRequest",
event.worker
);
}
}
ActiveSequenceEventData::Free => {
if let Some((_, worker)) =
self.request_to_worker.remove(&event.request_id)
&& let Some(mut entry) = self.workers.get_mut(&worker)
{
entry.free(&event.request_id);
}
self.request_to_lora.remove(&event.request_id);
}
ActiveSequenceEventData::MarkPrefillCompleted => {
if let Some(worker) =
self.request_to_worker.get(&event.request_id)
&& let Some(mut entry) = self.workers.get_mut(&*worker)
{
entry.mark_prefill_completed(&event.request_id);
}
}
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("Subscription task cancelled");
break;
}
}
}
Ok(())
}
/// Update the set of workers, adding and removing as needed.
///
/// `new_dp_sizes` maps worker IDs to their data-parallel size.
pub fn update_workers(&self, new_dp_sizes: HashMap<u64, u32>) {
let current_workers: HashSet<WorkerWithDpRank> =
self.workers.iter().map(|entry| *entry.key()).collect();
let mut new_workers: HashSet<WorkerWithDpRank> = HashSet::new();
for (worker_id, dp_size) in &new_dp_sizes {
for dp_rank in 0..*dp_size {
new_workers.insert(WorkerWithDpRank::new(*worker_id, dp_rank));
}
}
let workers_to_remove: Vec<WorkerWithDpRank> =
current_workers.difference(&new_workers).copied().collect();
let workers_to_add: Vec<WorkerWithDpRank> =
new_workers.difference(&current_workers).copied().collect();
for worker in &workers_to_remove {
tracing::warn!("Removing worker {:?}", worker);
self.workers.remove(worker);
let requests_to_remove: Vec<RequestId> = self
.request_to_worker
.iter()
.filter(|entry| entry.value() == worker)
.map(|entry| entry.key().clone())
.collect();
self.request_to_worker
.retain(|_request_id, mapped_worker| mapped_worker != worker);
for request_id in requests_to_remove {
self.request_to_lora.remove(&request_id);
}
}
for worker in &workers_to_add {
tracing::warn!("Adding worker {:?}", worker);
self.workers
.insert(*worker, ActiveSequences::new(self.block_size));
}
}
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
let SequenceRequest {
request_id,
token_sequence,
isl,
overlap,
expected_output_tokens,
worker,
lora_name,
} = req;
if !self.workers.contains_key(&worker) {
return Err(SequenceError::WorkerNotFound { worker });
}
if let Some(existing_worker) = self.request_to_worker.get(&request_id) {
return Err(SequenceError::DuplicateRequest {
request_id,
worker: *existing_worker,
});
}
if self.replica_sync {
let event = ActiveSequenceEvent {
request_id: request_id.clone(),
worker,
data: ActiveSequenceEventData::AddRequest {
token_sequence: token_sequence.clone(),
isl,
overlap,
expected_output_tokens,
},
router_id: self.router_id,
lora_name: lora_name.clone(),
};
self.publisher.publish_event(&event).await?;
}
self.request_to_worker.insert(request_id.clone(), worker);
if let Some(lora) = lora_name {
self.request_to_lora.insert(request_id.clone(), lora);
}
let removed_requests = {
let mut entry = self
.workers
.get_mut(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
entry.add_request(
request_id,
token_sequence,
isl,
overlap,
expected_output_tokens,
)
};
for expired_id in &removed_requests {
self.request_to_worker.remove(expired_id);
self.request_to_lora.remove(expired_id);
}
self.publish_active_load_for_worker(worker);
Ok(())
}
/// Send a mutation to the worker assigned to a request, optionally publishing
/// a replica-sync event and cleaning up request mappings afterward.
async fn mutate_request_worker(
&self,
request_id: &RequestId,
event_data: ActiveSequenceEventData,
mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId),
remove_mapping: bool,
) -> Result<(), SequenceError> {
let worker = self
.request_to_worker
.get(request_id)
.map(|entry| *entry)
.ok_or_else(|| SequenceError::RequestNotFound {
request_id: request_id.clone(),
})?;
if self.replica_sync {
let lora_name = self
.request_to_lora
.get(request_id)
.map(|entry| entry.value().clone());
let event = ActiveSequenceEvent {
request_id: request_id.clone(),
worker,
data: event_data,
router_id: self.router_id,
lora_name,
};
self.publisher.publish_event(&event).await?;
}
{
let mut entry = self
.workers
.get_mut(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
mutate_fn(&mut entry, request_id);
}
if remove_mapping {
self.request_to_worker.remove(request_id);
self.request_to_lora.remove(request_id);
}
self.publish_active_load_for_worker(worker);
Ok(())
}
/// Free all blocks associated with a request.
///
/// Note: This operation is idempotent. Calling it multiple times for the same request
/// will log a warning but not return an error (double free is allowed).
pub async fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> {
if !self.request_to_worker.contains_key(request_id) {
tracing::debug!("Request {request_id} not found, already freed (idempotent)");
return Ok(());
}
self.mutate_request_worker(
request_id,
ActiveSequenceEventData::Free,
|seqs, rid| {
seqs.free(rid);
},
true,
)
.await
}
/// Mark prefill as completed for a request.
///
/// Note: Calling this multiple times for the same request is allowed and will be a no-op
/// after the first call (idempotent).
pub async fn mark_prefill_completed(
&self,
request_id: &RequestId,
) -> Result<(), SequenceError> {
self.mutate_request_worker(
request_id,
ActiveSequenceEventData::MarkPrefillCompleted,
|seqs, rid| {
seqs.mark_prefill_completed(rid);
},
false,
)
.await
}
/// Add an output block with optional fractional decay weight.
///
/// This is used during generation to track output blocks as they are created.
/// The decay_fraction represents how "temporary" the block is based on generation progress.
// TODO: output blocks are not replicated via replica_sync — add an
// ActiveSequenceEventData variant if cross-instance accuracy matters.
pub fn add_output_block(
&self,
request_id: &RequestId,
decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
let worker = self
.request_to_worker
.get(request_id)
.map(|entry| *entry)
.ok_or_else(|| SequenceError::RequestNotFound {
request_id: request_id.clone(),
})?;
let success = {
let mut entry = self
.workers
.get_mut(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
entry.add_output_block(request_id, decay_fraction)
};
if !success {
return Err(SequenceError::RequestNotFound {
request_id: request_id.clone(),
});
}
self.publish_active_load_for_worker(worker);
Ok(())
}
/// Read active blocks/tokens from a worker and publish ActiveLoad metrics.
fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) {
let (active_blocks, active_tokens) = {
let Some(entry) = self.workers.get(&worker) else {
tracing::warn!("Worker {worker:?} not found when publishing ActiveLoad");
return;
};
(entry.active_blocks(), entry.active_tokens())
};
self.publisher
.observe_load(&worker, self.worker_type, active_blocks, active_tokens);
let active_load = ActiveLoad {
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
active_decode_blocks: Some(active_blocks as u64),
active_prefill_tokens: Some(active_tokens as u64),
};
self.publisher.publish_load(active_load);
}
/// Get the number of workers.
pub fn num_workers(&self) -> usize {
self.workers.len()
}
/// Get the worker type for this router ("prefill" or "decode").
pub fn worker_type(&self) -> &'static str {
self.worker_type
}
/// Query all workers for the number of new blocks that would be added by a token sequence.
pub fn new_blocks(
&self,
token_sequence: Vec<SequenceHash>,
) -> HashMap<WorkerWithDpRank, usize> {
let mut results = HashMap::with_capacity(self.workers.len());
for entry in self.workers.iter() {
results.insert(*entry.key(), entry.value().new_blocks(&token_sequence));
}
results
}
/// Query all workers for the total number of blocks (new + active) that would be used.
pub fn potential_blocks(
&self,
token_sequence: Vec<SequenceHash>,
) -> HashMap<WorkerWithDpRank, usize> {
let mut results = HashMap::with_capacity(self.workers.len());
for entry in self.workers.iter() {
results.insert(
*entry.key(),
entry.value().potential_blocks(&token_sequence),
);
}
results
}
/// Query all workers for the potential blocks and tokens.
pub fn potential_blocks_and_tokens(
&self,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlaps: OverlapScores,
) -> (
HashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>,
) {
#[cfg(feature = "bench")]
let start = tokio::time::Instant::now();
#[cfg(feature = "bench")]
let num_workers = self.workers.len();
let mut potential_blocks = HashMap::with_capacity(self.workers.len());
let mut potential_tokens = HashMap::with_capacity(self.workers.len());
for entry in self.workers.iter() {
let worker = *entry.key();
let overlap = *overlaps.scores.get(&worker).unwrap_or(&0);
let (blocks, tokens) =
entry
.value()
.potential_blocks_and_tokens(token_sequence.as_deref(), isl, overlap);
potential_blocks.insert(worker, blocks);
potential_tokens.insert(worker, tokens);
}
#[cfg(feature = "bench")]
{
let total_elapsed = start.elapsed();
tracing::info!(
num_workers,
total_us = total_elapsed.as_micros() as u64,
"potential_blocks_and_tokens completed"
);
}
(potential_blocks, potential_tokens)
}
/// Query all workers for their current number of active blocks.
pub fn active_blocks(&self) -> HashMap<WorkerWithDpRank, usize> {
let mut results = HashMap::with_capacity(self.workers.len());
for entry in self.workers.iter() {
results.insert(*entry.key(), entry.value().active_blocks());
}
results
}
/// Query all workers for their current number of active tokens.
pub fn active_tokens(&self) -> HashMap<WorkerWithDpRank, usize> {
let mut results = HashMap::with_capacity(self.workers.len());
for entry in self.workers.iter() {
results.insert(*entry.key(), entry.value().active_tokens());
}
results
}
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
let mut counts: HashMap<String, usize> = HashMap::new();
for entry in self.request_to_lora.iter() {
let lora_name = entry.value().clone();
*counts.entry(lora_name).or_insert(0) += 1;
}
counts
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! KV Cache Sequence Management for LLM Inference
//!
//! This module provides efficient management of token sequences and their associated KV cache blocks
//! for distributed LLM inference. It implements a shared block system where multiple requests can
//! reuse the same KV cache blocks for common token prefixes, significantly reducing memory usage.
//!
//! # Key Components
//!
//! - [`ActiveSequences`]: Per-worker sequence manager that tracks active requests and their
//! token sequences, managing shared KV cache blocks efficiently.
//!
//! # Architecture
//!
//! The system uses a block-based approach where token sequences are divided into fixed-size blocks.
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples).
use derive_getters::Getters;
use dynamo_tokens::SequenceHash;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use uuid::Uuid;
/// Duration after which stale requests are forcibly expired (5 minutes)
const EXPIRY_DURATION: Duration = Duration::from_secs(300);
// TODO: use the common request_id if it exists in the repo
pub type RequestId = String;
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)]
pub struct ActiveSequences {
active_seqs: HashMap<RequestId, Vec<(SequenceHash, Arc<()>)>>,
prefill_tokens: HashMap<RequestId, usize>,
/// Expected output tokens per request (used for resource estimation)
expected_output_tokens: HashMap<RequestId, u32>,
unique_blocks: HashMap<SequenceHash, std::sync::Weak<()>>,
/// Fractional block counts for blocks that are partially cached
/// When a block is in both unique_blocks and fractional_blocks,
/// it contributes the fractional value instead of 1 to active_blocks()
fractional_blocks: HashMap<SequenceHash, f64>,
#[getter(copy)]
block_size: usize,
#[getter(copy)]
active_tokens: usize,
/// Timer for when to force expiry of stale requests
expiry_timer: Instant,
/// Set of request IDs to check for expiry
expiry_requests: HashSet<RequestId>,
}
impl ActiveSequences {
/// Create a new SharedSequenceManager instance
pub fn new(block_size: usize) -> Self {
// TODO: make this not a hard req
assert!(block_size > 1, "block_size must be greater than 1");
Self {
active_seqs: HashMap::new(),
prefill_tokens: HashMap::new(),
expected_output_tokens: HashMap::new(),
unique_blocks: HashMap::new(),
fractional_blocks: HashMap::new(),
block_size,
active_tokens: 0,
expiry_timer: Instant::now() + EXPIRY_DURATION,
expiry_requests: HashSet::new(),
}
}
fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> {
if let Some(weak) = self.unique_blocks.get(block)
&& let Some(rc) = weak.upgrade()
{
return rc;
}
let rc = Arc::new(());
self.unique_blocks.insert(*block, Arc::downgrade(&rc));
rc
}
fn try_remove_block(&mut self, block: &SequenceHash) {
if let Some(weak) = self.unique_blocks.get(block)
&& weak.strong_count() == 0
{
self.unique_blocks.remove(block);
self.fractional_blocks.remove(block);
}
}
pub fn active_blocks(&self) -> usize {
let mut count = self.unique_blocks.len() as f64;
for (hash, frac) in &self.fractional_blocks {
if self.unique_blocks.contains_key(hash) {
// Subtract 1 (the full block) and add the fractional value
count = count - 1.0 + frac;
}
}
count.round() as usize
}
/// Find all blocks in a request that have only a single strong reference (only used by this request)
/// and insert them into fractional_blocks with the given fraction value.
pub fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
let Some(blocks) = self.active_seqs.get(request_id) else {
tracing::warn!(
"Request {request_id} not found for set_single_ref_blocks_as_fractional"
);
return;
};
for (hash, rc) in blocks {
// A block with strong_count == 1 means only this request holds a reference
if Arc::strong_count(rc) == 1 {
self.fractional_blocks.insert(*hash, fraction);
}
}
}
/// Add a new request with its initial tokens
/// Returns the set of expired request IDs that were removed during cleanup
pub fn add_request(
&mut self,
request_id: RequestId,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
expected_output_tokens: Option<u32>,
) -> HashSet<RequestId> {
// Check for double-add and log error, returning early
if self.active_seqs.contains_key(&request_id) {
tracing::error!("Request {request_id} is already active. Ignoring duplicate add.");
return HashSet::new();
}
// Lazily check and clean up expired requests, capturing removed IDs
let removed_requests = self.force_expiry();
let prefill_tokens = self.new_tokens(isl, overlap);
self.prefill_tokens
.insert(request_id.clone(), prefill_tokens);
self.active_tokens += prefill_tokens;
// Store expected output tokens if provided
if let Some(tokens) = expected_output_tokens {
self.expected_output_tokens
.insert(request_id.clone(), tokens);
}
if let Some(sequence) = token_sequence {
let sequence_with_refs: Vec<(SequenceHash, Arc<()>)> = sequence
.iter()
.map(|block| (*block, self.touch_block(block)))
.collect();
self.active_seqs
.insert(request_id.clone(), sequence_with_refs);
} else {
// dummy empty sequence
self.active_seqs.insert(request_id.clone(), Vec::new());
}
removed_requests
}
/// Mark prefill as completed for a request, removing it from prefill_tokens tracking
pub fn mark_prefill_completed(&mut self, request_id: &RequestId) {
if let Some(tokens) = self.prefill_tokens.remove(request_id) {
self.active_tokens = self
.active_tokens
.checked_sub(tokens)
.expect("active_tokens underflow");
}
}
pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
let cached_tokens = (overlap as usize) * self.block_size;
isl.checked_sub(cached_tokens)
.unwrap_or_else(|| {
tracing::error!(
"prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {}), returning 0",
self.block_size
);
0
})
}
pub fn potential_blocks_and_tokens(
&self,
token_sequence: Option<&[SequenceHash]>,
isl: usize,
overlap: u32,
) -> (usize, usize) {
let potential_blocks = if let Some(token_seq) = token_sequence {
self.new_blocks(token_seq) + self.active_blocks()
} else {
self.active_blocks()
};
let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens;
(potential_blocks, potential_tokens)
}
/// Match a request against existing blocks and return the number of new blocks that would be added
pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
token_sequence
.iter()
.filter(|block| !self.unique_blocks.contains_key(block))
.count()
}
/// Return the total number of blocks that would be used if the token sequence was added
/// This is the sum of new blocks that would be added plus the current active blocks
pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
self.new_blocks(token_sequence) + self.active_blocks()
}
/// Free all blocks associated with a request
pub fn free(&mut self, request_id: &RequestId) -> usize {
self.mark_prefill_completed(request_id);
self.expiry_requests.remove(request_id);
// Remove expected output tokens tracking
self.expected_output_tokens.remove(request_id);
// Remove from active_seqs and get the token sequence
let token_seq = match self.active_seqs.remove(request_id) {
Some(seq) => seq,
None => {
tracing::warn!("Trying to free non-existent request {request_id}");
return self.active_blocks();
}
};
// Drop each Rc reference, then clean up the corresponding weak reference
for (block_hash, rc) in token_seq {
drop(rc);
self.try_remove_block(&block_hash);
}
self.active_blocks()
}
/// Add an output block with a random hash and optional fractional decay weight.
///
/// This is used during generation to track output blocks as they are created.
/// The decay_fraction (if provided) represents how "temporary" the block is:
/// - 1.0 means fully counted (early in generation)
/// - 0.0 means not counted (near end of expected output)
/// - Computed as: 1 - (current_osl / expected_output_tokens)
///
/// Returns true if the block was added, false if the request was not found.
pub fn add_output_block(
&mut self,
request_id: &RequestId,
decay_fraction: Option<f64>,
) -> bool {
// Check if request exists first (immutable borrow)
if !self.active_seqs.contains_key(request_id) {
tracing::warn!("Request {request_id} not found for add_output_block");
return false;
}
// Generate a random block hash using UUID
let random_hash: SequenceHash = Uuid::new_v4().as_u64_pair().0;
// Touch the block (adds to unique_blocks)
let rc = self.touch_block(&random_hash);
// Now we can safely get_mut and push
self.active_seqs
.get_mut(request_id)
.unwrap()
.push((random_hash, rc));
// Apply fractional decay to all single-ref blocks in this request if provided
if let Some(frac) = decay_fraction {
self.set_single_ref_blocks_as_fractional(request_id, frac);
}
true
}
/// Force expiry of stale requests if the timer has elapsed
/// Returns the set of expired request IDs that were removed
pub fn force_expiry(&mut self) -> HashSet<RequestId> {
let now = Instant::now();
// Early return if timer hasn't expired yet
if now < self.expiry_timer {
return HashSet::new();
}
// Process expired requests - drain to avoid clone
let expired_requests: HashSet<RequestId> = self.expiry_requests.drain().collect();
for request_id in &expired_requests {
tracing::warn!("Force expiring stale request: {}", request_id);
self.free(request_id);
}
self.expiry_timer = now + EXPIRY_DURATION;
self.expiry_requests = self.active_seqs.keys().cloned().collect();
expired_requests
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_active_sequences_shared_blocks() {
let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size);
seq_manager.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(), 12);
seq_manager.add_request("request_2".to_string(), Some(vec![4]), 4, 0, None);
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16);
seq_manager.add_request("request_3".to_string(), Some(vec![1, 2, 3, 4]), 16, 4, None);
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 16);
seq_manager.free(&"request_2".to_string());
assert_eq!(seq_manager.active_blocks(), 4);
assert_eq!(seq_manager.active_tokens(), 12);
seq_manager.free(&"request_3".to_string());
assert_eq!(seq_manager.active_blocks(), 3);
assert_eq!(seq_manager.active_tokens(), 12);
seq_manager.free(&"request_1".to_string());
assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(), 0);
}
#[test]
fn test_output_blocks_with_fractional_decay() {
let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size);
// Add request with 3 prefill blocks
seq_manager.add_request("r1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
assert_eq!(seq_manager.active_blocks(), 3);
// Add output block with 0.5 decay fraction.
// This adds a random block and sets all single-ref blocks to 0.5.
assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.5)));
// 4 unique blocks, all single-ref → all fractional at 0.5
// active_blocks = 4 - 4 + 4*0.5 = 2
assert_eq!(seq_manager.active_blocks(), 2);
// Add second request sharing prefix [1, 2]
seq_manager.add_request("r2".to_string(), Some(vec![1, 2]), 8, 0, None);
// Blocks 1,2 now have strong_count=2 but still have fractional 0.5 from before
// No new unique blocks → active_blocks = 4 - 4 + 2.0 = 2
assert_eq!(seq_manager.active_blocks(), 2);
// Add another output block with 0.0 decay for r1.
// set_single_ref_blocks_as_fractional updates only single-ref blocks:
// blocks 1,2: strong_count=2, NOT updated (remain 0.5)
// block 3, old output, new output: strong_count=1, set to 0.0
// active_blocks = 5 - 5 + (0.5+0.5+0.0+0.0+0.0) = 1
assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.0)));
assert_eq!(seq_manager.active_blocks(), 1);
// Free both requests, verify clean state
seq_manager.free(&"r2".to_string());
seq_manager.free(&"r1".to_string());
assert_eq!(seq_manager.active_blocks(), 0);
assert_eq!(seq_manager.active_tokens(), 0);
}
#[test]
fn test_mark_prefill_completed() {
let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size);
// Add request with isl=12, overlap=0 → active_tokens=12
seq_manager.add_request("r1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
assert_eq!(seq_manager.active_tokens(), 12);
// Mark prefill completed → active_tokens drops to 0
seq_manager.mark_prefill_completed(&"r1".to_string());
assert_eq!(seq_manager.active_tokens(), 0);
// Double-mark: no panic, still 0
seq_manager.mark_prefill_completed(&"r1".to_string());
assert_eq!(seq_manager.active_tokens(), 0);
// Add second request with isl=8
seq_manager.add_request("r2".to_string(), Some(vec![4, 5]), 8, 0, None);
assert_eq!(seq_manager.active_tokens(), 8);
// Free it (internally calls mark_prefill_completed) → active_tokens=0
seq_manager.free(&"r2".to_string());
assert_eq!(seq_manager.active_tokens(), 0);
}
#[tokio::test(start_paused = true)]
async fn test_force_expiry() {
let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size);
// Add two requests
seq_manager.add_request("r1".to_string(), Some(vec![1, 2]), 8, 0, None);
seq_manager.add_request("r2".to_string(), Some(vec![3, 4]), 8, 0, None);
assert_eq!(seq_manager.active_blocks(), 4);
// First expiry cycle: advance past EXPIRY_DURATION.
// This populates expiry_requests with {r1, r2} but doesn't expire anything
// since expiry_requests started empty.
tokio::time::advance(Duration::from_secs(301)).await;
let expired = seq_manager.force_expiry();
assert!(expired.is_empty());
// Second expiry cycle: advance again so the timer expires.
// Adding r3 triggers force_expiry which drains {r1, r2}.
tokio::time::advance(Duration::from_secs(301)).await;
let expired = seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None);
assert_eq!(expired, HashSet::from(["r1".to_string(), "r2".to_string()]));
// Only r3's block remains
assert_eq!(seq_manager.active_blocks(), 1);
assert_eq!(seq_manager.active_tokens(), 4);
}
}
...@@ -11,7 +11,7 @@ use tokio::sync::Mutex; ...@@ -11,7 +11,7 @@ use tokio::sync::Mutex;
use super::WorkerSelector; use super::WorkerSelector;
use super::protocols::WorkerWithDpRank; use super::protocols::WorkerWithDpRank;
use super::scheduler::{SchedulingRequest, SchedulingResponse}; use super::scheduler::{SchedulingRequest, SchedulingResponse};
use super::sequence::{ActiveSequencesMultiWorker, SequenceRequest}; use super::sequence::{ActiveSequencesMulti, SequenceRequest};
use crate::discovery::RuntimeConfigWatch; use crate::discovery::RuntimeConfigWatch;
/// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker) /// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker)
...@@ -51,7 +51,7 @@ impl PartialOrd for QueueEntry { ...@@ -51,7 +51,7 @@ impl PartialOrd for QueueEntry {
/// If queueing is disabled (threshold_frac is None), requests are scheduled immediately. /// If queueing is disabled (threshold_frac is None), requests are scheduled immediately.
pub struct SchedulerQueue { pub struct SchedulerQueue {
pending: Mutex<BinaryHeap<QueueEntry>>, pending: Mutex<BinaryHeap<QueueEntry>>,
slots: Arc<ActiveSequencesMultiWorker>, slots: Arc<ActiveSequencesMulti>,
workers_with_configs: RuntimeConfigWatch, workers_with_configs: RuntimeConfigWatch,
/// Cached threshold fraction; None means queueing is disabled. /// Cached threshold fraction; None means queueing is disabled.
threshold_frac: Option<f64>, threshold_frac: Option<f64>,
...@@ -63,7 +63,7 @@ pub struct SchedulerQueue { ...@@ -63,7 +63,7 @@ pub struct SchedulerQueue {
impl SchedulerQueue { impl SchedulerQueue {
pub fn new( pub fn new(
slots: Arc<ActiveSequencesMultiWorker>, slots: Arc<ActiveSequencesMulti>,
workers_with_configs: RuntimeConfigWatch, workers_with_configs: RuntimeConfigWatch,
threshold_frac: Option<f64>, threshold_frac: Option<f64>,
block_size: u32, block_size: u32,
......
...@@ -6,7 +6,9 @@ use super::RouterConfigOverride; ...@@ -6,7 +6,9 @@ use super::RouterConfigOverride;
use super::WorkerSelector; use super::WorkerSelector;
use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank}; use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::queue::SchedulerQueue; use super::queue::SchedulerQueue;
use super::sequence::{ActiveSequencesMultiWorker, SequenceError, SequenceRequest}; use super::sequence::{
ActiveSequencesMulti, SequenceError, SequenceRequest, create_multi_worker_sequences,
};
use crate::discovery::RuntimeConfigWatch; use crate::discovery::RuntimeConfigWatch;
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result; use anyhow::Result;
...@@ -82,7 +84,7 @@ impl SchedulingRequest { ...@@ -82,7 +84,7 @@ impl SchedulingRequest {
pub struct KvScheduler { pub struct KvScheduler {
request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>, request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
slots: Arc<ActiveSequencesMultiWorker>, slots: Arc<ActiveSequencesMulti>,
queue: Arc<SchedulerQueue>, queue: Arc<SchedulerQueue>,
} }
...@@ -103,8 +105,7 @@ impl KvScheduler { ...@@ -103,8 +105,7 @@ impl KvScheduler {
workers_with_configs.borrow().clone(); workers_with_configs.borrow().clone();
let router_id = component.drt().discovery().instance_id(); let router_id = component.drt().discovery().instance_id();
let slots = Arc::new( let slots = create_multi_worker_sequences(
ActiveSequencesMultiWorker::new(
component.clone(), component.clone(),
block_size as usize, block_size as usize,
initial_workers, initial_workers,
...@@ -113,8 +114,7 @@ impl KvScheduler { ...@@ -113,8 +114,7 @@ impl KvScheduler {
worker_type, worker_type,
) )
.await .await
.map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?, .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
);
// Spawn background task to sync slots when the watch value changes. // Spawn background task to sync slots when the watch value changes.
let slots_monitor = slots.clone(); let slots_monitor = slots.clone();
...@@ -141,7 +141,11 @@ impl KvScheduler { ...@@ -141,7 +141,11 @@ impl KvScheduler {
let current_workers = monitor_rx.borrow_and_update().clone(); let current_workers = monitor_rx.borrow_and_update().clone();
if current_workers != last_workers { if current_workers != last_workers {
slots_monitor.update_workers(current_workers.clone()); let dp_sizes: HashMap<u64, u32> = current_workers
.iter()
.map(|(&id, c)| (id, c.data_parallel_size))
.collect();
slots_monitor.update_workers(dp_sizes);
last_workers = current_workers; last_workers = current_workers;
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//! KV Cache Sequence Management for LLM Inference //! Runtime-specific glue for [`ActiveSequencesMultiWorker`].
//! //!
//! This module provides efficient management of token sequences and their associated KV cache blocks //! This module provides the concrete [`SequencePublisher`] and [`SequenceSubscriber`]
//! for distributed LLM inference. It implements a shared block system where multiple requests can //! implementations that wire the runtime-agnostic business logic (in `dynamo_kv_router`)
//! reuse the same KV cache blocks for common token prefixes, significantly reducing memory usage. //! to NATS event transport and Prometheus metrics.
//!
//! # Key Components pub use dynamo_kv_router::multi_worker_sequence::{
//! ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest,
//! - [`ActiveSequences`]: Per-worker sequence manager that tracks active requests and their SequenceSubscriber,
//! token sequences, managing shared KV cache blocks efficiently. };
//! pub use dynamo_kv_router::sequence::{ActiveSequences, RequestId};
//! - [`ActiveSequencesMultiWorker`]: Multi-worker extension that stores per-worker
//! `ActiveSequences` in a shared `DashMap` for lock-free concurrent access.
//!
//! # Architecture
//!
//! The system uses a block-based approach where token sequences are divided into fixed-size blocks.
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples).
use crate::kv_router::protocols::OverlapScores;
use anyhow::Result; use anyhow::Result;
use dashmap::DashMap;
use derive_getters::Getters;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::transports::event_plane::{EventPublisher, EventSubscriber}; use dynamo_runtime::transports::event_plane::{EventPublisher, EventSubscriber};
use dynamo_tokens::SequenceHash; use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use uuid::Uuid;
use super::metrics::WORKER_LOAD_METRICS; use super::metrics::WORKER_LOAD_METRICS;
use super::protocols::{ use super::protocols::{ActiveLoad, ActiveSequenceEvent, WorkerWithDpRank};
ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, WorkerWithDpRank,
};
use crate::kv_router::{ACTIVE_SEQUENCES_SUBJECT, KV_METRICS_SUBJECT}; use crate::kv_router::{ACTIVE_SEQUENCES_SUBJECT, KV_METRICS_SUBJECT};
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_runtime::CancellationToken;
/// Errors that can occur during sequence management operations
#[derive(Debug, thiserror::Error)]
pub enum SequenceError {
#[error("Worker {worker:?} not found")]
WorkerNotFound { worker: WorkerWithDpRank },
#[error("Request {request_id} already exists (assigned to worker {worker:?})")]
DuplicateRequest {
request_id: String,
worker: WorkerWithDpRank,
},
#[error("Request {request_id} not found")]
RequestNotFound { request_id: String },
#[error("Failed to publish event: {0}")]
PublishFailed(#[from] anyhow::Error),
}
/// Duration after which stale requests are forcibly expired (5 minutes)
const EXPIRY_DURATION: Duration = Duration::from_secs(300);
// TODO: use the common request_id if it exists in the repo
pub type RequestId = String;
/// Bundled parameters for adding a request to the sequence tracker.
pub struct SequenceRequest {
pub request_id: RequestId,
pub token_sequence: Option<Vec<SequenceHash>>,
pub isl: usize,
pub overlap: u32,
pub expected_output_tokens: Option<u32>,
pub worker: WorkerWithDpRank,
pub lora_name: Option<String>,
}
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)]
pub struct ActiveSequences {
active_seqs: HashMap<RequestId, Vec<(SequenceHash, Arc<()>)>>,
prefill_tokens: HashMap<RequestId, usize>,
/// Expected output tokens per request (used for resource estimation)
expected_output_tokens: HashMap<RequestId, u32>,
unique_blocks: HashMap<SequenceHash, std::sync::Weak<()>>,
/// Fractional block counts for blocks that are partially cached /// Concrete [`SequencePublisher`] backed by NATS [`EventPublisher`] and Prometheus gauges.
/// When a block is in both unique_blocks and fractional_blocks, pub struct RuntimeSequencePublisher {
/// it contributes the fractional value instead of 1 to active_blocks() event_publisher: EventPublisher,
fractional_blocks: HashMap<SequenceHash, f64>, metrics_publisher: Arc<EventPublisher>,
#[getter(copy)]
block_size: usize,
#[getter(copy)]
active_tokens: usize,
/// Timer for when to force expiry of stale requests
expiry_timer: Instant,
/// Set of request IDs to check for expiry
expiry_requests: HashSet<RequestId>,
} }
impl ActiveSequences { impl SequencePublisher for RuntimeSequencePublisher {
/// Create a new SharedSequenceManager instance async fn publish_event(&self, event: &ActiveSequenceEvent) -> anyhow::Result<()> {
pub fn new(block_size: usize) -> Self { self.event_publisher.publish(event).await
// TODO: make this not a hard req
assert!(block_size > 1, "block_size must be greater than 1");
Self {
active_seqs: HashMap::new(),
prefill_tokens: HashMap::new(),
expected_output_tokens: HashMap::new(),
unique_blocks: HashMap::new(),
fractional_blocks: HashMap::new(),
block_size,
active_tokens: 0,
expiry_timer: Instant::now() + EXPIRY_DURATION,
expiry_requests: HashSet::new(),
}
}
fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> {
if let Some(weak) = self.unique_blocks.get(block)
&& let Some(rc) = weak.upgrade()
{
return rc;
} }
let rc = Arc::new(()); fn publish_load(&self, load: ActiveLoad) {
self.unique_blocks.insert(*block, Arc::downgrade(&rc)); let publisher = self.metrics_publisher.clone();
rc tokio::spawn(async move {
} if let Err(e) = publisher.publish(&load).await {
tracing::trace!(
fn try_remove_block(&mut self, block: &SequenceHash) { "Failed to publish ActiveLoad to NATS for worker (id={}, dp_rank={}): {e:?}",
if let Some(weak) = self.unique_blocks.get(block) load.worker_id,
&& weak.strong_count() == 0 load.dp_rank
{
self.unique_blocks.remove(block);
self.fractional_blocks.remove(block);
}
}
pub fn active_blocks(&self) -> usize {
let mut count = self.unique_blocks.len() as f64;
for (hash, frac) in &self.fractional_blocks {
if self.unique_blocks.contains_key(hash) {
// Subtract 1 (the full block) and add the fractional value
count = count - 1.0 + frac;
}
}
count.round() as usize
}
/// Find all blocks in a request that have only a single strong reference (only used by this request)
/// and insert them into fractional_blocks with the given fraction value.
pub fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
let Some(blocks) = self.active_seqs.get(request_id) else {
tracing::warn!(
"Request {request_id} not found for set_single_ref_blocks_as_fractional"
); );
return;
};
for (hash, rc) in blocks {
// A block with strong_count == 1 means only this request holds a reference
if Arc::strong_count(rc) == 1 {
self.fractional_blocks.insert(*hash, fraction);
}
}
}
/// Add a new request with its initial tokens
/// Returns the set of expired request IDs that were removed during cleanup
pub fn add_request(
&mut self,
request_id: RequestId,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
expected_output_tokens: Option<u32>,
) -> HashSet<RequestId> {
// Check for double-add and log error, returning early
if self.active_seqs.contains_key(&request_id) {
tracing::error!("Request {request_id} is already active. Ignoring duplicate add.");
return HashSet::new();
}
// Lazily check and clean up expired requests, capturing removed IDs
let removed_requests = self.force_expiry();
let prefill_tokens = self.new_tokens(isl, overlap);
self.prefill_tokens
.insert(request_id.clone(), prefill_tokens);
self.active_tokens += prefill_tokens;
// Store expected output tokens if provided
if let Some(tokens) = expected_output_tokens {
self.expected_output_tokens
.insert(request_id.clone(), tokens);
}
if let Some(sequence) = token_sequence {
let sequence_with_refs: Vec<(SequenceHash, Arc<()>)> = sequence
.iter()
.map(|block| (*block, self.touch_block(block)))
.collect();
self.active_seqs
.insert(request_id.clone(), sequence_with_refs);
} else {
// dummy empty sequence
self.active_seqs.insert(request_id.clone(), Vec::new());
} }
});
removed_requests
}
/// Mark prefill as completed for a request, removing it from prefill_tokens tracking
pub fn mark_prefill_completed(&mut self, request_id: &RequestId) {
if let Some(tokens) = self.prefill_tokens.remove(request_id) {
self.active_tokens = self
.active_tokens
.checked_sub(tokens)
.expect("active_tokens underflow");
}
}
pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
let cached_tokens = (overlap as usize) * self.block_size;
isl.checked_sub(cached_tokens)
.unwrap_or_else(|| {
tracing::error!(
"prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {}), returning 0",
self.block_size
);
0
})
} }
pub fn potential_blocks_and_tokens( fn observe_load(
&self, &self,
token_sequence: Option<&[SequenceHash]>, worker: &WorkerWithDpRank,
isl: usize, worker_type: &str,
overlap: u32, blocks: usize,
) -> (usize, usize) { tokens: usize,
let potential_blocks = if let Some(token_seq) = token_sequence { ) {
self.new_blocks(token_seq) + self.active_blocks() WORKER_LOAD_METRICS.observe(
} else { worker.worker_id,
self.active_blocks() worker.dp_rank,
}; worker_type,
let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens; blocks,
(potential_blocks, potential_tokens) tokens,
} );
/// Match a request against existing blocks and return the number of new blocks that would be added
pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
token_sequence
.iter()
.filter(|block| !self.unique_blocks.contains_key(block))
.count()
}
/// Return the total number of blocks that would be used if the token sequence was added
/// This is the sum of new blocks that would be added plus the current active blocks
pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
self.new_blocks(token_sequence) + self.active_blocks()
}
/// Free all blocks associated with a request
pub fn free(&mut self, request_id: &RequestId) -> usize {
self.mark_prefill_completed(request_id);
self.expiry_requests.remove(request_id);
// Remove expected output tokens tracking
self.expected_output_tokens.remove(request_id);
// Remove from active_seqs and get the token sequence
let token_seq = match self.active_seqs.remove(request_id) {
Some(seq) => seq,
None => {
tracing::warn!("Trying to free non-existent request {request_id}");
return self.active_blocks();
}
};
// Drop each Rc reference, then clean up the corresponding weak reference
for (block_hash, rc) in token_seq {
drop(rc);
self.try_remove_block(&block_hash);
}
self.active_blocks()
}
/// Add an output block with a random hash and optional fractional decay weight.
///
/// This is used during generation to track output blocks as they are created.
/// The decay_fraction (if provided) represents how "temporary" the block is:
/// - 1.0 means fully counted (early in generation)
/// - 0.0 means not counted (near end of expected output)
/// - Computed as: 1 - (current_osl / expected_output_tokens)
///
/// Returns true if the block was added, false if the request was not found.
pub fn add_output_block(
&mut self,
request_id: &RequestId,
decay_fraction: Option<f64>,
) -> bool {
// Check if request exists first (immutable borrow)
if !self.active_seqs.contains_key(request_id) {
tracing::warn!("Request {request_id} not found for add_output_block");
return false;
}
// Generate a random block hash using UUID
let random_hash: SequenceHash = Uuid::new_v4().as_u64_pair().0;
// Touch the block (adds to unique_blocks)
let rc = self.touch_block(&random_hash);
// Now we can safely get_mut and push
self.active_seqs
.get_mut(request_id)
.unwrap()
.push((random_hash, rc));
// Apply fractional decay to all single-ref blocks in this request if provided
if let Some(frac) = decay_fraction {
self.set_single_ref_blocks_as_fractional(request_id, frac);
}
true
} }
}
/// Force expiry of stale requests if the timer has elapsed /// Concrete [`SequenceSubscriber`] backed by NATS typed event stream.
/// Returns the set of expired request IDs that were removed pub struct RuntimeSequenceSubscriber {
pub fn force_expiry(&mut self) -> HashSet<RequestId> { inner: dynamo_runtime::transports::event_plane::TypedEventSubscriber<ActiveSequenceEvent>,
let now = Instant::now(); }
// Early return if timer hasn't expired yet
if now < self.expiry_timer {
return HashSet::new();
}
// Process expired requests - drain to avoid clone impl SequenceSubscriber for RuntimeSequenceSubscriber {
let expired_requests: HashSet<RequestId> = self.expiry_requests.drain().collect(); async fn next_event(&mut self) -> Option<anyhow::Result<ActiveSequenceEvent>> {
for request_id in &expired_requests { match self.inner.next().await? {
tracing::warn!("Force expiring stale request: {}", request_id); Ok((_envelope, event)) => Some(Ok(event)),
self.free(request_id); Err(e) => Some(Err(e)),
} }
self.expiry_timer = now + EXPIRY_DURATION;
self.expiry_requests = self.active_seqs.keys().cloned().collect();
expired_requests
} }
} }
/// Multi-worker extension of ActiveSequences using shared DashMap for lock-free concurrent access /// Type alias for the runtime-wired multi-worker sequence tracker.
pub struct ActiveSequencesMultiWorker { pub type ActiveSequencesMulti = ActiveSequencesMultiWorker<RuntimeSequencePublisher>;
workers: Arc<DashMap<WorkerWithDpRank, ActiveSequences>>,
request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
request_to_lora: Arc<DashMap<RequestId, String>>,
block_size: usize,
router_id: u64,
event_publisher: EventPublisher,
metrics_publisher: Arc<EventPublisher>,
replica_sync: bool,
worker_type: &'static str,
}
impl ActiveSequencesMultiWorker { /// Convenience async constructor that creates the NATS publishers/subscribers
pub async fn new( /// and returns an `Arc<ActiveSequencesMulti>` with replica sync already running.
pub async fn create_multi_worker_sequences(
component: Component, component: Component,
block_size: usize, block_size: usize,
workers_with_configs: HashMap<u64, ModelRuntimeConfig>, workers_with_configs: HashMap<u64, ModelRuntimeConfig>,
replica_sync: bool, replica_sync: bool,
router_id: u64, router_id: u64,
worker_type: &'static str, worker_type: &'static str,
) -> Result<Self> { ) -> Result<Arc<ActiveSequencesMulti>> {
assert!(block_size > 1, "block_size must be greater than 1");
let workers = Arc::new(DashMap::new());
let request_to_worker = Arc::new(DashMap::new());
let request_to_lora = Arc::new(DashMap::new());
for (worker_id, config) in workers_with_configs {
let dp_size = config.data_parallel_size;
for dp_rank in 0..dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
workers.insert(worker, ActiveSequences::new(block_size));
}
}
let event_publisher = let event_publisher =
EventPublisher::for_component(&component, ACTIVE_SEQUENCES_SUBJECT).await?; EventPublisher::for_component(&component, ACTIVE_SEQUENCES_SUBJECT).await?;
let metrics_publisher = Arc::new( let metrics_publisher =
EventPublisher::for_namespace(component.namespace(), KV_METRICS_SUBJECT).await?, Arc::new(EventPublisher::for_namespace(component.namespace(), KV_METRICS_SUBJECT).await?);
);
let multi_worker = Self { let publisher = RuntimeSequencePublisher {
workers: workers.clone(),
request_to_worker: request_to_worker.clone(),
request_to_lora: request_to_lora.clone(),
block_size,
event_publisher, event_publisher,
metrics_publisher, metrics_publisher,
router_id,
replica_sync,
worker_type,
};
if replica_sync {
let workers_clone = workers.clone();
let request_to_worker_clone = request_to_worker.clone();
let request_to_lora_clone = request_to_lora.clone();
let component_clone = component.clone();
let router_id_clone = router_id;
let cancel_token = component.drt().runtime().child_token();
tokio::spawn(async move {
if let Err(e) = Self::subscribe_to_events(
workers_clone,
request_to_worker_clone,
request_to_lora_clone,
component_clone,
router_id_clone,
cancel_token,
)
.await
{
tracing::error!("Error in active sequences events subscription: {}", e);
}
});
}
Ok(multi_worker)
}
/// Background task to subscribe to active sequence events and update all workers
async fn subscribe_to_events(
workers: Arc<DashMap<WorkerWithDpRank, ActiveSequences>>,
request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
request_to_lora: Arc<DashMap<RequestId, String>>,
component: Component,
router_id: u64,
cancel_token: CancellationToken,
) -> Result<()> {
let mut subscriber = EventSubscriber::for_component(&component, ACTIVE_SEQUENCES_SUBJECT)
.await?
.typed::<ActiveSequenceEvent>();
loop {
tokio::select! {
result = subscriber.next() => {
let Some(result) = result else {
break;
};
let Ok((_envelope, event)) = result else {
tracing::error!(
"Error receiving active sequence event: {}",
result.unwrap_err()
);
continue;
}; };
if event.router_id == router_id { let dp_sizes: HashMap<u64, u32> = workers_with_configs
continue; .into_iter()
} .map(|(id, config)| (id, config.data_parallel_size))
match &event.data {
ActiveSequenceEventData::AddRequest {
token_sequence,
isl,
overlap,
expected_output_tokens,
} => {
request_to_worker.insert(event.request_id.clone(), event.worker);
if let Some(ref lora_name) = event.lora_name {
request_to_lora.insert(event.request_id.clone(), lora_name.clone());
}
if let Some(mut entry) = workers.get_mut(&event.worker) {
entry.add_request(
event.request_id.clone(),
token_sequence.clone(),
*isl,
*overlap,
*expected_output_tokens,
);
} else {
tracing::warn!(
"Worker {:?} not found, cannot process AddRequest",
event.worker
);
}
}
ActiveSequenceEventData::Free => {
if let Some((_, worker)) = request_to_worker.remove(&event.request_id)
&& let Some(mut entry) = workers.get_mut(&worker)
{
entry.free(&event.request_id);
}
request_to_lora.remove(&event.request_id);
}
ActiveSequenceEventData::MarkPrefillCompleted => {
if let Some(worker) = request_to_worker.get(&event.request_id)
&& let Some(mut entry) = workers.get_mut(&*worker)
{
entry.mark_prefill_completed(&event.request_id);
}
}
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("Subscription task cancelled");
break;
}
}
}
Ok(())
}
/// Update the set of workers, adding and removing as needed
pub fn update_workers(&self, new_workers_with_configs: HashMap<u64, ModelRuntimeConfig>) {
let current_workers: HashSet<WorkerWithDpRank> =
self.workers.iter().map(|entry| *entry.key()).collect();
let mut new_workers: HashSet<WorkerWithDpRank> = HashSet::new();
for (worker_id, config) in &new_workers_with_configs {
let dp_size = config.data_parallel_size;
for dp_rank in 0..dp_size {
new_workers.insert(WorkerWithDpRank::new(*worker_id, dp_rank));
}
}
let workers_to_remove: Vec<WorkerWithDpRank> =
current_workers.difference(&new_workers).copied().collect();
let workers_to_add: Vec<WorkerWithDpRank> =
new_workers.difference(&current_workers).copied().collect();
for worker in &workers_to_remove {
tracing::warn!("Removing worker {:?}", worker);
self.workers.remove(worker);
let requests_to_remove: Vec<RequestId> = self
.request_to_worker
.iter()
.filter(|entry| entry.value() == worker)
.map(|entry| entry.key().clone())
.collect(); .collect();
self.request_to_worker let multi_worker = ActiveSequencesMultiWorker::new(
.retain(|_request_id, mapped_worker| mapped_worker != worker); publisher,
block_size,
for request_id in requests_to_remove { dp_sizes,
self.request_to_lora.remove(&request_id); replica_sync,
} router_id,
} worker_type,
for worker in &workers_to_add {
tracing::warn!("Adding worker {:?}", worker);
self.workers
.insert(*worker, ActiveSequences::new(self.block_size));
}
}
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
let SequenceRequest {
request_id,
token_sequence,
isl,
overlap,
expected_output_tokens,
worker,
lora_name,
} = req;
if !self.workers.contains_key(&worker) {
return Err(SequenceError::WorkerNotFound { worker });
}
if let Some(existing_worker) = self.request_to_worker.get(&request_id) {
return Err(SequenceError::DuplicateRequest {
request_id,
worker: *existing_worker,
});
}
if self.replica_sync {
let event = ActiveSequenceEvent {
request_id: request_id.clone(),
worker,
data: ActiveSequenceEventData::AddRequest {
token_sequence: token_sequence.clone(),
isl,
overlap,
expected_output_tokens,
},
router_id: self.router_id,
lora_name: lora_name.clone(),
};
self.event_publisher.publish(&event).await?;
}
self.request_to_worker.insert(request_id.clone(), worker);
if let Some(lora) = lora_name {
self.request_to_lora.insert(request_id.clone(), lora);
}
let removed_requests = {
let mut entry = self
.workers
.get_mut(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
entry.add_request(
request_id,
token_sequence,
isl,
overlap,
expected_output_tokens,
)
};
for expired_id in &removed_requests {
self.request_to_worker.remove(expired_id);
self.request_to_lora.remove(expired_id);
}
self.publish_active_load_for_worker(worker);
Ok(())
}
/// Send a mutation to the worker assigned to a request, optionally publishing
/// a replica-sync event and cleaning up request mappings afterward.
async fn mutate_request_worker(
&self,
request_id: &RequestId,
event_data: ActiveSequenceEventData,
mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId),
remove_mapping: bool,
) -> Result<(), SequenceError> {
let worker = self
.request_to_worker
.get(request_id)
.map(|entry| *entry)
.ok_or_else(|| SequenceError::RequestNotFound {
request_id: request_id.clone(),
})?;
if self.replica_sync {
let lora_name = self
.request_to_lora
.get(request_id)
.map(|entry| entry.value().clone());
let event = ActiveSequenceEvent {
request_id: request_id.clone(),
worker,
data: event_data,
router_id: self.router_id,
lora_name,
};
self.event_publisher.publish(&event).await?;
}
{
let mut entry = self
.workers
.get_mut(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
mutate_fn(&mut entry, request_id);
}
if remove_mapping {
self.request_to_worker.remove(request_id);
self.request_to_lora.remove(request_id);
}
self.publish_active_load_for_worker(worker);
Ok(())
}
/// Free all blocks associated with a request
///
/// Note: This operation is idempotent. Calling it multiple times for the same request
/// will log a warning but not return an error (double free is allowed).
pub async fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> {
if !self.request_to_worker.contains_key(request_id) {
tracing::debug!("Request {request_id} not found, already freed (idempotent)");
return Ok(());
}
self.mutate_request_worker(
request_id,
ActiveSequenceEventData::Free,
|seqs, rid| {
seqs.free(rid);
},
true,
)
.await
}
/// Mark prefill as completed for a request
///
/// Note: Calling this multiple times for the same request is allowed and will be a no-op
/// after the first call (idempotent).
pub async fn mark_prefill_completed(
&self,
request_id: &RequestId,
) -> Result<(), SequenceError> {
self.mutate_request_worker(
request_id,
ActiveSequenceEventData::MarkPrefillCompleted,
|seqs, rid| {
seqs.mark_prefill_completed(rid);
},
false,
)
.await
}
/// Add an output block with optional fractional decay weight
///
/// This is used during generation to track output blocks as they are created.
/// The decay_fraction represents how "temporary" the block is based on generation progress.
// TODO: output blocks are not replicated via replica_sync — add an
// ActiveSequenceEventData variant if cross-instance accuracy matters.
pub fn add_output_block(
&self,
request_id: &RequestId,
decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
let worker = self
.request_to_worker
.get(request_id)
.map(|entry| *entry)
.ok_or_else(|| SequenceError::RequestNotFound {
request_id: request_id.clone(),
})?;
let success = {
let mut entry = self
.workers
.get_mut(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
entry.add_output_block(request_id, decay_fraction)
};
if !success {
return Err(SequenceError::RequestNotFound {
request_id: request_id.clone(),
});
}
self.publish_active_load_for_worker(worker);
Ok(())
}
/// Read active blocks/tokens from a worker and publish ActiveLoad metrics.
/// The NATS publish is spawned as a background task to avoid blocking the caller.
fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) {
let (active_blocks, active_tokens) = {
let Some(entry) = self.workers.get(&worker) else {
tracing::warn!("Worker {worker:?} not found when publishing ActiveLoad");
return;
};
(entry.active_blocks(), entry.active_tokens())
};
WORKER_LOAD_METRICS.observe(
worker.worker_id,
worker.dp_rank,
self.worker_type,
active_blocks,
active_tokens,
);
let active_load = ActiveLoad {
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
active_decode_blocks: Some(active_blocks as u64),
active_prefill_tokens: Some(active_tokens as u64),
};
let publisher = self.metrics_publisher.clone();
tokio::spawn(async move {
if let Err(e) = publisher.publish(&active_load).await {
tracing::trace!(
"Failed to publish ActiveLoad to NATS for worker {worker:?}: {e:?}"
);
}
});
}
/// Get the number of workers
pub fn num_workers(&self) -> usize {
self.workers.len()
}
/// Get the worker type for this router ("prefill" or "decode").
/// Used for Prometheus metric labeling.
pub fn worker_type(&self) -> &'static str {
self.worker_type
}
/// Query all workers for the number of new blocks that would be added by a token sequence
pub fn new_blocks(
&self,
token_sequence: Vec<SequenceHash>,
) -> HashMap<WorkerWithDpRank, usize> {
let mut results = HashMap::with_capacity(self.workers.len());
for entry in self.workers.iter() {
results.insert(*entry.key(), entry.value().new_blocks(&token_sequence));
}
results
}
/// Query all workers for the total number of blocks (new + active) that would be used by a token sequence
pub fn potential_blocks(
&self,
token_sequence: Vec<SequenceHash>,
) -> HashMap<WorkerWithDpRank, usize> {
let mut results = HashMap::with_capacity(self.workers.len());
for entry in self.workers.iter() {
results.insert(
*entry.key(),
entry.value().potential_blocks(&token_sequence),
);
}
results
}
/// Query all workers for the potential blocks and tokens
pub fn potential_blocks_and_tokens(
&self,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlaps: OverlapScores,
) -> (
HashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>,
) {
#[cfg(feature = "bench")]
let start = Instant::now();
#[cfg(feature = "bench")]
let num_workers = self.workers.len();
let mut potential_blocks = HashMap::with_capacity(self.workers.len());
let mut potential_tokens = HashMap::with_capacity(self.workers.len());
for entry in self.workers.iter() {
let worker = *entry.key();
let overlap = *overlaps.scores.get(&worker).unwrap_or(&0);
let (blocks, tokens) =
entry
.value()
.potential_blocks_and_tokens(token_sequence.as_deref(), isl, overlap);
potential_blocks.insert(worker, blocks);
potential_tokens.insert(worker, tokens);
}
#[cfg(feature = "bench")]
{
let total_elapsed = start.elapsed();
tracing::info!(
num_workers,
total_us = total_elapsed.as_micros() as u64,
"potential_blocks_and_tokens completed"
); );
}
(potential_blocks, potential_tokens) let arc = Arc::new(multi_worker);
}
/// Query all workers for their current number of active blocks if replica_sync {
pub fn active_blocks(&self) -> HashMap<WorkerWithDpRank, usize> { let subscriber = EventSubscriber::for_component(&component, ACTIVE_SEQUENCES_SUBJECT)
let mut results = HashMap::with_capacity(self.workers.len()); .await?
for entry in self.workers.iter() { .typed::<ActiveSequenceEvent>();
results.insert(*entry.key(), entry.value().active_blocks()); let subscriber = RuntimeSequenceSubscriber { inner: subscriber };
} let cancel_token = component.drt().runtime().child_token();
results arc.start_replica_sync(subscriber, cancel_token);
}
/// Query all workers for their current number of active tokens
pub fn active_tokens(&self) -> HashMap<WorkerWithDpRank, usize> {
let mut results = HashMap::with_capacity(self.workers.len());
for entry in self.workers.iter() {
results.insert(*entry.key(), entry.value().active_tokens());
}
results
} }
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> { Ok(arc)
let mut counts: HashMap<String, usize> = HashMap::new();
for entry in self.request_to_lora.iter() {
let lora_name = entry.value().clone();
*counts.entry(lora_name).or_insert(0) += 1;
}
counts
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use dynamo_runtime::{DistributedRuntime, Runtime}; use dynamo_runtime::{DistributedRuntime, Runtime};
use std::sync::Arc;
#[test] #[test]
fn test_active_sequences_shared_blocks() { fn test_active_sequences_shared_blocks() {
...@@ -961,37 +169,26 @@ mod tests { ...@@ -961,37 +169,26 @@ mod tests {
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_multi_worker_cross_instance_sync() -> Result<()> { async fn test_multi_worker_cross_instance_sync() -> Result<()> {
// Initialize logging once
dynamo_runtime::logging::init(); dynamo_runtime::logging::init();
let block_size = 4; // arbitrary block size let block_size = 4;
// Create runtime and distributed runtime
let runtime = Runtime::from_current()?; let runtime = Runtime::from_current()?;
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
// Create namespace and shared component for both seq_managers
let namespace = distributed.namespace("test_cross_instance_sync")?; let namespace = distributed.namespace("test_cross_instance_sync")?;
let component = namespace.component("sequences")?; let component = namespace.component("sequences")?;
// Create multi-worker sequence managers with:
// - Worker 0 with dp_size=2 (dp_ranks 0 and 1)
// - Worker 1 with dp_size=1 (dp_rank 0)
// This gives us 3 effective workers total to test dp_rank effect
// Both seq_managers use the same component to ensure event synchronization works
let mut workers_with_configs = HashMap::new(); let mut workers_with_configs = HashMap::new();
// Create runtime config for worker 0 with dp_size=2
let mut config_worker_0 = crate::local_model::runtime_config::ModelRuntimeConfig::new(); let mut config_worker_0 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
config_worker_0.data_parallel_size = 2; config_worker_0.data_parallel_size = 2;
workers_with_configs.insert(0, config_worker_0); workers_with_configs.insert(0, config_worker_0);
// Create runtime config for worker 1 with dp_size=1 (default)
let config_worker_1 = crate::local_model::runtime_config::ModelRuntimeConfig::new(); let config_worker_1 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
workers_with_configs.insert(1, config_worker_1); workers_with_configs.insert(1, config_worker_1);
let seq_manager_1 = Arc::new( let seq_manager_1 = create_multi_worker_sequences(
ActiveSequencesMultiWorker::new(
component.clone(), component.clone(),
block_size, block_size,
workers_with_configs.clone(), workers_with_configs.clone(),
...@@ -999,10 +196,8 @@ mod tests { ...@@ -999,10 +196,8 @@ mod tests {
1, 1,
crate::discovery::WORKER_TYPE_DECODE, crate::discovery::WORKER_TYPE_DECODE,
) )
.await?, .await?;
); let seq_manager_2 = create_multi_worker_sequences(
let seq_manager_2 = Arc::new(
ActiveSequencesMultiWorker::new(
component, component,
block_size, block_size,
workers_with_configs, workers_with_configs,
...@@ -1010,15 +205,10 @@ mod tests { ...@@ -1010,15 +205,10 @@ mod tests {
2, 2,
crate::discovery::WORKER_TYPE_DECODE, crate::discovery::WORKER_TYPE_DECODE,
) )
.await?, .await?;
);
// Give some time for the subscription loops to start
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
// PHASE 1: Add requests using both seq_manager_1 and seq_manager_2
// Add request_0 to worker 0, dp_rank 0: sequence [0, 1, 2]
seq_manager_1 seq_manager_1
.add_request(SequenceRequest { .add_request(SequenceRequest {
request_id: "request_0".to_string(), request_id: "request_0".to_string(),
...@@ -1031,7 +221,6 @@ mod tests { ...@@ -1031,7 +221,6 @@ mod tests {
}) })
.await?; .await?;
// Add request_1 to worker 0, dp_rank 1: sequence [3, 4]
seq_manager_1 seq_manager_1
.add_request(SequenceRequest { .add_request(SequenceRequest {
request_id: "request_1".to_string(), request_id: "request_1".to_string(),
...@@ -1044,7 +233,6 @@ mod tests { ...@@ -1044,7 +233,6 @@ mod tests {
}) })
.await?; .await?;
// Add request_2 to worker 1, dp_rank 0: sequence [0, 1, 2, 3] using seq_manager_2
seq_manager_2 seq_manager_2
.add_request(SequenceRequest { .add_request(SequenceRequest {
request_id: "request_2".to_string(), request_id: "request_2".to_string(),
...@@ -1057,18 +245,11 @@ mod tests { ...@@ -1057,18 +245,11 @@ mod tests {
}) })
.await?; .await?;
// Give some time for synchronization
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
let blocks_phase1 = seq_manager_1.active_blocks(); let blocks_phase1 = seq_manager_1.active_blocks();
let tokens_phase1 = seq_manager_1.active_tokens(); let tokens_phase1 = seq_manager_1.active_tokens();
// Verify that seq_manager_1 sees all requests including request_2 from seq_manager_2
// We now have:
// - Worker 0, dp_rank 0: request_0
// - Worker 0, dp_rank 1: request_1
// - Worker 1, dp_rank 0: request_2
let worker_0_dp0 = WorkerWithDpRank::new(0, 0); let worker_0_dp0 = WorkerWithDpRank::new(0, 0);
let worker_0_dp1 = WorkerWithDpRank::new(0, 1); let worker_0_dp1 = WorkerWithDpRank::new(0, 1);
let worker_1_dp0 = WorkerWithDpRank::new(1, 0); let worker_1_dp0 = WorkerWithDpRank::new(1, 0);
...@@ -1098,23 +279,16 @@ mod tests { ...@@ -1098,23 +279,16 @@ mod tests {
"Worker 1 dp_rank 0 should have 16 active tokens (from request_2 added by seq_manager_2)" "Worker 1 dp_rank 0 should have 16 active tokens (from request_2 added by seq_manager_2)"
); );
// PHASE 2: Free requests using opposite sequence managers, verify on seq_manager_2
// Free request_2 (which was added by seq_manager_2) using seq_manager_1
seq_manager_1.free(&"request_2".to_string()).await?; seq_manager_1.free(&"request_2".to_string()).await?;
// Free request_0 and request_1 (which were added by seq_manager_1) using seq_manager_2
seq_manager_2.free(&"request_0".to_string()).await?; seq_manager_2.free(&"request_0".to_string()).await?;
seq_manager_2.free(&"request_1".to_string()).await?; seq_manager_2.free(&"request_1".to_string()).await?;
// Give some time for synchronization
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_2 to verify everything is empty
let blocks_phase2 = seq_manager_2.active_blocks(); let blocks_phase2 = seq_manager_2.active_blocks();
let tokens_phase2 = seq_manager_2.active_tokens(); let tokens_phase2 = seq_manager_2.active_tokens();
// Verify phase 2 results - everything should be empty for all 3 workers
let all_workers = vec![ let all_workers = vec![
WorkerWithDpRank::new(0, 0), WorkerWithDpRank::new(0, 0),
WorkerWithDpRank::new(0, 1), WorkerWithDpRank::new(0, 1),
...@@ -1140,21 +314,16 @@ mod tests { ...@@ -1140,21 +314,16 @@ mod tests {
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_multi_worker_no_token_sequence_sync() -> Result<()> { async fn test_multi_worker_no_token_sequence_sync() -> Result<()> {
// Initialize logging once
dynamo_runtime::logging::init(); dynamo_runtime::logging::init();
let block_size = 4; // arbitrary block size let block_size = 4;
// Create runtime and distributed runtime
let runtime = Runtime::from_current()?; let runtime = Runtime::from_current()?;
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
// Create namespace and shared component for both seq_managers
let namespace = distributed.namespace("test_no_token_seq_sync")?; let namespace = distributed.namespace("test_no_token_seq_sync")?;
let component = namespace.component("sequences")?; let component = namespace.component("sequences")?;
// Create multi-worker sequence managers with ALL workers [0, 1, 2]
// Both use the same component to ensure event synchronization works
let mut workers_with_configs = HashMap::new(); let mut workers_with_configs = HashMap::new();
workers_with_configs.insert( workers_with_configs.insert(
0, 0,
...@@ -1169,8 +338,7 @@ mod tests { ...@@ -1169,8 +338,7 @@ mod tests {
crate::local_model::runtime_config::ModelRuntimeConfig::new(), crate::local_model::runtime_config::ModelRuntimeConfig::new(),
); );
let seq_manager_1 = Arc::new( let seq_manager_1 = create_multi_worker_sequences(
ActiveSequencesMultiWorker::new(
component.clone(), component.clone(),
block_size, block_size,
workers_with_configs.clone(), workers_with_configs.clone(),
...@@ -1178,10 +346,8 @@ mod tests { ...@@ -1178,10 +346,8 @@ mod tests {
1, 1,
crate::discovery::WORKER_TYPE_DECODE, crate::discovery::WORKER_TYPE_DECODE,
) )
.await?, .await?;
); let seq_manager_2 = create_multi_worker_sequences(
let seq_manager_2 = Arc::new(
ActiveSequencesMultiWorker::new(
component, component,
block_size, block_size,
workers_with_configs, workers_with_configs,
...@@ -1189,15 +355,10 @@ mod tests { ...@@ -1189,15 +355,10 @@ mod tests {
2, 2,
crate::discovery::WORKER_TYPE_DECODE, crate::discovery::WORKER_TYPE_DECODE,
) )
.await?, .await?;
);
// Give some time for the subscription loops to start
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
// PHASE 1: Add requests (without token sequences) using both seq_managers
// Add request_0 to worker 0 with no token sequence
seq_manager_1 seq_manager_1
.add_request(SequenceRequest { .add_request(SequenceRequest {
request_id: "request_0".to_string(), request_id: "request_0".to_string(),
...@@ -1210,7 +371,6 @@ mod tests { ...@@ -1210,7 +371,6 @@ mod tests {
}) })
.await?; .await?;
// Add request_1 to worker 1 with no token sequence
seq_manager_1 seq_manager_1
.add_request(SequenceRequest { .add_request(SequenceRequest {
request_id: "request_1".to_string(), request_id: "request_1".to_string(),
...@@ -1223,7 +383,6 @@ mod tests { ...@@ -1223,7 +383,6 @@ mod tests {
}) })
.await?; .await?;
// Add request_2 to worker 2 with no token sequence using seq_manager_2
seq_manager_2 seq_manager_2
.add_request(SequenceRequest { .add_request(SequenceRequest {
request_id: "request_2".to_string(), request_id: "request_2".to_string(),
...@@ -1236,13 +395,10 @@ mod tests { ...@@ -1236,13 +395,10 @@ mod tests {
}) })
.await?; .await?;
// Give some time for synchronization
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
let tokens_phase1 = seq_manager_1.active_tokens(); let tokens_phase1 = seq_manager_1.active_tokens();
// Verify that seq_manager_1 sees all requests including request_2 from thread 2
let worker_0 = WorkerWithDpRank::from_worker_id(0); let worker_0 = WorkerWithDpRank::from_worker_id(0);
let worker_1 = WorkerWithDpRank::from_worker_id(1); let worker_1 = WorkerWithDpRank::from_worker_id(1);
let worker_2 = WorkerWithDpRank::from_worker_id(2); let worker_2 = WorkerWithDpRank::from_worker_id(2);
...@@ -1260,15 +416,11 @@ mod tests { ...@@ -1260,15 +416,11 @@ mod tests {
"Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)" "Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
); );
// PHASE 2: Free requests using opposite sequence managers, verify on seq_manager_2
// Mark prefill completed and free request_2 (which was added by seq_manager_2) using seq_manager_1
seq_manager_1 seq_manager_1
.mark_prefill_completed(&"request_2".to_string()) .mark_prefill_completed(&"request_2".to_string())
.await?; .await?;
seq_manager_1.free(&"request_2".to_string()).await?; seq_manager_1.free(&"request_2".to_string()).await?;
// Mark prefill completed and free requests 0 and 1 (which were added by seq_manager_1) using seq_manager_2
seq_manager_2 seq_manager_2
.mark_prefill_completed(&"request_0".to_string()) .mark_prefill_completed(&"request_0".to_string())
.await?; .await?;
...@@ -1278,13 +430,10 @@ mod tests { ...@@ -1278,13 +430,10 @@ mod tests {
seq_manager_2.free(&"request_0".to_string()).await?; seq_manager_2.free(&"request_0".to_string()).await?;
seq_manager_2.free(&"request_1".to_string()).await?; seq_manager_2.free(&"request_1".to_string()).await?;
// Give some time for synchronization
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
// Query seq_manager_2 to verify everything is empty
let tokens_phase2 = seq_manager_2.active_tokens(); let tokens_phase2 = seq_manager_2.active_tokens();
// Verify phase 2 results - everything should be empty
for worker_id in 0..=2 { for worker_id in 0..=2 {
let worker = WorkerWithDpRank::from_worker_id(worker_id); let worker = WorkerWithDpRank::from_worker_id(worker_id);
assert_eq!( assert_eq!(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment