Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
b7fe46b1
Unverified
Commit
b7fe46b1
authored
Mar 23, 2026
by
Yan Ru Pei
Committed by
GitHub
Mar 23, 2026
Browse files
feat(mocker): add multi-worker replay and router startup fixes (#7553)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
82794761
Changes
102
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3654 additions
and
1711 deletions
+3654
-1711
lib/bindings/python/rust/lib.rs
lib/bindings/python/rust/lib.rs
+5
-1
lib/bindings/python/rust/llm.rs
lib/bindings/python/rust/llm.rs
+1
-0
lib/bindings/python/rust/llm/entrypoint.rs
lib/bindings/python/rust/llm/entrypoint.rs
+21
-85
lib/bindings/python/rust/llm/replay.rs
lib/bindings/python/rust/llm/replay.rs
+563
-0
lib/bindings/python/src/dynamo/_core.pyi
lib/bindings/python/src/dynamo/_core.pyi
+138
-6
lib/bindings/python/src/dynamo/llm/__init__.py
lib/bindings/python/src/dynamo/llm/__init__.py
+25
-1
lib/bindings/python/src/dynamo/replay/__init__.py
lib/bindings/python/src/dynamo/replay/__init__.py
+6
-0
lib/bindings/python/src/dynamo/replay/__main__.py
lib/bindings/python/src/dynamo/replay/__main__.py
+7
-0
lib/bindings/python/src/dynamo/replay/api.py
lib/bindings/python/src/dynamo/replay/api.py
+59
-0
lib/bindings/python/src/dynamo/replay/main.py
lib/bindings/python/src/dynamo/replay/main.py
+94
-0
lib/bindings/python/tests/test_replay.py
lib/bindings/python/tests/test_replay.py
+421
-0
lib/kv-router/src/event_sink.rs
lib/kv-router/src/event_sink.rs
+0
-19
lib/kv-router/src/indexer/tests.rs
lib/kv-router/src/indexer/tests.rs
+1629
-1587
lib/kv-router/src/lib.rs
lib/kv-router/src/lib.rs
+3
-4
lib/kv-router/src/protocols.rs
lib/kv-router/src/protocols.rs
+8
-0
lib/kv-router/src/scheduling/config.rs
lib/kv-router/src/scheduling/config.rs
+54
-2
lib/kv-router/src/scheduling/local.rs
lib/kv-router/src/scheduling/local.rs
+553
-0
lib/kv-router/src/scheduling/mod.rs
lib/kv-router/src/scheduling/mod.rs
+2
-0
lib/kv-router/src/scheduling/policy.rs
lib/kv-router/src/scheduling/policy.rs
+54
-0
lib/kv-router/src/scheduling/queue.rs
lib/kv-router/src/scheduling/queue.rs
+11
-6
No files found.
lib/bindings/python/rust/lib.rs
View file @
b7fe46b1
...
...
@@ -149,8 +149,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m
.add_function
(
wrap_pyfunction!
(
fetch_model
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
run_kv_indexer
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
llm
::
entrypoint
::
make_engine
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
llm
::
replay
::
run_mocker_trace_replay
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
llm
::
entrypoint
::
run_mocker_trace_replay
,
llm
::
replay
::
run_mocker_
synthetic_
trace_replay
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
llm
::
entrypoint
::
run_input
,
m
)
?
)
?
;
...
...
@@ -165,6 +166,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m
.add_class
::
<
llm
::
entrypoint
::
EngineType
>
()
?
;
m
.add_class
::
<
llm
::
entrypoint
::
RouterConfig
>
()
?
;
m
.add_class
::
<
llm
::
entrypoint
::
KvRouterConfig
>
()
?
;
m
.add_class
::
<
llm
::
replay
::
ReasoningConfig
>
()
?
;
m
.add_class
::
<
llm
::
replay
::
SglangArgs
>
()
?
;
m
.add_class
::
<
llm
::
replay
::
MockEngineArgs
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
WorkerMetricsPublisher
>
()
?
;
m
.add_class
::
<
llm
::
model_card
::
ModelDeploymentCard
>
()
?
;
// Internal: only in _internal, not public API
m
.add_class
::
<
llm
::
local_model
::
ModelRuntimeConfig
>
()
?
;
...
...
lib/bindings/python/rust/llm.rs
View file @
b7fe46b1
...
...
@@ -31,3 +31,4 @@ pub mod local_model;
pub
mod
lora
;
pub
mod
model_card
;
pub
mod
preprocessor
;
pub
mod
replay
;
lib/bindings/python/rust/llm/entrypoint.rs
View file @
b7fe46b1
...
...
@@ -9,7 +9,6 @@ use std::sync::Arc;
use
pyo3
::{
exceptions
::
PyException
,
prelude
::
*
};
use
pyo3_async_runtimes
::
TaskLocals
;
use
pythonize
::
pythonize
;
use
dynamo_kv_router
::
config
::
KvRouterConfig
as
RsKvRouterConfig
;
use
dynamo_llm
::
discovery
::
LoadThresholdConfig
as
RsLoadThresholdConfig
;
...
...
@@ -25,7 +24,8 @@ use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingE
use
dynamo_mocker
::
common
::
perf_model
::
PerfModel
;
use
super
::
aic_callback
::
create_aic_callback
;
use
dynamo_mocker
::
common
::
protocols
::
MockEngineArgs
;
use
super
::
replay
::
MockEngineArgs
as
PyMockEngineArgs
;
use
dynamo_mocker
::
common
::
protocols
::
MockEngineArgs
as
RsMockEngineArgs
;
use
dynamo_runtime
::
discovery
::
ModelCardInstanceId
as
RsModelCardInstanceId
;
use
dynamo_runtime
::
protocols
::
EndpointId
;
...
...
@@ -58,7 +58,7 @@ impl KvRouterConfig {
#[pymethods]
impl
KvRouterConfig
{
#[new]
#[pyo3(signature
=
(overlap_score_weight=
1.0
,
router_temperature=
0.0
,
use_kv_events=
true
,
durable_kv_events=
false
,
router_replica_sync=
false
,
router_track_active_blocks=
true
,
router_track_output_blocks=
false
,
router_assume_kv_reuse=
true
,
router_snapshot_threshold=
1000000
,
router_reset_states=
false
,
router_ttl_secs=
120.0
,
router_max_tree_size=
1048576
,
router_prune_target_ratio=
0.8
,
router_queue_threshold=Some(
2
.0
),
router_event_threads=
4
,
router_enable_cache_control=
false
,
router_queue_policy=
"fcfs"
,
remote_indexer_component=None))]
#[pyo3(signature
=
(overlap_score_weight=
1.0
,
router_temperature=
0.0
,
use_kv_events=
true
,
durable_kv_events=
false
,
router_replica_sync=
false
,
router_track_active_blocks=
true
,
router_track_output_blocks=
false
,
router_assume_kv_reuse=
true
,
router_snapshot_threshold=
1000000
,
router_reset_states=
false
,
router_ttl_secs=
120.0
,
router_max_tree_size=
1048576
,
router_prune_target_ratio=
0.8
,
router_queue_threshold=Some(
4
.0
),
router_event_threads=
4
,
router_enable_cache_control=
false
,
min_initial_workers=
1
,
router_queue_policy=
"fcfs"
,
remote_indexer_component=None))]
#[allow(clippy::too_many_arguments)]
fn
new
(
overlap_score_weight
:
f64
,
...
...
@@ -77,6 +77,7 @@ impl KvRouterConfig {
router_queue_threshold
:
Option
<
f64
>
,
router_event_threads
:
u32
,
router_enable_cache_control
:
bool
,
min_initial_workers
:
usize
,
router_queue_policy
:
&
str
,
remote_indexer_component
:
Option
<
String
>
,
)
->
Self
{
...
...
@@ -99,6 +100,7 @@ impl KvRouterConfig {
router_event_threads
,
router_enable_cache_control
,
skip_initial_worker_wait
:
false
,
min_initial_workers
,
router_queue_policy
:
router_queue_policy
.parse
()
.unwrap_or_else
(|
_
|
{
panic!
(
"invalid router_queue_policy: {router_queue_policy:?}"
)
}),
...
...
@@ -106,6 +108,13 @@ impl KvRouterConfig {
},
}
}
/// Builds a `KvRouterConfig` by deserializing `config_json` into the wrapped Rust config.
///
/// Returns a Python exception containing the serde error text when the JSON cannot be
/// deserialized.
#[staticmethod]
fn from_json(config_json: &str) -> PyResult<Self> {
    match serde_json::from_str::<RsKvRouterConfig>(config_json) {
        Ok(inner) => Ok(KvRouterConfig { inner }),
        Err(e) => Err(PyException::new_err(format!(
            "Failed to parse KvRouterConfig JSON: {e}"
        ))),
    }
}
}
#[pyclass]
...
...
@@ -196,6 +205,7 @@ pub(crate) struct EntrypointArgs {
tls_cert_path
:
Option
<
PathBuf
>
,
tls_key_path
:
Option
<
PathBuf
>
,
extra_engine_args
:
Option
<
PathBuf
>
,
mocker_engine_args
:
Option
<
PyMockEngineArgs
>
,
runtime_config
:
Option
<
ModelRuntimeConfig
>
,
namespace
:
Option
<
String
>
,
namespace_prefix
:
Option
<
String
>
,
...
...
@@ -208,7 +218,7 @@ pub(crate) struct EntrypointArgs {
impl
EntrypointArgs
{
#[allow(clippy::too_many_arguments)]
#[new]
#[pyo3(signature
=
(engine_type,
model_path=None,
model_name=None,
endpoint_id=None,
context_length=None,
template_file=None,
router_config=None,
kv_cache_block_size=None,
http_host=None,
http_port=None,
http_metrics_port=None,
tls_cert_path=None,
tls_key_path=None,
extra_engine_args=None,
runtime_config=None,
namespace=None,
namespace_prefix=None,
is_prefill=
false
,
migration_limit=
0
,
chat_engine_factory=None))]
#[pyo3(signature
=
(engine_type,
model_path=None,
model_name=None,
endpoint_id=None,
context_length=None,
template_file=None,
router_config=None,
kv_cache_block_size=None,
http_host=None,
http_port=None,
http_metrics_port=None,
tls_cert_path=None,
tls_key_path=None,
extra_engine_args=None,
mocker_engine_args=None,
runtime_config=None,
namespace=None,
namespace_prefix=None,
is_prefill=
false
,
migration_limit=
0
,
chat_engine_factory=None))]
pub
fn
new
(
py
:
Python
<
'_
>
,
engine_type
:
EngineType
,
...
...
@@ -225,6 +235,7 @@ impl EntrypointArgs {
tls_cert_path
:
Option
<
PathBuf
>
,
tls_key_path
:
Option
<
PathBuf
>
,
extra_engine_args
:
Option
<
PathBuf
>
,
mocker_engine_args
:
Option
<
PyMockEngineArgs
>
,
runtime_config
:
Option
<
ModelRuntimeConfig
>
,
namespace
:
Option
<
String
>
,
namespace_prefix
:
Option
<
String
>
,
...
...
@@ -272,6 +283,7 @@ impl EntrypointArgs {
tls_cert_path
,
tls_key_path
,
extra_engine_args
,
mocker_engine_args
,
runtime_config
,
namespace
,
namespace_prefix
,
...
...
@@ -419,8 +431,10 @@ async fn select_engine(
}
}
EngineType
::
Mocker
=>
{
let
mut
mocker_args
=
if
let
Some
(
extra_args_path
)
=
args
.extra_engine_args
{
MockEngineArgs
::
from_json_file
(
&
extra_args_path
)
.map_err
(|
e
|
{
let
mut
mocker_args
=
if
let
Some
(
mocker_engine_args
)
=
args
.mocker_engine_args
{
mocker_engine_args
.inner
()
}
else
if
let
Some
(
extra_args_path
)
=
args
.extra_engine_args
{
RsMockEngineArgs
::
from_json_file
(
&
extra_args_path
)
.map_err
(|
e
|
{
anyhow
::
anyhow!
(
"Failed to load mocker args from {:?}: {}"
,
extra_args_path
,
...
...
@@ -431,7 +445,7 @@ async fn select_engine(
tracing
::
warn!
(
"No extra_engine_args specified for mocker engine. Using default mocker args."
);
MockEngineArgs
::
default
()
Rs
MockEngineArgs
::
default
()
};
// If aic_backend is set, create Python AIC callback and override perf_model
...
...
@@ -503,84 +517,6 @@ pub fn run_input<'p>(
})
}
#[pyfunction]
#[pyo3(signature
=
(trace_file,
extra_engine_args=None,
num_workers=
1
,
replay_concurrency=None))]
pub
fn
run_mocker_trace_replay
(
py
:
Python
<
'_
>
,
trace_file
:
PathBuf
,
extra_engine_args
:
Option
<
PathBuf
>
,
num_workers
:
usize
,
replay_concurrency
:
Option
<
isize
>
,
)
->
PyResult
<
PyObject
>
{
// Load args before allow_threads so we can use the GIL for AIC callback creation.
let
mut
args
=
if
let
Some
(
ref
extra_args_path
)
=
extra_engine_args
{
MockEngineArgs
::
from_json_file
(
extra_args_path
)
.map_err
(|
e
|
{
PyException
::
new_err
(
format!
(
"Failed to load mocker args from {:?}: {}"
,
extra_args_path
,
e
))
})
?
}
else
{
MockEngineArgs
::
default
()
};
// Create AIC callback if requested (requires GIL, must be done before allow_threads).
if
let
Some
(
ref
backend_name
)
=
args
.aic_backend
.clone
()
{
let
backend
=
backend_name
.clone
();
let
system
=
args
.aic_system
.as_deref
()
.unwrap_or
(
"h200_sxm"
)
.to_string
();
let
model_name
=
args
.aic_model_path
.clone
()
.ok_or_else
(||
PyException
::
new_err
(
"--aic-perf-model requires --model-path"
))
?
;
let
backend_version
=
args
.aic_backend_version
.clone
();
let
tp_size
=
args
.aic_tp_size
.unwrap_or
(
1
);
let
callback
=
create_aic_callback
(
py
,
&
backend
,
&
system
,
&
model_name
,
tp_size
,
backend_version
.as_deref
(),
)
.map_err
(|
e
|
{
PyException
::
new_err
(
format!
(
"Failed to create AIC callback (--aic-perf-model was requested): {}"
,
e
))
})
?
;
tracing
::
info!
(
"AIC perf model: backend={}, gpu={}, model={}, version={:?}"
,
backend
,
system
,
model_name
,
backend_version
);
args
.perf_model
=
Arc
::
new
(
PerfModel
::
from_aic_callback
(
callback
));
}
let
report
=
py
.allow_threads
(
move
||
{
let
replay_concurrency
=
replay_concurrency
.map
(
usize
::
try_from
)
.transpose
()
.map_err
(|
_
|
anyhow
::
anyhow!
(
"replay_concurrency must be at least 1"
))
?
;
if
let
Some
(
max_in_flight
)
=
replay_concurrency
{
dynamo_mocker
::
simulation
::
simulate_concurrency_file
(
args
,
&
trace_file
,
max_in_flight
,
num_workers
,
)
}
else
{
dynamo_mocker
::
simulation
::
simulate_trace_file
(
args
,
&
trace_file
,
num_workers
)
}
});
let
report
=
report
.map_err
(
to_pyerr
)
?
;
pythonize
(
py
,
&
report
)
.map_err
(
to_pyerr
)
.map
(|
obj
|
obj
.unbind
())
}
pub
fn
to_pyerr
<
E
>
(
err
:
E
)
->
PyErr
where
E
:
Display
,
...
...
lib/bindings/python/rust/llm/replay.rs
0 → 100644
View file @
b7fe46b1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
path
::
PathBuf
;
use
std
::
sync
::
Arc
;
use
dynamo_mocker
::
common
::
perf_model
::
PerfModel
;
use
dynamo_mocker
::
common
::
protocols
::{
DirectRequest
,
EngineType
as
RsMockerEngineType
,
MockEngineArgs
as
RsMockEngineArgs
,
PreemptionMode
as
RsPreemptionMode
,
ReasoningConfig
as
RsReasoningConfig
,
SglangArgs
as
RsSglangArgs
,
WorkerType
as
RsWorkerType
,
};
use
pyo3
::{
exceptions
::
PyException
,
prelude
::
*
};
use
pythonize
::
pythonize
;
use
uuid
::
Uuid
;
use
super
::
aic_callback
::
create_aic_callback
;
use
super
::
entrypoint
::{
KvRouterConfig
,
to_pyerr
};
/// Maps the Python-facing engine name (`"vllm"` or `"sglang"`) to the mocker engine enum.
///
/// Any other string yields a Python exception naming the rejected value.
fn parse_mocker_engine_type(engine_type: &str) -> PyResult<RsMockerEngineType> {
    let parsed = match engine_type {
        "vllm" => RsMockerEngineType::Vllm,
        "sglang" => RsMockerEngineType::Sglang,
        _ => {
            return Err(PyException::new_err(format!(
                "engine_type must be either 'vllm' or 'sglang', got '{other}'",
                other = engine_type
            )));
        }
    };
    Ok(parsed)
}
/// Maps the Python-facing worker type (`"aggregated"`, `"prefill"`, `"decode"`) to the
/// mocker worker enum.
///
/// Any other string yields a Python exception naming the rejected value.
fn parse_worker_type(worker_type: &str) -> PyResult<RsWorkerType> {
    let parsed = match worker_type {
        "aggregated" => RsWorkerType::Aggregated,
        "prefill" => RsWorkerType::Prefill,
        "decode" => RsWorkerType::Decode,
        _ => {
            return Err(PyException::new_err(format!(
                "worker_type must be one of 'aggregated', 'prefill', or 'decode', got '{other}'",
                other = worker_type
            )));
        }
    };
    Ok(parsed)
}
/// Maps the Python-facing preemption mode (`"lifo"` or `"fifo"`) to the mocker enum.
///
/// Any other string yields a Python exception naming the rejected value.
fn parse_preemption_mode(preemption_mode: &str) -> PyResult<RsPreemptionMode> {
    let parsed = match preemption_mode {
        "lifo" => RsPreemptionMode::Lifo,
        "fifo" => RsPreemptionMode::Fifo,
        _ => {
            return Err(PyException::new_err(format!(
                "preemption_mode must be either 'lifo' or 'fifo', got '{other}'",
                other = preemption_mode
            )));
        }
    };
    Ok(parsed)
}
#[pyclass]
#[derive(Clone,
Debug)]
pub
struct
ReasoningConfig
{
inner
:
RsReasoningConfig
,
}
impl
ReasoningConfig
{
pub
fn
inner
(
&
self
)
->
RsReasoningConfig
{
self
.inner
.clone
()
}
}
#[pymethods]
impl
ReasoningConfig
{
#[new]
fn
new
(
start_thinking_token_id
:
u32
,
end_thinking_token_id
:
u32
,
thinking_ratio
:
f64
,
)
->
PyResult
<
Self
>
{
let
inner
=
RsReasoningConfig
{
start_thinking_token_id
,
end_thinking_token_id
,
thinking_ratio
,
};
Ok
(
Self
{
inner
})
}
}
#[pyclass]
#[derive(Clone,
Debug,
Default)]
pub
struct
SglangArgs
{
inner
:
RsSglangArgs
,
}
impl
SglangArgs
{
pub
fn
inner
(
&
self
)
->
RsSglangArgs
{
self
.inner
.clone
()
}
}
#[pymethods]
impl
SglangArgs
{
#[new]
#[pyo3(signature
=
(schedule_policy=None,
page_size=None,
max_prefill_tokens=None,
chunked_prefill_size=None,
clip_max_new_tokens=None,
schedule_conservativeness=None))]
fn
new
(
schedule_policy
:
Option
<
String
>
,
page_size
:
Option
<
usize
>
,
max_prefill_tokens
:
Option
<
usize
>
,
chunked_prefill_size
:
Option
<
usize
>
,
clip_max_new_tokens
:
Option
<
usize
>
,
schedule_conservativeness
:
Option
<
f64
>
,
)
->
PyResult
<
Self
>
{
let
inner
=
RsSglangArgs
{
schedule_policy
,
page_size
,
max_prefill_tokens
,
chunked_prefill_size
,
clip_max_new_tokens
,
schedule_conservativeness
,
};
Ok
(
Self
{
inner
})
}
}
#[pyclass]
#[derive(Clone,
Debug,
Default)]
pub
struct
MockEngineArgs
{
inner
:
RsMockEngineArgs
,
}
impl
MockEngineArgs
{
pub
fn
inner
(
&
self
)
->
RsMockEngineArgs
{
self
.inner
.clone
()
}
}
#[pymethods]
impl MockEngineArgs {
    /// Builds mocker engine args from keyword arguments.
    ///
    /// `engine_type`, `worker_type`, `preemption_mode` and `router_queue_policy` arrive as
    /// strings and are parsed (in that order) before the builder runs; the built value is then
    /// passed through `normalized()`. Any parse, build or normalize failure is returned as a
    /// Python exception.
    #[new]
    #[pyo3(signature = (engine_type="vllm", num_gpu_blocks=16384, block_size=0, max_num_seqs=Some(256), max_num_batched_tokens=Some(8192), enable_prefix_caching=true, enable_chunked_prefill=true, speedup_ratio=1.0, decode_speedup_ratio=1.0, dp_size=1, startup_time=None, worker_type="aggregated", aic_backend=None, aic_system=None, aic_backend_version=None, aic_tp_size=None, aic_model_path=None, enable_local_indexer=false, bootstrap_port=None, kv_bytes_per_token=None, kv_transfer_bandwidth=None, reasoning=None, zmq_kv_events_port=None, zmq_replay_port=None, preemption_mode="lifo", router_queue_policy=None, sglang=None))]
    #[allow(clippy::too_many_arguments)]
    fn new(
        engine_type: &str,
        num_gpu_blocks: usize,
        block_size: usize,
        max_num_seqs: Option<usize>,
        max_num_batched_tokens: Option<usize>,
        enable_prefix_caching: bool,
        enable_chunked_prefill: bool,
        speedup_ratio: f64,
        decode_speedup_ratio: f64,
        dp_size: u32,
        startup_time: Option<f64>,
        worker_type: &str,
        aic_backend: Option<String>,
        aic_system: Option<String>,
        aic_backend_version: Option<String>,
        aic_tp_size: Option<usize>,
        aic_model_path: Option<String>,
        enable_local_indexer: bool,
        bootstrap_port: Option<u16>,
        kv_bytes_per_token: Option<usize>,
        kv_transfer_bandwidth: Option<f64>,
        reasoning: Option<ReasoningConfig>,
        zmq_kv_events_port: Option<u16>,
        zmq_replay_port: Option<u16>,
        preemption_mode: &str,
        router_queue_policy: Option<&str>,
        sglang: Option<SglangArgs>,
    ) -> PyResult<Self> {
        // String-typed options are validated first, in the same order as the signature.
        let engine_type = parse_mocker_engine_type(engine_type)?;
        let worker_type = parse_worker_type(worker_type)?;
        let preemption_mode = parse_preemption_mode(preemption_mode)?;
        let router_queue_policy = match router_queue_policy {
            Some(value) => Some(value.parse().map_err(|e: String| {
                PyException::new_err(format!("invalid router_queue_policy {value:?}: {e}"))
            })?),
            None => None,
        };

        // Unwrap the Python wrappers into their Rust counterparts.
        let reasoning = reasoning.map(|config| config.inner());
        let sglang = sglang.map(|config| config.inner());

        let built = RsMockEngineArgs::builder()
            .engine_type(engine_type)
            .num_gpu_blocks(num_gpu_blocks)
            .block_size(block_size)
            .max_num_seqs(max_num_seqs)
            .max_num_batched_tokens(max_num_batched_tokens)
            .enable_prefix_caching(enable_prefix_caching)
            .enable_chunked_prefill(enable_chunked_prefill)
            .speedup_ratio(speedup_ratio)
            .decode_speedup_ratio(decode_speedup_ratio)
            .dp_size(dp_size)
            .startup_time(startup_time)
            .worker_type(worker_type)
            .aic_backend(aic_backend)
            .aic_system(aic_system)
            .aic_backend_version(aic_backend_version)
            .aic_tp_size(aic_tp_size)
            .aic_model_path(aic_model_path)
            .enable_local_indexer(enable_local_indexer)
            .bootstrap_port(bootstrap_port)
            .kv_bytes_per_token(kv_bytes_per_token)
            .kv_transfer_bandwidth(kv_transfer_bandwidth)
            .reasoning(reasoning)
            .zmq_kv_events_port(zmq_kv_events_port)
            .zmq_replay_port(zmq_replay_port)
            .preemption_mode(preemption_mode)
            .router_queue_policy(router_queue_policy)
            .sglang(sglang)
            .build()
            .map_err(|e| PyException::new_err(format!("Failed to build MockEngineArgs: {e}")))?;

        let inner = built.normalized().map_err(|e| {
            PyException::new_err(format!("Failed to normalize MockEngineArgs: {e}"))
        })?;
        Ok(Self { inner })
    }

    /// Parses mocker engine args from a JSON string via `from_json_str`.
    ///
    /// Parse failures are returned as a Python exception with the underlying error text.
    #[staticmethod]
    fn from_json(config_json: &str) -> PyResult<Self> {
        match RsMockEngineArgs::from_json_str(config_json) {
            Ok(inner) => Ok(Self { inner }),
            Err(e) => Err(PyException::new_err(format!(
                "Failed to parse MockEngineArgs JSON: {e}"
            ))),
        }
    }

    /// `block_size` of the wrapped args.
    #[getter]
    fn block_size(&self) -> usize {
        self.inner.block_size
    }

    /// `num_gpu_blocks` of the wrapped args.
    #[getter]
    fn num_gpu_blocks(&self) -> usize {
        self.inner.num_gpu_blocks
    }

    /// `max_num_seqs` of the wrapped args.
    #[getter]
    fn max_num_seqs(&self) -> Option<usize> {
        self.inner.max_num_seqs
    }

    /// `max_num_batched_tokens` of the wrapped args.
    #[getter]
    fn max_num_batched_tokens(&self) -> Option<usize> {
        self.inner.max_num_batched_tokens
    }

    /// `enable_local_indexer` of the wrapped args.
    #[getter]
    fn enable_local_indexer(&self) -> bool {
        self.inner.enable_local_indexer
    }

    /// `dp_size` of the wrapped args.
    #[getter]
    fn dp_size(&self) -> u32 {
        self.inner.dp_size
    }

    /// `bootstrap_port` of the wrapped args.
    #[getter]
    fn bootstrap_port(&self) -> Option<u16> {
        self.inner.bootstrap_port
    }

    /// Delegates to `is_prefill()` on the wrapped args.
    fn is_prefill(&self) -> bool {
        self.inner.is_prefill()
    }

    /// Delegates to `is_decode()` on the wrapped args.
    fn is_decode(&self) -> bool {
        self.inner.is_decode()
    }

    /// Returns a copy with the given fields replaced, then re-normalized.
    ///
    /// An argument left as `None` keeps the current value of that field; it never clears it.
    /// A normalization failure is returned as a Python exception.
    #[pyo3(signature = (bootstrap_port=None, zmq_kv_events_port=None, zmq_replay_port=None, kv_bytes_per_token=None))]
    fn with_overrides(
        &self,
        bootstrap_port: Option<u16>,
        zmq_kv_events_port: Option<u16>,
        zmq_replay_port: Option<u16>,
        kv_bytes_per_token: Option<usize>,
    ) -> PyResult<Self> {
        let mut updated = self.inner.clone();
        // `a.or(b)` picks the override when present and falls back to the existing value.
        updated.bootstrap_port = bootstrap_port.or(updated.bootstrap_port);
        updated.zmq_kv_events_port = zmq_kv_events_port.or(updated.zmq_kv_events_port);
        updated.zmq_replay_port = zmq_replay_port.or(updated.zmq_replay_port);
        updated.kv_bytes_per_token = kv_bytes_per_token.or(updated.kv_bytes_per_token);

        match updated.normalized() {
            Ok(inner) => Ok(Self { inner }),
            Err(e) => Err(PyException::new_err(format!(
                "Failed to normalize MockEngineArgs overrides: {e}"
            ))),
        }
    }
}
#[pyfunction]
#[pyo3(signature
=
(trace_file,
extra_engine_args=None,
router_config=None,
num_workers=
1
,
replay_concurrency=None,
replay_mode=
"offline"
,
router_mode=
"round_robin"
,
arrival_speedup_ratio=
1.0
))]
#[allow(clippy::too_many_arguments)]
pub
fn
run_mocker_trace_replay
(
py
:
Python
<
'_
>
,
trace_file
:
PathBuf
,
extra_engine_args
:
Option
<
MockEngineArgs
>
,
router_config
:
Option
<
KvRouterConfig
>
,
num_workers
:
usize
,
replay_concurrency
:
Option
<
isize
>
,
replay_mode
:
&
str
,
router_mode
:
&
str
,
arrival_speedup_ratio
:
f64
,
)
->
PyResult
<
PyObject
>
{
let
args
=
load_replay_mocker_args
(
py
,
extra_engine_args
)
?
;
let
router_config
=
load_replay_router_config
(
router_config
);
let
replay_mode
=
replay_mode
.to_owned
();
let
router_mode
=
parse_replay_router_mode
(
router_mode
)
?
;
let
report
=
py
.allow_threads
(
move
||
{
let
replay_concurrency
=
parse_replay_concurrency
(
replay_concurrency
)
?
;
match
(
replay_mode
.as_str
(),
replay_concurrency
)
{
(
"offline"
,
Some
(
max_in_flight
))
=>
{
dynamo_mocker
::
replay
::
simulate_concurrency_file_with_router_mode
(
args
,
router_config
.clone
(),
&
trace_file
,
max_in_flight
,
num_workers
,
router_mode
,
)
}
(
"offline"
,
None
)
=>
dynamo_mocker
::
replay
::
simulate_trace_file_with_router_mode
(
args
,
router_config
.clone
(),
&
trace_file
,
num_workers
,
arrival_speedup_ratio
,
router_mode
,
),
(
"online"
,
Some
(
max_in_flight
))
=>
{
dynamo_mocker
::
replay
::
simulate_concurrency_live_file_with_router_mode
(
args
,
router_config
.clone
(),
&
trace_file
,
max_in_flight
,
num_workers
,
router_mode
,
)
}
(
"online"
,
None
)
=>
dynamo_mocker
::
replay
::
simulate_trace_live_file_with_router_mode
(
args
,
router_config
.clone
(),
&
trace_file
,
num_workers
,
arrival_speedup_ratio
,
router_mode
,
),
(
other
,
_
)
=>
anyhow
::
bail!
(
"replay_mode must be either 'offline' or 'online', got '{}'"
,
other
),
}
});
let
report
=
report
.map_err
(
to_pyerr
)
?
;
pythonize
(
py
,
&
report
)
.map_err
(
to_pyerr
)
.map
(|
obj
|
obj
.unbind
())
}
#[pyfunction]
#[pyo3(signature
=
(input_tokens,
output_tokens,
request_count,
extra_engine_args=None,
router_config=None,
num_workers=
1
,
replay_concurrency=None,
replay_mode=
"offline"
,
router_mode=
"round_robin"
,
arrival_speedup_ratio=
1.0
,
arrival_interval_ms=
1.0
))]
#[allow(clippy::too_many_arguments)]
pub
fn
run_mocker_synthetic_trace_replay
(
py
:
Python
<
'_
>
,
input_tokens
:
usize
,
output_tokens
:
usize
,
request_count
:
usize
,
extra_engine_args
:
Option
<
MockEngineArgs
>
,
router_config
:
Option
<
KvRouterConfig
>
,
num_workers
:
usize
,
replay_concurrency
:
Option
<
isize
>
,
replay_mode
:
&
str
,
router_mode
:
&
str
,
arrival_speedup_ratio
:
f64
,
arrival_interval_ms
:
f64
,
)
->
PyResult
<
PyObject
>
{
let
args
=
load_replay_mocker_args
(
py
,
extra_engine_args
)
?
;
let
router_config
=
load_replay_router_config
(
router_config
);
let
replay_mode
=
replay_mode
.to_owned
();
let
router_mode
=
parse_replay_router_mode
(
router_mode
)
?
;
let
report
=
py
.allow_threads
(
move
||
{
let
replay_concurrency
=
parse_replay_concurrency
(
replay_concurrency
)
?
;
let
requests
=
build_synthetic_requests
(
input_tokens
,
output_tokens
,
request_count
,
arrival_interval_ms
,
replay_concurrency
.is_none
(),
)
?
;
match
(
replay_mode
.as_str
(),
replay_concurrency
)
{
(
"offline"
,
Some
(
max_in_flight
))
=>
{
dynamo_mocker
::
replay
::
simulate_concurrency_requests_with_router_mode
(
args
,
router_config
.clone
(),
requests
,
max_in_flight
,
num_workers
,
router_mode
,
)
}
(
"offline"
,
None
)
=>
dynamo_mocker
::
replay
::
simulate_trace_requests_with_router_mode
(
args
,
router_config
.clone
(),
requests
,
num_workers
,
arrival_speedup_ratio
,
router_mode
,
),
(
"online"
,
Some
(
max_in_flight
))
=>
{
dynamo_mocker
::
replay
::
simulate_concurrency_live_requests_with_router_mode
(
args
,
router_config
.clone
(),
requests
,
max_in_flight
,
num_workers
,
router_mode
,
)
}
(
"online"
,
None
)
=>
{
dynamo_mocker
::
replay
::
simulate_trace_live_requests_with_router_mode
(
args
,
router_config
.clone
(),
requests
,
num_workers
,
arrival_speedup_ratio
,
router_mode
,
)
}
(
other
,
_
)
=>
anyhow
::
bail!
(
"replay_mode must be either 'offline' or 'online', got '{}'"
,
other
),
}
});
let
report
=
report
.map_err
(
to_pyerr
)
?
;
pythonize
(
py
,
&
report
)
.map_err
(
to_pyerr
)
.map
(|
obj
|
obj
.unbind
())
}
fn
load_replay_mocker_args
(
py
:
Python
<
'_
>
,
extra_engine_args
:
Option
<
MockEngineArgs
>
,
)
->
PyResult
<
RsMockEngineArgs
>
{
let
mut
args
=
match
extra_engine_args
{
Some
(
extra_args
)
=>
extra_args
.inner
(),
None
=>
RsMockEngineArgs
::
default
(),
};
if
let
Some
(
ref
backend_name
)
=
args
.aic_backend
.clone
()
{
let
backend
=
backend_name
.clone
();
let
system
=
args
.aic_system
.as_deref
()
.unwrap_or
(
"h200_sxm"
)
.to_string
();
let
model_name
=
args
.aic_model_path
.clone
()
.ok_or_else
(||
PyException
::
new_err
(
"--aic-perf-model requires --model-path"
))
?
;
let
backend_version
=
args
.aic_backend_version
.clone
();
let
tp_size
=
args
.aic_tp_size
.unwrap_or
(
1
);
let
callback
=
create_aic_callback
(
py
,
&
backend
,
&
system
,
&
model_name
,
tp_size
,
backend_version
.as_deref
(),
)
.map_err
(|
e
|
{
PyException
::
new_err
(
format!
(
"Failed to create AIC callback (--aic-perf-model was requested): {}"
,
e
))
})
?
;
tracing
::
info!
(
"AIC perf model: backend={}, gpu={}, model={}, version={:?}"
,
backend
,
system
,
model_name
,
backend_version
);
args
.perf_model
=
Arc
::
new
(
PerfModel
::
from_aic_callback
(
callback
));
}
Ok
(
args
)
}
/// Unwraps the optional Python `KvRouterConfig` into the Rust router config it holds.
///
/// `None` stays `None`.
fn load_replay_router_config(
    router_config: Option<KvRouterConfig>,
) -> Option<dynamo_kv_router::config::KvRouterConfig> {
    match router_config {
        Some(wrapper) => Some(wrapper.inner()),
        None => None,
    }
}
/// Maps the Python-facing router mode (`"round_robin"` or `"kv_router"`) to the replay enum.
///
/// Any other string yields a Python exception naming the rejected value.
fn parse_replay_router_mode(
    router_mode: &str,
) -> PyResult<dynamo_mocker::replay::ReplayRouterMode> {
    use dynamo_mocker::replay::ReplayRouterMode;

    let parsed = match router_mode {
        "round_robin" => ReplayRouterMode::RoundRobin,
        "kv_router" => ReplayRouterMode::KvRouter,
        other => {
            return Err(PyException::new_err(format!(
                "router_mode must be either 'round_robin' or 'kv_router', got '{}'",
                other
            )));
        }
    };
    Ok(parsed)
}
/// Validates the optional concurrency limit received from Python.
///
/// `None` passes through unchanged; a value below 1 is an error; any other value is
/// returned as a `usize`.
fn parse_replay_concurrency(replay_concurrency: Option<isize>) -> anyhow::Result<Option<usize>> {
    let value = match replay_concurrency {
        Some(value) => value,
        None => return Ok(None),
    };
    if value < 1 {
        anyhow::bail!("replay_concurrency must be at least 1");
    }
    // `value` is known to be positive here, so the cast cannot change it.
    Ok(Some(value as usize))
}
/// Generates `request_count` synthetic requests for replay.
///
/// Each request carries `input_tokens` token ids produced by `synthetic_token_id`,
/// `max_output_tokens` set to `output_tokens`, a UUID built from `request_idx + 1`, and
/// `dp_rank` 0. When `include_arrival_timestamps` is true, request `i` gets
/// `arrival_timestamp_ms = i * arrival_interval_ms`; otherwise the timestamp is `None`.
///
/// Errors when any of the three counts is zero, or when `arrival_interval_ms` is negative
/// or not finite.
fn build_synthetic_requests(
    input_tokens: usize,
    output_tokens: usize,
    request_count: usize,
    arrival_interval_ms: f64,
    include_arrival_timestamps: bool,
) -> anyhow::Result<Vec<DirectRequest>> {
    anyhow::ensure!(input_tokens != 0, "input_tokens must be at least 1");
    anyhow::ensure!(output_tokens != 0, "output_tokens must be at least 1");
    anyhow::ensure!(request_count != 0, "request_count must be at least 1");
    anyhow::ensure!(
        arrival_interval_ms.is_finite() && arrival_interval_ms >= 0.0,
        "arrival_interval_ms must be a finite non-negative number, got {}",
        arrival_interval_ms
    );

    let requests = (0..request_count)
        .map(|request_idx| {
            let arrival_timestamp_ms = if include_arrival_timestamps {
                Some(request_idx as f64 * arrival_interval_ms)
            } else {
                None
            };
            DirectRequest {
                tokens: (0..input_tokens)
                    .map(|token_idx| synthetic_token_id(request_idx, token_idx))
                    .collect(),
                max_output_tokens: output_tokens,
                // Offset by one so the first request does not get the all-zero UUID.
                uuid: Some(Uuid::from_u128((request_idx as u128) + 1)),
                dp_rank: 0,
                arrival_timestamp_ms,
            }
        })
        .collect();
    Ok(requests)
}
/// Derives a deterministic, non-zero token id from a request index and a token index.
///
/// The two indices are packed into one 64-bit word (`request_idx` in the upper half,
/// XOR-ed with `token_idx`), offset by an additive constant, and scrambled with
/// xor-shift / wrapping-multiply rounds. The low 32 bits are the result, except that `0`
/// is replaced by `1`.
fn synthetic_token_id(request_idx: usize, token_idx: usize) -> u32 {
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
    const MULTIPLIER_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MULTIPLIER_B: u64 = 0x94D0_49BB_1331_11EB;

    let packed = ((request_idx as u64) << 32) ^ (token_idx as u64);
    let mixed = packed.wrapping_add(INCREMENT);
    let mixed = (mixed ^ (mixed >> 30)).wrapping_mul(MULTIPLIER_A);
    let mixed = (mixed ^ (mixed >> 27)).wrapping_mul(MULTIPLIER_B);
    let mixed = mixed ^ (mixed >> 31);

    // Truncate to the low 32 bits; zero is never returned.
    match mixed as u32 {
        0 => 1,
        token => token,
    }
}
lib/bindings/python/src/dynamo/_core.pyi
View file @
b7fe46b1
...
...
@@ -3,7 +3,17 @@
import asyncio
import os
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
)
# Import from specialized modules
from .prometheus_metrics import RuntimeMetrics as PyRuntimeMetrics
...
...
@@ -1104,9 +1114,10 @@ class KvRouterConfig:
router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8,
router_queue_threshold: Optional[float] =
2
.0,
router_queue_threshold: Optional[float] =
4
.0,
router_event_threads: int = 4,
router_enable_cache_control: bool = False,
min_initial_workers: int = 1,
router_queue_policy: str = "fcfs",
) -> None:
"""
...
...
@@ -1132,7 +1143,7 @@ class KvRouterConfig:
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
router_queue_threshold: Queue threshold fraction for prefill token capacity (default:
2
.0).
router_queue_threshold: Queue threshold fraction for prefill token capacity (default:
4
.0).
Requests are queued if all workers exceed this fraction of max_num_batched_tokens.
Enables priority scheduling via request priority hints.
Set to None to disable queueing (all requests go directly to the scheduler).
...
...
@@ -1140,12 +1151,111 @@ class KvRouterConfig:
When > 1, uses a concurrent radix tree with a thread pool.
router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's
cache_control service mesh endpoint (default: False).
min_initial_workers: Minimum number of discovered workers required before
router startup continues (default: 1). Ignored when
skip_initial_worker_wait is enabled.
router_queue_policy: Scheduling policy for the router queue (default: "fcfs").
"fcfs": first-come first-served with priority bumps — optimizes tail TTFT.
"lcfs": last-come first-served with priority bumps — intentionally worsens tail behavior for policy comparisons.
"wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
"""
...
@staticmethod
def from_json(config_json: str) -> "KvRouterConfig":
...
class ReasoningConfig:
def __init__(
self,
start_thinking_token_id: int,
end_thinking_token_id: int,
thinking_ratio: float,
) -> None:
...
class SglangArgs:
def __init__(
self,
schedule_policy: Optional[str] = None,
page_size: Optional[int] = None,
max_prefill_tokens: Optional[int] = None,
chunked_prefill_size: Optional[int] = None,
clip_max_new_tokens: Optional[int] = None,
schedule_conservativeness: Optional[float] = None,
) -> None:
...
class MockEngineArgs:
"""Typed engine arguments for the mocker.

Built either from keyword arguments or from a JSON document via
``from_json``. Instances are what ``run_mocker_trace_replay`` and
``run_mocker_synthetic_trace_replay`` take as ``extra_engine_args`` and what
``EntrypointArgs`` takes as ``mocker_engine_args``.
"""
def __init__(
self,
engine_type: str = "vllm",
num_gpu_blocks: int = 16384,
block_size: int = 0,
max_num_seqs: Optional[int] = 256,
max_num_batched_tokens: Optional[int] = 8192,
enable_prefix_caching: bool = True,
enable_chunked_prefill: bool = True,
speedup_ratio: float = 1.0,
decode_speedup_ratio: float = 1.0,
dp_size: int = 1,
startup_time: Optional[float] = None,
worker_type: str = "aggregated",
aic_backend: Optional[str] = None,
aic_system: Optional[str] = None,
aic_backend_version: Optional[str] = None,
aic_tp_size: Optional[int] = None,
aic_model_path: Optional[str] = None,
enable_local_indexer: bool = False,
bootstrap_port: Optional[int] = None,
kv_bytes_per_token: Optional[int] = None,
kv_transfer_bandwidth: Optional[float] = None,
reasoning: Optional[ReasoningConfig] = None,
zmq_kv_events_port: Optional[int] = None,
zmq_replay_port: Optional[int] = None,
preemption_mode: str = "lifo",
router_queue_policy: Optional[str] = None,
sglang: Optional[SglangArgs] = None,
) -> None:
"""Create engine arguments; every parameter has a default.

NOTE(review): the accepted string values for ``engine_type``,
``worker_type`` and ``preemption_mode`` are not listed in this stub; only the
defaults ("vllm", "aggregated", "lifo") are visible. ``block_size`` defaults
to 0, presumably meaning "choose an engine default" (TODO confirm).
"""
...
@staticmethod
def from_json(config_json: str) -> "MockEngineArgs":
"""Build an instance from a JSON string (the argument is the JSON text itself)."""
...
# Read-only accessors: only getters are declared for the properties below.
@property
def block_size(self) -> int: ...
@property
def num_gpu_blocks(self) -> int: ...
@property
def max_num_seqs(self) -> Optional[int]: ...
@property
def max_num_batched_tokens(self) -> Optional[int]: ...
@property
def enable_local_indexer(self) -> bool: ...
@property
def dp_size(self) -> int: ...
@property
def bootstrap_port(self) -> Optional[int]: ...
# Boolean predicates; presumably derived from ``worker_type`` (TODO confirm).
def is_prefill(self) -> bool: ...
def is_decode(self) -> bool: ...
# Returns a MockEngineArgs; presumably a copy in which each non-None argument
# replaces the corresponding field (TODO confirm that None means "keep").
def with_overrides(
self,
bootstrap_port: Optional[int] = None,
zmq_kv_events_port: Optional[int] = None,
zmq_replay_port: Optional[int] = None,
kv_bytes_per_token: Optional[int] = None,
) -> "MockEngineArgs": ...
async def register_model(
model_input: ModelInput,
model_type: ModelType,
...
...
@@ -1249,11 +1359,31 @@ async def run_input(runtime: DistributedRuntime, input: str, engine_config: Engi
def run_mocker_trace_replay(
    trace_file: str | os.PathLike[str],
    extra_engine_args: Optional[MockEngineArgs] = None,
    router_config: Optional[KvRouterConfig] = None,
    num_workers: int = 1,
    replay_concurrency: Optional[int] = None,
    replay_mode: Literal["offline", "online"] = "offline",
    router_mode: Literal["round_robin", "kv_router"] = "round_robin",
    arrival_speedup_ratio: float = 1.0,
) -> Dict[str, Any]:
    """Replay a mocker trace file and return the simulation report for aggregated vLLM or SGLang configs.

    Args:
        trace_file: Path to the trace to replay.
        extra_engine_args: Typed mocker engine arguments; ``None`` uses the defaults.
        router_config: Optional KV router configuration.
        num_workers: Number of mock workers to replay against.
        replay_concurrency: Optional concurrency setting for the replay
            (NOTE(review): exact semantics are not visible from this stub).
        replay_mode: Either ``"offline"`` or ``"online"``.
        router_mode: Either ``"round_robin"`` or ``"kv_router"``.
        arrival_speedup_ratio: Float factor applied to request arrivals
            (presumably compresses trace timestamps; TODO confirm).

    Returns:
        The simulation report as a dictionary.
    """
    ...
def run_mocker_synthetic_trace_replay(
    input_tokens: int,
    output_tokens: int,
    request_count: int,
    extra_engine_args: Optional[MockEngineArgs] = None,
    router_config: Optional[KvRouterConfig] = None,
    num_workers: int = 1,
    replay_concurrency: Optional[int] = None,
    replay_mode: Literal["offline", "online"] = "offline",
    router_mode: Literal["round_robin", "kv_router"] = "round_robin",
    arrival_speedup_ratio: float = 1.0,
    arrival_interval_ms: float = 1.0,
) -> Dict[str, Any]:
    """Replay a synthetic mocker workload without requiring a trace file.

    Args:
        input_tokens: Input length, in tokens, of each synthetic request.
        output_tokens: Output length, in tokens, of each synthetic request.
        request_count: Number of synthetic requests to generate.
        extra_engine_args: Typed mocker engine arguments; ``None`` uses the defaults.
        router_config: Optional KV router configuration.
        num_workers: Number of mock workers to replay against.
        replay_concurrency: Optional concurrency setting for the replay
            (NOTE(review): exact semantics are not visible from this stub).
        replay_mode: Either ``"offline"`` or ``"online"``.
        router_mode: Either ``"round_robin"`` or ``"kv_router"``.
        arrival_speedup_ratio: Float factor applied to request arrivals
            (presumably compresses arrival times; TODO confirm).
        arrival_interval_ms: Spacing between synthetic arrivals, in milliseconds.

    Returns:
        The simulation report as a dictionary.
    """
    ...
class Layer:
...
...
@@ -1687,6 +1817,7 @@ class EntrypointArgs:
tls_cert_path: Optional[str] = None,
tls_key_path: Optional[str] = None,
extra_engine_args: Optional[str] = None,
mocker_engine_args: Optional[MockEngineArgs] = None,
runtime_config: Optional[ModelRuntimeConfig] = None,
namespace: Optional[str] = None,
namespace_prefix: Optional[str] = None,
...
...
@@ -1711,7 +1842,8 @@ class EntrypointArgs:
http_metrics_port: HTTP metrics port (for gRPC service)
tls_cert_path: TLS certificate path (PEM format)
tls_key_path: TLS key path (PEM format)
extra_engine_args: Path to extra engine arguments file
extra_engine_args: Optional path to mocker engine arguments JSON
mocker_engine_args: Typed mocker engine arguments
runtime_config: Optional runtime configuration for discovery registration
namespace: Dynamo namespace for model discovery scoping
namespace_prefix: Optional namespace prefix
...
...
lib/bindings/python/src/dynamo/llm/__init__.py
View file @
b7fe46b1
...
...
@@ -18,6 +18,7 @@ from dynamo._core import KvRouterConfig as KvRouterConfig
from
dynamo._core
import
LoRADownloader
as
LoRADownloader
from
dynamo._core
import
MediaDecoder
as
MediaDecoder
from
dynamo._core
import
MediaFetcher
as
MediaFetcher
from
dynamo._core
import
MockEngineArgs
as
MockEngineArgs
from
dynamo._core
import
ModelCardInstanceId
as
ModelCardInstanceId
from
dynamo._core
import
ModelInput
as
ModelInput
from
dynamo._core
import
ModelRuntimeConfig
as
ModelRuntimeConfig
...
...
@@ -25,8 +26,10 @@ from dynamo._core import ModelType as ModelType
from
dynamo._core
import
OverlapScores
as
OverlapScores
from
dynamo._core
import
PythonAsyncEngine
as
PythonAsyncEngine
from
dynamo._core
import
RadixTree
as
RadixTree
from
dynamo._core
import
ReasoningConfig
as
ReasoningConfig
from
dynamo._core
import
RouterConfig
as
RouterConfig
from
dynamo._core
import
RouterMode
as
RouterMode
from
dynamo._core
import
SglangArgs
as
SglangArgs
from
dynamo._core
import
WorkerMetricsPublisher
as
WorkerMetricsPublisher
from
dynamo._core
import
compute_block_hash_for_seq
as
compute_block_hash_for_seq
from
dynamo._core
import
fetch_model
as
fetch_model
...
...
@@ -35,7 +38,7 @@ from dynamo._core import make_engine
from
dynamo._core
import
register_model
as
register_model
from
dynamo._core
import
run_input
from
dynamo._core
import
run_kv_indexer
as
run_kv_indexer
from
dynamo._core
import
run_mocker_trace_replay
from
dynamo._core
import
run_mocker_trace_replay
as
_run_mocker_trace_replay
from
dynamo._core
import
unregister_model
as
unregister_model
from
.exceptions
import
HttpError
...
...
@@ -44,3 +47,24 @@ from .exceptions import HttpError
# Alternate ``*_llm`` names bound to the very same callables as the ``*_model``
# functions imported above (presumably kept so older call sites keep working;
# TODO confirm).
fetch_llm
=
fetch_model
register_llm
=
register_model
unregister_llm
=
unregister_model
def run_mocker_trace_replay(
    trace_file,
    extra_engine_args=None,
    router_config=None,
    num_workers=1,
    replay_concurrency=None,
    router_mode="round_robin",
    arrival_speedup_ratio=1.0,
    replay_mode="offline",
):
    """Replay a mocker trace file and return the simulation report.

    Thin wrapper over ``dynamo._core.run_mocker_trace_replay``.

    Args:
        trace_file: Path to the trace to replay.
        extra_engine_args: Optional ``MockEngineArgs``.
        router_config: Optional ``KvRouterConfig``.
        num_workers: Number of mock workers.
        replay_concurrency: Optional concurrency setting, forwarded unchanged.
        router_mode: ``"round_robin"`` or ``"kv_router"``.
        arrival_speedup_ratio: Forwarded unchanged.
        replay_mode: ``"offline"`` (default) or ``"online"``. It is placed last so
            existing positional callers keep their meaning, and it defaults to
            the value this wrapper previously hard-coded.

    Returns:
        The report dictionary produced by the underlying binding.
    """
    return _run_mocker_trace_replay(
        trace_file,
        extra_engine_args=extra_engine_args,
        router_config=router_config,
        num_workers=num_workers,
        replay_concurrency=replay_concurrency,
        replay_mode=replay_mode,
        router_mode=router_mode,
        arrival_speedup_ratio=arrival_speedup_ratio,
    )
lib/bindings/python/src/dynamo/replay/__init__.py
0 → 100644
View file @
b7fe46b1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
dynamo.replay.api
import
run_synthetic_trace_replay
,
run_trace_replay
# Names re-exported as the public surface of the ``dynamo.replay`` package.
__all__
=
[
"run_synthetic_trace_replay"
,
"run_trace_replay"
]
lib/bindings/python/src/dynamo/replay/__main__.py
0 → 100644
View file @
b7fe46b1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
dynamo.replay.main
import
main
# Entry point for ``python -m dynamo.replay``: the integer returned by main()
# becomes the process exit status via SystemExit.
if
__name__
==
"__main__"
:
raise
SystemExit
(
main
())
lib/bindings/python/src/dynamo/replay/api.py
0 → 100644
View file @
b7fe46b1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
dynamo._core
import
(
run_mocker_synthetic_trace_replay
as
_run_mocker_synthetic_trace_replay
,
)
from
dynamo._core
import
run_mocker_trace_replay
as
_run_mocker_trace_replay
def run_trace_replay(
    trace_file,
    *,
    extra_engine_args=None,
    router_config=None,
    num_workers=1,
    replay_concurrency=None,
    replay_mode="offline",
    router_mode="round_robin",
    arrival_speedup_ratio=1.0,
):
    """Replay the trace stored at ``trace_file`` and return the report.

    Every option is keyword-only and is handed unchanged to the
    ``dynamo._core`` binding; this function adds no validation of its own.
    """
    forwarded = {
        "extra_engine_args": extra_engine_args,
        "router_config": router_config,
        "num_workers": num_workers,
        "replay_concurrency": replay_concurrency,
        "replay_mode": replay_mode,
        "router_mode": router_mode,
        "arrival_speedup_ratio": arrival_speedup_ratio,
    }
    return _run_mocker_trace_replay(trace_file, **forwarded)
def run_synthetic_trace_replay(
    input_tokens,
    output_tokens,
    request_count,
    *,
    extra_engine_args=None,
    router_config=None,
    num_workers=1,
    replay_concurrency=None,
    replay_mode="offline",
    router_mode="round_robin",
    arrival_speedup_ratio=1.0,
    arrival_interval_ms=1.0,
):
    """Replay a synthetic workload (no trace file) and return the report.

    The three positional values describe the workload; every option is
    keyword-only and is handed unchanged to the ``dynamo._core`` binding.
    """
    forwarded = {
        "extra_engine_args": extra_engine_args,
        "router_config": router_config,
        "num_workers": num_workers,
        "replay_concurrency": replay_concurrency,
        "replay_mode": replay_mode,
        "router_mode": router_mode,
        "arrival_speedup_ratio": arrival_speedup_ratio,
        "arrival_interval_ms": arrival_interval_ms,
    }
    return _run_mocker_synthetic_trace_replay(
        input_tokens, output_tokens, request_count, **forwarded
    )
lib/bindings/python/src/dynamo/replay/main.py
0 → 100644
View file @
b7fe46b1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
import
argparse
import
json
import
os
import
sys
from
collections.abc
import
Sequence
# Keep this statement above the ``dynamo`` imports that follow it. setdefault()
# leaves a value already present in the environment untouched. Presumably the
# dynamo package reads this variable at import time to skip its own Python
# logging setup (TODO confirm).
os
.
environ
.
setdefault
(
"DYNAMO_SKIP_PYTHON_LOG_INIT"
,
"1"
)
from
dynamo.llm
import
KvRouterConfig
,
MockEngineArgs
from
dynamo.replay
import
run_synthetic_trace_replay
,
run_trace_replay
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for ``python -m dynamo.replay``."""
    parser = argparse.ArgumentParser(prog="python -m dynamo.replay")
    parser.add_argument("trace_file", nargs="?")
    parser.add_argument("--extra-engine-args")
    parser.add_argument("--router-config")
    # The three flags that together select a synthetic workload.
    for flag in ("--input-tokens", "--output-tokens", "--request-count"):
        parser.add_argument(flag, type=int)
    parser.add_argument("--arrival-interval-ms", type=float, default=1.0)
    parser.add_argument("--num-workers", type=int, default=1)
    parser.add_argument("--replay-concurrency", type=int)
    parser.add_argument(
        "--replay-mode", choices=("offline", "online"), default="offline"
    )
    parser.add_argument(
        "--router-mode", choices=("round_robin", "kv_router"), default="round_robin"
    )
    parser.add_argument("--arrival-speedup-ratio", type=float, default=1.0)
    return parser


def main(argv: Sequence[str] | None = None) -> int:
    """Run a trace or synthetic replay from the command line.

    Args:
        argv: Arguments to parse; ``sys.argv[1:]`` is used when ``None``.

    Returns:
        ``0`` after the report has been written to stdout as indented,
        key-sorted JSON followed by a newline.

    Exactly one workload source must be given: either the positional
    ``trace_file`` or all three synthetic flags. Anything else is reported
    through ``parser.error``, which exits the process.
    """
    parser = _build_parser()
    raw_args = sys.argv[1:] if argv is None else argv
    args = parser.parse_args(list(raw_args))

    synthetic_values = (args.input_tokens, args.output_tokens, args.request_count)
    has_trace = args.trace_file is not None
    has_synthetic = not all(value is None for value in synthetic_values)

    if has_trace == has_synthetic:
        parser.error(
            "provide either trace_file or all of --input-tokens/--output-tokens/--request-count"
        )
    if has_synthetic and any(value is None for value in synthetic_values):
        parser.error(
            "synthetic replay requires --input-tokens, --output-tokens, and --request-count"
        )

    # Both options carry JSON text that is parsed into the typed config objects.
    extra_engine_args = None
    if args.extra_engine_args is not None:
        extra_engine_args = MockEngineArgs.from_json(args.extra_engine_args)
    router_config = None
    if args.router_config is not None:
        router_config = KvRouterConfig.from_json(args.router_config)

    shared_options = {
        "extra_engine_args": extra_engine_args,
        "router_config": router_config,
        "num_workers": args.num_workers,
        "replay_concurrency": args.replay_concurrency,
        "replay_mode": args.replay_mode,
        "router_mode": args.router_mode,
        "arrival_speedup_ratio": args.arrival_speedup_ratio,
    }
    if has_trace:
        report = run_trace_replay(args.trace_file, **shared_options)
    else:
        report = run_synthetic_trace_replay(
            args.input_tokens,
            args.output_tokens,
            args.request_count,
            arrival_interval_ms=args.arrival_interval_ms,
            **shared_options,
        )

    json.dump(report, sys.stdout, indent=2, sort_keys=True)
    sys.stdout.write("\n")
    return 0
lib/bindings/python/tests/test_replay.py
0 → 100644
View file @
b7fe46b1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
json
import
pytest
from
dynamo.llm
import
KvRouterConfig
,
MockEngineArgs
from
dynamo.replay
import
run_synthetic_trace_replay
,
run_trace_replay
# Module-level pytest marks: pytest applies every mark listed here to each test
# collected from this file. NOTE(review): gpu_0 / parallel / pre_merge are
# project-defined marker names; their meaning is not visible from this file.
pytestmark
=
[
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
parallel
,
pytest
.
mark
.
pre_merge
,
]
MOONCAKE_TRACE_FIRST20
=
"""{"timestamp": 0, "input_length": 6755, "output_length": 500, "hash_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]}
{"timestamp": 0, "input_length": 7319, "output_length": 490, "hash_ids": [0, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]}
{"timestamp": 0, "input_length": 7234, "output_length": 794, "hash_ids": [0, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]}
{"timestamp": 0, "input_length": 2287, "output_length": 316, "hash_ids": [0, 42, 43, 44, 45]}
{"timestamp": 0, "input_length": 9013, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]}
{"timestamp": 0, "input_length": 6506, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 64]}
{"timestamp": 0, "input_length": 4824, "output_length": 173, "hash_ids": [0, 65, 66, 67, 68, 69, 70, 71, 72, 73]}
{"timestamp": 0, "input_length": 3119, "output_length": 20, "hash_ids": [74, 75, 76, 77, 78, 79, 80]}
{"timestamp": 0, "input_length": 23090, "output_length": 453, "hash_ids": [0, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125]}
{"timestamp": 0, "input_length": 3135, "output_length": 19, "hash_ids": [74, 75, 76, 77, 78, 126, 127]}
{"timestamp": 0, "input_length": 26874, "output_length": 458, "hash_ids": [0, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179]}
{"timestamp": 0, "input_length": 10487, "output_length": 402, "hash_ids": [0, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]}
{"timestamp": 0, "input_length": 17448, "output_length": 610, "hash_ids": [0, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233]}
{"timestamp": 0, "input_length": 6253, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 234]}
{"timestamp": 0, "input_length": 6725, "output_length": 32, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 235, 236]}
{"timestamp": 3052, "input_length": 13538, "output_length": 71, "hash_ids": [0, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262]}
{"timestamp": 3052, "input_length": 87162, "output_length": 402, "hash_ids": [0, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432]}
{"timestamp": 3052, "input_length": 6166, "output_length": 24, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 433]}
{"timestamp": 3052, "input_length": 6320, "output_length": 548, "hash_ids": [0, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445]}
{"timestamp": 3052, "input_length": 2007, "output_length": 354, "hash_ids": [0, 446, 447, 448]}
"""
def
_write_trace_and_args
(
tmp_path
):
trace_path
=
tmp_path
/
"trace.jsonl"
records
=
[
{
"timestamp"
:
1000.0
,
"input_length"
:
64
,
"output_length"
:
2
,
"hash_ids"
:
[
101
],
},
{
"timestamp"
:
1005.0
,
"input_length"
:
64
,
"output_length"
:
2
,
"hash_ids"
:
[
101
],
},
]
trace_path
.
write_text
(
"
\n
"
.
join
(
json
.
dumps
(
record
)
for
record
in
records
)
+
"
\n
"
,
encoding
=
"utf-8"
,
)
return
trace_path
def
_write_vllm_args
(
tmp_path
):
args_path
=
tmp_path
/
"args.json"
args_path
.
write_text
(
json
.
dumps
(
{
"block_size"
:
64
,
"speedup_ratio"
:
1000.0
,
}
),
encoding
=
"utf-8"
,
)
return
args_path
def _vllm_args():
    """Return ``MockEngineArgs`` parsed from the JSON used by the vLLM test cases."""
    payload = {"block_size": 64, "speedup_ratio": 1000.0}
    encoded = json.dumps(payload)
    return MockEngineArgs.from_json(encoded)
def
_write_sglang_args
(
tmp_path
):
args_path
=
tmp_path
/
"sglang_args.json"
args_path
.
write_text
(
json
.
dumps
(
{
"engine_type"
:
"sglang"
,
"num_gpu_blocks"
:
512
,
"block_size"
:
64
,
"speedup_ratio"
:
1000.0
,
"sglang"
:
{
"page_size"
:
64
,
},
}
),
encoding
=
"utf-8"
,
)
return
args_path
def _sglang_args():
    """Return ``MockEngineArgs`` parsed from the JSON used by the SGLang test cases."""
    payload = {
        "engine_type": "sglang",
        "num_gpu_blocks": 512,
        "block_size": 64,
        "speedup_ratio": 1000.0,
        "sglang": {"page_size": 64},
    }
    encoded = json.dumps(payload)
    return MockEngineArgs.from_json(encoded)
def
_write_router_config
(
tmp_path
):
config_path
=
tmp_path
/
"router_config.json"
config_path
.
write_text
(
json
.
dumps
(
{
"router_queue_threshold"
:
1.25
,
"router_event_threads"
:
1
,
"router_queue_policy"
:
"wspt"
,
"router_temperature"
:
0.0
,
"overlap_score_weight"
:
1.0
,
"use_kv_events"
:
True
,
"durable_kv_events"
:
False
,
"router_replica_sync"
:
False
,
"router_track_active_blocks"
:
True
,
"router_track_output_blocks"
:
False
,
"router_assume_kv_reuse"
:
True
,
"router_snapshot_threshold"
:
1000000
,
"router_reset_states"
:
False
,
"router_ttl_secs"
:
120.0
,
"router_max_tree_size"
:
1048576
,
"router_prune_target_ratio"
:
0.8
,
"router_enable_cache_control"
:
False
,
"skip_initial_worker_wait"
:
False
,
"min_initial_workers"
:
1
,
"remote_indexer_component"
:
None
,
}
),
encoding
=
"utf-8"
,
)
return
config_path
def _router_config():
    """Return a ``KvRouterConfig`` parsed from a fully populated JSON document."""
    settings = dict(
        router_queue_threshold=1.25,
        router_event_threads=1,
        router_queue_policy="wspt",
        router_temperature=0.0,
        overlap_score_weight=1.0,
        use_kv_events=True,
        durable_kv_events=False,
        router_replica_sync=False,
        router_track_active_blocks=True,
        router_track_output_blocks=False,
        router_assume_kv_reuse=True,
        router_snapshot_threshold=1000000,
        router_reset_states=False,
        router_ttl_secs=120.0,
        router_max_tree_size=1048576,
        router_prune_target_ratio=0.8,
        router_enable_cache_control=False,
        skip_initial_worker_wait=False,
        min_initial_workers=1,
        remote_indexer_component=None,
    )
    encoded = json.dumps(settings)
    return KvRouterConfig.from_json(encoded)
def _partial_router_config():
    """Return a ``KvRouterConfig`` built from only three keyword arguments."""
    overrides = {
        "router_queue_threshold": 1.25,
        "router_event_threads": 1,
        "router_queue_policy": "wspt",
    }
    return KvRouterConfig(**overrides)
def
_assert_basic_report_counts
(
report
,
*
,
num_requests
,
input_tokens
,
output_tokens
):
assert
report
[
"num_requests"
]
==
num_requests
assert
report
[
"completed_requests"
]
==
num_requests
assert
report
[
"total_input_tokens"
]
==
num_requests
*
input_tokens
assert
report
[
"total_output_tokens"
]
==
num_requests
*
output_tokens
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
def test_run_trace_replay_smoke_matrix(tmp_path, engine_type, replay_mode, router_mode):
    """Trace replay completes for every engine / replay-mode / router-mode combination."""
    trace_file = _write_trace_and_args(tmp_path)
    make_engine_args = _vllm_args if engine_type == "vllm" else _sglang_args
    engine_args = make_engine_args()
    # Round-robin runs on a single worker; the KV router case uses two.
    worker_count = 2 if router_mode != "round_robin" else 1

    report = run_trace_replay(
        trace_file,
        extra_engine_args=engine_args,
        num_workers=worker_count,
        replay_mode=replay_mode,
        router_mode=router_mode,
    )

    _assert_basic_report_counts(report, num_requests=2, input_tokens=64, output_tokens=2)
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_invariant_counts_match(tmp_path, engine_type, replay_mode):
    """Request and token totals do not depend on worker count or router mode."""
    trace_file = _write_trace_and_args(tmp_path)
    make_engine_args = _vllm_args if engine_type == "vllm" else _sglang_args
    shared = {"extra_engine_args": make_engine_args(), "replay_mode": replay_mode}

    baseline = run_trace_replay(trace_file, num_workers=1, **shared)
    # Four-worker runs, first round-robin and then KV-routed.
    fanned_out = [
        run_trace_replay(trace_file, num_workers=4, router_mode=mode, **shared)
        for mode in ("round_robin", "kv_router")
    ]

    invariant_fields = (
        "num_requests",
        "completed_requests",
        "total_input_tokens",
        "total_output_tokens",
    )
    for field in invariant_fields:
        for candidate in fanned_out:
            assert baseline[field] == candidate[field]
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
def test_run_synthetic_trace_replay_smoke_matrix(
    tmp_path, engine_type, replay_mode, router_mode
):
    """Synthetic replay completes for every engine / replay-mode / router-mode combination."""
    make_engine_args = _vllm_args if engine_type == "vllm" else _sglang_args
    engine_args = make_engine_args()
    # Round-robin runs on a single worker; the KV router case uses two.
    worker_count = 2 if router_mode != "round_robin" else 1

    # Two requests of 64 input / 2 output tokens, arriving 5 ms apart.
    report = run_synthetic_trace_replay(
        64,
        2,
        2,
        extra_engine_args=engine_args,
        num_workers=worker_count,
        replay_mode=replay_mode,
        router_mode=router_mode,
        arrival_interval_ms=5.0,
    )

    _assert_basic_report_counts(report, num_requests=2, input_tokens=64, output_tokens=2)
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_synthetic_trace_replay_invariant_counts_match(
    tmp_path, engine_type, replay_mode
):
    """Synthetic request and token totals do not depend on worker count or router mode."""
    make_engine_args = _vllm_args if engine_type == "vllm" else _sglang_args
    shared = {
        "extra_engine_args": make_engine_args(),
        "replay_mode": replay_mode,
        "arrival_interval_ms": 5.0,
    }

    baseline = run_synthetic_trace_replay(64, 2, 2, num_workers=1, **shared)
    # Four-worker runs, first round-robin and then KV-routed.
    fanned_out = [
        run_synthetic_trace_replay(64, 2, 2, num_workers=4, router_mode=mode, **shared)
        for mode in ("round_robin", "kv_router")
    ]

    invariant_fields = (
        "num_requests",
        "completed_requests",
        "total_input_tokens",
        "total_output_tokens",
    )
    for field in invariant_fields:
        for candidate in fanned_out:
            assert baseline[field] == candidate[field]
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_synthetic_concurrency_replay_counts_match(tmp_path, engine_type, replay_mode):
    """Three synthetic requests all complete when replayed with ``replay_concurrency=2``."""
    make_engine_args = _vllm_args if engine_type == "vllm" else _sglang_args
    engine_args = make_engine_args()

    report = run_synthetic_trace_replay(
        64,
        2,
        3,
        extra_engine_args=engine_args,
        num_workers=2,
        replay_mode=replay_mode,
        replay_concurrency=2,
    )

    _assert_basic_report_counts(report, num_requests=3, input_tokens=64, output_tokens=2)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_router_config(tmp_path, replay_mode):
    """A fully populated ``KvRouterConfig`` is accepted by KV-routed trace replay."""
    trace_file = _write_trace_and_args(tmp_path)
    engine_args = _vllm_args()
    full_router_config = _router_config()

    report = run_trace_replay(
        trace_file,
        extra_engine_args=engine_args,
        router_config=full_router_config,
        num_workers=2,
        replay_mode=replay_mode,
        router_mode="kv_router",
    )

    _assert_basic_report_counts(report, num_requests=2, input_tokens=64, output_tokens=2)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_partial_router_config_json(tmp_path, replay_mode):
    """A ``KvRouterConfig`` with only a few fields set is accepted by KV-routed replay."""
    trace_file = _write_trace_and_args(tmp_path)
    engine_args = _vllm_args()
    sparse_router_config = _partial_router_config()

    report = run_trace_replay(
        trace_file,
        extra_engine_args=engine_args,
        router_config=sparse_router_config,
        num_workers=2,
        replay_mode=replay_mode,
        router_mode="kv_router",
    )

    _assert_basic_report_counts(report, num_requests=2, input_tokens=64, output_tokens=2)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_partial_extra_engine_args_json(tmp_path, replay_mode):
    """``MockEngineArgs`` built from just two keyword arguments is accepted by trace replay."""
    trace_file = _write_trace_and_args(tmp_path)
    sparse_engine_args = MockEngineArgs(block_size=64, speedup_ratio=1000.0)

    report = run_trace_replay(
        trace_file,
        extra_engine_args=sparse_engine_args,
        num_workers=1,
        replay_mode=replay_mode,
    )

    _assert_basic_report_counts(report, num_requests=2, input_tokens=64, output_tokens=2)
lib/kv-router/src/event_sink.rs
deleted
100644 → 0
View file @
82794761
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transport abstraction for publishing batched KV cache events.
//!
//! Implementations handle the actual delivery mechanism (NATS event plane,
//! JetStream durable queue, direct indexer application, etc.). The trait lives
//! in this crate so that the batching processor and other routing logic can be
//! written generically; runtime-specific impls stay in `lib/llm`.
use
std
::
future
::
Future
;
use
crate
::
protocols
::
RouterEvent
;
/// Transport abstraction for publishing batched KV cache events.
///
/// Implementors must be `Send + Sync` (the supertrait bounds below).
pub
trait
EventSink
:
Send
+
Sync
{
/// Publishes a single [`RouterEvent`], taken by shared reference.
///
/// Returns a `Send` future that resolves to `anyhow::Result<()>`. What an
/// `Err` means is up to each implementation.
fn
publish_event
(
&
self
,
event
:
&
RouterEvent
)
->
impl
Future
<
Output
=
anyhow
::
Result
<
()
>>
+
Send
;
}
lib/kv-router/src/indexer/tests.rs
View file @
b7fe46b1
...
...
@@ -245,1392 +245,1412 @@ async fn flush_and_settle(index: &dyn KvIndexerInterface) {
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_store_and_find
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store a sequence for worker 0
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
mod
interface_tests
{
use
super
::
*
;
use
rstest_reuse
::
apply
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_store_and_find
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store a sequence for worker 0
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Find matches using local hashes
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
),
])
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_partial_match
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Find matches using local hashes
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
),
])
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
// Store [1, 2, 3] for worker 0
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_partial_match
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Store [1, 2, 3] for worker 0
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Find matches for [1, 2, 999] - should match first 2 then stop
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
999
),
])
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
}
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Find matches for [1, 2, 999] - should match first 2 then stop
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
999
),
])
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
}
// Store sequence for worker 0
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Remove all blocks
index
.apply_event
(
make_remove_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Store sequence for worker 0
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Remove all blocks
index
.apply_event
(
make_remove_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Find should return nothing
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
),
])
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
}
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_multiple_workers_shared_prefix
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Find should return nothing
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
),
])
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
}
// Worker 0 has [1, 2], Worker 1 has [1, 3]
// Since sequence hashes are cumulative, [1] has same hash for both,
// but [1, 2] and [1, 3] have different hashes.
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
]))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
[
1
,
3
]))
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_multiple_workers_shared_prefix
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Worker 0 has [1, 2], Worker 1 has [1, 3]
// Since sequence hashes are cumulative, [1] has same hash for both,
// but [1, 2] and [1, 3] have different hashes.
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
]))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
[
1
,
3
]))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Query [1] - both workers should match
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
)])
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
2
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
1
);
// Query [1, 2] - worker 0 matches both, worker 1 matches only first block
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
)])
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
2
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
1
);
}
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_worker
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Query [1] - both workers should match
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
)])
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
2
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
1
);
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
[
1
,
2
,
3
]))
.await
;
// Query [1, 2] - worker 0 matches both, worker 1 matches only first block
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
)])
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
2
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
1
);
}
// Allow time for async event processing
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_worker
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
index
.remove_worker
(
0
)
.await
;
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
[
1
,
2
,
3
]))
.await
;
// Allow time for async remove_worker processing
flush_and_settle
(
index
.as_ref
())
.await
;
// Allow time for async event processing
flush_and_settle
(
index
.as_ref
())
.await
;
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
),
])
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert
!
(
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
1
,
0
)));
}
index
.remove_worker
(
0
)
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_large_stores
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Allow time for async remove_worker processing
flush_and_settle
(
index
.as_ref
())
.await
;
// Test sequences of increasing sizes
for
i
in
0
..
10u64
{
let
len
=
1
<<
i
;
// 1, 2, 4, 8, ..., 512
let
worker_id
=
i
;
let
sequence
:
Vec
<
u64
>
=
(
1
..=
len
)
.map
(|
x
|
x
+
(
i
*
10000
))
.collect
();
index
.apply_event
(
make_store_event
(
worker_id
,
&
sequence
))
.await
;
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
),
])
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert
!
(
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
1
,
0
)));
}
flush_and_settle
(
index
.as_ref
())
.await
;
// Verify we can find matches for the last stored sequence
let
last_seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
512u64
)
.map
(|
x
|
LocalBlockHash
(
x
+
(
9
*
10000
)))
.collect
();
let
scores
=
index
.find_matches
(
last_seq
)
.await
.unwrap
();
assert
!
(
!
scores
.scores
.is_empty
());
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_dump_and_restore
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store some data
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
[
1
,
2
,
4
]))
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_large_stores
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Allow background worker threads to process events.
flush_and_settle
(
index
.as_ref
())
.await
;
// Test sequences of increasing sizes
for
i
in
0
..
10u64
{
let
len
=
1
<<
i
;
// 1, 2, 4, 8, ..., 512
let
worker_id
=
i
;
let
sequence
:
Vec
<
u64
>
=
(
1
..=
len
)
.map
(|
x
|
x
+
(
i
*
10000
))
.collect
();
index
.apply_event
(
make_store_event
(
worker_id
,
&
sequence
))
.await
;
}
// Dump the tree as events and replay into a new index
let
events
=
index
.dump_events
()
.await
.unwrap
();
assert
!
(
!
events
.is_empty
());
flush_and_settle
(
index
.as_ref
())
.await
;
let
restored
=
make_indexer
(
variant
);
for
event
in
events
{
restored
.apply_event
(
event
)
.await
;
// Verify we can find matches for the last stored sequence
let
last_seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
512u64
)
.map
(|
x
|
LocalBlockHash
(
x
+
(
9
*
10000
)))
.collect
();
let
scores
=
index
.find_matches
(
last_seq
)
.await
.unwrap
();
assert
!
(
!
scores
.scores
.is_empty
());
}
flush_and_settle
(
restored
.as_ref
())
.await
;
assert_eq!
(
snapshot_tree
(
index
.as_ref
())
.await
,
snapshot_tree
(
restored
.as_ref
())
.await
);
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_dump_and_restore
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_clear_all_blocks
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store some data
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
[
1
,
2
,
4
]))
.await
;
// Store some data for two workers
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
[
1
,
2
,
3
]))
.await
;
// Allow background worker threads to process events.
flush_and_settle
(
index
.as_ref
())
.await
;
// Clear worker 0's blocks using the Cleared event
index
.apply_event
(
make_clear_event
(
0
))
.await
;
// Dump the tree as events and replay into a new index
let
events
=
index
.dump_events
()
.await
.unwrap
();
assert
!
(
!
events
.is_empty
());
flush_and_settle
(
index
.as_ref
())
.await
;
let
restored
=
make_indexer
(
variant
);
for
event
in
events
{
restored
.apply_event
(
event
)
.await
;
}
// Worker 0's blocks should be gone, worker 1's remain
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
),
])
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert
!
(
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
1
,
0
)));
}
flush_and_settle
(
restored
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_empty_query
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
assert_eq!
(
snapshot_tree
(
index
.as_ref
())
.await
,
snapshot_tree
(
restored
.as_ref
())
.await
);
}
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_clear_all_blocks
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Store some data for two workers
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
[
1
,
2
,
3
]))
.await
;
// Empty query should return empty scores
let
scores
=
index
.find_matches
(
vec!
[])
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
}
// Clear worker 0's blocks using the Cleared event
index
.apply_event
(
make_clear_event
(
0
))
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_miss_query
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
flush_and_settle
(
index
.as_ref
())
.await
;
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Worker 0's blocks should be gone, worker 1's remain
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
),
])
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert
!
(
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
1
,
0
)));
}
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_empty_query
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Query for non-existent blocks
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
999
),
LocalBlockHash
(
998
)])
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
}
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Smoke test: a freshly built indexer can be shut down without panicking.
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown(variant: &str) {
    make_indexer(variant).shutdown();
}
flush_and_settle
(
index
.as_ref
())
.await
;
// Calling `shutdown` a second time on an indexer that has already processed
// an event must not panic.
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown_idempotent(variant: &str) {
    let indexer = make_indexer(variant);

    // Give the indexer some state before shutting it down.
    let stored = make_store_event(0, &[1, 2, 3]);
    indexer.apply_event(stored).await;
    flush_and_settle(indexer.as_ref()).await;

    // Two consecutive shutdowns; the second is the one under test.
    for _ in 0..2 {
        indexer.shutdown();
    }
}
// Empty query should return empty scores
let
scores
=
index
.find_matches
(
vec!
[])
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
}
// Exercises the token-based lookup `find_matches_for_request`:
//   1. on an empty index it must return no scores;
//   2. after a store event it must still complete without error.
//
// The second lookup is a smoke test only. The stored event uses the local
// block hashes [1, 2, 3], whereas `find_matches_for_request` derives its own
// block hashes from the raw tokens, so the two are not expected to line up
// and no particular score is asserted.
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches_for_request(variant: &str) {
    let index = make_indexer(variant);

    // Empty index should return no matches.
    let tokens = vec![1, 2, 3, 4];
    let scores = index
        .find_matches_for_request(&tokens, None)
        .await
        .unwrap();
    assert!(scores.scores.is_empty());

    // Store some data, then repeat the lookup.
    index.apply_event(make_store_event(0, &[1, 2, 3])).await;

    // Allow time for async processing.
    flush_and_settle(index.as_ref()).await;

    // The previous `assert!(is_empty() || !is_empty())` was a tautology and
    // could never fail; the only real check is that the call succeeds, which
    // `expect` enforces and documents.
    let _scores = index
        .find_matches_for_request(&tokens, None)
        .await
        .expect("find_matches_for_request should succeed on a populated index");
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_miss_query
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_process_routing_decision
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Create tokens with hashes
let
tokens
=
vec!
[
1u32
,
2
,
3
,
4
,
5
,
6
,
7
,
8
];
let
mut
tokens_with_hashes
=
TokensWithHashes
::
new
(
tokens
,
32
);
flush_and_settle
(
index
.as_ref
())
.await
;
let
worker
=
WorkerWithDpRank
::
new
(
0
,
0
);
// Query for non-existent blocks
let
scores
=
index
.find_matches
(
vec!
[
LocalBlockHash
(
999
),
LocalBlockHash
(
998
)])
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
}
// Process routing decision - should not error
let
result
=
index
.process_routing_decision_for_request
(
&
mut
tokens_with_hashes
,
worker
)
.await
;
assert
!
(
result
.is_ok
()
);
}
// Smoke test: a freshly built indexer can be shut down without panicking.
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown(variant: &str) {
    make_indexer(variant).shutdown();
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_parent_hash_chains
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Calling `shutdown` a second time on an indexer that has already processed
// an event must not panic.
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown_idempotent(variant: &str) {
    let indexer = make_indexer(variant);

    // Give the indexer some state before shutting it down.
    let stored = make_store_event(0, &[1, 2, 3]);
    indexer.apply_event(stored).await;
    flush_and_settle(indexer.as_ref()).await;

    // Two consecutive shutdowns; the second is the one under test.
    for _ in 0..2 {
        indexer.shutdown();
    }
}
// Store initial sequence [1, 2, 3]
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Exercises the token-based lookup `find_matches_for_request`:
//   1. on an empty index it must return no scores;
//   2. after a store event it must still complete without error.
//
// The second lookup is a smoke test only. The stored event uses the local
// block hashes [1, 2, 3], whereas `find_matches_for_request` derives its own
// block hashes from the raw tokens, so the two are not expected to line up
// and no particular score is asserted.
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches_for_request(variant: &str) {
    let index = make_indexer(variant);

    // Empty index should return no matches.
    let tokens = vec![1, 2, 3, 4];
    let scores = index
        .find_matches_for_request(&tokens, None)
        .await
        .unwrap();
    assert!(scores.scores.is_empty());

    // Store some data, then repeat the lookup.
    index.apply_event(make_store_event(0, &[1, 2, 3])).await;

    // Allow time for async processing.
    flush_and_settle(index.as_ref()).await;

    // The previous `assert!(is_empty() || !is_empty())` was a tautology and
    // could never fail; the only real check is that the call succeeds, which
    // `expect` enforces and documents.
    let _scores = index
        .find_matches_for_request(&tokens, None)
        .await
        .expect("find_matches_for_request should succeed on a populated index");
}
// Store continuation [4, 5] with parent pointing to block 3
index
.apply_event
(
make_store_event_with_parent
(
0
,
&
[
1
,
2
,
3
],
&
[
4
,
5
]))
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_process_routing_decision
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
)
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Create tokens with hashes
let
tokens
=
vec!
[
1u32
,
2
,
3
,
4
,
5
,
6
,
7
,
8
];
let
mut
tokens_with_hashes
=
TokensWithHashes
::
new
(
tokens
,
32
);
// Query for full sequence [1, 2, 3, 4, 5] should match all 5 blocks
let
full_seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
5
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
full_seq
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
5
);
let
worker
=
WorkerWithDpRank
::
new
(
0
,
0
);
// Query for just [1, 2, 3] should match 3 blocks
let
prefix_seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
prefix_seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
// Process routing decision - should not error
let
result
=
index
.process_routing_decision_for_request
(
&
mut
tokens_with_hashes
,
worker
)
.await
;
assert
!
(
result
.is_ok
());
}
// The same worker id stored under three different dp ranks must show up as
// three independent entries, each scoring the full three-block sequence.
#[tokio::test]
#[apply(indexer_template)]
async fn test_multiple_dp_ranks(variant: &str) {
    let indexer = make_indexer(variant);

    // Same worker_id, dp ranks 0, 1 and 2, identical block sequence.
    for dp_rank in 0..3 {
        let stored = make_store_event_with_dp_rank(0, &[1, 2, 3], dp_rank);
        indexer.apply_event(stored).await;
    }

    flush_and_settle(indexer.as_ref()).await;

    // One query over [1, 2, 3] should report every dp rank separately.
    let query: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
    let overlap = indexer.find_matches(query).await.unwrap();
    assert_eq!(overlap.scores.len(), 3);
    for dp_rank in 0..3 {
        let key = WorkerWithDpRank::new(0, dp_rank);
        assert_eq!(*overlap.scores.get(&key).unwrap(), 3);
    }
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_parent_hash_chains
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_partial_block_removal
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store initial sequence [1, 2, 3]
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Store [1, 2, 3]
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Store continuation [4, 5] with parent pointing to block 3
index
.apply_event
(
make_store_event_with_parent
(
0
,
&
[
1
,
2
,
3
],
&
[
4
,
5
]))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Verify all 3 blocks match
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq
.clone
())
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
// Query for full sequence [1, 2, 3, 4, 5] should match all 5 blocks
let
full_seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
5
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
full_seq
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
5
);
// Remove only the last block (block 3)
// To do this correctly, we need to compute the seq_hash for block 3 specifically,
// which requires the full sequence context [1,2,3].
let
full_hashes
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
seq_hashes
=
compute_seq_hash_for_block
(
&
full_hashes
);
let
block_3_seq_hash
=
ExternalSequenceBlockHash
(
seq_hashes
[
2
]);
// Last block's hash
// Query for just [1, 2, 3] should match 3 blocks
let
prefix_seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
prefix_seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
let
remove_event
=
remove_event
(
0
,
0
,
0
,
vec!
[
block_3_seq_hash
]);
index
.apply_event
(
remove_event
)
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_multiple_dp_ranks
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Same worker_id but different dp_ranks should be tracked separately
index
.apply_event
(
make_store_event_with_dp_rank
(
0
,
&
[
1
,
2
,
3
],
0
))
.await
;
index
.apply_event
(
make_store_event_with_dp_rank
(
0
,
&
[
1
,
2
,
3
],
1
))
.await
;
index
.apply_event
(
make_store_event_with_dp_rank
(
0
,
&
[
1
,
2
,
3
],
2
))
.await
;
// Query [1, 2, 3] - should only match 2 blocks now (block 3 is removed)
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Query [1, 2] - should still match 2 blocks
let
partial_seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
2
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
partial_seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
}
// Query should return all 3 dp_ranks as separate entries
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_mid_chain_block
(
variant
:
&
str
)
{
// TODO: positional indexer has no parent-child structure, so mid-chain removal
// doesn't invalidate later positions — jump search skips over the gap and over-counts.
if
variant
==
"flat"
{
return
;
assert_eq!
(
scores
.scores
.len
(),
3
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
1
))
.unwrap
(),
3
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
2
))
.unwrap
(),
3
);
}
let
index
=
make_indexer
(
variant
);
// Store [1, 2, 3, 4, 5]
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
,
4
,
5
]))
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_partial_block_removal
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Store [1, 2, 3]
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Verify all 5 blocks match
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
5
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq
.clone
())
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
5
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Remove only block 3 (index 2) — the middle of the chain
let
full_hashes
:
Vec
<
LocalBlockHash
>
=
(
1
..=
5
)
.map
(
LocalBlockHash
)
.collect
();
let
s
eq_hashes
=
compute_seq_hash_for_block
(
&
full_hashes
);
let
block_3_seq_hash
=
ExternalSequenceBlockHash
(
seq_hashes
[
2
]
);
// Verify all 3 blocks match
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
s
cores
=
index
.find_matches
(
seq
.clone
())
.await
.unwrap
(
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
let
remove_event
=
remove_event
(
0
,
0
,
0
,
vec!
[
block_3_seq_hash
]);
index
.apply_event
(
remove_event
)
.await
;
// Remove only the last block (block 3)
// To do this correctly, we need to compute the seq_hash for block 3 specifically,
// which requires the full sequence context [1,2,3].
let
full_hashes
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
seq_hashes
=
compute_seq_hash_for_block
(
&
full_hashes
);
let
block_3_seq_hash
=
ExternalSequenceBlockHash
(
seq_hashes
[
2
]);
// Last block's hash
flush_and_settle
(
index
.as_ref
())
.await
;
let
remove_event
=
remove_event
(
0
,
0
,
0
,
vec!
[
block_3_seq_hash
]);
index
.apply_event
(
remove_event
)
.await
;
// Query [1, 2, 3, 4, 5] — only first 2 positions reachable (block 3 removed, orphaning 4 & 5)
let
scores
=
index
.find_matches
(
seq
.clone
())
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Query [1, 2] — prefix before the gap is still intact
let
prefix_seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
2
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
prefix_seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
// Query [1, 2, 3] - should only match 2 blocks now (block 3 is removed)
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
// Re-store block 3 as a continuation of [1, 2]
index
.apply_event
(
make_store_event_with_parent
(
0
,
&
[
1
,
2
],
&
[
3
]))
.await
;
// Query [1, 2] - should still match 2 blocks
let
partial_seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
2
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
partial_seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
}
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_mid_chain_block
(
variant
:
&
str
)
{
// TODO: positional indexer has no parent-child structure, so mid-chain removal
// doesn't invalidate later positions — jump search skips over the gap and over-counts.
if
variant
==
"flat"
{
return
;
}
// Query [1, 2, 3, 4, 5] — block 3 is back but 4 & 5 were orphaned, so score = 3
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
let
index
=
make_indexer
(
variant
);
#[tokio::test
]
#[apply(indexer_template)]
async
fn
test_remove_nonexistent_worker
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
)
;
// Store [1, 2, 3, 4, 5
]
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
,
4
,
5
]))
.await
;
// Store data for worker 0
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Verify all 5 blocks match
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
5
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq
.clone
())
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
5
);
// Remove non-existent worker 999 - should not error or affect worker 0
index
.remove_worker
(
999
)
.await
;
// Remove only block 3 (index 2) — the middle of the chain
let
full_hashes
:
Vec
<
LocalBlockHash
>
=
(
1
..=
5
)
.map
(
LocalBlockHash
)
.collect
();
let
seq_hashes
=
compute_seq_hash_for_block
(
&
full_hashes
);
let
block_3_seq_hash
=
ExternalSequenceBlockHash
(
seq_hashes
[
2
]);
// Allow time for async processing
flush_and_settle
(
index
.as_ref
()
)
.await
;
let
remove_event
=
remove_event
(
0
,
0
,
0
,
vec!
[
block_3_seq_hash
]);
index
.apply_event
(
remove_event
)
.await
;
// Worker 0's data should still be there
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert
!
(
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
0
,
0
)));
}
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_nonexistent_blocks
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Query [1, 2, 3, 4, 5] — only first 2 positions reachable (block 3 removed, orphaning 4 & 5)
let
scores
=
index
.find_matches
(
seq
.clone
())
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
// Store [1, 2, 3]
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Query [1, 2] — prefix before the gap is still intact
let
prefix_seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
2
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
prefix_seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
// Try to remove blocks [999, 998] that don't exist - should not error
index
.apply_event
(
make_remove_event
(
0
,
&
[
999
,
998
]))
.await
;
// Re-store block 3 as a continuation of [1, 2]
index
.apply_event
(
make_store_event_with_parent
(
0
,
&
[
1
,
2
],
&
[
3
]))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Original data should still be there
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
// Query [1, 2, 3, 4, 5] — block 3 is back but 4 & 5 were orphaned, so score = 3
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_
clear_then_reuse
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_
remove_nonexistent_worker
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store initial data
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Store data for worker 0
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Clear the worker
index
.apply_event
(
make_clear_event
(
0
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Remove non-existent worker 999 - should not error or affect worker 0
index
.remove_worker
(
999
)
.await
;
// Verify data is gone
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq
.clone
())
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
// Allow time for async processing
flush_and_settle
(
index
.as_ref
())
.await
;
// Store new data for the same worker
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Worker 0's data should still be there
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert
!
(
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
0
,
0
)));
}
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_nonexistent_blocks
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Verify new data is accessible
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
// Store [1, 2, 3]
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// A single worker holding two unrelated sequences must score each sequence
// in full, and a query that splices the two together must stop matching at
// the point where the stored prefix ends.
#[tokio::test]
#[apply(indexer_template)]
async fn test_multiple_sequences_per_worker(variant: &str) {
    let indexer = make_indexer(variant);
    let worker = WorkerWithDpRank::new(0, 0);

    // Two disjoint sequences for the same worker, neither with a parent:
    // [1, 2, 3] first, then [100, 101, 102].
    for blocks in [[1, 2, 3], [100, 101, 102]] {
        indexer.apply_event(make_store_event(0, &blocks)).await;
    }

    flush_and_settle(indexer.as_ref()).await;

    // First sequence is fully matched.
    let first: Vec<LocalBlockHash> = (1..=3).map(LocalBlockHash).collect();
    let overlap = indexer.find_matches(first).await.unwrap();
    assert_eq!(*overlap.scores.get(&worker).unwrap(), 3);

    // Second sequence is fully matched as well.
    let second: Vec<LocalBlockHash> = (100..=102).map(LocalBlockHash).collect();
    let overlap = indexer.find_matches(second).await.unwrap();
    assert_eq!(*overlap.scores.get(&worker).unwrap(), 3);

    // [1, 100] was never stored as a sequence, so only block 1 matches.
    let spliced: Vec<LocalBlockHash> = vec![LocalBlockHash(1), LocalBlockHash(100)];
    let overlap = indexer.find_matches(spliced).await.unwrap();
    assert_eq!(*overlap.scores.get(&worker).unwrap(), 1);
}
// Try to remove blocks [999, 998] that don't exist - should not error
index
.apply_event
(
make_remove_event
(
0
,
&
[
999
,
998
]))
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_clear_clears_all_dp_ranks
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Store same sequence for different dp_ranks
index
.apply_event
(
make_store_event_with_dp_rank
(
0
,
&
[
1
,
2
,
3
],
0
))
.await
;
index
.apply_event
(
make_store_event_with_dp_rank
(
0
,
&
[
1
,
2
,
3
],
1
))
.await
;
// Original data should still be there
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_clear_then_reuse
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Verify both dp_ranks are present
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq
.clone
())
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
2
);
// Store initial data
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Clear event clears ALL blocks for the worker_id, regardless of dp_rank
index
.apply_event
(
make_clear_event
_with_dp_rank
(
0
,
0
))
.await
;
// Clear the worker
index
.apply_event
(
make_clear_event
(
0
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Both dp_ranks should be cleared
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
(),
"Cleared event should clear all dp_ranks for a worker"
);
}
// Verify data is gone
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq
.clone
())
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
// ============================================================================
// LoRA isolation tests
// ============================================================================
// Store new data for the same worker
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_lora_and_base_model_blocks_do_not_conflict
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
let
kv_block_size
:
u32
=
32
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Same token sequence for both base model and LoRA adapter
let
tokens
:
Vec
<
u32
>
=
(
0
..
kv_block_size
*
3
)
.collect
();
// Verify new data is accessible
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
let
base_hashes
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
None
);
let
lora_hashes
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
Some
(
"my-adapter"
));
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_multiple_sequences_per_worker
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Hashes must differ despite identical tokens
assert_ne!
(
base_hashes
,
lora_hashes
,
"Base and LoRA hashes must differ for the same tokens"
);
// Store two disjoint sequences for the same worker
// Sequence 1: [1, 2, 3]
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
// Sequence 2: [100, 101, 102] (completely different, no parent)
index
.apply_event
(
make_store_event
(
0
,
&
[
100
,
101
,
102
]))
.await
;
let
base_seq
=
compute_seq_hash_for_block
(
&
base_hashes
);
let
lora_seq
=
compute_seq_hash_for_block
(
&
lora_hashes
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Store base-model blocks on worker 0
let
base_event
=
router_event
(
0
,
0
,
0
,
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
stored_blocks_with_sequence_hashes
(
&
base_hashes
,
&
base_seq
),
}),
);
index
.apply_event
(
base_event
)
.await
;
// Query first sequence
let
seq1
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq1
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
// Store LoRA blocks on worker 1
let
lora_event
=
router_event
(
1
,
0
,
0
,
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
stored_blocks_with_sequence_hashes
(
&
lora_hashes
,
&
lora_seq
),
}),
);
index
.apply_event
(
lora_event
)
.await
;
// Query second sequence
let
seq2
:
Vec
<
LocalBlockHash
>
=
(
100
..=
102
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq2
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Query a mix that doesn't exist as a sequence - should only match first block
let
mixed
:
Vec
<
LocalBlockHash
>
=
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
100
)];
let
scores
=
index
.find_matches
(
mixed
)
.await
.unwrap
();
// Only block 1 matches because [1, 100] is not a valid prefix
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
1
);
}
// Query with base-model hashes → only worker 0
let
base_scores
=
index
.find_matches
(
base_hashes
.clone
())
.await
.unwrap
();
assert_eq!
(
base_scores
.scores
.len
(),
1
,
"Only base-model worker should match"
);
assert_eq!
(
*
base_scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_clear_clears_all_dp_ranks
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Query with LoRA hashes → only worker 1
let
lora_scores
=
index
.find_matches
(
lora_hashes
.clone
())
.await
.unwrap
();
assert_eq!
(
lora_scores
.scores
.len
(),
1
,
"Only LoRA worker should match"
);
assert_eq!
(
*
lora_scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
3
);
}
// Store same sequence for different dp_ranks
index
.apply_event
(
make_store_event_with_dp_rank
(
0
,
&
[
1
,
2
,
3
],
0
))
.await
;
index
.apply_event
(
make_store_event_with_dp_rank
(
0
,
&
[
1
,
2
,
3
],
1
))
.await
;
/// Reproduces the "block_hash mismatch: sequence hashes should be uniform
/// across workers" warning seen when the same prompt is sent to both a base
/// model worker and a LoRA worker.
///
/// On main (without LoRA-aware hashing), both workers compute the same
/// LocalBlockHash for identical tokens. But vLLM's engine includes the
/// adapter in its rolling ExternalSequenceBlockHash, so the radix tree
/// sees conflicting sequence hashes at the same tree node.
///
/// With LoRA-aware hashing, compute_block_hash_for_seq produces distinct
/// LocalBlockHash values for different adapters, so the blocks land on
/// separate tree paths and no mismatch occurs.
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_lora_base_same_tokens_no_seq_hash_mismatch
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
let
kv_block_size
:
u32
=
32
;
let
tokens
:
Vec
<
u32
>
=
(
0
..
kv_block_size
*
3
)
.collect
();
// With LoRA-aware hashing, base and adapter produce different LocalBlockHash
let
base_local
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
None
);
let
lora_local
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
Some
(
"my-adapter"
));
assert_ne!
(
base_local
,
lora_local
,
"LoRA-aware hashing must produce different LocalBlockHash values"
);
flush_and_settle
(
index
.as_ref
())
.await
;
//
Simulate what vLLM does: same tokens, different rolling seq hashes
// because the engine accounts for the adapter internally.
let
base_seq
=
compute_seq_hash_for_block
(
&
base_local
);
let
lora_seq
=
compute_seq_hash_for_block
(
&
lora_local
);
//
Verify both dp_ranks are present
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq
.clone
())
.await
.unwrap
(
);
assert_eq!
(
scores
.scores
.len
(),
2
);
// Worker 0: base model
index
.apply_event
(
router_event
(
0
,
0
,
0
,
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
stored_blocks_with_sequence_hashes
(
&
base_local
,
&
base_seq
),
}),
))
.await
;
// Clear event clears ALL blocks for the worker_id, regardless of dp_rank
index
.apply_event
(
make_clear_event_with_dp_rank
(
0
,
0
))
.await
;
// Worker 1: LoRA adapter — different LocalBlockHash, so this goes to
// a separate tree path instead of colliding with worker 0's node.
index
.apply_event
(
router_event
(
1
,
0
,
0
,
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
stored_blocks_with_sequence_hashes
(
&
lora_local
,
&
lora_seq
),
}),
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Base query finds only worker 0
let
base_scores
=
index
.find_matches
(
base_local
.clone
())
.await
.unwrap
();
assert_eq!
(
base_scores
.scores
.len
(),
1
);
assert_eq!
(
*
base_scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
flush_and_settle
(
index
.as_ref
())
.await
;
// LoRA query finds only worker 1
let
lora_scores
=
index
.find_matches
(
lora_local
.clone
())
.await
.unwrap
();
assert_eq!
(
lora_scores
.scores
.len
(),
1
);
assert_eq!
(
*
lora_scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
3
);
// Both dp_ranks should be cleared
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
(),
"Cleared event should clear all dp_ranks for a worker"
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_different_lora_adapters_do_not_conflict
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
let
kv_block_size
:
u32
=
32
;
// ============================================================================
// LoRA isolation tests
// ============================================================================
let
tokens
:
Vec
<
u32
>
=
(
0
..
kv_block_size
*
2
)
.collect
();
mod
lora_tests
{
use
super
::
*
;
use
rstest_reuse
::
apply
;
let
hashes_a
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
Some
(
"adapter-a"
));
let
hashes_b
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
Some
(
"adapter-b"
));
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_lora_and_base_model_blocks_do_not_conflict
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
let
kv_block_size
:
u32
=
32
;
assert_ne!
(
hashes_a
,
hashes_b
,
"Different adapters must produce different hashes"
);
// Same token sequence for both base model and LoRA adapter
let
tokens
:
Vec
<
u32
>
=
(
0
..
kv_block_size
*
3
)
.collect
();
let
seq_a
=
compute_seq_hash_for_block
(
&
hashes_a
);
let
seq_b
=
compute_seq_hash_for_block
(
&
hashes_b
);
let
base_hashes
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
None
);
let
lora_hashes
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
Some
(
"my-adapter"
));
// Store adapter-a blocks on worker 0
index
.apply_event
(
router_event
(
// Hashes must differ despite identical tokens
assert_ne!
(
base_hashes
,
lora_hashes
,
"Base and LoRA hashes must differ for the same tokens"
);
let
base_seq
=
compute_seq_hash_for_block
(
&
base_hashes
);
let
lora_seq
=
compute_seq_hash_for_block
(
&
lora_hashes
);
// Store base-model blocks on worker 0
let
base_event
=
router_event
(
0
,
0
,
0
,
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
stored_blocks_with_sequence_hashes
(
&
hashes
_a
,
&
seq
_a
),
blocks
:
stored_blocks_with_sequence_hashes
(
&
base_
hashes
,
&
base_
seq
),
}),
)
)
.await
;
)
;
index
.apply_event
(
base_event
)
.await
;
// Store adapter-b blocks on worker 1
index
.apply_event
(
router_event
(
// Store LoRA blocks on worker 1
let
lora_event
=
router_event
(
1
,
0
,
0
,
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
stored_blocks_with_sequence_hashes
(
&
hashes
_b
,
&
seq
_b
),
blocks
:
stored_blocks_with_sequence_hashes
(
&
lora_
hashes
,
&
lora_
seq
),
}),
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Query adapter-a → only worker 0
let
scores_a
=
index
.find_matches
(
hashes_a
.clone
())
.await
.unwrap
();
assert_eq!
(
scores_a
.scores
.len
(),
1
);
assert
!
(
scores_a
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
0
,
0
)));
assert
!
(
!
scores_a
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
1
,
0
)));
// Query adapter-b → only worker 1
let
scores_b
=
index
.find_matches
(
hashes_b
.clone
())
.await
.unwrap
();
assert_eq!
(
scores_b
.scores
.len
(),
1
);
assert
!
(
scores_b
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
1
,
0
)));
assert
!
(
!
scores_b
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
0
,
0
)));
}
// ============================================================================
// Long sequence tests - especially important for NestedMap/PositionalIndexer
// ============================================================================
);
index
.apply_event
(
lora_event
)
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_single_store
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store a long sequence (128 blocks) in a single event
let
seq_len
=
128
;
let
sequence
:
Vec
<
u64
>
=
(
1
..=
seq_len
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
sequence
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Query full sequence - should match all blocks
let
full_query
:
Vec
<
LocalBlockHash
>
=
sequence
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
full_query
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
seq_len
as
u32
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Query prefix (first 64 blocks)
let
prefix_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
64
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
prefix_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
64
);
// Query with base-model hashes → only worker 0
let
base_scores
=
index
.find_matches
(
base_hashes
.clone
())
.await
.unwrap
();
assert_eq!
(
base_scores
.scores
.len
(),
1
,
"Only base-model worker should match"
);
assert_eq!
(
*
base_scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
// Query with divergence at position 50
let
mut
divergent_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
100
)
.map
(
LocalBlockHash
)
.collect
();
divergent_query
[
49
]
=
LocalBlockHash
(
99999
);
// Position 49 (0-indexed) diverges
let
scores
=
index
.find_matches
(
divergent_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
49
);
}
// Query with LoRA hashes → only worker 1
let
lora_scores
=
index
.find_matches
(
lora_hashes
.clone
())
.await
.unwrap
();
assert_eq!
(
lora_scores
.scores
.len
(),
1
,
"Only LoRA worker should match"
);
assert_eq!
(
*
lora_scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
3
);
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_multiple_continuations
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Build a long sequence through multiple continuations
// First store: blocks 1-50
let
first_chunk
:
Vec
<
u64
>
=
(
1
..=
50
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
first_chunk
))
.await
;
// Second store: blocks 51-100 (continuation of first)
let
second_chunk
:
Vec
<
u64
>
=
(
51
..=
100
)
.collect
();
index
.apply_event
(
make_store_event_with_parent
(
0
,
&
first_chunk
,
&
second_chunk
))
.await
;
// Third store: blocks 101-150 (continuation of second)
let
prefix_1_2
:
Vec
<
u64
>
=
(
1
..=
100
)
.collect
();
let
third_chunk
:
Vec
<
u64
>
=
(
101
..=
150
)
.collect
();
index
.apply_event
(
make_store_event_with_parent
(
0
,
&
prefix_1_2
,
&
third_chunk
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Query full sequence - should match all 150 blocks
let
full_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
150
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
full_query
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
150
);
/// Reproduces the "block_hash mismatch: sequence hashes should be uniform
/// across workers" warning seen when the same prompt is sent to both a base
/// model worker and a LoRA worker.
///
/// On main (without LoRA-aware hashing), both workers compute the same
/// LocalBlockHash for identical tokens. But vLLM's engine includes the
/// adapter in its rolling ExternalSequenceBlockHash, so the radix tree
/// sees conflicting sequence hashes at the same tree node.
///
/// With LoRA-aware hashing, compute_block_hash_for_seq produces distinct
/// LocalBlockHash values for different adapters, so the blocks land on
/// separate tree paths and no mismatch occurs.
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_lora_base_same_tokens_no_seq_hash_mismatch
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
let
kv_block_size
:
u32
=
32
;
let
tokens
:
Vec
<
u32
>
=
(
0
..
kv_block_size
*
3
)
.collect
();
// With LoRA-aware hashing, base and adapter produce different LocalBlockHash
let
base_local
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
None
);
let
lora_local
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
Some
(
"my-adapter"
));
assert_ne!
(
base_local
,
lora_local
,
"LoRA-aware hashing must produce different LocalBlockHash values"
);
// Query crossing continuation boundaries
let
cross_boundary_query
:
Vec
<
LocalBlockHash
>
=
(
45
..=
105
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
cross_boundary_query
)
.await
.unwrap
();
// Query starts at block 45, but stored sequence starts at 1, so this won't match
// because the sequence hash at position 0 of our query (block 45) won't match
// the stored sequence hash at position 0 (block 1)
assert
!
(
scores
.scores
.is_empty
()
||
!
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
0
,
0
)));
}
// Simulate what vLLM does: same tokens, different rolling seq hashes
// because the engine accounts for the adapter internally.
let
base_seq
=
compute_seq_hash_for_block
(
&
base_local
);
let
lora_seq
=
compute_seq_hash_for_block
(
&
lora_local
);
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_branching_continuations
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Common prefix: blocks 1-30
let
common_prefix
:
Vec
<
u64
>
=
(
1
..=
30
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
common_prefix
))
.await
;
// Branch A: blocks 31-60 on worker 0
let
branch_a
:
Vec
<
u64
>
=
(
31
..=
60
)
.collect
();
index
.apply_event
(
make_store_event_with_parent
(
0
,
&
common_prefix
,
&
branch_a
))
.await
;
// Branch B: blocks 131-160 (different content) on worker 1
// First store the common prefix for worker 1
index
.apply_event
(
make_store_event
(
1
,
&
common_prefix
))
.await
;
let
branch_b
:
Vec
<
u64
>
=
(
131
..=
160
)
.collect
();
index
.apply_event
(
make_store_event_with_parent
(
1
,
&
common_prefix
,
&
branch_b
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Query common prefix - both workers should match
let
prefix_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
30
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
prefix_query
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
2
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
30
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
30
);
// Worker 0: base model
index
.apply_event
(
router_event
(
0
,
0
,
0
,
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
stored_blocks_with_sequence_hashes
(
&
base_local
,
&
base_seq
),
}),
))
.await
;
// Query branch A path - only worker 0 should match fully
let
branch_a_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
60
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
branch_a_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
60
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
30
);
}
// Worker 1: LoRA adapter — different LocalBlockHash, so this goes to
// a separate tree path instead of colliding with worker 0's node.
index
.apply_event
(
router_event
(
1
,
0
,
0
,
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
stored_blocks_with_sequence_hashes
(
&
lora_local
,
&
lora_seq
),
}),
))
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_partial_removal
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Store a long sequence
let
sequence
:
Vec
<
u64
>
=
(
1
..=
100
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
sequence
))
.await
;
// Base query finds only worker 0
let
base_scores
=
index
.find_matches
(
base_local
.clone
())
.await
.unwrap
();
assert_eq!
(
base_scores
.scores
.len
(),
1
);
assert_eq!
(
*
base_scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
flush_and_settle
(
index
.as_ref
())
.await
;
// LoRA query finds only worker 1
let
lora_scores
=
index
.find_matches
(
lora_local
.clone
())
.await
.unwrap
();
assert_eq!
(
lora_scores
.scores
.len
(),
1
);
assert_eq!
(
*
lora_scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
3
);
}
// Verify full match
let
full_query
:
Vec
<
LocalBlockHash
>
=
sequence
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
full_query
.clone
())
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
100
);
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_different_lora_adapters_do_not_conflict
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
let
kv_block_size
:
u32
=
32
;
// Remove blocks 80-100 (the tail)
let
tail_hashes
:
Vec
<
LocalBlockHash
>
=
(
1
..=
100
)
.map
(
LocalBlockHash
)
.collect
();
let
seq_hashes
=
compute_seq_hash_for_block
(
&
tail_hashes
);
let
remove_hashes
:
Vec
<
ExternalSequenceBlockHash
>
=
seq_hashes
[
79
..
100
]
.iter
()
.map
(|
&
h
|
ExternalSequenceBlockHash
(
h
))
.collect
();
let
tokens
:
Vec
<
u32
>
=
(
0
..
kv_block_size
*
2
)
.collect
();
let
remove_event
=
remove_event
(
0
,
0
,
0
,
remove_hashes
);
index
.apply_event
(
remove_event
)
.await
;
let
hashes_a
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
Some
(
"adapter-a"
)
);
let
hashes_b
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
Some
(
"adapter-b"
))
;
flush_and_settle
(
index
.as_ref
())
.await
;
assert_ne!
(
hashes_a
,
hashes_b
,
"Different adapters must produce different hashes"
);
// Query should now only match first 79 blocks
let
scores
=
index
.find_matches
(
full_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
79
);
}
let
seq_a
=
compute_seq_hash_for_block
(
&
hashes_a
);
let
seq_b
=
compute_seq_hash_for_block
(
&
hashes_b
);
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_interleaved_workers
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Multiple workers storing overlapping long sequences concurrently
// Worker 0: blocks 1-100
// Worker 1: blocks 1-75
// Worker 2: blocks 1-50
// Worker 3: blocks 1-25
let
seq_100
:
Vec
<
u64
>
=
(
1
..=
100
)
.collect
();
let
seq_75
:
Vec
<
u64
>
=
(
1
..=
75
)
.collect
();
let
seq_50
:
Vec
<
u64
>
=
(
1
..=
50
)
.collect
();
let
seq_25
:
Vec
<
u64
>
=
(
1
..=
25
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
seq_100
))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
seq_75
))
.await
;
index
.apply_event
(
make_store_event
(
2
,
&
seq_50
))
.await
;
index
.apply_event
(
make_store_event
(
3
,
&
seq_25
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Query for 60 blocks - workers 0,1 match 60, worker 2 matches 50, worker 3 matches 25
let
query_60
:
Vec
<
LocalBlockHash
>
=
(
1
..=
60
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
query_60
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
4
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
60
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
60
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
2
,
0
))
.unwrap
(),
50
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
3
,
0
))
.unwrap
(),
25
);
}
// Store adapter-a blocks on worker 0
index
.apply_event
(
router_event
(
0
,
0
,
0
,
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
stored_blocks_with_sequence_hashes
(
&
hashes_a
,
&
seq_a
),
}),
))
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_exact_jump_size_boundaries
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store adapter-b blocks on worker 1
index
.apply_event
(
router_event
(
1
,
0
,
0
,
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
stored_blocks_with_sequence_hashes
(
&
hashes_b
,
&
seq_b
),
}),
))
.await
;
// Test sequences that align exactly with jump_size boundaries (32 for PositionalIndexer)
// This tests edge cases in the jump search algorithm
flush_and_settle
(
index
.as_ref
())
.await
;
// Store sequence of exactly 32 blocks
let
seq_32
:
Vec
<
u64
>
=
(
1
..=
32
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
seq_32
))
.await
;
// Query adapter-a → only worker 0
let
scores_a
=
index
.find_matches
(
hashes_a
.clone
())
.await
.unwrap
();
assert_eq!
(
scores_a
.scores
.len
(),
1
);
assert
!
(
scores_a
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
0
,
0
)));
assert
!
(
!
scores_a
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
1
,
0
)));
// Store sequence of exactly 64 blocks (2x jump_size)
let
seq_64
:
Vec
<
u64
>
=
(
1001
..=
1064
)
.collect
();
index
.apply_event
(
make_store_event
(
1
,
&
seq_64
))
.await
;
// Query adapter-b → only worker 1
let
scores_b
=
index
.find_matches
(
hashes_b
.clone
())
.await
.unwrap
();
assert_eq!
(
scores_b
.scores
.len
(),
1
);
assert
!
(
scores_b
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
1
,
0
)));
assert
!
(
!
scores_b
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
0
,
0
)));
}
}
// Store sequence of exactly 96 blocks (3x jump_size)
let
seq_96
:
Vec
<
u64
>
=
(
2001
..=
2096
)
.collect
();
index
.apply_event
(
make_store_event
(
2
,
&
seq_96
))
.await
;
// ============================================================================
// Long sequence tests - especially important for NestedMap/PositionalIndexer
// ============================================================================
flush_and_settle
(
index
.as_ref
())
.await
;
mod
long_sequence_tests
{
use
super
::
*
;
use
rstest_reuse
::
apply
;
// Verify all sequences match correctly
let
query_32
:
Vec
<
LocalBlockHash
>
=
seq_32
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_32
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
32
);
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_single_store
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
let
query_64
:
Vec
<
LocalBlockHash
>
=
seq_64
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_64
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
64
);
// Store a long sequence (128 blocks) in a single event
let
seq_len
=
128
;
let
sequence
:
Vec
<
u64
>
=
(
1
..=
seq_len
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
sequence
))
.await
;
let
query_96
:
Vec
<
LocalBlockHash
>
=
seq_96
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_96
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
2
,
0
))
.unwrap
(),
96
);
}
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_off_by_one_jump_boundaries
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Test sequences at jump_size +/- 1 boundaries to catch off-by-one errors
let
seq_31
:
Vec
<
u64
>
=
(
1
..=
31
)
.collect
();
let
seq_33
:
Vec
<
u64
>
=
(
101
..=
133
)
.collect
();
let
seq_63
:
Vec
<
u64
>
=
(
201
..=
263
)
.collect
();
let
seq_65
:
Vec
<
u64
>
=
(
301
..=
365
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
seq_31
))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
seq_33
))
.await
;
index
.apply_event
(
make_store_event
(
2
,
&
seq_63
))
.await
;
index
.apply_event
(
make_store_event
(
3
,
&
seq_65
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Verify all sequences match correctly
let
query_31
:
Vec
<
LocalBlockHash
>
=
seq_31
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_31
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
31
);
// Query full sequence - should match all blocks
let
full_query
:
Vec
<
LocalBlockHash
>
=
sequence
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
full_query
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
seq_len
as
u32
);
let
query_33
:
Vec
<
LocalBlockHash
>
=
seq_33
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_33
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
33
);
// Query prefix (first 64 blocks)
let
prefix_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
64
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
prefix_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
64
);
let
query_63
:
Vec
<
LocalBlockHash
>
=
seq_63
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_63
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
2
,
0
))
.unwrap
(),
63
);
// Query with divergence at position 50
let
mut
divergent_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
100
)
.map
(
LocalBlockHash
)
.collect
();
divergent_query
[
49
]
=
LocalBlockHash
(
99999
);
// Position 49 (0-indexed) diverges
let
scores
=
index
.find_matches
(
divergent_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
49
);
}
let
query_65
:
Vec
<
LocalBlockHash
>
=
seq_65
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_65
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
3
,
0
))
.unwrap
(),
65
);
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_multiple_continuations
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_divergence_at_jump_boundaries
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
)
;
// Build a long sequence through multiple continuations
// First store: blocks 1-50
let
first_chunk
:
Vec
<
u64
>
=
(
1
..=
50
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
first_chunk
))
.await
;
// Store a long sequence
let
sequence
:
Vec
<
u64
>
=
(
1
..=
128
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
sequence
))
.await
;
// Second store: blocks 51-100 (continuation of first)
let
second_chunk
:
Vec
<
u64
>
=
(
51
..=
100
)
.collect
();
index
.apply_event
(
make_store_event_with_parent
(
0
,
&
first_chunk
,
&
second_chunk
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Third store: blocks 101-150 (continuation of second)
let
prefix_1_2
:
Vec
<
u64
>
=
(
1
..=
100
)
.collect
();
let
third_chunk
:
Vec
<
u64
>
=
(
101
..=
150
)
.collect
();
index
.apply_event
(
make_store_event_with_parent
(
0
,
&
prefix_1_2
,
&
third_chunk
))
.await
;
// Test divergence exactly at jump boundaries (position 31, 32, 33, 63, 64, 65)
for
diverge_pos
in
[
31u
size
,
32
,
33
,
63
,
64
,
65
,
95
,
96
,
97
]
{
let
mut
query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
128
)
.map
(
LocalBlockHash
)
.collect
();
query
[
diverge_pos
]
=
LocalBlockHash
(
99999
);
flush_and_settle
(
index
.as_ref
())
.await
;
let
scores
=
index
.find_matches
(
query
)
.await
.unwrap
();
// Query full sequence - should match all 150 blocks
let
full_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
150
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
full_query
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
diverge_pos
as
u32
,
"Divergence at position {} should match {} blocks"
,
diverge_pos
,
diverge_pos
150
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_deep_continuation_chain
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Query crossing continuation boundaries
let
cross_boundary_query
:
Vec
<
LocalBlockHash
>
=
(
45
..=
105
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
cross_boundary_query
)
.await
.unwrap
();
// Query starts at block 45, but stored sequence starts at 1, so this won't match
// because the sequence hash at position 0 of our query (block 45) won't match
// the stored sequence hash at position 0 (block 1)
assert
!
(
scores
.scores
.is_empty
()
||
!
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
);
}
// Build a very long sequence through many small continuations
// This tests the parent_hash chain handling
let
chunk_size
=
10
;
let
num_chunks
=
20
;
// Total 200 blocks
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_branching_continuations
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
let
mut
full_prefix
:
Vec
<
u64
>
=
Vec
::
new
();
// Common prefix: blocks 1-30
let
common_prefix
:
Vec
<
u64
>
=
(
1
..=
30
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
common_prefix
))
.await
;
for
chunk_idx
in
0
..
num_chunks
{
let
chunk_start
=
chunk_idx
*
chunk_size
+
1
;
let
chunk
:
Vec
<
u64
>
=
(
chunk_start
..
chunk_start
+
chunk_size
)
.
map
(|
x
|
x
as
u64
)
.
collect
()
;
// Branch A: blocks 31-60 on worker 0
let
branch_a
:
Vec
<
u64
>
=
(
31
..=
60
)
.collect
()
;
index
.
apply_event
(
make_store_event_with_parent
(
0
,
&
common_prefix
,
&
branch_a
)
)
.
await
;
if
chunk_idx
==
0
{
index
.apply_event
(
make_store_event
(
0
,
&
chunk
))
.await
;
}
else
{
index
.apply_event
(
make_store_event_with_parent
(
0
,
&
full_prefix
,
&
chunk
))
.await
;
}
// Branch B: blocks 131-160 (different content) on worker 1
// First store the common prefix for worker 1
index
.apply_event
(
make_store_event
(
1
,
&
common_prefix
))
.await
;
let
branch_b
:
Vec
<
u64
>
=
(
131
..=
160
)
.collect
();
index
.apply_event
(
make_store_event_with_parent
(
1
,
&
common_prefix
,
&
branch_b
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
full_prefix
.extend
(
&
chunk
);
// Query common prefix - both workers should match
let
prefix_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
30
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
prefix_query
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
2
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
30
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
30
);
// Query branch A path - only worker 0 should match fully
let
branch_a_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
60
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
branch_a_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
60
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
30
);
}
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_partial_removal
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Query full sequence
let
full_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
200
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
full_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
200
);
// Store a long sequence
let
sequence
:
Vec
<
u64
>
=
(
1
..=
100
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
sequence
))
.await
;
// Query partial prefix crossing multiple chunk boundaries
let
partial_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
75
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
partial_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
75
);
}
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_clear_and_rebuild
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Verify full match
let
full_query
:
Vec
<
LocalBlockHash
>
=
sequence
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
full_query
.clone
())
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
100
);
// Store a long sequence
let
sequence
:
Vec
<
u64
>
=
(
1
..=
100
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
sequence
))
.await
;
// Remove blocks 80-100 (the tail)
let
tail_hashes
:
Vec
<
LocalBlockHash
>
=
(
1
..=
100
)
.map
(
LocalBlockHash
)
.collect
();
let
seq_hashes
=
compute_seq_hash_for_block
(
&
tail_hashes
);
let
remove_hashes
:
Vec
<
ExternalSequenceBlockHash
>
=
seq_hashes
[
79
..
100
]
.iter
()
.map
(|
&
h
|
ExternalSequenceBlockHash
(
h
))
.collect
();
flush_and_settle
(
index
.as_ref
())
.await
;
let
remove_event
=
remove_event
(
0
,
0
,
0
,
remove_hashes
);
index
.apply_event
(
remove_event
)
.await
;
// Verify it's stored
let
query
:
Vec
<
LocalBlockHash
>
=
sequence
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query
.clone
())
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
100
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Clear the worker
index
.apply_event
(
make_clear_event
(
0
))
.await
;
// Query should now only match first 79 blocks
let
scores
=
index
.find_matches
(
full_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
79
);
}
flush_and_settle
(
index
.as_ref
())
.await
;
// Four workers store prefixes of the same block sequence (1..=N) with
// different lengths. A 60-block query must then report, per worker, the
// smaller of its stored length and the query length.
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_interleaved_workers(variant: &str) {
    let index = make_indexer(variant);

    // (worker id, number of leading blocks that worker stores), applied in
    // worker order: 100, 75, 50 and 25 blocks respectively.
    for (worker_id, stored_len) in [(0, 100), (1, 75), (2, 50), (3, 25)] {
        let blocks: Vec<u64> = (1..=stored_len).collect();
        index.apply_event(make_store_event(worker_id, &blocks)).await;
    }

    flush_and_settle(index.as_ref()).await;

    // Query the first 60 blocks of the shared sequence.
    let query_60: Vec<LocalBlockHash> = (1..=60).map(LocalBlockHash).collect();
    let scores = index.find_matches(query_60).await.unwrap();

    // Every worker overlaps with the query to some degree.
    assert_eq!(scores.scores.len(), 4);

    // Workers 0 and 1 cover the whole query; workers 2 and 3 are capped by
    // what they stored.
    for (worker_id, expected_overlap) in [(0, 60), (1, 60), (2, 50), (3, 25)] {
        assert_eq!(
            *scores
                .scores
                .get(&WorkerWithDpRank::new(worker_id, 0))
                .unwrap(),
            expected_overlap
        );
    }
}
// Verify it's cleared
let
scores
=
index
.find_matches
(
query
.clone
())
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_exact_jump_size_boundaries
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Rebuild with a different sequence
let
new_sequence
:
Vec
<
u64
>
=
(
1001
..=
1100
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
new_sequence
))
.await
;
// Test sequences that align exactly with jump_size boundaries (32 for PositionalIndexer)
// This tests edge cases in the jump search algorithm
flush_and_settle
(
index
.as_ref
())
.await
;
// Store sequence of exactly 32 blocks
let
seq_32
:
Vec
<
u64
>
=
(
1
..=
32
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
seq_32
))
.await
;
// Verify new sequence works
let
new_query
:
Vec
<
LocalBlockHash
>
=
new_sequence
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
new_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
100
);
// Store sequence of exactly 64 blocks (2x jump_size)
let
seq_64
:
Vec
<
u64
>
=
(
1001
..=
1064
)
.collect
();
index
.apply_event
(
make_store_event
(
1
,
&
seq_64
))
.await
;
// Verify old sequence no longer matches
let
scores
=
index
.find_matches
(
query
)
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
}
// Store sequence of exactly 96 blocks (3x jump_size)
let
seq_96
:
Vec
<
u64
>
=
(
2001
..=
2096
)
.collect
();
index
.apply_event
(
make_store_event
(
2
,
&
seq_96
))
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_multiple_workers_diverging
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Multiple workers with long sequences that share a prefix then diverge
// This tests precise drain point tracking across workers
// Verify all sequences match correctly
let
query_32
:
Vec
<
LocalBlockHash
>
=
seq_32
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_32
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
32
);
// All workers share prefix 1-40
let
shared_prefix
:
Vec
<
u64
>
=
(
1
..=
40
)
.collect
();
let
query_64
:
Vec
<
LocalBlockHash
>
=
seq_64
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_64
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
64
);
// Worker 0: prefix + 41-100 (stores full sequence 1-100)
let
worker_0_full
:
Vec
<
u64
>
=
(
1
..=
100
)
.collect
();
let
query_96
:
Vec
<
LocalBlockHash
>
=
seq_96
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_96
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
2
,
0
))
.unwrap
(),
96
);
}
// Worker 1: prefix + 141-180 (diverges at block 41)
let
worker_1_suffix
:
Vec
<
u64
>
=
(
141
..=
180
)
.collect
();
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_off_by_one_jump_boundaries
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Worker 2: prefix + 241-300 (diverges at block 41)
let
worker_2_suffix
:
Vec
<
u64
>
=
(
241
..=
300
)
.collect
();
// Test sequences at jump_size +/- 1 boundaries to catch off-by-one errors
let
seq_31
:
Vec
<
u64
>
=
(
1
..=
31
)
.collect
();
let
seq_33
:
Vec
<
u64
>
=
(
101
..=
133
)
.collect
();
let
seq_63
:
Vec
<
u64
>
=
(
201
..=
263
)
.collect
();
let
seq_65
:
Vec
<
u64
>
=
(
301
..=
365
)
.collect
();
// Store for all workers
index
.apply_event
(
make_store_event
(
0
,
&
worker_0_full
))
.await
;
index
.apply_event
(
make_store_event
(
0
,
&
seq_31
))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
seq_33
))
.await
;
index
.apply_event
(
make_store_event
(
2
,
&
seq_63
))
.await
;
index
.apply_event
(
make_store_event
(
3
,
&
seq_65
))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
shared_prefix
))
.await
;
index
.apply_event
(
make_store_event_with_parent
(
1
,
&
shared_prefix
,
&
worker_1_suffix
,
))
.await
;
index
.apply_event
(
make_store_event
(
2
,
&
shared_prefix
))
.await
;
index
.apply_event
(
make_store_event_with_parent
(
2
,
&
shared_prefix
,
&
worker_2_suffix
,
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Query 1-100 - worker 0 matches 100, workers 1&2 match 40
let
query
:
Vec
<
LocalBlockHash
>
=
worker_0_full
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
100
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
40
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
2
,
0
))
.unwrap
(),
40
);
}
flush_and_settle
(
index
.as_ref
())
.await
;
// Verify all sequences match correctly
let
query_31
:
Vec
<
LocalBlockHash
>
=
seq_31
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_31
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
31
);
let
query_33
:
Vec
<
LocalBlockHash
>
=
seq_33
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_33
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
33
);
let
query_63
:
Vec
<
LocalBlockHash
>
=
seq_63
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_63
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
2
,
0
))
.unwrap
(),
63
);
let
query_65
:
Vec
<
LocalBlockHash
>
=
seq_65
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query_65
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
3
,
0
))
.unwrap
(),
65
);
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_divergence_at_jump_boundaries
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store a long sequence
let
sequence
:
Vec
<
u64
>
=
(
1
..=
128
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
sequence
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Test divergence exactly at jump boundaries (position 31, 32, 33, 63, 64, 65)
for
diverge_pos
in
[
31u
size
,
32
,
33
,
63
,
64
,
65
,
95
,
96
,
97
]
{
let
mut
query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
128
)
.map
(
LocalBlockHash
)
.collect
();
query
[
diverge_pos
]
=
LocalBlockHash
(
99999
);
let
scores
=
index
.find_matches
(
query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
diverge_pos
as
u32
,
"Divergence at position {} should match {} blocks"
,
diverge_pos
,
diverge_pos
);
}
}
// Builds a 200-block sequence on worker 0 out of 20 consecutive store events
// of 10 blocks each. Every event after the first is a continuation whose
// parent is everything stored so far, which exercises parent-hash chaining.
#[tokio::test]
#[apply(indexer_template)]
async fn test_long_sequence_deep_continuation_chain(variant: &str) {
    let index = make_indexer(variant);

    let blocks_per_chunk: u64 = 10;
    let chunk_count: u64 = 20; // 20 * 10 = 200 blocks in total

    // Blocks stored by all previous chunks; used as the parent prefix.
    let mut stored_so_far: Vec<u64> = Vec::new();

    for chunk_no in 0..chunk_count {
        let first_block = chunk_no * blocks_per_chunk + 1;
        let chunk: Vec<u64> = (first_block..first_block + blocks_per_chunk).collect();

        // Only the very first chunk has no parent.
        let event = if chunk_no == 0 {
            make_store_event(0, &chunk)
        } else {
            make_store_event_with_parent(0, &stored_so_far, &chunk)
        };
        index.apply_event(event).await;

        stored_so_far.extend(&chunk);
    }

    flush_and_settle(index.as_ref()).await;

    // Query full sequence
    let full_query: Vec<LocalBlockHash> = (1..=200).map(LocalBlockHash).collect();
    let full_scores = index.find_matches(full_query).await.unwrap();
    assert_eq!(
        *full_scores
            .scores
            .get(&WorkerWithDpRank::new(0, 0))
            .unwrap(),
        200
    );

    // Query a prefix that ends in the middle of a chunk (75 is not a
    // multiple of 10), so the match crosses several chunk boundaries.
    let partial_query: Vec<LocalBlockHash> = (1..=75).map(LocalBlockHash).collect();
    let partial_scores = index.find_matches(partial_query).await.unwrap();
    assert_eq!(
        *partial_scores
            .scores
            .get(&WorkerWithDpRank::new(0, 0))
            .unwrap(),
        75
    );
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_clear_and_rebuild
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store a long sequence
let
sequence
:
Vec
<
u64
>
=
(
1
..=
100
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
sequence
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Verify it's stored
let
query
:
Vec
<
LocalBlockHash
>
=
sequence
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query
.clone
())
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
100
);
// Clear the worker
index
.apply_event
(
make_clear_event
(
0
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Verify it's cleared
let
scores
=
index
.find_matches
(
query
.clone
())
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
// Rebuild with a different sequence
let
new_sequence
:
Vec
<
u64
>
=
(
1001
..=
1100
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
new_sequence
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Verify new sequence works
let
new_query
:
Vec
<
LocalBlockHash
>
=
new_sequence
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
new_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
100
);
// Verify old sequence no longer matches
let
scores
=
index
.find_matches
(
query
)
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_multiple_workers_diverging
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Multiple workers with long sequences that share a prefix then diverge
// This tests precise drain point tracking across workers
// All workers share prefix 1-40
let
shared_prefix
:
Vec
<
u64
>
=
(
1
..=
40
)
.collect
();
// Worker 0: prefix + 41-100 (stores full sequence 1-100)
let
worker_0_full
:
Vec
<
u64
>
=
(
1
..=
100
)
.collect
();
// Worker 1: prefix + 141-180 (diverges at block 41)
let
worker_1_suffix
:
Vec
<
u64
>
=
(
141
..=
180
)
.collect
();
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_staggered_lengths
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Worker 2: prefix + 241-300 (diverges at block 41)
let
worker_2_suffix
:
Vec
<
u64
>
=
(
241
..=
300
)
.collect
();
// Workers with sequences of staggered lengths to test drain tracking
// Worker 0: 10 blocks
// Worker 1: 20 blocks
// Worker 2: 35 blocks (just past first jump)
// Worker 3: 64 blocks (exactly 2 jumps)
// Worker 4: 100 blocks
// Store for all workers
index
.apply_event
(
make_store_event
(
0
,
&
worker_0_full
))
.await
;
for
(
worker_id
,
len
)
in
[(
0
,
10
),
(
1
,
20
),
(
2
,
35
),
(
3
,
64
),
(
4
,
100
)]
{
let
sequence
:
Vec
<
u64
>
=
(
1
..=
len
)
.collect
();
index
.apply_event
(
make_store_event
(
1
,
&
shared_prefix
))
.await
;
index
.apply_event
(
make_store_event
(
worker_id
,
&
sequence
))
.apply_event
(
make_store_event_with_parent
(
1
,
&
shared_prefix
,
&
worker_1_suffix
,
))
.await
;
index
.apply_event
(
make_store_event
(
2
,
&
shared_prefix
))
.await
;
index
.apply_event
(
make_store_event_with_parent
(
2
,
&
shared_prefix
,
&
worker_2_suffix
,
))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
// Query 1-100 - worker 0 matches 100, workers 1&2 match 40
let
query
:
Vec
<
LocalBlockHash
>
=
worker_0_full
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
100
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
40
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
2
,
0
))
.unwrap
(),
40
);
}
flush_and_settle
(
index
.as_ref
())
.await
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_staggered_lengths
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Query for 100 blocks - each worker should match their stored length
let
query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
100
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
query
)
.await
.unwrap
();
// Workers with sequences of staggered lengths to test drain tracking
// Worker 0: 10 blocks
// Worker 1: 20 blocks
// Worker 2: 35 blocks (just past first jump)
// Worker 3: 64 blocks (exactly 2 jumps)
// Worker 4: 100 blocks
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
10
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
20
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
2
,
0
))
.unwrap
(),
35
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
3
,
0
))
.unwrap
(),
64
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
4
,
0
))
.unwrap
(),
100
);
}
for
(
worker_id
,
len
)
in
[(
0
,
10
),
(
1
,
20
),
(
2
,
35
),
(
3
,
64
),
(
4
,
100
)]
{
let
sequence
:
Vec
<
u64
>
=
(
1
..=
len
)
.collect
();
index
.apply_event
(
make_store_event
(
worker_id
,
&
sequence
))
.await
;
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_very_long_sequence
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Test with a very long sequence (1000 blocks)
let
seq_len
=
1000u64
;
let
sequence
:
Vec
<
u64
>
=
(
1
..=
seq_len
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
sequence
))
.await
;
// Query for 100 blocks - each worker should match their stored length
let
query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
100
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
query
)
.await
.unwrap
();
flush_and_settle
(
index
.as_ref
())
.await
;
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
10
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
20
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
2
,
0
))
.unwrap
(),
35
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
3
,
0
))
.unwrap
(),
64
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
4
,
0
))
.unwrap
(),
100
);
}
// Full match
let
full_query
:
Vec
<
LocalBlockHash
>
=
sequence
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
full_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
seq_len
as
u32
);
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_very_long_sequence
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Partial match (first 500)
let
partial_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
500
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
partial_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
500
);
// Test with a very long sequence (1000 blocks)
let
seq_len
=
1000u64
;
let
sequence
:
Vec
<
u64
>
=
(
1
..=
seq_len
)
.collect
();
index
.apply_event
(
make_store_event
(
0
,
&
sequence
))
.await
;
// Divergence in the middle
let
mut
mid_diverge
:
Vec
<
LocalBlockHash
>
=
(
1
..=
1000
)
.map
(
LocalBlockHash
)
.collect
();
mid_diverge
[
499
]
=
LocalBlockHash
(
99999
);
let
scores
=
index
.find_matches
(
mid_diverge
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
499
);
flush_and_settle
(
index
.as_ref
())
.await
;
// Full match
let
full_query
:
Vec
<
LocalBlockHash
>
=
sequence
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
full_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
seq_len
as
u32
);
// Partial match (first 500)
let
partial_query
:
Vec
<
LocalBlockHash
>
=
(
1
..=
500
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
partial_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
500
);
// Divergence in the middle
let
mut
mid_diverge
:
Vec
<
LocalBlockHash
>
=
(
1
..=
1000
)
.map
(
LocalBlockHash
)
.collect
();
mid_diverge
[
499
]
=
LocalBlockHash
(
99999
);
let
scores
=
index
.find_matches
(
mid_diverge
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
499
);
}
}
// ============================================================================
...
...
@@ -1670,129 +1690,146 @@ fn make_tree_indexer_with_frequency(
}
}
#[tokio::test]
#[apply(tree_indexer_template)]
async
fn
test_frequency
(
variant
:
&
str
)
{
const
ONE_MILLIS
:
Duration
=
Duration
::
from_millis
(
1
);
let
expiration
=
Duration
::
from_millis
(
50
);
let
kv_indexer
=
make_tree_indexer_with_frequency
(
variant
,
expiration
);
// The blocks
let
block_hashes
=
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
),
LocalBlockHash
(
4
),
];
let
overlap
=
kv_indexer
.find_matches
(
block_hashes
.clone
())
.await
.unwrap
();
assert_eq!
(
overlap
.frequencies
.len
(),
0
,
"Should be no cached blocks yet"
);
mod
tree_specific_tests
{
use
super
::
*
;
use
rstest_reuse
::
apply
;
// Blocks go in cache
let
event
=
make_store_event
(
0
,
&
[
1
,
2
,
3
,
4
]);
kv_indexer
.apply_event
(
event
)
.await
;
// First access - poll briefly since store event is applied async
let
mut
overlap
=
OverlapScores
::
default
();
let
timeout
=
Duration
::
from_millis
(
10
);
let
start
=
Instant
::
now
();
while
overlap
.scores
.is_empty
()
&&
Instant
::
now
()
.duration_since
(
start
)
<
timeout
{
time
::
sleep
(
ONE_MILLIS
)
.await
;
overlap
=
kv_indexer
.find_matches
(
block_hashes
.clone
())
.await
.unwrap
();
}
assert_eq!
(
overlap
.scores
.len
(),
1
,
"One worker has these blocks cached"
);
assert_eq!
(
overlap
.frequencies
.len
(),
0
,
"Blocks have not previously been accessed"
);
#[tokio::test]
#[apply(tree_indexer_template)]
async
fn
test_frequency
(
variant
:
&
str
)
{
const
ONE_MILLIS
:
Duration
=
Duration
::
from_millis
(
1
);
// Second access
let
overlap
=
kv_indexer
.find_matches
(
block_hashes
.clone
())
.await
.unwrap
();
assert_eq!
(
overlap
.scores
.len
(),
1
,
"Still one worker matches"
);
assert_eq!
(
overlap
.frequencies
,
vec!
[
1
,
1
,
1
,
1
],
"We should see the first access now"
);
let
expiration
=
Duration
::
from_millis
(
50
);
let
kv_indexer
=
make_tree_indexer_with_frequency
(
variant
,
expiration
);
// Let those two accesses expire
time
::
sleep
(
expiration
+
Duration
::
from_millis
(
10
))
.await
;
// The blocks
let
block_hashes
=
vec!
[
LocalBlockHash
(
1
),
LocalBlockHash
(
2
),
LocalBlockHash
(
3
),
LocalBlockHash
(
4
),
];
// New first access
let
overlap
=
kv_indexer
.find_matches
(
block_hashes
.clone
())
.await
.unwrap
();
assert_eq!
(
overlap
.frequencies
.len
(),
0
,
"Blocks were accessed too long ago"
);
let
overlap
=
kv_indexer
.find_matches
(
block_hashes
.clone
())
.await
.unwrap
();
assert_eq!
(
overlap
.frequencies
.len
(),
0
,
"Should be no cached blocks yet"
);
// Blocks go in cache
let
event
=
make_store_event
(
0
,
&
[
1
,
2
,
3
,
4
]);
kv_indexer
.apply_event
(
event
)
.await
;
// First access - poll briefly since store event is applied async
let
mut
overlap
=
OverlapScores
::
default
();
let
timeout
=
Duration
::
from_millis
(
10
);
let
start
=
Instant
::
now
();
while
overlap
.scores
.is_empty
()
&&
Instant
::
now
()
.duration_since
(
start
)
<
timeout
{
time
::
sleep
(
ONE_MILLIS
)
.await
;
overlap
=
kv_indexer
.find_matches
(
block_hashes
.clone
())
.await
.unwrap
();
}
assert_eq!
(
overlap
.scores
.len
(),
1
,
"One worker has these blocks cached"
);
assert_eq!
(
overlap
.frequencies
.len
(),
0
,
"Blocks have not previously been accessed"
);
// New second access
let
_
=
kv_indexer
.find_matches
(
block_hashes
.clone
())
.await
.unwrap
();
// Second access
let
overlap
=
kv_indexer
.find_matches
(
block_hashes
.clone
())
.await
.unwrap
();
assert_eq!
(
overlap
.scores
.len
(),
1
,
"Still one worker matches"
);
assert_eq!
(
overlap
.frequencies
,
vec!
[
1
,
1
,
1
,
1
],
"We should see the first access now"
);
// Access only the first three blocks
let
overlap
=
kv_indexer
.find_matches
(
block_hashes
[
0
..
3
]
.to_vec
())
.await
.unwrap
();
// We see the previous two new accesses
assert_eq!
(
overlap
.frequencies
,
vec!
[
2
,
2
,
2
]);
// Let those two accesses expire
time
::
sleep
(
expiration
+
Duration
::
from_millis
(
10
))
.await
;
// The third access did not touch the last block
let
overlap
=
kv_indexer
.find_matches
(
block_hashes
.clone
())
.await
.unwrap
();
assert_eq!
(
overlap
.frequencies
,
vec!
[
3
,
3
,
3
,
2
]);
// New first access
let
overlap
=
kv_indexer
.find_matches
(
block_hashes
.clone
())
.await
.unwrap
();
assert_eq!
(
overlap
.frequencies
.len
(),
0
,
"Blocks were accessed too long ago"
);
// New second access
let
_
=
kv_indexer
.find_matches
(
block_hashes
.clone
())
.await
.unwrap
();
// Access only the first three blocks
let
overlap
=
kv_indexer
.find_matches
(
block_hashes
[
0
..
3
]
.to_vec
())
.await
.unwrap
();
// We see the previous two new accesses
assert_eq!
(
overlap
.frequencies
,
vec!
[
2
,
2
,
2
]);
// The third access did not touch the last block
let
overlap
=
kv_indexer
.find_matches
(
block_hashes
.clone
())
.await
.unwrap
();
assert_eq!
(
overlap
.frequencies
,
vec!
[
3
,
3
,
3
,
2
]);
}
}
// ============================================================================
// KvIndexerMetrics tests
// ============================================================================
#[cfg(feature
=
"metrics"
)]
#[test]
fn
test_increment_event_applied
()
{
let
metrics
=
KvIndexerMetrics
::
new_unregistered
();
mod
metrics_tests
{
#[cfg(feature
=
"metrics"
)]
use
super
::
*
;
metrics
.increment_event_applied
(
METRIC_EVENT_STORED
,
Ok
(()));
assert_eq!
(
metrics
.kv_cache_events_applied
.get_metric_with_label_values
(
&
[
METRIC_EVENT_STORED
,
METRIC_STATUS_OK
])
.unwrap
()
.get
(),
1
);
#[cfg(feature
=
"metrics"
)]
#[test]
fn
test_increment_event_applied
()
{
let
metrics
=
KvIndexerMetrics
::
new_unregistered
();
metrics
.increment_event_applied
(
METRIC_EVENT_STORED
,
Err
(
KvCacheEventError
::
ParentBlockNotFound
),
);
assert_eq!
(
metrics
.kv_cache_events_applied
.get_metric_with_label_values
(
&
[
METRIC_EVENT_STORED
,
METRIC_STATUS_PARENT_NOT_FOUND
])
.unwrap
()
.get
(),
1
);
metrics
.increment_event_applied
(
METRIC_EVENT_STORED
,
Ok
(()));
assert_eq!
(
metrics
.kv_cache_events_applied
.get_metric_with_label_values
(
&
[
METRIC_EVENT_STORED
,
METRIC_STATUS_OK
])
.unwrap
()
.get
(),
1
);
metrics
.increment_event_applied
(
METRIC_EVENT_STORED
,
Err
(
KvCacheEventError
::
ParentBlockNotFound
),
);
assert_eq!
(
metrics
.kv_cache_events_applied
.get_metric_with_label_values
(
&
[
METRIC_EVENT_STORED
,
METRIC_STATUS_PARENT_NOT_FOUND
])
.unwrap
()
.get
(),
1
);
metrics
.increment_event_applied
(
METRIC_EVENT_REMOVED
,
Err
(
KvCacheEventError
::
BlockNotFound
));
assert_eq!
(
metrics
.kv_cache_events_applied
.get_metric_with_label_values
(
&
[
METRIC_EVENT_REMOVED
,
METRIC_STATUS_BLOCK_NOT_FOUND
])
.unwrap
()
.get
(),
1
);
.increment_event_applied
(
METRIC_EVENT_REMOVED
,
Err
(
KvCacheEventError
::
BlockNotFound
));
assert_eq!
(
metrics
.kv_cache_events_applied
.get_metric_with_label_values
(
&
[
METRIC_EVENT_REMOVED
,
METRIC_STATUS_BLOCK_NOT_FOUND
])
.unwrap
()
.get
(),
1
);
}
}
// ============================================================================
...
...
@@ -1822,363 +1859,368 @@ fn make_local_indexer_with_events(ids: &[u64]) -> LocalKvIndexer {
indexer
}
#[tokio::test]
async
fn
test_local_indexer_slice_within_range
()
{
let
indexer
=
make_local_indexer_with_events
(
&
[
1
,
2
,
3
,
4
,
5
]);
mod
local_indexer_tests
{
use
super
::
*
;
use
rstest_reuse
::
apply
;
// Exercises `get_events_in_id_range` on an indexer built from events 1..=5:
// in-range queries, an end past the newest event, a start before the oldest
// event, a single-element range, and an inverted range.
#[tokio::test]
async fn test_local_indexer_slice_within_range() {
    let indexer = make_local_indexer_with_events(&[1, 2, 3, 4, 5]);

    // Pulls the event list out of a response (either a plain event list or a
    // tree dump) and reduces it to the event ids, in order.
    let event_ids = |resp: WorkerKvQueryResponse| -> Vec<u64> {
        let events: Vec<RouterEvent> = match resp {
            WorkerKvQueryResponse::Events(found) => found,
            WorkerKvQueryResponse::TreeDump { events: found, .. } => found,
            _ => panic!("Unexpected response type"),
        };
        events.iter().map(|ev| ev.event.event_id).collect()
    };

    // Both bounds are inclusive: [2, 4] yields 2, 3 and 4.
    let within = indexer.get_events_in_id_range(Some(2), Some(4)).await;
    assert_eq!(event_ids(within), vec![2, 3, 4]);

    // An end id past the newest buffered event is clamped to the buffer max.
    let past_end = indexer.get_events_in_id_range(Some(2), Some(6)).await;
    assert_eq!(event_ids(past_end), vec![2, 3, 4, 5]);

    // start_id=0 is before buffer (first is 1), so should trigger tree dump
    let before_start = indexer.get_events_in_id_range(Some(0), Some(4)).await;
    assert!(matches!(
        before_start,
        WorkerKvQueryResponse::TreeDump { .. }
    ));

    // start == end selects exactly one event.
    let single = indexer.get_events_in_id_range(Some(3), Some(3)).await;
    assert_eq!(event_ids(single), vec![3]);

    // Invalid range: end < start
    let inverted = indexer.get_events_in_id_range(Some(5), Some(2)).await;
    assert!(matches!(
        inverted,
        WorkerKvQueryResponse::InvalidRange { .. }
    ));
}
// Helper to extract events from response
let
extract_events
=
|
resp
:
WorkerKvQueryResponse
|
->
Vec
<
RouterEvent
>
{
match
resp
{
WorkerKvQueryResponse
::
Events
(
e
)
=>
e
,
WorkerKvQueryResponse
::
TreeDump
{
events
:
e
,
..
}
=>
e
,
_
=>
panic!
(
"Unexpected response type"
),
#[tokio::test]
async
fn
test_local_indexer_get_events_in_id_range_all_cases
()
{
// Create indexer with small buffer (5 events max)
let
indexer
=
LocalKvIndexer
::
new
(
CancellationToken
::
new
(),
4
,
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
()),
5
,
);
// Helper to create a test event
let
make_event
=
|
id
:
u64
|
{
RouterEvent
::
new
(
0
,
KvCacheEvent
{
event_id
:
id
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
:
ExternalSequenceBlockHash
(
id
*
100
),
tokens_hash
:
LocalBlockHash
(
id
*
200
),
mm_extra_info
:
None
,
}],
}),
dp_rank
:
0
,
},
)
};
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for
id
in
5
..
15
{
indexer
.apply_event_with_buffer
(
make_event
(
id
))
.await
.unwrap
();
}
};
let
get_ids
=
|
events
:
Vec
<
RouterEvent
>
|
->
Vec
<
u64
>
{
events
.iter
()
.map
(|
e
|
e
.event.event_id
)
.collect
()
};
// Wait for events to be processed
indexer
.flush
()
.await
;
// Test get_events_in_id_range (buffer queries)
// Range is [start, end] inclusive
let
result
=
indexer
.get_events_in_id_range
(
Some
(
2
),
Some
(
4
))
.await
;
let
ids
=
get_ids
(
extract_events
(
result
));
assert_eq!
(
ids
,
vec!
[
2
,
3
,
4
]);
// inclusive range [2, 4]
let
extract_events
=
|
resp
:
WorkerKvQueryResponse
|
->
Vec
<
RouterEvent
>
{
match
resp
{
WorkerKvQueryResponse
::
Events
(
e
)
=>
e
,
WorkerKvQueryResponse
::
TreeDump
{
events
:
e
,
..
}
=>
e
,
_
=>
panic!
(
"Unexpected response type: {:?}"
,
resp
),
}
};
let
result
=
indexer
.get_events_in_id_range
(
Some
(
2
),
Some
(
6
))
.await
;
let
ids
=
get_ids
(
extract_events
(
result
));
assert_eq!
(
ids
,
vec!
[
2
,
3
,
4
,
5
]);
// clamp end to buffer max
let
get_ids
=
|
events
:
Vec
<
RouterEvent
>
|
->
Vec
<
u64
>
{
events
.iter
()
.map
(|
e
|
e
.event.event_id
)
.collect
()
};
//
start_id=0 is before buffer (first is 1), so should trigger tree dump
let
result
=
indexer
.get_events_in_
id_range
(
Some
(
0
),
Some
(
4
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TreeDump
{
..
})
);
//
Verify buffer state
let
buffer_events
=
indexer
.get_
all_
events_in_
buffer
()
;
assert
_eq!
(
get_ids
(
buffer_events
),
vec!
[
10
,
11
,
12
,
13
,
14
]
);
let
result
=
indexer
.get_events_in_id_range
(
Some
(
3
),
Some
(
3
))
.await
;
let
ids
=
get_ids
(
extract_events
(
result
))
;
assert_eq!
(
ids
,
vec!
[
3
]);
// single element when start == end
// Buffer path tests
let
result
=
indexer
.get_events_in_id_range
(
Some
(
11
),
None
)
.await
;
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
11
,
12
,
13
,
14
]);
// Invalid range: end < start
let
result
=
indexer
.get_events_in_id_range
(
Some
(
5
),
Some
(
2
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
InvalidRange
{
..
}));
}
let
result
=
indexer
.get_events_in_id_range
(
Some
(
10
),
Some
(
14
))
.await
;
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
10
,
11
,
12
,
13
,
14
]);
#[tokio::test]
async
fn
test_local_indexer_get_events_in_id_range_all_cases
()
{
// Create indexer with small buffer (5 events max)
let
indexer
=
LocalKvIndexer
::
new
(
CancellationToken
::
new
(),
4
,
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
()),
5
,
);
// Tree dump path tests
let
result
=
indexer
.get_events_in_id_range
(
None
,
None
)
.await
;
assert
!
(
matches!
(
&
result
,
WorkerKvQueryResponse
::
TreeDump
{
..
}));
assert_eq!
(
extract_events
(
result
)
.len
(),
10
);
// Helper to create a test event
let
make_event
=
|
id
:
u64
|
{
RouterEvent
::
new
(
0
,
let
result
=
indexer
.get_events_in_id_range
(
Some
(
7
),
None
)
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TreeDump
{
..
}));
// Edge cases
let
result
=
indexer
.get_events_in_id_range
(
Some
(
15
),
Some
(
10
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
InvalidRange
{
..
}));
let
result
=
indexer
.get_events_in_id_range
(
Some
(
100
),
Some
(
200
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TooNew
{
..
}));
}
#[tokio::test]
async
fn
test_tree_dump_includes_last_event_id
()
{
// Create indexer with small buffer (5 events max)
let
indexer
=
LocalKvIndexer
::
new
(
CancellationToken
::
new
(),
4
,
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
()),
5
,
);
let
make_event
=
|
id
:
u64
|
{
RouterEvent
::
new
(
0
,
KvCacheEvent
{
event_id
:
id
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
:
ExternalSequenceBlockHash
(
id
*
100
),
tokens_hash
:
LocalBlockHash
(
id
*
200
),
mm_extra_info
:
None
,
}],
}),
dp_rank
:
0
,
},
)
};
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for
id
in
5
..
15
{
indexer
.apply_event_with_buffer
(
make_event
(
id
))
.await
.unwrap
();
}
indexer
.flush
()
.await
;
// Request with start_id=None -> tree dump should include last_event_id=14
let
result
=
indexer
.get_events_in_id_range
(
None
,
None
)
.await
;
match
result
{
WorkerKvQueryResponse
::
TreeDump
{
last_event_id
,
events
,
}
=>
{
assert_eq!
(
last_event_id
,
14
,
"last_event_id should be the buffer's newest event ID"
);
assert
!
(
!
events
.is_empty
(),
"tree dump should contain events"
);
}
other
=>
panic!
(
"Expected TreeDump, got: {other:?}"
),
}
// Request with start_id older than buffer -> tree dump should include last_event_id=14
let
result
=
indexer
.get_events_in_id_range
(
Some
(
7
),
None
)
.await
;
match
result
{
WorkerKvQueryResponse
::
TreeDump
{
last_event_id
,
events
,
}
=>
{
assert_eq!
(
last_event_id
,
14
,
"last_event_id should be the buffer's newest event ID"
);
assert
!
(
!
events
.is_empty
(),
"tree dump should contain events"
);
}
other
=>
panic!
(
"Expected TreeDump, got: {other:?}"
),
}
// Empty buffer case: create a fresh indexer with no events
let
empty_indexer
=
LocalKvIndexer
::
new
(
CancellationToken
::
new
(),
4
,
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
()),
5
,
);
let
result
=
empty_indexer
.get_events_in_id_range
(
None
,
None
)
.await
;
match
result
{
WorkerKvQueryResponse
::
TreeDump
{
last_event_id
,
events
,
}
=>
{
assert_eq!
(
last_event_id
,
0
,
"empty buffer should return last_event_id=0"
);
assert
!
(
events
.is_empty
(),
"empty indexer should have no events"
);
}
other
=>
panic!
(
"Expected TreeDump, got: {other:?}"
),
}
}
#[tokio::test]
async
fn
test_local_indexer_buffer_and_serialization
()
{
let
worker_id
=
42u64
;
let
token
=
CancellationToken
::
new
();
let
metrics
=
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
());
let
local_indexer
=
Arc
::
new
(
LocalKvIndexer
::
new
(
token
,
4
,
metrics
,
100
));
let
test_event
=
RouterEvent
::
new
(
worker_id
,
KvCacheEvent
{
event_id
:
id
,
event_id
:
1
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
:
ExternalSequenceBlockHash
(
id
*
100
),
tokens_hash
:
LocalBlockHash
(
id
*
200
),
block_hash
:
ExternalSequenceBlockHash
(
100
),
tokens_hash
:
LocalBlockHash
(
200
),
mm_extra_info
:
None
,
}],
}),
dp_rank
:
0
,
},
)
};
);
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for
id
in
5
..
15
{
indexer
.apply_event_with_buffer
(
make_event
(
id
))
local_indexer
.apply_event_with_buffer
(
test_event
)
.await
.unwrap
();
}
// Wait for events to be processed
indexer
.flush
()
.await
;
local_indexer
.flush
()
.await
;
let
extract_events
=
|
resp
:
WorkerKvQueryResponse
|
->
Vec
<
RouterEvent
>
{
match
resp
{
WorkerKvQueryResponse
::
Events
(
e
)
=>
e
,
WorkerKvQueryResponse
::
TreeDump
{
events
:
e
,
..
}
=>
e
,
_
=>
panic!
(
"Unexpected response type: {:?}"
,
resp
),
}
};
let
buffered_events
=
local_indexer
.get_all_events_in_buffer
();
assert_eq!
(
buffered_events
.len
(),
1
);
assert_eq!
(
buffered_events
[
0
]
.worker_id
,
worker_id
);
let
get_ids
=
|
events
:
Vec
<
RouterEvent
>
|
->
Vec
<
u64
>
{
events
.iter
()
.map
(|
e
|
e
.event.event_id
)
.collect
()
};
// Test serialization round-trip
let
response
=
WorkerKvQueryResponse
::
Events
(
buffered_events
);
let
serialized
=
serde_json
::
to_vec
(
&
response
)
.unwrap
();
let
deserialized
:
WorkerKvQueryResponse
=
serde_json
::
from_slice
(
&
serialized
)
.unwrap
();
// Verify buffer state
let
buffer_events
=
indexer
.get_all_events_in_buffer
();
assert_eq!
(
get_ids
(
buffer_events
),
vec!
[
10
,
11
,
12
,
13
,
14
]);
// Buffer path tests
let
result
=
indexer
.get_events_in_id_range
(
Some
(
11
),
None
)
.await
;
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
11
,
12
,
13
,
14
]);
let
result
=
indexer
.get_events_in_id_range
(
Some
(
10
),
Some
(
14
))
.await
;
assert_eq!
(
get_ids
(
extract_events
(
result
)),
vec!
[
10
,
11
,
12
,
13
,
14
]);
// Tree dump path tests
let
result
=
indexer
.get_events_in_id_range
(
None
,
None
)
.await
;
assert
!
(
matches!
(
&
result
,
WorkerKvQueryResponse
::
TreeDump
{
..
}));
assert_eq!
(
extract_events
(
result
)
.len
(),
10
);
let
result
=
indexer
.get_events_in_id_range
(
Some
(
7
),
None
)
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TreeDump
{
..
}));
// Edge cases
let
result
=
indexer
.get_events_in_id_range
(
Some
(
15
),
Some
(
10
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
InvalidRange
{
..
}));
let
result
=
indexer
.get_events_in_id_range
(
Some
(
100
),
Some
(
200
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TooNew
{
..
}));
}
let
events
=
match
deserialized
{
WorkerKvQueryResponse
::
Events
(
e
)
=>
e
,
_
=>
panic!
(
"Expected Events variant"
),
};
assert_eq!
(
events
.len
(),
1
);
assert_eq!
(
events
[
0
]
.worker_id
,
worker_id
);
}
#[tokio::test]
async
fn
test_tree_dump_includes_last_event_id
()
{
// Create indexer with small buffer (5 events max)
let
indexer
=
LocalKvIndexer
::
new
(
CancellationToken
::
new
(),
4
,
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
()),
5
,
);
#[tokio::test]
async
fn
test_local_indexer_does_not_buffer_failed_send
()
{
let
local_indexer
=
LocalKvIndexer
::
new
(
CancellationToken
::
new
(),
4
,
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
()),
5
,
);
let
make_event
=
|
id
:
u64
|
{
RouterEvent
::
new
(
0
,
let
test_event
=
RouterEvent
::
new
(
7
,
KvCacheEvent
{
event_id
:
id
,
event_id
:
1
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
:
ExternalSequenceBlockHash
(
id
*
100
),
tokens_hash
:
LocalBlockHash
(
id
*
200
),
block_hash
:
ExternalSequenceBlockHash
(
100
),
tokens_hash
:
LocalBlockHash
(
200
),
mm_extra_info
:
None
,
}],
}),
dp_rank
:
0
,
},
)
};
// Add 10 events (IDs 5-14), buffer keeps last 5: events 10-14
for
id
in
5
..
15
{
indexer
.apply_event_with_buffer
(
make_event
(
id
))
.await
.unwrap
();
}
indexer
.flush
()
.await
;
// Request with start_id=None -> tree dump should include last_event_id=14
let
result
=
indexer
.get_events_in_id_range
(
None
,
None
)
.await
;
match
result
{
WorkerKvQueryResponse
::
TreeDump
{
last_event_id
,
events
,
}
=>
{
assert_eq!
(
last_event_id
,
14
,
"last_event_id should be the buffer's newest event ID"
);
assert
!
(
!
events
.is_empty
(),
"tree dump should contain events"
);
}
other
=>
panic!
(
"Expected TreeDump, got: {other:?}"
),
}
// Request with start_id older than buffer -> tree dump should include last_event_id=14
let
result
=
indexer
.get_events_in_id_range
(
Some
(
7
),
None
)
.await
;
match
result
{
WorkerKvQueryResponse
::
TreeDump
{
last_event_id
,
events
,
}
=>
{
assert_eq!
(
last_event_id
,
14
,
"last_event_id should be the buffer's newest event ID"
);
assert
!
(
!
events
.is_empty
(),
"tree dump should contain events"
);
}
other
=>
panic!
(
"Expected TreeDump, got: {other:?}"
),
}
);
// Empty buffer case: create a fresh indexer with no events
let
empty_indexer
=
LocalKvIndexer
::
new
(
CancellationToken
::
new
(),
4
,
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
()),
5
,
);
let
result
=
empty_indexer
.get_events_in_id_range
(
None
,
None
)
.await
;
match
result
{
WorkerKvQueryResponse
::
TreeDump
{
last_event_id
,
events
,
}
=>
{
assert_eq!
(
last_event_id
,
0
,
"empty buffer should return last_event_id=0"
);
assert
!
(
events
.is_empty
(),
"empty indexer should have no events"
);
let
event_tx
=
local_indexer
.event_sender
();
local_indexer
.shutdown
();
event_tx
.closed
()
.await
;
let
result
=
local_indexer
.apply_event_with_buffer
(
test_event
)
.await
;
assert
!
(
matches!
(
result
,
Err
(
KvRouterError
::
IndexerOffline
)));
assert_eq!
(
local_indexer
.buffer_len
(),
0
);
match
local_indexer
.get_events_in_id_range
(
None
,
None
)
.await
{
WorkerKvQueryResponse
::
TreeDump
{
events
,
last_event_id
,
}
=>
{
assert
!
(
events
.is_empty
());
assert_eq!
(
last_event_id
,
0
);
}
other
=>
panic!
(
"Expected TreeDump, got: {other:?}"
),
}
other
=>
panic!
(
"Expected TreeDump, got: {other:?}"
),
}
}
#[tokio::test]
async
fn
test_local_indexer_buffer_and_serialization
()
{
let
worker_id
=
42u64
;
let
token
=
CancellationToken
::
new
();
let
metrics
=
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
());
let
local_indexer
=
Arc
::
new
(
LocalKvIndexer
::
new
(
token
,
4
,
metrics
,
100
));
let
test_event
=
RouterEvent
::
new
(
worker_id
,
KvCacheEvent
{
event_id
:
1
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
:
ExternalSequenceBlockHash
(
100
),
tokens_hash
:
LocalBlockHash
(
200
),
mm_extra_info
:
None
,
}],
}),
dp_rank
:
0
,
},
);
local_indexer
.apply_event_with_buffer
(
test_event
)
.await
.unwrap
();
local_indexer
.flush
()
.await
;
let
buffered_events
=
local_indexer
.get_all_events_in_buffer
();
assert_eq!
(
buffered_events
.len
(),
1
);
assert_eq!
(
buffered_events
[
0
]
.worker_id
,
worker_id
);
// Test serialization round-trip
let
response
=
WorkerKvQueryResponse
::
Events
(
buffered_events
);
let
serialized
=
serde_json
::
to_vec
(
&
response
)
.unwrap
();
let
deserialized
:
WorkerKvQueryResponse
=
serde_json
::
from_slice
(
&
serialized
)
.unwrap
();
let
events
=
match
deserialized
{
WorkerKvQueryResponse
::
Events
(
e
)
=>
e
,
_
=>
panic!
(
"Expected Events variant"
),
};
assert_eq!
(
events
.len
(),
1
);
assert_eq!
(
events
[
0
]
.worker_id
,
worker_id
);
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_apply_events_idempotent
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
#[tokio::test]
async
fn
test_local_indexer_does_not_buffer_failed_send
()
{
let
local_indexer
=
LocalKvIndexer
::
new
(
CancellationToken
::
new
(),
4
,
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
()),
5
,
);
let
test_event
=
RouterEvent
::
new
(
7
,
KvCacheEvent
{
event_id
:
1
,
data
:
KvCacheEventData
::
Stored
(
KvCacheStoreData
{
parent_hash
:
None
,
blocks
:
vec!
[
KvCacheStoredBlockData
{
block_hash
:
ExternalSequenceBlockHash
(
100
),
tokens_hash
:
LocalBlockHash
(
200
),
mm_extra_info
:
None
,
}],
}),
dp_rank
:
0
,
},
);
// Setup: build initial tree
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
[
4
,
5
,
6
]))
.await
;
index
.apply_event
(
make_store_event_with_parent
(
0
,
&
[
1
,
2
,
3
],
&
[
7
,
8
]))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
let
s0
=
snapshot_tree
(
index
.as_ref
())
.await
;
// Mutation events: each add paired with its remove
let
adds
=
[
make_store_event
(
2
,
&
[
1
,
2
,
9
]),
make_store_event_with_parent
(
1
,
&
[
4
,
5
,
6
],
&
[
10
,
11
,
12
]),
];
let
removes
=
[
make_remove_event
(
2
,
&
[
1
,
2
,
9
]),
make_remove_event_with_parent
(
1
,
&
[
4
,
5
,
6
],
&
[
10
,
11
,
12
]),
];
// Phase 1: interleaved add/remove
index
.apply_event
(
adds
[
0
]
.clone
())
.await
;
index
.apply_event
(
removes
[
0
]
.clone
())
.await
;
index
.apply_event
(
adds
[
1
]
.clone
())
.await
;
index
.apply_event
(
removes
[
1
]
.clone
())
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
let
s1
=
snapshot_tree
(
index
.as_ref
())
.await
;
assert_eq!
(
s0
,
s1
,
"Phase 1: interleaved add/remove should restore tree"
);
let
event_tx
=
local_indexer
.event_sender
();
local_indexer
.shutdown
();
event_tx
.closed
()
.await
;
let
result
=
local_indexer
.apply_event_with_buffer
(
test_event
)
.await
;
assert
!
(
matches!
(
result
,
Err
(
KvRouterError
::
IndexerOffline
)));
assert_eq!
(
local_indexer
.buffer_len
(),
0
);
match
local_indexer
.get_events_in_id_range
(
None
,
None
)
.await
{
WorkerKvQueryResponse
::
TreeDump
{
events
,
last_event_id
,
}
=>
{
assert
!
(
events
.is_empty
());
assert_eq!
(
last_event_id
,
0
);
}
other
=>
panic!
(
"Expected TreeDump, got: {other:?}"
),
// Phase 2: same interleaved again (idempotence of the full cycle)
index
.apply_event
(
adds
[
0
]
.clone
())
.await
;
index
.apply_event
(
removes
[
0
]
.clone
())
.await
;
index
.apply_event
(
adds
[
1
]
.clone
())
.await
;
index
.apply_event
(
removes
[
1
]
.clone
())
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
let
s2
=
snapshot_tree
(
index
.as_ref
())
.await
;
assert_eq!
(
s1
,
s2
,
"Phase 2: repeated cycle should be idempotent"
);
// Phase 3: non-interleaved (all adds then all removes)
index
.apply_event
(
adds
[
0
]
.clone
())
.await
;
index
.apply_event
(
adds
[
1
]
.clone
())
.await
;
index
.apply_event
(
removes
[
0
]
.clone
())
.await
;
index
.apply_event
(
removes
[
1
]
.clone
())
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
let
s3
=
snapshot_tree
(
index
.as_ref
())
.await
;
assert_eq!
(
s2
,
s3
,
"Phase 3: non-interleaved ordering should restore tree"
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_apply_events_idempotent
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Setup: build initial tree
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
index
.apply_event
(
make_store_event
(
1
,
&
[
4
,
5
,
6
]))
.await
;
index
.apply_event
(
make_store_event_with_parent
(
0
,
&
[
1
,
2
,
3
],
&
[
7
,
8
]))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
let
s0
=
snapshot_tree
(
index
.as_ref
())
.await
;
// Mutation events: each add paired with its remove
let
adds
=
[
make_store_event
(
2
,
&
[
1
,
2
,
9
]),
make_store_event_with_parent
(
1
,
&
[
4
,
5
,
6
],
&
[
10
,
11
,
12
]),
];
let
removes
=
[
make_remove_event
(
2
,
&
[
1
,
2
,
9
]),
make_remove_event_with_parent
(
1
,
&
[
4
,
5
,
6
],
&
[
10
,
11
,
12
]),
];
// Phase 1: interleaved add/remove
index
.apply_event
(
adds
[
0
]
.clone
())
.await
;
index
.apply_event
(
removes
[
0
]
.clone
())
.await
;
index
.apply_event
(
adds
[
1
]
.clone
())
.await
;
index
.apply_event
(
removes
[
1
]
.clone
())
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
let
s1
=
snapshot_tree
(
index
.as_ref
())
.await
;
assert_eq!
(
s0
,
s1
,
"Phase 1: interleaved add/remove should restore tree"
);
// Phase 2: same interleaved again (idempotence of the full cycle)
index
.apply_event
(
adds
[
0
]
.clone
())
.await
;
index
.apply_event
(
removes
[
0
]
.clone
())
.await
;
index
.apply_event
(
adds
[
1
]
.clone
())
.await
;
index
.apply_event
(
removes
[
1
]
.clone
())
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
let
s2
=
snapshot_tree
(
index
.as_ref
())
.await
;
assert_eq!
(
s1
,
s2
,
"Phase 2: repeated cycle should be idempotent"
);
// Phase 3: non-interleaved (all adds then all removes)
index
.apply_event
(
adds
[
0
]
.clone
())
.await
;
index
.apply_event
(
adds
[
1
]
.clone
())
.await
;
index
.apply_event
(
removes
[
0
]
.clone
())
.await
;
index
.apply_event
(
removes
[
1
]
.clone
())
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
let
s3
=
snapshot_tree
(
index
.as_ref
())
.await
;
assert_eq!
(
s2
,
s3
,
"Phase 3: non-interleaved ordering should restore tree"
);
}
lib/kv-router/src/lib.rs
View file @
b7fe46b1
...
...
@@ -6,7 +6,6 @@
//! This crate provides the core radix tree implementation and protocols for
//! efficient KV cache lookup and routing in distributed LLM inference systems.
pub
mod
event_sink
;
pub
mod
indexer
;
pub
mod
protocols
;
pub
mod
scheduling
;
...
...
@@ -41,15 +40,15 @@ pub use self::sequence::{ActiveSequences, RequestId};
pub
use
concurrent_radix_tree
::
ConcurrentRadixTree
;
pub
use
concurrent_radix_tree_compressed
::
ConcurrentRadixTreeCompressed
;
pub
use
config
::{
KvRouterConfig
,
RouterConfigOverride
,
RouterQueuePolicy
};
pub
use
event_sink
::
EventSink
;
pub
use
indexer
::{
MaybeError
,
SyncIndexer
,
ThreadPoolIndexer
};
pub
use
nested_map
::
PositionalIndexer
;
pub
use
protocols
::{
KvCacheEventError
,
LocalBlockHash
,
OverlapScores
,
RouterEvent
,
WorkerConfigLike
,
WorkerId
,
compute_block_hash_for_seq
,
KvCacheEventError
,
LocalBlockHash
,
OverlapScores
,
RouterEvent
,
RouterEventSink
,
WorkerConfigLike
,
WorkerId
,
compute_block_hash_for_seq
,
};
pub
use
queue
::
SchedulerQueue
;
pub
use
radix_tree
::
RadixTree
;
pub
use
scheduling
::
LocalScheduler
;
pub
use
scheduling
::
policy
::{
FcfsPolicy
,
RouterSchedulingPolicy
,
SchedulingPolicy
,
WsptPolicy
};
pub
use
scheduling
::{
KvSchedulerError
,
PotentialLoad
,
SchedulingRequest
,
SchedulingResponse
};
pub
use
selector
::{
DefaultWorkerSelector
,
WorkerSelector
};
lib/kv-router/src/protocols.rs
View file @
b7fe46b1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
future
::
Future
;
use
dynamo_tokens
::{
SequenceHash
,
Token
};
use
rustc_hash
::
FxHashMap
;
use
serde
::{
Deserialize
,
Serialize
};
...
...
@@ -105,6 +107,12 @@ pub trait WorkerConfigLike {
fn
total_kv_blocks
(
&
self
)
->
Option
<
u64
>
;
}
/// Transport abstraction for publishing batched router-visible KV cache events.
///
/// Implementors must be `Send + Sync` (supertrait bounds below).
pub
trait
RouterEventSink
:
Send
+
Sync
{
/// Publishes `event` through this sink.
///
/// Returns a `Send` future whose output is `anyhow::Result<()>`.
// NOTE(review): the trait doc says "batched", but this method takes a single
// `&RouterEvent` — confirm where batching happens.
fn
publish_event
(
&
self
,
event
:
&
RouterEvent
)
->
impl
Future
<
Output
=
anyhow
::
Result
<
()
>>
+
Send
;
}
/// A worker identifier.
pub
type
WorkerId
=
u64
;
...
...
lib/kv-router/src/scheduling/config.rs
View file @
b7fe46b1
...
...
@@ -11,11 +11,16 @@ use validator::{Validate, ValidationError};
use
crate
::
protocols
::{
compute_block_hash_for_seq
,
compute_seq_hash_for_block
};
/// Default value for `KvRouterConfig::min_initial_workers`.
///
/// Referenced by name from the field's `#[serde(default = "...")]` attribute
/// and called from `KvRouterConfig::default()`.
const fn default_min_initial_workers() -> usize {
    const SINGLE_WORKER: usize = 1;
    SINGLE_WORKER
}
/// Ordering policy for the router queue.
///
/// Serialized by serde using the lowercase variant name
/// (`"fcfs"`, `"lcfs"`, `"wspt"`).
#[derive(Debug,
Clone,
Copy,
Default,
PartialEq,
Eq,
Serialize,
Deserialize)]
#[serde(rename_all
=
"lowercase"
)]
pub
enum
RouterQueuePolicy
{
/// First-come first-served; this is the `Default` variant.
#[default]
Fcfs
,
/// Presumably last-come first-served — NOTE(review): confirm against the policy implementation.
Lcfs
,
/// Weighted shortest processing time.
Wspt
,
}
...
...
@@ -23,6 +28,7 @@ impl fmt::Display for RouterQueuePolicy {
/// Writes the policy's lowercase name (`"fcfs"`, `"lcfs"` or `"wspt"`) to `f`.
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    // Pick the label first so there is a single write call.
    let label = match self {
        Self::Fcfs => "fcfs",
        Self::Lcfs => "lcfs",
        Self::Wspt => "wspt",
    };
    f.write_str(label)
}
...
...
@@ -34,9 +40,10 @@ impl FromStr for RouterQueuePolicy {
fn
from_str
(
s
:
&
str
)
->
Result
<
Self
,
Self
::
Err
>
{
match
s
{
"fcfs"
=>
Ok
(
Self
::
Fcfs
),
"lcfs"
=>
Ok
(
Self
::
Lcfs
),
"wspt"
=>
Ok
(
Self
::
Wspt
),
_
=>
Err
(
format!
(
"unknown queue policy: {s:?}, expected 'fcfs' or 'wspt'"
"unknown queue policy: {s:?}, expected 'fcfs'
, 'lcfs',
or 'wspt'"
)),
}
}
...
...
@@ -58,6 +65,7 @@ pub struct RouterConfigOverride {
/// KV Router configuration parameters
#[derive(Debug,
Clone,
Serialize,
Deserialize,
Validate)]
#[serde(default)]
#[validate(schema(function
=
"validate_kv_router_config"
))]
pub
struct
KvRouterConfig
{
#[validate(range(min
=
0.0
))]
...
...
@@ -130,6 +138,13 @@ pub struct KvRouterConfig {
/// When true, the router starts immediately without waiting for discovery-based
/// workers and workers are provided externally per-request (e.g., EPP).
pub
skip_initial_worker_wait
:
bool
,
/// Minimum number of workers that must be discovered before router startup continues.
/// Default: 1. Ignored when skip_initial_worker_wait=true.
#[serde(default
=
"default_min_initial_workers"
)]
#[validate(range(min
=
1
))]
pub
min_initial_workers
:
usize
,
/// Scheduling policy for the router queue.
/// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT.
/// "lcfs": last-come first-served.
/// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
...
...
@@ -159,10 +174,11 @@ impl Default for KvRouterConfig {
router_ttl_secs
:
120.0
,
router_max_tree_size
:
2u
size
.pow
(
20
),
// 2^20 = 1048576, matches PruneConfig::default()
router_prune_target_ratio
:
0.8
,
router_queue_threshold
:
Some
(
2
.0
),
router_queue_threshold
:
Some
(
4
.0
),
router_event_threads
:
4
,
router_enable_cache_control
:
false
,
skip_initial_worker_wait
:
false
,
min_initial_workers
:
default_min_initial_workers
(),
router_queue_policy
:
RouterQueuePolicy
::
default
(),
remote_indexer_component
:
None
,
}
...
...
@@ -237,3 +253,39 @@ impl KvRouterConfig {
self
.use_kv_events
&&
self
.overlap_score_weight
>
0.0
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Lcfs` renders as "lcfs", and "lcfs" parses back to `Lcfs`.
    #[test]
    fn router_queue_policy_display_and_parse_support_lcfs() {
        let rendered = RouterQueuePolicy::Lcfs.to_string();
        assert_eq!(rendered, "lcfs");

        let parsed: RouterQueuePolicy = "lcfs".parse().unwrap();
        assert_eq!(parsed, RouterQueuePolicy::Lcfs);
    }

    /// `Lcfs` survives a JSON round trip and is encoded as the string "lcfs".
    #[test]
    fn router_queue_policy_serde_round_trip_supports_lcfs() {
        let json = serde_json::to_string(&RouterQueuePolicy::Lcfs).unwrap();
        assert_eq!(json, "\"lcfs\"");

        let round_tripped = serde_json::from_str::<RouterQueuePolicy>(&json).unwrap();
        assert_eq!(round_tripped, RouterQueuePolicy::Lcfs);
    }

    /// The default configuration waits for exactly one worker.
    #[test]
    fn kv_router_config_defaults_to_one_initial_worker() {
        let defaults = KvRouterConfig::default();
        assert_eq!(defaults.min_initial_workers, 1);
    }

    /// Validation fails when `min_initial_workers` is zero.
    #[test]
    fn kv_router_config_rejects_zero_initial_workers() {
        let zero_workers = KvRouterConfig {
            min_initial_workers: 0,
            ..Default::default()
        };
        assert!(zero_workers.validate().is_err());
    }
}
lib/kv-router/src/scheduling/local.rs
0 → 100644
View file @
b7fe46b1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
collections
::{
HashMap
,
HashSet
};
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
tokio
::
sync
::{
mpsc
,
watch
};
use
tokio_util
::
sync
::
CancellationToken
;
use
super
::
policy
::{
RouterSchedulingPolicy
,
SchedulingPolicy
};
use
super
::
queue
::
SchedulerQueue
;
use
super
::
selector
::{
DefaultWorkerSelector
,
WorkerSelector
};
use
super
::
types
::{
KvSchedulerError
,
PotentialLoad
,
SchedulingRequest
,
SchedulingResponse
};
use
crate
::
protocols
::{
OverlapScores
,
WorkerConfigLike
,
WorkerId
,
WorkerWithDpRank
};
use
crate
::
sequences
::{
ActiveSequencesMultiWorker
,
SequenceError
,
SequencePublisher
,
SequenceRequest
,
};
use
dynamo_tokens
::
SequenceHash
;
/// Period (60 seconds) of the `tokio::time::interval` used by the scheduler's
/// background task; each tick calls `update()` on the scheduler queue.
const
RECHECK_INTERVAL
:
Duration
=
Duration
::
from_secs
(
60
);
/// Scheduler front end that forwards `SchedulingRequest`s over an mpsc channel
/// to a `SchedulerQueue` and exposes the shared `ActiveSequencesMultiWorker`.
///
/// Type parameters:
/// - `P`: the `SequencePublisher` used by the active-sequence tracker.
/// - `C`: the per-worker configuration type (`WorkerConfigLike`).
/// - `S`: the `SchedulingPolicy`; defaults to `RouterSchedulingPolicy`.
/// - `Sel`: the `WorkerSelector<C>`; defaults to `DefaultWorkerSelector`.
pub
struct
LocalScheduler
<
P
,
C
,
S
=
RouterSchedulingPolicy
,
Sel
=
DefaultWorkerSelector
>
where
P
:
SequencePublisher
,
C
:
WorkerConfigLike
,
S
:
SchedulingPolicy
,
Sel
:
WorkerSelector
<
C
>
,
{
/// Sending half of the request channel; the receiving half is owned by the
/// background task spawned in `new`.
request_tx
:
mpsc
::
Sender
<
SchedulingRequest
>
,
/// Shared active-sequence tracker that the request-lifecycle methods delegate to.
slots
:
Arc
<
ActiveSequencesMultiWorker
<
P
>>
,
/// Shared scheduler queue; also held by the background task spawned in `new`.
queue
:
Arc
<
SchedulerQueue
<
P
,
C
,
S
,
Sel
>>
,
/// Static label supplied at construction and returned by `worker_type()`.
worker_type
:
&
'static
str
,
}
impl
<
P
,
C
,
S
,
Sel
>
LocalScheduler
<
P
,
C
,
S
,
Sel
>
where
P
:
SequencePublisher
+
'static
,
C
:
WorkerConfigLike
+
Clone
+
PartialEq
+
Send
+
Sync
+
'static
,
S
:
SchedulingPolicy
+
'static
,
Sel
:
WorkerSelector
<
C
>
+
Send
+
Sync
+
'static
,
{
#[allow(clippy::too_many_arguments)]
pub
fn
new
(
slots
:
Arc
<
ActiveSequencesMultiWorker
<
P
>>
,
workers_with_configs
:
watch
::
Receiver
<
HashMap
<
WorkerId
,
C
>>
,
threshold_frac
:
Option
<
f64
>
,
block_size
:
u32
,
selector
:
Sel
,
policy
:
S
,
cancellation_token
:
CancellationToken
,
worker_type
:
&
'static
str
,
monitor_worker_configs
:
bool
,
)
->
Self
{
if
monitor_worker_configs
{
let
slots_monitor
=
Arc
::
clone
(
&
slots
);
let
mut
monitor_rx
=
workers_with_configs
.clone
();
let
mut
last_workers
=
monitor_rx
.borrow
()
.clone
();
let
monitor_cancel_token
=
cancellation_token
.clone
();
tokio
::
spawn
(
async
move
{
tracing
::
trace!
(
"LocalScheduler workers monitoring task started"
);
loop
{
tokio
::
select!
{
_
=
monitor_cancel_token
.cancelled
()
=>
{
tracing
::
trace!
(
"LocalScheduler workers monitoring task shutting down"
);
break
;
}
result
=
monitor_rx
.changed
()
=>
{
if
result
.is_err
()
{
tracing
::
warn!
(
"LocalScheduler worker config watch dropped, shutting down"
);
break
;
}
}
}
let
current_workers
=
monitor_rx
.borrow_and_update
()
.clone
();
if
current_workers
==
last_workers
{
continue
;
}
let
dp_range
:
HashMap
<
WorkerId
,
(
u32
,
u32
)
>
=
current_workers
.iter
()
.map
(|(
&
id
,
cfg
)|
{
(
id
,
(
cfg
.data_parallel_start_rank
(),
cfg
.data_parallel_size
()),
)
})
.collect
();
slots_monitor
.update_workers
(
&
dp_range
);
last_workers
=
current_workers
;
}
});
}
let
queue
=
Arc
::
new
(
SchedulerQueue
::
new
(
Arc
::
clone
(
&
slots
),
workers_with_configs
,
threshold_frac
,
block_size
,
selector
,
policy
,
));
let
(
request_tx
,
request_rx
)
=
mpsc
::
channel
::
<
SchedulingRequest
>
(
1024
);
let
queue_clone
=
Arc
::
clone
(
&
queue
);
tokio
::
spawn
(
async
move
{
let
mut
request_rx
=
request_rx
;
let
mut
recheck_interval
=
tokio
::
time
::
interval
(
RECHECK_INTERVAL
);
tracing
::
trace!
(
"LocalScheduler background task started"
);
loop
{
tokio
::
select!
{
_
=
cancellation_token
.cancelled
()
=>
{
tracing
::
trace!
(
"LocalScheduler background task shutting down"
);
break
;
}
request
=
request_rx
.recv
()
=>
{
let
Some
(
request
)
=
request
else
{
tracing
::
warn!
(
"LocalScheduler request channel closed"
);
break
;
};
tracing
::
trace!
(
"received request to be scheduled"
);
queue_clone
.enqueue
(
request
)
.await
;
}
_
=
recheck_interval
.tick
()
=>
{
queue_clone
.update
()
.await
;
}
}
}
});
Self
{
request_tx
,
slots
,
queue
,
worker_type
,
}
}
/// Submits a scheduling request to the background task and waits for its reply.
///
/// The arguments are packed into a `SchedulingRequest` (with empty
/// `decode_blocks` / `prefill_tokens` maps and a cloned
/// `router_config_override`) together with a oneshot sender for the response.
///
/// # Errors
/// Returns `KvSchedulerError::SubscriberShutdown` if the request channel is
/// closed or the reply sender is dropped before answering; otherwise returns
/// whatever result the reply carries.
#[expect(clippy::too_many_arguments)]
pub async fn schedule(
    &self,
    maybe_request_id: Option<String>,
    isl_tokens: usize,
    token_seq: Option<Vec<SequenceHash>>,
    overlaps: OverlapScores,
    router_config_override: Option<&super::config::RouterConfigOverride>,
    update_states: bool,
    lora_name: Option<String>,
    priority_jump: f64,
    expected_output_tokens: Option<u32>,
    allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<SchedulingResponse, KvSchedulerError> {
    let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();

    let request = SchedulingRequest {
        maybe_request_id,
        token_seq,
        isl_tokens,
        overlaps,
        decode_blocks: HashMap::new(),
        prefill_tokens: HashMap::new(),
        router_config_override: router_config_override.cloned(),
        update_states,
        lora_name,
        priority_jump,
        expected_output_tokens,
        allowed_worker_ids,
        resp_tx: Some(reply_tx),
    };

    // A failed send means the background task has already gone away.
    if self.request_tx.send(request).await.is_err() {
        return Err(KvSchedulerError::SubscriberShutdown);
    }

    match reply_rx.await {
        Ok(outcome) => outcome,
        Err(_) => Err(KvSchedulerError::SubscriberShutdown),
    }
}
/// Forwards `worker_ids` to the scheduler queue's `register_workers`.
pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
    let queue = &self.queue;
    queue.register_workers(worker_ids);
}
/// Hands `req` to the active-sequence tracker and returns its result unchanged.
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
    let outcome = self.slots.add_request(req).await;
    outcome
}
/// Marks prefill as completed for `request_id` in the active-sequence tracker,
/// then asks the queue to `update()`.
///
/// # Errors
/// Propagates the tracker's `SequenceError`; the queue is not updated in that case.
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
    // The tracker is called with a reference to an owned `String`.
    let owned_id = request_id.to_string();
    self.slots.mark_prefill_completed(&owned_id).await?;

    self.queue.update().await;
    Ok(())
}
/// Frees `request_id` in the active-sequence tracker, then asks the queue to `update()`.
///
/// # Errors
/// Propagates the tracker's `SequenceError`; the queue is not updated in that case.
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
    // The tracker is called with a reference to an owned `String`.
    let owned_id = request_id.to_string();
    self.slots.free(&owned_id).await?;

    self.queue.update().await;
    Ok(())
}
/// Returns the scheduler queue's `pending_count()`.
pub fn pending_count(&self) -> usize {
    let queue = &self.queue;
    queue.pending_count()
}
/// Returns the static worker-type label this scheduler was constructed with.
pub fn worker_type(&self) -> &'static str {
    let label: &'static str = self.worker_type;
    label
}
/// Forwards to the active-sequence tracker's `add_output_block` for
/// `request_id`, passing `decay_fraction` through unchanged.
///
/// # Errors
/// Returns whatever `SequenceError` the tracker reports.
pub fn add_output_block(
    &self,
    request_id: &str,
    decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
    // The tracker is called with a reference to an owned `String`.
    let owned_id = request_id.to_string();
    self.slots.add_output_block(&owned_id, decay_fraction)
}
/// Builds one [`PotentialLoad`] per worker/DP-rank reported by the slots
/// for a hypothetical request.
///
/// The slots are asked for `potential_blocks_and_tokens`; every worker that
/// appears in either returned map yields one entry. A worker missing from
/// the prefill map falls back to `isl_tokens`, and one missing from the
/// decode map falls back to `0`. The order of the returned vector follows
/// hash-set iteration and is therefore unspecified.
pub fn get_potential_loads(
    &self,
    token_seq: Option<Vec<SequenceHash>>,
    isl_tokens: usize,
    overlaps: OverlapScores,
) -> Vec<PotentialLoad> {
    let (decode_blocks, prefill_tokens) =
        self.slots
            .potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);

    // Union of every worker/rank mentioned by either map.
    let known_workers: HashSet<WorkerWithDpRank> = decode_blocks
        .keys()
        .chain(prefill_tokens.keys())
        .copied()
        .collect();

    known_workers
        .into_iter()
        .map(|worker| PotentialLoad {
            worker_id: worker.worker_id,
            dp_rank: worker.dp_rank,
            potential_prefill_tokens: prefill_tokens
                .get(&worker)
                .copied()
                .unwrap_or(isl_tokens),
            potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
        })
        .collect()
}
/// Returns the active LoRA counts reported by the slots, keyed by name.
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
    let slots = &self.slots;
    slots.get_active_lora_counts()
}
}
#[cfg(test)]
mod
tests
{
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
tokio
::
sync
::
watch
;
use
super
::
*
;
use
crate
::
protocols
::
OverlapScores
;
use
crate
::
scheduling
::
policy
::
FcfsPolicy
;
use
crate
::
scheduling
::
selector
::
DefaultWorkerSelector
;
use
crate
::
test_utils
::{
NoopSequencePublisher
,
SimpleWorkerConfig
};
/// Test fixture: builds a [`LocalScheduler`] over `workers` using a no-op
/// publisher, the default worker selector and [`FcfsPolicy`].
///
/// Returns the scheduler, the shared slots, the sender half of the worker
/// config watch channel, and the cancellation token handed to the scheduler.
#[allow(clippy::type_complexity)]
fn make_scheduler(
    workers: HashMap<WorkerId, SimpleWorkerConfig>,
    threshold_frac: Option<f64>,
    monitor_worker_configs: bool,
) -> (
    Arc<LocalScheduler<NoopSequencePublisher, SimpleWorkerConfig, FcfsPolicy>>,
    Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
    watch::Sender<HashMap<WorkerId, SimpleWorkerConfig>>,
    CancellationToken,
) {
    // Per-worker (start rank, size) pairs taken from each config.
    let dp_range = workers
        .iter()
        .map(|(id, cfg)| {
            let span = (cfg.data_parallel_start_rank, cfg.data_parallel_size);
            (*id, span)
        })
        .collect();

    let slots = Arc::new(ActiveSequencesMultiWorker::new(
        NoopSequencePublisher,
        64,
        dp_range,
        false,
        0,
        "test",
    ));

    let (cfg_tx, cfg_rx) = watch::channel(workers);
    let cancel_token = CancellationToken::new();
    let selector = DefaultWorkerSelector::new(None, "test");

    let scheduler = LocalScheduler::new(
        Arc::clone(&slots),
        cfg_rx,
        threshold_frac,
        64,
        selector,
        FcfsPolicy,
        cancel_token.clone(),
        "test",
        monitor_worker_configs,
    );

    (Arc::new(scheduler), slots, cfg_tx, cancel_token)
}
/// Scheduling with `update_states = true` picks the only worker and makes
/// the request's LoRA adapter show up in the active LoRA counts.
#[tokio::test]
async fn test_schedule_books_request_into_active_sequences() {
    let workers = HashMap::from([(
        0,
        SimpleWorkerConfig {
            max_num_batched_tokens: Some(64),
            ..Default::default()
        },
    )]);
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);

    let response = scheduler
        .schedule(
            Some("req-1".to_string()),
            64,
            Some(vec![1, 2, 3, 4]),
            OverlapScores::default(),
            None,
            true,
            Some("adapter-a".to_string()),
            0.0,
            None,
            None,
        )
        .await
        .unwrap();

    assert_eq!(response.best_worker.worker_id, 0);

    let expected_loras = HashMap::from([(String::from("adapter-a"), 1)]);
    assert_eq!(scheduler.get_active_lora_counts(), expected_loras);

    cancel_token.cancel();
}
/// A second request stays pending until the first one's prefill is marked
/// completed, after which it is scheduled and the pending count returns to 0.
#[tokio::test]
async fn test_mark_prefill_completed_drains_pending_queue() {
    let workers = HashMap::from([(
        0,
        SimpleWorkerConfig {
            max_num_batched_tokens: Some(64),
            ..Default::default()
        },
    )]);
    // NOTE(review): a 0.5 threshold on a 64-token budget presumably makes the
    // first 64-token request saturate the worker — confirm.
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, Some(0.5), true);

    scheduler
        .schedule(
            Some("req-1".to_string()),
            64,
            Some(vec![1, 2, 3, 4]),
            OverlapScores::default(),
            None,
            true,
            None,
            0.0,
            None,
            None,
        )
        .await
        .unwrap();

    // Issue the second request from a separate task so this test can observe
    // it while it is still waiting.
    let background = Arc::clone(&scheduler);
    let queued = tokio::spawn(async move {
        background
            .schedule(
                Some("req-2".to_string()),
                64,
                Some(vec![5, 6, 7, 8]),
                OverlapScores::default(),
                None,
                true,
                None,
                0.0,
                None,
                None,
            )
            .await
    });

    // Give the spawned task time to submit its request.
    let settle = Duration::from_millis(25);
    tokio::time::sleep(settle).await;
    assert_eq!(scheduler.pending_count(), 1);

    scheduler.mark_prefill_completed("req-1").await.unwrap();

    // Outer unwrap: the task joined; inner unwrap: scheduling succeeded.
    let join_result = queued.await;
    join_result.unwrap().unwrap();
    assert_eq!(scheduler.pending_count(), 0);

    cancel_token.cancel();
}
/// Freeing a scheduled request removes its LoRA adapter from the active
/// LoRA counts.
#[tokio::test]
async fn test_free_updates_active_state() {
    let workers = HashMap::from([(
        0,
        SimpleWorkerConfig {
            max_num_batched_tokens: Some(64),
            ..Default::default()
        },
    )]);
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);

    scheduler
        .schedule(
            Some("req-1".to_string()),
            64,
            Some(vec![1, 2, 3, 4]),
            OverlapScores::default(),
            None,
            true,
            Some("adapter-a".to_string()),
            0.0,
            None,
            None,
        )
        .await
        .unwrap();

    let expected_before = HashMap::from([(String::from("adapter-a"), 1)]);
    assert_eq!(scheduler.get_active_lora_counts(), expected_before);

    scheduler.free("req-1").await.unwrap();

    let remaining = scheduler.get_active_lora_counts();
    assert!(remaining.is_empty());

    cancel_token.cancel();
}
#[tokio::test]
async
fn
test_get_potential_loads_matches_slots
()
{
let
mut
workers
=
HashMap
::
new
();
workers
.insert
(
0
,
SimpleWorkerConfig
{
max_num_batched_tokens
:
Some
(
256
),
..
Default
::
default
()
},
);
workers
.insert
(
1
,
SimpleWorkerConfig
{
max_num_batched_tokens
:
Some
(
256
),
..
Default
::
default
()
},
);
let
(
scheduler
,
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
);
let
token_seq
=
vec!
[
11
,
22
,
33
,
44
];
let
overlaps
=
OverlapScores
::
default
();
let
(
decode_blocks
,
prefill_tokens
)
=
slots
.potential_blocks_and_tokens
(
Some
(
&
token_seq
),
128
,
overlaps
.clone
());
let
mut
expected
:
Vec
<
_
>
=
decode_blocks
.keys
()
.map
(|
worker
|
PotentialLoad
{
worker_id
:
worker
.worker_id
,
dp_rank
:
worker
.dp_rank
,
potential_prefill_tokens
:
prefill_tokens
.get
(
worker
)
.copied
()
.unwrap_or
(
128
),
potential_decode_blocks
:
decode_blocks
.get
(
worker
)
.copied
()
.unwrap_or
(
0
),
})
.collect
();
expected
.sort_by_key
(|
load
|
(
load
.worker_id
,
load
.dp_rank
));
let
mut
actual
=
scheduler
.get_potential_loads
(
Some
(
token_seq
),
128
,
overlaps
);
actual
.sort_by_key
(|
load
|
(
load
.worker_id
,
load
.dp_rank
));
assert_eq!
(
actual
.len
(),
expected
.len
());
for
(
actual
,
expected
)
in
actual
.iter
()
.zip
(
expected
.iter
())
{
assert_eq!
(
actual
.worker_id
,
expected
.worker_id
);
assert_eq!
(
actual
.dp_rank
,
expected
.dp_rank
);
assert_eq!
(
actual
.potential_prefill_tokens
,
expected
.potential_prefill_tokens
);
assert_eq!
(
actual
.potential_decode_blocks
,
expected
.potential_decode_blocks
);
}
cancel_token
.cancel
();
}
/// Registering a worker id that has no config (with config monitoring off)
/// yields exactly one potential load for that worker, at DP rank 0.
#[tokio::test]
async fn test_register_workers_uses_default_dp_fallback() {
    let (scheduler, _slots, _cfg_tx, cancel_token) =
        make_scheduler(HashMap::new(), None, false);

    let worker_ids = HashSet::from([42]);
    scheduler.register_workers(&worker_ids);

    let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default());
    assert_eq!(loads.len(), 1);

    let only = &loads[0];
    assert_eq!(only.worker_id, 42);
    assert_eq!(only.dp_rank, 0);

    cancel_token.cancel();
}
/// Publishing a new worker-config map over the watch channel changes the
/// number of potential loads: one load initially, three after worker 0 gets
/// `data_parallel_size: 2` and worker 1 is added.
#[tokio::test]
async fn test_worker_watch_updates_slot_ranges() {
    let initial_workers = HashMap::from([(0, SimpleWorkerConfig::default())]);
    let (scheduler, _slots, cfg_tx, cancel_token) =
        make_scheduler(initial_workers, None, true);

    let load_count = || {
        scheduler
            .get_potential_loads(None, 64, OverlapScores::default())
            .len()
    };
    assert_eq!(load_count(), 1);

    let updated_workers = HashMap::from([
        (
            0,
            SimpleWorkerConfig {
                data_parallel_size: 2,
                ..Default::default()
            },
        ),
        (1, SimpleWorkerConfig::default()),
    ]);
    cfg_tx.send(updated_workers).unwrap();

    // Poll, yielding to other tasks between checks, until the update is
    // visible; give up after one second.
    let wait_for_update = async {
        while load_count() != 3 {
            tokio::task::yield_now().await;
        }
    };
    tokio::time::timeout(Duration::from_secs(1), wait_for_update)
        .await
        .unwrap();

    cancel_token.cancel();
}
}
lib/kv-router/src/scheduling/mod.rs
View file @
b7fe46b1
...
...
@@ -2,9 +2,11 @@
// SPDX-License-Identifier: Apache-2.0
pub
mod
config
;
mod
local
;
pub
mod
policy
;
pub
mod
queue
;
pub
mod
selector
;
mod
types
;
pub
use
local
::
LocalScheduler
;
pub
use
types
::
*
;
lib/kv-router/src/scheduling/policy.rs
View file @
b7fe46b1
...
...
@@ -43,6 +43,21 @@ impl SchedulingPolicy for FcfsPolicy {
}
}
/// Last-come-first-served ordering with an optional priority boost.
///
/// The queue key is `max(priority_jump, 0) + arrival_offset` (in seconds), so
/// a request that arrives later, or carries a larger `priority_jump`, gets a
/// higher key and is scheduled first.
///
/// Newer arrivals are deliberately favored under saturation; the policy is
/// mainly useful for policy comparison experiments.
pub struct LcfsPolicy;

impl SchedulingPolicy for LcfsPolicy {
    type Key = OrderedFloat<f64>;

    fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
        // Negative jumps are clamped to zero, so they never lower the key.
        let boost = request.priority_jump.max(0.0);
        let arrival_secs = arrival_offset.as_secs_f64();
        OrderedFloat(boost + arrival_secs)
    }
}
/// Weighted Shortest Processing Time (Smith's rule):
/// key = (1 + priority_jump) / new_tokens, where new_tokens estimates the
/// actual prefill cost by subtracting the max KV cache overlap from ISL.
...
...
@@ -73,6 +88,7 @@ impl SchedulingPolicy for WsptPolicy {
/// since the variant is fixed at queue construction time.
pub
enum
RouterSchedulingPolicy
{
/// Queue ordering delegated to [`FcfsPolicy`].
Fcfs
(
FcfsPolicy
),
/// Queue ordering delegated to [`LcfsPolicy`].
Lcfs
(
LcfsPolicy
),
/// Queue ordering delegated to [`WsptPolicy`].
Wspt
(
WsptPolicy
),
}
...
...
@@ -80,6 +96,7 @@ impl RouterSchedulingPolicy {
pub
fn
new
(
kind
:
RouterQueuePolicy
,
block_size
:
usize
)
->
Self
{
match
kind
{
RouterQueuePolicy
::
Fcfs
=>
Self
::
Fcfs
(
FcfsPolicy
),
RouterQueuePolicy
::
Lcfs
=>
Self
::
Lcfs
(
LcfsPolicy
),
RouterQueuePolicy
::
Wspt
=>
Self
::
Wspt
(
WsptPolicy
{
block_size
}),
}
}
...
...
@@ -91,6 +108,7 @@ impl SchedulingPolicy for RouterSchedulingPolicy {
/// Computes the queue key by delegating to whichever policy this variant
/// wraps.
fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
    match self {
        Self::Fcfs(fcfs) => fcfs.enqueue_key(arrival_offset, request),
        Self::Lcfs(lcfs) => lcfs.enqueue_key(arrival_offset, request),
        Self::Wspt(wspt) => wspt.enqueue_key(arrival_offset, request),
    }
}
...
...
@@ -178,6 +196,42 @@ mod tests {
assert
!
(
key_b
>
key_a
);
}
/// With equal priority, a later arrival offset yields a strictly higher
/// LCFS key.
#[test]
fn lcfs_later_arrival_scheduled_first() {
    let policy = LcfsPolicy;
    let req = request_with(512, 0.0, OverlapScores::default());

    let key_at = |secs| policy.enqueue_key(Duration::from_secs(secs), &req);
    let early = key_at(1);
    let late = key_at(10);

    assert!(late > early, "later arrival should have higher key");
}
/// At the same arrival offset, a request with a positive `priority_jump`
/// gets a strictly higher LCFS key than one without.
#[test]
fn lcfs_priority_jump_promotes() {
    let policy = LcfsPolicy;
    let arrival = Duration::from_secs(10);

    let normal = request_with(512, 0.0, OverlapScores::default());
    let boosted = request_with(512, 100.0, OverlapScores::default());

    let key_normal = policy.enqueue_key(arrival, &normal);
    let key_boosted = policy.enqueue_key(arrival, &boosted);

    assert!(
        key_boosted > key_normal,
        "priority_jump should produce a higher key"
    );
}
/// The enum wrapper preserves each policy's ordering: under FCFS the earlier
/// arrival has the higher key, under LCFS the later arrival does.
#[test]
fn router_scheduling_policy_matches_fcfs_and_lcfs_ordering() {
    let req = request_with(512, 0.0, OverlapScores::default());
    let (early, late) = (Duration::from_secs(1), Duration::from_secs(10));

    let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs, 16);
    let fcfs_early = fcfs.enqueue_key(early, &req);
    let fcfs_late = fcfs.enqueue_key(late, &req);
    assert!(fcfs_early > fcfs_late);

    let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs, 16);
    let lcfs_late = lcfs.enqueue_key(late, &req);
    let lcfs_early = lcfs.enqueue_key(early, &req);
    assert!(lcfs_late > lcfs_early);
}
// ---- WSPT policy tests ----
#[test]
...
...
lib/kv-router/src/scheduling/queue.rs
View file @
b7fe46b1
...
...
@@ -11,7 +11,7 @@ use tokio::sync::Mutex;
use
tokio
::
sync
::
watch
;
use
super
::
policy
::{
FcfsPolicy
,
SchedulingPolicy
};
use
super
::
selector
::
WorkerSelector
;
use
super
::
selector
::
{
Default
WorkerSelector
,
WorkerSelector
}
;
use
super
::
types
::{
SchedulingRequest
,
SchedulingResponse
};
use
crate
::
protocols
::{
WorkerConfigLike
,
WorkerId
,
WorkerWithDpRank
};
use
crate
::
sequences
::{
ActiveSequencesMultiWorker
,
SequencePublisher
,
SequenceRequest
};
...
...
@@ -53,6 +53,7 @@ pub struct SchedulerQueue<
P
:
SequencePublisher
,
C
:
WorkerConfigLike
,
S
:
SchedulingPolicy
=
FcfsPolicy
,
Sel
:
WorkerSelector
<
C
>
=
DefaultWorkerSelector
,
>
{
pending
:
Mutex
<
BinaryHeap
<
QueueEntry
<
S
::
Key
>>>
,
/// Number of requests currently parked in the pending queue.
...
...
@@ -65,19 +66,23 @@ pub struct SchedulerQueue<
/// Reference instant for computing arrival offsets.
start_time
:
Instant
,
block_size
:
u32
,
selector
:
Box
<
dyn
WorkerSelector
<
C
>
+
Send
+
Sync
>
,
selector
:
Sel
,
policy
:
S
,
}
impl
<
P
:
SequencePublisher
+
'static
,
C
:
WorkerConfigLike
,
S
:
SchedulingPolicy
>
SchedulerQueue
<
P
,
C
,
S
>
impl
<
P
:
SequencePublisher
+
'static
,
C
:
WorkerConfigLike
,
S
:
SchedulingPolicy
,
Sel
:
WorkerSelector
<
C
>
,
>
SchedulerQueue
<
P
,
C
,
S
,
Sel
>
{
pub
fn
new
(
slots
:
Arc
<
ActiveSequencesMultiWorker
<
P
>>
,
workers_with_configs
:
watch
::
Receiver
<
HashMap
<
WorkerId
,
C
>>
,
threshold_frac
:
Option
<
f64
>
,
block_size
:
u32
,
selector
:
Box
<
dyn
WorkerSelector
<
C
>
+
Send
+
Sync
>
,
selector
:
Sel
,
policy
:
S
,
)
->
Self
{
if
let
Some
(
frac
)
=
threshold_frac
{
...
...
@@ -341,7 +346,7 @@ mod tests {
}
let
(
cfg_tx
,
cfg_rx
)
=
watch
::
channel
(
configs
);
let
selector
=
Box
::
new
(
DefaultWorkerSelector
::
new
(
None
,
"test"
)
)
;
let
selector
=
DefaultWorkerSelector
::
new
(
None
,
"test"
);
let
queue
=
Arc
::
new
(
SchedulerQueue
::
new
(
Arc
::
clone
(
&
slots
),
cfg_rx
,
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment