"lib/llm/src/vscode:/vscode.git/clone" did not exist on "713c96d26d1654ea0900fd25d11bdb544b42c0bc"
Unverified commit b7fe46b1, authored by Yan Ru Pei and committed by GitHub
Browse files

feat(mocker): add multi-worker replay and router startup fixes (#7553)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 82794761
...@@ -149,8 +149,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -149,8 +149,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(fetch_model, m)?)?; m.add_function(wrap_pyfunction!(fetch_model, m)?)?;
m.add_function(wrap_pyfunction!(run_kv_indexer, m)?)?; m.add_function(wrap_pyfunction!(run_kv_indexer, m)?)?;
m.add_function(wrap_pyfunction!(llm::entrypoint::make_engine, m)?)?; m.add_function(wrap_pyfunction!(llm::entrypoint::make_engine, m)?)?;
m.add_function(wrap_pyfunction!(llm::replay::run_mocker_trace_replay, m)?)?;
m.add_function(wrap_pyfunction!( m.add_function(wrap_pyfunction!(
llm::entrypoint::run_mocker_trace_replay, llm::replay::run_mocker_synthetic_trace_replay,
m m
)?)?; )?)?;
m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?; m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?;
...@@ -165,6 +166,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -165,6 +166,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::entrypoint::EngineType>()?; m.add_class::<llm::entrypoint::EngineType>()?;
m.add_class::<llm::entrypoint::RouterConfig>()?; m.add_class::<llm::entrypoint::RouterConfig>()?;
m.add_class::<llm::entrypoint::KvRouterConfig>()?; m.add_class::<llm::entrypoint::KvRouterConfig>()?;
m.add_class::<llm::replay::ReasoningConfig>()?;
m.add_class::<llm::replay::SglangArgs>()?;
m.add_class::<llm::replay::MockEngineArgs>()?;
m.add_class::<llm::kv::WorkerMetricsPublisher>()?; m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?; // Internal: only in _internal, not public API m.add_class::<llm::model_card::ModelDeploymentCard>()?; // Internal: only in _internal, not public API
m.add_class::<llm::local_model::ModelRuntimeConfig>()?; m.add_class::<llm::local_model::ModelRuntimeConfig>()?;
......
...@@ -31,3 +31,4 @@ pub mod local_model; ...@@ -31,3 +31,4 @@ pub mod local_model;
pub mod lora; pub mod lora;
pub mod model_card; pub mod model_card;
pub mod preprocessor; pub mod preprocessor;
pub mod replay;
...@@ -9,7 +9,6 @@ use std::sync::Arc; ...@@ -9,7 +9,6 @@ use std::sync::Arc;
use pyo3::{exceptions::PyException, prelude::*}; use pyo3::{exceptions::PyException, prelude::*};
use pyo3_async_runtimes::TaskLocals; use pyo3_async_runtimes::TaskLocals;
use pythonize::pythonize;
use dynamo_kv_router::config::KvRouterConfig as RsKvRouterConfig; use dynamo_kv_router::config::KvRouterConfig as RsKvRouterConfig;
use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig; use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig;
...@@ -25,7 +24,8 @@ use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingE ...@@ -25,7 +24,8 @@ use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingE
use dynamo_mocker::common::perf_model::PerfModel; use dynamo_mocker::common::perf_model::PerfModel;
use super::aic_callback::create_aic_callback; use super::aic_callback::create_aic_callback;
use dynamo_mocker::common::protocols::MockEngineArgs; use super::replay::MockEngineArgs as PyMockEngineArgs;
use dynamo_mocker::common::protocols::MockEngineArgs as RsMockEngineArgs;
use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId; use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId;
use dynamo_runtime::protocols::EndpointId; use dynamo_runtime::protocols::EndpointId;
...@@ -58,7 +58,7 @@ impl KvRouterConfig { ...@@ -58,7 +58,7 @@ impl KvRouterConfig {
#[pymethods] #[pymethods]
impl KvRouterConfig { impl KvRouterConfig {
#[new] #[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(2.0), router_event_threads=4, router_enable_cache_control=false, router_queue_policy="fcfs", remote_indexer_component=None))] #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_enable_cache_control=false, min_initial_workers=1, router_queue_policy="fcfs", remote_indexer_component=None))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn new( fn new(
overlap_score_weight: f64, overlap_score_weight: f64,
...@@ -77,6 +77,7 @@ impl KvRouterConfig { ...@@ -77,6 +77,7 @@ impl KvRouterConfig {
router_queue_threshold: Option<f64>, router_queue_threshold: Option<f64>,
router_event_threads: u32, router_event_threads: u32,
router_enable_cache_control: bool, router_enable_cache_control: bool,
min_initial_workers: usize,
router_queue_policy: &str, router_queue_policy: &str,
remote_indexer_component: Option<String>, remote_indexer_component: Option<String>,
) -> Self { ) -> Self {
...@@ -99,6 +100,7 @@ impl KvRouterConfig { ...@@ -99,6 +100,7 @@ impl KvRouterConfig {
router_event_threads, router_event_threads,
router_enable_cache_control, router_enable_cache_control,
skip_initial_worker_wait: false, skip_initial_worker_wait: false,
min_initial_workers,
router_queue_policy: router_queue_policy.parse().unwrap_or_else(|_| { router_queue_policy: router_queue_policy.parse().unwrap_or_else(|_| {
panic!("invalid router_queue_policy: {router_queue_policy:?}") panic!("invalid router_queue_policy: {router_queue_policy:?}")
}), }),
...@@ -106,6 +108,13 @@ impl KvRouterConfig { ...@@ -106,6 +108,13 @@ impl KvRouterConfig {
}, },
} }
} }
/// Builds a `KvRouterConfig` by deserializing `config_json` with `serde_json`
/// into the wrapped Rust router config.
///
/// Returns a Python exception carrying the serde error text when parsing fails.
#[staticmethod]
fn from_json(config_json: &str) -> PyResult<Self> {
    match serde_json::from_str::<RsKvRouterConfig>(config_json) {
        Ok(inner) => Ok(KvRouterConfig { inner }),
        Err(e) => Err(PyException::new_err(format!(
            "Failed to parse KvRouterConfig JSON: {e}"
        ))),
    }
}
} }
#[pyclass] #[pyclass]
...@@ -196,6 +205,7 @@ pub(crate) struct EntrypointArgs { ...@@ -196,6 +205,7 @@ pub(crate) struct EntrypointArgs {
tls_cert_path: Option<PathBuf>, tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>, tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
mocker_engine_args: Option<PyMockEngineArgs>,
runtime_config: Option<ModelRuntimeConfig>, runtime_config: Option<ModelRuntimeConfig>,
namespace: Option<String>, namespace: Option<String>,
namespace_prefix: Option<String>, namespace_prefix: Option<String>,
...@@ -208,7 +218,7 @@ pub(crate) struct EntrypointArgs { ...@@ -208,7 +218,7 @@ pub(crate) struct EntrypointArgs {
impl EntrypointArgs { impl EntrypointArgs {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[new] #[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))] #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
pub fn new( pub fn new(
py: Python<'_>, py: Python<'_>,
engine_type: EngineType, engine_type: EngineType,
...@@ -225,6 +235,7 @@ impl EntrypointArgs { ...@@ -225,6 +235,7 @@ impl EntrypointArgs {
tls_cert_path: Option<PathBuf>, tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>, tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
mocker_engine_args: Option<PyMockEngineArgs>,
runtime_config: Option<ModelRuntimeConfig>, runtime_config: Option<ModelRuntimeConfig>,
namespace: Option<String>, namespace: Option<String>,
namespace_prefix: Option<String>, namespace_prefix: Option<String>,
...@@ -272,6 +283,7 @@ impl EntrypointArgs { ...@@ -272,6 +283,7 @@ impl EntrypointArgs {
tls_cert_path, tls_cert_path,
tls_key_path, tls_key_path,
extra_engine_args, extra_engine_args,
mocker_engine_args,
runtime_config, runtime_config,
namespace, namespace,
namespace_prefix, namespace_prefix,
...@@ -419,8 +431,10 @@ async fn select_engine( ...@@ -419,8 +431,10 @@ async fn select_engine(
} }
} }
EngineType::Mocker => { EngineType::Mocker => {
let mut mocker_args = if let Some(extra_args_path) = args.extra_engine_args { let mut mocker_args = if let Some(mocker_engine_args) = args.mocker_engine_args {
MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| { mocker_engine_args.inner()
} else if let Some(extra_args_path) = args.extra_engine_args {
RsMockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
anyhow::anyhow!( anyhow::anyhow!(
"Failed to load mocker args from {:?}: {}", "Failed to load mocker args from {:?}: {}",
extra_args_path, extra_args_path,
...@@ -431,7 +445,7 @@ async fn select_engine( ...@@ -431,7 +445,7 @@ async fn select_engine(
tracing::warn!( tracing::warn!(
"No extra_engine_args specified for mocker engine. Using default mocker args." "No extra_engine_args specified for mocker engine. Using default mocker args."
); );
MockEngineArgs::default() RsMockEngineArgs::default()
}; };
// If aic_backend is set, create Python AIC callback and override perf_model // If aic_backend is set, create Python AIC callback and override perf_model
...@@ -503,84 +517,6 @@ pub fn run_input<'p>( ...@@ -503,84 +517,6 @@ pub fn run_input<'p>(
}) })
} }
#[pyfunction]
#[pyo3(signature = (trace_file, extra_engine_args=None, num_workers=1, replay_concurrency=None))]
pub fn run_mocker_trace_replay(
py: Python<'_>,
trace_file: PathBuf,
extra_engine_args: Option<PathBuf>,
num_workers: usize,
replay_concurrency: Option<isize>,
) -> PyResult<PyObject> {
// Load args before allow_threads so we can use the GIL for AIC callback creation.
let mut args = if let Some(ref extra_args_path) = extra_engine_args {
MockEngineArgs::from_json_file(extra_args_path).map_err(|e| {
PyException::new_err(format!(
"Failed to load mocker args from {:?}: {}",
extra_args_path, e
))
})?
} else {
MockEngineArgs::default()
};
// Create AIC callback if requested (requires GIL, must be done before allow_threads).
if let Some(ref backend_name) = args.aic_backend.clone() {
let backend = backend_name.clone();
let system = args.aic_system.as_deref().unwrap_or("h200_sxm").to_string();
let model_name = args
.aic_model_path
.clone()
.ok_or_else(|| PyException::new_err("--aic-perf-model requires --model-path"))?;
let backend_version = args.aic_backend_version.clone();
let tp_size = args.aic_tp_size.unwrap_or(1);
let callback = create_aic_callback(
py,
&backend,
&system,
&model_name,
tp_size,
backend_version.as_deref(),
)
.map_err(|e| {
PyException::new_err(format!(
"Failed to create AIC callback (--aic-perf-model was requested): {}",
e
))
})?;
tracing::info!(
"AIC perf model: backend={}, gpu={}, model={}, version={:?}",
backend,
system,
model_name,
backend_version
);
args.perf_model = Arc::new(PerfModel::from_aic_callback(callback));
}
let report = py.allow_threads(move || {
let replay_concurrency = replay_concurrency
.map(usize::try_from)
.transpose()
.map_err(|_| anyhow::anyhow!("replay_concurrency must be at least 1"))?;
if let Some(max_in_flight) = replay_concurrency {
dynamo_mocker::simulation::simulate_concurrency_file(
args,
&trace_file,
max_in_flight,
num_workers,
)
} else {
dynamo_mocker::simulation::simulate_trace_file(args, &trace_file, num_workers)
}
});
let report = report.map_err(to_pyerr)?;
pythonize(py, &report)
.map_err(to_pyerr)
.map(|obj| obj.unbind())
}
pub fn to_pyerr<E>(err: E) -> PyErr pub fn to_pyerr<E>(err: E) -> PyErr
where where
E: Display, E: Display,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::path::PathBuf;
use std::sync::Arc;
use dynamo_mocker::common::perf_model::PerfModel;
use dynamo_mocker::common::protocols::{
DirectRequest, EngineType as RsMockerEngineType, MockEngineArgs as RsMockEngineArgs,
PreemptionMode as RsPreemptionMode, ReasoningConfig as RsReasoningConfig,
SglangArgs as RsSglangArgs, WorkerType as RsWorkerType,
};
use pyo3::{exceptions::PyException, prelude::*};
use pythonize::pythonize;
use uuid::Uuid;
use super::aic_callback::create_aic_callback;
use super::entrypoint::{KvRouterConfig, to_pyerr};
/// Maps the Python-facing engine type string onto the mocker engine enum.
///
/// Only `"vllm"` and `"sglang"` are accepted; any other value yields a Python
/// exception naming the rejected input.
fn parse_mocker_engine_type(engine_type: &str) -> PyResult<RsMockerEngineType> {
    let parsed = match engine_type {
        "vllm" => RsMockerEngineType::Vllm,
        "sglang" => RsMockerEngineType::Sglang,
        other => {
            return Err(PyException::new_err(format!(
                "engine_type must be either 'vllm' or 'sglang', got '{other}'"
            )));
        }
    };
    Ok(parsed)
}
/// Maps the Python-facing worker type string onto the mocker worker enum.
///
/// Accepts `"aggregated"`, `"prefill"` and `"decode"`; anything else is
/// reported back as a Python exception that echoes the input.
fn parse_worker_type(worker_type: &str) -> PyResult<RsWorkerType> {
    if worker_type == "aggregated" {
        return Ok(RsWorkerType::Aggregated);
    }
    if worker_type == "prefill" {
        return Ok(RsWorkerType::Prefill);
    }
    if worker_type == "decode" {
        return Ok(RsWorkerType::Decode);
    }
    let other = worker_type;
    Err(PyException::new_err(format!(
        "worker_type must be one of 'aggregated', 'prefill', or 'decode', got '{other}'"
    )))
}
/// Maps the Python-facing preemption mode string onto the mocker enum.
///
/// Accepts `"lifo"` and `"fifo"`; any other value produces a Python exception
/// that includes the rejected input.
fn parse_preemption_mode(preemption_mode: &str) -> PyResult<RsPreemptionMode> {
    let recognized = match preemption_mode {
        "lifo" => Some(RsPreemptionMode::Lifo),
        "fifo" => Some(RsPreemptionMode::Fifo),
        _ => None,
    };
    recognized.ok_or_else(|| {
        let other = preemption_mode;
        PyException::new_err(format!(
            "preemption_mode must be either 'lifo' or 'fifo', got '{other}'"
        ))
    })
}
#[pyclass]
#[derive(Clone, Debug)]
pub struct ReasoningConfig {
inner: RsReasoningConfig,
}
impl ReasoningConfig {
pub fn inner(&self) -> RsReasoningConfig {
self.inner.clone()
}
}
#[pymethods]
impl ReasoningConfig {
#[new]
fn new(
start_thinking_token_id: u32,
end_thinking_token_id: u32,
thinking_ratio: f64,
) -> PyResult<Self> {
let inner = RsReasoningConfig {
start_thinking_token_id,
end_thinking_token_id,
thinking_ratio,
};
Ok(Self { inner })
}
}
#[pyclass]
#[derive(Clone, Debug, Default)]
pub struct SglangArgs {
inner: RsSglangArgs,
}
impl SglangArgs {
pub fn inner(&self) -> RsSglangArgs {
self.inner.clone()
}
}
#[pymethods]
impl SglangArgs {
#[new]
#[pyo3(signature = (schedule_policy=None, page_size=None, max_prefill_tokens=None, chunked_prefill_size=None, clip_max_new_tokens=None, schedule_conservativeness=None))]
fn new(
schedule_policy: Option<String>,
page_size: Option<usize>,
max_prefill_tokens: Option<usize>,
chunked_prefill_size: Option<usize>,
clip_max_new_tokens: Option<usize>,
schedule_conservativeness: Option<f64>,
) -> PyResult<Self> {
let inner = RsSglangArgs {
schedule_policy,
page_size,
max_prefill_tokens,
chunked_prefill_size,
clip_max_new_tokens,
schedule_conservativeness,
};
Ok(Self { inner })
}
}
#[pyclass]
#[derive(Clone, Debug, Default)]
pub struct MockEngineArgs {
inner: RsMockEngineArgs,
}
impl MockEngineArgs {
pub fn inner(&self) -> RsMockEngineArgs {
self.inner.clone()
}
}
#[pymethods]
impl MockEngineArgs {
    /// Builds the wrapped args through the Rust builder and then normalizes them.
    ///
    /// `engine_type`, `worker_type`, `preemption_mode` and (when given)
    /// `router_queue_policy` arrive as strings and are parsed first; an
    /// unrecognized value, a builder failure, or a normalization failure is
    /// returned as a Python exception.
    #[new]
    #[pyo3(signature = (engine_type="vllm", num_gpu_blocks=16384, block_size=0, max_num_seqs=Some(256), max_num_batched_tokens=Some(8192), enable_prefix_caching=true, enable_chunked_prefill=true, speedup_ratio=1.0, decode_speedup_ratio=1.0, dp_size=1, startup_time=None, worker_type="aggregated", aic_backend=None, aic_system=None, aic_backend_version=None, aic_tp_size=None, aic_model_path=None, enable_local_indexer=false, bootstrap_port=None, kv_bytes_per_token=None, kv_transfer_bandwidth=None, reasoning=None, zmq_kv_events_port=None, zmq_replay_port=None, preemption_mode="lifo", router_queue_policy=None, sglang=None))]
    #[allow(clippy::too_many_arguments)]
    fn new(
        engine_type: &str,
        num_gpu_blocks: usize,
        block_size: usize,
        max_num_seqs: Option<usize>,
        max_num_batched_tokens: Option<usize>,
        enable_prefix_caching: bool,
        enable_chunked_prefill: bool,
        speedup_ratio: f64,
        decode_speedup_ratio: f64,
        dp_size: u32,
        startup_time: Option<f64>,
        worker_type: &str,
        aic_backend: Option<String>,
        aic_system: Option<String>,
        aic_backend_version: Option<String>,
        aic_tp_size: Option<usize>,
        aic_model_path: Option<String>,
        enable_local_indexer: bool,
        bootstrap_port: Option<u16>,
        kv_bytes_per_token: Option<usize>,
        kv_transfer_bandwidth: Option<f64>,
        reasoning: Option<ReasoningConfig>,
        zmq_kv_events_port: Option<u16>,
        zmq_replay_port: Option<u16>,
        preemption_mode: &str,
        router_queue_policy: Option<&str>,
        sglang: Option<SglangArgs>,
    ) -> PyResult<Self> {
        // String-typed options are validated up front, in the same order as
        // they appear in the signature.
        let parsed_engine_type = parse_mocker_engine_type(engine_type)?;
        let parsed_worker_type = parse_worker_type(worker_type)?;
        let parsed_preemption_mode = parse_preemption_mode(preemption_mode)?;
        let parsed_queue_policy = match router_queue_policy {
            Some(value) => Some(value.parse().map_err(|e: String| {
                PyException::new_err(format!("invalid router_queue_policy {value:?}: {e}"))
            })?),
            None => None,
        };

        // Unwrap the nested Python wrappers into their Rust counterparts.
        let reasoning_inner = reasoning.map(|config| config.inner());
        let sglang_inner = sglang.map(|config| config.inner());

        let built = RsMockEngineArgs::builder()
            .engine_type(parsed_engine_type)
            .num_gpu_blocks(num_gpu_blocks)
            .block_size(block_size)
            .max_num_seqs(max_num_seqs)
            .max_num_batched_tokens(max_num_batched_tokens)
            .enable_prefix_caching(enable_prefix_caching)
            .enable_chunked_prefill(enable_chunked_prefill)
            .speedup_ratio(speedup_ratio)
            .decode_speedup_ratio(decode_speedup_ratio)
            .dp_size(dp_size)
            .startup_time(startup_time)
            .worker_type(parsed_worker_type)
            .aic_backend(aic_backend)
            .aic_system(aic_system)
            .aic_backend_version(aic_backend_version)
            .aic_tp_size(aic_tp_size)
            .aic_model_path(aic_model_path)
            .enable_local_indexer(enable_local_indexer)
            .bootstrap_port(bootstrap_port)
            .kv_bytes_per_token(kv_bytes_per_token)
            .kv_transfer_bandwidth(kv_transfer_bandwidth)
            .reasoning(reasoning_inner)
            .zmq_kv_events_port(zmq_kv_events_port)
            .zmq_replay_port(zmq_replay_port)
            .preemption_mode(parsed_preemption_mode)
            .router_queue_policy(parsed_queue_policy)
            .sglang(sglang_inner)
            .build()
            .map_err(|e| PyException::new_err(format!("Failed to build MockEngineArgs: {e}")))?;

        let inner = built.normalized().map_err(|e| {
            PyException::new_err(format!("Failed to normalize MockEngineArgs: {e}"))
        })?;
        Ok(Self { inner })
    }

    /// Parses `config_json` via `RsMockEngineArgs::from_json_str`; a parse
    /// failure is returned as a Python exception.
    #[staticmethod]
    fn from_json(config_json: &str) -> PyResult<Self> {
        match RsMockEngineArgs::from_json_str(config_json) {
            Ok(inner) => Ok(Self { inner }),
            Err(e) => Err(PyException::new_err(format!(
                "Failed to parse MockEngineArgs JSON: {e}"
            ))),
        }
    }

    /// `block_size` of the wrapped args.
    #[getter]
    fn block_size(&self) -> usize {
        let Self { inner } = self;
        inner.block_size
    }

    /// `num_gpu_blocks` of the wrapped args.
    #[getter]
    fn num_gpu_blocks(&self) -> usize {
        let Self { inner } = self;
        inner.num_gpu_blocks
    }

    /// `max_num_seqs` of the wrapped args.
    #[getter]
    fn max_num_seqs(&self) -> Option<usize> {
        let Self { inner } = self;
        inner.max_num_seqs
    }

    /// `max_num_batched_tokens` of the wrapped args.
    #[getter]
    fn max_num_batched_tokens(&self) -> Option<usize> {
        let Self { inner } = self;
        inner.max_num_batched_tokens
    }

    /// `enable_local_indexer` of the wrapped args.
    #[getter]
    fn enable_local_indexer(&self) -> bool {
        let Self { inner } = self;
        inner.enable_local_indexer
    }

    /// `dp_size` of the wrapped args.
    #[getter]
    fn dp_size(&self) -> u32 {
        let Self { inner } = self;
        inner.dp_size
    }

    /// `bootstrap_port` of the wrapped args.
    #[getter]
    fn bootstrap_port(&self) -> Option<u16> {
        let Self { inner } = self;
        inner.bootstrap_port
    }

    /// Delegates to `is_prefill()` on the wrapped args.
    fn is_prefill(&self) -> bool {
        let Self { inner } = self;
        inner.is_prefill()
    }

    /// Delegates to `is_decode()` on the wrapped args.
    fn is_decode(&self) -> bool {
        let Self { inner } = self;
        inner.is_decode()
    }

    /// Returns a copy of these args with the given fields replaced, then
    /// re-normalized.
    ///
    /// An argument left as `None` keeps the current value of that field. A
    /// normalization failure is returned as a Python exception.
    #[pyo3(signature = (bootstrap_port=None, zmq_kv_events_port=None, zmq_replay_port=None, kv_bytes_per_token=None))]
    fn with_overrides(
        &self,
        bootstrap_port: Option<u16>,
        zmq_kv_events_port: Option<u16>,
        zmq_replay_port: Option<u16>,
        kv_bytes_per_token: Option<usize>,
    ) -> PyResult<Self> {
        let mut overridden = self.inner.clone();
        // `new.or(current)` keeps the current value unless an override was supplied.
        overridden.bootstrap_port = bootstrap_port.or(overridden.bootstrap_port);
        overridden.zmq_kv_events_port = zmq_kv_events_port.or(overridden.zmq_kv_events_port);
        overridden.zmq_replay_port = zmq_replay_port.or(overridden.zmq_replay_port);
        overridden.kv_bytes_per_token = kv_bytes_per_token.or(overridden.kv_bytes_per_token);

        match overridden.normalized() {
            Ok(inner) => Ok(Self { inner }),
            Err(e) => Err(PyException::new_err(format!(
                "Failed to normalize MockEngineArgs overrides: {e}"
            ))),
        }
    }
}
/// Replays a trace file through the mocker and returns the report converted
/// to a Python object.
///
/// # Arguments
/// * `trace_file` - path forwarded to the `dynamo_mocker::replay` file entry points.
/// * `extra_engine_args` - optional engine args; Rust defaults are used when omitted.
/// * `router_config` - optional KV router config, unwrapped and forwarded as-is.
/// * `num_workers` - forwarded unchanged to the replay.
/// * `replay_concurrency` - when set (must be at least 1), the concurrency-limited
///   replay is used and `arrival_speedup_ratio` is not forwarded.
/// * `replay_mode` - `"offline"` or `"online"`.
/// * `router_mode` - `"round_robin"` or `"kv_router"`.
/// * `arrival_speedup_ratio` - forwarded only when `replay_concurrency` is `None`.
///
/// # Errors
/// Returns a Python exception for an invalid `router_mode`, `replay_mode` or
/// `replay_concurrency`, for AIC callback setup failures, and for replay or
/// report-conversion failures.
#[pyfunction]
#[pyo3(signature = (trace_file, extra_engine_args=None, router_config=None, num_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0))]
#[allow(clippy::too_many_arguments)]
pub fn run_mocker_trace_replay(
    py: Python<'_>,
    trace_file: PathBuf,
    extra_engine_args: Option<MockEngineArgs>,
    router_config: Option<KvRouterConfig>,
    num_workers: usize,
    replay_concurrency: Option<isize>,
    replay_mode: &str,
    router_mode: &str,
    arrival_speedup_ratio: f64,
) -> PyResult<PyObject> {
    // Everything that needs the `py` token is resolved before `allow_threads`.
    let args = load_replay_mocker_args(py, extra_engine_args)?;
    let router_config = load_replay_router_config(router_config);
    let replay_mode = replay_mode.to_owned();
    let router_mode = parse_replay_router_mode(router_mode)?;

    let report = py.allow_threads(move || {
        let replay_concurrency = parse_replay_concurrency(replay_concurrency)?;
        // The closure owns `args` and `router_config`, and exactly one arm
        // runs, so both are moved into the selected call without cloning.
        match (replay_mode.as_str(), replay_concurrency) {
            ("offline", Some(max_in_flight)) => {
                dynamo_mocker::replay::simulate_concurrency_file_with_router_mode(
                    args,
                    router_config,
                    &trace_file,
                    max_in_flight,
                    num_workers,
                    router_mode,
                )
            }
            ("offline", None) => dynamo_mocker::replay::simulate_trace_file_with_router_mode(
                args,
                router_config,
                &trace_file,
                num_workers,
                arrival_speedup_ratio,
                router_mode,
            ),
            ("online", Some(max_in_flight)) => {
                dynamo_mocker::replay::simulate_concurrency_live_file_with_router_mode(
                    args,
                    router_config,
                    &trace_file,
                    max_in_flight,
                    num_workers,
                    router_mode,
                )
            }
            ("online", None) => dynamo_mocker::replay::simulate_trace_live_file_with_router_mode(
                args,
                router_config,
                &trace_file,
                num_workers,
                arrival_speedup_ratio,
                router_mode,
            ),
            (other, _) => anyhow::bail!(
                "replay_mode must be either 'offline' or 'online', got '{}'",
                other
            ),
        }
    });

    let report = report.map_err(to_pyerr)?;
    pythonize(py, &report)
        .map_err(to_pyerr)
        .map(|obj| obj.unbind())
}
/// Generates a synthetic workload and replays it through the mocker, returning
/// the report converted to a Python object.
///
/// # Arguments
/// * `input_tokens`, `output_tokens`, `request_count` - shape of the generated
///   workload; each must be at least 1.
/// * `extra_engine_args` - optional engine args; Rust defaults are used when omitted.
/// * `router_config` - optional KV router config, unwrapped and forwarded as-is.
/// * `num_workers` - forwarded unchanged to the replay.
/// * `replay_concurrency` - when set (must be at least 1), the concurrency-limited
///   replay is used, requests carry no arrival timestamps, and
///   `arrival_speedup_ratio` is not forwarded.
/// * `replay_mode` - `"offline"` or `"online"`.
/// * `router_mode` - `"round_robin"` or `"kv_router"`.
/// * `arrival_speedup_ratio` - forwarded only when `replay_concurrency` is `None`.
/// * `arrival_interval_ms` - spacing between generated arrival timestamps; must
///   be finite and non-negative.
///
/// # Errors
/// Returns a Python exception for invalid arguments, for AIC callback setup
/// failures, and for replay or report-conversion failures.
#[pyfunction]
#[pyo3(signature = (input_tokens, output_tokens, request_count, extra_engine_args=None, router_config=None, num_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0, arrival_interval_ms=1.0))]
#[allow(clippy::too_many_arguments)]
pub fn run_mocker_synthetic_trace_replay(
    py: Python<'_>,
    input_tokens: usize,
    output_tokens: usize,
    request_count: usize,
    extra_engine_args: Option<MockEngineArgs>,
    router_config: Option<KvRouterConfig>,
    num_workers: usize,
    replay_concurrency: Option<isize>,
    replay_mode: &str,
    router_mode: &str,
    arrival_speedup_ratio: f64,
    arrival_interval_ms: f64,
) -> PyResult<PyObject> {
    // Everything that needs the `py` token is resolved before `allow_threads`.
    let args = load_replay_mocker_args(py, extra_engine_args)?;
    let router_config = load_replay_router_config(router_config);
    let replay_mode = replay_mode.to_owned();
    let router_mode = parse_replay_router_mode(router_mode)?;

    let report = py.allow_threads(move || {
        let replay_concurrency = parse_replay_concurrency(replay_concurrency)?;
        // Arrival timestamps are only attached when no concurrency cap is set.
        let requests = build_synthetic_requests(
            input_tokens,
            output_tokens,
            request_count,
            arrival_interval_ms,
            replay_concurrency.is_none(),
        )?;
        // The closure owns `args` and `router_config`, and exactly one arm
        // runs, so both are moved into the selected call without cloning.
        match (replay_mode.as_str(), replay_concurrency) {
            ("offline", Some(max_in_flight)) => {
                dynamo_mocker::replay::simulate_concurrency_requests_with_router_mode(
                    args,
                    router_config,
                    requests,
                    max_in_flight,
                    num_workers,
                    router_mode,
                )
            }
            ("offline", None) => dynamo_mocker::replay::simulate_trace_requests_with_router_mode(
                args,
                router_config,
                requests,
                num_workers,
                arrival_speedup_ratio,
                router_mode,
            ),
            ("online", Some(max_in_flight)) => {
                dynamo_mocker::replay::simulate_concurrency_live_requests_with_router_mode(
                    args,
                    router_config,
                    requests,
                    max_in_flight,
                    num_workers,
                    router_mode,
                )
            }
            ("online", None) => {
                dynamo_mocker::replay::simulate_trace_live_requests_with_router_mode(
                    args,
                    router_config,
                    requests,
                    num_workers,
                    arrival_speedup_ratio,
                    router_mode,
                )
            }
            (other, _) => anyhow::bail!(
                "replay_mode must be either 'offline' or 'online', got '{}'",
                other
            ),
        }
    });

    let report = report.map_err(to_pyerr)?;
    pythonize(py, &report)
        .map_err(to_pyerr)
        .map(|obj| obj.unbind())
}
fn load_replay_mocker_args(
py: Python<'_>,
extra_engine_args: Option<MockEngineArgs>,
) -> PyResult<RsMockEngineArgs> {
let mut args = match extra_engine_args {
Some(extra_args) => extra_args.inner(),
None => RsMockEngineArgs::default(),
};
if let Some(ref backend_name) = args.aic_backend.clone() {
let backend = backend_name.clone();
let system = args.aic_system.as_deref().unwrap_or("h200_sxm").to_string();
let model_name = args
.aic_model_path
.clone()
.ok_or_else(|| PyException::new_err("--aic-perf-model requires --model-path"))?;
let backend_version = args.aic_backend_version.clone();
let tp_size = args.aic_tp_size.unwrap_or(1);
let callback = create_aic_callback(
py,
&backend,
&system,
&model_name,
tp_size,
backend_version.as_deref(),
)
.map_err(|e| {
PyException::new_err(format!(
"Failed to create AIC callback (--aic-perf-model was requested): {}",
e
))
})?;
tracing::info!(
"AIC perf model: backend={}, gpu={}, model={}, version={:?}",
backend,
system,
model_name,
backend_version
);
args.perf_model = Arc::new(PerfModel::from_aic_callback(callback));
}
Ok(args)
}
/// Unwraps the optional Python-side router config into the Rust config it
/// wraps; `None` passes through unchanged.
fn load_replay_router_config(
    router_config: Option<KvRouterConfig>,
) -> Option<dynamo_kv_router::config::KvRouterConfig> {
    let wrapper = router_config?;
    Some(wrapper.inner())
}
/// Maps the Python-facing router mode string onto `ReplayRouterMode`.
///
/// Accepts `"round_robin"` and `"kv_router"`; any other value yields a Python
/// exception that echoes the rejected input.
fn parse_replay_router_mode(
    router_mode: &str,
) -> PyResult<dynamo_mocker::replay::ReplayRouterMode> {
    use dynamo_mocker::replay::ReplayRouterMode;

    let mode = match router_mode {
        "round_robin" => ReplayRouterMode::RoundRobin,
        "kv_router" => ReplayRouterMode::KvRouter,
        other => {
            return Err(PyException::new_err(format!(
                "router_mode must be either 'round_robin' or 'kv_router', got '{}'",
                other
            )));
        }
    };
    Ok(mode)
}
/// Validates the optional replay concurrency and converts it to `usize`.
///
/// `None` passes through as `Ok(None)`; zero and negative values are rejected
/// with an error.
fn parse_replay_concurrency(replay_concurrency: Option<isize>) -> anyhow::Result<Option<usize>> {
    let Some(value) = replay_concurrency else {
        return Ok(None);
    };
    // `try_from` rejects negatives; the guard additionally rejects zero.
    match usize::try_from(value) {
        Ok(max_in_flight) if max_in_flight >= 1 => Ok(Some(max_in_flight)),
        _ => anyhow::bail!("replay_concurrency must be at least 1"),
    }
}
/// Builds `request_count` synthetic requests of `input_tokens` tokens each.
///
/// Token ids come from `synthetic_token_id(request_idx, token_idx)`, and the
/// UUID of request `i` is derived from `i + 1`. Every request uses `dp_rank` 0
/// and `max_output_tokens = output_tokens`. When `include_arrival_timestamps`
/// is true, request `i` is stamped with `i * arrival_interval_ms`; otherwise
/// the timestamp is `None`.
///
/// # Errors
/// Fails when any of the three counts is zero, or when `arrival_interval_ms`
/// is negative or not finite.
fn build_synthetic_requests(
    input_tokens: usize,
    output_tokens: usize,
    request_count: usize,
    arrival_interval_ms: f64,
    include_arrival_timestamps: bool,
) -> anyhow::Result<Vec<DirectRequest>> {
    if input_tokens == 0 {
        anyhow::bail!("input_tokens must be at least 1");
    }
    if output_tokens == 0 {
        anyhow::bail!("output_tokens must be at least 1");
    }
    if request_count == 0 {
        anyhow::bail!("request_count must be at least 1");
    }
    let interval_is_valid = arrival_interval_ms.is_finite() && arrival_interval_ms >= 0.0;
    if !interval_is_valid {
        anyhow::bail!(
            "arrival_interval_ms must be a finite non-negative number, got {}",
            arrival_interval_ms
        );
    }

    let requests = (0..request_count)
        .map(|request_idx| {
            let arrival_timestamp_ms = if include_arrival_timestamps {
                Some(request_idx as f64 * arrival_interval_ms)
            } else {
                None
            };
            DirectRequest {
                tokens: (0..input_tokens)
                    .map(|token_idx| synthetic_token_id(request_idx, token_idx))
                    .collect(),
                max_output_tokens: output_tokens,
                uuid: Some(Uuid::from_u128((request_idx as u128) + 1)),
                dp_rank: 0,
                arrival_timestamp_ms,
            }
        })
        .collect();
    Ok(requests)
}
/// Derives a deterministic, non-zero token id from a request index and a
/// token index.
///
/// The two indices are packed into one 64-bit seed (request index shifted into
/// the upper half, XORed with the token index) and scrambled with an
/// add / xor-shift / multiply sequence whose constants and shifts match the
/// SplitMix64 output mix. The low 32 bits of the result are returned, with 0
/// remapped to 1.
fn synthetic_token_id(request_idx: usize, token_idx: usize) -> u32 {
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
    const MULTIPLIER_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MULTIPLIER_B: u64 = 0x94D0_49BB_1331_11EB;

    let seed = ((request_idx as u64) << 32) ^ (token_idx as u64);
    let mixed = seed.wrapping_add(INCREMENT);
    let mixed = (mixed ^ (mixed >> 30)).wrapping_mul(MULTIPLIER_A);
    let mixed = (mixed ^ (mixed >> 27)).wrapping_mul(MULTIPLIER_B);
    let mixed = mixed ^ (mixed >> 31);

    // Truncation to the low 32 bits is intentional.
    match mixed as u32 {
        0 => 1,
        token => token,
    }
}
...@@ -3,7 +3,17 @@ ...@@ -3,7 +3,17 @@
import asyncio import asyncio
import os import os
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
)
# Import from specialized modules # Import from specialized modules
from .prometheus_metrics import RuntimeMetrics as PyRuntimeMetrics from .prometheus_metrics import RuntimeMetrics as PyRuntimeMetrics
...@@ -1104,9 +1114,10 @@ class KvRouterConfig: ...@@ -1104,9 +1114,10 @@ class KvRouterConfig:
router_ttl_secs: float = 120.0, router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1048576, router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8, router_prune_target_ratio: float = 0.8,
router_queue_threshold: Optional[float] = 2.0, router_queue_threshold: Optional[float] = 4.0,
router_event_threads: int = 4, router_event_threads: int = 4,
router_enable_cache_control: bool = False, router_enable_cache_control: bool = False,
min_initial_workers: int = 1,
router_queue_policy: str = "fcfs", router_queue_policy: str = "fcfs",
) -> None: ) -> None:
""" """
...@@ -1132,7 +1143,7 @@ class KvRouterConfig: ...@@ -1132,7 +1143,7 @@ class KvRouterConfig:
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0) router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20) router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8) router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
router_queue_threshold: Queue threshold fraction for prefill token capacity (default: 2.0). router_queue_threshold: Queue threshold fraction for prefill token capacity (default: 4.0).
Requests are queued if all workers exceed this fraction of max_num_batched_tokens. Requests are queued if all workers exceed this fraction of max_num_batched_tokens.
Enables priority scheduling via request priority hints. Enables priority scheduling via request priority hints.
Set to None to disable queueing (all requests go directly to the scheduler). Set to None to disable queueing (all requests go directly to the scheduler).
...@@ -1140,12 +1151,111 @@ class KvRouterConfig: ...@@ -1140,12 +1151,111 @@ class KvRouterConfig:
When > 1, uses a concurrent radix tree with a thread pool. When > 1, uses a concurrent radix tree with a thread pool.
router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's
cache_control service mesh endpoint (default: False). cache_control service mesh endpoint (default: False).
min_initial_workers: Minimum number of discovered workers required before
router startup continues (default: 1). Ignored when
skip_initial_worker_wait is enabled.
router_queue_policy: Scheduling policy for the router queue (default: "fcfs"). router_queue_policy: Scheduling policy for the router queue (default: "fcfs").
"fcfs": first-come first-served with priority bumps — optimizes tail TTFT. "fcfs": first-come first-served with priority bumps — optimizes tail TTFT.
"lcfs": last-come first-served with priority bumps — intentionally worsens tail behavior for policy comparisons.
"wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT. "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
""" """
... ...
# Stub for the static constructor that builds a KvRouterConfig from a JSON string.
@staticmethod
def from_json(config_json: str) -> "KvRouterConfig":
...
# Type stub for the ReasoningConfig binding class.
# All three constructor arguments are required (no defaults).
class ReasoningConfig:
def __init__(
self,
start_thinking_token_id: int,
end_thinking_token_id: int,
thinking_ratio: float,
) -> None:
...
# Type stub for the SglangArgs binding class.
# Every constructor argument is optional and defaults to None.
class SglangArgs:
def __init__(
self,
schedule_policy: Optional[str] = None,
page_size: Optional[int] = None,
max_prefill_tokens: Optional[int] = None,
chunked_prefill_size: Optional[int] = None,
clip_max_new_tokens: Optional[int] = None,
schedule_conservativeness: Optional[float] = None,
) -> None:
...
class MockEngineArgs:
    """Typed stub for the mock-engine arguments exposed by the native module.

    Instances are built either from keyword arguments or from a JSON string
    via ``from_json``. They are the type accepted as ``extra_engine_args`` by
    the replay entry points in this stub file and as ``mocker_engine_args``
    by ``EntrypointArgs``.

    NOTE(review): parameter semantics are not visible from this stub; the
    names and defaults below are the only grounded information.
    """

    def __init__(
        self,
        engine_type: str = "vllm",
        num_gpu_blocks: int = 16384,
        block_size: int = 0,
        max_num_seqs: Optional[int] = 256,
        max_num_batched_tokens: Optional[int] = 8192,
        enable_prefix_caching: bool = True,
        enable_chunked_prefill: bool = True,
        speedup_ratio: float = 1.0,
        decode_speedup_ratio: float = 1.0,
        dp_size: int = 1,
        startup_time: Optional[float] = None,
        worker_type: str = "aggregated",
        aic_backend: Optional[str] = None,
        aic_system: Optional[str] = None,
        aic_backend_version: Optional[str] = None,
        aic_tp_size: Optional[int] = None,
        aic_model_path: Optional[str] = None,
        enable_local_indexer: bool = False,
        bootstrap_port: Optional[int] = None,
        kv_bytes_per_token: Optional[int] = None,
        kv_transfer_bandwidth: Optional[float] = None,
        reasoning: Optional[ReasoningConfig] = None,
        zmq_kv_events_port: Optional[int] = None,
        zmq_replay_port: Optional[int] = None,
        preemption_mode: str = "lifo",
        router_queue_policy: Optional[str] = None,
        sglang: Optional[SglangArgs] = None,
    ) -> None:
        """Create mock-engine arguments; every parameter has a default."""
        ...

    @staticmethod
    def from_json(config_json: str) -> "MockEngineArgs":
        """Build a ``MockEngineArgs`` from a JSON document supplied as a string.

        NOTE(review): presumably keys omitted from the JSON fall back to the
        constructor defaults; confirm against the native implementation.
        """
        ...

    # The accessors below are declared with ``@property`` only (no setters),
    # so this stub exposes them as read-only.
    @property
    def block_size(self) -> int: ...
    @property
    def num_gpu_blocks(self) -> int: ...
    @property
    def max_num_seqs(self) -> Optional[int]: ...
    @property
    def max_num_batched_tokens(self) -> Optional[int]: ...
    @property
    def enable_local_indexer(self) -> bool: ...
    @property
    def dp_size(self) -> int: ...
    @property
    def bootstrap_port(self) -> Optional[int]: ...

    # NOTE(review): presumably derived from ``worker_type``; confirm.
    def is_prefill(self) -> bool: ...
    def is_decode(self) -> bool: ...

    # NOTE(review): presumably returns a new object with the given fields
    # replaced and leaves fields passed as ``None`` untouched; confirm.
    def with_overrides(
        self,
        bootstrap_port: Optional[int] = None,
        zmq_kv_events_port: Optional[int] = None,
        zmq_replay_port: Optional[int] = None,
        kv_bytes_per_token: Optional[int] = None,
    ) -> "MockEngineArgs": ...
async def register_model( async def register_model(
model_input: ModelInput, model_input: ModelInput,
model_type: ModelType, model_type: ModelType,
...@@ -1249,11 +1359,31 @@ async def run_input(runtime: DistributedRuntime, input: str, engine_config: Engi ...@@ -1249,11 +1359,31 @@ async def run_input(runtime: DistributedRuntime, input: str, engine_config: Engi
def run_mocker_trace_replay( def run_mocker_trace_replay(
trace_file: str | os.PathLike[str], trace_file: str | os.PathLike[str],
extra_engine_args: Optional[str | os.PathLike[str]] = None, extra_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
num_workers: int = 1,
replay_concurrency: Optional[int] = None,
replay_mode: Literal["offline", "online"] = "offline",
router_mode: Literal["round_robin", "kv_router"] = "round_robin",
arrival_speedup_ratio: float = 1.0,
) -> Dict[str, Any]:
"""Replay a mocker trace file and return the simulation report for aggregated vLLM or SGLang configs."""
...
def run_mocker_synthetic_trace_replay(
input_tokens: int,
output_tokens: int,
request_count: int,
extra_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
num_workers: int = 1, num_workers: int = 1,
replay_concurrency: Optional[int] = None, replay_concurrency: Optional[int] = None,
replay_mode: Literal["offline", "online"] = "offline",
router_mode: Literal["round_robin", "kv_router"] = "round_robin",
arrival_speedup_ratio: float = 1.0,
arrival_interval_ms: float = 1.0,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Replay a mocker trace file and return the simulation report.""" """Replay a synthetic mocker workload without requiring a trace file."""
... ...
class Layer: class Layer:
...@@ -1687,6 +1817,7 @@ class EntrypointArgs: ...@@ -1687,6 +1817,7 @@ class EntrypointArgs:
tls_cert_path: Optional[str] = None, tls_cert_path: Optional[str] = None,
tls_key_path: Optional[str] = None, tls_key_path: Optional[str] = None,
extra_engine_args: Optional[str] = None, extra_engine_args: Optional[str] = None,
mocker_engine_args: Optional[MockEngineArgs] = None,
runtime_config: Optional[ModelRuntimeConfig] = None, runtime_config: Optional[ModelRuntimeConfig] = None,
namespace: Optional[str] = None, namespace: Optional[str] = None,
namespace_prefix: Optional[str] = None, namespace_prefix: Optional[str] = None,
...@@ -1711,7 +1842,8 @@ class EntrypointArgs: ...@@ -1711,7 +1842,8 @@ class EntrypointArgs:
http_metrics_port: HTTP metrics port (for gRPC service) http_metrics_port: HTTP metrics port (for gRPC service)
tls_cert_path: TLS certificate path (PEM format) tls_cert_path: TLS certificate path (PEM format)
tls_key_path: TLS key path (PEM format) tls_key_path: TLS key path (PEM format)
extra_engine_args: Path to extra engine arguments file extra_engine_args: Optional path to mocker engine arguments JSON
mocker_engine_args: Typed mocker engine arguments
runtime_config: Optional runtime configuration for discovery registration runtime_config: Optional runtime configuration for discovery registration
namespace: Dynamo namespace for model discovery scoping namespace: Dynamo namespace for model discovery scoping
namespace_prefix: Optional namespace prefix namespace_prefix: Optional namespace prefix
......
...@@ -18,6 +18,7 @@ from dynamo._core import KvRouterConfig as KvRouterConfig ...@@ -18,6 +18,7 @@ from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import LoRADownloader as LoRADownloader from dynamo._core import LoRADownloader as LoRADownloader
from dynamo._core import MediaDecoder as MediaDecoder from dynamo._core import MediaDecoder as MediaDecoder
from dynamo._core import MediaFetcher as MediaFetcher from dynamo._core import MediaFetcher as MediaFetcher
from dynamo._core import MockEngineArgs as MockEngineArgs
from dynamo._core import ModelCardInstanceId as ModelCardInstanceId from dynamo._core import ModelCardInstanceId as ModelCardInstanceId
from dynamo._core import ModelInput as ModelInput from dynamo._core import ModelInput as ModelInput
from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig
...@@ -25,8 +26,10 @@ from dynamo._core import ModelType as ModelType ...@@ -25,8 +26,10 @@ from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores from dynamo._core import OverlapScores as OverlapScores
from dynamo._core import PythonAsyncEngine as PythonAsyncEngine from dynamo._core import PythonAsyncEngine as PythonAsyncEngine
from dynamo._core import RadixTree as RadixTree from dynamo._core import RadixTree as RadixTree
from dynamo._core import ReasoningConfig as ReasoningConfig
from dynamo._core import RouterConfig as RouterConfig from dynamo._core import RouterConfig as RouterConfig
from dynamo._core import RouterMode as RouterMode from dynamo._core import RouterMode as RouterMode
from dynamo._core import SglangArgs as SglangArgs
from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher
from dynamo._core import compute_block_hash_for_seq as compute_block_hash_for_seq from dynamo._core import compute_block_hash_for_seq as compute_block_hash_for_seq
from dynamo._core import fetch_model as fetch_model from dynamo._core import fetch_model as fetch_model
...@@ -35,7 +38,7 @@ from dynamo._core import make_engine ...@@ -35,7 +38,7 @@ from dynamo._core import make_engine
from dynamo._core import register_model as register_model from dynamo._core import register_model as register_model
from dynamo._core import run_input from dynamo._core import run_input
from dynamo._core import run_kv_indexer as run_kv_indexer from dynamo._core import run_kv_indexer as run_kv_indexer
from dynamo._core import run_mocker_trace_replay from dynamo._core import run_mocker_trace_replay as _run_mocker_trace_replay
from dynamo._core import unregister_model as unregister_model from dynamo._core import unregister_model as unregister_model
from .exceptions import HttpError from .exceptions import HttpError
...@@ -44,3 +47,24 @@ from .exceptions import HttpError ...@@ -44,3 +47,24 @@ from .exceptions import HttpError
fetch_llm = fetch_model fetch_llm = fetch_model
register_llm = register_model register_llm = register_model
unregister_llm = unregister_model unregister_llm = unregister_model
def run_mocker_trace_replay(
    trace_file,
    extra_engine_args=None,
    router_config=None,
    num_workers=1,
    replay_concurrency=None,
    router_mode="round_robin",
    arrival_speedup_ratio=1.0,
):
    """Replay ``trace_file`` through the native mocker binding in offline mode.

    This wrapper does not expose ``replay_mode``; it always forwards
    ``replay_mode="offline"`` and passes every other argument through
    unchanged.

    Returns:
        Whatever the native ``run_mocker_trace_replay`` binding returns.
    """
    forwarded_options = {
        "extra_engine_args": extra_engine_args,
        "router_config": router_config,
        "num_workers": num_workers,
        "replay_concurrency": replay_concurrency,
        # Fixed on purpose: this entry point only ever runs offline replays.
        "replay_mode": "offline",
        "router_mode": router_mode,
        "arrival_speedup_ratio": arrival_speedup_ratio,
    }
    return _run_mocker_trace_replay(trace_file, **forwarded_options)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.replay.api import run_synthetic_trace_replay, run_trace_replay
__all__ = ["run_synthetic_trace_replay", "run_trace_replay"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.replay.main import main
# Executed only when the module is run as a script: call ``main()`` and use
# its return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo._core import (
run_mocker_synthetic_trace_replay as _run_mocker_synthetic_trace_replay,
)
from dynamo._core import run_mocker_trace_replay as _run_mocker_trace_replay
def run_trace_replay(
    trace_file,
    *,
    extra_engine_args=None,
    router_config=None,
    num_workers=1,
    replay_concurrency=None,
    replay_mode="offline",
    router_mode="round_robin",
    arrival_speedup_ratio=1.0,
):
    """Replay a trace file through the native mocker binding.

    ``trace_file`` is the only positional argument; every option is
    keyword-only and is forwarded to the binding under the same name.

    Returns:
        Whatever the native ``run_mocker_trace_replay`` binding returns.
    """
    replay_options = dict(
        extra_engine_args=extra_engine_args,
        router_config=router_config,
        num_workers=num_workers,
        replay_concurrency=replay_concurrency,
        replay_mode=replay_mode,
        router_mode=router_mode,
        arrival_speedup_ratio=arrival_speedup_ratio,
    )
    return _run_mocker_trace_replay(trace_file, **replay_options)
def run_synthetic_trace_replay(
    input_tokens,
    output_tokens,
    request_count,
    *,
    extra_engine_args=None,
    router_config=None,
    num_workers=1,
    replay_concurrency=None,
    replay_mode="offline",
    router_mode="round_robin",
    arrival_speedup_ratio=1.0,
    arrival_interval_ms=1.0,
):
    """Replay a synthetic workload through the native mocker binding.

    The three workload parameters are positional; every option is
    keyword-only and is forwarded to the binding under the same name.

    Returns:
        Whatever the native ``run_mocker_synthetic_trace_replay`` binding
        returns.
    """
    workload = (input_tokens, output_tokens, request_count)
    replay_options = dict(
        extra_engine_args=extra_engine_args,
        router_config=router_config,
        num_workers=num_workers,
        replay_concurrency=replay_concurrency,
        replay_mode=replay_mode,
        router_mode=router_mode,
        arrival_speedup_ratio=arrival_speedup_ratio,
        arrival_interval_ms=arrival_interval_ms,
    )
    return _run_mocker_synthetic_trace_replay(*workload, **replay_options)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import argparse
import json
import os
import sys
from collections.abc import Sequence
os.environ.setdefault("DYNAMO_SKIP_PYTHON_LOG_INIT", "1")
from dynamo.llm import KvRouterConfig, MockEngineArgs
from dynamo.replay import run_synthetic_trace_replay, run_trace_replay
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(prog="python -m dynamo.replay")
parser.add_argument("trace_file", nargs="?")
parser.add_argument("--extra-engine-args")
parser.add_argument("--router-config")
parser.add_argument("--input-tokens", type=int)
parser.add_argument("--output-tokens", type=int)
parser.add_argument("--request-count", type=int)
parser.add_argument("--arrival-interval-ms", type=float, default=1.0)
parser.add_argument("--num-workers", type=int, default=1)
parser.add_argument("--replay-concurrency", type=int)
parser.add_argument(
"--replay-mode",
choices=("offline", "online"),
default="offline",
)
parser.add_argument(
"--router-mode",
choices=("round_robin", "kv_router"),
default="round_robin",
)
parser.add_argument("--arrival-speedup-ratio", type=float, default=1.0)
args = parser.parse_args(list(sys.argv[1:] if argv is None else argv))
using_trace_file = args.trace_file is not None
synthetic_args = (args.input_tokens, args.output_tokens, args.request_count)
using_synthetic = any(value is not None for value in synthetic_args)
if using_trace_file == using_synthetic:
parser.error(
"provide either trace_file or all of --input-tokens/--output-tokens/--request-count"
)
if using_synthetic and not all(value is not None for value in synthetic_args):
parser.error(
"synthetic replay requires --input-tokens, --output-tokens, and --request-count"
)
extra_engine_args = (
MockEngineArgs.from_json(args.extra_engine_args)
if args.extra_engine_args is not None
else None
)
router_config = (
KvRouterConfig.from_json(args.router_config)
if args.router_config is not None
else None
)
if using_trace_file:
report = run_trace_replay(
args.trace_file,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=args.num_workers,
replay_concurrency=args.replay_concurrency,
replay_mode=args.replay_mode,
router_mode=args.router_mode,
arrival_speedup_ratio=args.arrival_speedup_ratio,
)
else:
report = run_synthetic_trace_replay(
args.input_tokens,
args.output_tokens,
args.request_count,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=args.num_workers,
replay_concurrency=args.replay_concurrency,
replay_mode=args.replay_mode,
router_mode=args.router_mode,
arrival_speedup_ratio=args.arrival_speedup_ratio,
arrival_interval_ms=args.arrival_interval_ms,
)
json.dump(report, sys.stdout, indent=2, sort_keys=True)
sys.stdout.write("\n")
return 0
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import pytest
from dynamo.llm import KvRouterConfig, MockEngineArgs
from dynamo.replay import run_synthetic_trace_replay, run_trace_replay
# Module-level pytest markers: pytest applies every mark listed in
# ``pytestmark`` to all tests collected from this file.
pytestmark = [
    pytest.mark.gpu_0,
    pytest.mark.parallel,
    pytest.mark.pre_merge,
]
MOONCAKE_TRACE_FIRST20 = """{"timestamp": 0, "input_length": 6755, "output_length": 500, "hash_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]}
{"timestamp": 0, "input_length": 7319, "output_length": 490, "hash_ids": [0, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]}
{"timestamp": 0, "input_length": 7234, "output_length": 794, "hash_ids": [0, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]}
{"timestamp": 0, "input_length": 2287, "output_length": 316, "hash_ids": [0, 42, 43, 44, 45]}
{"timestamp": 0, "input_length": 9013, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]}
{"timestamp": 0, "input_length": 6506, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 64]}
{"timestamp": 0, "input_length": 4824, "output_length": 173, "hash_ids": [0, 65, 66, 67, 68, 69, 70, 71, 72, 73]}
{"timestamp": 0, "input_length": 3119, "output_length": 20, "hash_ids": [74, 75, 76, 77, 78, 79, 80]}
{"timestamp": 0, "input_length": 23090, "output_length": 453, "hash_ids": [0, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125]}
{"timestamp": 0, "input_length": 3135, "output_length": 19, "hash_ids": [74, 75, 76, 77, 78, 126, 127]}
{"timestamp": 0, "input_length": 26874, "output_length": 458, "hash_ids": [0, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179]}
{"timestamp": 0, "input_length": 10487, "output_length": 402, "hash_ids": [0, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]}
{"timestamp": 0, "input_length": 17448, "output_length": 610, "hash_ids": [0, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233]}
{"timestamp": 0, "input_length": 6253, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 234]}
{"timestamp": 0, "input_length": 6725, "output_length": 32, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 235, 236]}
{"timestamp": 3052, "input_length": 13538, "output_length": 71, "hash_ids": [0, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262]}
{"timestamp": 3052, "input_length": 87162, "output_length": 402, "hash_ids": [0, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432]}
{"timestamp": 3052, "input_length": 6166, "output_length": 24, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 433]}
{"timestamp": 3052, "input_length": 6320, "output_length": 548, "hash_ids": [0, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445]}
{"timestamp": 3052, "input_length": 2007, "output_length": 354, "hash_ids": [0, 446, 447, 448]}
"""
def _write_trace_and_args(tmp_path):
trace_path = tmp_path / "trace.jsonl"
records = [
{
"timestamp": 1000.0,
"input_length": 64,
"output_length": 2,
"hash_ids": [101],
},
{
"timestamp": 1005.0,
"input_length": 64,
"output_length": 2,
"hash_ids": [101],
},
]
trace_path.write_text(
"\n".join(json.dumps(record) for record in records) + "\n",
encoding="utf-8",
)
return trace_path
def _write_vllm_args(tmp_path):
args_path = tmp_path / "args.json"
args_path.write_text(
json.dumps(
{
"block_size": 64,
"speedup_ratio": 1000.0,
}
),
encoding="utf-8",
)
return args_path
def _vllm_args():
    """Return ``MockEngineArgs`` parsed from the vLLM test settings JSON."""
    engine_settings = {
        "block_size": 64,
        "speedup_ratio": 1000.0,
    }
    serialized = json.dumps(engine_settings)
    return MockEngineArgs.from_json(serialized)
def _write_sglang_args(tmp_path):
args_path = tmp_path / "sglang_args.json"
args_path.write_text(
json.dumps(
{
"engine_type": "sglang",
"num_gpu_blocks": 512,
"block_size": 64,
"speedup_ratio": 1000.0,
"sglang": {
"page_size": 64,
},
}
),
encoding="utf-8",
)
return args_path
def _sglang_args():
    """Return ``MockEngineArgs`` parsed from the SGLang test settings JSON."""
    sglang_section = {"page_size": 64}
    engine_settings = {
        "engine_type": "sglang",
        "num_gpu_blocks": 512,
        "block_size": 64,
        "speedup_ratio": 1000.0,
        "sglang": sglang_section,
    }
    serialized = json.dumps(engine_settings)
    return MockEngineArgs.from_json(serialized)
def _write_router_config(tmp_path):
config_path = tmp_path / "router_config.json"
config_path.write_text(
json.dumps(
{
"router_queue_threshold": 1.25,
"router_event_threads": 1,
"router_queue_policy": "wspt",
"router_temperature": 0.0,
"overlap_score_weight": 1.0,
"use_kv_events": True,
"durable_kv_events": False,
"router_replica_sync": False,
"router_track_active_blocks": True,
"router_track_output_blocks": False,
"router_assume_kv_reuse": True,
"router_snapshot_threshold": 1000000,
"router_reset_states": False,
"router_ttl_secs": 120.0,
"router_max_tree_size": 1048576,
"router_prune_target_ratio": 0.8,
"router_enable_cache_control": False,
"skip_initial_worker_wait": False,
"min_initial_workers": 1,
"remote_indexer_component": None,
}
),
encoding="utf-8",
)
return config_path
def _router_config():
    """Return a ``KvRouterConfig`` parsed from a fully specified JSON document."""
    router_settings = dict(
        router_queue_threshold=1.25,
        router_event_threads=1,
        router_queue_policy="wspt",
        router_temperature=0.0,
        overlap_score_weight=1.0,
        use_kv_events=True,
        durable_kv_events=False,
        router_replica_sync=False,
        router_track_active_blocks=True,
        router_track_output_blocks=False,
        router_assume_kv_reuse=True,
        router_snapshot_threshold=1000000,
        router_reset_states=False,
        router_ttl_secs=120.0,
        router_max_tree_size=1048576,
        router_prune_target_ratio=0.8,
        router_enable_cache_control=False,
        skip_initial_worker_wait=False,
        min_initial_workers=1,
        remote_indexer_component=None,
    )
    serialized = json.dumps(router_settings)
    return KvRouterConfig.from_json(serialized)
def _partial_router_config():
    """Return a ``KvRouterConfig`` built from only three keyword arguments.

    Every other constructor parameter is left at its default.
    """
    explicit_fields = {
        "router_queue_threshold": 1.25,
        "router_event_threads": 1,
        "router_queue_policy": "wspt",
    }
    return KvRouterConfig(**explicit_fields)
def _assert_basic_report_counts(report, *, num_requests, input_tokens, output_tokens):
assert report["num_requests"] == num_requests
assert report["completed_requests"] == num_requests
assert report["total_input_tokens"] == num_requests * input_tokens
assert report["total_output_tokens"] == num_requests * output_tokens
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
def test_run_trace_replay_smoke_matrix(tmp_path, engine_type, replay_mode, router_mode):
    """Trace replay reports the expected totals for every engine/mode/router combination."""
    trace = _write_trace_and_args(tmp_path)
    if engine_type == "vllm":
        engine_args = _vllm_args()
    else:
        engine_args = _sglang_args()
    # Round-robin runs on a single worker; the KV router is exercised with two.
    worker_count = 2 if router_mode != "round_robin" else 1

    outcome = run_trace_replay(
        trace,
        extra_engine_args=engine_args,
        num_workers=worker_count,
        replay_mode=replay_mode,
        router_mode=router_mode,
    )

    _assert_basic_report_counts(
        outcome,
        num_requests=2,
        input_tokens=64,
        output_tokens=2,
    )
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_invariant_counts_match(tmp_path, engine_type, replay_mode):
    """Request and token totals do not depend on worker count or router mode."""
    trace = _write_trace_and_args(tmp_path)
    if engine_type == "vllm":
        engine_args = _vllm_args()
    else:
        engine_args = _sglang_args()

    def replay(**routing):
        # Same trace, engine arguments and replay mode; only routing varies.
        return run_trace_replay(
            trace,
            extra_engine_args=engine_args,
            replay_mode=replay_mode,
            **routing,
        )

    baseline = replay(num_workers=1)
    via_round_robin = replay(num_workers=4, router_mode="round_robin")
    via_kv_router = replay(num_workers=4, router_mode="kv_router")

    invariant_fields = (
        "num_requests",
        "completed_requests",
        "total_input_tokens",
        "total_output_tokens",
    )
    for name in invariant_fields:
        assert baseline[name] == via_round_robin[name]
        assert baseline[name] == via_kv_router[name]
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
def test_run_synthetic_trace_replay_smoke_matrix(
    tmp_path, engine_type, replay_mode, router_mode
):
    """Synthetic replay reports the expected totals for every engine/mode/router combination."""
    if engine_type == "vllm":
        engine_args = _vllm_args()
    else:
        engine_args = _sglang_args()
    # Round-robin runs on a single worker; the KV router is exercised with two.
    worker_count = 2 if router_mode != "round_robin" else 1

    outcome = run_synthetic_trace_replay(
        64,
        2,
        2,
        extra_engine_args=engine_args,
        num_workers=worker_count,
        replay_mode=replay_mode,
        router_mode=router_mode,
        arrival_interval_ms=5.0,
    )

    _assert_basic_report_counts(
        outcome,
        num_requests=2,
        input_tokens=64,
        output_tokens=2,
    )
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_synthetic_trace_replay_invariant_counts_match(
    tmp_path, engine_type, replay_mode
):
    """Synthetic replay totals do not depend on worker count or router mode."""
    if engine_type == "vllm":
        engine_args = _vllm_args()
    else:
        engine_args = _sglang_args()

    def replay(**routing):
        # Same synthetic workload every time; only the routing options vary.
        return run_synthetic_trace_replay(
            64,
            2,
            2,
            extra_engine_args=engine_args,
            replay_mode=replay_mode,
            arrival_interval_ms=5.0,
            **routing,
        )

    baseline = replay(num_workers=1)
    via_round_robin = replay(num_workers=4, router_mode="round_robin")
    via_kv_router = replay(num_workers=4, router_mode="kv_router")

    invariant_fields = (
        "num_requests",
        "completed_requests",
        "total_input_tokens",
        "total_output_tokens",
    )
    for name in invariant_fields:
        assert baseline[name] == via_round_robin[name]
        assert baseline[name] == via_kv_router[name]
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_synthetic_concurrency_replay_counts_match(
    tmp_path, engine_type, replay_mode
):
    """Synthetic replay with ``replay_concurrency=2`` still completes all three requests."""
    if engine_type == "vllm":
        engine_args = _vllm_args()
    else:
        engine_args = _sglang_args()

    outcome = run_synthetic_trace_replay(
        64,
        2,
        3,
        extra_engine_args=engine_args,
        num_workers=2,
        replay_mode=replay_mode,
        replay_concurrency=2,
    )

    _assert_basic_report_counts(
        outcome,
        num_requests=3,
        input_tokens=64,
        output_tokens=2,
    )
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_router_config(tmp_path, replay_mode):
    """Trace replay accepts a fully specified ``KvRouterConfig`` with the KV router."""
    trace = _write_trace_and_args(tmp_path)
    engine_args = _vllm_args()
    kv_router_config = _router_config()

    outcome = run_trace_replay(
        trace,
        extra_engine_args=engine_args,
        router_config=kv_router_config,
        num_workers=2,
        replay_mode=replay_mode,
        router_mode="kv_router",
    )

    _assert_basic_report_counts(
        outcome,
        num_requests=2,
        input_tokens=64,
        output_tokens=2,
    )
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_partial_router_config_json(tmp_path, replay_mode):
    """Trace replay accepts a ``KvRouterConfig`` built from only a few keyword arguments."""
    trace = _write_trace_and_args(tmp_path)
    engine_args = _vllm_args()
    sparse_router_config = _partial_router_config()

    outcome = run_trace_replay(
        trace,
        extra_engine_args=engine_args,
        router_config=sparse_router_config,
        num_workers=2,
        replay_mode=replay_mode,
        router_mode="kv_router",
    )

    _assert_basic_report_counts(
        outcome,
        num_requests=2,
        input_tokens=64,
        output_tokens=2,
    )
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_partial_extra_engine_args_json(tmp_path, replay_mode):
    """Trace replay accepts ``MockEngineArgs`` built from only two keyword arguments."""
    trace = _write_trace_and_args(tmp_path)
    sparse_engine_args = MockEngineArgs(block_size=64, speedup_ratio=1000.0)

    outcome = run_trace_replay(
        trace,
        extra_engine_args=sparse_engine_args,
        num_workers=1,
        replay_mode=replay_mode,
    )

    _assert_basic_report_counts(
        outcome,
        num_requests=2,
        input_tokens=64,
        output_tokens=2,
    )
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transport abstraction for publishing batched KV cache events.
//!
//! Implementations handle the actual delivery mechanism (NATS event plane,
//! JetStream durable queue, direct indexer application, etc.). The trait lives
//! in this crate so that the batching processor and other routing logic can be
//! written generically; runtime-specific impls stay in `lib/llm`.
use std::future::Future;
use crate::protocols::RouterEvent;
/// Transport abstraction for publishing batched KV cache events.
///
/// Implementors must be `Send + Sync`, as required by the supertrait bounds.
pub trait EventSink: Send + Sync {
    /// Publishes `event` through this sink.
    ///
    /// The event is only borrowed. The returned future is `Send` and resolves
    /// to `Ok(())` on success or to an [`anyhow::Error`] otherwise.
    fn publish_event(&self, event: &RouterEvent)
    -> impl Future<Output = anyhow::Result<()>> + Send;
}
...@@ -245,1392 +245,1412 @@ async fn flush_and_settle(index: &dyn KvIndexerInterface) { ...@@ -245,1392 +245,1412 @@ async fn flush_and_settle(index: &dyn KvIndexerInterface) {
tokio::time::sleep(Duration::from_millis(100)).await; tokio::time::sleep(Duration::from_millis(100)).await;
} }
#[tokio::test] mod interface_tests {
#[apply(indexer_template)] use super::*;
async fn test_store_and_find(variant: &str) { use rstest_reuse::apply;
let index = make_indexer(variant);
#[tokio::test]
// Store a sequence for worker 0 #[apply(indexer_template)]
index.apply_event(make_store_event(0, &[1, 2, 3])).await; async fn test_store_and_find(variant: &str) {
let index = make_indexer(variant);
// Store a sequence for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
flush_and_settle(index.as_ref()).await;
// Find matches using local hashes
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
flush_and_settle(index.as_ref()).await; #[tokio::test]
#[apply(indexer_template)]
async fn test_partial_match(variant: &str) {
let index = make_indexer(variant);
// Find matches using local hashes // Store [1, 2, 3] for worker 0
let scores = index index.apply_event(make_store_event(0, &[1, 2, 3])).await;
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
#[tokio::test] flush_and_settle(index.as_ref()).await;
#[apply(indexer_template)]
async fn test_partial_match(variant: &str) {
let index = make_indexer(variant);
// Store [1, 2, 3] for worker 0 // Find matches for [1, 2, 999] - should match first 2 then stop
index.apply_event(make_store_event(0, &[1, 2, 3])).await; let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(999),
])
.await
.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
}
flush_and_settle(index.as_ref()).await; #[tokio::test]
#[apply(indexer_template)]
async fn test_remove(variant: &str) {
let index = make_indexer(variant);
// Find matches for [1, 2, 999] - should match first 2 then stop // Store sequence for worker 0
let scores = index index.apply_event(make_store_event(0, &[1, 2, 3])).await;
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(999),
])
.await
.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
}
#[tokio::test] // Remove all blocks
#[apply(indexer_template)] index.apply_event(make_remove_event(0, &[1, 2, 3])).await;
async fn test_remove(variant: &str) {
let index = make_indexer(variant);
// Store sequence for worker 0 flush_and_settle(index.as_ref()).await;
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Remove all blocks // Find should return nothing
index.apply_event(make_remove_event(0, &[1, 2, 3])).await; let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert!(scores.scores.is_empty());
}
flush_and_settle(index.as_ref()).await; #[tokio::test]
#[apply(indexer_template)]
async fn test_multiple_workers_shared_prefix(variant: &str) {
let index = make_indexer(variant);
// Find should return nothing // Worker 0 has [1, 2], Worker 1 has [1, 3]
let scores = index // Since sequence hashes are cumulative, [1] has same hash for both,
.find_matches(vec![ // but [1, 2] and [1, 3] have different hashes.
LocalBlockHash(1), index.apply_event(make_store_event(0, &[1, 2])).await;
LocalBlockHash(2), index.apply_event(make_store_event(1, &[1, 3])).await;
LocalBlockHash(3),
])
.await
.unwrap();
assert!(scores.scores.is_empty());
}
#[tokio::test] flush_and_settle(index.as_ref()).await;
#[apply(indexer_template)]
async fn test_multiple_workers_shared_prefix(variant: &str) {
let index = make_indexer(variant);
// Worker 0 has [1, 2], Worker 1 has [1, 3]
// Since sequence hashes are cumulative, [1] has same hash for both,
// but [1, 2] and [1, 3] have different hashes.
index.apply_event(make_store_event(0, &[1, 2])).await;
index.apply_event(make_store_event(1, &[1, 3])).await;
flush_and_settle(index.as_ref()).await;
// Query [1] - both workers should match
let scores = index.find_matches(vec![LocalBlockHash(1)]).await.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
// Query [1, 2] - worker 0 matches both, worker 1 matches only first block
let scores = index
.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)])
.await
.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
}
#[tokio::test] // Query [1] - both workers should match
#[apply(indexer_template)] let scores = index.find_matches(vec![LocalBlockHash(1)]).await.unwrap();
async fn test_remove_worker(variant: &str) { assert_eq!(scores.scores.len(), 2);
let index = make_indexer(variant); assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
index.apply_event(make_store_event(0, &[1, 2, 3])).await; // Query [1, 2] - worker 0 matches both, worker 1 matches only first block
index.apply_event(make_store_event(1, &[1, 2, 3])).await; let scores = index
.find_matches(vec![LocalBlockHash(1), LocalBlockHash(2)])
.await
.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), 1);
}
// Allow time for async event processing #[tokio::test]
flush_and_settle(index.as_ref()).await; #[apply(indexer_template)]
async fn test_remove_worker(variant: &str) {
let index = make_indexer(variant);
index.remove_worker(0).await; index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.apply_event(make_store_event(1, &[1, 2, 3])).await;
// Allow time for async remove_worker processing // Allow time for async event processing
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
let scores = index index.remove_worker(0).await;
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
}
#[tokio::test] // Allow time for async remove_worker processing
#[apply(indexer_template)] flush_and_settle(index.as_ref()).await;
async fn test_large_stores(variant: &str) {
let index = make_indexer(variant);
// Test sequences of increasing sizes let scores = index
for i in 0..10u64 { .find_matches(vec![
let len = 1 << i; // 1, 2, 4, 8, ..., 512 LocalBlockHash(1),
let worker_id = i; LocalBlockHash(2),
let sequence: Vec<u64> = (1..=len).map(|x| x + (i * 10000)).collect(); LocalBlockHash(3),
index ])
.apply_event(make_store_event(worker_id, &sequence)) .await
.await; .unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
} }
flush_and_settle(index.as_ref()).await; #[tokio::test]
#[apply(indexer_template)]
// Verify we can find matches for the last stored sequence async fn test_large_stores(variant: &str) {
let last_seq: Vec<LocalBlockHash> = (1..=512u64) let index = make_indexer(variant);
.map(|x| LocalBlockHash(x + (9 * 10000)))
.collect();
let scores = index.find_matches(last_seq).await.unwrap();
assert!(!scores.scores.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_dump_and_restore(variant: &str) {
let index = make_indexer(variant);
// Store some data
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.apply_event(make_store_event(1, &[1, 2, 4])).await;
// Allow background worker threads to process events. // Test sequences of increasing sizes
flush_and_settle(index.as_ref()).await; for i in 0..10u64 {
let len = 1 << i; // 1, 2, 4, 8, ..., 512
let worker_id = i;
let sequence: Vec<u64> = (1..=len).map(|x| x + (i * 10000)).collect();
index
.apply_event(make_store_event(worker_id, &sequence))
.await;
}
// Dump the tree as events and replay into a new index flush_and_settle(index.as_ref()).await;
let events = index.dump_events().await.unwrap();
assert!(!events.is_empty());
let restored = make_indexer(variant); // Verify we can find matches for the last stored sequence
for event in events { let last_seq: Vec<LocalBlockHash> = (1..=512u64)
restored.apply_event(event).await; .map(|x| LocalBlockHash(x + (9 * 10000)))
.collect();
let scores = index.find_matches(last_seq).await.unwrap();
assert!(!scores.scores.is_empty());
} }
flush_and_settle(restored.as_ref()).await; #[tokio::test]
#[apply(indexer_template)]
assert_eq!( async fn test_dump_and_restore(variant: &str) {
snapshot_tree(index.as_ref()).await, let index = make_indexer(variant);
snapshot_tree(restored.as_ref()).await
);
}
#[tokio::test] // Store some data
#[apply(indexer_template)] index.apply_event(make_store_event(0, &[1, 2, 3])).await;
async fn test_clear_all_blocks(variant: &str) { index.apply_event(make_store_event(1, &[1, 2, 4])).await;
let index = make_indexer(variant);
// Store some data for two workers // Allow background worker threads to process events.
index.apply_event(make_store_event(0, &[1, 2, 3])).await; flush_and_settle(index.as_ref()).await;
index.apply_event(make_store_event(1, &[1, 2, 3])).await;
// Clear worker 0's blocks using the Cleared event // Dump the tree as events and replay into a new index
index.apply_event(make_clear_event(0)).await; let events = index.dump_events().await.unwrap();
assert!(!events.is_empty());
flush_and_settle(index.as_ref()).await; let restored = make_indexer(variant);
for event in events {
restored.apply_event(event).await;
}
// Worker 0's blocks should be gone, worker 1's remain flush_and_settle(restored.as_ref()).await;
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
}
#[tokio::test] assert_eq!(
#[apply(indexer_template)] snapshot_tree(index.as_ref()).await,
async fn test_empty_query(variant: &str) { snapshot_tree(restored.as_ref()).await
let index = make_indexer(variant); );
}
index.apply_event(make_store_event(0, &[1, 2, 3])).await; #[tokio::test]
#[apply(indexer_template)]
async fn test_clear_all_blocks(variant: &str) {
let index = make_indexer(variant);
flush_and_settle(index.as_ref()).await; // Store some data for two workers
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
index.apply_event(make_store_event(1, &[1, 2, 3])).await;
// Empty query should return empty scores // Clear worker 0's blocks using the Cleared event
let scores = index.find_matches(vec![]).await.unwrap(); index.apply_event(make_clear_event(0)).await;
assert!(scores.scores.is_empty());
}
#[tokio::test] flush_and_settle(index.as_ref()).await;
#[apply(indexer_template)]
async fn test_miss_query(variant: &str) {
let index = make_indexer(variant);
index.apply_event(make_store_event(0, &[1, 2, 3])).await; // Worker 0's blocks should be gone, worker 1's remain
let scores = index
.find_matches(vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
])
.await
.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
}
flush_and_settle(index.as_ref()).await; #[tokio::test]
#[apply(indexer_template)]
async fn test_empty_query(variant: &str) {
let index = make_indexer(variant);
// Query for non-existent blocks index.apply_event(make_store_event(0, &[1, 2, 3])).await;
let scores = index
.find_matches(vec![LocalBlockHash(999), LocalBlockHash(998)])
.await
.unwrap();
assert!(scores.scores.is_empty());
}
#[tokio::test] flush_and_settle(index.as_ref()).await;
#[apply(indexer_template)]
async fn test_shutdown(variant: &str) {
    // Smoke test: shutting down a freshly built, never-used indexer must not panic.
    make_indexer(variant).shutdown();
}
#[tokio::test] // Empty query should return empty scores
#[apply(indexer_template)] let scores = index.find_matches(vec![]).await.unwrap();
async fn test_shutdown_idempotent(variant: &str) { assert!(scores.scores.is_empty());
let index = make_indexer(variant); }
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
flush_and_settle(index.as_ref()).await;
index.shutdown();
index.shutdown();
}
#[tokio::test] #[tokio::test]
#[apply(indexer_template)] #[apply(indexer_template)]
async fn test_find_matches_for_request(variant: &str) { async fn test_miss_query(variant: &str) {
let index = make_indexer(variant); let index = make_indexer(variant);
// Empty index should return no matches
let tokens = vec![1, 2, 3, 4];
let scores = index.find_matches_for_request(&tokens, None).await.unwrap();
assert!(scores.scores.is_empty());
// Store some data and verify we can find it via tokens
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Allow time for async processing
flush_and_settle(index.as_ref()).await;
// Note: find_matches_for_request computes block hashes from tokens,
// so we need tokens that hash to the same LocalBlockHash values.
// For this test, we just verify the method works without error.
let scores = index.find_matches_for_request(&tokens, None).await.unwrap();
// The tokens [1,2,3,4] won't match our stored [1,2,3] local hashes
// because find_matches_for_request computes different hashes from raw tokens
assert!(scores.scores.is_empty() || !scores.scores.is_empty());
}
#[tokio::test] index.apply_event(make_store_event(0, &[1, 2, 3])).await;
#[apply(indexer_template)]
async fn test_process_routing_decision(variant: &str) {
let index = make_indexer(variant);
// Create tokens with hashes flush_and_settle(index.as_ref()).await;
let tokens = vec![1u32, 2, 3, 4, 5, 6, 7, 8];
let mut tokens_with_hashes = TokensWithHashes::new(tokens, 32);
let worker = WorkerWithDpRank::new(0, 0); // Query for non-existent blocks
let scores = index
.find_matches(vec![LocalBlockHash(999), LocalBlockHash(998)])
.await
.unwrap();
assert!(scores.scores.is_empty());
}
// Process routing decision - should not error #[tokio::test]
let result = index #[apply(indexer_template)]
.process_routing_decision_for_request(&mut tokens_with_hashes, worker) async fn test_shutdown(variant: &str) {
.await; let index = make_indexer(variant);
assert!(result.is_ok()); index.shutdown();
} }
#[tokio::test] #[tokio::test]
#[apply(indexer_template)] #[apply(indexer_template)]
async fn test_parent_hash_chains(variant: &str) { async fn test_shutdown_idempotent(variant: &str) {
let index = make_indexer(variant); let index = make_indexer(variant);
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
flush_and_settle(index.as_ref()).await;
index.shutdown();
index.shutdown();
}
// Store initial sequence [1, 2, 3] #[tokio::test]
index.apply_event(make_store_event(0, &[1, 2, 3])).await; #[apply(indexer_template)]
async fn test_find_matches_for_request(variant: &str) {
    // Exercises the token-based lookup path on an empty and on a populated index.
    let index = make_indexer(variant);

    // Empty index should return no matches.
    let tokens = vec![1, 2, 3, 4];
    let scores = index.find_matches_for_request(&tokens, None).await.unwrap();
    assert!(scores.scores.is_empty());

    // Store some data, then wait for the asynchronous event processing to finish.
    index.apply_event(make_store_event(0, &[1, 2, 3])).await;
    flush_and_settle(index.as_ref()).await;

    // find_matches_for_request computes block hashes from the raw tokens, so a
    // hit would need tokens that hash to the stored LocalBlockHash values.
    // This test does not construct such tokens; it only verifies that the call
    // succeeds on a populated index. The `unwrap()` is the check. The previous
    // `is_empty() || !is_empty()` assertion was a tautology that could never
    // fail, so it has been removed instead of pretending to verify the scores.
    let _scores = index.find_matches_for_request(&tokens, None).await.unwrap();
}
// Store continuation [4, 5] with parent pointing to block 3 #[tokio::test]
index #[apply(indexer_template)]
.apply_event(make_store_event_with_parent(0, &[1, 2, 3], &[4, 5])) async fn test_process_routing_decision(variant: &str) {
.await; let index = make_indexer(variant);
flush_and_settle(index.as_ref()).await; // Create tokens with hashes
let tokens = vec![1u32, 2, 3, 4, 5, 6, 7, 8];
let mut tokens_with_hashes = TokensWithHashes::new(tokens, 32);
// Query for full sequence [1, 2, 3, 4, 5] should match all 5 blocks let worker = WorkerWithDpRank::new(0, 0);
let full_seq: Vec<LocalBlockHash> = (1..=5).map(LocalBlockHash).collect();
let scores = index.find_matches(full_seq).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 5);
// Query for just [1, 2, 3] should match 3 blocks // Process routing decision - should not error
let prefix_seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect(); let result = index
let scores = index.find_matches(prefix_seq).await.unwrap(); .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3); .await;
} assert!(result.is_ok());
}
#[tokio::test] #[tokio::test]
#[apply(indexer_template)] #[apply(indexer_template)]
async fn test_multiple_dp_ranks(variant: &str) { async fn test_parent_hash_chains(variant: &str) {
let index = make_indexer(variant); let index = make_indexer(variant);
// Same worker_id but different dp_ranks should be tracked separately
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 0))
.await;
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 1))
.await;
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 2))
.await;
flush_and_settle(index.as_ref()).await;
// Query should return all 3 dp_ranks as separate entries
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(scores.scores.len(), 3);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 1)).unwrap(), 3);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 2)).unwrap(), 3);
}
#[tokio::test] // Store initial sequence [1, 2, 3]
#[apply(indexer_template)] index.apply_event(make_store_event(0, &[1, 2, 3])).await;
async fn test_partial_block_removal(variant: &str) {
let index = make_indexer(variant);
// Store [1, 2, 3] // Store continuation [4, 5] with parent pointing to block 3
index.apply_event(make_store_event(0, &[1, 2, 3])).await; index
.apply_event(make_store_event_with_parent(0, &[1, 2, 3], &[4, 5]))
.await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
// Verify all 3 blocks match // Query for full sequence [1, 2, 3, 4, 5] should match all 5 blocks
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect(); let full_seq: Vec<LocalBlockHash> = (1..=5).map(LocalBlockHash).collect();
let scores = index.find_matches(seq.clone()).await.unwrap(); let scores = index.find_matches(full_seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3); assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 5);
// Remove only the last block (block 3) // Query for just [1, 2, 3] should match 3 blocks
// To do this correctly, we need to compute the seq_hash for block 3 specifically, let prefix_seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
// which requires the full sequence context [1,2,3]. let scores = index.find_matches(prefix_seq).await.unwrap();
let full_hashes: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect(); assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
let seq_hashes = compute_seq_hash_for_block(&full_hashes); }
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); // Last block's hash
let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]); #[tokio::test]
index.apply_event(remove_event).await; #[apply(indexer_template)]
async fn test_multiple_dp_ranks(variant: &str) {
let index = make_indexer(variant);
flush_and_settle(index.as_ref()).await; // Same worker_id but different dp_ranks should be tracked separately
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 0))
.await;
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 1))
.await;
index
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 2))
.await;
// Query [1, 2, 3] - should only match 2 blocks now (block 3 is removed) flush_and_settle(index.as_ref()).await;
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
// Query [1, 2] - should still match 2 blocks // Query should return all 3 dp_ranks as separate entries
let partial_seq: Vec<LocalBlockHash> = (1..=2).map(LocalBlockHash).collect(); let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(partial_seq).await.unwrap(); let scores = index.find_matches(seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
}
#[tokio::test] assert_eq!(scores.scores.len(), 3);
#[apply(indexer_template)] assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
async fn test_remove_mid_chain_block(variant: &str) { assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 1)).unwrap(), 3);
// TODO: positional indexer has no parent-child structure, so mid-chain removal assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 2)).unwrap(), 3);
// doesn't invalidate later positions — jump search skips over the gap and over-counts.
if variant == "flat" {
return;
} }
let index = make_indexer(variant); #[tokio::test]
#[apply(indexer_template)]
// Store [1, 2, 3, 4, 5] async fn test_partial_block_removal(variant: &str) {
index let index = make_indexer(variant);
.apply_event(make_store_event(0, &[1, 2, 3, 4, 5]))
.await;
flush_and_settle(index.as_ref()).await; // Store [1, 2, 3]
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Verify all 5 blocks match flush_and_settle(index.as_ref()).await;
let seq: Vec<LocalBlockHash> = (1..=5).map(LocalBlockHash).collect();
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 5);
// Remove only block 3 (index 2) — the middle of the chain // Verify all 3 blocks match
let full_hashes: Vec<LocalBlockHash> = (1..=5).map(LocalBlockHash).collect(); let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let seq_hashes = compute_seq_hash_for_block(&full_hashes); let scores = index.find_matches(seq.clone()).await.unwrap();
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]); // Remove only the last block (block 3)
index.apply_event(remove_event).await; // To do this correctly, we need to compute the seq_hash for block 3 specifically,
// which requires the full sequence context [1,2,3].
let full_hashes: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]); // Last block's hash
flush_and_settle(index.as_ref()).await; let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
index.apply_event(remove_event).await;
// Query [1, 2, 3, 4, 5] — only first 2 positions reachable (block 3 removed, orphaning 4 & 5) flush_and_settle(index.as_ref()).await;
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
// Query [1, 2] — prefix before the gap is still intact // Query [1, 2, 3] - should only match 2 blocks now (block 3 is removed)
let prefix_seq: Vec<LocalBlockHash> = (1..=2).map(LocalBlockHash).collect(); let scores = index.find_matches(seq).await.unwrap();
let scores = index.find_matches(prefix_seq).await.unwrap(); assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
// Re-store block 3 as a continuation of [1, 2] // Query [1, 2] - should still match 2 blocks
index let partial_seq: Vec<LocalBlockHash> = (1..=2).map(LocalBlockHash).collect();
.apply_event(make_store_event_with_parent(0, &[1, 2], &[3])) let scores = index.find_matches(partial_seq).await.unwrap();
.await; assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
}
flush_and_settle(index.as_ref()).await; #[tokio::test]
#[apply(indexer_template)]
async fn test_remove_mid_chain_block(variant: &str) {
// TODO: positional indexer has no parent-child structure, so mid-chain removal
// doesn't invalidate later positions — jump search skips over the gap and over-counts.
if variant == "flat" {
return;
}
// Query [1, 2, 3, 4, 5] — block 3 is back but 4 & 5 were orphaned, so score = 3 let index = make_indexer(variant);
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
#[tokio::test] // Store [1, 2, 3, 4, 5]
#[apply(indexer_template)] index
async fn test_remove_nonexistent_worker(variant: &str) { .apply_event(make_store_event(0, &[1, 2, 3, 4, 5]))
let index = make_indexer(variant); .await;
// Store data for worker 0 flush_and_settle(index.as_ref()).await;
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
flush_and_settle(index.as_ref()).await; // Verify all 5 blocks match
let seq: Vec<LocalBlockHash> = (1..=5).map(LocalBlockHash).collect();
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 5);
// Remove non-existent worker 999 - should not error or affect worker 0 // Remove only block 3 (index 2) — the middle of the chain
index.remove_worker(999).await; let full_hashes: Vec<LocalBlockHash> = (1..=5).map(LocalBlockHash).collect();
let seq_hashes = compute_seq_hash_for_block(&full_hashes);
let block_3_seq_hash = ExternalSequenceBlockHash(seq_hashes[2]);
// Allow time for async processing let remove_event = remove_event(0, 0, 0, vec![block_3_seq_hash]);
flush_and_settle(index.as_ref()).await; index.apply_event(remove_event).await;
// Worker 0's data should still be there flush_and_settle(index.as_ref()).await;
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
}
#[tokio::test] // Query [1, 2, 3, 4, 5] — only first 2 positions reachable (block 3 removed, orphaning 4 & 5)
#[apply(indexer_template)] let scores = index.find_matches(seq.clone()).await.unwrap();
async fn test_remove_nonexistent_blocks(variant: &str) { assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
let index = make_indexer(variant);
// Store [1, 2, 3] // Query [1, 2] — prefix before the gap is still intact
index.apply_event(make_store_event(0, &[1, 2, 3])).await; let prefix_seq: Vec<LocalBlockHash> = (1..=2).map(LocalBlockHash).collect();
let scores = index.find_matches(prefix_seq).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 2);
// Try to remove blocks [999, 998] that don't exist - should not error // Re-store block 3 as a continuation of [1, 2]
index.apply_event(make_remove_event(0, &[999, 998])).await; index
.apply_event(make_store_event_with_parent(0, &[1, 2], &[3]))
.await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
// Original data should still be there // Query [1, 2, 3, 4, 5] — block 3 is back but 4 & 5 were orphaned, so score = 3
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect(); let scores = index.find_matches(seq).await.unwrap();
let scores = index.find_matches(seq).await.unwrap(); assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3); }
}
#[tokio::test] #[tokio::test]
#[apply(indexer_template)] #[apply(indexer_template)]
async fn test_clear_then_reuse(variant: &str) { async fn test_remove_nonexistent_worker(variant: &str) {
let index = make_indexer(variant); let index = make_indexer(variant);
// Store initial data // Store data for worker 0
index.apply_event(make_store_event(0, &[1, 2, 3])).await; index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Clear the worker flush_and_settle(index.as_ref()).await;
index.apply_event(make_clear_event(0)).await;
flush_and_settle(index.as_ref()).await; // Remove non-existent worker 999 - should not error or affect worker 0
index.remove_worker(999).await;
// Verify data is gone // Allow time for async processing
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect(); flush_and_settle(index.as_ref()).await;
let scores = index.find_matches(seq.clone()).await.unwrap();
assert!(scores.scores.is_empty());
// Store new data for the same worker // Worker 0's data should still be there
index.apply_event(make_store_event(0, &[1, 2, 3])).await; let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert!(scores.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
}
flush_and_settle(index.as_ref()).await; #[tokio::test]
#[apply(indexer_template)]
async fn test_remove_nonexistent_blocks(variant: &str) {
let index = make_indexer(variant);
// Verify new data is accessible // Store [1, 2, 3]
let scores = index.find_matches(seq).await.unwrap(); index.apply_event(make_store_event(0, &[1, 2, 3])).await;
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
#[tokio::test] // Try to remove blocks [999, 998] that don't exist - should not error
#[apply(indexer_template)] index.apply_event(make_remove_event(0, &[999, 998])).await;
async fn test_multiple_sequences_per_worker(variant: &str) {
let index = make_indexer(variant);
// Store two disjoint sequences for the same worker
// Sequence 1: [1, 2, 3]
index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// Sequence 2: [100, 101, 102] (completely different, no parent)
index
.apply_event(make_store_event(0, &[100, 101, 102]))
.await;
flush_and_settle(index.as_ref()).await;
// Query first sequence
let seq1: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let scores = index.find_matches(seq1).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
// Query second sequence
let seq2: Vec<LocalBlockHash> = (100..=102).map(LocalBlockHash).collect();
let scores = index.find_matches(seq2).await.unwrap();
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
// Query a mix that doesn't exist as a sequence - should only match first block
let mixed: Vec<LocalBlockHash> = vec![LocalBlockHash(1), LocalBlockHash(100)];
let scores = index.find_matches(mixed).await.unwrap();
// Only block 1 matches because [1, 100] is not a valid prefix
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 1);
}
#[tokio::test] flush_and_settle(index.as_ref()).await;
#[apply(indexer_template)]
async fn test_clear_clears_all_dp_ranks(variant: &str) {
let index = make_indexer(variant);
// Store same sequence for different dp_ranks // Original data should still be there
index let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 0)) let scores = index.find_matches(seq).await.unwrap();
.await; assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
index }
.apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 1))
.await;
flush_and_settle(index.as_ref()).await; #[tokio::test]
#[apply(indexer_template)]
async fn test_clear_then_reuse(variant: &str) {
let index = make_indexer(variant);
// Verify both dp_ranks are present // Store initial data
let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect(); index.apply_event(make_store_event(0, &[1, 2, 3])).await;
let scores = index.find_matches(seq.clone()).await.unwrap();
assert_eq!(scores.scores.len(), 2);
// Clear event clears ALL blocks for the worker_id, regardless of dp_rank // Clear the worker
index.apply_event(make_clear_event_with_dp_rank(0, 0)).await; index.apply_event(make_clear_event(0)).await;
flush_and_settle(index.as_ref()).await; flush_and_settle(index.as_ref()).await;
// Both dp_ranks should be cleared // Verify data is gone
let scores = index.find_matches(seq).await.unwrap(); let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
assert!( let scores = index.find_matches(seq.clone()).await.unwrap();
scores.scores.is_empty(), assert!(scores.scores.is_empty());
"Cleared event should clear all dp_ranks for a worker"
);
}
// ============================================================================ // Store new data for the same worker
// LoRA isolation tests index.apply_event(make_store_event(0, &[1, 2, 3])).await;
// ============================================================================
#[tokio::test] flush_and_settle(index.as_ref()).await;
#[apply(indexer_template)]
async fn test_lora_and_base_model_blocks_do_not_conflict(variant: &str) {
let index = make_indexer(variant);
let kv_block_size: u32 = 32;
// Same token sequence for both base model and LoRA adapter // Verify new data is accessible
let tokens: Vec<u32> = (0..kv_block_size * 3).collect(); let scores = index.find_matches(seq).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
}
let base_hashes = compute_block_hash_for_seq(&tokens, kv_block_size, None, None); #[tokio::test]
let lora_hashes = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter")); #[apply(indexer_template)]
async fn test_multiple_sequences_per_worker(variant: &str) {
let index = make_indexer(variant);
// Hashes must differ despite identical tokens // Store two disjoint sequences for the same worker
assert_ne!( // Sequence 1: [1, 2, 3]
base_hashes, lora_hashes, index.apply_event(make_store_event(0, &[1, 2, 3])).await;
"Base and LoRA hashes must differ for the same tokens" // Sequence 2: [100, 101, 102] (completely different, no parent)
); index
.apply_event(make_store_event(0, &[100, 101, 102]))
.await;
let base_seq = compute_seq_hash_for_block(&base_hashes); flush_and_settle(index.as_ref()).await;
let lora_seq = compute_seq_hash_for_block(&lora_hashes);
// Store base-model blocks on worker 0 // Query first sequence
let base_event = router_event( let seq1: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
0, let scores = index.find_matches(seq1).await.unwrap();
0, assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&base_hashes, &base_seq),
}),
);
index.apply_event(base_event).await;
// Store LoRA blocks on worker 1 // Query second sequence
let lora_event = router_event( let seq2: Vec<LocalBlockHash> = (100..=102).map(LocalBlockHash).collect();
1, let scores = index.find_matches(seq2).await.unwrap();
0, assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 3);
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&lora_hashes, &lora_seq),
}),
);
index.apply_event(lora_event).await;
flush_and_settle(index.as_ref()).await; // Query a mix that doesn't exist as a sequence - should only match first block
let mixed: Vec<LocalBlockHash> = vec![LocalBlockHash(1), LocalBlockHash(100)];
let scores = index.find_matches(mixed).await.unwrap();
// Only block 1 matches because [1, 100] is not a valid prefix
assert_eq!(*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 1);
}
// Query with base-model hashes → only worker 0 #[tokio::test]
let base_scores = index.find_matches(base_hashes.clone()).await.unwrap(); #[apply(indexer_template)]
assert_eq!( async fn test_clear_clears_all_dp_ranks(variant: &str) {
base_scores.scores.len(), let index = make_indexer(variant);
1,
"Only base-model worker should match"
);
assert_eq!(
*base_scores
.scores
.get(&WorkerWithDpRank::new(0, 0))
.unwrap(),
3
);
// Query with LoRA hashes → only worker 1 // Store same sequence for different dp_ranks
let lora_scores = index.find_matches(lora_hashes.clone()).await.unwrap(); index
assert_eq!(lora_scores.scores.len(), 1, "Only LoRA worker should match"); .apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 0))
assert_eq!( .await;
*lora_scores index
.scores .apply_event(make_store_event_with_dp_rank(0, &[1, 2, 3], 1))
.get(&WorkerWithDpRank::new(1, 0)) .await;
.unwrap(),
3
);
}
/// Reproduces the "block_hash mismatch: sequence hashes should be uniform flush_and_settle(index.as_ref()).await;
/// across workers" warning seen when the same prompt is sent to both a base
/// model worker and a LoRA worker.
///
/// On main (without LoRA-aware hashing), both workers compute the same
/// LocalBlockHash for identical tokens. But vLLM's engine includes the
/// adapter in its rolling ExternalSequenceBlockHash, so the radix tree
/// sees conflicting sequence hashes at the same tree node.
///
/// With LoRA-aware hashing, compute_block_hash_for_seq produces distinct
/// LocalBlockHash values for different adapters, so the blocks land on
/// separate tree paths and no mismatch occurs.
#[tokio::test]
#[apply(indexer_template)]
async fn test_lora_base_same_tokens_no_seq_hash_mismatch(variant: &str) {
let index = make_indexer(variant);
let kv_block_size: u32 = 32;
let tokens: Vec<u32> = (0..kv_block_size * 3).collect();
// With LoRA-aware hashing, base and adapter produce different LocalBlockHash
let base_local = compute_block_hash_for_seq(&tokens, kv_block_size, None, None);
let lora_local = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter"));
assert_ne!(
base_local, lora_local,
"LoRA-aware hashing must produce different LocalBlockHash values"
);
// Simulate what vLLM does: same tokens, different rolling seq hashes // Verify both dp_ranks are present
// because the engine accounts for the adapter internally. let seq: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
let base_seq = compute_seq_hash_for_block(&base_local); let scores = index.find_matches(seq.clone()).await.unwrap();
let lora_seq = compute_seq_hash_for_block(&lora_local); assert_eq!(scores.scores.len(), 2);
// Worker 0: base model // Clear event clears ALL blocks for the worker_id, regardless of dp_rank
index index.apply_event(make_clear_event_with_dp_rank(0, 0)).await;
.apply_event(router_event(
0,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&base_local, &base_seq),
}),
))
.await;
// Worker 1: LoRA adapter — different LocalBlockHash, so this goes to flush_and_settle(index.as_ref()).await;
// a separate tree path instead of colliding with worker 0's node.
index
.apply_event(router_event(
1,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&lora_local, &lora_seq),
}),
))
.await;
flush_and_settle(index.as_ref()).await;
// Base query finds only worker 0
let base_scores = index.find_matches(base_local.clone()).await.unwrap();
assert_eq!(base_scores.scores.len(), 1);
assert_eq!(
*base_scores
.scores
.get(&WorkerWithDpRank::new(0, 0))
.unwrap(),
3
);
// LoRA query finds only worker 1 // Both dp_ranks should be cleared
let lora_scores = index.find_matches(lora_local.clone()).await.unwrap(); let scores = index.find_matches(seq).await.unwrap();
assert_eq!(lora_scores.scores.len(), 1); assert!(
assert_eq!( scores.scores.is_empty(),
*lora_scores "Cleared event should clear all dp_ranks for a worker"
.scores );
.get(&WorkerWithDpRank::new(1, 0)) }
.unwrap(),
3
);
} }
#[tokio::test] // ============================================================================
#[apply(indexer_template)] // LoRA isolation tests
async fn test_different_lora_adapters_do_not_conflict(variant: &str) { // ============================================================================
let index = make_indexer(variant);
let kv_block_size: u32 = 32;
let tokens: Vec<u32> = (0..kv_block_size * 2).collect(); mod lora_tests {
use super::*;
use rstest_reuse::apply;
let hashes_a = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-a")); #[tokio::test]
let hashes_b = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-b")); #[apply(indexer_template)]
async fn test_lora_and_base_model_blocks_do_not_conflict(variant: &str) {
let index = make_indexer(variant);
let kv_block_size: u32 = 32;
assert_ne!( // Same token sequence for both base model and LoRA adapter
hashes_a, hashes_b, let tokens: Vec<u32> = (0..kv_block_size * 3).collect();
"Different adapters must produce different hashes"
);
let seq_a = compute_seq_hash_for_block(&hashes_a); let base_hashes = compute_block_hash_for_seq(&tokens, kv_block_size, None, None);
let seq_b = compute_seq_hash_for_block(&hashes_b); let lora_hashes =
compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter"));
// Store adapter-a blocks on worker 0 // Hashes must differ despite identical tokens
index assert_ne!(
.apply_event(router_event( base_hashes, lora_hashes,
"Base and LoRA hashes must differ for the same tokens"
);
let base_seq = compute_seq_hash_for_block(&base_hashes);
let lora_seq = compute_seq_hash_for_block(&lora_hashes);
// Store base-model blocks on worker 0
let base_event = router_event(
0, 0,
0, 0,
0, 0,
KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&hashes_a, &seq_a), blocks: stored_blocks_with_sequence_hashes(&base_hashes, &base_seq),
}), }),
)) );
.await; index.apply_event(base_event).await;
// Store adapter-b blocks on worker 1 // Store LoRA blocks on worker 1
index let lora_event = router_event(
.apply_event(router_event(
1, 1,
0, 0,
0, 0,
KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&hashes_b, &seq_b), blocks: stored_blocks_with_sequence_hashes(&lora_hashes, &lora_seq),
}), }),
)) );
.await; index.apply_event(lora_event).await;
flush_and_settle(index.as_ref()).await;
// Query adapter-a → only worker 0
let scores_a = index.find_matches(hashes_a.clone()).await.unwrap();
assert_eq!(scores_a.scores.len(), 1);
assert!(scores_a.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
assert!(!scores_a.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
// Query adapter-b → only worker 1
let scores_b = index.find_matches(hashes_b.clone()).await.unwrap();
assert_eq!(scores_b.scores.len(), 1);
assert!(scores_b.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
assert!(!scores_b.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
}
// ============================================================================
// Long sequence tests - especially important for NestedMap/PositionalIndexer
// ============================================================================
#[tokio::test] flush_and_settle(index.as_ref()).await;
#[apply(indexer_template)]
async fn test_long_sequence_single_store(variant: &str) {
let index = make_indexer(variant);
// Store a long sequence (128 blocks) in a single event
let seq_len = 128;
let sequence: Vec<u64> = (1..=seq_len).collect();
index.apply_event(make_store_event(0, &sequence)).await;
flush_and_settle(index.as_ref()).await;
// Query full sequence - should match all blocks
let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
seq_len as u32
);
// Query prefix (first 64 blocks) // Query with base-model hashes → only worker 0
let prefix_query: Vec<LocalBlockHash> = (1..=64).map(LocalBlockHash).collect(); let base_scores = index.find_matches(base_hashes.clone()).await.unwrap();
let scores = index.find_matches(prefix_query).await.unwrap(); assert_eq!(
assert_eq!( base_scores.scores.len(),
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 1,
64 "Only base-model worker should match"
); );
assert_eq!(
*base_scores
.scores
.get(&WorkerWithDpRank::new(0, 0))
.unwrap(),
3
);
// Query with divergence at position 50 // Query with LoRA hashes → only worker 1
let mut divergent_query: Vec<LocalBlockHash> = (1..=100).map(LocalBlockHash).collect(); let lora_scores = index.find_matches(lora_hashes.clone()).await.unwrap();
divergent_query[49] = LocalBlockHash(99999); // Position 49 (0-indexed) diverges assert_eq!(lora_scores.scores.len(), 1, "Only LoRA worker should match");
let scores = index.find_matches(divergent_query).await.unwrap(); assert_eq!(
assert_eq!( *lora_scores
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), .scores
49 .get(&WorkerWithDpRank::new(1, 0))
); .unwrap(),
} 3
);
}
#[tokio::test] /// Reproduces the "block_hash mismatch: sequence hashes should be uniform
#[apply(indexer_template)] /// across workers" warning seen when the same prompt is sent to both a base
async fn test_long_sequence_multiple_continuations(variant: &str) { /// model worker and a LoRA worker.
let index = make_indexer(variant); ///
/// On main (without LoRA-aware hashing), both workers compute the same
// Build a long sequence through multiple continuations /// LocalBlockHash for identical tokens. But vLLM's engine includes the
// First store: blocks 1-50 /// adapter in its rolling ExternalSequenceBlockHash, so the radix tree
let first_chunk: Vec<u64> = (1..=50).collect(); /// sees conflicting sequence hashes at the same tree node.
index.apply_event(make_store_event(0, &first_chunk)).await; ///
/// With LoRA-aware hashing, compute_block_hash_for_seq produces distinct
// Second store: blocks 51-100 (continuation of first) /// LocalBlockHash values for different adapters, so the blocks land on
let second_chunk: Vec<u64> = (51..=100).collect(); /// separate tree paths and no mismatch occurs.
index #[tokio::test]
.apply_event(make_store_event_with_parent(0, &first_chunk, &second_chunk)) #[apply(indexer_template)]
.await; async fn test_lora_base_same_tokens_no_seq_hash_mismatch(variant: &str) {
let index = make_indexer(variant);
// Third store: blocks 101-150 (continuation of second) let kv_block_size: u32 = 32;
let prefix_1_2: Vec<u64> = (1..=100).collect();
let third_chunk: Vec<u64> = (101..=150).collect(); let tokens: Vec<u32> = (0..kv_block_size * 3).collect();
index
.apply_event(make_store_event_with_parent(0, &prefix_1_2, &third_chunk)) // With LoRA-aware hashing, base and adapter produce different LocalBlockHash
.await; let base_local = compute_block_hash_for_seq(&tokens, kv_block_size, None, None);
let lora_local =
flush_and_settle(index.as_ref()).await; compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("my-adapter"));
// Query full sequence - should match all 150 blocks assert_ne!(
let full_query: Vec<LocalBlockHash> = (1..=150).map(LocalBlockHash).collect(); base_local, lora_local,
let scores = index.find_matches(full_query).await.unwrap(); "LoRA-aware hashing must produce different LocalBlockHash values"
assert_eq!(scores.scores.len(), 1); );
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
150
);
// Query crossing continuation boundaries // Simulate what vLLM does: same tokens, different rolling seq hashes
let cross_boundary_query: Vec<LocalBlockHash> = (45..=105).map(LocalBlockHash).collect(); // because the engine accounts for the adapter internally.
let scores = index.find_matches(cross_boundary_query).await.unwrap(); let base_seq = compute_seq_hash_for_block(&base_local);
// Query starts at block 45, but stored sequence starts at 1, so this won't match let lora_seq = compute_seq_hash_for_block(&lora_local);
// because the sequence hash at position 0 of our query (block 45) won't match
// the stored sequence hash at position 0 (block 1)
assert!(scores.scores.is_empty() || !scores.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
}
#[tokio::test] // Worker 0: base model
#[apply(indexer_template)] index
async fn test_long_sequence_branching_continuations(variant: &str) { .apply_event(router_event(
let index = make_indexer(variant); 0,
0,
// Common prefix: blocks 1-30 0,
let common_prefix: Vec<u64> = (1..=30).collect(); KvCacheEventData::Stored(KvCacheStoreData {
index.apply_event(make_store_event(0, &common_prefix)).await; parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&base_local, &base_seq),
// Branch A: blocks 31-60 on worker 0 }),
let branch_a: Vec<u64> = (31..=60).collect(); ))
index .await;
.apply_event(make_store_event_with_parent(0, &common_prefix, &branch_a))
.await;
// Branch B: blocks 131-160 (different content) on worker 1
// First store the common prefix for worker 1
index.apply_event(make_store_event(1, &common_prefix)).await;
let branch_b: Vec<u64> = (131..=160).collect();
index
.apply_event(make_store_event_with_parent(1, &common_prefix, &branch_b))
.await;
flush_and_settle(index.as_ref()).await;
// Query common prefix - both workers should match
let prefix_query: Vec<LocalBlockHash> = (1..=30).map(LocalBlockHash).collect();
let scores = index.find_matches(prefix_query).await.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
30
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
30
);
// Query branch A path - only worker 0 should match fully // Worker 1: LoRA adapter — different LocalBlockHash, so this goes to
let branch_a_query: Vec<LocalBlockHash> = (1..=60).map(LocalBlockHash).collect(); // a separate tree path instead of colliding with worker 0's node.
let scores = index.find_matches(branch_a_query).await.unwrap(); index
assert_eq!( .apply_event(router_event(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), 1,
60 0,
); 0,
assert_eq!( KvCacheEventData::Stored(KvCacheStoreData {
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), parent_hash: None,
30 blocks: stored_blocks_with_sequence_hashes(&lora_local, &lora_seq),
); }),
} ))
.await;
#[tokio::test] flush_and_settle(index.as_ref()).await;
#[apply(indexer_template)]
async fn test_long_sequence_partial_removal(variant: &str) {
let index = make_indexer(variant);
// Store a long sequence // Base query finds only worker 0
let sequence: Vec<u64> = (1..=100).collect(); let base_scores = index.find_matches(base_local.clone()).await.unwrap();
index.apply_event(make_store_event(0, &sequence)).await; assert_eq!(base_scores.scores.len(), 1);
assert_eq!(
*base_scores
.scores
.get(&WorkerWithDpRank::new(0, 0))
.unwrap(),
3
);
flush_and_settle(index.as_ref()).await; // LoRA query finds only worker 1
let lora_scores = index.find_matches(lora_local.clone()).await.unwrap();
assert_eq!(lora_scores.scores.len(), 1);
assert_eq!(
*lora_scores
.scores
.get(&WorkerWithDpRank::new(1, 0))
.unwrap(),
3
);
}
// Verify full match #[tokio::test]
let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect(); #[apply(indexer_template)]
let scores = index.find_matches(full_query.clone()).await.unwrap(); async fn test_different_lora_adapters_do_not_conflict(variant: &str) {
assert_eq!( let index = make_indexer(variant);
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), let kv_block_size: u32 = 32;
100
);
// Remove blocks 80-100 (the tail) let tokens: Vec<u32> = (0..kv_block_size * 2).collect();
let tail_hashes: Vec<LocalBlockHash> = (1..=100).map(LocalBlockHash).collect();
let seq_hashes = compute_seq_hash_for_block(&tail_hashes);
let remove_hashes: Vec<ExternalSequenceBlockHash> = seq_hashes[79..100]
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect();
let remove_event = remove_event(0, 0, 0, remove_hashes); let hashes_a = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-a"));
index.apply_event(remove_event).await; let hashes_b = compute_block_hash_for_seq(&tokens, kv_block_size, None, Some("adapter-b"));
flush_and_settle(index.as_ref()).await; assert_ne!(
hashes_a, hashes_b,
"Different adapters must produce different hashes"
);
// Query should now only match first 79 blocks let seq_a = compute_seq_hash_for_block(&hashes_a);
let scores = index.find_matches(full_query).await.unwrap(); let seq_b = compute_seq_hash_for_block(&hashes_b);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
79
);
}
#[tokio::test] // Store adapter-a blocks on worker 0
#[apply(indexer_template)] index
async fn test_long_sequence_interleaved_workers(variant: &str) { .apply_event(router_event(
let index = make_indexer(variant); 0,
0,
// Multiple workers storing overlapping long sequences concurrently 0,
// Worker 0: blocks 1-100 KvCacheEventData::Stored(KvCacheStoreData {
// Worker 1: blocks 1-75 parent_hash: None,
// Worker 2: blocks 1-50 blocks: stored_blocks_with_sequence_hashes(&hashes_a, &seq_a),
// Worker 3: blocks 1-25 }),
))
let seq_100: Vec<u64> = (1..=100).collect(); .await;
let seq_75: Vec<u64> = (1..=75).collect();
let seq_50: Vec<u64> = (1..=50).collect();
let seq_25: Vec<u64> = (1..=25).collect();
index.apply_event(make_store_event(0, &seq_100)).await;
index.apply_event(make_store_event(1, &seq_75)).await;
index.apply_event(make_store_event(2, &seq_50)).await;
index.apply_event(make_store_event(3, &seq_25)).await;
flush_and_settle(index.as_ref()).await;
// Query for 60 blocks - workers 0,1 match 60, worker 2 matches 50, worker 3 matches 25
let query_60: Vec<LocalBlockHash> = (1..=60).map(LocalBlockHash).collect();
let scores = index.find_matches(query_60).await.unwrap();
assert_eq!(scores.scores.len(), 4);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
60
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
60
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
50
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(3, 0)).unwrap(),
25
);
}
#[tokio::test] // Store adapter-b blocks on worker 1
#[apply(indexer_template)] index
async fn test_long_sequence_exact_jump_size_boundaries(variant: &str) { .apply_event(router_event(
let index = make_indexer(variant); 1,
0,
0,
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: stored_blocks_with_sequence_hashes(&hashes_b, &seq_b),
}),
))
.await;
// Test sequences that align exactly with jump_size boundaries (32 for PositionalIndexer) flush_and_settle(index.as_ref()).await;
// This tests edge cases in the jump search algorithm
// Store sequence of exactly 32 blocks // Query adapter-a → only worker 0
let seq_32: Vec<u64> = (1..=32).collect(); let scores_a = index.find_matches(hashes_a.clone()).await.unwrap();
index.apply_event(make_store_event(0, &seq_32)).await; assert_eq!(scores_a.scores.len(), 1);
assert!(scores_a.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
assert!(!scores_a.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
// Store sequence of exactly 64 blocks (2x jump_size) // Query adapter-b → only worker 1
let seq_64: Vec<u64> = (1001..=1064).collect(); let scores_b = index.find_matches(hashes_b.clone()).await.unwrap();
index.apply_event(make_store_event(1, &seq_64)).await; assert_eq!(scores_b.scores.len(), 1);
assert!(scores_b.scores.contains_key(&WorkerWithDpRank::new(1, 0)));
assert!(!scores_b.scores.contains_key(&WorkerWithDpRank::new(0, 0)));
}
}
// Store sequence of exactly 96 blocks (3x jump_size) // ============================================================================
let seq_96: Vec<u64> = (2001..=2096).collect(); // Long sequence tests - especially important for NestedMap/PositionalIndexer
index.apply_event(make_store_event(2, &seq_96)).await; // ============================================================================
flush_and_settle(index.as_ref()).await; mod long_sequence_tests {
use super::*;
use rstest_reuse::apply;
// Verify all sequences match correctly #[tokio::test]
let query_32: Vec<LocalBlockHash> = seq_32.iter().map(|&i| LocalBlockHash(i)).collect(); #[apply(indexer_template)]
let scores = index.find_matches(query_32).await.unwrap(); async fn test_long_sequence_single_store(variant: &str) {
assert_eq!( let index = make_indexer(variant);
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
32
);
let query_64: Vec<LocalBlockHash> = seq_64.iter().map(|&i| LocalBlockHash(i)).collect(); // Store a long sequence (128 blocks) in a single event
let scores = index.find_matches(query_64).await.unwrap(); let seq_len = 128;
assert_eq!( let sequence: Vec<u64> = (1..=seq_len).collect();
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), index.apply_event(make_store_event(0, &sequence)).await;
64
);
let query_96: Vec<LocalBlockHash> = seq_96.iter().map(|&i| LocalBlockHash(i)).collect(); flush_and_settle(index.as_ref()).await;
let scores = index.find_matches(query_96).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
96
);
}
#[tokio::test] // Query full sequence - should match all blocks
#[apply(indexer_template)] let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
async fn test_long_sequence_off_by_one_jump_boundaries(variant: &str) { let scores = index.find_matches(full_query).await.unwrap();
let index = make_indexer(variant); assert_eq!(scores.scores.len(), 1);
assert_eq!(
// Test sequences at jump_size +/- 1 boundaries to catch off-by-one errors *scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
let seq_31: Vec<u64> = (1..=31).collect(); seq_len as u32
let seq_33: Vec<u64> = (101..=133).collect(); );
let seq_63: Vec<u64> = (201..=263).collect();
let seq_65: Vec<u64> = (301..=365).collect();
index.apply_event(make_store_event(0, &seq_31)).await;
index.apply_event(make_store_event(1, &seq_33)).await;
index.apply_event(make_store_event(2, &seq_63)).await;
index.apply_event(make_store_event(3, &seq_65)).await;
flush_and_settle(index.as_ref()).await;
// Verify all sequences match correctly
let query_31: Vec<LocalBlockHash> = seq_31.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_31).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
31
);
let query_33: Vec<LocalBlockHash> = seq_33.iter().map(|&i| LocalBlockHash(i)).collect(); // Query prefix (first 64 blocks)
let scores = index.find_matches(query_33).await.unwrap(); let prefix_query: Vec<LocalBlockHash> = (1..=64).map(LocalBlockHash).collect();
assert_eq!( let scores = index.find_matches(prefix_query).await.unwrap();
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), assert_eq!(
33 *scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
); 64
);
let query_63: Vec<LocalBlockHash> = seq_63.iter().map(|&i| LocalBlockHash(i)).collect(); // Query with divergence at position 50
let scores = index.find_matches(query_63).await.unwrap(); let mut divergent_query: Vec<LocalBlockHash> = (1..=100).map(LocalBlockHash).collect();
assert_eq!( divergent_query[49] = LocalBlockHash(99999); // Position 49 (0-indexed) diverges
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(), let scores = index.find_matches(divergent_query).await.unwrap();
63 assert_eq!(
); *scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
49
);
}
let query_65: Vec<LocalBlockHash> = seq_65.iter().map(|&i| LocalBlockHash(i)).collect(); #[tokio::test]
let scores = index.find_matches(query_65).await.unwrap(); #[apply(indexer_template)]
assert_eq!( async fn test_long_sequence_multiple_continuations(variant: &str) {
*scores.scores.get(&WorkerWithDpRank::new(3, 0)).unwrap(), let index = make_indexer(variant);
65
);
}
#[tokio::test] // Build a long sequence through multiple continuations
#[apply(indexer_template)] // First store: blocks 1-50
async fn test_long_sequence_divergence_at_jump_boundaries(variant: &str) { let first_chunk: Vec<u64> = (1..=50).collect();
let index = make_indexer(variant); index.apply_event(make_store_event(0, &first_chunk)).await;
// Store a long sequence // Second store: blocks 51-100 (continuation of first)
let sequence: Vec<u64> = (1..=128).collect(); let second_chunk: Vec<u64> = (51..=100).collect();
index.apply_event(make_store_event(0, &sequence)).await; index
.apply_event(make_store_event_with_parent(0, &first_chunk, &second_chunk))
.await;
flush_and_settle(index.as_ref()).await; // Third store: blocks 101-150 (continuation of second)
let prefix_1_2: Vec<u64> = (1..=100).collect();
let third_chunk: Vec<u64> = (101..=150).collect();
index
.apply_event(make_store_event_with_parent(0, &prefix_1_2, &third_chunk))
.await;
// Test divergence exactly at jump boundaries (position 31, 32, 33, 63, 64, 65) flush_and_settle(index.as_ref()).await;
for diverge_pos in [31usize, 32, 33, 63, 64, 65, 95, 96, 97] {
let mut query: Vec<LocalBlockHash> = (1..=128).map(LocalBlockHash).collect();
query[diverge_pos] = LocalBlockHash(99999);
let scores = index.find_matches(query).await.unwrap(); // Query full sequence - should match all 150 blocks
let full_query: Vec<LocalBlockHash> = (1..=150).map(LocalBlockHash).collect();
let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(scores.scores.len(), 1);
assert_eq!( assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), *scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
diverge_pos as u32, 150
"Divergence at position {} should match {} blocks",
diverge_pos,
diverge_pos
); );
}
}
#[tokio::test] // Query crossing continuation boundaries
#[apply(indexer_template)] let cross_boundary_query: Vec<LocalBlockHash> = (45..=105).map(LocalBlockHash).collect();
async fn test_long_sequence_deep_continuation_chain(variant: &str) { let scores = index.find_matches(cross_boundary_query).await.unwrap();
let index = make_indexer(variant); // Query starts at block 45, but stored sequence starts at 1, so this won't match
// because the sequence hash at position 0 of our query (block 45) won't match
// the stored sequence hash at position 0 (block 1)
assert!(
scores.scores.is_empty() || !scores.scores.contains_key(&WorkerWithDpRank::new(0, 0))
);
}
// Build a very long sequence through many small continuations #[tokio::test]
// This tests the parent_hash chain handling #[apply(indexer_template)]
let chunk_size = 10; async fn test_long_sequence_branching_continuations(variant: &str) {
let num_chunks = 20; // Total 200 blocks let index = make_indexer(variant);
let mut full_prefix: Vec<u64> = Vec::new(); // Common prefix: blocks 1-30
let common_prefix: Vec<u64> = (1..=30).collect();
index.apply_event(make_store_event(0, &common_prefix)).await;
for chunk_idx in 0..num_chunks { // Branch A: blocks 31-60 on worker 0
let chunk_start = chunk_idx * chunk_size + 1; let branch_a: Vec<u64> = (31..=60).collect();
let chunk: Vec<u64> = (chunk_start..chunk_start + chunk_size) index
.map(|x| x as u64) .apply_event(make_store_event_with_parent(0, &common_prefix, &branch_a))
.collect(); .await;
if chunk_idx == 0 { // Branch B: blocks 131-160 (different content) on worker 1
index.apply_event(make_store_event(0, &chunk)).await; // First store the common prefix for worker 1
} else { index.apply_event(make_store_event(1, &common_prefix)).await;
index let branch_b: Vec<u64> = (131..=160).collect();
.apply_event(make_store_event_with_parent(0, &full_prefix, &chunk)) index
.await; .apply_event(make_store_event_with_parent(1, &common_prefix, &branch_b))
} .await;
flush_and_settle(index.as_ref()).await;
full_prefix.extend(&chunk); // Query common prefix - both workers should match
let prefix_query: Vec<LocalBlockHash> = (1..=30).map(LocalBlockHash).collect();
let scores = index.find_matches(prefix_query).await.unwrap();
assert_eq!(scores.scores.len(), 2);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
30
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
30
);
// Query branch A path - only worker 0 should match fully
let branch_a_query: Vec<LocalBlockHash> = (1..=60).map(LocalBlockHash).collect();
let scores = index.find_matches(branch_a_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
60
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
30
);
} }
flush_and_settle(index.as_ref()).await; #[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_partial_removal(variant: &str) {
let index = make_indexer(variant);
// Query full sequence // Store a long sequence
let full_query: Vec<LocalBlockHash> = (1..=200).map(LocalBlockHash).collect(); let sequence: Vec<u64> = (1..=100).collect();
let scores = index.find_matches(full_query).await.unwrap(); index.apply_event(make_store_event(0, &sequence)).await;
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
200
);
// Query partial prefix crossing multiple chunk boundaries flush_and_settle(index.as_ref()).await;
let partial_query: Vec<LocalBlockHash> = (1..=75).map(LocalBlockHash).collect();
let scores = index.find_matches(partial_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
75
);
}
#[tokio::test] // Verify full match
#[apply(indexer_template)] let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
async fn test_long_sequence_clear_and_rebuild(variant: &str) { let scores = index.find_matches(full_query.clone()).await.unwrap();
let index = make_indexer(variant); assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
// Store a long sequence // Remove blocks 80-100 (the tail)
let sequence: Vec<u64> = (1..=100).collect(); let tail_hashes: Vec<LocalBlockHash> = (1..=100).map(LocalBlockHash).collect();
index.apply_event(make_store_event(0, &sequence)).await; let seq_hashes = compute_seq_hash_for_block(&tail_hashes);
let remove_hashes: Vec<ExternalSequenceBlockHash> = seq_hashes[79..100]
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect();
flush_and_settle(index.as_ref()).await; let remove_event = remove_event(0, 0, 0, remove_hashes);
index.apply_event(remove_event).await;
// Verify it's stored flush_and_settle(index.as_ref()).await;
let query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query.clone()).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
// Clear the worker // Query should now only match first 79 blocks
index.apply_event(make_clear_event(0)).await; let scores = index.find_matches(full_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
79
);
}
flush_and_settle(index.as_ref()).await; #[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_interleaved_workers(variant: &str) {
let index = make_indexer(variant);
// Multiple workers storing overlapping long sequences concurrently
// Worker 0: blocks 1-100
// Worker 1: blocks 1-75
// Worker 2: blocks 1-50
// Worker 3: blocks 1-25
let seq_100: Vec<u64> = (1..=100).collect();
let seq_75: Vec<u64> = (1..=75).collect();
let seq_50: Vec<u64> = (1..=50).collect();
let seq_25: Vec<u64> = (1..=25).collect();
index.apply_event(make_store_event(0, &seq_100)).await;
index.apply_event(make_store_event(1, &seq_75)).await;
index.apply_event(make_store_event(2, &seq_50)).await;
index.apply_event(make_store_event(3, &seq_25)).await;
flush_and_settle(index.as_ref()).await;
// Query for 60 blocks - workers 0,1 match 60, worker 2 matches 50, worker 3 matches 25
let query_60: Vec<LocalBlockHash> = (1..=60).map(LocalBlockHash).collect();
let scores = index.find_matches(query_60).await.unwrap();
assert_eq!(scores.scores.len(), 4);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
60
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
60
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
50
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(3, 0)).unwrap(),
25
);
}
// Verify it's cleared #[tokio::test]
let scores = index.find_matches(query.clone()).await.unwrap(); #[apply(indexer_template)]
assert!(scores.scores.is_empty()); async fn test_long_sequence_exact_jump_size_boundaries(variant: &str) {
let index = make_indexer(variant);
// Rebuild with a different sequence // Test sequences that align exactly with jump_size boundaries (32 for PositionalIndexer)
let new_sequence: Vec<u64> = (1001..=1100).collect(); // This tests edge cases in the jump search algorithm
index.apply_event(make_store_event(0, &new_sequence)).await;
flush_and_settle(index.as_ref()).await; // Store sequence of exactly 32 blocks
let seq_32: Vec<u64> = (1..=32).collect();
index.apply_event(make_store_event(0, &seq_32)).await;
// Verify new sequence works // Store sequence of exactly 64 blocks (2x jump_size)
let new_query: Vec<LocalBlockHash> = new_sequence.iter().map(|&i| LocalBlockHash(i)).collect(); let seq_64: Vec<u64> = (1001..=1064).collect();
let scores = index.find_matches(new_query).await.unwrap(); index.apply_event(make_store_event(1, &seq_64)).await;
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
// Verify old sequence no longer matches // Store sequence of exactly 96 blocks (3x jump_size)
let scores = index.find_matches(query).await.unwrap(); let seq_96: Vec<u64> = (2001..=2096).collect();
assert!(scores.scores.is_empty()); index.apply_event(make_store_event(2, &seq_96)).await;
}
#[tokio::test] flush_and_settle(index.as_ref()).await;
#[apply(indexer_template)]
async fn test_long_sequence_multiple_workers_diverging(variant: &str) {
let index = make_indexer(variant);
// Multiple workers with long sequences that share a prefix then diverge // Verify all sequences match correctly
// This tests precise drain point tracking across workers let query_32: Vec<LocalBlockHash> = seq_32.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query_32).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
32
);
// All workers share prefix 1-40 let query_64: Vec<LocalBlockHash> = seq_64.iter().map(|&i| LocalBlockHash(i)).collect();
let shared_prefix: Vec<u64> = (1..=40).collect(); let scores = index.find_matches(query_64).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
64
);
// Worker 0: prefix + 41-100 (stores full sequence 1-100) let query_96: Vec<LocalBlockHash> = seq_96.iter().map(|&i| LocalBlockHash(i)).collect();
let worker_0_full: Vec<u64> = (1..=100).collect(); let scores = index.find_matches(query_96).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
96
);
}
// Worker 1: prefix + 141-180 (diverges at block 41) #[tokio::test]
let worker_1_suffix: Vec<u64> = (141..=180).collect(); #[apply(indexer_template)]
async fn test_long_sequence_off_by_one_jump_boundaries(variant: &str) {
let index = make_indexer(variant);
// Worker 2: prefix + 241-300 (diverges at block 41) // Test sequences at jump_size +/- 1 boundaries to catch off-by-one errors
let worker_2_suffix: Vec<u64> = (241..=300).collect(); let seq_31: Vec<u64> = (1..=31).collect();
let seq_33: Vec<u64> = (101..=133).collect();
let seq_63: Vec<u64> = (201..=263).collect();
let seq_65: Vec<u64> = (301..=365).collect();
// Store for all workers index.apply_event(make_store_event(0, &seq_31)).await;
index.apply_event(make_store_event(0, &worker_0_full)).await; index.apply_event(make_store_event(1, &seq_33)).await;
index.apply_event(make_store_event(2, &seq_63)).await;
index.apply_event(make_store_event(3, &seq_65)).await;
index.apply_event(make_store_event(1, &shared_prefix)).await; flush_and_settle(index.as_ref()).await;
index
.apply_event(make_store_event_with_parent( // Verify all sequences match correctly
1, let query_31: Vec<LocalBlockHash> = seq_31.iter().map(|&i| LocalBlockHash(i)).collect();
&shared_prefix, let scores = index.find_matches(query_31).await.unwrap();
&worker_1_suffix, assert_eq!(
)) *scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
.await; 31
);
index.apply_event(make_store_event(2, &shared_prefix)).await;
index let query_33: Vec<LocalBlockHash> = seq_33.iter().map(|&i| LocalBlockHash(i)).collect();
.apply_event(make_store_event_with_parent( let scores = index.find_matches(query_33).await.unwrap();
2, assert_eq!(
&shared_prefix, *scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
&worker_2_suffix, 33
)) );
.await;
let query_63: Vec<LocalBlockHash> = seq_63.iter().map(|&i| LocalBlockHash(i)).collect();
flush_and_settle(index.as_ref()).await; let scores = index.find_matches(query_63).await.unwrap();
assert_eq!(
// Query 1-100 - worker 0 matches 100, workers 1&2 match 40 *scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
let query: Vec<LocalBlockHash> = worker_0_full.iter().map(|&i| LocalBlockHash(i)).collect(); 63
let scores = index.find_matches(query).await.unwrap(); );
assert_eq!( let query_65: Vec<LocalBlockHash> = seq_65.iter().map(|&i| LocalBlockHash(i)).collect();
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), let scores = index.find_matches(query_65).await.unwrap();
100 assert_eq!(
); *scores.scores.get(&WorkerWithDpRank::new(3, 0)).unwrap(),
assert_eq!( 65
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), );
40 }
);
assert_eq!( #[tokio::test]
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(), #[apply(indexer_template)]
40 async fn test_long_sequence_divergence_at_jump_boundaries(variant: &str) {
); let index = make_indexer(variant);
}
// Store a long sequence
let sequence: Vec<u64> = (1..=128).collect();
index.apply_event(make_store_event(0, &sequence)).await;
flush_and_settle(index.as_ref()).await;
// Test divergence exactly at and around jump boundaries (positions 31, 32, 33, 63, 64, 65, 95, 96, 97)
for diverge_pos in [31usize, 32, 33, 63, 64, 65, 95, 96, 97] {
let mut query: Vec<LocalBlockHash> = (1..=128).map(LocalBlockHash).collect();
query[diverge_pos] = LocalBlockHash(99999);
let scores = index.find_matches(query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
diverge_pos as u32,
"Divergence at position {} should match {} blocks",
diverge_pos,
diverge_pos
);
}
}
/// Builds one 200-block sequence for worker 0 out of twenty 10-block store
/// events, each one (after the first) chained onto everything stored so far
/// via `make_store_event_with_parent`, then checks that both the full
/// sequence and a 75-block prefix are matched in their entirety.
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_deep_continuation_chain(variant: &str) {
    const CHUNK_LEN: u64 = 10;
    const CHUNK_COUNT: u64 = 20; // 20 chunks of 10 blocks => 200 blocks total

    let index = make_indexer(variant);

    // Blocks already submitted; used as the parent chain of the next chunk.
    let mut stored_so_far: Vec<u64> = Vec::new();
    for chunk_idx in 0..CHUNK_COUNT {
        let first = chunk_idx * CHUNK_LEN + 1;
        let chunk: Vec<u64> = (first..first + CHUNK_LEN).collect();
        match chunk_idx {
            // The very first chunk has no parent.
            0 => {
                index.apply_event(make_store_event(0, &chunk)).await;
            }
            // Every later chunk continues from the accumulated prefix.
            _ => {
                index
                    .apply_event(make_store_event_with_parent(0, &stored_so_far, &chunk))
                    .await;
            }
        }
        stored_so_far.extend_from_slice(&chunk);
    }
    flush_and_settle(index.as_ref()).await;

    // The whole chain must be matched end to end.
    let whole: Vec<LocalBlockHash> = (1..=200).map(LocalBlockHash).collect();
    let overlap = index.find_matches(whole).await.unwrap();
    assert_eq!(
        *overlap.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
        200
    );

    // A prefix that ends in the middle of a chunk (75 = 7.5 chunks) must be
    // matched across all the chunk boundaries it spans.
    let prefix: Vec<LocalBlockHash> = (1..=75).map(LocalBlockHash).collect();
    let overlap = index.find_matches(prefix).await.unwrap();
    assert_eq!(
        *overlap.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
        75
    );
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_clear_and_rebuild(variant: &str) {
let index = make_indexer(variant);
// Store a long sequence
let sequence: Vec<u64> = (1..=100).collect();
index.apply_event(make_store_event(0, &sequence)).await;
flush_and_settle(index.as_ref()).await;
// Verify it's stored
let query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query.clone()).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
// Clear the worker
index.apply_event(make_clear_event(0)).await;
flush_and_settle(index.as_ref()).await;
// Verify it's cleared
let scores = index.find_matches(query.clone()).await.unwrap();
assert!(scores.scores.is_empty());
// Rebuild with a different sequence
let new_sequence: Vec<u64> = (1001..=1100).collect();
index.apply_event(make_store_event(0, &new_sequence)).await;
flush_and_settle(index.as_ref()).await;
// Verify new sequence works
let new_query: Vec<LocalBlockHash> =
new_sequence.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(new_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
// Verify old sequence no longer matches
let scores = index.find_matches(query).await.unwrap();
assert!(scores.scores.is_empty());
}
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_multiple_workers_diverging(variant: &str) {
let index = make_indexer(variant);
// Multiple workers with long sequences that share a prefix then diverge
// This tests precise drain point tracking across workers
// All workers share prefix 1-40
let shared_prefix: Vec<u64> = (1..=40).collect();
// Worker 0: prefix + 41-100 (stores full sequence 1-100)
let worker_0_full: Vec<u64> = (1..=100).collect();
// Worker 1: prefix + 141-180 (diverges at block 41)
let worker_1_suffix: Vec<u64> = (141..=180).collect();
#[tokio::test] // Worker 2: prefix + 241-300 (diverges at block 41)
#[apply(indexer_template)] let worker_2_suffix: Vec<u64> = (241..=300).collect();
async fn test_long_sequence_staggered_lengths(variant: &str) {
let index = make_indexer(variant);
// Workers with sequences of staggered lengths to test drain tracking // Store for all workers
// Worker 0: 10 blocks index.apply_event(make_store_event(0, &worker_0_full)).await;
// Worker 1: 20 blocks
// Worker 2: 35 blocks (just past first jump)
// Worker 3: 64 blocks (exactly 2 jumps)
// Worker 4: 100 blocks
for (worker_id, len) in [(0, 10), (1, 20), (2, 35), (3, 64), (4, 100)] { index.apply_event(make_store_event(1, &shared_prefix)).await;
let sequence: Vec<u64> = (1..=len).collect();
index index
.apply_event(make_store_event(worker_id, &sequence)) .apply_event(make_store_event_with_parent(
1,
&shared_prefix,
&worker_1_suffix,
))
.await; .await;
index.apply_event(make_store_event(2, &shared_prefix)).await;
index
.apply_event(make_store_event_with_parent(
2,
&shared_prefix,
&worker_2_suffix,
))
.await;
flush_and_settle(index.as_ref()).await;
// Query 1-100 - worker 0 matches 100, workers 1&2 match 40
let query: Vec<LocalBlockHash> = worker_0_full.iter().map(|&i| LocalBlockHash(i)).collect();
let scores = index.find_matches(query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
100
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
40
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
40
);
} }
flush_and_settle(index.as_ref()).await; #[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_staggered_lengths(variant: &str) {
let index = make_indexer(variant);
// Query for 100 blocks - each worker should match their stored length // Workers with sequences of staggered lengths to test drain tracking
let query: Vec<LocalBlockHash> = (1..=100).map(LocalBlockHash).collect(); // Worker 0: 10 blocks
let scores = index.find_matches(query).await.unwrap(); // Worker 1: 20 blocks
// Worker 2: 35 blocks (just past first jump)
// Worker 3: 64 blocks (exactly 2 jumps)
// Worker 4: 100 blocks
assert_eq!( for (worker_id, len) in [(0, 10), (1, 20), (2, 35), (3, 64), (4, 100)] {
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), let sequence: Vec<u64> = (1..=len).collect();
10 index
); .apply_event(make_store_event(worker_id, &sequence))
assert_eq!( .await;
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(), }
20
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
35
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(3, 0)).unwrap(),
64
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(4, 0)).unwrap(),
100
);
}
#[tokio::test] flush_and_settle(index.as_ref()).await;
#[apply(indexer_template)]
async fn test_very_long_sequence(variant: &str) {
let index = make_indexer(variant);
// Test with a very long sequence (1000 blocks) // Query for 100 blocks - each worker should match their stored length
let seq_len = 1000u64; let query: Vec<LocalBlockHash> = (1..=100).map(LocalBlockHash).collect();
let sequence: Vec<u64> = (1..=seq_len).collect(); let scores = index.find_matches(query).await.unwrap();
index.apply_event(make_store_event(0, &sequence)).await;
flush_and_settle(index.as_ref()).await; assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
10
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(1, 0)).unwrap(),
20
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(2, 0)).unwrap(),
35
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(3, 0)).unwrap(),
64
);
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(4, 0)).unwrap(),
100
);
}
// Full match #[tokio::test]
let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect(); #[apply(indexer_template)]
let scores = index.find_matches(full_query).await.unwrap(); async fn test_very_long_sequence(variant: &str) {
assert_eq!( let index = make_indexer(variant);
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
seq_len as u32
);
// Partial match (first 500) // Test with a very long sequence (1000 blocks)
let partial_query: Vec<LocalBlockHash> = (1..=500).map(LocalBlockHash).collect(); let seq_len = 1000u64;
let scores = index.find_matches(partial_query).await.unwrap(); let sequence: Vec<u64> = (1..=seq_len).collect();
assert_eq!( index.apply_event(make_store_event(0, &sequence)).await;
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
500
);
// Divergence in the middle flush_and_settle(index.as_ref()).await;
let mut mid_diverge: Vec<LocalBlockHash> = (1..=1000).map(LocalBlockHash).collect();
mid_diverge[499] = LocalBlockHash(99999); // Full match
let scores = index.find_matches(mid_diverge).await.unwrap(); let full_query: Vec<LocalBlockHash> = sequence.iter().map(|&i| LocalBlockHash(i)).collect();
assert_eq!( let scores = index.find_matches(full_query).await.unwrap();
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(), assert_eq!(
499 *scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
); seq_len as u32
);
// Partial match (first 500)
let partial_query: Vec<LocalBlockHash> = (1..=500).map(LocalBlockHash).collect();
let scores = index.find_matches(partial_query).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
500
);
// Divergence in the middle
let mut mid_diverge: Vec<LocalBlockHash> = (1..=1000).map(LocalBlockHash).collect();
mid_diverge[499] = LocalBlockHash(99999);
let scores = index.find_matches(mid_diverge).await.unwrap();
assert_eq!(
*scores.scores.get(&WorkerWithDpRank::new(0, 0)).unwrap(),
499
);
}
} }
// ============================================================================ // ============================================================================
...@@ -1670,129 +1690,146 @@ fn make_tree_indexer_with_frequency( ...@@ -1670,129 +1690,146 @@ fn make_tree_indexer_with_frequency(
} }
} }
#[tokio::test] mod tree_specific_tests {
#[apply(tree_indexer_template)] use super::*;
async fn test_frequency(variant: &str) { use rstest_reuse::apply;
const ONE_MILLIS: Duration = Duration::from_millis(1);
let expiration = Duration::from_millis(50);
let kv_indexer = make_tree_indexer_with_frequency(variant, expiration);
// The blocks
let block_hashes = vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
LocalBlockHash(4),
];
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(
overlap.frequencies.len(),
0,
"Should be no cached blocks yet"
);
// Blocks go in cache #[tokio::test]
let event = make_store_event(0, &[1, 2, 3, 4]); #[apply(tree_indexer_template)]
kv_indexer.apply_event(event).await; async fn test_frequency(variant: &str) {
const ONE_MILLIS: Duration = Duration::from_millis(1);
// First access - poll briefly since store event is applied async
let mut overlap = OverlapScores::default();
let timeout = Duration::from_millis(10);
let start = Instant::now();
while overlap.scores.is_empty() && Instant::now().duration_since(start) < timeout {
time::sleep(ONE_MILLIS).await;
overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
}
assert_eq!(
overlap.scores.len(),
1,
"One worker has these blocks cached"
);
assert_eq!(
overlap.frequencies.len(),
0,
"Blocks have not previously been accessed"
);
// Second access let expiration = Duration::from_millis(50);
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap(); let kv_indexer = make_tree_indexer_with_frequency(variant, expiration);
assert_eq!(overlap.scores.len(), 1, "Still one worker matches");
assert_eq!(
overlap.frequencies,
vec![1, 1, 1, 1],
"We should see the first access now"
);
// Let those two accesses expire // The blocks
time::sleep(expiration + Duration::from_millis(10)).await; let block_hashes = vec![
LocalBlockHash(1),
LocalBlockHash(2),
LocalBlockHash(3),
LocalBlockHash(4),
];
// New first access let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap(); assert_eq!(
assert_eq!( overlap.frequencies.len(),
overlap.frequencies.len(), 0,
0, "Should be no cached blocks yet"
"Blocks were accessed too long ago" );
);
// Blocks go in cache
let event = make_store_event(0, &[1, 2, 3, 4]);
kv_indexer.apply_event(event).await;
// First access - poll briefly since store event is applied async
let mut overlap = OverlapScores::default();
let timeout = Duration::from_millis(10);
let start = Instant::now();
while overlap.scores.is_empty() && Instant::now().duration_since(start) < timeout {
time::sleep(ONE_MILLIS).await;
overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
}
assert_eq!(
overlap.scores.len(),
1,
"One worker has these blocks cached"
);
assert_eq!(
overlap.frequencies.len(),
0,
"Blocks have not previously been accessed"
);
// New second access // Second access
let _ = kv_indexer.find_matches(block_hashes.clone()).await.unwrap(); let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(overlap.scores.len(), 1, "Still one worker matches");
assert_eq!(
overlap.frequencies,
vec![1, 1, 1, 1],
"We should see the first access now"
);
// Access only the first three blocks // Let those two accesses expire
let overlap = kv_indexer time::sleep(expiration + Duration::from_millis(10)).await;
.find_matches(block_hashes[0..3].to_vec())
.await
.unwrap();
// We see the previous two new accesses
assert_eq!(overlap.frequencies, vec![2, 2, 2]);
// The third access did not touch the last block // New first access
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap(); let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(overlap.frequencies, vec![3, 3, 3, 2]); assert_eq!(
overlap.frequencies.len(),
0,
"Blocks were accessed too long ago"
);
// New second access
let _ = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
// Access only the first three blocks
let overlap = kv_indexer
.find_matches(block_hashes[0..3].to_vec())
.await
.unwrap();
// We see the previous two new accesses
assert_eq!(overlap.frequencies, vec![2, 2, 2]);
// The third access did not touch the last block
let overlap = kv_indexer.find_matches(block_hashes.clone()).await.unwrap();
assert_eq!(overlap.frequencies, vec![3, 3, 3, 2]);
}
} }
// ============================================================================ // ============================================================================
// KvIndexerMetrics tests // KvIndexerMetrics tests
// ============================================================================ // ============================================================================
#[cfg(feature = "metrics")] mod metrics_tests {
#[test] #[cfg(feature = "metrics")]
fn test_increment_event_applied() { use super::*;
let metrics = KvIndexerMetrics::new_unregistered();
metrics.increment_event_applied(METRIC_EVENT_STORED, Ok(())); #[cfg(feature = "metrics")]
assert_eq!( #[test]
metrics fn test_increment_event_applied() {
.kv_cache_events_applied let metrics = KvIndexerMetrics::new_unregistered();
.get_metric_with_label_values(&[METRIC_EVENT_STORED, METRIC_STATUS_OK])
.unwrap()
.get(),
1
);
metrics.increment_event_applied( metrics.increment_event_applied(METRIC_EVENT_STORED, Ok(()));
METRIC_EVENT_STORED, assert_eq!(
Err(KvCacheEventError::ParentBlockNotFound), metrics
); .kv_cache_events_applied
assert_eq!( .get_metric_with_label_values(&[METRIC_EVENT_STORED, METRIC_STATUS_OK])
metrics .unwrap()
.kv_cache_events_applied .get(),
.get_metric_with_label_values(&[METRIC_EVENT_STORED, METRIC_STATUS_PARENT_NOT_FOUND]) 1
.unwrap() );
.get(),
1 metrics.increment_event_applied(
); METRIC_EVENT_STORED,
Err(KvCacheEventError::ParentBlockNotFound),
);
assert_eq!(
metrics
.kv_cache_events_applied
.get_metric_with_label_values(&[
METRIC_EVENT_STORED,
METRIC_STATUS_PARENT_NOT_FOUND
])
.unwrap()
.get(),
1
);
metrics.increment_event_applied(METRIC_EVENT_REMOVED, Err(KvCacheEventError::BlockNotFound));
assert_eq!(
metrics metrics
.kv_cache_events_applied .increment_event_applied(METRIC_EVENT_REMOVED, Err(KvCacheEventError::BlockNotFound));
.get_metric_with_label_values(&[METRIC_EVENT_REMOVED, METRIC_STATUS_BLOCK_NOT_FOUND]) assert_eq!(
.unwrap() metrics
.get(), .kv_cache_events_applied
1 .get_metric_with_label_values(&[
); METRIC_EVENT_REMOVED,
METRIC_STATUS_BLOCK_NOT_FOUND
])
.unwrap()
.get(),
1
);
}
} }
// ============================================================================ // ============================================================================
...@@ -1822,363 +1859,368 @@ fn make_local_indexer_with_events(ids: &[u64]) -> LocalKvIndexer { ...@@ -1822,363 +1859,368 @@ fn make_local_indexer_with_events(ids: &[u64]) -> LocalKvIndexer {
indexer indexer
} }
#[tokio::test] mod local_indexer_tests {
async fn test_local_indexer_slice_within_range() { use super::*;
let indexer = make_local_indexer_with_events(&[1, 2, 3, 4, 5]); use rstest_reuse::apply;
/// Exercises `get_events_in_id_range` on an indexer built from event ids
/// 1..=5: inclusive ranges, clamping of the end bound, a start before the
/// buffer, a single-element range, and an inverted range.
#[tokio::test]
async fn test_local_indexer_slice_within_range() {
    let indexer = make_local_indexer_with_events(&[1, 2, 3, 4, 5]);

    // Pulls the events out of an `Events` or `TreeDump` response and reduces
    // them to their event ids; any other response variant fails the test.
    let ids_of = |resp: WorkerKvQueryResponse| -> Vec<u64> {
        let events: Vec<RouterEvent> = match resp {
            WorkerKvQueryResponse::Events(e) | WorkerKvQueryResponse::TreeDump { events: e, .. } => e,
            _ => panic!("Unexpected response type"),
        };
        events.iter().map(|e| e.event.event_id).collect()
    };

    // Buffer queries: the range is [start, end], inclusive on both sides.
    let resp = indexer.get_events_in_id_range(Some(2), Some(4)).await;
    assert_eq!(ids_of(resp), vec![2, 3, 4]); // inclusive range [2, 4]

    let resp = indexer.get_events_in_id_range(Some(2), Some(6)).await;
    assert_eq!(ids_of(resp), vec![2, 3, 4, 5]); // clamp end to buffer max

    // start_id=0 is before the buffer (first id is 1), so a tree dump is expected.
    let resp = indexer.get_events_in_id_range(Some(0), Some(4)).await;
    assert!(matches!(resp, WorkerKvQueryResponse::TreeDump { .. }));

    let resp = indexer.get_events_in_id_range(Some(3), Some(3)).await;
    assert_eq!(ids_of(resp), vec![3]); // single element when start == end

    // Invalid range: end < start.
    let resp = indexer.get_events_in_id_range(Some(5), Some(2)).await;
    assert!(matches!(resp, WorkerKvQueryResponse::InvalidRange { .. }));
}
// Helper to extract events from response #[tokio::test]
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> { async fn test_local_indexer_get_events_in_id_range_all_cases() {
match resp { // Create indexer with small buffer (5 events max)
WorkerKvQueryResponse::Events(e) => e, let indexer = LocalKvIndexer::new(
WorkerKvQueryResponse::TreeDump { events: e, .. } => e, CancellationToken::new(),
_ => panic!("Unexpected response type"), 4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
// Helper to create a test event
let make_event = |id: u64| {
RouterEvent::new(
0,
KvCacheEvent {
event_id: id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(id * 100),
tokens_hash: LocalBlockHash(id * 200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
},
)
};
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for id in 5..15 {
indexer
.apply_event_with_buffer(make_event(id))
.await
.unwrap();
} }
};
let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> { // Wait for events to be processed
events.iter().map(|e| e.event.event_id).collect() indexer.flush().await;
};
// Test get_events_in_id_range (buffer queries) let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> {
// Range is [start, end] inclusive match resp {
let result = indexer.get_events_in_id_range(Some(2), Some(4)).await; WorkerKvQueryResponse::Events(e) => e,
let ids = get_ids(extract_events(result)); WorkerKvQueryResponse::TreeDump { events: e, .. } => e,
assert_eq!(ids, vec![2, 3, 4]); // inclusive range [2, 4] _ => panic!("Unexpected response type: {:?}", resp),
}
};
let result = indexer.get_events_in_id_range(Some(2), Some(6)).await; let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> {
let ids = get_ids(extract_events(result)); events.iter().map(|e| e.event.event_id).collect()
assert_eq!(ids, vec![2, 3, 4, 5]); // clamp end to buffer max };
// start_id=0 is before buffer (first is 1), so should trigger tree dump // Verify buffer state
let result = indexer.get_events_in_id_range(Some(0), Some(4)).await; let buffer_events = indexer.get_all_events_in_buffer();
assert!(matches!(result, WorkerKvQueryResponse::TreeDump { .. })); assert_eq!(get_ids(buffer_events), vec![10, 11, 12, 13, 14]);
let result = indexer.get_events_in_id_range(Some(3), Some(3)).await; // Buffer path tests
let ids = get_ids(extract_events(result)); let result = indexer.get_events_in_id_range(Some(11), None).await;
assert_eq!(ids, vec![3]); // single element when start == end assert_eq!(get_ids(extract_events(result)), vec![11, 12, 13, 14]);
// Invalid range: end < start let result = indexer.get_events_in_id_range(Some(10), Some(14)).await;
let result = indexer.get_events_in_id_range(Some(5), Some(2)).await; assert_eq!(get_ids(extract_events(result)), vec![10, 11, 12, 13, 14]);
assert!(matches!(result, WorkerKvQueryResponse::InvalidRange { .. }));
}
#[tokio::test] // Tree dump path tests
async fn test_local_indexer_get_events_in_id_range_all_cases() { let result = indexer.get_events_in_id_range(None, None).await;
// Create indexer with small buffer (5 events max) assert!(matches!(&result, WorkerKvQueryResponse::TreeDump { .. }));
let indexer = LocalKvIndexer::new( assert_eq!(extract_events(result).len(), 10);
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
// Helper to create a test event let result = indexer.get_events_in_id_range(Some(7), None).await;
let make_event = |id: u64| { assert!(matches!(result, WorkerKvQueryResponse::TreeDump { .. }));
RouterEvent::new(
0, // Edge cases
let result = indexer.get_events_in_id_range(Some(15), Some(10)).await;
assert!(matches!(result, WorkerKvQueryResponse::InvalidRange { .. }));
let result = indexer.get_events_in_id_range(Some(100), Some(200)).await;
assert!(matches!(result, WorkerKvQueryResponse::TooNew { .. }));
}
#[tokio::test]
async fn test_tree_dump_includes_last_event_id() {
// Create indexer with small buffer (5 events max)
let indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
let make_event = |id: u64| {
RouterEvent::new(
0,
KvCacheEvent {
event_id: id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(id * 100),
tokens_hash: LocalBlockHash(id * 200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
},
)
};
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for id in 5..15 {
indexer
.apply_event_with_buffer(make_event(id))
.await
.unwrap();
}
indexer.flush().await;
// Request with start_id=None -> tree dump should include last_event_id=14
let result = indexer.get_events_in_id_range(None, None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 14,
"last_event_id should be the buffer's newest event ID"
);
assert!(!events.is_empty(), "tree dump should contain events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
// Request with start_id older than buffer -> tree dump should include last_event_id=14
let result = indexer.get_events_in_id_range(Some(7), None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 14,
"last_event_id should be the buffer's newest event ID"
);
assert!(!events.is_empty(), "tree dump should contain events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
// Empty buffer case: create a fresh indexer with no events
let empty_indexer = LocalKvIndexer::new(
CancellationToken::new(),
4,
Arc::new(KvIndexerMetrics::new_unregistered()),
5,
);
let result = empty_indexer.get_events_in_id_range(None, None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 0,
"empty buffer should return last_event_id=0"
);
assert!(events.is_empty(), "empty indexer should have no events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
}
#[tokio::test]
async fn test_local_indexer_buffer_and_serialization() {
let worker_id = 42u64;
let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let local_indexer = Arc::new(LocalKvIndexer::new(token, 4, metrics, 100));
let test_event = RouterEvent::new(
worker_id,
KvCacheEvent { KvCacheEvent {
event_id: id, event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: vec![KvCacheStoredBlockData { blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(id * 100), block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(id * 200), tokens_hash: LocalBlockHash(200),
mm_extra_info: None, mm_extra_info: None,
}], }],
}), }),
dp_rank: 0, dp_rank: 0,
}, },
) );
};
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14 local_indexer
for id in 5..15 { .apply_event_with_buffer(test_event)
indexer
.apply_event_with_buffer(make_event(id))
.await .await
.unwrap(); .unwrap();
}
// Wait for events to be processed local_indexer.flush().await;
indexer.flush().await;
let extract_events = |resp: WorkerKvQueryResponse| -> Vec<RouterEvent> { let buffered_events = local_indexer.get_all_events_in_buffer();
match resp { assert_eq!(buffered_events.len(), 1);
WorkerKvQueryResponse::Events(e) => e, assert_eq!(buffered_events[0].worker_id, worker_id);
WorkerKvQueryResponse::TreeDump { events: e, .. } => e,
_ => panic!("Unexpected response type: {:?}", resp),
}
};
let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> { // Test serialization round-trip
events.iter().map(|e| e.event.event_id).collect() let response = WorkerKvQueryResponse::Events(buffered_events);
}; let serialized = serde_json::to_vec(&response).unwrap();
let deserialized: WorkerKvQueryResponse = serde_json::from_slice(&serialized).unwrap();
// Verify buffer state let events = match deserialized {
let buffer_events = indexer.get_all_events_in_buffer(); WorkerKvQueryResponse::Events(e) => e,
assert_eq!(get_ids(buffer_events), vec![10, 11, 12, 13, 14]); _ => panic!("Expected Events variant"),
};
// Buffer path tests assert_eq!(events.len(), 1);
let result = indexer.get_events_in_id_range(Some(11), None).await; assert_eq!(events[0].worker_id, worker_id);
assert_eq!(get_ids(extract_events(result)), vec![11, 12, 13, 14]); }
let result = indexer.get_events_in_id_range(Some(10), Some(14)).await;
assert_eq!(get_ids(extract_events(result)), vec![10, 11, 12, 13, 14]);
// Tree dump path tests
let result = indexer.get_events_in_id_range(None, None).await;
assert!(matches!(&result, WorkerKvQueryResponse::TreeDump { .. }));
assert_eq!(extract_events(result).len(), 10);
let result = indexer.get_events_in_id_range(Some(7), None).await;
assert!(matches!(result, WorkerKvQueryResponse::TreeDump { .. }));
// Edge cases
let result = indexer.get_events_in_id_range(Some(15), Some(10)).await;
assert!(matches!(result, WorkerKvQueryResponse::InvalidRange { .. }));
let result = indexer.get_events_in_id_range(Some(100), Some(200)).await;
assert!(matches!(result, WorkerKvQueryResponse::TooNew { .. }));
}
#[tokio::test] #[tokio::test]
async fn test_tree_dump_includes_last_event_id() { async fn test_local_indexer_does_not_buffer_failed_send() {
// Create indexer with small buffer (5 events max) let local_indexer = LocalKvIndexer::new(
let indexer = LocalKvIndexer::new( CancellationToken::new(),
CancellationToken::new(), 4,
4, Arc::new(KvIndexerMetrics::new_unregistered()),
Arc::new(KvIndexerMetrics::new_unregistered()), 5,
5, );
);
let make_event = |id: u64| { let test_event = RouterEvent::new(
RouterEvent::new( 7,
0,
KvCacheEvent { KvCacheEvent {
event_id: id, event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None, parent_hash: None,
blocks: vec![KvCacheStoredBlockData { blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(id * 100), block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(id * 200), tokens_hash: LocalBlockHash(200),
mm_extra_info: None, mm_extra_info: None,
}], }],
}), }),
dp_rank: 0, dp_rank: 0,
}, },
) );
};
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for id in 5..15 {
indexer
.apply_event_with_buffer(make_event(id))
.await
.unwrap();
}
indexer.flush().await;
// Request with start_id=None -> tree dump should include last_event_id=14
let result = indexer.get_events_in_id_range(None, None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 14,
"last_event_id should be the buffer's newest event ID"
);
assert!(!events.is_empty(), "tree dump should contain events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
// Request with start_id older than buffer -> tree dump should include last_event_id=14
let result = indexer.get_events_in_id_range(Some(7), None).await;
match result {
WorkerKvQueryResponse::TreeDump {
last_event_id,
events,
} => {
assert_eq!(
last_event_id, 14,
"last_event_id should be the buffer's newest event ID"
);
assert!(!events.is_empty(), "tree dump should contain events");
}
other => panic!("Expected TreeDump, got: {other:?}"),
}
// Empty buffer case: create a fresh indexer with no events let event_tx = local_indexer.event_sender();
let empty_indexer = LocalKvIndexer::new( local_indexer.shutdown();
CancellationToken::new(), event_tx.closed().await;
4,
Arc::new(KvIndexerMetrics::new_unregistered()), let result = local_indexer.apply_event_with_buffer(test_event).await;
5, assert!(matches!(result, Err(KvRouterError::IndexerOffline)));
); assert_eq!(local_indexer.buffer_len(), 0);
let result = empty_indexer.get_events_in_id_range(None, None).await;
match result { match local_indexer.get_events_in_id_range(None, None).await {
WorkerKvQueryResponse::TreeDump { WorkerKvQueryResponse::TreeDump {
last_event_id, events,
events, last_event_id,
} => { } => {
assert_eq!( assert!(events.is_empty());
last_event_id, 0, assert_eq!(last_event_id, 0);
"empty buffer should return last_event_id=0" }
); other => panic!("Expected TreeDump, got: {other:?}"),
assert!(events.is_empty(), "empty indexer should have no events");
} }
other => panic!("Expected TreeDump, got: {other:?}"),
} }
}
#[tokio::test]
async fn test_local_indexer_buffer_and_serialization() {
let worker_id = 42u64;
let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let local_indexer = Arc::new(LocalKvIndexer::new(token, 4, metrics, 100));
let test_event = RouterEvent::new( #[tokio::test]
worker_id, #[apply(indexer_template)]
KvCacheEvent { async fn test_apply_events_idempotent(variant: &str) {
event_id: 1, let index = make_indexer(variant);
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
},
);
local_indexer
.apply_event_with_buffer(test_event)
.await
.unwrap();
local_indexer.flush().await;
let buffered_events = local_indexer.get_all_events_in_buffer();
assert_eq!(buffered_events.len(), 1);
assert_eq!(buffered_events[0].worker_id, worker_id);
// Test serialization round-trip
let response = WorkerKvQueryResponse::Events(buffered_events);
let serialized = serde_json::to_vec(&response).unwrap();
let deserialized: WorkerKvQueryResponse = serde_json::from_slice(&serialized).unwrap();
let events = match deserialized {
WorkerKvQueryResponse::Events(e) => e,
_ => panic!("Expected Events variant"),
};
assert_eq!(events.len(), 1);
assert_eq!(events[0].worker_id, worker_id);
}
#[tokio::test] // Setup: build initial tree
async fn test_local_indexer_does_not_buffer_failed_send() { index.apply_event(make_store_event(0, &[1, 2, 3])).await;
let local_indexer = LocalKvIndexer::new( index.apply_event(make_store_event(1, &[4, 5, 6])).await;
CancellationToken::new(), index
4, .apply_event(make_store_event_with_parent(0, &[1, 2, 3], &[7, 8]))
Arc::new(KvIndexerMetrics::new_unregistered()), .await;
5, flush_and_settle(index.as_ref()).await;
); let s0 = snapshot_tree(index.as_ref()).await;
let test_event = RouterEvent::new( // Mutation events: each add paired with its remove
7, let adds = [
KvCacheEvent { make_store_event(2, &[1, 2, 9]),
event_id: 1, make_store_event_with_parent(1, &[4, 5, 6], &[10, 11, 12]),
data: KvCacheEventData::Stored(KvCacheStoreData { ];
parent_hash: None, let removes = [
blocks: vec![KvCacheStoredBlockData { make_remove_event(2, &[1, 2, 9]),
block_hash: ExternalSequenceBlockHash(100), make_remove_event_with_parent(1, &[4, 5, 6], &[10, 11, 12]),
tokens_hash: LocalBlockHash(200), ];
mm_extra_info: None,
}], // Phase 1: interleaved add/remove
}), index.apply_event(adds[0].clone()).await;
dp_rank: 0, index.apply_event(removes[0].clone()).await;
}, index.apply_event(adds[1].clone()).await;
); index.apply_event(removes[1].clone()).await;
flush_and_settle(index.as_ref()).await;
let s1 = snapshot_tree(index.as_ref()).await;
assert_eq!(
s0, s1,
"Phase 1: interleaved add/remove should restore tree"
);
let event_tx = local_indexer.event_sender(); // Phase 2: same interleaved again (idempotence of the full cycle)
local_indexer.shutdown(); index.apply_event(adds[0].clone()).await;
event_tx.closed().await; index.apply_event(removes[0].clone()).await;
index.apply_event(adds[1].clone()).await;
let result = local_indexer.apply_event_with_buffer(test_event).await; index.apply_event(removes[1].clone()).await;
assert!(matches!(result, Err(KvRouterError::IndexerOffline))); flush_and_settle(index.as_ref()).await;
assert_eq!(local_indexer.buffer_len(), 0); let s2 = snapshot_tree(index.as_ref()).await;
assert_eq!(s1, s2, "Phase 2: repeated cycle should be idempotent");
match local_indexer.get_events_in_id_range(None, None).await {
WorkerKvQueryResponse::TreeDump { // Phase 3: non-interleaved (all adds then all removes)
events, index.apply_event(adds[0].clone()).await;
last_event_id, index.apply_event(adds[1].clone()).await;
} => { index.apply_event(removes[0].clone()).await;
assert!(events.is_empty()); index.apply_event(removes[1].clone()).await;
assert_eq!(last_event_id, 0); flush_and_settle(index.as_ref()).await;
} let s3 = snapshot_tree(index.as_ref()).await;
other => panic!("Expected TreeDump, got: {other:?}"), assert_eq!(
s2, s3,
"Phase 3: non-interleaved ordering should restore tree"
);
} }
} }
// Applying store events together with their matching remove events must leave the tree
// snapshot unchanged, both when the cycle is repeated and when the ordering is regrouped.
#[tokio::test]
#[apply(indexer_template)]
async fn test_apply_events_idempotent(variant: &str) {
    let index = make_indexer(variant);

    // Baseline tree that every later snapshot is compared against.
    index.apply_event(make_store_event(0, &[1, 2, 3])).await;
    index.apply_event(make_store_event(1, &[4, 5, 6])).await;
    index
        .apply_event(make_store_event_with_parent(0, &[1, 2, 3], &[7, 8]))
        .await;
    flush_and_settle(index.as_ref()).await;
    let baseline = snapshot_tree(index.as_ref()).await;

    // Each add at position i is undone by the remove at position i.
    let adds = [
        make_store_event(2, &[1, 2, 9]),
        make_store_event_with_parent(1, &[4, 5, 6], &[10, 11, 12]),
    ];
    let removes = [
        make_remove_event(2, &[1, 2, 9]),
        make_remove_event_with_parent(1, &[4, 5, 6], &[10, 11, 12]),
    ];
    let interleaved = [&adds[0], &removes[0], &adds[1], &removes[1]];
    let grouped = [&adds[0], &adds[1], &removes[0], &removes[1]];

    // Phase 1: add/remove pairs applied back to back.
    for event in interleaved {
        index.apply_event(event.clone()).await;
    }
    flush_and_settle(index.as_ref()).await;
    let after_first_cycle = snapshot_tree(index.as_ref()).await;
    assert_eq!(
        baseline, after_first_cycle,
        "Phase 1: interleaved add/remove should restore tree"
    );

    // Phase 2: the very same cycle once more.
    for event in interleaved {
        index.apply_event(event.clone()).await;
    }
    flush_and_settle(index.as_ref()).await;
    let after_second_cycle = snapshot_tree(index.as_ref()).await;
    assert_eq!(
        after_first_cycle, after_second_cycle,
        "Phase 2: repeated cycle should be idempotent"
    );

    // Phase 3: every add first, then every remove.
    for event in grouped {
        index.apply_event(event.clone()).await;
    }
    flush_and_settle(index.as_ref()).await;
    let after_grouped = snapshot_tree(index.as_ref()).await;
    assert_eq!(
        after_second_cycle, after_grouped,
        "Phase 3: non-interleaved ordering should restore tree"
    );
}
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
//! This crate provides the core radix tree implementation and protocols for //! This crate provides the core radix tree implementation and protocols for
//! efficient KV cache lookup and routing in distributed LLM inference systems. //! efficient KV cache lookup and routing in distributed LLM inference systems.
pub mod event_sink;
pub mod indexer; pub mod indexer;
pub mod protocols; pub mod protocols;
pub mod scheduling; pub mod scheduling;
...@@ -41,15 +40,15 @@ pub use self::sequence::{ActiveSequences, RequestId}; ...@@ -41,15 +40,15 @@ pub use self::sequence::{ActiveSequences, RequestId};
pub use concurrent_radix_tree::ConcurrentRadixTree; pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed; pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed;
pub use config::{KvRouterConfig, RouterConfigOverride, RouterQueuePolicy}; pub use config::{KvRouterConfig, RouterConfigOverride, RouterQueuePolicy};
pub use event_sink::EventSink;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer}; pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
pub use nested_map::PositionalIndexer; pub use nested_map::PositionalIndexer;
pub use protocols::{ pub use protocols::{
KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, WorkerConfigLike, WorkerId, KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, RouterEventSink,
compute_block_hash_for_seq, WorkerConfigLike, WorkerId, compute_block_hash_for_seq,
}; };
pub use queue::SchedulerQueue; pub use queue::SchedulerQueue;
pub use radix_tree::RadixTree; pub use radix_tree::RadixTree;
pub use scheduling::LocalScheduler;
pub use scheduling::policy::{FcfsPolicy, RouterSchedulingPolicy, SchedulingPolicy, WsptPolicy}; pub use scheduling::policy::{FcfsPolicy, RouterSchedulingPolicy, SchedulingPolicy, WsptPolicy};
pub use scheduling::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse}; pub use scheduling::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
pub use selector::{DefaultWorkerSelector, WorkerSelector}; pub use selector::{DefaultWorkerSelector, WorkerSelector};
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::future::Future;
use dynamo_tokens::{SequenceHash, Token}; use dynamo_tokens::{SequenceHash, Token};
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -105,6 +107,12 @@ pub trait WorkerConfigLike { ...@@ -105,6 +107,12 @@ pub trait WorkerConfigLike {
fn total_kv_blocks(&self) -> Option<u64>; fn total_kv_blocks(&self) -> Option<u64>;
} }
/// Transport abstraction for publishing batched router-visible KV cache events.
///
/// Implementors must be `Send + Sync`, and the future returned by
/// [`publish_event`](RouterEventSink::publish_event) must be `Send`.
pub trait RouterEventSink: Send + Sync {
/// Publishes `event` through the underlying transport; the returned future resolves to an
/// `anyhow::Result<()>` describing the outcome.
// NOTE(review): the trait doc says "batched", but this method takes a single `RouterEvent`
// reference; presumably batching happens upstream or inside the event. Confirm.
fn publish_event(&self, event: &RouterEvent)
-> impl Future<Output = anyhow::Result<()>> + Send;
}
/// A worker identifier. /// A worker identifier.
pub type WorkerId = u64; pub type WorkerId = u64;
......
...@@ -11,11 +11,16 @@ use validator::{Validate, ValidationError}; ...@@ -11,11 +11,16 @@ use validator::{Validate, ValidationError};
use crate::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block}; use crate::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block};
/// Default for `min_initial_workers`; referenced by the field's
/// `#[serde(default = "...")]` attribute and by `KvRouterConfig::default()`.
const fn default_min_initial_workers() -> usize {
    const MIN_INITIAL_WORKERS: usize = 1;
    MIN_INITIAL_WORKERS
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum RouterQueuePolicy { pub enum RouterQueuePolicy {
#[default] #[default]
Fcfs, Fcfs,
Lcfs,
Wspt, Wspt,
} }
...@@ -23,6 +28,7 @@ impl fmt::Display for RouterQueuePolicy { ...@@ -23,6 +28,7 @@ impl fmt::Display for RouterQueuePolicy {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
Self::Fcfs => f.write_str("fcfs"), Self::Fcfs => f.write_str("fcfs"),
Self::Lcfs => f.write_str("lcfs"),
Self::Wspt => f.write_str("wspt"), Self::Wspt => f.write_str("wspt"),
} }
} }
...@@ -34,9 +40,10 @@ impl FromStr for RouterQueuePolicy { ...@@ -34,9 +40,10 @@ impl FromStr for RouterQueuePolicy {
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
match s { match s {
"fcfs" => Ok(Self::Fcfs), "fcfs" => Ok(Self::Fcfs),
"lcfs" => Ok(Self::Lcfs),
"wspt" => Ok(Self::Wspt), "wspt" => Ok(Self::Wspt),
_ => Err(format!( _ => Err(format!(
"unknown queue policy: {s:?}, expected 'fcfs' or 'wspt'" "unknown queue policy: {s:?}, expected 'fcfs', 'lcfs', or 'wspt'"
)), )),
} }
} }
...@@ -58,6 +65,7 @@ pub struct RouterConfigOverride { ...@@ -58,6 +65,7 @@ pub struct RouterConfigOverride {
/// KV Router configuration parameters /// KV Router configuration parameters
#[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(default)]
#[validate(schema(function = "validate_kv_router_config"))] #[validate(schema(function = "validate_kv_router_config"))]
pub struct KvRouterConfig { pub struct KvRouterConfig {
#[validate(range(min = 0.0))] #[validate(range(min = 0.0))]
...@@ -130,6 +138,13 @@ pub struct KvRouterConfig { ...@@ -130,6 +138,13 @@ pub struct KvRouterConfig {
/// When true, the router starts immediately without waiting for discovery-based /// When true, the router starts immediately without waiting for discovery-based
/// workers and workers are provided externally per-request (e.g., EPP). /// workers and workers are provided externally per-request (e.g., EPP).
pub skip_initial_worker_wait: bool, pub skip_initial_worker_wait: bool,
/// Minimum number of workers that must be discovered before router startup continues.
/// Default: 1. Ignored when skip_initial_worker_wait=true.
#[serde(default = "default_min_initial_workers")]
#[validate(range(min = 1))]
pub min_initial_workers: usize,
/// Scheduling policy for the router queue. /// Scheduling policy for the router queue.
/// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT. /// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT.
/// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT. /// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
...@@ -159,10 +174,11 @@ impl Default for KvRouterConfig { ...@@ -159,10 +174,11 @@ impl Default for KvRouterConfig {
router_ttl_secs: 120.0, router_ttl_secs: 120.0,
router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default() router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
router_prune_target_ratio: 0.8, router_prune_target_ratio: 0.8,
router_queue_threshold: Some(2.0), router_queue_threshold: Some(4.0),
router_event_threads: 4, router_event_threads: 4,
router_enable_cache_control: false, router_enable_cache_control: false,
skip_initial_worker_wait: false, skip_initial_worker_wait: false,
min_initial_workers: default_min_initial_workers(),
router_queue_policy: RouterQueuePolicy::default(), router_queue_policy: RouterQueuePolicy::default(),
remote_indexer_component: None, remote_indexer_component: None,
} }
...@@ -237,3 +253,39 @@ impl KvRouterConfig { ...@@ -237,3 +253,39 @@ impl KvRouterConfig {
self.use_kv_events && self.overlap_score_weight > 0.0 self.use_kv_events && self.overlap_score_weight > 0.0
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// `lcfs` must be produced by `Display` and accepted by `FromStr`.
    #[test]
    fn router_queue_policy_display_and_parse_support_lcfs() {
        let rendered = RouterQueuePolicy::Lcfs.to_string();
        assert_eq!(rendered, "lcfs");

        let parsed: RouterQueuePolicy = "lcfs".parse().unwrap();
        assert_eq!(parsed, RouterQueuePolicy::Lcfs);
    }

    /// The serde representation of `Lcfs` is the lowercase string and round-trips.
    #[test]
    fn router_queue_policy_serde_round_trip_supports_lcfs() {
        let json = serde_json::to_string(&RouterQueuePolicy::Lcfs).unwrap();
        assert_eq!(json, "\"lcfs\"");

        let restored = serde_json::from_str::<RouterQueuePolicy>(&json).unwrap();
        assert_eq!(restored, RouterQueuePolicy::Lcfs);
    }

    /// The default configuration waits for exactly one worker.
    #[test]
    fn kv_router_config_defaults_to_one_initial_worker() {
        let defaults = KvRouterConfig::default();
        assert_eq!(defaults.min_initial_workers, 1);
    }

    /// Validation must refuse a configuration that requires zero initial workers.
    #[test]
    fn kv_router_config_rejects_zero_initial_workers() {
        let outcome = KvRouterConfig {
            min_initial_workers: 0,
            ..KvRouterConfig::default()
        }
        .validate();
        assert!(outcome.is_err());
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, watch};
use tokio_util::sync::CancellationToken;
use super::policy::{RouterSchedulingPolicy, SchedulingPolicy};
use super::queue::SchedulerQueue;
use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
use crate::protocols::{OverlapScores, WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{
ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest,
};
use dynamo_tokens::SequenceHash;
const RECHECK_INTERVAL: Duration = Duration::from_secs(60);
/// In-process scheduler handle.
///
/// Scheduling requests are sent over an mpsc channel to a background task that feeds a
/// [`SchedulerQueue`]; request bookkeeping is delegated to a shared
/// [`ActiveSequencesMultiWorker`].
///
/// Type parameters: `P` publishes sequence updates, `C` is the per-worker configuration,
/// `S` is the queue's scheduling policy (defaults to [`RouterSchedulingPolicy`]) and `Sel`
/// selects a worker (defaults to [`DefaultWorkerSelector`]).
pub struct LocalScheduler<P, C, S = RouterSchedulingPolicy, Sel = DefaultWorkerSelector>
where
P: SequencePublisher,
C: WorkerConfigLike,
S: SchedulingPolicy,
Sel: WorkerSelector<C>,
{
// Sending half of the channel drained by the background task spawned in `new`.
request_tx: mpsc::Sender<SchedulingRequest>,
// Active-sequence state shared with the queue.
slots: Arc<ActiveSequencesMultiWorker<P>>,
// Queue that receives requests and is asked to `update()` after state changes.
queue: Arc<SchedulerQueue<P, C, S, Sel>>,
// Static label returned verbatim by `worker_type()`.
worker_type: &'static str,
}
impl<P, C, S, Sel> LocalScheduler<P, C, S, Sel>
where
P: SequencePublisher + 'static,
C: WorkerConfigLike + Clone + PartialEq + Send + Sync + 'static,
S: SchedulingPolicy + 'static,
Sel: WorkerSelector<C> + Send + Sync + 'static,
{
/// Builds the scheduler and spawns its background Tokio task(s).
///
/// * When `monitor_worker_configs` is true, a monitoring task watches `workers_with_configs`
///   and pushes each worker's `(data_parallel_start_rank, data_parallel_size)` into `slots`
///   whenever the observed map differs from the last one seen.
/// * A second task always runs: it enqueues requests received on the internal channel and
///   calls `queue.update()` on every `RECHECK_INTERVAL` tick.
///
/// Both tasks stop when `cancellation_token` is cancelled. Because this calls
/// `tokio::spawn`, it has to be invoked from within a Tokio runtime.
#[allow(clippy::too_many_arguments)]
pub fn new(
slots: Arc<ActiveSequencesMultiWorker<P>>,
workers_with_configs: watch::Receiver<HashMap<WorkerId, C>>,
threshold_frac: Option<f64>,
block_size: u32,
selector: Sel,
policy: S,
cancellation_token: CancellationToken,
worker_type: &'static str,
monitor_worker_configs: bool,
) -> Self {
if monitor_worker_configs {
let slots_monitor = Arc::clone(&slots);
let mut monitor_rx = workers_with_configs.clone();
// Snapshot the current map so the first change notification can be compared against it.
let mut last_workers = monitor_rx.borrow().clone();
let monitor_cancel_token = cancellation_token.clone();
tokio::spawn(async move {
tracing::trace!("LocalScheduler workers monitoring task started");
loop {
tokio::select! {
_ = monitor_cancel_token.cancelled() => {
tracing::trace!("LocalScheduler workers monitoring task shutting down");
break;
}
result = monitor_rx.changed() => {
// An error here means the watch sender was dropped; nothing more will arrive.
if result.is_err() {
tracing::warn!("LocalScheduler worker config watch dropped, shutting down");
break;
}
}
}
let current_workers = monitor_rx.borrow_and_update().clone();
// Skip notifications that did not actually change the worker map.
if current_workers == last_workers {
continue;
}
let dp_range: HashMap<WorkerId, (u32, u32)> = current_workers
.iter()
.map(|(&id, cfg)| {
(
id,
(cfg.data_parallel_start_rank(), cfg.data_parallel_size()),
)
})
.collect();
slots_monitor.update_workers(&dp_range);
last_workers = current_workers;
}
});
}
let queue = Arc::new(SchedulerQueue::new(
Arc::clone(&slots),
workers_with_configs,
threshold_frac,
block_size,
selector,
policy,
));
// Bounded channel: `schedule` awaits on `send` once 1024 requests are buffered.
let (request_tx, request_rx) = mpsc::channel::<SchedulingRequest>(1024);
let queue_clone = Arc::clone(&queue);
tokio::spawn(async move {
let mut request_rx = request_rx;
// Note: a Tokio interval completes its first tick immediately.
let mut recheck_interval = tokio::time::interval(RECHECK_INTERVAL);
tracing::trace!("LocalScheduler background task started");
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::trace!("LocalScheduler background task shutting down");
break;
}
request = request_rx.recv() => {
// `None` means every sender (i.e. every scheduler handle) has been dropped.
let Some(request) = request else {
tracing::warn!("LocalScheduler request channel closed");
break;
};
tracing::trace!("received request to be scheduled");
queue_clone.enqueue(request).await;
}
_ = recheck_interval.tick() => {
queue_clone.update().await;
}
}
}
});
Self {
request_tx,
slots,
queue,
worker_type,
}
}
/// Submits one scheduling request to the background task and waits for its answer.
///
/// The arguments are packed into a [`SchedulingRequest`] (with empty `decode_blocks` and
/// `prefill_tokens` maps) and sent over the request channel; the reply arrives on a
/// oneshot channel created for this call.
///
/// # Errors
/// Returns [`KvSchedulerError::SubscriberShutdown`] when the request channel is closed or
/// the reply sender is dropped without answering; otherwise the scheduler's own result is
/// returned unchanged.
#[expect(clippy::too_many_arguments)]
pub async fn schedule(
    &self,
    maybe_request_id: Option<String>,
    isl_tokens: usize,
    token_seq: Option<Vec<SequenceHash>>,
    overlaps: OverlapScores,
    router_config_override: Option<&super::config::RouterConfigOverride>,
    update_states: bool,
    lora_name: Option<String>,
    priority_jump: f64,
    expected_output_tokens: Option<u32>,
    allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<SchedulingResponse, KvSchedulerError> {
    let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();

    let outgoing = SchedulingRequest {
        maybe_request_id,
        token_seq,
        isl_tokens,
        overlaps,
        decode_blocks: HashMap::new(),
        prefill_tokens: HashMap::new(),
        router_config_override: router_config_override.cloned(),
        update_states,
        lora_name,
        priority_jump,
        expected_output_tokens,
        allowed_worker_ids,
        resp_tx: Some(reply_tx),
    };

    // A failed send means the background task is gone.
    if self.request_tx.send(outgoing).await.is_err() {
        return Err(KvSchedulerError::SubscriberShutdown);
    }

    match reply_rx.await {
        Ok(outcome) => outcome,
        Err(_) => Err(KvSchedulerError::SubscriberShutdown),
    }
}
/// Forwards `worker_ids` to the underlying queue's `register_workers`.
pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
self.queue.register_workers(worker_ids);
}
/// Passes `req` to the shared active-sequence state and returns its result unchanged.
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
self.slots.add_request(req).await
}
/// Marks prefill as completed for `request_id` in the active-sequence state, then asks the
/// queue to update.
///
/// # Errors
/// Propagates the [`SequenceError`] from the active-sequence state; in that case the queue
/// update is skipped.
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
    let owned_id = request_id.to_owned();
    self.slots.mark_prefill_completed(&owned_id).await?;
    self.queue.update().await;
    Ok(())
}
/// Frees `request_id` in the active-sequence state, then asks the queue to update.
///
/// # Errors
/// Propagates the [`SequenceError`] from the active-sequence state; in that case the queue
/// update is skipped.
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
    let owned_id = request_id.to_owned();
    self.slots.free(&owned_id).await?;
    self.queue.update().await;
    Ok(())
}
/// Returns the queue's current `pending_count()`.
pub fn pending_count(&self) -> usize {
self.queue.pending_count()
}
/// Returns the static worker-type label supplied to `new`.
pub fn worker_type(&self) -> &'static str {
self.worker_type
}
/// Forwards to the active-sequence state's `add_output_block` for `request_id`, passing
/// `decay_fraction` through unchanged.
pub fn add_output_block(
&self,
request_id: &str,
decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
self.slots
.add_output_block(&request_id.to_string(), decay_fraction)
}
/// Reports one [`PotentialLoad`] per worker/dp-rank known to the active-sequence state for
/// the given request shape.
///
/// The result covers every key that appears in either map returned by
/// `potential_blocks_and_tokens`. A key missing from the prefill map falls back to
/// `isl_tokens`; a key missing from the decode map falls back to `0`. The order of the
/// returned vector is unspecified because it follows hash-set iteration.
pub fn get_potential_loads(
    &self,
    token_seq: Option<Vec<SequenceHash>>,
    isl_tokens: usize,
    overlaps: OverlapScores,
) -> Vec<PotentialLoad> {
    let (decode_blocks, prefill_tokens) =
        self.slots
            .potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);

    // Union of the keys of both maps, deduplicated.
    let workers: HashSet<WorkerWithDpRank> = decode_blocks
        .keys()
        .chain(prefill_tokens.keys())
        .copied()
        .collect();

    workers
        .into_iter()
        .map(|worker| PotentialLoad {
            worker_id: worker.worker_id,
            dp_rank: worker.dp_rank,
            potential_prefill_tokens: prefill_tokens
                .get(&worker)
                .copied()
                .unwrap_or(isl_tokens),
            potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
        })
        .collect()
}
/// Returns the per-LoRA-name counts reported by the active-sequence state.
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
self.slots.get_active_lora_counts()
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use super::*;
use crate::protocols::OverlapScores;
use crate::scheduling::policy::FcfsPolicy;
use crate::scheduling::selector::DefaultWorkerSelector;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
/// Test fixture: builds a `LocalScheduler` with an FCFS policy over `workers`.
///
/// Returns the scheduler together with the shared slots, the watch sender that feeds worker
/// configs (callers keep it alive or push updates through it) and the cancellation token
/// that stops the background tasks.
#[allow(clippy::type_complexity)]
fn make_scheduler(
    workers: HashMap<WorkerId, SimpleWorkerConfig>,
    threshold_frac: Option<f64>,
    monitor_worker_configs: bool,
) -> (
    Arc<LocalScheduler<NoopSequencePublisher, SimpleWorkerConfig, FcfsPolicy>>,
    Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
    watch::Sender<HashMap<WorkerId, SimpleWorkerConfig>>,
    CancellationToken,
) {
    // Seed the slots with each worker's (start rank, size) pair.
    let initial_dp_ranges = workers
        .iter()
        .map(|(&id, cfg)| (id, (cfg.data_parallel_start_rank, cfg.data_parallel_size)))
        .collect();
    let slots = Arc::new(ActiveSequencesMultiWorker::new(
        NoopSequencePublisher,
        64,
        initial_dp_ranges,
        false,
        0,
        "test",
    ));

    let (cfg_tx, cfg_rx) = watch::channel(workers);
    let cancel_token = CancellationToken::new();
    let selector = DefaultWorkerSelector::new(None, "test");

    let scheduler = LocalScheduler::new(
        Arc::clone(&slots),
        cfg_rx,
        threshold_frac,
        64,
        selector,
        FcfsPolicy,
        cancel_token.clone(),
        "test",
        monitor_worker_configs,
    );

    (Arc::new(scheduler), slots, cfg_tx, cancel_token)
}
/// A scheduled request with `update_states = true` must show up in the active LoRA counts.
#[tokio::test]
async fn test_schedule_books_request_into_active_sequences() {
    let workers = HashMap::from([(
        0,
        SimpleWorkerConfig {
            max_num_batched_tokens: Some(64),
            ..Default::default()
        },
    )]);
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);

    let scheduled = scheduler.schedule(
        Some("req-1".to_string()),
        64,
        Some(vec![1, 2, 3, 4]),
        OverlapScores::default(),
        None,
        true,
        Some("adapter-a".to_string()),
        0.0,
        None,
        None,
    );
    let response = scheduled.await.unwrap();
    assert_eq!(response.best_worker.worker_id, 0);

    let expected_loras = HashMap::from([(String::from("adapter-a"), 1)]);
    assert_eq!(scheduler.get_active_lora_counts(), expected_loras);

    cancel_token.cancel();
}
/// A request held back by the queue threshold must be released once the earlier request's
/// prefill is marked completed.
#[tokio::test]
async fn test_mark_prefill_completed_drains_pending_queue() {
    let workers = HashMap::from([(
        0,
        SimpleWorkerConfig {
            max_num_batched_tokens: Some(64),
            ..Default::default()
        },
    )]);
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, Some(0.5), true);

    // First request goes straight through.
    let first = scheduler.schedule(
        Some("req-1".to_string()),
        64,
        Some(vec![1, 2, 3, 4]),
        OverlapScores::default(),
        None,
        true,
        None,
        0.0,
        None,
        None,
    );
    first.await.unwrap();

    // Second request is issued from its own task because it is expected to stay pending.
    let background_scheduler = Arc::clone(&scheduler);
    let queued = tokio::spawn(async move {
        background_scheduler
            .schedule(
                Some("req-2".to_string()),
                64,
                Some(vec![5, 6, 7, 8]),
                OverlapScores::default(),
                None,
                true,
                None,
                0.0,
                None,
                None,
            )
            .await
    });

    // Give the spawned task and the scheduler loop time to enqueue the second request.
    tokio::time::sleep(Duration::from_millis(25)).await;
    assert_eq!(scheduler.pending_count(), 1);

    scheduler.mark_prefill_completed("req-1").await.unwrap();
    let second = queued.await.unwrap();
    second.unwrap();
    assert_eq!(scheduler.pending_count(), 0);

    cancel_token.cancel();
}
/// Freeing a request must remove its LoRA adapter from the active counts.
#[tokio::test]
async fn test_free_updates_active_state() {
    let workers = HashMap::from([(
        0,
        SimpleWorkerConfig {
            max_num_batched_tokens: Some(64),
            ..Default::default()
        },
    )]);
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);

    let scheduled = scheduler.schedule(
        Some("req-1".to_string()),
        64,
        Some(vec![1, 2, 3, 4]),
        OverlapScores::default(),
        None,
        true,
        Some("adapter-a".to_string()),
        0.0,
        None,
        None,
    );
    scheduled.await.unwrap();

    let loras_while_active = scheduler.get_active_lora_counts();
    assert_eq!(
        loras_while_active,
        HashMap::from([(String::from("adapter-a"), 1)])
    );

    scheduler.free("req-1").await.unwrap();
    let loras_after_free = scheduler.get_active_lora_counts();
    assert!(loras_after_free.is_empty());

    cancel_token.cancel();
}
/// `get_potential_loads` must agree with what the slots report directly.
#[tokio::test]
async fn test_get_potential_loads_matches_slots() {
    let worker_cfg = || SimpleWorkerConfig {
        max_num_batched_tokens: Some(256),
        ..Default::default()
    };
    let workers = HashMap::from([(0, worker_cfg()), (1, worker_cfg())]);
    let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);

    let token_seq = vec![11, 22, 33, 44];
    let overlaps = OverlapScores::default();

    // Reference values computed straight from the slots.
    let (decode_blocks, prefill_tokens) =
        slots.potential_blocks_and_tokens(Some(&token_seq), 128, overlaps.clone());
    let mut expected = Vec::new();
    for worker in decode_blocks.keys() {
        expected.push(PotentialLoad {
            worker_id: worker.worker_id,
            dp_rank: worker.dp_rank,
            potential_prefill_tokens: prefill_tokens.get(worker).copied().unwrap_or(128),
            potential_decode_blocks: decode_blocks.get(worker).copied().unwrap_or(0),
        });
    }
    expected.sort_by_key(|load| (load.worker_id, load.dp_rank));

    let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, overlaps);
    actual.sort_by_key(|load| (load.worker_id, load.dp_rank));

    // Compare field by field after sorting both sides the same way.
    assert_eq!(actual.len(), expected.len());
    for (got, want) in actual.iter().zip(&expected) {
        assert_eq!(got.worker_id, want.worker_id);
        assert_eq!(got.dp_rank, want.dp_rank);
        assert_eq!(got.potential_prefill_tokens, want.potential_prefill_tokens);
        assert_eq!(got.potential_decode_blocks, want.potential_decode_blocks);
    }

    cancel_token.cancel();
}
/// Registering a worker that has no config must yield a single load entry at dp rank 0.
#[tokio::test]
async fn test_register_workers_uses_default_dp_fallback() {
    let (scheduler, _slots, _cfg_tx, cancel_token) =
        make_scheduler(HashMap::new(), None, false);

    let registered = HashSet::from([42]);
    scheduler.register_workers(&registered);

    let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default());
    assert_eq!(loads.len(), 1);
    let only_load = &loads[0];
    assert_eq!(only_load.worker_id, 42);
    assert_eq!(only_load.dp_rank, 0);

    cancel_token.cancel();
}
/// Publishing a new worker map on the watch channel must be reflected in the potential
/// loads: one entry before the update, three entries afterwards.
#[tokio::test]
async fn test_worker_watch_updates_slot_ranges() {
    let workers = HashMap::from([(0, SimpleWorkerConfig::default())]);
    let (scheduler, _slots, cfg_tx, cancel_token) = make_scheduler(workers, None, true);

    let load_count = || {
        scheduler
            .get_potential_loads(None, 64, OverlapScores::default())
            .len()
    };
    assert_eq!(load_count(), 1);

    let updated_workers = HashMap::from([
        (
            0,
            SimpleWorkerConfig {
                data_parallel_size: 2,
                ..Default::default()
            },
        ),
        (1, SimpleWorkerConfig::default()),
    ]);
    cfg_tx.send(updated_workers).unwrap();

    // Yield until the monitoring task has applied the update; fail after one second.
    let wait_for_update = async {
        while load_count() != 3 {
            tokio::task::yield_now().await;
        }
    };
    tokio::time::timeout(Duration::from_secs(1), wait_for_update)
        .await
        .unwrap();

    cancel_token.cancel();
}
}
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
pub mod config; pub mod config;
mod local;
pub mod policy; pub mod policy;
pub mod queue; pub mod queue;
pub mod selector; pub mod selector;
mod types; mod types;
pub use local::LocalScheduler;
pub use types::*; pub use types::*;
...@@ -43,6 +43,21 @@ impl SchedulingPolicy for FcfsPolicy { ...@@ -43,6 +43,21 @@ impl SchedulingPolicy for FcfsPolicy {
} }
} }
/// LCFS with priority bumps: key = max(priority_jump, 0) + arrival_offset (in seconds).
/// Later arrival or higher priority_jump produces a higher key, scheduled first.
/// Negative `priority_jump` values are clamped to zero before being added.
///
/// This intentionally favors newer arrivals under saturation and is mainly useful
/// for policy comparison experiments.
pub struct LcfsPolicy;

impl SchedulingPolicy for LcfsPolicy {
    type Key = OrderedFloat<f64>;

    /// Computes the queue key as the clamped priority bump plus the arrival offset in
    /// fractional seconds.
    fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
        let priority_bump = request.priority_jump.max(0.0);
        let arrival_secs = arrival_offset.as_secs_f64();
        OrderedFloat(priority_bump + arrival_secs)
    }
}
/// Weighted Shortest Processing Time (Smith's rule): /// Weighted Shortest Processing Time (Smith's rule):
/// key = (1 + priority_jump) / new_tokens, where new_tokens estimates the /// key = (1 + priority_jump) / new_tokens, where new_tokens estimates the
/// actual prefill cost by subtracting the max KV cache overlap from ISL. /// actual prefill cost by subtracting the max KV cache overlap from ISL.
...@@ -73,6 +88,7 @@ impl SchedulingPolicy for WsptPolicy { ...@@ -73,6 +88,7 @@ impl SchedulingPolicy for WsptPolicy {
/// since the variant is fixed at queue construction time. /// since the variant is fixed at queue construction time.
pub enum RouterSchedulingPolicy { pub enum RouterSchedulingPolicy {
Fcfs(FcfsPolicy), Fcfs(FcfsPolicy),
Lcfs(LcfsPolicy),
Wspt(WsptPolicy), Wspt(WsptPolicy),
} }
...@@ -80,6 +96,7 @@ impl RouterSchedulingPolicy { ...@@ -80,6 +96,7 @@ impl RouterSchedulingPolicy {
pub fn new(kind: RouterQueuePolicy, block_size: usize) -> Self { pub fn new(kind: RouterQueuePolicy, block_size: usize) -> Self {
match kind { match kind {
RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy), RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy),
RouterQueuePolicy::Lcfs => Self::Lcfs(LcfsPolicy),
RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy { block_size }), RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy { block_size }),
} }
} }
...@@ -91,6 +108,7 @@ impl SchedulingPolicy for RouterSchedulingPolicy { ...@@ -91,6 +108,7 @@ impl SchedulingPolicy for RouterSchedulingPolicy {
fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key { fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
match self { match self {
Self::Fcfs(p) => p.enqueue_key(arrival_offset, request), Self::Fcfs(p) => p.enqueue_key(arrival_offset, request),
Self::Lcfs(p) => p.enqueue_key(arrival_offset, request),
Self::Wspt(p) => p.enqueue_key(arrival_offset, request), Self::Wspt(p) => p.enqueue_key(arrival_offset, request),
} }
} }
...@@ -178,6 +196,42 @@ mod tests { ...@@ -178,6 +196,42 @@ mod tests {
assert!(key_b > key_a); assert!(key_b > key_a);
} }
#[test]
fn lcfs_later_arrival_scheduled_first() {
    // Same request keyed at two arrival offsets: the later one must win.
    let req = request_with(512, 0.0, OverlapScores::default());
    let (first_offset, second_offset) = (Duration::from_secs(1), Duration::from_secs(10));

    let early = LcfsPolicy.enqueue_key(first_offset, &req);
    let late = LcfsPolicy.enqueue_key(second_offset, &req);

    assert!(late > early, "later arrival should have higher key");
}
#[test]
fn lcfs_priority_jump_promotes() {
    // Two requests arriving at the same offset; only the priority jump differs.
    let arrival = Duration::from_secs(10);
    let normal = request_with(512, 0.0, OverlapScores::default());
    let boosted = request_with(512, 100.0, OverlapScores::default());

    let key_normal = LcfsPolicy.enqueue_key(arrival, &normal);
    let key_boosted = LcfsPolicy.enqueue_key(arrival, &boosted);

    assert!(
        key_boosted > key_normal,
        "priority_jump should produce a higher key"
    );
}
#[test]
fn router_scheduling_policy_matches_fcfs_and_lcfs_ordering() {
    // The enum wrapper must keep each policy's ordering: under FCFS the
    // earlier arrival gets the higher key, under LCFS the later one does.
    let req = request_with(512, 0.0, OverlapScores::default());
    let (early, late) = (Duration::from_secs(1), Duration::from_secs(10));

    let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs, 16);
    let fcfs_early = fcfs.enqueue_key(early, &req);
    let fcfs_late = fcfs.enqueue_key(late, &req);
    assert!(fcfs_early > fcfs_late);

    let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs, 16);
    let lcfs_late = lcfs.enqueue_key(late, &req);
    let lcfs_early = lcfs.enqueue_key(early, &req);
    assert!(lcfs_late > lcfs_early);
}
// ---- WSPT policy tests ---- // ---- WSPT policy tests ----
#[test] #[test]
......
...@@ -11,7 +11,7 @@ use tokio::sync::Mutex; ...@@ -11,7 +11,7 @@ use tokio::sync::Mutex;
use tokio::sync::watch; use tokio::sync::watch;
use super::policy::{FcfsPolicy, SchedulingPolicy}; use super::policy::{FcfsPolicy, SchedulingPolicy};
use super::selector::WorkerSelector; use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{SchedulingRequest, SchedulingResponse}; use super::types::{SchedulingRequest, SchedulingResponse};
use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank}; use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRequest}; use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRequest};
...@@ -53,6 +53,7 @@ pub struct SchedulerQueue< ...@@ -53,6 +53,7 @@ pub struct SchedulerQueue<
P: SequencePublisher, P: SequencePublisher,
C: WorkerConfigLike, C: WorkerConfigLike,
S: SchedulingPolicy = FcfsPolicy, S: SchedulingPolicy = FcfsPolicy,
Sel: WorkerSelector<C> = DefaultWorkerSelector,
> { > {
pending: Mutex<BinaryHeap<QueueEntry<S::Key>>>, pending: Mutex<BinaryHeap<QueueEntry<S::Key>>>,
/// Number of requests currently parked in the pending queue. /// Number of requests currently parked in the pending queue.
...@@ -65,19 +66,23 @@ pub struct SchedulerQueue< ...@@ -65,19 +66,23 @@ pub struct SchedulerQueue<
/// Reference instant for computing arrival offsets. /// Reference instant for computing arrival offsets.
start_time: Instant, start_time: Instant,
block_size: u32, block_size: u32,
selector: Box<dyn WorkerSelector<C> + Send + Sync>, selector: Sel,
policy: S, policy: S,
} }
impl<P: SequencePublisher + 'static, C: WorkerConfigLike, S: SchedulingPolicy> impl<
SchedulerQueue<P, C, S> P: SequencePublisher + 'static,
C: WorkerConfigLike,
S: SchedulingPolicy,
Sel: WorkerSelector<C>,
> SchedulerQueue<P, C, S, Sel>
{ {
pub fn new( pub fn new(
slots: Arc<ActiveSequencesMultiWorker<P>>, slots: Arc<ActiveSequencesMultiWorker<P>>,
workers_with_configs: watch::Receiver<HashMap<WorkerId, C>>, workers_with_configs: watch::Receiver<HashMap<WorkerId, C>>,
threshold_frac: Option<f64>, threshold_frac: Option<f64>,
block_size: u32, block_size: u32,
selector: Box<dyn WorkerSelector<C> + Send + Sync>, selector: Sel,
policy: S, policy: S,
) -> Self { ) -> Self {
if let Some(frac) = threshold_frac { if let Some(frac) = threshold_frac {
...@@ -341,7 +346,7 @@ mod tests { ...@@ -341,7 +346,7 @@ mod tests {
} }
let (cfg_tx, cfg_rx) = watch::channel(configs); let (cfg_tx, cfg_rx) = watch::channel(configs);
let selector = Box::new(DefaultWorkerSelector::new(None, "test")); let selector = DefaultWorkerSelector::new(None, "test");
let queue = Arc::new(SchedulerQueue::new( let queue = Arc::new(SchedulerQueue::new(
Arc::clone(&slots), Arc::clone(&slots),
cfg_rx, cfg_rx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment