Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
b7fe46b1
Unverified
Commit
b7fe46b1
authored
Mar 23, 2026
by
Yan Ru Pei
Committed by
GitHub
Mar 23, 2026
Browse files
feat(mocker): add multi-worker replay and router startup fixes (#7553)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
82794761
Changes
102
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3654 additions
and
1711 deletions
+3654
-1711
lib/bindings/python/rust/lib.rs
lib/bindings/python/rust/lib.rs
+5
-1
lib/bindings/python/rust/llm.rs
lib/bindings/python/rust/llm.rs
+1
-0
lib/bindings/python/rust/llm/entrypoint.rs
lib/bindings/python/rust/llm/entrypoint.rs
+21
-85
lib/bindings/python/rust/llm/replay.rs
lib/bindings/python/rust/llm/replay.rs
+563
-0
lib/bindings/python/src/dynamo/_core.pyi
lib/bindings/python/src/dynamo/_core.pyi
+138
-6
lib/bindings/python/src/dynamo/llm/__init__.py
lib/bindings/python/src/dynamo/llm/__init__.py
+25
-1
lib/bindings/python/src/dynamo/replay/__init__.py
lib/bindings/python/src/dynamo/replay/__init__.py
+6
-0
lib/bindings/python/src/dynamo/replay/__main__.py
lib/bindings/python/src/dynamo/replay/__main__.py
+7
-0
lib/bindings/python/src/dynamo/replay/api.py
lib/bindings/python/src/dynamo/replay/api.py
+59
-0
lib/bindings/python/src/dynamo/replay/main.py
lib/bindings/python/src/dynamo/replay/main.py
+94
-0
lib/bindings/python/tests/test_replay.py
lib/bindings/python/tests/test_replay.py
+421
-0
lib/kv-router/src/event_sink.rs
lib/kv-router/src/event_sink.rs
+0
-19
lib/kv-router/src/indexer/tests.rs
lib/kv-router/src/indexer/tests.rs
+1629
-1587
lib/kv-router/src/lib.rs
lib/kv-router/src/lib.rs
+3
-4
lib/kv-router/src/protocols.rs
lib/kv-router/src/protocols.rs
+8
-0
lib/kv-router/src/scheduling/config.rs
lib/kv-router/src/scheduling/config.rs
+54
-2
lib/kv-router/src/scheduling/local.rs
lib/kv-router/src/scheduling/local.rs
+553
-0
lib/kv-router/src/scheduling/mod.rs
lib/kv-router/src/scheduling/mod.rs
+2
-0
lib/kv-router/src/scheduling/policy.rs
lib/kv-router/src/scheduling/policy.rs
+54
-0
lib/kv-router/src/scheduling/queue.rs
lib/kv-router/src/scheduling/queue.rs
+11
-6
No files found.
lib/bindings/python/rust/lib.rs
View file @
b7fe46b1
...
...
@@ -149,8 +149,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m
.add_function
(
wrap_pyfunction!
(
fetch_model
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
run_kv_indexer
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
llm
::
entrypoint
::
make_engine
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
llm
::
replay
::
run_mocker_trace_replay
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
llm
::
entrypoint
::
run_mocker_trace_replay
,
llm
::
replay
::
run_mocker_
synthetic_
trace_replay
,
m
)
?
)
?
;
m
.add_function
(
wrap_pyfunction!
(
llm
::
entrypoint
::
run_input
,
m
)
?
)
?
;
...
...
@@ -165,6 +166,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m
.add_class
::
<
llm
::
entrypoint
::
EngineType
>
()
?
;
m
.add_class
::
<
llm
::
entrypoint
::
RouterConfig
>
()
?
;
m
.add_class
::
<
llm
::
entrypoint
::
KvRouterConfig
>
()
?
;
m
.add_class
::
<
llm
::
replay
::
ReasoningConfig
>
()
?
;
m
.add_class
::
<
llm
::
replay
::
SglangArgs
>
()
?
;
m
.add_class
::
<
llm
::
replay
::
MockEngineArgs
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
WorkerMetricsPublisher
>
()
?
;
m
.add_class
::
<
llm
::
model_card
::
ModelDeploymentCard
>
()
?
;
// Internal: only in _internal, not public API
m
.add_class
::
<
llm
::
local_model
::
ModelRuntimeConfig
>
()
?
;
...
...
lib/bindings/python/rust/llm.rs
View file @
b7fe46b1
...
...
@@ -31,3 +31,4 @@ pub mod local_model;
pub
mod
lora
;
pub
mod
model_card
;
pub
mod
preprocessor
;
pub
mod
replay
;
lib/bindings/python/rust/llm/entrypoint.rs
View file @
b7fe46b1
...
...
@@ -9,7 +9,6 @@ use std::sync::Arc;
use
pyo3
::{
exceptions
::
PyException
,
prelude
::
*
};
use
pyo3_async_runtimes
::
TaskLocals
;
use
pythonize
::
pythonize
;
use
dynamo_kv_router
::
config
::
KvRouterConfig
as
RsKvRouterConfig
;
use
dynamo_llm
::
discovery
::
LoadThresholdConfig
as
RsLoadThresholdConfig
;
...
...
@@ -25,7 +24,8 @@ use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingE
use
dynamo_mocker
::
common
::
perf_model
::
PerfModel
;
use
super
::
aic_callback
::
create_aic_callback
;
use
dynamo_mocker
::
common
::
protocols
::
MockEngineArgs
;
use
super
::
replay
::
MockEngineArgs
as
PyMockEngineArgs
;
use
dynamo_mocker
::
common
::
protocols
::
MockEngineArgs
as
RsMockEngineArgs
;
use
dynamo_runtime
::
discovery
::
ModelCardInstanceId
as
RsModelCardInstanceId
;
use
dynamo_runtime
::
protocols
::
EndpointId
;
...
...
@@ -58,7 +58,7 @@ impl KvRouterConfig {
#[pymethods]
impl
KvRouterConfig
{
#[new]
#[pyo3(signature
=
(overlap_score_weight=
1.0
,
router_temperature=
0.0
,
use_kv_events=
true
,
durable_kv_events=
false
,
router_replica_sync=
false
,
router_track_active_blocks=
true
,
router_track_output_blocks=
false
,
router_assume_kv_reuse=
true
,
router_snapshot_threshold=
1000000
,
router_reset_states=
false
,
router_ttl_secs=
120.0
,
router_max_tree_size=
1048576
,
router_prune_target_ratio=
0.8
,
router_queue_threshold=Some(
2
.0
),
router_event_threads=
4
,
router_enable_cache_control=
false
,
router_queue_policy=
"fcfs"
,
remote_indexer_component=None))]
#[pyo3(signature
=
(overlap_score_weight=
1.0
,
router_temperature=
0.0
,
use_kv_events=
true
,
durable_kv_events=
false
,
router_replica_sync=
false
,
router_track_active_blocks=
true
,
router_track_output_blocks=
false
,
router_assume_kv_reuse=
true
,
router_snapshot_threshold=
1000000
,
router_reset_states=
false
,
router_ttl_secs=
120.0
,
router_max_tree_size=
1048576
,
router_prune_target_ratio=
0.8
,
router_queue_threshold=Some(
4
.0
),
router_event_threads=
4
,
router_enable_cache_control=
false
,
min_initial_workers=
1
,
router_queue_policy=
"fcfs"
,
remote_indexer_component=None))]
#[allow(clippy::too_many_arguments)]
fn
new
(
overlap_score_weight
:
f64
,
...
...
@@ -77,6 +77,7 @@ impl KvRouterConfig {
router_queue_threshold
:
Option
<
f64
>
,
router_event_threads
:
u32
,
router_enable_cache_control
:
bool
,
min_initial_workers
:
usize
,
router_queue_policy
:
&
str
,
remote_indexer_component
:
Option
<
String
>
,
)
->
Self
{
...
...
@@ -99,6 +100,7 @@ impl KvRouterConfig {
router_event_threads
,
router_enable_cache_control
,
skip_initial_worker_wait
:
false
,
min_initial_workers
,
router_queue_policy
:
router_queue_policy
.parse
()
.unwrap_or_else
(|
_
|
{
panic!
(
"invalid router_queue_policy: {router_queue_policy:?}"
)
}),
...
...
@@ -106,6 +108,13 @@ impl KvRouterConfig {
},
}
}
/// Builds a `KvRouterConfig` from a JSON document.
///
/// The text is deserialized with `serde_json` into the Rust router config; a
/// parse failure is surfaced to Python as an exception carrying the serde error.
#[staticmethod]
fn from_json(config_json: &str) -> PyResult<Self> {
    match serde_json::from_str::<RsKvRouterConfig>(config_json) {
        Ok(inner) => Ok(KvRouterConfig { inner }),
        Err(e) => Err(PyException::new_err(format!(
            "Failed to parse KvRouterConfig JSON: {e}"
        ))),
    }
}
}
#[pyclass]
...
...
@@ -196,6 +205,7 @@ pub(crate) struct EntrypointArgs {
tls_cert_path
:
Option
<
PathBuf
>
,
tls_key_path
:
Option
<
PathBuf
>
,
extra_engine_args
:
Option
<
PathBuf
>
,
mocker_engine_args
:
Option
<
PyMockEngineArgs
>
,
runtime_config
:
Option
<
ModelRuntimeConfig
>
,
namespace
:
Option
<
String
>
,
namespace_prefix
:
Option
<
String
>
,
...
...
@@ -208,7 +218,7 @@ pub(crate) struct EntrypointArgs {
impl
EntrypointArgs
{
#[allow(clippy::too_many_arguments)]
#[new]
#[pyo3(signature
=
(engine_type,
model_path=None,
model_name=None,
endpoint_id=None,
context_length=None,
template_file=None,
router_config=None,
kv_cache_block_size=None,
http_host=None,
http_port=None,
http_metrics_port=None,
tls_cert_path=None,
tls_key_path=None,
extra_engine_args=None,
runtime_config=None,
namespace=None,
namespace_prefix=None,
is_prefill=
false
,
migration_limit=
0
,
chat_engine_factory=None))]
#[pyo3(signature
=
(engine_type,
model_path=None,
model_name=None,
endpoint_id=None,
context_length=None,
template_file=None,
router_config=None,
kv_cache_block_size=None,
http_host=None,
http_port=None,
http_metrics_port=None,
tls_cert_path=None,
tls_key_path=None,
extra_engine_args=None,
mocker_engine_args=None,
runtime_config=None,
namespace=None,
namespace_prefix=None,
is_prefill=
false
,
migration_limit=
0
,
chat_engine_factory=None))]
pub
fn
new
(
py
:
Python
<
'_
>
,
engine_type
:
EngineType
,
...
...
@@ -225,6 +235,7 @@ impl EntrypointArgs {
tls_cert_path
:
Option
<
PathBuf
>
,
tls_key_path
:
Option
<
PathBuf
>
,
extra_engine_args
:
Option
<
PathBuf
>
,
mocker_engine_args
:
Option
<
PyMockEngineArgs
>
,
runtime_config
:
Option
<
ModelRuntimeConfig
>
,
namespace
:
Option
<
String
>
,
namespace_prefix
:
Option
<
String
>
,
...
...
@@ -272,6 +283,7 @@ impl EntrypointArgs {
tls_cert_path
,
tls_key_path
,
extra_engine_args
,
mocker_engine_args
,
runtime_config
,
namespace
,
namespace_prefix
,
...
...
@@ -419,8 +431,10 @@ async fn select_engine(
}
}
EngineType
::
Mocker
=>
{
let
mut
mocker_args
=
if
let
Some
(
extra_args_path
)
=
args
.extra_engine_args
{
MockEngineArgs
::
from_json_file
(
&
extra_args_path
)
.map_err
(|
e
|
{
let
mut
mocker_args
=
if
let
Some
(
mocker_engine_args
)
=
args
.mocker_engine_args
{
mocker_engine_args
.inner
()
}
else
if
let
Some
(
extra_args_path
)
=
args
.extra_engine_args
{
RsMockEngineArgs
::
from_json_file
(
&
extra_args_path
)
.map_err
(|
e
|
{
anyhow
::
anyhow!
(
"Failed to load mocker args from {:?}: {}"
,
extra_args_path
,
...
...
@@ -431,7 +445,7 @@ async fn select_engine(
tracing
::
warn!
(
"No extra_engine_args specified for mocker engine. Using default mocker args."
);
MockEngineArgs
::
default
()
Rs
MockEngineArgs
::
default
()
};
// If aic_backend is set, create Python AIC callback and override perf_model
...
...
@@ -503,84 +517,6 @@ pub fn run_input<'p>(
})
}
#[pyfunction]
#[pyo3(signature
=
(trace_file,
extra_engine_args=None,
num_workers=
1
,
replay_concurrency=None))]
pub
fn
run_mocker_trace_replay
(
py
:
Python
<
'_
>
,
trace_file
:
PathBuf
,
extra_engine_args
:
Option
<
PathBuf
>
,
num_workers
:
usize
,
replay_concurrency
:
Option
<
isize
>
,
)
->
PyResult
<
PyObject
>
{
// Load args before allow_threads so we can use the GIL for AIC callback creation.
let
mut
args
=
if
let
Some
(
ref
extra_args_path
)
=
extra_engine_args
{
MockEngineArgs
::
from_json_file
(
extra_args_path
)
.map_err
(|
e
|
{
PyException
::
new_err
(
format!
(
"Failed to load mocker args from {:?}: {}"
,
extra_args_path
,
e
))
})
?
}
else
{
MockEngineArgs
::
default
()
};
// Create AIC callback if requested (requires GIL, must be done before allow_threads).
if
let
Some
(
ref
backend_name
)
=
args
.aic_backend
.clone
()
{
let
backend
=
backend_name
.clone
();
let
system
=
args
.aic_system
.as_deref
()
.unwrap_or
(
"h200_sxm"
)
.to_string
();
let
model_name
=
args
.aic_model_path
.clone
()
.ok_or_else
(||
PyException
::
new_err
(
"--aic-perf-model requires --model-path"
))
?
;
let
backend_version
=
args
.aic_backend_version
.clone
();
let
tp_size
=
args
.aic_tp_size
.unwrap_or
(
1
);
let
callback
=
create_aic_callback
(
py
,
&
backend
,
&
system
,
&
model_name
,
tp_size
,
backend_version
.as_deref
(),
)
.map_err
(|
e
|
{
PyException
::
new_err
(
format!
(
"Failed to create AIC callback (--aic-perf-model was requested): {}"
,
e
))
})
?
;
tracing
::
info!
(
"AIC perf model: backend={}, gpu={}, model={}, version={:?}"
,
backend
,
system
,
model_name
,
backend_version
);
args
.perf_model
=
Arc
::
new
(
PerfModel
::
from_aic_callback
(
callback
));
}
let
report
=
py
.allow_threads
(
move
||
{
let
replay_concurrency
=
replay_concurrency
.map
(
usize
::
try_from
)
.transpose
()
.map_err
(|
_
|
anyhow
::
anyhow!
(
"replay_concurrency must be at least 1"
))
?
;
if
let
Some
(
max_in_flight
)
=
replay_concurrency
{
dynamo_mocker
::
simulation
::
simulate_concurrency_file
(
args
,
&
trace_file
,
max_in_flight
,
num_workers
,
)
}
else
{
dynamo_mocker
::
simulation
::
simulate_trace_file
(
args
,
&
trace_file
,
num_workers
)
}
});
let
report
=
report
.map_err
(
to_pyerr
)
?
;
pythonize
(
py
,
&
report
)
.map_err
(
to_pyerr
)
.map
(|
obj
|
obj
.unbind
())
}
pub
fn
to_pyerr
<
E
>
(
err
:
E
)
->
PyErr
where
E
:
Display
,
...
...
lib/bindings/python/rust/llm/replay.rs
0 → 100644
View file @
b7fe46b1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
path
::
PathBuf
;
use
std
::
sync
::
Arc
;
use
dynamo_mocker
::
common
::
perf_model
::
PerfModel
;
use
dynamo_mocker
::
common
::
protocols
::{
DirectRequest
,
EngineType
as
RsMockerEngineType
,
MockEngineArgs
as
RsMockEngineArgs
,
PreemptionMode
as
RsPreemptionMode
,
ReasoningConfig
as
RsReasoningConfig
,
SglangArgs
as
RsSglangArgs
,
WorkerType
as
RsWorkerType
,
};
use
pyo3
::{
exceptions
::
PyException
,
prelude
::
*
};
use
pythonize
::
pythonize
;
use
uuid
::
Uuid
;
use
super
::
aic_callback
::
create_aic_callback
;
use
super
::
entrypoint
::{
KvRouterConfig
,
to_pyerr
};
/// Translates the Python-facing engine name into the mocker's engine enum.
///
/// Only `"vllm"` and `"sglang"` are accepted; any other string produces a
/// Python exception that echoes the rejected value.
fn parse_mocker_engine_type(engine_type: &str) -> PyResult<RsMockerEngineType> {
    let parsed = match engine_type {
        "vllm" => RsMockerEngineType::Vllm,
        "sglang" => RsMockerEngineType::Sglang,
        other => {
            return Err(PyException::new_err(format!(
                "engine_type must be either 'vllm' or 'sglang', got '{other}'"
            )));
        }
    };
    Ok(parsed)
}
/// Translates the Python-facing worker role into the mocker's worker enum.
///
/// Accepts `"aggregated"`, `"prefill"` or `"decode"`; anything else produces a
/// Python exception that echoes the rejected value.
fn parse_worker_type(worker_type: &str) -> PyResult<RsWorkerType> {
    let parsed = match worker_type {
        "aggregated" => RsWorkerType::Aggregated,
        "prefill" => RsWorkerType::Prefill,
        "decode" => RsWorkerType::Decode,
        other => {
            return Err(PyException::new_err(format!(
                "worker_type must be one of 'aggregated', 'prefill', or 'decode', got '{other}'"
            )));
        }
    };
    Ok(parsed)
}
/// Translates the Python-facing preemption mode into the mocker's enum.
///
/// Accepts `"lifo"` or `"fifo"`; anything else produces a Python exception
/// that echoes the rejected value.
fn parse_preemption_mode(preemption_mode: &str) -> PyResult<RsPreemptionMode> {
    let parsed = match preemption_mode {
        "lifo" => RsPreemptionMode::Lifo,
        "fifo" => RsPreemptionMode::Fifo,
        other => {
            return Err(PyException::new_err(format!(
                "preemption_mode must be either 'lifo' or 'fifo', got '{other}'"
            )));
        }
    };
    Ok(parsed)
}
/// Python-visible wrapper around the mocker's reasoning configuration.
#[pyclass]
#[derive(Clone, Debug)]
pub struct ReasoningConfig {
    inner: RsReasoningConfig,
}

impl ReasoningConfig {
    /// Returns an owned copy of the wrapped Rust configuration.
    pub fn inner(&self) -> RsReasoningConfig {
        let copy = self.inner.clone();
        copy
    }
}

#[pymethods]
impl ReasoningConfig {
    /// Stores the three values verbatim in the wrapped Rust struct; no
    /// validation is performed here.
    #[new]
    fn new(
        start_thinking_token_id: u32,
        end_thinking_token_id: u32,
        thinking_ratio: f64,
    ) -> PyResult<Self> {
        Ok(Self {
            inner: RsReasoningConfig {
                start_thinking_token_id,
                end_thinking_token_id,
                thinking_ratio,
            },
        })
    }
}
/// Python-visible wrapper around the mocker's SGLang-specific arguments.
#[pyclass]
#[derive(Clone, Debug, Default)]
pub struct SglangArgs {
    inner: RsSglangArgs,
}

impl SglangArgs {
    /// Returns an owned copy of the wrapped Rust arguments.
    pub fn inner(&self) -> RsSglangArgs {
        let copy = self.inner.clone();
        copy
    }
}

#[pymethods]
impl SglangArgs {
    /// Every option defaults to `None` and is stored verbatim in the wrapped
    /// Rust struct; no validation is performed here.
    #[new]
    #[pyo3(signature = (schedule_policy=None, page_size=None, max_prefill_tokens=None, chunked_prefill_size=None, clip_max_new_tokens=None, schedule_conservativeness=None))]
    fn new(
        schedule_policy: Option<String>,
        page_size: Option<usize>,
        max_prefill_tokens: Option<usize>,
        chunked_prefill_size: Option<usize>,
        clip_max_new_tokens: Option<usize>,
        schedule_conservativeness: Option<f64>,
    ) -> PyResult<Self> {
        Ok(Self {
            inner: RsSglangArgs {
                schedule_policy,
                page_size,
                max_prefill_tokens,
                chunked_prefill_size,
                clip_max_new_tokens,
                schedule_conservativeness,
            },
        })
    }
}
/// Python-visible wrapper around the mocker's engine arguments.
#[pyclass]
#[derive(Clone, Debug, Default)]
pub struct MockEngineArgs {
    inner: RsMockEngineArgs,
}

impl MockEngineArgs {
    /// Returns an owned copy of the wrapped Rust arguments.
    pub fn inner(&self) -> RsMockEngineArgs {
        let copy = self.inner.clone();
        copy
    }
}
#[pymethods]
impl
MockEngineArgs
{
#[new]
#[pyo3(signature
=
(engine_type=
"vllm"
,
num_gpu_blocks=
16384
,
block_size=
0
,
max_num_seqs=Some(
256
),
max_num_batched_tokens=Some(
8192
),
enable_prefix_caching=
true
,
enable_chunked_prefill=
true
,
speedup_ratio=
1.0
,
decode_speedup_ratio=
1.0
,
dp_size=
1
,
startup_time=None,
worker_type=
"aggregated"
,
aic_backend=None,
aic_system=None,
aic_backend_version=None,
aic_tp_size=None,
aic_model_path=None,
enable_local_indexer=
false
,
bootstrap_port=None,
kv_bytes_per_token=None,
kv_transfer_bandwidth=None,
reasoning=None,
zmq_kv_events_port=None,
zmq_replay_port=None,
preemption_mode=
"lifo"
,
router_queue_policy=None,
sglang=None))]
#[allow(clippy::too_many_arguments)]
fn
new
(
engine_type
:
&
str
,
num_gpu_blocks
:
usize
,
block_size
:
usize
,
max_num_seqs
:
Option
<
usize
>
,
max_num_batched_tokens
:
Option
<
usize
>
,
enable_prefix_caching
:
bool
,
enable_chunked_prefill
:
bool
,
speedup_ratio
:
f64
,
decode_speedup_ratio
:
f64
,
dp_size
:
u32
,
startup_time
:
Option
<
f64
>
,
worker_type
:
&
str
,
aic_backend
:
Option
<
String
>
,
aic_system
:
Option
<
String
>
,
aic_backend_version
:
Option
<
String
>
,
aic_tp_size
:
Option
<
usize
>
,
aic_model_path
:
Option
<
String
>
,
enable_local_indexer
:
bool
,
bootstrap_port
:
Option
<
u16
>
,
kv_bytes_per_token
:
Option
<
usize
>
,
kv_transfer_bandwidth
:
Option
<
f64
>
,
reasoning
:
Option
<
ReasoningConfig
>
,
zmq_kv_events_port
:
Option
<
u16
>
,
zmq_replay_port
:
Option
<
u16
>
,
preemption_mode
:
&
str
,
router_queue_policy
:
Option
<&
str
>
,
sglang
:
Option
<
SglangArgs
>
,
)
->
PyResult
<
Self
>
{
let
engine_type
=
parse_mocker_engine_type
(
engine_type
)
?
;
let
worker_type
=
parse_worker_type
(
worker_type
)
?
;
let
preemption_mode
=
parse_preemption_mode
(
preemption_mode
)
?
;
let
router_queue_policy
=
router_queue_policy
.map
(|
value
|
{
value
.parse
()
.map_err
(|
e
:
String
|
{
PyException
::
new_err
(
format!
(
"invalid router_queue_policy {value:?}: {e}"
))
})
})
.transpose
()
?
;
let
inner
=
RsMockEngineArgs
::
builder
()
.engine_type
(
engine_type
)
.num_gpu_blocks
(
num_gpu_blocks
)
.block_size
(
block_size
)
.max_num_seqs
(
max_num_seqs
)
.max_num_batched_tokens
(
max_num_batched_tokens
)
.enable_prefix_caching
(
enable_prefix_caching
)
.enable_chunked_prefill
(
enable_chunked_prefill
)
.speedup_ratio
(
speedup_ratio
)
.decode_speedup_ratio
(
decode_speedup_ratio
)
.dp_size
(
dp_size
)
.startup_time
(
startup_time
)
.worker_type
(
worker_type
)
.aic_backend
(
aic_backend
)
.aic_system
(
aic_system
)
.aic_backend_version
(
aic_backend_version
)
.aic_tp_size
(
aic_tp_size
)
.aic_model_path
(
aic_model_path
)
.enable_local_indexer
(
enable_local_indexer
)
.bootstrap_port
(
bootstrap_port
)
.kv_bytes_per_token
(
kv_bytes_per_token
)
.kv_transfer_bandwidth
(
kv_transfer_bandwidth
)
.reasoning
(
reasoning
.map
(|
config
|
config
.inner
()))
.zmq_kv_events_port
(
zmq_kv_events_port
)
.zmq_replay_port
(
zmq_replay_port
)
.preemption_mode
(
preemption_mode
)
.router_queue_policy
(
router_queue_policy
)
.sglang
(
sglang
.map
(|
config
|
config
.inner
()))
.build
()
.map_err
(|
e
|
PyException
::
new_err
(
format!
(
"Failed to build MockEngineArgs: {e}"
)))
?
.normalized
()
.map_err
(|
e
|
{
PyException
::
new_err
(
format!
(
"Failed to normalize MockEngineArgs: {e}"
))
})
?
;
Ok
(
Self
{
inner
})
}
#[staticmethod]
fn
from_json
(
config_json
:
&
str
)
->
PyResult
<
Self
>
{
RsMockEngineArgs
::
from_json_str
(
config_json
)
.map
(|
inner
|
Self
{
inner
})
.map_err
(|
e
|
PyException
::
new_err
(
format!
(
"Failed to parse MockEngineArgs JSON: {e}"
)))
}
#[getter]
fn
block_size
(
&
self
)
->
usize
{
self
.inner.block_size
}
#[getter]
fn
num_gpu_blocks
(
&
self
)
->
usize
{
self
.inner.num_gpu_blocks
}
#[getter]
fn
max_num_seqs
(
&
self
)
->
Option
<
usize
>
{
self
.inner.max_num_seqs
}
#[getter]
fn
max_num_batched_tokens
(
&
self
)
->
Option
<
usize
>
{
self
.inner.max_num_batched_tokens
}
#[getter]
fn
enable_local_indexer
(
&
self
)
->
bool
{
self
.inner.enable_local_indexer
}
#[getter]
fn
dp_size
(
&
self
)
->
u32
{
self
.inner.dp_size
}
#[getter]
fn
bootstrap_port
(
&
self
)
->
Option
<
u16
>
{
self
.inner.bootstrap_port
}
fn
is_prefill
(
&
self
)
->
bool
{
self
.inner
.is_prefill
()
}
fn
is_decode
(
&
self
)
->
bool
{
self
.inner
.is_decode
()
}
#[pyo3(signature
=
(bootstrap_port=None,
zmq_kv_events_port=None,
zmq_replay_port=None,
kv_bytes_per_token=None))]
fn
with_overrides
(
&
self
,
bootstrap_port
:
Option
<
u16
>
,
zmq_kv_events_port
:
Option
<
u16
>
,
zmq_replay_port
:
Option
<
u16
>
,
kv_bytes_per_token
:
Option
<
usize
>
,
)
->
PyResult
<
Self
>
{
let
mut
inner
=
self
.inner
.clone
();
if
let
Some
(
port
)
=
bootstrap_port
{
inner
.bootstrap_port
=
Some
(
port
);
}
if
let
Some
(
port
)
=
zmq_kv_events_port
{
inner
.zmq_kv_events_port
=
Some
(
port
);
}
if
let
Some
(
port
)
=
zmq_replay_port
{
inner
.zmq_replay_port
=
Some
(
port
);
}
if
let
Some
(
bytes_per_token
)
=
kv_bytes_per_token
{
inner
.kv_bytes_per_token
=
Some
(
bytes_per_token
);
}
inner
.normalized
()
.map
(|
inner
|
Self
{
inner
})
.map_err
(|
e
|
{
PyException
::
new_err
(
format!
(
"Failed to normalize MockEngineArgs overrides: {e}"
))
})
}
}
/// Replays a trace file through the mocker and returns the report as a Python object.
///
/// The engine args and router mode are resolved while the GIL is held; the
/// simulation itself runs inside `allow_threads`. The simulator entry point is
/// chosen from `replay_mode` (`"offline"` or `"online"`) and from whether
/// `replay_concurrency` was supplied. `arrival_speedup_ratio` is only forwarded
/// on the paths without a concurrency limit.
///
/// # Errors
/// Raises a Python exception when the engine args or `router_mode` are invalid,
/// when `replay_concurrency` is smaller than 1, when `replay_mode` is not one
/// of the two accepted values, or when the simulation or the conversion of the
/// report fails.
#[pyfunction]
#[pyo3(signature = (trace_file, extra_engine_args=None, router_config=None, num_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0))]
#[allow(clippy::too_many_arguments)]
pub fn run_mocker_trace_replay(
    py: Python<'_>,
    trace_file: PathBuf,
    extra_engine_args: Option<MockEngineArgs>,
    router_config: Option<KvRouterConfig>,
    num_workers: usize,
    replay_concurrency: Option<isize>,
    replay_mode: &str,
    router_mode: &str,
    arrival_speedup_ratio: f64,
) -> PyResult<PyObject> {
    let args = load_replay_mocker_args(py, extra_engine_args)?;
    let router_config = load_replay_router_config(router_config);
    // The closure below must own its captures, so take an owned copy of the mode.
    let replay_mode = replay_mode.to_owned();
    let router_mode = parse_replay_router_mode(router_mode)?;

    let report = py.allow_threads(move || {
        let replay_concurrency = parse_replay_concurrency(replay_concurrency)?;
        // Exactly one arm runs, so `args` and `router_config` are moved into
        // the chosen simulator call without cloning.
        match (replay_mode.as_str(), replay_concurrency) {
            ("offline", Some(max_in_flight)) => {
                dynamo_mocker::replay::simulate_concurrency_file_with_router_mode(
                    args,
                    router_config,
                    &trace_file,
                    max_in_flight,
                    num_workers,
                    router_mode,
                )
            }
            ("offline", None) => dynamo_mocker::replay::simulate_trace_file_with_router_mode(
                args,
                router_config,
                &trace_file,
                num_workers,
                arrival_speedup_ratio,
                router_mode,
            ),
            ("online", Some(max_in_flight)) => {
                dynamo_mocker::replay::simulate_concurrency_live_file_with_router_mode(
                    args,
                    router_config,
                    &trace_file,
                    max_in_flight,
                    num_workers,
                    router_mode,
                )
            }
            ("online", None) => dynamo_mocker::replay::simulate_trace_live_file_with_router_mode(
                args,
                router_config,
                &trace_file,
                num_workers,
                arrival_speedup_ratio,
                router_mode,
            ),
            (other, _) => anyhow::bail!(
                "replay_mode must be either 'offline' or 'online', got '{}'",
                other
            ),
        }
    });

    let report = report.map_err(to_pyerr)?;
    pythonize(py, &report)
        .map_err(to_pyerr)
        .map(|obj| obj.unbind())
}
/// Replays a synthetic workload through the mocker and returns the report as a Python object.
///
/// `request_count` requests are generated, each with `input_tokens` prompt
/// tokens and `output_tokens` as its output limit. Arrival timestamps, spaced
/// `arrival_interval_ms` apart, are attached only when no `replay_concurrency`
/// is supplied. The simulator entry point is chosen from `replay_mode`
/// (`"offline"` or `"online"`) and from whether `replay_concurrency` was
/// supplied; `arrival_speedup_ratio` is only forwarded on the paths without a
/// concurrency limit.
///
/// # Errors
/// Raises a Python exception when the engine args or `router_mode` are invalid,
/// when `replay_concurrency` is smaller than 1, when the synthetic workload
/// parameters are rejected, when `replay_mode` is not one of the two accepted
/// values, or when the simulation or the conversion of the report fails.
#[pyfunction]
#[pyo3(signature = (input_tokens, output_tokens, request_count, extra_engine_args=None, router_config=None, num_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0, arrival_interval_ms=1.0))]
#[allow(clippy::too_many_arguments)]
pub fn run_mocker_synthetic_trace_replay(
    py: Python<'_>,
    input_tokens: usize,
    output_tokens: usize,
    request_count: usize,
    extra_engine_args: Option<MockEngineArgs>,
    router_config: Option<KvRouterConfig>,
    num_workers: usize,
    replay_concurrency: Option<isize>,
    replay_mode: &str,
    router_mode: &str,
    arrival_speedup_ratio: f64,
    arrival_interval_ms: f64,
) -> PyResult<PyObject> {
    let args = load_replay_mocker_args(py, extra_engine_args)?;
    let router_config = load_replay_router_config(router_config);
    // The closure below must own its captures, so take an owned copy of the mode.
    let replay_mode = replay_mode.to_owned();
    let router_mode = parse_replay_router_mode(router_mode)?;

    let report = py.allow_threads(move || {
        let replay_concurrency = parse_replay_concurrency(replay_concurrency)?;
        let requests = build_synthetic_requests(
            input_tokens,
            output_tokens,
            request_count,
            arrival_interval_ms,
            // Timestamps are only generated when no concurrency limit is set.
            replay_concurrency.is_none(),
        )?;
        // Exactly one arm runs, so `args`, `router_config` and `requests` are
        // moved into the chosen simulator call without cloning.
        match (replay_mode.as_str(), replay_concurrency) {
            ("offline", Some(max_in_flight)) => {
                dynamo_mocker::replay::simulate_concurrency_requests_with_router_mode(
                    args,
                    router_config,
                    requests,
                    max_in_flight,
                    num_workers,
                    router_mode,
                )
            }
            ("offline", None) => dynamo_mocker::replay::simulate_trace_requests_with_router_mode(
                args,
                router_config,
                requests,
                num_workers,
                arrival_speedup_ratio,
                router_mode,
            ),
            ("online", Some(max_in_flight)) => {
                dynamo_mocker::replay::simulate_concurrency_live_requests_with_router_mode(
                    args,
                    router_config,
                    requests,
                    max_in_flight,
                    num_workers,
                    router_mode,
                )
            }
            ("online", None) => {
                dynamo_mocker::replay::simulate_trace_live_requests_with_router_mode(
                    args,
                    router_config,
                    requests,
                    num_workers,
                    arrival_speedup_ratio,
                    router_mode,
                )
            }
            (other, _) => anyhow::bail!(
                "replay_mode must be either 'offline' or 'online', got '{}'",
                other
            ),
        }
    });

    let report = report.map_err(to_pyerr)?;
    pythonize(py, &report)
        .map_err(to_pyerr)
        .map(|obj| obj.unbind())
}
fn
load_replay_mocker_args
(
py
:
Python
<
'_
>
,
extra_engine_args
:
Option
<
MockEngineArgs
>
,
)
->
PyResult
<
RsMockEngineArgs
>
{
let
mut
args
=
match
extra_engine_args
{
Some
(
extra_args
)
=>
extra_args
.inner
(),
None
=>
RsMockEngineArgs
::
default
(),
};
if
let
Some
(
ref
backend_name
)
=
args
.aic_backend
.clone
()
{
let
backend
=
backend_name
.clone
();
let
system
=
args
.aic_system
.as_deref
()
.unwrap_or
(
"h200_sxm"
)
.to_string
();
let
model_name
=
args
.aic_model_path
.clone
()
.ok_or_else
(||
PyException
::
new_err
(
"--aic-perf-model requires --model-path"
))
?
;
let
backend_version
=
args
.aic_backend_version
.clone
();
let
tp_size
=
args
.aic_tp_size
.unwrap_or
(
1
);
let
callback
=
create_aic_callback
(
py
,
&
backend
,
&
system
,
&
model_name
,
tp_size
,
backend_version
.as_deref
(),
)
.map_err
(|
e
|
{
PyException
::
new_err
(
format!
(
"Failed to create AIC callback (--aic-perf-model was requested): {}"
,
e
))
})
?
;
tracing
::
info!
(
"AIC perf model: backend={}, gpu={}, model={}, version={:?}"
,
backend
,
system
,
model_name
,
backend_version
);
args
.perf_model
=
Arc
::
new
(
PerfModel
::
from_aic_callback
(
callback
));
}
Ok
(
args
)
}
/// Unwraps the optional Python router config into the Rust router config.
///
/// `None` is passed through unchanged.
fn load_replay_router_config(
    router_config: Option<KvRouterConfig>,
) -> Option<dynamo_kv_router::config::KvRouterConfig> {
    let unwrapped = router_config.map(|py_config| py_config.inner());
    unwrapped
}
/// Translates the Python-facing router mode into the replay router enum.
///
/// Accepts `"round_robin"` or `"kv_router"`; anything else produces a Python
/// exception that echoes the rejected value.
fn parse_replay_router_mode(
    router_mode: &str,
) -> PyResult<dynamo_mocker::replay::ReplayRouterMode> {
    use dynamo_mocker::replay::ReplayRouterMode;

    let parsed = match router_mode {
        "round_robin" => ReplayRouterMode::RoundRobin,
        "kv_router" => ReplayRouterMode::KvRouter,
        other => {
            return Err(PyException::new_err(format!(
                "router_mode must be either 'round_robin' or 'kv_router', got '{}'",
                other
            )));
        }
    };
    Ok(parsed)
}
/// Validates the optional in-flight request limit.
///
/// `None` is passed through. A value of at least 1 is returned as `usize`;
/// zero and negative values are rejected.
fn parse_replay_concurrency(replay_concurrency: Option<isize>) -> anyhow::Result<Option<usize>> {
    match replay_concurrency {
        None => Ok(None),
        // `try_from` filters out negatives; the guard additionally rejects zero.
        Some(value) => match usize::try_from(value) {
            Ok(limit) if limit >= 1 => Ok(Some(limit)),
            _ => anyhow::bail!("replay_concurrency must be at least 1"),
        },
    }
}
/// Generates `request_count` synthetic requests for a replay run.
///
/// Each request carries `input_tokens` token ids produced by
/// `synthetic_token_id`, uses `output_tokens` as its output limit, targets
/// `dp_rank` 0 and gets the UUID `request_idx + 1`. When
/// `include_arrival_timestamps` is set, request `i` arrives at
/// `i * arrival_interval_ms`; otherwise no timestamp is attached.
///
/// # Errors
/// Fails when any of the three counts is zero, or when `arrival_interval_ms`
/// is negative or not finite. The checks run in that order.
fn build_synthetic_requests(
    input_tokens: usize,
    output_tokens: usize,
    request_count: usize,
    arrival_interval_ms: f64,
    include_arrival_timestamps: bool,
) -> anyhow::Result<Vec<DirectRequest>> {
    anyhow::ensure!(input_tokens != 0, "input_tokens must be at least 1");
    anyhow::ensure!(output_tokens != 0, "output_tokens must be at least 1");
    anyhow::ensure!(request_count != 0, "request_count must be at least 1");
    anyhow::ensure!(
        arrival_interval_ms.is_finite() && arrival_interval_ms >= 0.0,
        "arrival_interval_ms must be a finite non-negative number, got {}",
        arrival_interval_ms
    );

    let requests = (0..request_count)
        .map(|request_idx| DirectRequest {
            tokens: (0..input_tokens)
                .map(|token_idx| synthetic_token_id(request_idx, token_idx))
                .collect(),
            max_output_tokens: output_tokens,
            // Offset by one so the first request does not get the nil UUID.
            uuid: Some(Uuid::from_u128((request_idx as u128) + 1)),
            dp_rank: 0,
            arrival_timestamp_ms: include_arrival_timestamps
                .then_some(request_idx as f64 * arrival_interval_ms),
        })
        .collect();

    Ok(requests)
}
/// Derives a deterministic, non-zero token id for a (request, token) position.
///
/// The two indices are packed into one 64-bit seed (request index in the upper
/// half, XORed with the token index) and scrambled with the SplitMix64
/// increment and finalizer. The low 32 bits of the result are the token id; a
/// result of 0 is replaced by 1 so that 0 is never produced.
fn synthetic_token_id(request_idx: usize, token_idx: usize) -> u32 {
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
    const MULTIPLIER_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MULTIPLIER_B: u64 = 0x94D0_49BB_1331_11EB;

    let seed = ((request_idx as u64) << 32) ^ (token_idx as u64);
    let mixed = seed.wrapping_add(INCREMENT);
    let mixed = (mixed ^ (mixed >> 30)).wrapping_mul(MULTIPLIER_A);
    let mixed = (mixed ^ (mixed >> 27)).wrapping_mul(MULTIPLIER_B);
    let mixed = mixed ^ (mixed >> 31);

    // Intentional truncation to the low 32 bits.
    match mixed as u32 {
        0 => 1,
        token => token,
    }
}
lib/bindings/python/src/dynamo/_core.pyi
View file @
b7fe46b1
...
...
@@ -3,7 +3,17 @@
import asyncio
import os
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
)
# Import from specialized modules
from .prometheus_metrics import RuntimeMetrics as PyRuntimeMetrics
...
...
@@ -1104,9 +1114,10 @@ class KvRouterConfig:
router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8,
router_queue_threshold: Optional[float] =
2
.0,
router_queue_threshold: Optional[float] =
4
.0,
router_event_threads: int = 4,
router_enable_cache_control: bool = False,
min_initial_workers: int = 1,
router_queue_policy: str = "fcfs",
) -> None:
"""
...
...
@@ -1132,7 +1143,7 @@ class KvRouterConfig:
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
router_queue_threshold: Queue threshold fraction for prefill token capacity (default:
2
.0).
router_queue_threshold: Queue threshold fraction for prefill token capacity (default:
4
.0).
Requests are queued if all workers exceed this fraction of max_num_batched_tokens.
Enables priority scheduling via request priority hints.
Set to None to disable queueing (all requests go directly to the scheduler).
...
...
@@ -1140,12 +1151,111 @@ class KvRouterConfig:
When > 1, uses a concurrent radix tree with a thread pool.
router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's
cache_control service mesh endpoint (default: False).
min_initial_workers: Minimum number of discovered workers required before
router startup continues (default: 1). Ignored when
skip_initial_worker_wait is enabled.
router_queue_policy: Scheduling policy for the router queue (default: "fcfs").
"fcfs": first-come first-served with priority bumps — optimizes tail TTFT.
"lcfs": last-come first-served with priority bumps — intentionally worsens tail behavior for policy comparisons.
"wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
"""
...
@staticmethod
def from_json(config_json: str) -> "KvRouterConfig":
    """Build a KvRouterConfig from a JSON string; raises if the JSON cannot be parsed."""
    ...
class ReasoningConfig:
    """Reasoning settings that can be passed to ``MockEngineArgs`` via ``reasoning``."""

    def __init__(
        self,
        start_thinking_token_id: int,
        end_thinking_token_id: int,
        thinking_ratio: float,
    ) -> None:
        """Store the given values; all three arguments are required."""
        ...
class SglangArgs:
    """SGLang-specific settings that can be passed to ``MockEngineArgs`` via ``sglang``."""

    def __init__(
        self,
        schedule_policy: Optional[str] = None,
        page_size: Optional[int] = None,
        max_prefill_tokens: Optional[int] = None,
        chunked_prefill_size: Optional[int] = None,
        clip_max_new_tokens: Optional[int] = None,
        schedule_conservativeness: Optional[float] = None,
    ) -> None:
        """Store the given values; every option is optional and defaults to None."""
        ...
class MockEngineArgs:
    """Configuration for the mocker engine."""

    def __init__(
        self,
        engine_type: str = "vllm",
        num_gpu_blocks: int = 16384,
        block_size: int = 0,
        max_num_seqs: Optional[int] = 256,
        max_num_batched_tokens: Optional[int] = 8192,
        enable_prefix_caching: bool = True,
        enable_chunked_prefill: bool = True,
        speedup_ratio: float = 1.0,
        decode_speedup_ratio: float = 1.0,
        dp_size: int = 1,
        startup_time: Optional[float] = None,
        worker_type: str = "aggregated",
        aic_backend: Optional[str] = None,
        aic_system: Optional[str] = None,
        aic_backend_version: Optional[str] = None,
        aic_tp_size: Optional[int] = None,
        aic_model_path: Optional[str] = None,
        enable_local_indexer: bool = False,
        bootstrap_port: Optional[int] = None,
        kv_bytes_per_token: Optional[int] = None,
        kv_transfer_bandwidth: Optional[float] = None,
        reasoning: Optional[ReasoningConfig] = None,
        zmq_kv_events_port: Optional[int] = None,
        zmq_replay_port: Optional[int] = None,
        preemption_mode: str = "lifo",
        router_queue_policy: Optional[str] = None,
        sglang: Optional[SglangArgs] = None,
    ) -> None:
        """
        Build mocker engine args from keyword options.

        Raises an exception when a string option is not recognised:
            engine_type: "vllm" or "sglang".
            worker_type: "aggregated", "prefill" or "decode".
            preemption_mode: "lifo" or "fifo".
            router_queue_policy: must be a valid router queue policy name when given.
        """
        ...

    @staticmethod
    def from_json(config_json: str) -> "MockEngineArgs":
        """Build MockEngineArgs from a JSON string; raises if the JSON cannot be parsed."""
        ...

    # Read-only views of the stored configuration values.
    @property
    def block_size(self) -> int: ...
    @property
    def num_gpu_blocks(self) -> int: ...
    @property
    def max_num_seqs(self) -> Optional[int]: ...
    @property
    def max_num_batched_tokens(self) -> Optional[int]: ...
    @property
    def enable_local_indexer(self) -> bool: ...
    @property
    def dp_size(self) -> int: ...
    @property
    def bootstrap_port(self) -> Optional[int]: ...

    def is_prefill(self) -> bool: ...
    def is_decode(self) -> bool: ...

    # Returns a new MockEngineArgs; an argument left as None keeps the current value.
    def with_overrides(
        self,
        bootstrap_port: Optional[int] = None,
        zmq_kv_events_port: Optional[int] = None,
        zmq_replay_port: Optional[int] = None,
        kv_bytes_per_token: Optional[int] = None,
    ) -> "MockEngineArgs": ...
async def register_model(
model_input: ModelInput,
model_type: ModelType,
...
...
@@ -1249,11 +1359,31 @@ async def run_input(runtime: DistributedRuntime, input: str, engine_config: Engi
def run_mocker_trace_replay(
    trace_file: str | os.PathLike[str],
    extra_engine_args: Optional[MockEngineArgs] = None,
    router_config: Optional[KvRouterConfig] = None,
    num_workers: int = 1,
    replay_concurrency: Optional[int] = None,
    replay_mode: Literal["offline", "online"] = "offline",
    router_mode: Literal["round_robin", "kv_router"] = "round_robin",
    arrival_speedup_ratio: float = 1.0,
) -> Dict[str, Any]:
    """Replay a mocker trace file and return the simulation report for aggregated vLLM or SGLang configs.

    Args:
        trace_file: Path of the trace file to replay.
        extra_engine_args: Optional typed mock engine arguments.
        router_config: Optional KV router configuration.
        num_workers: Number of workers to replay against.
        replay_concurrency: Optional concurrency setting for the replay.
        replay_mode: Either "offline" or "online".
        router_mode: Either "round_robin" or "kv_router".
        arrival_speedup_ratio: Speed-up ratio for request arrivals.

    Returns:
        The simulation report as a dictionary.
    """
    ...
def run_mocker_synthetic_trace_replay(
    input_tokens: int,
    output_tokens: int,
    request_count: int,
    extra_engine_args: Optional[MockEngineArgs] = None,
    router_config: Optional[KvRouterConfig] = None,
    num_workers: int = 1,
    replay_concurrency: Optional[int] = None,
    replay_mode: Literal["offline", "online"] = "offline",
    router_mode: Literal["round_robin", "kv_router"] = "round_robin",
    arrival_speedup_ratio: float = 1.0,
    arrival_interval_ms: float = 1.0,
) -> Dict[str, Any]:
    """Replay a synthetic mocker workload without requiring a trace file.

    Args:
        input_tokens: Input token count used for the synthetic requests.
        output_tokens: Output token count used for the synthetic requests.
        request_count: Number of synthetic requests.
        extra_engine_args: Optional typed mock engine arguments.
        router_config: Optional KV router configuration.
        num_workers: Number of workers to replay against.
        replay_concurrency: Optional concurrency setting for the replay.
        replay_mode: Either "offline" or "online".
        router_mode: Either "round_robin" or "kv_router".
        arrival_speedup_ratio: Speed-up ratio for request arrivals.
        arrival_interval_ms: Arrival interval, in milliseconds.

    Returns:
        The simulation report as a dictionary.
    """
    ...
class Layer:
...
...
@@ -1687,6 +1817,7 @@ class EntrypointArgs:
tls_cert_path: Optional[str] = None,
tls_key_path: Optional[str] = None,
extra_engine_args: Optional[str] = None,
mocker_engine_args: Optional[MockEngineArgs] = None,
runtime_config: Optional[ModelRuntimeConfig] = None,
namespace: Optional[str] = None,
namespace_prefix: Optional[str] = None,
...
...
@@ -1711,7 +1842,8 @@ class EntrypointArgs:
http_metrics_port: HTTP metrics port (for gRPC service)
tls_cert_path: TLS certificate path (PEM format)
tls_key_path: TLS key path (PEM format)
extra_engine_args: Path to extra engine arguments file
extra_engine_args: Optional path to mocker engine arguments JSON
mocker_engine_args: Typed mocker engine arguments
runtime_config: Optional runtime configuration for discovery registration
namespace: Dynamo namespace for model discovery scoping
namespace_prefix: Optional namespace prefix
...
...
lib/bindings/python/src/dynamo/llm/__init__.py
View file @
b7fe46b1
...
...
@@ -18,6 +18,7 @@ from dynamo._core import KvRouterConfig as KvRouterConfig
from
dynamo._core
import
LoRADownloader
as
LoRADownloader
from
dynamo._core
import
MediaDecoder
as
MediaDecoder
from
dynamo._core
import
MediaFetcher
as
MediaFetcher
from
dynamo._core
import
MockEngineArgs
as
MockEngineArgs
from
dynamo._core
import
ModelCardInstanceId
as
ModelCardInstanceId
from
dynamo._core
import
ModelInput
as
ModelInput
from
dynamo._core
import
ModelRuntimeConfig
as
ModelRuntimeConfig
...
...
@@ -25,8 +26,10 @@ from dynamo._core import ModelType as ModelType
from
dynamo._core
import
OverlapScores
as
OverlapScores
from
dynamo._core
import
PythonAsyncEngine
as
PythonAsyncEngine
from
dynamo._core
import
RadixTree
as
RadixTree
from
dynamo._core
import
ReasoningConfig
as
ReasoningConfig
from
dynamo._core
import
RouterConfig
as
RouterConfig
from
dynamo._core
import
RouterMode
as
RouterMode
from
dynamo._core
import
SglangArgs
as
SglangArgs
from
dynamo._core
import
WorkerMetricsPublisher
as
WorkerMetricsPublisher
from
dynamo._core
import
compute_block_hash_for_seq
as
compute_block_hash_for_seq
from
dynamo._core
import
fetch_model
as
fetch_model
...
...
@@ -35,7 +38,7 @@ from dynamo._core import make_engine
from
dynamo._core
import
register_model
as
register_model
from
dynamo._core
import
run_input
from
dynamo._core
import
run_kv_indexer
as
run_kv_indexer
from
dynamo._core
import
run_mocker_trace_replay
from
dynamo._core
import
run_mocker_trace_replay
as
_run_mocker_trace_replay
from
dynamo._core
import
unregister_model
as
unregister_model
from
.exceptions
import
HttpError
...
...
@@ -44,3 +47,24 @@ from .exceptions import HttpError
fetch_llm
=
fetch_model
register_llm
=
register_model
unregister_llm
=
unregister_model
def run_mocker_trace_replay(
    trace_file,
    extra_engine_args=None,
    router_config=None,
    num_workers=1,
    replay_concurrency=None,
    router_mode="round_robin",
    arrival_speedup_ratio=1.0,
):
    """Replay a trace through the native binding, always in offline mode.

    All arguments are forwarded unchanged to ``_run_mocker_trace_replay``;
    ``replay_mode`` is not exposed here and is pinned to ``"offline"``.
    """
    forwarded = {
        "extra_engine_args": extra_engine_args,
        "router_config": router_config,
        "num_workers": num_workers,
        "replay_concurrency": replay_concurrency,
        "replay_mode": "offline",
        "router_mode": router_mode,
        "arrival_speedup_ratio": arrival_speedup_ratio,
    }
    return _run_mocker_trace_replay(trace_file, **forwarded)
lib/bindings/python/src/dynamo/replay/__init__.py
0 → 100644
View file @
b7fe46b1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
dynamo.replay.api
import
run_synthetic_trace_replay
,
run_trace_replay
__all__
=
[
"run_synthetic_trace_replay"
,
"run_trace_replay"
]
lib/bindings/python/src/dynamo/replay/__main__.py
0 → 100644
View file @
b7fe46b1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
dynamo.replay.main
import
main
if
__name__
==
"__main__"
:
raise
SystemExit
(
main
())
lib/bindings/python/src/dynamo/replay/api.py
0 → 100644
View file @
b7fe46b1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
dynamo._core
import
(
run_mocker_synthetic_trace_replay
as
_run_mocker_synthetic_trace_replay
,
)
from
dynamo._core
import
run_mocker_trace_replay
as
_run_mocker_trace_replay
def run_trace_replay(
    trace_file,
    *,
    extra_engine_args=None,
    router_config=None,
    num_workers=1,
    replay_concurrency=None,
    replay_mode="offline",
    router_mode="round_robin",
    arrival_speedup_ratio=1.0,
):
    """Replay a trace file via the native ``_run_mocker_trace_replay`` binding.

    Everything after ``trace_file`` is keyword-only and is passed through
    unchanged; the binding's return value is returned as-is.
    """
    options = dict(
        extra_engine_args=extra_engine_args,
        router_config=router_config,
        num_workers=num_workers,
        replay_concurrency=replay_concurrency,
        replay_mode=replay_mode,
        router_mode=router_mode,
        arrival_speedup_ratio=arrival_speedup_ratio,
    )
    return _run_mocker_trace_replay(trace_file, **options)
def run_synthetic_trace_replay(
    input_tokens,
    output_tokens,
    request_count,
    *,
    extra_engine_args=None,
    router_config=None,
    num_workers=1,
    replay_concurrency=None,
    replay_mode="offline",
    router_mode="round_robin",
    arrival_speedup_ratio=1.0,
    arrival_interval_ms=1.0,
):
    """Replay a synthetic workload via ``_run_mocker_synthetic_trace_replay``.

    The three workload sizes are positional; every other option is
    keyword-only and is passed through unchanged to the native binding.
    """
    options = dict(
        extra_engine_args=extra_engine_args,
        router_config=router_config,
        num_workers=num_workers,
        replay_concurrency=replay_concurrency,
        replay_mode=replay_mode,
        router_mode=router_mode,
        arrival_speedup_ratio=arrival_speedup_ratio,
        arrival_interval_ms=arrival_interval_ms,
    )
    return _run_mocker_synthetic_trace_replay(
        input_tokens, output_tokens, request_count, **options
    )
lib/bindings/python/src/dynamo/replay/main.py
0 → 100644
View file @
b7fe46b1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
import
argparse
import
json
import
os
import
sys
from
collections.abc
import
Sequence
os
.
environ
.
setdefault
(
"DYNAMO_SKIP_PYTHON_LOG_INIT"
,
"1"
)
from
dynamo.llm
import
KvRouterConfig
,
MockEngineArgs
from
dynamo.replay
import
run_synthetic_trace_replay
,
run_trace_replay
def main(argv: Sequence[str] | None = None) -> int:
    """Command-line entry point for ``python -m dynamo.replay``.

    Parses ``argv`` (or ``sys.argv[1:]`` when ``argv`` is None), runs either a
    trace-file replay or a synthetic replay, and writes the resulting report
    to stdout as indented, key-sorted JSON followed by a newline.

    Exactly one input mode must be chosen: a positional ``trace_file``, or
    all of ``--input-tokens``, ``--output-tokens`` and ``--request-count``.
    Anything else ends in ``parser.error``.

    Returns:
        0 once the report has been written.
    """
    cli = argparse.ArgumentParser(prog="python -m dynamo.replay")
    cli.add_argument("trace_file", nargs="?")
    cli.add_argument("--extra-engine-args")
    cli.add_argument("--router-config")
    cli.add_argument("--input-tokens", type=int)
    cli.add_argument("--output-tokens", type=int)
    cli.add_argument("--request-count", type=int)
    cli.add_argument("--arrival-interval-ms", type=float, default=1.0)
    cli.add_argument("--num-workers", type=int, default=1)
    cli.add_argument("--replay-concurrency", type=int)
    cli.add_argument(
        "--replay-mode",
        choices=("offline", "online"),
        default="offline",
    )
    cli.add_argument(
        "--router-mode",
        choices=("round_robin", "kv_router"),
        default="round_robin",
    )
    cli.add_argument("--arrival-speedup-ratio", type=float, default=1.0)

    raw_args = sys.argv[1:] if argv is None else argv
    opts = cli.parse_args(list(raw_args))

    workload = (opts.input_tokens, opts.output_tokens, opts.request_count)
    has_trace = opts.trace_file is not None
    has_synthetic = not all(item is None for item in workload)

    # Both modes at once, or neither, is rejected first.
    if has_trace == has_synthetic:
        cli.error(
            "provide either trace_file or all of --input-tokens/--output-tokens/--request-count"
        )
    # A partially specified synthetic workload is rejected next.
    if has_synthetic and any(item is None for item in workload):
        cli.error(
            "synthetic replay requires --input-tokens, --output-tokens, and --request-count"
        )

    engine_args = None
    if opts.extra_engine_args is not None:
        engine_args = MockEngineArgs.from_json(opts.extra_engine_args)

    kv_router_config = None
    if opts.router_config is not None:
        kv_router_config = KvRouterConfig.from_json(opts.router_config)

    # Options common to both replay flavours.
    shared = dict(
        extra_engine_args=engine_args,
        router_config=kv_router_config,
        num_workers=opts.num_workers,
        replay_concurrency=opts.replay_concurrency,
        replay_mode=opts.replay_mode,
        router_mode=opts.router_mode,
        arrival_speedup_ratio=opts.arrival_speedup_ratio,
    )

    if has_trace:
        report = run_trace_replay(opts.trace_file, **shared)
    else:
        report = run_synthetic_trace_replay(
            opts.input_tokens,
            opts.output_tokens,
            opts.request_count,
            arrival_interval_ms=opts.arrival_interval_ms,
            **shared,
        )

    json.dump(report, sys.stdout, indent=2, sort_keys=True)
    sys.stdout.write("\n")
    return 0
lib/bindings/python/tests/test_replay.py
0 → 100644
View file @
b7fe46b1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
json
import
pytest
from
dynamo.llm
import
KvRouterConfig
,
MockEngineArgs
from
dynamo.replay
import
run_synthetic_trace_replay
,
run_trace_replay
pytestmark
=
[
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
parallel
,
pytest
.
mark
.
pre_merge
,
]
MOONCAKE_TRACE_FIRST20
=
"""{"timestamp": 0, "input_length": 6755, "output_length": 500, "hash_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]}
{"timestamp": 0, "input_length": 7319, "output_length": 490, "hash_ids": [0, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]}
{"timestamp": 0, "input_length": 7234, "output_length": 794, "hash_ids": [0, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]}
{"timestamp": 0, "input_length": 2287, "output_length": 316, "hash_ids": [0, 42, 43, 44, 45]}
{"timestamp": 0, "input_length": 9013, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]}
{"timestamp": 0, "input_length": 6506, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 64]}
{"timestamp": 0, "input_length": 4824, "output_length": 173, "hash_ids": [0, 65, 66, 67, 68, 69, 70, 71, 72, 73]}
{"timestamp": 0, "input_length": 3119, "output_length": 20, "hash_ids": [74, 75, 76, 77, 78, 79, 80]}
{"timestamp": 0, "input_length": 23090, "output_length": 453, "hash_ids": [0, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125]}
{"timestamp": 0, "input_length": 3135, "output_length": 19, "hash_ids": [74, 75, 76, 77, 78, 126, 127]}
{"timestamp": 0, "input_length": 26874, "output_length": 458, "hash_ids": [0, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179]}
{"timestamp": 0, "input_length": 10487, "output_length": 402, "hash_ids": [0, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]}
{"timestamp": 0, "input_length": 17448, "output_length": 610, "hash_ids": [0, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233]}
{"timestamp": 0, "input_length": 6253, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 234]}
{"timestamp": 0, "input_length": 6725, "output_length": 32, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 235, 236]}
{"timestamp": 3052, "input_length": 13538, "output_length": 71, "hash_ids": [0, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262]}
{"timestamp": 3052, "input_length": 87162, "output_length": 402, "hash_ids": [0, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432]}
{"timestamp": 3052, "input_length": 6166, "output_length": 24, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 433]}
{"timestamp": 3052, "input_length": 6320, "output_length": 548, "hash_ids": [0, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445]}
{"timestamp": 3052, "input_length": 2007, "output_length": 354, "hash_ids": [0, 446, 447, 448]}
"""
def
_write_trace_and_args
(
tmp_path
):
trace_path
=
tmp_path
/
"trace.jsonl"
records
=
[
{
"timestamp"
:
1000.0
,
"input_length"
:
64
,
"output_length"
:
2
,
"hash_ids"
:
[
101
],
},
{
"timestamp"
:
1005.0
,
"input_length"
:
64
,
"output_length"
:
2
,
"hash_ids"
:
[
101
],
},
]
trace_path
.
write_text
(
"
\n
"
.
join
(
json
.
dumps
(
record
)
for
record
in
records
)
+
"
\n
"
,
encoding
=
"utf-8"
,
)
return
trace_path
def
_write_vllm_args
(
tmp_path
):
args_path
=
tmp_path
/
"args.json"
args_path
.
write_text
(
json
.
dumps
(
{
"block_size"
:
64
,
"speedup_ratio"
:
1000.0
,
}
),
encoding
=
"utf-8"
,
)
return
args_path
def
_vllm_args
():
return
MockEngineArgs
.
from_json
(
json
.
dumps
(
{
"block_size"
:
64
,
"speedup_ratio"
:
1000.0
,
}
)
)
def
_write_sglang_args
(
tmp_path
):
args_path
=
tmp_path
/
"sglang_args.json"
args_path
.
write_text
(
json
.
dumps
(
{
"engine_type"
:
"sglang"
,
"num_gpu_blocks"
:
512
,
"block_size"
:
64
,
"speedup_ratio"
:
1000.0
,
"sglang"
:
{
"page_size"
:
64
,
},
}
),
encoding
=
"utf-8"
,
)
return
args_path
def
_sglang_args
():
return
MockEngineArgs
.
from_json
(
json
.
dumps
(
{
"engine_type"
:
"sglang"
,
"num_gpu_blocks"
:
512
,
"block_size"
:
64
,
"speedup_ratio"
:
1000.0
,
"sglang"
:
{
"page_size"
:
64
,
},
}
)
)
def
_write_router_config
(
tmp_path
):
config_path
=
tmp_path
/
"router_config.json"
config_path
.
write_text
(
json
.
dumps
(
{
"router_queue_threshold"
:
1.25
,
"router_event_threads"
:
1
,
"router_queue_policy"
:
"wspt"
,
"router_temperature"
:
0.0
,
"overlap_score_weight"
:
1.0
,
"use_kv_events"
:
True
,
"durable_kv_events"
:
False
,
"router_replica_sync"
:
False
,
"router_track_active_blocks"
:
True
,
"router_track_output_blocks"
:
False
,
"router_assume_kv_reuse"
:
True
,
"router_snapshot_threshold"
:
1000000
,
"router_reset_states"
:
False
,
"router_ttl_secs"
:
120.0
,
"router_max_tree_size"
:
1048576
,
"router_prune_target_ratio"
:
0.8
,
"router_enable_cache_control"
:
False
,
"skip_initial_worker_wait"
:
False
,
"min_initial_workers"
:
1
,
"remote_indexer_component"
:
None
,
}
),
encoding
=
"utf-8"
,
)
return
config_path
def
_router_config
():
return
KvRouterConfig
.
from_json
(
json
.
dumps
(
{
"router_queue_threshold"
:
1.25
,
"router_event_threads"
:
1
,
"router_queue_policy"
:
"wspt"
,
"router_temperature"
:
0.0
,
"overlap_score_weight"
:
1.0
,
"use_kv_events"
:
True
,
"durable_kv_events"
:
False
,
"router_replica_sync"
:
False
,
"router_track_active_blocks"
:
True
,
"router_track_output_blocks"
:
False
,
"router_assume_kv_reuse"
:
True
,
"router_snapshot_threshold"
:
1000000
,
"router_reset_states"
:
False
,
"router_ttl_secs"
:
120.0
,
"router_max_tree_size"
:
1048576
,
"router_prune_target_ratio"
:
0.8
,
"router_enable_cache_control"
:
False
,
"skip_initial_worker_wait"
:
False
,
"min_initial_workers"
:
1
,
"remote_indexer_component"
:
None
,
}
)
)
def
_partial_router_config
():
return
KvRouterConfig
(
router_queue_threshold
=
1.25
,
router_event_threads
=
1
,
router_queue_policy
=
"wspt"
,
)
def
_assert_basic_report_counts
(
report
,
*
,
num_requests
,
input_tokens
,
output_tokens
):
assert
report
[
"num_requests"
]
==
num_requests
assert
report
[
"completed_requests"
]
==
num_requests
assert
report
[
"total_input_tokens"
]
==
num_requests
*
input_tokens
assert
report
[
"total_output_tokens"
]
==
num_requests
*
output_tokens
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
def test_run_trace_replay_smoke_matrix(tmp_path, engine_type, replay_mode, router_mode):
    """Smoke-test trace replay for every engine / replay-mode / router-mode combination."""
    trace_file = _write_trace_and_args(tmp_path)

    if engine_type == "vllm":
        engine_args = _vllm_args()
    else:
        engine_args = _sglang_args()

    # Round-robin runs use one worker; KV-router runs use two.
    if router_mode == "round_robin":
        worker_count = 1
    else:
        worker_count = 2

    result = run_trace_replay(
        trace_file,
        extra_engine_args=engine_args,
        num_workers=worker_count,
        replay_mode=replay_mode,
        router_mode=router_mode,
    )

    _assert_basic_report_counts(result, num_requests=2, input_tokens=64, output_tokens=2)
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_invariant_counts_match(tmp_path, engine_type, replay_mode):
    """Request and token totals must not depend on worker count or router mode."""
    trace_file = _write_trace_and_args(tmp_path)

    if engine_type == "vllm":
        engine_args = _vllm_args()
    else:
        engine_args = _sglang_args()

    baseline = run_trace_replay(
        trace_file,
        extra_engine_args=engine_args,
        num_workers=1,
        replay_mode=replay_mode,
    )
    four_round_robin = run_trace_replay(
        trace_file,
        extra_engine_args=engine_args,
        num_workers=4,
        replay_mode=replay_mode,
        router_mode="round_robin",
    )
    four_kv_router = run_trace_replay(
        trace_file,
        extra_engine_args=engine_args,
        num_workers=4,
        replay_mode=replay_mode,
        router_mode="kv_router",
    )

    invariant_fields = (
        "num_requests",
        "completed_requests",
        "total_input_tokens",
        "total_output_tokens",
    )
    for name in invariant_fields:
        assert baseline[name] == four_round_robin[name]
        assert baseline[name] == four_kv_router[name]
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
def test_run_synthetic_trace_replay_smoke_matrix(
    tmp_path, engine_type, replay_mode, router_mode
):
    """Smoke-test synthetic replay for every engine / replay-mode / router-mode combination."""
    if engine_type == "vllm":
        engine_args = _vllm_args()
    else:
        engine_args = _sglang_args()

    # Round-robin runs use one worker; KV-router runs use two.
    if router_mode == "round_robin":
        worker_count = 1
    else:
        worker_count = 2

    result = run_synthetic_trace_replay(
        64,
        2,
        2,
        extra_engine_args=engine_args,
        num_workers=worker_count,
        replay_mode=replay_mode,
        router_mode=router_mode,
        arrival_interval_ms=5.0,
    )

    _assert_basic_report_counts(result, num_requests=2, input_tokens=64, output_tokens=2)
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_synthetic_trace_replay_invariant_counts_match(
    tmp_path, engine_type, replay_mode
):
    """Synthetic replay totals must not depend on worker count or router mode."""
    if engine_type == "vllm":
        engine_args = _vllm_args()
    else:
        engine_args = _sglang_args()

    baseline = run_synthetic_trace_replay(
        64,
        2,
        2,
        extra_engine_args=engine_args,
        num_workers=1,
        replay_mode=replay_mode,
        arrival_interval_ms=5.0,
    )
    four_round_robin = run_synthetic_trace_replay(
        64,
        2,
        2,
        extra_engine_args=engine_args,
        num_workers=4,
        replay_mode=replay_mode,
        router_mode="round_robin",
        arrival_interval_ms=5.0,
    )
    four_kv_router = run_synthetic_trace_replay(
        64,
        2,
        2,
        extra_engine_args=engine_args,
        num_workers=4,
        replay_mode=replay_mode,
        router_mode="kv_router",
        arrival_interval_ms=5.0,
    )

    invariant_fields = (
        "num_requests",
        "completed_requests",
        "total_input_tokens",
        "total_output_tokens",
    )
    for name in invariant_fields:
        assert baseline[name] == four_round_robin[name]
        assert baseline[name] == four_kv_router[name]
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_synthetic_concurrency_replay_counts_match(tmp_path, engine_type, replay_mode):
    """Synthetic replay with ``replay_concurrency=2`` still completes all three requests."""
    if engine_type == "vllm":
        engine_args = _vllm_args()
    else:
        engine_args = _sglang_args()

    result = run_synthetic_trace_replay(
        64,
        2,
        3,
        extra_engine_args=engine_args,
        num_workers=2,
        replay_mode=replay_mode,
        replay_concurrency=2,
    )

    _assert_basic_report_counts(result, num_requests=3, input_tokens=64, output_tokens=2)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_router_config(tmp_path, replay_mode):
    """Trace replay accepts a fully populated ``KvRouterConfig`` in kv_router mode."""
    trace_file = _write_trace_and_args(tmp_path)
    engine_args = _vllm_args()
    kv_config = _router_config()

    result = run_trace_replay(
        trace_file,
        extra_engine_args=engine_args,
        router_config=kv_config,
        num_workers=2,
        replay_mode=replay_mode,
        router_mode="kv_router",
    )

    _assert_basic_report_counts(result, num_requests=2, input_tokens=64, output_tokens=2)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_partial_router_config_json(tmp_path, replay_mode):
    """Trace replay accepts a ``KvRouterConfig`` that sets only a few fields."""
    trace_file = _write_trace_and_args(tmp_path)
    engine_args = _vllm_args()
    kv_config = _partial_router_config()

    result = run_trace_replay(
        trace_file,
        extra_engine_args=engine_args,
        router_config=kv_config,
        num_workers=2,
        replay_mode=replay_mode,
        router_mode="kv_router",
    )

    _assert_basic_report_counts(result, num_requests=2, input_tokens=64, output_tokens=2)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_partial_extra_engine_args_json(tmp_path, replay_mode):
    """Trace replay accepts ``MockEngineArgs`` built with only two keyword arguments."""
    trace_file = _write_trace_and_args(tmp_path)
    engine_args = MockEngineArgs(block_size=64, speedup_ratio=1000.0)

    result = run_trace_replay(
        trace_file,
        extra_engine_args=engine_args,
        num_workers=1,
        replay_mode=replay_mode,
    )

    _assert_basic_report_counts(result, num_requests=2, input_tokens=64, output_tokens=2)
lib/kv-router/src/event_sink.rs
deleted
100644 → 0
View file @
82794761
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transport abstraction for publishing batched KV cache events.
//!
//! Implementations handle the actual delivery mechanism (NATS event plane,
//! JetStream durable queue, direct indexer application, etc.). The trait lives
//! in this crate so that the batching processor and other routing logic can be
//! written generically; runtime-specific impls stay in `lib/llm`.
use
std
::
future
::
Future
;
use
crate
::
protocols
::
RouterEvent
;
/// Transport abstraction for publishing batched KV cache events.
///
/// Implementors must be `Send + Sync`.
pub trait EventSink: Send + Sync {
    /// Publishes `event` through this sink.
    ///
    /// Returns a `Send` future that resolves to `Ok(())` or an `anyhow` error.
    fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = anyhow::Result<()>> + Send;
}
lib/kv-router/src/indexer/tests.rs
View file @
b7fe46b1
...
...
@@ -245,9 +245,13 @@ async fn flush_and_settle(index: &dyn KvIndexerInterface) {
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
100
))
.await
;
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_store_and_find
(
variant
:
&
str
)
{
mod
interface_tests
{
use
super
::
*
;
use
rstest_reuse
::
apply
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_store_and_find
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store a sequence for worker 0
...
...
@@ -266,11 +270,11 @@ async fn test_store_and_find(variant: &str) {
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_partial_match
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_partial_match
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store [1, 2, 3] for worker 0
...
...
@@ -288,11 +292,11 @@ async fn test_partial_match(variant: &str) {
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store sequence for worker 0
...
...
@@ -313,11 +317,11 @@ async fn test_remove(variant: &str) {
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_multiple_workers_shared_prefix
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_multiple_workers_shared_prefix
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Worker 0 has [1, 2], Worker 1 has [1, 3]
...
...
@@ -342,11 +346,11 @@ async fn test_multiple_workers_shared_prefix(variant: &str) {
assert_eq!
(
scores
.scores
.len
(),
2
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
1
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_worker
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_worker
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
...
...
@@ -370,11 +374,11 @@ async fn test_remove_worker(variant: &str) {
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert
!
(
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
1
,
0
)));
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_large_stores
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_large_stores
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Test sequences of increasing sizes
...
...
@@ -395,11 +399,11 @@ async fn test_large_stores(variant: &str) {
.collect
();
let
scores
=
index
.find_matches
(
last_seq
)
.await
.unwrap
();
assert
!
(
!
scores
.scores
.is_empty
());
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_dump_and_restore
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_dump_and_restore
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store some data
...
...
@@ -424,11 +428,11 @@ async fn test_dump_and_restore(variant: &str) {
snapshot_tree
(
index
.as_ref
())
.await
,
snapshot_tree
(
restored
.as_ref
())
.await
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_clear_all_blocks
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_clear_all_blocks
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store some data for two workers
...
...
@@ -451,11 +455,11 @@ async fn test_clear_all_blocks(variant: &str) {
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert
!
(
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
1
,
0
)));
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_empty_query
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_empty_query
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
...
...
@@ -465,11 +469,11 @@ async fn test_empty_query(variant: &str) {
// Empty query should return empty scores
let
scores
=
index
.find_matches
(
vec!
[])
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_miss_query
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_miss_query
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
...
...
@@ -482,28 +486,28 @@ async fn test_miss_query(variant: &str) {
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_shutdown
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_shutdown
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
index
.shutdown
();
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_shutdown_idempotent
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_shutdown_idempotent
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
index
.apply_event
(
make_store_event
(
0
,
&
[
1
,
2
,
3
]))
.await
;
flush_and_settle
(
index
.as_ref
())
.await
;
index
.shutdown
();
index
.shutdown
();
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_find_matches_for_request
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_find_matches_for_request
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Empty index should return no matches
...
...
@@ -524,11 +528,11 @@ async fn test_find_matches_for_request(variant: &str) {
// The tokens [1,2,3,4] won't match our stored [1,2,3] local hashes
// because find_matches_for_request computes different hashes from raw tokens
assert
!
(
scores
.scores
.is_empty
()
||
!
scores
.scores
.is_empty
());
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_process_routing_decision
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_process_routing_decision
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Create tokens with hashes
...
...
@@ -542,11 +546,11 @@ async fn test_process_routing_decision(variant: &str) {
.process_routing_decision_for_request
(
&
mut
tokens_with_hashes
,
worker
)
.await
;
assert
!
(
result
.is_ok
());
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_parent_hash_chains
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_parent_hash_chains
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store initial sequence [1, 2, 3]
...
...
@@ -569,11 +573,11 @@ async fn test_parent_hash_chains(variant: &str) {
let
prefix_seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
prefix_seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_multiple_dp_ranks
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_multiple_dp_ranks
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Same worker_id but different dp_ranks should be tracked separately
...
...
@@ -597,11 +601,11 @@ async fn test_multiple_dp_ranks(variant: &str) {
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
1
))
.unwrap
(),
3
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
2
))
.unwrap
(),
3
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_partial_block_removal
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_partial_block_removal
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store [1, 2, 3]
...
...
@@ -634,11 +638,11 @@ async fn test_partial_block_removal(variant: &str) {
let
partial_seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
2
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
partial_seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
2
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_mid_chain_block
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_mid_chain_block
(
variant
:
&
str
)
{
// TODO: positional indexer has no parent-child structure, so mid-chain removal
// doesn't invalidate later positions — jump search skips over the gap and over-counts.
if
variant
==
"flat"
{
...
...
@@ -688,11 +692,11 @@ async fn test_remove_mid_chain_block(variant: &str) {
// Query [1, 2, 3, 4, 5] — block 3 is back but 4 & 5 were orphaned, so score = 3
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_nonexistent_worker
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_nonexistent_worker
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store data for worker 0
...
...
@@ -711,11 +715,11 @@ async fn test_remove_nonexistent_worker(variant: &str) {
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert
!
(
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
0
,
0
)));
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_nonexistent_blocks
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_remove_nonexistent_blocks
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store [1, 2, 3]
...
...
@@ -730,11 +734,11 @@ async fn test_remove_nonexistent_blocks(variant: &str) {
let
seq
:
Vec
<
LocalBlockHash
>
=
(
1
..=
3
)
.map
(
LocalBlockHash
)
.collect
();
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_clear_then_reuse
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_clear_then_reuse
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store initial data
...
...
@@ -759,11 +763,11 @@ async fn test_clear_then_reuse(variant: &str) {
let
scores
=
index
.find_matches
(
seq
)
.await
.unwrap
();
assert_eq!
(
scores
.scores
.len
(),
1
);
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
3
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_multiple_sequences_per_worker
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_multiple_sequences_per_worker
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store two disjoint sequences for the same worker
...
...
@@ -791,11 +795,11 @@ async fn test_multiple_sequences_per_worker(variant: &str) {
let
scores
=
index
.find_matches
(
mixed
)
.await
.unwrap
();
// Only block 1 matches because [1, 100] is not a valid prefix
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
1
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_clear_clears_all_dp_ranks
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_clear_clears_all_dp_ranks
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store same sequence for different dp_ranks
...
...
@@ -824,15 +828,20 @@ async fn test_clear_clears_all_dp_ranks(variant: &str) {
scores
.scores
.is_empty
(),
"Cleared event should clear all dp_ranks for a worker"
);
}
}
// ============================================================================
// LoRA isolation tests
// ============================================================================
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_lora_and_base_model_blocks_do_not_conflict
(
variant
:
&
str
)
{
mod
lora_tests
{
use
super
::
*
;
use
rstest_reuse
::
apply
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_lora_and_base_model_blocks_do_not_conflict
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
let
kv_block_size
:
u32
=
32
;
...
...
@@ -840,7 +849,8 @@ async fn test_lora_and_base_model_blocks_do_not_conflict(variant: &str) {
let
tokens
:
Vec
<
u32
>
=
(
0
..
kv_block_size
*
3
)
.collect
();
let
base_hashes
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
None
);
let
lora_hashes
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
Some
(
"my-adapter"
));
let
lora_hashes
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
Some
(
"my-adapter"
));
// Hashes must differ despite identical tokens
assert_ne!
(
...
...
@@ -902,23 +912,23 @@ async fn test_lora_and_base_model_blocks_do_not_conflict(variant: &str) {
.unwrap
(),
3
);
}
}
/// Reproduces the "block_hash mismatch: sequence hashes should be uniform
/// across workers" warning seen when the same prompt is sent to both a base
/// model worker and a LoRA worker.
///
/// On main (without LoRA-aware hashing), both workers compute the same
/// LocalBlockHash for identical tokens. But vLLM's engine includes the
/// adapter in its rolling ExternalSequenceBlockHash, so the radix tree
/// sees conflicting sequence hashes at the same tree node.
///
/// With LoRA-aware hashing, compute_block_hash_for_seq produces distinct
/// LocalBlockHash values for different adapters, so the blocks land on
/// separate tree paths and no mismatch occurs.
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_lora_base_same_tokens_no_seq_hash_mismatch
(
variant
:
&
str
)
{
/// Reproduces the "block_hash mismatch: sequence hashes should be uniform
/// across workers" warning seen when the same prompt is sent to both a base
/// model worker and a LoRA worker.
///
/// On main (without LoRA-aware hashing), both workers compute the same
/// LocalBlockHash for identical tokens. But vLLM's engine includes the
/// adapter in its rolling ExternalSequenceBlockHash, so the radix tree
/// sees conflicting sequence hashes at the same tree node.
///
/// With LoRA-aware hashing, compute_block_hash_for_seq produces distinct
/// LocalBlockHash values for different adapters, so the blocks land on
/// separate tree paths and no mismatch occurs.
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_lora_base_same_tokens_no_seq_hash_mismatch
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
let
kv_block_size
:
u32
=
32
;
...
...
@@ -926,7 +936,8 @@ async fn test_lora_base_same_tokens_no_seq_hash_mismatch(variant: &str) {
// With LoRA-aware hashing, base and adapter produce different LocalBlockHash
let
base_local
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
None
);
let
lora_local
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
Some
(
"my-adapter"
));
let
lora_local
=
compute_block_hash_for_seq
(
&
tokens
,
kv_block_size
,
None
,
Some
(
"my-adapter"
));
assert_ne!
(
base_local
,
lora_local
,
...
...
@@ -988,11 +999,11 @@ async fn test_lora_base_same_tokens_no_seq_hash_mismatch(variant: &str) {
.unwrap
(),
3
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_different_lora_adapters_do_not_conflict
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_different_lora_adapters_do_not_conflict
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
let
kv_block_size
:
u32
=
32
;
...
...
@@ -1048,15 +1059,20 @@ async fn test_different_lora_adapters_do_not_conflict(variant: &str) {
assert_eq!
(
scores_b
.scores
.len
(),
1
);
assert
!
(
scores_b
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
1
,
0
)));
assert
!
(
!
scores_b
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
0
,
0
)));
}
}
// ============================================================================
// Long sequence tests - especially important for NestedMap/PositionalIndexer
// ============================================================================
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_single_store
(
variant
:
&
str
)
{
mod
long_sequence_tests
{
use
super
::
*
;
use
rstest_reuse
::
apply
;
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_single_store
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store a long sequence (128 blocks) in a single event
...
...
@@ -1091,11 +1107,11 @@ async fn test_long_sequence_single_store(variant: &str) {
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
49
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_multiple_continuations
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_multiple_continuations
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Build a long sequence through multiple continuations
...
...
@@ -1133,12 +1149,14 @@ async fn test_long_sequence_multiple_continuations(variant: &str) {
// Query starts at block 45, but stored sequence starts at 1, so this won't match
// because the sequence hash at position 0 of our query (block 45) won't match
// the stored sequence hash at position 0 (block 1)
assert
!
(
scores
.scores
.is_empty
()
||
!
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
0
,
0
)));
}
assert
!
(
scores
.scores
.is_empty
()
||
!
scores
.scores
.contains_key
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
);
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_branching_continuations
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_branching_continuations
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Common prefix: blocks 1-30
...
...
@@ -1185,11 +1203,11 @@ async fn test_long_sequence_branching_continuations(variant: &str) {
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
1
,
0
))
.unwrap
(),
30
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_partial_removal
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_partial_removal
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store a long sequence
...
...
@@ -1225,11 +1243,11 @@ async fn test_long_sequence_partial_removal(variant: &str) {
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
79
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_interleaved_workers
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_interleaved_workers
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Multiple workers storing overlapping long sequences concurrently
...
...
@@ -1270,11 +1288,11 @@ async fn test_long_sequence_interleaved_workers(variant: &str) {
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
3
,
0
))
.unwrap
(),
25
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_exact_jump_size_boundaries
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_exact_jump_size_boundaries
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Test sequences that align exactly with jump_size boundaries (32 for PositionalIndexer)
...
...
@@ -1315,11 +1333,11 @@ async fn test_long_sequence_exact_jump_size_boundaries(variant: &str) {
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
2
,
0
))
.unwrap
(),
96
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_off_by_one_jump_boundaries
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_off_by_one_jump_boundaries
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Test sequences at jump_size +/- 1 boundaries to catch off-by-one errors
...
...
@@ -1363,11 +1381,11 @@ async fn test_long_sequence_off_by_one_jump_boundaries(variant: &str) {
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
3
,
0
))
.unwrap
(),
65
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_divergence_at_jump_boundaries
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_divergence_at_jump_boundaries
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store a long sequence
...
...
@@ -1390,11 +1408,11 @@ async fn test_long_sequence_divergence_at_jump_boundaries(variant: &str) {
diverge_pos
);
}
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_deep_continuation_chain
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_deep_continuation_chain
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Build a very long sequence through many small continuations
...
...
@@ -1438,11 +1456,11 @@ async fn test_long_sequence_deep_continuation_chain(variant: &str) {
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
75
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_clear_and_rebuild
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_clear_and_rebuild
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Store a long sequence
...
...
@@ -1475,7 +1493,8 @@ async fn test_long_sequence_clear_and_rebuild(variant: &str) {
flush_and_settle
(
index
.as_ref
())
.await
;
// Verify new sequence works
let
new_query
:
Vec
<
LocalBlockHash
>
=
new_sequence
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
new_query
:
Vec
<
LocalBlockHash
>
=
new_sequence
.iter
()
.map
(|
&
i
|
LocalBlockHash
(
i
))
.collect
();
let
scores
=
index
.find_matches
(
new_query
)
.await
.unwrap
();
assert_eq!
(
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
...
...
@@ -1485,11 +1504,11 @@ async fn test_long_sequence_clear_and_rebuild(variant: &str) {
// Verify old sequence no longer matches
let
scores
=
index
.find_matches
(
query
)
.await
.unwrap
();
assert
!
(
scores
.scores
.is_empty
());
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_multiple_workers_diverging
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_multiple_workers_diverging
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Multiple workers with long sequences that share a prefix then diverge
...
...
@@ -1546,11 +1565,11 @@ async fn test_long_sequence_multiple_workers_diverging(variant: &str) {
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
2
,
0
))
.unwrap
(),
40
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_staggered_lengths
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_long_sequence_staggered_lengths
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Workers with sequences of staggered lengths to test drain tracking
...
...
@@ -1593,11 +1612,11 @@ async fn test_long_sequence_staggered_lengths(variant: &str) {
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
4
,
0
))
.unwrap
(),
100
);
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_very_long_sequence
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_very_long_sequence
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Test with a very long sequence (1000 blocks)
...
...
@@ -1631,6 +1650,7 @@ async fn test_very_long_sequence(variant: &str) {
*
scores
.scores
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.unwrap
(),
499
);
}
}
// ============================================================================
...
...
@@ -1670,9 +1690,13 @@ fn make_tree_indexer_with_frequency(
}
}
#[tokio::test]
#[apply(tree_indexer_template)]
async
fn
test_frequency
(
variant
:
&
str
)
{
mod
tree_specific_tests
{
use
super
::
*
;
use
rstest_reuse
::
apply
;
#[tokio::test]
#[apply(tree_indexer_template)]
async
fn
test_frequency
(
variant
:
&
str
)
{
const
ONE_MILLIS
:
Duration
=
Duration
::
from_millis
(
1
);
let
expiration
=
Duration
::
from_millis
(
50
);
...
...
@@ -1750,15 +1774,20 @@ async fn test_frequency(variant: &str) {
// The third access did not touch the last block
let
overlap
=
kv_indexer
.find_matches
(
block_hashes
.clone
())
.await
.unwrap
();
assert_eq!
(
overlap
.frequencies
,
vec!
[
3
,
3
,
3
,
2
]);
}
}
// ============================================================================
// KvIndexerMetrics tests
// ============================================================================
#[cfg(feature
=
"metrics"
)]
#[test]
fn
test_increment_event_applied
()
{
mod
metrics_tests
{
#[cfg(feature
=
"metrics"
)]
use
super
::
*
;
#[cfg(feature
=
"metrics"
)]
#[test]
fn
test_increment_event_applied
()
{
let
metrics
=
KvIndexerMetrics
::
new_unregistered
();
metrics
.increment_event_applied
(
METRIC_EVENT_STORED
,
Ok
(()));
...
...
@@ -1778,21 +1807,29 @@ fn test_increment_event_applied() {
assert_eq!
(
metrics
.kv_cache_events_applied
.get_metric_with_label_values
(
&
[
METRIC_EVENT_STORED
,
METRIC_STATUS_PARENT_NOT_FOUND
])
.get_metric_with_label_values
(
&
[
METRIC_EVENT_STORED
,
METRIC_STATUS_PARENT_NOT_FOUND
])
.unwrap
()
.get
(),
1
);
metrics
.increment_event_applied
(
METRIC_EVENT_REMOVED
,
Err
(
KvCacheEventError
::
BlockNotFound
));
metrics
.increment_event_applied
(
METRIC_EVENT_REMOVED
,
Err
(
KvCacheEventError
::
BlockNotFound
));
assert_eq!
(
metrics
.kv_cache_events_applied
.get_metric_with_label_values
(
&
[
METRIC_EVENT_REMOVED
,
METRIC_STATUS_BLOCK_NOT_FOUND
])
.get_metric_with_label_values
(
&
[
METRIC_EVENT_REMOVED
,
METRIC_STATUS_BLOCK_NOT_FOUND
])
.unwrap
()
.get
(),
1
);
}
}
// ============================================================================
...
...
@@ -1822,8 +1859,12 @@ fn make_local_indexer_with_events(ids: &[u64]) -> LocalKvIndexer {
indexer
}
#[tokio::test]
async
fn
test_local_indexer_slice_within_range
()
{
mod
local_indexer_tests
{
use
super
::
*
;
use
rstest_reuse
::
apply
;
#[tokio::test]
async
fn
test_local_indexer_slice_within_range
()
{
let
indexer
=
make_local_indexer_with_events
(
&
[
1
,
2
,
3
,
4
,
5
]);
// Helper to extract events from response
...
...
@@ -1860,10 +1901,10 @@ async fn test_local_indexer_slice_within_range() {
// Invalid range: end < start
let
result
=
indexer
.get_events_in_id_range
(
Some
(
5
),
Some
(
2
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
InvalidRange
{
..
}));
}
}
#[tokio::test]
async
fn
test_local_indexer_get_events_in_id_range_all_cases
()
{
#[tokio::test]
async
fn
test_local_indexer_get_events_in_id_range_all_cases
()
{
// Create indexer with small buffer (5 events max)
let
indexer
=
LocalKvIndexer
::
new
(
CancellationToken
::
new
(),
...
...
@@ -1939,10 +1980,10 @@ async fn test_local_indexer_get_events_in_id_range_all_cases() {
let
result
=
indexer
.get_events_in_id_range
(
Some
(
100
),
Some
(
200
))
.await
;
assert
!
(
matches!
(
result
,
WorkerKvQueryResponse
::
TooNew
{
..
}));
}
}
#[tokio::test]
async
fn
test_tree_dump_includes_last_event_id
()
{
#[tokio::test]
async
fn
test_tree_dump_includes_last_event_id
()
{
// Create indexer with small buffer (5 events max)
let
indexer
=
LocalKvIndexer
::
new
(
CancellationToken
::
new
(),
...
...
@@ -2031,10 +2072,10 @@ async fn test_tree_dump_includes_last_event_id() {
}
other
=>
panic!
(
"Expected TreeDump, got: {other:?}"
),
}
}
}
#[tokio::test]
async
fn
test_local_indexer_buffer_and_serialization
()
{
#[tokio::test]
async
fn
test_local_indexer_buffer_and_serialization
()
{
let
worker_id
=
42u64
;
let
token
=
CancellationToken
::
new
();
let
metrics
=
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
());
...
...
@@ -2078,10 +2119,10 @@ async fn test_local_indexer_buffer_and_serialization() {
};
assert_eq!
(
events
.len
(),
1
);
assert_eq!
(
events
[
0
]
.worker_id
,
worker_id
);
}
}
#[tokio::test]
async
fn
test_local_indexer_does_not_buffer_failed_send
()
{
#[tokio::test]
async
fn
test_local_indexer_does_not_buffer_failed_send
()
{
let
local_indexer
=
LocalKvIndexer
::
new
(
CancellationToken
::
new
(),
4
,
...
...
@@ -2123,11 +2164,11 @@ async fn test_local_indexer_does_not_buffer_failed_send() {
}
other
=>
panic!
(
"Expected TreeDump, got: {other:?}"
),
}
}
}
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_apply_events_idempotent
(
variant
:
&
str
)
{
#[tokio::test]
#[apply(indexer_template)]
async
fn
test_apply_events_idempotent
(
variant
:
&
str
)
{
let
index
=
make_indexer
(
variant
);
// Setup: build initial tree
...
...
@@ -2181,4 +2222,5 @@ async fn test_apply_events_idempotent(variant: &str) {
s2
,
s3
,
"Phase 3: non-interleaved ordering should restore tree"
);
}
}
lib/kv-router/src/lib.rs
View file @
b7fe46b1
...
...
@@ -6,7 +6,6 @@
//! This crate provides the core radix tree implementation and protocols for
//! efficient KV cache lookup and routing in distributed LLM inference systems.
pub
mod
event_sink
;
pub
mod
indexer
;
pub
mod
protocols
;
pub
mod
scheduling
;
...
...
@@ -41,15 +40,15 @@ pub use self::sequence::{ActiveSequences, RequestId};
pub
use
concurrent_radix_tree
::
ConcurrentRadixTree
;
pub
use
concurrent_radix_tree_compressed
::
ConcurrentRadixTreeCompressed
;
pub
use
config
::{
KvRouterConfig
,
RouterConfigOverride
,
RouterQueuePolicy
};
pub
use
event_sink
::
EventSink
;
pub
use
indexer
::{
MaybeError
,
SyncIndexer
,
ThreadPoolIndexer
};
pub
use
nested_map
::
PositionalIndexer
;
pub
use
protocols
::{
KvCacheEventError
,
LocalBlockHash
,
OverlapScores
,
RouterEvent
,
WorkerConfigLike
,
WorkerId
,
compute_block_hash_for_seq
,
KvCacheEventError
,
LocalBlockHash
,
OverlapScores
,
RouterEvent
,
RouterEventSink
,
WorkerConfigLike
,
WorkerId
,
compute_block_hash_for_seq
,
};
pub
use
queue
::
SchedulerQueue
;
pub
use
radix_tree
::
RadixTree
;
pub
use
scheduling
::
LocalScheduler
;
pub
use
scheduling
::
policy
::{
FcfsPolicy
,
RouterSchedulingPolicy
,
SchedulingPolicy
,
WsptPolicy
};
pub
use
scheduling
::{
KvSchedulerError
,
PotentialLoad
,
SchedulingRequest
,
SchedulingResponse
};
pub
use
selector
::{
DefaultWorkerSelector
,
WorkerSelector
};
lib/kv-router/src/protocols.rs
View file @
b7fe46b1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
future
::
Future
;
use
dynamo_tokens
::{
SequenceHash
,
Token
};
use
rustc_hash
::
FxHashMap
;
use
serde
::{
Deserialize
,
Serialize
};
...
...
@@ -105,6 +107,12 @@ pub trait WorkerConfigLike {
fn
total_kv_blocks
(
&
self
)
->
Option
<
u64
>
;
}
/// Transport abstraction for publishing batched router-visible KV cache events.
///
/// Implementors must be `Send + Sync`.
pub
trait
RouterEventSink
:
Send
+
Sync
{
/// Publishes `event` through this sink.
///
/// The returned future is `Send` and resolves to `Ok(())` on success or an
/// `anyhow` error otherwise; what counts as a failure is left to the
/// implementor.
fn
publish_event
(
&
self
,
event
:
&
RouterEvent
)
->
impl
Future
<
Output
=
anyhow
::
Result
<
()
>>
+
Send
;
}
/// A worker identifier.
///
/// This is a plain alias for `u64`, not a newtype, so it is interchangeable
/// with any other `u64` value.
pub
type
WorkerId
=
u64
;
...
...
lib/kv-router/src/scheduling/config.rs
View file @
b7fe46b1
...
...
@@ -11,11 +11,16 @@ use validator::{Validate, ValidationError};
use
crate
::
protocols
::{
compute_block_hash_for_seq
,
compute_seq_hash_for_block
};
/// Default for `KvRouterConfig::min_initial_workers`.
///
/// Referenced by name from the field's `#[serde(default = "...")]` attribute
/// and called from `KvRouterConfig::default()`.
const fn default_min_initial_workers() -> usize {
    const MIN_INITIAL_WORKERS: usize = 1;
    MIN_INITIAL_WORKERS
}
/// Scheduling policy selector for the router queue.
///
/// Variants are (de)serialized by serde under their lowercase names
/// (`rename_all = "lowercase"`), matching the strings used by the `Display`
/// and `FromStr` implementations: `"fcfs"`, `"lcfs"`, `"wspt"`.
#[derive(Debug,
Clone,
Copy,
Default,
PartialEq,
Eq,
Serialize,
Deserialize)]
#[serde(rename_all
=
"lowercase"
)]
pub
enum
RouterQueuePolicy
{
/// First-come first-served with priority bumps; this is the default policy.
#[default]
Fcfs
,
/// NOTE(review): presumably last-come first-served — confirm against the
/// queue implementation and document the intended trade-off.
Lcfs
,
/// Weighted shortest processing time (Smith's rule).
Wspt
,
}
...
...
@@ -23,6 +28,7 @@ impl fmt::Display for RouterQueuePolicy {
/// Writes the policy's lowercase name (`"fcfs"`, `"lcfs"` or `"wspt"`) to `f`.
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    // Pick the label first so there is a single write site.
    let label = match self {
        Self::Fcfs => "fcfs",
        Self::Lcfs => "lcfs",
        Self::Wspt => "wspt",
    };
    f.write_str(label)
}
...
...
@@ -34,9 +40,10 @@ impl FromStr for RouterQueuePolicy {
/// Parses a queue policy from its lowercase name.
///
/// Accepts exactly `"fcfs"`, `"lcfs"` and `"wspt"`; matching is
/// case-sensitive and the input is not trimmed.
///
/// # Errors
///
/// Any other input yields an error message that echoes the rejected value
/// and lists every accepted name.
fn from_str(s: &str) -> Result<Self, Self::Err> {
    match s {
        "fcfs" => Ok(Self::Fcfs),
        "lcfs" => Ok(Self::Lcfs),
        "wspt" => Ok(Self::Wspt),
        // Keep this list in sync with the arms above.
        _ => Err(format!(
            "unknown queue policy: {s:?}, expected 'fcfs', 'lcfs', or 'wspt'"
        )),
    }
}
...
...
@@ -58,6 +65,7 @@ pub struct RouterConfigOverride {
/// KV Router configuration parameters
#[derive(Debug,
Clone,
Serialize,
Deserialize,
Validate)]
#[serde(default)]
#[validate(schema(function
=
"validate_kv_router_config"
))]
pub
struct
KvRouterConfig
{
#[validate(range(min
=
0.0
))]
...
...
@@ -130,6 +138,13 @@ pub struct KvRouterConfig {
/// When true, the router starts immediately without waiting for discovery-based
/// workers and workers are provided externally per-request (e.g., EPP).
pub
skip_initial_worker_wait
:
bool
,
/// Minimum number of workers that must be discovered before router startup continues.
/// Default: 1. Ignored when skip_initial_worker_wait=true.
#[serde(default
=
"default_min_initial_workers"
)]
#[validate(range(min
=
1
))]
pub
min_initial_workers
:
usize
,
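/// Scheduling policy for the router queue.
/// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT.
/// "lcfs": also accepted by `RouterQueuePolicy`; presumably last-come first-served (TODO: confirm and document the trade-off).
/// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.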
...
...
@@ -159,10 +174,11 @@ impl Default for KvRouterConfig {
router_ttl_secs
:
120.0
,
router_max_tree_size
:
2u
size
.pow
(
20
),
// 2^20 = 1048576, matches PruneConfig::default()
router_prune_target_ratio
:
0.8
,
router_queue_threshold
:
Some
(
2
.0
),
router_queue_threshold
:
Some
(
4
.0
),
router_event_threads
:
4
,
router_enable_cache_control
:
false
,
skip_initial_worker_wait
:
false
,
min_initial_workers
:
default_min_initial_workers
(),
router_queue_policy
:
RouterQueuePolicy
::
default
(),
remote_indexer_component
:
None
,
}
...
...
@@ -237,3 +253,39 @@ impl KvRouterConfig {
self
.use_kv_events
&&
self
.overlap_score_weight
>
0.0
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Lcfs` renders as "lcfs" and "lcfs" parses back to `Lcfs`.
    #[test]
    fn router_queue_policy_display_and_parse_support_lcfs() {
        let rendered = RouterQueuePolicy::Lcfs.to_string();
        assert_eq!(rendered, "lcfs");

        let parsed: RouterQueuePolicy = "lcfs".parse().unwrap();
        assert_eq!(parsed, RouterQueuePolicy::Lcfs);
    }

    /// `Lcfs` survives a JSON round trip and is encoded as the string "lcfs".
    #[test]
    fn router_queue_policy_serde_round_trip_supports_lcfs() {
        let json = serde_json::to_string(&RouterQueuePolicy::Lcfs).unwrap();
        assert_eq!(json, "\"lcfs\"");

        let round_tripped: RouterQueuePolicy = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped, RouterQueuePolicy::Lcfs);
    }

    /// The default configuration waits for exactly one initial worker.
    #[test]
    fn kv_router_config_defaults_to_one_initial_worker() {
        let defaults = KvRouterConfig::default();
        assert_eq!(defaults.min_initial_workers, 1);
    }

    /// Validation fails when `min_initial_workers` is zero.
    #[test]
    fn kv_router_config_rejects_zero_initial_workers() {
        let invalid = KvRouterConfig {
            min_initial_workers: 0,
            ..KvRouterConfig::default()
        };
        assert!(invalid.validate().is_err());
    }
}
lib/kv-router/src/scheduling/local.rs
0 → 100644
View file @
b7fe46b1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
collections
::{
HashMap
,
HashSet
};
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
tokio
::
sync
::{
mpsc
,
watch
};
use
tokio_util
::
sync
::
CancellationToken
;
use
super
::
policy
::{
RouterSchedulingPolicy
,
SchedulingPolicy
};
use
super
::
queue
::
SchedulerQueue
;
use
super
::
selector
::{
DefaultWorkerSelector
,
WorkerSelector
};
use
super
::
types
::{
KvSchedulerError
,
PotentialLoad
,
SchedulingRequest
,
SchedulingResponse
};
use
crate
::
protocols
::{
OverlapScores
,
WorkerConfigLike
,
WorkerId
,
WorkerWithDpRank
};
use
crate
::
sequences
::{
ActiveSequencesMultiWorker
,
SequenceError
,
SequencePublisher
,
SequenceRequest
,
};
use
dynamo_tokens
::
SequenceHash
;
/// Period of the timer on which the scheduler's background task calls
/// `SchedulerQueue::update`, independently of incoming requests.
const RECHECK_INTERVAL: Duration = Duration::from_secs(60);
/// In-process scheduler front end: requests are sent over a channel to a
/// background task that feeds a [`SchedulerQueue`], while sequence bookkeeping
/// is delegated to a shared [`ActiveSequencesMultiWorker`].
///
/// Type parameters:
/// * `P`   – publisher used by the active-sequence tracker.
/// * `C`   – per-worker configuration type.
/// * `S`   – queue ordering policy (defaults to [`RouterSchedulingPolicy`]).
/// * `Sel` – worker selection strategy (defaults to [`DefaultWorkerSelector`]).
pub struct LocalScheduler<
    P: SequencePublisher,
    C: WorkerConfigLike,
    S: SchedulingPolicy = RouterSchedulingPolicy,
    Sel: WorkerSelector<C> = DefaultWorkerSelector,
> {
    /// Sending half of the request channel consumed by the background task.
    request_tx: mpsc::Sender<SchedulingRequest>,
    /// Shared active-sequence state for all workers.
    slots: Arc<ActiveSequencesMultiWorker<P>>,
    /// Pending-request queue shared with the background task.
    queue: Arc<SchedulerQueue<P, C, S, Sel>>,
    /// Static label reported by `worker_type()`.
    worker_type: &'static str,
}
impl
<
P
,
C
,
S
,
Sel
>
LocalScheduler
<
P
,
C
,
S
,
Sel
>
where
P
:
SequencePublisher
+
'static
,
C
:
WorkerConfigLike
+
Clone
+
PartialEq
+
Send
+
Sync
+
'static
,
S
:
SchedulingPolicy
+
'static
,
Sel
:
WorkerSelector
<
C
>
+
Send
+
Sync
+
'static
,
{
/// Builds a scheduler and spawns its background tasks.
///
/// Must be called from within a Tokio runtime because it uses `tokio::spawn`.
///
/// * When `monitor_worker_configs` is `true`, a task watches
///   `workers_with_configs`; every time the map differs from the last one it
///   saw, it pushes the per-worker `(data_parallel_start_rank,
///   data_parallel_size)` pairs to `slots.update_workers`.
/// * A second task receives [`SchedulingRequest`]s from an mpsc channel of
///   capacity 1024 and enqueues them, and additionally calls `queue.update()`
///   on every `RECHECK_INTERVAL` tick.
///
/// Both tasks stop when `cancellation_token` is cancelled; each also stops
/// when the channel it listens on is closed.
#[allow(clippy::too_many_arguments)]
pub fn new(
    slots: Arc<ActiveSequencesMultiWorker<P>>,
    workers_with_configs: watch::Receiver<HashMap<WorkerId, C>>,
    threshold_frac: Option<f64>,
    block_size: u32,
    selector: Sel,
    policy: S,
    cancellation_token: CancellationToken,
    worker_type: &'static str,
    monitor_worker_configs: bool,
) -> Self {
    if monitor_worker_configs {
        let watched_slots = Arc::clone(&slots);
        let mut config_rx = workers_with_configs.clone();
        // Snapshot taken before spawning so only real changes trigger an update.
        let mut known_workers = config_rx.borrow().clone();
        let watcher_cancel = cancellation_token.clone();

        tokio::spawn(async move {
            tracing::trace!("LocalScheduler workers monitoring task started");
            loop {
                tokio::select! {
                    _ = watcher_cancel.cancelled() => {
                        tracing::trace!("LocalScheduler workers monitoring task shutting down");
                        break;
                    }
                    changed = config_rx.changed() => {
                        match changed {
                            Ok(()) => {}
                            Err(_) => {
                                tracing::warn!("LocalScheduler worker config watch dropped, shutting down");
                                break;
                            }
                        }
                    }
                }

                let latest = config_rx.borrow_and_update().clone();
                if latest != known_workers {
                    let mut dp_range: HashMap<WorkerId, (u32, u32)> =
                        HashMap::with_capacity(latest.len());
                    for (&id, cfg) in &latest {
                        dp_range.insert(
                            id,
                            (cfg.data_parallel_start_rank(), cfg.data_parallel_size()),
                        );
                    }
                    watched_slots.update_workers(&dp_range);
                    known_workers = latest;
                }
            }
        });
    }

    let queue = Arc::new(SchedulerQueue::new(
        Arc::clone(&slots),
        workers_with_configs,
        threshold_frac,
        block_size,
        selector,
        policy,
    ));

    let (request_tx, mut request_rx) = mpsc::channel::<SchedulingRequest>(1024);
    let task_queue = Arc::clone(&queue);

    tokio::spawn(async move {
        let mut recheck = tokio::time::interval(RECHECK_INTERVAL);
        tracing::trace!("LocalScheduler background task started");
        loop {
            tokio::select! {
                _ = cancellation_token.cancelled() => {
                    tracing::trace!("LocalScheduler background task shutting down");
                    break;
                }
                received = request_rx.recv() => {
                    match received {
                        Some(request) => {
                            tracing::trace!("received request to be scheduled");
                            task_queue.enqueue(request).await;
                        }
                        None => {
                            tracing::warn!("LocalScheduler request channel closed");
                            break;
                        }
                    }
                }
                _ = recheck.tick() => {
                    task_queue.update().await;
                }
            }
        }
    });

    Self {
        request_tx,
        slots,
        queue,
        worker_type,
    }
}
/// Submits one request to the background scheduling task and waits for its
/// decision.
///
/// The arguments are packed into a [`SchedulingRequest`] (with empty
/// `decode_blocks` / `prefill_tokens` maps and a cloned config override) and
/// sent over the request channel; the reply arrives on a oneshot channel.
///
/// # Errors
/// Returns [`KvSchedulerError::SubscriberShutdown`] if the request channel is
/// closed or the reply sender is dropped without answering; otherwise returns
/// whatever result the scheduling side sent back.
#[expect(clippy::too_many_arguments)]
pub async fn schedule(
    &self,
    maybe_request_id: Option<String>,
    isl_tokens: usize,
    token_seq: Option<Vec<SequenceHash>>,
    overlaps: OverlapScores,
    router_config_override: Option<&super::config::RouterConfigOverride>,
    update_states: bool,
    lora_name: Option<String>,
    priority_jump: f64,
    expected_output_tokens: Option<u32>,
    allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<SchedulingResponse, KvSchedulerError> {
    let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();

    let outgoing = SchedulingRequest {
        maybe_request_id,
        token_seq,
        isl_tokens,
        overlaps,
        decode_blocks: HashMap::new(),
        prefill_tokens: HashMap::new(),
        router_config_override: router_config_override.cloned(),
        update_states,
        lora_name,
        priority_jump,
        expected_output_tokens,
        allowed_worker_ids,
        resp_tx: Some(reply_tx),
    };

    if self.request_tx.send(outgoing).await.is_err() {
        return Err(KvSchedulerError::SubscriberShutdown);
    }

    match reply_rx.await {
        Ok(outcome) => outcome,
        Err(_) => Err(KvSchedulerError::SubscriberShutdown),
    }
}
/// Forwards `worker_ids` to the underlying queue's `register_workers`.
pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
    let queue = &self.queue;
    queue.register_workers(worker_ids);
}
/// Records `req` in the shared active-sequence state and returns that
/// call's result unchanged.
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
    let slots = &self.slots;
    slots.add_request(req).await
}
/// Marks the prefill phase of `request_id` as finished in the active-sequence
/// state, then asks the queue to re-evaluate pending requests.
///
/// # Errors
/// Propagates the error from the slot update; in that case the queue is not
/// touched.
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
    let owned_id = request_id.to_owned();
    self.slots.mark_prefill_completed(&owned_id).await?;

    self.queue.update().await;
    Ok(())
}
/// Releases `request_id` from the active-sequence state, then asks the queue
/// to re-evaluate pending requests.
///
/// # Errors
/// Propagates the error from the slot update; in that case the queue is not
/// touched.
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
    let owned_id = request_id.to_owned();
    self.slots.free(&owned_id).await?;

    self.queue.update().await;
    Ok(())
}
/// Returns the queue's current `pending_count()`.
pub fn pending_count(&self) -> usize {
    let queue = &self.queue;
    queue.pending_count()
}
/// Returns the static worker-type label supplied at construction.
pub fn worker_type(&self) -> &'static str {
    let label: &'static str = self.worker_type;
    label
}
/// Forwards an output-block notification for `request_id`, together with the
/// optional `decay_fraction`, to the active-sequence state and returns its
/// result unchanged.
pub fn add_output_block(
    &self,
    request_id: &str,
    decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
    let owned_id = request_id.to_owned();
    self.slots.add_output_block(&owned_id, decay_fraction)
}
/// Computes, for every worker/dp-rank known to the active-sequence state, the
/// load it would carry if this request were routed to it.
///
/// The result has one entry per key appearing in either map returned by
/// `potential_blocks_and_tokens`. A worker missing from the prefill map is
/// reported with `isl_tokens` prefill tokens; one missing from the decode map
/// is reported with `0` decode blocks. Entry order is unspecified.
pub fn get_potential_loads(
    &self,
    token_seq: Option<Vec<SequenceHash>>,
    isl_tokens: usize,
    overlaps: OverlapScores,
) -> Vec<PotentialLoad> {
    let (decode_blocks, prefill_tokens) =
        self.slots
            .potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);

    // Union of the workers mentioned by either map.
    let candidates: HashSet<WorkerWithDpRank> = decode_blocks
        .keys()
        .chain(prefill_tokens.keys())
        .copied()
        .collect();

    candidates
        .into_iter()
        .map(|worker| PotentialLoad {
            worker_id: worker.worker_id,
            dp_rank: worker.dp_rank,
            potential_prefill_tokens: prefill_tokens
                .get(&worker)
                .copied()
                .unwrap_or(isl_tokens),
            potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
        })
        .collect()
}
/// Returns the per-LoRA-name counts reported by the active-sequence state.
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
    let slots = &self.slots;
    slots.get_active_lora_counts()
}
}
#[cfg(test)]
mod
tests
{
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
tokio
::
sync
::
watch
;
use
super
::
*
;
use
crate
::
protocols
::
OverlapScores
;
use
crate
::
scheduling
::
policy
::
FcfsPolicy
;
use
crate
::
scheduling
::
selector
::
DefaultWorkerSelector
;
use
crate
::
test_utils
::{
NoopSequencePublisher
,
SimpleWorkerConfig
};
/// Test fixture: builds a `LocalScheduler` with an FCFS policy and the default
/// selector on top of a fresh `ActiveSequencesMultiWorker`.
///
/// Returns the scheduler, the shared slots, the sender side of the worker
/// config watch channel (so tests can publish updates), and the cancellation
/// token that stops the scheduler's background tasks.
#[allow(clippy::type_complexity)]
fn make_scheduler(
    workers: HashMap<WorkerId, SimpleWorkerConfig>,
    threshold_frac: Option<f64>,
    monitor_worker_configs: bool,
) -> (
    Arc<LocalScheduler<NoopSequencePublisher, SimpleWorkerConfig, FcfsPolicy>>,
    Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
    watch::Sender<HashMap<WorkerId, SimpleWorkerConfig>>,
    CancellationToken,
) {
    // Seed the slot tracker with each worker's (start rank, size) pair.
    let initial_ranges = workers
        .iter()
        .map(|(&worker_id, config)| {
            let range = (config.data_parallel_start_rank, config.data_parallel_size);
            (worker_id, range)
        })
        .collect();

    let tracker = ActiveSequencesMultiWorker::new(
        NoopSequencePublisher,
        64,
        initial_ranges,
        false,
        0,
        "test",
    );
    let slots = Arc::new(tracker);

    let (cfg_tx, cfg_rx) = watch::channel(workers);
    let cancel_token = CancellationToken::new();

    let selector = DefaultWorkerSelector::new(None, "test");
    let scheduler = LocalScheduler::new(
        Arc::clone(&slots),
        cfg_rx,
        threshold_frac,
        64,
        selector,
        FcfsPolicy,
        cancel_token.clone(),
        "test",
        monitor_worker_configs,
    );

    (Arc::new(scheduler), slots, cfg_tx, cancel_token)
}
/// Scheduling a request with `update_states = true` picks the only worker and
/// makes its LoRA adapter show up in the active LoRA counts.
#[tokio::test]
async fn test_schedule_books_request_into_active_sequences() {
    let single_worker = HashMap::from([(
        0,
        SimpleWorkerConfig {
            max_num_batched_tokens: Some(64),
            ..Default::default()
        },
    )]);
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(single_worker, None, true);

    let request_id = Some(String::from("req-1"));
    let lora = Some(String::from("adapter-a"));
    let response = scheduler
        .schedule(
            request_id,
            64,
            Some(vec![1, 2, 3, 4]),
            OverlapScores::default(),
            None,
            true,
            lora,
            0.0,
            None,
            None,
        )
        .await
        .unwrap();

    assert_eq!(response.best_worker.worker_id, 0);

    let expected_counts = HashMap::from([("adapter-a".to_string(), 1)]);
    assert_eq!(scheduler.get_active_lora_counts(), expected_counts);

    cancel_token.cancel();
}
/// A request parked in the pending queue is released once the request ahead
/// of it finishes prefill.
///
/// With a queue threshold of 0.5 and a 64-token budget, the first 64-token
/// request is expected to leave the single worker busy enough that the second
/// one gets queued instead of scheduled.
#[tokio::test]
async fn test_mark_prefill_completed_drains_pending_queue() {
    let mut workers = HashMap::new();
    workers.insert(
        0,
        SimpleWorkerConfig {
            max_num_batched_tokens: Some(64),
            ..Default::default()
        },
    );
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, Some(0.5), true);

    scheduler
        .schedule(
            Some("req-1".to_string()),
            64,
            Some(vec![1, 2, 3, 4]),
            OverlapScores::default(),
            None,
            true,
            None,
            0.0,
            None,
            None,
        )
        .await
        .unwrap();

    // The second request is issued from its own task because `schedule` only
    // resolves once the request leaves the pending queue.
    let queued = {
        let scheduler = Arc::clone(&scheduler);
        tokio::spawn(async move {
            scheduler
                .schedule(
                    Some("req-2".to_string()),
                    64,
                    Some(vec![5, 6, 7, 8]),
                    OverlapScores::default(),
                    None,
                    true,
                    None,
                    0.0,
                    None,
                    None,
                )
                .await
        })
    };

    // Wait until the background task has actually parked the request. A
    // bounded poll replaces a fixed sleep so the test does not depend on how
    // quickly the runtime gets around to processing the channel.
    tokio::time::timeout(Duration::from_secs(1), async {
        while scheduler.pending_count() != 1 {
            tokio::task::yield_now().await;
        }
    })
    .await
    .expect("second request should be parked in the pending queue");
    assert_eq!(scheduler.pending_count(), 1);

    scheduler.mark_prefill_completed("req-1").await.unwrap();
    queued.await.unwrap().unwrap();
    assert_eq!(scheduler.pending_count(), 0);

    cancel_token.cancel();
}
/// Freeing a scheduled request removes its LoRA adapter from the active
/// LoRA counts.
#[tokio::test]
async fn test_free_updates_active_state() {
    let single_worker = HashMap::from([(
        0,
        SimpleWorkerConfig {
            max_num_batched_tokens: Some(64),
            ..Default::default()
        },
    )]);
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(single_worker, None, true);

    let lora = Some(String::from("adapter-a"));
    scheduler
        .schedule(
            Some(String::from("req-1")),
            64,
            Some(vec![1, 2, 3, 4]),
            OverlapScores::default(),
            None,
            true,
            lora,
            0.0,
            None,
            None,
        )
        .await
        .unwrap();

    let counts_while_active = scheduler.get_active_lora_counts();
    assert_eq!(
        counts_while_active,
        HashMap::from([("adapter-a".to_string(), 1)])
    );

    scheduler.free("req-1").await.unwrap();

    let counts_after_free = scheduler.get_active_lora_counts();
    assert!(counts_after_free.is_empty());

    cancel_token.cancel();
}
/// `get_potential_loads` must agree, worker by worker, with what the slot
/// tracker itself reports for the same query.
#[tokio::test]
async fn test_get_potential_loads_matches_slots() {
    let roomy_worker = || SimpleWorkerConfig {
        max_num_batched_tokens: Some(256),
        ..Default::default()
    };
    let two_workers = HashMap::from([(0, roomy_worker()), (1, roomy_worker())]);
    let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(two_workers, None, true);

    let token_seq = vec![11, 22, 33, 44];
    let overlaps = OverlapScores::default();

    // Reference values straight from the slot tracker.
    let (decode_blocks, prefill_tokens) =
        slots.potential_blocks_and_tokens(Some(&token_seq), 128, overlaps.clone());
    let mut expected: Vec<_> = decode_blocks
        .keys()
        .map(|worker| {
            let prefill = prefill_tokens.get(worker).copied().unwrap_or(128);
            let decode = decode_blocks.get(worker).copied().unwrap_or(0);
            PotentialLoad {
                worker_id: worker.worker_id,
                dp_rank: worker.dp_rank,
                potential_prefill_tokens: prefill,
                potential_decode_blocks: decode,
            }
        })
        .collect();
    expected.sort_by_key(|entry| (entry.worker_id, entry.dp_rank));

    let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, overlaps);
    actual.sort_by_key(|entry| (entry.worker_id, entry.dp_rank));

    assert_eq!(actual.len(), expected.len());
    for (got, want) in actual.iter().zip(expected.iter()) {
        assert_eq!(got.worker_id, want.worker_id);
        assert_eq!(got.dp_rank, want.dp_rank);
        assert_eq!(got.potential_prefill_tokens, want.potential_prefill_tokens);
        assert_eq!(got.potential_decode_blocks, want.potential_decode_blocks);
    }

    cancel_token.cancel();
}
/// A worker registered by id only (no config, config monitoring disabled)
/// shows up as a single load entry at dp rank 0.
#[tokio::test]
async fn test_register_workers_uses_default_dp_fallback() {
    let no_workers = HashMap::new();
    let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(no_workers, None, false);

    let registered = HashSet::from([42]);
    scheduler.register_workers(&registered);

    let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default());
    assert_eq!(loads.len(), 1);

    let only = &loads[0];
    assert_eq!(only.worker_id, 42);
    assert_eq!(only.dp_rank, 0);

    cancel_token.cancel();
}
/// Publishing a new worker map on the watch channel is picked up by the
/// monitoring task and reflected in the potential-load entries.
///
/// The update gives worker 0 a data-parallel size of 2 and adds a default
/// worker 1, so three entries are expected once the change has propagated.
#[tokio::test]
async fn test_worker_watch_updates_slot_ranges() {
    let initial = HashMap::from([(0, SimpleWorkerConfig::default())]);
    let (scheduler, _slots, cfg_tx, cancel_token) = make_scheduler(initial, None, true);

    let entries_before = scheduler
        .get_potential_loads(None, 64, OverlapScores::default())
        .len();
    assert_eq!(entries_before, 1);

    let widened_worker = SimpleWorkerConfig {
        data_parallel_size: 2,
        ..Default::default()
    };
    let updated_workers =
        HashMap::from([(0, widened_worker), (1, SimpleWorkerConfig::default())]);
    cfg_tx.send(updated_workers).unwrap();

    // Poll (bounded by a one-second timeout) until the monitor task has applied the update.
    let wait_for_three_entries = async {
        while scheduler
            .get_potential_loads(None, 64, OverlapScores::default())
            .len()
            != 3
        {
            tokio::task::yield_now().await;
        }
    };
    tokio::time::timeout(Duration::from_secs(1), wait_for_three_entries)
        .await
        .unwrap();

    cancel_token.cancel();
}
}
lib/kv-router/src/scheduling/mod.rs
View file @
b7fe46b1
...
...
@@ -2,9 +2,11 @@
// SPDX-License-Identifier: Apache-2.0
pub
mod
config
;
mod
local
;
pub
mod
policy
;
pub
mod
queue
;
pub
mod
selector
;
mod
types
;
pub
use
local
::
LocalScheduler
;
pub
use
types
::
*
;
lib/kv-router/src/scheduling/policy.rs
View file @
b7fe46b1
...
...
@@ -43,6 +43,21 @@ impl SchedulingPolicy for FcfsPolicy {
}
}
/// Last-come first-served ordering with priority bumps.
///
/// The queue key is `priority_jump + arrival_offset` (both in seconds, with a
/// negative `priority_jump` clamped to zero), so a request that arrives later
/// or carries a larger `priority_jump` gets a higher key and is scheduled
/// first.
///
/// Newer arrivals are deliberately favoured when the system is saturated; the
/// policy is mainly meant for policy comparison experiments.
pub struct LcfsPolicy;

impl SchedulingPolicy for LcfsPolicy {
    type Key = OrderedFloat<f64>;

    /// Returns the clamped priority bump plus the arrival offset in seconds.
    fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
        let bump_secs = request.priority_jump.max(0.0);
        let arrival_secs = arrival_offset.as_secs_f64();
        OrderedFloat(bump_secs + arrival_secs)
    }
}
/// Weighted Shortest Processing Time (Smith's rule):
/// key = (1 + priority_jump) / new_tokens, where new_tokens estimates the
/// actual prefill cost by subtracting the max KV cache overlap from ISL.
...
...
@@ -73,6 +88,7 @@ impl SchedulingPolicy for WsptPolicy {
/// since the variant is fixed at queue construction time.
pub enum RouterSchedulingPolicy {
    /// First-come first-served ordering.
    Fcfs(FcfsPolicy),
    /// Last-come first-served ordering.
    Lcfs(LcfsPolicy),
    /// Weighted shortest processing time ordering.
    Wspt(WsptPolicy),
}
...
...
@@ -80,6 +96,7 @@ impl RouterSchedulingPolicy {
pub
fn
new
(
kind
:
RouterQueuePolicy
,
block_size
:
usize
)
->
Self
{
match
kind
{
RouterQueuePolicy
::
Fcfs
=>
Self
::
Fcfs
(
FcfsPolicy
),
RouterQueuePolicy
::
Lcfs
=>
Self
::
Lcfs
(
LcfsPolicy
),
RouterQueuePolicy
::
Wspt
=>
Self
::
Wspt
(
WsptPolicy
{
block_size
}),
}
}
...
...
@@ -91,6 +108,7 @@ impl SchedulingPolicy for RouterSchedulingPolicy {
/// Delegates key computation to whichever concrete policy this value wraps.
fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
    match self {
        Self::Fcfs(fcfs) => fcfs.enqueue_key(arrival_offset, request),
        Self::Lcfs(lcfs) => lcfs.enqueue_key(arrival_offset, request),
        Self::Wspt(wspt) => wspt.enqueue_key(arrival_offset, request),
    }
}
...
...
@@ -178,6 +196,42 @@ mod tests {
assert
!
(
key_b
>
key_a
);
}
/// Under LCFS, the same request enqueued later gets the larger key.
#[test]
fn lcfs_later_arrival_scheduled_first() {
    let lcfs = LcfsPolicy;
    let request = request_with(512, 0.0, OverlapScores::default());

    let key_at_1s = lcfs.enqueue_key(Duration::from_secs(1), &request);
    let key_at_10s = lcfs.enqueue_key(Duration::from_secs(10), &request);

    assert!(
        key_at_10s > key_at_1s,
        "later arrival should have higher key"
    );
}
/// Under LCFS, at equal arrival time a positive `priority_jump` yields the
/// larger key.
#[test]
fn lcfs_priority_jump_promotes() {
    let lcfs = LcfsPolicy;
    let arrival = Duration::from_secs(10);

    let plain_request = request_with(512, 0.0, OverlapScores::default());
    let bumped_request = request_with(512, 100.0, OverlapScores::default());

    let plain_key = lcfs.enqueue_key(arrival, &plain_request);
    let bumped_key = lcfs.enqueue_key(arrival, &bumped_request);

    assert!(
        bumped_key > plain_key,
        "priority_jump should produce a higher key"
    );
}
/// The enum wrapper preserves each policy's ordering: FCFS ranks the earlier
/// arrival higher, LCFS ranks the later arrival higher.
#[test]
fn router_scheduling_policy_matches_fcfs_and_lcfs_ordering() {
    let request = request_with(512, 0.0, OverlapScores::default());
    let (early, late) = (Duration::from_secs(1), Duration::from_secs(10));

    let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs, 16);
    let fcfs_early = fcfs.enqueue_key(early, &request);
    let fcfs_late = fcfs.enqueue_key(late, &request);
    assert!(fcfs_early > fcfs_late);

    let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs, 16);
    let lcfs_late = lcfs.enqueue_key(late, &request);
    let lcfs_early = lcfs.enqueue_key(early, &request);
    assert!(lcfs_late > lcfs_early);
}
// ---- WSPT policy tests ----
#[test]
...
...
lib/kv-router/src/scheduling/queue.rs
View file @
b7fe46b1
...
...
@@ -11,7 +11,7 @@ use tokio::sync::Mutex;
use
tokio
::
sync
::
watch
;
use
super
::
policy
::{
FcfsPolicy
,
SchedulingPolicy
};
use
super
::
selector
::
WorkerSelector
;
use
super
::
selector
::
{
Default
WorkerSelector
,
WorkerSelector
}
;
use
super
::
types
::{
SchedulingRequest
,
SchedulingResponse
};
use
crate
::
protocols
::{
WorkerConfigLike
,
WorkerId
,
WorkerWithDpRank
};
use
crate
::
sequences
::{
ActiveSequencesMultiWorker
,
SequencePublisher
,
SequenceRequest
};
...
...
@@ -53,6 +53,7 @@ pub struct SchedulerQueue<
P
:
SequencePublisher
,
C
:
WorkerConfigLike
,
S
:
SchedulingPolicy
=
FcfsPolicy
,
Sel
:
WorkerSelector
<
C
>
=
DefaultWorkerSelector
,
>
{
pending
:
Mutex
<
BinaryHeap
<
QueueEntry
<
S
::
Key
>>>
,
/// Number of requests currently parked in the pending queue.
...
...
@@ -65,19 +66,23 @@ pub struct SchedulerQueue<
/// Reference instant for computing arrival offsets.
start_time
:
Instant
,
block_size
:
u32
,
selector
:
Box
<
dyn
WorkerSelector
<
C
>
+
Send
+
Sync
>
,
selector
:
Sel
,
policy
:
S
,
}
impl
<
P
:
SequencePublisher
+
'static
,
C
:
WorkerConfigLike
,
S
:
SchedulingPolicy
>
SchedulerQueue
<
P
,
C
,
S
>
impl
<
P
:
SequencePublisher
+
'static
,
C
:
WorkerConfigLike
,
S
:
SchedulingPolicy
,
Sel
:
WorkerSelector
<
C
>
,
>
SchedulerQueue
<
P
,
C
,
S
,
Sel
>
{
pub
fn
new
(
slots
:
Arc
<
ActiveSequencesMultiWorker
<
P
>>
,
workers_with_configs
:
watch
::
Receiver
<
HashMap
<
WorkerId
,
C
>>
,
threshold_frac
:
Option
<
f64
>
,
block_size
:
u32
,
selector
:
Box
<
dyn
WorkerSelector
<
C
>
+
Send
+
Sync
>
,
selector
:
Sel
,
policy
:
S
,
)
->
Self
{
if
let
Some
(
frac
)
=
threshold_frac
{
...
...
@@ -341,7 +346,7 @@ mod tests {
}
let
(
cfg_tx
,
cfg_rx
)
=
watch
::
channel
(
configs
);
let
selector
=
Box
::
new
(
DefaultWorkerSelector
::
new
(
None
,
"test"
)
)
;
let
selector
=
DefaultWorkerSelector
::
new
(
None
,
"test"
);
let
queue
=
Arc
::
new
(
SchedulerQueue
::
new
(
Arc
::
clone
(
&
slots
),
cfg_rx
,
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment