Unverified commit b7fe46b1, authored by Yan Ru Pei, committed by GitHub
Browse files

feat(mocker): add multi-worker replay and router startup fixes (#7553)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 82794761
......@@ -149,8 +149,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(fetch_model, m)?)?;
m.add_function(wrap_pyfunction!(run_kv_indexer, m)?)?;
m.add_function(wrap_pyfunction!(llm::entrypoint::make_engine, m)?)?;
m.add_function(wrap_pyfunction!(llm::replay::run_mocker_trace_replay, m)?)?;
m.add_function(wrap_pyfunction!(
llm::entrypoint::run_mocker_trace_replay,
llm::replay::run_mocker_synthetic_trace_replay,
m
)?)?;
m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?;
......@@ -165,6 +166,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::entrypoint::EngineType>()?;
m.add_class::<llm::entrypoint::RouterConfig>()?;
m.add_class::<llm::entrypoint::KvRouterConfig>()?;
m.add_class::<llm::replay::ReasoningConfig>()?;
m.add_class::<llm::replay::SglangArgs>()?;
m.add_class::<llm::replay::MockEngineArgs>()?;
m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?; // Internal: only in _internal, not public API
m.add_class::<llm::local_model::ModelRuntimeConfig>()?;
......
......@@ -31,3 +31,4 @@ pub mod local_model;
pub mod lora;
pub mod model_card;
pub mod preprocessor;
pub mod replay;
......@@ -9,7 +9,6 @@ use std::sync::Arc;
use pyo3::{exceptions::PyException, prelude::*};
use pyo3_async_runtimes::TaskLocals;
use pythonize::pythonize;
use dynamo_kv_router::config::KvRouterConfig as RsKvRouterConfig;
use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig;
......@@ -25,7 +24,8 @@ use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingE
use dynamo_mocker::common::perf_model::PerfModel;
use super::aic_callback::create_aic_callback;
use dynamo_mocker::common::protocols::MockEngineArgs;
use super::replay::MockEngineArgs as PyMockEngineArgs;
use dynamo_mocker::common::protocols::MockEngineArgs as RsMockEngineArgs;
use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId;
use dynamo_runtime::protocols::EndpointId;
......@@ -58,7 +58,7 @@ impl KvRouterConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(2.0), router_event_threads=4, router_enable_cache_control=false, router_queue_policy="fcfs", remote_indexer_component=None))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_enable_cache_control=false, min_initial_workers=1, router_queue_policy="fcfs", remote_indexer_component=None))]
#[allow(clippy::too_many_arguments)]
fn new(
overlap_score_weight: f64,
......@@ -77,6 +77,7 @@ impl KvRouterConfig {
router_queue_threshold: Option<f64>,
router_event_threads: u32,
router_enable_cache_control: bool,
min_initial_workers: usize,
router_queue_policy: &str,
remote_indexer_component: Option<String>,
) -> Self {
......@@ -99,6 +100,7 @@ impl KvRouterConfig {
router_event_threads,
router_enable_cache_control,
skip_initial_worker_wait: false,
min_initial_workers,
router_queue_policy: router_queue_policy.parse().unwrap_or_else(|_| {
panic!("invalid router_queue_policy: {router_queue_policy:?}")
}),
......@@ -106,6 +108,13 @@ impl KvRouterConfig {
},
}
}
#[staticmethod]
fn from_json(config_json: &str) -> PyResult<Self> {
serde_json::from_str::<RsKvRouterConfig>(config_json)
.map(|inner| KvRouterConfig { inner })
.map_err(|e| PyException::new_err(format!("Failed to parse KvRouterConfig JSON: {e}")))
}
}
#[pyclass]
......@@ -196,6 +205,7 @@ pub(crate) struct EntrypointArgs {
tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>,
mocker_engine_args: Option<PyMockEngineArgs>,
runtime_config: Option<ModelRuntimeConfig>,
namespace: Option<String>,
namespace_prefix: Option<String>,
......@@ -208,7 +218,7 @@ pub(crate) struct EntrypointArgs {
impl EntrypointArgs {
#[allow(clippy::too_many_arguments)]
#[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
pub fn new(
py: Python<'_>,
engine_type: EngineType,
......@@ -225,6 +235,7 @@ impl EntrypointArgs {
tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>,
mocker_engine_args: Option<PyMockEngineArgs>,
runtime_config: Option<ModelRuntimeConfig>,
namespace: Option<String>,
namespace_prefix: Option<String>,
......@@ -272,6 +283,7 @@ impl EntrypointArgs {
tls_cert_path,
tls_key_path,
extra_engine_args,
mocker_engine_args,
runtime_config,
namespace,
namespace_prefix,
......@@ -419,8 +431,10 @@ async fn select_engine(
}
}
EngineType::Mocker => {
let mut mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
let mut mocker_args = if let Some(mocker_engine_args) = args.mocker_engine_args {
mocker_engine_args.inner()
} else if let Some(extra_args_path) = args.extra_engine_args {
RsMockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
anyhow::anyhow!(
"Failed to load mocker args from {:?}: {}",
extra_args_path,
......@@ -431,7 +445,7 @@ async fn select_engine(
tracing::warn!(
"No extra_engine_args specified for mocker engine. Using default mocker args."
);
MockEngineArgs::default()
RsMockEngineArgs::default()
};
// If aic_backend is set, create Python AIC callback and override perf_model
......@@ -503,84 +517,6 @@ pub fn run_input<'p>(
})
}
/// Replays a mocker trace file and returns the simulation report as a Python object.
///
/// Engine arguments are loaded from the file at `extra_engine_args` when it is given;
/// otherwise `MockEngineArgs::default()` is used. When `replay_concurrency` is set the
/// concurrency-driven simulation runs, otherwise the trace-driven simulation runs.
/// Any failure is surfaced as a Python exception.
#[pyfunction]
#[pyo3(signature = (trace_file, extra_engine_args=None, num_workers=1, replay_concurrency=None))]
pub fn run_mocker_trace_replay(
py: Python<'_>,
trace_file: PathBuf,
extra_engine_args: Option<PathBuf>,
num_workers: usize,
replay_concurrency: Option<isize>,
) -> PyResult<PyObject> {
// Load args before allow_threads so we can use the GIL for AIC callback creation.
let mut args = if let Some(ref extra_args_path) = extra_engine_args {
MockEngineArgs::from_json_file(extra_args_path).map_err(|e| {
PyException::new_err(format!(
"Failed to load mocker args from {:?}: {}",
extra_args_path, e
))
})?
} else {
MockEngineArgs::default()
};
// Create AIC callback if requested (requires GIL, must be done before allow_threads).
if let Some(ref backend_name) = args.aic_backend.clone() {
let backend = backend_name.clone();
// Falls back to "h200_sxm" when no AIC system is configured.
let system = args.aic_system.as_deref().unwrap_or("h200_sxm").to_string();
// The AIC callback cannot be built without a model path.
let model_name = args
.aic_model_path
.clone()
.ok_or_else(|| PyException::new_err("--aic-perf-model requires --model-path"))?;
let backend_version = args.aic_backend_version.clone();
// Tensor-parallel size defaults to 1 when unset.
let tp_size = args.aic_tp_size.unwrap_or(1);
let callback = create_aic_callback(
py,
&backend,
&system,
&model_name,
tp_size,
backend_version.as_deref(),
)
.map_err(|e| {
PyException::new_err(format!(
"Failed to create AIC callback (--aic-perf-model was requested): {}",
e
))
})?;
tracing::info!(
"AIC perf model: backend={}, gpu={}, model={}, version={:?}",
backend,
system,
model_name,
backend_version
);
// Replace the configured perf model with one backed by the Python callback.
args.perf_model = Arc::new(PerfModel::from_aic_callback(callback));
}
// The simulation itself runs with the GIL released.
let report = py.allow_threads(move || {
// NOTE(review): `usize::try_from` only rejects negative values, so a
// `replay_concurrency` of 0 is forwarded to the simulator even though the
// error message says "at least 1" — confirm the simulator handles 0.
let replay_concurrency = replay_concurrency
.map(usize::try_from)
.transpose()
.map_err(|_| anyhow::anyhow!("replay_concurrency must be at least 1"))?;
if let Some(max_in_flight) = replay_concurrency {
dynamo_mocker::simulation::simulate_concurrency_file(
args,
&trace_file,
max_in_flight,
num_workers,
)
} else {
dynamo_mocker::simulation::simulate_trace_file(args, &trace_file, num_workers)
}
});
let report = report.map_err(to_pyerr)?;
// Convert the serializable report into a Python object owned independently of `py`.
pythonize(py, &report)
.map_err(to_pyerr)
.map(|obj| obj.unbind())
}
pub fn to_pyerr<E>(err: E) -> PyErr
where
E: Display,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::path::PathBuf;
use std::sync::Arc;
use dynamo_mocker::common::perf_model::PerfModel;
use dynamo_mocker::common::protocols::{
DirectRequest, EngineType as RsMockerEngineType, MockEngineArgs as RsMockEngineArgs,
PreemptionMode as RsPreemptionMode, ReasoningConfig as RsReasoningConfig,
SglangArgs as RsSglangArgs, WorkerType as RsWorkerType,
};
use pyo3::{exceptions::PyException, prelude::*};
use pythonize::pythonize;
use uuid::Uuid;
use super::aic_callback::create_aic_callback;
use super::entrypoint::{KvRouterConfig, to_pyerr};
/// Maps the Python-facing engine name onto the mocker engine enum.
///
/// Only the exact strings `"vllm"` and `"sglang"` are accepted; anything else
/// produces a Python exception that names the rejected value.
fn parse_mocker_engine_type(engine_type: &str) -> PyResult<RsMockerEngineType> {
    let parsed = match engine_type {
        "vllm" => RsMockerEngineType::Vllm,
        "sglang" => RsMockerEngineType::Sglang,
        other => {
            return Err(PyException::new_err(format!(
                "engine_type must be either 'vllm' or 'sglang', got '{other}'"
            )));
        }
    };
    Ok(parsed)
}
/// Maps the Python-facing worker role name onto the mocker worker enum.
///
/// Accepts exactly `"aggregated"`, `"prefill"` or `"decode"`; any other string
/// produces a Python exception that names the rejected value.
fn parse_worker_type(worker_type: &str) -> PyResult<RsWorkerType> {
    let parsed = match worker_type {
        "aggregated" => RsWorkerType::Aggregated,
        "prefill" => RsWorkerType::Prefill,
        "decode" => RsWorkerType::Decode,
        other => {
            return Err(PyException::new_err(format!(
                "worker_type must be one of 'aggregated', 'prefill', or 'decode', got '{other}'"
            )));
        }
    };
    Ok(parsed)
}
/// Maps the Python-facing preemption mode name onto the mocker enum.
///
/// Accepts exactly `"lifo"` or `"fifo"`; any other string produces a Python
/// exception that names the rejected value.
fn parse_preemption_mode(preemption_mode: &str) -> PyResult<RsPreemptionMode> {
    let parsed = match preemption_mode {
        "lifo" => RsPreemptionMode::Lifo,
        "fifo" => RsPreemptionMode::Fifo,
        other => {
            return Err(PyException::new_err(format!(
                "preemption_mode must be either 'lifo' or 'fifo', got '{other}'"
            )));
        }
    };
    Ok(parsed)
}
/// Python-visible wrapper around the mocker's reasoning configuration.
#[pyclass]
#[derive(Clone, Debug)]
pub struct ReasoningConfig {
    inner: RsReasoningConfig,
}

impl ReasoningConfig {
    /// Returns an owned copy of the wrapped Rust configuration.
    pub fn inner(&self) -> RsReasoningConfig {
        self.inner.clone()
    }
}

#[pymethods]
impl ReasoningConfig {
    /// Stores the three values verbatim in the Rust configuration; this
    /// constructor performs no validation of its own.
    #[new]
    fn new(
        start_thinking_token_id: u32,
        end_thinking_token_id: u32,
        thinking_ratio: f64,
    ) -> PyResult<Self> {
        Ok(Self {
            inner: RsReasoningConfig {
                start_thinking_token_id,
                end_thinking_token_id,
                thinking_ratio,
            },
        })
    }
}
/// Python-visible wrapper around the SGLang-specific mocker arguments.
#[pyclass]
#[derive(Clone, Debug, Default)]
pub struct SglangArgs {
    inner: RsSglangArgs,
}

impl SglangArgs {
    /// Returns an owned copy of the wrapped Rust arguments.
    pub fn inner(&self) -> RsSglangArgs {
        self.inner.clone()
    }
}

#[pymethods]
impl SglangArgs {
    /// Every field is optional and defaults to `None`; the values are stored
    /// verbatim without validation in this constructor.
    #[new]
    #[pyo3(signature = (schedule_policy=None, page_size=None, max_prefill_tokens=None, chunked_prefill_size=None, clip_max_new_tokens=None, schedule_conservativeness=None))]
    fn new(
        schedule_policy: Option<String>,
        page_size: Option<usize>,
        max_prefill_tokens: Option<usize>,
        chunked_prefill_size: Option<usize>,
        clip_max_new_tokens: Option<usize>,
        schedule_conservativeness: Option<f64>,
    ) -> PyResult<Self> {
        Ok(Self {
            inner: RsSglangArgs {
                schedule_policy,
                page_size,
                max_prefill_tokens,
                chunked_prefill_size,
                clip_max_new_tokens,
                schedule_conservativeness,
            },
        })
    }
}
/// Python-visible, typed wrapper around the mocker engine arguments.
#[pyclass]
#[derive(Clone, Debug, Default)]
pub struct MockEngineArgs {
    inner: RsMockEngineArgs,
}

impl MockEngineArgs {
    /// Returns an owned copy of the wrapped Rust arguments.
    pub fn inner(&self) -> RsMockEngineArgs {
        self.inner.clone()
    }
}

#[pymethods]
impl MockEngineArgs {
    /// Builds the arguments from keyword values.
    ///
    /// The string-typed options (`engine_type`, `worker_type`,
    /// `preemption_mode`, `router_queue_policy`) are parsed first, in that
    /// order, so the first invalid one determines the error. The result of the
    /// builder is then passed through `normalized()`.
    #[new]
    #[pyo3(signature = (engine_type="vllm", num_gpu_blocks=16384, block_size=0, max_num_seqs=Some(256), max_num_batched_tokens=Some(8192), enable_prefix_caching=true, enable_chunked_prefill=true, speedup_ratio=1.0, decode_speedup_ratio=1.0, dp_size=1, startup_time=None, worker_type="aggregated", aic_backend=None, aic_system=None, aic_backend_version=None, aic_tp_size=None, aic_model_path=None, enable_local_indexer=false, bootstrap_port=None, kv_bytes_per_token=None, kv_transfer_bandwidth=None, reasoning=None, zmq_kv_events_port=None, zmq_replay_port=None, preemption_mode="lifo", router_queue_policy=None, sglang=None))]
    #[allow(clippy::too_many_arguments)]
    fn new(
        engine_type: &str,
        num_gpu_blocks: usize,
        block_size: usize,
        max_num_seqs: Option<usize>,
        max_num_batched_tokens: Option<usize>,
        enable_prefix_caching: bool,
        enable_chunked_prefill: bool,
        speedup_ratio: f64,
        decode_speedup_ratio: f64,
        dp_size: u32,
        startup_time: Option<f64>,
        worker_type: &str,
        aic_backend: Option<String>,
        aic_system: Option<String>,
        aic_backend_version: Option<String>,
        aic_tp_size: Option<usize>,
        aic_model_path: Option<String>,
        enable_local_indexer: bool,
        bootstrap_port: Option<u16>,
        kv_bytes_per_token: Option<usize>,
        kv_transfer_bandwidth: Option<f64>,
        reasoning: Option<ReasoningConfig>,
        zmq_kv_events_port: Option<u16>,
        zmq_replay_port: Option<u16>,
        preemption_mode: &str,
        router_queue_policy: Option<&str>,
        sglang: Option<SglangArgs>,
    ) -> PyResult<Self> {
        let engine_type = parse_mocker_engine_type(engine_type)?;
        let worker_type = parse_worker_type(worker_type)?;
        let preemption_mode = parse_preemption_mode(preemption_mode)?;
        // The parse target type is inferred from the builder setter below.
        let router_queue_policy = match router_queue_policy {
            Some(value) => Some(value.parse().map_err(|e: String| {
                PyException::new_err(format!("invalid router_queue_policy {value:?}: {e}"))
            })?),
            None => None,
        };
        // Unwrap the Python wrappers into their Rust counterparts.
        let reasoning = reasoning.map(|config| config.inner());
        let sglang = sglang.map(|config| config.inner());

        let built = RsMockEngineArgs::builder()
            .engine_type(engine_type)
            .num_gpu_blocks(num_gpu_blocks)
            .block_size(block_size)
            .max_num_seqs(max_num_seqs)
            .max_num_batched_tokens(max_num_batched_tokens)
            .enable_prefix_caching(enable_prefix_caching)
            .enable_chunked_prefill(enable_chunked_prefill)
            .speedup_ratio(speedup_ratio)
            .decode_speedup_ratio(decode_speedup_ratio)
            .dp_size(dp_size)
            .startup_time(startup_time)
            .worker_type(worker_type)
            .aic_backend(aic_backend)
            .aic_system(aic_system)
            .aic_backend_version(aic_backend_version)
            .aic_tp_size(aic_tp_size)
            .aic_model_path(aic_model_path)
            .enable_local_indexer(enable_local_indexer)
            .bootstrap_port(bootstrap_port)
            .kv_bytes_per_token(kv_bytes_per_token)
            .kv_transfer_bandwidth(kv_transfer_bandwidth)
            .reasoning(reasoning)
            .zmq_kv_events_port(zmq_kv_events_port)
            .zmq_replay_port(zmq_replay_port)
            .preemption_mode(preemption_mode)
            .router_queue_policy(router_queue_policy)
            .sglang(sglang)
            .build()
            .map_err(|e| PyException::new_err(format!("Failed to build MockEngineArgs: {e}")))?;

        match built.normalized() {
            Ok(inner) => Ok(Self { inner }),
            Err(e) => Err(PyException::new_err(format!(
                "Failed to normalize MockEngineArgs: {e}"
            ))),
        }
    }

    /// Builds the arguments from a JSON string via `from_json_str`.
    // NOTE(review): unlike `new` and `with_overrides`, no explicit
    // `normalized()` call happens here — presumably `from_json_str` handles
    // it; confirm.
    #[staticmethod]
    fn from_json(config_json: &str) -> PyResult<Self> {
        match RsMockEngineArgs::from_json_str(config_json) {
            Ok(inner) => Ok(Self { inner }),
            Err(e) => Err(PyException::new_err(format!(
                "Failed to parse MockEngineArgs JSON: {e}"
            ))),
        }
    }

    /// The stored `block_size` value.
    #[getter]
    fn block_size(&self) -> usize {
        self.inner.block_size
    }

    /// The stored `num_gpu_blocks` value.
    #[getter]
    fn num_gpu_blocks(&self) -> usize {
        self.inner.num_gpu_blocks
    }

    /// The stored `max_num_seqs` value, if any.
    #[getter]
    fn max_num_seqs(&self) -> Option<usize> {
        self.inner.max_num_seqs
    }

    /// The stored `max_num_batched_tokens` value, if any.
    #[getter]
    fn max_num_batched_tokens(&self) -> Option<usize> {
        self.inner.max_num_batched_tokens
    }

    /// The stored `enable_local_indexer` flag.
    #[getter]
    fn enable_local_indexer(&self) -> bool {
        self.inner.enable_local_indexer
    }

    /// The stored `dp_size` value.
    #[getter]
    fn dp_size(&self) -> u32 {
        self.inner.dp_size
    }

    /// The stored `bootstrap_port` value, if any.
    #[getter]
    fn bootstrap_port(&self) -> Option<u16> {
        self.inner.bootstrap_port
    }

    /// Delegates to the Rust arguments' `is_prefill()`.
    fn is_prefill(&self) -> bool {
        self.inner.is_prefill()
    }

    /// Delegates to the Rust arguments' `is_decode()`.
    fn is_decode(&self) -> bool {
        self.inner.is_decode()
    }

    /// Returns a copy with the given fields replaced and re-normalized.
    ///
    /// An argument left as `None` keeps the current value of that field; it
    /// never clears a value that is already set.
    #[pyo3(signature = (bootstrap_port=None, zmq_kv_events_port=None, zmq_replay_port=None, kv_bytes_per_token=None))]
    fn with_overrides(
        &self,
        bootstrap_port: Option<u16>,
        zmq_kv_events_port: Option<u16>,
        zmq_replay_port: Option<u16>,
        kv_bytes_per_token: Option<usize>,
    ) -> PyResult<Self> {
        let mut inner = self.inner.clone();
        // `Option::or` falls back to the existing value when no override is given.
        inner.bootstrap_port = bootstrap_port.or(inner.bootstrap_port);
        inner.zmq_kv_events_port = zmq_kv_events_port.or(inner.zmq_kv_events_port);
        inner.zmq_replay_port = zmq_replay_port.or(inner.zmq_replay_port);
        inner.kv_bytes_per_token = kv_bytes_per_token.or(inner.kv_bytes_per_token);

        match inner.normalized() {
            Ok(inner) => Ok(Self { inner }),
            Err(e) => Err(PyException::new_err(format!(
                "Failed to normalize MockEngineArgs overrides: {e}"
            ))),
        }
    }
}
/// Replays a trace file through the mocker simulation and returns the report
/// as a Python object.
///
/// * `trace_file` – path of the trace to replay.
/// * `extra_engine_args` – typed engine arguments; defaults are used when `None`.
/// * `router_config` – optional KV router configuration forwarded to the simulation.
/// * `num_workers` – number of workers forwarded to the simulation.
/// * `replay_concurrency` – when set (must be at least 1) the concurrency-driven
///   variant runs with this value as `max_in_flight`; otherwise the
///   trace-driven variant runs.
/// * `replay_mode` – `"offline"` or `"online"`.
/// * `router_mode` – `"round_robin"` or `"kv_router"`.
/// * `arrival_speedup_ratio` – forwarded only to the trace-driven variants.
///
/// Invalid `router_mode` is rejected before the GIL is released; invalid
/// `replay_concurrency` or `replay_mode`, and any simulation failure, are
/// reported afterwards. All errors surface as Python exceptions.
#[pyfunction]
#[pyo3(signature = (trace_file, extra_engine_args=None, router_config=None, num_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0))]
#[allow(clippy::too_many_arguments)]
pub fn run_mocker_trace_replay(
    py: Python<'_>,
    trace_file: PathBuf,
    extra_engine_args: Option<MockEngineArgs>,
    router_config: Option<KvRouterConfig>,
    num_workers: usize,
    replay_concurrency: Option<isize>,
    replay_mode: &str,
    router_mode: &str,
    arrival_speedup_ratio: f64,
) -> PyResult<PyObject> {
    // Everything that needs the GIL (AIC callback creation) happens up front.
    let args = load_replay_mocker_args(py, extra_engine_args)?;
    let router_config = load_replay_router_config(router_config);
    let replay_mode = replay_mode.to_owned();
    let router_mode = parse_replay_router_mode(router_mode)?;

    // The closure owns `args` and `router_config`; exactly one match arm runs,
    // so both are moved into the selected simulation without cloning.
    let report = py.allow_threads(move || {
        let replay_concurrency = parse_replay_concurrency(replay_concurrency)?;
        match (replay_mode.as_str(), replay_concurrency) {
            ("offline", Some(max_in_flight)) => {
                dynamo_mocker::replay::simulate_concurrency_file_with_router_mode(
                    args,
                    router_config,
                    &trace_file,
                    max_in_flight,
                    num_workers,
                    router_mode,
                )
            }
            ("offline", None) => dynamo_mocker::replay::simulate_trace_file_with_router_mode(
                args,
                router_config,
                &trace_file,
                num_workers,
                arrival_speedup_ratio,
                router_mode,
            ),
            ("online", Some(max_in_flight)) => {
                dynamo_mocker::replay::simulate_concurrency_live_file_with_router_mode(
                    args,
                    router_config,
                    &trace_file,
                    max_in_flight,
                    num_workers,
                    router_mode,
                )
            }
            ("online", None) => dynamo_mocker::replay::simulate_trace_live_file_with_router_mode(
                args,
                router_config,
                &trace_file,
                num_workers,
                arrival_speedup_ratio,
                router_mode,
            ),
            (other, _) => anyhow::bail!(
                "replay_mode must be either 'offline' or 'online', got '{}'",
                other
            ),
        }
    });

    let report = report.map_err(to_pyerr)?;
    pythonize(py, &report)
        .map_err(to_pyerr)
        .map(|obj| obj.unbind())
}
/// Replays a synthetic workload through the mocker simulation and returns the
/// report as a Python object; no trace file is required.
///
/// * `input_tokens` / `output_tokens` / `request_count` – shape of the generated
///   workload; each must be at least 1.
/// * `extra_engine_args` – typed engine arguments; defaults are used when `None`.
/// * `router_config` – optional KV router configuration forwarded to the simulation.
/// * `num_workers` – number of workers forwarded to the simulation.
/// * `replay_concurrency` – when set (must be at least 1) the concurrency-driven
///   variant runs with this value as `max_in_flight` and the generated requests
///   carry no arrival timestamps; otherwise the trace-driven variant runs with
///   timestamps spaced `arrival_interval_ms` apart.
/// * `replay_mode` – `"offline"` or `"online"`.
/// * `router_mode` – `"round_robin"` or `"kv_router"`.
/// * `arrival_speedup_ratio` – forwarded only to the trace-driven variants.
/// * `arrival_interval_ms` – must be finite and non-negative.
///
/// Invalid `router_mode` is rejected before the GIL is released; all other
/// validation and simulation errors are reported afterwards. All errors
/// surface as Python exceptions.
#[pyfunction]
#[pyo3(signature = (input_tokens, output_tokens, request_count, extra_engine_args=None, router_config=None, num_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0, arrival_interval_ms=1.0))]
#[allow(clippy::too_many_arguments)]
pub fn run_mocker_synthetic_trace_replay(
    py: Python<'_>,
    input_tokens: usize,
    output_tokens: usize,
    request_count: usize,
    extra_engine_args: Option<MockEngineArgs>,
    router_config: Option<KvRouterConfig>,
    num_workers: usize,
    replay_concurrency: Option<isize>,
    replay_mode: &str,
    router_mode: &str,
    arrival_speedup_ratio: f64,
    arrival_interval_ms: f64,
) -> PyResult<PyObject> {
    // Everything that needs the GIL (AIC callback creation) happens up front.
    let args = load_replay_mocker_args(py, extra_engine_args)?;
    let router_config = load_replay_router_config(router_config);
    let replay_mode = replay_mode.to_owned();
    let router_mode = parse_replay_router_mode(router_mode)?;

    // The closure owns `args`, `router_config` and `requests`; exactly one
    // match arm runs, so each is moved into the selected simulation without
    // cloning.
    let report = py.allow_threads(move || {
        let replay_concurrency = parse_replay_concurrency(replay_concurrency)?;
        // Arrival timestamps are only generated for the trace-driven variants.
        let requests = build_synthetic_requests(
            input_tokens,
            output_tokens,
            request_count,
            arrival_interval_ms,
            replay_concurrency.is_none(),
        )?;
        match (replay_mode.as_str(), replay_concurrency) {
            ("offline", Some(max_in_flight)) => {
                dynamo_mocker::replay::simulate_concurrency_requests_with_router_mode(
                    args,
                    router_config,
                    requests,
                    max_in_flight,
                    num_workers,
                    router_mode,
                )
            }
            ("offline", None) => dynamo_mocker::replay::simulate_trace_requests_with_router_mode(
                args,
                router_config,
                requests,
                num_workers,
                arrival_speedup_ratio,
                router_mode,
            ),
            ("online", Some(max_in_flight)) => {
                dynamo_mocker::replay::simulate_concurrency_live_requests_with_router_mode(
                    args,
                    router_config,
                    requests,
                    max_in_flight,
                    num_workers,
                    router_mode,
                )
            }
            ("online", None) => {
                dynamo_mocker::replay::simulate_trace_live_requests_with_router_mode(
                    args,
                    router_config,
                    requests,
                    num_workers,
                    arrival_speedup_ratio,
                    router_mode,
                )
            }
            (other, _) => anyhow::bail!(
                "replay_mode must be either 'offline' or 'online', got '{}'",
                other
            ),
        }
    });

    let report = report.map_err(to_pyerr)?;
    pythonize(py, &report)
        .map_err(to_pyerr)
        .map(|obj| obj.unbind())
}
fn load_replay_mocker_args(
py: Python<'_>,
extra_engine_args: Option<MockEngineArgs>,
) -> PyResult<RsMockEngineArgs> {
let mut args = match extra_engine_args {
Some(extra_args) => extra_args.inner(),
None => RsMockEngineArgs::default(),
};
if let Some(ref backend_name) = args.aic_backend.clone() {
let backend = backend_name.clone();
let system = args.aic_system.as_deref().unwrap_or("h200_sxm").to_string();
let model_name = args
.aic_model_path
.clone()
.ok_or_else(|| PyException::new_err("--aic-perf-model requires --model-path"))?;
let backend_version = args.aic_backend_version.clone();
let tp_size = args.aic_tp_size.unwrap_or(1);
let callback = create_aic_callback(
py,
&backend,
&system,
&model_name,
tp_size,
backend_version.as_deref(),
)
.map_err(|e| {
PyException::new_err(format!(
"Failed to create AIC callback (--aic-perf-model was requested): {}",
e
))
})?;
tracing::info!(
"AIC perf model: backend={}, gpu={}, model={}, version={:?}",
backend,
system,
model_name,
backend_version
);
args.perf_model = Arc::new(PerfModel::from_aic_callback(callback));
}
Ok(args)
}
/// Unwraps the optional Python router configuration into its Rust
/// counterpart, passing `None` through unchanged.
fn load_replay_router_config(
    router_config: Option<KvRouterConfig>,
) -> Option<dynamo_kv_router::config::KvRouterConfig> {
    let config = router_config?;
    Some(config.inner())
}
/// Maps the Python-facing router mode name onto the replay router enum.
///
/// Accepts exactly `"round_robin"` or `"kv_router"`; any other string produces
/// a Python exception that names the rejected value.
fn parse_replay_router_mode(
    router_mode: &str,
) -> PyResult<dynamo_mocker::replay::ReplayRouterMode> {
    use dynamo_mocker::replay::ReplayRouterMode;

    let mode = match router_mode {
        "round_robin" => ReplayRouterMode::RoundRobin,
        "kv_router" => ReplayRouterMode::KvRouter,
        other => {
            return Err(PyException::new_err(format!(
                "router_mode must be either 'round_robin' or 'kv_router', got '{}'",
                other
            )));
        }
    };
    Ok(mode)
}
/// Validates the optional replay concurrency limit.
///
/// `None` passes through untouched. A provided value must be at least 1;
/// zero and negative values are rejected with an error.
fn parse_replay_concurrency(replay_concurrency: Option<isize>) -> anyhow::Result<Option<usize>> {
    let Some(value) = replay_concurrency.map(usize::try_from) else {
        return Ok(None);
    };
    match value {
        // The conversion fails for negatives; the guard additionally rejects zero.
        Ok(limit) if limit >= 1 => Ok(Some(limit)),
        _ => anyhow::bail!("replay_concurrency must be at least 1"),
    }
}
/// Generates `request_count` synthetic requests for a replay run.
///
/// Every request carries `input_tokens` token ids produced by
/// `synthetic_token_id`, `output_tokens` as its output limit, a UUID derived
/// from its 1-based position, and `dp_rank` 0. When
/// `include_arrival_timestamps` is true, request `i` arrives at
/// `i * arrival_interval_ms`; otherwise no timestamp is attached.
///
/// # Errors
/// Fails when any of the three counts is zero, or when `arrival_interval_ms`
/// is negative, NaN or infinite. The checks run in parameter order, so the
/// first offending argument determines the message.
fn build_synthetic_requests(
    input_tokens: usize,
    output_tokens: usize,
    request_count: usize,
    arrival_interval_ms: f64,
    include_arrival_timestamps: bool,
) -> anyhow::Result<Vec<DirectRequest>> {
    anyhow::ensure!(input_tokens != 0, "input_tokens must be at least 1");
    anyhow::ensure!(output_tokens != 0, "output_tokens must be at least 1");
    anyhow::ensure!(request_count != 0, "request_count must be at least 1");
    anyhow::ensure!(
        arrival_interval_ms.is_finite() && arrival_interval_ms >= 0.0,
        "arrival_interval_ms must be a finite non-negative number, got {}",
        arrival_interval_ms
    );

    let requests = (0..request_count)
        .map(|request_idx| {
            let arrival_timestamp_ms = if include_arrival_timestamps {
                Some(request_idx as f64 * arrival_interval_ms)
            } else {
                None
            };
            DirectRequest {
                tokens: (0..input_tokens)
                    .map(|token_idx| synthetic_token_id(request_idx, token_idx))
                    .collect(),
                max_output_tokens: output_tokens,
                // Offset by one so the first request does not get the nil UUID.
                uuid: Some(Uuid::from_u128((request_idx as u128) + 1)),
                dp_rank: 0,
                arrival_timestamp_ms,
            }
        })
        .collect();

    Ok(requests)
}
/// Derives a deterministic, non-zero token id from a request index and a
/// token position.
///
/// The two indices are packed into one 64-bit word (request index in the high
/// half), offset by a fixed odd constant and scrambled with three
/// xor-shift / multiply rounds. The low 32 bits of the result are returned,
/// with 0 remapped to 1 so the id is never zero.
fn synthetic_token_id(request_idx: usize, token_idx: usize) -> u32 {
    const OFFSET: u64 = 0x9E37_79B9_7F4A_7C15;
    const MUL_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MUL_B: u64 = 0x94D0_49BB_1331_11EB;

    let packed = ((request_idx as u64) << 32) ^ (token_idx as u64);
    let z = packed.wrapping_add(OFFSET);
    let z = (z ^ (z >> 30)).wrapping_mul(MUL_A);
    let z = (z ^ (z >> 27)).wrapping_mul(MUL_B);
    let z = z ^ (z >> 31);

    // Truncation to the low 32 bits is intentional.
    match z as u32 {
        0 => 1,
        token => token,
    }
}
......@@ -3,7 +3,17 @@
import asyncio
import os
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
)
# Import from specialized modules
from .prometheus_metrics import RuntimeMetrics as PyRuntimeMetrics
......@@ -1104,9 +1114,10 @@ class KvRouterConfig:
router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8,
router_queue_threshold: Optional[float] = 2.0,
router_queue_threshold: Optional[float] = 4.0,
router_event_threads: int = 4,
router_enable_cache_control: bool = False,
min_initial_workers: int = 1,
router_queue_policy: str = "fcfs",
) -> None:
"""
......@@ -1132,7 +1143,7 @@ class KvRouterConfig:
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
router_queue_threshold: Queue threshold fraction for prefill token capacity (default: 2.0).
router_queue_threshold: Queue threshold fraction for prefill token capacity (default: 4.0).
Requests are queued if all workers exceed this fraction of max_num_batched_tokens.
Enables priority scheduling via request priority hints.
Set to None to disable queueing (all requests go directly to the scheduler).
......@@ -1140,12 +1151,111 @@ class KvRouterConfig:
When > 1, uses a concurrent radix tree with a thread pool.
router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's
cache_control service mesh endpoint (default: False).
min_initial_workers: Minimum number of discovered workers required before
router startup continues (default: 1). Ignored when
skip_initial_worker_wait is enabled.
router_queue_policy: Scheduling policy for the router queue (default: "fcfs").
"fcfs": first-come first-served with priority bumps — optimizes tail TTFT.
"lcfs": last-come first-served with priority bumps — intentionally worsens tail behavior for policy comparisons.
"wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
"""
...
@staticmethod
def from_json(config_json: str) -> "KvRouterConfig":
...
class ReasoningConfig:
    """Typed reasoning settings accepted by `MockEngineArgs` through its `reasoning` argument."""

    def __init__(
        self,
        start_thinking_token_id: int,
        end_thinking_token_id: int,
        thinking_ratio: float,
    ) -> None:
        """Store the three values; all of them are required."""
        ...
class SglangArgs:
    """Typed SGLang-specific settings accepted by `MockEngineArgs` through its `sglang` argument."""

    def __init__(
        self,
        schedule_policy: Optional[str] = None,
        page_size: Optional[int] = None,
        max_prefill_tokens: Optional[int] = None,
        chunked_prefill_size: Optional[int] = None,
        clip_max_new_tokens: Optional[int] = None,
        schedule_conservativeness: Optional[float] = None,
    ) -> None:
        """Store the given values; every field is optional and defaults to None."""
        ...
class MockEngineArgs:
    """Typed mocker engine arguments."""

    def __init__(
        self,
        engine_type: str = "vllm",
        num_gpu_blocks: int = 16384,
        block_size: int = 0,
        max_num_seqs: Optional[int] = 256,
        max_num_batched_tokens: Optional[int] = 8192,
        enable_prefix_caching: bool = True,
        enable_chunked_prefill: bool = True,
        speedup_ratio: float = 1.0,
        decode_speedup_ratio: float = 1.0,
        dp_size: int = 1,
        startup_time: Optional[float] = None,
        worker_type: str = "aggregated",
        aic_backend: Optional[str] = None,
        aic_system: Optional[str] = None,
        aic_backend_version: Optional[str] = None,
        aic_tp_size: Optional[int] = None,
        aic_model_path: Optional[str] = None,
        enable_local_indexer: bool = False,
        bootstrap_port: Optional[int] = None,
        kv_bytes_per_token: Optional[int] = None,
        kv_transfer_bandwidth: Optional[float] = None,
        reasoning: Optional[ReasoningConfig] = None,
        zmq_kv_events_port: Optional[int] = None,
        zmq_replay_port: Optional[int] = None,
        preemption_mode: str = "lifo",
        router_queue_policy: Optional[str] = None,
        sglang: Optional[SglangArgs] = None,
    ) -> None:
        """
        Build mocker engine arguments from keyword values.

        Args:
            engine_type: Either "vllm" or "sglang".
            worker_type: One of "aggregated", "prefill", or "decode".
            preemption_mode: Either "lifo" or "fifo".
            router_queue_policy: Optional policy name; rejected if it cannot be parsed.

        Raises:
            Exception: If one of the string options is not an accepted value, or
                if the arguments cannot be built or normalized.
        """
        ...

    @staticmethod
    def from_json(config_json: str) -> "MockEngineArgs":
        """Build the arguments from a JSON string; raises if it cannot be parsed."""
        ...

    # Read-only views of the stored argument values.
    @property
    def block_size(self) -> int: ...
    @property
    def num_gpu_blocks(self) -> int: ...
    @property
    def max_num_seqs(self) -> Optional[int]: ...
    @property
    def max_num_batched_tokens(self) -> Optional[int]: ...
    @property
    def enable_local_indexer(self) -> bool: ...
    @property
    def dp_size(self) -> int: ...
    @property
    def bootstrap_port(self) -> Optional[int]: ...
    def is_prefill(self) -> bool: ...
    def is_decode(self) -> bool: ...
    # Returns a new instance; an argument left as None keeps the current value.
    def with_overrides(
        self,
        bootstrap_port: Optional[int] = None,
        zmq_kv_events_port: Optional[int] = None,
        zmq_replay_port: Optional[int] = None,
        kv_bytes_per_token: Optional[int] = None,
    ) -> "MockEngineArgs": ...
async def register_model(
model_input: ModelInput,
model_type: ModelType,
......@@ -1249,11 +1359,31 @@ async def run_input(runtime: DistributedRuntime, input: str, engine_config: Engi
def run_mocker_trace_replay(
trace_file: str | os.PathLike[str],
extra_engine_args: Optional[str | os.PathLike[str]] = None,
extra_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
num_workers: int = 1,
replay_concurrency: Optional[int] = None,
replay_mode: Literal["offline", "online"] = "offline",
router_mode: Literal["round_robin", "kv_router"] = "round_robin",
arrival_speedup_ratio: float = 1.0,
) -> Dict[str, Any]:
"""Replay a mocker trace file and return the simulation report for aggregated vLLM or SGLang configs."""
...
def run_mocker_synthetic_trace_replay(
input_tokens: int,
output_tokens: int,
request_count: int,
extra_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
num_workers: int = 1,
replay_concurrency: Optional[int] = None,
replay_mode: Literal["offline", "online"] = "offline",
router_mode: Literal["round_robin", "kv_router"] = "round_robin",
arrival_speedup_ratio: float = 1.0,
arrival_interval_ms: float = 1.0,
) -> Dict[str, Any]:
"""Replay a mocker trace file and return the simulation report."""
"""Replay a synthetic mocker workload without requiring a trace file."""
...
class Layer:
......@@ -1687,6 +1817,7 @@ class EntrypointArgs:
tls_cert_path: Optional[str] = None,
tls_key_path: Optional[str] = None,
extra_engine_args: Optional[str] = None,
mocker_engine_args: Optional[MockEngineArgs] = None,
runtime_config: Optional[ModelRuntimeConfig] = None,
namespace: Optional[str] = None,
namespace_prefix: Optional[str] = None,
......@@ -1711,7 +1842,8 @@ class EntrypointArgs:
http_metrics_port: HTTP metrics port (for gRPC service)
tls_cert_path: TLS certificate path (PEM format)
tls_key_path: TLS key path (PEM format)
extra_engine_args: Path to extra engine arguments file
extra_engine_args: Optional path to mocker engine arguments JSON
mocker_engine_args: Typed mocker engine arguments
runtime_config: Optional runtime configuration for discovery registration
namespace: Dynamo namespace for model discovery scoping
namespace_prefix: Optional namespace prefix
......
......@@ -18,6 +18,7 @@ from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import LoRADownloader as LoRADownloader
from dynamo._core import MediaDecoder as MediaDecoder
from dynamo._core import MediaFetcher as MediaFetcher
from dynamo._core import MockEngineArgs as MockEngineArgs
from dynamo._core import ModelCardInstanceId as ModelCardInstanceId
from dynamo._core import ModelInput as ModelInput
from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig
......@@ -25,8 +26,10 @@ from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores
from dynamo._core import PythonAsyncEngine as PythonAsyncEngine
from dynamo._core import RadixTree as RadixTree
from dynamo._core import ReasoningConfig as ReasoningConfig
from dynamo._core import RouterConfig as RouterConfig
from dynamo._core import RouterMode as RouterMode
from dynamo._core import SglangArgs as SglangArgs
from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher
from dynamo._core import compute_block_hash_for_seq as compute_block_hash_for_seq
from dynamo._core import fetch_model as fetch_model
......@@ -35,7 +38,7 @@ from dynamo._core import make_engine
from dynamo._core import register_model as register_model
from dynamo._core import run_input
from dynamo._core import run_kv_indexer as run_kv_indexer
from dynamo._core import run_mocker_trace_replay
from dynamo._core import run_mocker_trace_replay as _run_mocker_trace_replay
from dynamo._core import unregister_model as unregister_model
from .exceptions import HttpError
......@@ -44,3 +47,24 @@ from .exceptions import HttpError
# Expose the model helpers imported from dynamo._core under *_llm names as
# well, so both spellings resolve to the same callables.
fetch_llm = fetch_model
register_llm = register_model
unregister_llm = unregister_model
def run_mocker_trace_replay(
    trace_file,
    extra_engine_args=None,
    router_config=None,
    num_workers=1,
    replay_concurrency=None,
    router_mode="round_robin",
    arrival_speedup_ratio=1.0,
):
    """Replay ``trace_file`` through the core mocker binding in offline mode.

    Every argument is forwarded unchanged to the ``dynamo._core`` function
    imported as ``_run_mocker_trace_replay``. This wrapper exposes no
    ``replay_mode`` parameter: it always passes ``replay_mode="offline"``.

    Returns:
        Whatever the underlying binding returns.
    """
    forwarded = dict(
        extra_engine_args=extra_engine_args,
        router_config=router_config,
        num_workers=num_workers,
        replay_concurrency=replay_concurrency,
        # Fixed on purpose: callers of this entry point cannot choose the mode.
        replay_mode="offline",
        router_mode=router_mode,
        arrival_speedup_ratio=arrival_speedup_ratio,
    )
    return _run_mocker_trace_replay(trace_file, **forwarded)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.replay.api import run_synthetic_trace_replay, run_trace_replay
__all__ = ["run_synthetic_trace_replay", "run_trace_replay"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.replay.main import main
# Script entry point: run the CLI and use its return value as the exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo._core import (
run_mocker_synthetic_trace_replay as _run_mocker_synthetic_trace_replay,
)
from dynamo._core import run_mocker_trace_replay as _run_mocker_trace_replay
def run_trace_replay(
    trace_file,
    *,
    extra_engine_args=None,
    router_config=None,
    num_workers=1,
    replay_concurrency=None,
    replay_mode="offline",
    router_mode="round_robin",
    arrival_speedup_ratio=1.0,
):
    """Replay a trace file through the core mocker binding.

    ``trace_file`` is the only positional argument; every option is
    keyword-only and is handed unchanged to the ``dynamo._core`` function
    imported as ``_run_mocker_trace_replay``.

    Returns:
        Whatever the underlying binding returns.
    """
    options = dict(
        extra_engine_args=extra_engine_args,
        router_config=router_config,
        num_workers=num_workers,
        replay_concurrency=replay_concurrency,
        replay_mode=replay_mode,
        router_mode=router_mode,
        arrival_speedup_ratio=arrival_speedup_ratio,
    )
    return _run_mocker_trace_replay(trace_file, **options)
def run_synthetic_trace_replay(
    input_tokens,
    output_tokens,
    request_count,
    *,
    extra_engine_args=None,
    router_config=None,
    num_workers=1,
    replay_concurrency=None,
    replay_mode="offline",
    router_mode="round_robin",
    arrival_speedup_ratio=1.0,
    arrival_interval_ms=1.0,
):
    """Run a synthetic replay through the core mocker binding.

    The three workload values are positional; every option is keyword-only.
    All arguments are handed unchanged to the ``dynamo._core`` function
    imported as ``_run_mocker_synthetic_trace_replay``.

    Returns:
        Whatever the underlying binding returns.
    """
    workload = (input_tokens, output_tokens, request_count)
    options = dict(
        extra_engine_args=extra_engine_args,
        router_config=router_config,
        num_workers=num_workers,
        replay_concurrency=replay_concurrency,
        replay_mode=replay_mode,
        router_mode=router_mode,
        arrival_speedup_ratio=arrival_speedup_ratio,
        arrival_interval_ms=arrival_interval_ms,
    )
    return _run_mocker_synthetic_trace_replay(*workload, **options)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import argparse
import json
import os
import sys
from collections.abc import Sequence
# Set before the dynamo imports below; setdefault leaves any value the caller
# already exported untouched.
# NOTE(review): presumably this makes dynamo skip its Python logging setup so
# the JSON report on stdout stays clean -- confirm against the dynamo package.
os.environ.setdefault("DYNAMO_SKIP_PYTHON_LOG_INIT", "1")
from dynamo.llm import KvRouterConfig, MockEngineArgs
from dynamo.replay import run_synthetic_trace_replay, run_trace_replay
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(prog="python -m dynamo.replay")
parser.add_argument("trace_file", nargs="?")
parser.add_argument("--extra-engine-args")
parser.add_argument("--router-config")
parser.add_argument("--input-tokens", type=int)
parser.add_argument("--output-tokens", type=int)
parser.add_argument("--request-count", type=int)
parser.add_argument("--arrival-interval-ms", type=float, default=1.0)
parser.add_argument("--num-workers", type=int, default=1)
parser.add_argument("--replay-concurrency", type=int)
parser.add_argument(
"--replay-mode",
choices=("offline", "online"),
default="offline",
)
parser.add_argument(
"--router-mode",
choices=("round_robin", "kv_router"),
default="round_robin",
)
parser.add_argument("--arrival-speedup-ratio", type=float, default=1.0)
args = parser.parse_args(list(sys.argv[1:] if argv is None else argv))
using_trace_file = args.trace_file is not None
synthetic_args = (args.input_tokens, args.output_tokens, args.request_count)
using_synthetic = any(value is not None for value in synthetic_args)
if using_trace_file == using_synthetic:
parser.error(
"provide either trace_file or all of --input-tokens/--output-tokens/--request-count"
)
if using_synthetic and not all(value is not None for value in synthetic_args):
parser.error(
"synthetic replay requires --input-tokens, --output-tokens, and --request-count"
)
extra_engine_args = (
MockEngineArgs.from_json(args.extra_engine_args)
if args.extra_engine_args is not None
else None
)
router_config = (
KvRouterConfig.from_json(args.router_config)
if args.router_config is not None
else None
)
if using_trace_file:
report = run_trace_replay(
args.trace_file,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=args.num_workers,
replay_concurrency=args.replay_concurrency,
replay_mode=args.replay_mode,
router_mode=args.router_mode,
arrival_speedup_ratio=args.arrival_speedup_ratio,
)
else:
report = run_synthetic_trace_replay(
args.input_tokens,
args.output_tokens,
args.request_count,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=args.num_workers,
replay_concurrency=args.replay_concurrency,
replay_mode=args.replay_mode,
router_mode=args.router_mode,
arrival_speedup_ratio=args.arrival_speedup_ratio,
arrival_interval_ms=args.arrival_interval_ms,
)
json.dump(report, sys.stdout, indent=2, sort_keys=True)
sys.stdout.write("\n")
return 0
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import pytest
from dynamo.llm import KvRouterConfig, MockEngineArgs
from dynamo.replay import run_synthetic_trace_replay, run_trace_replay
# Module-level marker list: pytest applies each mark below to every test
# collected from this file.
pytestmark = [
    pytest.mark.gpu_0,
    pytest.mark.parallel,
    pytest.mark.pre_merge,
]
MOONCAKE_TRACE_FIRST20 = """{"timestamp": 0, "input_length": 6755, "output_length": 500, "hash_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]}
{"timestamp": 0, "input_length": 7319, "output_length": 490, "hash_ids": [0, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]}
{"timestamp": 0, "input_length": 7234, "output_length": 794, "hash_ids": [0, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]}
{"timestamp": 0, "input_length": 2287, "output_length": 316, "hash_ids": [0, 42, 43, 44, 45]}
{"timestamp": 0, "input_length": 9013, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]}
{"timestamp": 0, "input_length": 6506, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 64]}
{"timestamp": 0, "input_length": 4824, "output_length": 173, "hash_ids": [0, 65, 66, 67, 68, 69, 70, 71, 72, 73]}
{"timestamp": 0, "input_length": 3119, "output_length": 20, "hash_ids": [74, 75, 76, 77, 78, 79, 80]}
{"timestamp": 0, "input_length": 23090, "output_length": 453, "hash_ids": [0, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125]}
{"timestamp": 0, "input_length": 3135, "output_length": 19, "hash_ids": [74, 75, 76, 77, 78, 126, 127]}
{"timestamp": 0, "input_length": 26874, "output_length": 458, "hash_ids": [0, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179]}
{"timestamp": 0, "input_length": 10487, "output_length": 402, "hash_ids": [0, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]}
{"timestamp": 0, "input_length": 17448, "output_length": 610, "hash_ids": [0, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233]}
{"timestamp": 0, "input_length": 6253, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 234]}
{"timestamp": 0, "input_length": 6725, "output_length": 32, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 235, 236]}
{"timestamp": 3052, "input_length": 13538, "output_length": 71, "hash_ids": [0, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262]}
{"timestamp": 3052, "input_length": 87162, "output_length": 402, "hash_ids": [0, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432]}
{"timestamp": 3052, "input_length": 6166, "output_length": 24, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 433]}
{"timestamp": 3052, "input_length": 6320, "output_length": 548, "hash_ids": [0, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445]}
{"timestamp": 3052, "input_length": 2007, "output_length": 354, "hash_ids": [0, 446, 447, 448]}
"""
def _write_trace_and_args(tmp_path):
trace_path = tmp_path / "trace.jsonl"
records = [
{
"timestamp": 1000.0,
"input_length": 64,
"output_length": 2,
"hash_ids": [101],
},
{
"timestamp": 1005.0,
"input_length": 64,
"output_length": 2,
"hash_ids": [101],
},
]
trace_path.write_text(
"\n".join(json.dumps(record) for record in records) + "\n",
encoding="utf-8",
)
return trace_path
def _write_vllm_args(tmp_path):
args_path = tmp_path / "args.json"
args_path.write_text(
json.dumps(
{
"block_size": 64,
"speedup_ratio": 1000.0,
}
),
encoding="utf-8",
)
return args_path
def _vllm_args():
return MockEngineArgs.from_json(
json.dumps(
{
"block_size": 64,
"speedup_ratio": 1000.0,
}
)
)
def _write_sglang_args(tmp_path):
args_path = tmp_path / "sglang_args.json"
args_path.write_text(
json.dumps(
{
"engine_type": "sglang",
"num_gpu_blocks": 512,
"block_size": 64,
"speedup_ratio": 1000.0,
"sglang": {
"page_size": 64,
},
}
),
encoding="utf-8",
)
return args_path
def _sglang_args():
return MockEngineArgs.from_json(
json.dumps(
{
"engine_type": "sglang",
"num_gpu_blocks": 512,
"block_size": 64,
"speedup_ratio": 1000.0,
"sglang": {
"page_size": 64,
},
}
)
)
def _write_router_config(tmp_path):
config_path = tmp_path / "router_config.json"
config_path.write_text(
json.dumps(
{
"router_queue_threshold": 1.25,
"router_event_threads": 1,
"router_queue_policy": "wspt",
"router_temperature": 0.0,
"overlap_score_weight": 1.0,
"use_kv_events": True,
"durable_kv_events": False,
"router_replica_sync": False,
"router_track_active_blocks": True,
"router_track_output_blocks": False,
"router_assume_kv_reuse": True,
"router_snapshot_threshold": 1000000,
"router_reset_states": False,
"router_ttl_secs": 120.0,
"router_max_tree_size": 1048576,
"router_prune_target_ratio": 0.8,
"router_enable_cache_control": False,
"skip_initial_worker_wait": False,
"min_initial_workers": 1,
"remote_indexer_component": None,
}
),
encoding="utf-8",
)
return config_path
def _router_config():
return KvRouterConfig.from_json(
json.dumps(
{
"router_queue_threshold": 1.25,
"router_event_threads": 1,
"router_queue_policy": "wspt",
"router_temperature": 0.0,
"overlap_score_weight": 1.0,
"use_kv_events": True,
"durable_kv_events": False,
"router_replica_sync": False,
"router_track_active_blocks": True,
"router_track_output_blocks": False,
"router_assume_kv_reuse": True,
"router_snapshot_threshold": 1000000,
"router_reset_states": False,
"router_ttl_secs": 120.0,
"router_max_tree_size": 1048576,
"router_prune_target_ratio": 0.8,
"router_enable_cache_control": False,
"skip_initial_worker_wait": False,
"min_initial_workers": 1,
"remote_indexer_component": None,
}
)
)
def _partial_router_config():
return KvRouterConfig(
router_queue_threshold=1.25,
router_event_threads=1,
router_queue_policy="wspt",
)
def _assert_basic_report_counts(report, *, num_requests, input_tokens, output_tokens):
assert report["num_requests"] == num_requests
assert report["completed_requests"] == num_requests
assert report["total_input_tokens"] == num_requests * input_tokens
assert report["total_output_tokens"] == num_requests * output_tokens
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
def test_run_trace_replay_smoke_matrix(tmp_path, engine_type, replay_mode, router_mode):
    """Trace replay completes both requests for every engine/mode/router combo.

    Round-robin runs use one worker; KV-router runs use two.
    """
    trace_file = _write_trace_and_args(tmp_path)
    engine_args = _sglang_args() if engine_type != "vllm" else _vllm_args()
    worker_count = 2 if router_mode != "round_robin" else 1

    report = run_trace_replay(
        trace_file,
        extra_engine_args=engine_args,
        num_workers=worker_count,
        replay_mode=replay_mode,
        router_mode=router_mode,
    )

    _assert_basic_report_counts(
        report, num_requests=2, input_tokens=64, output_tokens=2
    )
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_invariant_counts_match(tmp_path, engine_type, replay_mode):
    """Request/token totals are identical for 1 worker and for 4 workers.

    The four-worker case is checked under both round-robin and KV routing.
    """
    trace_file = _write_trace_and_args(tmp_path)
    engine_args = _sglang_args() if engine_type != "vllm" else _vllm_args()

    def replay(**overrides):
        # Everything except worker count / router mode is shared by all runs.
        return run_trace_replay(
            trace_file,
            extra_engine_args=engine_args,
            replay_mode=replay_mode,
            **overrides,
        )

    baseline = replay(num_workers=1)
    round_robin_report = replay(num_workers=4, router_mode="round_robin")
    kv_router_report = replay(num_workers=4, router_mode="kv_router")

    invariant_fields = (
        "num_requests",
        "completed_requests",
        "total_input_tokens",
        "total_output_tokens",
    )
    for name in invariant_fields:
        assert baseline[name] == round_robin_report[name]
        assert baseline[name] == kv_router_report[name]
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
def test_run_synthetic_trace_replay_smoke_matrix(
    tmp_path, engine_type, replay_mode, router_mode
):
    """Synthetic replay completes both requests for every parameter combo.

    Round-robin runs use one worker; KV-router runs use two.
    """
    engine_args = _sglang_args() if engine_type != "vllm" else _vllm_args()
    worker_count = 2 if router_mode != "round_robin" else 1

    # Positional workload: 64 input tokens, 2 output tokens, 2 requests.
    report = run_synthetic_trace_replay(
        64,
        2,
        2,
        extra_engine_args=engine_args,
        num_workers=worker_count,
        replay_mode=replay_mode,
        router_mode=router_mode,
        arrival_interval_ms=5.0,
    )

    _assert_basic_report_counts(
        report, num_requests=2, input_tokens=64, output_tokens=2
    )
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_synthetic_trace_replay_invariant_counts_match(
    tmp_path, engine_type, replay_mode
):
    """Synthetic totals are identical for 1 worker and for 4 workers.

    The four-worker case is checked under both round-robin and KV routing.
    """
    engine_args = _sglang_args() if engine_type != "vllm" else _vllm_args()

    def replay(**overrides):
        # Same workload for every run: 64 input tokens, 2 output tokens,
        # 2 requests, arriving 5 ms apart.
        return run_synthetic_trace_replay(
            64,
            2,
            2,
            extra_engine_args=engine_args,
            replay_mode=replay_mode,
            arrival_interval_ms=5.0,
            **overrides,
        )

    baseline = replay(num_workers=1)
    round_robin_report = replay(num_workers=4, router_mode="round_robin")
    kv_router_report = replay(num_workers=4, router_mode="kv_router")

    invariant_fields = (
        "num_requests",
        "completed_requests",
        "total_input_tokens",
        "total_output_tokens",
    )
    for name in invariant_fields:
        assert baseline[name] == round_robin_report[name]
        assert baseline[name] == kv_router_report[name]
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_synthetic_concurrency_replay_counts_match(
    tmp_path, engine_type, replay_mode
):
    """With ``replay_concurrency=2`` on two workers, all three requests finish."""
    engine_args = _sglang_args() if engine_type != "vllm" else _vllm_args()

    # Positional workload: 64 input tokens, 2 output tokens, 3 requests.
    report = run_synthetic_trace_replay(
        64,
        2,
        3,
        extra_engine_args=engine_args,
        num_workers=2,
        replay_mode=replay_mode,
        replay_concurrency=2,
    )

    _assert_basic_report_counts(
        report, num_requests=3, input_tokens=64, output_tokens=2
    )
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_router_config(tmp_path, replay_mode):
    """KV-router trace replay accepts a fully specified ``KvRouterConfig``."""
    trace_file = _write_trace_and_args(tmp_path)
    engine_args = _vllm_args()
    kv_config = _router_config()

    report = run_trace_replay(
        trace_file,
        extra_engine_args=engine_args,
        router_config=kv_config,
        num_workers=2,
        replay_mode=replay_mode,
        router_mode="kv_router",
    )

    _assert_basic_report_counts(
        report, num_requests=2, input_tokens=64, output_tokens=2
    )
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_partial_router_config_json(tmp_path, replay_mode):
    """KV-router trace replay accepts a config built from only a few fields.

    The config comes from ``_partial_router_config()``, which calls the
    ``KvRouterConfig`` constructor with keyword arguments rather than JSON.
    """
    trace_file = _write_trace_and_args(tmp_path)
    engine_args = _vllm_args()

    report = run_trace_replay(
        trace_file,
        extra_engine_args=engine_args,
        router_config=_partial_router_config(),
        num_workers=2,
        replay_mode=replay_mode,
        router_mode="kv_router",
    )

    _assert_basic_report_counts(
        report, num_requests=2, input_tokens=64, output_tokens=2
    )
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_partial_extra_engine_args_json(tmp_path, replay_mode):
    """Trace replay accepts ``MockEngineArgs`` built with only two keywords.

    The engine args are constructed directly via the ``MockEngineArgs``
    constructor rather than parsed from JSON.
    """
    trace_file = _write_trace_and_args(tmp_path)

    report = run_trace_replay(
        trace_file,
        extra_engine_args=MockEngineArgs(block_size=64, speedup_ratio=1000.0),
        num_workers=1,
        replay_mode=replay_mode,
    )

    _assert_basic_report_counts(
        report, num_requests=2, input_tokens=64, output_tokens=2
    )
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transport abstraction for publishing batched KV cache events.
//!
//! Implementations handle the actual delivery mechanism (NATS event plane,
//! JetStream durable queue, direct indexer application, etc.). The trait lives
//! in this crate so that the batching processor and other routing logic can be
//! written generically; runtime-specific impls stay in `lib/llm`.
use std::future::Future;
use crate::protocols::RouterEvent;
/// Transport abstraction for publishing batched KV cache events.
///
/// Implementors must be `Send + Sync`. The delivery mechanism is left entirely
/// to the implementation; this trait only fixes the call shape.
pub trait EventSink: Send + Sync {
    /// Publishes a single [`RouterEvent`] through this sink.
    ///
    /// The event is borrowed, not consumed. The returned future is `Send` and
    /// resolves to an `anyhow::Result<()>` reporting the outcome.
    fn publish_event(&self, event: &RouterEvent)
    -> impl Future<Output = anyhow::Result<()>> + Send;
}
......@@ -245,1392 +245,1412 @@ async fn flush_and_settle(index: &dyn KvIndexerInterface) {
tokio::time::sleep(Duration::from_millis(100)).await;
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_store_and_find(variant: &str) {
let index = make_indexer(variant);
// Store a sequence for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
mod interface_tests {
use super::*;
use rstest_reuse::apply;
// Stores blocks [1, 2, 3] for worker 0 and checks that querying the same three
// local hashes reports exactly one worker, with a score of 3 for
// `WorkerWithDpRank::new(0, 0)`. Runs once per `indexer_template` variant.
#[tokio::test]
#[apply(indexer_template)]
async fn test_store_and_find(variant: &str) {
    let index = make_indexer(variant);
    // Store a sequence for worker 0
    index.apply_event(make_store_event(0, &[1, 2, 3])).await;
    // Let the indexer settle before querying.
    flush_and_settle(index.as_ref()).await;
    // Find matches using local hashes
    let scores = index
        .find_matches(vec![
            LocalBlockHash(1),
            LocalBlockHash(2),
            LocalBlockHash(3),
        ])
        .await
        .unwrap();
    // Only worker 0 stored anything, and all three blocks match.
    assert_eq!(scores.scores.len(), 1);
    assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_partial_match(variant: &str) {
let index = make_indexer(variant);
// Find matches using local hashes
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
// Store [1, 2, 3] for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_partial_match(variant: &str) {
let index = make_indexer(variant);
flush_and_settle(index.as_ref()).await;
// Store [1, 2, 3] for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Find matches for [1, 2, 999] - should match first 2 then stop
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(999),
])
.await
.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
}
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove(variant: &str) {
let index = make_indexer(variant);
// Find matches for [1, 2, 999] - should match first 2 then stop
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(999),
])
.await
.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
}
// Store sequence for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove(variant: &str) {
let index = make_indexer(variant);
// Remove all blocks
index.apply_event(make_remove_event(0, &[1, 2, 3])).await;
// Store sequence for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
flush_and_settle(index.as_ref()).await;
// Remove all blocks
index.apply_event(make_remove_event(0, &[1, 2, 3])).await;
// Find should return nothing
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert!(scores.scores.is_empty());
}
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_multiple_workers_shared_prefix(variant: &str) {
let index = make_indexer(variant);
// Find should return nothing
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert!(scores.scores.is_empty());
}
// Worker 0 has [1, 2], Worker 1 has [1, 3]
// Since sequence hashes are cumulative, [1] has same hash for both,
// but [1, 2] and [1, 3] have different hashes.
index.apply_event(make_store_event(0, &[1, 2])).await;
index.apply_event(make_store_event(1, &[1, 3])).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_multiple_workers_shared_prefix(variant: &str) {
let index = make_indexer(variant);
// Worker 0 has [1, 2], Worker 1 has [1, 3]
// Since sequence hashes are cumulative, [1] has same hash for both,
// but [1, 2] and [1, 3] have different hashes.
index.apply_event(make_store_event(0, &[1, 2])).await;
index.apply_event(make_store_event(1, &[1, 3])).await;
flush_and_settle(index.as_ref()).await;
// Query [1] - both workers should match
let scores = index.find_matches(vec![LocalBlockHash(1)]).await.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
// Query [1, 2] - worker 0 matches both, worker 1 matches only first block
let scores = index
.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)])
.await
.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
}
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove_worker(variant: &str) {
let index = make_indexer(variant);
// Query [1] - both workers should match
let scores = index.find_matches(vec![LocalBlockHash(1)]).await.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.apply_event(make_store_event(1, &[1, 2, 3])).await;
// Query [1, 2] - worker 0 matches both, worker 1 matches only first block
let scores = index
.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)])
.await
.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
}
// Allow time for async event processing
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove_worker(variant: &str) {
let index = make_indexer(variant);
index.remove_worker(0).await;
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.apply_event(make_store_event(1, &[1, 2, 3])).await;
// Allow time for async remove_worker processing
flush_and_settle(index.as_ref()).await;
// Allow time for async event processing
flush_and_settle(index.as_ref()).await;
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
}
index.remove_worker(0).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_large_stores(variant: &str) {
let index = make_indexer(variant);
// Allow time for async remove_worker processing
flush_and_settle(index.as_ref()).await;
// Test sequences of increasing sizes
for i in 0..10u64 {
let len = 1 << i; // 1, 2, 4, 8, ..., 512
let worker_id = i;
let sequence: Vec<u64> = (1..=len).map(|x| x + (i * 10000)).collect();
index
.apply_event(make_store_event(worker_id, &sequence))
.await;
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
}
flush_and_settle(index.as_ref()).await;
// Verify we can find matches for the last stored sequence
let last_seq: Vec<LocalBlockHash> = (1..=512u64)
.map(|x| LocalBlockHash(x + (9 * 10000)))
.collect();
let scores = index.find_matches(last_seq).await.unwrap();
assert!(!scores.scores.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_dump_and_restore(variant: &str) {
let index = make_indexer(variant);
// Store some data
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.apply_event(make_store_event(1, &[1, 2, 4])).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_large_stores(variant: &str) {
let index = make_indexer(variant);
// Allow background worker threads to process events.
flush_and_settle(index.as_ref()).await;
// Test sequences of increasing sizes
for i in 0..10u64 {
let len = 1 << i; // 1, 2, 4, 8, ..., 512
let worker_id = i;
let sequence: Vec<u64> = (1..=len).map(|x| x + (i * 10000)).collect();
index
.apply_event(make_store_event(worker_id, &sequence))
.await;
}
// Dump the tree as events and replay into a new index
let events = index.dump_events().await.unwrap();
assert!(!events.is_empty());
flush_and_settle(index.as_ref()).await;
let restored = make_indexer(variant);
for event in events {
restored.apply_event(event).await;
// Verify we can find matches for the last stored sequence
let last_seq: Vec<LocalBlockHash> = (1..=512u64)
.map(|x| LocalBlockHash(x + (9 * 10000)))
.collect();
let scores = index.find_matches(last_seq).await.unwrap();
assert!(!scores.scores.is_empty());
}
flush_and_settle(restored.as_ref()).await;
assert_eq!(
snapshot_tree(index.as_ref()).await,
snapshot_tree(restored.as_ref()).await
);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_dump_and_restore(variant: &str) {
let index = make_indexer(variant);
#[tokio::test]
#[apply(indexer_template)]
async fn test_clear_all_blocks(variant: &str) {
let index = make_indexer(variant);
// Store some data
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.apply_event(make_store_event(1, &[1, 2, 4])).await;
// Store some data for two workers
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.apply_event(make_store_event(1, &[1, 2, 3])).await;
// Allow background worker threads to process events.
flush_and_settle(index.as_ref()).await;
// Clear worker 0's blocks using the Cleared event
index.apply_event(make_clear_event(0)).await;
// Dump the tree as events and replay into a new index
let events = index.dump_events().await.unwrap();
assert!(!events.is_empty());
flush_and_settle(index.as_ref()).await;
let restored = make_indexer(variant);
for event in events {
restored.apply_event(event).await;
}
// Worker 0's blocks should be gone, worker 1's remain
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
}
flush_and_settle(restored.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_empty_query(variant: &str) {
let index = make_indexer(variant);
assert_eq!(
snapshot_tree(index.as_ref()).await,
snapshot_tree(restored.as_ref()).await
);
}
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_clear_all_blocks(variant: &str) {
let index = make_indexer(variant);
flush_and_settle(index.as_ref()).await;
// Store some data for two workers
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.apply_event(make_store_event(1, &[1, 2, 3])).await;
// Empty query should return empty scores
let scores = index.find_matches(vec![]).await.unwrap();
assert!(scores.scores.is_empty());
}
// Clear worker 0's blocks using the Cleared event
index.apply_event(make_clear_event(0)).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_miss_query(variant: &str) {
let index = make_indexer(variant);
flush_and_settle(index.as_ref()).await;
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Worker 0's blocks should be gone, worker 1's remain
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
}
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_empty_query(variant: &str) {
let index = make_indexer(variant);
// Query for non-existent blocks
let scores = index
.find_matches(vec![LocalBlockHash(999), LocalBlockHash(998)])
.await
.unwrap();
assert!(scores.scores.is_empty());
}
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Smoke test: `shutdown` is called on a freshly created indexer with no events
// applied. There are no assertions; the test passes if the call returns
// without panicking.
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown(variant: &str) {
    let index = make_indexer(variant);
    index.shutdown();
}
flush_and_settle(index.as_ref()).await;
// Calls `shutdown` twice in a row on an indexer that has already processed a
// store event. There are no assertions; the test passes if neither call
// panics.
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown_idempotent(variant: &str) {
    let index = make_indexer(variant);
    index.apply_event(make_store_event(0, &[1, 2, 3])).await;
    flush_and_settle(index.as_ref()).await;
    // The second call is the one under test.
    index.shutdown();
    index.shutdown();
}
// Empty query should return empty scores
let scores = index.find_matches(vec![]).await.unwrap();
assert!(scores.scores.is_empty());
}
/// Smoke test for `find_matches_for_request` on an empty and a populated index.
///
/// An empty index must return no scores. After a store event the call is only
/// required to succeed: the stored blocks use hand-picked `LocalBlockHash`
/// values, whereas `find_matches_for_request` computes its block hashes from
/// the raw tokens, so no particular overlap is asserted.
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches_for_request(variant: &str) {
    let index = make_indexer(variant);

    // Empty index should return no matches
    let tokens = vec![1, 2, 3, 4];
    let scores = index.find_matches_for_request(&tokens, None).await.unwrap();
    assert!(scores.scores.is_empty());

    // Store some data, then wait for async processing to finish
    index.apply_event(make_store_event(0, &[1, 2, 3])).await;
    flush_and_settle(index.as_ref()).await;

    // The `unwrap` is the whole check: the query must succeed on a populated
    // index. The previous `is_empty() || !is_empty()` assertion was a
    // tautology that could never fail, so it has been dropped rather than
    // kept as a misleading guarantee.
    let _scores = index.find_matches_for_request(&tokens, None).await.unwrap();
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_miss_query(variant: &str) {
let index = make_indexer(variant);
#[tokio::test]
#[apply(indexer_template)]
async fn test_process_routing_decision(variant: &str) {
let index = make_indexer(variant);
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Create tokens with hashes
let tokens = vec![1u32, 2, 3, 4, 5, 6, 7, 8];
let mut tokens_with_hashes = TokensWithHashes::new(tokens, 32);
flush_and_settle(index.as_ref()).await;
let worker = WorkerWithDpRank::new(0, 0);
// Query for non-existent blocks
let scores = index
.find_matches(vec![LocalBlockHash(999), LocalBlockHash(998)])
.await
.unwrap();
assert!(scores.scores.is_empty());
}
// Process routing decision - should not error
let result = index
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await;
assert!(result.is_ok());
}
/// A newly constructed indexer can be shut down immediately without
/// panicking.
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown(variant: &str) {
    let fresh_indexer = make_indexer(variant);

    // Nothing is stored first; the test covers the bare shutdown path.
    fresh_indexer.shutdown();
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_parent_hash_chains(variant: &str) {
let index = make_indexer(variant);
/// A repeated `shutdown` call on an indexer that has processed an event
/// must not panic.
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown_idempotent(variant: &str) {
    let indexer = make_indexer(variant);

    // Populate the indexer and wait for the event to be processed.
    let stored = make_store_event(0, &[1, 2, 3]);
    indexer.apply_event(stored).await;
    flush_and_settle(indexer.as_ref()).await;

    // First shutdown, then a second one that must not panic.
    indexer.shutdown();
    indexer.shutdown();
}
// Store initial sequence [1, 2, 3]
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
/// Smoke test for `find_matches_for_request` on an empty and a populated index.
///
/// An empty index must return no scores. After a store event the call is only
/// required to succeed: the stored blocks use hand-picked `LocalBlockHash`
/// values, whereas `find_matches_for_request` computes its block hashes from
/// the raw tokens, so no particular overlap is asserted.
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches_for_request(variant: &str) {
    let index = make_indexer(variant);

    // Empty index should return no matches
    let tokens = vec![1, 2, 3, 4];
    let scores = index.find_matches_for_request(&tokens, None).await.unwrap();
    assert!(scores.scores.is_empty());

    // Store some data, then wait for async processing to finish
    index.apply_event(make_store_event(0, &[1, 2, 3])).await;
    flush_and_settle(index.as_ref()).await;

    // The `unwrap` is the whole check: the query must succeed on a populated
    // index. The previous `is_empty() || !is_empty()` assertion was a
    // tautology that could never fail, so it has been dropped rather than
    // kept as a misleading guarantee.
    let _scores = index.find_matches_for_request(&tokens, None).await.unwrap();
}
// Store continuation [4, 5] with parent pointing to block 3
index
.apply_event(make_store_event_with_parent(0, &[1, 2, 3], &[4, 5]))
.await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_process_routing_decision(variant: &str) {
let index = make_indexer(variant);
flush_and_settle(index.as_ref()).await;
// Create tokens with hashes
let tokens = vec![1u32, 2, 3, 4, 5, 6, 7, 8];
let mut tokens_with_hashes = TokensWithHashes::new(tokens, 32);
// Query for full sequence [1, 2, 3, 4, 5] should match all 5 blocks
let full_seq: Vec<LocalBlockHash> = (1..=5).map(LocalBlockHash).collect();
let scores = index.find_matches(full_seq).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 5);
let worker = WorkerWithDpRank::new(0, 0);
// Query for just [1, 2, 3] should match 3 blocks
let prefix_seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(prefix_seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
// Process routing decision - should not error
let result = index
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await;
assert!(result.is_ok());
}
/// The same `worker_id` stored under three different dp_ranks must show up
/// as three separate score entries, each matching the full 3-block sequence.
#[tokio::test]
#[apply(indexer_template)]
async fn test_multiple_dp_ranks(variant: &str) {
    let indexer = make_indexer(variant);

    // Store the identical sequence for worker 0 under dp_ranks 0, 1 and 2.
    for dp_rank in 0..3 {
        let event = make_store_event_with_dp_rank(0, &[1, 2, 3], dp_rank);
        indexer.apply_event(event).await;
    }

    flush_and_settle(indexer.as_ref()).await;

    // One query over the stored sequence should report every dp_rank separately.
    let query: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
    let result = indexer.find_matches(query).await.unwrap();
    assert_eq!(result.scores.len(), 3);
    for dp_rank in 0..3 {
        let key = WorkerWithDpRank::new(0, dp_rank);
        assert_eq!(*result.scores.get(&key).unwrap(), 3);
    }
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_parent_hash_chains(variant: &str) {
let index = make_indexer(variant);
#[tokio::test]
#[apply(indexer_template)]
async fn test_partial_block_removal(variant: &str) {
let index = make_indexer(variant);
// Store initial sequence [1, 2, 3]
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Store [1, 2, 3]
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Store continuation [4, 5] with parent pointing to block 3
index
.apply_event(make_store_event_with_parent(0, &[1, 2, 3], &[4, 5]))
.await;
flush_and_settle(index.as_ref()).await;
flush_and_settle(index.as_ref()).await;
// Verify all 3 blocks match
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
// Query for full sequence [1, 2, 3, 4, 5] should match all 5 blocks
let full_seq: Vec<LocalBlockHash> = (1..=5).map(LocalBlockHash).collect();
let scores = index.find_matches(full_seq).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 5);
// Remove only the last block (block 3)
// To do this correctly, we need to compute the seq_hash for block 3 specifically,
// which requires the full sequence context [1,2,3].
let full_hashes: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); // Last block's hash
// Query for just [1, 2, 3] should match 3 blocks
let prefix_seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(prefix_seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
index.apply_event(remove_event).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_multiple_dp_ranks(variant: &str) {
let index = make_indexer(variant);
flush_and_settle(index.as_ref()).await;
// Same worker_id but different dp_ranks should be tracked separately
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 0))
.await;
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 1))
.await;
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 2))
.await;
// Query [1, 2, 3] - should only match 2 blocks now (block 3 is removed)
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
flush_and_settle(index.as_ref()).await;
// Query [1, 2] - should still match 2 blocks
let partial_seq: Vec<LocalBlockHash> = (1..=2).map(LocalBlockHash).collect();
let scores = index.find_matches(partial_seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
}
// Query should return all 3 dp_ranks as separate entries
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq).await.unwrap();
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove_mid_chain_block(variant: &str) {
// TODO: positional indexer has no parent-child structure, so mid-chain removal
// doesn't invalidate later positions — jump search skips over the gap and over-counts.
if variant == "flat" {
return;
assert_eq!(scores.scores.len(), 3);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 1)).unwrap(), 3);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 2)).unwrap(), 3);
}
let index = make_indexer(variant);
// Store [1, 2, 3, 4, 5]
index
.apply_event(make_store_event(0, &[1, 2, 3, 4, 5]))
.await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_partial_block_removal(variant: &str) {
let index = make_indexer(variant);
flush_and_settle(index.as_ref()).await;
// Store [1, 2, 3]
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Verify all 5 blocks match
let seq: Vec<LocalBlockHash> = (1..=5).map(LocalBlockHash).collect();
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 5);
flush_and_settle(index.as_ref()).await;
// Remove only block 3 (index 2) — the middle of the chain
let full_hashes: Vec<LocalBlockHash> = (1..=5).map(LocalBlockHash).collect();
let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]);
// Verify all 3 blocks match
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
index.apply_event(remove_event).await;
// Remove only the last block (block 3)
// To do this correctly, we need to compute the seq_hash for block 3 specifically,
// which requires the full sequence context [1,2,3].
let full_hashes: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); // Last block's hash
flush_and_settle(index.as_ref()).await;
let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
index.apply_event(remove_event).await;
// Query [1, 2, 3, 4, 5] — only first 2 positions reachable (block 3 removed, orphaning 4 & 5)
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
flush_and_settle(index.as_ref()).await;
// Query [1, 2] — prefix before the gap is still intact
let prefix_seq: Vec<LocalBlockHash> = (1..=2).map(LocalBlockHash).collect();
let scores = index.find_matches(prefix_seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
// Query [1, 2, 3] - should only match 2 blocks now (block 3 is removed)
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
// Re-store block 3 as a continuation of [1, 2]
index
.apply_event(make_store_event_with_parent(0, &[1, 2], &[3]))
.await;
// Query [1, 2] - should still match 2 blocks
let partial_seq: Vec<LocalBlockHash> = (1..=2).map(LocalBlockHash).collect();
let scores = index.find_matches(partial_seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
}
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove_mid_chain_block(variant: &str) {
// TODO: positional indexer has no parent-child structure, so mid-chain removal
// doesn't invalidate later positions — jump search skips over the gap and over-counts.
if variant == "flat" {
return;
}
// Query [1, 2, 3, 4, 5] — block 3 is back but 4 & 5 were orphaned, so score = 3
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
let index = make_indexer(variant);
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove_nonexistent_worker(variant: &str) {
let index = make_indexer(variant);
// Store [1, 2, 3, 4, 5]
index
.apply_event(make_store_event(0, &[1, 2, 3, 4, 5]))
.await;
// Store data for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
flush_and_settle(index.as_ref()).await;
flush_and_settle(index.as_ref()).await;
// Verify all 5 blocks match
let seq: Vec<LocalBlockHash> = (1..=5).map(LocalBlockHash).collect();
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 5);
// Remove non-existent worker 999 - should not error or affect worker 0
index.remove_worker(999).await;
// Remove only block 3 (index 2) — the middle of the chain
let full_hashes: Vec<LocalBlockHash> = (1..=5).map(LocalBlockHash).collect();
let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]);
// Allow time for async processing
flush_and_settle(index.as_ref()).await;
let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
index.apply_event(remove_event).await;
// Worker 0's data should still be there
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
}
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove_nonexistent_blocks(variant: &str) {
let index = make_indexer(variant);
// Query [1, 2, 3, 4, 5] — only first 2 positions reachable (block 3 removed, orphaning 4 & 5)
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
// Store [1, 2, 3]
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Query [1, 2] — prefix before the gap is still intact
let prefix_seq: Vec<LocalBlockHash> = (1..=2).map(LocalBlockHash).collect();
let scores = index.find_matches(prefix_seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
// Try to remove blocks [999, 998] that don't exist - should not error
index.apply_event(make_remove_event(0, &[999, 998])).await;
// Re-store block 3 as a continuation of [1, 2]
index
.apply_event(make_store_event_with_parent(0, &[1, 2], &[3]))
.await;
flush_and_settle(index.as_ref()).await;
flush_and_settle(index.as_ref()).await;
// Original data should still be there
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
// Query [1, 2, 3, 4, 5] — block 3 is back but 4 & 5 were orphaned, so score = 3
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_clear_then_reuse(variant: &str) {
let index = make_indexer(variant);
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove_nonexistent_worker(variant: &str) {
let index = make_indexer(variant);
// Store initial data
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Store data for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Clear the worker
index.apply_event(make_clear_event(0)).await;
flush_and_settle(index.as_ref()).await;
flush_and_settle(index.as_ref()).await;
// Remove non-existent worker 999 - should not error or affect worker 0
index.remove_worker(999).await;
// Verify data is gone
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq.clone()).await.unwrap();
assert!(scores.scores.is_empty());
// Allow time for async processing
flush_and_settle(index.as_ref()).await;
// Store new data for the same worker
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Worker 0's data should still be there
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
}
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_remove_nonexistent_blocks(variant: &str) {
let index = make_indexer(variant);
// Verify new data is accessible
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
// Store [1, 2, 3]
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
/// A single worker holding two disjoint sequences must match each sequence
/// in full, while a query that splices the two together only matches the
/// first block.
#[tokio::test]
#[apply(indexer_template)]
async fn test_multiple_sequences_per_worker(variant: &str) {
    let indexer = make_indexer(variant);
    let worker = WorkerWithDpRank::new(0, 0);

    // Two unrelated sequences for worker 0: [1, 2, 3] and [100, 101, 102]
    // (the second is stored without a parent).
    let first_store = make_store_event(0, &[1, 2, 3]);
    indexer.apply_event(first_store).await;
    let second_store = make_store_event(0, &[100, 101, 102]);
    indexer.apply_event(second_store).await;

    flush_and_settle(indexer.as_ref()).await;

    // The first sequence matches all three blocks.
    let first_query: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
    let first_result = indexer.find_matches(first_query).await.unwrap();
    assert_eq!(*first_result.scores.get(&worker).unwrap(), 3);

    // The second sequence matches all three blocks as well.
    let second_query: Vec<LocalBlockHash> = (100..=102).map(LocalBlockHash).collect();
    let second_result = indexer.find_matches(second_query).await.unwrap();
    assert_eq!(*second_result.scores.get(&worker).unwrap(), 3);

    // [1, 100] was never stored as a sequence, so only block 1 counts:
    // it is not a valid prefix of either stored sequence beyond that.
    let spliced_query: Vec<LocalBlockHash> = vec![LocalBlockHash(1), LocalBlockHash(100)];
    let spliced_result = indexer.find_matches(spliced_query).await.unwrap();
    assert_eq!(*spliced_result.scores.get(&worker).unwrap(), 1);
}
// Try to remove blocks [999, 998] that don't exist - should not error
index.apply_event(make_remove_event(0, &[999, 998])).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_clear_clears_all_dp_ranks(variant: &str) {
let index = make_indexer(variant);
flush_and_settle(index.as_ref()).await;
// Store same sequence for different dp_ranks
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 0))
.await;
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 1))
.await;
// Original data should still be there
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_clear_then_reuse(variant: &str) {
let index = make_indexer(variant);
// Verify both dp_ranks are present
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(scores.scores.len(), 2);
// Store initial data
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Clear event clears ALL blocks for the worker_id, regardless of dp_rank
index.apply_event(make_clear_event_with_dp_rank(0, 0)).await;
// Clear the worker
index.apply_event(make_clear_event(0)).await;
flush_and_settle(index.as_ref()).await;
flush_and_settle(index.as_ref()).await;
// Both dp_ranks should be cleared
let scores = index.find_matches(seq).await.unwrap();
assert!(
scores.scores.is_empty(),
"Cleared event should clear all dp_ranks for a worker"
);
}
// Verify data is gone
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq.clone()).await.unwrap();
assert!(scores.scores.is_empty());
// ============================================================================
// LoRA isolation tests
// ============================================================================
// Store new data for the same worker
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_lora_and_base_model_blocks_do_not_conflict(variant: &str) {
let index = make_indexer(variant);
let kv_block_size: u32 = 32;
flush_and_settle(index.as_ref()).await;
// Same token sequence for both base model and LoRA adapter
let tokens: Vec<u32> = (0..kv_block_size * 3).collect();
// Verify new data is accessible
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
let base_hashes = compute_block_hash_for_seq(&tokens, kv_block_size, None, None);
let lora_hashes = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter"));
#[tokio::test]
#[apply(indexer_template)]
async fn test_multiple_sequences_per_worker(variant: &str) {
let index = make_indexer(variant);
// Hashes must differ despite identical tokens
assert_ne!(
base_hashes, lora_hashes,
"Base and LoRA hashes must differ for the same tokens"
);
// Store two disjoint sequences for the same worker
// Sequence 1: [1, 2, 3]
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Sequence 2: [100, 101, 102] (completely different, no parent)
index
.apply_event(make_store_event(0, &[100, 101, 102]))
.await;
let base_seq = compute_seq_hash_for_block(&base_hashes);
let lora_seq = compute_seq_hash_for_block(&lora_hashes);
flush_and_settle(index.as_ref()).await;
// Store base-model blocks on worker 0
let base_event = router_event(
0,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&base_hashes, &base_seq),
}),
);
index.apply_event(base_event).await;
// Query first sequence
let seq1: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq1).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
// Store LoRA blocks on worker 1
let lora_event = router_event(
1,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&lora_hashes, &lora_seq),
}),
);
index.apply_event(lora_event).await;
// Query second sequence
let seq2: Vec<LocalBlockHash> = (100..=102).map(LocalBlockHash).collect();
let scores = index.find_matches(seq2).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
flush_and_settle(index.as_ref()).await;
// Query a mix that doesn't exist as a sequence - should only match first block
let mixed: Vec<LocalBlockHash> = vec![LocalBlockHash(1), LocalBlockHash(100)];
let scores = index.find_matches(mixed).await.unwrap();
// Only block 1 matches because [1, 100] is not a valid prefix
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 1);
}
// Query with base-model hashes → only worker 0
let base_scores = index.find_matches(base_hashes.clone()).await.unwrap();
assert_eq!(
base_scores.scores.len(),
1,
"Only base-model worker should match"
);
assert_eq!(
*base_scores
.scores
.get(&WorkerWithDpRank::new(0, 0))
.unwrap(),
3
);
#[tokio::test]
#[apply(indexer_template)]
async fn test_clear_clears_all_dp_ranks(variant: &str) {
let index = make_indexer(variant);
// Query with LoRA hashes → only worker 1
let lora_scores = index.find_matches(lora_hashes.clone()).await.unwrap();
assert_eq!(lora_scores.scores.len(), 1, "Only LoRA worker should match");
assert_eq!(
*lora_scores
.scores
.get(&WorkerWithDpRank::new(1, 0))
.unwrap(),
3
);
}
// Store same sequence for different dp_ranks
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 0))
.await;
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 1))
.await;
/// Reproduces the "block_hash mismatch: sequence hashes should be uniform
/// across workers" warning seen when the same prompt is sent to both a base
/// model worker and a LoRA worker.
///
/// On main (without LoRA-aware hashing), both workers compute the same
/// LocalBlockHash for identical tokens. But vLLM's engine includes the
/// adapter in its rolling ExternalSequenceBlockHash, so the radix tree
/// sees conflicting sequence hashes at the same tree node.
///
/// With LoRA-aware hashing, compute_block_hash_for_seq produces distinct
/// LocalBlockHash values for different adapters, so the blocks land on
/// separate tree paths and no mismatch occurs.
#[tokio::test]
#[apply(indexer_template)]
async fn test_lora_base_same_tokens_no_seq_hash_mismatch(variant: &str) {
let index = make_indexer(variant);
let kv_block_size: u32 = 32;
let tokens: Vec<u32> = (0..kv_block_size * 3).collect();
// With LoRA-aware hashing, base and adapter produce different LocalBlockHash
let base_local = compute_block_hash_for_seq(&tokens, kv_block_size, None, None);
let lora_local = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter"));
assert_ne!(
base_local, lora_local,
"LoRA-aware hashing must produce different LocalBlockHash values"
);
flush_and_settle(index.as_ref()).await;
// Simulate what vLLM does: same tokens, different rolling seq hashes
// because the engine accounts for the adapter internally.
let base_seq = compute_seq_hash_for_block(&base_local);
let lora_seq = compute_seq_hash_for_block(&lora_local);
// Verify both dp_ranks are present
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(scores.scores.len(), 2);
// Worker 0: base model
index
.apply_event(router_event(
0,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&base_local, &base_seq),
}),
))
.await;
// Clear event clears ALL blocks for the worker_id, regardless of dp_rank
index.apply_event(make_clear_event_with_dp_rank(0, 0)).await;
// Worker 1: LoRA adapter — different LocalBlockHash, so this goes to
// a separate tree path instead of colliding with worker 0's node.
index
.apply_event(router_event(
1,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&lora_local, &lora_seq),
}),
))
.await;
flush_and_settle(index.as_ref()).await;
// Base query finds only worker 0
let base_scores = index.find_matches(base_local.clone()).await.unwrap();
assert_eq!(base_scores.scores.len(), 1);
assert_eq!(
*base_scores
.scores
.get(&WorkerWithDpRank::new(0, 0))
.unwrap(),
3
);
flush_and_settle(index.as_ref()).await;
// LoRA query finds only worker 1
let lora_scores = index.find_matches(lora_local.clone()).await.unwrap();
assert_eq!(lora_scores.scores.len(), 1);
assert_eq!(
*lora_scores
.scores
.get(&WorkerWithDpRank::new(1, 0))
.unwrap(),
3
);
// Both dp_ranks should be cleared
let scores = index.find_matches(seq).await.unwrap();
assert!(
scores.scores.is_empty(),
"Cleared event should clear all dp_ranks for a worker"
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_different_lora_adapters_do_not_conflict(variant: &str) {
let index = make_indexer(variant);
let kv_block_size: u32 = 32;
// ============================================================================
// LoRA isolation tests
// ============================================================================
let tokens: Vec<u32> = (0..kv_block_size * 2).collect();
mod lora_tests {
use super::*;
use rstest_reuse::apply;
let hashes_a = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-a"));
let hashes_b = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-b"));
#[tokio::test]
#[apply(indexer_template)]
async fn test_lora_and_base_model_blocks_do_not_conflict(variant: &str) {
let index = make_indexer(variant);
let kv_block_size: u32 = 32;
assert_ne!(
hashes_a, hashes_b,
"Different adapters must produce different hashes"
);
// Same token sequence for both base model and LoRA adapter
let tokens: Vec<u32> = (0..kv_block_size * 3).collect();
let seq_a = compute_seq_hash_for_block(&hashes_a);
let seq_b = compute_seq_hash_for_block(&hashes_b);
let base_hashes = compute_block_hash_for_seq(&tokens, kv_block_size, None, None);
let lora_hashes =
compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter"));
// Store adapter-a blocks on worker 0
index
.apply_event(router_event(
// Hashes must differ despite identical tokens
assert_ne!(
base_hashes, lora_hashes,
"Base and LoRA hashes must differ for the same tokens"
);
let base_seq = compute_seq_hash_for_block(&base_hashes);
let lora_seq = compute_seq_hash_for_block(&lora_hashes);
// Store base-model blocks on worker 0
let base_event = router_event(
0,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&hashes_a, &seq_a),
blocks: stored_blocks_with_sequence_hashes(&base_hashes, &base_seq),
}),
))
.await;
);
index.apply_event(base_event).await;
// Store adapter-b blocks on worker 1
index
.apply_event(router_event(
// Store LoRA blocks on worker 1
let lora_event = router_event(
1,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&hashes_b, &seq_b),
blocks: stored_blocks_with_sequence_hashes(&lora_hashes, &lora_seq),
}),
))
.await;
flush_and_settle(index.as_ref()).await;
// Query adapter-a → only worker 0
let scores_a = index.find_matches(hashes_a.clone()).await.unwrap();
assert_eq!(scores_a.scores.len(), 1);
assert!(scores_a.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
assert!(!scores_a.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
// Query adapter-b → only worker 1
let scores_b = index.find_matches(hashes_b.clone()).await.unwrap();
assert_eq!(scores_b.scores.len(), 1);
assert!(scores_b.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
assert!(!scores_b.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
}
// ============================================================================
// Long sequence tests - especially important for NestedMap/PositionalIndexer
// ============================================================================
);
index.apply_event(lora_event).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_single_store(variant: &str) {
let index = make_indexer(variant);
// Store a long sequence (128 blocks) in a single event
let seq_len = 128;
let sequence: Vec<u64> = (1..=seq_len).collect();
index.apply_event(make_store_event(0, &sequence)).await;
flush_and_settle(index.as_ref()).await;
// Query full sequence - should match all blocks
let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
seq_len as u32
);
flush_and_settle(index.as_ref()).await;
// Query prefix (first 64 blocks)
let prefix_query: Vec<LocalBlockHash> = (1..=64).map(LocalBlockHash).collect();
let scores = index.find_matches(prefix_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
64
);
// Query with base-model hashes → only worker 0
let base_scores = index.find_matches(base_hashes.clone()).await.unwrap();
assert_eq!(
base_scores.scores.len(),
1,
"Only base-model worker should match"
);
assert_eq!(
*base_scores
.scores
.get(&WorkerWithDpRank::new(0, 0))
.unwrap(),
3
);
// Query with divergence at position 50
let mut divergent_query: Vec<LocalBlockHash> = (1..=100).map(LocalBlockHash).collect();
divergent_query[49] = LocalBlockHash(99999); // Position 49 (0-indexed) diverges
let scores = index.find_matches(divergent_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
49
);
}
// Query with LoRA hashes → only worker 1
let lora_scores = index.find_matches(lora_hashes.clone()).await.unwrap();
assert_eq!(lora_scores.scores.len(), 1, "Only LoRA worker should match");
assert_eq!(
*lora_scores
.scores
.get(&WorkerWithDpRank::new(1, 0))
.unwrap(),
3
);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_multiple_continuations(variant: &str) {
let index = make_indexer(variant);
// Build a long sequence through multiple continuations
// First store: blocks 1-50
let first_chunk: Vec<u64> = (1..=50).collect();
index.apply_event(make_store_event(0, &first_chunk)).await;
// Second store: blocks 51-100 (continuation of first)
let second_chunk: Vec<u64> = (51..=100).collect();
index
.apply_event(make_store_event_with_parent(0, &first_chunk, &second_chunk))
.await;
// Third store: blocks 101-150 (continuation of second)
let prefix_1_2: Vec<u64> = (1..=100).collect();
let third_chunk: Vec<u64> = (101..=150).collect();
index
.apply_event(make_store_event_with_parent(0, &prefix_1_2, &third_chunk))
.await;
flush_and_settle(index.as_ref()).await;
// Query full sequence - should match all 150 blocks
let full_query: Vec<LocalBlockHash> = (1..=150).map(LocalBlockHash).collect();
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
150
);
/// Reproduces the "block_hash mismatch: sequence hashes should be uniform
/// across workers" warning seen when the same prompt is sent to both a base
/// model worker and a LoRA worker.
///
/// On main (without LoRA-aware hashing), both workers compute the same
/// LocalBlockHash for identical tokens. But vLLM's engine includes the
/// adapter in its rolling ExternalSequenceBlockHash, so the radix tree
/// sees conflicting sequence hashes at the same tree node.
///
/// With LoRA-aware hashing, compute_block_hash_for_seq produces distinct
/// LocalBlockHash values for different adapters, so the blocks land on
/// separate tree paths and no mismatch occurs.
#[tokio::test]
#[apply(indexer_template)]
async fn test_lora_base_same_tokens_no_seq_hash_mismatch(variant: &str) {
let index = make_indexer(variant);
let kv_block_size: u32 = 32;
let tokens: Vec<u32> = (0..kv_block_size * 3).collect();
// With LoRA-aware hashing, base and adapter produce different LocalBlockHash
let base_local = compute_block_hash_for_seq(&tokens, kv_block_size, None, None);
let lora_local =
compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter"));
assert_ne!(
base_local, lora_local,
"LoRA-aware hashing must produce different LocalBlockHash values"
);
// Query crossing continuation boundaries
let cross_boundary_query: Vec<LocalBlockHash> = (45..=105).map(LocalBlockHash).collect();
let scores = index.find_matches(cross_boundary_query).await.unwrap();
// Query starts at block 45, but stored sequence starts at 1, so this won't match
// because the sequence hash at position 0 of our query (block 45) won't match
// the stored sequence hash at position 0 (block 1)
assert!(scores.scores.is_empty() || !scores.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
}
// Simulate what vLLM does: same tokens, different rolling seq hashes
// because the engine accounts for the adapter internally.
let base_seq = compute_seq_hash_for_block(&base_local);
let lora_seq = compute_seq_hash_for_block(&lora_local);
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_branching_continuations(variant: &str) {
let index = make_indexer(variant);
// Common prefix: blocks 1-30
let common_prefix: Vec<u64> = (1..=30).collect();
index.apply_event(make_store_event(0, &common_prefix)).await;
// Branch A: blocks 31-60 on worker 0
let branch_a: Vec<u64> = (31..=60).collect();
index
.apply_event(make_store_event_with_parent(0, &common_prefix, &branch_a))
.await;
// Branch B: blocks 131-160 (different content) on worker 1
// First store the common prefix for worker 1
index.apply_event(make_store_event(1, &common_prefix)).await;
let branch_b: Vec<u64> = (131..=160).collect();
index
.apply_event(make_store_event_with_parent(1, &common_prefix, &branch_b))
.await;
flush_and_settle(index.as_ref()).await;
// Query common prefix - both workers should match
let prefix_query: Vec<LocalBlockHash> = (1..=30).map(LocalBlockHash).collect();
let scores = index.find_matches(prefix_query).await.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
30
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
30
);
// Worker 0: base model
index
.apply_event(router_event(
0,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&base_local, &base_seq),
}),
))
.await;
// Query branch A path - only worker 0 should match fully
let branch_a_query: Vec<LocalBlockHash> = (1..=60).map(LocalBlockHash).collect();
let scores = index.find_matches(branch_a_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
60
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
30
);
}
// Worker 1: LoRA adapter — different LocalBlockHash, so this goes to
// a separate tree path instead of colliding with worker 0's node.
index
.apply_event(router_event(
1,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&lora_local, &lora_seq),
}),
))
.await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_partial_removal(variant: &str) {
let index = make_indexer(variant);
flush_and_settle(index.as_ref()).await;
// Store a long sequence
let sequence: Vec<u64> = (1..=100).collect();
index.apply_event(make_store_event(0, &sequence)).await;
// Base query finds only worker 0
let base_scores = index.find_matches(base_local.clone()).await.unwrap();
assert_eq!(base_scores.scores.len(), 1);
assert_eq!(
*base_scores
.scores
.get(&WorkerWithDpRank::new(0, 0))
.unwrap(),
3
);
flush_and_settle(index.as_ref()).await;
// LoRA query finds only worker 1
let lora_scores = index.find_matches(lora_local.clone()).await.unwrap();
assert_eq!(lora_scores.scores.len(), 1);
assert_eq!(
*lora_scores
.scores
.get(&WorkerWithDpRank::new(1, 0))
.unwrap(),
3
);
}
// Verify full match
let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(full_query.clone()).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
#[tokio::test]
#[apply(indexer_template)]
async fn test_different_lora_adapters_do_not_conflict(variant: &str) {
let index = make_indexer(variant);
let kv_block_size: u32 = 32;
// Remove blocks 80-100 (the tail)
let tail_hashes: Vec<LocalBlockHash> = (1..=100).map(LocalBlockHash).collect();
let seq_hashes = compute_seq_hash_for_block(&tail_hashes);
let remove_hashes: Vec<ExternalSequenceBlockHash> = seq_hashes[79..100]
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect();
let tokens: Vec<u32> = (0..kv_block_size * 2).collect();
let remove_event = remove_event(0, 0, 0, remove_hashes);
index.apply_event(remove_event).await;
let hashes_a = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-a"));
let hashes_b = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-b"));
flush_and_settle(index.as_ref()).await;
assert_ne!(
hashes_a, hashes_b,
"Different adapters must produce different hashes"
);
// Query should now only match first 79 blocks
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
79
);
}
let seq_a = compute_seq_hash_for_block(&hashes_a);
let seq_b = compute_seq_hash_for_block(&hashes_b);
/// Four workers cache nested prefixes of the same sequence (100, 75, 50 and 25
/// blocks). A 60-block query must report, per worker, the shorter of the query
/// length and that worker's stored prefix.
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_interleaved_workers(variant: &str) {
    let index = make_indexer(variant);

    // (worker id, stored length): each worker stores blocks 1..=stored_len.
    for (worker_id, stored_len) in [(0, 100), (1, 75), (2, 50), (3, 25)] {
        let blocks: Vec<u64> = (1..=stored_len).collect();
        index.apply_event(make_store_event(worker_id, &blocks)).await;
    }

    flush_and_settle(index.as_ref()).await;

    let probe: Vec<LocalBlockHash> = (1..=60).map(LocalBlockHash).collect();
    let scores = index.find_matches(probe).await.unwrap();

    // Matched-block count reported for a worker (dp rank 0); panics if absent.
    let matched = |worker_id| {
        *scores
            .scores
            .get(&WorkerWithDpRank::new(worker_id, 0))
            .unwrap()
    };

    assert_eq!(scores.scores.len(), 4);
    // Workers 0 and 1 hold at least 60 blocks, so the query length is the cap.
    assert_eq!(matched(0), 60);
    assert_eq!(matched(1), 60);
    // Workers 2 and 3 run out of stored blocks before the query ends.
    assert_eq!(matched(2), 50);
    assert_eq!(matched(3), 25);
}
// Store adapter-a blocks on worker 0
index
.apply_event(router_event(
0,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&hashes_a, &seq_a),
}),
))
.await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_exact_jump_size_boundaries(variant: &str) {
let index = make_indexer(variant);
// Store adapter-b blocks on worker 1
index
.apply_event(router_event(
1,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&hashes_b, &seq_b),
}),
))
.await;
// Test sequences that align exactly with jump_size boundaries (32 for PositionalIndexer)
// This tests edge cases in the jump search algorithm
flush_and_settle(index.as_ref()).await;
// Store sequence of exactly 32 blocks
let seq_32: Vec<u64> = (1..=32).collect();
index.apply_event(make_store_event(0, &seq_32)).await;
// Query adapter-a → only worker 0
let scores_a = index.find_matches(hashes_a.clone()).await.unwrap();
assert_eq!(scores_a.scores.len(), 1);
assert!(scores_a.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
assert!(!scores_a.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
// Store sequence of exactly 64 blocks (2x jump_size)
let seq_64: Vec<u64> = (1001..=1064).collect();
index.apply_event(make_store_event(1, &seq_64)).await;
// Query adapter-b → only worker 1
let scores_b = index.find_matches(hashes_b.clone()).await.unwrap();
assert_eq!(scores_b.scores.len(), 1);
assert!(scores_b.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
assert!(!scores_b.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
}
}
// Store sequence of exactly 96 blocks (3x jump_size)
let seq_96: Vec<u64> = (2001..=2096).collect();
index.apply_event(make_store_event(2, &seq_96)).await;
// ============================================================================
// Long sequence tests - especially important for NestedMap/PositionalIndexer
// ============================================================================
flush_and_settle(index.as_ref()).await;
mod long_sequence_tests {
use super::*;
use rstest_reuse::apply;
// Verify all sequences match correctly
let query_32: Vec<LocalBlockHash> = seq_32.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_32).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
32
);
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_single_store(variant: &str) {
let index = make_indexer(variant);
let query_64: Vec<LocalBlockHash> = seq_64.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_64).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
64
);
// Store a long sequence (128 blocks) in a single event
let seq_len = 128;
let sequence: Vec<u64> = (1..=seq_len).collect();
index.apply_event(make_store_event(0, &sequence)).await;
let query_96: Vec<LocalBlockHash> = seq_96.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_96).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
96
);
}
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_off_by_one_jump_boundaries(variant: &str) {
let index = make_indexer(variant);
// Test sequences at jump_size +/- 1 boundaries to catch off-by-one errors
let seq_31: Vec<u64> = (1..=31).collect();
let seq_33: Vec<u64> = (101..=133).collect();
let seq_63: Vec<u64> = (201..=263).collect();
let seq_65: Vec<u64> = (301..=365).collect();
index.apply_event(make_store_event(0, &seq_31)).await;
index.apply_event(make_store_event(1, &seq_33)).await;
index.apply_event(make_store_event(2, &seq_63)).await;
index.apply_event(make_store_event(3, &seq_65)).await;
flush_and_settle(index.as_ref()).await;
// Verify all sequences match correctly
let query_31: Vec<LocalBlockHash> = seq_31.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_31).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
31
);
// Query full sequence - should match all blocks
let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
seq_len as u32
);
let query_33: Vec<LocalBlockHash> = seq_33.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_33).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
33
);
// Query prefix (first 64 blocks)
let prefix_query: Vec<LocalBlockHash> = (1..=64).map(LocalBlockHash).collect();
let scores = index.find_matches(prefix_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
64
);
let query_63: Vec<LocalBlockHash> = seq_63.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_63).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
63
);
// Query with divergence at position 50
let mut divergent_query: Vec<LocalBlockHash> = (1..=100).map(LocalBlockHash).collect();
divergent_query[49] = LocalBlockHash(99999); // Position 49 (0-indexed) diverges
let scores = index.find_matches(divergent_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
49
);
}
let query_65: Vec<LocalBlockHash> = seq_65.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_65).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(3, 0)).unwrap(),
65
);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_multiple_continuations(variant: &str) {
let index = make_indexer(variant);
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_divergence_at_jump_boundaries(variant: &str) {
let index = make_indexer(variant);
// Build a long sequence through multiple continuations
// First store: blocks 1-50
let first_chunk: Vec<u64> = (1..=50).collect();
index.apply_event(make_store_event(0, &first_chunk)).await;
// Store a long sequence
let sequence: Vec<u64> = (1..=128).collect();
index.apply_event(make_store_event(0, &sequence)).await;
// Second store: blocks 51-100 (continuation of first)
let second_chunk: Vec<u64> = (51..=100).collect();
index
.apply_event(make_store_event_with_parent(0, &first_chunk, &second_chunk))
.await;
flush_and_settle(index.as_ref()).await;
// Third store: blocks 101-150 (continuation of second)
let prefix_1_2: Vec<u64> = (1..=100).collect();
let third_chunk: Vec<u64> = (101..=150).collect();
index
.apply_event(make_store_event_with_parent(0, &prefix_1_2, &third_chunk))
.await;
// Test divergence exactly at jump boundaries (position 31, 32, 33, 63, 64, 65)
for diverge_pos in [31usize, 32, 33, 63, 64, 65, 95, 96, 97] {
let mut query: Vec<LocalBlockHash> = (1..=128).map(LocalBlockHash).collect();
query[diverge_pos] = LocalBlockHash(99999);
flush_and_settle(index.as_ref()).await;
let scores = index.find_matches(query).await.unwrap();
// Query full sequence - should match all 150 blocks
let full_query: Vec<LocalBlockHash> = (1..=150).map(LocalBlockHash).collect();
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
diverge_pos as u32,
"Divergence at position {} should match {} blocks",
diverge_pos,
diverge_pos
150
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_deep_continuation_chain(variant: &str) {
let index = make_indexer(variant);
// Query crossing continuation boundaries
let cross_boundary_query: Vec<LocalBlockHash> = (45..=105).map(LocalBlockHash).collect();
let scores = index.find_matches(cross_boundary_query).await.unwrap();
// Query starts at block 45, but stored sequence starts at 1, so this won't match
// because the sequence hash at position 0 of our query (block 45) won't match
// the stored sequence hash at position 0 (block 1)
assert!(
scores.scores.is_empty() || !scores.scores.contains_key(&WorkerWithDpRank::new(0, 0))
);
}
// Build a very long sequence through many small continuations
// This tests the parent_hash chain handling
let chunk_size = 10;
let num_chunks = 20; // Total 200 blocks
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_branching_continuations(variant: &str) {
let index = make_indexer(variant);
let mut full_prefix: Vec<u64> = Vec::new();
// Common prefix: blocks 1-30
let common_prefix: Vec<u64> = (1..=30).collect();
index.apply_event(make_store_event(0, &common_prefix)).await;
for chunk_idx in 0..num_chunks {
let chunk_start = chunk_idx * chunk_size + 1;
let chunk: Vec<u64> = (chunk_start..chunk_start + chunk_size)
.map(|x| x as u64)
.collect();
// Branch A: blocks 31-60 on worker 0
let branch_a: Vec<u64> = (31..=60).collect();
index
.apply_event(make_store_event_with_parent(0, &common_prefix, &branch_a))
.await;
if chunk_idx == 0 {
index.apply_event(make_store_event(0, &chunk)).await;
} else {
index
.apply_event(make_store_event_with_parent(0, &full_prefix, &chunk))
.await;
}
// Branch B: blocks 131-160 (different content) on worker 1
// First store the common prefix for worker 1
index.apply_event(make_store_event(1, &common_prefix)).await;
let branch_b: Vec<u64> = (131..=160).collect();
index
.apply_event(make_store_event_with_parent(1, &common_prefix, &branch_b))
.await;
flush_and_settle(index.as_ref()).await;
full_prefix.extend(&chunk);
// Query common prefix - both workers should match
let prefix_query: Vec<LocalBlockHash> = (1..=30).map(LocalBlockHash).collect();
let scores = index.find_matches(prefix_query).await.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
30
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
30
);
// Query branch A path - only worker 0 should match fully
let branch_a_query: Vec<LocalBlockHash> = (1..=60).map(LocalBlockHash).collect();
let scores = index.find_matches(branch_a_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
60
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
30
);
}
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_partial_removal(variant: &str) {
let index = make_indexer(variant);
// Query full sequence
let full_query: Vec<LocalBlockHash> = (1..=200).map(LocalBlockHash).collect();
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
200
);
// Store a long sequence
let sequence: Vec<u64> = (1..=100).collect();
index.apply_event(make_store_event(0, &sequence)).await;
// Query partial prefix crossing multiple chunk boundaries
let partial_query: Vec<LocalBlockHash> = (1..=75).map(LocalBlockHash).collect();
let scores = index.find_matches(partial_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
75
);
}
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_clear_and_rebuild(variant: &str) {
let index = make_indexer(variant);
// Verify full match
let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(full_query.clone()).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
// Store a long sequence
let sequence: Vec<u64> = (1..=100).collect();
index.apply_event(make_store_event(0, &sequence)).await;
// Remove blocks 80-100 (the tail)
let tail_hashes: Vec<LocalBlockHash> = (1..=100).map(LocalBlockHash).collect();
let seq_hashes = compute_seq_hash_for_block(&tail_hashes);
let remove_hashes: Vec<ExternalSequenceBlockHash> = seq_hashes[79..100]
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect();
flush_and_settle(index.as_ref()).await;
let remove_event = remove_event(0, 0, 0, remove_hashes);
index.apply_event(remove_event).await;
// Verify it's stored
let query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query.clone()).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
flush_and_settle(index.as_ref()).await;
// Clear the worker
index.apply_event(make_clear_event(0)).await;
// Query should now only match first 79 blocks
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
79
);
}
flush_and_settle(index.as_ref()).await;
/// Four workers cache nested prefixes of the same sequence (100, 75, 50 and 25
/// blocks). A 60-block query must report, per worker, the shorter of the query
/// length and that worker's stored prefix.
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_interleaved_workers(variant: &str) {
    let index = make_indexer(variant);

    // (worker id, stored length): each worker stores blocks 1..=stored_len.
    for (worker_id, stored_len) in [(0, 100), (1, 75), (2, 50), (3, 25)] {
        let blocks: Vec<u64> = (1..=stored_len).collect();
        index.apply_event(make_store_event(worker_id, &blocks)).await;
    }

    flush_and_settle(index.as_ref()).await;

    let probe: Vec<LocalBlockHash> = (1..=60).map(LocalBlockHash).collect();
    let scores = index.find_matches(probe).await.unwrap();

    // Matched-block count reported for a worker (dp rank 0); panics if absent.
    let matched = |worker_id| {
        *scores
            .scores
            .get(&WorkerWithDpRank::new(worker_id, 0))
            .unwrap()
    };

    assert_eq!(scores.scores.len(), 4);
    // Workers 0 and 1 hold at least 60 blocks, so the query length is the cap.
    assert_eq!(matched(0), 60);
    assert_eq!(matched(1), 60);
    // Workers 2 and 3 run out of stored blocks before the query ends.
    assert_eq!(matched(2), 50);
    assert_eq!(matched(3), 25);
}
// Verify it's cleared
let scores = index.find_matches(query.clone()).await.unwrap();
assert!(scores.scores.is_empty());
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_exact_jump_size_boundaries(variant: &str) {
let index = make_indexer(variant);
// Rebuild with a different sequence
let new_sequence: Vec<u64> = (1001..=1100).collect();
index.apply_event(make_store_event(0, &new_sequence)).await;
// Test sequences that align exactly with jump_size boundaries (32 for PositionalIndexer)
// This tests edge cases in the jump search algorithm
flush_and_settle(index.as_ref()).await;
// Store sequence of exactly 32 blocks
let seq_32: Vec<u64> = (1..=32).collect();
index.apply_event(make_store_event(0, &seq_32)).await;
// Verify new sequence works
let new_query: Vec<LocalBlockHash> = new_sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(new_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
// Store sequence of exactly 64 blocks (2x jump_size)
let seq_64: Vec<u64> = (1001..=1064).collect();
index.apply_event(make_store_event(1, &seq_64)).await;
// Verify old sequence no longer matches
let scores = index.find_matches(query).await.unwrap();
assert!(scores.scores.is_empty());
}
// Store sequence of exactly 96 blocks (3x jump_size)
let seq_96: Vec<u64> = (2001..=2096).collect();
index.apply_event(make_store_event(2, &seq_96)).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_multiple_workers_diverging(variant: &str) {
let index = make_indexer(variant);
flush_and_settle(index.as_ref()).await;
// Multiple workers with long sequences that share a prefix then diverge
// This tests precise drain point tracking across workers
// Verify all sequences match correctly
let query_32: Vec<LocalBlockHash> = seq_32.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_32).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
32
);
// All workers share prefix 1-40
let shared_prefix: Vec<u64> = (1..=40).collect();
let query_64: Vec<LocalBlockHash> = seq_64.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_64).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
64
);
// Worker 0: prefix + 41-100 (stores full sequence 1-100)
let worker_0_full: Vec<u64> = (1..=100).collect();
let query_96: Vec<LocalBlockHash> = seq_96.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_96).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
96
);
}
// Worker 1: prefix + 141-180 (diverges at block 41)
let worker_1_suffix: Vec<u64> = (141..=180).collect();
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_off_by_one_jump_boundaries(variant: &str) {
let index = make_indexer(variant);
// Worker 2: prefix + 241-300 (diverges at block 41)
let worker_2_suffix: Vec<u64> = (241..=300).collect();
// Test sequences at jump_size +/- 1 boundaries to catch off-by-one errors
let seq_31: Vec<u64> = (1..=31).collect();
let seq_33: Vec<u64> = (101..=133).collect();
let seq_63: Vec<u64> = (201..=263).collect();
let seq_65: Vec<u64> = (301..=365).collect();
// Store for all workers
index.apply_event(make_store_event(0, &worker_0_full)).await;
index.apply_event(make_store_event(0, &seq_31)).await;
index.apply_event(make_store_event(1, &seq_33)).await;
index.apply_event(make_store_event(2, &seq_63)).await;
index.apply_event(make_store_event(3, &seq_65)).await;
index.apply_event(make_store_event(1, &shared_prefix)).await;
index
.apply_event(make_store_event_with_parent(
1,
&shared_prefix,
&worker_1_suffix,
))
.await;
index.apply_event(make_store_event(2, &shared_prefix)).await;
index
.apply_event(make_store_event_with_parent(
2,
&shared_prefix,
&worker_2_suffix,
))
.await;
flush_and_settle(index.as_ref()).await;
// Query 1-100 - worker 0 matches 100, workers 1&2 match 40
let query: Vec<LocalBlockHash> = worker_0_full.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
40
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
40
);
}
flush_and_settle(index.as_ref()).await;
// Verify all sequences match correctly
let query_31: Vec<LocalBlockHash> = seq_31.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_31).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
31
);
let query_33: Vec<LocalBlockHash> = seq_33.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_33).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
33
);
let query_63: Vec<LocalBlockHash> = seq_63.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_63).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
63
);
let query_65: Vec<LocalBlockHash> = seq_65.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_65).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(3, 0)).unwrap(),
65
);
}
/// Stores 128 blocks on worker 0, then issues queries that differ from the
/// stored sequence at exactly one index. The probed indices sit just before,
/// on, and just after each multiple of 32 (31..=33, 63..=65, 95..=97), and the
/// reported match length must equal the index of the altered block.
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_divergence_at_jump_boundaries(variant: &str) {
    let index = make_indexer(variant);

    let stored: Vec<u64> = (1..=128).collect();
    index.apply_event(make_store_event(0, &stored)).await;

    flush_and_settle(index.as_ref()).await;

    // Yields 31, 32, 33, 63, 64, 65, 95, 96, 97 in that order.
    let boundary_positions = (1..=3usize).flat_map(|k| [32 * k - 1, 32 * k, 32 * k + 1]);

    for diverge_pos in boundary_positions {
        // Same blocks as stored, except the one at `diverge_pos` (0-indexed).
        let query: Vec<LocalBlockHash> = (1..=128u64)
            .enumerate()
            .map(|(pos, block)| LocalBlockHash(if pos == diverge_pos { 99999 } else { block }))
            .collect();

        let scores = index.find_matches(query).await.unwrap();
        let matched = *scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap();

        // Blocks 0..diverge_pos still agree, so exactly `diverge_pos` blocks match.
        assert_eq!(
            matched,
            diverge_pos as u32,
            "Divergence at position {} should match {} blocks",
            diverge_pos,
            diverge_pos
        );
    }
}
/// Grows a single sequence on worker 0 to 200 blocks through 20 consecutive
/// 10-block store events. Every event after the first is parented on all
/// blocks stored so far, which exercises parent-hash chaining. Both the full
/// sequence and a 75-block prefix must then match in full.
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_deep_continuation_chain(variant: &str) {
    const CHUNK_SIZE: u64 = 10;
    const NUM_CHUNKS: u64 = 20; // 200 blocks in total

    let index = make_indexer(variant);

    let mut stored_so_far: Vec<u64> = Vec::new();
    for first_block in (0..NUM_CHUNKS).map(|chunk_idx| chunk_idx * CHUNK_SIZE + 1) {
        let chunk: Vec<u64> = (first_block..first_block + CHUNK_SIZE).collect();

        if stored_so_far.is_empty() {
            // Very first chunk: nothing to parent it on.
            index.apply_event(make_store_event(0, &chunk)).await;
        } else {
            index
                .apply_event(make_store_event_with_parent(0, &stored_so_far, &chunk))
                .await;
        }

        stored_so_far.extend_from_slice(&chunk);
    }

    flush_and_settle(index.as_ref()).await;

    // The whole chain should be reachable as one 200-block match.
    let whole_chain: Vec<LocalBlockHash> = (1..=200).map(LocalBlockHash).collect();
    let scores = index.find_matches(whole_chain).await.unwrap();
    let matched = *scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap();
    assert_eq!(matched, 200);

    // A prefix that ends mid-chunk (75 = 7 full chunks + 5 blocks).
    let mid_chunk_prefix: Vec<LocalBlockHash> = (1..=75).map(LocalBlockHash).collect();
    let scores = index.find_matches(mid_chunk_prefix).await.unwrap();
    let matched = *scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap();
    assert_eq!(matched, 75);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_clear_and_rebuild(variant: &str) {
let index = make_indexer(variant);
// Store a long sequence
let sequence: Vec<u64> = (1..=100).collect();
index.apply_event(make_store_event(0, &sequence)).await;
flush_and_settle(index.as_ref()).await;
// Verify it's stored
let query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query.clone()).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
// Clear the worker
index.apply_event(make_clear_event(0)).await;
flush_and_settle(index.as_ref()).await;
// Verify it's cleared
let scores = index.find_matches(query.clone()).await.unwrap();
assert!(scores.scores.is_empty());
// Rebuild with a different sequence
let new_sequence: Vec<u64> = (1001..=1100).collect();
index.apply_event(make_store_event(0, &new_sequence)).await;
flush_and_settle(index.as_ref()).await;
// Verify new sequence works
let new_query: Vec<LocalBlockHash> =
new_sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(new_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
// Verify old sequence no longer matches
let scores = index.find_matches(query).await.unwrap();
assert!(scores.scores.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_multiple_workers_diverging(variant: &str) {
let index = make_indexer(variant);
// Multiple workers with long sequences that share a prefix then diverge
// This tests precise drain point tracking across workers
// All workers share prefix 1-40
let shared_prefix: Vec<u64> = (1..=40).collect();
// Worker 0: prefix + 41-100 (stores full sequence 1-100)
let worker_0_full: Vec<u64> = (1..=100).collect();
// Worker 1: prefix + 141-180 (diverges at block 41)
let worker_1_suffix: Vec<u64> = (141..=180).collect();
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_staggered_lengths(variant: &str) {
let index = make_indexer(variant);
// Worker 2: prefix + 241-300 (diverges at block 41)
let worker_2_suffix: Vec<u64> = (241..=300).collect();
// Workers with sequences of staggered lengths to test drain tracking
// Worker 0: 10 blocks
// Worker 1: 20 blocks
// Worker 2: 35 blocks (just past first jump)
// Worker 3: 64 blocks (exactly 2 jumps)
// Worker 4: 100 blocks
// Store for all workers
index.apply_event(make_store_event(0, &worker_0_full)).await;
for (worker_id, len) in [(0, 10), (1, 20), (2, 35), (3, 64), (4, 100)] {
let sequence: Vec<u64> = (1..=len).collect();
index.apply_event(make_store_event(1, &shared_prefix)).await;
index
.apply_event(make_store_event(worker_id, &sequence))
.apply_event(make_store_event_with_parent(
1,
&shared_prefix,
&worker_1_suffix,
))
.await;
index.apply_event(make_store_event(2, &shared_prefix)).await;
index
.apply_event(make_store_event_with_parent(
2,
&shared_prefix,
&worker_2_suffix,
))
.await;
flush_and_settle(index.as_ref()).await;
// Query 1-100 - worker 0 matches 100, workers 1&2 match 40
let query: Vec<LocalBlockHash> = worker_0_full.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
40
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
40
);
}
flush_and_settle(index.as_ref()).await;
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_staggered_lengths(variant: &str) {
let index = make_indexer(variant);
// Query for 100 blocks - each worker should match their stored length
let query: Vec<LocalBlockHash> = (1..=100).map(LocalBlockHash).collect();
let scores = index.find_matches(query).await.unwrap();
// Workers with sequences of staggered lengths to test drain tracking
// Worker 0: 10 blocks
// Worker 1: 20 blocks
// Worker 2: 35 blocks (just past first jump)
// Worker 3: 64 blocks (exactly 2 jumps)
// Worker 4: 100 blocks
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
10
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
20
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
35
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(3, 0)).unwrap(),
64
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(4, 0)).unwrap(),
100
);
}
for (worker_id, len) in [(0, 10), (1, 20), (2, 35), (3, 64), (4, 100)] {
let sequence: Vec<u64> = (1..=len).collect();
index
.apply_event(make_store_event(worker_id, &sequence))
.await;
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_very_long_sequence(variant: &str) {
let index = make_indexer(variant);
flush_and_settle(index.as_ref()).await;
// Test with a very long sequence (1000 blocks)
let seq_len = 1000u64;
let sequence: Vec<u64> = (1..=seq_len).collect();
index.apply_event(make_store_event(0, &sequence)).await;
// Query for 100 blocks - each worker should match their stored length
let query: Vec<LocalBlockHash> = (1..=100).map(LocalBlockHash).collect();
let scores = index.find_matches(query).await.unwrap();
flush_and_settle(index.as_ref()).await;
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
10
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
20
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
35
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(3, 0)).unwrap(),
64
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(4, 0)).unwrap(),
100
);
}
// Full match
let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
seq_len as u32
);
#[tokio::test]
#[apply(indexer_template)]
async fn test_very_long_sequence(variant: &str) {
let index = make_indexer(variant);
// Partial match (first 500)
let partial_query: Vec<LocalBlockHash> = (1..=500).map(LocalBlockHash).collect();
let scores = index.find_matches(partial_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
500
);
// Test with a very long sequence (1000 blocks)
let seq_len = 1000u64;
let sequence: Vec<u64> = (1..=seq_len).collect();
index.apply_event(make_store_event(0, &sequence)).await;
// Divergence in the middle
let mut mid_diverge: Vec<LocalBlockHash> = (1..=1000).map(LocalBlockHash).collect();
mid_diverge[499] = LocalBlockHash(99999);
let scores = index.find_matches(mid_diverge).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
499
);
flush_and_settle(index.as_ref()).await;
// Full match
let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
seq_len as u32
);
// Partial match (first 500)
let partial_query: Vec<LocalBlockHash> = (1..=500).map(LocalBlockHash).collect();
let scores = index.find_matches(partial_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
500
);
// Divergence in the middle
let mut mid_diverge: Vec<LocalBlockHash> = (1..=1000).map(LocalBlockHash).collect();
mid_diverge[499] = LocalBlockHash(99999);
let scores = index.find_matches(mid_diverge).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
499
);
}
}
// ============================================================================
......@@ -1670,129 +1690,146 @@ fn make_tree_indexer_with_frequency(
}
}
#[tokio::test]
#[apply(tree_indexer_template)]
async fn test_frequency(variant: &str) {
const ONE_MILLIS: Duration = Duration::from_millis(1);
let expiration = Duration::from_millis(50);
let kv_indexer = make_tree_indexer_with_frequency(variant, expiration);
// The blocks
let block_hashes = vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
LocalBlockHash(4),
];
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(
overlap.frequencies.len(),
0,
"Should be no cached blocks yet"
);
mod tree_specific_tests {
use super::*;
use rstest_reuse::apply;
// Blocks go in cache
let event = make_store_event(0, &[1, 2, 3, 4]);
kv_indexer.apply_event(event).await;
// First access - poll briefly since store event is applied async
let mut overlap = OverlapScores::default();
let timeout = Duration::from_millis(10);
let start = Instant::now();
while overlap.scores.is_empty() && Instant::now().duration_since(start) < timeout {
time::sleep(ONE_MILLIS).await;
overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
}
assert_eq!(
overlap.scores.len(),
1,
"One worker has these blocks cached"
);
assert_eq!(
overlap.frequencies.len(),
0,
"Blocks have not previously been accessed"
);
#[tokio::test]
#[apply(tree_indexer_template)]
async fn test_frequency(variant: &str) {
const ONE_MILLIS: Duration = Duration::from_millis(1);
// Second access
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(overlap.scores.len(), 1, "Still one worker matches");
assert_eq!(
overlap.frequencies,
vec![1, 1, 1, 1],
"We should see the first access now"
);
let expiration = Duration::from_millis(50);
let kv_indexer = make_tree_indexer_with_frequency(variant, expiration);
// Let those two accesses expire
time::sleep(expiration + Duration::from_millis(10)).await;
// The blocks
let block_hashes = vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
LocalBlockHash(4),
];
// New first access
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(
overlap.frequencies.len(),
0,
"Blocks were accessed too long ago"
);
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(
overlap.frequencies.len(),
0,
"Should be no cached blocks yet"
);
// Blocks go in cache
let event = make_store_event(0, &[1, 2, 3, 4]);
kv_indexer.apply_event(event).await;
// First access - poll briefly since store event is applied async
let mut overlap = OverlapScores::default();
let timeout = Duration::from_millis(10);
let start = Instant::now();
while overlap.scores.is_empty() && Instant::now().duration_since(start) < timeout {
time::sleep(ONE_MILLIS).await;
overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
}
assert_eq!(
overlap.scores.len(),
1,
"One worker has these blocks cached"
);
assert_eq!(
overlap.frequencies.len(),
0,
"Blocks have not previously been accessed"
);
// New second access
let _ = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
// Second access
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(overlap.scores.len(), 1, "Still one worker matches");
assert_eq!(
overlap.frequencies,
vec![1, 1, 1, 1],
"We should see the first access now"
);
// Access only the first three blocks
let overlap = kv_indexer
.find_matches(block_hashes[0..3].to_vec())
.await
.unwrap();
// We see the previous two new accesses
assert_eq!(overlap.frequencies, vec![2, 2, 2]);
// Let those two accesses expire
time::sleep(expiration + Duration::from_millis(10)).await;
// The third access did not touch the last block
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(overlap.frequencies, vec![3, 3, 3, 2]);
// New first access
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(
overlap.frequencies.len(),
0,
"Blocks were accessed too long ago"
);
// New second access
let _ = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
// Access only the first three blocks
let overlap = kv_indexer
.find_matches(block_hashes[0..3].to_vec())
.await
.unwrap();
// We see the previous two new accesses
assert_eq!(overlap.frequencies, vec![2, 2, 2]);
// The third access did not touch the last block
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(overlap.frequencies, vec![3, 3, 3, 2]);
}
}
// ============================================================================
// KvIndexerMetrics tests
// ============================================================================
#[cfg(feature = "metrics")]
#[test]
fn test_increment_event_applied() {
let metrics = KvIndexerMetrics::new_unregistered();
mod metrics_tests {
#[cfg(feature = "metrics")]
use super::*;
metrics.increment_event_applied(METRIC_EVENT_STORED, Ok(()));
assert_eq!(
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[METRIC_EVENT_STORED, METRIC_STATUS_OK])
.unwrap()
.get(),
1
);
#[cfg(feature = "metrics")]
#[test]
fn test_increment_event_applied() {
let metrics = KvIndexerMetrics::new_unregistered();
metrics.increment_event_applied(
METRIC_EVENT_STORED,
Err(KvCacheEventError::ParentBlockNotFound),
);
assert_eq!(
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[METRIC_EVENT_STORED, METRIC_STATUS_PARENT_NOT_FOUND])
.unwrap()
.get(),
1
);
metrics.increment_event_applied(METRIC_EVENT_STORED, Ok(()));
assert_eq!(
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[METRIC_EVENT_STORED, METRIC_STATUS_OK])
.unwrap()
.get(),
1
);
metrics.increment_event_applied(
METRIC_EVENT_STORED,
Err(KvCacheEventError::ParentBlockNotFound),
);
assert_eq!(
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[
METRIC_EVENT_STORED,
METRIC_STATUS_PARENT_NOT_FOUND
])
.unwrap()
.get(),
1
);
metrics.increment_event_applied(METRIC_EVENT_REMOVED, Err(KvCacheEventError::BlockNotFound));
assert_eq!(
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[METRIC_EVENT_REMOVED, METRIC_STATUS_BLOCK_NOT_FOUND])
.unwrap()
.get(),
1
);
.increment_event_applied(METRIC_EVENT_REMOVED, Err(KvCacheEventError::BlockNotFound));
assert_eq!(
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[
METRIC_EVENT_REMOVED,
METRIC_STATUS_BLOCK_NOT_FOUND
])
.unwrap()
.get(),
1
);
}
}
// ============================================================================
......@@ -1822,363 +1859,368 @@ fn make_local_indexer_with_events(ids: &[u64]) -> LocalKvIndexer {
indexer
}
#[tokio::test]
async fn test_local_indexer_slice_within_range() {
let indexer = make_local_indexer_with_events(&[1, 2, 3, 4, 5]);
mod local_indexer_tests {
use super::*;
use rstest_reuse::apply;
#[tokio::test]
async fn test_local_indexer_slice_within_range() {
    let indexer = make_local_indexer_with_events(&[1, 2, 3, 4, 5]);

    // Reduce a query response to the event IDs it carries. Both the
    // buffered-events variant and the tree-dump variant hold a list of
    // events; any other variant fails the test.
    let ids_of = |resp: WorkerKvQueryResponse| -> Vec<u64> {
        let events: Vec<RouterEvent> = match resp {
            WorkerKvQueryResponse::Events(list) => list,
            WorkerKvQueryResponse::TreeDump { events, .. } => events,
            _ => panic!("Unexpected response type"),
        };
        events.iter().map(|ev| ev.event.event_id).collect()
    };

    // Buffer queries use an inclusive [start, end] range.
    let inclusive = indexer.get_events_in_id_range(Some(2), Some(4)).await;
    assert_eq!(ids_of(inclusive), vec![2, 3, 4]);

    // An end beyond the newest buffered event is clamped to the buffer max.
    let clamped = indexer.get_events_in_id_range(Some(2), Some(6)).await;
    assert_eq!(ids_of(clamped), vec![2, 3, 4, 5]);

    // start_id=0 lies before the first buffered event (1), so the indexer
    // answers with a tree dump instead of a buffer slice.
    let before_buffer = indexer.get_events_in_id_range(Some(0), Some(4)).await;
    assert!(matches!(
        before_buffer,
        WorkerKvQueryResponse::TreeDump { .. }
    ));

    // start == end selects exactly one event.
    let single = indexer.get_events_in_id_range(Some(3), Some(3)).await;
    assert_eq!(ids_of(single), vec![3]);

    // end < start is rejected as an invalid range.
    let reversed = indexer.get_events_in_id_range(Some(5), Some(2)).await;
    assert!(matches!(
        reversed,
        WorkerKvQueryResponse::InvalidRange { .. }
    ));
}
// Helper to extract events from response
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
match resp {
WorkerKvQueryResponse::Events(e) => e,
WorkerKvQueryResponse::TreeDump { events: e, .. } => e,
_ => panic!("Unexpected response type"),
#[tokio::test]
async fn test_local_indexer_get_events_in_id_range_all_cases() {
// Create indexer with small buffer (5 events max)
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
// Helper to create a test event
let make_event = |id: u64| {
RouterEvent::new(
0,
KvCacheEvent {
event_id: id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(id * 100),
tokens_hash: LocalBlockHash(id * 200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
},
)
};
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for id in 5..15 {
indexer
.apply_event_with_buffer(make_event(id))
.await
.unwrap();
}
};
let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> {
events.iter().map(|e| e.event.event_id).collect()
};
// Wait for events to be processed
indexer.flush().await;
// Test get_events_in_id_range (buffer queries)
// Range is [start, end] inclusive
let result = indexer.get_events_in_id_range(Some(2), Some(4)).await;
let ids = get_ids(extract_events(result));
assert_eq!(ids, vec![2, 3, 4]); // inclusive range [2, 4]
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
match resp {
WorkerKvQueryResponse::Events(e) => e,
WorkerKvQueryResponse::TreeDump { events: e, .. } => e,
_ => panic!("Unexpected response type: {:?}", resp),
}
};
let result = indexer.get_events_in_id_range(Some(2), Some(6)).await;
let ids = get_ids(extract_events(result));
assert_eq!(ids, vec![2, 3, 4, 5]); // clamp end to buffer max
let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> {
events.iter().map(|e| e.event.event_id).collect()
};
// start_id=0 is before buffer (first is 1), so should trigger tree dump
let result = indexer.get_events_in_id_range(Some(0), Some(4)).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump { .. }));
// Verify buffer state
let buffer_events = indexer.get_all_events_in_buffer();
assert_eq!(get_ids(buffer_events), vec![10, 11, 12, 13, 14]);
let result = indexer.get_events_in_id_range(Some(3), Some(3)).await;
let ids = get_ids(extract_events(result));
assert_eq!(ids, vec![3]); // single element when start == end
// Buffer path tests
let result = indexer.get_events_in_id_range(Some(11), None).await;
assert_eq!(get_ids(extract_events(result)), vec![11, 12, 13, 14]);
// Invalid range: end < start
let result = indexer.get_events_in_id_range(Some(5), Some(2)).await;
assert!(matches!(result, WorkerKvQueryResponse::InvalidRange { .. }));
}
let result = indexer.get_events_in_id_range(Some(10), Some(14)).await;
assert_eq!(get_ids(extract_events(result)), vec![10, 11, 12, 13, 14]);
#[tokio::test]
async fn test_local_indexer_get_events_in_id_range_all_cases() {
// Create indexer with small buffer (5 events max)
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
// Tree dump path tests
let result = indexer.get_events_in_id_range(None, None).await;
assert!(matches!(&result, WorkerKvQueryResponse::TreeDump { .. }));
assert_eq!(extract_events(result).len(), 10);
// Helper to create a test event
let make_event = |id: u64| {
RouterEvent::new(
0,
let result = indexer.get_events_in_id_range(Some(7), None).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump { .. }));
// Edge cases
let result = indexer.get_events_in_id_range(Some(15), Some(10)).await;
assert!(matches!(result, WorkerKvQueryResponse::InvalidRange { .. }));
let result = indexer.get_events_in_id_range(Some(100), Some(200)).await;
assert!(matches!(result, WorkerKvQueryResponse::TooNew { .. }));
}
#[tokio::test]
async fn test_tree_dump_includes_last_event_id() {
// Create indexer with small buffer (5 events max)
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
let make_event = |id: u64| {
RouterEvent::new(
0,
KvCacheEvent {
event_id: id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(id * 100),
tokens_hash: LocalBlockHash(id * 200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
},
)
};
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for id in 5..15 {
indexer
.apply_event_with_buffer(make_event(id))
.await
.unwrap();
}
indexer.flush().await;
// Request with start_id=None -> tree dump should include last_event_id=14
let result = indexer.get_events_in_id_range(None, None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 14,
"last_event_id should be the buffer's newest event ID"
);
assert!(!events.is_empty(), "tree dump should contain events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
// Request with start_id older than buffer -> tree dump should include last_event_id=14
let result = indexer.get_events_in_id_range(Some(7), None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 14,
"last_event_id should be the buffer's newest event ID"
);
assert!(!events.is_empty(), "tree dump should contain events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
// Empty buffer case: create a fresh indexer with no events
let empty_indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
let result = empty_indexer.get_events_in_id_range(None, None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 0,
"empty buffer should return last_event_id=0"
);
assert!(events.is_empty(), "empty indexer should have no events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
}
#[tokio::test]
async fn test_local_indexer_buffer_and_serialization() {
let worker_id = 42u64;
let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let local_indexer = Arc::new(LocalKvIndexer::new(token, 4, metrics, 100));
let test_event = RouterEvent::new(
worker_id,
KvCacheEvent {
event_id: id,
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(id * 100),
tokens_hash: LocalBlockHash(id * 200),
block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
},
)
};
);
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for id in 5..15 {
indexer
.apply_event_with_buffer(make_event(id))
local_indexer
.apply_event_with_buffer(test_event)
.await
.unwrap();
}
// Wait for events to be processed
indexer.flush().await;
local_indexer.flush().await;
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
match resp {
WorkerKvQueryResponse::Events(e) => e,
WorkerKvQueryResponse::TreeDump { events: e, .. } => e,
_ => panic!("Unexpected response type: {:?}", resp),
}
};
let buffered_events = local_indexer.get_all_events_in_buffer();
assert_eq!(buffered_events.len(), 1);
assert_eq!(buffered_events[0].worker_id, worker_id);
let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> {
events.iter().map(|e| e.event.event_id).collect()
};
// Test serialization round-trip
let response = WorkerKvQueryResponse::Events(buffered_events);
let serialized = serde_json::to_vec(&response).unwrap();
let deserialized: WorkerKvQueryResponse = serde_json::from_slice(&serialized).unwrap();
// Verify buffer state
let buffer_events = indexer.get_all_events_in_buffer();
assert_eq!(get_ids(buffer_events), vec![10, 11, 12, 13, 14]);
// Buffer path tests
let result = indexer.get_events_in_id_range(Some(11), None).await;
assert_eq!(get_ids(extract_events(result)), vec![11, 12, 13, 14]);
let result = indexer.get_events_in_id_range(Some(10), Some(14)).await;
assert_eq!(get_ids(extract_events(result)), vec![10, 11, 12, 13, 14]);
// Tree dump path tests
let result = indexer.get_events_in_id_range(None, None).await;
assert!(matches!(&result, WorkerKvQueryResponse::TreeDump { .. }));
assert_eq!(extract_events(result).len(), 10);
let result = indexer.get_events_in_id_range(Some(7), None).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump { .. }));
// Edge cases
let result = indexer.get_events_in_id_range(Some(15), Some(10)).await;
assert!(matches!(result, WorkerKvQueryResponse::InvalidRange { .. }));
let result = indexer.get_events_in_id_range(Some(100), Some(200)).await;
assert!(matches!(result, WorkerKvQueryResponse::TooNew { .. }));
}
let events = match deserialized {
WorkerKvQueryResponse::Events(e) => e,
_ => panic!("Expected Events variant"),
};
assert_eq!(events.len(), 1);
assert_eq!(events[0].worker_id, worker_id);
}
#[tokio::test]
async fn test_tree_dump_includes_last_event_id() {
// Create indexer with small buffer (5 events max)
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
#[tokio::test]
async fn test_local_indexer_does_not_buffer_failed_send() {
let local_indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
let make_event = |id: u64| {
RouterEvent::new(
0,
let test_event = RouterEvent::new(
7,
KvCacheEvent {
event_id: id,
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(id * 100),
tokens_hash: LocalBlockHash(id * 200),
block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
},
)
};
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for id in 5..15 {
indexer
.apply_event_with_buffer(make_event(id))
.await
.unwrap();
}
indexer.flush().await;
// Request with start_id=None -> tree dump should include last_event_id=14
let result = indexer.get_events_in_id_range(None, None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 14,
"last_event_id should be the buffer's newest event ID"
);
assert!(!events.is_empty(), "tree dump should contain events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
// Request with start_id older than buffer -> tree dump should include last_event_id=14
let result = indexer.get_events_in_id_range(Some(7), None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 14,
"last_event_id should be the buffer's newest event ID"
);
assert!(!events.is_empty(), "tree dump should contain events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
);
// Empty buffer case: create a fresh indexer with no events
let empty_indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
let result = empty_indexer.get_events_in_id_range(None, None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 0,
"empty buffer should return last_event_id=0"
);
assert!(events.is_empty(), "empty indexer should have no events");
let event_tx = local_indexer.event_sender();
local_indexer.shutdown();
event_tx.closed().await;
let result = local_indexer.apply_event_with_buffer(test_event).await;
assert!(matches!(result, Err(KvRouterError::IndexerOffline)));
assert_eq!(local_indexer.buffer_len(), 0);
match local_indexer.get_events_in_id_range(None, None).await {
WorkerKvQueryResponse::TreeDump {
events,
last_event_id,
} => {
assert!(events.is_empty());
assert_eq!(last_event_id, 0);
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
}
#[tokio::test]
async fn test_local_indexer_buffer_and_serialization() {
    let worker_id = 42u64;
    let local_indexer = Arc::new(LocalKvIndexer::new(
        CancellationToken::new(),
        4,
        Arc::new(KvIndexerMetrics::new_unregistered()),
        100,
    ));

    // One single-block store event attributed to `worker_id`.
    let stored = KvCacheEventData::Stored(KvCacheStoreData {
        parent_hash: None,
        blocks: vec![KvCacheStoredBlockData {
            block_hash: ExternalSequenceBlockHash(100),
            tokens_hash: LocalBlockHash(200),
            mm_extra_info: None,
        }],
    });
    let test_event = RouterEvent::new(
        worker_id,
        KvCacheEvent {
            event_id: 1,
            data: stored,
            dp_rank: 0,
        },
    );

    local_indexer
        .apply_event_with_buffer(test_event)
        .await
        .unwrap();
    local_indexer.flush().await;

    // The event must now be in the buffer, still tagged with its worker.
    let buffered_events = local_indexer.get_all_events_in_buffer();
    assert_eq!(buffered_events.len(), 1);
    assert_eq!(buffered_events[0].worker_id, worker_id);

    // Round-trip the buffered events through JSON and check they survive.
    let wire = serde_json::to_vec(&WorkerKvQueryResponse::Events(buffered_events)).unwrap();
    let decoded: WorkerKvQueryResponse = serde_json::from_slice(&wire).unwrap();
    let events = match decoded {
        WorkerKvQueryResponse::Events(list) => list,
        _ => panic!("Expected Events variant"),
    };
    assert_eq!(events.len(), 1);
    assert_eq!(events[0].worker_id, worker_id);
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_apply_events_idempotent(variant: &str) {
let index = make_indexer(variant);
#[tokio::test]
async fn test_local_indexer_does_not_buffer_failed_send() {
let local_indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
let test_event = RouterEvent::new(
7,
KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
},
);
// Setup: build initial tree
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.apply_event(make_store_event(1, &[4, 5, 6])).await;
index
.apply_event(make_store_event_with_parent(0, &[1, 2, 3], &[7, 8]))
.await;
flush_and_settle(index.as_ref()).await;
let s0 = snapshot_tree(index.as_ref()).await;
// Mutation events: each add paired with its remove
let adds = [
make_store_event(2, &[1, 2, 9]),
make_store_event_with_parent(1, &[4, 5, 6], &[10, 11, 12]),
];
let removes = [
make_remove_event(2, &[1, 2, 9]),
make_remove_event_with_parent(1, &[4, 5, 6], &[10, 11, 12]),
];
// Phase 1: interleaved add/remove
index.apply_event(adds[0].clone()).await;
index.apply_event(removes[0].clone()).await;
index.apply_event(adds[1].clone()).await;
index.apply_event(removes[1].clone()).await;
flush_and_settle(index.as_ref()).await;
let s1 = snapshot_tree(index.as_ref()).await;
assert_eq!(
s0, s1,
"Phase 1: interleaved add/remove should restore tree"
);
let event_tx = local_indexer.event_sender();
local_indexer.shutdown();
event_tx.closed().await;
let result = local_indexer.apply_event_with_buffer(test_event).await;
assert!(matches!(result, Err(KvRouterError::IndexerOffline)));
assert_eq!(local_indexer.buffer_len(), 0);
match local_indexer.get_events_in_id_range(None, None).await {
WorkerKvQueryResponse::TreeDump {
events,
last_event_id,
} => {
assert!(events.is_empty());
assert_eq!(last_event_id, 0);
}
other => panic!("Expected TreeDump, got: {other:?}"),
// Phase 2: same interleaved again (idempotence of the full cycle)
index.apply_event(adds[0].clone()).await;
index.apply_event(removes[0].clone()).await;
index.apply_event(adds[1].clone()).await;
index.apply_event(removes[1].clone()).await;
flush_and_settle(index.as_ref()).await;
let s2 = snapshot_tree(index.as_ref()).await;
assert_eq!(s1, s2, "Phase 2: repeated cycle should be idempotent");
// Phase 3: non-interleaved (all adds then all removes)
index.apply_event(adds[0].clone()).await;
index.apply_event(adds[1].clone()).await;
index.apply_event(removes[0].clone()).await;
index.apply_event(removes[1].clone()).await;
flush_and_settle(index.as_ref()).await;
let s3 = snapshot_tree(index.as_ref()).await;
assert_eq!(
s2, s3,
"Phase 3: non-interleaved ordering should restore tree"
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_apply_events_idempotent(variant: &str) {
    let index = make_indexer(variant);

    // Baseline tree: two workers with independent roots, plus a
    // continuation of worker 0's sequence.
    index.apply_event(make_store_event(0, &[1, 2, 3])).await;
    index.apply_event(make_store_event(1, &[4, 5, 6])).await;
    index
        .apply_event(make_store_event_with_parent(0, &[1, 2, 3], &[7, 8]))
        .await;
    flush_and_settle(index.as_ref()).await;
    let baseline = snapshot_tree(index.as_ref()).await;

    // Mutations come in pairs: adds[i] is undone by removes[i].
    let adds = [
        make_store_event(2, &[1, 2, 9]),
        make_store_event_with_parent(1, &[4, 5, 6], &[10, 11, 12]),
    ];
    let removes = [
        make_remove_event(2, &[1, 2, 9]),
        make_remove_event_with_parent(1, &[4, 5, 6], &[10, 11, 12]),
    ];

    // Phase 1: each add is immediately followed by its remove.
    for (add, remove) in adds.iter().zip(removes.iter()) {
        index.apply_event(add.clone()).await;
        index.apply_event(remove.clone()).await;
    }
    flush_and_settle(index.as_ref()).await;
    let after_interleaved = snapshot_tree(index.as_ref()).await;
    assert_eq!(
        baseline, after_interleaved,
        "Phase 1: interleaved add/remove should restore tree"
    );

    // Phase 2: run the identical interleaved cycle a second time.
    for (add, remove) in adds.iter().zip(removes.iter()) {
        index.apply_event(add.clone()).await;
        index.apply_event(remove.clone()).await;
    }
    flush_and_settle(index.as_ref()).await;
    let after_repeat = snapshot_tree(index.as_ref()).await;
    assert_eq!(
        after_interleaved, after_repeat,
        "Phase 2: repeated cycle should be idempotent"
    );

    // Phase 3: apply every add first, then every remove.
    for event in adds.iter().chain(removes.iter()) {
        index.apply_event(event.clone()).await;
    }
    flush_and_settle(index.as_ref()).await;
    let after_batched = snapshot_tree(index.as_ref()).await;
    assert_eq!(
        after_repeat, after_batched,
        "Phase 3: non-interleaved ordering should restore tree"
    );
}
......@@ -6,7 +6,6 @@
//! This crate provides the core radix tree implementation and protocols for
//! efficient KV cache lookup and routing in distributed LLM inference systems.
pub mod event_sink;
pub mod indexer;
pub mod protocols;
pub mod scheduling;
......@@ -41,15 +40,15 @@ pub use self::sequence::{ActiveSequences, RequestId};
pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed;
pub use config::{KvRouterConfig, RouterConfigOverride, RouterQueuePolicy};
pub use event_sink::EventSink;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
pub use nested_map::PositionalIndexer;
pub use protocols::{
KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, WorkerConfigLike, WorkerId,
compute_block_hash_for_seq,
KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, RouterEventSink,
WorkerConfigLike, WorkerId, compute_block_hash_for_seq,
};
pub use queue::SchedulerQueue;
pub use radix_tree::RadixTree;
pub use scheduling::LocalScheduler;
pub use scheduling::policy::{FcfsPolicy, RouterSchedulingPolicy, SchedulingPolicy, WsptPolicy};
pub use scheduling::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
pub use selector::{DefaultWorkerSelector, WorkerSelector};
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::future::Future;
use dynamo_tokens::{SequenceHash, Token};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
......@@ -105,6 +107,12 @@ pub trait WorkerConfigLike {
fn total_kv_blocks(&self) -> Option<u64>;
}
/// Transport abstraction for publishing batched router-visible KV cache events.
///
/// Implementors must be `Send + Sync`.
pub trait RouterEventSink: Send + Sync {
/// Publishes `event` through this sink.
///
/// Returns a `Send` future that resolves to `Ok(())` or an [`anyhow::Error`].
/// What counts as a failure is up to the implementor.
fn publish_event(&self, event: &RouterEvent)
-> impl Future<Output = anyhow::Result<()>> + Send;
}
/// A worker identifier.
pub type WorkerId = u64;
......
......@@ -11,11 +11,16 @@ use validator::{Validate, ValidationError};
use crate::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block};
/// Default for `KvRouterConfig::min_initial_workers`: a single worker.
const fn default_min_initial_workers() -> usize {
    const MIN_INITIAL_WORKERS: usize = 1;
    MIN_INITIAL_WORKERS
}
/// Ordering policy for the router's scheduling queue.
///
/// Serialized with lowercase variant names (`"fcfs"`, `"lcfs"`, `"wspt"`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterQueuePolicy {
/// First-come first-served; this is the default variant.
#[default]
Fcfs,
/// Last-come first-served.
Lcfs,
/// Weighted shortest processing time.
Wspt,
}
......@@ -23,6 +28,7 @@ impl fmt::Display for RouterQueuePolicy {
/// Writes the lowercase policy name (`fcfs`, `lcfs` or `wspt`).
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    let label = match self {
        Self::Fcfs => "fcfs",
        Self::Lcfs => "lcfs",
        Self::Wspt => "wspt",
    };
    f.write_str(label)
}
......@@ -34,9 +40,10 @@ impl FromStr for RouterQueuePolicy {
/// Parses a queue policy from its lowercase name.
///
/// Accepts exactly `"fcfs"`, `"lcfs"` and `"wspt"`; matching is case-sensitive.
///
/// # Errors
///
/// Returns a message naming the rejected input and listing every accepted
/// value when `s` is anything else.
fn from_str(s: &str) -> Result<Self, Self::Err> {
    match s {
        "fcfs" => Ok(Self::Fcfs),
        "lcfs" => Ok(Self::Lcfs),
        "wspt" => Ok(Self::Wspt),
        // The message lists all three accepted names so it stays in sync with
        // the arms above.
        _ => Err(format!(
            "unknown queue policy: {s:?}, expected 'fcfs', 'lcfs', or 'wspt'"
        )),
    }
}
......@@ -58,6 +65,7 @@ pub struct RouterConfigOverride {
/// KV Router configuration parameters
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(default)]
#[validate(schema(function = "validate_kv_router_config"))]
pub struct KvRouterConfig {
#[validate(range(min = 0.0))]
......@@ -130,6 +138,13 @@ pub struct KvRouterConfig {
/// When true, the router starts immediately without waiting for discovery-based
/// workers and workers are provided externally per-request (e.g., EPP).
pub skip_initial_worker_wait: bool,
/// Minimum number of workers that must be discovered before router startup continues.
/// Default: 1. Ignored when skip_initial_worker_wait=true.
#[serde(default = "default_min_initial_workers")]
#[validate(range(min = 1))]
pub min_initial_workers: usize,
/// Scheduling policy for the router queue.
/// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT.
/// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
......@@ -159,10 +174,11 @@ impl Default for KvRouterConfig {
router_ttl_secs: 120.0,
router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
router_prune_target_ratio: 0.8,
router_queue_threshold: Some(2.0),
router_queue_threshold: Some(4.0),
router_event_threads: 4,
router_enable_cache_control: false,
skip_initial_worker_wait: false,
min_initial_workers: default_min_initial_workers(),
router_queue_policy: RouterQueuePolicy::default(),
remote_indexer_component: None,
}
......@@ -237,3 +253,39 @@ impl KvRouterConfig {
self.use_kv_events && self.overlap_score_weight > 0.0
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` and `FromStr` both use the `"lcfs"` spelling.
    #[test]
    fn router_queue_policy_display_and_parse_support_lcfs() {
        let rendered = RouterQueuePolicy::Lcfs.to_string();
        assert_eq!(rendered, "lcfs");

        let parsed: RouterQueuePolicy = "lcfs".parse().unwrap();
        assert_eq!(parsed, RouterQueuePolicy::Lcfs);
    }

    /// The serde representation of `Lcfs` is the JSON string `"lcfs"` and
    /// deserializes back to the same variant.
    #[test]
    fn router_queue_policy_serde_round_trip_supports_lcfs() {
        let json = serde_json::to_string(&RouterQueuePolicy::Lcfs).unwrap();
        assert_eq!(json, "\"lcfs\"");

        let restored = serde_json::from_str::<RouterQueuePolicy>(&json).unwrap();
        assert_eq!(restored, RouterQueuePolicy::Lcfs);
    }

    /// The default configuration asks for one initial worker.
    #[test]
    fn kv_router_config_defaults_to_one_initial_worker() {
        let defaults = KvRouterConfig::default();
        assert_eq!(defaults.min_initial_workers, 1);
    }

    /// Validation fails when zero initial workers are requested.
    #[test]
    fn kv_router_config_rejects_zero_initial_workers() {
        let invalid = KvRouterConfig {
            min_initial_workers: 0,
            ..Default::default()
        };
        let outcome = invalid.validate();
        assert!(outcome.is_err());
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, watch};
use tokio_util::sync::CancellationToken;
use super::policy::{RouterSchedulingPolicy, SchedulingPolicy};
use super::queue::SchedulerQueue;
use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
use crate::protocols::{OverlapScores, WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{
ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest,
};
use dynamo_tokens::SequenceHash;
/// Period of the timer on which the `LocalScheduler` background task calls
/// `SchedulerQueue::update`, independently of incoming requests.
const RECHECK_INTERVAL: Duration = Duration::from_secs(60);
/// Front end that hands scheduling requests to a background task and exposes
/// the shared active-sequence state.
///
/// Type parameters: `P` publishes sequence updates, `C` is the per-worker
/// configuration, `S` is the queue ordering policy and `Sel` is the worker
/// selector.
pub struct LocalScheduler<P, C, S = RouterSchedulingPolicy, Sel = DefaultWorkerSelector>
where
P: SequencePublisher,
C: WorkerConfigLike,
S: SchedulingPolicy,
Sel: WorkerSelector<C>,
{
// Sending half of the channel drained by the background task spawned in `new`.
request_tx: mpsc::Sender<SchedulingRequest>,
// Active-sequence tracker, shared with `queue`.
slots: Arc<ActiveSequencesMultiWorker<P>>,
// Queue that receives requests from the background task; also updated directly
// by `mark_prefill_completed` and `free`.
queue: Arc<SchedulerQueue<P, C, S, Sel>>,
// Label supplied by the caller of `new`; returned verbatim by `worker_type()`.
worker_type: &'static str,
}
impl<P, C, S, Sel> LocalScheduler<P, C, S, Sel>
where
P: SequencePublisher + 'static,
C: WorkerConfigLike + Clone + PartialEq + Send + Sync + 'static,
S: SchedulingPolicy + 'static,
Sel: WorkerSelector<C> + Send + Sync + 'static,
{
/// Builds the scheduler and spawns its background tasks.
///
/// Calls `tokio::spawn`, so it must run inside a Tokio runtime.
///
/// * `slots` - shared active-sequence tracker, also handed to the queue.
/// * `workers_with_configs` - watch channel carrying the current worker map.
/// * `threshold_frac`, `block_size`, `selector`, `policy` - forwarded unchanged
///   to `SchedulerQueue::new`.
/// * `cancellation_token` - cancelling it stops every task spawned here.
/// * `worker_type` - label stored and returned by `worker_type()`.
/// * `monitor_worker_configs` - when `true`, an extra task watches
///   `workers_with_configs` and pushes data-parallel range changes into `slots`.
#[allow(clippy::too_many_arguments)]
pub fn new(
slots: Arc<ActiveSequencesMultiWorker<P>>,
workers_with_configs: watch::Receiver<HashMap<WorkerId, C>>,
threshold_frac: Option<f64>,
block_size: u32,
selector: Sel,
policy: S,
cancellation_token: CancellationToken,
worker_type: &'static str,
monitor_worker_configs: bool,
) -> Self {
if monitor_worker_configs {
let slots_monitor = Arc::clone(&slots);
let mut monitor_rx = workers_with_configs.clone();
// Snapshot taken synchronously, before the task starts, so the first
// comparison is against the map as it was when `new` was called.
let mut last_workers = monitor_rx.borrow().clone();
let monitor_cancel_token = cancellation_token.clone();
tokio::spawn(async move {
tracing::trace!("LocalScheduler workers monitoring task started");
loop {
// Wait for either cancellation or a new value on the watch channel.
tokio::select! {
_ = monitor_cancel_token.cancelled() => {
tracing::trace!("LocalScheduler workers monitoring task shutting down");
break;
}
result = monitor_rx.changed() => {
// An error here means the watch sender was dropped.
if result.is_err() {
tracing::warn!("LocalScheduler worker config watch dropped, shutting down");
break;
}
}
}
let current_workers = monitor_rx.borrow_and_update().clone();
// Skip notifications that did not actually change the map.
if current_workers == last_workers {
continue;
}
// Reduce each config to its (data-parallel start rank, data-parallel size) pair.
let dp_range: HashMap<WorkerId, (u32, u32)> = current_workers
.iter()
.map(|(&id, cfg)| {
(
id,
(cfg.data_parallel_start_rank(), cfg.data_parallel_size()),
)
})
.collect();
slots_monitor.update_workers(&dp_range);
last_workers = current_workers;
}
});
}
let queue = Arc::new(SchedulerQueue::new(
Arc::clone(&slots),
workers_with_configs,
threshold_frac,
block_size,
selector,
policy,
));
// Bounded channel (capacity 1024) between `schedule` callers and the task below.
let (request_tx, request_rx) = mpsc::channel::<SchedulingRequest>(1024);
let queue_clone = Arc::clone(&queue);
tokio::spawn(async move {
let mut request_rx = request_rx;
let mut recheck_interval = tokio::time::interval(RECHECK_INTERVAL);
tracing::trace!("LocalScheduler background task started");
loop {
// Three wake-up sources: cancellation, an incoming request, the periodic timer.
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::trace!("LocalScheduler background task shutting down");
break;
}
request = request_rx.recv() => {
// `None` means every sender, including `self.request_tx`, has been dropped.
let Some(request) = request else {
tracing::warn!("LocalScheduler request channel closed");
break;
};
tracing::trace!("received request to be scheduled");
queue_clone.enqueue(request).await;
}
_ = recheck_interval.tick() => {
queue_clone.update().await;
}
}
}
});
Self {
request_tx,
slots,
queue,
worker_type,
}
}
/// Submits one scheduling request to the background task and waits for its
/// response.
///
/// All arguments are packed into a [`SchedulingRequest`]; `decode_blocks` and
/// `prefill_tokens` start out empty.
///
/// # Errors
///
/// Returns [`KvSchedulerError::SubscriberShutdown`] when the request channel
/// is closed or the response sender is dropped without replying. Any error
/// sent back on the response channel is returned as-is.
#[expect(clippy::too_many_arguments)]
pub async fn schedule(
    &self,
    maybe_request_id: Option<String>,
    isl_tokens: usize,
    token_seq: Option<Vec<SequenceHash>>,
    overlaps: OverlapScores,
    router_config_override: Option<&super::config::RouterConfigOverride>,
    update_states: bool,
    lora_name: Option<String>,
    priority_jump: f64,
    expected_output_tokens: Option<u32>,
    allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<SchedulingResponse, KvSchedulerError> {
    let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();

    // The request outlives the borrowed override, so take an owned copy.
    let router_config_override = router_config_override.cloned();
    let request = SchedulingRequest {
        maybe_request_id,
        token_seq,
        isl_tokens,
        overlaps,
        decode_blocks: HashMap::new(),
        prefill_tokens: HashMap::new(),
        router_config_override,
        update_states,
        lora_name,
        priority_jump,
        expected_output_tokens,
        allowed_worker_ids,
        resp_tx: Some(resp_tx),
    };

    if self.request_tx.send(request).await.is_err() {
        return Err(KvSchedulerError::SubscriberShutdown);
    }

    match resp_rx.await {
        Ok(outcome) => outcome,
        Err(_) => Err(KvSchedulerError::SubscriberShutdown),
    }
}
/// Forwards `worker_ids` to the scheduler queue's `register_workers`.
pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
self.queue.register_workers(worker_ids);
}
/// Adds `req` directly to the active-sequence tracker, bypassing the queue.
///
/// # Errors
///
/// Returns whatever [`SequenceError`] the tracker reports.
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
self.slots.add_request(req).await
}
/// Marks prefill as completed for `request_id` in the active-sequence tracker,
/// then asks the queue to update.
///
/// # Errors
///
/// Returns the tracker's [`SequenceError`]; in that case the queue is not
/// updated.
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
    let owned_id = request_id.to_string();
    self.slots.mark_prefill_completed(&owned_id).await?;
    self.queue.update().await;
    Ok(())
}
/// Frees `request_id` in the active-sequence tracker, then asks the queue to
/// update.
///
/// # Errors
///
/// Returns the tracker's [`SequenceError`]; in that case the queue is not
/// updated.
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
    let owned_id = request_id.to_string();
    self.slots.free(&owned_id).await?;
    self.queue.update().await;
    Ok(())
}
/// Returns the queue's current pending-request count.
pub fn pending_count(&self) -> usize {
self.queue.pending_count()
}
/// Returns the `worker_type` label passed to `new`.
pub fn worker_type(&self) -> &'static str {
self.worker_type
}
/// Forwards an output-block notification for `request_id` to the
/// active-sequence tracker, passing `decay_fraction` through unchanged.
///
/// # Errors
///
/// Returns whatever [`SequenceError`] the tracker reports.
pub fn add_output_block(
    &self,
    request_id: &str,
    decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
    let owned_id = request_id.to_string();
    self.slots.add_output_block(&owned_id, decay_fraction)
}
/// Reports the potential load each known worker/dp-rank would carry for a
/// request, as computed by the active-sequence tracker.
///
/// A worker appears in the result if the tracker returned either a decode
/// block count or a prefill token count for it. A missing prefill count falls
/// back to `isl_tokens`; a missing decode count falls back to `0`. The order
/// of the returned entries is unspecified.
pub fn get_potential_loads(
    &self,
    token_seq: Option<Vec<SequenceHash>>,
    isl_tokens: usize,
    overlaps: OverlapScores,
) -> Vec<PotentialLoad> {
    let (decode_blocks, prefill_tokens) = self.slots.potential_blocks_and_tokens(
        token_seq.as_deref(),
        isl_tokens,
        overlaps,
    );

    // Union of the workers mentioned by either map.
    let known_workers: HashSet<WorkerWithDpRank> = decode_blocks
        .keys()
        .chain(prefill_tokens.keys())
        .copied()
        .collect();

    known_workers
        .into_iter()
        .map(|worker| PotentialLoad {
            worker_id: worker.worker_id,
            dp_rank: worker.dp_rank,
            potential_prefill_tokens: prefill_tokens
                .get(&worker)
                .copied()
                .unwrap_or(isl_tokens),
            potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
        })
        .collect()
}
/// Returns the active-sequence tracker's per-LoRA-name counts.
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
self.slots.get_active_lora_counts()
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use super::*;
use crate::protocols::OverlapScores;
use crate::scheduling::policy::FcfsPolicy;
use crate::scheduling::selector::DefaultWorkerSelector;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
/// Test fixture: builds a `LocalScheduler` over `workers` and returns it with
/// the shared tracker, the worker-config sender and the cancellation token.
#[allow(clippy::type_complexity)]
fn make_scheduler(
    workers: HashMap<WorkerId, SimpleWorkerConfig>,
    threshold_frac: Option<f64>,
    monitor_worker_configs: bool,
) -> (
    Arc<LocalScheduler<NoopSequencePublisher, SimpleWorkerConfig, FcfsPolicy>>,
    Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
    watch::Sender<HashMap<WorkerId, SimpleWorkerConfig>>,
    CancellationToken,
) {
    // Seed the tracker with each worker's data-parallel range.
    let initial_ranges = workers
        .iter()
        .map(|(worker_id, config)| {
            let range = (config.data_parallel_start_rank, config.data_parallel_size);
            (*worker_id, range)
        })
        .collect();
    let tracker = Arc::new(ActiveSequencesMultiWorker::new(
        NoopSequencePublisher,
        64,
        initial_ranges,
        false,
        0,
        "test",
    ));

    let (config_tx, config_rx) = watch::channel(workers);
    let shutdown = CancellationToken::new();
    let selector = DefaultWorkerSelector::new(None, "test");

    let scheduler = LocalScheduler::new(
        Arc::clone(&tracker),
        config_rx,
        threshold_frac,
        64,
        selector,
        FcfsPolicy,
        shutdown.clone(),
        "test",
        monitor_worker_configs,
    );

    (Arc::new(scheduler), tracker, config_tx, shutdown)
}
/// Scheduling with `update_states = true` picks the only worker and records
/// the request's LoRA adapter in the active counts.
#[tokio::test]
async fn test_schedule_books_request_into_active_sequences() {
    let worker_cfg = SimpleWorkerConfig {
        max_num_batched_tokens: Some(64),
        ..Default::default()
    };
    let workers = HashMap::from([(0, worker_cfg)]);
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);

    let scheduled = scheduler.schedule(
        Some("req-1".to_string()),
        64,
        Some(vec![1, 2, 3, 4]),
        OverlapScores::default(),
        None,
        true,
        Some("adapter-a".to_string()),
        0.0,
        None,
        None,
    );
    let response = scheduled.await.unwrap();
    assert_eq!(response.best_worker.worker_id, 0);

    let expected_counts: HashMap<String, usize> =
        [("adapter-a".to_string(), 1)].into_iter().collect();
    assert_eq!(scheduler.get_active_lora_counts(), expected_counts);

    cancel_token.cancel();
}
/// A second request parks in the pending queue while the first is prefilling,
/// and is released once the first request's prefill is marked complete.
#[tokio::test]
async fn test_mark_prefill_completed_drains_pending_queue() {
    let worker_cfg = SimpleWorkerConfig {
        max_num_batched_tokens: Some(64),
        ..Default::default()
    };
    let workers = HashMap::from([(0, worker_cfg)]);
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, Some(0.5), true);

    // First request is scheduled right away.
    let first = scheduler.schedule(
        Some("req-1".to_string()),
        64,
        Some(vec![1, 2, 3, 4]),
        OverlapScores::default(),
        None,
        true,
        None,
        0.0,
        None,
        None,
    );
    first.await.unwrap();

    // Second request runs on its own task because it is expected to block.
    let queued = tokio::spawn({
        let scheduler = Arc::clone(&scheduler);
        async move {
            scheduler
                .schedule(
                    Some("req-2".to_string()),
                    64,
                    Some(vec![5, 6, 7, 8]),
                    OverlapScores::default(),
                    None,
                    true,
                    None,
                    0.0,
                    None,
                    None,
                )
                .await
        }
    });

    // Give the spawned task time to reach the queue before inspecting it.
    tokio::time::sleep(Duration::from_millis(25)).await;
    assert_eq!(scheduler.pending_count(), 1);

    scheduler.mark_prefill_completed("req-1").await.unwrap();
    let released = queued.await.unwrap();
    released.unwrap();
    assert_eq!(scheduler.pending_count(), 0);

    cancel_token.cancel();
}
/// Freeing a scheduled request removes its LoRA adapter from the active counts.
#[tokio::test]
async fn test_free_updates_active_state() {
    let worker_cfg = SimpleWorkerConfig {
        max_num_batched_tokens: Some(64),
        ..Default::default()
    };
    let workers = HashMap::from([(0, worker_cfg)]);
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);

    let scheduled = scheduler.schedule(
        Some("req-1".to_string()),
        64,
        Some(vec![1, 2, 3, 4]),
        OverlapScores::default(),
        None,
        true,
        Some("adapter-a".to_string()),
        0.0,
        None,
        None,
    );
    scheduled.await.unwrap();

    let counts_while_active: HashMap<String, usize> =
        [("adapter-a".to_string(), 1)].into_iter().collect();
    assert_eq!(scheduler.get_active_lora_counts(), counts_while_active);

    scheduler.free("req-1").await.unwrap();
    let counts_after_free = scheduler.get_active_lora_counts();
    assert!(counts_after_free.is_empty());

    cancel_token.cancel();
}
#[tokio::test]
async fn test_get_potential_loads_matches_slots() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(256),
..Default::default()
},
);
workers.insert(
1,
SimpleWorkerConfig {
max_num_batched_tokens: Some(256),
..Default::default()
},
);
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);
let token_seq = vec![11, 22, 33, 44];
let overlaps = OverlapScores::default();
let (decode_blocks, prefill_tokens) =
slots.potential_blocks_and_tokens(Some(&token_seq), 128, overlaps.clone());
let mut expected: Vec<_> = decode_blocks
.keys()
.map(|worker| PotentialLoad {
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
potential_prefill_tokens: prefill_tokens.get(worker).copied().unwrap_or(128),
potential_decode_blocks: decode_blocks.get(worker).copied().unwrap_or(0),
})
.collect();
expected.sort_by_key(|load| (load.worker_id, load.dp_rank));
let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, overlaps);
actual.sort_by_key(|load| (load.worker_id, load.dp_rank));
assert_eq!(actual.len(), expected.len());
for (actual, expected) in actual.iter().zip(expected.iter()) {
assert_eq!(actual.worker_id, expected.worker_id);
assert_eq!(actual.dp_rank, expected.dp_rank);
assert_eq!(
actual.potential_prefill_tokens,
expected.potential_prefill_tokens
);
assert_eq!(
actual.potential_decode_blocks,
expected.potential_decode_blocks
);
}
cancel_token.cancel();
}
/// Registering a worker that has no config yields a single load entry for it
/// at dp rank 0.
#[tokio::test]
async fn test_register_workers_uses_default_dp_fallback() {
    let no_workers = HashMap::new();
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(no_workers, None, false);

    let registered = HashSet::from([42]);
    scheduler.register_workers(&registered);

    let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default());
    assert_eq!(loads.len(), 1);
    let only = &loads[0];
    assert_eq!(only.worker_id, 42);
    assert_eq!(only.dp_rank, 0);

    cancel_token.cancel();
}
/// Publishing a new worker map on the watch channel is eventually reflected in
/// the number of worker/dp-rank entries the scheduler reports.
#[tokio::test]
async fn test_worker_watch_updates_slot_ranges() {
    let workers = HashMap::from([(0, SimpleWorkerConfig::default())]);
    let (scheduler, _slots, cfg_tx, cancel_token) = make_scheduler(workers, None, true);

    let initial_loads = scheduler.get_potential_loads(None, 64, OverlapScores::default());
    assert_eq!(initial_loads.len(), 1);

    // Worker 0 grows to two dp ranks and worker 1 joins: three entries in total.
    let widened = SimpleWorkerConfig {
        data_parallel_size: 2,
        ..Default::default()
    };
    let updated_workers = HashMap::from([(0, widened), (1, SimpleWorkerConfig::default())]);
    cfg_tx.send(updated_workers).unwrap();

    // Poll, yielding between attempts, until the monitor task has applied the update.
    let saw_three_entries = async {
        while scheduler
            .get_potential_loads(None, 64, OverlapScores::default())
            .len()
            != 3
        {
            tokio::task::yield_now().await;
        }
    };
    tokio::time::timeout(Duration::from_secs(1), saw_three_entries)
        .await
        .unwrap();

    cancel_token.cancel();
}
}
......@@ -2,9 +2,11 @@
// SPDX-License-Identifier: Apache-2.0
pub mod config;
mod local;
pub mod policy;
pub mod queue;
pub mod selector;
mod types;
pub use local::LocalScheduler;
pub use types::*;
......@@ -43,6 +43,21 @@ impl SchedulingPolicy for FcfsPolicy {
}
}
/// Last-come first-served ordering with priority bumps.
///
/// The queue key is `max(priority_jump, 0) + arrival_offset` (in seconds), so a
/// later arrival or a larger `priority_jump` yields a higher key and is
/// scheduled first.
///
/// Newer arrivals are deliberately favored under saturation; the policy is
/// mainly useful for policy comparison experiments.
pub struct LcfsPolicy;

impl SchedulingPolicy for LcfsPolicy {
    type Key = OrderedFloat<f64>;

    fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
        // Negative jumps are clamped to zero so they cannot lower the key.
        let boost = request.priority_jump.max(0.0);
        let arrival_secs = arrival_offset.as_secs_f64();
        OrderedFloat(boost + arrival_secs)
    }
}
/// Weighted Shortest Processing Time (Smith's rule):
/// key = (1 + priority_jump) / new_tokens, where new_tokens estimates the
/// actual prefill cost by subtracting the max KV cache overlap from ISL.
......@@ -73,6 +88,7 @@ impl SchedulingPolicy for WsptPolicy {
/// since the variant is fixed at queue construction time.
pub enum RouterSchedulingPolicy {
/// Wraps [`FcfsPolicy`].
Fcfs(FcfsPolicy),
/// Wraps [`LcfsPolicy`].
Lcfs(LcfsPolicy),
/// Wraps [`WsptPolicy`].
Wspt(WsptPolicy),
}
......@@ -80,6 +96,7 @@ impl RouterSchedulingPolicy {
/// Builds the policy variant that corresponds to `kind`.
///
/// `block_size` is only used by the `Wspt` variant; the other variants ignore it.
pub fn new(kind: RouterQueuePolicy, block_size: usize) -> Self {
match kind {
RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy),
RouterQueuePolicy::Lcfs => Self::Lcfs(LcfsPolicy),
RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy { block_size }),
}
}
......@@ -91,6 +108,7 @@ impl SchedulingPolicy for RouterSchedulingPolicy {
/// Delegates to the wrapped policy's `enqueue_key` with the same arguments.
fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
match self {
Self::Fcfs(p) => p.enqueue_key(arrival_offset, request),
Self::Lcfs(p) => p.enqueue_key(arrival_offset, request),
Self::Wspt(p) => p.enqueue_key(arrival_offset, request),
}
}
......@@ -178,6 +196,42 @@ mod tests {
assert!(key_b > key_a);
}
/// Under LCFS, the same request gets a higher key at a later arrival offset.
#[test]
fn lcfs_later_arrival_scheduled_first() {
    let req = request_with(512, 0.0, OverlapScores::default());
    let key_at = |secs: u64| LcfsPolicy.enqueue_key(Duration::from_secs(secs), &req);

    let (early, late) = (key_at(1), key_at(10));
    assert!(late > early, "later arrival should have higher key");
}
/// Under LCFS, a positive `priority_jump` raises the key at equal arrival time.
#[test]
fn lcfs_priority_jump_promotes() {
    let arrival = Duration::from_secs(10);
    let key_for_jump = |priority_jump: f64| {
        let req = request_with(512, priority_jump, OverlapScores::default());
        LcfsPolicy.enqueue_key(arrival, &req)
    };

    let key_normal = key_for_jump(0.0);
    let key_boosted = key_for_jump(100.0);
    assert!(
        key_boosted > key_normal,
        "priority_jump should produce a higher key"
    );
}
/// The enum wrapper keeps each policy's ordering: FCFS ranks the earlier
/// arrival higher, LCFS ranks the later arrival higher.
#[test]
fn router_scheduling_policy_matches_fcfs_and_lcfs_ordering() {
    let req = request_with(512, 0.0, OverlapScores::default());
    let (early, late) = (Duration::from_secs(1), Duration::from_secs(10));

    let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs, 16);
    let fcfs_early = fcfs.enqueue_key(early, &req);
    let fcfs_late = fcfs.enqueue_key(late, &req);
    assert!(fcfs_early > fcfs_late);

    let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs, 16);
    let lcfs_late = lcfs.enqueue_key(late, &req);
    let lcfs_early = lcfs.enqueue_key(early, &req);
    assert!(lcfs_late > lcfs_early);
}
// ---- WSPT policy tests ----
#[test]
......
......@@ -11,7 +11,7 @@ use tokio::sync::Mutex;
use tokio::sync::watch;
use super::policy::{FcfsPolicy, SchedulingPolicy};
use super::selector::WorkerSelector;
use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{SchedulingRequest, SchedulingResponse};
use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRequest};
......@@ -53,6 +53,7 @@ pub struct SchedulerQueue<
P: SequencePublisher,
C: WorkerConfigLike,
S: SchedulingPolicy = FcfsPolicy,
Sel: WorkerSelector<C> = DefaultWorkerSelector,
> {
pending: Mutex<BinaryHeap<QueueEntry<S::Key>>>,
/// Number of requests currently parked in the pending queue.
......@@ -65,19 +66,23 @@ pub struct SchedulerQueue<
/// Reference instant for computing arrival offsets.
start_time: Instant,
block_size: u32,
selector: Box<dyn WorkerSelector<C> + Send + Sync>,
selector: Sel,
policy: S,
}
impl<P: SequencePublisher + 'static, C: WorkerConfigLike, S: SchedulingPolicy>
SchedulerQueue<P, C, S>
impl<
P: SequencePublisher + 'static,
C: WorkerConfigLike,
S: SchedulingPolicy,
Sel: WorkerSelector<C>,
> SchedulerQueue<P, C, S, Sel>
{
pub fn new(
slots: Arc<ActiveSequencesMultiWorker<P>>,
workers_with_configs: watch::Receiver<HashMap<WorkerId, C>>,
threshold_frac: Option<f64>,
block_size: u32,
selector: Box<dyn WorkerSelector<C> + Send + Sync>,
selector: Sel,
policy: S,
) -> Self {
if let Some(frac) = threshold_frac {
......@@ -341,7 +346,7 @@ mod tests {
}
let (cfg_tx, cfg_rx) = watch::channel(configs);
let selector = Box::new(DefaultWorkerSelector::new(None, "test"));
let selector = DefaultWorkerSelector::new(None, "test");
let queue = Arc::new(SchedulerQueue::new(
Arc::clone(&slots),
cfg_rx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment