Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
95a750f4
Unverified
Commit
95a750f4
authored
Apr 06, 2026
by
Yan Ru Pei
Committed by
GitHub
Apr 06, 2026
Browse files
chore(replay): refactor offline components into cleaner lanes (#7866)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
210bbf5d
Changes
91
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1085 additions
and
154 deletions
+1085
-154
lib/bindings/python/rust/llm/aic_callback.rs
lib/bindings/python/rust/llm/aic_callback.rs
+50
-7
lib/bindings/python/rust/llm/entrypoint.rs
lib/bindings/python/rust/llm/entrypoint.rs
+98
-4
lib/bindings/python/rust/llm/kv.rs
lib/bindings/python/rust/llm/kv.rs
+24
-1
lib/bindings/python/rust/llm/replay.rs
lib/bindings/python/rust/llm/replay.rs
+96
-6
lib/bindings/python/src/dynamo/_core.pyi
lib/bindings/python/src/dynamo/_core.pyi
+24
-0
lib/bindings/python/src/dynamo/_internal/aic.py
lib/bindings/python/src/dynamo/_internal/aic.py
+193
-0
lib/bindings/python/src/dynamo/llm/__init__.py
lib/bindings/python/src/dynamo/llm/__init__.py
+3
-0
lib/bindings/python/src/dynamo/replay/api.py
lib/bindings/python/src/dynamo/replay/api.py
+6
-0
lib/bindings/python/src/dynamo/replay/main.py
lib/bindings/python/src/dynamo/replay/main.py
+48
-1
lib/bindings/python/tests/replay/test_replay_smoke.py
lib/bindings/python/tests/replay/test_replay_smoke.py
+26
-0
lib/kv-router/src/lib.rs
lib/kv-router/src/lib.rs
+2
-1
lib/kv-router/src/protocols.rs
lib/kv-router/src/protocols.rs
+9
-0
lib/kv-router/src/scheduling/config.rs
lib/kv-router/src/scheduling/config.rs
+65
-33
lib/kv-router/src/scheduling/local.rs
lib/kv-router/src/scheduling/local.rs
+88
-17
lib/kv-router/src/scheduling/mod.rs
lib/kv-router/src/scheduling/mod.rs
+2
-0
lib/kv-router/src/scheduling/prefill_load.rs
lib/kv-router/src/scheduling/prefill_load.rs
+13
-0
lib/kv-router/src/scheduling/queue.rs
lib/kv-router/src/scheduling/queue.rs
+158
-47
lib/kv-router/src/sequences/block_tracker.rs
lib/kv-router/src/sequences/block_tracker.rs
+45
-0
lib/kv-router/src/sequences/mod.rs
lib/kv-router/src/sequences/mod.rs
+2
-0
lib/kv-router/src/sequences/multi_worker.rs
lib/kv-router/src/sequences/multi_worker.rs
+133
-37
No files found.
lib/bindings/python/rust/llm/aic_callback.rs
View file @
95a750f4
...
@@ -8,15 +8,17 @@
...
@@ -8,15 +8,17 @@
//! predictions without knowing about PyO3.
//! predictions without knowing about PyO3.
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
pyo3
::
prelude
::
*
;
use
pyo3
::
prelude
::
*
;
use
dynamo_kv_router
::
PrefillLoadEstimator
;
use
dynamo_mocker
::
common
::
perf_model
::
AicCallback
;
use
dynamo_mocker
::
common
::
perf_model
::
AicCallback
;
/// Wraps a Python AIC InferenceSession for direct calls from Rust.
/// Wraps a Python AIC InferenceSession for direct calls from Rust.
///
///
/// The Python object must expose:
/// The Python object must expose:
/// - `predict_prefill(batch_size, isl, prefix
, osl
) -> float`
/// - `predict_prefill(batch_size,
effective_
isl, prefix) -> float`
/// - `predict_decode(batch_size, isl, osl) -> float`
/// - `predict_decode(batch_size, isl, osl) -> float`
pub
(
super
)
struct
PyAicCallback
{
pub
(
super
)
struct
PyAicCallback
{
pub
(
super
)
session
:
PyObject
,
pub
(
super
)
session
:
PyObject
,
...
@@ -26,15 +28,26 @@ pub(super) struct PyAicCallback {
...
@@ -26,15 +28,26 @@ pub(super) struct PyAicCallback {
unsafe
impl
Send
for
PyAicCallback
{}
unsafe
impl
Send
for
PyAicCallback
{}
unsafe
impl
Sync
for
PyAicCallback
{}
unsafe
impl
Sync
for
PyAicCallback
{}
impl
AicCallback
for
PyAicCallback
{
impl
PyAicCallback
{
fn
predict_prefill
(
&
self
,
batch_size
:
usize
,
isl
:
usize
,
prefix
:
usize
,
osl
:
usize
)
->
f64
{
fn
predict_prefill_ms
(
&
self
,
batch_size
:
usize
,
effective_isl
:
usize
,
prefix
:
usize
,
)
->
PyResult
<
f64
>
{
Python
::
with_gil
(|
py
|
{
Python
::
with_gil
(|
py
|
{
self
.session
self
.session
.call_method1
(
py
,
"predict_prefill"
,
(
batch_size
,
isl
,
prefix
,
osl
))
.call_method1
(
py
,
"predict_prefill"
,
(
batch_size
,
effective_isl
,
prefix
))
.and_then
(|
r
|
r
.extract
::
<
f64
>
(
py
))
.and_then
(|
result
|
result
.extract
::
<
f64
>
(
py
))
.unwrap_or_else
(|
e
|
panic!
(
"AIC predict_prefill failed: {e}"
))
})
})
}
}
}
impl
AicCallback
for
PyAicCallback
{
fn
predict_prefill
(
&
self
,
batch_size
:
usize
,
effective_isl
:
usize
,
prefix
:
usize
)
->
f64
{
self
.predict_prefill_ms
(
batch_size
,
effective_isl
,
prefix
)
.unwrap_or_else
(|
e
|
panic!
(
"AIC predict_prefill failed: {e}"
))
}
fn
predict_decode
(
&
self
,
batch_size
:
usize
,
isl
:
usize
,
osl
:
usize
)
->
f64
{
fn
predict_decode
(
&
self
,
batch_size
:
usize
,
isl
:
usize
,
osl
:
usize
)
->
f64
{
Python
::
with_gil
(|
py
|
{
Python
::
with_gil
(|
py
|
{
...
@@ -46,6 +59,18 @@ impl AicCallback for PyAicCallback {
...
@@ -46,6 +59,18 @@ impl AicCallback for PyAicCallback {
}
}
}
}
impl
PrefillLoadEstimator
for
PyAicCallback
{
fn
predict_prefill_duration
(
&
self
,
batch_size
:
usize
,
effective_isl
:
usize
,
prefix
:
usize
,
)
->
anyhow
::
Result
<
Duration
>
{
let
latency_ms
=
self
.predict_prefill_ms
(
batch_size
,
effective_isl
,
prefix
)
?
;
Ok
(
Duration
::
from_secs_f64
(
latency_ms
/
1000.0
))
}
}
/// Initialize an AIC callback by importing and calling the Python setup function.
/// Initialize an AIC callback by importing and calling the Python setup function.
///
///
/// Called once at mocker startup when `--aic-perf-model` is requested.
/// Called once at mocker startup when `--aic-perf-model` is requested.
...
@@ -61,7 +86,7 @@ pub(super) fn create_aic_callback(
...
@@ -61,7 +86,7 @@ pub(super) fn create_aic_callback(
moe_ep_size
:
Option
<
usize
>
,
moe_ep_size
:
Option
<
usize
>
,
attention_dp_size
:
Option
<
usize
>
,
attention_dp_size
:
Option
<
usize
>
,
)
->
PyResult
<
Arc
<
dyn
AicCallback
>>
{
)
->
PyResult
<
Arc
<
dyn
AicCallback
>>
{
let
module
=
py
.import
(
"dynamo.
mocker.aic_session
"
)
?
;
let
module
=
py
.import
(
"dynamo.
_internal.aic
"
)
?
;
let
session
=
module
.call_method1
(
let
session
=
module
.call_method1
(
"create_session"
,
"create_session"
,
(
(
...
@@ -79,3 +104,21 @@ pub(super) fn create_aic_callback(
...
@@ -79,3 +104,21 @@ pub(super) fn create_aic_callback(
session
:
session
.into
(),
session
:
session
.into
(),
}))
}))
}
}
pub
(
super
)
fn
create_aic_prefill_load_estimator
(
py
:
Python
<
'_
>
,
backend_name
:
&
str
,
system
:
&
str
,
model_path
:
&
str
,
tp_size
:
usize
,
backend_version
:
Option
<&
str
>
,
)
->
PyResult
<
Arc
<
dyn
PrefillLoadEstimator
>>
{
let
module
=
py
.import
(
"dynamo._internal.aic"
)
?
;
let
session
=
module
.call_method1
(
"create_session"
,
(
backend_name
,
system
,
model_path
,
tp_size
,
backend_version
),
)
?
;
Ok
(
Arc
::
new
(
PyAicCallback
{
session
:
session
.into
(),
}))
}
lib/bindings/python/rust/llm/entrypoint.rs
View file @
95a750f4
...
@@ -10,7 +10,9 @@ use std::sync::Arc;
...
@@ -10,7 +10,9 @@ use std::sync::Arc;
use
pyo3
::{
exceptions
::
PyException
,
exceptions
::
PyValueError
,
prelude
::
*
};
use
pyo3
::{
exceptions
::
PyException
,
exceptions
::
PyValueError
,
prelude
::
*
};
use
pyo3_async_runtimes
::
TaskLocals
;
use
pyo3_async_runtimes
::
TaskLocals
;
use
dynamo_kv_router
::
config
::
KvRouterConfig
as
RsKvRouterConfig
;
use
dynamo_kv_router
::
config
::{
KvRouterConfig
as
RsKvRouterConfig
,
RouterPrefillLoadModel
as
RsRouterPrefillLoadModel
,
};
use
dynamo_llm
::
discovery
::
LoadThresholdConfig
as
RsLoadThresholdConfig
;
use
dynamo_llm
::
discovery
::
LoadThresholdConfig
as
RsLoadThresholdConfig
;
use
dynamo_llm
::
entrypoint
::
ChatEngineFactoryCallback
;
use
dynamo_llm
::
entrypoint
::
ChatEngineFactoryCallback
;
use
dynamo_llm
::
entrypoint
::
EngineConfig
as
RsEngineConfig
;
use
dynamo_llm
::
entrypoint
::
EngineConfig
as
RsEngineConfig
;
...
@@ -23,7 +25,7 @@ use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard;
...
@@ -23,7 +25,7 @@ use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard;
use
dynamo_llm
::
types
::
openai
::
chat_completions
::
OpenAIChatCompletionsStreamingEngine
;
use
dynamo_llm
::
types
::
openai
::
chat_completions
::
OpenAIChatCompletionsStreamingEngine
;
use
dynamo_mocker
::
common
::
perf_model
::
PerfModel
;
use
dynamo_mocker
::
common
::
perf_model
::
PerfModel
;
use
super
::
aic_callback
::
create_aic_callback
;
use
super
::
aic_callback
::
{
create_aic_callback
,
create_aic_prefill_load_estimator
}
;
use
super
::
replay
::
MockEngineArgs
as
PyMockEngineArgs
;
use
super
::
replay
::
MockEngineArgs
as
PyMockEngineArgs
;
use
dynamo_mocker
::
common
::
protocols
::
MockEngineArgs
as
RsMockEngineArgs
;
use
dynamo_mocker
::
common
::
protocols
::
MockEngineArgs
as
RsMockEngineArgs
;
use
dynamo_runtime
::
discovery
::
ModelCardInstanceId
as
RsModelCardInstanceId
;
use
dynamo_runtime
::
discovery
::
ModelCardInstanceId
as
RsModelCardInstanceId
;
...
@@ -55,10 +57,76 @@ impl KvRouterConfig {
...
@@ -55,10 +57,76 @@ impl KvRouterConfig {
}
}
}
}
#[pyclass]
#[derive(Clone,
Debug)]
pub
struct
AicPerfConfig
{
aic_backend
:
String
,
aic_system
:
String
,
aic_backend_version
:
Option
<
String
>
,
aic_tp_size
:
usize
,
aic_model_path
:
String
,
}
impl
AicPerfConfig
{
pub
(
crate
)
fn
backend_name
(
&
self
)
->
&
str
{
&
self
.aic_backend
}
pub
(
crate
)
fn
system
(
&
self
)
->
&
str
{
&
self
.aic_system
}
pub
(
crate
)
fn
backend_version
(
&
self
)
->
Option
<&
str
>
{
self
.aic_backend_version
.as_deref
()
}
pub
(
crate
)
fn
tp_size
(
&
self
)
->
usize
{
self
.aic_tp_size
}
pub
(
crate
)
fn
model_path
(
&
self
)
->
&
str
{
&
self
.aic_model_path
}
}
#[pymethods]
impl
AicPerfConfig
{
#[new]
#[pyo3(signature
=
(aic_backend,
aic_system,
aic_model_path,
aic_tp_size=
1
,
aic_backend_version=None))]
fn
new
(
aic_backend
:
String
,
aic_system
:
String
,
aic_model_path
:
String
,
aic_tp_size
:
usize
,
aic_backend_version
:
Option
<
String
>
,
)
->
PyResult
<
Self
>
{
if
aic_backend
.is_empty
()
{
return
Err
(
PyValueError
::
new_err
(
"aic_backend must be non-empty"
));
}
if
aic_system
.is_empty
()
{
return
Err
(
PyValueError
::
new_err
(
"aic_system must be non-empty"
));
}
if
aic_model_path
.is_empty
()
{
return
Err
(
PyValueError
::
new_err
(
"aic_model_path must be non-empty"
));
}
if
aic_tp_size
==
0
{
return
Err
(
PyValueError
::
new_err
(
"aic_tp_size must be >= 1"
));
}
Ok
(
Self
{
aic_backend
,
aic_system
,
aic_backend_version
,
aic_tp_size
,
aic_model_path
,
})
}
}
#[pymethods]
#[pymethods]
impl
KvRouterConfig
{
impl
KvRouterConfig
{
#[new]
#[new]
#[pyo3(signature
=
(overlap_score_weight=
1.0
,
router_temperature=
0.0
,
use_kv_events=
true
,
durable_kv_events=
false
,
router_replica_sync=
false
,
router_track_active_blocks=
true
,
router_track_output_blocks=
false
,
router_assume_kv_reuse=
true
,
router_track_prefill_tokens=
true
,
router_snapshot_threshold=
1000000
,
router_reset_states=
false
,
router_ttl_secs=
120.0
,
router_max_tree_size=
1048576
,
router_prune_target_ratio=
0.8
,
router_queue_threshold=Some(
4.0
),
router_event_threads=
4
,
router_queue_policy=
"fcfs"
,
remote_indexer_component=None))]
#[pyo3(signature
=
(overlap_score_weight=
1.0
,
router_temperature=
0.0
,
use_kv_events=
true
,
durable_kv_events=
false
,
router_replica_sync=
false
,
router_track_active_blocks=
true
,
router_track_output_blocks=
false
,
router_assume_kv_reuse=
true
,
router_track_prefill_tokens=
true
,
router_prefill_load_model=
"none"
,
router_snapshot_threshold=
1000000
,
router_reset_states=
false
,
router_ttl_secs=
120.0
,
router_max_tree_size=
1048576
,
router_prune_target_ratio=
0.8
,
router_queue_threshold=Some(
4.0
),
router_event_threads=
4
,
router_queue_policy=
"fcfs"
,
remote_indexer_component=None))]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_arguments)]
fn
new
(
fn
new
(
overlap_score_weight
:
f64
,
overlap_score_weight
:
f64
,
...
@@ -70,6 +138,7 @@ impl KvRouterConfig {
...
@@ -70,6 +138,7 @@ impl KvRouterConfig {
router_track_output_blocks
:
bool
,
router_track_output_blocks
:
bool
,
router_assume_kv_reuse
:
bool
,
router_assume_kv_reuse
:
bool
,
router_track_prefill_tokens
:
bool
,
router_track_prefill_tokens
:
bool
,
router_prefill_load_model
:
&
str
,
router_snapshot_threshold
:
Option
<
u32
>
,
router_snapshot_threshold
:
Option
<
u32
>
,
router_reset_states
:
bool
,
router_reset_states
:
bool
,
router_ttl_secs
:
f64
,
router_ttl_secs
:
f64
,
...
@@ -91,6 +160,11 @@ impl KvRouterConfig {
...
@@ -91,6 +160,11 @@ impl KvRouterConfig {
router_track_output_blocks
,
router_track_output_blocks
,
router_assume_kv_reuse
,
router_assume_kv_reuse
,
router_track_prefill_tokens
,
router_track_prefill_tokens
,
router_prefill_load_model
:
router_prefill_load_model
.parse
::
<
RsRouterPrefillLoadModel
>
()
.unwrap_or_else
(|
_
|
{
panic!
(
"invalid router_prefill_load_model: {router_prefill_load_model:?}"
)
}),
router_snapshot_threshold
,
router_snapshot_threshold
,
router_reset_states
,
router_reset_states
,
router_ttl_secs
,
router_ttl_secs
,
...
@@ -249,13 +323,14 @@ pub(crate) struct EntrypointArgs {
...
@@ -249,13 +323,14 @@ pub(crate) struct EntrypointArgs {
is_prefill
:
bool
,
is_prefill
:
bool
,
migration_limit
:
u32
,
migration_limit
:
u32
,
chat_engine_factory
:
Option
<
PyEngineFactory
>
,
chat_engine_factory
:
Option
<
PyEngineFactory
>
,
aic_perf_config
:
Option
<
AicPerfConfig
>
,
}
}
#[pymethods]
#[pymethods]
impl
EntrypointArgs
{
impl
EntrypointArgs
{
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_arguments)]
#[new]
#[new]
#[pyo3(signature
=
(engine_type,
model_path=None,
model_name=None,
endpoint_id=None,
context_length=None,
template_file=None,
router_config=None,
kv_cache_block_size=None,
http_host=None,
http_port=None,
http_metrics_port=None,
tls_cert_path=None,
tls_key_path=None,
extra_engine_args=None,
mocker_engine_args=None,
runtime_config=None,
namespace=None,
namespace_prefix=None,
is_prefill=
false
,
migration_limit=
0
,
chat_engine_factory=None))]
#[pyo3(signature
=
(engine_type,
model_path=None,
model_name=None,
endpoint_id=None,
context_length=None,
template_file=None,
router_config=None,
kv_cache_block_size=None,
http_host=None,
http_port=None,
http_metrics_port=None,
tls_cert_path=None,
tls_key_path=None,
extra_engine_args=None,
mocker_engine_args=None,
runtime_config=None,
namespace=None,
namespace_prefix=None,
is_prefill=
false
,
migration_limit=
0
,
chat_engine_factory=None
,
aic_perf_config=None
))]
pub
fn
new
(
pub
fn
new
(
py
:
Python
<
'_
>
,
py
:
Python
<
'_
>
,
engine_type
:
EngineType
,
engine_type
:
EngineType
,
...
@@ -279,6 +354,7 @@ impl EntrypointArgs {
...
@@ -279,6 +354,7 @@ impl EntrypointArgs {
is_prefill
:
bool
,
is_prefill
:
bool
,
migration_limit
:
u32
,
migration_limit
:
u32
,
chat_engine_factory
:
Option
<
PyObject
>
,
chat_engine_factory
:
Option
<
PyObject
>
,
aic_perf_config
:
Option
<
AicPerfConfig
>
,
)
->
PyResult
<
Self
>
{
)
->
PyResult
<
Self
>
{
let
endpoint_id_obj
:
Option
<
EndpointId
>
=
endpoint_id
.as_deref
()
.map
(
EndpointId
::
from
);
let
endpoint_id_obj
:
Option
<
EndpointId
>
=
endpoint_id
.as_deref
()
.map
(
EndpointId
::
from
);
if
(
tls_cert_path
.is_some
()
&&
tls_key_path
.is_none
())
if
(
tls_cert_path
.is_some
()
&&
tls_key_path
.is_none
())
...
@@ -327,6 +403,7 @@ impl EntrypointArgs {
...
@@ -327,6 +403,7 @@ impl EntrypointArgs {
is_prefill
,
is_prefill
,
migration_limit
,
migration_limit
,
chat_engine_factory
,
chat_engine_factory
,
aic_perf_config
,
})
})
}
}
}
}
...
@@ -467,9 +544,26 @@ async fn select_engine(
...
@@ -467,9 +544,26 @@ async fn select_engine(
EngineType
::
Dynamic
=>
{
EngineType
::
Dynamic
=>
{
// Convert Python chat engine factory to Rust callback
// Convert Python chat engine factory to Rust callback
let
chat_engine_factory
=
args
.chat_engine_factory
.map
(
py_engine_factory_to_callback
);
let
chat_engine_factory
=
args
.chat_engine_factory
.map
(
py_engine_factory_to_callback
);
let
prefill_load_estimator
=
args
.aic_perf_config
.as_ref
()
.map
(|
config
|
{
Python
::
with_gil
(|
py
|
{
create_aic_prefill_load_estimator
(
py
,
config
.backend_name
(),
config
.system
(),
config
.model_path
(),
config
.tp_size
(),
config
.backend_version
(),
)
})
})
.transpose
()
?
;
RsEngineConfig
::
Dynamic
{
RsEngineConfig
::
Dynamic
{
model
:
Box
::
new
(
local_model
),
model
:
Box
::
new
(
local_model
),
chat_engine_factory
,
chat_engine_factory
,
prefill_load_estimator
,
}
}
}
}
EngineType
::
Mocker
=>
{
EngineType
::
Mocker
=>
{
...
...
lib/bindings/python/rust/llm/kv.rs
View file @
95a750f4
...
@@ -30,6 +30,9 @@ use llm_rs::protocols::common::timing::RequestTracker;
...
@@ -30,6 +30,9 @@ use llm_rs::protocols::common::timing::RequestTracker;
use
llm_rs
::
protocols
::
common
::{
OutputOptions
,
SamplingOptions
,
StopConditions
};
use
llm_rs
::
protocols
::
common
::{
OutputOptions
,
SamplingOptions
,
StopConditions
};
use
serde_json
::
json
;
use
serde_json
::
json
;
use
super
::
aic_callback
::
create_aic_prefill_load_estimator
;
use
super
::
entrypoint
::
AicPerfConfig
;
fn
depythonize_block_mm_infos
(
obj
:
&
Bound
<
'_
,
PyAny
>
)
->
PyResult
<
Vec
<
Option
<
BlockExtraInfo
>>>
{
fn
depythonize_block_mm_infos
(
obj
:
&
Bound
<
'_
,
PyAny
>
)
->
PyResult
<
Vec
<
Option
<
BlockExtraInfo
>>>
{
depythonize
(
obj
)
.map_err
(
to_pyerr
)
depythonize
(
obj
)
.map_err
(
to_pyerr
)
}
}
...
@@ -703,6 +706,7 @@ async fn create_kv_router_from_endpoint(
...
@@ -703,6 +706,7 @@ async fn create_kv_router_from_endpoint(
endpoint
:
&
Endpoint
,
endpoint
:
&
Endpoint
,
block_size
:
usize
,
block_size
:
usize
,
kv_router_config
:
Option
<
KvRouterConfig
>
,
kv_router_config
:
Option
<
KvRouterConfig
>
,
prefill_load_estimator
:
Option
<
Arc
<
dyn
dynamo_kv_router
::
PrefillLoadEstimator
>>
,
)
->
Result
<
Arc
<
llm_rs
::
kv_router
::
KvRouter
>
,
PyErr
>
{
)
->
Result
<
Arc
<
llm_rs
::
kv_router
::
KvRouter
>
,
PyErr
>
{
// Create ModelManager and use it to create KvRouter (ensures registration)
// Create ModelManager and use it to create KvRouter (ensures registration)
let
model_manager
=
Arc
::
new
(
llm_rs
::
discovery
::
ModelManager
::
new
());
let
model_manager
=
Arc
::
new
(
llm_rs
::
discovery
::
ModelManager
::
new
());
...
@@ -766,6 +770,7 @@ async fn create_kv_router_from_endpoint(
...
@@ -766,6 +770,7 @@ async fn create_kv_router_from_endpoint(
&
endpoint
.inner
,
&
endpoint
.inner
,
block_size
as
u32
,
block_size
as
u32
,
kv_router_config
,
kv_router_config
,
prefill_load_estimator
,
worker_type
,
worker_type
,
model_name
,
model_name
,
enable_eagle
,
enable_eagle
,
...
@@ -888,12 +893,29 @@ impl KvRouter {
...
@@ -888,12 +893,29 @@ impl KvRouter {
/// Note: Worker type for Prometheus metrics is inferred from the endpoint name/component
/// Note: Worker type for Prometheus metrics is inferred from the endpoint name/component
/// (contains "prefill") or by `router_track_active_blocks` being disabled.
/// (contains "prefill") or by `router_track_active_blocks` being disabled.
#[new]
#[new]
#[pyo3(signature
=
(endpoint,
block_size,
kv_router_config))]
#[pyo3(signature
=
(endpoint,
block_size,
kv_router_config
,
aic_perf_config=None
))]
fn
new
(
fn
new
(
endpoint
:
&
Endpoint
,
endpoint
:
&
Endpoint
,
block_size
:
usize
,
block_size
:
usize
,
kv_router_config
:
&
super
::
entrypoint
::
KvRouterConfig
,
kv_router_config
:
&
super
::
entrypoint
::
KvRouterConfig
,
aic_perf_config
:
Option
<&
AicPerfConfig
>
,
)
->
PyResult
<
Self
>
{
)
->
PyResult
<
Self
>
{
let
prefill_load_estimator
=
aic_perf_config
.map
(|
config
|
{
Python
::
with_gil
(|
py
|
{
create_aic_prefill_load_estimator
(
py
,
config
.backend_name
(),
config
.system
(),
config
.model_path
(),
config
.tp_size
(),
config
.backend_version
(),
)
})
})
.transpose
()
.map_err
(
to_pyerr
)
?
;
let
runtime
=
pyo3_async_runtimes
::
tokio
::
get_runtime
();
let
runtime
=
pyo3_async_runtimes
::
tokio
::
get_runtime
();
runtime
.block_on
(
async
move
{
runtime
.block_on
(
async
move
{
let
client
=
endpoint
.inner
.client
()
.await
.map_err
(
to_pyerr
)
?
;
let
client
=
endpoint
.inner
.client
()
.await
.map_err
(
to_pyerr
)
?
;
...
@@ -916,6 +938,7 @@ impl KvRouter {
...
@@ -916,6 +938,7 @@ impl KvRouter {
endpoint
,
endpoint
,
block_size
,
block_size
,
Some
(
kv_router_config
.inner
()),
Some
(
kv_router_config
.inner
()),
prefill_load_estimator
,
)
)
.await
?
;
.await
?
;
...
...
lib/bindings/python/rust/llm/replay.rs
View file @
95a750f4
...
@@ -19,8 +19,8 @@ use pythonize::pythonize;
...
@@ -19,8 +19,8 @@ use pythonize::pythonize;
use
serde_json
::
json
;
use
serde_json
::
json
;
use
uuid
::
Uuid
;
use
uuid
::
Uuid
;
use
super
::
aic_callback
::
create_aic_callback
;
use
super
::
aic_callback
::
{
create_aic_callback
,
create_aic_prefill_load_estimator
}
;
use
super
::
entrypoint
::{
KvRouterConfig
,
to_pyerr
};
use
super
::
entrypoint
::{
AicPerfConfig
,
KvRouterConfig
,
to_pyerr
};
fn
parse_mocker_engine_type
(
engine_type
:
&
str
)
->
PyResult
<
RsMockerEngineType
>
{
fn
parse_mocker_engine_type
(
engine_type
:
&
str
)
->
PyResult
<
RsMockerEngineType
>
{
match
engine_type
{
match
engine_type
{
...
@@ -526,7 +526,7 @@ impl MockEngineArgs {
...
@@ -526,7 +526,7 @@ impl MockEngineArgs {
}
}
#[pyfunction]
#[pyfunction]
#[pyo3(signature
=
(trace_file,
extra_engine_args=None,
prefill_engine_args=None,
decode_engine_args=None,
router_config=None,
num_workers=
1
,
num_prefill_workers=
1
,
num_decode_workers=
1
,
replay_concurrency=None,
replay_mode=
"offline"
,
router_mode=
"round_robin"
,
arrival_speedup_ratio=
1.0
))]
#[pyo3(signature
=
(trace_file,
extra_engine_args=None,
prefill_engine_args=None,
decode_engine_args=None,
router_config=None,
aic_perf_config=None,
num_workers=
1
,
num_prefill_workers=
1
,
num_decode_workers=
1
,
replay_concurrency=None,
replay_mode=
"offline"
,
router_mode=
"round_robin"
,
arrival_speedup_ratio=
1.0
,
trace_block_size=
512
))]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_arguments)]
pub
fn
run_mocker_trace_replay
(
pub
fn
run_mocker_trace_replay
(
py
:
Python
<
'_
>
,
py
:
Python
<
'_
>
,
...
@@ -535,6 +535,7 @@ pub fn run_mocker_trace_replay(
...
@@ -535,6 +535,7 @@ pub fn run_mocker_trace_replay(
prefill_engine_args
:
Option
<
MockEngineArgs
>
,
prefill_engine_args
:
Option
<
MockEngineArgs
>
,
decode_engine_args
:
Option
<
MockEngineArgs
>
,
decode_engine_args
:
Option
<
MockEngineArgs
>
,
router_config
:
Option
<
KvRouterConfig
>
,
router_config
:
Option
<
KvRouterConfig
>
,
aic_perf_config
:
Option
<&
AicPerfConfig
>
,
num_workers
:
usize
,
num_workers
:
usize
,
num_prefill_workers
:
usize
,
num_prefill_workers
:
usize
,
num_decode_workers
:
usize
,
num_decode_workers
:
usize
,
...
@@ -542,6 +543,7 @@ pub fn run_mocker_trace_replay(
...
@@ -542,6 +543,7 @@ pub fn run_mocker_trace_replay(
replay_mode
:
&
str
,
replay_mode
:
&
str
,
router_mode
:
&
str
,
router_mode
:
&
str
,
arrival_speedup_ratio
:
f64
,
arrival_speedup_ratio
:
f64
,
trace_block_size
:
usize
,
)
->
PyResult
<
PyObject
>
{
)
->
PyResult
<
PyObject
>
{
let
args_selection
=
load_replay_args_selection
(
let
args_selection
=
load_replay_args_selection
(
py
,
py
,
...
@@ -552,9 +554,15 @@ pub fn run_mocker_trace_replay(
...
@@ -552,9 +554,15 @@ pub fn run_mocker_trace_replay(
num_prefill_workers
,
num_prefill_workers
,
num_decode_workers
,
num_decode_workers
,
)
?
;
)
?
;
let
router_mode
=
parse_replay_router_mode
(
router_mode
)
?
;
let
prefill_load_estimator
=
load_replay_prefill_load_estimator
(
py
,
router_mode
,
router_config
.as_ref
(),
aic_perf_config
,
)
?
;
let
router_config
=
load_replay_router_config
(
router_config
);
let
router_config
=
load_replay_router_config
(
router_config
);
let
replay_mode
=
replay_mode
.to_owned
();
let
replay_mode
=
replay_mode
.to_owned
();
let
router_mode
=
parse_replay_router_mode
(
router_mode
)
?
;
let
report
=
py
.allow_threads
(
move
||
{
let
report
=
py
.allow_threads
(
move
||
{
let
replay_concurrency
=
parse_replay_concurrency
(
replay_concurrency
)
?
;
let
replay_concurrency
=
parse_replay_concurrency
(
replay_concurrency
)
?
;
...
@@ -565,7 +573,9 @@ pub fn run_mocker_trace_replay(
...
@@ -565,7 +573,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_file_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_concurrency_file_with_router_mode
(
*
args
,
*
args
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
&
trace_file
,
&
trace_file
,
trace_block_size
,
max_in_flight
,
max_in_flight
,
num_workers
,
num_workers
,
router_mode
,
router_mode
,
...
@@ -575,7 +585,9 @@ pub fn run_mocker_trace_replay(
...
@@ -575,7 +585,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_file_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_trace_file_with_router_mode
(
*
args
,
*
args
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
&
trace_file
,
&
trace_file
,
trace_block_size
,
num_workers
,
num_workers
,
arrival_speedup_ratio
,
arrival_speedup_ratio
,
router_mode
,
router_mode
,
...
@@ -585,7 +597,9 @@ pub fn run_mocker_trace_replay(
...
@@ -585,7 +597,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_live_file_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_concurrency_live_file_with_router_mode
(
*
args
,
*
args
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
&
trace_file
,
&
trace_file
,
trace_block_size
,
max_in_flight
,
max_in_flight
,
num_workers
,
num_workers
,
router_mode
,
router_mode
,
...
@@ -595,7 +609,9 @@ pub fn run_mocker_trace_replay(
...
@@ -595,7 +609,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_live_file_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_trace_live_file_with_router_mode
(
*
args
,
*
args
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
&
trace_file
,
&
trace_file
,
trace_block_size
,
num_workers
,
num_workers
,
arrival_speedup_ratio
,
arrival_speedup_ratio
,
router_mode
,
router_mode
,
...
@@ -613,7 +629,9 @@ pub fn run_mocker_trace_replay(
...
@@ -613,7 +629,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_file_disagg_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_concurrency_file_disagg_with_router_mode
(
*
config
,
*
config
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
&
trace_file
,
&
trace_file
,
trace_block_size
,
max_in_flight
,
max_in_flight
,
router_mode
,
router_mode
,
)
)
...
@@ -622,7 +640,9 @@ pub fn run_mocker_trace_replay(
...
@@ -622,7 +640,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_file_disagg_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_trace_file_disagg_with_router_mode
(
*
config
,
*
config
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
&
trace_file
,
&
trace_file
,
trace_block_size
,
arrival_speedup_ratio
,
arrival_speedup_ratio
,
router_mode
,
router_mode
,
)
)
...
@@ -642,7 +662,7 @@ pub fn run_mocker_trace_replay(
...
@@ -642,7 +662,7 @@ pub fn run_mocker_trace_replay(
}
}
#[pyfunction]
#[pyfunction]
#[pyo3(signature
=
(input_tokens,
output_tokens,
request_count,
extra_engine_args=None,
prefill_engine_args=None,
decode_engine_args=None,
router_config=None,
num_workers=
1
,
num_prefill_workers=
1
,
num_decode_workers=
1
,
replay_concurrency=None,
replay_mode=
"offline"
,
router_mode=
"round_robin"
,
arrival_speedup_ratio=
1.0
,
arrival_interval_ms=
1.0
,
turns_per_session=
1
,
shared_prefix_ratio=
0.0
,
num_prefix_groups=
0
,
inter_turn_delay_ms=
0.0
))]
#[pyo3(signature
=
(input_tokens,
output_tokens,
request_count,
extra_engine_args=None,
prefill_engine_args=None,
decode_engine_args=None,
router_config=None,
aic_perf_config=None,
num_workers=
1
,
num_prefill_workers=
1
,
num_decode_workers=
1
,
replay_concurrency=None,
replay_mode=
"offline"
,
router_mode=
"round_robin"
,
arrival_speedup_ratio=
1.0
,
arrival_interval_ms=
1.0
,
turns_per_session=
1
,
shared_prefix_ratio=
0.0
,
num_prefix_groups=
0
,
inter_turn_delay_ms=
0.0
))]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_arguments)]
pub
fn
run_mocker_synthetic_trace_replay
(
pub
fn
run_mocker_synthetic_trace_replay
(
py
:
Python
<
'_
>
,
py
:
Python
<
'_
>
,
...
@@ -653,6 +673,7 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -653,6 +673,7 @@ pub fn run_mocker_synthetic_trace_replay(
prefill_engine_args
:
Option
<
MockEngineArgs
>
,
prefill_engine_args
:
Option
<
MockEngineArgs
>
,
decode_engine_args
:
Option
<
MockEngineArgs
>
,
decode_engine_args
:
Option
<
MockEngineArgs
>
,
router_config
:
Option
<
KvRouterConfig
>
,
router_config
:
Option
<
KvRouterConfig
>
,
aic_perf_config
:
Option
<&
AicPerfConfig
>
,
num_workers
:
usize
,
num_workers
:
usize
,
num_prefill_workers
:
usize
,
num_prefill_workers
:
usize
,
num_decode_workers
:
usize
,
num_decode_workers
:
usize
,
...
@@ -675,9 +696,15 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -675,9 +696,15 @@ pub fn run_mocker_synthetic_trace_replay(
num_prefill_workers
,
num_prefill_workers
,
num_decode_workers
,
num_decode_workers
,
)
?
;
)
?
;
let
router_mode
=
parse_replay_router_mode
(
router_mode
)
?
;
let
prefill_load_estimator
=
load_replay_prefill_load_estimator
(
py
,
router_mode
,
router_config
.as_ref
(),
aic_perf_config
,
)
?
;
let
router_config
=
load_replay_router_config
(
router_config
);
let
router_config
=
load_replay_router_config
(
router_config
);
let
replay_mode
=
replay_mode
.to_owned
();
let
replay_mode
=
replay_mode
.to_owned
();
let
router_mode
=
parse_replay_router_mode
(
router_mode
)
?
;
let
block_size
=
match
&
args_selection
{
let
block_size
=
match
&
args_selection
{
ReplayArgsSelection
::
Aggregated
(
args
)
=>
args
.block_size
.max
(
1
),
ReplayArgsSelection
::
Aggregated
(
args
)
=>
args
.block_size
.max
(
1
),
ReplayArgsSelection
::
Disagg
(
config
)
=>
config
.prefill_args.block_size
.max
(
1
),
ReplayArgsSelection
::
Disagg
(
config
)
=>
config
.prefill_args.block_size
.max
(
1
),
...
@@ -712,6 +739,7 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -712,6 +739,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_workload_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_concurrency_workload_with_router_mode
(
*
args
,
*
args
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
trace
,
trace
,
max_in_flight
,
max_in_flight
,
num_workers
,
num_workers
,
...
@@ -722,6 +750,7 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -722,6 +750,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_workload_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_trace_workload_with_router_mode
(
*
args
,
*
args
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
trace
,
trace
,
num_workers
,
num_workers
,
router_mode
,
router_mode
,
...
@@ -731,6 +760,7 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -731,6 +760,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_live_workload_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_concurrency_live_workload_with_router_mode
(
*
args
,
*
args
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
trace
,
trace
,
max_in_flight
,
max_in_flight
,
num_workers
,
num_workers
,
...
@@ -741,6 +771,7 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -741,6 +771,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_live_workload_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_trace_live_workload_with_router_mode
(
*
args
,
*
args
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
trace
,
trace
,
num_workers
,
num_workers
,
router_mode
,
router_mode
,
...
@@ -756,6 +787,7 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -756,6 +787,7 @@ pub fn run_mocker_synthetic_trace_replay(
(
"offline"
,
Some
(
max_in_flight
))
=>
dynamo_mocker
::
replay
::
simulate_concurrency_workload_disagg_with_router_mode
(
(
"offline"
,
Some
(
max_in_flight
))
=>
dynamo_mocker
::
replay
::
simulate_concurrency_workload_disagg_with_router_mode
(
*
config
,
*
config
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
trace
,
trace
,
max_in_flight
,
max_in_flight
,
router_mode
,
router_mode
,
...
@@ -763,6 +795,7 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -763,6 +795,7 @@ pub fn run_mocker_synthetic_trace_replay(
(
"offline"
,
None
)
=>
dynamo_mocker
::
replay
::
simulate_trace_workload_disagg_with_router_mode
(
(
"offline"
,
None
)
=>
dynamo_mocker
::
replay
::
simulate_trace_workload_disagg_with_router_mode
(
*
config
,
*
config
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
trace
,
trace
,
router_mode
,
router_mode
,
),
),
...
@@ -793,6 +826,7 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -793,6 +826,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_requests_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_concurrency_requests_with_router_mode
(
*
args
,
*
args
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
requests
,
requests
,
max_in_flight
,
max_in_flight
,
num_workers
,
num_workers
,
...
@@ -802,6 +836,7 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -802,6 +836,7 @@ pub fn run_mocker_synthetic_trace_replay(
(
"offline"
,
None
)
=>
dynamo_mocker
::
replay
::
simulate_trace_requests_with_router_mode
(
(
"offline"
,
None
)
=>
dynamo_mocker
::
replay
::
simulate_trace_requests_with_router_mode
(
*
args
,
*
args
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
requests
,
requests
,
num_workers
,
num_workers
,
arrival_speedup_ratio
,
arrival_speedup_ratio
,
...
@@ -811,6 +846,7 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -811,6 +846,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_live_requests_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_concurrency_live_requests_with_router_mode
(
*
args
,
*
args
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
requests
,
requests
,
max_in_flight
,
max_in_flight
,
num_workers
,
num_workers
,
...
@@ -821,6 +857,7 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -821,6 +857,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_live_requests_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_trace_live_requests_with_router_mode
(
*
args
,
*
args
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
requests
,
requests
,
num_workers
,
num_workers
,
arrival_speedup_ratio
,
arrival_speedup_ratio
,
...
@@ -838,6 +875,7 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -838,6 +875,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_concurrency_requests_disagg_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_concurrency_requests_disagg_with_router_mode
(
*
config
,
*
config
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
requests
,
requests
,
max_in_flight
,
max_in_flight
,
router_mode
,
router_mode
,
...
@@ -847,6 +885,7 @@ pub fn run_mocker_synthetic_trace_replay(
...
@@ -847,6 +885,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker
::
replay
::
simulate_trace_requests_disagg_with_router_mode
(
dynamo_mocker
::
replay
::
simulate_trace_requests_disagg_with_router_mode
(
*
config
,
*
config
,
router_config
.clone
(),
router_config
.clone
(),
prefill_load_estimator
.clone
(),
requests
,
requests
,
arrival_speedup_ratio
,
arrival_speedup_ratio
,
router_mode
,
router_mode
,
...
@@ -970,6 +1009,57 @@ fn load_replay_router_config(
...
@@ -970,6 +1009,57 @@ fn load_replay_router_config(
router_config
.map
(|
config
|
config
.inner
())
router_config
.map
(|
config
|
config
.inner
())
}
}
/// Resolves the optional prefill load estimator for a replay run.
///
/// An estimator is only built when all of the following hold:
/// `router_mode` is `KvRouter`, a router config was supplied, that config's
/// `router_prefill_load_model` is enabled, and an `aic_perf_config` was supplied.
///
/// # Errors
///
/// Returns a Python exception when `aic_perf_config` is given but one of the
/// prerequisites above is not met, when the load model is enabled but no
/// `aic_perf_config` is given, or when `create_aic_prefill_load_estimator` fails.
fn load_replay_prefill_load_estimator(
    py: Python<'_>,
    router_mode: dynamo_mocker::replay::ReplayRouterMode,
    router_config: Option<&KvRouterConfig>,
    aic_perf_config: Option<&AicPerfConfig>,
) -> PyResult<Option<dynamo_mocker::replay::ReplayPrefillLoadEstimator>> {
    // Shared outcome for every "estimator not applicable" branch: a stray
    // aic_perf_config is an error, otherwise there is simply no estimator.
    let without_estimator = |message: &'static str| -> PyResult<
        Option<dynamo_mocker::replay::ReplayPrefillLoadEstimator>,
    > {
        match aic_perf_config {
            Some(_) => Err(PyException::new_err(message)),
            None => Ok(None),
        }
    };

    if router_mode != dynamo_mocker::replay::ReplayRouterMode::KvRouter {
        return without_estimator("aic_perf_config requires router_mode='kv_router'");
    }

    let Some(router_config) = router_config.map(|config| config.inner()) else {
        return without_estimator(
            "aic_perf_config requires router_config with router_prefill_load_model='aic'",
        );
    };

    if !router_config.router_prefill_load_model.is_enabled() {
        return without_estimator("aic_perf_config requires router_prefill_load_model='aic'");
    }

    // The load model is enabled, so the perf config is now mandatory.
    let Some(perf_config) = aic_perf_config else {
        return Err(PyException::new_err(
            "router_prefill_load_model='aic' requires aic_perf_config",
        ));
    };

    create_aic_prefill_load_estimator(
        py,
        perf_config.backend_name(),
        perf_config.system(),
        perf_config.model_path(),
        perf_config.tp_size(),
        perf_config.backend_version(),
    )
    .map(Some)
}
fn
parse_replay_router_mode
(
fn
parse_replay_router_mode
(
router_mode
:
&
str
,
router_mode
:
&
str
,
)
->
PyResult
<
dynamo_mocker
::
replay
::
ReplayRouterMode
>
{
)
->
PyResult
<
dynamo_mocker
::
replay
::
ReplayRouterMode
>
{
...
...
lib/bindings/python/src/dynamo/_core.pyi
View file @
95a750f4
...
@@ -1159,6 +1159,17 @@ class RouterConfig:
...
@@ -1159,6 +1159,17 @@ class RouterConfig:
"""
"""
...
...
class AicPerfConfig:
    """AIC perf-model configuration.

    Accepted as the optional ``aic_perf_config`` argument of the replay entry
    points, ``KvRouter`` and ``EntrypointArgs`` declared in this stub file.

    Args:
        aic_backend: Backend name used for the AIC lookup
            (presumably e.g. "vllm" or "sglang" -- TODO confirm accepted values).
        aic_system: System identifier used for the AIC lookup.
        aic_model_path: Model path handed to AIC.
        aic_tp_size: Tensor-parallel size (default: 1).
        aic_backend_version: Optional backend version; None leaves the choice
            to the implementation.
    """

    def __init__(
        self,
        aic_backend: str,
        aic_system: str,
        aic_model_path: str,
        aic_tp_size: int = 1,
        aic_backend_version: Optional[str] = None,
    ) -> None:
        ...
class KvRouterConfig:
class KvRouterConfig:
"""Values for KV router"""
"""Values for KV router"""
...
@@ -1172,6 +1183,8 @@ class KvRouterConfig:
...
@@ -1172,6 +1183,8 @@ class KvRouterConfig:
router_track_active_blocks: bool = True,
router_track_active_blocks: bool = True,
router_track_output_blocks: bool = False,
router_track_output_blocks: bool = False,
router_assume_kv_reuse: bool = True,
router_assume_kv_reuse: bool = True,
router_track_prefill_tokens: bool = True,
router_prefill_load_model: str = "none",
router_snapshot_threshold: Optional[int] = 1000000,
router_snapshot_threshold: Optional[int] = 1000000,
router_reset_states: bool = False,
router_reset_states: bool = False,
router_ttl_secs: float = 120.0,
router_ttl_secs: float = 120.0,
...
@@ -1199,6 +1212,10 @@ class KvRouterConfig:
...
@@ -1199,6 +1212,10 @@ class KvRouterConfig:
sequence length (agent_hints.osl in nvext).
sequence length (agent_hints.osl in nvext).
router_assume_kv_reuse: Assume KV cache reuse when tracking active blocks (default: True).
router_assume_kv_reuse: Assume KV cache reuse when tracking active blocks (default: True).
When True, computes actual block hashes. When False, generates random hashes.
When True, computes actual block hashes. When False, generates random hashes.
router_track_prefill_tokens: Include prompt-side prefill tokens in active load accounting (default: True).
router_prefill_load_model: Prompt-side prefill load model (default: "none").
"none" keeps static prompt load accounting.
"aic" decays the oldest active prefill request using AIC-predicted duration.
router_snapshot_threshold: Number of messages before snapshot (default: 1000000)
router_snapshot_threshold: Number of messages before snapshot (default: 1000000)
router_reset_states: Reset router state on startup (default: False)
router_reset_states: Reset router state on startup (default: False)
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
...
@@ -1516,6 +1533,7 @@ def run_mocker_trace_replay(
...
@@ -1516,6 +1533,7 @@ def run_mocker_trace_replay(
prefill_engine_args: Optional[MockEngineArgs] = None,
prefill_engine_args: Optional[MockEngineArgs] = None,
decode_engine_args: Optional[MockEngineArgs] = None,
decode_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
router_config: Optional[KvRouterConfig] = None,
aic_perf_config: Optional[AicPerfConfig] = None,
num_workers: int = 1,
num_workers: int = 1,
num_prefill_workers: int = 1,
num_prefill_workers: int = 1,
num_decode_workers: int = 1,
num_decode_workers: int = 1,
...
@@ -1523,6 +1541,7 @@ def run_mocker_trace_replay(
...
@@ -1523,6 +1541,7 @@ def run_mocker_trace_replay(
replay_mode: Literal["offline", "online"] = "offline",
replay_mode: Literal["offline", "online"] = "offline",
router_mode: Literal["round_robin", "kv_router"] = "round_robin",
router_mode: Literal["round_robin", "kv_router"] = "round_robin",
arrival_speedup_ratio: float = 1.0,
arrival_speedup_ratio: float = 1.0,
trace_block_size: int = 512,
) -> Dict[str, Any]:
) -> Dict[str, Any]:
"""Replay a mocker trace file and return the simulation report for aggregated vLLM or SGLang configs."""
"""Replay a mocker trace file and return the simulation report for aggregated vLLM or SGLang configs."""
...
...
...
@@ -1535,6 +1554,7 @@ def run_mocker_synthetic_trace_replay(
...
@@ -1535,6 +1554,7 @@ def run_mocker_synthetic_trace_replay(
prefill_engine_args: Optional[MockEngineArgs] = None,
prefill_engine_args: Optional[MockEngineArgs] = None,
decode_engine_args: Optional[MockEngineArgs] = None,
decode_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
router_config: Optional[KvRouterConfig] = None,
aic_perf_config: Optional[AicPerfConfig] = None,
num_workers: int = 1,
num_workers: int = 1,
num_prefill_workers: int = 1,
num_prefill_workers: int = 1,
num_decode_workers: int = 1,
num_decode_workers: int = 1,
...
@@ -1779,6 +1799,7 @@ class KvRouter:
...
@@ -1779,6 +1799,7 @@ class KvRouter:
endpoint: Endpoint,
endpoint: Endpoint,
block_size: int,
block_size: int,
kv_router_config: KvRouterConfig,
kv_router_config: KvRouterConfig,
aic_perf_config: Optional[AicPerfConfig] = None,
) -> None:
) -> None:
"""
"""
Create a new KvRouter instance.
Create a new KvRouter instance.
...
@@ -1787,6 +1808,7 @@ class KvRouter:
...
@@ -1787,6 +1808,7 @@ class KvRouter:
endpoint: The endpoint to connect to for routing requests
endpoint: The endpoint to connect to for routing requests
block_size: The KV cache block size
block_size: The KV cache block size
kv_router_config: Configuration for the KV router
kv_router_config: Configuration for the KV router
aic_perf_config: Optional AIC perf-model config for effective prefill load tracking
"""
"""
...
...
...
@@ -1998,6 +2020,7 @@ class EntrypointArgs:
...
@@ -1998,6 +2020,7 @@ class EntrypointArgs:
is_prefill: bool = False,
is_prefill: bool = False,
migration_limit: int = 0,
migration_limit: int = 0,
chat_engine_factory: Optional[Callable] = None,
chat_engine_factory: Optional[Callable] = None,
aic_perf_config: Optional[AicPerfConfig] = None,
) -> None:
) -> None:
"""
"""
Create EntrypointArgs.
Create EntrypointArgs.
...
@@ -2024,6 +2047,7 @@ class EntrypointArgs:
...
@@ -2024,6 +2047,7 @@ class EntrypointArgs:
is_prefill: Whether this is a prefill worker
is_prefill: Whether this is a prefill worker
migration_limit: Maximum number of request migrations (0=disabled)
migration_limit: Maximum number of request migrations (0=disabled)
chat_engine_factory: Optional Python chat completions engine factory callback
chat_engine_factory: Optional Python chat completions engine factory callback
aic_perf_config: Optional AIC perf-model configuration for default KV routing
"""
"""
...
...
...
...
lib/bindings/python/src/dynamo/_internal/aic.py
0 → 100644
View file @
95a750f4
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared AIC session helpers used by internal Dynamo integrations."""
from
__future__
import
annotations
import
logging
logger
=
logging
.
getLogger
(
__name__
)
DEFAULT_BACKEND_VERSIONS
=
{
"vllm"
:
"0.12.0"
,
"sglang"
:
"0.5.6.post2"
,
}
DEFAULT_STATIC_STRIDE
=
32
def resolve_backend_version(backend_name: str, backend_version: str | None) -> str:
    """Pick the backend version to use for AIC perf lookups.

    An explicitly supplied ``backend_version`` always wins. Otherwise the
    pinned default for ``backend_name`` is used, falling back to the pinned
    vLLM version when the backend has no entry of its own.
    """
    if backend_version is None:
        vllm_default = DEFAULT_BACKEND_VERSIONS["vllm"]
        return DEFAULT_BACKEND_VERSIONS.get(backend_name, vllm_default)
    return backend_version
def
_load_aiconfigurator
():
try
:
from
aiconfigurator.sdk
import
config
from
aiconfigurator.sdk.backends.factory
import
get_backend
from
aiconfigurator.sdk.inference_session
import
InferenceSession
from
aiconfigurator.sdk.models
import
get_model
from
aiconfigurator.sdk.perf_database
import
(
get_database
,
get_supported_databases
,
)
except
(
ImportError
)
as
exc
:
# pragma: no cover - exercised in integration environments
raise
RuntimeError
(
"aiconfigurator is required for AIC perf modeling but is not installed"
)
from
exc
return
{
"config"
:
config
,
"get_backend"
:
get_backend
,
"InferenceSession"
:
InferenceSession
,
"get_model"
:
get_model
,
"get_database"
:
get_database
,
"get_supported_databases"
:
get_supported_databases
,
}
class AicSession:
    """Wrap an AIC InferenceSession with direct prefill/decode predictors."""

    def __init__(
        self,
        backend_name: str,
        system: str,
        model_path: str,
        tp_size: int,
        backend_version: str | None = None,
        moe_tp_size: int | None = None,
        moe_ep_size: int | None = None,
        attention_dp_size: int | None = None,
    ):
        """Load the AIC perf database, model and backend and open a session.

        Raises:
            RuntimeError: if aiconfigurator is unavailable, or if no perf
                database exists for the system/backend/version combination.
        """
        sdk = _load_aiconfigurator()
        resolved_version = resolve_backend_version(backend_name, backend_version)

        perf_db = sdk["get_database"](
            system=system, backend=backend_name, version=resolved_version
        )
        if perf_db is None:
            # List what is available for this system/backend to make the error actionable.
            known = sdk["get_supported_databases"]().get(system, {}).get(backend_name, [])
            known_text = ", ".join(known) if known else "<none>"
            raise RuntimeError(
                "AIC perf database not found for "
                f"system={system!r}, backend={backend_name!r}, version={resolved_version!r}. "
                f"Supported versions for this system/backend: {known_text}"
            )

        parallelism = sdk["config"].ModelConfig(
            tp_size=tp_size,
            moe_tp_size=moe_tp_size,
            moe_ep_size=moe_ep_size,
            attention_dp_size=attention_dp_size or 1,
        )
        aic_model = sdk["get_model"](
            model_path=model_path,
            model_config=parallelism,
            backend_name=backend_name,
        )
        aic_backend = sdk["get_backend"](backend_name)

        self._session = sdk["InferenceSession"](
            model=aic_model, database=perf_db, backend=aic_backend
        )
        self._database = perf_db
        self._model = aic_model
        # Prefer the name reported by the model object; fall back to the path.
        self._model_name = getattr(aic_model, "model_name", None) or model_path

        logger.info(
            "AIC session initialized: backend=%s, system=%s, model=%s, tp=%d",
            backend_name,
            system,
            model_path,
            tp_size,
        )

    def _predict_context_latency(
        self, batch_size: int, effective_isl: int, prefix: int
    ) -> float:
        """Sum the queried latency of every context op of the model.

        Raises:
            ValueError: if ``effective_isl`` is not positive.
        """
        if effective_isl <= 0:
            raise ValueError(
                f"effective_isl must be positive, got effective_isl={effective_isl}"
            )

        def _query(op) -> float:
            # Ops whose name contains "logits_gemm" are queried with x=batch_size;
            # every other op is queried with x=batch_size * effective_isl.
            if "logits_gemm" in getattr(op, "_name", ""):
                x = batch_size
            else:
                x = batch_size * effective_isl
            return float(
                op.query(
                    self._database,
                    x=x,
                    batch_size=batch_size,
                    beam_width=1,
                    s=effective_isl,
                    prefix=prefix,
                    model_name=self._model_name,
                    seq_imbalance_correction_scale=1.0,
                )
            )

        return sum((_query(op) for op in self._model.context_ops), 0.0)

    def _predict_generation_latency(self, batch_size: int, isl: int, osl: int) -> float:
        """Estimate generation latency by sampling one step per stride.

        One step out of every ``DEFAULT_STATIC_STRIDE`` is queried and its
        latency is multiplied by the number of steps that sample stands for.
        Returns 0.0 when ``osl`` is at most 1.
        """
        if osl <= 1:
            return 0.0

        decode_steps = osl - 1
        scaled_batch = batch_size * (self._model._nextn + 1)
        accumulated = 0.0
        for step in range(0, decode_steps, DEFAULT_STATIC_STRIDE):
            sampled = 0.0
            for op in self._model.generation_ops:
                sampled += float(
                    op.query(
                        self._database,
                        x=scaled_batch,
                        batch_size=scaled_batch,
                        beam_width=1,
                        s=isl + step + 1,
                        model_name=self._model_name,
                        gen_seq_imbalance_correction_scale=1.0,
                    )
                )
            # The final stride may cover fewer than DEFAULT_STATIC_STRIDE steps.
            covered = min(DEFAULT_STATIC_STRIDE, decode_steps - step)
            accumulated += sampled * covered
        return accumulated

    def predict_prefill(self, batch_size: int, effective_isl: int, prefix: int) -> float:
        """Predict prefill latency in ms from uncached tokens and cached prefix."""
        return self._predict_context_latency(batch_size, effective_isl, prefix)

    def predict_decode(self, batch_size: int, isl: int, osl: int) -> float:
        """Predict decode (generation) latency in ms."""
        return self._predict_generation_latency(batch_size, isl, osl)
def create_session(
    backend_name: str,
    system: str,
    model_path: str,
    tp_size: int,
    backend_version: str | None = None,
    moe_tp_size: int | None = None,
    moe_ep_size: int | None = None,
    attention_dp_size: int | None = None,
) -> AicSession:
    """Factory function called from Rust via PyO3.

    Forwards every argument unchanged to the ``AicSession`` constructor.
    """
    session = AicSession(
        backend_name=backend_name,
        system=system,
        model_path=model_path,
        tp_size=tp_size,
        backend_version=backend_version,
        moe_tp_size=moe_tp_size,
        moe_ep_size=moe_ep_size,
        attention_dp_size=attention_dp_size,
    )
    return session
lib/bindings/python/src/dynamo/llm/__init__.py
View file @
95a750f4
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
import
logging
import
logging
from
dynamo._core
import
AicPerfConfig
as
AicPerfConfig
from
dynamo._core
import
EngineType
from
dynamo._core
import
EngineType
from
dynamo._core
import
EntrypointArgs
as
EntrypointArgs
from
dynamo._core
import
EntrypointArgs
as
EntrypointArgs
from
dynamo._core
import
FpmEventRelay
as
FpmEventRelay
from
dynamo._core
import
FpmEventRelay
as
FpmEventRelay
...
@@ -57,6 +58,7 @@ def run_mocker_trace_replay(
...
@@ -57,6 +58,7 @@ def run_mocker_trace_replay(
replay_concurrency
=
None
,
replay_concurrency
=
None
,
router_mode
=
"round_robin"
,
router_mode
=
"round_robin"
,
arrival_speedup_ratio
=
1.0
,
arrival_speedup_ratio
=
1.0
,
trace_block_size
=
512
,
):
):
return
_run_mocker_trace_replay
(
return
_run_mocker_trace_replay
(
trace_file
,
trace_file
,
...
@@ -67,4 +69,5 @@ def run_mocker_trace_replay(
...
@@ -67,4 +69,5 @@ def run_mocker_trace_replay(
replay_mode
=
"offline"
,
replay_mode
=
"offline"
,
router_mode
=
router_mode
,
router_mode
=
router_mode
,
arrival_speedup_ratio
=
arrival_speedup_ratio
,
arrival_speedup_ratio
=
arrival_speedup_ratio
,
trace_block_size
=
trace_block_size
,
)
)
lib/bindings/python/src/dynamo/replay/api.py
View file @
95a750f4
...
@@ -14,6 +14,7 @@ def run_trace_replay(
...
@@ -14,6 +14,7 @@ def run_trace_replay(
prefill_engine_args
=
None
,
prefill_engine_args
=
None
,
decode_engine_args
=
None
,
decode_engine_args
=
None
,
router_config
=
None
,
router_config
=
None
,
aic_perf_config
=
None
,
num_workers
=
1
,
num_workers
=
1
,
num_prefill_workers
=
1
,
num_prefill_workers
=
1
,
num_decode_workers
=
1
,
num_decode_workers
=
1
,
...
@@ -21,6 +22,7 @@ def run_trace_replay(
...
@@ -21,6 +22,7 @@ def run_trace_replay(
replay_mode
=
"offline"
,
replay_mode
=
"offline"
,
router_mode
=
"round_robin"
,
router_mode
=
"round_robin"
,
arrival_speedup_ratio
=
1.0
,
arrival_speedup_ratio
=
1.0
,
trace_block_size
=
512
,
):
):
return
_run_mocker_trace_replay
(
return
_run_mocker_trace_replay
(
trace_file
,
trace_file
,
...
@@ -28,6 +30,7 @@ def run_trace_replay(
...
@@ -28,6 +30,7 @@ def run_trace_replay(
prefill_engine_args
=
prefill_engine_args
,
prefill_engine_args
=
prefill_engine_args
,
decode_engine_args
=
decode_engine_args
,
decode_engine_args
=
decode_engine_args
,
router_config
=
router_config
,
router_config
=
router_config
,
aic_perf_config
=
aic_perf_config
,
num_workers
=
num_workers
,
num_workers
=
num_workers
,
num_prefill_workers
=
num_prefill_workers
,
num_prefill_workers
=
num_prefill_workers
,
num_decode_workers
=
num_decode_workers
,
num_decode_workers
=
num_decode_workers
,
...
@@ -35,6 +38,7 @@ def run_trace_replay(
...
@@ -35,6 +38,7 @@ def run_trace_replay(
replay_mode
=
replay_mode
,
replay_mode
=
replay_mode
,
router_mode
=
router_mode
,
router_mode
=
router_mode
,
arrival_speedup_ratio
=
arrival_speedup_ratio
,
arrival_speedup_ratio
=
arrival_speedup_ratio
,
trace_block_size
=
trace_block_size
,
)
)
...
@@ -47,6 +51,7 @@ def run_synthetic_trace_replay(
...
@@ -47,6 +51,7 @@ def run_synthetic_trace_replay(
prefill_engine_args
=
None
,
prefill_engine_args
=
None
,
decode_engine_args
=
None
,
decode_engine_args
=
None
,
router_config
=
None
,
router_config
=
None
,
aic_perf_config
=
None
,
num_workers
=
1
,
num_workers
=
1
,
num_prefill_workers
=
1
,
num_prefill_workers
=
1
,
num_decode_workers
=
1
,
num_decode_workers
=
1
,
...
@@ -68,6 +73,7 @@ def run_synthetic_trace_replay(
...
@@ -68,6 +73,7 @@ def run_synthetic_trace_replay(
prefill_engine_args
=
prefill_engine_args
,
prefill_engine_args
=
prefill_engine_args
,
decode_engine_args
=
decode_engine_args
,
decode_engine_args
=
decode_engine_args
,
router_config
=
router_config
,
router_config
=
router_config
,
aic_perf_config
=
aic_perf_config
,
num_workers
=
num_workers
,
num_workers
=
num_workers
,
num_prefill_workers
=
num_prefill_workers
,
num_prefill_workers
=
num_prefill_workers
,
num_decode_workers
=
num_decode_workers
,
num_decode_workers
=
num_decode_workers
,
...
...
lib/bindings/python/src/dynamo/replay/main.py
View file @
95a750f4
...
@@ -15,7 +15,7 @@ from typing import Protocol
...
@@ -15,7 +15,7 @@ from typing import Protocol
os
.
environ
.
setdefault
(
"DYNAMO_SKIP_PYTHON_LOG_INIT"
,
"1"
)
os
.
environ
.
setdefault
(
"DYNAMO_SKIP_PYTHON_LOG_INIT"
,
"1"
)
from
dynamo.llm
import
KvRouterConfig
,
MockEngineArgs
from
dynamo.llm
import
AicPerfConfig
,
KvRouterConfig
,
MockEngineArgs
from
dynamo.replay
import
run_synthetic_trace_replay
,
run_trace_replay
from
dynamo.replay
import
run_synthetic_trace_replay
,
run_trace_replay
from
dynamo.replay.reporting
import
format_report_table
,
write_report_json
from
dynamo.replay.reporting
import
format_report_table
,
write_report_json
...
@@ -72,6 +72,35 @@ def _load_engine_args(raw_args: str | None):
...
@@ -72,6 +72,35 @@ def _load_engine_args(raw_args: str | None):
return
MockEngineArgs
.
from_json
(
json
.
dumps
(
raw
))
return
MockEngineArgs
.
from_json
(
json
.
dumps
(
raw
))
def
_load_aic_perf_config
(
args
:
argparse
.
Namespace
):
values
=
{
"aic_backend"
:
args
.
aic_backend
,
"aic_system"
:
args
.
aic_system
,
"aic_model_path"
:
args
.
aic_model_path
,
"aic_backend_version"
:
args
.
aic_backend_version
,
"aic_tp_size"
:
args
.
aic_tp_size
,
}
if
not
any
(
value
is
not
None
for
value
in
values
.
values
()):
return
None
missing
=
[
name
for
name
in
(
"aic_backend"
,
"aic_system"
,
"aic_model_path"
)
if
values
[
name
]
is
None
]
if
missing
:
missing_flags
=
", "
.
join
(
f
"--
{
name
.
replace
(
'_'
,
'-'
)
}
"
for
name
in
missing
)
raise
ValueError
(
f
"AIC replay modeling requires
{
missing_flags
}
"
)
return
AicPerfConfig
(
aic_backend
=
values
[
"aic_backend"
],
aic_system
=
values
[
"aic_system"
],
aic_model_path
=
values
[
"aic_model_path"
],
aic_tp_size
=
values
[
"aic_tp_size"
]
or
1
,
aic_backend_version
=
values
[
"aic_backend_version"
],
)
def
main
(
argv
:
Sequence
[
str
]
|
None
=
None
)
->
int
:
def
main
(
argv
:
Sequence
[
str
]
|
None
=
None
)
->
int
:
parser
=
argparse
.
ArgumentParser
(
prog
=
"python -m dynamo.replay"
)
parser
=
argparse
.
ArgumentParser
(
prog
=
"python -m dynamo.replay"
)
parser
.
add_argument
(
"trace_file"
,
nargs
=
"?"
)
parser
.
add_argument
(
"trace_file"
,
nargs
=
"?"
)
...
@@ -79,6 +108,11 @@ def main(argv: Sequence[str] | None = None) -> int:
...
@@ -79,6 +108,11 @@ def main(argv: Sequence[str] | None = None) -> int:
parser
.
add_argument
(
"--prefill-engine-args"
)
parser
.
add_argument
(
"--prefill-engine-args"
)
parser
.
add_argument
(
"--decode-engine-args"
)
parser
.
add_argument
(
"--decode-engine-args"
)
parser
.
add_argument
(
"--router-config"
)
parser
.
add_argument
(
"--router-config"
)
parser
.
add_argument
(
"--aic-backend"
)
parser
.
add_argument
(
"--aic-system"
)
parser
.
add_argument
(
"--aic-backend-version"
)
parser
.
add_argument
(
"--aic-tp-size"
,
type
=
int
)
parser
.
add_argument
(
"--aic-model-path"
)
parser
.
add_argument
(
"--input-tokens"
,
type
=
int
)
parser
.
add_argument
(
"--input-tokens"
,
type
=
int
)
parser
.
add_argument
(
"--output-tokens"
,
type
=
int
)
parser
.
add_argument
(
"--output-tokens"
,
type
=
int
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -106,6 +140,12 @@ def main(argv: Sequence[str] | None = None) -> int:
...
@@ -106,6 +140,12 @@ def main(argv: Sequence[str] | None = None) -> int:
default
=
"round_robin"
,
default
=
"round_robin"
,
)
)
parser
.
add_argument
(
"--arrival-speedup-ratio"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--arrival-speedup-ratio"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--trace-block-size"
,
type
=
int
,
default
=
512
,
help
=
"tokens represented by each hash_id in the trace file; only used for file replay"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--report-json"
,
"--report-json"
,
help
=
"path to save the full replay report JSON; defaults to a timestamped file in the current directory"
,
help
=
"path to save the full replay report JSON; defaults to a timestamped file in the current directory"
,
...
@@ -140,6 +180,10 @@ def main(argv: Sequence[str] | None = None) -> int:
...
@@ -140,6 +180,10 @@ def main(argv: Sequence[str] | None = None) -> int:
if
args
.
router_config
is
not
None
if
args
.
router_config
is
not
None
else
None
else
None
)
)
try
:
aic_perf_config
=
_load_aic_perf_config
(
args
)
except
ValueError
as
exc
:
parser
.
error
(
str
(
exc
))
if
using_trace_file
:
if
using_trace_file
:
report
=
run_trace_replay
(
report
=
run_trace_replay
(
...
@@ -148,6 +192,7 @@ def main(argv: Sequence[str] | None = None) -> int:
...
@@ -148,6 +192,7 @@ def main(argv: Sequence[str] | None = None) -> int:
prefill_engine_args
=
prefill_engine_args
,
prefill_engine_args
=
prefill_engine_args
,
decode_engine_args
=
decode_engine_args
,
decode_engine_args
=
decode_engine_args
,
router_config
=
router_config
,
router_config
=
router_config
,
aic_perf_config
=
aic_perf_config
,
num_workers
=
args
.
num_workers
,
num_workers
=
args
.
num_workers
,
num_prefill_workers
=
args
.
num_prefill_workers
,
num_prefill_workers
=
args
.
num_prefill_workers
,
num_decode_workers
=
args
.
num_decode_workers
,
num_decode_workers
=
args
.
num_decode_workers
,
...
@@ -155,6 +200,7 @@ def main(argv: Sequence[str] | None = None) -> int:
...
@@ -155,6 +200,7 @@ def main(argv: Sequence[str] | None = None) -> int:
replay_mode
=
args
.
replay_mode
,
replay_mode
=
args
.
replay_mode
,
router_mode
=
args
.
router_mode
,
router_mode
=
args
.
router_mode
,
arrival_speedup_ratio
=
args
.
arrival_speedup_ratio
,
arrival_speedup_ratio
=
args
.
arrival_speedup_ratio
,
trace_block_size
=
args
.
trace_block_size
,
)
)
else
:
else
:
report
=
run_synthetic_trace_replay
(
report
=
run_synthetic_trace_replay
(
...
@@ -165,6 +211,7 @@ def main(argv: Sequence[str] | None = None) -> int:
...
@@ -165,6 +211,7 @@ def main(argv: Sequence[str] | None = None) -> int:
prefill_engine_args
=
prefill_engine_args
,
prefill_engine_args
=
prefill_engine_args
,
decode_engine_args
=
decode_engine_args
,
decode_engine_args
=
decode_engine_args
,
router_config
=
router_config
,
router_config
=
router_config
,
aic_perf_config
=
aic_perf_config
,
num_workers
=
args
.
num_workers
,
num_workers
=
args
.
num_workers
,
num_prefill_workers
=
args
.
num_prefill_workers
,
num_prefill_workers
=
args
.
num_prefill_workers
,
num_decode_workers
=
args
.
num_decode_workers
,
num_decode_workers
=
args
.
num_decode_workers
,
...
...
lib/bindings/python/tests/replay/test_replay_smoke.py
View file @
95a750f4
...
@@ -125,6 +125,32 @@ def test_run_trace_replay_supports_multiturn_sessions(tmp_path, replay_mode):
...
@@ -125,6 +125,32 @@ def test_run_trace_replay_supports_multiturn_sessions(tmp_path, replay_mode):
)
)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_supports_distinct_trace_and_engine_block_sizes(
    tmp_path, replay_mode
):
    """Replay a one-request trace with an explicit ``trace_block_size`` of 512.

    The single request has 128 input tokens, 2 output tokens and one hash id.
    NOTE(review): the engine block size comes from ``_vllm_args()`` and is
    presumably different from 512 -- confirm against that helper.
    """
    trace_record = (
        '{"timestamp":1000.0,"input_length":128,"output_length":2,"hash_ids":[101]}\n'
    )
    trace_path = tmp_path / "trace_block_size_split.jsonl"
    trace_path.write_text(trace_record, encoding="utf-8")

    engine_args = _vllm_args()
    report = run_trace_replay(
        trace_path,
        extra_engine_args=engine_args,
        num_workers=1,
        replay_mode=replay_mode,
        trace_block_size=512,
    )

    expected_counts = dict(num_requests=1, input_tokens=128, output_tokens=2)
    _assert_basic_report_counts(report, **expected_counts)
@
pytest
.
mark
.
parametrize
(
"engine_type"
,
[
"vllm"
,
"sglang"
])
@
pytest
.
mark
.
parametrize
(
"engine_type"
,
[
"vllm"
,
"sglang"
])
@
pytest
.
mark
.
parametrize
(
"replay_mode"
,
[
"offline"
,
"online"
])
@
pytest
.
mark
.
parametrize
(
"replay_mode"
,
[
"offline"
,
"online"
])
@
pytest
.
mark
.
parametrize
(
"router_mode"
,
[
"round_robin"
,
"kv_router"
])
@
pytest
.
mark
.
parametrize
(
"router_mode"
,
[
"round_robin"
,
"kv_router"
])
...
...
lib/kv-router/src/lib.rs
View file @
95a750f4
...
@@ -40,7 +40,7 @@ pub use self::multi_worker_sequence::{
...
@@ -40,7 +40,7 @@ pub use self::multi_worker_sequence::{
pub
use
self
::
sequence
::{
ActiveSequences
,
RequestId
};
pub
use
self
::
sequence
::{
ActiveSequences
,
RequestId
};
pub
use
concurrent_radix_tree
::
ConcurrentRadixTree
;
pub
use
concurrent_radix_tree
::
ConcurrentRadixTree
;
pub
use
concurrent_radix_tree_compressed
::
ConcurrentRadixTreeCompressed
;
pub
use
concurrent_radix_tree_compressed
::
ConcurrentRadixTreeCompressed
;
pub
use
config
::{
KvRouterConfig
,
RouterConfigOverride
,
RouterQueuePolicy
};
pub
use
config
::{
KvRouterConfig
,
RouterConfigOverride
,
RouterPrefillLoadModel
,
RouterQueuePolicy
};
pub
use
indexer
::{
MaybeError
,
SyncIndexer
,
ThreadPoolIndexer
};
pub
use
indexer
::{
MaybeError
,
SyncIndexer
,
ThreadPoolIndexer
};
pub
use
nested_map
::
PositionalIndexer
;
pub
use
nested_map
::
PositionalIndexer
;
pub
use
protocols
::{
pub
use
protocols
::{
...
@@ -50,6 +50,7 @@ pub use protocols::{
...
@@ -50,6 +50,7 @@ pub use protocols::{
pub
use
queue
::
SchedulerQueue
;
pub
use
queue
::
SchedulerQueue
;
pub
use
radix_tree
::
RadixTree
;
pub
use
radix_tree
::
RadixTree
;
pub
use
scheduling
::
LocalScheduler
;
pub
use
scheduling
::
LocalScheduler
;
pub
use
scheduling
::
PrefillLoadEstimator
;
pub
use
scheduling
::
policy
::{
FcfsPolicy
,
RouterSchedulingPolicy
,
SchedulingPolicy
,
WsptPolicy
};
pub
use
scheduling
::
policy
::{
FcfsPolicy
,
RouterSchedulingPolicy
,
SchedulingPolicy
,
WsptPolicy
};
pub
use
scheduling
::{
KvSchedulerError
,
PotentialLoad
,
SchedulingRequest
,
SchedulingResponse
};
pub
use
scheduling
::{
KvSchedulerError
,
PotentialLoad
,
SchedulingRequest
,
SchedulingResponse
};
pub
use
selector
::{
DefaultWorkerSelector
,
WorkerSelector
};
pub
use
selector
::{
DefaultWorkerSelector
,
WorkerSelector
};
lib/kv-router/src/protocols.rs
View file @
95a750f4
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: Apache-2.0
use
std
::
future
::
Future
;
use
std
::
future
::
Future
;
use
std
::
time
::
Duration
;
use
dynamo_tokens
::{
SequenceHash
,
Token
};
use
dynamo_tokens
::{
SequenceHash
,
Token
};
use
rustc_hash
::
FxHashMap
;
use
rustc_hash
::
FxHashMap
;
...
@@ -429,6 +430,12 @@ pub struct ActiveSequenceEvent {
...
@@ -429,6 +430,12 @@ pub struct ActiveSequenceEvent {
pub
lora_name
:
Option
<
String
>
,
pub
lora_name
:
Option
<
String
>
,
}
}
/// Prefill load hint that can accompany an
/// `ActiveSequenceEventData::AddRequest` event (where it is optional and
/// defaults to absent when deserializing).
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrefillLoadHint {
    /// Effective prefill token count at the time the hint is created.
    // NOTE(review): presumably the prompt tokens still to be prefilled
    // after prefix reuse; confirm against the code that builds the hint.
    pub initial_effective_prefill_tokens: usize,
    /// Expected duration of the prefill, when one is available.
    // NOTE(review): presumably supplied by the AIC prefill load estimator; confirm.
    pub expected_prefill_duration: Option<Duration>,
}
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
pub
enum
ActiveSequenceEventData
{
pub
enum
ActiveSequenceEventData
{
AddRequest
{
AddRequest
{
...
@@ -438,6 +445,8 @@ pub enum ActiveSequenceEventData {
...
@@ -438,6 +445,8 @@ pub enum ActiveSequenceEventData {
#[serde(default
=
"default_track_prefill_tokens"
)]
#[serde(default
=
"default_track_prefill_tokens"
)]
track_prefill_tokens
:
bool
,
track_prefill_tokens
:
bool
,
expected_output_tokens
:
Option
<
u32
>
,
expected_output_tokens
:
Option
<
u32
>
,
#[serde(default)]
prefill_load_hint
:
Option
<
PrefillLoadHint
>
,
},
},
Free
,
Free
,
MarkPrefillCompleted
,
MarkPrefillCompleted
,
...
...
lib/kv-router/src/scheduling/config.rs
View file @
95a750f4
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
use
std
::
env
::{
self
,
VarError
};
use
std
::
env
::{
self
,
VarError
};
use
std
::
fmt
;
use
std
::
fmt
;
use
std
::
str
::
FromStr
;
use
std
::
str
::
FromStr
;
use
std
::
time
::
Duration
;
use
derive_builder
::
Builder
;
use
derive_builder
::
Builder
;
use
rand
::
Rng
;
use
rand
::
Rng
;
...
@@ -53,6 +54,43 @@ impl fmt::Display for RouterQueuePolicy {
...
@@ -53,6 +54,43 @@ impl fmt::Display for RouterQueuePolicy {
}
}
}
}
/// Selects the model used to estimate effective prefill load for the router.
///
/// Variants are (de)serialized in lowercase (`"none"`, `"aic"`); the default is `None`.
#[derive(Debug,
Clone,
Copy,
Default,
PartialEq,
Eq,
Serialize,
Deserialize)]
#[serde(rename_all
=
"lowercase"
)]
pub
enum
RouterPrefillLoadModel
{
/// No prefill load model; this is the only variant for which `is_enabled` is `false`.
#[default]
None
,
/// Enables prefill load estimation; spelled `"aic"` when parsed or displayed.
// NOTE(review): the expansion of "AIC" is not visible in this file — confirm before documenting it.
Aic
,
}
/// Renders the model as its lowercase configuration name (`"none"` or `"aic"`).
impl fmt::Display for RouterPrefillLoadModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the label first, then write it once.
        let label = match self {
            Self::None => "none",
            Self::Aic => "aic",
        };
        f.write_str(label)
    }
}
/// Parses the lowercase configuration name of a prefill load model.
///
/// Matching is exact: only `"none"` and `"aic"` are accepted. Any other input
/// yields an error message that quotes the rejected value.
impl FromStr for RouterPrefillLoadModel {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "none" {
            return Ok(Self::None);
        }
        if s == "aic" {
            return Ok(Self::Aic);
        }
        Err(format!(
            "unknown prefill load model: {s:?}, expected 'none' or 'aic'"
        ))
    }
}
impl RouterPrefillLoadModel {
    /// Returns `true` for every model except [`RouterPrefillLoadModel::None`].
    pub fn is_enabled(self) -> bool {
        match self {
            Self::None => false,
            Self::Aic => true,
        }
    }
}
impl
FromStr
for
RouterQueuePolicy
{
impl
FromStr
for
RouterQueuePolicy
{
type
Err
=
String
;
type
Err
=
String
;
...
@@ -124,6 +162,9 @@ pub struct KvRouterConfig {
...
@@ -124,6 +162,9 @@ pub struct KvRouterConfig {
#[serde(default
=
"default_track_prefill_tokens"
)]
#[serde(default
=
"default_track_prefill_tokens"
)]
pub
router_track_prefill_tokens
:
bool
,
pub
router_track_prefill_tokens
:
bool
,
/// Optional model for estimating effective prompt-side prefill load over time.
pub
router_prefill_load_model
:
RouterPrefillLoadModel
,
/// Threshold for triggering snapshots. If None, no snapshots will be performed.
/// Threshold for triggering snapshots. If None, no snapshots will be performed.
#[validate(range(min
=
1
))]
#[validate(range(min
=
1
))]
pub
router_snapshot_threshold
:
Option
<
u32
>
,
pub
router_snapshot_threshold
:
Option
<
u32
>
,
...
@@ -183,6 +224,7 @@ impl Default for KvRouterConfig {
...
@@ -183,6 +224,7 @@ impl Default for KvRouterConfig {
router_track_output_blocks
:
false
,
router_track_output_blocks
:
false
,
router_assume_kv_reuse
:
true
,
router_assume_kv_reuse
:
true
,
router_track_prefill_tokens
:
default_track_prefill_tokens
(),
router_track_prefill_tokens
:
default_track_prefill_tokens
(),
router_prefill_load_model
:
RouterPrefillLoadModel
::
default
(),
router_snapshot_threshold
:
Some
(
1000000
),
router_snapshot_threshold
:
Some
(
1000000
),
router_reset_states
:
false
,
router_reset_states
:
false
,
router_ttl_secs
:
120.0
,
router_ttl_secs
:
120.0
,
...
@@ -214,10 +256,33 @@ fn validate_kv_router_config(config: &KvRouterConfig) -> Result<(), ValidationEr
...
@@ -214,10 +256,33 @@ fn validate_kv_router_config(config: &KvRouterConfig) -> Result<(), ValidationEr
"router_track_output_blocks requires router_track_active_blocks=true"
,
"router_track_output_blocks requires router_track_active_blocks=true"
,
));
));
}
}
if
config
.router_prefill_load_model
.is_enabled
()
&&
!
config
.router_track_prefill_tokens
{
return
Err
(
ValidationError
::
new
(
"router_prefill_load_model requires router_track_prefill_tokens=true"
,
));
}
if
config
.router_prefill_load_model
.is_enabled
()
&&
!
matches!
(
config
.router_queue_policy
,
RouterQueuePolicy
::
Fcfs
)
{
return
Err
(
ValidationError
::
new
(
"router_prefill_load_model currently requires router_queue_policy='fcfs'"
,
));
}
Ok
(())
Ok
(())
}
}
impl
KvRouterConfig
{
impl
KvRouterConfig
{
/// Returns how often the router queue should be re-checked.
///
/// The interval is 100 ms when a prefill load model is enabled and a queue
/// threshold is configured; otherwise it is 60 s.
pub fn router_queue_recheck_interval(&self) -> Duration {
    const DEFAULT_RECHECK_INTERVAL: Duration = Duration::from_secs(60);
    const PREFILL_LOAD_RECHECK_INTERVAL: Duration = Duration::from_millis(100);

    let prefill_model_active = self.router_prefill_load_model.is_enabled();
    let wants_fast_recheck = prefill_model_active && self.router_queue_threshold.is_some();

    if wants_fast_recheck {
        PREFILL_LOAD_RECHECK_INTERVAL
    } else {
        DEFAULT_RECHECK_INTERVAL
    }
}
pub
fn
assume_kv_reuse
(
&
self
,
config_override
:
Option
<&
RouterConfigOverride
>
)
->
bool
{
pub
fn
assume_kv_reuse
(
&
self
,
config_override
:
Option
<&
RouterConfigOverride
>
)
->
bool
{
config_override
config_override
.and_then
(|
cfg
|
cfg
.assume_kv_reuse
)
.and_then
(|
cfg
|
cfg
.assume_kv_reuse
)
...
@@ -288,28 +353,6 @@ mod tests {
...
@@ -288,28 +353,6 @@ mod tests {
use
super
::
*
;
use
super
::
*
;
use
crate
::
protocols
::{
BlockExtraInfo
,
BlockMmObjectInfo
};
use
crate
::
protocols
::{
BlockExtraInfo
,
BlockMmObjectInfo
};
/// `Lcfs` must display as `"lcfs"` and `"lcfs"` must parse back to `Lcfs`.
#[test]
fn router_queue_policy_display_and_parse_support_lcfs() {
    let rendered = RouterQueuePolicy::Lcfs.to_string();
    assert_eq!(rendered, "lcfs");

    let parsed: RouterQueuePolicy = "lcfs".parse().unwrap();
    assert_eq!(parsed, RouterQueuePolicy::Lcfs);
}
/// `Lcfs` must serialize to the JSON string `"lcfs"` and deserialize back unchanged.
#[test]
fn router_queue_policy_serde_round_trip_supports_lcfs() {
    let policy = RouterQueuePolicy::Lcfs;

    let json = serde_json::to_string(&policy).unwrap();
    assert_eq!(json, "\"lcfs\"");

    let round_tripped: RouterQueuePolicy = serde_json::from_str(&json).unwrap();
    assert_eq!(round_tripped, RouterQueuePolicy::Lcfs);
}
/// The default router configuration must have prefill token tracking switched on.
#[test]
fn kv_router_config_defaults_to_tracking_prefill_tokens() {
    let config = KvRouterConfig::default();
    assert!(config.router_track_prefill_tokens);
}
#[test]
#[test]
fn
compute_seq_hashes_for_tracking_uses_mm_hashes
()
{
fn
compute_seq_hashes_for_tracking_uses_mm_hashes
()
{
let
cfg
=
KvRouterConfig
::
default
();
let
cfg
=
KvRouterConfig
::
default
();
...
@@ -343,17 +386,6 @@ mod tests {
...
@@ -343,17 +386,6 @@ mod tests {
assert_ne!
(
without_mm
,
with_mm
);
assert_ne!
(
without_mm
,
with_mm
);
}
}
/// An explicit `track_prefill_tokens: Some(false)` override must survive a JSON round trip.
#[test]
fn router_config_override_serde_round_trip_preserves_track_prefill_tokens() {
    let original = RouterConfigOverride {
        track_prefill_tokens: Some(false),
        ..Default::default()
    };

    let json = serde_json::to_string(&original).unwrap();
    let restored: RouterConfigOverride = serde_json::from_str(&json).unwrap();

    assert_eq!(restored.track_prefill_tokens, Some(false));
}
#[test]
#[test]
fn
compute_seq_hashes_for_tracking_uses_precomputed_block_hashes
()
{
fn
compute_seq_hashes_for_tracking_uses_precomputed_block_hashes
()
{
let
config
=
KvRouterConfig
::
default
();
let
config
=
KvRouterConfig
::
default
();
...
...
lib/kv-router/src/scheduling/local.rs
View file @
95a750f4
...
@@ -6,9 +6,11 @@ use std::sync::Arc;
...
@@ -6,9 +6,11 @@ use std::sync::Arc;
use
std
::
time
::
Duration
;
use
std
::
time
::
Duration
;
use
tokio
::
sync
::{
mpsc
,
watch
};
use
tokio
::
sync
::{
mpsc
,
watch
};
use
tokio
::
time
::
Instant
;
use
tokio_util
::
sync
::
CancellationToken
;
use
tokio_util
::
sync
::
CancellationToken
;
use
super
::
policy
::{
RouterSchedulingPolicy
,
SchedulingPolicy
};
use
super
::
policy
::{
RouterSchedulingPolicy
,
SchedulingPolicy
};
use
super
::
prefill_load
::
PrefillLoadEstimator
;
use
super
::
queue
::
SchedulerQueue
;
use
super
::
queue
::
SchedulerQueue
;
use
super
::
selector
::{
DefaultWorkerSelector
,
WorkerSelector
};
use
super
::
selector
::{
DefaultWorkerSelector
,
WorkerSelector
};
use
super
::
types
::{
KvSchedulerError
,
PotentialLoad
,
SchedulingRequest
,
SchedulingResponse
};
use
super
::
types
::{
KvSchedulerError
,
PotentialLoad
,
SchedulingRequest
,
SchedulingResponse
};
...
@@ -18,8 +20,6 @@ use crate::sequences::{
...
@@ -18,8 +20,6 @@ use crate::sequences::{
};
};
use
dynamo_tokens
::
SequenceHash
;
use
dynamo_tokens
::
SequenceHash
;
/// Period (60 seconds) passed to `tokio::time::interval` by the `LocalScheduler`
/// background task.
const
RECHECK_INTERVAL
:
Duration
=
Duration
::
from_secs
(
60
);
pub
struct
LocalScheduler
<
P
,
C
,
S
=
RouterSchedulingPolicy
,
Sel
=
DefaultWorkerSelector
>
pub
struct
LocalScheduler
<
P
,
C
,
S
=
RouterSchedulingPolicy
,
Sel
=
DefaultWorkerSelector
>
where
where
P
:
SequencePublisher
,
P
:
SequencePublisher
,
...
@@ -49,6 +49,8 @@ where
...
@@ -49,6 +49,8 @@ where
block_size
:
u32
,
block_size
:
u32
,
selector
:
Sel
,
selector
:
Sel
,
policy
:
S
,
policy
:
S
,
prefill_load_estimator
:
Option
<
Arc
<
dyn
PrefillLoadEstimator
>>
,
recheck_interval
:
Duration
,
track_prefill_tokens_default
:
bool
,
track_prefill_tokens_default
:
bool
,
cancellation_token
:
CancellationToken
,
cancellation_token
:
CancellationToken
,
worker_type
:
&
'static
str
,
worker_type
:
&
'static
str
,
...
@@ -103,13 +105,14 @@ where
...
@@ -103,13 +105,14 @@ where
block_size
,
block_size
,
selector
,
selector
,
policy
,
policy
,
prefill_load_estimator
,
));
));
let
(
request_tx
,
request_rx
)
=
mpsc
::
channel
::
<
SchedulingRequest
>
(
1024
);
let
(
request_tx
,
request_rx
)
=
mpsc
::
channel
::
<
SchedulingRequest
>
(
1024
);
let
queue_clone
=
Arc
::
clone
(
&
queue
);
let
queue_clone
=
Arc
::
clone
(
&
queue
);
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
let
mut
request_rx
=
request_rx
;
let
mut
request_rx
=
request_rx
;
let
mut
recheck_interval
=
tokio
::
time
::
interval
(
RECHECK_INTERVAL
);
let
mut
recheck_interval
=
tokio
::
time
::
interval
(
recheck_interval
);
tracing
::
trace!
(
"LocalScheduler background task started"
);
tracing
::
trace!
(
"LocalScheduler background task started"
);
loop
{
loop
{
...
@@ -192,17 +195,18 @@ where
...
@@ -192,17 +195,18 @@ where
}
}
pub
async
fn
add_request
(
&
self
,
req
:
SequenceRequest
)
->
Result
<
(),
SequenceError
>
{
pub
async
fn
add_request
(
&
self
,
req
:
SequenceRequest
)
->
Result
<
(),
SequenceError
>
{
self
.slots
.add_request
(
req
)
self
.slots
.add_request
(
req
,
Instant
::
now
()
)
}
}
pub
async
fn
mark_prefill_completed
(
&
self
,
request_id
:
&
str
)
->
Result
<
(),
SequenceError
>
{
pub
async
fn
mark_prefill_completed
(
&
self
,
request_id
:
&
str
)
->
Result
<
(),
SequenceError
>
{
self
.slots
.mark_prefill_completed
(
&
request_id
.to_string
())
?
;
self
.slots
.mark_prefill_completed
(
&
request_id
.to_string
(),
Instant
::
now
())
?
;
self
.queue
.update
()
.await
;
self
.queue
.update
()
.await
;
Ok
(())
Ok
(())
}
}
pub
async
fn
free
(
&
self
,
request_id
:
&
str
)
->
Result
<
(),
SequenceError
>
{
pub
async
fn
free
(
&
self
,
request_id
:
&
str
)
->
Result
<
(),
SequenceError
>
{
self
.slots
.free
(
&
request_id
.to_string
())
?
;
self
.slots
.free
(
&
request_id
.to_string
()
,
Instant
::
now
()
)
?
;
self
.queue
.update
()
.await
;
self
.queue
.update
()
.await
;
Ok
(())
Ok
(())
}
}
...
@@ -231,6 +235,7 @@ where
...
@@ -231,6 +235,7 @@ where
overlaps
:
OverlapScores
,
overlaps
:
OverlapScores
,
track_prefill_tokens
:
bool
,
track_prefill_tokens
:
bool
,
)
->
Vec
<
PotentialLoad
>
{
)
->
Vec
<
PotentialLoad
>
{
let
decay_now
=
Instant
::
now
();
let
(
decode_blocks
,
prefill_tokens
)
=
self
let
(
decode_blocks
,
prefill_tokens
)
=
self
.slots
.slots
.potential_blocks_and_tokens_with_prefill_tracking
(
.potential_blocks_and_tokens_with_prefill_tracking
(
...
@@ -238,6 +243,7 @@ where
...
@@ -238,6 +243,7 @@ where
isl_tokens
,
isl_tokens
,
overlaps
,
overlaps
,
track_prefill_tokens
,
track_prefill_tokens
,
decay_now
,
);
);
let
mut
workers
:
HashSet
<
WorkerWithDpRank
>
=
HashSet
::
new
();
let
mut
workers
:
HashSet
<
WorkerWithDpRank
>
=
HashSet
::
new
();
...
@@ -275,15 +281,32 @@ mod tests {
...
@@ -275,15 +281,32 @@ mod tests {
use
super
::
*
;
use
super
::
*
;
use
crate
::
protocols
::
OverlapScores
;
use
crate
::
protocols
::
OverlapScores
;
use
crate
::
scheduling
::
PrefillLoadEstimator
;
use
crate
::
scheduling
::
policy
::
FcfsPolicy
;
use
crate
::
scheduling
::
policy
::
FcfsPolicy
;
use
crate
::
scheduling
::
selector
::
DefaultWorkerSelector
;
use
crate
::
scheduling
::
selector
::
DefaultWorkerSelector
;
use
crate
::
test_utils
::{
NoopSequencePublisher
,
SimpleWorkerConfig
};
use
crate
::
test_utils
::{
NoopSequencePublisher
,
SimpleWorkerConfig
};
/// Test double: an estimator that always predicts the same prefill duration.
struct FixedPrefillLoadEstimator {
    /// The duration returned for every prediction.
    duration: Duration,
}

impl PrefillLoadEstimator for FixedPrefillLoadEstimator {
    /// Ignores all inputs and returns the configured `duration`.
    fn predict_prefill_duration(
        &self,
        _batch_size: usize,
        _effective_isl: usize,
        _prefix: usize,
    ) -> anyhow::Result<Duration> {
        let Self { duration } = self;
        Ok(*duration)
    }
}
#[allow(clippy::type_complexity)]
#[allow(clippy::type_complexity)]
fn
make_scheduler
(
fn
make_scheduler
(
workers
:
HashMap
<
WorkerId
,
SimpleWorkerConfig
>
,
workers
:
HashMap
<
WorkerId
,
SimpleWorkerConfig
>
,
threshold_frac
:
Option
<
f64
>
,
threshold_frac
:
Option
<
f64
>
,
monitor_worker_configs
:
bool
,
monitor_worker_configs
:
bool
,
prefill_load_estimator
:
Option
<
Arc
<
dyn
PrefillLoadEstimator
>>
,
)
->
(
)
->
(
Arc
<
LocalScheduler
<
NoopSequencePublisher
,
SimpleWorkerConfig
,
FcfsPolicy
>>
,
Arc
<
LocalScheduler
<
NoopSequencePublisher
,
SimpleWorkerConfig
,
FcfsPolicy
>>
,
Arc
<
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>>
,
Arc
<
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>>
,
...
@@ -311,6 +334,8 @@ mod tests {
...
@@ -311,6 +334,8 @@ mod tests {
64
,
64
,
DefaultWorkerSelector
::
new
(
None
,
"test"
),
DefaultWorkerSelector
::
new
(
None
,
"test"
),
FcfsPolicy
,
FcfsPolicy
,
prefill_load_estimator
,
Duration
::
from_secs
(
60
),
true
,
true
,
cancel_token
.clone
(),
cancel_token
.clone
(),
"test"
,
"test"
,
...
@@ -329,7 +354,7 @@ mod tests {
...
@@ -329,7 +354,7 @@ mod tests {
..
Default
::
default
()
..
Default
::
default
()
},
},
);
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
,
None
);
let
response
=
scheduler
let
response
=
scheduler
.schedule
(
.schedule
(
...
@@ -366,7 +391,7 @@ mod tests {
...
@@ -366,7 +391,7 @@ mod tests {
..
Default
::
default
()
..
Default
::
default
()
},
},
);
);
let
(
scheduler
,
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
);
let
(
scheduler
,
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
,
None
);
scheduler
scheduler
.schedule
(
.schedule
(
...
@@ -389,7 +414,7 @@ mod tests {
...
@@ -389,7 +414,7 @@ mod tests {
assert_eq!
(
assert_eq!
(
slots
slots
.active_tokens
()
.active_tokens
(
Instant
::
now
()
)
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.copied
(),
.copied
(),
Some
(
0
)
Some
(
0
)
...
@@ -408,7 +433,8 @@ mod tests {
...
@@ -408,7 +433,8 @@ mod tests {
..
Default
::
default
()
..
Default
::
default
()
},
},
);
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
Some
(
0.5
),
true
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
Some
(
0.5
),
true
,
None
);
scheduler
scheduler
.schedule
(
.schedule
(
...
@@ -466,7 +492,7 @@ mod tests {
...
@@ -466,7 +492,7 @@ mod tests {
..
Default
::
default
()
..
Default
::
default
()
},
},
);
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
,
None
);
scheduler
scheduler
.schedule
(
.schedule
(
...
@@ -511,12 +537,16 @@ mod tests {
...
@@ -511,12 +537,16 @@ mod tests {
..
Default
::
default
()
..
Default
::
default
()
},
},
);
);
let
(
scheduler
,
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
);
let
(
scheduler
,
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
,
None
);
let
token_seq
=
vec!
[
11
,
22
,
33
,
44
];
let
token_seq
=
vec!
[
11
,
22
,
33
,
44
];
let
overlaps
=
OverlapScores
::
default
();
let
overlaps
=
OverlapScores
::
default
();
let
(
decode_blocks
,
prefill_tokens
)
=
let
(
decode_blocks
,
prefill_tokens
)
=
slots
.potential_blocks_and_tokens
(
slots
.potential_blocks_and_tokens
(
Some
(
&
token_seq
),
128
,
overlaps
.clone
());
Some
(
&
token_seq
),
128
,
overlaps
.clone
(),
Instant
::
now
(),
);
let
mut
expected
:
Vec
<
_
>
=
decode_blocks
let
mut
expected
:
Vec
<
_
>
=
decode_blocks
.keys
()
.keys
()
.map
(|
worker
|
PotentialLoad
{
.map
(|
worker
|
PotentialLoad
{
...
@@ -548,10 +578,51 @@ mod tests {
...
@@ -548,10 +578,51 @@ mod tests {
cancel_token
.cancel
();
cancel_token
.cancel
();
}
}
/// Checks that `get_potential_loads` reports a reduced prefill token count after time passes.
///
/// One request is scheduled against a single worker using an estimator that
/// always predicts a 10 s prefill. The paused Tokio clock is then advanced by
/// 6 s, and the worker is expected to report 40 potential prefill tokens.
// NOTE(review): 40 is consistent with 100 tokens decayed linearly over 6 s of 10 s;
// the decay computation itself lives outside this test — confirm.
#[tokio::test(start_paused = true)]
async fn test_get_potential_loads_uses_decayed_prefill_tokens() {
    let worker_config = SimpleWorkerConfig {
        max_num_batched_tokens: Some(256),
        ..Default::default()
    };
    let mut workers = HashMap::new();
    workers.insert(0, worker_config);

    let predicted = Duration::from_secs(10);
    let estimator: Arc<dyn PrefillLoadEstimator> = Arc::new(FixedPrefillLoadEstimator {
        duration: predicted,
    });
    let (scheduler, _slots, _cfg_tx, cancel_token) =
        make_scheduler(workers, None, true, Some(estimator));

    let request_id = Some("req-1".to_string());
    let token_seq = Some(vec![1, 2, 3, 4]);
    let scheduled = scheduler
        .schedule(
            request_id,
            100,
            token_seq,
            OverlapScores::default(),
            None,
            true,
            None,
            0.0,
            None,
            None,
        )
        .await;
    scheduled.unwrap();

    // Move the paused clock 6 s into the predicted 10 s prefill window.
    let elapsed = Duration::from_secs(6);
    tokio::time::advance(elapsed).await;

    let loads = scheduler.get_potential_loads(None, 0, OverlapScores::default(), true);
    assert_eq!(loads.len(), 1);
    assert_eq!(loads[0].potential_prefill_tokens, 40);

    cancel_token.cancel();
}
#[tokio::test]
#[tokio::test]
async
fn
test_register_workers_uses_default_dp_fallback
()
{
async
fn
test_register_workers_uses_default_dp_fallback
()
{
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
HashMap
::
new
(),
None
,
false
);
make_scheduler
(
HashMap
::
new
(),
None
,
false
,
None
);
scheduler
.register_workers
(
&
HashSet
::
from
([
42
]));
scheduler
.register_workers
(
&
HashSet
::
from
([
42
]));
let
loads
=
scheduler
.get_potential_loads
(
None
,
64
,
OverlapScores
::
default
(),
true
);
let
loads
=
scheduler
.get_potential_loads
(
None
,
64
,
OverlapScores
::
default
(),
true
);
...
@@ -567,7 +638,7 @@ mod tests {
...
@@ -567,7 +638,7 @@ mod tests {
async
fn
test_worker_watch_updates_slot_ranges
()
{
async
fn
test_worker_watch_updates_slot_ranges
()
{
let
mut
workers
=
HashMap
::
new
();
let
mut
workers
=
HashMap
::
new
();
workers
.insert
(
0
,
SimpleWorkerConfig
::
default
());
workers
.insert
(
0
,
SimpleWorkerConfig
::
default
());
let
(
scheduler
,
_
slots
,
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
);
let
(
scheduler
,
_
slots
,
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
,
None
);
assert_eq!
(
assert_eq!
(
scheduler
scheduler
...
@@ -615,7 +686,7 @@ mod tests {
...
@@ -615,7 +686,7 @@ mod tests {
..
Default
::
default
()
..
Default
::
default
()
},
},
);
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
);
let
(
scheduler
,
_
slots
,
_
cfg_tx
,
cancel_token
)
=
make_scheduler
(
workers
,
None
,
true
,
None
);
scheduler
scheduler
.schedule
(
.schedule
(
...
...
lib/kv-router/src/scheduling/mod.rs
View file @
95a750f4
...
@@ -4,9 +4,11 @@
...
@@ -4,9 +4,11 @@
pub
mod
config
;
pub
mod
config
;
mod
local
;
mod
local
;
pub
mod
policy
;
pub
mod
policy
;
pub
mod
prefill_load
;
pub
mod
queue
;
pub
mod
queue
;
pub
mod
selector
;
pub
mod
selector
;
mod
types
;
mod
types
;
pub
use
local
::
LocalScheduler
;
pub
use
local
::
LocalScheduler
;
pub
use
prefill_load
::
PrefillLoadEstimator
;
pub
use
types
::
*
;
pub
use
types
::
*
;
lib/kv-router/src/scheduling/prefill_load.rs
0 → 100644
View file @
95a750f4
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
time
::
Duration
;
/// Predicts how long a prefill is expected to take.
///
/// Implementations must be `Send + Sync`; the scheduler holds them as
/// `Arc<dyn PrefillLoadEstimator>`.
pub
trait
PrefillLoadEstimator
:
Send
+
Sync
{
/// Returns the predicted prefill duration, or an error if no prediction can be made.
///
/// The scheduler queue calls this with `batch_size = 1`, `effective_isl` set to the
/// input length minus the cached prefix tokens, and `prefix` set to the number of
/// cached prefix tokens (overlap blocks times block size).
fn
predict_prefill_duration
(
&
self
,
batch_size
:
usize
,
effective_isl
:
usize
,
prefix
:
usize
,
)
->
anyhow
::
Result
<
Duration
>
;
}
lib/kv-router/src/scheduling/queue.rs
View file @
95a750f4
...
@@ -5,15 +5,16 @@ use std::cmp::Ordering;
...
@@ -5,15 +5,16 @@ use std::cmp::Ordering;
use
std
::
collections
::{
BinaryHeap
,
HashMap
,
HashSet
};
use
std
::
collections
::{
BinaryHeap
,
HashMap
,
HashSet
};
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
use
std
::
sync
::
atomic
::{
AtomicUsize
,
Ordering
as
AtomicOrdering
};
use
std
::
sync
::
atomic
::{
AtomicUsize
,
Ordering
as
AtomicOrdering
};
use
std
::
time
::
Instant
;
use
tokio
::
sync
::
Mutex
;
use
tokio
::
sync
::
Mutex
;
use
tokio
::
sync
::
watch
;
use
tokio
::
sync
::
watch
;
use
tokio
::
time
::
Instant
;
use
super
::
policy
::{
FcfsPolicy
,
SchedulingPolicy
};
use
super
::
policy
::{
FcfsPolicy
,
SchedulingPolicy
};
use
super
::
prefill_load
::
PrefillLoadEstimator
;
use
super
::
selector
::{
DefaultWorkerSelector
,
WorkerSelector
};
use
super
::
selector
::{
DefaultWorkerSelector
,
WorkerSelector
};
use
super
::
types
::{
SchedulingRequest
,
SchedulingResponse
};
use
super
::
types
::{
SchedulingRequest
,
SchedulingResponse
};
use
crate
::
protocols
::{
WorkerConfigLike
,
WorkerId
,
WorkerWithDpRank
};
use
crate
::
protocols
::{
PrefillLoadHint
,
WorkerConfigLike
,
WorkerId
,
WorkerWithDpRank
};
use
crate
::
sequences
::{
ActiveSequencesMultiWorker
,
SequencePublisher
,
SequenceRequest
};
use
crate
::
sequences
::{
ActiveSequencesMultiWorker
,
SequencePublisher
,
SequenceRequest
};
/// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker)
/// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker)
...
@@ -68,6 +69,7 @@ pub struct SchedulerQueue<
...
@@ -68,6 +69,7 @@ pub struct SchedulerQueue<
block_size
:
u32
,
block_size
:
u32
,
selector
:
Sel
,
selector
:
Sel
,
policy
:
S
,
policy
:
S
,
prefill_load_estimator
:
Option
<
Arc
<
dyn
PrefillLoadEstimator
>>
,
}
}
impl
<
impl
<
...
@@ -84,6 +86,7 @@ impl<
...
@@ -84,6 +86,7 @@ impl<
block_size
:
u32
,
block_size
:
u32
,
selector
:
Sel
,
selector
:
Sel
,
policy
:
S
,
policy
:
S
,
prefill_load_estimator
:
Option
<
Arc
<
dyn
PrefillLoadEstimator
>>
,
)
->
Self
{
)
->
Self
{
if
let
Some
(
frac
)
=
threshold_frac
{
if
let
Some
(
frac
)
=
threshold_frac
{
tracing
::
info!
(
"Router queue enabled with threshold fraction {frac}"
);
tracing
::
info!
(
"Router queue enabled with threshold fraction {frac}"
);
...
@@ -98,6 +101,7 @@ impl<
...
@@ -98,6 +101,7 @@ impl<
block_size
,
block_size
,
selector
,
selector
,
policy
,
policy
,
prefill_load_estimator
,
}
}
}
}
...
@@ -133,23 +137,24 @@ impl<
...
@@ -133,23 +137,24 @@ impl<
/// capacity check is skipped.
/// capacity check is skipped.
pub
async
fn
enqueue
(
&
self
,
request
:
SchedulingRequest
)
{
pub
async
fn
enqueue
(
&
self
,
request
:
SchedulingRequest
)
{
let
Some
(
threshold
)
=
self
.threshold_frac
else
{
let
Some
(
threshold
)
=
self
.threshold_frac
else
{
self
.schedule
(
request
)
.await
;
self
.schedule
(
request
,
Instant
::
now
()
)
.await
;
return
;
return
;
};
};
if
request
.allowed_worker_ids
.is_some
()
{
if
request
.allowed_worker_ids
.is_some
()
{
self
.schedule
(
request
)
.await
;
self
.schedule
(
request
,
Instant
::
now
()
)
.await
;
return
;
return
;
}
}
if
self
.all_workers_busy
(
threshold
,
request
.allowed_worker_ids
.as_ref
())
{
let
decay_now
=
Instant
::
now
();
if
self
.all_workers_busy
(
threshold
,
request
.allowed_worker_ids
.as_ref
(),
decay_now
)
{
tracing
::
debug!
(
"all workers busy, queueing request"
);
tracing
::
debug!
(
"all workers busy, queueing request"
);
let
arrival_offset
=
self
.start_time
.elapsed
();
let
arrival_offset
=
self
.start_time
.elapsed
();
let
key
=
self
.policy
.enqueue_key
(
arrival_offset
,
&
request
);
let
key
=
self
.policy
.enqueue_key
(
arrival_offset
,
&
request
);
self
.pending
.lock
()
.await
.push
(
QueueEntry
{
key
,
request
});
self
.pending
.lock
()
.await
.push
(
QueueEntry
{
key
,
request
});
self
.pending_count
.fetch_add
(
1
,
AtomicOrdering
::
Relaxed
);
self
.pending_count
.fetch_add
(
1
,
AtomicOrdering
::
Relaxed
);
}
else
{
}
else
{
self
.schedule
(
request
)
.await
;
self
.schedule
(
request
,
decay_now
)
.await
;
}
}
}
}
...
@@ -176,7 +181,8 @@ impl<
...
@@ -176,7 +181,8 @@ impl<
}
}
loop
{
loop
{
if
self
.all_workers_busy
(
threshold
,
None
)
{
let
decay_now
=
Instant
::
now
();
if
self
.all_workers_busy
(
threshold
,
None
,
decay_now
)
{
break
;
break
;
}
}
let
Some
(
entry
)
=
self
.pending
.lock
()
.await
.pop
()
else
{
let
Some
(
entry
)
=
self
.pending
.lock
()
.await
.pop
()
else
{
...
@@ -184,13 +190,13 @@ impl<
...
@@ -184,13 +190,13 @@ impl<
};
};
self
.pending_count
.fetch_sub
(
1
,
AtomicOrdering
::
Relaxed
);
self
.pending_count
.fetch_sub
(
1
,
AtomicOrdering
::
Relaxed
);
tracing
::
debug!
(
"scheduling request from pending queue"
);
tracing
::
debug!
(
"scheduling request from pending queue"
);
self
.schedule
(
entry
.request
)
.await
;
self
.schedule
(
entry
.request
,
decay_now
)
.await
;
}
}
}
}
/// Run the full scheduling pipeline for a single request:
/// Run the full scheduling pipeline for a single request:
/// compute potential load -> select worker -> respond -> book via add_request.
/// compute potential load -> select worker -> respond -> book via add_request.
async
fn
schedule
(
&
self
,
mut
request
:
SchedulingRequest
)
{
async
fn
schedule
(
&
self
,
mut
request
:
SchedulingRequest
,
decay_now
:
Instant
)
{
let
(
decode_blocks
,
prefill_tokens
)
=
self
let
(
decode_blocks
,
prefill_tokens
)
=
self
.slots
.slots
.potential_blocks_and_tokens_with_prefill_tracking
(
.potential_blocks_and_tokens_with_prefill_tracking
(
...
@@ -198,6 +204,7 @@ impl<
...
@@ -198,6 +204,7 @@ impl<
request
.isl_tokens
,
request
.isl_tokens
,
request
.overlaps
.clone
(),
request
.overlaps
.clone
(),
request
.track_prefill_tokens
,
request
.track_prefill_tokens
,
decay_now
,
);
);
request
.decode_blocks
=
decode_blocks
;
request
.decode_blocks
=
decode_blocks
;
request
.prefill_tokens
=
prefill_tokens
;
request
.prefill_tokens
=
prefill_tokens
;
...
@@ -231,20 +238,66 @@ impl<
...
@@ -231,20 +238,66 @@ impl<
return
;
return
;
};
};
if
let
Err
(
e
)
=
self
.slots
.add_request
(
SequenceRequest
{
let
prefill_load_hint
=
self
.prefill_load_hint_for
(
request
.isl_tokens
,
selection
.overlap_blocks
,
request
.track_prefill_tokens
,
);
if
let
Err
(
e
)
=
self
.slots
.add_request
(
SequenceRequest
{
request_id
:
request_id
.clone
(),
request_id
:
request_id
.clone
(),
token_sequence
:
request
.token_seq
,
token_sequence
:
request
.token_seq
,
isl
:
request
.isl_tokens
,
isl
:
request
.isl_tokens
,
overlap
:
selection
.overlap_blocks
,
overlap
:
selection
.overlap_blocks
,
track_prefill_tokens
:
request
.track_prefill_tokens
,
track_prefill_tokens
:
request
.track_prefill_tokens
,
expected_output_tokens
:
request
.expected_output_tokens
,
expected_output_tokens
:
request
.expected_output_tokens
,
prefill_load_hint
,
worker
:
selection
.worker
,
worker
:
selection
.worker
,
lora_name
:
request
.lora_name
.clone
(),
lora_name
:
request
.lora_name
.clone
(),
})
{
},
decay_now
,
)
{
tracing
::
warn!
(
"Failed to add request {request_id}: {e}"
);
tracing
::
warn!
(
"Failed to add request {request_id}: {e}"
);
}
}
}
}
/// Builds the prefill load hint that accompanies a request being booked on a worker.
///
/// Returns `None` when prefill tracking is off for the request, when the cached
/// prefix already covers the whole input, when no estimator is configured, or
/// when the estimator fails (the failure is logged as a warning).
fn prefill_load_hint_for(
    &self,
    isl_tokens: usize,
    overlap_blocks: u32,
    track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
    if !track_prefill_tokens {
        return None;
    }

    // Tokens already covered by overlapping blocks do not need to be prefilled.
    // The local names `prefix` and `effective_isl` double as tracing field names below.
    let prefix = (overlap_blocks as usize) * (self.block_size as usize);
    let effective_isl = isl_tokens.saturating_sub(prefix);
    if effective_isl == 0 {
        return None;
    }

    let estimator = self.prefill_load_estimator.as_ref()?;

    // A failed prediction is logged and then treated as "no hint".
    let expected_prefill_duration = estimator
        .predict_prefill_duration(1, effective_isl, prefix)
        .map_err(|error| {
            tracing::warn!(
                effective_isl,
                prefix,
                "failed to predict prefill duration for active load tracking: {error}"
            );
        })
        .ok()?;

    Some(PrefillLoadHint {
        initial_effective_prefill_tokens: effective_isl,
        expected_prefill_duration: Some(expected_prefill_duration),
    })
}
/// Number of requests currently parked in the pending queue (lock-free).
/// Number of requests currently parked in the pending queue (lock-free).
pub
fn
pending_count
(
&
self
)
->
usize
{
pub
fn
pending_count
(
&
self
)
->
usize
{
self
.pending_count
.load
(
AtomicOrdering
::
Relaxed
)
self
.pending_count
.load
(
AtomicOrdering
::
Relaxed
)
...
@@ -255,8 +308,13 @@ impl<
...
@@ -255,8 +308,13 @@ impl<
/// otherwise all registered workers are checked.
/// otherwise all registered workers are checked.
/// Returns false when no eligible workers exist so the request falls
/// Returns false when no eligible workers exist so the request falls
/// through to `schedule`, which returns a proper `NoEndpoints` error.
/// through to `schedule`, which returns a proper `NoEndpoints` error.
fn
all_workers_busy
(
&
self
,
threshold
:
f64
,
allowed
:
Option
<&
HashSet
<
WorkerId
>>
)
->
bool
{
fn
all_workers_busy
(
let
active_tokens
=
self
.slots
.active_tokens
();
&
self
,
threshold
:
f64
,
allowed
:
Option
<&
HashSet
<
WorkerId
>>
,
decay_now
:
Instant
,
)
->
bool
{
let
active_tokens
=
self
.slots
.active_tokens
(
decay_now
);
let
configs
=
self
.workers_with_configs
.borrow
();
let
configs
=
self
.workers_with_configs
.borrow
();
let
mut
checked_any
=
false
;
let
mut
checked_any
=
false
;
...
@@ -289,6 +347,7 @@ impl<
...
@@ -289,6 +347,7 @@ impl<
mod
tests
{
mod
tests
{
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
tokio
::
sync
::
watch
;
use
tokio
::
sync
::
watch
;
...
@@ -298,6 +357,25 @@ mod tests {
...
@@ -298,6 +357,25 @@ mod tests {
use
crate
::
sequences
::
ActiveSequencesMultiWorker
;
use
crate
::
sequences
::
ActiveSequencesMultiWorker
;
use
crate
::
test_utils
::{
NoopSequencePublisher
,
SimpleWorkerConfig
};
use
crate
::
test_utils
::{
NoopSequencePublisher
,
SimpleWorkerConfig
};
/// Test helper: returns the current instant, used as the `decay_now` argument
/// passed to the slot-tracking calls in these tests.
// NOTE(review): `Instant` resolves to whichever `Instant` this test module has in
// scope (presumably `tokio::time::Instant` via the parent module) — confirm.
fn
decay_now
()
->
Instant
{
Instant
::
now
()
}
/// Test double: an estimator whose prediction is a fixed, preconfigured duration.
struct FixedPrefillLoadEstimator {
    /// The value handed back by every call to `predict_prefill_duration`.
    duration: Duration,
}

impl PrefillLoadEstimator for FixedPrefillLoadEstimator {
    /// Returns the configured `duration` regardless of the arguments.
    fn predict_prefill_duration(
        &self,
        _batch_size: usize,
        _effective_isl: usize,
        _prefix: usize,
    ) -> anyhow::Result<Duration> {
        let fixed = self.duration;
        anyhow::Result::Ok(fixed)
    }
}
fn
make_queue
(
fn
make_queue
(
num_workers
:
usize
,
num_workers
:
usize
,
block_size
:
u32
,
block_size
:
u32
,
...
@@ -308,7 +386,7 @@ mod tests {
...
@@ -308,7 +386,7 @@ mod tests {
Arc
<
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>>
,
Arc
<
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>>
,
)
{
)
{
let
(
queue
,
slots
,
_
tx
)
=
let
(
queue
,
slots
,
_
tx
)
=
make_queue_with_sender
(
num_workers
,
block_size
,
isl
,
threshold_frac
);
make_queue_with_sender
(
num_workers
,
block_size
,
isl
,
threshold_frac
,
None
);
(
queue
,
slots
)
(
queue
,
slots
)
}
}
...
@@ -318,6 +396,7 @@ mod tests {
...
@@ -318,6 +396,7 @@ mod tests {
block_size
:
u32
,
block_size
:
u32
,
isl
:
usize
,
isl
:
usize
,
threshold_frac
:
Option
<
f64
>
,
threshold_frac
:
Option
<
f64
>
,
prefill_load_estimator
:
Option
<
Arc
<
dyn
PrefillLoadEstimator
>>
,
)
->
(
)
->
(
Arc
<
SchedulerQueue
<
NoopSequencePublisher
,
SimpleWorkerConfig
>>
,
Arc
<
SchedulerQueue
<
NoopSequencePublisher
,
SimpleWorkerConfig
>>
,
Arc
<
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>>
,
Arc
<
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>>
,
...
@@ -354,6 +433,7 @@ mod tests {
...
@@ -354,6 +433,7 @@ mod tests {
block_size
,
block_size
,
selector
,
selector
,
FcfsPolicy
,
FcfsPolicy
,
prefill_load_estimator
,
));
));
(
queue
,
slots
,
cfg_tx
)
(
queue
,
slots
,
cfg_tx
)
...
@@ -409,8 +489,8 @@ mod tests {
...
@@ -409,8 +489,8 @@ mod tests {
let
resp
=
resp
.expect
(
"scheduling failed"
);
let
resp
=
resp
.expect
(
"scheduling failed"
);
assert
!
(
resp
.best_worker.worker_id
<
num_workers
as
u64
);
assert
!
(
resp
.best_worker.worker_id
<
num_workers
as
u64
);
slots
.mark_prefill_completed
(
&
req_id
)
.unwrap
();
slots
.mark_prefill_completed
(
&
req_id
,
decay_now
()
)
.unwrap
();
slots
.free
(
&
req_id
)
.unwrap
();
slots
.free
(
&
req_id
,
decay_now
()
)
.unwrap
();
queue
.update
()
.await
;
queue
.update
()
.await
;
}));
}));
}
}
...
@@ -419,7 +499,7 @@ mod tests {
...
@@ -419,7 +499,7 @@ mod tests {
h
.await
.expect
(
"task panicked"
);
h
.await
.expect
(
"task panicked"
);
}
}
let
active
=
slots
.active_tokens
();
let
active
=
slots
.active_tokens
(
decay_now
()
);
for
(
worker
,
tokens
)
in
&
active
{
for
(
worker
,
tokens
)
in
&
active
{
assert_eq!
(
assert_eq!
(
*
tokens
,
0
,
*
tokens
,
0
,
...
@@ -453,8 +533,8 @@ mod tests {
...
@@ -453,8 +533,8 @@ mod tests {
for
_
in
0
..
num_requests
{
for
_
in
0
..
num_requests
{
queue
.update
()
.await
;
queue
.update
()
.await
;
for
rid
in
&
req_ids
{
for
rid
in
&
req_ids
{
let
_
=
slots
.mark_prefill_completed
(
rid
);
let
_
=
slots
.mark_prefill_completed
(
rid
,
decay_now
()
);
let
_
=
slots
.free
(
rid
);
let
_
=
slots
.free
(
rid
,
decay_now
()
);
}
}
}
}
queue
.update
()
.await
;
queue
.update
()
.await
;
...
@@ -495,8 +575,10 @@ mod tests {
...
@@ -495,8 +575,10 @@ mod tests {
assert_eq!
(
queue
.pending_count
(),
2
);
assert_eq!
(
queue
.pending_count
(),
2
);
// Free the first request and update — should drain one from pending
// Free the first request and update — should drain one from pending
slots
.mark_prefill_completed
(
&
"req-1"
.to_string
())
.unwrap
();
slots
slots
.free
(
&
"req-1"
.to_string
())
.unwrap
();
.mark_prefill_completed
(
&
"req-1"
.to_string
(),
decay_now
())
.unwrap
();
slots
.free
(
&
"req-1"
.to_string
(),
decay_now
())
.unwrap
();
queue
.update
()
.await
;
queue
.update
()
.await
;
// After update, one pending request should have been scheduled
// After update, one pending request should have been scheduled
...
@@ -507,16 +589,43 @@ mod tests {
...
@@ -507,16 +589,43 @@ mod tests {
);
);
// Free req-2 and update to drain remaining
// Free req-2 and update to drain remaining
let
_
=
slots
.mark_prefill_completed
(
&
"req-2"
.to_string
());
let
_
=
slots
.mark_prefill_completed
(
&
"req-2"
.to_string
()
,
decay_now
()
);
let
_
=
slots
.free
(
&
"req-2"
.to_string
());
let
_
=
slots
.free
(
&
"req-2"
.to_string
()
,
decay_now
()
);
queue
.update
()
.await
;
queue
.update
()
.await
;
let
_
=
slots
.mark_prefill_completed
(
&
"req-3"
.to_string
());
let
_
=
slots
.mark_prefill_completed
(
&
"req-3"
.to_string
()
,
decay_now
()
);
let
_
=
slots
.free
(
&
"req-3"
.to_string
());
let
_
=
slots
.free
(
&
"req-3"
.to_string
()
,
decay_now
()
);
queue
.update
()
.await
;
queue
.update
()
.await
;
assert_eq!
(
queue
.pending_count
(),
0
,
"all requests should be drained"
);
assert_eq!
(
queue
.pending_count
(),
0
,
"all requests should be drained"
);
}
}
/// With a fixed 10 s prefill estimate, a request that was queued behind a
/// busy worker must be scheduled by `update()` once paused time has been
/// advanced by 6 s.
// NOTE(review): presumably the first request's prompt load has decayed below
// the 0.5 threshold fraction by then — confirm against the queue's
// threshold logic, which is not visible here.
#[tokio::test(start_paused = true)]
async fn test_queue_update_uses_decayed_oldest_prefill_load() {
    // One worker, block size 16, ISL 100, threshold fraction 0.5, fixed 10 s estimator.
    let fixed_estimator: Arc<dyn PrefillLoadEstimator> = Arc::new(FixedPrefillLoadEstimator {
        duration: Duration::from_secs(10),
    });
    // Both unused handles stay bound for the whole test so nothing is dropped early.
    let (queue, _slots, _cfg_tx) =
        make_queue_with_sender(1, 16, 100, Some(0.5), Some(fixed_estimator));

    // The first request is scheduled right away.
    let (first_request, first_rx) = make_request("req-1", 100);
    queue.enqueue(first_request).await;
    let _ = first_rx.await.unwrap().unwrap();

    // The second request has to wait in the pending queue.
    let (second_request, mut second_rx) = make_request("req-2", 100);
    queue.enqueue(second_request).await;
    assert_eq!(queue.pending_count(), 1);

    // Move the paused clock forward, then let the queue re-evaluate.
    tokio::time::advance(Duration::from_secs(6)).await;
    queue.update().await;

    let outcome = second_rx
        .try_recv()
        .expect("queued request should have been scheduled");
    let reply = outcome.expect("scheduling returned error");
    assert_eq!(reply.best_worker.worker_id, 0);
    assert_eq!(queue.pending_count(), 0);
}
#[tokio::test]
#[tokio::test]
async
fn
test_no_workers_returns_error
()
{
async
fn
test_no_workers_returns_error
()
{
let
(
queue
,
_
slots
)
=
make_queue
(
0
,
16
,
512
,
None
);
let
(
queue
,
_
slots
)
=
make_queue
(
0
,
16
,
512
,
None
);
...
@@ -542,7 +651,7 @@ mod tests {
...
@@ -542,7 +651,7 @@ mod tests {
let
isl
=
512
;
let
isl
=
512
;
// Start with zero workers (mimics skip_initial_worker_wait=true)
// Start with zero workers (mimics skip_initial_worker_wait=true)
let
(
queue
,
slots
,
cfg_tx
)
=
make_queue_with_sender
(
0
,
block_size
,
isl
,
None
);
let
(
queue
,
slots
,
cfg_tx
)
=
make_queue_with_sender
(
0
,
block_size
,
isl
,
None
,
None
);
// Routing with no workers must fail
// Routing with no workers must fail
let
(
req_fail
,
rx_fail
)
=
make_request
(
"before-register"
,
isl
);
let
(
req_fail
,
rx_fail
)
=
make_request
(
"before-register"
,
isl
);
...
@@ -590,9 +699,11 @@ mod tests {
...
@@ -590,9 +699,11 @@ mod tests {
// Clean up
// Clean up
slots
slots
.mark_prefill_completed
(
&
"after-register"
.to_string
())
.mark_prefill_completed
(
&
"after-register"
.to_string
(),
decay_now
())
.unwrap
();
slots
.free
(
&
"after-register"
.to_string
(),
decay_now
())
.unwrap
();
.unwrap
();
slots
.free
(
&
"after-register"
.to_string
())
.unwrap
();
}
}
/// Register_workers is additive: calling with a new set does NOT remove old workers.
/// Register_workers is additive: calling with a new set does NOT remove old workers.
...
@@ -601,7 +712,7 @@ mod tests {
...
@@ -601,7 +712,7 @@ mod tests {
let
block_size
=
16
;
let
block_size
=
16
;
let
isl
=
256
;
let
isl
=
256
;
let
(
queue
,
slots
,
cfg_tx
)
=
make_queue_with_sender
(
0
,
block_size
,
isl
,
None
);
let
(
queue
,
slots
,
cfg_tx
)
=
make_queue_with_sender
(
0
,
block_size
,
isl
,
None
,
None
);
// Register worker 10 in slots and config
// Register worker 10 in slots and config
let
mut
dp1
=
std
::
collections
::
HashMap
::
new
();
let
mut
dp1
=
std
::
collections
::
HashMap
::
new
();
...
@@ -643,8 +754,8 @@ mod tests {
...
@@ -643,8 +754,8 @@ mod tests {
.expect
(
"oneshot dropped"
)
.expect
(
"oneshot dropped"
)
.expect
(
"scheduling failed"
);
.expect
(
"scheduling failed"
);
seen
.insert
(
resp
.best_worker.worker_id
);
seen
.insert
(
resp
.best_worker.worker_id
);
slots
.mark_prefill_completed
(
&
req_id
)
.unwrap
();
slots
.mark_prefill_completed
(
&
req_id
,
decay_now
()
)
.unwrap
();
slots
.free
(
&
req_id
)
.unwrap
();
slots
.free
(
&
req_id
,
decay_now
()
)
.unwrap
();
}
}
assert
!
(
assert
!
(
...
@@ -659,7 +770,7 @@ mod tests {
...
@@ -659,7 +770,7 @@ mod tests {
let
block_size
=
16
;
let
block_size
=
16
;
let
isl
=
256
;
let
isl
=
256
;
let
(
queue
,
slots
,
cfg_tx
)
=
make_queue_with_sender
(
0
,
block_size
,
isl
,
None
);
let
(
queue
,
slots
,
cfg_tx
)
=
make_queue_with_sender
(
0
,
block_size
,
isl
,
None
,
None
);
// Register three workers
// Register three workers
let
mut
dp
=
std
::
collections
::
HashMap
::
new
();
let
mut
dp
=
std
::
collections
::
HashMap
::
new
();
...
@@ -712,9 +823,9 @@ mod tests {
...
@@ -712,9 +823,9 @@ mod tests {
resp
.best_worker.worker_id
resp
.best_worker.worker_id
);
);
slots
slots
.mark_prefill_completed
(
&
"filter-0"
.to_string
())
.mark_prefill_completed
(
&
"filter-0"
.to_string
()
,
decay_now
()
)
.unwrap
();
.unwrap
();
slots
.free
(
&
"filter-0"
.to_string
())
.unwrap
();
slots
.free
(
&
"filter-0"
.to_string
()
,
decay_now
()
)
.unwrap
();
}
}
#[tokio::test(flavor
=
"multi_thread"
)]
#[tokio::test(flavor
=
"multi_thread"
)]
...
@@ -727,7 +838,7 @@ mod tests {
...
@@ -727,7 +838,7 @@ mod tests {
let
_
resp1
=
rx1
.await
.unwrap
()
.unwrap
();
let
_
resp1
=
rx1
.await
.unwrap
()
.unwrap
();
assert_eq!
(
assert_eq!
(
slots
slots
.active_tokens
()
.active_tokens
(
decay_now
()
)
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.copied
(),
.copied
(),
Some
(
0
)
Some
(
0
)
...
@@ -738,9 +849,9 @@ mod tests {
...
@@ -738,9 +849,9 @@ mod tests {
let
_
resp2
=
rx2
.await
.unwrap
()
.unwrap
();
let
_
resp2
=
rx2
.await
.unwrap
()
.unwrap
();
assert_eq!
(
queue
.pending_count
(),
0
);
assert_eq!
(
queue
.pending_count
(),
0
);
let
_
=
slots
.mark_prefill_completed
(
&
"req-1"
.to_string
());
let
_
=
slots
.mark_prefill_completed
(
&
"req-1"
.to_string
()
,
decay_now
()
);
let
_
=
slots
.free
(
&
"req-1"
.to_string
());
let
_
=
slots
.free
(
&
"req-1"
.to_string
()
,
decay_now
()
);
let
_
=
slots
.mark_prefill_completed
(
&
"req-2"
.to_string
());
let
_
=
slots
.mark_prefill_completed
(
&
"req-2"
.to_string
()
,
decay_now
()
);
let
_
=
slots
.free
(
&
"req-2"
.to_string
());
let
_
=
slots
.free
(
&
"req-2"
.to_string
()
,
decay_now
()
);
}
}
}
}
lib/kv-router/src/sequences/block_tracker.rs
0 → 100644
View file @
95a750f4
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
dynamo_tokens
::
SequenceHash
;
use
std
::
collections
::
HashMap
;
use
std
::
sync
::{
Arc
,
Weak
};
/// Bookkeeping for the blocks currently referenced by sequences.
///
/// Liveness of a block is derived from the strong count behind the `Weak`
/// handle stored for its hash; the strong handles are the `Arc<()>` values
/// returned by `touch_block`.
#[derive(Debug, Default)]
pub(super) struct BlockTracker {
    /// Block hash -> weak handle to the shared reference token for that block.
    pub(super) unique_blocks: HashMap<SequenceHash, Weak<()>>,
    /// Block hash -> fractional weight that replaces the block's full count of
    /// `1.0` in `active_blocks`.
    // NOTE(review): entries are only read and removed in this file; they are
    // presumably written by the parent module — confirm.
    pub(super) fractional_blocks: HashMap<SequenceHash, f64>,
}
impl
BlockTracker
{
pub
(
super
)
fn
touch_block
(
&
mut
self
,
block
:
&
SequenceHash
)
->
Arc
<
()
>
{
if
let
Some
(
weak
)
=
self
.unique_blocks
.get
(
block
)
&&
let
Some
(
rc
)
=
weak
.upgrade
()
{
return
rc
;
}
let
rc
=
Arc
::
new
(());
self
.unique_blocks
.insert
(
*
block
,
Arc
::
downgrade
(
&
rc
));
rc
}
pub
(
super
)
fn
try_remove_block
(
&
mut
self
,
block
:
&
SequenceHash
)
{
if
let
Some
(
weak
)
=
self
.unique_blocks
.get
(
block
)
&&
weak
.strong_count
()
==
0
{
self
.unique_blocks
.remove
(
block
);
self
.fractional_blocks
.remove
(
block
);
}
}
pub
(
super
)
fn
active_blocks
(
&
self
)
->
usize
{
let
mut
count
=
self
.unique_blocks
.len
()
as
f64
;
for
(
hash
,
frac
)
in
&
self
.fractional_blocks
{
if
self
.unique_blocks
.contains_key
(
hash
)
{
count
=
count
-
1.0
+
frac
;
}
}
count
.round
()
as
usize
}
}
lib/kv-router/src/sequences/mod.rs
View file @
95a750f4
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: Apache-2.0
mod
block_tracker
;
pub
mod
multi_worker
;
pub
mod
multi_worker
;
mod
prefill_tracker
;
pub
mod
single
;
pub
mod
single
;
pub
use
multi_worker
::
*
;
pub
use
multi_worker
::
*
;
...
...
lib/kv-router/src/sequences/multi_worker.rs
View file @
95a750f4
...
@@ -20,7 +20,8 @@ use tokio_util::sync::CancellationToken;
...
@@ -20,7 +20,8 @@ use tokio_util::sync::CancellationToken;
use
super
::
single
::{
ActiveSequences
,
RequestId
};
use
super
::
single
::{
ActiveSequences
,
RequestId
};
use
crate
::
protocols
::{
use
crate
::
protocols
::{
ActiveLoad
,
ActiveSequenceEvent
,
ActiveSequenceEventData
,
OverlapScores
,
WorkerWithDpRank
,
ActiveLoad
,
ActiveSequenceEvent
,
ActiveSequenceEventData
,
OverlapScores
,
PrefillLoadHint
,
WorkerWithDpRank
,
};
};
// How often we force expire stale requests across all workers. See the comment
// How often we force expire stale requests across all workers. See the comment
...
@@ -93,6 +94,7 @@ pub struct SequenceRequest {
...
@@ -93,6 +94,7 @@ pub struct SequenceRequest {
pub
overlap
:
u32
,
pub
overlap
:
u32
,
pub
track_prefill_tokens
:
bool
,
pub
track_prefill_tokens
:
bool
,
pub
expected_output_tokens
:
Option
<
u32
>
,
pub
expected_output_tokens
:
Option
<
u32
>
,
pub
prefill_load_hint
:
Option
<
PrefillLoadHint
>
,
pub
worker
:
WorkerWithDpRank
,
pub
worker
:
WorkerWithDpRank
,
pub
lora_name
:
Option
<
String
>
,
pub
lora_name
:
Option
<
String
>
,
}
}
...
@@ -177,6 +179,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -177,6 +179,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
return
;
return
;
}
}
// TODO: Publish explicit prompt-load decay timestamps with these events so peer routers
// can mirror the same oldest-prefill anchor instead of approximating from receive time.
let
publisher
=
Arc
::
clone
(
&
self
.publisher
);
let
publisher
=
Arc
::
clone
(
&
self
.publisher
);
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
if
let
Err
(
e
)
=
publisher
.publish_event
(
&
event
)
.await
{
if
let
Err
(
e
)
=
publisher
.publish_event
(
&
event
)
.await
{
...
@@ -228,6 +232,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -228,6 +232,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
continue
;
continue
;
}
}
// TODO: ActiveSequenceEvent does not carry prompt-load decay timestamps yet.
// Peer routers still approximate decay anchoring with local receive time.
let
decay_now
=
Instant
::
now
();
match
&
event
.data
{
match
&
event
.data
{
ActiveSequenceEventData
::
AddRequest
{
ActiveSequenceEventData
::
AddRequest
{
token_sequence
,
token_sequence
,
...
@@ -235,6 +242,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -235,6 +242,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap
,
overlap
,
track_prefill_tokens
,
track_prefill_tokens
,
expected_output_tokens
,
expected_output_tokens
,
prefill_load_hint
,
}
=>
{
}
=>
{
self
.request_to_worker
self
.request_to_worker
.insert
(
event
.request_id
.clone
(),
event
.worker
);
.insert
(
event
.request_id
.clone
(),
event
.worker
);
...
@@ -253,6 +261,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -253,6 +261,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
*
overlap
,
*
overlap
,
*
expected_output_tokens
,
*
expected_output_tokens
,
*
track_prefill_tokens
,
*
track_prefill_tokens
,
*
prefill_load_hint
,
decay_now
,
);
);
}
else
{
}
else
{
tracing
::
warn!
(
tracing
::
warn!
(
...
@@ -267,7 +277,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -267,7 +277,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
{
{
let
table
=
self
.workers
.read
();
let
table
=
self
.workers
.read
();
if
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
{
if
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
{
table
.slots
[
idx
]
.1
.write
()
.free
(
&
event
.request_id
);
table
.slots
[
idx
]
.1
.write
()
.free
(
&
event
.request_id
,
decay_now
);
}
}
}
}
self
.request_to_lora
.remove
(
&
event
.request_id
);
self
.request_to_lora
.remove
(
&
event
.request_id
);
...
@@ -281,7 +291,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -281,7 +291,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
table
.slots
[
idx
]
table
.slots
[
idx
]
.1
.1
.write
()
.write
()
.mark_prefill_completed
(
&
event
.request_id
);
.mark_prefill_completed
(
&
event
.request_id
,
decay_now
);
}
}
}
}
}
}
...
@@ -381,7 +391,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -381,7 +391,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}
}
}
}
fn
add_request_local
(
&
self
,
req
:
SequenceRequest
)
->
Result
<
(),
SequenceError
>
{
fn
add_request_local
(
&
self
,
req
:
SequenceRequest
,
decay_now
:
Instant
,
)
->
Result
<
(),
SequenceError
>
{
let
SequenceRequest
{
let
SequenceRequest
{
request_id
,
request_id
,
token_sequence
,
token_sequence
,
...
@@ -389,6 +403,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -389,6 +403,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap
,
overlap
,
track_prefill_tokens
,
track_prefill_tokens
,
expected_output_tokens
,
expected_output_tokens
,
prefill_load_hint
,
worker
,
worker
,
lora_name
,
lora_name
,
}
=
req
;
}
=
req
;
...
@@ -435,6 +450,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -435,6 +450,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap
,
overlap
,
expected_output_tokens
,
expected_output_tokens
,
track_prefill_tokens
,
track_prefill_tokens
,
prefill_load_hint
,
decay_now
,
)
)
};
};
...
@@ -443,12 +460,16 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -443,12 +460,16 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self
.request_to_lora
.remove
(
expired_id
);
self
.request_to_lora
.remove
(
expired_id
);
}
}
self
.publish_active_load_for_worker
(
worker
);
self
.publish_active_load_for_worker
(
worker
,
decay_now
);
Ok
(())
Ok
(())
}
}
pub
fn
add_request
(
&
self
,
req
:
SequenceRequest
)
->
Result
<
(),
SequenceError
>
{
pub
fn
add_request
(
&
self
,
req
:
SequenceRequest
,
decay_now
:
Instant
,
)
->
Result
<
(),
SequenceError
>
{
self
.spawn_publish_event
(
ActiveSequenceEvent
{
self
.spawn_publish_event
(
ActiveSequenceEvent
{
request_id
:
req
.request_id
.clone
(),
request_id
:
req
.request_id
.clone
(),
worker
:
req
.worker
,
worker
:
req
.worker
,
...
@@ -458,11 +479,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -458,11 +479,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap
:
req
.overlap
,
overlap
:
req
.overlap
,
track_prefill_tokens
:
req
.track_prefill_tokens
,
track_prefill_tokens
:
req
.track_prefill_tokens
,
expected_output_tokens
:
req
.expected_output_tokens
,
expected_output_tokens
:
req
.expected_output_tokens
,
prefill_load_hint
:
req
.prefill_load_hint
,
},
},
router_id
:
self
.router_id
,
router_id
:
self
.router_id
,
lora_name
:
req
.lora_name
.clone
(),
lora_name
:
req
.lora_name
.clone
(),
});
});
self
.add_request_local
(
req
)
self
.add_request_local
(
req
,
decay_now
)
}
}
/// Send a mutation to the worker assigned to a request, optionally publishing
/// Send a mutation to the worker assigned to a request, optionally publishing
...
@@ -470,7 +492,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -470,7 +492,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
fn
mutate_request_worker_local
(
fn
mutate_request_worker_local
(
&
self
,
&
self
,
request_id
:
&
RequestId
,
request_id
:
&
RequestId
,
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
),
decay_now
:
Instant
,
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
,
Instant
),
remove_mapping
:
bool
,
remove_mapping
:
bool
,
)
->
Result
<
(),
SequenceError
>
{
)
->
Result
<
(),
SequenceError
>
{
let
worker
=
self
let
worker
=
self
...
@@ -488,7 +511,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -488,7 +511,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
.get
(
&
worker
)
.get
(
&
worker
)
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
;
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
;
let
mut
seq
=
table
.slots
[
idx
]
.1
.write
();
let
mut
seq
=
table
.slots
[
idx
]
.1
.write
();
mutate_fn
(
&
mut
seq
,
request_id
);
mutate_fn
(
&
mut
seq
,
request_id
,
decay_now
);
}
}
if
remove_mapping
{
if
remove_mapping
{
...
@@ -496,7 +519,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -496,7 +519,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self
.request_to_lora
.remove
(
request_id
);
self
.request_to_lora
.remove
(
request_id
);
}
}
self
.publish_active_load_for_worker
(
worker
);
self
.publish_active_load_for_worker
(
worker
,
decay_now
);
Ok
(())
Ok
(())
}
}
...
@@ -504,8 +527,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -504,8 +527,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
fn
mutate_request_worker
(
fn
mutate_request_worker
(
&
self
,
&
self
,
request_id
:
&
RequestId
,
request_id
:
&
RequestId
,
decay_now
:
Instant
,
event_data
:
ActiveSequenceEventData
,
event_data
:
ActiveSequenceEventData
,
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
),
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
,
Instant
),
remove_mapping
:
bool
,
remove_mapping
:
bool
,
)
->
Result
<
(),
SequenceError
>
{
)
->
Result
<
(),
SequenceError
>
{
let
worker
=
self
let
worker
=
self
...
@@ -528,7 +552,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -528,7 +552,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
lora_name
,
lora_name
,
});
});
self
.mutate_request_worker_local
(
request_id
,
mutate_fn
,
remove_mapping
)
self
.mutate_request_worker_local
(
request_id
,
decay_now
,
mutate_fn
,
remove_mapping
)
}
}
/// Free all blocks associated with a request.
/// Free all blocks associated with a request.
...
@@ -539,7 +563,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -539,7 +563,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// This also performs the underlying prefill-complete cleanup via
/// This also performs the underlying prefill-complete cleanup via
/// [`ActiveSequences::free`], so callers do not need to call
/// [`ActiveSequences::free`], so callers do not need to call
/// [`Self::mark_prefill_completed`] before freeing a completed request.
/// [`Self::mark_prefill_completed`] before freeing a completed request.
pub
fn
free
(
&
self
,
request_id
:
&
RequestId
)
->
Result
<
(),
SequenceError
>
{
pub
fn
free
(
&
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
)
->
Result
<
(),
SequenceError
>
{
if
!
self
.request_to_worker
.contains_key
(
request_id
)
{
if
!
self
.request_to_worker
.contains_key
(
request_id
)
{
tracing
::
debug!
(
"Request {request_id} not found, already freed (idempotent)"
);
tracing
::
debug!
(
"Request {request_id} not found, already freed (idempotent)"
);
return
Ok
(());
return
Ok
(());
...
@@ -547,9 +571,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -547,9 +571,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self
.mutate_request_worker
(
self
.mutate_request_worker
(
request_id
,
request_id
,
decay_now
,
ActiveSequenceEventData
::
Free
,
ActiveSequenceEventData
::
Free
,
|
seqs
,
rid
|
{
|
seqs
,
rid
,
decay_now
|
{
seqs
.free
(
rid
);
seqs
.free
(
rid
,
decay_now
);
},
},
true
,
true
,
)
)
...
@@ -559,12 +584,17 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -559,12 +584,17 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
///
///
/// Note: Calling this multiple times for the same request is allowed and will be a no-op
/// Note: Calling this multiple times for the same request is allowed and will be a no-op
/// after the first call (idempotent).
/// after the first call (idempotent).
pub
fn
mark_prefill_completed
(
&
self
,
request_id
:
&
RequestId
)
->
Result
<
(),
SequenceError
>
{
pub
fn
mark_prefill_completed
(
&
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
,
)
->
Result
<
(),
SequenceError
>
{
self
.mutate_request_worker
(
self
.mutate_request_worker
(
request_id
,
request_id
,
decay_now
,
ActiveSequenceEventData
::
MarkPrefillCompleted
,
ActiveSequenceEventData
::
MarkPrefillCompleted
,
|
seqs
,
rid
|
{
|
seqs
,
rid
,
decay_now
|
{
seqs
.mark_prefill_completed
(
rid
);
seqs
.mark_prefill_completed
(
rid
,
decay_now
);
},
},
false
,
false
,
)
)
...
@@ -605,13 +635,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -605,13 +635,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
});
});
}
}
self
.publish_active_load_for_worker
(
worker
);
self
.publish_active_load_for_worker
(
worker
,
Instant
::
now
()
);
Ok
(())
Ok
(())
}
}
/// Read active blocks/tokens from a worker and publish ActiveLoad metrics.
/// Read active blocks/tokens from a worker and publish ActiveLoad metrics.
fn
publish_active_load_for_worker
(
&
self
,
worker
:
WorkerWithDpRank
)
{
fn
publish_active_load_for_worker
(
&
self
,
worker
:
WorkerWithDpRank
,
decay_now
:
Instant
)
{
let
(
active_blocks
,
active_tokens
)
=
{
let
(
active_blocks
,
active_tokens
)
=
{
let
table
=
self
.workers
.read
();
let
table
=
self
.workers
.read
();
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
else
{
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
else
{
...
@@ -619,7 +649,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -619,7 +649,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
return
;
return
;
};
};
let
seq
=
table
.slots
[
idx
]
.1
.read
();
let
seq
=
table
.slots
[
idx
]
.1
.read
();
(
seq
.active_blocks
(),
seq
.active_tokens
())
(
seq
.active_blocks
(),
seq
.active_tokens
(
decay_now
))
};
};
self
.publisher
self
.publisher
...
@@ -674,11 +704,18 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -674,11 +704,18 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
token_sequence
:
Option
<&
[
SequenceHash
]
>
,
token_sequence
:
Option
<&
[
SequenceHash
]
>
,
isl
:
usize
,
isl
:
usize
,
overlaps
:
OverlapScores
,
overlaps
:
OverlapScores
,
decay_now
:
Instant
,
)
->
(
)
->
(
HashMap
<
WorkerWithDpRank
,
usize
>
,
HashMap
<
WorkerWithDpRank
,
usize
>
,
HashMap
<
WorkerWithDpRank
,
usize
>
,
HashMap
<
WorkerWithDpRank
,
usize
>
,
)
{
)
{
self
.potential_blocks_and_tokens_with_prefill_tracking
(
token_sequence
,
isl
,
overlaps
,
true
)
self
.potential_blocks_and_tokens_with_prefill_tracking
(
token_sequence
,
isl
,
overlaps
,
true
,
decay_now
,
)
}
}
pub
fn
potential_blocks_and_tokens_with_prefill_tracking
(
pub
fn
potential_blocks_and_tokens_with_prefill_tracking
(
...
@@ -687,6 +724,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -687,6 +724,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
isl
:
usize
,
isl
:
usize
,
overlaps
:
OverlapScores
,
overlaps
:
OverlapScores
,
track_prefill_tokens
:
bool
,
track_prefill_tokens
:
bool
,
decay_now
:
Instant
,
)
->
(
)
->
(
HashMap
<
WorkerWithDpRank
,
usize
>
,
HashMap
<
WorkerWithDpRank
,
usize
>
,
HashMap
<
WorkerWithDpRank
,
usize
>
,
HashMap
<
WorkerWithDpRank
,
usize
>
,
...
@@ -712,6 +750,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -712,6 +750,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
isl
,
isl
,
overlap
,
overlap
,
track_prefill_tokens
,
track_prefill_tokens
,
decay_now
,
);
);
potential_blocks
.insert
(
*
worker
,
blocks
);
potential_blocks
.insert
(
*
worker
,
blocks
);
potential_tokens
.insert
(
*
worker
,
tokens
);
potential_tokens
.insert
(
*
worker
,
tokens
);
...
@@ -741,11 +780,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -741,11 +780,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}
}
/// Query all workers for their current number of active tokens.
/// Query all workers for their current number of active tokens.
pub
fn
active_tokens
(
&
self
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
pub
fn
active_tokens
(
&
self
,
decay_now
:
Instant
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
let
table
=
self
.workers
.read
();
let
table
=
self
.workers
.read
();
let
mut
results
=
HashMap
::
with_capacity
(
table
.slots
.len
());
let
mut
results
=
HashMap
::
with_capacity
(
table
.slots
.len
());
for
(
worker
,
lock
)
in
&
table
.slots
{
for
(
worker
,
lock
)
in
&
table
.slots
{
results
.insert
(
*
worker
,
lock
.read
()
.active_tokens
());
results
.insert
(
*
worker
,
lock
.read
()
.active_tokens
(
decay_now
));
}
}
results
results
}
}
...
@@ -753,11 +792,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -753,11 +792,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// Return true if any worker satisfies the provided predicate on active token count.
/// Return true if any worker satisfies the provided predicate on active token count.
pub
fn
any_worker_matches_active_tokens
(
pub
fn
any_worker_matches_active_tokens
(
&
self
,
&
self
,
decay_now
:
Instant
,
mut
predicate
:
impl
FnMut
(
WorkerWithDpRank
,
usize
)
->
bool
,
mut
predicate
:
impl
FnMut
(
WorkerWithDpRank
,
usize
)
->
bool
,
)
->
bool
{
)
->
bool
{
let
table
=
self
.workers
.read
();
let
table
=
self
.workers
.read
();
for
(
worker
,
lock
)
in
&
table
.slots
{
for
(
worker
,
lock
)
in
&
table
.slots
{
if
predicate
(
*
worker
,
lock
.read
()
.active_tokens
())
{
if
predicate
(
*
worker
,
lock
.read
()
.active_tokens
(
decay_now
))
{
return
true
;
return
true
;
}
}
}
}
...
@@ -792,7 +832,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -792,7 +832,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self
.request_to_lora
.remove
(
expired_id
);
self
.request_to_lora
.remove
(
expired_id
);
removed_request_count
+=
1
;
removed_request_count
+=
1
;
}
}
self
.publish_active_load_for_worker
(
*
worker
);
self
.publish_active_load_for_worker
(
*
worker
,
now
);
}
}
}
}
let
duration
=
now
.elapsed
();
let
duration
=
now
.elapsed
();
...
@@ -835,8 +875,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -835,8 +875,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
HashMap
;
use
std
::
time
::
Duration
;
use
super
::
*
;
use
super
::
*
;
use
crate
::
protocols
::{
OverlapScores
,
PrefillLoadHint
};
use
crate
::
test_utils
::
NoopSequencePublisher
;
use
crate
::
test_utils
::
NoopSequencePublisher
;
fn
make_sequences
()
->
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>
{
fn
make_sequences
()
->
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>
{
...
@@ -854,20 +896,74 @@ mod tests {
...
@@ -854,20 +896,74 @@ mod tests {
async
fn
add_request_can_skip_prefill_token_tracking
()
{
async
fn
add_request_can_skip_prefill_token_tracking
()
{
let
sequences
=
make_sequences
();
let
sequences
=
make_sequences
();
let
worker
=
WorkerWithDpRank
::
new
(
1
,
0
);
let
worker
=
WorkerWithDpRank
::
new
(
1
,
0
);
let
decay_now
=
Instant
::
now
();
sequences
sequences
.add_request
(
SequenceRequest
{
.add_request
(
SequenceRequest
{
request_id
:
"req-1"
.to_string
(),
request_id
:
"req-1"
.to_string
(),
token_sequence
:
Some
(
vec!
[
1
,
2
,
3
]),
token_sequence
:
Some
(
vec!
[
1
,
2
,
3
]),
isl
:
12
,
isl
:
12
,
overlap
:
0
,
overlap
:
0
,
track_prefill_tokens
:
false
,
track_prefill_tokens
:
false
,
expected_output_tokens
:
None
,
expected_output_tokens
:
None
,
prefill_load_hint
:
None
,
worker
,
worker
,
lora_name
:
None
,
lora_name
:
None
,
})
},
decay_now
,
)
.unwrap
();
assert_eq!
(
sequences
.active_tokens
(
decay_now
)
.get
(
&
worker
)
.copied
(),
Some
(
0
)
);
}
#[test]
fn
explicit_decay_time_drives_multi_worker_load_queries_consistently
()
{
let
sequences
=
make_sequences
();
let
worker
=
WorkerWithDpRank
::
new
(
1
,
0
);
let
start
=
Instant
::
now
();
sequences
.add_request
(
SequenceRequest
{
request_id
:
"req-1"
.to_string
(),
token_sequence
:
Some
(
vec!
[
1
,
2
,
3
]),
isl
:
100
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
prefill_load_hint
:
Some
(
PrefillLoadHint
{
initial_effective_prefill_tokens
:
100
,
expected_prefill_duration
:
Some
(
Duration
::
from_secs
(
10
)),
}),
worker
,
lora_name
:
None
,
},
start
,
)
.unwrap
();
.unwrap
();
assert_eq!
(
sequences
.active_tokens
()
.get
(
&
worker
)
.copied
(),
Some
(
0
));
let
decay_now
=
start
+
Duration
::
from_secs
(
5
);
let
active_tokens
=
sequences
.active_tokens
(
decay_now
);
assert_eq!
(
active_tokens
.get
(
&
worker
)
.copied
(),
Some
(
50
));
let
(
_
,
potential_tokens
)
=
sequences
.potential_blocks_and_tokens_with_prefill_tracking
(
None
,
0
,
OverlapScores
::
default
(),
false
,
decay_now
,
);
assert_eq!
(
potential_tokens
.get
(
&
worker
)
.copied
(),
Some
(
50
));
assert
!
(
sequences
.any_worker_matches_active_tokens
(
decay_now
,
|
candidate
,
tokens
|
{
candidate
==
worker
&&
tokens
==
50
})
);
}
}
}
}
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment