Unverified Commit 95a750f4 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore(replay): refactor offline components into cleaner lanes (#7866)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 210bbf5d
...@@ -8,15 +8,17 @@ ...@@ -8,15 +8,17 @@
//! predictions without knowing about PyO3. //! predictions without knowing about PyO3.
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use pyo3::prelude::*; use pyo3::prelude::*;
use dynamo_kv_router::PrefillLoadEstimator;
use dynamo_mocker::common::perf_model::AicCallback; use dynamo_mocker::common::perf_model::AicCallback;
/// Wraps a Python AIC InferenceSession for direct calls from Rust. /// Wraps a Python AIC InferenceSession for direct calls from Rust.
/// ///
/// The Python object must expose: /// The Python object must expose:
/// - `predict_prefill(batch_size, isl, prefix, osl) -> float` /// - `predict_prefill(batch_size, effective_isl, prefix) -> float`
/// - `predict_decode(batch_size, isl, osl) -> float` /// - `predict_decode(batch_size, isl, osl) -> float`
pub(super) struct PyAicCallback { pub(super) struct PyAicCallback {
pub(super) session: PyObject, pub(super) session: PyObject,
...@@ -26,15 +28,26 @@ pub(super) struct PyAicCallback { ...@@ -26,15 +28,26 @@ pub(super) struct PyAicCallback {
unsafe impl Send for PyAicCallback {} unsafe impl Send for PyAicCallback {}
unsafe impl Sync for PyAicCallback {} unsafe impl Sync for PyAicCallback {}
impl AicCallback for PyAicCallback { impl PyAicCallback {
fn predict_prefill(&self, batch_size: usize, isl: usize, prefix: usize, osl: usize) -> f64 { fn predict_prefill_ms(
&self,
batch_size: usize,
effective_isl: usize,
prefix: usize,
) -> PyResult<f64> {
Python::with_gil(|py| { Python::with_gil(|py| {
self.session self.session
.call_method1(py, "predict_prefill", (batch_size, isl, prefix, osl)) .call_method1(py, "predict_prefill", (batch_size, effective_isl, prefix))
.and_then(|r| r.extract::<f64>(py)) .and_then(|result| result.extract::<f64>(py))
.unwrap_or_else(|e| panic!("AIC predict_prefill failed: {e}"))
}) })
} }
}
impl AicCallback for PyAicCallback {
fn predict_prefill(&self, batch_size: usize, effective_isl: usize, prefix: usize) -> f64 {
self.predict_prefill_ms(batch_size, effective_isl, prefix)
.unwrap_or_else(|e| panic!("AIC predict_prefill failed: {e}"))
}
fn predict_decode(&self, batch_size: usize, isl: usize, osl: usize) -> f64 { fn predict_decode(&self, batch_size: usize, isl: usize, osl: usize) -> f64 {
Python::with_gil(|py| { Python::with_gil(|py| {
...@@ -46,6 +59,18 @@ impl AicCallback for PyAicCallback { ...@@ -46,6 +59,18 @@ impl AicCallback for PyAicCallback {
} }
} }
impl PrefillLoadEstimator for PyAicCallback {
fn predict_prefill_duration(
&self,
batch_size: usize,
effective_isl: usize,
prefix: usize,
) -> anyhow::Result<Duration> {
let latency_ms = self.predict_prefill_ms(batch_size, effective_isl, prefix)?;
Ok(Duration::from_secs_f64(latency_ms / 1000.0))
}
}
/// Initialize an AIC callback by importing and calling the Python setup function. /// Initialize an AIC callback by importing and calling the Python setup function.
/// ///
/// Called once at mocker startup when `--aic-perf-model` is requested. /// Called once at mocker startup when `--aic-perf-model` is requested.
...@@ -61,7 +86,7 @@ pub(super) fn create_aic_callback( ...@@ -61,7 +86,7 @@ pub(super) fn create_aic_callback(
moe_ep_size: Option<usize>, moe_ep_size: Option<usize>,
attention_dp_size: Option<usize>, attention_dp_size: Option<usize>,
) -> PyResult<Arc<dyn AicCallback>> { ) -> PyResult<Arc<dyn AicCallback>> {
let module = py.import("dynamo.mocker.aic_session")?; let module = py.import("dynamo._internal.aic")?;
let session = module.call_method1( let session = module.call_method1(
"create_session", "create_session",
( (
...@@ -79,3 +104,21 @@ pub(super) fn create_aic_callback( ...@@ -79,3 +104,21 @@ pub(super) fn create_aic_callback(
session: session.into(), session: session.into(),
})) }))
} }
pub(super) fn create_aic_prefill_load_estimator(
py: Python<'_>,
backend_name: &str,
system: &str,
model_path: &str,
tp_size: usize,
backend_version: Option<&str>,
) -> PyResult<Arc<dyn PrefillLoadEstimator>> {
let module = py.import("dynamo._internal.aic")?;
let session = module.call_method1(
"create_session",
(backend_name, system, model_path, tp_size, backend_version),
)?;
Ok(Arc::new(PyAicCallback {
session: session.into(),
}))
}
...@@ -10,7 +10,9 @@ use std::sync::Arc; ...@@ -10,7 +10,9 @@ use std::sync::Arc;
use pyo3::{exceptions::PyException, exceptions::PyValueError, prelude::*}; use pyo3::{exceptions::PyException, exceptions::PyValueError, prelude::*};
use pyo3_async_runtimes::TaskLocals; use pyo3_async_runtimes::TaskLocals;
use dynamo_kv_router::config::KvRouterConfig as RsKvRouterConfig; use dynamo_kv_router::config::{
KvRouterConfig as RsKvRouterConfig, RouterPrefillLoadModel as RsRouterPrefillLoadModel,
};
use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig; use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig;
use dynamo_llm::entrypoint::ChatEngineFactoryCallback; use dynamo_llm::entrypoint::ChatEngineFactoryCallback;
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig; use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
...@@ -23,7 +25,7 @@ use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard; ...@@ -23,7 +25,7 @@ use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard;
use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine; use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
use dynamo_mocker::common::perf_model::PerfModel; use dynamo_mocker::common::perf_model::PerfModel;
use super::aic_callback::create_aic_callback; use super::aic_callback::{create_aic_callback, create_aic_prefill_load_estimator};
use super::replay::MockEngineArgs as PyMockEngineArgs; use super::replay::MockEngineArgs as PyMockEngineArgs;
use dynamo_mocker::common::protocols::MockEngineArgs as RsMockEngineArgs; use dynamo_mocker::common::protocols::MockEngineArgs as RsMockEngineArgs;
use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId; use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId;
...@@ -55,10 +57,76 @@ impl KvRouterConfig { ...@@ -55,10 +57,76 @@ impl KvRouterConfig {
} }
} }
#[pyclass]
#[derive(Clone, Debug)]
pub struct AicPerfConfig {
aic_backend: String,
aic_system: String,
aic_backend_version: Option<String>,
aic_tp_size: usize,
aic_model_path: String,
}
impl AicPerfConfig {
pub(crate) fn backend_name(&self) -> &str {
&self.aic_backend
}
pub(crate) fn system(&self) -> &str {
&self.aic_system
}
pub(crate) fn backend_version(&self) -> Option<&str> {
self.aic_backend_version.as_deref()
}
pub(crate) fn tp_size(&self) -> usize {
self.aic_tp_size
}
pub(crate) fn model_path(&self) -> &str {
&self.aic_model_path
}
}
#[pymethods]
impl AicPerfConfig {
#[new]
#[pyo3(signature = (aic_backend, aic_system, aic_model_path, aic_tp_size=1, aic_backend_version=None))]
fn new(
aic_backend: String,
aic_system: String,
aic_model_path: String,
aic_tp_size: usize,
aic_backend_version: Option<String>,
) -> PyResult<Self> {
if aic_backend.is_empty() {
return Err(PyValueError::new_err("aic_backend must be non-empty"));
}
if aic_system.is_empty() {
return Err(PyValueError::new_err("aic_system must be non-empty"));
}
if aic_model_path.is_empty() {
return Err(PyValueError::new_err("aic_model_path must be non-empty"));
}
if aic_tp_size == 0 {
return Err(PyValueError::new_err("aic_tp_size must be >= 1"));
}
Ok(Self {
aic_backend,
aic_system,
aic_backend_version,
aic_tp_size,
aic_model_path,
})
}
}
#[pymethods] #[pymethods]
impl KvRouterConfig { impl KvRouterConfig {
#[new] #[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_queue_policy="fcfs", remote_indexer_component=None))] #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_prefill_load_model="none", router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_queue_policy="fcfs", remote_indexer_component=None))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn new( fn new(
overlap_score_weight: f64, overlap_score_weight: f64,
...@@ -70,6 +138,7 @@ impl KvRouterConfig { ...@@ -70,6 +138,7 @@ impl KvRouterConfig {
router_track_output_blocks: bool, router_track_output_blocks: bool,
router_assume_kv_reuse: bool, router_assume_kv_reuse: bool,
router_track_prefill_tokens: bool, router_track_prefill_tokens: bool,
router_prefill_load_model: &str,
router_snapshot_threshold: Option<u32>, router_snapshot_threshold: Option<u32>,
router_reset_states: bool, router_reset_states: bool,
router_ttl_secs: f64, router_ttl_secs: f64,
...@@ -91,6 +160,11 @@ impl KvRouterConfig { ...@@ -91,6 +160,11 @@ impl KvRouterConfig {
router_track_output_blocks, router_track_output_blocks,
router_assume_kv_reuse, router_assume_kv_reuse,
router_track_prefill_tokens, router_track_prefill_tokens,
router_prefill_load_model: router_prefill_load_model
.parse::<RsRouterPrefillLoadModel>()
.unwrap_or_else(|_| {
panic!("invalid router_prefill_load_model: {router_prefill_load_model:?}")
}),
router_snapshot_threshold, router_snapshot_threshold,
router_reset_states, router_reset_states,
router_ttl_secs, router_ttl_secs,
...@@ -249,13 +323,14 @@ pub(crate) struct EntrypointArgs { ...@@ -249,13 +323,14 @@ pub(crate) struct EntrypointArgs {
is_prefill: bool, is_prefill: bool,
migration_limit: u32, migration_limit: u32,
chat_engine_factory: Option<PyEngineFactory>, chat_engine_factory: Option<PyEngineFactory>,
aic_perf_config: Option<AicPerfConfig>,
} }
#[pymethods] #[pymethods]
impl EntrypointArgs { impl EntrypointArgs {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[new] #[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))] #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None, aic_perf_config=None))]
pub fn new( pub fn new(
py: Python<'_>, py: Python<'_>,
engine_type: EngineType, engine_type: EngineType,
...@@ -279,6 +354,7 @@ impl EntrypointArgs { ...@@ -279,6 +354,7 @@ impl EntrypointArgs {
is_prefill: bool, is_prefill: bool,
migration_limit: u32, migration_limit: u32,
chat_engine_factory: Option<PyObject>, chat_engine_factory: Option<PyObject>,
aic_perf_config: Option<AicPerfConfig>,
) -> PyResult<Self> { ) -> PyResult<Self> {
let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from); let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
if (tls_cert_path.is_some() && tls_key_path.is_none()) if (tls_cert_path.is_some() && tls_key_path.is_none())
...@@ -327,6 +403,7 @@ impl EntrypointArgs { ...@@ -327,6 +403,7 @@ impl EntrypointArgs {
is_prefill, is_prefill,
migration_limit, migration_limit,
chat_engine_factory, chat_engine_factory,
aic_perf_config,
}) })
} }
} }
...@@ -467,9 +544,26 @@ async fn select_engine( ...@@ -467,9 +544,26 @@ async fn select_engine(
EngineType::Dynamic => { EngineType::Dynamic => {
// Convert Python chat engine factory to Rust callback // Convert Python chat engine factory to Rust callback
let chat_engine_factory = args.chat_engine_factory.map(py_engine_factory_to_callback); let chat_engine_factory = args.chat_engine_factory.map(py_engine_factory_to_callback);
let prefill_load_estimator = args
.aic_perf_config
.as_ref()
.map(|config| {
Python::with_gil(|py| {
create_aic_prefill_load_estimator(
py,
config.backend_name(),
config.system(),
config.model_path(),
config.tp_size(),
config.backend_version(),
)
})
})
.transpose()?;
RsEngineConfig::Dynamic { RsEngineConfig::Dynamic {
model: Box::new(local_model), model: Box::new(local_model),
chat_engine_factory, chat_engine_factory,
prefill_load_estimator,
} }
} }
EngineType::Mocker => { EngineType::Mocker => {
......
...@@ -30,6 +30,9 @@ use llm_rs::protocols::common::timing::RequestTracker; ...@@ -30,6 +30,9 @@ use llm_rs::protocols::common::timing::RequestTracker;
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions}; use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use serde_json::json; use serde_json::json;
use super::aic_callback::create_aic_prefill_load_estimator;
use super::entrypoint::AicPerfConfig;
fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<BlockExtraInfo>>> { fn depythonize_block_mm_infos(obj: &Bound<'_, PyAny>) -> PyResult<Vec<Option<BlockExtraInfo>>> {
depythonize(obj).map_err(to_pyerr) depythonize(obj).map_err(to_pyerr)
} }
...@@ -703,6 +706,7 @@ async fn create_kv_router_from_endpoint( ...@@ -703,6 +706,7 @@ async fn create_kv_router_from_endpoint(
endpoint: &Endpoint, endpoint: &Endpoint,
block_size: usize, block_size: usize,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<Arc<dyn dynamo_kv_router::PrefillLoadEstimator>>,
) -> Result<Arc<llm_rs::kv_router::KvRouter>, PyErr> { ) -> Result<Arc<llm_rs::kv_router::KvRouter>, PyErr> {
// Create ModelManager and use it to create KvRouter (ensures registration) // Create ModelManager and use it to create KvRouter (ensures registration)
let model_manager = Arc::new(llm_rs::discovery::ModelManager::new()); let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
...@@ -766,6 +770,7 @@ async fn create_kv_router_from_endpoint( ...@@ -766,6 +770,7 @@ async fn create_kv_router_from_endpoint(
&endpoint.inner, &endpoint.inner,
block_size as u32, block_size as u32,
kv_router_config, kv_router_config,
prefill_load_estimator,
worker_type, worker_type,
model_name, model_name,
enable_eagle, enable_eagle,
...@@ -888,12 +893,29 @@ impl KvRouter { ...@@ -888,12 +893,29 @@ impl KvRouter {
/// Note: Worker type for Prometheus metrics is inferred from the endpoint name/component /// Note: Worker type for Prometheus metrics is inferred from the endpoint name/component
/// (contains "prefill") or by `router_track_active_blocks` being disabled. /// (contains "prefill") or by `router_track_active_blocks` being disabled.
#[new] #[new]
#[pyo3(signature = (endpoint, block_size, kv_router_config))] #[pyo3(signature = (endpoint, block_size, kv_router_config, aic_perf_config=None))]
fn new( fn new(
endpoint: &Endpoint, endpoint: &Endpoint,
block_size: usize, block_size: usize,
kv_router_config: &super::entrypoint::KvRouterConfig, kv_router_config: &super::entrypoint::KvRouterConfig,
aic_perf_config: Option<&AicPerfConfig>,
) -> PyResult<Self> { ) -> PyResult<Self> {
let prefill_load_estimator = aic_perf_config
.map(|config| {
Python::with_gil(|py| {
create_aic_prefill_load_estimator(
py,
config.backend_name(),
config.system(),
config.model_path(),
config.tp_size(),
config.backend_version(),
)
})
})
.transpose()
.map_err(to_pyerr)?;
let runtime = pyo3_async_runtimes::tokio::get_runtime(); let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async move { runtime.block_on(async move {
let client = endpoint.inner.client().await.map_err(to_pyerr)?; let client = endpoint.inner.client().await.map_err(to_pyerr)?;
...@@ -916,6 +938,7 @@ impl KvRouter { ...@@ -916,6 +938,7 @@ impl KvRouter {
endpoint, endpoint,
block_size, block_size,
Some(kv_router_config.inner()), Some(kv_router_config.inner()),
prefill_load_estimator,
) )
.await?; .await?;
......
...@@ -19,8 +19,8 @@ use pythonize::pythonize; ...@@ -19,8 +19,8 @@ use pythonize::pythonize;
use serde_json::json; use serde_json::json;
use uuid::Uuid; use uuid::Uuid;
use super::aic_callback::create_aic_callback; use super::aic_callback::{create_aic_callback, create_aic_prefill_load_estimator};
use super::entrypoint::{KvRouterConfig, to_pyerr}; use super::entrypoint::{AicPerfConfig, KvRouterConfig, to_pyerr};
fn parse_mocker_engine_type(engine_type: &str) -> PyResult<RsMockerEngineType> { fn parse_mocker_engine_type(engine_type: &str) -> PyResult<RsMockerEngineType> {
match engine_type { match engine_type {
...@@ -526,7 +526,7 @@ impl MockEngineArgs { ...@@ -526,7 +526,7 @@ impl MockEngineArgs {
} }
#[pyfunction] #[pyfunction]
#[pyo3(signature = (trace_file, extra_engine_args=None, prefill_engine_args=None, decode_engine_args=None, router_config=None, num_workers=1, num_prefill_workers=1, num_decode_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0))] #[pyo3(signature = (trace_file, extra_engine_args=None, prefill_engine_args=None, decode_engine_args=None, router_config=None, aic_perf_config=None, num_workers=1, num_prefill_workers=1, num_decode_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0, trace_block_size=512))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn run_mocker_trace_replay( pub fn run_mocker_trace_replay(
py: Python<'_>, py: Python<'_>,
...@@ -535,6 +535,7 @@ pub fn run_mocker_trace_replay( ...@@ -535,6 +535,7 @@ pub fn run_mocker_trace_replay(
prefill_engine_args: Option<MockEngineArgs>, prefill_engine_args: Option<MockEngineArgs>,
decode_engine_args: Option<MockEngineArgs>, decode_engine_args: Option<MockEngineArgs>,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
aic_perf_config: Option<&AicPerfConfig>,
num_workers: usize, num_workers: usize,
num_prefill_workers: usize, num_prefill_workers: usize,
num_decode_workers: usize, num_decode_workers: usize,
...@@ -542,6 +543,7 @@ pub fn run_mocker_trace_replay( ...@@ -542,6 +543,7 @@ pub fn run_mocker_trace_replay(
replay_mode: &str, replay_mode: &str,
router_mode: &str, router_mode: &str,
arrival_speedup_ratio: f64, arrival_speedup_ratio: f64,
trace_block_size: usize,
) -> PyResult<PyObject> { ) -> PyResult<PyObject> {
let args_selection = load_replay_args_selection( let args_selection = load_replay_args_selection(
py, py,
...@@ -552,9 +554,15 @@ pub fn run_mocker_trace_replay( ...@@ -552,9 +554,15 @@ pub fn run_mocker_trace_replay(
num_prefill_workers, num_prefill_workers,
num_decode_workers, num_decode_workers,
)?; )?;
let router_mode = parse_replay_router_mode(router_mode)?;
let prefill_load_estimator = load_replay_prefill_load_estimator(
py,
router_mode,
router_config.as_ref(),
aic_perf_config,
)?;
let router_config = load_replay_router_config(router_config); let router_config = load_replay_router_config(router_config);
let replay_mode = replay_mode.to_owned(); let replay_mode = replay_mode.to_owned();
let router_mode = parse_replay_router_mode(router_mode)?;
let report = py.allow_threads(move || { let report = py.allow_threads(move || {
let replay_concurrency = parse_replay_concurrency(replay_concurrency)?; let replay_concurrency = parse_replay_concurrency(replay_concurrency)?;
...@@ -565,7 +573,9 @@ pub fn run_mocker_trace_replay( ...@@ -565,7 +573,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker::replay::simulate_concurrency_file_with_router_mode( dynamo_mocker::replay::simulate_concurrency_file_with_router_mode(
*args, *args,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
&trace_file, &trace_file,
trace_block_size,
max_in_flight, max_in_flight,
num_workers, num_workers,
router_mode, router_mode,
...@@ -575,7 +585,9 @@ pub fn run_mocker_trace_replay( ...@@ -575,7 +585,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker::replay::simulate_trace_file_with_router_mode( dynamo_mocker::replay::simulate_trace_file_with_router_mode(
*args, *args,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
&trace_file, &trace_file,
trace_block_size,
num_workers, num_workers,
arrival_speedup_ratio, arrival_speedup_ratio,
router_mode, router_mode,
...@@ -585,7 +597,9 @@ pub fn run_mocker_trace_replay( ...@@ -585,7 +597,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker::replay::simulate_concurrency_live_file_with_router_mode( dynamo_mocker::replay::simulate_concurrency_live_file_with_router_mode(
*args, *args,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
&trace_file, &trace_file,
trace_block_size,
max_in_flight, max_in_flight,
num_workers, num_workers,
router_mode, router_mode,
...@@ -595,7 +609,9 @@ pub fn run_mocker_trace_replay( ...@@ -595,7 +609,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker::replay::simulate_trace_live_file_with_router_mode( dynamo_mocker::replay::simulate_trace_live_file_with_router_mode(
*args, *args,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
&trace_file, &trace_file,
trace_block_size,
num_workers, num_workers,
arrival_speedup_ratio, arrival_speedup_ratio,
router_mode, router_mode,
...@@ -613,7 +629,9 @@ pub fn run_mocker_trace_replay( ...@@ -613,7 +629,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker::replay::simulate_concurrency_file_disagg_with_router_mode( dynamo_mocker::replay::simulate_concurrency_file_disagg_with_router_mode(
*config, *config,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
&trace_file, &trace_file,
trace_block_size,
max_in_flight, max_in_flight,
router_mode, router_mode,
) )
...@@ -622,7 +640,9 @@ pub fn run_mocker_trace_replay( ...@@ -622,7 +640,9 @@ pub fn run_mocker_trace_replay(
dynamo_mocker::replay::simulate_trace_file_disagg_with_router_mode( dynamo_mocker::replay::simulate_trace_file_disagg_with_router_mode(
*config, *config,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
&trace_file, &trace_file,
trace_block_size,
arrival_speedup_ratio, arrival_speedup_ratio,
router_mode, router_mode,
) )
...@@ -642,7 +662,7 @@ pub fn run_mocker_trace_replay( ...@@ -642,7 +662,7 @@ pub fn run_mocker_trace_replay(
} }
#[pyfunction] #[pyfunction]
#[pyo3(signature = (input_tokens, output_tokens, request_count, extra_engine_args=None, prefill_engine_args=None, decode_engine_args=None, router_config=None, num_workers=1, num_prefill_workers=1, num_decode_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0, arrival_interval_ms=1.0, turns_per_session=1, shared_prefix_ratio=0.0, num_prefix_groups=0, inter_turn_delay_ms=0.0))] #[pyo3(signature = (input_tokens, output_tokens, request_count, extra_engine_args=None, prefill_engine_args=None, decode_engine_args=None, router_config=None, aic_perf_config=None, num_workers=1, num_prefill_workers=1, num_decode_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0, arrival_interval_ms=1.0, turns_per_session=1, shared_prefix_ratio=0.0, num_prefix_groups=0, inter_turn_delay_ms=0.0))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn run_mocker_synthetic_trace_replay( pub fn run_mocker_synthetic_trace_replay(
py: Python<'_>, py: Python<'_>,
...@@ -653,6 +673,7 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -653,6 +673,7 @@ pub fn run_mocker_synthetic_trace_replay(
prefill_engine_args: Option<MockEngineArgs>, prefill_engine_args: Option<MockEngineArgs>,
decode_engine_args: Option<MockEngineArgs>, decode_engine_args: Option<MockEngineArgs>,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
aic_perf_config: Option<&AicPerfConfig>,
num_workers: usize, num_workers: usize,
num_prefill_workers: usize, num_prefill_workers: usize,
num_decode_workers: usize, num_decode_workers: usize,
...@@ -675,9 +696,15 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -675,9 +696,15 @@ pub fn run_mocker_synthetic_trace_replay(
num_prefill_workers, num_prefill_workers,
num_decode_workers, num_decode_workers,
)?; )?;
let router_mode = parse_replay_router_mode(router_mode)?;
let prefill_load_estimator = load_replay_prefill_load_estimator(
py,
router_mode,
router_config.as_ref(),
aic_perf_config,
)?;
let router_config = load_replay_router_config(router_config); let router_config = load_replay_router_config(router_config);
let replay_mode = replay_mode.to_owned(); let replay_mode = replay_mode.to_owned();
let router_mode = parse_replay_router_mode(router_mode)?;
let block_size = match &args_selection { let block_size = match &args_selection {
ReplayArgsSelection::Aggregated(args) => args.block_size.max(1), ReplayArgsSelection::Aggregated(args) => args.block_size.max(1),
ReplayArgsSelection::Disagg(config) => config.prefill_args.block_size.max(1), ReplayArgsSelection::Disagg(config) => config.prefill_args.block_size.max(1),
...@@ -712,6 +739,7 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -712,6 +739,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_concurrency_workload_with_router_mode( dynamo_mocker::replay::simulate_concurrency_workload_with_router_mode(
*args, *args,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
trace, trace,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -722,6 +750,7 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -722,6 +750,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_trace_workload_with_router_mode( dynamo_mocker::replay::simulate_trace_workload_with_router_mode(
*args, *args,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
trace, trace,
num_workers, num_workers,
router_mode, router_mode,
...@@ -731,6 +760,7 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -731,6 +760,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_concurrency_live_workload_with_router_mode( dynamo_mocker::replay::simulate_concurrency_live_workload_with_router_mode(
*args, *args,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
trace, trace,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -741,6 +771,7 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -741,6 +771,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_trace_live_workload_with_router_mode( dynamo_mocker::replay::simulate_trace_live_workload_with_router_mode(
*args, *args,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
trace, trace,
num_workers, num_workers,
router_mode, router_mode,
...@@ -756,6 +787,7 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -756,6 +787,7 @@ pub fn run_mocker_synthetic_trace_replay(
("offline", Some(max_in_flight)) => dynamo_mocker::replay::simulate_concurrency_workload_disagg_with_router_mode( ("offline", Some(max_in_flight)) => dynamo_mocker::replay::simulate_concurrency_workload_disagg_with_router_mode(
*config, *config,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
trace, trace,
max_in_flight, max_in_flight,
router_mode, router_mode,
...@@ -763,6 +795,7 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -763,6 +795,7 @@ pub fn run_mocker_synthetic_trace_replay(
("offline", None) => dynamo_mocker::replay::simulate_trace_workload_disagg_with_router_mode( ("offline", None) => dynamo_mocker::replay::simulate_trace_workload_disagg_with_router_mode(
*config, *config,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
trace, trace,
router_mode, router_mode,
), ),
...@@ -793,6 +826,7 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -793,6 +826,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_concurrency_requests_with_router_mode( dynamo_mocker::replay::simulate_concurrency_requests_with_router_mode(
*args, *args,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
requests, requests,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -802,6 +836,7 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -802,6 +836,7 @@ pub fn run_mocker_synthetic_trace_replay(
("offline", None) => dynamo_mocker::replay::simulate_trace_requests_with_router_mode( ("offline", None) => dynamo_mocker::replay::simulate_trace_requests_with_router_mode(
*args, *args,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
requests, requests,
num_workers, num_workers,
arrival_speedup_ratio, arrival_speedup_ratio,
...@@ -811,6 +846,7 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -811,6 +846,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_concurrency_live_requests_with_router_mode( dynamo_mocker::replay::simulate_concurrency_live_requests_with_router_mode(
*args, *args,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
requests, requests,
max_in_flight, max_in_flight,
num_workers, num_workers,
...@@ -821,6 +857,7 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -821,6 +857,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_trace_live_requests_with_router_mode( dynamo_mocker::replay::simulate_trace_live_requests_with_router_mode(
*args, *args,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
requests, requests,
num_workers, num_workers,
arrival_speedup_ratio, arrival_speedup_ratio,
...@@ -838,6 +875,7 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -838,6 +875,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_concurrency_requests_disagg_with_router_mode( dynamo_mocker::replay::simulate_concurrency_requests_disagg_with_router_mode(
*config, *config,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
requests, requests,
max_in_flight, max_in_flight,
router_mode, router_mode,
...@@ -847,6 +885,7 @@ pub fn run_mocker_synthetic_trace_replay( ...@@ -847,6 +885,7 @@ pub fn run_mocker_synthetic_trace_replay(
dynamo_mocker::replay::simulate_trace_requests_disagg_with_router_mode( dynamo_mocker::replay::simulate_trace_requests_disagg_with_router_mode(
*config, *config,
router_config.clone(), router_config.clone(),
prefill_load_estimator.clone(),
requests, requests,
arrival_speedup_ratio, arrival_speedup_ratio,
router_mode, router_mode,
...@@ -970,6 +1009,57 @@ fn load_replay_router_config( ...@@ -970,6 +1009,57 @@ fn load_replay_router_config(
router_config.map(|config| config.inner()) router_config.map(|config| config.inner())
} }
fn load_replay_prefill_load_estimator(
py: Python<'_>,
router_mode: dynamo_mocker::replay::ReplayRouterMode,
router_config: Option<&KvRouterConfig>,
aic_perf_config: Option<&AicPerfConfig>,
) -> PyResult<Option<dynamo_mocker::replay::ReplayPrefillLoadEstimator>> {
if router_mode != dynamo_mocker::replay::ReplayRouterMode::KvRouter {
if aic_perf_config.is_some() {
return Err(PyException::new_err(
"aic_perf_config requires router_mode='kv_router'",
));
}
return Ok(None);
}
let Some(router_config) = router_config else {
if aic_perf_config.is_some() {
return Err(PyException::new_err(
"aic_perf_config requires router_config with router_prefill_load_model='aic'",
));
}
return Ok(None);
};
let router_config = router_config.inner();
if !router_config.router_prefill_load_model.is_enabled() {
if aic_perf_config.is_some() {
return Err(PyException::new_err(
"aic_perf_config requires router_prefill_load_model='aic'",
));
}
return Ok(None);
}
let Some(aic_perf_config) = aic_perf_config else {
return Err(PyException::new_err(
"router_prefill_load_model='aic' requires aic_perf_config",
));
};
create_aic_prefill_load_estimator(
py,
aic_perf_config.backend_name(),
aic_perf_config.system(),
aic_perf_config.model_path(),
aic_perf_config.tp_size(),
aic_perf_config.backend_version(),
)
.map(Some)
}
fn parse_replay_router_mode( fn parse_replay_router_mode(
router_mode: &str, router_mode: &str,
) -> PyResult<dynamo_mocker::replay::ReplayRouterMode> { ) -> PyResult<dynamo_mocker::replay::ReplayRouterMode> {
......
...@@ -1159,6 +1159,17 @@ class RouterConfig: ...@@ -1159,6 +1159,17 @@ class RouterConfig:
""" """
... ...
class AicPerfConfig:
def __init__(
self,
aic_backend: str,
aic_system: str,
aic_model_path: str,
aic_tp_size: int = 1,
aic_backend_version: Optional[str] = None,
) -> None:
...
class KvRouterConfig: class KvRouterConfig:
"""Values for KV router""" """Values for KV router"""
...@@ -1172,6 +1183,8 @@ class KvRouterConfig: ...@@ -1172,6 +1183,8 @@ class KvRouterConfig:
router_track_active_blocks: bool = True, router_track_active_blocks: bool = True,
router_track_output_blocks: bool = False, router_track_output_blocks: bool = False,
router_assume_kv_reuse: bool = True, router_assume_kv_reuse: bool = True,
router_track_prefill_tokens: bool = True,
router_prefill_load_model: str = "none",
router_snapshot_threshold: Optional[int] = 1000000, router_snapshot_threshold: Optional[int] = 1000000,
router_reset_states: bool = False, router_reset_states: bool = False,
router_ttl_secs: float = 120.0, router_ttl_secs: float = 120.0,
...@@ -1199,6 +1212,10 @@ class KvRouterConfig: ...@@ -1199,6 +1212,10 @@ class KvRouterConfig:
sequence length (agent_hints.osl in nvext). sequence length (agent_hints.osl in nvext).
router_assume_kv_reuse: Assume KV cache reuse when tracking active blocks (default: True). router_assume_kv_reuse: Assume KV cache reuse when tracking active blocks (default: True).
When True, computes actual block hashes. When False, generates random hashes. When True, computes actual block hashes. When False, generates random hashes.
router_track_prefill_tokens: Include prompt-side prefill tokens in active load accounting (default: True).
router_prefill_load_model: Prompt-side prefill load model (default: "none").
"none" keeps static prompt load accounting.
"aic" decays the oldest active prefill request using AIC-predicted duration.
router_snapshot_threshold: Number of messages before snapshot (default: 1000000) router_snapshot_threshold: Number of messages before snapshot (default: 1000000)
router_reset_states: Reset router state on startup (default: False) router_reset_states: Reset router state on startup (default: False)
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0) router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
...@@ -1516,6 +1533,7 @@ def run_mocker_trace_replay( ...@@ -1516,6 +1533,7 @@ def run_mocker_trace_replay(
prefill_engine_args: Optional[MockEngineArgs] = None, prefill_engine_args: Optional[MockEngineArgs] = None,
decode_engine_args: Optional[MockEngineArgs] = None, decode_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None, router_config: Optional[KvRouterConfig] = None,
aic_perf_config: Optional[AicPerfConfig] = None,
num_workers: int = 1, num_workers: int = 1,
num_prefill_workers: int = 1, num_prefill_workers: int = 1,
num_decode_workers: int = 1, num_decode_workers: int = 1,
...@@ -1523,6 +1541,7 @@ def run_mocker_trace_replay( ...@@ -1523,6 +1541,7 @@ def run_mocker_trace_replay(
replay_mode: Literal["offline", "online"] = "offline", replay_mode: Literal["offline", "online"] = "offline",
router_mode: Literal["round_robin", "kv_router"] = "round_robin", router_mode: Literal["round_robin", "kv_router"] = "round_robin",
arrival_speedup_ratio: float = 1.0, arrival_speedup_ratio: float = 1.0,
trace_block_size: int = 512,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Replay a mocker trace file and return the simulation report for aggregated vLLM or SGLang configs.""" """Replay a mocker trace file and return the simulation report for aggregated vLLM or SGLang configs."""
... ...
...@@ -1535,6 +1554,7 @@ def run_mocker_synthetic_trace_replay( ...@@ -1535,6 +1554,7 @@ def run_mocker_synthetic_trace_replay(
prefill_engine_args: Optional[MockEngineArgs] = None, prefill_engine_args: Optional[MockEngineArgs] = None,
decode_engine_args: Optional[MockEngineArgs] = None, decode_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None, router_config: Optional[KvRouterConfig] = None,
aic_perf_config: Optional[AicPerfConfig] = None,
num_workers: int = 1, num_workers: int = 1,
num_prefill_workers: int = 1, num_prefill_workers: int = 1,
num_decode_workers: int = 1, num_decode_workers: int = 1,
...@@ -1779,6 +1799,7 @@ class KvRouter: ...@@ -1779,6 +1799,7 @@ class KvRouter:
endpoint: Endpoint, endpoint: Endpoint,
block_size: int, block_size: int,
kv_router_config: KvRouterConfig, kv_router_config: KvRouterConfig,
aic_perf_config: Optional[AicPerfConfig] = None,
) -> None: ) -> None:
""" """
Create a new KvRouter instance. Create a new KvRouter instance.
...@@ -1787,6 +1808,7 @@ class KvRouter: ...@@ -1787,6 +1808,7 @@ class KvRouter:
endpoint: The endpoint to connect to for routing requests endpoint: The endpoint to connect to for routing requests
block_size: The KV cache block size block_size: The KV cache block size
kv_router_config: Configuration for the KV router kv_router_config: Configuration for the KV router
aic_perf_config: Optional AIC perf-model config for effective prefill load tracking
""" """
... ...
...@@ -1998,6 +2020,7 @@ class EntrypointArgs: ...@@ -1998,6 +2020,7 @@ class EntrypointArgs:
is_prefill: bool = False, is_prefill: bool = False,
migration_limit: int = 0, migration_limit: int = 0,
chat_engine_factory: Optional[Callable] = None, chat_engine_factory: Optional[Callable] = None,
aic_perf_config: Optional[AicPerfConfig] = None,
) -> None: ) -> None:
""" """
Create EntrypointArgs. Create EntrypointArgs.
...@@ -2024,6 +2047,7 @@ class EntrypointArgs: ...@@ -2024,6 +2047,7 @@ class EntrypointArgs:
is_prefill: Whether this is a prefill worker is_prefill: Whether this is a prefill worker
migration_limit: Maximum number of request migrations (0=disabled) migration_limit: Maximum number of request migrations (0=disabled)
chat_engine_factory: Optional Python chat completions engine factory callback chat_engine_factory: Optional Python chat completions engine factory callback
aic_perf_config: Optional AIC perf-model configuration for default KV routing
""" """
... ...
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared AIC session helpers used by internal Dynamo integrations."""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
DEFAULT_BACKEND_VERSIONS = {
"vllm": "0.12.0",
"sglang": "0.5.6.post2",
}
DEFAULT_STATIC_STRIDE = 32
def resolve_backend_version(backend_name: str, backend_version: str | None) -> str:
"""Return the pinned backend version used for AIC perf lookups."""
if backend_version is not None:
return backend_version
return DEFAULT_BACKEND_VERSIONS.get(backend_name, DEFAULT_BACKEND_VERSIONS["vllm"])
def _load_aiconfigurator():
try:
from aiconfigurator.sdk import config
from aiconfigurator.sdk.backends.factory import get_backend
from aiconfigurator.sdk.inference_session import InferenceSession
from aiconfigurator.sdk.models import get_model
from aiconfigurator.sdk.perf_database import (
get_database,
get_supported_databases,
)
except (
ImportError
) as exc: # pragma: no cover - exercised in integration environments
raise RuntimeError(
"aiconfigurator is required for AIC perf modeling but is not installed"
) from exc
return {
"config": config,
"get_backend": get_backend,
"InferenceSession": InferenceSession,
"get_model": get_model,
"get_database": get_database,
"get_supported_databases": get_supported_databases,
}
class AicSession:
"""Wrap an AIC InferenceSession with direct prefill/decode predictors."""
def __init__(
self,
backend_name: str,
system: str,
model_path: str,
tp_size: int,
backend_version: str | None = None,
moe_tp_size: int | None = None,
moe_ep_size: int | None = None,
attention_dp_size: int | None = None,
):
aic = _load_aiconfigurator()
version = resolve_backend_version(backend_name, backend_version)
database = aic["get_database"](
system=system, backend=backend_name, version=version
)
if database is None:
supported = (
aic["get_supported_databases"]().get(system, {}).get(backend_name, [])
)
supported_versions = ", ".join(supported) if supported else "<none>"
raise RuntimeError(
"AIC perf database not found for "
f"system={system!r}, backend={backend_name!r}, version={version!r}. "
f"Supported versions for this system/backend: {supported_versions}"
)
model_config = aic["config"].ModelConfig(
tp_size=tp_size,
moe_tp_size=moe_tp_size,
moe_ep_size=moe_ep_size,
attention_dp_size=attention_dp_size or 1,
)
model = aic["get_model"](
model_path=model_path,
model_config=model_config,
backend_name=backend_name,
)
backend = aic["get_backend"](backend_name)
self._session = aic["InferenceSession"](
model=model, database=database, backend=backend
)
self._database = database
self._model = model
self._model_name = getattr(model, "model_name", None) or model_path
logger.info(
"AIC session initialized: backend=%s, system=%s, model=%s, tp=%d",
backend_name,
system,
model_path,
tp_size,
)
def _predict_context_latency(
self, batch_size: int, effective_isl: int, prefix: int
) -> float:
if effective_isl <= 0:
raise ValueError(
f"effective_isl must be positive, got effective_isl={effective_isl}"
)
total_latency = 0.0
for op in self._model.context_ops:
op_name = getattr(op, "_name", "")
x = batch_size if "logits_gemm" in op_name else batch_size * effective_isl
result = op.query(
self._database,
x=x,
batch_size=batch_size,
beam_width=1,
s=effective_isl,
prefix=prefix,
model_name=self._model_name,
seq_imbalance_correction_scale=1.0,
)
total_latency += float(result)
return total_latency
def _predict_generation_latency(self, batch_size: int, isl: int, osl: int) -> float:
if osl <= 1:
return 0.0
effective_batch_size = batch_size * (self._model._nextn + 1)
total_latency = 0.0
for step in range(0, osl - 1, DEFAULT_STATIC_STRIDE):
step_latency = 0.0
for op in self._model.generation_ops:
result = op.query(
self._database,
x=effective_batch_size,
batch_size=effective_batch_size,
beam_width=1,
s=isl + step + 1,
model_name=self._model_name,
gen_seq_imbalance_correction_scale=1.0,
)
step_latency += float(result)
repeat_count = min(DEFAULT_STATIC_STRIDE, osl - 1 - step)
total_latency += step_latency * repeat_count
return total_latency
def predict_prefill(
self, batch_size: int, effective_isl: int, prefix: int
) -> float:
"""Predict prefill latency in ms from uncached tokens and cached prefix."""
return self._predict_context_latency(batch_size, effective_isl, prefix)
def predict_decode(self, batch_size: int, isl: int, osl: int) -> float:
"""Predict decode (generation) latency in ms."""
return self._predict_generation_latency(batch_size, isl, osl)
def create_session(
backend_name: str,
system: str,
model_path: str,
tp_size: int,
backend_version: str | None = None,
moe_tp_size: int | None = None,
moe_ep_size: int | None = None,
attention_dp_size: int | None = None,
) -> AicSession:
"""Factory function called from Rust via PyO3."""
return AicSession(
backend_name,
system,
model_path,
tp_size,
backend_version,
moe_tp_size,
moe_ep_size,
attention_dp_size,
)
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import logging import logging
from dynamo._core import AicPerfConfig as AicPerfConfig
from dynamo._core import EngineType from dynamo._core import EngineType
from dynamo._core import EntrypointArgs as EntrypointArgs from dynamo._core import EntrypointArgs as EntrypointArgs
from dynamo._core import FpmEventRelay as FpmEventRelay from dynamo._core import FpmEventRelay as FpmEventRelay
...@@ -57,6 +58,7 @@ def run_mocker_trace_replay( ...@@ -57,6 +58,7 @@ def run_mocker_trace_replay(
replay_concurrency=None, replay_concurrency=None,
router_mode="round_robin", router_mode="round_robin",
arrival_speedup_ratio=1.0, arrival_speedup_ratio=1.0,
trace_block_size=512,
): ):
return _run_mocker_trace_replay( return _run_mocker_trace_replay(
trace_file, trace_file,
...@@ -67,4 +69,5 @@ def run_mocker_trace_replay( ...@@ -67,4 +69,5 @@ def run_mocker_trace_replay(
replay_mode="offline", replay_mode="offline",
router_mode=router_mode, router_mode=router_mode,
arrival_speedup_ratio=arrival_speedup_ratio, arrival_speedup_ratio=arrival_speedup_ratio,
trace_block_size=trace_block_size,
) )
...@@ -14,6 +14,7 @@ def run_trace_replay( ...@@ -14,6 +14,7 @@ def run_trace_replay(
prefill_engine_args=None, prefill_engine_args=None,
decode_engine_args=None, decode_engine_args=None,
router_config=None, router_config=None,
aic_perf_config=None,
num_workers=1, num_workers=1,
num_prefill_workers=1, num_prefill_workers=1,
num_decode_workers=1, num_decode_workers=1,
...@@ -21,6 +22,7 @@ def run_trace_replay( ...@@ -21,6 +22,7 @@ def run_trace_replay(
replay_mode="offline", replay_mode="offline",
router_mode="round_robin", router_mode="round_robin",
arrival_speedup_ratio=1.0, arrival_speedup_ratio=1.0,
trace_block_size=512,
): ):
return _run_mocker_trace_replay( return _run_mocker_trace_replay(
trace_file, trace_file,
...@@ -28,6 +30,7 @@ def run_trace_replay( ...@@ -28,6 +30,7 @@ def run_trace_replay(
prefill_engine_args=prefill_engine_args, prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args, decode_engine_args=decode_engine_args,
router_config=router_config, router_config=router_config,
aic_perf_config=aic_perf_config,
num_workers=num_workers, num_workers=num_workers,
num_prefill_workers=num_prefill_workers, num_prefill_workers=num_prefill_workers,
num_decode_workers=num_decode_workers, num_decode_workers=num_decode_workers,
...@@ -35,6 +38,7 @@ def run_trace_replay( ...@@ -35,6 +38,7 @@ def run_trace_replay(
replay_mode=replay_mode, replay_mode=replay_mode,
router_mode=router_mode, router_mode=router_mode,
arrival_speedup_ratio=arrival_speedup_ratio, arrival_speedup_ratio=arrival_speedup_ratio,
trace_block_size=trace_block_size,
) )
...@@ -47,6 +51,7 @@ def run_synthetic_trace_replay( ...@@ -47,6 +51,7 @@ def run_synthetic_trace_replay(
prefill_engine_args=None, prefill_engine_args=None,
decode_engine_args=None, decode_engine_args=None,
router_config=None, router_config=None,
aic_perf_config=None,
num_workers=1, num_workers=1,
num_prefill_workers=1, num_prefill_workers=1,
num_decode_workers=1, num_decode_workers=1,
...@@ -68,6 +73,7 @@ def run_synthetic_trace_replay( ...@@ -68,6 +73,7 @@ def run_synthetic_trace_replay(
prefill_engine_args=prefill_engine_args, prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args, decode_engine_args=decode_engine_args,
router_config=router_config, router_config=router_config,
aic_perf_config=aic_perf_config,
num_workers=num_workers, num_workers=num_workers,
num_prefill_workers=num_prefill_workers, num_prefill_workers=num_prefill_workers,
num_decode_workers=num_decode_workers, num_decode_workers=num_decode_workers,
......
...@@ -15,7 +15,7 @@ from typing import Protocol ...@@ -15,7 +15,7 @@ from typing import Protocol
os.environ.setdefault("DYNAMO_SKIP_PYTHON_LOG_INIT", "1") os.environ.setdefault("DYNAMO_SKIP_PYTHON_LOG_INIT", "1")
from dynamo.llm import KvRouterConfig, MockEngineArgs from dynamo.llm import AicPerfConfig, KvRouterConfig, MockEngineArgs
from dynamo.replay import run_synthetic_trace_replay, run_trace_replay from dynamo.replay import run_synthetic_trace_replay, run_trace_replay
from dynamo.replay.reporting import format_report_table, write_report_json from dynamo.replay.reporting import format_report_table, write_report_json
...@@ -72,6 +72,35 @@ def _load_engine_args(raw_args: str | None): ...@@ -72,6 +72,35 @@ def _load_engine_args(raw_args: str | None):
return MockEngineArgs.from_json(json.dumps(raw)) return MockEngineArgs.from_json(json.dumps(raw))
def _load_aic_perf_config(args: argparse.Namespace):
values = {
"aic_backend": args.aic_backend,
"aic_system": args.aic_system,
"aic_model_path": args.aic_model_path,
"aic_backend_version": args.aic_backend_version,
"aic_tp_size": args.aic_tp_size,
}
if not any(value is not None for value in values.values()):
return None
missing = [
name
for name in ("aic_backend", "aic_system", "aic_model_path")
if values[name] is None
]
if missing:
missing_flags = ", ".join(f"--{name.replace('_', '-')}" for name in missing)
raise ValueError(f"AIC replay modeling requires {missing_flags}")
return AicPerfConfig(
aic_backend=values["aic_backend"],
aic_system=values["aic_system"],
aic_model_path=values["aic_model_path"],
aic_tp_size=values["aic_tp_size"] or 1,
aic_backend_version=values["aic_backend_version"],
)
def main(argv: Sequence[str] | None = None) -> int: def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(prog="python -m dynamo.replay") parser = argparse.ArgumentParser(prog="python -m dynamo.replay")
parser.add_argument("trace_file", nargs="?") parser.add_argument("trace_file", nargs="?")
...@@ -79,6 +108,11 @@ def main(argv: Sequence[str] | None = None) -> int: ...@@ -79,6 +108,11 @@ def main(argv: Sequence[str] | None = None) -> int:
parser.add_argument("--prefill-engine-args") parser.add_argument("--prefill-engine-args")
parser.add_argument("--decode-engine-args") parser.add_argument("--decode-engine-args")
parser.add_argument("--router-config") parser.add_argument("--router-config")
parser.add_argument("--aic-backend")
parser.add_argument("--aic-system")
parser.add_argument("--aic-backend-version")
parser.add_argument("--aic-tp-size", type=int)
parser.add_argument("--aic-model-path")
parser.add_argument("--input-tokens", type=int) parser.add_argument("--input-tokens", type=int)
parser.add_argument("--output-tokens", type=int) parser.add_argument("--output-tokens", type=int)
parser.add_argument( parser.add_argument(
...@@ -106,6 +140,12 @@ def main(argv: Sequence[str] | None = None) -> int: ...@@ -106,6 +140,12 @@ def main(argv: Sequence[str] | None = None) -> int:
default="round_robin", default="round_robin",
) )
parser.add_argument("--arrival-speedup-ratio", type=float, default=1.0) parser.add_argument("--arrival-speedup-ratio", type=float, default=1.0)
parser.add_argument(
"--trace-block-size",
type=int,
default=512,
help="tokens represented by each hash_id in the trace file; only used for file replay",
)
parser.add_argument( parser.add_argument(
"--report-json", "--report-json",
help="path to save the full replay report JSON; defaults to a timestamped file in the current directory", help="path to save the full replay report JSON; defaults to a timestamped file in the current directory",
...@@ -140,6 +180,10 @@ def main(argv: Sequence[str] | None = None) -> int: ...@@ -140,6 +180,10 @@ def main(argv: Sequence[str] | None = None) -> int:
if args.router_config is not None if args.router_config is not None
else None else None
) )
try:
aic_perf_config = _load_aic_perf_config(args)
except ValueError as exc:
parser.error(str(exc))
if using_trace_file: if using_trace_file:
report = run_trace_replay( report = run_trace_replay(
...@@ -148,6 +192,7 @@ def main(argv: Sequence[str] | None = None) -> int: ...@@ -148,6 +192,7 @@ def main(argv: Sequence[str] | None = None) -> int:
prefill_engine_args=prefill_engine_args, prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args, decode_engine_args=decode_engine_args,
router_config=router_config, router_config=router_config,
aic_perf_config=aic_perf_config,
num_workers=args.num_workers, num_workers=args.num_workers,
num_prefill_workers=args.num_prefill_workers, num_prefill_workers=args.num_prefill_workers,
num_decode_workers=args.num_decode_workers, num_decode_workers=args.num_decode_workers,
...@@ -155,6 +200,7 @@ def main(argv: Sequence[str] | None = None) -> int: ...@@ -155,6 +200,7 @@ def main(argv: Sequence[str] | None = None) -> int:
replay_mode=args.replay_mode, replay_mode=args.replay_mode,
router_mode=args.router_mode, router_mode=args.router_mode,
arrival_speedup_ratio=args.arrival_speedup_ratio, arrival_speedup_ratio=args.arrival_speedup_ratio,
trace_block_size=args.trace_block_size,
) )
else: else:
report = run_synthetic_trace_replay( report = run_synthetic_trace_replay(
...@@ -165,6 +211,7 @@ def main(argv: Sequence[str] | None = None) -> int: ...@@ -165,6 +211,7 @@ def main(argv: Sequence[str] | None = None) -> int:
prefill_engine_args=prefill_engine_args, prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args, decode_engine_args=decode_engine_args,
router_config=router_config, router_config=router_config,
aic_perf_config=aic_perf_config,
num_workers=args.num_workers, num_workers=args.num_workers,
num_prefill_workers=args.num_prefill_workers, num_prefill_workers=args.num_prefill_workers,
num_decode_workers=args.num_decode_workers, num_decode_workers=args.num_decode_workers,
......
...@@ -125,6 +125,32 @@ def test_run_trace_replay_supports_multiturn_sessions(tmp_path, replay_mode): ...@@ -125,6 +125,32 @@ def test_run_trace_replay_supports_multiturn_sessions(tmp_path, replay_mode):
) )
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_supports_distinct_trace_and_engine_block_sizes(
tmp_path, replay_mode
):
trace_path = tmp_path / "trace_block_size_split.jsonl"
trace_path.write_text(
'{"timestamp":1000.0,"input_length":128,"output_length":2,"hash_ids":[101]}\n',
encoding="utf-8",
)
report = run_trace_replay(
trace_path,
extra_engine_args=_vllm_args(),
num_workers=1,
replay_mode=replay_mode,
trace_block_size=512,
)
_assert_basic_report_counts(
report,
num_requests=1,
input_tokens=128,
output_tokens=2,
)
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"]) @pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"]) @pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"]) @pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
......
...@@ -40,7 +40,7 @@ pub use self::multi_worker_sequence::{ ...@@ -40,7 +40,7 @@ pub use self::multi_worker_sequence::{
pub use self::sequence::{ActiveSequences, RequestId}; pub use self::sequence::{ActiveSequences, RequestId};
pub use concurrent_radix_tree::ConcurrentRadixTree; pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed; pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed;
pub use config::{KvRouterConfig, RouterConfigOverride, RouterQueuePolicy}; pub use config::{KvRouterConfig, RouterConfigOverride, RouterPrefillLoadModel, RouterQueuePolicy};
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer}; pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
pub use nested_map::PositionalIndexer; pub use nested_map::PositionalIndexer;
pub use protocols::{ pub use protocols::{
...@@ -50,6 +50,7 @@ pub use protocols::{ ...@@ -50,6 +50,7 @@ pub use protocols::{
pub use queue::SchedulerQueue; pub use queue::SchedulerQueue;
pub use radix_tree::RadixTree; pub use radix_tree::RadixTree;
pub use scheduling::LocalScheduler; pub use scheduling::LocalScheduler;
pub use scheduling::PrefillLoadEstimator;
pub use scheduling::policy::{FcfsPolicy, RouterSchedulingPolicy, SchedulingPolicy, WsptPolicy}; pub use scheduling::policy::{FcfsPolicy, RouterSchedulingPolicy, SchedulingPolicy, WsptPolicy};
pub use scheduling::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse}; pub use scheduling::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
pub use selector::{DefaultWorkerSelector, WorkerSelector}; pub use selector::{DefaultWorkerSelector, WorkerSelector};
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::future::Future; use std::future::Future;
use std::time::Duration;
use dynamo_tokens::{SequenceHash, Token}; use dynamo_tokens::{SequenceHash, Token};
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
...@@ -429,6 +430,12 @@ pub struct ActiveSequenceEvent { ...@@ -429,6 +430,12 @@ pub struct ActiveSequenceEvent {
pub lora_name: Option<String>, pub lora_name: Option<String>,
} }
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrefillLoadHint {
pub initial_effective_prefill_tokens: usize,
pub expected_prefill_duration: Option<Duration>,
}
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ActiveSequenceEventData { pub enum ActiveSequenceEventData {
AddRequest { AddRequest {
...@@ -438,6 +445,8 @@ pub enum ActiveSequenceEventData { ...@@ -438,6 +445,8 @@ pub enum ActiveSequenceEventData {
#[serde(default = "default_track_prefill_tokens")] #[serde(default = "default_track_prefill_tokens")]
track_prefill_tokens: bool, track_prefill_tokens: bool,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
#[serde(default)]
prefill_load_hint: Option<PrefillLoadHint>,
}, },
Free, Free,
MarkPrefillCompleted, MarkPrefillCompleted,
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
use std::env::{self, VarError}; use std::env::{self, VarError};
use std::fmt; use std::fmt;
use std::str::FromStr; use std::str::FromStr;
use std::time::Duration;
use derive_builder::Builder; use derive_builder::Builder;
use rand::Rng; use rand::Rng;
...@@ -53,6 +54,43 @@ impl fmt::Display for RouterQueuePolicy { ...@@ -53,6 +54,43 @@ impl fmt::Display for RouterQueuePolicy {
} }
} }
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterPrefillLoadModel {
#[default]
None,
Aic,
}
impl fmt::Display for RouterPrefillLoadModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::None => f.write_str("none"),
Self::Aic => f.write_str("aic"),
}
}
}
impl FromStr for RouterPrefillLoadModel {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"none" => Ok(Self::None),
"aic" => Ok(Self::Aic),
_ => Err(format!(
"unknown prefill load model: {s:?}, expected 'none' or 'aic'"
)),
}
}
}
impl RouterPrefillLoadModel {
pub fn is_enabled(self) -> bool {
!matches!(self, Self::None)
}
}
impl FromStr for RouterQueuePolicy { impl FromStr for RouterQueuePolicy {
type Err = String; type Err = String;
...@@ -124,6 +162,9 @@ pub struct KvRouterConfig { ...@@ -124,6 +162,9 @@ pub struct KvRouterConfig {
#[serde(default = "default_track_prefill_tokens")] #[serde(default = "default_track_prefill_tokens")]
pub router_track_prefill_tokens: bool, pub router_track_prefill_tokens: bool,
/// Optional model for estimating effective prompt-side prefill load over time.
pub router_prefill_load_model: RouterPrefillLoadModel,
/// Threshold for triggering snapshots. If None, no snapshots will be performed. /// Threshold for triggering snapshots. If None, no snapshots will be performed.
#[validate(range(min = 1))] #[validate(range(min = 1))]
pub router_snapshot_threshold: Option<u32>, pub router_snapshot_threshold: Option<u32>,
...@@ -183,6 +224,7 @@ impl Default for KvRouterConfig { ...@@ -183,6 +224,7 @@ impl Default for KvRouterConfig {
router_track_output_blocks: false, router_track_output_blocks: false,
router_assume_kv_reuse: true, router_assume_kv_reuse: true,
router_track_prefill_tokens: default_track_prefill_tokens(), router_track_prefill_tokens: default_track_prefill_tokens(),
router_prefill_load_model: RouterPrefillLoadModel::default(),
router_snapshot_threshold: Some(1000000), router_snapshot_threshold: Some(1000000),
router_reset_states: false, router_reset_states: false,
router_ttl_secs: 120.0, router_ttl_secs: 120.0,
...@@ -214,10 +256,33 @@ fn validate_kv_router_config(config: &KvRouterConfig) -> Result<(), ValidationEr ...@@ -214,10 +256,33 @@ fn validate_kv_router_config(config: &KvRouterConfig) -> Result<(), ValidationEr
"router_track_output_blocks requires router_track_active_blocks=true", "router_track_output_blocks requires router_track_active_blocks=true",
)); ));
} }
if config.router_prefill_load_model.is_enabled() && !config.router_track_prefill_tokens {
return Err(ValidationError::new(
"router_prefill_load_model requires router_track_prefill_tokens=true",
));
}
if config.router_prefill_load_model.is_enabled()
&& !matches!(config.router_queue_policy, RouterQueuePolicy::Fcfs)
{
return Err(ValidationError::new(
"router_prefill_load_model currently requires router_queue_policy='fcfs'",
));
}
Ok(()) Ok(())
} }
impl KvRouterConfig { impl KvRouterConfig {
pub fn router_queue_recheck_interval(&self) -> Duration {
const DEFAULT_RECHECK_INTERVAL: Duration = Duration::from_secs(60);
const PREFILL_LOAD_RECHECK_INTERVAL: Duration = Duration::from_millis(100);
if self.router_prefill_load_model.is_enabled() && self.router_queue_threshold.is_some() {
return PREFILL_LOAD_RECHECK_INTERVAL;
}
DEFAULT_RECHECK_INTERVAL
}
pub fn assume_kv_reuse(&self, config_override: Option<&RouterConfigOverride>) -> bool { pub fn assume_kv_reuse(&self, config_override: Option<&RouterConfigOverride>) -> bool {
config_override config_override
.and_then(|cfg| cfg.assume_kv_reuse) .and_then(|cfg| cfg.assume_kv_reuse)
...@@ -288,28 +353,6 @@ mod tests { ...@@ -288,28 +353,6 @@ mod tests {
use super::*; use super::*;
use crate::protocols::{BlockExtraInfo, BlockMmObjectInfo}; use crate::protocols::{BlockExtraInfo, BlockMmObjectInfo};
#[test]
fn router_queue_policy_display_and_parse_support_lcfs() {
assert_eq!(RouterQueuePolicy::Lcfs.to_string(), "lcfs");
assert_eq!(
"lcfs".parse::<RouterQueuePolicy>().unwrap(),
RouterQueuePolicy::Lcfs
);
}
#[test]
fn router_queue_policy_serde_round_trip_supports_lcfs() {
let serialized = serde_json::to_string(&RouterQueuePolicy::Lcfs).unwrap();
assert_eq!(serialized, "\"lcfs\"");
let deserialized: RouterQueuePolicy = serde_json::from_str(&serialized).unwrap();
assert_eq!(deserialized, RouterQueuePolicy::Lcfs);
}
#[test]
fn kv_router_config_defaults_to_tracking_prefill_tokens() {
assert!(KvRouterConfig::default().router_track_prefill_tokens);
}
#[test] #[test]
fn compute_seq_hashes_for_tracking_uses_mm_hashes() { fn compute_seq_hashes_for_tracking_uses_mm_hashes() {
let cfg = KvRouterConfig::default(); let cfg = KvRouterConfig::default();
...@@ -343,17 +386,6 @@ mod tests { ...@@ -343,17 +386,6 @@ mod tests {
assert_ne!(without_mm, with_mm); assert_ne!(without_mm, with_mm);
} }
#[test]
fn router_config_override_serde_round_trip_preserves_track_prefill_tokens() {
let serialized = serde_json::to_string(&RouterConfigOverride {
track_prefill_tokens: Some(false),
..Default::default()
})
.unwrap();
let deserialized: RouterConfigOverride = serde_json::from_str(&serialized).unwrap();
assert_eq!(deserialized.track_prefill_tokens, Some(false));
}
#[test] #[test]
fn compute_seq_hashes_for_tracking_uses_precomputed_block_hashes() { fn compute_seq_hashes_for_tracking_uses_precomputed_block_hashes() {
let config = KvRouterConfig::default(); let config = KvRouterConfig::default();
......
...@@ -6,9 +6,11 @@ use std::sync::Arc; ...@@ -6,9 +6,11 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
use tokio::time::Instant;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use super::policy::{RouterSchedulingPolicy, SchedulingPolicy}; use super::policy::{RouterSchedulingPolicy, SchedulingPolicy};
use super::prefill_load::PrefillLoadEstimator;
use super::queue::SchedulerQueue; use super::queue::SchedulerQueue;
use super::selector::{DefaultWorkerSelector, WorkerSelector}; use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse}; use super::types::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
...@@ -18,8 +20,6 @@ use crate::sequences::{ ...@@ -18,8 +20,6 @@ use crate::sequences::{
}; };
use dynamo_tokens::SequenceHash; use dynamo_tokens::SequenceHash;
const RECHECK_INTERVAL: Duration = Duration::from_secs(60);
pub struct LocalScheduler<P, C, S = RouterSchedulingPolicy, Sel = DefaultWorkerSelector> pub struct LocalScheduler<P, C, S = RouterSchedulingPolicy, Sel = DefaultWorkerSelector>
where where
P: SequencePublisher, P: SequencePublisher,
...@@ -49,6 +49,8 @@ where ...@@ -49,6 +49,8 @@ where
block_size: u32, block_size: u32,
selector: Sel, selector: Sel,
policy: S, policy: S,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
recheck_interval: Duration,
track_prefill_tokens_default: bool, track_prefill_tokens_default: bool,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
worker_type: &'static str, worker_type: &'static str,
...@@ -103,13 +105,14 @@ where ...@@ -103,13 +105,14 @@ where
block_size, block_size,
selector, selector,
policy, policy,
prefill_load_estimator,
)); ));
let (request_tx, request_rx) = mpsc::channel::<SchedulingRequest>(1024); let (request_tx, request_rx) = mpsc::channel::<SchedulingRequest>(1024);
let queue_clone = Arc::clone(&queue); let queue_clone = Arc::clone(&queue);
tokio::spawn(async move { tokio::spawn(async move {
let mut request_rx = request_rx; let mut request_rx = request_rx;
let mut recheck_interval = tokio::time::interval(RECHECK_INTERVAL); let mut recheck_interval = tokio::time::interval(recheck_interval);
tracing::trace!("LocalScheduler background task started"); tracing::trace!("LocalScheduler background task started");
loop { loop {
...@@ -192,17 +195,18 @@ where ...@@ -192,17 +195,18 @@ where
} }
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> { pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
self.slots.add_request(req) self.slots.add_request(req, Instant::now())
} }
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> { pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots.mark_prefill_completed(&request_id.to_string())?; self.slots
.mark_prefill_completed(&request_id.to_string(), Instant::now())?;
self.queue.update().await; self.queue.update().await;
Ok(()) Ok(())
} }
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> { pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots.free(&request_id.to_string())?; self.slots.free(&request_id.to_string(), Instant::now())?;
self.queue.update().await; self.queue.update().await;
Ok(()) Ok(())
} }
...@@ -231,6 +235,7 @@ where ...@@ -231,6 +235,7 @@ where
overlaps: OverlapScores, overlaps: OverlapScores,
track_prefill_tokens: bool, track_prefill_tokens: bool,
) -> Vec<PotentialLoad> { ) -> Vec<PotentialLoad> {
let decay_now = Instant::now();
let (decode_blocks, prefill_tokens) = self let (decode_blocks, prefill_tokens) = self
.slots .slots
.potential_blocks_and_tokens_with_prefill_tracking( .potential_blocks_and_tokens_with_prefill_tracking(
...@@ -238,6 +243,7 @@ where ...@@ -238,6 +243,7 @@ where
isl_tokens, isl_tokens,
overlaps, overlaps,
track_prefill_tokens, track_prefill_tokens,
decay_now,
); );
let mut workers: HashSet<WorkerWithDpRank> = HashSet::new(); let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
...@@ -275,15 +281,32 @@ mod tests { ...@@ -275,15 +281,32 @@ mod tests {
use super::*; use super::*;
use crate::protocols::OverlapScores; use crate::protocols::OverlapScores;
use crate::scheduling::PrefillLoadEstimator;
use crate::scheduling::policy::FcfsPolicy; use crate::scheduling::policy::FcfsPolicy;
use crate::scheduling::selector::DefaultWorkerSelector; use crate::scheduling::selector::DefaultWorkerSelector;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig}; use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
struct FixedPrefillLoadEstimator {
duration: Duration,
}
impl PrefillLoadEstimator for FixedPrefillLoadEstimator {
fn predict_prefill_duration(
&self,
_batch_size: usize,
_effective_isl: usize,
_prefix: usize,
) -> anyhow::Result<Duration> {
Ok(self.duration)
}
}
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
fn make_scheduler( fn make_scheduler(
workers: HashMap<WorkerId, SimpleWorkerConfig>, workers: HashMap<WorkerId, SimpleWorkerConfig>,
threshold_frac: Option<f64>, threshold_frac: Option<f64>,
monitor_worker_configs: bool, monitor_worker_configs: bool,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
) -> ( ) -> (
Arc<LocalScheduler<NoopSequencePublisher, SimpleWorkerConfig, FcfsPolicy>>, Arc<LocalScheduler<NoopSequencePublisher, SimpleWorkerConfig, FcfsPolicy>>,
Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>, Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
...@@ -311,6 +334,8 @@ mod tests { ...@@ -311,6 +334,8 @@ mod tests {
64, 64,
DefaultWorkerSelector::new(None, "test"), DefaultWorkerSelector::new(None, "test"),
FcfsPolicy, FcfsPolicy,
prefill_load_estimator,
Duration::from_secs(60),
true, true,
cancel_token.clone(), cancel_token.clone(),
"test", "test",
...@@ -329,7 +354,7 @@ mod tests { ...@@ -329,7 +354,7 @@ mod tests {
..Default::default() ..Default::default()
}, },
); );
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true); let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
let response = scheduler let response = scheduler
.schedule( .schedule(
...@@ -366,7 +391,7 @@ mod tests { ...@@ -366,7 +391,7 @@ mod tests {
..Default::default() ..Default::default()
}, },
); );
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true); let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
scheduler scheduler
.schedule( .schedule(
...@@ -389,7 +414,7 @@ mod tests { ...@@ -389,7 +414,7 @@ mod tests {
assert_eq!( assert_eq!(
slots slots
.active_tokens() .active_tokens(Instant::now())
.get(&WorkerWithDpRank::new(0, 0)) .get(&WorkerWithDpRank::new(0, 0))
.copied(), .copied(),
Some(0) Some(0)
...@@ -408,7 +433,8 @@ mod tests { ...@@ -408,7 +433,8 @@ mod tests {
..Default::default() ..Default::default()
}, },
); );
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, Some(0.5), true); let (scheduler, _slots, _cfg_tx, cancel_token) =
make_scheduler(workers, Some(0.5), true, None);
scheduler scheduler
.schedule( .schedule(
...@@ -466,7 +492,7 @@ mod tests { ...@@ -466,7 +492,7 @@ mod tests {
..Default::default() ..Default::default()
}, },
); );
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true); let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
scheduler scheduler
.schedule( .schedule(
...@@ -511,12 +537,16 @@ mod tests { ...@@ -511,12 +537,16 @@ mod tests {
..Default::default() ..Default::default()
}, },
); );
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true); let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
let token_seq = vec![11, 22, 33, 44]; let token_seq = vec![11, 22, 33, 44];
let overlaps = OverlapScores::default(); let overlaps = OverlapScores::default();
let (decode_blocks, prefill_tokens) = let (decode_blocks, prefill_tokens) = slots.potential_blocks_and_tokens(
slots.potential_blocks_and_tokens(Some(&token_seq), 128, overlaps.clone()); Some(&token_seq),
128,
overlaps.clone(),
Instant::now(),
);
let mut expected: Vec<_> = decode_blocks let mut expected: Vec<_> = decode_blocks
.keys() .keys()
.map(|worker| PotentialLoad { .map(|worker| PotentialLoad {
...@@ -548,10 +578,51 @@ mod tests { ...@@ -548,10 +578,51 @@ mod tests {
cancel_token.cancel(); cancel_token.cancel();
} }
#[tokio::test(start_paused = true)]
async fn test_get_potential_loads_uses_decayed_prefill_tokens() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(256),
..Default::default()
},
);
let estimator: Arc<dyn PrefillLoadEstimator> = Arc::new(FixedPrefillLoadEstimator {
duration: Duration::from_secs(10),
});
let (scheduler, _slots, _cfg_tx, cancel_token) =
make_scheduler(workers, None, true, Some(estimator));
scheduler
.schedule(
Some("req-1".to_string()),
100,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
.unwrap();
tokio::time::advance(Duration::from_secs(6)).await;
let loads = scheduler.get_potential_loads(None, 0, OverlapScores::default(), true);
assert_eq!(loads.len(), 1);
assert_eq!(loads[0].potential_prefill_tokens, 40);
cancel_token.cancel();
}
#[tokio::test] #[tokio::test]
async fn test_register_workers_uses_default_dp_fallback() { async fn test_register_workers_uses_default_dp_fallback() {
let (scheduler, _slots, _cfg_tx, cancel_token) = let (scheduler, _slots, _cfg_tx, cancel_token) =
make_scheduler(HashMap::new(), None, false); make_scheduler(HashMap::new(), None, false, None);
scheduler.register_workers(&HashSet::from([42])); scheduler.register_workers(&HashSet::from([42]));
let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default(), true); let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default(), true);
...@@ -567,7 +638,7 @@ mod tests { ...@@ -567,7 +638,7 @@ mod tests {
async fn test_worker_watch_updates_slot_ranges() { async fn test_worker_watch_updates_slot_ranges() {
let mut workers = HashMap::new(); let mut workers = HashMap::new();
workers.insert(0, SimpleWorkerConfig::default()); workers.insert(0, SimpleWorkerConfig::default());
let (scheduler, _slots, cfg_tx, cancel_token) = make_scheduler(workers, None, true); let (scheduler, _slots, cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
assert_eq!( assert_eq!(
scheduler scheduler
...@@ -615,7 +686,7 @@ mod tests { ...@@ -615,7 +686,7 @@ mod tests {
..Default::default() ..Default::default()
}, },
); );
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true); let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true, None);
scheduler scheduler
.schedule( .schedule(
......
...@@ -4,9 +4,11 @@ ...@@ -4,9 +4,11 @@
pub mod config; pub mod config;
mod local; mod local;
pub mod policy; pub mod policy;
pub mod prefill_load;
pub mod queue; pub mod queue;
pub mod selector; pub mod selector;
mod types; mod types;
pub use local::LocalScheduler; pub use local::LocalScheduler;
pub use prefill_load::PrefillLoadEstimator;
pub use types::*; pub use types::*;
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::time::Duration;
pub trait PrefillLoadEstimator: Send + Sync {
fn predict_prefill_duration(
&self,
batch_size: usize,
effective_isl: usize,
prefix: usize,
) -> anyhow::Result<Duration>;
}
...@@ -5,15 +5,16 @@ use std::cmp::Ordering; ...@@ -5,15 +5,16 @@ use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet}; use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering}; use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::time::Instant;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::sync::watch; use tokio::sync::watch;
use tokio::time::Instant;
use super::policy::{FcfsPolicy, SchedulingPolicy}; use super::policy::{FcfsPolicy, SchedulingPolicy};
use super::prefill_load::PrefillLoadEstimator;
use super::selector::{DefaultWorkerSelector, WorkerSelector}; use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{SchedulingRequest, SchedulingResponse}; use super::types::{SchedulingRequest, SchedulingResponse};
use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank}; use crate::protocols::{PrefillLoadHint, WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRequest}; use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRequest};
/// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker) /// Large default for max_num_batched_tokens when not configured (effectively disables queueing for that worker)
...@@ -68,6 +69,7 @@ pub struct SchedulerQueue< ...@@ -68,6 +69,7 @@ pub struct SchedulerQueue<
block_size: u32, block_size: u32,
selector: Sel, selector: Sel,
policy: S, policy: S,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
} }
impl< impl<
...@@ -84,6 +86,7 @@ impl< ...@@ -84,6 +86,7 @@ impl<
block_size: u32, block_size: u32,
selector: Sel, selector: Sel,
policy: S, policy: S,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
) -> Self { ) -> Self {
if let Some(frac) = threshold_frac { if let Some(frac) = threshold_frac {
tracing::info!("Router queue enabled with threshold fraction {frac}"); tracing::info!("Router queue enabled with threshold fraction {frac}");
...@@ -98,6 +101,7 @@ impl< ...@@ -98,6 +101,7 @@ impl<
block_size, block_size,
selector, selector,
policy, policy,
prefill_load_estimator,
} }
} }
...@@ -133,23 +137,24 @@ impl< ...@@ -133,23 +137,24 @@ impl<
/// capacity check is skipped. /// capacity check is skipped.
pub async fn enqueue(&self, request: SchedulingRequest) { pub async fn enqueue(&self, request: SchedulingRequest) {
let Some(threshold) = self.threshold_frac else { let Some(threshold) = self.threshold_frac else {
self.schedule(request).await; self.schedule(request, Instant::now()).await;
return; return;
}; };
if request.allowed_worker_ids.is_some() { if request.allowed_worker_ids.is_some() {
self.schedule(request).await; self.schedule(request, Instant::now()).await;
return; return;
} }
if self.all_workers_busy(threshold, request.allowed_worker_ids.as_ref()) { let decay_now = Instant::now();
if self.all_workers_busy(threshold, request.allowed_worker_ids.as_ref(), decay_now) {
tracing::debug!("all workers busy, queueing request"); tracing::debug!("all workers busy, queueing request");
let arrival_offset = self.start_time.elapsed(); let arrival_offset = self.start_time.elapsed();
let key = self.policy.enqueue_key(arrival_offset, &request); let key = self.policy.enqueue_key(arrival_offset, &request);
self.pending.lock().await.push(QueueEntry { key, request }); self.pending.lock().await.push(QueueEntry { key, request });
self.pending_count.fetch_add(1, AtomicOrdering::Relaxed); self.pending_count.fetch_add(1, AtomicOrdering::Relaxed);
} else { } else {
self.schedule(request).await; self.schedule(request, decay_now).await;
} }
} }
...@@ -176,7 +181,8 @@ impl< ...@@ -176,7 +181,8 @@ impl<
} }
loop { loop {
if self.all_workers_busy(threshold, None) { let decay_now = Instant::now();
if self.all_workers_busy(threshold, None, decay_now) {
break; break;
} }
let Some(entry) = self.pending.lock().await.pop() else { let Some(entry) = self.pending.lock().await.pop() else {
...@@ -184,13 +190,13 @@ impl< ...@@ -184,13 +190,13 @@ impl<
}; };
self.pending_count.fetch_sub(1, AtomicOrdering::Relaxed); self.pending_count.fetch_sub(1, AtomicOrdering::Relaxed);
tracing::debug!("scheduling request from pending queue"); tracing::debug!("scheduling request from pending queue");
self.schedule(entry.request).await; self.schedule(entry.request, decay_now).await;
} }
} }
/// Run the full scheduling pipeline for a single request: /// Run the full scheduling pipeline for a single request:
/// compute potential load -> select worker -> respond -> book via add_request. /// compute potential load -> select worker -> respond -> book via add_request.
async fn schedule(&self, mut request: SchedulingRequest) { async fn schedule(&self, mut request: SchedulingRequest, decay_now: Instant) {
let (decode_blocks, prefill_tokens) = self let (decode_blocks, prefill_tokens) = self
.slots .slots
.potential_blocks_and_tokens_with_prefill_tracking( .potential_blocks_and_tokens_with_prefill_tracking(
...@@ -198,6 +204,7 @@ impl< ...@@ -198,6 +204,7 @@ impl<
request.isl_tokens, request.isl_tokens,
request.overlaps.clone(), request.overlaps.clone(),
request.track_prefill_tokens, request.track_prefill_tokens,
decay_now,
); );
request.decode_blocks = decode_blocks; request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens; request.prefill_tokens = prefill_tokens;
...@@ -231,20 +238,66 @@ impl< ...@@ -231,20 +238,66 @@ impl<
return; return;
}; };
if let Err(e) = self.slots.add_request(SequenceRequest { let prefill_load_hint = self.prefill_load_hint_for(
request_id: request_id.clone(), request.isl_tokens,
token_sequence: request.token_seq, selection.overlap_blocks,
isl: request.isl_tokens, request.track_prefill_tokens,
overlap: selection.overlap_blocks, );
track_prefill_tokens: request.track_prefill_tokens,
expected_output_tokens: request.expected_output_tokens, if let Err(e) = self.slots.add_request(
worker: selection.worker, SequenceRequest {
lora_name: request.lora_name.clone(), request_id: request_id.clone(),
}) { token_sequence: request.token_seq,
isl: request.isl_tokens,
overlap: selection.overlap_blocks,
track_prefill_tokens: request.track_prefill_tokens,
expected_output_tokens: request.expected_output_tokens,
prefill_load_hint,
worker: selection.worker,
lora_name: request.lora_name.clone(),
},
decay_now,
) {
tracing::warn!("Failed to add request {request_id}: {e}"); tracing::warn!("Failed to add request {request_id}: {e}");
} }
} }
fn prefill_load_hint_for(
&self,
isl_tokens: usize,
overlap_blocks: u32,
track_prefill_tokens: bool,
) -> Option<PrefillLoadHint> {
if !track_prefill_tokens {
return None;
}
let prefix = (overlap_blocks as usize) * (self.block_size as usize);
let effective_isl = isl_tokens.saturating_sub(prefix);
if effective_isl == 0 {
return None;
}
let Some(estimator) = &self.prefill_load_estimator else {
return None;
};
match estimator.predict_prefill_duration(1, effective_isl, prefix) {
Ok(expected_prefill_duration) => Some(PrefillLoadHint {
initial_effective_prefill_tokens: effective_isl,
expected_prefill_duration: Some(expected_prefill_duration),
}),
Err(error) => {
tracing::warn!(
effective_isl,
prefix,
"failed to predict prefill duration for active load tracking: {error}"
);
None
}
}
}
/// Number of requests currently parked in the pending queue (lock-free). /// Number of requests currently parked in the pending queue (lock-free).
pub fn pending_count(&self) -> usize { pub fn pending_count(&self) -> usize {
self.pending_count.load(AtomicOrdering::Relaxed) self.pending_count.load(AtomicOrdering::Relaxed)
...@@ -255,8 +308,13 @@ impl< ...@@ -255,8 +308,13 @@ impl<
/// otherwise all registered workers are checked. /// otherwise all registered workers are checked.
/// Returns false when no eligible workers exist so the request falls /// Returns false when no eligible workers exist so the request falls
/// through to `schedule`, which returns a proper `NoEndpoints` error. /// through to `schedule`, which returns a proper `NoEndpoints` error.
fn all_workers_busy(&self, threshold: f64, allowed: Option<&HashSet<WorkerId>>) -> bool { fn all_workers_busy(
let active_tokens = self.slots.active_tokens(); &self,
threshold: f64,
allowed: Option<&HashSet<WorkerId>>,
decay_now: Instant,
) -> bool {
let active_tokens = self.slots.active_tokens(decay_now);
let configs = self.workers_with_configs.borrow(); let configs = self.workers_with_configs.borrow();
let mut checked_any = false; let mut checked_any = false;
...@@ -289,6 +347,7 @@ impl< ...@@ -289,6 +347,7 @@ impl<
mod tests { mod tests {
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch; use tokio::sync::watch;
...@@ -298,6 +357,25 @@ mod tests { ...@@ -298,6 +357,25 @@ mod tests {
use crate::sequences::ActiveSequencesMultiWorker; use crate::sequences::ActiveSequencesMultiWorker;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig}; use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
fn decay_now() -> Instant {
Instant::now()
}
struct FixedPrefillLoadEstimator {
duration: Duration,
}
impl PrefillLoadEstimator for FixedPrefillLoadEstimator {
fn predict_prefill_duration(
&self,
_batch_size: usize,
_effective_isl: usize,
_prefix: usize,
) -> anyhow::Result<Duration> {
Ok(self.duration)
}
}
fn make_queue( fn make_queue(
num_workers: usize, num_workers: usize,
block_size: u32, block_size: u32,
...@@ -308,7 +386,7 @@ mod tests { ...@@ -308,7 +386,7 @@ mod tests {
Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>, Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
) { ) {
let (queue, slots, _tx) = let (queue, slots, _tx) =
make_queue_with_sender(num_workers, block_size, isl, threshold_frac); make_queue_with_sender(num_workers, block_size, isl, threshold_frac, None);
(queue, slots) (queue, slots)
} }
...@@ -318,6 +396,7 @@ mod tests { ...@@ -318,6 +396,7 @@ mod tests {
block_size: u32, block_size: u32,
isl: usize, isl: usize,
threshold_frac: Option<f64>, threshold_frac: Option<f64>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
) -> ( ) -> (
Arc<SchedulerQueue<NoopSequencePublisher, SimpleWorkerConfig>>, Arc<SchedulerQueue<NoopSequencePublisher, SimpleWorkerConfig>>,
Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>, Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
...@@ -354,6 +433,7 @@ mod tests { ...@@ -354,6 +433,7 @@ mod tests {
block_size, block_size,
selector, selector,
FcfsPolicy, FcfsPolicy,
prefill_load_estimator,
)); ));
(queue, slots, cfg_tx) (queue, slots, cfg_tx)
...@@ -409,8 +489,8 @@ mod tests { ...@@ -409,8 +489,8 @@ mod tests {
let resp = resp.expect("scheduling failed"); let resp = resp.expect("scheduling failed");
assert!(resp.best_worker.worker_id < num_workers as u64); assert!(resp.best_worker.worker_id < num_workers as u64);
slots.mark_prefill_completed(&req_id).unwrap(); slots.mark_prefill_completed(&req_id, decay_now()).unwrap();
slots.free(&req_id).unwrap(); slots.free(&req_id, decay_now()).unwrap();
queue.update().await; queue.update().await;
})); }));
} }
...@@ -419,7 +499,7 @@ mod tests { ...@@ -419,7 +499,7 @@ mod tests {
h.await.expect("task panicked"); h.await.expect("task panicked");
} }
let active = slots.active_tokens(); let active = slots.active_tokens(decay_now());
for (worker, tokens) in &active { for (worker, tokens) in &active {
assert_eq!( assert_eq!(
*tokens, 0, *tokens, 0,
...@@ -453,8 +533,8 @@ mod tests { ...@@ -453,8 +533,8 @@ mod tests {
for _ in 0..num_requests { for _ in 0..num_requests {
queue.update().await; queue.update().await;
for rid in &req_ids { for rid in &req_ids {
let _ = slots.mark_prefill_completed(rid); let _ = slots.mark_prefill_completed(rid, decay_now());
let _ = slots.free(rid); let _ = slots.free(rid, decay_now());
} }
} }
queue.update().await; queue.update().await;
...@@ -495,8 +575,10 @@ mod tests { ...@@ -495,8 +575,10 @@ mod tests {
assert_eq!(queue.pending_count(), 2); assert_eq!(queue.pending_count(), 2);
// Free the first request and update — should drain one from pending // Free the first request and update — should drain one from pending
slots.mark_prefill_completed(&"req-1".to_string()).unwrap(); slots
slots.free(&"req-1".to_string()).unwrap(); .mark_prefill_completed(&"req-1".to_string(), decay_now())
.unwrap();
slots.free(&"req-1".to_string(), decay_now()).unwrap();
queue.update().await; queue.update().await;
// After update, one pending request should have been scheduled // After update, one pending request should have been scheduled
...@@ -507,16 +589,43 @@ mod tests { ...@@ -507,16 +589,43 @@ mod tests {
); );
// Free req-2 and update to drain remaining // Free req-2 and update to drain remaining
let _ = slots.mark_prefill_completed(&"req-2".to_string()); let _ = slots.mark_prefill_completed(&"req-2".to_string(), decay_now());
let _ = slots.free(&"req-2".to_string()); let _ = slots.free(&"req-2".to_string(), decay_now());
queue.update().await; queue.update().await;
let _ = slots.mark_prefill_completed(&"req-3".to_string()); let _ = slots.mark_prefill_completed(&"req-3".to_string(), decay_now());
let _ = slots.free(&"req-3".to_string()); let _ = slots.free(&"req-3".to_string(), decay_now());
queue.update().await; queue.update().await;
assert_eq!(queue.pending_count(), 0, "all requests should be drained"); assert_eq!(queue.pending_count(), 0, "all requests should be drained");
} }
#[tokio::test(start_paused = true)]
async fn test_queue_update_uses_decayed_oldest_prefill_load() {
let estimator: Arc<dyn PrefillLoadEstimator> = Arc::new(FixedPrefillLoadEstimator {
duration: Duration::from_secs(10),
});
let (queue, _slots, _cfg_tx) =
make_queue_with_sender(1, 16, 100, Some(0.5), Some(estimator));
let (req1, rx1) = make_request("req-1", 100);
queue.enqueue(req1).await;
let _ = rx1.await.unwrap().unwrap();
let (req2, mut rx2) = make_request("req-2", 100);
queue.enqueue(req2).await;
assert_eq!(queue.pending_count(), 1);
tokio::time::advance(Duration::from_secs(6)).await;
queue.update().await;
let scheduled = rx2
.try_recv()
.expect("queued request should have been scheduled");
let response = scheduled.expect("scheduling returned error");
assert_eq!(response.best_worker.worker_id, 0);
assert_eq!(queue.pending_count(), 0);
}
#[tokio::test] #[tokio::test]
async fn test_no_workers_returns_error() { async fn test_no_workers_returns_error() {
let (queue, _slots) = make_queue(0, 16, 512, None); let (queue, _slots) = make_queue(0, 16, 512, None);
...@@ -542,7 +651,7 @@ mod tests { ...@@ -542,7 +651,7 @@ mod tests {
let isl = 512; let isl = 512;
// Start with zero workers (mimics skip_initial_worker_wait=true) // Start with zero workers (mimics skip_initial_worker_wait=true)
let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None); let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None, None);
// Routing with no workers must fail // Routing with no workers must fail
let (req_fail, rx_fail) = make_request("before-register", isl); let (req_fail, rx_fail) = make_request("before-register", isl);
...@@ -590,9 +699,11 @@ mod tests { ...@@ -590,9 +699,11 @@ mod tests {
// Clean up // Clean up
slots slots
.mark_prefill_completed(&"after-register".to_string()) .mark_prefill_completed(&"after-register".to_string(), decay_now())
.unwrap();
slots
.free(&"after-register".to_string(), decay_now())
.unwrap(); .unwrap();
slots.free(&"after-register".to_string()).unwrap();
} }
/// Register_workers is additive: calling with a new set does NOT remove old workers. /// Register_workers is additive: calling with a new set does NOT remove old workers.
...@@ -601,7 +712,7 @@ mod tests { ...@@ -601,7 +712,7 @@ mod tests {
let block_size = 16; let block_size = 16;
let isl = 256; let isl = 256;
let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None); let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None, None);
// Register worker 10 in slots and config // Register worker 10 in slots and config
let mut dp1 = std::collections::HashMap::new(); let mut dp1 = std::collections::HashMap::new();
...@@ -643,8 +754,8 @@ mod tests { ...@@ -643,8 +754,8 @@ mod tests {
.expect("oneshot dropped") .expect("oneshot dropped")
.expect("scheduling failed"); .expect("scheduling failed");
seen.insert(resp.best_worker.worker_id); seen.insert(resp.best_worker.worker_id);
slots.mark_prefill_completed(&req_id).unwrap(); slots.mark_prefill_completed(&req_id, decay_now()).unwrap();
slots.free(&req_id).unwrap(); slots.free(&req_id, decay_now()).unwrap();
} }
assert!( assert!(
...@@ -659,7 +770,7 @@ mod tests { ...@@ -659,7 +770,7 @@ mod tests {
let block_size = 16; let block_size = 16;
let isl = 256; let isl = 256;
let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None); let (queue, slots, cfg_tx) = make_queue_with_sender(0, block_size, isl, None, None);
// Register three workers // Register three workers
let mut dp = std::collections::HashMap::new(); let mut dp = std::collections::HashMap::new();
...@@ -712,9 +823,9 @@ mod tests { ...@@ -712,9 +823,9 @@ mod tests {
resp.best_worker.worker_id resp.best_worker.worker_id
); );
slots slots
.mark_prefill_completed(&"filter-0".to_string()) .mark_prefill_completed(&"filter-0".to_string(), decay_now())
.unwrap(); .unwrap();
slots.free(&"filter-0".to_string()).unwrap(); slots.free(&"filter-0".to_string(), decay_now()).unwrap();
} }
#[tokio::test(flavor = "multi_thread")] #[tokio::test(flavor = "multi_thread")]
...@@ -727,7 +838,7 @@ mod tests { ...@@ -727,7 +838,7 @@ mod tests {
let _resp1 = rx1.await.unwrap().unwrap(); let _resp1 = rx1.await.unwrap().unwrap();
assert_eq!( assert_eq!(
slots slots
.active_tokens() .active_tokens(decay_now())
.get(&WorkerWithDpRank::new(0, 0)) .get(&WorkerWithDpRank::new(0, 0))
.copied(), .copied(),
Some(0) Some(0)
...@@ -738,9 +849,9 @@ mod tests { ...@@ -738,9 +849,9 @@ mod tests {
let _resp2 = rx2.await.unwrap().unwrap(); let _resp2 = rx2.await.unwrap().unwrap();
assert_eq!(queue.pending_count(), 0); assert_eq!(queue.pending_count(), 0);
let _ = slots.mark_prefill_completed(&"req-1".to_string()); let _ = slots.mark_prefill_completed(&"req-1".to_string(), decay_now());
let _ = slots.free(&"req-1".to_string()); let _ = slots.free(&"req-1".to_string(), decay_now());
let _ = slots.mark_prefill_completed(&"req-2".to_string()); let _ = slots.mark_prefill_completed(&"req-2".to_string(), decay_now());
let _ = slots.free(&"req-2".to_string()); let _ = slots.free(&"req-2".to_string(), decay_now());
} }
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_tokens::SequenceHash;
use std::collections::HashMap;
use std::sync::{Arc, Weak};
#[derive(Debug, Default)]
pub(super) struct BlockTracker {
pub(super) unique_blocks: HashMap<SequenceHash, Weak<()>>,
pub(super) fractional_blocks: HashMap<SequenceHash, f64>,
}
impl BlockTracker {
pub(super) fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> {
if let Some(weak) = self.unique_blocks.get(block)
&& let Some(rc) = weak.upgrade()
{
return rc;
}
let rc = Arc::new(());
self.unique_blocks.insert(*block, Arc::downgrade(&rc));
rc
}
pub(super) fn try_remove_block(&mut self, block: &SequenceHash) {
if let Some(weak) = self.unique_blocks.get(block)
&& weak.strong_count() == 0
{
self.unique_blocks.remove(block);
self.fractional_blocks.remove(block);
}
}
pub(super) fn active_blocks(&self) -> usize {
let mut count = self.unique_blocks.len() as f64;
for (hash, frac) in &self.fractional_blocks {
if self.unique_blocks.contains_key(hash) {
count = count - 1.0 + frac;
}
}
count.round() as usize
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
mod block_tracker;
pub mod multi_worker; pub mod multi_worker;
mod prefill_tracker;
pub mod single; pub mod single;
pub use multi_worker::*; pub use multi_worker::*;
......
...@@ -20,7 +20,8 @@ use tokio_util::sync::CancellationToken; ...@@ -20,7 +20,8 @@ use tokio_util::sync::CancellationToken;
use super::single::{ActiveSequences, RequestId}; use super::single::{ActiveSequences, RequestId};
use crate::protocols::{ use crate::protocols::{
ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores, WorkerWithDpRank, ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, OverlapScores, PrefillLoadHint,
WorkerWithDpRank,
}; };
// How often we force expire stale requests across all workers. See the comment // How often we force expire stale requests across all workers. See the comment
...@@ -93,6 +94,7 @@ pub struct SequenceRequest { ...@@ -93,6 +94,7 @@ pub struct SequenceRequest {
pub overlap: u32, pub overlap: u32,
pub track_prefill_tokens: bool, pub track_prefill_tokens: bool,
pub expected_output_tokens: Option<u32>, pub expected_output_tokens: Option<u32>,
pub prefill_load_hint: Option<PrefillLoadHint>,
pub worker: WorkerWithDpRank, pub worker: WorkerWithDpRank,
pub lora_name: Option<String>, pub lora_name: Option<String>,
} }
...@@ -177,6 +179,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -177,6 +179,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
return; return;
} }
// TODO: Publish explicit prompt-load decay timestamps with these events so peer routers
// can mirror the same oldest-prefill anchor instead of approximating from receive time.
let publisher = Arc::clone(&self.publisher); let publisher = Arc::clone(&self.publisher);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = publisher.publish_event(&event).await { if let Err(e) = publisher.publish_event(&event).await {
...@@ -228,6 +232,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -228,6 +232,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
continue; continue;
} }
// TODO: ActiveSequenceEvent does not carry prompt-load decay timestamps yet.
// Peer routers still approximate decay anchoring with local receive time.
let decay_now = Instant::now();
match &event.data { match &event.data {
ActiveSequenceEventData::AddRequest { ActiveSequenceEventData::AddRequest {
token_sequence, token_sequence,
...@@ -235,6 +242,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -235,6 +242,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap, overlap,
track_prefill_tokens, track_prefill_tokens,
expected_output_tokens, expected_output_tokens,
prefill_load_hint,
} => { } => {
self.request_to_worker self.request_to_worker
.insert(event.request_id.clone(), event.worker); .insert(event.request_id.clone(), event.worker);
...@@ -253,6 +261,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -253,6 +261,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
*overlap, *overlap,
*expected_output_tokens, *expected_output_tokens,
*track_prefill_tokens, *track_prefill_tokens,
*prefill_load_hint,
decay_now,
); );
} else { } else {
tracing::warn!( tracing::warn!(
...@@ -267,7 +277,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -267,7 +277,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
{ {
let table = self.workers.read(); let table = self.workers.read();
if let Some(&idx) = table.index.get(&worker) { if let Some(&idx) = table.index.get(&worker) {
table.slots[idx].1.write().free(&event.request_id); table.slots[idx].1.write().free(&event.request_id, decay_now);
} }
} }
self.request_to_lora.remove(&event.request_id); self.request_to_lora.remove(&event.request_id);
...@@ -281,7 +291,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -281,7 +291,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
table.slots[idx] table.slots[idx]
.1 .1
.write() .write()
.mark_prefill_completed(&event.request_id); .mark_prefill_completed(&event.request_id, decay_now);
} }
} }
} }
...@@ -381,7 +391,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -381,7 +391,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
} }
} }
fn add_request_local(&self, req: SequenceRequest) -> Result<(), SequenceError> { fn add_request_local(
&self,
req: SequenceRequest,
decay_now: Instant,
) -> Result<(), SequenceError> {
let SequenceRequest { let SequenceRequest {
request_id, request_id,
token_sequence, token_sequence,
...@@ -389,6 +403,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -389,6 +403,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap, overlap,
track_prefill_tokens, track_prefill_tokens,
expected_output_tokens, expected_output_tokens,
prefill_load_hint,
worker, worker,
lora_name, lora_name,
} = req; } = req;
...@@ -435,6 +450,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -435,6 +450,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap, overlap,
expected_output_tokens, expected_output_tokens,
track_prefill_tokens, track_prefill_tokens,
prefill_load_hint,
decay_now,
) )
}; };
...@@ -443,12 +460,16 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -443,12 +460,16 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self.request_to_lora.remove(expired_id); self.request_to_lora.remove(expired_id);
} }
self.publish_active_load_for_worker(worker); self.publish_active_load_for_worker(worker, decay_now);
Ok(()) Ok(())
} }
pub fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> { pub fn add_request(
&self,
req: SequenceRequest,
decay_now: Instant,
) -> Result<(), SequenceError> {
self.spawn_publish_event(ActiveSequenceEvent { self.spawn_publish_event(ActiveSequenceEvent {
request_id: req.request_id.clone(), request_id: req.request_id.clone(),
worker: req.worker, worker: req.worker,
...@@ -458,11 +479,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -458,11 +479,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlap: req.overlap, overlap: req.overlap,
track_prefill_tokens: req.track_prefill_tokens, track_prefill_tokens: req.track_prefill_tokens,
expected_output_tokens: req.expected_output_tokens, expected_output_tokens: req.expected_output_tokens,
prefill_load_hint: req.prefill_load_hint,
}, },
router_id: self.router_id, router_id: self.router_id,
lora_name: req.lora_name.clone(), lora_name: req.lora_name.clone(),
}); });
self.add_request_local(req) self.add_request_local(req, decay_now)
} }
/// Send a mutation to the worker assigned to a request, optionally publishing /// Send a mutation to the worker assigned to a request, optionally publishing
...@@ -470,7 +492,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -470,7 +492,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
fn mutate_request_worker_local( fn mutate_request_worker_local(
&self, &self,
request_id: &RequestId, request_id: &RequestId,
mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId), decay_now: Instant,
mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId, Instant),
remove_mapping: bool, remove_mapping: bool,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
let worker = self let worker = self
...@@ -488,7 +511,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -488,7 +511,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
.get(&worker) .get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?; .ok_or(SequenceError::WorkerNotFound { worker })?;
let mut seq = table.slots[idx].1.write(); let mut seq = table.slots[idx].1.write();
mutate_fn(&mut seq, request_id); mutate_fn(&mut seq, request_id, decay_now);
} }
if remove_mapping { if remove_mapping {
...@@ -496,7 +519,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -496,7 +519,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self.request_to_lora.remove(request_id); self.request_to_lora.remove(request_id);
} }
self.publish_active_load_for_worker(worker); self.publish_active_load_for_worker(worker, decay_now);
Ok(()) Ok(())
} }
...@@ -504,8 +527,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -504,8 +527,9 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
fn mutate_request_worker( fn mutate_request_worker(
&self, &self,
request_id: &RequestId, request_id: &RequestId,
decay_now: Instant,
event_data: ActiveSequenceEventData, event_data: ActiveSequenceEventData,
mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId), mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId, Instant),
remove_mapping: bool, remove_mapping: bool,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
let worker = self let worker = self
...@@ -528,7 +552,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -528,7 +552,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
lora_name, lora_name,
}); });
self.mutate_request_worker_local(request_id, mutate_fn, remove_mapping) self.mutate_request_worker_local(request_id, decay_now, mutate_fn, remove_mapping)
} }
/// Free all blocks associated with a request. /// Free all blocks associated with a request.
...@@ -539,7 +563,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -539,7 +563,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// This also performs the underlying prefill-complete cleanup via /// This also performs the underlying prefill-complete cleanup via
/// [`ActiveSequences::free`], so callers do not need to call /// [`ActiveSequences::free`], so callers do not need to call
/// [`Self::mark_prefill_completed`] before freeing a completed request. /// [`Self::mark_prefill_completed`] before freeing a completed request.
pub fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> { pub fn free(&self, request_id: &RequestId, decay_now: Instant) -> Result<(), SequenceError> {
if !self.request_to_worker.contains_key(request_id) { if !self.request_to_worker.contains_key(request_id) {
tracing::debug!("Request {request_id} not found, already freed (idempotent)"); tracing::debug!("Request {request_id} not found, already freed (idempotent)");
return Ok(()); return Ok(());
...@@ -547,9 +571,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -547,9 +571,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self.mutate_request_worker( self.mutate_request_worker(
request_id, request_id,
decay_now,
ActiveSequenceEventData::Free, ActiveSequenceEventData::Free,
|seqs, rid| { |seqs, rid, decay_now| {
seqs.free(rid); seqs.free(rid, decay_now);
}, },
true, true,
) )
...@@ -559,12 +584,17 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -559,12 +584,17 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// ///
/// Note: Calling this multiple times for the same request is allowed and will be a no-op /// Note: Calling this multiple times for the same request is allowed and will be a no-op
/// after the first call (idempotent). /// after the first call (idempotent).
pub fn mark_prefill_completed(&self, request_id: &RequestId) -> Result<(), SequenceError> { pub fn mark_prefill_completed(
&self,
request_id: &RequestId,
decay_now: Instant,
) -> Result<(), SequenceError> {
self.mutate_request_worker( self.mutate_request_worker(
request_id, request_id,
decay_now,
ActiveSequenceEventData::MarkPrefillCompleted, ActiveSequenceEventData::MarkPrefillCompleted,
|seqs, rid| { |seqs, rid, decay_now| {
seqs.mark_prefill_completed(rid); seqs.mark_prefill_completed(rid, decay_now);
}, },
false, false,
) )
...@@ -605,13 +635,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -605,13 +635,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}); });
} }
self.publish_active_load_for_worker(worker); self.publish_active_load_for_worker(worker, Instant::now());
Ok(()) Ok(())
} }
/// Read active blocks/tokens from a worker and publish ActiveLoad metrics. /// Read active blocks/tokens from a worker and publish ActiveLoad metrics.
fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) { fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank, decay_now: Instant) {
let (active_blocks, active_tokens) = { let (active_blocks, active_tokens) = {
let table = self.workers.read(); let table = self.workers.read();
let Some(&idx) = table.index.get(&worker) else { let Some(&idx) = table.index.get(&worker) else {
...@@ -619,7 +649,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -619,7 +649,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
return; return;
}; };
let seq = table.slots[idx].1.read(); let seq = table.slots[idx].1.read();
(seq.active_blocks(), seq.active_tokens()) (seq.active_blocks(), seq.active_tokens(decay_now))
}; };
self.publisher self.publisher
...@@ -674,11 +704,18 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -674,11 +704,18 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
token_sequence: Option<&[SequenceHash]>, token_sequence: Option<&[SequenceHash]>,
isl: usize, isl: usize,
overlaps: OverlapScores, overlaps: OverlapScores,
decay_now: Instant,
) -> ( ) -> (
HashMap<WorkerWithDpRank, usize>, HashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>, HashMap<WorkerWithDpRank, usize>,
) { ) {
self.potential_blocks_and_tokens_with_prefill_tracking(token_sequence, isl, overlaps, true) self.potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
isl,
overlaps,
true,
decay_now,
)
} }
pub fn potential_blocks_and_tokens_with_prefill_tracking( pub fn potential_blocks_and_tokens_with_prefill_tracking(
...@@ -687,6 +724,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -687,6 +724,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
isl: usize, isl: usize,
overlaps: OverlapScores, overlaps: OverlapScores,
track_prefill_tokens: bool, track_prefill_tokens: bool,
decay_now: Instant,
) -> ( ) -> (
HashMap<WorkerWithDpRank, usize>, HashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>, HashMap<WorkerWithDpRank, usize>,
...@@ -712,6 +750,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -712,6 +750,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
isl, isl,
overlap, overlap,
track_prefill_tokens, track_prefill_tokens,
decay_now,
); );
potential_blocks.insert(*worker, blocks); potential_blocks.insert(*worker, blocks);
potential_tokens.insert(*worker, tokens); potential_tokens.insert(*worker, tokens);
...@@ -741,11 +780,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -741,11 +780,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
} }
/// Query all workers for their current number of active tokens. /// Query all workers for their current number of active tokens.
pub fn active_tokens(&self) -> HashMap<WorkerWithDpRank, usize> { pub fn active_tokens(&self, decay_now: Instant) -> HashMap<WorkerWithDpRank, usize> {
let table = self.workers.read(); let table = self.workers.read();
let mut results = HashMap::with_capacity(table.slots.len()); let mut results = HashMap::with_capacity(table.slots.len());
for (worker, lock) in &table.slots { for (worker, lock) in &table.slots {
results.insert(*worker, lock.read().active_tokens()); results.insert(*worker, lock.read().active_tokens(decay_now));
} }
results results
} }
...@@ -753,11 +792,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -753,11 +792,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// Return true if any worker satisfies the provided predicate on active token count. /// Return true if any worker satisfies the provided predicate on active token count.
pub fn any_worker_matches_active_tokens( pub fn any_worker_matches_active_tokens(
&self, &self,
decay_now: Instant,
mut predicate: impl FnMut(WorkerWithDpRank, usize) -> bool, mut predicate: impl FnMut(WorkerWithDpRank, usize) -> bool,
) -> bool { ) -> bool {
let table = self.workers.read(); let table = self.workers.read();
for (worker, lock) in &table.slots { for (worker, lock) in &table.slots {
if predicate(*worker, lock.read().active_tokens()) { if predicate(*worker, lock.read().active_tokens(decay_now)) {
return true; return true;
} }
} }
...@@ -792,7 +832,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -792,7 +832,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
self.request_to_lora.remove(expired_id); self.request_to_lora.remove(expired_id);
removed_request_count += 1; removed_request_count += 1;
} }
self.publish_active_load_for_worker(*worker); self.publish_active_load_for_worker(*worker, now);
} }
} }
let duration = now.elapsed(); let duration = now.elapsed();
...@@ -835,8 +875,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -835,8 +875,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::HashMap; use std::collections::HashMap;
use std::time::Duration;
use super::*; use super::*;
use crate::protocols::{OverlapScores, PrefillLoadHint};
use crate::test_utils::NoopSequencePublisher; use crate::test_utils::NoopSequencePublisher;
fn make_sequences() -> ActiveSequencesMultiWorker<NoopSequencePublisher> { fn make_sequences() -> ActiveSequencesMultiWorker<NoopSequencePublisher> {
...@@ -854,20 +896,74 @@ mod tests { ...@@ -854,20 +896,74 @@ mod tests {
async fn add_request_can_skip_prefill_token_tracking() { async fn add_request_can_skip_prefill_token_tracking() {
let sequences = make_sequences(); let sequences = make_sequences();
let worker = WorkerWithDpRank::new(1, 0); let worker = WorkerWithDpRank::new(1, 0);
let decay_now = Instant::now();
sequences sequences
.add_request(SequenceRequest { .add_request(
request_id: "req-1".to_string(), SequenceRequest {
token_sequence: Some(vec![1, 2, 3]), request_id: "req-1".to_string(),
isl: 12, token_sequence: Some(vec![1, 2, 3]),
overlap: 0, isl: 12,
track_prefill_tokens: false, overlap: 0,
expected_output_tokens: None, track_prefill_tokens: false,
worker, expected_output_tokens: None,
lora_name: None, prefill_load_hint: None,
}) worker,
lora_name: None,
},
decay_now,
)
.unwrap();
assert_eq!(
sequences.active_tokens(decay_now).get(&worker).copied(),
Some(0)
);
}
#[test]
fn explicit_decay_time_drives_multi_worker_load_queries_consistently() {
let sequences = make_sequences();
let worker = WorkerWithDpRank::new(1, 0);
let start = Instant::now();
sequences
.add_request(
SequenceRequest {
request_id: "req-1".to_string(),
token_sequence: Some(vec![1, 2, 3]),
isl: 100,
overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None,
prefill_load_hint: Some(PrefillLoadHint {
initial_effective_prefill_tokens: 100,
expected_prefill_duration: Some(Duration::from_secs(10)),
}),
worker,
lora_name: None,
},
start,
)
.unwrap(); .unwrap();
assert_eq!(sequences.active_tokens().get(&worker).copied(), Some(0)); let decay_now = start + Duration::from_secs(5);
let active_tokens = sequences.active_tokens(decay_now);
assert_eq!(active_tokens.get(&worker).copied(), Some(50));
let (_, potential_tokens) = sequences.potential_blocks_and_tokens_with_prefill_tracking(
None,
0,
OverlapScores::default(),
false,
decay_now,
);
assert_eq!(potential_tokens.get(&worker).copied(), Some(50));
assert!(
sequences.any_worker_matches_active_tokens(decay_now, |candidate, tokens| {
candidate == worker && tokens == 50
})
);
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment