Unverified Commit b7fe46b1 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(mocker): add multi-worker replay and router startup fixes (#7553)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 82794761
...@@ -149,8 +149,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -149,8 +149,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(fetch_model, m)?)?; m.add_function(wrap_pyfunction!(fetch_model, m)?)?;
m.add_function(wrap_pyfunction!(run_kv_indexer, m)?)?; m.add_function(wrap_pyfunction!(run_kv_indexer, m)?)?;
m.add_function(wrap_pyfunction!(llm::entrypoint::make_engine, m)?)?; m.add_function(wrap_pyfunction!(llm::entrypoint::make_engine, m)?)?;
m.add_function(wrap_pyfunction!(llm::replay::run_mocker_trace_replay, m)?)?;
m.add_function(wrap_pyfunction!( m.add_function(wrap_pyfunction!(
llm::entrypoint::run_mocker_trace_replay, llm::replay::run_mocker_synthetic_trace_replay,
m m
)?)?; )?)?;
m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?; m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?;
...@@ -165,6 +166,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -165,6 +166,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::entrypoint::EngineType>()?; m.add_class::<llm::entrypoint::EngineType>()?;
m.add_class::<llm::entrypoint::RouterConfig>()?; m.add_class::<llm::entrypoint::RouterConfig>()?;
m.add_class::<llm::entrypoint::KvRouterConfig>()?; m.add_class::<llm::entrypoint::KvRouterConfig>()?;
m.add_class::<llm::replay::ReasoningConfig>()?;
m.add_class::<llm::replay::SglangArgs>()?;
m.add_class::<llm::replay::MockEngineArgs>()?;
m.add_class::<llm::kv::WorkerMetricsPublisher>()?; m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?; // Internal: only in _internal, not public API m.add_class::<llm::model_card::ModelDeploymentCard>()?; // Internal: only in _internal, not public API
m.add_class::<llm::local_model::ModelRuntimeConfig>()?; m.add_class::<llm::local_model::ModelRuntimeConfig>()?;
......
...@@ -31,3 +31,4 @@ pub mod local_model; ...@@ -31,3 +31,4 @@ pub mod local_model;
pub mod lora; pub mod lora;
pub mod model_card; pub mod model_card;
pub mod preprocessor; pub mod preprocessor;
pub mod replay;
...@@ -9,7 +9,6 @@ use std::sync::Arc; ...@@ -9,7 +9,6 @@ use std::sync::Arc;
use pyo3::{exceptions::PyException, prelude::*}; use pyo3::{exceptions::PyException, prelude::*};
use pyo3_async_runtimes::TaskLocals; use pyo3_async_runtimes::TaskLocals;
use pythonize::pythonize;
use dynamo_kv_router::config::KvRouterConfig as RsKvRouterConfig; use dynamo_kv_router::config::KvRouterConfig as RsKvRouterConfig;
use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig; use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig;
...@@ -25,7 +24,8 @@ use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingE ...@@ -25,7 +24,8 @@ use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingE
use dynamo_mocker::common::perf_model::PerfModel; use dynamo_mocker::common::perf_model::PerfModel;
use super::aic_callback::create_aic_callback; use super::aic_callback::create_aic_callback;
use dynamo_mocker::common::protocols::MockEngineArgs; use super::replay::MockEngineArgs as PyMockEngineArgs;
use dynamo_mocker::common::protocols::MockEngineArgs as RsMockEngineArgs;
use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId; use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId;
use dynamo_runtime::protocols::EndpointId; use dynamo_runtime::protocols::EndpointId;
...@@ -58,7 +58,7 @@ impl KvRouterConfig { ...@@ -58,7 +58,7 @@ impl KvRouterConfig {
#[pymethods] #[pymethods]
impl KvRouterConfig { impl KvRouterConfig {
#[new] #[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(2.0), router_event_threads=4, router_enable_cache_control=false, router_queue_policy="fcfs", remote_indexer_component=None))] #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_enable_cache_control=false, min_initial_workers=1, router_queue_policy="fcfs", remote_indexer_component=None))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn new( fn new(
overlap_score_weight: f64, overlap_score_weight: f64,
...@@ -77,6 +77,7 @@ impl KvRouterConfig { ...@@ -77,6 +77,7 @@ impl KvRouterConfig {
router_queue_threshold: Option<f64>, router_queue_threshold: Option<f64>,
router_event_threads: u32, router_event_threads: u32,
router_enable_cache_control: bool, router_enable_cache_control: bool,
min_initial_workers: usize,
router_queue_policy: &str, router_queue_policy: &str,
remote_indexer_component: Option<String>, remote_indexer_component: Option<String>,
) -> Self { ) -> Self {
...@@ -99,6 +100,7 @@ impl KvRouterConfig { ...@@ -99,6 +100,7 @@ impl KvRouterConfig {
router_event_threads, router_event_threads,
router_enable_cache_control, router_enable_cache_control,
skip_initial_worker_wait: false, skip_initial_worker_wait: false,
min_initial_workers,
router_queue_policy: router_queue_policy.parse().unwrap_or_else(|_| { router_queue_policy: router_queue_policy.parse().unwrap_or_else(|_| {
panic!("invalid router_queue_policy: {router_queue_policy:?}") panic!("invalid router_queue_policy: {router_queue_policy:?}")
}), }),
...@@ -106,6 +108,13 @@ impl KvRouterConfig { ...@@ -106,6 +108,13 @@ impl KvRouterConfig {
}, },
} }
} }
#[staticmethod]
fn from_json(config_json: &str) -> PyResult<Self> {
serde_json::from_str::<RsKvRouterConfig>(config_json)
.map(|inner| KvRouterConfig { inner })
.map_err(|e| PyException::new_err(format!("Failed to parse KvRouterConfig JSON: {e}")))
}
} }
#[pyclass] #[pyclass]
...@@ -196,6 +205,7 @@ pub(crate) struct EntrypointArgs { ...@@ -196,6 +205,7 @@ pub(crate) struct EntrypointArgs {
tls_cert_path: Option<PathBuf>, tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>, tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
mocker_engine_args: Option<PyMockEngineArgs>,
runtime_config: Option<ModelRuntimeConfig>, runtime_config: Option<ModelRuntimeConfig>,
namespace: Option<String>, namespace: Option<String>,
namespace_prefix: Option<String>, namespace_prefix: Option<String>,
...@@ -208,7 +218,7 @@ pub(crate) struct EntrypointArgs { ...@@ -208,7 +218,7 @@ pub(crate) struct EntrypointArgs {
impl EntrypointArgs { impl EntrypointArgs {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[new] #[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))] #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
pub fn new( pub fn new(
py: Python<'_>, py: Python<'_>,
engine_type: EngineType, engine_type: EngineType,
...@@ -225,6 +235,7 @@ impl EntrypointArgs { ...@@ -225,6 +235,7 @@ impl EntrypointArgs {
tls_cert_path: Option<PathBuf>, tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>, tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
mocker_engine_args: Option<PyMockEngineArgs>,
runtime_config: Option<ModelRuntimeConfig>, runtime_config: Option<ModelRuntimeConfig>,
namespace: Option<String>, namespace: Option<String>,
namespace_prefix: Option<String>, namespace_prefix: Option<String>,
...@@ -272,6 +283,7 @@ impl EntrypointArgs { ...@@ -272,6 +283,7 @@ impl EntrypointArgs {
tls_cert_path, tls_cert_path,
tls_key_path, tls_key_path,
extra_engine_args, extra_engine_args,
mocker_engine_args,
runtime_config, runtime_config,
namespace, namespace,
namespace_prefix, namespace_prefix,
...@@ -419,8 +431,10 @@ async fn select_engine( ...@@ -419,8 +431,10 @@ async fn select_engine(
} }
} }
EngineType::Mocker => { EngineType::Mocker => {
let mut mocker_args = if let Some(extra_args_path) = args.extra_engine_args { let mut mocker_args = if let Some(mocker_engine_args) = args.mocker_engine_args {
MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| { mocker_engine_args.inner()
} else if let Some(extra_args_path) = args.extra_engine_args {
RsMockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
anyhow::anyhow!( anyhow::anyhow!(
"Failed to load mocker args from {:?}: {}", "Failed to load mocker args from {:?}: {}",
extra_args_path, extra_args_path,
...@@ -431,7 +445,7 @@ async fn select_engine( ...@@ -431,7 +445,7 @@ async fn select_engine(
tracing::warn!( tracing::warn!(
"No extra_engine_args specified for mocker engine. Using default mocker args." "No extra_engine_args specified for mocker engine. Using default mocker args."
); );
MockEngineArgs::default() RsMockEngineArgs::default()
}; };
// If aic_backend is set, create Python AIC callback and override perf_model // If aic_backend is set, create Python AIC callback and override perf_model
...@@ -503,84 +517,6 @@ pub fn run_input<'p>( ...@@ -503,84 +517,6 @@ pub fn run_input<'p>(
}) })
} }
#[pyfunction]
#[pyo3(signature = (trace_file, extra_engine_args=None, num_workers=1, replay_concurrency=None))]
pub fn run_mocker_trace_replay(
py: Python<'_>,
trace_file: PathBuf,
extra_engine_args: Option<PathBuf>,
num_workers: usize,
replay_concurrency: Option<isize>,
) -> PyResult<PyObject> {
// Load args before allow_threads so we can use the GIL for AIC callback creation.
let mut args = if let Some(ref extra_args_path) = extra_engine_args {
MockEngineArgs::from_json_file(extra_args_path).map_err(|e| {
PyException::new_err(format!(
"Failed to load mocker args from {:?}: {}",
extra_args_path, e
))
})?
} else {
MockEngineArgs::default()
};
// Create AIC callback if requested (requires GIL, must be done before allow_threads).
if let Some(ref backend_name) = args.aic_backend.clone() {
let backend = backend_name.clone();
let system = args.aic_system.as_deref().unwrap_or("h200_sxm").to_string();
let model_name = args
.aic_model_path
.clone()
.ok_or_else(|| PyException::new_err("--aic-perf-model requires --model-path"))?;
let backend_version = args.aic_backend_version.clone();
let tp_size = args.aic_tp_size.unwrap_or(1);
let callback = create_aic_callback(
py,
&backend,
&system,
&model_name,
tp_size,
backend_version.as_deref(),
)
.map_err(|e| {
PyException::new_err(format!(
"Failed to create AIC callback (--aic-perf-model was requested): {}",
e
))
})?;
tracing::info!(
"AIC perf model: backend={}, gpu={}, model={}, version={:?}",
backend,
system,
model_name,
backend_version
);
args.perf_model = Arc::new(PerfModel::from_aic_callback(callback));
}
let report = py.allow_threads(move || {
let replay_concurrency = replay_concurrency
.map(usize::try_from)
.transpose()
.map_err(|_| anyhow::anyhow!("replay_concurrency must be at least 1"))?;
if let Some(max_in_flight) = replay_concurrency {
dynamo_mocker::simulation::simulate_concurrency_file(
args,
&trace_file,
max_in_flight,
num_workers,
)
} else {
dynamo_mocker::simulation::simulate_trace_file(args, &trace_file, num_workers)
}
});
let report = report.map_err(to_pyerr)?;
pythonize(py, &report)
.map_err(to_pyerr)
.map(|obj| obj.unbind())
}
pub fn to_pyerr<E>(err: E) -> PyErr pub fn to_pyerr<E>(err: E) -> PyErr
where where
E: Display, E: Display,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::path::PathBuf;
use std::sync::Arc;
use dynamo_mocker::common::perf_model::PerfModel;
use dynamo_mocker::common::protocols::{
DirectRequest, EngineType as RsMockerEngineType, MockEngineArgs as RsMockEngineArgs,
PreemptionMode as RsPreemptionMode, ReasoningConfig as RsReasoningConfig,
SglangArgs as RsSglangArgs, WorkerType as RsWorkerType,
};
use pyo3::{exceptions::PyException, prelude::*};
use pythonize::pythonize;
use uuid::Uuid;
use super::aic_callback::create_aic_callback;
use super::entrypoint::{KvRouterConfig, to_pyerr};
fn parse_mocker_engine_type(engine_type: &str) -> PyResult<RsMockerEngineType> {
match engine_type {
"vllm" => Ok(RsMockerEngineType::Vllm),
"sglang" => Ok(RsMockerEngineType::Sglang),
other => Err(PyException::new_err(format!(
"engine_type must be either 'vllm' or 'sglang', got '{other}'"
))),
}
}
fn parse_worker_type(worker_type: &str) -> PyResult<RsWorkerType> {
match worker_type {
"aggregated" => Ok(RsWorkerType::Aggregated),
"prefill" => Ok(RsWorkerType::Prefill),
"decode" => Ok(RsWorkerType::Decode),
other => Err(PyException::new_err(format!(
"worker_type must be one of 'aggregated', 'prefill', or 'decode', got '{other}'"
))),
}
}
fn parse_preemption_mode(preemption_mode: &str) -> PyResult<RsPreemptionMode> {
match preemption_mode {
"lifo" => Ok(RsPreemptionMode::Lifo),
"fifo" => Ok(RsPreemptionMode::Fifo),
other => Err(PyException::new_err(format!(
"preemption_mode must be either 'lifo' or 'fifo', got '{other}'"
))),
}
}
#[pyclass]
#[derive(Clone, Debug)]
pub struct ReasoningConfig {
inner: RsReasoningConfig,
}
impl ReasoningConfig {
pub fn inner(&self) -> RsReasoningConfig {
self.inner.clone()
}
}
#[pymethods]
impl ReasoningConfig {
#[new]
fn new(
start_thinking_token_id: u32,
end_thinking_token_id: u32,
thinking_ratio: f64,
) -> PyResult<Self> {
let inner = RsReasoningConfig {
start_thinking_token_id,
end_thinking_token_id,
thinking_ratio,
};
Ok(Self { inner })
}
}
#[pyclass]
#[derive(Clone, Debug, Default)]
pub struct SglangArgs {
inner: RsSglangArgs,
}
impl SglangArgs {
pub fn inner(&self) -> RsSglangArgs {
self.inner.clone()
}
}
#[pymethods]
impl SglangArgs {
#[new]
#[pyo3(signature = (schedule_policy=None, page_size=None, max_prefill_tokens=None, chunked_prefill_size=None, clip_max_new_tokens=None, schedule_conservativeness=None))]
fn new(
schedule_policy: Option<String>,
page_size: Option<usize>,
max_prefill_tokens: Option<usize>,
chunked_prefill_size: Option<usize>,
clip_max_new_tokens: Option<usize>,
schedule_conservativeness: Option<f64>,
) -> PyResult<Self> {
let inner = RsSglangArgs {
schedule_policy,
page_size,
max_prefill_tokens,
chunked_prefill_size,
clip_max_new_tokens,
schedule_conservativeness,
};
Ok(Self { inner })
}
}
#[pyclass]
#[derive(Clone, Debug, Default)]
pub struct MockEngineArgs {
inner: RsMockEngineArgs,
}
impl MockEngineArgs {
pub fn inner(&self) -> RsMockEngineArgs {
self.inner.clone()
}
}
#[pymethods]
impl MockEngineArgs {
#[new]
#[pyo3(signature = (engine_type="vllm", num_gpu_blocks=16384, block_size=0, max_num_seqs=Some(256), max_num_batched_tokens=Some(8192), enable_prefix_caching=true, enable_chunked_prefill=true, speedup_ratio=1.0, decode_speedup_ratio=1.0, dp_size=1, startup_time=None, worker_type="aggregated", aic_backend=None, aic_system=None, aic_backend_version=None, aic_tp_size=None, aic_model_path=None, enable_local_indexer=false, bootstrap_port=None, kv_bytes_per_token=None, kv_transfer_bandwidth=None, reasoning=None, zmq_kv_events_port=None, zmq_replay_port=None, preemption_mode="lifo", router_queue_policy=None, sglang=None))]
#[allow(clippy::too_many_arguments)]
fn new(
engine_type: &str,
num_gpu_blocks: usize,
block_size: usize,
max_num_seqs: Option<usize>,
max_num_batched_tokens: Option<usize>,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
speedup_ratio: f64,
decode_speedup_ratio: f64,
dp_size: u32,
startup_time: Option<f64>,
worker_type: &str,
aic_backend: Option<String>,
aic_system: Option<String>,
aic_backend_version: Option<String>,
aic_tp_size: Option<usize>,
aic_model_path: Option<String>,
enable_local_indexer: bool,
bootstrap_port: Option<u16>,
kv_bytes_per_token: Option<usize>,
kv_transfer_bandwidth: Option<f64>,
reasoning: Option<ReasoningConfig>,
zmq_kv_events_port: Option<u16>,
zmq_replay_port: Option<u16>,
preemption_mode: &str,
router_queue_policy: Option<&str>,
sglang: Option<SglangArgs>,
) -> PyResult<Self> {
let engine_type = parse_mocker_engine_type(engine_type)?;
let worker_type = parse_worker_type(worker_type)?;
let preemption_mode = parse_preemption_mode(preemption_mode)?;
let router_queue_policy = router_queue_policy
.map(|value| {
value.parse().map_err(|e: String| {
PyException::new_err(format!("invalid router_queue_policy {value:?}: {e}"))
})
})
.transpose()?;
let inner = RsMockEngineArgs::builder()
.engine_type(engine_type)
.num_gpu_blocks(num_gpu_blocks)
.block_size(block_size)
.max_num_seqs(max_num_seqs)
.max_num_batched_tokens(max_num_batched_tokens)
.enable_prefix_caching(enable_prefix_caching)
.enable_chunked_prefill(enable_chunked_prefill)
.speedup_ratio(speedup_ratio)
.decode_speedup_ratio(decode_speedup_ratio)
.dp_size(dp_size)
.startup_time(startup_time)
.worker_type(worker_type)
.aic_backend(aic_backend)
.aic_system(aic_system)
.aic_backend_version(aic_backend_version)
.aic_tp_size(aic_tp_size)
.aic_model_path(aic_model_path)
.enable_local_indexer(enable_local_indexer)
.bootstrap_port(bootstrap_port)
.kv_bytes_per_token(kv_bytes_per_token)
.kv_transfer_bandwidth(kv_transfer_bandwidth)
.reasoning(reasoning.map(|config| config.inner()))
.zmq_kv_events_port(zmq_kv_events_port)
.zmq_replay_port(zmq_replay_port)
.preemption_mode(preemption_mode)
.router_queue_policy(router_queue_policy)
.sglang(sglang.map(|config| config.inner()))
.build()
.map_err(|e| PyException::new_err(format!("Failed to build MockEngineArgs: {e}")))?
.normalized()
.map_err(|e| {
PyException::new_err(format!("Failed to normalize MockEngineArgs: {e}"))
})?;
Ok(Self { inner })
}
#[staticmethod]
fn from_json(config_json: &str) -> PyResult<Self> {
RsMockEngineArgs::from_json_str(config_json)
.map(|inner| Self { inner })
.map_err(|e| PyException::new_err(format!("Failed to parse MockEngineArgs JSON: {e}")))
}
#[getter]
fn block_size(&self) -> usize {
self.inner.block_size
}
#[getter]
fn num_gpu_blocks(&self) -> usize {
self.inner.num_gpu_blocks
}
#[getter]
fn max_num_seqs(&self) -> Option<usize> {
self.inner.max_num_seqs
}
#[getter]
fn max_num_batched_tokens(&self) -> Option<usize> {
self.inner.max_num_batched_tokens
}
#[getter]
fn enable_local_indexer(&self) -> bool {
self.inner.enable_local_indexer
}
#[getter]
fn dp_size(&self) -> u32 {
self.inner.dp_size
}
#[getter]
fn bootstrap_port(&self) -> Option<u16> {
self.inner.bootstrap_port
}
fn is_prefill(&self) -> bool {
self.inner.is_prefill()
}
fn is_decode(&self) -> bool {
self.inner.is_decode()
}
#[pyo3(signature = (bootstrap_port=None, zmq_kv_events_port=None, zmq_replay_port=None, kv_bytes_per_token=None))]
fn with_overrides(
&self,
bootstrap_port: Option<u16>,
zmq_kv_events_port: Option<u16>,
zmq_replay_port: Option<u16>,
kv_bytes_per_token: Option<usize>,
) -> PyResult<Self> {
let mut inner = self.inner.clone();
if let Some(port) = bootstrap_port {
inner.bootstrap_port = Some(port);
}
if let Some(port) = zmq_kv_events_port {
inner.zmq_kv_events_port = Some(port);
}
if let Some(port) = zmq_replay_port {
inner.zmq_replay_port = Some(port);
}
if let Some(bytes_per_token) = kv_bytes_per_token {
inner.kv_bytes_per_token = Some(bytes_per_token);
}
inner.normalized().map(|inner| Self { inner }).map_err(|e| {
PyException::new_err(format!("Failed to normalize MockEngineArgs overrides: {e}"))
})
}
}
#[pyfunction]
#[pyo3(signature = (trace_file, extra_engine_args=None, router_config=None, num_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0))]
#[allow(clippy::too_many_arguments)]
pub fn run_mocker_trace_replay(
py: Python<'_>,
trace_file: PathBuf,
extra_engine_args: Option<MockEngineArgs>,
router_config: Option<KvRouterConfig>,
num_workers: usize,
replay_concurrency: Option<isize>,
replay_mode: &str,
router_mode: &str,
arrival_speedup_ratio: f64,
) -> PyResult<PyObject> {
let args = load_replay_mocker_args(py, extra_engine_args)?;
let router_config = load_replay_router_config(router_config);
let replay_mode = replay_mode.to_owned();
let router_mode = parse_replay_router_mode(router_mode)?;
let report = py.allow_threads(move || {
let replay_concurrency = parse_replay_concurrency(replay_concurrency)?;
match (replay_mode.as_str(), replay_concurrency) {
("offline", Some(max_in_flight)) => {
dynamo_mocker::replay::simulate_concurrency_file_with_router_mode(
args,
router_config.clone(),
&trace_file,
max_in_flight,
num_workers,
router_mode,
)
}
("offline", None) => dynamo_mocker::replay::simulate_trace_file_with_router_mode(
args,
router_config.clone(),
&trace_file,
num_workers,
arrival_speedup_ratio,
router_mode,
),
("online", Some(max_in_flight)) => {
dynamo_mocker::replay::simulate_concurrency_live_file_with_router_mode(
args,
router_config.clone(),
&trace_file,
max_in_flight,
num_workers,
router_mode,
)
}
("online", None) => dynamo_mocker::replay::simulate_trace_live_file_with_router_mode(
args,
router_config.clone(),
&trace_file,
num_workers,
arrival_speedup_ratio,
router_mode,
),
(other, _) => anyhow::bail!(
"replay_mode must be either 'offline' or 'online', got '{}'",
other
),
}
});
let report = report.map_err(to_pyerr)?;
pythonize(py, &report)
.map_err(to_pyerr)
.map(|obj| obj.unbind())
}
#[pyfunction]
#[pyo3(signature = (input_tokens, output_tokens, request_count, extra_engine_args=None, router_config=None, num_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0, arrival_interval_ms=1.0))]
#[allow(clippy::too_many_arguments)]
pub fn run_mocker_synthetic_trace_replay(
py: Python<'_>,
input_tokens: usize,
output_tokens: usize,
request_count: usize,
extra_engine_args: Option<MockEngineArgs>,
router_config: Option<KvRouterConfig>,
num_workers: usize,
replay_concurrency: Option<isize>,
replay_mode: &str,
router_mode: &str,
arrival_speedup_ratio: f64,
arrival_interval_ms: f64,
) -> PyResult<PyObject> {
let args = load_replay_mocker_args(py, extra_engine_args)?;
let router_config = load_replay_router_config(router_config);
let replay_mode = replay_mode.to_owned();
let router_mode = parse_replay_router_mode(router_mode)?;
let report = py.allow_threads(move || {
let replay_concurrency = parse_replay_concurrency(replay_concurrency)?;
let requests = build_synthetic_requests(
input_tokens,
output_tokens,
request_count,
arrival_interval_ms,
replay_concurrency.is_none(),
)?;
match (replay_mode.as_str(), replay_concurrency) {
("offline", Some(max_in_flight)) => {
dynamo_mocker::replay::simulate_concurrency_requests_with_router_mode(
args,
router_config.clone(),
requests,
max_in_flight,
num_workers,
router_mode,
)
}
("offline", None) => dynamo_mocker::replay::simulate_trace_requests_with_router_mode(
args,
router_config.clone(),
requests,
num_workers,
arrival_speedup_ratio,
router_mode,
),
("online", Some(max_in_flight)) => {
dynamo_mocker::replay::simulate_concurrency_live_requests_with_router_mode(
args,
router_config.clone(),
requests,
max_in_flight,
num_workers,
router_mode,
)
}
("online", None) => {
dynamo_mocker::replay::simulate_trace_live_requests_with_router_mode(
args,
router_config.clone(),
requests,
num_workers,
arrival_speedup_ratio,
router_mode,
)
}
(other, _) => anyhow::bail!(
"replay_mode must be either 'offline' or 'online', got '{}'",
other
),
}
});
let report = report.map_err(to_pyerr)?;
pythonize(py, &report)
.map_err(to_pyerr)
.map(|obj| obj.unbind())
}
fn load_replay_mocker_args(
py: Python<'_>,
extra_engine_args: Option<MockEngineArgs>,
) -> PyResult<RsMockEngineArgs> {
let mut args = match extra_engine_args {
Some(extra_args) => extra_args.inner(),
None => RsMockEngineArgs::default(),
};
if let Some(ref backend_name) = args.aic_backend.clone() {
let backend = backend_name.clone();
let system = args.aic_system.as_deref().unwrap_or("h200_sxm").to_string();
let model_name = args
.aic_model_path
.clone()
.ok_or_else(|| PyException::new_err("--aic-perf-model requires --model-path"))?;
let backend_version = args.aic_backend_version.clone();
let tp_size = args.aic_tp_size.unwrap_or(1);
let callback = create_aic_callback(
py,
&backend,
&system,
&model_name,
tp_size,
backend_version.as_deref(),
)
.map_err(|e| {
PyException::new_err(format!(
"Failed to create AIC callback (--aic-perf-model was requested): {}",
e
))
})?;
tracing::info!(
"AIC perf model: backend={}, gpu={}, model={}, version={:?}",
backend,
system,
model_name,
backend_version
);
args.perf_model = Arc::new(PerfModel::from_aic_callback(callback));
}
Ok(args)
}
fn load_replay_router_config(
router_config: Option<KvRouterConfig>,
) -> Option<dynamo_kv_router::config::KvRouterConfig> {
router_config.map(|config| config.inner())
}
fn parse_replay_router_mode(
router_mode: &str,
) -> PyResult<dynamo_mocker::replay::ReplayRouterMode> {
match router_mode {
"round_robin" => Ok(dynamo_mocker::replay::ReplayRouterMode::RoundRobin),
"kv_router" => Ok(dynamo_mocker::replay::ReplayRouterMode::KvRouter),
other => Err(PyException::new_err(format!(
"router_mode must be either 'round_robin' or 'kv_router', got '{}'",
other
))),
}
}
fn parse_replay_concurrency(replay_concurrency: Option<isize>) -> anyhow::Result<Option<usize>> {
match replay_concurrency {
Some(value) if value < 1 => anyhow::bail!("replay_concurrency must be at least 1"),
Some(value) => Ok(Some(value as usize)),
None => Ok(None),
}
}
fn build_synthetic_requests(
input_tokens: usize,
output_tokens: usize,
request_count: usize,
arrival_interval_ms: f64,
include_arrival_timestamps: bool,
) -> anyhow::Result<Vec<DirectRequest>> {
if input_tokens == 0 {
anyhow::bail!("input_tokens must be at least 1");
}
if output_tokens == 0 {
anyhow::bail!("output_tokens must be at least 1");
}
if request_count == 0 {
anyhow::bail!("request_count must be at least 1");
}
if !arrival_interval_ms.is_finite() || arrival_interval_ms < 0.0 {
anyhow::bail!(
"arrival_interval_ms must be a finite non-negative number, got {}",
arrival_interval_ms
);
}
let mut requests = Vec::with_capacity(request_count);
for request_idx in 0..request_count {
let tokens = (0..input_tokens)
.map(|token_idx| synthetic_token_id(request_idx, token_idx))
.collect();
requests.push(DirectRequest {
tokens,
max_output_tokens: output_tokens,
uuid: Some(Uuid::from_u128((request_idx as u128) + 1)),
dp_rank: 0,
arrival_timestamp_ms: include_arrival_timestamps
.then_some(request_idx as f64 * arrival_interval_ms),
});
}
Ok(requests)
}
fn synthetic_token_id(request_idx: usize, token_idx: usize) -> u32 {
let mut value =
(((request_idx as u64) << 32) ^ (token_idx as u64)).wrapping_add(0x9E37_79B9_7F4A_7C15);
value ^= value >> 30;
value = value.wrapping_mul(0xBF58_476D_1CE4_E5B9);
value ^= value >> 27;
value = value.wrapping_mul(0x94D0_49BB_1331_11EB);
value ^= value >> 31;
let token = value as u32;
if token == 0 { 1 } else { token }
}
...@@ -3,7 +3,17 @@ ...@@ -3,7 +3,17 @@
import asyncio import asyncio
import os import os
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
)
# Import from specialized modules # Import from specialized modules
from .prometheus_metrics import RuntimeMetrics as PyRuntimeMetrics from .prometheus_metrics import RuntimeMetrics as PyRuntimeMetrics
...@@ -1104,9 +1114,10 @@ class KvRouterConfig: ...@@ -1104,9 +1114,10 @@ class KvRouterConfig:
router_ttl_secs: float = 120.0, router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1048576, router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8, router_prune_target_ratio: float = 0.8,
router_queue_threshold: Optional[float] = 2.0, router_queue_threshold: Optional[float] = 4.0,
router_event_threads: int = 4, router_event_threads: int = 4,
router_enable_cache_control: bool = False, router_enable_cache_control: bool = False,
min_initial_workers: int = 1,
router_queue_policy: str = "fcfs", router_queue_policy: str = "fcfs",
) -> None: ) -> None:
""" """
...@@ -1132,7 +1143,7 @@ class KvRouterConfig: ...@@ -1132,7 +1143,7 @@ class KvRouterConfig:
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0) router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20) router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8) router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
router_queue_threshold: Queue threshold fraction for prefill token capacity (default: 2.0). router_queue_threshold: Queue threshold fraction for prefill token capacity (default: 4.0).
Requests are queued if all workers exceed this fraction of max_num_batched_tokens. Requests are queued if all workers exceed this fraction of max_num_batched_tokens.
Enables priority scheduling via request priority hints. Enables priority scheduling via request priority hints.
Set to None to disable queueing (all requests go directly to the scheduler). Set to None to disable queueing (all requests go directly to the scheduler).
...@@ -1140,12 +1151,111 @@ class KvRouterConfig: ...@@ -1140,12 +1151,111 @@ class KvRouterConfig:
When > 1, uses a concurrent radix tree with a thread pool. When > 1, uses a concurrent radix tree with a thread pool.
router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's
cache_control service mesh endpoint (default: False). cache_control service mesh endpoint (default: False).
min_initial_workers: Minimum number of discovered workers required before
router startup continues (default: 1). Ignored when
skip_initial_worker_wait is enabled.
router_queue_policy: Scheduling policy for the router queue (default: "fcfs"). router_queue_policy: Scheduling policy for the router queue (default: "fcfs").
"fcfs": first-come first-served with priority bumps — optimizes tail TTFT. "fcfs": first-come first-served with priority bumps — optimizes tail TTFT.
"lcfs": last-come first-served with priority bumps — intentionally worsens tail behavior for policy comparisons.
"wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT. "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
""" """
... ...
@staticmethod
def from_json(config_json: str) -> "KvRouterConfig":
...
class ReasoningConfig:
def __init__(
self,
start_thinking_token_id: int,
end_thinking_token_id: int,
thinking_ratio: float,
) -> None:
...
class SglangArgs:
def __init__(
self,
schedule_policy: Optional[str] = None,
page_size: Optional[int] = None,
max_prefill_tokens: Optional[int] = None,
chunked_prefill_size: Optional[int] = None,
clip_max_new_tokens: Optional[int] = None,
schedule_conservativeness: Optional[float] = None,
) -> None:
...
class MockEngineArgs:
def __init__(
self,
engine_type: str = "vllm",
num_gpu_blocks: int = 16384,
block_size: int = 0,
max_num_seqs: Optional[int] = 256,
max_num_batched_tokens: Optional[int] = 8192,
enable_prefix_caching: bool = True,
enable_chunked_prefill: bool = True,
speedup_ratio: float = 1.0,
decode_speedup_ratio: float = 1.0,
dp_size: int = 1,
startup_time: Optional[float] = None,
worker_type: str = "aggregated",
aic_backend: Optional[str] = None,
aic_system: Optional[str] = None,
aic_backend_version: Optional[str] = None,
aic_tp_size: Optional[int] = None,
aic_model_path: Optional[str] = None,
enable_local_indexer: bool = False,
bootstrap_port: Optional[int] = None,
kv_bytes_per_token: Optional[int] = None,
kv_transfer_bandwidth: Optional[float] = None,
reasoning: Optional[ReasoningConfig] = None,
zmq_kv_events_port: Optional[int] = None,
zmq_replay_port: Optional[int] = None,
preemption_mode: str = "lifo",
router_queue_policy: Optional[str] = None,
sglang: Optional[SglangArgs] = None,
) -> None:
...
@staticmethod
def from_json(config_json: str) -> "MockEngineArgs":
...
@property
def block_size(self) -> int: ...
@property
def num_gpu_blocks(self) -> int: ...
@property
def max_num_seqs(self) -> Optional[int]: ...
@property
def max_num_batched_tokens(self) -> Optional[int]: ...
@property
def enable_local_indexer(self) -> bool: ...
@property
def dp_size(self) -> int: ...
@property
def bootstrap_port(self) -> Optional[int]: ...
def is_prefill(self) -> bool: ...
def is_decode(self) -> bool: ...
def with_overrides(
self,
bootstrap_port: Optional[int] = None,
zmq_kv_events_port: Optional[int] = None,
zmq_replay_port: Optional[int] = None,
kv_bytes_per_token: Optional[int] = None,
) -> "MockEngineArgs": ...
async def register_model( async def register_model(
model_input: ModelInput, model_input: ModelInput,
model_type: ModelType, model_type: ModelType,
...@@ -1249,11 +1359,31 @@ async def run_input(runtime: DistributedRuntime, input: str, engine_config: Engi ...@@ -1249,11 +1359,31 @@ async def run_input(runtime: DistributedRuntime, input: str, engine_config: Engi
def run_mocker_trace_replay( def run_mocker_trace_replay(
trace_file: str | os.PathLike[str], trace_file: str | os.PathLike[str],
extra_engine_args: Optional[str | os.PathLike[str]] = None, extra_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
num_workers: int = 1,
replay_concurrency: Optional[int] = None,
replay_mode: Literal["offline", "online"] = "offline",
router_mode: Literal["round_robin", "kv_router"] = "round_robin",
arrival_speedup_ratio: float = 1.0,
) -> Dict[str, Any]:
"""Replay a mocker trace file and return the simulation report for aggregated vLLM or SGLang configs."""
...
def run_mocker_synthetic_trace_replay(
input_tokens: int,
output_tokens: int,
request_count: int,
extra_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
num_workers: int = 1, num_workers: int = 1,
replay_concurrency: Optional[int] = None, replay_concurrency: Optional[int] = None,
replay_mode: Literal["offline", "online"] = "offline",
router_mode: Literal["round_robin", "kv_router"] = "round_robin",
arrival_speedup_ratio: float = 1.0,
arrival_interval_ms: float = 1.0,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Replay a mocker trace file and return the simulation report.""" """Replay a synthetic mocker workload without requiring a trace file."""
... ...
class Layer: class Layer:
...@@ -1687,6 +1817,7 @@ class EntrypointArgs: ...@@ -1687,6 +1817,7 @@ class EntrypointArgs:
tls_cert_path: Optional[str] = None, tls_cert_path: Optional[str] = None,
tls_key_path: Optional[str] = None, tls_key_path: Optional[str] = None,
extra_engine_args: Optional[str] = None, extra_engine_args: Optional[str] = None,
mocker_engine_args: Optional[MockEngineArgs] = None,
runtime_config: Optional[ModelRuntimeConfig] = None, runtime_config: Optional[ModelRuntimeConfig] = None,
namespace: Optional[str] = None, namespace: Optional[str] = None,
namespace_prefix: Optional[str] = None, namespace_prefix: Optional[str] = None,
...@@ -1711,7 +1842,8 @@ class EntrypointArgs: ...@@ -1711,7 +1842,8 @@ class EntrypointArgs:
http_metrics_port: HTTP metrics port (for gRPC service) http_metrics_port: HTTP metrics port (for gRPC service)
tls_cert_path: TLS certificate path (PEM format) tls_cert_path: TLS certificate path (PEM format)
tls_key_path: TLS key path (PEM format) tls_key_path: TLS key path (PEM format)
extra_engine_args: Path to extra engine arguments file extra_engine_args: Optional path to mocker engine arguments JSON
mocker_engine_args: Typed mocker engine arguments
runtime_config: Optional runtime configuration for discovery registration runtime_config: Optional runtime configuration for discovery registration
namespace: Dynamo namespace for model discovery scoping namespace: Dynamo namespace for model discovery scoping
namespace_prefix: Optional namespace prefix namespace_prefix: Optional namespace prefix
......
...@@ -18,6 +18,7 @@ from dynamo._core import KvRouterConfig as KvRouterConfig ...@@ -18,6 +18,7 @@ from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import LoRADownloader as LoRADownloader from dynamo._core import LoRADownloader as LoRADownloader
from dynamo._core import MediaDecoder as MediaDecoder from dynamo._core import MediaDecoder as MediaDecoder
from dynamo._core import MediaFetcher as MediaFetcher from dynamo._core import MediaFetcher as MediaFetcher
from dynamo._core import MockEngineArgs as MockEngineArgs
from dynamo._core import ModelCardInstanceId as ModelCardInstanceId from dynamo._core import ModelCardInstanceId as ModelCardInstanceId
from dynamo._core import ModelInput as ModelInput from dynamo._core import ModelInput as ModelInput
from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig
...@@ -25,8 +26,10 @@ from dynamo._core import ModelType as ModelType ...@@ -25,8 +26,10 @@ from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores from dynamo._core import OverlapScores as OverlapScores
from dynamo._core import PythonAsyncEngine as PythonAsyncEngine from dynamo._core import PythonAsyncEngine as PythonAsyncEngine
from dynamo._core import RadixTree as RadixTree from dynamo._core import RadixTree as RadixTree
from dynamo._core import ReasoningConfig as ReasoningConfig
from dynamo._core import RouterConfig as RouterConfig from dynamo._core import RouterConfig as RouterConfig
from dynamo._core import RouterMode as RouterMode from dynamo._core import RouterMode as RouterMode
from dynamo._core import SglangArgs as SglangArgs
from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher
from dynamo._core import compute_block_hash_for_seq as compute_block_hash_for_seq from dynamo._core import compute_block_hash_for_seq as compute_block_hash_for_seq
from dynamo._core import fetch_model as fetch_model from dynamo._core import fetch_model as fetch_model
...@@ -35,7 +38,7 @@ from dynamo._core import make_engine ...@@ -35,7 +38,7 @@ from dynamo._core import make_engine
from dynamo._core import register_model as register_model from dynamo._core import register_model as register_model
from dynamo._core import run_input from dynamo._core import run_input
from dynamo._core import run_kv_indexer as run_kv_indexer from dynamo._core import run_kv_indexer as run_kv_indexer
from dynamo._core import run_mocker_trace_replay from dynamo._core import run_mocker_trace_replay as _run_mocker_trace_replay
from dynamo._core import unregister_model as unregister_model from dynamo._core import unregister_model as unregister_model
from .exceptions import HttpError from .exceptions import HttpError
...@@ -44,3 +47,24 @@ from .exceptions import HttpError ...@@ -44,3 +47,24 @@ from .exceptions import HttpError
fetch_llm = fetch_model fetch_llm = fetch_model
register_llm = register_model register_llm = register_model
unregister_llm = unregister_model unregister_llm = unregister_model
def run_mocker_trace_replay(
trace_file,
extra_engine_args=None,
router_config=None,
num_workers=1,
replay_concurrency=None,
router_mode="round_robin",
arrival_speedup_ratio=1.0,
):
return _run_mocker_trace_replay(
trace_file,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=num_workers,
replay_concurrency=replay_concurrency,
replay_mode="offline",
router_mode=router_mode,
arrival_speedup_ratio=arrival_speedup_ratio,
)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.replay.api import run_synthetic_trace_replay, run_trace_replay
__all__ = ["run_synthetic_trace_replay", "run_trace_replay"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.replay.main import main
if __name__ == "__main__":
raise SystemExit(main())
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo._core import (
run_mocker_synthetic_trace_replay as _run_mocker_synthetic_trace_replay,
)
from dynamo._core import run_mocker_trace_replay as _run_mocker_trace_replay
def run_trace_replay(
trace_file,
*,
extra_engine_args=None,
router_config=None,
num_workers=1,
replay_concurrency=None,
replay_mode="offline",
router_mode="round_robin",
arrival_speedup_ratio=1.0,
):
return _run_mocker_trace_replay(
trace_file,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=num_workers,
replay_concurrency=replay_concurrency,
replay_mode=replay_mode,
router_mode=router_mode,
arrival_speedup_ratio=arrival_speedup_ratio,
)
def run_synthetic_trace_replay(
input_tokens,
output_tokens,
request_count,
*,
extra_engine_args=None,
router_config=None,
num_workers=1,
replay_concurrency=None,
replay_mode="offline",
router_mode="round_robin",
arrival_speedup_ratio=1.0,
arrival_interval_ms=1.0,
):
return _run_mocker_synthetic_trace_replay(
input_tokens,
output_tokens,
request_count,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=num_workers,
replay_concurrency=replay_concurrency,
replay_mode=replay_mode,
router_mode=router_mode,
arrival_speedup_ratio=arrival_speedup_ratio,
arrival_interval_ms=arrival_interval_ms,
)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import argparse
import json
import os
import sys
from collections.abc import Sequence
os.environ.setdefault("DYNAMO_SKIP_PYTHON_LOG_INIT", "1")
from dynamo.llm import KvRouterConfig, MockEngineArgs
from dynamo.replay import run_synthetic_trace_replay, run_trace_replay
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(prog="python -m dynamo.replay")
parser.add_argument("trace_file", nargs="?")
parser.add_argument("--extra-engine-args")
parser.add_argument("--router-config")
parser.add_argument("--input-tokens", type=int)
parser.add_argument("--output-tokens", type=int)
parser.add_argument("--request-count", type=int)
parser.add_argument("--arrival-interval-ms", type=float, default=1.0)
parser.add_argument("--num-workers", type=int, default=1)
parser.add_argument("--replay-concurrency", type=int)
parser.add_argument(
"--replay-mode",
choices=("offline", "online"),
default="offline",
)
parser.add_argument(
"--router-mode",
choices=("round_robin", "kv_router"),
default="round_robin",
)
parser.add_argument("--arrival-speedup-ratio", type=float, default=1.0)
args = parser.parse_args(list(sys.argv[1:] if argv is None else argv))
using_trace_file = args.trace_file is not None
synthetic_args = (args.input_tokens, args.output_tokens, args.request_count)
using_synthetic = any(value is not None for value in synthetic_args)
if using_trace_file == using_synthetic:
parser.error(
"provide either trace_file or all of --input-tokens/--output-tokens/--request-count"
)
if using_synthetic and not all(value is not None for value in synthetic_args):
parser.error(
"synthetic replay requires --input-tokens, --output-tokens, and --request-count"
)
extra_engine_args = (
MockEngineArgs.from_json(args.extra_engine_args)
if args.extra_engine_args is not None
else None
)
router_config = (
KvRouterConfig.from_json(args.router_config)
if args.router_config is not None
else None
)
if using_trace_file:
report = run_trace_replay(
args.trace_file,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=args.num_workers,
replay_concurrency=args.replay_concurrency,
replay_mode=args.replay_mode,
router_mode=args.router_mode,
arrival_speedup_ratio=args.arrival_speedup_ratio,
)
else:
report = run_synthetic_trace_replay(
args.input_tokens,
args.output_tokens,
args.request_count,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=args.num_workers,
replay_concurrency=args.replay_concurrency,
replay_mode=args.replay_mode,
router_mode=args.router_mode,
arrival_speedup_ratio=args.arrival_speedup_ratio,
arrival_interval_ms=args.arrival_interval_ms,
)
json.dump(report, sys.stdout, indent=2, sort_keys=True)
sys.stdout.write("\n")
return 0
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import pytest
from dynamo.llm import KvRouterConfig, MockEngineArgs
from dynamo.replay import run_synthetic_trace_replay, run_trace_replay
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.parallel,
pytest.mark.pre_merge,
]
MOONCAKE_TRACE_FIRST20 = """{"timestamp": 0, "input_length": 6755, "output_length": 500, "hash_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]}
{"timestamp": 0, "input_length": 7319, "output_length": 490, "hash_ids": [0, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]}
{"timestamp": 0, "input_length": 7234, "output_length": 794, "hash_ids": [0, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]}
{"timestamp": 0, "input_length": 2287, "output_length": 316, "hash_ids": [0, 42, 43, 44, 45]}
{"timestamp": 0, "input_length": 9013, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]}
{"timestamp": 0, "input_length": 6506, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 64]}
{"timestamp": 0, "input_length": 4824, "output_length": 173, "hash_ids": [0, 65, 66, 67, 68, 69, 70, 71, 72, 73]}
{"timestamp": 0, "input_length": 3119, "output_length": 20, "hash_ids": [74, 75, 76, 77, 78, 79, 80]}
{"timestamp": 0, "input_length": 23090, "output_length": 453, "hash_ids": [0, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125]}
{"timestamp": 0, "input_length": 3135, "output_length": 19, "hash_ids": [74, 75, 76, 77, 78, 126, 127]}
{"timestamp": 0, "input_length": 26874, "output_length": 458, "hash_ids": [0, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179]}
{"timestamp": 0, "input_length": 10487, "output_length": 402, "hash_ids": [0, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]}
{"timestamp": 0, "input_length": 17448, "output_length": 610, "hash_ids": [0, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233]}
{"timestamp": 0, "input_length": 6253, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 234]}
{"timestamp": 0, "input_length": 6725, "output_length": 32, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 235, 236]}
{"timestamp": 3052, "input_length": 13538, "output_length": 71, "hash_ids": [0, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262]}
{"timestamp": 3052, "input_length": 87162, "output_length": 402, "hash_ids": [0, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432]}
{"timestamp": 3052, "input_length": 6166, "output_length": 24, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 433]}
{"timestamp": 3052, "input_length": 6320, "output_length": 548, "hash_ids": [0, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445]}
{"timestamp": 3052, "input_length": 2007, "output_length": 354, "hash_ids": [0, 446, 447, 448]}
"""
def _write_trace_and_args(tmp_path):
trace_path = tmp_path / "trace.jsonl"
records = [
{
"timestamp": 1000.0,
"input_length": 64,
"output_length": 2,
"hash_ids": [101],
},
{
"timestamp": 1005.0,
"input_length": 64,
"output_length": 2,
"hash_ids": [101],
},
]
trace_path.write_text(
"\n".join(json.dumps(record) for record in records) + "\n",
encoding="utf-8",
)
return trace_path
def _write_vllm_args(tmp_path):
args_path = tmp_path / "args.json"
args_path.write_text(
json.dumps(
{
"block_size": 64,
"speedup_ratio": 1000.0,
}
),
encoding="utf-8",
)
return args_path
def _vllm_args():
return MockEngineArgs.from_json(
json.dumps(
{
"block_size": 64,
"speedup_ratio": 1000.0,
}
)
)
def _write_sglang_args(tmp_path):
args_path = tmp_path / "sglang_args.json"
args_path.write_text(
json.dumps(
{
"engine_type": "sglang",
"num_gpu_blocks": 512,
"block_size": 64,
"speedup_ratio": 1000.0,
"sglang": {
"page_size": 64,
},
}
),
encoding="utf-8",
)
return args_path
def _sglang_args():
return MockEngineArgs.from_json(
json.dumps(
{
"engine_type": "sglang",
"num_gpu_blocks": 512,
"block_size": 64,
"speedup_ratio": 1000.0,
"sglang": {
"page_size": 64,
},
}
)
)
def _write_router_config(tmp_path):
config_path = tmp_path / "router_config.json"
config_path.write_text(
json.dumps(
{
"router_queue_threshold": 1.25,
"router_event_threads": 1,
"router_queue_policy": "wspt",
"router_temperature": 0.0,
"overlap_score_weight": 1.0,
"use_kv_events": True,
"durable_kv_events": False,
"router_replica_sync": False,
"router_track_active_blocks": True,
"router_track_output_blocks": False,
"router_assume_kv_reuse": True,
"router_snapshot_threshold": 1000000,
"router_reset_states": False,
"router_ttl_secs": 120.0,
"router_max_tree_size": 1048576,
"router_prune_target_ratio": 0.8,
"router_enable_cache_control": False,
"skip_initial_worker_wait": False,
"min_initial_workers": 1,
"remote_indexer_component": None,
}
),
encoding="utf-8",
)
return config_path
def _router_config():
return KvRouterConfig.from_json(
json.dumps(
{
"router_queue_threshold": 1.25,
"router_event_threads": 1,
"router_queue_policy": "wspt",
"router_temperature": 0.0,
"overlap_score_weight": 1.0,
"use_kv_events": True,
"durable_kv_events": False,
"router_replica_sync": False,
"router_track_active_blocks": True,
"router_track_output_blocks": False,
"router_assume_kv_reuse": True,
"router_snapshot_threshold": 1000000,
"router_reset_states": False,
"router_ttl_secs": 120.0,
"router_max_tree_size": 1048576,
"router_prune_target_ratio": 0.8,
"router_enable_cache_control": False,
"skip_initial_worker_wait": False,
"min_initial_workers": 1,
"remote_indexer_component": None,
}
)
)
def _partial_router_config():
return KvRouterConfig(
router_queue_threshold=1.25,
router_event_threads=1,
router_queue_policy="wspt",
)
def _assert_basic_report_counts(report, *, num_requests, input_tokens, output_tokens):
assert report["num_requests"] == num_requests
assert report["completed_requests"] == num_requests
assert report["total_input_tokens"] == num_requests * input_tokens
assert report["total_output_tokens"] == num_requests * output_tokens
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
def test_run_trace_replay_smoke_matrix(tmp_path, engine_type, replay_mode, router_mode):
trace_path = _write_trace_and_args(tmp_path)
args_path = _vllm_args() if engine_type == "vllm" else _sglang_args()
num_workers = 1 if router_mode == "round_robin" else 2
report = run_trace_replay(
trace_path,
extra_engine_args=args_path,
num_workers=num_workers,
replay_mode=replay_mode,
router_mode=router_mode,
)
_assert_basic_report_counts(
report,
num_requests=2,
input_tokens=64,
output_tokens=2,
)
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_invariant_counts_match(tmp_path, engine_type, replay_mode):
trace_path = _write_trace_and_args(tmp_path)
args_path = _vllm_args() if engine_type == "vllm" else _sglang_args()
single = run_trace_replay(
trace_path,
extra_engine_args=args_path,
num_workers=1,
replay_mode=replay_mode,
)
multi_round_robin = run_trace_replay(
trace_path,
extra_engine_args=args_path,
num_workers=4,
replay_mode=replay_mode,
router_mode="round_robin",
)
multi_kv_router = run_trace_replay(
trace_path,
extra_engine_args=args_path,
num_workers=4,
replay_mode=replay_mode,
router_mode="kv_router",
)
for field in (
"num_requests",
"completed_requests",
"total_input_tokens",
"total_output_tokens",
):
assert single[field] == multi_round_robin[field]
assert single[field] == multi_kv_router[field]
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
def test_run_synthetic_trace_replay_smoke_matrix(
tmp_path, engine_type, replay_mode, router_mode
):
args_path = _vllm_args() if engine_type == "vllm" else _sglang_args()
num_workers = 1 if router_mode == "round_robin" else 2
report = run_synthetic_trace_replay(
64,
2,
2,
extra_engine_args=args_path,
num_workers=num_workers,
replay_mode=replay_mode,
router_mode=router_mode,
arrival_interval_ms=5.0,
)
_assert_basic_report_counts(
report,
num_requests=2,
input_tokens=64,
output_tokens=2,
)
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_synthetic_trace_replay_invariant_counts_match(
tmp_path, engine_type, replay_mode
):
args_path = _vllm_args() if engine_type == "vllm" else _sglang_args()
single = run_synthetic_trace_replay(
64,
2,
2,
extra_engine_args=args_path,
num_workers=1,
replay_mode=replay_mode,
arrival_interval_ms=5.0,
)
multi_round_robin = run_synthetic_trace_replay(
64,
2,
2,
extra_engine_args=args_path,
num_workers=4,
replay_mode=replay_mode,
router_mode="round_robin",
arrival_interval_ms=5.0,
)
multi_kv_router = run_synthetic_trace_replay(
64,
2,
2,
extra_engine_args=args_path,
num_workers=4,
replay_mode=replay_mode,
router_mode="kv_router",
arrival_interval_ms=5.0,
)
for field in (
"num_requests",
"completed_requests",
"total_input_tokens",
"total_output_tokens",
):
assert single[field] == multi_round_robin[field]
assert single[field] == multi_kv_router[field]
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_synthetic_concurrency_replay_counts_match(
tmp_path, engine_type, replay_mode
):
args_path = _vllm_args() if engine_type == "vllm" else _sglang_args()
report = run_synthetic_trace_replay(
64,
2,
3,
extra_engine_args=args_path,
num_workers=2,
replay_mode=replay_mode,
replay_concurrency=2,
)
_assert_basic_report_counts(
report,
num_requests=3,
input_tokens=64,
output_tokens=2,
)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_router_config(tmp_path, replay_mode):
trace_path = _write_trace_and_args(tmp_path)
args_path = _vllm_args()
router_config_path = _router_config()
report = run_trace_replay(
trace_path,
extra_engine_args=args_path,
router_config=router_config_path,
num_workers=2,
replay_mode=replay_mode,
router_mode="kv_router",
)
_assert_basic_report_counts(
report,
num_requests=2,
input_tokens=64,
output_tokens=2,
)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_partial_router_config_json(tmp_path, replay_mode):
trace_path = _write_trace_and_args(tmp_path)
args_path = _vllm_args()
report = run_trace_replay(
trace_path,
extra_engine_args=args_path,
router_config=_partial_router_config(),
num_workers=2,
replay_mode=replay_mode,
router_mode="kv_router",
)
_assert_basic_report_counts(
report,
num_requests=2,
input_tokens=64,
output_tokens=2,
)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_partial_extra_engine_args_json(tmp_path, replay_mode):
trace_path = _write_trace_and_args(tmp_path)
report = run_trace_replay(
trace_path,
extra_engine_args=MockEngineArgs(block_size=64, speedup_ratio=1000.0),
num_workers=1,
replay_mode=replay_mode,
)
_assert_basic_report_counts(
report,
num_requests=2,
input_tokens=64,
output_tokens=2,
)
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transport abstraction for publishing batched KV cache events.
//!
//! Implementations handle the actual delivery mechanism (NATS event plane,
//! JetStream durable queue, direct indexer application, etc.). The trait lives
//! in this crate so that the batching processor and other routing logic can be
//! written generically; runtime-specific impls stay in `lib/llm`.
use std::future::Future;
use crate::protocols::RouterEvent;
/// Transport abstraction for publishing batched KV cache events.
pub trait EventSink: Send + Sync {
fn publish_event(&self, event: &RouterEvent)
-> impl Future<Output = anyhow::Result<()>> + Send;
}
This diff is collapsed.
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
//! This crate provides the core radix tree implementation and protocols for //! This crate provides the core radix tree implementation and protocols for
//! efficient KV cache lookup and routing in distributed LLM inference systems. //! efficient KV cache lookup and routing in distributed LLM inference systems.
pub mod event_sink;
pub mod indexer; pub mod indexer;
pub mod protocols; pub mod protocols;
pub mod scheduling; pub mod scheduling;
...@@ -41,15 +40,15 @@ pub use self::sequence::{ActiveSequences, RequestId}; ...@@ -41,15 +40,15 @@ pub use self::sequence::{ActiveSequences, RequestId};
pub use concurrent_radix_tree::ConcurrentRadixTree; pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed; pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed;
pub use config::{KvRouterConfig, RouterConfigOverride, RouterQueuePolicy}; pub use config::{KvRouterConfig, RouterConfigOverride, RouterQueuePolicy};
pub use event_sink::EventSink;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer}; pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
pub use nested_map::PositionalIndexer; pub use nested_map::PositionalIndexer;
pub use protocols::{ pub use protocols::{
KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, WorkerConfigLike, WorkerId, KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, RouterEventSink,
compute_block_hash_for_seq, WorkerConfigLike, WorkerId, compute_block_hash_for_seq,
}; };
pub use queue::SchedulerQueue; pub use queue::SchedulerQueue;
pub use radix_tree::RadixTree; pub use radix_tree::RadixTree;
pub use scheduling::LocalScheduler;
pub use scheduling::policy::{FcfsPolicy, RouterSchedulingPolicy, SchedulingPolicy, WsptPolicy}; pub use scheduling::policy::{FcfsPolicy, RouterSchedulingPolicy, SchedulingPolicy, WsptPolicy};
pub use scheduling::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse}; pub use scheduling::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
pub use selector::{DefaultWorkerSelector, WorkerSelector}; pub use selector::{DefaultWorkerSelector, WorkerSelector};
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::future::Future;
use dynamo_tokens::{SequenceHash, Token}; use dynamo_tokens::{SequenceHash, Token};
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -105,6 +107,12 @@ pub trait WorkerConfigLike { ...@@ -105,6 +107,12 @@ pub trait WorkerConfigLike {
fn total_kv_blocks(&self) -> Option<u64>; fn total_kv_blocks(&self) -> Option<u64>;
} }
/// Transport abstraction for publishing batched router-visible KV cache events.
pub trait RouterEventSink: Send + Sync {
fn publish_event(&self, event: &RouterEvent)
-> impl Future<Output = anyhow::Result<()>> + Send;
}
/// A worker identifier. /// A worker identifier.
pub type WorkerId = u64; pub type WorkerId = u64;
......
...@@ -11,11 +11,16 @@ use validator::{Validate, ValidationError}; ...@@ -11,11 +11,16 @@ use validator::{Validate, ValidationError};
use crate::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block}; use crate::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block};
const fn default_min_initial_workers() -> usize {
1
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum RouterQueuePolicy { pub enum RouterQueuePolicy {
#[default] #[default]
Fcfs, Fcfs,
Lcfs,
Wspt, Wspt,
} }
...@@ -23,6 +28,7 @@ impl fmt::Display for RouterQueuePolicy { ...@@ -23,6 +28,7 @@ impl fmt::Display for RouterQueuePolicy {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
Self::Fcfs => f.write_str("fcfs"), Self::Fcfs => f.write_str("fcfs"),
Self::Lcfs => f.write_str("lcfs"),
Self::Wspt => f.write_str("wspt"), Self::Wspt => f.write_str("wspt"),
} }
} }
...@@ -34,9 +40,10 @@ impl FromStr for RouterQueuePolicy { ...@@ -34,9 +40,10 @@ impl FromStr for RouterQueuePolicy {
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
match s { match s {
"fcfs" => Ok(Self::Fcfs), "fcfs" => Ok(Self::Fcfs),
"lcfs" => Ok(Self::Lcfs),
"wspt" => Ok(Self::Wspt), "wspt" => Ok(Self::Wspt),
_ => Err(format!( _ => Err(format!(
"unknown queue policy: {s:?}, expected 'fcfs' or 'wspt'" "unknown queue policy: {s:?}, expected 'fcfs', 'lcfs', or 'wspt'"
)), )),
} }
} }
...@@ -58,6 +65,7 @@ pub struct RouterConfigOverride { ...@@ -58,6 +65,7 @@ pub struct RouterConfigOverride {
/// KV Router configuration parameters /// KV Router configuration parameters
#[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(default)]
#[validate(schema(function = "validate_kv_router_config"))] #[validate(schema(function = "validate_kv_router_config"))]
pub struct KvRouterConfig { pub struct KvRouterConfig {
#[validate(range(min = 0.0))] #[validate(range(min = 0.0))]
...@@ -130,6 +138,13 @@ pub struct KvRouterConfig { ...@@ -130,6 +138,13 @@ pub struct KvRouterConfig {
/// When true, the router starts immediately without waiting for discovery-based /// When true, the router starts immediately without waiting for discovery-based
/// workers and workers are provided externally per-request (e.g., EPP). /// workers and workers are provided externally per-request (e.g., EPP).
pub skip_initial_worker_wait: bool, pub skip_initial_worker_wait: bool,
/// Minimum number of workers that must be discovered before router startup continues.
/// Default: 1. Ignored when skip_initial_worker_wait=true.
#[serde(default = "default_min_initial_workers")]
#[validate(range(min = 1))]
pub min_initial_workers: usize,
/// Scheduling policy for the router queue. /// Scheduling policy for the router queue.
/// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT. /// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT.
/// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT. /// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
...@@ -159,10 +174,11 @@ impl Default for KvRouterConfig { ...@@ -159,10 +174,11 @@ impl Default for KvRouterConfig {
router_ttl_secs: 120.0, router_ttl_secs: 120.0,
router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default() router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
router_prune_target_ratio: 0.8, router_prune_target_ratio: 0.8,
router_queue_threshold: Some(2.0), router_queue_threshold: Some(4.0),
router_event_threads: 4, router_event_threads: 4,
router_enable_cache_control: false, router_enable_cache_control: false,
skip_initial_worker_wait: false, skip_initial_worker_wait: false,
min_initial_workers: default_min_initial_workers(),
router_queue_policy: RouterQueuePolicy::default(), router_queue_policy: RouterQueuePolicy::default(),
remote_indexer_component: None, remote_indexer_component: None,
} }
...@@ -237,3 +253,39 @@ impl KvRouterConfig { ...@@ -237,3 +253,39 @@ impl KvRouterConfig {
self.use_kv_events && self.overlap_score_weight > 0.0 self.use_kv_events && self.overlap_score_weight > 0.0
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn router_queue_policy_display_and_parse_support_lcfs() {
assert_eq!(RouterQueuePolicy::Lcfs.to_string(), "lcfs");
assert_eq!(
"lcfs".parse::<RouterQueuePolicy>().unwrap(),
RouterQueuePolicy::Lcfs
);
}
#[test]
fn router_queue_policy_serde_round_trip_supports_lcfs() {
let serialized = serde_json::to_string(&RouterQueuePolicy::Lcfs).unwrap();
assert_eq!(serialized, "\"lcfs\"");
let deserialized: RouterQueuePolicy = serde_json::from_str(&serialized).unwrap();
assert_eq!(deserialized, RouterQueuePolicy::Lcfs);
}
#[test]
fn kv_router_config_defaults_to_one_initial_worker() {
assert_eq!(KvRouterConfig::default().min_initial_workers, 1);
}
#[test]
fn kv_router_config_rejects_zero_initial_workers() {
let cfg = KvRouterConfig {
min_initial_workers: 0,
..KvRouterConfig::default()
};
assert!(cfg.validate().is_err());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, watch};
use tokio_util::sync::CancellationToken;
use super::policy::{RouterSchedulingPolicy, SchedulingPolicy};
use super::queue::SchedulerQueue;
use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
use crate::protocols::{OverlapScores, WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{
ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest,
};
use dynamo_tokens::SequenceHash;
const RECHECK_INTERVAL: Duration = Duration::from_secs(60);
pub struct LocalScheduler<P, C, S = RouterSchedulingPolicy, Sel = DefaultWorkerSelector>
where
P: SequencePublisher,
C: WorkerConfigLike,
S: SchedulingPolicy,
Sel: WorkerSelector<C>,
{
request_tx: mpsc::Sender<SchedulingRequest>,
slots: Arc<ActiveSequencesMultiWorker<P>>,
queue: Arc<SchedulerQueue<P, C, S, Sel>>,
worker_type: &'static str,
}
impl<P, C, S, Sel> LocalScheduler<P, C, S, Sel>
where
P: SequencePublisher + 'static,
C: WorkerConfigLike + Clone + PartialEq + Send + Sync + 'static,
S: SchedulingPolicy + 'static,
Sel: WorkerSelector<C> + Send + Sync + 'static,
{
#[allow(clippy::too_many_arguments)]
pub fn new(
slots: Arc<ActiveSequencesMultiWorker<P>>,
workers_with_configs: watch::Receiver<HashMap<WorkerId, C>>,
threshold_frac: Option<f64>,
block_size: u32,
selector: Sel,
policy: S,
cancellation_token: CancellationToken,
worker_type: &'static str,
monitor_worker_configs: bool,
) -> Self {
if monitor_worker_configs {
let slots_monitor = Arc::clone(&slots);
let mut monitor_rx = workers_with_configs.clone();
let mut last_workers = monitor_rx.borrow().clone();
let monitor_cancel_token = cancellation_token.clone();
tokio::spawn(async move {
tracing::trace!("LocalScheduler workers monitoring task started");
loop {
tokio::select! {
_ = monitor_cancel_token.cancelled() => {
tracing::trace!("LocalScheduler workers monitoring task shutting down");
break;
}
result = monitor_rx.changed() => {
if result.is_err() {
tracing::warn!("LocalScheduler worker config watch dropped, shutting down");
break;
}
}
}
let current_workers = monitor_rx.borrow_and_update().clone();
if current_workers == last_workers {
continue;
}
let dp_range: HashMap<WorkerId, (u32, u32)> = current_workers
.iter()
.map(|(&id, cfg)| {
(
id,
(cfg.data_parallel_start_rank(), cfg.data_parallel_size()),
)
})
.collect();
slots_monitor.update_workers(&dp_range);
last_workers = current_workers;
}
});
}
let queue = Arc::new(SchedulerQueue::new(
Arc::clone(&slots),
workers_with_configs,
threshold_frac,
block_size,
selector,
policy,
));
let (request_tx, request_rx) = mpsc::channel::<SchedulingRequest>(1024);
let queue_clone = Arc::clone(&queue);
tokio::spawn(async move {
let mut request_rx = request_rx;
let mut recheck_interval = tokio::time::interval(RECHECK_INTERVAL);
tracing::trace!("LocalScheduler background task started");
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::trace!("LocalScheduler background task shutting down");
break;
}
request = request_rx.recv() => {
let Some(request) = request else {
tracing::warn!("LocalScheduler request channel closed");
break;
};
tracing::trace!("received request to be scheduled");
queue_clone.enqueue(request).await;
}
_ = recheck_interval.tick() => {
queue_clone.update().await;
}
}
}
});
Self {
request_tx,
slots,
queue,
worker_type,
}
}
#[expect(clippy::too_many_arguments)]
pub async fn schedule(
&self,
maybe_request_id: Option<String>,
isl_tokens: usize,
token_seq: Option<Vec<SequenceHash>>,
overlaps: OverlapScores,
router_config_override: Option<&super::config::RouterConfigOverride>,
update_states: bool,
lora_name: Option<String>,
priority_jump: f64,
expected_output_tokens: Option<u32>,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<SchedulingResponse, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
maybe_request_id,
token_seq,
isl_tokens,
overlaps,
decode_blocks: HashMap::new(),
prefill_tokens: HashMap::new(),
router_config_override: router_config_override.cloned(),
update_states,
lora_name,
priority_jump,
expected_output_tokens,
allowed_worker_ids,
resp_tx: Some(resp_tx),
};
self.request_tx
.send(request)
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
resp_rx
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?
}
pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
self.queue.register_workers(worker_ids);
}
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
self.slots.add_request(req).await
}
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots
.mark_prefill_completed(&request_id.to_string())
.await?;
self.queue.update().await;
Ok(())
}
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots.free(&request_id.to_string()).await?;
self.queue.update().await;
Ok(())
}
pub fn pending_count(&self) -> usize {
self.queue.pending_count()
}
pub fn worker_type(&self) -> &'static str {
self.worker_type
}
pub fn add_output_block(
&self,
request_id: &str,
decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
self.slots
.add_output_block(&request_id.to_string(), decay_fraction)
}
pub fn get_potential_loads(
&self,
token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize,
overlaps: OverlapScores,
) -> Vec<PotentialLoad> {
let (decode_blocks, prefill_tokens) =
self.slots
.potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);
let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
workers.extend(decode_blocks.keys().copied());
workers.extend(prefill_tokens.keys().copied());
let mut loads = Vec::with_capacity(workers.len());
for worker in workers {
loads.push(PotentialLoad {
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
potential_prefill_tokens: prefill_tokens
.get(&worker)
.copied()
.unwrap_or(isl_tokens),
potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
});
}
loads
}
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
self.slots.get_active_lora_counts()
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use super::*;
use crate::protocols::OverlapScores;
use crate::scheduling::policy::FcfsPolicy;
use crate::scheduling::selector::DefaultWorkerSelector;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
#[allow(clippy::type_complexity)]
fn make_scheduler(
workers: HashMap<WorkerId, SimpleWorkerConfig>,
threshold_frac: Option<f64>,
monitor_worker_configs: bool,
) -> (
Arc<LocalScheduler<NoopSequencePublisher, SimpleWorkerConfig, FcfsPolicy>>,
Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
watch::Sender<HashMap<WorkerId, SimpleWorkerConfig>>,
CancellationToken,
) {
let dp_range = workers
.iter()
.map(|(&id, cfg)| (id, (cfg.data_parallel_start_rank, cfg.data_parallel_size)))
.collect();
let slots = Arc::new(ActiveSequencesMultiWorker::new(
NoopSequencePublisher,
64,
dp_range,
false,
0,
"test",
));
let (cfg_tx, cfg_rx) = watch::channel(workers);
let cancel_token = CancellationToken::new();
let scheduler = Arc::new(LocalScheduler::new(
Arc::clone(&slots),
cfg_rx,
threshold_frac,
64,
DefaultWorkerSelector::new(None, "test"),
FcfsPolicy,
cancel_token.clone(),
"test",
monitor_worker_configs,
));
(scheduler, slots, cfg_tx, cancel_token)
}
#[tokio::test]
async fn test_schedule_books_request_into_active_sequences() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(64),
..Default::default()
},
);
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);
let response = scheduler
.schedule(
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
None,
true,
Some("adapter-a".to_string()),
0.0,
None,
None,
)
.await
.unwrap();
assert_eq!(response.best_worker.worker_id, 0);
assert_eq!(
scheduler.get_active_lora_counts(),
HashMap::from([(String::from("adapter-a"), 1)])
);
cancel_token.cancel();
}
#[tokio::test]
async fn test_mark_prefill_completed_drains_pending_queue() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(64),
..Default::default()
},
);
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, Some(0.5), true);
scheduler
.schedule(
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
.unwrap();
let queued = {
let scheduler = Arc::clone(&scheduler);
tokio::spawn(async move {
scheduler
.schedule(
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
})
};
tokio::time::sleep(Duration::from_millis(25)).await;
assert_eq!(scheduler.pending_count(), 1);
scheduler.mark_prefill_completed("req-1").await.unwrap();
queued.await.unwrap().unwrap();
assert_eq!(scheduler.pending_count(), 0);
cancel_token.cancel();
}
#[tokio::test]
async fn test_free_updates_active_state() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(64),
..Default::default()
},
);
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);
scheduler
.schedule(
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
None,
true,
Some("adapter-a".to_string()),
0.0,
None,
None,
)
.await
.unwrap();
assert_eq!(
scheduler.get_active_lora_counts(),
HashMap::from([(String::from("adapter-a"), 1)])
);
scheduler.free("req-1").await.unwrap();
assert!(scheduler.get_active_lora_counts().is_empty());
cancel_token.cancel();
}
#[tokio::test]
async fn test_get_potential_loads_matches_slots() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(256),
..Default::default()
},
);
workers.insert(
1,
SimpleWorkerConfig {
max_num_batched_tokens: Some(256),
..Default::default()
},
);
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);
let token_seq = vec![11, 22, 33, 44];
let overlaps = OverlapScores::default();
let (decode_blocks, prefill_tokens) =
slots.potential_blocks_and_tokens(Some(&token_seq), 128, overlaps.clone());
let mut expected: Vec<_> = decode_blocks
.keys()
.map(|worker| PotentialLoad {
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
potential_prefill_tokens: prefill_tokens.get(worker).copied().unwrap_or(128),
potential_decode_blocks: decode_blocks.get(worker).copied().unwrap_or(0),
})
.collect();
expected.sort_by_key(|load| (load.worker_id, load.dp_rank));
let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, overlaps);
actual.sort_by_key(|load| (load.worker_id, load.dp_rank));
assert_eq!(actual.len(), expected.len());
for (actual, expected) in actual.iter().zip(expected.iter()) {
assert_eq!(actual.worker_id, expected.worker_id);
assert_eq!(actual.dp_rank, expected.dp_rank);
assert_eq!(
actual.potential_prefill_tokens,
expected.potential_prefill_tokens
);
assert_eq!(
actual.potential_decode_blocks,
expected.potential_decode_blocks
);
}
cancel_token.cancel();
}
#[tokio::test]
async fn test_register_workers_uses_default_dp_fallback() {
let (scheduler, _slots, _cfg_tx, cancel_token) =
make_scheduler(HashMap::new(), None, false);
scheduler.register_workers(&HashSet::from([42]));
let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default());
assert_eq!(loads.len(), 1);
assert_eq!(loads[0].worker_id, 42);
assert_eq!(loads[0].dp_rank, 0);
cancel_token.cancel();
}
#[tokio::test]
async fn test_worker_watch_updates_slot_ranges() {
let mut workers = HashMap::new();
workers.insert(0, SimpleWorkerConfig::default());
let (scheduler, _slots, cfg_tx, cancel_token) = make_scheduler(workers, None, true);
assert_eq!(
scheduler
.get_potential_loads(None, 64, OverlapScores::default())
.len(),
1
);
let mut updated_workers = HashMap::new();
updated_workers.insert(
0,
SimpleWorkerConfig {
data_parallel_size: 2,
..Default::default()
},
);
updated_workers.insert(1, SimpleWorkerConfig::default());
cfg_tx.send(updated_workers).unwrap();
tokio::time::timeout(Duration::from_secs(1), async {
loop {
if scheduler
.get_potential_loads(None, 64, OverlapScores::default())
.len()
== 3
{
break;
}
tokio::task::yield_now().await;
}
})
.await
.unwrap();
cancel_token.cancel();
}
}
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
pub mod config; pub mod config;
mod local;
pub mod policy; pub mod policy;
pub mod queue; pub mod queue;
pub mod selector; pub mod selector;
mod types; mod types;
pub use local::LocalScheduler;
pub use types::*; pub use types::*;
...@@ -43,6 +43,21 @@ impl SchedulingPolicy for FcfsPolicy { ...@@ -43,6 +43,21 @@ impl SchedulingPolicy for FcfsPolicy {
} }
} }
/// LCFS with priority bumps: key = priority_jump + arrival_offset.
/// Later arrival or higher priority_jump produces a higher key, scheduled first.
///
/// This intentionally favors newer arrivals under saturation and is mainly useful
/// for policy comparison experiments.
pub struct LcfsPolicy;
impl SchedulingPolicy for LcfsPolicy {
type Key = OrderedFloat<f64>;
fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
OrderedFloat(request.priority_jump.max(0.0) + arrival_offset.as_secs_f64())
}
}
/// Weighted Shortest Processing Time (Smith's rule): /// Weighted Shortest Processing Time (Smith's rule):
/// key = (1 + priority_jump) / new_tokens, where new_tokens estimates the /// key = (1 + priority_jump) / new_tokens, where new_tokens estimates the
/// actual prefill cost by subtracting the max KV cache overlap from ISL. /// actual prefill cost by subtracting the max KV cache overlap from ISL.
...@@ -73,6 +88,7 @@ impl SchedulingPolicy for WsptPolicy { ...@@ -73,6 +88,7 @@ impl SchedulingPolicy for WsptPolicy {
/// since the variant is fixed at queue construction time. /// since the variant is fixed at queue construction time.
pub enum RouterSchedulingPolicy { pub enum RouterSchedulingPolicy {
Fcfs(FcfsPolicy), Fcfs(FcfsPolicy),
Lcfs(LcfsPolicy),
Wspt(WsptPolicy), Wspt(WsptPolicy),
} }
...@@ -80,6 +96,7 @@ impl RouterSchedulingPolicy { ...@@ -80,6 +96,7 @@ impl RouterSchedulingPolicy {
pub fn new(kind: RouterQueuePolicy, block_size: usize) -> Self { pub fn new(kind: RouterQueuePolicy, block_size: usize) -> Self {
match kind { match kind {
RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy), RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy),
RouterQueuePolicy::Lcfs => Self::Lcfs(LcfsPolicy),
RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy { block_size }), RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy { block_size }),
} }
} }
...@@ -91,6 +108,7 @@ impl SchedulingPolicy for RouterSchedulingPolicy { ...@@ -91,6 +108,7 @@ impl SchedulingPolicy for RouterSchedulingPolicy {
fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key { fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
match self { match self {
Self::Fcfs(p) => p.enqueue_key(arrival_offset, request), Self::Fcfs(p) => p.enqueue_key(arrival_offset, request),
Self::Lcfs(p) => p.enqueue_key(arrival_offset, request),
Self::Wspt(p) => p.enqueue_key(arrival_offset, request), Self::Wspt(p) => p.enqueue_key(arrival_offset, request),
} }
} }
...@@ -178,6 +196,42 @@ mod tests { ...@@ -178,6 +196,42 @@ mod tests {
assert!(key_b > key_a); assert!(key_b > key_a);
} }
#[test]
fn lcfs_later_arrival_scheduled_first() {
let policy = LcfsPolicy;
let req = request_with(512, 0.0, OverlapScores::default());
let early = policy.enqueue_key(Duration::from_secs(1), &req);
let late = policy.enqueue_key(Duration::from_secs(10), &req);
assert!(late > early, "later arrival should have higher key");
}
#[test]
fn lcfs_priority_jump_promotes() {
let policy = LcfsPolicy;
let normal = request_with(512, 0.0, OverlapScores::default());
let boosted = request_with(512, 100.0, OverlapScores::default());
let t = Duration::from_secs(10);
let key_normal = policy.enqueue_key(t, &normal);
let key_boosted = policy.enqueue_key(t, &boosted);
assert!(
key_boosted > key_normal,
"priority_jump should produce a higher key"
);
}
#[test]
fn router_scheduling_policy_matches_fcfs_and_lcfs_ordering() {
let req = request_with(512, 0.0, OverlapScores::default());
let early = Duration::from_secs(1);
let late = Duration::from_secs(10);
let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs, 16);
assert!(fcfs.enqueue_key(early, &req) > fcfs.enqueue_key(late, &req));
let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs, 16);
assert!(lcfs.enqueue_key(late, &req) > lcfs.enqueue_key(early, &req));
}
// ---- WSPT policy tests ---- // ---- WSPT policy tests ----
#[test] #[test]
......
...@@ -11,7 +11,7 @@ use tokio::sync::Mutex; ...@@ -11,7 +11,7 @@ use tokio::sync::Mutex;
use tokio::sync::watch; use tokio::sync::watch;
use super::policy::{FcfsPolicy, SchedulingPolicy}; use super::policy::{FcfsPolicy, SchedulingPolicy};
use super::selector::WorkerSelector; use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{SchedulingRequest, SchedulingResponse}; use super::types::{SchedulingRequest, SchedulingResponse};
use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank}; use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRequest}; use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRequest};
...@@ -53,6 +53,7 @@ pub struct SchedulerQueue< ...@@ -53,6 +53,7 @@ pub struct SchedulerQueue<
P: SequencePublisher, P: SequencePublisher,
C: WorkerConfigLike, C: WorkerConfigLike,
S: SchedulingPolicy = FcfsPolicy, S: SchedulingPolicy = FcfsPolicy,
Sel: WorkerSelector<C> = DefaultWorkerSelector,
> { > {
pending: Mutex<BinaryHeap<QueueEntry<S::Key>>>, pending: Mutex<BinaryHeap<QueueEntry<S::Key>>>,
/// Number of requests currently parked in the pending queue. /// Number of requests currently parked in the pending queue.
...@@ -65,19 +66,23 @@ pub struct SchedulerQueue< ...@@ -65,19 +66,23 @@ pub struct SchedulerQueue<
/// Reference instant for computing arrival offsets. /// Reference instant for computing arrival offsets.
start_time: Instant, start_time: Instant,
block_size: u32, block_size: u32,
selector: Box<dyn WorkerSelector<C> + Send + Sync>, selector: Sel,
policy: S, policy: S,
} }
impl<P: SequencePublisher + 'static, C: WorkerConfigLike, S: SchedulingPolicy> impl<
SchedulerQueue<P, C, S> P: SequencePublisher + 'static,
C: WorkerConfigLike,
S: SchedulingPolicy,
Sel: WorkerSelector<C>,
> SchedulerQueue<P, C, S, Sel>
{ {
pub fn new( pub fn new(
slots: Arc<ActiveSequencesMultiWorker<P>>, slots: Arc<ActiveSequencesMultiWorker<P>>,
workers_with_configs: watch::Receiver<HashMap<WorkerId, C>>, workers_with_configs: watch::Receiver<HashMap<WorkerId, C>>,
threshold_frac: Option<f64>, threshold_frac: Option<f64>,
block_size: u32, block_size: u32,
selector: Box<dyn WorkerSelector<C> + Send + Sync>, selector: Sel,
policy: S, policy: S,
) -> Self { ) -> Self {
if let Some(frac) = threshold_frac { if let Some(frac) = threshold_frac {
...@@ -341,7 +346,7 @@ mod tests { ...@@ -341,7 +346,7 @@ mod tests {
} }
let (cfg_tx, cfg_rx) = watch::channel(configs); let (cfg_tx, cfg_rx) = watch::channel(configs);
let selector = Box::new(DefaultWorkerSelector::new(None, "test")); let selector = DefaultWorkerSelector::new(None, "test");
let queue = Arc::new(SchedulerQueue::new( let queue = Arc::new(SchedulerQueue::new(
Arc::clone(&slots), Arc::clone(&slots),
cfg_rx, cfg_rx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment