Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
95a750f4
Unverified
Commit
95a750f4
authored
Apr 06, 2026
by
Yan Ru Pei
Committed by
GitHub
Apr 06, 2026
Browse files
chore(replay): refactor offline components into cleaner lanes (#7866)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
210bbf5d
Changes
91
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1085 additions
and
154 deletions
+1085
-154
lib/bindings/python/rust/llm/aic_callback.rs
lib/bindings/python/rust/llm/aic_callback.rs
+50
-7
lib/bindings/python/rust/llm/entrypoint.rs
lib/bindings/python/rust/llm/entrypoint.rs
+98
-4
lib/bindings/python/rust/llm/kv.rs
lib/bindings/python/rust/llm/kv.rs
+24
-1
lib/bindings/python/rust/llm/replay.rs
lib/bindings/python/rust/llm/replay.rs
+96
-6
lib/bindings/python/src/dynamo/_core.pyi
lib/bindings/python/src/dynamo/_core.pyi
+24
-0
lib/bindings/python/src/dynamo/_internal/aic.py
lib/bindings/python/src/dynamo/_internal/aic.py
+193
-0
lib/bindings/python/src/dynamo/llm/__init__.py
lib/bindings/python/src/dynamo/llm/__init__.py
+3
-0
lib/bindings/python/src/dynamo/replay/api.py
lib/bindings/python/src/dynamo/replay/api.py
+6
-0
lib/bindings/python/src/dynamo/replay/main.py
lib/bindings/python/src/dynamo/replay/main.py
+48
-1
lib/bindings/python/tests/replay/test_replay_smoke.py
lib/bindings/python/tests/replay/test_replay_smoke.py
+26
-0
lib/kv-router/src/lib.rs
lib/kv-router/src/lib.rs
+2
-1
lib/kv-router/src/protocols.rs
lib/kv-router/src/protocols.rs
+9
-0
lib/kv-router/src/scheduling/config.rs
lib/kv-router/src/scheduling/config.rs
+65
-33
lib/kv-router/src/scheduling/local.rs
lib/kv-router/src/scheduling/local.rs
+88
-17
lib/kv-router/src/scheduling/mod.rs
lib/kv-router/src/scheduling/mod.rs
+2
-0
lib/kv-router/src/scheduling/prefill_load.rs
lib/kv-router/src/scheduling/prefill_load.rs
+13
-0
lib/kv-router/src/scheduling/queue.rs
lib/kv-router/src/scheduling/queue.rs
+158
-47
lib/kv-router/src/sequences/block_tracker.rs
lib/kv-router/src/sequences/block_tracker.rs
+45
-0
lib/kv-router/src/sequences/mod.rs
lib/kv-router/src/sequences/mod.rs
+2
-0
lib/kv-router/src/sequences/multi_worker.rs
lib/kv-router/src/sequences/multi_worker.rs
+133
-37
No files found.
lib/bindings/python/rust/llm/aic_callback.rs
View file @
95a750f4
...
...
@@ -8,15 +8,17 @@
//! predictions without knowing about PyO3.
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
pyo3
::
prelude
::
*
;
use
dynamo_kv_router
::
PrefillLoadEstimator
;
use
dynamo_mocker
::
common
::
perf_model
::
AicCallback
;
/// Wraps a Python AIC InferenceSession for direct calls from Rust.
///
/// The Python object must expose:
/// - `predict_prefill(batch_size, isl, prefix
, osl
) -> float`
/// - `predict_prefill(batch_size,
effective_
isl, prefix) -> float`
/// - `predict_decode(batch_size, isl, osl) -> float`
pub
(
super
)
struct
PyAicCallback
{
pub
(
super
)
session
:
PyObject
,
...
...
@@ -26,15 +28,26 @@ pub(super) struct PyAicCallback {
// SAFETY: NOTE(review) — no justification is recorded in the visible code. The methods
// shown in this file only touch `session` inside `Python::with_gil`, which is presumably
// what makes moving/sharing the callback across threads sound. The struct definition is
// only partially visible here, so confirm that every field upholds this before relying
// on these impls.
unsafe impl Send for PyAicCallback {}
unsafe impl Sync for PyAicCallback {}
impl
AicCallback
for
PyAicCallback
{
fn
predict_prefill
(
&
self
,
batch_size
:
usize
,
isl
:
usize
,
prefix
:
usize
,
osl
:
usize
)
->
f64
{
impl
PyAicCallback
{
fn
predict_prefill_ms
(
&
self
,
batch_size
:
usize
,
effective_isl
:
usize
,
prefix
:
usize
,
)
->
PyResult
<
f64
>
{
Python
::
with_gil
(|
py
|
{
self
.session
.call_method1
(
py
,
"predict_prefill"
,
(
batch_size
,
isl
,
prefix
,
osl
))
.and_then
(|
r
|
r
.extract
::
<
f64
>
(
py
))
.unwrap_or_else
(|
e
|
panic!
(
"AIC predict_prefill failed: {e}"
))
.call_method1
(
py
,
"predict_prefill"
,
(
batch_size
,
effective_isl
,
prefix
))
.and_then
(|
result
|
result
.extract
::
<
f64
>
(
py
))
})
}
}
impl
AicCallback
for
PyAicCallback
{
fn
predict_prefill
(
&
self
,
batch_size
:
usize
,
effective_isl
:
usize
,
prefix
:
usize
)
->
f64
{
self
.predict_prefill_ms
(
batch_size
,
effective_isl
,
prefix
)
.unwrap_or_else
(|
e
|
panic!
(
"AIC predict_prefill failed: {e}"
))
}
fn
predict_decode
(
&
self
,
batch_size
:
usize
,
isl
:
usize
,
osl
:
usize
)
->
f64
{
Python
::
with_gil
(|
py
|
{
...
...
@@ -46,6 +59,18 @@ impl AicCallback for PyAicCallback {
}
}
impl
PrefillLoadEstimator
for
PyAicCallback
{
fn
predict_prefill_duration
(
&
self
,
batch_size
:
usize
,
effective_isl
:
usize
,
prefix
:
usize
,
)
->
anyhow
::
Result
<
Duration
>
{
let
latency_ms
=
self
.predict_prefill_ms
(
batch_size
,
effective_isl
,
prefix
)
?
;
Ok
(
Duration
::
from_secs_f64
(
latency_ms
/
1000.0
))
}
}
/// Initialize an AIC callback by importing and calling the Python setup function.
///
/// Called once at mocker startup when `--aic-perf-model` is requested.
...
...
@@ -61,7 +86,7 @@ pub(super) fn create_aic_callback(
moe_ep_size
:
Option
<
usize
>
,
attention_dp_size
:
Option
<
usize
>
,
)
->
PyResult
<
Arc
<
dyn
AicCallback
>>
{
let
module
=
py
.import
(
"dynamo.
mocker.aic_session
"
)
?
;
let
module
=
py
.import
(
"dynamo.
_internal.aic
"
)
?
;
let
session
=
module
.call_method1
(
"create_session"
,
(
...
...
@@ -79,3 +104,21 @@ pub(super) fn create_aic_callback(
session
:
session
.into
(),
}))
}
/// Builds a prefill-load estimator backed by a Python AIC session.
///
/// Imports the `dynamo._internal.aic` Python module, calls its `create_session` with
/// `(backend_name, system, model_path, tp_size, backend_version)`, and wraps the returned
/// object in a [`PyAicCallback`].
///
/// # Errors
///
/// Propagates any Python error raised by the import or by `create_session`.
pub(super) fn create_aic_prefill_load_estimator(
    py: Python<'_>,
    backend_name: &str,
    system: &str,
    model_path: &str,
    tp_size: usize,
    backend_version: Option<&str>,
) -> PyResult<Arc<dyn PrefillLoadEstimator>> {
    let session_args = (backend_name, system, model_path, tp_size, backend_version);
    let session = py
        .import("dynamo._internal.aic")?
        .call_method1("create_session", session_args)?;

    let estimator: Arc<dyn PrefillLoadEstimator> = Arc::new(PyAicCallback {
        session: session.into(),
    });
    Ok(estimator)
}
lib/bindings/python/rust/llm/entrypoint.rs
View file @
95a750f4
...
...
@@ -10,7 +10,9 @@ use std::sync::Arc;
use
pyo3
::{
exceptions
::
PyException
,
exceptions
::
PyValueError
,
prelude
::
*
};
use
pyo3_async_runtimes
::
TaskLocals
;
use
dynamo_kv_router
::
config
::
KvRouterConfig
as
RsKvRouterConfig
;
use
dynamo_kv_router
::
config
::{
KvRouterConfig
as
RsKvRouterConfig
,
RouterPrefillLoadModel
as
RsRouterPrefillLoadModel
,
};
use
dynamo_llm
::
discovery
::
LoadThresholdConfig
as
RsLoadThresholdConfig
;
use
dynamo_llm
::
entrypoint
::
ChatEngineFactoryCallback
;
use
dynamo_llm
::
entrypoint
::
EngineConfig
as
RsEngineConfig
;
...
...
@@ -23,7 +25,7 @@ use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard;
use
dynamo_llm
::
types
::
openai
::
chat_completions
::
OpenAIChatCompletionsStreamingEngine
;
use
dynamo_mocker
::
common
::
perf_model
::
PerfModel
;
use
super
::
aic_callback
::
create_aic_callback
;
use
super
::
aic_callback
::
{
create_aic_callback
,
create_aic_prefill_load_estimator
}
;
use
super
::
replay
::
MockEngineArgs
as
PyMockEngineArgs
;
use
dynamo_mocker
::
common
::
protocols
::
MockEngineArgs
as
RsMockEngineArgs
;
use
dynamo_runtime
::
discovery
::
ModelCardInstanceId
as
RsModelCardInstanceId
;
...
...
@@ -55,10 +57,76 @@ impl KvRouterConfig {
}
}
// Python-exposed bundle of the settings used to create an AIC session.
// The `#[new]` constructor in this file rejects empty strings for the backend, system and
// model path, and a tensor-parallel size of zero.
#[pyclass]
#[derive(Clone, Debug)]
pub struct AicPerfConfig {
    // Backend name; validated as non-empty by the constructor.
    aic_backend: String,
    // System identifier; validated as non-empty by the constructor.
    aic_system: String,
    // Optional backend version; `None` leaves the choice to the Python side.
    aic_backend_version: Option<String>,
    // Tensor-parallel size; validated as >= 1 by the constructor.
    aic_tp_size: usize,
    // Model path; validated as non-empty by the constructor.
    aic_model_path: String,
}
/// Crate-internal read-only accessors for the fields of [`AicPerfConfig`].
impl AicPerfConfig {
    /// Borrowed view of the configured backend name.
    pub(crate) fn backend_name(&self) -> &str {
        self.aic_backend.as_str()
    }

    /// Borrowed view of the configured system identifier.
    pub(crate) fn system(&self) -> &str {
        self.aic_system.as_str()
    }

    /// Borrowed view of the backend version, or `None` when none was given.
    pub(crate) fn backend_version(&self) -> Option<&str> {
        self.aic_backend_version.as_deref()
    }

    /// The configured tensor-parallel size.
    pub(crate) fn tp_size(&self) -> usize {
        self.aic_tp_size
    }

    /// Borrowed view of the configured model path.
    pub(crate) fn model_path(&self) -> &str {
        self.aic_model_path.as_str()
    }
}
#[pymethods]
impl AicPerfConfig {
    /// Validates the arguments and builds the config.
    ///
    /// Raises `ValueError` when `aic_backend`, `aic_system` or `aic_model_path` is empty
    /// (checked in that order), or when `aic_tp_size` is zero.
    #[new]
    #[pyo3(signature = (aic_backend, aic_system, aic_model_path, aic_tp_size=1, aic_backend_version=None))]
    fn new(
        aic_backend: String,
        aic_system: String,
        aic_model_path: String,
        aic_tp_size: usize,
        aic_backend_version: Option<String>,
    ) -> PyResult<Self> {
        // Required string arguments, listed in the order they are validated.
        let required = [
            ("aic_backend", aic_backend.as_str()),
            ("aic_system", aic_system.as_str()),
            ("aic_model_path", aic_model_path.as_str()),
        ];
        if let Some((field, _)) = required.iter().find(|(_, value)| value.is_empty()) {
            return Err(PyValueError::new_err(format!("{field} must be non-empty")));
        }

        if aic_tp_size < 1 {
            return Err(PyValueError::new_err("aic_tp_size must be >= 1"));
        }

        Ok(Self {
            aic_backend,
            aic_system,
            aic_backend_version,
            aic_tp_size,
            aic_model_path,
        })
    }
}
#[pymethods]
impl
KvRouterConfig
{
#[new]
#[pyo3(signature
=
(overlap_score_weight=
1.0
,
router_temperature=
0.0
,
use_kv_events=
true
,
durable_kv_events=
false
,
router_replica_sync=
false
,
router_track_active_blocks=
true
,
router_track_output_blocks=
false
,
router_assume_kv_reuse=
true
,
router_track_prefill_tokens=
true
,
router_snapshot_threshold=
1000000
,
router_reset_states=
false
,
router_ttl_secs=
120.0
,
router_max_tree_size=
1048576
,
router_prune_target_ratio=
0.8
,
router_queue_threshold=Some(
4.0
),
router_event_threads=
4
,
router_queue_policy=
"fcfs"
,
remote_indexer_component=None))]
#[pyo3(signature
=
(overlap_score_weight=
1.0
,
router_temperature=
0.0
,
use_kv_events=
true
,
durable_kv_events=
false
,
router_replica_sync=
false
,
router_track_active_blocks=
true
,
router_track_output_blocks=
false
,
router_assume_kv_reuse=
true
,
router_track_prefill_tokens=
true
,
router_prefill_load_model=
"none"
,
router_snapshot_threshold=
1000000
,
router_reset_states=
false
,
router_ttl_secs=
120.0
,
router_max_tree_size=
1048576
,
router_prune_target_ratio=
0.8
,
router_queue_threshold=Some(
4.0
),
router_event_threads=
4
,
router_queue_policy=
"fcfs"
,
remote_indexer_component=None))]
#[allow(clippy::too_many_arguments)]
fn
new
(
overlap_score_weight
:
f64
,
...
...
@@ -70,6 +138,7 @@ impl KvRouterConfig {
router_track_output_blocks
:
bool
,
router_assume_kv_reuse
:
bool
,
router_track_prefill_tokens
:
bool
,
router_prefill_load_model
:
&
str
,
router_snapshot_threshold
:
Option
<
u32
>
,
router_reset_states
:
bool
,
router_ttl_secs
:
f64
,
...
...
@@ -91,6 +160,11 @@ impl KvRouterConfig {
router_track_output_blocks
,
router_assume_kv_reuse
,
router_track_prefill_tokens
,
router_prefill_load_model
:
router_prefill_load_model
.parse
::
<
RsRouterPrefillLoadModel
>
()
.unwrap_or_else
(|
_
|
{
panic!
(
"invalid router_prefill_load_model: {router_prefill_load_model:?}"
)
}),
router_snapshot_threshold
,
router_reset_states
,
router_ttl_secs
,
...
...
@@ -249,13 +323,14 @@ pub(crate) struct EntrypointArgs {
is_prefill
:
bool
,
migration_limit
:
u32
,
chat_engine_factory
:
Option
<
PyEngineFactory
>
,
aic_perf_config
:
Option
<
AicPerfConfig
>
,
}
#[pymethods]
impl
EntrypointArgs
{
#[allow(clippy::too_many_arguments)]
#[new]
#[pyo3(signature
=
(engine_type,
model_path=None,
model_name=None,
endpoint_id=None,
context_length=None,
template_file=None,
router_config=None,
kv_cache_block_size=None,
http_host=None,
http_port=None,
http_metrics_port=None,
tls_cert_path=None,
tls_key_path=None,
extra_engine_args=None,
mocker_engine_args=None,
runtime_config=None,
namespace=None,
namespace_prefix=None,
is_prefill=
false
,
migration_limit=
0
,
chat_engine_factory=None))]
#[pyo3(signature
=
(engine_type,
model_path=None,
model_name=None,
endpoint_id=None,
context_length=None,
template_file=None,
router_config=None,
kv_cache_block_size=None,
http_host=None,
http_port=None,
http_metrics_port=None,
tls_cert_path=None,
tls_key_path=None,
extra_engine_args=None,
mocker_engine_args=None,
runtime_config=None,
namespace=None,
namespace_prefix=None,
is_prefill=
false
,
migration_limit=
0
,
chat_engine_factory=None
,
aic_perf_config=None
))]
pub
fn
new
(
py
:
Python
<
'_
>
,
engine_type
:
EngineType
,
...
...
@@ -279,6 +354,7 @@ impl EntrypointArgs {
is_prefill
:
bool
,
migration_limit
:
u32
,
chat_engine_factory
:
Option
<
PyObject
>
,
aic_perf_config
:
Option
<
AicPerfConfig
>
,
)
->
PyResult
<
Self
>
{
let
endpoint_id_obj
:
Option
<
EndpointId
>
=
endpoint_id
.as_deref
()
.map
(
EndpointId
::
from
);
if
(
tls_cert_path
.is_some
()
&&
tls_key_path
.is_none
())
...
...
@@ -327,6 +403,7 @@ impl EntrypointArgs {
is_prefill
,
migration_limit
,
chat_engine_factory
,
aic_perf_config
,
})
}
}
...
...
@@ -467,9 +544,26 @@ async fn select_engine(
EngineType
::
Dynamic
=>
{
// Convert Python chat engine factory to Rust callback
let
chat_engine_factory
=
args
.chat_engine_factory
.map
(
py_engine_factory_to_callback
);
let
prefill_load_estimator
=
args
.aic_perf_config
.as_ref
()
.map
(|
config
|
{
Python
::
with_gil
(|
py
|
{
create_aic_prefill_load_estimator
(
py
,
config
.backend_name
(),
config
.system
(),
config
.model_path
(),
config
.tp_size
(),
config
.backend_version
(),
)
})
})
.transpose
()
?
;
RsEngineConfig
::
Dynamic
{
model
:
Box
::
new
(
local_model
),
chat_engine_factory
,
prefill_load_estimator
,
}
}
EngineType
::
Mocker
=>
{
...
...
lib/bindings/python/rust/llm/kv.rs
View file @
95a750f4
...
...
@@ -30,6 +30,9 @@ use llm_rs::protocols::common::timing::RequestTracker;
use
llm_rs
::
protocols
::
common
::{
OutputOptions
,
SamplingOptions
,
StopConditions
};
use
serde_json
::
json
;
use
super
::
aic_callback
::
create_aic_prefill_load_estimator
;
use
super
::
entrypoint
::
AicPerfConfig
;
/// Converts a Python object into a `Vec<Option<BlockExtraInfo>>` via `depythonize`.
///
/// Any conversion failure is turned into a `PyErr` with `to_pyerr`.
fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<BlockExtraInfo>>> {
    depythonize(obj).map_err(to_pyerr)
}
...
...
@@ -703,6 +706,7 @@ async fn create_kv_router_from_endpoint(
endpoint
:
&
Endpoint
,
block_size
:
usize
,
kv_router_config
:
Option
<
KvRouterConfig
>
,
prefill_load_estimator
:
Option
<
Arc
<
dyn
dynamo_kv_router
::
PrefillLoadEstimator
>>
,
)
->
Result
<
Arc
<
llm_rs
::
kv_router
::
KvRouter
>
,
PyErr
>
{
// Create ModelManager and use it to create KvRouter (ensures registration)
let
model_manager
=
Arc
::
new
(
llm_rs
::
discovery
::
ModelManager
::
new
());
...
...
@@ -766,6 +770,7 @@ async fn create_kv_router_from_endpoint(
&
endpoint
.inner
,
block_size
as
u32
,
kv_router_config
,
prefill_load_estimator
,
worker_type
,
model_name
,
enable_eagle
,
...
...
@@ -888,12 +893,29 @@ impl KvRouter {
/// Note: Worker type for Prometheus metrics is inferred from the endpoint name/component
/// (contains "prefill") or by `router_track_active_blocks` being disabled.
#[new]
#[pyo3(signature
=
(endpoint,
block_size,
kv_router_config))]
#[pyo3(signature
=
(endpoint,
block_size,
kv_router_config
,
aic_perf_config=None
))]
fn
new
(
endpoint
:
&
Endpoint
,
block_size
:
usize
,
kv_router_config
:
&
super
::
entrypoint
::
KvRouterConfig
,
aic_perf_config
:
Option
<&
AicPerfConfig
>
,
)
->
PyResult
<
Self
>
{
let
prefill_load_estimator
=
aic_perf_config
.map
(|
config
|
{
Python
::
with_gil
(|
py
|
{
create_aic_prefill_load_estimator
(
py
,
config
.backend_name
(),
config
.system
(),
config
.model_path
(),
config
.tp_size
(),
config
.backend_version
(),
)
})
})
.transpose
()
.map_err
(
to_pyerr
)
?
;
let
runtime
=
pyo3_async_runtimes
::
tokio
::
get_runtime
();
runtime
.block_on
(
async
move
{
let
client
=
endpoint
.inner
.client
()
.await
.map_err
(
to_pyerr
)
?
;
...
...
@@ -916,6 +938,7 @@ impl KvRouter {
endpoint
,
block_size
,
Some
(
kv_router_config
.inner
()),
prefill_load_estimator
,
)
.await
?
;
...
...
lib/bindings/python/rust/llm/replay.rs
View file @
95a750f4
...
...
@@ -19,8 +19,8 @@ use pythonize::pythonize;
use
serde_json
::
json
;
use
uuid
::
Uuid
;
use
super
::
aic_callback
::
create_aic_callback
;
use
super
::
entrypoint
::{
KvRouterConfig
,
to_pyerr
};
use
super
::
aic_callback
::
{
create_aic_callback
,
create_aic_prefill_load_estimator
}
;
use
super
::
entrypoint
::{
AicPerfConfig
,
KvRouterConfig
,
to_pyerr
};
fn
parse_mocker_engine_type
(
engine_type
:
&
str
)
->
PyResult
<
RsMockerEngineType
>
{
match
engine_type
{
...
...
@@ -526,7 +526,7 @@ impl MockEngineArgs {
}
#[pyfunction]
#[pyo3(signature
=
(trace_file,
extra_engine_args=None,
prefill_engine_args=None,
decode_engine_args=None,
router_config=None,
num_workers=
1
,
num_prefill_workers=
1
,
num_decode_workers=
1
,
replay_concurrency=None,
replay_mode=
"offline"
,
router_mode=
"round_robin"
,
arrival_speedup_ratio=
1.0
))]
#[pyo3(signature
=
(trace_file,
extra_engine_args=None,
prefill_engine_args=None,
decode_engine_args=None,
router_config=None,
aic_perf_config=None,
num_workers=
1
,
num_prefill_workers=
1
,
num_decode_workers=
1
,
replay_concurrency=None,
replay_mode=
"offline"
,
router_mode=
"round_robin"
,
arrival_speedup_ratio=
1.0
,
trace_block_size=
512
))]
#[allow(clippy::too_many_arguments)]
pub
fn
run_mocker_trace_replay
(
py
:
Python
<
'_
>
,
...
...
@@ -535,6 +535,7 @@ pub fn run_mocker_trace_replay(
prefill_engine_args
:
Option
<
MockEngineArgs
>
,
decode_engine_args
:
Option
<
MockEngineArgs
>
,
router_config
:
Option
<
KvRouterConfig
>
,
aic_perf_config
:
Option
<&
AicPerfConfig
>
,
num_workers
:
usize
,
num_prefill_workers
:
usize
,
num_decode_workers
:
usize
,
...
...
@@ -542,6 +543,7 @@ pub fn run_mocker_trace_replay(
replay_mode
:
&
str
,
router_mode
:
&
str
,
arrival_speedup_ratio
:
f64
,
trace_block_size
:
usize
,
)
->
PyResult
<
PyObject
>
{
let
args_selection
=
load_replay_args_selection
(
py
,
...
...
@@ -552,9 +554,15 @@ pub fn run_mocker_trace_replay(
num_prefill_workers
,
num_decode_workers
,
)
?
;
let
router_mode
=
parse_replay_router_mode
(
router_mode
)
?
;
let
prefill_load_estimator
=
load_replay_prefill_load_estimator
(
py
,
router_mode
,
router_config
.as_ref
(),
aic_perf_config
,
)
?
;
let
router_config
=
load_replay_router_config
(
router_config
);
let
replay_mode
=
replay_mode
.to_owned
();
let
router_mode
=
parse_replay_router_mode
(
router_mode
)
?
;
let
report
=
py
.allow_threads
(
move
||
{
let
replay_concurrency
=
parse_replay_concurrency
(
replay_concurrency
)
?
;
...
...
@@ -565,7 +573,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_file_with_router_mode
(
*
args
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
&
trace_file
,
trace_block_size
,
max_in_flight
,
num_workers
,
router_mode
,
...
...
@@ -575,7 +585,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_file_with_router_mode
(
*
args
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
&
trace_file
,
trace_block_size
,
num_workers
,
arrival_speedup_ratio
,
router_mode
,
...
...
@@ -585,7 +597,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_live_file_with_router_mode
(
*
args
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
&
trace_file
,
trace_block_size
,
max_in_flight
,
num_workers
,
router_mode
,
...
...
@@ -595,7 +609,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_live_file_with_router_mode
(
*
args
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
&
trace_file
,
trace_block_size
,
num_workers
,
arrival_speedup_ratio
,
router_mode
,
...
...
@@ -613,7 +629,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_file_disagg_with_router_mode
(
*
config
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
&
trace_file
,
trace_block_size
,
max_in_flight
,
router_mode
,
)
...
...
@@ -622,7 +640,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_file_disagg_with_router_mode
(
*
config
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
&
trace_file
,
trace_block_size
,
arrival_speedup_ratio
,
router_mode
,
)
...
...
@@ -642,7 +662,7 @@ pub fn run_mocker_trace_replay(
}
#[pyfunction]
#[pyo3(signature
=
(input_tokens,
output_tokens,
request_count,
extra_engine_args=None,
prefill_engine_args=None,
decode_engine_args=None,
router_config=None,
num_workers=
1
,
num_prefill_workers=
1
,
num_decode_workers=
1
,
replay_concurrency=None,
replay_mode=
"offline"
,
router_mode=
"round_robin"
,
arrival_speedup_ratio=
1.0
,
arrival_interval_ms=
1.0
,
turns_per_session=
1
,
shared_prefix_ratio=
0.0
,
num_prefix_groups=
0
,
inter_turn_delay_ms=
0.0
))]
#[pyo3(signature
=
(input_tokens,
output_tokens,
request_count,
extra_engine_args=None,
prefill_engine_args=None,
decode_engine_args=None,
router_config=None,
aic_perf_config=None,
num_workers=
1
,
num_prefill_workers=
1
,
num_decode_workers=
1
,
replay_concurrency=None,
replay_mode=
"offline"
,
router_mode=
"round_robin"
,
arrival_speedup_ratio=
1.0
,
arrival_interval_ms=
1.0
,
turns_per_session=
1
,
shared_prefix_ratio=
0.0
,
num_prefix_groups=
0
,
inter_turn_delay_ms=
0.0
))]
#[allow(clippy::too_many_arguments)]
pub
fn
run_mocker_synthetic_trace_replay
(
py
:
Python
<
'_
>
,
...
...
@@ -653,6 +673,7 @@ pub fn run_mocker_synthetic_trace_replay(
prefill_engine_args
:
Option
<
MockEngineArgs
>
,
decode_engine_args
:
Option
<
MockEngineArgs
>
,
router_config
:
Option
<
KvRouterConfig
>
,
aic_perf_config
:
Option
<&
AicPerfConfig
>
,
num_workers
:
usize
,
num_prefill_workers
:
usize
,
num_decode_workers
:
usize
,
...
...
@@ -675,9 +696,15 @@ pub fn run_mocker_synthetic_trace_replay(
num_prefill_workers
,
num_decode_workers
,
)
?
;
let
router_mode
=
parse_replay_router_mode
(
router_mode
)
?
;
let
prefill_load_estimator
=
load_replay_prefill_load_estimator
(
py
,
router_mode
,
router_config
.as_ref
(),
aic_perf_config
,
)
?
;
let
router_config
=
load_replay_router_config
(
router_config
);
let
replay_mode
=
replay_mode
.to_owned
();
let
router_mode
=
parse_replay_router_mode
(
router_mode
)
?
;
let
block_size
=
match
&
args_selection
{
ReplayArgsSelection
::
Aggregated
(
args
)
=>
args
.block_size
.max
(
1
),
ReplayArgsSelection
::
Disagg
(
config
)
=>
config
.prefill_args.block_size
.max
(
1
),
...
...
@@ -712,6 +739,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_workload_with_router_mode
(
*
args
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
trace
,
max_in_flight
,
num_workers
,
...
...
@@ -722,6 +750,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_workload_with_router_mode
(
*
args
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
trace
,
num_workers
,
router_mode
,
...
...
@@ -731,6 +760,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_live_workload_with_router_mode
(
*
args
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
trace
,
max_in_flight
,
num_workers
,
...
...
@@ -741,6 +771,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_live_workload_with_router_mode
(
*
args
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
trace
,
num_workers
,
router_mode
,
...
...
@@ -756,6 +787,7 @@ pub fn run_mocker_synthetic_trace_replay(
(
"offline"
,
Some
(
max_in_flight
))
=>
dynamo_mocker
::
replay
::
simulate_concurrency_workload_disagg_with_router_mode
(
*
config
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
trace
,
max_in_flight
,
router_mode
,
...
...
@@ -763,6 +795,7 @@ pub fn run_mocker_synthetic_trace_replay(
(
"offline"
,
None
)
=>
dynamo_mocker
::
replay
::
simulate_trace_workload_disagg_with_router_mode
(
*
config
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
trace
,
router_mode
,
),
...
...
@@ -793,6 +826,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_requests_with_router_mode
(
*
args
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
requests
,
max_in_flight
,
num_workers
,
...
...
@@ -802,6 +836,7 @@ pub fn run_mocker_synthetic_trace_replay(
(
"offline"
,
None
)
=>
dynamo_mocker
::
replay
::
simulate_trace_requests_with_router_mode
(
*
args
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
requests
,
num_workers
,
arrival_speedup_ratio
,
...
...
@@ -811,6 +846,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_live_requests_with_router_mode
(
*
args
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
requests
,
max_in_flight
,
num_workers
,
...
...
@@ -821,6 +857,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_live_requests_with_router_mode
(
*
args
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
requests
,
num_workers
,
arrival_speedup_ratio
,
...
...
@@ -838,6 +875,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_requests_disagg_with_router_mode
(
*
config
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
requests
,
max_in_flight
,
router_mode
,
...
...
@@ -847,6 +885,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_requests_disagg_with_router_mode
(
*
config
,
router_config
.clone
(),
prefill_load_estimator
.clone
(),
requests
,
arrival_speedup_ratio
,
router_mode
,
...
...
@@ -970,6 +1009,57 @@ fn load_replay_router_config(
router_config
.map
(|
config
|
config
.inner
())
}
/// Decides whether a replay run needs an AIC prefill-load estimator and builds it if so.
///
/// An estimator is only created when `router_mode` is `KvRouter`, a `router_config` is
/// supplied, and that config's `router_prefill_load_model` reports `is_enabled()`.
/// In every other case the result is `Ok(None)` — unless the caller still passed an
/// `aic_perf_config`, which is rejected as a configuration mistake.
///
/// # Errors
///
/// * `aic_perf_config` was given but one of the prerequisites above is not met.
/// * The prefill load model is enabled but no `aic_perf_config` was given.
/// * Creating the Python-backed estimator fails.
fn load_replay_prefill_load_estimator(
    py: Python<'_>,
    router_mode: dynamo_mocker::replay::ReplayRouterMode,
    router_config: Option<&KvRouterConfig>,
    aic_perf_config: Option<&AicPerfConfig>,
) -> PyResult<Option<dynamo_mocker::replay::ReplayPrefillLoadEstimator>> {
    // First prerequisite that is not satisfied, phrased as the error to report when an
    // `aic_perf_config` was supplied anyway. `None` means the estimator is required.
    let unmet_requirement: Option<&'static str> =
        if router_mode != dynamo_mocker::replay::ReplayRouterMode::KvRouter {
            Some("aic_perf_config requires router_mode='kv_router'")
        } else {
            match router_config.map(|config| config.inner()) {
                None => Some(
                    "aic_perf_config requires router_config with router_prefill_load_model='aic'",
                ),
                Some(inner) if !inner.router_prefill_load_model.is_enabled() => {
                    Some("aic_perf_config requires router_prefill_load_model='aic'")
                }
                Some(_) => None,
            }
        };

    match (unmet_requirement, aic_perf_config) {
        // Estimator not applicable, yet a perf config was passed: reject.
        (Some(message), Some(_)) => Err(PyException::new_err(message)),
        // Estimator not applicable and nothing was passed: nothing to build.
        (Some(_), None) => Ok(None),
        // Estimator required but the perf config is missing.
        (None, None) => Err(PyException::new_err(
            "router_prefill_load_model='aic' requires aic_perf_config",
        )),
        (None, Some(perf)) => create_aic_prefill_load_estimator(
            py,
            perf.backend_name(),
            perf.system(),
            perf.model_path(),
            perf.tp_size(),
            perf.backend_version(),
        )
        .map(Some),
    }
}
fn
parse_replay_router_mode
(
router_mode
:
&
str
,
)
->
PyResult
<
dynamo_mocker
::
replay
::
ReplayRouterMode
>
{
...
...
lib/bindings/python/src/dynamo/_core.pyi
View file @
95a750f4
...
...
@@ -1159,6 +1159,17 @@ class RouterConfig:
"""
...
class AicPerfConfig:
    """Settings used to create an AIC perf-model session."""

    def __init__(
        self,
        aic_backend: str,
        aic_system: str,
        aic_model_path: str,
        aic_tp_size: int = 1,
        aic_backend_version: Optional[str] = None,
    ) -> None:
        """
        Create an AicPerfConfig.

        Args:
            aic_backend: Backend name; must be non-empty.
            aic_system: System identifier; must be non-empty.
            aic_model_path: Model path; must be non-empty.
            aic_tp_size: Tensor-parallel size; must be >= 1 (default: 1).
            aic_backend_version: Optional backend version (default: None).

        Raises:
            ValueError: If a required string is empty or aic_tp_size is 0.
        """
        ...
class KvRouterConfig:
"""Values for KV router"""
...
...
@@ -1172,6 +1183,8 @@ class KvRouterConfig:
router_track_active_blocks: bool = True,
router_track_output_blocks: bool = False,
router_assume_kv_reuse: bool = True,
router_track_prefill_tokens: bool = True,
router_prefill_load_model: str = "none",
router_snapshot_threshold: Optional[int] = 1000000,
router_reset_states: bool = False,
router_ttl_secs: float = 120.0,
...
...
@@ -1199,6 +1212,10 @@ class KvRouterConfig:
sequence length (agent_hints.osl in nvext).
router_assume_kv_reuse: Assume KV cache reuse when tracking active blocks (default: True).
When True, computes actual block hashes. When False, generates random hashes.
router_track_prefill_tokens: Include prompt-side prefill tokens in active load accounting (default: True).
router_prefill_load_model: Prompt-side prefill load model (default: "none").
"none" keeps static prompt load accounting.
"aic" decays the oldest active prefill request using AIC-predicted duration.
router_snapshot_threshold: Number of messages before snapshot (default: 1000000)
router_reset_states: Reset router state on startup (default: False)
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
...
...
@@ -1516,6 +1533,7 @@ def run_mocker_trace_replay(
prefill_engine_args: Optional[MockEngineArgs] = None,
decode_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
aic_perf_config: Optional[AicPerfConfig] = None,
num_workers: int = 1,
num_prefill_workers: int = 1,
num_decode_workers: int = 1,
...
...
@@ -1523,6 +1541,7 @@ def run_mocker_trace_replay(
replay_mode: Literal["offline", "online"] = "offline",
router_mode: Literal["round_robin", "kv_router"] = "round_robin",
arrival_speedup_ratio: float = 1.0,
trace_block_size: int = 512,
) -> Dict[str, Any]:
"""Replay a mocker trace file and return the simulation report for aggregated vLLM or SGLang configs."""
...
...
...
@@ -1535,6 +1554,7 @@ def run_mocker_synthetic_trace_replay(
prefill_engine_args: Optional[MockEngineArgs] = None,
decode_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
aic_perf_config: Optional[AicPerfConfig] = None,
num_workers: int = 1,
num_prefill_workers: int = 1,
num_decode_workers: int = 1,
...
...
@@ -1779,6 +1799,7 @@ class KvRouter:
endpoint: Endpoint,
block_size: int,
kv_router_config: KvRouterConfig,
aic_perf_config: Optional[AicPerfConfig] = None,
) -> None:
"""
Create a new KvRouter instance.
...
...
@@ -1787,6 +1808,7 @@ class KvRouter:
endpoint: The endpoint to connect to for routing requests
block_size: The KV cache block size
kv_router_config: Configuration for the KV router
aic_perf_config: Optional AIC perf-model config for effective prefill load tracking
"""
...
...
...
@@ -1998,6 +2020,7 @@ class EntrypointArgs:
is_prefill: bool = False,
migration_limit: int = 0,
chat_engine_factory: Optional[Callable] = None,
aic_perf_config: Optional[AicPerfConfig] = None,
) -> None:
"""
Create EntrypointArgs.
...
...
@@ -2024,6 +2047,7 @@ class EntrypointArgs:
is_prefill: Whether this is a prefill worker
migration_limit: Maximum number of request migrations (0=disabled)
chat_engine_factory: Optional Python chat completions engine factory callback
aic_perf_config: Optional AIC perf-model configuration for default KV routing
"""
...
...
...
lib/bindings/python/src/dynamo/_internal/aic.py
0 → 100644
View file @
95a750f4
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared AIC session helpers used by internal Dynamo integrations."""
from
__future__
import
annotations
import
logging
logger
=
logging
.
getLogger
(
__name__
)
DEFAULT_BACKEND_VERSIONS
=
{
"vllm"
:
"0.12.0"
,
"sglang"
:
"0.5.6.post2"
,
}
DEFAULT_STATIC_STRIDE
=
32
def resolve_backend_version(backend_name: str, backend_version: str | None) -> str:
    """Pick the backend version string to use for AIC perf lookups.

    An explicitly supplied ``backend_version`` always wins. Otherwise the
    pinned default for ``backend_name`` is returned, and a backend without a
    pinned default falls back to the vLLM entry.
    """
    if backend_version is None:
        vllm_default = DEFAULT_BACKEND_VERSIONS["vllm"]
        return DEFAULT_BACKEND_VERSIONS.get(backend_name, vllm_default)
    return backend_version
def
_load_aiconfigurator
():
try
:
from
aiconfigurator.sdk
import
config
from
aiconfigurator.sdk.backends.factory
import
get_backend
from
aiconfigurator.sdk.inference_session
import
InferenceSession
from
aiconfigurator.sdk.models
import
get_model
from
aiconfigurator.sdk.perf_database
import
(
get_database
,
get_supported_databases
,
)
except
(
ImportError
)
as
exc
:
# pragma: no cover - exercised in integration environments
raise
RuntimeError
(
"aiconfigurator is required for AIC perf modeling but is not installed"
)
from
exc
return
{
"config"
:
config
,
"get_backend"
:
get_backend
,
"InferenceSession"
:
InferenceSession
,
"get_model"
:
get_model
,
"get_database"
:
get_database
,
"get_supported_databases"
:
get_supported_databases
,
}
class AicSession:
    """AIC ``InferenceSession`` wrapper exposing direct prefill and decode latency predictors."""

    def __init__(
        self,
        backend_name: str,
        system: str,
        model_path: str,
        tp_size: int,
        backend_version: str | None = None,
        moe_tp_size: int | None = None,
        moe_ep_size: int | None = None,
        attention_dp_size: int | None = None,
    ):
        """Load the perf database and model, then build the inference session.

        Raises:
            RuntimeError: if aiconfigurator cannot be imported, or if no perf
                database is found for the system/backend/version combination.
        """
        sdk = _load_aiconfigurator()
        version = resolve_backend_version(backend_name, backend_version)

        database = sdk["get_database"](
            system=system, backend=backend_name, version=version
        )
        if database is None:
            # Report which versions do exist for this system/backend so the
            # error tells the caller what they could have asked for.
            known_versions = (
                sdk["get_supported_databases"]().get(system, {}).get(backend_name, [])
            )
            if known_versions:
                supported_versions = ", ".join(known_versions)
            else:
                supported_versions = "<none>"
            raise RuntimeError(
                "AIC perf database not found for "
                f"system={system!r}, backend={backend_name!r}, version={version!r}. "
                f"Supported versions for this system/backend: {supported_versions}"
            )

        model = sdk["get_model"](
            model_path=model_path,
            model_config=sdk["config"].ModelConfig(
                tp_size=tp_size,
                moe_tp_size=moe_tp_size,
                moe_ep_size=moe_ep_size,
                # A missing (or zero) attention DP size is forwarded as 1.
                attention_dp_size=attention_dp_size or 1,
            ),
            backend_name=backend_name,
        )
        backend = sdk["get_backend"](backend_name)

        self._session = sdk["InferenceSession"](
            model=model, database=database, backend=backend
        )
        self._database = database
        self._model = model
        # Prefer the name reported by the loaded model; otherwise use the path.
        self._model_name = getattr(model, "model_name", None) or model_path

        logger.info(
            "AIC session initialized: backend=%s, system=%s, model=%s, tp=%d",
            backend_name,
            system,
            model_path,
            tp_size,
        )

    def _predict_context_latency(
        self, batch_size: int, effective_isl: int, prefix: int
    ) -> float:
        """Add up the queried latency of every context op of the model.

        Raises:
            ValueError: if ``effective_isl`` is not positive.
        """
        if effective_isl <= 0:
            raise ValueError(
                f"effective_isl must be positive, got effective_isl={effective_isl}"
            )

        accumulated = 0.0
        for context_op in self._model.context_ops:
            if "logits_gemm" in getattr(context_op, "_name", ""):
                # Ops whose name contains "logits_gemm" are queried with the
                # batch size alone rather than batch size times sequence length.
                problem_size = batch_size
            else:
                problem_size = batch_size * effective_isl
            op_latency = context_op.query(
                self._database,
                x=problem_size,
                batch_size=batch_size,
                beam_width=1,
                s=effective_isl,
                prefix=prefix,
                model_name=self._model_name,
                seq_imbalance_correction_scale=1.0,
            )
            accumulated += float(op_latency)
        return accumulated

    def _predict_generation_latency(self, batch_size: int, isl: int, osl: int) -> float:
        """Estimate total generation latency by sampling every ``DEFAULT_STATIC_STRIDE`` steps.

        Each sampled step's latency is reused for the steps up to the next
        sample (or up to the last step). Returns ``0.0`` when ``osl <= 1``.
        """
        if osl <= 1:
            return 0.0

        # NOTE(review): `_nextn` is read from the model as-is; it presumably
        # counts extra tokens produced per step — confirm against aiconfigurator.
        effective_batch_size = batch_size * (self._model._nextn + 1)
        decode_steps = osl - 1

        accumulated = 0.0
        for first_step in range(0, decode_steps, DEFAULT_STATIC_STRIDE):
            sampled_latency = 0.0
            for generation_op in self._model.generation_ops:
                op_latency = generation_op.query(
                    self._database,
                    x=effective_batch_size,
                    batch_size=effective_batch_size,
                    beam_width=1,
                    s=isl + first_step + 1,
                    model_name=self._model_name,
                    gen_seq_imbalance_correction_scale=1.0,
                )
                sampled_latency += float(op_latency)
            # The final window may be shorter than a full stride.
            steps_covered = min(DEFAULT_STATIC_STRIDE, decode_steps - first_step)
            accumulated += sampled_latency * steps_covered
        return accumulated

    def predict_prefill(self, batch_size: int, effective_isl: int, prefix: int) -> float:
        """Predict prefill latency in ms from uncached tokens and cached prefix."""
        return self._predict_context_latency(batch_size, effective_isl, prefix)

    def predict_decode(self, batch_size: int, isl: int, osl: int) -> float:
        """Predict decode (generation) latency in ms."""
        return self._predict_generation_latency(batch_size, isl, osl)
def create_session(
    backend_name: str,
    system: str,
    model_path: str,
    tp_size: int,
    backend_version: str | None = None,
    moe_tp_size: int | None = None,
    moe_ep_size: int | None = None,
    attention_dp_size: int | None = None,
) -> AicSession:
    """Build an ``AicSession``; this is the factory the Rust side calls through PyO3.

    All arguments are forwarded unchanged to the ``AicSession`` constructor.
    """
    return AicSession(
        backend_name=backend_name,
        system=system,
        model_path=model_path,
        tp_size=tp_size,
        backend_version=backend_version,
        moe_tp_size=moe_tp_size,
        moe_ep_size=moe_ep_size,
        attention_dp_size=attention_dp_size,
    )
lib/bindings/python/src/dynamo/llm/__init__.py
View file @
95a750f4
...
...
@@ -5,6 +5,7 @@
import
logging
from
dynamo._core
import
AicPerfConfig
as
AicPerfConfig
from
dynamo._core
import
EngineType
from
dynamo._core
import
EntrypointArgs
as
EntrypointArgs
from
dynamo._core
import
FpmEventRelay
as
FpmEventRelay
...
...
@@ -57,6 +58,7 @@ def run_mocker_trace_replay(
replay_concurrency
=
None
,
router_mode
=
"round_robin"
,
arrival_speedup_ratio
=
1.0
,
trace_block_size
=
512
,
):
return
_run_mocker_trace_replay
(
trace_file
,
...
...
@@ -67,4 +69,5 @@ def run_mocker_trace_replay(
replay_mode
=
"offline"
,
router_mode
=
router_mode
,
arrival_speedup_ratio
=
arrival_speedup_ratio
,
trace_block_size
=
trace_block_size
,
)
lib/bindings/python/src/dynamo/replay/api.py
View file @
95a750f4
...
...
@@ -14,6 +14,7 @@ def run_trace_replay(
prefill_engine_args
=
None
,
decode_engine_args
=
None
,
router_config
=
None
,
aic_perf_config
=
None
,
num_workers
=
1
,
num_prefill_workers
=
1
,
num_decode_workers
=
1
,
...
...
@@ -21,6 +22,7 @@ def run_trace_replay(
replay_mode
=
"offline"
,
router_mode
=
"round_robin"
,
arrival_speedup_ratio
=
1.0
,
trace_block_size
=
512
,
):
return
_run_mocker_trace_replay
(
trace_file
,
...
...
@@ -28,6 +30,7 @@ def run_trace_replay(
prefill_engine_args
=
prefill_engine_args
,
decode_engine_args
=
decode_engine_args
,
router_config
=
router_config
,
aic_perf_config
=
aic_perf_config
,
num_workers
=
num_workers
,
num_prefill_workers
=
num_prefill_workers
,
num_decode_workers
=
num_decode_workers
,
...
...
@@ -35,6 +38,7 @@ def run_trace_replay(
replay_mode
=
replay_mode
,
router_mode
=
router_mode
,
arrival_speedup_ratio
=
arrival_speedup_ratio
,
trace_block_size
=
trace_block_size
,
)
...
...
@@ -47,6 +51,7 @@ def run_synthetic_trace_replay(
prefill_engine_args
=
None
,
decode_engine_args
=
None
,
router_config
=
None
,
aic_perf_config
=
None
,
num_workers
=
1
,
num_prefill_workers
=
1
,
num_decode_workers
=
1
,
...
...
@@ -68,6 +73,7 @@ def run_synthetic_trace_replay(
prefill_engine_args
=
prefill_engine_args
,
decode_engine_args
=
decode_engine_args
,
router_config
=
router_config
,
aic_perf_config
=
aic_perf_config
,
num_workers
=
num_workers
,
num_prefill_workers
=
num_prefill_workers
,
num_decode_workers
=
num_decode_workers
,
...
...
lib/bindings/python/src/dynamo/replay/main.py
View file @
95a750f4
...
...
@@ -15,7 +15,7 @@ from typing import Protocol
os
.
environ
.
setdefault
(
"DYNAMO_SKIP_PYTHON_LOG_INIT"
,
"1"
)
from
dynamo.llm
import
KvRouterConfig
,
MockEngineArgs
from
dynamo.llm
import
AicPerfConfig
,
KvRouterConfig
,
MockEngineArgs
from
dynamo.replay
import
run_synthetic_trace_replay
,
run_trace_replay
from
dynamo.replay.reporting
import
format_report_table
,
write_report_json
...
...
@@ -72,6 +72,35 @@ def _load_engine_args(raw_args: str | None):
return
MockEngineArgs
.
from_json
(
json
.
dumps
(
raw
))
def
_load_aic_perf_config
(
args
:
argparse
.
Namespace
):
values
=
{
"aic_backend"
:
args
.
aic_backend
,
"aic_system"
:
args
.
aic_system
,
"aic_model_path"
:
args
.
aic_model_path
,
"aic_backend_version"
:
args
.
aic_backend_version
,
"aic_tp_size"
:
args
.
aic_tp_size
,
}
if
not
any
(
value
is
not
None
for
value
in
values
.
values
()):
return
None
missing
=
[
name
for
name
in
(
"aic_backend"
,
"aic_system"
,
"aic_model_path"
)
if
values
[
name
]
is
None
]
if
missing
:
missing_flags
=
", "
.
join
(
f
"--
{
name
.
replace
(
'_'
,
'-'
)
}
"
for
name
in
missing
)
raise
ValueError
(
f
"AIC replay modeling requires
{
missing_flags
}
"
)
return
AicPerfConfig
(
aic_backend
=
values
[
"aic_backend"
],
aic_system
=
values
[
"aic_system"
],
aic_model_path
=
values
[
"aic_model_path"
],
aic_tp_size
=
values
[
"aic_tp_size"
]
or
1
,
aic_backend_version
=
values
[
"aic_backend_version"
],
)
def
main
(
argv
:
Sequence
[
str
]
|
None
=
None
)
->
int
:
parser
=
argparse
.
ArgumentParser
(
prog
=
"python -m dynamo.replay"
)
parser
.
add_argument
(
"trace_file"
,
nargs
=
"?"
)
...
...
@@ -79,6 +108,11 @@ def main(argv: Sequence[str] | None = None) -> int:
parser
.
add_argument
(
"--prefill-engine-args"
)
parser
.
add_argument
(
"--decode-engine-args"
)
parser
.
add_argument
(
"--router-config"
)
parser
.
add_argument
(
"--aic-backend"
)
parser
.
add_argument
(
"--aic-system"
)
parser
.
add_argument
(
"--aic-backend-version"
)
parser
.
add_argument
(
"--aic-tp-size"
,
type
=
int
)
parser
.
add_argument
(
"--aic-model-path"
)
parser
.
add_argument
(
"--input-tokens"
,
type
=
int
)
parser
.
add_argument
(
"--output-tokens"
,
type
=
int
)
parser
.
add_argument
(
...
...
@@ -106,6 +140,12 @@ def main(argv: Sequence[str] | None = None) -> int:
default
=
"round_robin"
,
)
parser
.
add_argument
(
"--arrival-speedup-ratio"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--trace-block-size"
,
type
=
int
,
default
=
512
,
help
=
"tokens represented by each hash_id in the trace file; only used for file replay"
,
)
parser
.
add_argument
(
"--report-json"
,
help
=
"path to save the full replay report JSON; defaults to a timestamped file in the current directory"
,
...
...
@@ -140,6 +180,10 @@ def main(argv: Sequence[str] | None = None) -> int:
if
args
.
router_config
is
not
None
else
None
)
try
:
aic_perf_config
=
_load_aic_perf_config
(
args
)
except
ValueError
as
exc
:
parser
.
error
(
str
(
exc
))
if
using_trace_file
:
report
=
run_trace_replay
(
...
...
@@ -148,6 +192,7 @@ def main(argv: Sequence[str] | None = None) -> int:
prefill_engine_args
=
prefill_engine_args
,
decode_engine_args
=
decode_engine_args
,
router_config
=
router_config
,
aic_perf_config
=
aic_perf_config
,
num_workers
=
args
.
num_workers
,
num_prefill_workers
=
args
.
num_prefill_workers
,
num_decode_workers
=
args
.
num_decode_workers
,
...
...
@@ -155,6 +200,7 @@ def main(argv: Sequence[str] | None = None) -> int:
replay_mode
=
args
.
replay_mode
,
router_mode
=
args
.
router_mode
,
arrival_speedup_ratio
=
args
.
arrival_speedup_ratio
,
trace_block_size
=
args
.
trace_block_size
,
)
else
:
report
=
run_synthetic_trace_replay
(
...
...
@@ -165,6 +211,7 @@ def main(argv: Sequence[str] | None = None) -> int:
prefill_engine_args
=
prefill_engine_args
,
decode_engine_args
=
decode_engine_args
,
router_config
=
router_config
,
aic_perf_config
=
aic_perf_config
,
num_workers
=
args
.
num_workers
,
num_prefill_workers
=
args
.
num_prefill_workers
,
num_decode_workers
=
args
.
num_decode_workers
,
...
...
lib/bindings/python/tests/replay/test_replay_smoke.py
View file @
95a750f4
...
...
@@ -125,6 +125,32 @@ def test_run_trace_replay_supports_multiturn_sessions(tmp_path, replay_mode):
)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_supports_distinct_trace_and_engine_block_sizes(
    tmp_path, replay_mode
):
    """Replay a one-request trace with an explicit ``trace_block_size`` of 512.

    The trace holds a single request (128 input tokens, 2 output tokens, one
    hash id); the report must show exactly those counts in both replay modes.
    """
    single_request_line = (
        '{"timestamp":1000.0,"input_length":128,"output_length":2,"hash_ids":[101]}\n'
    )
    trace_path = tmp_path / "trace_block_size_split.jsonl"
    trace_path.write_text(single_request_line, encoding="utf-8")

    replay_kwargs = dict(
        extra_engine_args=_vllm_args(),
        num_workers=1,
        replay_mode=replay_mode,
        trace_block_size=512,
    )
    report = run_trace_replay(trace_path, **replay_kwargs)

    _assert_basic_report_counts(
        report,
        num_requests=1,
        input_tokens=128,
        output_tokens=2,
    )
@
pytest
.
mark
.
parametrize
(
"engine_type"
,
[
"vllm"
,
"sglang"
])
@
pytest
.
mark
.
parametrize
(
"replay_mode"
,
[
"offline"
,
"online"
])
@
pytest
.
mark
.
parametrize
(
"router_mode"
,
[
"round_robin"
,
"kv_router"
])
...
...
lib/kv-router/src/lib.rs
View file @
95a750f4
...
...
@@ -40,7 +40,7 @@ pub use self::multi_worker_sequence::{
pub
use
self
::
sequence
::{
ActiveSequences
,
RequestId
};
pub
use
concurrent_radix_tree
::
ConcurrentRadixTree
;
pub
use
concurrent_radix_tree_compressed
::
ConcurrentRadixTreeCompressed
;
pub
use
config
::{
KvRouterConfig
,
RouterConfigOverride
,
RouterQueuePolicy
};
pub
use
config
::{
KvRouterConfig
,
RouterConfigOverride
,
RouterPrefillLoadModel
,
RouterQueuePolicy
};
pub
use
indexer
::{
MaybeError
,
SyncIndexer
,
ThreadPoolIndexer
};
pub
use
nested_map
::
PositionalIndexer
;
pub
use
protocols
::{
...
...
@@ -50,6 +50,7 @@ pub use protocols::{
pub
use
queue
::
SchedulerQueue
;
pub
use
radix_tree
::
RadixTree
;
pub
use
scheduling
::
LocalScheduler
;
pub
use
scheduling
::
PrefillLoadEstimator
;
pub
use
scheduling
::
policy
::{
FcfsPolicy
,
RouterSchedulingPolicy
,
SchedulingPolicy
,
WsptPolicy
};
pub
use
scheduling
::{
KvSchedulerError
,
PotentialLoad
,
SchedulingRequest
,
SchedulingResponse
};
pub
use
selector
::{
DefaultWorkerSelector
,
WorkerSelector
};
lib/kv-router/src/protocols.rs
View file @
95a750f4
...
...
@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use
std
::
future
::
Future
;
use
std
::
time
::
Duration
;
use
dynamo_tokens
::{
SequenceHash
,
Token
};
use
rustc_hash
::
FxHashMap
;
...
...
@@ -429,6 +430,12 @@ pub struct ActiveSequenceEvent {
pub
lora_name
:
Option
<
String
>
,
}
/// Prefill-load information that can accompany a request when it is added to
/// the active sequences (see the optional `prefill_load_hint` field on
/// `ActiveSequenceEventData::AddRequest`).
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrefillLoadHint {
    /// Effective prefill token count at the time the request is added.
    // NOTE(review): presumably the uncached prompt tokens (ISL minus the
    // cached prefix) — confirm against the code that builds the hint.
    pub initial_effective_prefill_tokens: usize,
    /// Expected duration of the prefill, when a prediction is available.
    pub expected_prefill_duration: Option<Duration>,
}
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
pub
enum
ActiveSequenceEventData
{
AddRequest
{
...
...
@@ -438,6 +445,8 @@ pub enum ActiveSequenceEventData {
#[serde(default
=
"default_track_prefill_tokens"
)]
track_prefill_tokens
:
bool
,
expected_output_tokens
:
Option
<
u32
>
,
#[serde(default)]
prefill_load_hint
:
Option
<
PrefillLoadHint
>
,
},
Free
,
MarkPrefillCompleted
,
...
...
lib/kv-router/src/scheduling/config.rs
View file @
95a750f4
...
...
@@ -4,6 +4,7 @@
use
std
::
env
::{
self
,
VarError
};
use
std
::
fmt
;
use
std
::
str
::
FromStr
;
use
std
::
time
::
Duration
;
use
derive_builder
::
Builder
;
use
rand
::
Rng
;
...
...
@@ -53,6 +54,43 @@ impl fmt::Display for RouterQueuePolicy {
}
}
/// Selects the model used to estimate prefill load in the router.
///
/// Serialized in lowercase (`"none"` / `"aic"`) because of `rename_all`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterPrefillLoadModel {
    /// No prefill load model; this is the `Default` variant.
    #[default]
    None,
    /// Use the AIC model.
    // NOTE(review): presumably backed by the AIC perf-model session on the
    // Python side — confirm where the estimator is wired up.
    Aic,
}
impl fmt::Display for RouterPrefillLoadModel {
    /// Writes the lowercase name of the variant (`none` or `aic`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::None => "none",
            Self::Aic => "aic",
        };
        f.write_str(label)
    }
}
impl FromStr for RouterPrefillLoadModel {
    type Err = String;

    /// Parses `"none"` or `"aic"` (exact, case-sensitive match).
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input for any other string.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let model = match s {
            "none" => Self::None,
            "aic" => Self::Aic,
            _ => {
                return Err(format!(
                    "unknown prefill load model: {s:?}, expected 'none' or 'aic'"
                ));
            }
        };
        Ok(model)
    }
}
impl RouterPrefillLoadModel {
    /// Returns `true` for every variant except [`RouterPrefillLoadModel::None`].
    pub fn is_enabled(self) -> bool {
        match self {
            Self::None => false,
            Self::Aic => true,
        }
    }
}
impl
FromStr
for
RouterQueuePolicy
{
type
Err
=
String
;
...
...
@@ -124,6 +162,9 @@ pub struct KvRouterConfig {
#[serde(default
=
"default_track_prefill_tokens"
)]
pub
router_track_prefill_tokens
:
bool
,
/// Optional model for estimating effective prompt-side prefill load over time.
pub
router_prefill_load_model
:
RouterPrefillLoadModel
,
/// Threshold for triggering snapshots. If None, no snapshots will be performed.
#[validate(range(min
=
1
))]
pub
router_snapshot_threshold
:
Option
<
u32
>
,
...
...
@@ -183,6 +224,7 @@ impl Default for KvRouterConfig {
router_track_output_blocks
:
false
,
router_assume_kv_reuse
:
true
,
router_track_prefill_tokens
:
default_track_prefill_tokens
(),
router_prefill_load_model
:
RouterPrefillLoadModel
::
default
(),
router_snapshot_threshold
:
Some
(
1000000
),
router_reset_states
:
false
,
router_ttl_secs
:
120.0
,
...
...
@@ -214,10 +256,33 @@ fn validate_kv_router_config(config: &KvRouterConfig) -> Result<(), ValidationEr
"router_track_output_blocks requires router_track_active_blocks=true"
,
));
}
if
config
.router_prefill_load_model
.is_enabled
()
&&
!
config
.router_track_prefill_tokens
{
return
Err
(
ValidationError
::
new
(
"router_prefill_load_model requires router_track_prefill_tokens=true"
,
));
}
if
config
.router_prefill_load_model
.is_enabled
()
&&
!
matches!
(
config
.router_queue_policy
,
RouterQueuePolicy
::
Fcfs
)
{
return
Err
(
ValidationError
::
new
(
"router_prefill_load_model currently requires router_queue_policy='fcfs'"
,
));
}
Ok
(())
}
impl
KvRouterConfig
{
/// Interval at which the router queue is re-checked.
///
/// Returns 100 ms when a prefill load model is enabled and a queue
/// threshold is configured; otherwise returns 60 s.
pub fn router_queue_recheck_interval(&self) -> Duration {
    const DEFAULT_RECHECK_INTERVAL: Duration = Duration::from_secs(60);
    const PREFILL_LOAD_RECHECK_INTERVAL: Duration = Duration::from_millis(100);

    let prefill_model_enabled = self.router_prefill_load_model.is_enabled();
    let queue_threshold_set = self.router_queue_threshold.is_some();

    if prefill_model_enabled && queue_threshold_set {
        PREFILL_LOAD_RECHECK_INTERVAL
    } else {
        DEFAULT_RECHECK_INTERVAL
    }
}
pub
fn
assume_kv_reuse
(
&
self
,
config_override
:
Option
<&
RouterConfigOverride
>
)
->
bool
{
config_override
.and_then
(|
cfg
|
cfg
.assume_kv_reuse
)
...
...
@@ -288,28 +353,6 @@ mod tests {
use
super
::
*
;
use
crate
::
protocols
::{
BlockExtraInfo
,
BlockMmObjectInfo
};
#[test]
fn
router_queue_policy_display_and_parse_support_lcfs
()
{
assert_eq!
(
RouterQueuePolicy
::
Lcfs
.to_string
(),
"lcfs"
);
assert_eq!
(
"lcfs"
.parse
::
<
RouterQueuePolicy
>
()
.unwrap
(),
RouterQueuePolicy
::
Lcfs
);
}
#[test]
fn
router_queue_policy_serde_round_trip_supports_lcfs
()
{
let
serialized
=
serde_json
::
to_string
(
&
RouterQueuePolicy
::
Lcfs
)
.unwrap
();
assert_eq!
(
serialized
,
"
\"
lcfs
\"
"
);
let
deserialized
:
RouterQueuePolicy
=
serde_json
::
from_str
(
&
serialized
)
.unwrap
();
assert_eq!
(
deserialized
,
RouterQueuePolicy
::
Lcfs
);
}
#[test]
fn
kv_router_config_defaults_to_tracking_prefill_tokens
()
{
assert
!
(
KvRouterConfig
::
default
()
.router_track_prefill_tokens
);
}
#[test]
fn
compute_seq_hashes_for_tracking_uses_mm_hashes
()
{
let
cfg
=
KvRouterConfig
::
default
();
...
...
@@ -343,17 +386,6 @@ mod tests {
assert_ne!
(
without_mm
,
with_mm
);
}
#[test]
fn
router_config_override_serde_round_trip_preserves_track_prefill_tokens
()
{
let
serialized
=
serde_json
::
to_string
(
&
RouterConfigOverride
{
track_prefill_tokens
:
Some
(
false
),
..
Default
::
default
()
})
.unwrap
();
let
deserialized
:
RouterConfigOverride
=
serde_json
::
from_str
(
&
serialized
)
.unwrap
();
assert_eq!
(
deserialized
.track_prefill_tokens
,
Some
(
false
));
}
#[test]
fn
compute_seq_hashes_for_tracking_uses_precomputed_block_hashes
()
{
let
config
=
KvRouterConfig
::
default
();
...
...
lib/kv-router/src/scheduling/local.rs
View file @
95a750f4
...
...
@@ -6,9 +6,11 @@ use std::sync::Arc;
use
std
::
time
::
Duration
;
use
tokio
::
sync
::{
mpsc
,
watch
};
use
tokio
::
time
::
Instant
;
use
tokio_util
::
sync
::
CancellationToken
;
use
super
::
policy
::{
RouterSchedulingPolicy
,
SchedulingPolicy
};
use
super
::
prefill_load
::
PrefillLoadEstimator
;
use
super
::
queue
::
SchedulerQueue
;
use
super
::
selector
::{
DefaultWorkerSelector
,
WorkerSelector
};
use
super
::
types
::{
KvSchedulerError
,
PotentialLoad
,
SchedulingRequest
,
SchedulingResponse
};
...
...
@@ -18,8 +20,6 @@ use crate::sequences::{
};
use
dynamo_tokens
::
SequenceHash
;
const
RECHECK_INTERVAL
:
Duration
=
Duration
::
from_secs
(
60
);
pub
struct
LocalScheduler
<
P
,
C
,
S
=
RouterSchedulingPolicy
,
Sel
=
DefaultWorkerSelector
>
where
P
:
SequencePublisher
,
...
...
@@ -49,6 +49,8 @@ where
block_size
:
u32
,
selector
:
Sel
,
policy
:
S
,
prefill_load_estimator
:
Option
<
Arc
<
dyn
PrefillLoadEstimator
>>
,
recheck_interval
:
Duration
,
track_prefill_tokens_default
:
bool
,
cancellation_token
:
CancellationToken
,
worker_type
:
&
'static
str
,
...
...
@@ -103,13 +105,14 @@ where
block_size
,
selector
,
policy
,
prefill_load_estimator
,
));
let
(
request_tx
,
request_rx
)
=
mpsc
::
channel
::
<
SchedulingRequest
>
(
1024
);
let
queue_clone
=
Arc
::
clone
(
&
queue
);
tokio
::
spawn
(
async
move
{
let
mut
request_rx
=
request_rx
;
let
mut
recheck_interval
=
tokio
::
time
::
interval
(
RECHECK_INTERVAL
);
let
mut
recheck_interval
=
tokio
::
time
::
interval
(
recheck_interval
);
tracing
::
trace!
(
"LocalScheduler background task started"
);
loop
{
...
...
@@ -192,17 +195,18 @@ where
}
pub
async
fn
add_request
(
&
self
,
req
:
SequenceRequest
)
->
Result
<
(),
SequenceError
>
{
self
.slots
.add_request
(
req
)
self
.slots
.add_request
(
req
,
Instant
::
now
()
)
}
pub
async
fn
mark_prefill_completed
(
&
self
,
request_id
:
&
str
)
->
Result
<
(),
SequenceError
>
{
self
.slots
.mark_prefill_completed
(
&
request_id
.to_string
())
?
;
self
.slots
.mark_prefill_completed
(
&
request_id
.to_string
(),
Instant
::
now
())
?
;
self
.queue
.update
()
.await
;
Ok
(())
}
pub
async
fn
free
(
&
self
,
request_id
:
&
str
)
->
Result
<
(),
SequenceError
>
{
self
.slots
.free
(
&
request_id
.to_string
())
?
;
self
.slots
.free
(
&
request_id
.to_string
()
,
Instant
::
now
()
)
?
;
self
.queue
.update
()
.await
;
Ok
(())
}
...
...
@@ -231,6 +235,7 @@ where
overlaps
:
OverlapScores
,
track_prefill_tokens
:
bool
,
)
->
Vec
<
PotentialLoad
>
{
let
decay_now
=
Instant
::
now
();
let
(
decode_blocks
,
prefill_tokens
)
=
self
.slots
.potential_blocks_and_tokens_with_prefill_tracking
(
...
...
@@ -238,6 +243,7 @@ where
isl_tokens
,
overlaps
,
track_prefill_tokens
,
decay_now
,
);
let
mut
workers
:
HashSet
<
WorkerWithDpRank
>
=
HashSet
::
new
();
...
...
@@ -275,15 +281,32 @@ mod tests {
use
super
::
*
;
use
crate
::
protocols
::
OverlapScores
;
use
crate
::
scheduling
::
PrefillLoadEstimator
;
use
crate
::
scheduling
::
policy
::
FcfsPolicy
;
use
crate
::
scheduling
::
selector
::
DefaultWorkerSelector
;
use
crate
::
test_utils
::{
NoopSequencePublisher
,
SimpleWorkerConfig
};
/// Test double for [`PrefillLoadEstimator`] that always predicts the same duration.
struct FixedPrefillLoadEstimator {
    /// Duration returned by every prediction.
    duration: Duration,
}

impl PrefillLoadEstimator for FixedPrefillLoadEstimator {
    /// Ignores all inputs and returns the configured `duration`.
    fn predict_prefill_duration(
        &self,
        _batch_size: usize,
        _effective_isl: usize,
        _prefix: usize,
    ) -> anyhow::Result<Duration> {
        Ok(self.duration)
    }
}
#[allow(clippy::type_complexity)]
fn
make_scheduler
(
workers
:
HashMap
<
WorkerId
,
SimpleWorkerConfig
>
,
threshold_frac
:
Option
<
f64
>
,
monitor_worker_configs
:
bool
,
prefill_load_estimator
:
Option
<
Arc
<
dyn
PrefillLoadEstimator
>>
,
)
->
(
Arc
<
LocalScheduler
<
NoopSequencePublisher
,
SimpleWorkerConfig
,
FcfsPolicy
>>
,
Arc
<
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>>
,
...
...
@@ -311,6 +334,8 @@ mod tests {
64
,
DefaultWorkerSelector
::
new
(
None
,
"test"
),
FcfsPolicy
,
prefill_load_estimator
,
Duration
::
from_secs
(
60
),
true
,
cancel_token
.clone
(),
"test"
,
...
...
@@ -329,7 +354,7 @@ mod tests {
..
Default
::
default
()
},
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
,
None
);
let
response
=
scheduler
.schedule
(
...
...
@@ -366,7 +391,7 @@ mod tests {
..
Default
::
default
()
},
);
let
(
scheduler
,
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
);
let
(
scheduler
,
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
,
None
);
scheduler
.schedule
(
...
...
@@ -389,7 +414,7 @@ mod tests {
assert_eq!
(
slots
.active_tokens
()
.active_tokens
(
Instant
::
now
()
)
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.copied
(),
Some
(
0
)
...
...
@@ -408,7 +433,8 @@ mod tests {
..
Default
::
default
()
},
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
Some
(
0.5
),
true
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
Some
(
0.5
),
true
,
None
);
scheduler
.schedule
(
...
...
@@ -466,7 +492,7 @@ mod tests {
..
Default
::
default
()
},
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
,
None
);
scheduler
.schedule
(
...
...
@@ -511,12 +537,16 @@ mod tests {
..
Default
::
default
()
},
);
let
(
scheduler
,
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
);
let
(
scheduler
,
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
,
None
);
let
token_seq
=
vec!
[
11
,
22
,
33
,
44
];
let
overlaps
=
OverlapScores
::
default
();
let
(
decode_blocks
,
prefill_tokens
)
=
slots
.potential_blocks_and_tokens
(
Some
(
&
token_seq
),
128
,
overlaps
.clone
());
let
(
decode_blocks
,
prefill_tokens
)
=
slots
.potential_blocks_and_tokens
(
Some
(
&
token_seq
),
128
,
overlaps
.clone
(),
Instant
::
now
(),
);
let
mut
expected
:
Vec
<
_
>
=
decode_blocks
.keys
()
.map
(|
worker
|
PotentialLoad
{
...
...
@@ -548,10 +578,51 @@ mod tests {
cancel_token
.cancel
();
}
/// Checks that `get_potential_loads` reports a reduced prefill token count
/// once part of the predicted prefill duration has elapsed.
///
/// The runtime starts with the clock paused, so time only moves through the
/// explicit `tokio::time::advance` call below.
#[tokio::test(start_paused = true)]
async fn test_get_potential_loads_uses_decayed_prefill_tokens() {
    let mut workers = HashMap::new();
    workers.insert(
        0,
        SimpleWorkerConfig {
            max_num_batched_tokens: Some(256),
            ..Default::default()
        },
    );
    // Every prefill is predicted to take exactly 10 seconds.
    let estimator: Arc<dyn PrefillLoadEstimator> = Arc::new(FixedPrefillLoadEstimator {
        duration: Duration::from_secs(10),
    });
    let (scheduler, _slots, _cfg_tx, cancel_token) =
        make_scheduler(workers, None, true, Some(estimator));

    // Schedule one request; the second argument (100) is its ISL token count.
    scheduler
        .schedule(
            Some("req-1".to_string()),
            100,
            Some(vec![1, 2, 3, 4]),
            OverlapScores::default(),
            None,
            true,
            None,
            0.0,
            None,
            None,
        )
        .await
        .unwrap();

    // Move the paused clock 6 s into the 10 s predicted prefill.
    tokio::time::advance(Duration::from_secs(6)).await;

    let loads = scheduler.get_potential_loads(None, 0, OverlapScores::default(), true);
    assert_eq!(loads.len(), 1);
    // NOTE(review): 40 of the 100 tokens remaining after 60% of the predicted
    // duration is consistent with a linear decay — confirm against the
    // sequence-tracking implementation.
    assert_eq!(loads[0].potential_prefill_tokens, 40);

    cancel_token.cancel();
}
#[tokio::test]
async
fn
test_register_workers_uses_default_dp_fallback
()
{
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
HashMap
::
new
(),
None
,
false
);
make_scheduler
(
HashMap
::
new
(),
None
,
false
,
None
);
scheduler
.register_workers
(
&
HashSet
::
from
([
42
]));
let
loads
=
scheduler
.get_potential_loads
(
None
,
64
,
OverlapScores
::
default
(),
true
);
...
...
@@ -567,7 +638,7 @@ mod tests {
async
fn
test_worker_watch_updates_slot_ranges
()
{
let
mut
workers
=
HashMap
::
new
();
workers
.insert
(
0
,
SimpleWorkerConfig
::
default
());
let
(
scheduler
,
_
slots
,
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
);
let
(
scheduler
,
_
slots
,
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
,
None
);
assert_eq!
(
scheduler
...
...
@@ -615,7 +686,7 @@ mod tests {
..
Default
::
default
()
},
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
,
None
);
scheduler
.schedule
(
...
...
lib/kv-router/src/scheduling/mod.rs
View file @
95a750f4
...
...
@@ -4,9 +4,11 @@
pub
mod
config
;
mod
local
;
pub
mod
policy
;
pub
mod
prefill_load
;
pub
mod
queue
;
pub
mod
selector
;
mod
types
;
pub
use
local
::
LocalScheduler
;
pub
use
prefill_load
::
PrefillLoadEstimator
;
pub
use
types
::
*
;
lib/kv-router/src/scheduling/prefill_load.rs
0 → 100644
View file @
95a750f4
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
time
::
Duration
;
/// Predicts how long a prefill will take.
///
/// Implementations must be `Send + Sync`; the scheduler holds them as
/// `Arc<dyn PrefillLoadEstimator>`.
pub trait PrefillLoadEstimator: Send + Sync {
    /// Returns the predicted prefill duration for the given request shape.
    ///
    /// NOTE(review): the scheduler queue derives `prefix` as overlap blocks
    /// times block size and `effective_isl` as the ISL minus that prefix
    /// (saturating at zero); implementors presumably follow the same
    /// convention — confirm.
    ///
    /// # Errors
    ///
    /// Implementations return an error when no prediction can be produced.
    fn predict_prefill_duration(
        &self,
        batch_size: usize,
        effective_isl: usize,
        prefix: usize,
    ) -> anyhow::Result<Duration>;
}
lib/kv-router/src/scheduling/queue.rs
View file @
95a750f4
...
...
@@ -5,15 +5,16 @@ use std::cmp::Ordering;
use
std
::
collections
::{
BinaryHeap
,
HashMap
,
HashSet
};
use
std
::
sync
::
Arc
;
use
std
::
sync
::
atomic
::{
AtomicUsize
,
Ordering
as
AtomicOrdering
};
use
std
::
time
::
Instant
;
use
tokio
::
sync
::
Mutex
;
use
tokio
::
sync
::
watch
;
use
tokio
::
time
::
Instant
;
use
super
::
policy
::{
FcfsPolicy
,
SchedulingPolicy
};
use
super
::
prefill_load
::
PrefillLoadEstimator
;
use
super
::
selector
::{
DefaultWorkerSelector
,
WorkerSelector
};
use
super
::
types
::{
SchedulingRequest
,
SchedulingResponse
};
use
crate
::
protocols
::{
WorkerConfigLike
,
WorkerId
,
WorkerWithDpRank
};
use
crate
::
protocols
::{
PrefillLoadHint
,
WorkerConfigLike
,
WorkerId
,
WorkerWithDpRank
};
use
crate
::
sequences
::{
ActiveSequencesMultiWorker
,
SequencePublisher
,
SequenceRequest
};
/// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker)
...
...
@@ -68,6 +69,7 @@ pub struct SchedulerQueue<
block_size
:
u32
,
selector
:
Sel
,
policy
:
S
,
prefill_load_estimator
:
Option
<
Arc
<
dyn
PrefillLoadEstimator
>>
,
}
impl
<
...
...
@@ -84,6 +86,7 @@ impl<
block_size
:
u32
,
selector
:
Sel
,
policy
:
S
,
prefill_load_estimator
:
Option
<
Arc
<
dyn
PrefillLoadEstimator
>>
,
)
->
Self
{
if
let
Some
(
frac
)
=
threshold_frac
{
tracing
::
info!
(
"Router queue enabled with threshold fraction {frac}"
);
...
...
@@ -98,6 +101,7 @@ impl<
block_size
,
selector
,
policy
,
prefill_load_estimator
,
}
}
...
...
@@ -133,23 +137,24 @@ impl<
/// capacity check is skipped.
pub
async
fn
enqueue
(
&
self
,
request
:
SchedulingRequest
)
{
let
Some
(
threshold
)
=
self
.threshold_frac
else
{
self
.schedule
(
request
)
.await
;
self
.schedule
(
request
,
Instant
::
now
()
)
.await
;
return
;
};
if
request
.allowed_worker_ids
.is_some
()
{
self
.schedule
(
request
)
.await
;
self
.schedule
(
request
,
Instant
::
now
()
)
.await
;
return
;
}
if
self
.all_workers_busy
(
threshold
,
request
.allowed_worker_ids
.as_ref
())
{
let
decay_now
=
Instant
::
now
();
if
self
.all_workers_busy
(
threshold
,
request
.allowed_worker_ids
.as_ref
(),
decay_now
)
{
tracing
::
debug!
(
"all workers busy, queueing request"
);
let
arrival_offset
=
self
.start_time
.elapsed
();
let
key
=
self
.policy
.enqueue_key
(
arrival_offset
,
&
request
);
self
.pending
.lock
()
.await
.push
(
QueueEntry
{
key
,
request
});
self
.pending_count
.fetch_add
(
1
,
AtomicOrdering
::
Relaxed
);
}
else
{
self
.schedule
(
request
)
.await
;
self
.schedule
(
request
,
decay_now
)
.await
;
}
}
...
...
@@ -176,7 +181,8 @@ impl<
}
loop
{
if
self
.all_workers_busy
(
threshold
,
None
)
{
let
decay_now
=
Instant
::
now
();
if
self
.all_workers_busy
(
threshold
,
None
,
decay_now
)
{
break
;
}
let
Some
(
entry
)
=
self
.pending
.lock
()
.await
.pop
()
else
{
...
...
@@ -184,13 +190,13 @@ impl<
};
self
.pending_count
.fetch_sub
(
1
,
AtomicOrdering
::
Relaxed
);
tracing
::
debug!
(
"scheduling request from pending queue"
);
self
.schedule
(
entry
.request
)
.await
;
self
.schedule
(
entry
.request
,
decay_now
)
.await
;
}
}
/// Run the full scheduling pipeline for a single request:
/// compute potential load -> select worker -> respond -> book via add_request.
async
fn
schedule
(
&
self
,
mut
request
:
SchedulingRequest
)
{
async
fn
schedule
(
&
self
,
mut
request
:
SchedulingRequest
,
decay_now
:
Instant
)
{
let
(
decode_blocks
,
prefill_tokens
)
=
self
.slots
.potential_blocks_and_tokens_with_prefill_tracking
(
...
...
@@ -198,6 +204,7 @@ impl<
request
.isl_tokens
,
request
.overlaps
.clone
(),
request
.track_prefill_tokens
,
decay_now
,
);
request
.decode_blocks
=
decode_blocks
;
request
.prefill_tokens
=
prefill_tokens
;
...
...
@@ -231,20 +238,66 @@ impl<
return
;
};
if
let
Err
(
e
)
=
self
.slots
.add_request
(
SequenceRequest
{
let
prefill_load_hint
=
self
.prefill_load_hint_for
(
request
.isl_tokens
,
selection
.overlap_blocks
,
request
.track_prefill_tokens
,
);
if
let
Err
(
e
)
=
self
.slots
.add_request
(
SequenceRequest
{
request_id
:
request_id
.clone
(),
token_sequence
:
request
.token_seq
,
isl
:
request
.isl_tokens
,
overlap
:
selection
.overlap_blocks
,
track_prefill_tokens
:
request
.track_prefill_tokens
,
expected_output_tokens
:
request
.expected_output_tokens
,
prefill_load_hint
,
worker
:
selection
.worker
,
lora_name
:
request
.lora_name
.clone
(),
})
{
},
decay_now
,
)
{
tracing
::
warn!
(
"Failed to add request {request_id}: {e}"
);
}
}
/// Builds the optional [`PrefillLoadHint`] attached to a request when it is booked.
///
/// The cached prefix is `overlap_blocks * block_size` tokens; whatever part of
/// `isl_tokens` exceeds it is the effective prefill length handed to the estimator.
///
/// Returns `None` when:
/// - `track_prefill_tokens` is `false`,
/// - the prefix already covers the whole input (effective length is zero),
/// - no prefill load estimator is configured, or
/// - the estimator returns an error (which is logged at `warn` level).
fn prefill_load_hint_for(
    &self,
    isl_tokens: usize,
    overlap_blocks: u32,
    track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
    if !track_prefill_tokens {
        return None;
    }

    let prefix = (overlap_blocks as usize) * (self.block_size as usize);
    // Saturating: an overlap longer than the input yields zero rather than underflowing.
    let effective_isl = isl_tokens.saturating_sub(prefix);
    if effective_isl == 0 {
        return None;
    }

    let estimator = self.prefill_load_estimator.as_ref()?;
    match estimator.predict_prefill_duration(1, effective_isl, prefix) {
        Err(error) => {
            tracing::warn!(
                effective_isl,
                prefix,
                "failed to predict prefill duration for active load tracking: {error}"
            );
            None
        }
        Ok(expected_prefill_duration) => Some(PrefillLoadHint {
            initial_effective_prefill_tokens: effective_isl,
            expected_prefill_duration: Some(expected_prefill_duration),
        }),
    }
}
/// Number of requests currently parked in the pending queue (lock-free).
pub
fn
pending_count
(
&
self
)
->
usize
{
self
.pending_count
.load
(
AtomicOrdering
::
Relaxed
)
...
...
@@ -255,8 +308,13 @@ impl<
/// otherwise all registered workers are checked.
/// Returns false when no eligible workers exist so the request falls
/// through to `schedule`, which returns a proper `NoEndpoints` error.
fn
all_workers_busy
(
&
self
,
threshold
:
f64
,
allowed
:
Option
<&
HashSet
<
WorkerId
>>
)
->
bool
{
let
active_tokens
=
self
.slots
.active_tokens
();
fn
all_workers_busy
(
&
self
,
threshold
:
f64
,
allowed
:
Option
<&
HashSet
<
WorkerId
>>
,
decay_now
:
Instant
,
)
->
bool
{
let
active_tokens
=
self
.slots
.active_tokens
(
decay_now
);
let
configs
=
self
.workers_with_configs
.borrow
();
let
mut
checked_any
=
false
;
...
...
@@ -289,6 +347,7 @@ impl<
mod
tests
{
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
tokio
::
sync
::
watch
;
...
...
@@ -298,6 +357,25 @@ mod tests {
use
crate
::
sequences
::
ActiveSequencesMultiWorker
;
use
crate
::
test_utils
::{
NoopSequencePublisher
,
SimpleWorkerConfig
};
/// Test helper: samples the current instant for use as the `decay_now` argument of the
/// sequence-tracking calls in these tests.
fn decay_now() -> Instant {
    Instant::now()
}
/// Test double for `PrefillLoadEstimator` that answers every query with the same
/// duration.
struct FixedPrefillLoadEstimator {
    /// Value returned by every call to `predict_prefill_duration`.
    duration: Duration,
}

impl PrefillLoadEstimator for FixedPrefillLoadEstimator {
    /// Ignores all three inputs and always succeeds with the configured `duration`.
    fn predict_prefill_duration(
        &self,
        _batch_size: usize,
        _effective_isl: usize,
        _prefix: usize,
    ) -> anyhow::Result<Duration> {
        Ok(self.duration)
    }
}
fn
make_queue
(
num_workers
:
usize
,
block_size
:
u32
,
...
...
@@ -308,7 +386,7 @@ mod tests {
Arc
<
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>>
,
)
{
let
(
queue
,
slots
,
_
tx
)
=
make_queue_with_sender
(
num_workers
,
block_size
,
isl
,
threshold_frac
);
make_queue_with_sender
(
num_workers
,
block_size
,
isl
,
threshold_frac
,
None
);
(
queue
,
slots
)
}
...
...
@@ -318,6 +396,7 @@ mod tests {
block_size
:
u32
,
isl
:
usize
,
threshold_frac
:
Option
<
f64
>
,
prefill_load_estimator
:
Option
<
Arc
<
dyn
PrefillLoadEstimator
>>
,
)
->
(
Arc
<
SchedulerQueue
<
NoopSequencePublisher
,
SimpleWorkerConfig
>>
,
Arc
<
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>>
,
...
...
@@ -354,6 +433,7 @@ mod tests {
block_size
,
selector
,
FcfsPolicy
,
prefill_load_estimator
,
));
(
queue
,
slots
,
cfg_tx
)
...
...
@@ -409,8 +489,8 @@ mod tests {
let
resp
=
resp
.expect
(
"scheduling failed"
);
assert
!
(
resp
.best_worker.worker_id
<
num_workers
as
u64
);
slots
.mark_prefill_completed
(
&
req_id
)
.unwrap
();
slots
.free
(
&
req_id
)
.unwrap
();
slots
.mark_prefill_completed
(
&
req_id
,
decay_now
()
)
.unwrap
();
slots
.free
(
&
req_id
,
decay_now
()
)
.unwrap
();
queue
.update
()
.await
;
}));
}
...
...
@@ -419,7 +499,7 @@ mod tests {
h
.await
.expect
(
"task panicked"
);
}
let
active
=
slots
.active_tokens
();
let
active
=
slots
.active_tokens
(
decay_now
()
);
for
(
worker
,
tokens
)
in
&
active
{
assert_eq!
(
*
tokens
,
0
,
...
...
@@ -453,8 +533,8 @@ mod tests {
for
_
in
0
..
num_requests
{
queue
.update
()
.await
;
for
rid
in
&
req_ids
{
let
_
=
slots
.mark_prefill_completed
(
rid
);
let
_
=
slots
.free
(
rid
);
let
_
=
slots
.mark_prefill_completed
(
rid
,
decay_now
()
);
let
_
=
slots
.free
(
rid
,
decay_now
()
);
}
}
queue
.update
()
.await
;
...
...
@@ -495,8 +575,10 @@ mod tests {
assert_eq!
(
queue
.pending_count
(),
2
);
// Free the first request and update — should drain one from pending
slots
.mark_prefill_completed
(
&
"req-1"
.to_string
())
.unwrap
();
slots
.free
(
&
"req-1"
.to_string
())
.unwrap
();
slots
.mark_prefill_completed
(
&
"req-1"
.to_string
(),
decay_now
())
.unwrap
();
slots
.free
(
&
"req-1"
.to_string
(),
decay_now
())
.unwrap
();
queue
.update
()
.await
;
// After update, one pending request should have been scheduled
...
...
@@ -507,16 +589,43 @@ mod tests {
);
// Free req-2 and update to drain remaining
let
_
=
slots
.mark_prefill_completed
(
&
"req-2"
.to_string
());
let
_
=
slots
.free
(
&
"req-2"
.to_string
());
let
_
=
slots
.mark_prefill_completed
(
&
"req-2"
.to_string
()
,
decay_now
()
);
let
_
=
slots
.free
(
&
"req-2"
.to_string
()
,
decay_now
()
);
queue
.update
()
.await
;
let
_
=
slots
.mark_prefill_completed
(
&
"req-3"
.to_string
());
let
_
=
slots
.free
(
&
"req-3"
.to_string
());
let
_
=
slots
.mark_prefill_completed
(
&
"req-3"
.to_string
()
,
decay_now
()
);
let
_
=
slots
.free
(
&
"req-3"
.to_string
()
,
decay_now
()
);
queue
.update
()
.await
;
assert_eq!
(
queue
.pending_count
(),
0
,
"all requests should be drained"
);
}
/// With one worker and a fixed 10-second prefill estimate, a second request is parked
/// behind the first; after the paused clock is advanced by 6 seconds, `update()` must
/// release it to worker 0 without the first request ever being freed.
#[tokio::test(start_paused = true)]
async fn test_queue_update_uses_decayed_oldest_prefill_load() {
    let fixed_estimator: Arc<dyn PrefillLoadEstimator> = Arc::new(FixedPrefillLoadEstimator {
        duration: Duration::from_secs(10),
    });
    // The unused bindings are kept alive for the duration of the test.
    let (queue, _slots, _cfg_tx) =
        make_queue_with_sender(1, 16, 100, Some(0.5), Some(fixed_estimator));

    // First request: scheduled right away; wait for its response.
    let (first_request, first_rx) = make_request("req-1", 100);
    queue.enqueue(first_request).await;
    let _ = first_rx.await.unwrap().unwrap();

    // Second request: the only worker is now busy, so it must be parked.
    let (second_request, mut second_rx) = make_request("req-2", 100);
    queue.enqueue(second_request).await;
    assert_eq!(queue.pending_count(), 1);

    // Move the paused clock forward, then let the queue re-evaluate worker load.
    // NOTE(review): presumably 6s of the 10s estimate decays the load below the 0.5
    // threshold — confirm against make_queue_with_sender's worker config.
    tokio::time::advance(Duration::from_secs(6)).await;
    queue.update().await;

    let response = second_rx
        .try_recv()
        .expect("queued request should have been scheduled")
        .expect("scheduling returned error");
    assert_eq!(response.best_worker.worker_id, 0);
    assert_eq!(queue.pending_count(), 0);
}
#[tokio::test]
async
fn
test_no_workers_returns_error
()
{
let
(
queue
,
_
slots
)
=
make_queue
(
0
,
16
,
512
,
None
);
...
...
@@ -542,7 +651,7 @@ mod tests {
let
isl
=
512
;
// Start with zero workers (mimics skip_initial_worker_wait=true)
let
(
queue
,
slots
,
cfg_tx
)
=
make_queue_with_sender
(
0
,
block_size
,
isl
,
None
);
let
(
queue
,
slots
,
cfg_tx
)
=
make_queue_with_sender
(
0
,
block_size
,
isl
,
None
,
None
);
// Routing with no workers must fail
let
(
req_fail
,
rx_fail
)
=
make_request
(
"before-register"
,
isl
);
...
...
@@ -590,9 +699,11 @@ mod tests {
// Clean up
slots
.mark_prefill_completed
(
&
"after-register"
.to_string
())
.mark_prefill_completed
(
&
"after-register"
.to_string
(),
decay_now
())
.unwrap
();
slots
.free
(
&
"after-register"
.to_string
(),
decay_now
())
.unwrap
();
slots
.free
(
&
"after-register"
.to_string
())
.unwrap
();
}
/// Register_workers is additive: calling with a new set does NOT remove old workers.
...
...
@@ -601,7 +712,7 @@ mod tests {
let
block_size
=
16
;
let
isl
=
256
;
let
(
queue
,
slots
,
cfg_tx
)
=
make_queue_with_sender
(
0
,
block_size
,
isl
,
None
);
let
(
queue
,
slots
,
cfg_tx
)
=
make_queue_with_sender
(
0
,
block_size
,
isl
,
None
,
None
);
// Register worker 10 in slots and config
let
mut
dp1
=
std
::
collections
::
HashMap
::
new
();
...
...
@@ -643,8 +754,8 @@ mod tests {
.expect
(
"oneshot dropped"
)
.expect
(
"scheduling failed"
);
seen
.insert
(
resp
.best_worker.worker_id
);
slots
.mark_prefill_completed
(
&
req_id
)
.unwrap
();
slots
.free
(
&
req_id
)
.unwrap
();
slots
.mark_prefill_completed
(
&
req_id
,
decay_now
()
)
.unwrap
();
slots
.free
(
&
req_id
,
decay_now
()
)
.unwrap
();
}
assert
!
(
...
...
@@ -659,7 +770,7 @@ mod tests {
let
block_size
=
16
;
let
isl
=
256
;
let
(
queue
,
slots
,
cfg_tx
)
=
make_queue_with_sender
(
0
,
block_size
,
isl
,
None
);
let
(
queue
,
slots
,
cfg_tx
)
=
make_queue_with_sender
(
0
,
block_size
,
isl
,
None
,
None
);
// Register three workers
let
mut
dp
=
std
::
collections
::
HashMap
::
new
();
...
...
@@ -712,9 +823,9 @@ mod tests {
resp
.best_worker.worker_id
);
slots
.mark_prefill_completed
(
&
"filter-0"
.to_string
())
.mark_prefill_completed
(
&
"filter-0"
.to_string
()
,
decay_now
()
)
.unwrap
();
slots
.free
(
&
"filter-0"
.to_string
())
.unwrap
();
slots
.free
(
&
"filter-0"
.to_string
()
,
decay_now
()
)
.unwrap
();
}
#[tokio::test(flavor
=
"multi_thread"
)]
...
...
@@ -727,7 +838,7 @@ mod tests {
let
_
resp1
=
rx1
.await
.unwrap
()
.unwrap
();
assert_eq!
(
slots
.active_tokens
()
.active_tokens
(
decay_now
()
)
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.copied
(),
Some
(
0
)
...
...
@@ -738,9 +849,9 @@ mod tests {
let
_
resp2
=
rx2
.await
.unwrap
()
.unwrap
();
assert_eq!
(
queue
.pending_count
(),
0
);
let
_
=
slots
.mark_prefill_completed
(
&
"req-1"
.to_string
());
let
_
=
slots
.free
(
&
"req-1"
.to_string
());
let
_
=
slots
.mark_prefill_completed
(
&
"req-2"
.to_string
());
let
_
=
slots
.free
(
&
"req-2"
.to_string
());
let
_
=
slots
.mark_prefill_completed
(
&
"req-1"
.to_string
()
,
decay_now
()
);
let
_
=
slots
.free
(
&
"req-1"
.to_string
()
,
decay_now
()
);
let
_
=
slots
.mark_prefill_completed
(
&
"req-2"
.to_string
()
,
decay_now
()
);
let
_
=
slots
.free
(
&
"req-2"
.to_string
()
,
decay_now
()
);
}
}
lib/kv-router/src/sequences/block_tracker.rs
0 → 100644
View file @
95a750f4
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
dynamo_tokens
::
SequenceHash
;
use
std
::
collections
::
HashMap
;
use
std
::
sync
::{
Arc
,
Weak
};
/// Tracks a set of blocks keyed by `SequenceHash`.
///
/// Liveness is reference counted: `touch_block` hands out an `Arc<()>` per block while
/// only the matching `Weak<()>` is stored here, so a block's strong count reaches zero
/// once every handle returned for it has been dropped.
#[derive(Debug, Default)]
pub(super) struct BlockTracker {
    /// Weak liveness handle for every tracked block hash.
    pub(super) unique_blocks: HashMap<SequenceHash, Weak<()>>,
    /// Optional per-block weight; in `active_blocks` it replaces the default weight of
    /// `1.0` for hashes that are also present in `unique_blocks`.
    pub(super) fractional_blocks: HashMap<SequenceHash, f64>,
}

impl BlockTracker {
    /// Returns a strong handle for `block`.
    ///
    /// If the block is tracked and still has a live handle, that shared handle is
    /// returned. Otherwise a fresh handle is created and its weak counterpart is stored,
    /// overwriting any dead entry for the same hash.
    pub(super) fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> {
        let live_handle = self.unique_blocks.get(block).and_then(Weak::upgrade);
        match live_handle {
            Some(handle) => handle,
            None => {
                let handle = Arc::new(());
                self.unique_blocks.insert(*block, Arc::downgrade(&handle));
                handle
            }
        }
    }

    /// Forgets `block` if it is tracked and no strong handle to it remains.
    ///
    /// Both the liveness entry and any fractional weight are removed. Nothing happens
    /// when the block is untracked or still has live handles.
    pub(super) fn try_remove_block(&mut self, block: &SequenceHash) {
        let is_dead = self
            .unique_blocks
            .get(block)
            .is_some_and(|weak| weak.strong_count() == 0);
        if is_dead {
            self.unique_blocks.remove(block);
            self.fractional_blocks.remove(block);
        }
    }

    /// Returns the weighted number of tracked blocks, rounded to the nearest integer.
    ///
    /// Every entry in `unique_blocks` counts as `1.0`, unless `fractional_blocks` holds a
    /// weight for the same hash, in which case that weight is used instead. Fractional
    /// weights for hashes that are not in `unique_blocks` are ignored.
    pub(super) fn active_blocks(&self) -> usize {
        let weighted_total = self
            .fractional_blocks
            .iter()
            .filter(|(hash, _)| self.unique_blocks.contains_key(*hash))
            // Swap the block's default weight of 1.0 for its fractional weight.
            .fold(self.unique_blocks.len() as f64, |total, (_, frac)| {
                total - 1.0 + frac
            });
        weighted_total.round() as usize
    }
}
lib/kv-router/src/sequences/mod.rs
View file @
95a750f4
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod
block_tracker
;
pub
mod
multi_worker
;
mod
prefill_tracker
;
pub
mod
single
;
pub
use
multi_worker
::
*
;
...
...
lib/kv-router/src/sequences/multi_worker.rs
View file @
95a750f4
...
...
@@ -20,7 +20,8 @@ use tokio_util::sync::CancellationToken;
use
super
::
single
::{
ActiveSequences
,
RequestId
};
use
crate
::
protocols
::{
ActiveLoad
,
ActiveSequenceEvent
,
ActiveSequenceEventData
,
OverlapScores
,
WorkerWithDpRank
,
ActiveLoad
,
ActiveSequenceEvent
,
ActiveSequenceEventData
,
OverlapScores
,
PrefillLoadHint
,
WorkerWithDpRank
,
};
// How often we force expire stale requests across all workers. See the comment
...
...
@@ -93,6 +94,7 @@ pub struct SequenceRequest {
pub
overlap
:
u32
,
pub
track_prefill_tokens
:
bool
,
pub
expected_output_tokens
:
Option
<
u32
>
,
pub
prefill_load_hint
:
Option
<
PrefillLoadHint
>
,
pub
worker
:
WorkerWithDpRank
,
pub
lora_name
:
Option
<
String
>
,
}
...
...
@@ -177,6 +179,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
return
;
}
// TODO: Publish explicit prompt-load decay timestamps with these events so peer routers
// can mirror the same oldest-prefill anchor instead of approximating from receive time.
let
publisher
=
Arc
::
clone
(
&
self
.publisher
);
tokio
::
spawn
(
async
move
{
if
let
Err
(
e
)
=
publisher
.publish_event
(
&
event
)
.await
{
...
...
@@ -228,6 +232,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
continue
;
}
// TODO: ActiveSequenceEvent does not carry prompt-load decay timestamps yet.
// Peer routers still approximate decay anchoring with local receive time.
let
decay_now
=
Instant
::
now
();
match
&
event
.data
{
ActiveSequenceEventData
::
AddRequest
{
token_sequence
,
...
...
@@ -235,6 +242,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap
,
track_prefill_tokens
,
expected_output_tokens
,
prefill_load_hint
,
}
=>
{
self
.request_to_worker
.insert
(
event
.request_id
.clone
(),
event
.worker
);
...
...
@@ -253,6 +261,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
*
overlap
,
*
expected_output_tokens
,
*
track_prefill_tokens
,
*
prefill_load_hint
,
decay_now
,
);
}
else
{
tracing
::
warn!
(
...
...
@@ -267,7 +277,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
{
let
table
=
self
.workers
.read
();
if
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
{
table
.slots
[
idx
]
.1
.write
()
.free
(
&
event
.request_id
);
table
.slots
[
idx
]
.1
.write
()
.free
(
&
event
.request_id
,
decay_now
);
}
}
self
.request_to_lora
.remove
(
&
event
.request_id
);
...
...
@@ -281,7 +291,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
table
.slots
[
idx
]
.1
.write
()
.mark_prefill_completed
(
&
event
.request_id
);
.mark_prefill_completed
(
&
event
.request_id
,
decay_now
);
}
}
}
...
...
@@ -381,7 +391,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}
}
fn
add_request_local
(
&
self
,
req
:
SequenceRequest
)
->
Result
<
(),
SequenceError
>
{
fn
add_request_local
(
&
self
,
req
:
SequenceRequest
,
decay_now
:
Instant
,
)
->
Result
<
(),
SequenceError
>
{
let
SequenceRequest
{
request_id
,
token_sequence
,
...
...
@@ -389,6 +403,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap
,
track_prefill_tokens
,
expected_output_tokens
,
prefill_load_hint
,
worker
,
lora_name
,
}
=
req
;
...
...
@@ -435,6 +450,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap
,
expected_output_tokens
,
track_prefill_tokens
,
prefill_load_hint
,
decay_now
,
)
};
...
...
@@ -443,12 +460,16 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self
.request_to_lora
.remove
(
expired_id
);
}
self
.publish_active_load_for_worker
(
worker
);
self
.publish_active_load_for_worker
(
worker
,
decay_now
);
Ok
(())
}
pub
fn
add_request
(
&
self
,
req
:
SequenceRequest
)
->
Result
<
(),
SequenceError
>
{
pub
fn
add_request
(
&
self
,
req
:
SequenceRequest
,
decay_now
:
Instant
,
)
->
Result
<
(),
SequenceError
>
{
self
.spawn_publish_event
(
ActiveSequenceEvent
{
request_id
:
req
.request_id
.clone
(),
worker
:
req
.worker
,
...
...
@@ -458,11 +479,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap
:
req
.overlap
,
track_prefill_tokens
:
req
.track_prefill_tokens
,
expected_output_tokens
:
req
.expected_output_tokens
,
prefill_load_hint
:
req
.prefill_load_hint
,
},
router_id
:
self
.router_id
,
lora_name
:
req
.lora_name
.clone
(),
});
self
.add_request_local
(
req
)
self
.add_request_local
(
req
,
decay_now
)
}
/// Send a mutation to the worker assigned to a request, optionally publishing
...
...
@@ -470,7 +492,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
fn
mutate_request_worker_local
(
&
self
,
request_id
:
&
RequestId
,
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
),
decay_now
:
Instant
,
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
,
Instant
),
remove_mapping
:
bool
,
)
->
Result
<
(),
SequenceError
>
{
let
worker
=
self
...
...
@@ -488,7 +511,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
.get
(
&
worker
)
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
;
let
mut
seq
=
table
.slots
[
idx
]
.1
.write
();
mutate_fn
(
&
mut
seq
,
request_id
);
mutate_fn
(
&
mut
seq
,
request_id
,
decay_now
);
}
if
remove_mapping
{
...
...
@@ -496,7 +519,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self
.request_to_lora
.remove
(
request_id
);
}
self
.publish_active_load_for_worker
(
worker
);
self
.publish_active_load_for_worker
(
worker
,
decay_now
);
Ok
(())
}
...
...
@@ -504,8 +527,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
fn
mutate_request_worker
(
&
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
,
event_data
:
ActiveSequenceEventData
,
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
),
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
,
Instant
),
remove_mapping
:
bool
,
)
->
Result
<
(),
SequenceError
>
{
let
worker
=
self
...
...
@@ -528,7 +552,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
lora_name
,
});
self
.mutate_request_worker_local
(
request_id
,
mutate_fn
,
remove_mapping
)
self
.mutate_request_worker_local
(
request_id
,
decay_now
,
mutate_fn
,
remove_mapping
)
}
/// Free all blocks associated with a request.
...
...
@@ -539,7 +563,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// This also performs the underlying prefill-complete cleanup via
/// [`ActiveSequences::free`], so callers do not need to call
/// [`Self::mark_prefill_completed`] before freeing a completed request.
pub
fn
free
(
&
self
,
request_id
:
&
RequestId
)
->
Result
<
(),
SequenceError
>
{
pub
fn
free
(
&
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
)
->
Result
<
(),
SequenceError
>
{
if
!
self
.request_to_worker
.contains_key
(
request_id
)
{
tracing
::
debug!
(
"Request {request_id} not found, already freed (idempotent)"
);
return
Ok
(());
...
...
@@ -547,9 +571,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self
.mutate_request_worker
(
request_id
,
decay_now
,
ActiveSequenceEventData
::
Free
,
|
seqs
,
rid
|
{
seqs
.free
(
rid
);
|
seqs
,
rid
,
decay_now
|
{
seqs
.free
(
rid
,
decay_now
);
},
true
,
)
...
...
@@ -559,12 +584,17 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
///
/// Note: Calling this multiple times for the same request is allowed and will be a no-op
/// after the first call (idempotent).
pub
fn
mark_prefill_completed
(
&
self
,
request_id
:
&
RequestId
)
->
Result
<
(),
SequenceError
>
{
pub
fn
mark_prefill_completed
(
&
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
,
)
->
Result
<
(),
SequenceError
>
{
self
.mutate_request_worker
(
request_id
,
decay_now
,
ActiveSequenceEventData
::
MarkPrefillCompleted
,
|
seqs
,
rid
|
{
seqs
.mark_prefill_completed
(
rid
);
|
seqs
,
rid
,
decay_now
|
{
seqs
.mark_prefill_completed
(
rid
,
decay_now
);
},
false
,
)
...
...
@@ -605,13 +635,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
});
}
self
.publish_active_load_for_worker
(
worker
);
self
.publish_active_load_for_worker
(
worker
,
Instant
::
now
()
);
Ok
(())
}
/// Read active blocks/tokens from a worker and publish ActiveLoad metrics.
fn
publish_active_load_for_worker
(
&
self
,
worker
:
WorkerWithDpRank
)
{
fn
publish_active_load_for_worker
(
&
self
,
worker
:
WorkerWithDpRank
,
decay_now
:
Instant
)
{
let
(
active_blocks
,
active_tokens
)
=
{
let
table
=
self
.workers
.read
();
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
else
{
...
...
@@ -619,7 +649,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
return
;
};
let
seq
=
table
.slots
[
idx
]
.1
.read
();
(
seq
.active_blocks
(),
seq
.active_tokens
())
(
seq
.active_blocks
(),
seq
.active_tokens
(
decay_now
))
};
self
.publisher
...
...
@@ -674,11 +704,18 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
token_sequence
:
Option
<&
[
SequenceHash
]
>
,
isl
:
usize
,
overlaps
:
OverlapScores
,
decay_now
:
Instant
,
)
->
(
HashMap
<
WorkerWithDpRank
,
usize
>
,
HashMap
<
WorkerWithDpRank
,
usize
>
,
)
{
self
.potential_blocks_and_tokens_with_prefill_tracking
(
token_sequence
,
isl
,
overlaps
,
true
)
self
.potential_blocks_and_tokens_with_prefill_tracking
(
token_sequence
,
isl
,
overlaps
,
true
,
decay_now
,
)
}
pub
fn
potential_blocks_and_tokens_with_prefill_tracking
(
...
...
@@ -687,6 +724,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
isl
:
usize
,
overlaps
:
OverlapScores
,
track_prefill_tokens
:
bool
,
decay_now
:
Instant
,
)
->
(
HashMap
<
WorkerWithDpRank
,
usize
>
,
HashMap
<
WorkerWithDpRank
,
usize
>
,
...
...
@@ -712,6 +750,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
isl
,
overlap
,
track_prefill_tokens
,
decay_now
,
);
potential_blocks
.insert
(
*
worker
,
blocks
);
potential_tokens
.insert
(
*
worker
,
tokens
);
...
...
@@ -741,11 +780,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}
/// Query all workers for their current number of active tokens.
pub
fn
active_tokens
(
&
self
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
pub
fn
active_tokens
(
&
self
,
decay_now
:
Instant
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
let
table
=
self
.workers
.read
();
let
mut
results
=
HashMap
::
with_capacity
(
table
.slots
.len
());
for
(
worker
,
lock
)
in
&
table
.slots
{
results
.insert
(
*
worker
,
lock
.read
()
.active_tokens
());
results
.insert
(
*
worker
,
lock
.read
()
.active_tokens
(
decay_now
));
}
results
}
...
...
@@ -753,11 +792,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// Return true if any worker satisfies the provided predicate on active token count.
pub
fn
any_worker_matches_active_tokens
(
&
self
,
decay_now
:
Instant
,
mut
predicate
:
impl
FnMut
(
WorkerWithDpRank
,
usize
)
->
bool
,
)
->
bool
{
let
table
=
self
.workers
.read
();
for
(
worker
,
lock
)
in
&
table
.slots
{
if
predicate
(
*
worker
,
lock
.read
()
.active_tokens
())
{
if
predicate
(
*
worker
,
lock
.read
()
.active_tokens
(
decay_now
))
{
return
true
;
}
}
...
...
@@ -792,7 +832,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self
.request_to_lora
.remove
(
expired_id
);
removed_request_count
+=
1
;
}
self
.publish_active_load_for_worker
(
*
worker
);
self
.publish_active_load_for_worker
(
*
worker
,
now
);
}
}
let
duration
=
now
.elapsed
();
...
...
@@ -835,8 +875,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
#[cfg(test)]
mod
tests
{
use
std
::
collections
::
HashMap
;
use
std
::
time
::
Duration
;
use
super
::
*
;
use
crate
::
protocols
::{
OverlapScores
,
PrefillLoadHint
};
use
crate
::
test_utils
::
NoopSequencePublisher
;
fn
make_sequences
()
->
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>
{
...
...
@@ -854,20 +896,74 @@ mod tests {
async fn add_request_can_skip_prefill_token_tracking() {
    let sequences = make_sequences();
    let worker = WorkerWithDpRank::new(1, 0);
    // One timestamp is shared by the insertion and the query below.
    let decay_now = Instant::now();

    sequences
        .add_request(
            SequenceRequest {
                request_id: "req-1".to_string(),
                token_sequence: Some(vec![1, 2, 3]),
                isl: 12,
                overlap: 0,
                track_prefill_tokens: false,
                expected_output_tokens: None,
                prefill_load_hint: None,
                worker,
                lora_name: None,
            },
            decay_now,
        )
        .unwrap();

    // Prefill tracking was disabled for the request, so it adds no active tokens.
    assert_eq!(
        sequences.active_tokens(decay_now).get(&worker).copied(),
        Some(0)
    );
}
/// A request with a 100-token / 10-second prefill hint is added at `start`; every load
/// query evaluated at `start + 5s` must report the same decayed value of 50 tokens.
#[test]
fn explicit_decay_time_drives_multi_worker_load_queries_consistently() {
    let sequences = make_sequences();
    let worker = WorkerWithDpRank::new(1, 0);
    let start = Instant::now();

    sequences
        .add_request(
            SequenceRequest {
                request_id: "req-1".to_string(),
                token_sequence: Some(vec![1, 2, 3]),
                isl: 100,
                overlap: 0,
                track_prefill_tokens: true,
                expected_output_tokens: None,
                prefill_load_hint: Some(PrefillLoadHint {
                    initial_effective_prefill_tokens: 100,
                    expected_prefill_duration: Some(Duration::from_secs(10)),
                }),
                worker,
                lora_name: None,
            },
            start,
        )
        .unwrap();

    // The query time is built explicitly from `start` rather than read from the clock.
    let decay_now = start + Duration::from_secs(5);

    let active_tokens = sequences.active_tokens(decay_now);
    assert_eq!(active_tokens.get(&worker).copied(), Some(50));

    // A zero-length probe with prefill tracking disabled must see the same decayed load.
    let (_, potential_tokens) = sequences.potential_blocks_and_tokens_with_prefill_tracking(
        None,
        0,
        OverlapScores::default(),
        false,
        decay_now,
    );
    assert_eq!(potential_tokens.get(&worker).copied(), Some(50));

    assert!(
        sequences.any_worker_matches_active_tokens(decay_now, |candidate, tokens| {
            candidate == worker && tokens == 50
        })
    );
}
}
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment