Unverified Commit 95a750f4 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(replay): refactor offline components into cleaner lanes (#7866)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 210bbf5d
......@@ -8,15 +8,17 @@
//! predictions without knowing about PyO3.
use std::sync::Arc;
use std::time::Duration;
use pyo3::prelude::*;
use dynamo_kv_router::PrefillLoadEstimator;
use dynamo_mocker::common::perf_model::AicCallback;
/// Wraps a Python AIC InferenceSession for direct calls from Rust.
///
/// The Python object must expose:
/// - `predict_prefill(batch_size, isl, prefix, osl) -> float`
/// - `predict_prefill(batch_size, effective_isl, prefix) -> float`
/// - `predict_decode(batch_size, isl, osl) -> float`
pub(super) struct PyAicCallback {
pub(super) session: PyObject,
......@@ -26,15 +28,26 @@ pub(super) struct PyAicCallback {
unsafe impl Send for PyAicCallback {}
unsafe impl Sync for PyAicCallback {}
impl AicCallback for PyAicCallback {
fn predict_prefill(&self, batch_size: usize, isl: usize, prefix: usize, osl: usize) -> f64 {
impl PyAicCallback {
fn predict_prefill_ms(
&self,
batch_size: usize,
effective_isl: usize,
prefix: usize,
) -> PyResult<f64> {
Python::with_gil(|py| {
self.session
.call_method1(py, "predict_prefill", (batch_size, isl, prefix, osl))
.and_then(|r| r.extract::<f64>(py))
.unwrap_or_else(|e| panic!("AIC predict_prefill failed: {e}"))
.call_method1(py, "predict_prefill", (batch_size, effective_isl, prefix))
.and_then(|result| result.extract::<f64>(py))
})
}
}
impl AicCallback for PyAicCallback {
fn predict_prefill(&self, batch_size: usize, effective_isl: usize, prefix: usize) -> f64 {
self.predict_prefill_ms(batch_size, effective_isl, prefix)
.unwrap_or_else(|e| panic!("AIC predict_prefill failed: {e}"))
}
fn predict_decode(&self, batch_size: usize, isl: usize, osl: usize) -> f64 {
Python::with_gil(|py| {
......@@ -46,6 +59,18 @@ impl AicCallback for PyAicCallback {
}
}
impl PrefillLoadEstimator for PyAicCallback {
fn predict_prefill_duration(
&self,
batch_size: usize,
effective_isl: usize,
prefix: usize,
) -> anyhow::Result<Duration> {
let latency_ms = self.predict_prefill_ms(batch_size, effective_isl, prefix)?;
Ok(Duration::from_secs_f64(latency_ms / 1000.0))
}
}
/// Initialize an AIC callback by importing and calling the Python setup function.
///
/// Called once at mocker startup when `--aic-perf-model` is requested.
......@@ -61,7 +86,7 @@ pub(super) fn create_aic_callback(
moe_ep_size: Option<usize>,
attention_dp_size: Option<usize>,
) -> PyResult<Arc<dyn AicCallback>> {
let module = py.import("dynamo.mocker.aic_session")?;
let module = py.import("dynamo._internal.aic")?;
let session = module.call_method1(
"create_session",
(
......@@ -79,3 +104,21 @@ pub(super) fn create_aic_callback(
session: session.into(),
}))
}
pub(super) fn create_aic_prefill_load_estimator(
py: Python<'_>,
backend_name: &str,
system: &str,
model_path: &str,
tp_size: usize,
backend_version: Option<&str>,
) -> PyResult<Arc<dyn PrefillLoadEstimator>> {
let module = py.import("dynamo._internal.aic")?;
let session = module.call_method1(
"create_session",
(backend_name, system, model_path, tp_size, backend_version),
)?;
Ok(Arc::new(PyAicCallback {
session: session.into(),
}))
}
......@@ -10,7 +10,9 @@ use std::sync::Arc;
use pyo3::{exceptions::PyException, exceptions::PyValueError, prelude::*};
use pyo3_async_runtimes::TaskLocals;
use dynamo_kv_router::config::KvRouterConfig as RsKvRouterConfig;
use dynamo_kv_router::config::{
KvRouterConfig as RsKvRouterConfig, RouterPrefillLoadModel as RsRouterPrefillLoadModel,
};
use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig;
use dynamo_llm::entrypoint::ChatEngineFactoryCallback;
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
......@@ -23,7 +25,7 @@ use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard;
use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
use dynamo_mocker::common::perf_model::PerfModel;
use super::aic_callback::create_aic_callback;
use super::aic_callback::{create_aic_callback, create_aic_prefill_load_estimator};
use super::replay::MockEngineArgs as PyMockEngineArgs;
use dynamo_mocker::common::protocols::MockEngineArgs as RsMockEngineArgs;
use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId;
......@@ -55,10 +57,76 @@ impl KvRouterConfig {
}
}
#[pyclass]
#[derive(Clone, Debug)]
pub struct AicPerfConfig {
aic_backend: String,
aic_system: String,
aic_backend_version: Option<String>,
aic_tp_size: usize,
aic_model_path: String,
}
impl AicPerfConfig {
pub(crate) fn backend_name(&self) -> &str {
&self.aic_backend
}
pub(crate) fn system(&self) -> &str {
&self.aic_system
}
pub(crate) fn backend_version(&self) -> Option<&str> {
self.aic_backend_version.as_deref()
}
pub(crate) fn tp_size(&self) -> usize {
self.aic_tp_size
}
pub(crate) fn model_path(&self) -> &str {
&self.aic_model_path
}
}
#[pymethods]
impl AicPerfConfig {
#[new]
#[pyo3(signature = (aic_backend, aic_system, aic_model_path, aic_tp_size=1, aic_backend_version=None))]
fn new(
aic_backend: String,
aic_system: String,
aic_model_path: String,
aic_tp_size: usize,
aic_backend_version: Option<String>,
) -> PyResult<Self> {
if aic_backend.is_empty() {
return Err(PyValueError::new_err("aic_backend must be non-empty"));
}
if aic_system.is_empty() {
return Err(PyValueError::new_err("aic_system must be non-empty"));
}
if aic_model_path.is_empty() {
return Err(PyValueError::new_err("aic_model_path must be non-empty"));
}
if aic_tp_size == 0 {
return Err(PyValueError::new_err("aic_tp_size must be >= 1"));
}
Ok(Self {
aic_backend,
aic_system,
aic_backend_version,
aic_tp_size,
aic_model_path,
})
}
}
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_queue_policy="fcfs", remote_indexer_component=None))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_prefill_load_model="none", router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_queue_policy="fcfs", remote_indexer_component=None))]
#[allow(clippy::too_many_arguments)]
fn new(
overlap_score_weight: f64,
......@@ -70,6 +138,7 @@ impl KvRouterConfig {
router_track_output_blocks: bool,
router_assume_kv_reuse: bool,
router_track_prefill_tokens: bool,
router_prefill_load_model: &str,
router_snapshot_threshold: Option<u32>,
router_reset_states: bool,
router_ttl_secs: f64,
......@@ -91,6 +160,11 @@ impl KvRouterConfig {
router_track_output_blocks,
router_assume_kv_reuse,
router_track_prefill_tokens,
router_prefill_load_model: router_prefill_load_model
.parse::<RsRouterPrefillLoadModel>()
.unwrap_or_else(|_| {
panic!("invalid router_prefill_load_model: {router_prefill_load_model:?}")
}),
router_snapshot_threshold,
router_reset_states,
router_ttl_secs,
......@@ -249,13 +323,14 @@ pub(crate) struct EntrypointArgs {
is_prefill: bool,
migration_limit: u32,
chat_engine_factory: Option<PyEngineFactory>,
aic_perf_config: Option<AicPerfConfig>,
}
#[pymethods]
impl EntrypointArgs {
#[allow(clippy::too_many_arguments)]
#[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None, aic_perf_config=None))]
pub fn new(
py: Python<'_>,
engine_type: EngineType,
......@@ -279,6 +354,7 @@ impl EntrypointArgs {
is_prefill: bool,
migration_limit: u32,
chat_engine_factory: Option<PyObject>,
aic_perf_config: Option<AicPerfConfig>,
) -> PyResult<Self> {
let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
if (tls_cert_path.is_some() && tls_key_path.is_none())
......@@ -327,6 +403,7 @@ impl EntrypointArgs {
is_prefill,
migration_limit,
chat_engine_factory,
aic_perf_config,
})
}
}
......@@ -467,9 +544,26 @@ async fn select_engine(
EngineType::Dynamic => {
// Convert Python chat engine factory to Rust callback
let chat_engine_factory = args.chat_engine_factory.map(py_engine_factory_to_callback);
let prefill_load_estimator = args
.aic_perf_config
.as_ref()
.map(|config| {
Python::with_gil(|py| {
create_aic_prefill_load_estimator(
py,
config.backend_name(),
config.system(),
config.model_path(),
config.tp_size(),
config.backend_version(),
)
})
})
.transpose()?;
RsEngineConfig::Dynamic {
model: Box::new(local_model),
chat_engine_factory,
prefill_load_estimator,
}
}
EngineType::Mocker => {
......
......@@ -30,6 +30,9 @@ use llm_rs::protocols::common::timing::RequestTracker;
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use serde_json::json;
use super::aic_callback::create_aic_prefill_load_estimator;
use super::entrypoint::AicPerfConfig;
fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<BlockExtraInfo>>> {
depythonize(obj).map_err(to_pyerr)
}
......@@ -703,6 +706,7 @@ async fn create_kv_router_from_endpoint(
endpoint: &Endpoint,
block_size: usize,
kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn dynamo_kv_router::PrefillLoadEstimator>>,
) -> Result<Arc<llm_rs::kv_router::KvRouter>, PyErr> {
// Create ModelManager and use it to create KvRouter (ensures registration)
let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
......@@ -766,6 +770,7 @@ async fn create_kv_router_from_endpoint(
&endpoint.inner,
block_size as u32,
kv_router_config,
prefill_load_estimator,
worker_type,
model_name,
enable_eagle,
......@@ -888,12 +893,29 @@ impl KvRouter {
/// Note: Worker type for Prometheus metrics is inferred from the endpoint name/component
/// (contains "prefill") or by `router_track_active_blocks` being disabled.
#[new]
#[pyo3(signature = (endpoint, block_size, kv_router_config))]
#[pyo3(signature = (endpoint, block_size, kv_router_config, aic_perf_config=None))]
fn new(
endpoint: &Endpoint,
block_size: usize,
kv_router_config: &super::entrypoint::KvRouterConfig,
aic_perf_config: Option<&AicPerfConfig>,
) -> PyResult<Self> {
let prefill_load_estimator = aic_perf_config
.map(|config| {
Python::with_gil(|py| {
create_aic_prefill_load_estimator(
py,
config.backend_name(),
config.system(),
config.model_path(),
config.tp_size(),
config.backend_version(),
)
})
})
.transpose()
.map_err(to_pyerr)?;
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async move {
let client = endpoint.inner.client().await.map_err(to_pyerr)?;
......@@ -916,6 +938,7 @@ impl KvRouter {
endpoint,
block_size,
Some(kv_router_config.inner()),
prefill_load_estimator,
)
.await?;
......
......@@ -19,8 +19,8 @@ use pythonize::pythonize;
use serde_json::json;
use uuid::Uuid;
use super::aic_callback::create_aic_callback;
use super::entrypoint::{KvRouterConfig, to_pyerr};
use super::aic_callback::{create_aic_callback, create_aic_prefill_load_estimator};
use super::entrypoint::{AicPerfConfig, KvRouterConfig, to_pyerr};
fn parse_mocker_engine_type(engine_type: &str) -> PyResult<RsMockerEngineType> {
match engine_type {
......@@ -526,7 +526,7 @@ impl MockEngineArgs {
}
#[pyfunction]
#[pyo3(signature = (trace_file, extra_engine_args=None, prefill_engine_args=None, decode_engine_args=None, router_config=None, num_workers=1, num_prefill_workers=1, num_decode_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0))]
#[pyo3(signature = (trace_file, extra_engine_args=None, prefill_engine_args=None, decode_engine_args=None, router_config=None, aic_perf_config=None, num_workers=1, num_prefill_workers=1, num_decode_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0, trace_block_size=512))]
#[allow(clippy::too_many_arguments)]
pub fn run_mocker_trace_replay(
py: Python<'_>,
......@@ -535,6 +535,7 @@ pub fn run_mocker_trace_replay(
prefill_engine_args: Option<MockEngineArgs>,
decode_engine_args: Option<MockEngineArgs>,
router_config: Option<KvRouterConfig>,
aic_perf_config: Option<&AicPerfConfig>,
num_workers: usize,
num_prefill_workers: usize,
num_decode_workers: usize,
......@@ -542,6 +543,7 @@ pub fn run_mocker_trace_replay(
replay_mode: &str,
router_mode: &str,
arrival_speedup_ratio: f64,
trace_block_size: usize,
) -> PyResult<PyObject> {
let args_selection = load_replay_args_selection(
py,
......@@ -552,9 +554,15 @@ pub fn run_mocker_trace_replay(
num_prefill_workers,
num_decode_workers,
)?;
let router_mode = parse_replay_router_mode(router_mode)?;
let prefill_load_estimator = load_replay_prefill_load_estimator(
py,
router_mode,
router_config.as_ref(),
aic_perf_config,
)?;
let router_config = load_replay_router_config(router_config);
let replay_mode = replay_mode.to_owned();
let router_mode = parse_replay_router_mode(router_mode)?;
let report = py.allow_threads(move || {
let replay_concurrency = parse_replay_concurrency(replay_concurrency)?;
......@@ -565,7 +573,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker::replay::simulate_concurrency_file_with_router_mode(
*args,
router_config.clone(),
prefill_load_estimator.clone(),
&trace_file,
trace_block_size,
max_in_flight,
num_workers,
router_mode,
......@@ -575,7 +585,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker::replay::simulate_trace_file_with_router_mode(
*args,
router_config.clone(),
prefill_load_estimator.clone(),
&trace_file,
trace_block_size,
num_workers,
arrival_speedup_ratio,
router_mode,
......@@ -585,7 +597,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker::replay::simulate_concurrency_live_file_with_router_mode(
*args,
router_config.clone(),
prefill_load_estimator.clone(),
&trace_file,
trace_block_size,
max_in_flight,
num_workers,
router_mode,
......@@ -595,7 +609,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker::replay::simulate_trace_live_file_with_router_mode(
*args,
router_config.clone(),
prefill_load_estimator.clone(),
&trace_file,
trace_block_size,
num_workers,
arrival_speedup_ratio,
router_mode,
......@@ -613,7 +629,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker::replay::simulate_concurrency_file_disagg_with_router_mode(
*config,
router_config.clone(),
prefill_load_estimator.clone(),
&trace_file,
trace_block_size,
max_in_flight,
router_mode,
)
......@@ -622,7 +640,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker::replay::simulate_trace_file_disagg_with_router_mode(
*config,
router_config.clone(),
prefill_load_estimator.clone(),
&trace_file,
trace_block_size,
arrival_speedup_ratio,
router_mode,
)
......@@ -642,7 +662,7 @@ pub fn run_mocker_trace_replay(
}
#[pyfunction]
#[pyo3(signature = (input_tokens, output_tokens, request_count, extra_engine_args=None, prefill_engine_args=None, decode_engine_args=None, router_config=None, num_workers=1, num_prefill_workers=1, num_decode_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0, arrival_interval_ms=1.0, turns_per_session=1, shared_prefix_ratio=0.0, num_prefix_groups=0, inter_turn_delay_ms=0.0))]
#[pyo3(signature = (input_tokens, output_tokens, request_count, extra_engine_args=None, prefill_engine_args=None, decode_engine_args=None, router_config=None, aic_perf_config=None, num_workers=1, num_prefill_workers=1, num_decode_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0, arrival_interval_ms=1.0, turns_per_session=1, shared_prefix_ratio=0.0, num_prefix_groups=0, inter_turn_delay_ms=0.0))]
#[allow(clippy::too_many_arguments)]
pub fn run_mocker_synthetic_trace_replay(
py: Python<'_>,
......@@ -653,6 +673,7 @@ pub fn run_mocker_synthetic_trace_replay(
prefill_engine_args: Option<MockEngineArgs>,
decode_engine_args: Option<MockEngineArgs>,
router_config: Option<KvRouterConfig>,
aic_perf_config: Option<&AicPerfConfig>,
num_workers: usize,
num_prefill_workers: usize,
num_decode_workers: usize,
......@@ -675,9 +696,15 @@ pub fn run_mocker_synthetic_trace_replay(
num_prefill_workers,
num_decode_workers,
)?;
let router_mode = parse_replay_router_mode(router_mode)?;
let prefill_load_estimator = load_replay_prefill_load_estimator(
py,
router_mode,
router_config.as_ref(),
aic_perf_config,
)?;
let router_config = load_replay_router_config(router_config);
let replay_mode = replay_mode.to_owned();
let router_mode = parse_replay_router_mode(router_mode)?;
let block_size = match &args_selection {
ReplayArgsSelection::Aggregated(args) => args.block_size.max(1),
ReplayArgsSelection::Disagg(config) => config.prefill_args.block_size.max(1),
......@@ -712,6 +739,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_concurrency_workload_with_router_mode(
*args,
router_config.clone(),
prefill_load_estimator.clone(),
trace,
max_in_flight,
num_workers,
......@@ -722,6 +750,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_trace_workload_with_router_mode(
*args,
router_config.clone(),
prefill_load_estimator.clone(),
trace,
num_workers,
router_mode,
......@@ -731,6 +760,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_concurrency_live_workload_with_router_mode(
*args,
router_config.clone(),
prefill_load_estimator.clone(),
trace,
max_in_flight,
num_workers,
......@@ -741,6 +771,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_trace_live_workload_with_router_mode(
*args,
router_config.clone(),
prefill_load_estimator.clone(),
trace,
num_workers,
router_mode,
......@@ -756,6 +787,7 @@ pub fn run_mocker_synthetic_trace_replay(
("offline", Some(max_in_flight)) => dynamo_mocker::replay::simulate_concurrency_workload_disagg_with_router_mode(
*config,
router_config.clone(),
prefill_load_estimator.clone(),
trace,
max_in_flight,
router_mode,
......@@ -763,6 +795,7 @@ pub fn run_mocker_synthetic_trace_replay(
("offline", None) => dynamo_mocker::replay::simulate_trace_workload_disagg_with_router_mode(
*config,
router_config.clone(),
prefill_load_estimator.clone(),
trace,
router_mode,
),
......@@ -793,6 +826,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_concurrency_requests_with_router_mode(
*args,
router_config.clone(),
prefill_load_estimator.clone(),
requests,
max_in_flight,
num_workers,
......@@ -802,6 +836,7 @@ pub fn run_mocker_synthetic_trace_replay(
("offline", None) => dynamo_mocker::replay::simulate_trace_requests_with_router_mode(
*args,
router_config.clone(),
prefill_load_estimator.clone(),
requests,
num_workers,
arrival_speedup_ratio,
......@@ -811,6 +846,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_concurrency_live_requests_with_router_mode(
*args,
router_config.clone(),
prefill_load_estimator.clone(),
requests,
max_in_flight,
num_workers,
......@@ -821,6 +857,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_trace_live_requests_with_router_mode(
*args,
router_config.clone(),
prefill_load_estimator.clone(),
requests,
num_workers,
arrival_speedup_ratio,
......@@ -838,6 +875,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_concurrency_requests_disagg_with_router_mode(
*config,
router_config.clone(),
prefill_load_estimator.clone(),
requests,
max_in_flight,
router_mode,
......@@ -847,6 +885,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_trace_requests_disagg_with_router_mode(
*config,
router_config.clone(),
prefill_load_estimator.clone(),
requests,
arrival_speedup_ratio,
router_mode,
......@@ -970,6 +1009,57 @@ fn load_replay_router_config(
router_config.map(|config| config.inner())
}
fn load_replay_prefill_load_estimator(
py: Python<'_>,
router_mode: dynamo_mocker::replay::ReplayRouterMode,
router_config: Option<&KvRouterConfig>,
aic_perf_config: Option<&AicPerfConfig>,
) -> PyResult<Option<dynamo_mocker::replay::ReplayPrefillLoadEstimator>> {
if router_mode != dynamo_mocker::replay::ReplayRouterMode::KvRouter {
if aic_perf_config.is_some() {
return Err(PyException::new_err(
"aic_perf_config requires router_mode='kv_router'",
));
}
return Ok(None);
}
let Some(router_config) = router_config else {
if aic_perf_config.is_some() {
return Err(PyException::new_err(
"aic_perf_config requires router_config with router_prefill_load_model='aic'",
));
}
return Ok(None);
};
let router_config = router_config.inner();
if !router_config.router_prefill_load_model.is_enabled() {
if aic_perf_config.is_some() {
return Err(PyException::new_err(
"aic_perf_config requires router_prefill_load_model='aic'",
));
}
return Ok(None);
}
let Some(aic_perf_config) = aic_perf_config else {
return Err(PyException::new_err(
"router_prefill_load_model='aic' requires aic_perf_config",
));
};
create_aic_prefill_load_estimator(
py,
aic_perf_config.backend_name(),
aic_perf_config.system(),
aic_perf_config.model_path(),
aic_perf_config.tp_size(),
aic_perf_config.backend_version(),
)
.map(Some)
}
fn parse_replay_router_mode(
router_mode: &str,
) -> PyResult<dynamo_mocker::replay::ReplayRouterMode> {
......
......@@ -1159,6 +1159,17 @@ class RouterConfig:
"""
...
class AicPerfConfig:
def __init__(
self,
aic_backend: str,
aic_system: str,
aic_model_path: str,
aic_tp_size: int = 1,
aic_backend_version: Optional[str] = None,
) -> None:
...
class KvRouterConfig:
"""Values for KV router"""
......@@ -1172,6 +1183,8 @@ class KvRouterConfig:
router_track_active_blocks: bool = True,
router_track_output_blocks: bool = False,
router_assume_kv_reuse: bool = True,
router_track_prefill_tokens: bool = True,
router_prefill_load_model: str = "none",
router_snapshot_threshold: Optional[int] = 1000000,
router_reset_states: bool = False,
router_ttl_secs: float = 120.0,
......@@ -1199,6 +1212,10 @@ class KvRouterConfig:
sequence length (agent_hints.osl in nvext).
router_assume_kv_reuse: Assume KV cache reuse when tracking active blocks (default: True).
When True, computes actual block hashes. When False, generates random hashes.
router_track_prefill_tokens: Include prompt-side prefill tokens in active load accounting (default: True).
router_prefill_load_model: Prompt-side prefill load model (default: "none").
"none" keeps static prompt load accounting.
"aic" decays the oldest active prefill request using AIC-predicted duration.
router_snapshot_threshold: Number of messages before snapshot (default: 1000000)
router_reset_states: Reset router state on startup (default: False)
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
......@@ -1516,6 +1533,7 @@ def run_mocker_trace_replay(
prefill_engine_args: Optional[MockEngineArgs] = None,
decode_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
aic_perf_config: Optional[AicPerfConfig] = None,
num_workers: int = 1,
num_prefill_workers: int = 1,
num_decode_workers: int = 1,
......@@ -1523,6 +1541,7 @@ def run_mocker_trace_replay(
replay_mode: Literal["offline", "online"] = "offline",
router_mode: Literal["round_robin", "kv_router"] = "round_robin",
arrival_speedup_ratio: float = 1.0,
trace_block_size: int = 512,
) -> Dict[str, Any]:
"""Replay a mocker trace file and return the simulation report for aggregated vLLM or SGLang configs."""
...
......@@ -1535,6 +1554,7 @@ def run_mocker_synthetic_trace_replay(
prefill_engine_args: Optional[MockEngineArgs] = None,
decode_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
aic_perf_config: Optional[AicPerfConfig] = None,
num_workers: int = 1,
num_prefill_workers: int = 1,
num_decode_workers: int = 1,
......@@ -1779,6 +1799,7 @@ class KvRouter:
endpoint: Endpoint,
block_size: int,
kv_router_config: KvRouterConfig,
aic_perf_config: Optional[AicPerfConfig] = None,
) -> None:
"""
Create a new KvRouter instance.
......@@ -1787,6 +1808,7 @@ class KvRouter:
endpoint: The endpoint to connect to for routing requests
block_size: The KV cache block size
kv_router_config: Configuration for the KV router
aic_perf_config: Optional AIC perf-model config for effective prefill load tracking
"""
...
......@@ -1998,6 +2020,7 @@ class EntrypointArgs:
is_prefill: bool = False,
migration_limit: int = 0,
chat_engine_factory: Optional[Callable] = None,
aic_perf_config: Optional[AicPerfConfig] = None,
) -> None:
"""
Create EntrypointArgs.
......@@ -2024,6 +2047,7 @@ class EntrypointArgs:
is_prefill: Whether this is a prefill worker
migration_limit: Maximum number of request migrations (0=disabled)
chat_engine_factory: Optional Python chat completions engine factory callback
aic_perf_config: Optional AIC perf-model configuration for default KV routing
"""
...
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared AIC session helpers used by internal Dynamo integrations."""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
DEFAULT_BACKEND_VERSIONS = {
"vllm": "0.12.0",
"sglang": "0.5.6.post2",
}
DEFAULT_STATIC_STRIDE = 32
def resolve_backend_version(backend_name: str, backend_version: str | None) -> str:
"""Return the pinned backend version used for AIC perf lookups."""
if backend_version is not None:
return backend_version
return DEFAULT_BACKEND_VERSIONS.get(backend_name, DEFAULT_BACKEND_VERSIONS["vllm"])
def _load_aiconfigurator():
try:
from aiconfigurator.sdk import config
from aiconfigurator.sdk.backends.factory import get_backend
from aiconfigurator.sdk.inference_session import InferenceSession
from aiconfigurator.sdk.models import get_model
from aiconfigurator.sdk.perf_database import (
get_database,
get_supported_databases,
)
except (
ImportError
) as exc: # pragma: no cover - exercised in integration environments
raise RuntimeError(
"aiconfigurator is required for AIC perf modeling but is not installed"
) from exc
return {
"config": config,
"get_backend": get_backend,
"InferenceSession": InferenceSession,
"get_model": get_model,
"get_database": get_database,
"get_supported_databases": get_supported_databases,
}
class AicSession:
"""Wrap an AIC InferenceSession with direct prefill/decode predictors."""
def __init__(
self,
backend_name: str,
system: str,
model_path: str,
tp_size: int,
backend_version: str | None = None,
moe_tp_size: int | None = None,
moe_ep_size: int | None = None,
attention_dp_size: int | None = None,
):
aic = _load_aiconfigurator()
version = resolve_backend_version(backend_name, backend_version)
database = aic["get_database"](
system=system, backend=backend_name, version=version
)
if database is None:
supported = (
aic["get_supported_databases"]().get(system, {}).get(backend_name, [])
)
supported_versions = ", ".join(supported) if supported else "<none>"
raise RuntimeError(
"AIC perf database not found for "
f"system={system!r}, backend={backend_name!r}, version={version!r}. "
f"Supported versions for this system/backend: {supported_versions}"
)
model_config = aic["config"].ModelConfig(
tp_size=tp_size,
moe_tp_size=moe_tp_size,
moe_ep_size=moe_ep_size,
attention_dp_size=attention_dp_size or 1,
)
model = aic["get_model"](
model_path=model_path,
model_config=model_config,
backend_name=backend_name,
)
backend = aic["get_backend"](backend_name)
self._session = aic["InferenceSession"](
model=model, database=database, backend=backend
)
self._database = database
self._model = model
self._model_name = getattr(model, "model_name", None) or model_path
logger.info(
"AIC session initialized: backend=%s, system=%s, model=%s, tp=%d",
backend_name,
system,
model_path,
tp_size,
)
def _predict_context_latency(
self, batch_size: int, effective_isl: int, prefix: int
) -> float:
if effective_isl <= 0:
raise ValueError(
f"effective_isl must be positive, got effective_isl={effective_isl}"
)
total_latency = 0.0
for op in self._model.context_ops:
op_name = getattr(op, "_name", "")
x = batch_size if "logits_gemm" in op_name else batch_size * effective_isl
result = op.query(
self._database,
x=x,
batch_size=batch_size,
beam_width=1,
s=effective_isl,
prefix=prefix,
model_name=self._model_name,
seq_imbalance_correction_scale=1.0,
)
total_latency += float(result)
return total_latency
def _predict_generation_latency(self, batch_size: int, isl: int, osl: int) -> float:
if osl <= 1:
return 0.0
effective_batch_size = batch_size * (self._model._nextn + 1)
total_latency = 0.0
for step in range(0, osl - 1, DEFAULT_STATIC_STRIDE):
step_latency = 0.0
for op in self._model.generation_ops:
result = op.query(
self._database,
x=effective_batch_size,
batch_size=effective_batch_size,
beam_width=1,
s=isl + step + 1,
model_name=self._model_name,
gen_seq_imbalance_correction_scale=1.0,
)
step_latency += float(result)
repeat_count = min(DEFAULT_STATIC_STRIDE, osl - 1 - step)
total_latency += step_latency * repeat_count
return total_latency
def predict_prefill(
self, batch_size: int, effective_isl: int, prefix: int
) -> float:
"""Predict prefill latency in ms from uncached tokens and cached prefix."""
return self._predict_context_latency(batch_size, effective_isl, prefix)
def predict_decode(self, batch_size: int, isl: int, osl: int) -> float:
"""Predict decode (generation) latency in ms."""
return self._predict_generation_latency(batch_size, isl, osl)
def create_session(
backend_name: str,
system: str,
model_path: str,
tp_size: int,
backend_version: str | None = None,
moe_tp_size: int | None = None,
moe_ep_size: int | None = None,
attention_dp_size: int | None = None,
) -> AicSession:
"""Factory function called from Rust via PyO3."""
return AicSession(
backend_name,
system,
model_path,
tp_size,
backend_version,
moe_tp_size,
moe_ep_size,
attention_dp_size,
)
......@@ -5,6 +5,7 @@
import logging
from dynamo._core import AicPerfConfig as AicPerfConfig
from dynamo._core import EngineType
from dynamo._core import EntrypointArgs as EntrypointArgs
from dynamo._core import FpmEventRelay as FpmEventRelay
......@@ -57,6 +58,7 @@ def run_mocker_trace_replay(
replay_concurrency=None,
router_mode="round_robin",
arrival_speedup_ratio=1.0,
trace_block_size=512,
):
return _run_mocker_trace_replay(
trace_file,
......@@ -67,4 +69,5 @@ def run_mocker_trace_replay(
replay_mode="offline",
router_mode=router_mode,
arrival_speedup_ratio=arrival_speedup_ratio,
trace_block_size=trace_block_size,
)
......@@ -14,6 +14,7 @@ def run_trace_replay(
prefill_engine_args=None,
decode_engine_args=None,
router_config=None,
aic_perf_config=None,
num_workers=1,
num_prefill_workers=1,
num_decode_workers=1,
......@@ -21,6 +22,7 @@ def run_trace_replay(
replay_mode="offline",
router_mode="round_robin",
arrival_speedup_ratio=1.0,
trace_block_size=512,
):
return _run_mocker_trace_replay(
trace_file,
......@@ -28,6 +30,7 @@ def run_trace_replay(
prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args,
router_config=router_config,
aic_perf_config=aic_perf_config,
num_workers=num_workers,
num_prefill_workers=num_prefill_workers,
num_decode_workers=num_decode_workers,
......@@ -35,6 +38,7 @@ def run_trace_replay(
replay_mode=replay_mode,
router_mode=router_mode,
arrival_speedup_ratio=arrival_speedup_ratio,
trace_block_size=trace_block_size,
)
......@@ -47,6 +51,7 @@ def run_synthetic_trace_replay(
prefill_engine_args=None,
decode_engine_args=None,
router_config=None,
aic_perf_config=None,
num_workers=1,
num_prefill_workers=1,
num_decode_workers=1,
......@@ -68,6 +73,7 @@ def run_synthetic_trace_replay(
prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args,
router_config=router_config,
aic_perf_config=aic_perf_config,
num_workers=num_workers,
num_prefill_workers=num_prefill_workers,
num_decode_workers=num_decode_workers,
......
......@@ -15,7 +15,7 @@ from typing import Protocol
os.environ.setdefault("DYNAMO_SKIP_PYTHON_LOG_INIT", "1")
from dynamo.llm import KvRouterConfig, MockEngineArgs
from dynamo.llm import AicPerfConfig, KvRouterConfig, MockEngineArgs
from dynamo.replay import run_synthetic_trace_replay, run_trace_replay
from dynamo.replay.reporting import format_report_table, write_report_json
......@@ -72,6 +72,35 @@ def _load_engine_args(raw_args: str | None):
return MockEngineArgs.from_json(json.dumps(raw))
def _load_aic_perf_config(args: argparse.Namespace):
values = {
"aic_backend": args.aic_backend,
"aic_system": args.aic_system,
"aic_model_path": args.aic_model_path,
"aic_backend_version": args.aic_backend_version,
"aic_tp_size": args.aic_tp_size,
}
if not any(value is not None for value in values.values()):
return None
missing = [
name
for name in ("aic_backend", "aic_system", "aic_model_path")
if values[name] is None
]
if missing:
missing_flags = ", ".join(f"--{name.replace('_', '-')}" for name in missing)
raise ValueError(f"AIC replay modeling requires {missing_flags}")
return AicPerfConfig(
aic_backend=values["aic_backend"],
aic_system=values["aic_system"],
aic_model_path=values["aic_model_path"],
aic_tp_size=values["aic_tp_size"] or 1,
aic_backend_version=values["aic_backend_version"],
)
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(prog="python -m dynamo.replay")
parser.add_argument("trace_file", nargs="?")
......@@ -79,6 +108,11 @@ def main(argv: Sequence[str] | None = None) -> int:
parser.add_argument("--prefill-engine-args")
parser.add_argument("--decode-engine-args")
parser.add_argument("--router-config")
parser.add_argument("--aic-backend")
parser.add_argument("--aic-system")
parser.add_argument("--aic-backend-version")
parser.add_argument("--aic-tp-size", type=int)
parser.add_argument("--aic-model-path")
parser.add_argument("--input-tokens", type=int)
parser.add_argument("--output-tokens", type=int)
parser.add_argument(
......@@ -106,6 +140,12 @@ def main(argv: Sequence[str] | None = None) -> int:
default="round_robin",
)
parser.add_argument("--arrival-speedup-ratio", type=float, default=1.0)
parser.add_argument(
"--trace-block-size",
type=int,
default=512,
help="tokens represented by each hash_id in the trace file; only used for file replay",
)
parser.add_argument(
"--report-json",
help="path to save the full replay report JSON; defaults to a timestamped file in the current directory",
......@@ -140,6 +180,10 @@ def main(argv: Sequence[str] | None = None) -> int:
if args.router_config is not None
else None
)
try:
aic_perf_config = _load_aic_perf_config(args)
except ValueError as exc:
parser.error(str(exc))
if using_trace_file:
report = run_trace_replay(
......@@ -148,6 +192,7 @@ def main(argv: Sequence[str] | None = None) -> int:
prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args,
router_config=router_config,
aic_perf_config=aic_perf_config,
num_workers=args.num_workers,
num_prefill_workers=args.num_prefill_workers,
num_decode_workers=args.num_decode_workers,
......@@ -155,6 +200,7 @@ def main(argv: Sequence[str] | None = None) -> int:
replay_mode=args.replay_mode,
router_mode=args.router_mode,
arrival_speedup_ratio=args.arrival_speedup_ratio,
trace_block_size=args.trace_block_size,
)
else:
report = run_synthetic_trace_replay(
......@@ -165,6 +211,7 @@ def main(argv: Sequence[str] | None = None) -> int:
prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args,
router_config=router_config,
aic_perf_config=aic_perf_config,
num_workers=args.num_workers,
num_prefill_workers=args.num_prefill_workers,
num_decode_workers=args.num_decode_workers,
......
......@@ -125,6 +125,32 @@ def test_run_trace_replay_supports_multiturn_sessions(tmp_path, replay_mode):
)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_supports_distinct_trace_and_engine_block_sizes(
tmp_path, replay_mode
):
trace_path = tmp_path / "trace_block_size_split.jsonl"
trace_path.write_text(
'{"timestamp":1000.0,"input_length":128,"output_length":2,"hash_ids":[101]}\n',
encoding="utf-8",
)
report = run_trace_replay(
trace_path,
extra_engine_args=_vllm_args(),
num_workers=1,
replay_mode=replay_mode,
trace_block_size=512,
)
_assert_basic_report_counts(
report,
num_requests=1,
input_tokens=128,
output_tokens=2,
)
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
......
......@@ -40,7 +40,7 @@ pub use self::multi_worker_sequence::{
pub use self::sequence::{ActiveSequences, RequestId};
pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed;
pub use config::{KvRouterConfig, RouterConfigOverride, RouterQueuePolicy};
pub use config::{KvRouterConfig, RouterConfigOverride, RouterPrefillLoadModel, RouterQueuePolicy};
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
pub use nested_map::PositionalIndexer;
pub use protocols::{
......@@ -50,6 +50,7 @@ pub use protocols::{
pub use queue::SchedulerQueue;
pub use radix_tree::RadixTree;
pub use scheduling::LocalScheduler;
pub use scheduling::PrefillLoadEstimator;
pub use scheduling::policy::{FcfsPolicy, RouterSchedulingPolicy, SchedulingPolicy, WsptPolicy};
pub use scheduling::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
pub use selector::{DefaultWorkerSelector, WorkerSelector};
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use std::future::Future;
use std::time::Duration;
use dynamo_tokens::{SequenceHash, Token};
use rustc_hash::FxHashMap;
......@@ -429,6 +430,12 @@ pub struct ActiveSequenceEvent {
pub lora_name: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrefillLoadHint {
pub initial_effective_prefill_tokens: usize,
pub expected_prefill_duration: Option<Duration>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ActiveSequenceEventData {
AddRequest {
......@@ -438,6 +445,8 @@ pub enum ActiveSequenceEventData {
#[serde(default = "default_track_prefill_tokens")]
track_prefill_tokens: bool,
expected_output_tokens: Option<u32>,
#[serde(default)]
prefill_load_hint: Option<PrefillLoadHint>,
},
Free,
MarkPrefillCompleted,
......
......@@ -4,6 +4,7 @@
use std::env::{self, VarError};
use std::fmt;
use std::str::FromStr;
use std::time::Duration;
use derive_builder::Builder;
use rand::Rng;
......@@ -53,6 +54,43 @@ impl fmt::Display for RouterQueuePolicy {
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterPrefillLoadModel {
#[default]
None,
Aic,
}
impl fmt::Display for RouterPrefillLoadModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::None => f.write_str("none"),
Self::Aic => f.write_str("aic"),
}
}
}
impl FromStr for RouterPrefillLoadModel {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"none" => Ok(Self::None),
"aic" => Ok(Self::Aic),
_ => Err(format!(
"unknown prefill load model: {s:?}, expected 'none' or 'aic'"
)),
}
}
}
impl RouterPrefillLoadModel {
pub fn is_enabled(self) -> bool {
!matches!(self, Self::None)
}
}
impl FromStr for RouterQueuePolicy {
type Err = String;
......@@ -124,6 +162,9 @@ pub struct KvRouterConfig {
#[serde(default = "default_track_prefill_tokens")]
pub router_track_prefill_tokens: bool,
/// Optional model for estimating effective prompt-side prefill load over time.
pub router_prefill_load_model: RouterPrefillLoadModel,
/// Threshold for triggering snapshots. If None, no snapshots will be performed.
#[validate(range(min = 1))]
pub router_snapshot_threshold: Option<u32>,
......@@ -183,6 +224,7 @@ impl Default for KvRouterConfig {
router_track_output_blocks: false,
router_assume_kv_reuse: true,
router_track_prefill_tokens: default_track_prefill_tokens(),
router_prefill_load_model: RouterPrefillLoadModel::default(),
router_snapshot_threshold: Some(1000000),
router_reset_states: false,
router_ttl_secs: 120.0,
......@@ -214,10 +256,33 @@ fn validate_kv_router_config(config: &KvRouterConfig) -> Result<(), ValidationEr
"router_track_output_blocks requires router_track_active_blocks=true",
));
}
if config.router_prefill_load_model.is_enabled() && !config.router_track_prefill_tokens {
return Err(ValidationError::new(
"router_prefill_load_model requires router_track_prefill_tokens=true",
));
}
if config.router_prefill_load_model.is_enabled()
&& !matches!(config.router_queue_policy, RouterQueuePolicy::Fcfs)
{
return Err(ValidationError::new(
"router_prefill_load_model currently requires router_queue_policy='fcfs'",
));
}
Ok(())
}
impl KvRouterConfig {
pub fn router_queue_recheck_interval(&self) -> Duration {
const DEFAULT_RECHECK_INTERVAL: Duration = Duration::from_secs(60);
const PREFILL_LOAD_RECHECK_INTERVAL: Duration = Duration::from_millis(100);
if self.router_prefill_load_model.is_enabled() && self.router_queue_threshold.is_some() {
return PREFILL_LOAD_RECHECK_INTERVAL;
}
DEFAULT_RECHECK_INTERVAL
}
pub fn assume_kv_reuse(&self, config_override: Option<&RouterConfigOverride>) -> bool {
config_override
.and_then(|cfg| cfg.assume_kv_reuse)
......@@ -288,28 +353,6 @@ mod tests {
use super::*;
use crate::protocols::{BlockExtraInfo, BlockMmObjectInfo};
#[test]
fn router_queue_policy_display_and_parse_support_lcfs() {
assert_eq!(RouterQueuePolicy::Lcfs.to_string(), "lcfs");
assert_eq!(
"lcfs".parse::<RouterQueuePolicy>().unwrap(),
RouterQueuePolicy::Lcfs
);
}
#[test]
fn router_queue_policy_serde_round_trip_supports_lcfs() {
let serialized = serde_json::to_string(&RouterQueuePolicy::Lcfs).unwrap();
assert_eq!(serialized, "\"lcfs\"");
let deserialized: RouterQueuePolicy = serde_json::from_str(&serialized).unwrap();
assert_eq!(deserialized, RouterQueuePolicy::Lcfs);
}
#[test]
fn kv_router_config_defaults_to_tracking_prefill_tokens() {
assert!(KvRouterConfig::default().router_track_prefill_tokens);
}
#[test]
fn compute_seq_hashes_for_tracking_uses_mm_hashes() {
let cfg = KvRouterConfig::default();
......@@ -343,17 +386,6 @@ mod tests {
assert_ne!(without_mm, with_mm);
}
#[test]
fn router_config_override_serde_round_trip_preserves_track_prefill_tokens() {
let serialized = serde_json::to_string(&RouterConfigOverride {
track_prefill_tokens: Some(false),
..Default::default()
})
.unwrap();
let deserialized: RouterConfigOverride = serde_json::from_str(&serialized).unwrap();
assert_eq!(deserialized.track_prefill_tokens, Some(false));
}
#[test]
fn compute_seq_hashes_for_tracking_uses_precomputed_block_hashes() {
let config = KvRouterConfig::default();
......
......@@ -6,9 +6,11 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, watch};
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use super::policy::{RouterSchedulingPolicy, SchedulingPolicy};
use super::prefill_load::PrefillLoadEstimator;
use super::queue::SchedulerQueue;
use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
......@@ -18,8 +20,6 @@ use crate::sequences::{
};
use dynamo_tokens::SequenceHash;
const RECHECK_INTERVAL: Duration = Duration::from_secs(60);
pub struct LocalScheduler<P, C, S = RouterSchedulingPolicy, Sel = DefaultWorkerSelector>
where
P: SequencePublisher,
......@@ -49,6 +49,8 @@ where
block_size: u32,
selector: Sel,
policy: S,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
recheck_interval: Duration,
track_prefill_tokens_default: bool,
cancellation_token: CancellationToken,
worker_type: &'static str,
......@@ -103,13 +105,14 @@ where
block_size,
selector,
policy,
prefill_load_estimator,
));
let (request_tx, request_rx) = mpsc::channel::<SchedulingRequest>(1024);
let queue_clone = Arc::clone(&queue);
tokio::spawn(async move {
let mut request_rx = request_rx;
let mut recheck_interval = tokio::time::interval(RECHECK_INTERVAL);
let mut recheck_interval = tokio::time::interval(recheck_interval);
tracing::trace!("LocalScheduler background task started");
loop {
......@@ -192,17 +195,18 @@ where
}
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
self.slots.add_request(req)
self.slots.add_request(req, Instant::now())
}
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots.mark_prefill_completed(&request_id.to_string())?;
self.slots
.mark_prefill_completed(&request_id.to_string(), Instant::now())?;
self.queue.update().await;
Ok(())
}
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots.free(&request_id.to_string())?;
self.slots.free(&request_id.to_string(), Instant::now())?;
self.queue.update().await;
Ok(())
}
......@@ -231,6 +235,7 @@ where
overlaps: OverlapScores,
track_prefill_tokens: bool,
) -> Vec<PotentialLoad> {
let decay_now = Instant::now();
let (decode_blocks, prefill_tokens) = self
.slots
.potential_blocks_and_tokens_with_prefill_tracking(
......@@ -238,6 +243,7 @@ where
isl_tokens,
overlaps,
track_prefill_tokens,
decay_now,
);
let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
......@@ -275,15 +281,32 @@ mod tests {
use super::*;
use crate::protocols::OverlapScores;
use crate::scheduling::PrefillLoadEstimator;
use crate::scheduling::policy::FcfsPolicy;
use crate::scheduling::selector::DefaultWorkerSelector;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
struct FixedPrefillLoadEstimator {
duration: Duration,
}
impl PrefillLoadEstimator for FixedPrefillLoadEstimator {
fn predict_prefill_duration(
&self,
_batch_size: usize,
_effective_isl: usize,
_prefix: usize,
) -> anyhow::Result<Duration> {
Ok(self.duration)
}
}
#[allow(clippy::type_complexity)]
fn make_scheduler(
workers: HashMap<WorkerId, SimpleWorkerConfig>,
threshold_frac: Option<f64>,
monitor_worker_configs: bool,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
) -> (
Arc<LocalScheduler<NoopSequencePublisher, SimpleWorkerConfig, FcfsPolicy>>,
Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
......@@ -311,6 +334,8 @@ mod tests {
64,
DefaultWorkerSelector::new(None, "test"),
FcfsPolicy,
prefill_load_estimator,
Duration::from_secs(60),
true,
cancel_token.clone(),
"test",
......@@ -329,7 +354,7 @@ mod tests {
..Default::default()
},
);
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
let response = scheduler
.schedule(
......@@ -366,7 +391,7 @@ mod tests {
..Default::default()
},
);
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
scheduler
.schedule(
......@@ -389,7 +414,7 @@ mod tests {
assert_eq!(
slots
.active_tokens()
.active_tokens(Instant::now())
.get(&WorkerWithDpRank::new(0, 0))
.copied(),
Some(0)
......@@ -408,7 +433,8 @@ mod tests {
..Default::default()
},
);
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, Some(0.5), true);
let (scheduler, _slots, _cfg_tx, cancel_token) =
make_scheduler(workers, Some(0.5), true, None);
scheduler
.schedule(
......@@ -466,7 +492,7 @@ mod tests {
..Default::default()
},
);
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
scheduler
.schedule(
......@@ -511,12 +537,16 @@ mod tests {
..Default::default()
},
);
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
let token_seq = vec![11, 22, 33, 44];
let overlaps = OverlapScores::default();
let (decode_blocks, prefill_tokens) =
slots.potential_blocks_and_tokens(Some(&token_seq), 128, overlaps.clone());
let (decode_blocks, prefill_tokens) = slots.potential_blocks_and_tokens(
Some(&token_seq),
128,
overlaps.clone(),
Instant::now(),
);
let mut expected: Vec<_> = decode_blocks
.keys()
.map(|worker| PotentialLoad {
......@@ -548,10 +578,51 @@ mod tests {
cancel_token.cancel();
}
#[tokio::test(start_paused = true)]
async fn test_get_potential_loads_uses_decayed_prefill_tokens() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(256),
..Default::default()
},
);
let estimator: Arc<dyn PrefillLoadEstimator> = Arc::new(FixedPrefillLoadEstimator {
duration: Duration::from_secs(10),
});
let (scheduler, _slots, _cfg_tx, cancel_token) =
make_scheduler(workers, None, true, Some(estimator));
scheduler
.schedule(
Some("req-1".to_string()),
100,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
.unwrap();
tokio::time::advance(Duration::from_secs(6)).await;
let loads = scheduler.get_potential_loads(None, 0, OverlapScores::default(), true);
assert_eq!(loads.len(), 1);
assert_eq!(loads[0].potential_prefill_tokens, 40);
cancel_token.cancel();
}
#[tokio::test]
async fn test_register_workers_uses_default_dp_fallback() {
let (scheduler, _slots, _cfg_tx, cancel_token) =
make_scheduler(HashMap::new(), None, false);
make_scheduler(HashMap::new(), None, false, None);
scheduler.register_workers(&HashSet::from([42]));
let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default(), true);
......@@ -567,7 +638,7 @@ mod tests {
async fn test_worker_watch_updates_slot_ranges() {
let mut workers = HashMap::new();
workers.insert(0, SimpleWorkerConfig::default());
let (scheduler, _slots, cfg_tx, cancel_token) = make_scheduler(workers, None, true);
let (scheduler, _slots, cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
assert_eq!(
scheduler
......@@ -615,7 +686,7 @@ mod tests {
..Default::default()
},
);
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
scheduler
.schedule(
......
......@@ -4,9 +4,11 @@
pub mod config;
mod local;
pub mod policy;
pub mod prefill_load;
pub mod queue;
pub mod selector;
mod types;
pub use local::LocalScheduler;
pub use prefill_load::PrefillLoadEstimator;
pub use types::*;
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::time::Duration;
pub trait PrefillLoadEstimator: Send + Sync {
fn predict_prefill_duration(
&self,
batch_size: usize,
effective_isl: usize,
prefix: usize,
) -> anyhow::Result<Duration>;
}
......@@ -5,15 +5,16 @@ use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::time::Instant;
use tokio::sync::Mutex;
use tokio::sync::watch;
use tokio::time::Instant;
use super::policy::{FcfsPolicy, SchedulingPolicy};
use super::prefill_load::PrefillLoadEstimator;
use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{SchedulingRequest, SchedulingResponse};
use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::protocols::{PrefillLoadHint, WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRequest};
/// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker)
......@@ -68,6 +69,7 @@ pub struct SchedulerQueue<
block_size: u32,
selector: Sel,
policy: S,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
}
impl<
......@@ -84,6 +86,7 @@ impl<
block_size: u32,
selector: Sel,
policy: S,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
) -> Self {
if let Some(frac) = threshold_frac {
tracing::info!("Router queue enabled with threshold fraction {frac}");
......@@ -98,6 +101,7 @@ impl<
block_size,
selector,
policy,
prefill_load_estimator,
}
}
......@@ -133,23 +137,24 @@ impl<
/// capacity check is skipped.
pub async fn enqueue(&self, request: SchedulingRequest) {
let Some(threshold) = self.threshold_frac else {
self.schedule(request).await;
self.schedule(request, Instant::now()).await;
return;
};
if request.allowed_worker_ids.is_some() {
self.schedule(request).await;
self.schedule(request, Instant::now()).await;
return;
}
if self.all_workers_busy(threshold, request.allowed_worker_ids.as_ref()) {
let decay_now = Instant::now();
if self.all_workers_busy(threshold, request.allowed_worker_ids.as_ref(), decay_now) {
tracing::debug!("all workers busy, queueing request");
let arrival_offset = self.start_time.elapsed();
let key = self.policy.enqueue_key(arrival_offset, &request);
self.pending.lock().await.push(QueueEntry { key, request });
self.pending_count.fetch_add(1, AtomicOrdering::Relaxed);
} else {
self.schedule(request).await;
self.schedule(request, decay_now).await;
}
}
......@@ -176,7 +181,8 @@ impl<
}
loop {
if self.all_workers_busy(threshold, None) {
let decay_now = Instant::now();
if self.all_workers_busy(threshold, None, decay_now) {
break;
}
let Some(entry) = self.pending.lock().await.pop() else {
......@@ -184,13 +190,13 @@ impl<
};
self.pending_count.fetch_sub(1, AtomicOrdering::Relaxed);
tracing::debug!("scheduling request from pending queue");
self.schedule(entry.request).await;
self.schedule(entry.request, decay_now).await;
}
}
/// Run the full scheduling pipeline for a single request:
/// compute potential load -> select worker -> respond -> book via add_request.
async fn schedule(&self, mut request: SchedulingRequest) {
async fn schedule(&self, mut request: SchedulingRequest, decay_now: Instant) {
let (decode_blocks, prefill_tokens) = self
.slots
.potential_blocks_and_tokens_with_prefill_tracking(
......@@ -198,6 +204,7 @@ impl<
request.isl_tokens,
request.overlaps.clone(),
request.track_prefill_tokens,
decay_now,
);
request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens;
......@@ -231,20 +238,66 @@ impl<
return;
};
if let Err(e) = self.slots.add_request(SequenceRequest {
let prefill_load_hint = self.prefill_load_hint_for(
request.isl_tokens,
selection.overlap_blocks,
request.track_prefill_tokens,
);
if let Err(e) = self.slots.add_request(
SequenceRequest {
request_id: request_id.clone(),
token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
track_prefill_tokens: request.track_prefill_tokens,
expected_output_tokens: request.expected_output_tokens,
prefill_load_hint,
worker: selection.worker,
lora_name: request.lora_name.clone(),
}) {
},
decay_now,
) {
tracing::warn!("Failed to add request {request_id}: {e}");
}
}
fn prefill_load_hint_for(
&self,
isl_tokens: usize,
overlap_blocks: u32,
track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
if !track_prefill_tokens {
return None;
}
let prefix = (overlap_blocks as usize) * (self.block_size as usize);
let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 {
return None;
}
let Some(estimator) = &self.prefill_load_estimator else {
return None;
};
match estimator.predict_prefill_duration(1, effective_isl, prefix) {
Ok(expected_prefill_duration) => Some(PrefillLoadHint {
initial_effective_prefill_tokens: effective_isl,
expected_prefill_duration: Some(expected_prefill_duration),
}),
Err(error) => {
tracing::warn!(
effective_isl,
prefix,
"failed to predict prefill duration for active load tracking: {error}"
);
None
}
}
}
/// Number of requests currently parked in the pending queue (lock-free).
pub fn pending_count(&self) -> usize {
self.pending_count.load(AtomicOrdering::Relaxed)
......@@ -255,8 +308,13 @@ impl<
/// otherwise all registered workers are checked.
/// Returns false when no eligible workers exist so the request falls
/// through to `schedule`, which returns a proper `NoEndpoints` error.
fn all_workers_busy(&self, threshold: f64, allowed: Option<&HashSet<WorkerId>>) -> bool {
let active_tokens = self.slots.active_tokens();
fn all_workers_busy(
&self,
threshold: f64,
allowed: Option<&HashSet<WorkerId>>,
decay_now: Instant,
) -> bool {
let active_tokens = self.slots.active_tokens(decay_now);
let configs = self.workers_with_configs.borrow();
let mut checked_any = false;
......@@ -289,6 +347,7 @@ impl<
mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
......@@ -298,6 +357,25 @@ mod tests {
use crate::sequences::ActiveSequencesMultiWorker;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
fn decay_now() -> Instant {
Instant::now()
}
struct FixedPrefillLoadEstimator {
duration: Duration,
}
impl PrefillLoadEstimator for FixedPrefillLoadEstimator {
fn predict_prefill_duration(
&self,
_batch_size: usize,
_effective_isl: usize,
_prefix: usize,
) -> anyhow::Result<Duration> {
Ok(self.duration)
}
}
fn make_queue(
num_workers: usize,
block_size: u32,
......@@ -308,7 +386,7 @@ mod tests {
Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
) {
let (queue, slots, _tx) =
make_queue_with_sender(num_workers, block_size, isl, threshold_frac);
make_queue_with_sender(num_workers, block_size, isl, threshold_frac, None);
(queue, slots)
}
......@@ -318,6 +396,7 @@ mod tests {
block_size: u32,
isl: usize,
threshold_frac: Option<f64>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
) -> (
Arc<SchedulerQueue<NoopSequencePublisher, SimpleWorkerConfig>>,
Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
......@@ -354,6 +433,7 @@ mod tests {
block_size,
selector,
FcfsPolicy,
prefill_load_estimator,
));
(queue, slots, cfg_tx)
......@@ -409,8 +489,8 @@ mod tests {
let resp = resp.expect("scheduling failed");
assert!(resp.best_worker.worker_id < num_workers as u64);
slots.mark_prefill_completed(&req_id).unwrap();
slots.free(&req_id).unwrap();
slots.mark_prefill_completed(&req_id, decay_now()).unwrap();
slots.free(&req_id, decay_now()).unwrap();
queue.update().await;
}));
}
......@@ -419,7 +499,7 @@ mod tests {
h.await.expect("task panicked");
}
let active = slots.active_tokens();
let active = slots.active_tokens(decay_now());
for (worker, tokens) in &active {
assert_eq!(
*tokens, 0,
......@@ -453,8 +533,8 @@ mod tests {
for _ in 0..num_requests {
queue.update().await;
for rid in &req_ids {
let _ = slots.mark_prefill_completed(rid);
let _ = slots.free(rid);
let _ = slots.mark_prefill_completed(rid, decay_now());
let _ = slots.free(rid, decay_now());
}
}
queue.update().await;
......@@ -495,8 +575,10 @@ mod tests {
assert_eq!(queue.pending_count(), 2);
// Free the first request and update — should drain one from pending
slots.mark_prefill_completed(&"req-1".to_string()).unwrap();
slots.free(&"req-1".to_string()).unwrap();
slots
.mark_prefill_completed(&"req-1".to_string(), decay_now())
.unwrap();
slots.free(&"req-1".to_string(), decay_now()).unwrap();
queue.update().await;
// After update, one pending request should have been scheduled
......@@ -507,16 +589,43 @@ mod tests {
);
// Free req-2 and update to drain remaining
let _ = slots.mark_prefill_completed(&"req-2".to_string());
let _ = slots.free(&"req-2".to_string());
let _ = slots.mark_prefill_completed(&"req-2".to_string(), decay_now());
let _ = slots.free(&"req-2".to_string(), decay_now());
queue.update().await;
let _ = slots.mark_prefill_completed(&"req-3".to_string());
let _ = slots.free(&"req-3".to_string());
let _ = slots.mark_prefill_completed(&"req-3".to_string(), decay_now());
let _ = slots.free(&"req-3".to_string(), decay_now());
queue.update().await;
assert_eq!(queue.pending_count(), 0, "all requests should be drained");
}
#[tokio::test(start_paused = true)]
async fn test_queue_update_uses_decayed_oldest_prefill_load() {
let estimator: Arc<dyn PrefillLoadEstimator> = Arc::new(FixedPrefillLoadEstimator {
duration: Duration::from_secs(10),
});
let (queue, _slots, _cfg_tx) =
make_queue_with_sender(1, 16, 100, Some(0.5), Some(estimator));
let (req1, rx1) = make_request("req-1", 100);
queue.enqueue(req1).await;
let _ = rx1.await.unwrap().unwrap();
let (req2, mut rx2) = make_request("req-2", 100);
queue.enqueue(req2).await;
assert_eq!(queue.pending_count(), 1);
tokio::time::advance(Duration::from_secs(6)).await;
queue.update().await;
let scheduled = rx2
.try_recv()
.expect("queued request should have been scheduled");
let response = scheduled.expect("scheduling returned error");
assert_eq!(response.best_worker.worker_id, 0);
assert_eq!(queue.pending_count(), 0);
}
#[tokio::test]
async fn test_no_workers_returns_error() {
let (queue, _slots) = make_queue(0, 16, 512, None);
......@@ -542,7 +651,7 @@ mod tests {
let isl = 512;
// Start with zero workers (mimics skip_initial_worker_wait=true)
let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None);
let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None, None);
// Routing with no workers must fail
let (req_fail, rx_fail) = make_request("before-register", isl);
......@@ -590,9 +699,11 @@ mod tests {
// Clean up
slots
.mark_prefill_completed(&"after-register".to_string())
.mark_prefill_completed(&"after-register".to_string(), decay_now())
.unwrap();
slots
.free(&"after-register".to_string(), decay_now())
.unwrap();
slots.free(&"after-register".to_string()).unwrap();
}
/// Register_workers is additive: calling with a new set does NOT remove old workers.
......@@ -601,7 +712,7 @@ mod tests {
let block_size = 16;
let isl = 256;
let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None);
let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None, None);
// Register worker 10 in slots and config
let mut dp1 = std::collections::HashMap::new();
......@@ -643,8 +754,8 @@ mod tests {
.expect("oneshot dropped")
.expect("scheduling failed");
seen.insert(resp.best_worker.worker_id);
slots.mark_prefill_completed(&req_id).unwrap();
slots.free(&req_id).unwrap();
slots.mark_prefill_completed(&req_id, decay_now()).unwrap();
slots.free(&req_id, decay_now()).unwrap();
}
assert!(
......@@ -659,7 +770,7 @@ mod tests {
let block_size = 16;
let isl = 256;
let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None);
let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None, None);
// Register three workers
let mut dp = std::collections::HashMap::new();
......@@ -712,9 +823,9 @@ mod tests {
resp.best_worker.worker_id
);
slots
.mark_prefill_completed(&"filter-0".to_string())
.mark_prefill_completed(&"filter-0".to_string(), decay_now())
.unwrap();
slots.free(&"filter-0".to_string()).unwrap();
slots.free(&"filter-0".to_string(), decay_now()).unwrap();
}
#[tokio::test(flavor = "multi_thread")]
......@@ -727,7 +838,7 @@ mod tests {
let _resp1 = rx1.await.unwrap().unwrap();
assert_eq!(
slots
.active_tokens()
.active_tokens(decay_now())
.get(&WorkerWithDpRank::new(0, 0))
.copied(),
Some(0)
......@@ -738,9 +849,9 @@ mod tests {
let _resp2 = rx2.await.unwrap().unwrap();
assert_eq!(queue.pending_count(), 0);
let _ = slots.mark_prefill_completed(&"req-1".to_string());
let _ = slots.free(&"req-1".to_string());
let _ = slots.mark_prefill_completed(&"req-2".to_string());
let _ = slots.free(&"req-2".to_string());
let _ = slots.mark_prefill_completed(&"req-1".to_string(), decay_now());
let _ = slots.free(&"req-1".to_string(), decay_now());
let _ = slots.mark_prefill_completed(&"req-2".to_string(), decay_now());
let _ = slots.free(&"req-2".to_string(), decay_now());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_tokens::SequenceHash;
use std::collections::HashMap;
use std::sync::{Arc, Weak};
#[derive(Debug, Default)]
pub(super) struct BlockTracker {
pub(super) unique_blocks: HashMap<SequenceHash, Weak<()>>,
pub(super) fractional_blocks: HashMap<SequenceHash, f64>,
}
impl BlockTracker {
pub(super) fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> {
if let Some(weak) = self.unique_blocks.get(block)
&& let Some(rc) = weak.upgrade()
{
return rc;
}
let rc = Arc::new(());
self.unique_blocks.insert(*block, Arc::downgrade(&rc));
rc
}
pub(super) fn try_remove_block(&mut self, block: &SequenceHash) {
if let Some(weak) = self.unique_blocks.get(block)
&& weak.strong_count() == 0
{
self.unique_blocks.remove(block);
self.fractional_blocks.remove(block);
}
}
pub(super) fn active_blocks(&self) -> usize {
let mut count = self.unique_blocks.len() as f64;
for (hash, frac) in &self.fractional_blocks {
if self.unique_blocks.contains_key(hash) {
count = count - 1.0 + frac;
}
}
count.round() as usize
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod block_tracker;
pub mod multi_worker;
mod prefill_tracker;
pub mod single;
pub use multi_worker::*;
......
......@@ -20,7 +20,8 @@ use tokio_util::sync::CancellationToken;
use super::single::{ActiveSequences, RequestId};
use crate::protocols::{
ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores, WorkerWithDpRank,
ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores, PrefillLoadHint,
WorkerWithDpRank,
};
// How often we force expire stale requests across all workers. See the comment
......@@ -93,6 +94,7 @@ pub struct SequenceRequest {
pub overlap: u32,
pub track_prefill_tokens: bool,
pub expected_output_tokens: Option<u32>,
pub prefill_load_hint: Option<PrefillLoadHint>,
pub worker: WorkerWithDpRank,
pub lora_name: Option<String>,
}
......@@ -177,6 +179,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
return;
}
// TODO: Publish explicit prompt-load decay timestamps with these events so peer routers
// can mirror the same oldest-prefill anchor instead of approximating from receive time.
let publisher = Arc::clone(&self.publisher);
tokio::spawn(async move {
if let Err(e) = publisher.publish_event(&event).await {
......@@ -228,6 +232,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
continue;
}
// TODO: ActiveSequenceEvent does not carry prompt-load decay timestamps yet.
// Peer routers still approximate decay anchoring with local receive time.
let decay_now = Instant::now();
match &event.data {
ActiveSequenceEventData::AddRequest {
token_sequence,
......@@ -235,6 +242,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap,
track_prefill_tokens,
expected_output_tokens,
prefill_load_hint,
} => {
self.request_to_worker
.insert(event.request_id.clone(), event.worker);
......@@ -253,6 +261,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
*overlap,
*expected_output_tokens,
*track_prefill_tokens,
*prefill_load_hint,
decay_now,
);
} else {
tracing::warn!(
......@@ -267,7 +277,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
{
let table = self.workers.read();
if let Some(&idx) = table.index.get(&worker) {
table.slots[idx].1.write().free(&event.request_id);
table.slots[idx].1.write().free(&event.request_id, decay_now);
}
}
self.request_to_lora.remove(&event.request_id);
......@@ -281,7 +291,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
table.slots[idx]
.1
.write()
.mark_prefill_completed(&event.request_id);
.mark_prefill_completed(&event.request_id, decay_now);
}
}
}
......@@ -381,7 +391,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}
}
fn add_request_local(&self, req: SequenceRequest) -> Result<(), SequenceError> {
fn add_request_local(
&self,
req: SequenceRequest,
decay_now: Instant,
) -> Result<(), SequenceError> {
let SequenceRequest {
request_id,
token_sequence,
......@@ -389,6 +403,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap,
track_prefill_tokens,
expected_output_tokens,
prefill_load_hint,
worker,
lora_name,
} = req;
......@@ -435,6 +450,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap,
expected_output_tokens,
track_prefill_tokens,
prefill_load_hint,
decay_now,
)
};
......@@ -443,12 +460,16 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self.request_to_lora.remove(expired_id);
}
self.publish_active_load_for_worker(worker);
self.publish_active_load_for_worker(worker, decay_now);
Ok(())
}
pub fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
pub fn add_request(
&self,
req: SequenceRequest,
decay_now: Instant,
) -> Result<(), SequenceError> {
self.spawn_publish_event(ActiveSequenceEvent {
request_id: req.request_id.clone(),
worker: req.worker,
......@@ -458,11 +479,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap: req.overlap,
track_prefill_tokens: req.track_prefill_tokens,
expected_output_tokens: req.expected_output_tokens,
prefill_load_hint: req.prefill_load_hint,
},
router_id: self.router_id,
lora_name: req.lora_name.clone(),
});
self.add_request_local(req)
self.add_request_local(req, decay_now)
}
/// Send a mutation to the worker assigned to a request, optionally publishing
......@@ -470,7 +492,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
fn mutate_request_worker_local(
&self,
request_id: &RequestId,
mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId),
decay_now: Instant,
mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId, Instant),
remove_mapping: bool,
) -> Result<(), SequenceError> {
let worker = self
......@@ -488,7 +511,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?;
let mut seq = table.slots[idx].1.write();
mutate_fn(&mut seq, request_id);
mutate_fn(&mut seq, request_id, decay_now);
}
if remove_mapping {
......@@ -496,7 +519,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self.request_to_lora.remove(request_id);
}
self.publish_active_load_for_worker(worker);
self.publish_active_load_for_worker(worker, decay_now);
Ok(())
}
......@@ -504,8 +527,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
fn mutate_request_worker(
&self,
request_id: &RequestId,
decay_now: Instant,
event_data: ActiveSequenceEventData,
mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId),
mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId, Instant),
remove_mapping: bool,
) -> Result<(), SequenceError> {
let worker = self
......@@ -528,7 +552,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
lora_name,
});
self.mutate_request_worker_local(request_id, mutate_fn, remove_mapping)
self.mutate_request_worker_local(request_id, decay_now, mutate_fn, remove_mapping)
}
/// Free all blocks associated with a request.
......@@ -539,7 +563,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// This also performs the underlying prefill-complete cleanup via
/// [`ActiveSequences::free`], so callers do not need to call
/// [`Self::mark_prefill_completed`] before freeing a completed request.
pub fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> {
pub fn free(&self, request_id: &RequestId, decay_now: Instant) -> Result<(), SequenceError> {
if !self.request_to_worker.contains_key(request_id) {
tracing::debug!("Request {request_id} not found, already freed (idempotent)");
return Ok(());
......@@ -547,9 +571,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self.mutate_request_worker(
request_id,
decay_now,
ActiveSequenceEventData::Free,
|seqs, rid| {
seqs.free(rid);
|seqs, rid, decay_now| {
seqs.free(rid, decay_now);
},
true,
)
......@@ -559,12 +584,17 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
///
/// Note: Calling this multiple times for the same request is allowed and will be a no-op
/// after the first call (idempotent).
pub fn mark_prefill_completed(&self, request_id: &RequestId) -> Result<(), SequenceError> {
pub fn mark_prefill_completed(
&self,
request_id: &RequestId,
decay_now: Instant,
) -> Result<(), SequenceError> {
self.mutate_request_worker(
request_id,
decay_now,
ActiveSequenceEventData::MarkPrefillCompleted,
|seqs, rid| {
seqs.mark_prefill_completed(rid);
|seqs, rid, decay_now| {
seqs.mark_prefill_completed(rid, decay_now);
},
false,
)
......@@ -605,13 +635,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
});
}
self.publish_active_load_for_worker(worker);
self.publish_active_load_for_worker(worker, Instant::now());
Ok(())
}
/// Read active blocks/tokens from a worker and publish ActiveLoad metrics.
fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) {
fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank, decay_now: Instant) {
let (active_blocks, active_tokens) = {
let table = self.workers.read();
let Some(&idx) = table.index.get(&worker) else {
......@@ -619,7 +649,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
return;
};
let seq = table.slots[idx].1.read();
(seq.active_blocks(), seq.active_tokens())
(seq.active_blocks(), seq.active_tokens(decay_now))
};
self.publisher
......@@ -674,11 +704,18 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
token_sequence: Option<&[SequenceHash]>,
isl: usize,
overlaps: OverlapScores,
decay_now: Instant,
) -> (
HashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>,
) {
self.potential_blocks_and_tokens_with_prefill_tracking(token_sequence, isl, overlaps, true)
self.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
isl,
overlaps,
true,
decay_now,
)
}
pub fn potential_blocks_and_tokens_with_prefill_tracking(
......@@ -687,6 +724,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
isl: usize,
overlaps: OverlapScores,
track_prefill_tokens: bool,
decay_now: Instant,
) -> (
HashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>,
......@@ -712,6 +750,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
isl,
overlap,
track_prefill_tokens,
decay_now,
);
potential_blocks.insert(*worker, blocks);
potential_tokens.insert(*worker, tokens);
......@@ -741,11 +780,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}
/// Query all workers for their current number of active tokens.
pub fn active_tokens(&self) -> HashMap<WorkerWithDpRank, usize> {
pub fn active_tokens(&self, decay_now: Instant) -> HashMap<WorkerWithDpRank, usize> {
let table = self.workers.read();
let mut results = HashMap::with_capacity(table.slots.len());
for (worker, lock) in &table.slots {
results.insert(*worker, lock.read().active_tokens());
results.insert(*worker, lock.read().active_tokens(decay_now));
}
results
}
......@@ -753,11 +792,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// Return true if any worker satisfies the provided predicate on active token count.
pub fn any_worker_matches_active_tokens(
&self,
decay_now: Instant,
mut predicate: impl FnMut(WorkerWithDpRank, usize) -> bool,
) -> bool {
let table = self.workers.read();
for (worker, lock) in &table.slots {
if predicate(*worker, lock.read().active_tokens()) {
if predicate(*worker, lock.read().active_tokens(decay_now)) {
return true;
}
}
......@@ -792,7 +832,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self.request_to_lora.remove(expired_id);
removed_request_count += 1;
}
self.publish_active_load_for_worker(*worker);
self.publish_active_load_for_worker(*worker, now);
}
}
let duration = now.elapsed();
......@@ -835,8 +875,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::time::Duration;
use super::*;
use crate::protocols::{OverlapScores, PrefillLoadHint};
use crate::test_utils::NoopSequencePublisher;
fn make_sequences() -> ActiveSequencesMultiWorker<NoopSequencePublisher> {
......@@ -854,20 +896,74 @@ mod tests {
async fn add_request_can_skip_prefill_token_tracking() {
let sequences = make_sequences();
let worker = WorkerWithDpRank::new(1, 0);
let decay_now = Instant::now();
sequences
.add_request(SequenceRequest {
.add_request(
SequenceRequest {
request_id: "req-1".to_string(),
token_sequence: Some(vec![1, 2, 3]),
isl: 12,
overlap: 0,
track_prefill_tokens: false,
expected_output_tokens: None,
prefill_load_hint: None,
worker,
lora_name: None,
})
},
decay_now,
)
.unwrap();
assert_eq!(
sequences.active_tokens(decay_now).get(&worker).copied(),
Some(0)
);
}
#[test]
fn explicit_decay_time_drives_multi_worker_load_queries_consistently() {
let sequences = make_sequences();
let worker = WorkerWithDpRank::new(1, 0);
let start = Instant::now();
sequences
.add_request(
SequenceRequest {
request_id: "req-1".to_string(),
token_sequence: Some(vec![1, 2, 3]),
isl: 100,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
prefill_load_hint: Some(PrefillLoadHint {
initial_effective_prefill_tokens: 100,
expected_prefill_duration: Some(Duration::from_secs(10)),
}),
worker,
lora_name: None,
},
start,
)
.unwrap();
assert_eq!(sequences.active_tokens().get(&worker).copied(), Some(0));
let decay_now = start + Duration::from_secs(5);
let active_tokens = sequences.active_tokens(decay_now);
assert_eq!(active_tokens.get(&worker).copied(), Some(50));
let (_, potential_tokens) = sequences.potential_blocks_and_tokens_with_prefill_tracking(
None,
0,
OverlapScores::default(),
false,
decay_now,
);
assert_eq!(potential_tokens.get(&worker).copied(), Some(50));
assert!(
sequences.any_worker_matches_active_tokens(decay_now, |candidate, tokens| {
candidate == worker && tokens == 50
})
);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment