Unverified Commit b7fe46b1 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(mocker): add multi-worker replay and router startup fixes (#7553)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 82794761
......@@ -149,8 +149,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(fetch_model, m)?)?;
m.add_function(wrap_pyfunction!(run_kv_indexer, m)?)?;
m.add_function(wrap_pyfunction!(llm::entrypoint::make_engine, m)?)?;
m.add_function(wrap_pyfunction!(llm::replay::run_mocker_trace_replay, m)?)?;
m.add_function(wrap_pyfunction!(
llm::entrypoint::run_mocker_trace_replay,
llm::replay::run_mocker_synthetic_trace_replay,
m
)?)?;
m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?;
......@@ -165,6 +166,9 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::entrypoint::EngineType>()?;
m.add_class::<llm::entrypoint::RouterConfig>()?;
m.add_class::<llm::entrypoint::KvRouterConfig>()?;
m.add_class::<llm::replay::ReasoningConfig>()?;
m.add_class::<llm::replay::SglangArgs>()?;
m.add_class::<llm::replay::MockEngineArgs>()?;
m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?; // Internal: only in _internal, not public API
m.add_class::<llm::local_model::ModelRuntimeConfig>()?;
......
......@@ -31,3 +31,4 @@ pub mod local_model;
pub mod lora;
pub mod model_card;
pub mod preprocessor;
pub mod replay;
......@@ -9,7 +9,6 @@ use std::sync::Arc;
use pyo3::{exceptions::PyException, prelude::*};
use pyo3_async_runtimes::TaskLocals;
use pythonize::pythonize;
use dynamo_kv_router::config::KvRouterConfig as RsKvRouterConfig;
use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig;
......@@ -25,7 +24,8 @@ use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingE
use dynamo_mocker::common::perf_model::PerfModel;
use super::aic_callback::create_aic_callback;
use dynamo_mocker::common::protocols::MockEngineArgs;
use super::replay::MockEngineArgs as PyMockEngineArgs;
use dynamo_mocker::common::protocols::MockEngineArgs as RsMockEngineArgs;
use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId;
use dynamo_runtime::protocols::EndpointId;
......@@ -58,7 +58,7 @@ impl KvRouterConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(2.0), router_event_threads=4, router_enable_cache_control=false, router_queue_policy="fcfs", remote_indexer_component=None))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_enable_cache_control=false, min_initial_workers=1, router_queue_policy="fcfs", remote_indexer_component=None))]
#[allow(clippy::too_many_arguments)]
fn new(
overlap_score_weight: f64,
......@@ -77,6 +77,7 @@ impl KvRouterConfig {
router_queue_threshold: Option<f64>,
router_event_threads: u32,
router_enable_cache_control: bool,
min_initial_workers: usize,
router_queue_policy: &str,
remote_indexer_component: Option<String>,
) -> Self {
......@@ -99,6 +100,7 @@ impl KvRouterConfig {
router_event_threads,
router_enable_cache_control,
skip_initial_worker_wait: false,
min_initial_workers,
router_queue_policy: router_queue_policy.parse().unwrap_or_else(|_| {
panic!("invalid router_queue_policy: {router_queue_policy:?}")
}),
......@@ -106,6 +108,13 @@ impl KvRouterConfig {
},
}
}
#[staticmethod]
fn from_json(config_json: &str) -> PyResult<Self> {
serde_json::from_str::<RsKvRouterConfig>(config_json)
.map(|inner| KvRouterConfig { inner })
.map_err(|e| PyException::new_err(format!("Failed to parse KvRouterConfig JSON: {e}")))
}
}
#[pyclass]
......@@ -196,6 +205,7 @@ pub(crate) struct EntrypointArgs {
tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>,
mocker_engine_args: Option<PyMockEngineArgs>,
runtime_config: Option<ModelRuntimeConfig>,
namespace: Option<String>,
namespace_prefix: Option<String>,
......@@ -208,7 +218,7 @@ pub(crate) struct EntrypointArgs {
impl EntrypointArgs {
#[allow(clippy::too_many_arguments)]
#[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
pub fn new(
py: Python<'_>,
engine_type: EngineType,
......@@ -225,6 +235,7 @@ impl EntrypointArgs {
tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>,
mocker_engine_args: Option<PyMockEngineArgs>,
runtime_config: Option<ModelRuntimeConfig>,
namespace: Option<String>,
namespace_prefix: Option<String>,
......@@ -272,6 +283,7 @@ impl EntrypointArgs {
tls_cert_path,
tls_key_path,
extra_engine_args,
mocker_engine_args,
runtime_config,
namespace,
namespace_prefix,
......@@ -419,8 +431,10 @@ async fn select_engine(
}
}
EngineType::Mocker => {
let mut mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
let mut mocker_args = if let Some(mocker_engine_args) = args.mocker_engine_args {
mocker_engine_args.inner()
} else if let Some(extra_args_path) = args.extra_engine_args {
RsMockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
anyhow::anyhow!(
"Failed to load mocker args from {:?}: {}",
extra_args_path,
......@@ -431,7 +445,7 @@ async fn select_engine(
tracing::warn!(
"No extra_engine_args specified for mocker engine. Using default mocker args."
);
MockEngineArgs::default()
RsMockEngineArgs::default()
};
// If aic_backend is set, create Python AIC callback and override perf_model
......@@ -503,84 +517,6 @@ pub fn run_input<'p>(
})
}
#[pyfunction]
#[pyo3(signature = (trace_file, extra_engine_args=None, num_workers=1, replay_concurrency=None))]
pub fn run_mocker_trace_replay(
py: Python<'_>,
trace_file: PathBuf,
extra_engine_args: Option<PathBuf>,
num_workers: usize,
replay_concurrency: Option<isize>,
) -> PyResult<PyObject> {
// Load args before allow_threads so we can use the GIL for AIC callback creation.
let mut args = if let Some(ref extra_args_path) = extra_engine_args {
MockEngineArgs::from_json_file(extra_args_path).map_err(|e| {
PyException::new_err(format!(
"Failed to load mocker args from {:?}: {}",
extra_args_path, e
))
})?
} else {
MockEngineArgs::default()
};
// Create AIC callback if requested (requires GIL, must be done before allow_threads).
if let Some(ref backend_name) = args.aic_backend.clone() {
let backend = backend_name.clone();
let system = args.aic_system.as_deref().unwrap_or("h200_sxm").to_string();
let model_name = args
.aic_model_path
.clone()
.ok_or_else(|| PyException::new_err("--aic-perf-model requires --model-path"))?;
let backend_version = args.aic_backend_version.clone();
let tp_size = args.aic_tp_size.unwrap_or(1);
let callback = create_aic_callback(
py,
&backend,
&system,
&model_name,
tp_size,
backend_version.as_deref(),
)
.map_err(|e| {
PyException::new_err(format!(
"Failed to create AIC callback (--aic-perf-model was requested): {}",
e
))
})?;
tracing::info!(
"AIC perf model: backend={}, gpu={}, model={}, version={:?}",
backend,
system,
model_name,
backend_version
);
args.perf_model = Arc::new(PerfModel::from_aic_callback(callback));
}
let report = py.allow_threads(move || {
let replay_concurrency = replay_concurrency
.map(usize::try_from)
.transpose()
.map_err(|_| anyhow::anyhow!("replay_concurrency must be at least 1"))?;
if let Some(max_in_flight) = replay_concurrency {
dynamo_mocker::simulation::simulate_concurrency_file(
args,
&trace_file,
max_in_flight,
num_workers,
)
} else {
dynamo_mocker::simulation::simulate_trace_file(args, &trace_file, num_workers)
}
});
let report = report.map_err(to_pyerr)?;
pythonize(py, &report)
.map_err(to_pyerr)
.map(|obj| obj.unbind())
}
pub fn to_pyerr<E>(err: E) -> PyErr
where
E: Display,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::path::PathBuf;
use std::sync::Arc;
use dynamo_mocker::common::perf_model::PerfModel;
use dynamo_mocker::common::protocols::{
DirectRequest, EngineType as RsMockerEngineType, MockEngineArgs as RsMockEngineArgs,
PreemptionMode as RsPreemptionMode, ReasoningConfig as RsReasoningConfig,
SglangArgs as RsSglangArgs, WorkerType as RsWorkerType,
};
use pyo3::{exceptions::PyException, prelude::*};
use pythonize::pythonize;
use uuid::Uuid;
use super::aic_callback::create_aic_callback;
use super::entrypoint::{KvRouterConfig, to_pyerr};
fn parse_mocker_engine_type(engine_type: &str) -> PyResult<RsMockerEngineType> {
match engine_type {
"vllm" => Ok(RsMockerEngineType::Vllm),
"sglang" => Ok(RsMockerEngineType::Sglang),
other => Err(PyException::new_err(format!(
"engine_type must be either 'vllm' or 'sglang', got '{other}'"
))),
}
}
fn parse_worker_type(worker_type: &str) -> PyResult<RsWorkerType> {
match worker_type {
"aggregated" => Ok(RsWorkerType::Aggregated),
"prefill" => Ok(RsWorkerType::Prefill),
"decode" => Ok(RsWorkerType::Decode),
other => Err(PyException::new_err(format!(
"worker_type must be one of 'aggregated', 'prefill', or 'decode', got '{other}'"
))),
}
}
fn parse_preemption_mode(preemption_mode: &str) -> PyResult<RsPreemptionMode> {
match preemption_mode {
"lifo" => Ok(RsPreemptionMode::Lifo),
"fifo" => Ok(RsPreemptionMode::Fifo),
other => Err(PyException::new_err(format!(
"preemption_mode must be either 'lifo' or 'fifo', got '{other}'"
))),
}
}
#[pyclass]
#[derive(Clone, Debug)]
pub struct ReasoningConfig {
inner: RsReasoningConfig,
}
impl ReasoningConfig {
pub fn inner(&self) -> RsReasoningConfig {
self.inner.clone()
}
}
#[pymethods]
impl ReasoningConfig {
#[new]
fn new(
start_thinking_token_id: u32,
end_thinking_token_id: u32,
thinking_ratio: f64,
) -> PyResult<Self> {
let inner = RsReasoningConfig {
start_thinking_token_id,
end_thinking_token_id,
thinking_ratio,
};
Ok(Self { inner })
}
}
#[pyclass]
#[derive(Clone, Debug, Default)]
pub struct SglangArgs {
inner: RsSglangArgs,
}
impl SglangArgs {
pub fn inner(&self) -> RsSglangArgs {
self.inner.clone()
}
}
#[pymethods]
impl SglangArgs {
#[new]
#[pyo3(signature = (schedule_policy=None, page_size=None, max_prefill_tokens=None, chunked_prefill_size=None, clip_max_new_tokens=None, schedule_conservativeness=None))]
fn new(
schedule_policy: Option<String>,
page_size: Option<usize>,
max_prefill_tokens: Option<usize>,
chunked_prefill_size: Option<usize>,
clip_max_new_tokens: Option<usize>,
schedule_conservativeness: Option<f64>,
) -> PyResult<Self> {
let inner = RsSglangArgs {
schedule_policy,
page_size,
max_prefill_tokens,
chunked_prefill_size,
clip_max_new_tokens,
schedule_conservativeness,
};
Ok(Self { inner })
}
}
#[pyclass]
#[derive(Clone, Debug, Default)]
pub struct MockEngineArgs {
inner: RsMockEngineArgs,
}
impl MockEngineArgs {
pub fn inner(&self) -> RsMockEngineArgs {
self.inner.clone()
}
}
#[pymethods]
impl MockEngineArgs {
#[new]
#[pyo3(signature = (engine_type="vllm", num_gpu_blocks=16384, block_size=0, max_num_seqs=Some(256), max_num_batched_tokens=Some(8192), enable_prefix_caching=true, enable_chunked_prefill=true, speedup_ratio=1.0, decode_speedup_ratio=1.0, dp_size=1, startup_time=None, worker_type="aggregated", aic_backend=None, aic_system=None, aic_backend_version=None, aic_tp_size=None, aic_model_path=None, enable_local_indexer=false, bootstrap_port=None, kv_bytes_per_token=None, kv_transfer_bandwidth=None, reasoning=None, zmq_kv_events_port=None, zmq_replay_port=None, preemption_mode="lifo", router_queue_policy=None, sglang=None))]
#[allow(clippy::too_many_arguments)]
fn new(
engine_type: &str,
num_gpu_blocks: usize,
block_size: usize,
max_num_seqs: Option<usize>,
max_num_batched_tokens: Option<usize>,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
speedup_ratio: f64,
decode_speedup_ratio: f64,
dp_size: u32,
startup_time: Option<f64>,
worker_type: &str,
aic_backend: Option<String>,
aic_system: Option<String>,
aic_backend_version: Option<String>,
aic_tp_size: Option<usize>,
aic_model_path: Option<String>,
enable_local_indexer: bool,
bootstrap_port: Option<u16>,
kv_bytes_per_token: Option<usize>,
kv_transfer_bandwidth: Option<f64>,
reasoning: Option<ReasoningConfig>,
zmq_kv_events_port: Option<u16>,
zmq_replay_port: Option<u16>,
preemption_mode: &str,
router_queue_policy: Option<&str>,
sglang: Option<SglangArgs>,
) -> PyResult<Self> {
let engine_type = parse_mocker_engine_type(engine_type)?;
let worker_type = parse_worker_type(worker_type)?;
let preemption_mode = parse_preemption_mode(preemption_mode)?;
let router_queue_policy = router_queue_policy
.map(|value| {
value.parse().map_err(|e: String| {
PyException::new_err(format!("invalid router_queue_policy {value:?}: {e}"))
})
})
.transpose()?;
let inner = RsMockEngineArgs::builder()
.engine_type(engine_type)
.num_gpu_blocks(num_gpu_blocks)
.block_size(block_size)
.max_num_seqs(max_num_seqs)
.max_num_batched_tokens(max_num_batched_tokens)
.enable_prefix_caching(enable_prefix_caching)
.enable_chunked_prefill(enable_chunked_prefill)
.speedup_ratio(speedup_ratio)
.decode_speedup_ratio(decode_speedup_ratio)
.dp_size(dp_size)
.startup_time(startup_time)
.worker_type(worker_type)
.aic_backend(aic_backend)
.aic_system(aic_system)
.aic_backend_version(aic_backend_version)
.aic_tp_size(aic_tp_size)
.aic_model_path(aic_model_path)
.enable_local_indexer(enable_local_indexer)
.bootstrap_port(bootstrap_port)
.kv_bytes_per_token(kv_bytes_per_token)
.kv_transfer_bandwidth(kv_transfer_bandwidth)
.reasoning(reasoning.map(|config| config.inner()))
.zmq_kv_events_port(zmq_kv_events_port)
.zmq_replay_port(zmq_replay_port)
.preemption_mode(preemption_mode)
.router_queue_policy(router_queue_policy)
.sglang(sglang.map(|config| config.inner()))
.build()
.map_err(|e| PyException::new_err(format!("Failed to build MockEngineArgs: {e}")))?
.normalized()
.map_err(|e| {
PyException::new_err(format!("Failed to normalize MockEngineArgs: {e}"))
})?;
Ok(Self { inner })
}
#[staticmethod]
fn from_json(config_json: &str) -> PyResult<Self> {
RsMockEngineArgs::from_json_str(config_json)
.map(|inner| Self { inner })
.map_err(|e| PyException::new_err(format!("Failed to parse MockEngineArgs JSON: {e}")))
}
#[getter]
fn block_size(&self) -> usize {
self.inner.block_size
}
#[getter]
fn num_gpu_blocks(&self) -> usize {
self.inner.num_gpu_blocks
}
#[getter]
fn max_num_seqs(&self) -> Option<usize> {
self.inner.max_num_seqs
}
#[getter]
fn max_num_batched_tokens(&self) -> Option<usize> {
self.inner.max_num_batched_tokens
}
#[getter]
fn enable_local_indexer(&self) -> bool {
self.inner.enable_local_indexer
}
#[getter]
fn dp_size(&self) -> u32 {
self.inner.dp_size
}
#[getter]
fn bootstrap_port(&self) -> Option<u16> {
self.inner.bootstrap_port
}
fn is_prefill(&self) -> bool {
self.inner.is_prefill()
}
fn is_decode(&self) -> bool {
self.inner.is_decode()
}
#[pyo3(signature = (bootstrap_port=None, zmq_kv_events_port=None, zmq_replay_port=None, kv_bytes_per_token=None))]
fn with_overrides(
&self,
bootstrap_port: Option<u16>,
zmq_kv_events_port: Option<u16>,
zmq_replay_port: Option<u16>,
kv_bytes_per_token: Option<usize>,
) -> PyResult<Self> {
let mut inner = self.inner.clone();
if let Some(port) = bootstrap_port {
inner.bootstrap_port = Some(port);
}
if let Some(port) = zmq_kv_events_port {
inner.zmq_kv_events_port = Some(port);
}
if let Some(port) = zmq_replay_port {
inner.zmq_replay_port = Some(port);
}
if let Some(bytes_per_token) = kv_bytes_per_token {
inner.kv_bytes_per_token = Some(bytes_per_token);
}
inner.normalized().map(|inner| Self { inner }).map_err(|e| {
PyException::new_err(format!("Failed to normalize MockEngineArgs overrides: {e}"))
})
}
}
#[pyfunction]
#[pyo3(signature = (trace_file, extra_engine_args=None, router_config=None, num_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0))]
#[allow(clippy::too_many_arguments)]
pub fn run_mocker_trace_replay(
py: Python<'_>,
trace_file: PathBuf,
extra_engine_args: Option<MockEngineArgs>,
router_config: Option<KvRouterConfig>,
num_workers: usize,
replay_concurrency: Option<isize>,
replay_mode: &str,
router_mode: &str,
arrival_speedup_ratio: f64,
) -> PyResult<PyObject> {
let args = load_replay_mocker_args(py, extra_engine_args)?;
let router_config = load_replay_router_config(router_config);
let replay_mode = replay_mode.to_owned();
let router_mode = parse_replay_router_mode(router_mode)?;
let report = py.allow_threads(move || {
let replay_concurrency = parse_replay_concurrency(replay_concurrency)?;
match (replay_mode.as_str(), replay_concurrency) {
("offline", Some(max_in_flight)) => {
dynamo_mocker::replay::simulate_concurrency_file_with_router_mode(
args,
router_config.clone(),
&trace_file,
max_in_flight,
num_workers,
router_mode,
)
}
("offline", None) => dynamo_mocker::replay::simulate_trace_file_with_router_mode(
args,
router_config.clone(),
&trace_file,
num_workers,
arrival_speedup_ratio,
router_mode,
),
("online", Some(max_in_flight)) => {
dynamo_mocker::replay::simulate_concurrency_live_file_with_router_mode(
args,
router_config.clone(),
&trace_file,
max_in_flight,
num_workers,
router_mode,
)
}
("online", None) => dynamo_mocker::replay::simulate_trace_live_file_with_router_mode(
args,
router_config.clone(),
&trace_file,
num_workers,
arrival_speedup_ratio,
router_mode,
),
(other, _) => anyhow::bail!(
"replay_mode must be either 'offline' or 'online', got '{}'",
other
),
}
});
let report = report.map_err(to_pyerr)?;
pythonize(py, &report)
.map_err(to_pyerr)
.map(|obj| obj.unbind())
}
#[pyfunction]
#[pyo3(signature = (input_tokens, output_tokens, request_count, extra_engine_args=None, router_config=None, num_workers=1, replay_concurrency=None, replay_mode="offline", router_mode="round_robin", arrival_speedup_ratio=1.0, arrival_interval_ms=1.0))]
#[allow(clippy::too_many_arguments)]
pub fn run_mocker_synthetic_trace_replay(
py: Python<'_>,
input_tokens: usize,
output_tokens: usize,
request_count: usize,
extra_engine_args: Option<MockEngineArgs>,
router_config: Option<KvRouterConfig>,
num_workers: usize,
replay_concurrency: Option<isize>,
replay_mode: &str,
router_mode: &str,
arrival_speedup_ratio: f64,
arrival_interval_ms: f64,
) -> PyResult<PyObject> {
let args = load_replay_mocker_args(py, extra_engine_args)?;
let router_config = load_replay_router_config(router_config);
let replay_mode = replay_mode.to_owned();
let router_mode = parse_replay_router_mode(router_mode)?;
let report = py.allow_threads(move || {
let replay_concurrency = parse_replay_concurrency(replay_concurrency)?;
let requests = build_synthetic_requests(
input_tokens,
output_tokens,
request_count,
arrival_interval_ms,
replay_concurrency.is_none(),
)?;
match (replay_mode.as_str(), replay_concurrency) {
("offline", Some(max_in_flight)) => {
dynamo_mocker::replay::simulate_concurrency_requests_with_router_mode(
args,
router_config.clone(),
requests,
max_in_flight,
num_workers,
router_mode,
)
}
("offline", None) => dynamo_mocker::replay::simulate_trace_requests_with_router_mode(
args,
router_config.clone(),
requests,
num_workers,
arrival_speedup_ratio,
router_mode,
),
("online", Some(max_in_flight)) => {
dynamo_mocker::replay::simulate_concurrency_live_requests_with_router_mode(
args,
router_config.clone(),
requests,
max_in_flight,
num_workers,
router_mode,
)
}
("online", None) => {
dynamo_mocker::replay::simulate_trace_live_requests_with_router_mode(
args,
router_config.clone(),
requests,
num_workers,
arrival_speedup_ratio,
router_mode,
)
}
(other, _) => anyhow::bail!(
"replay_mode must be either 'offline' or 'online', got '{}'",
other
),
}
});
let report = report.map_err(to_pyerr)?;
pythonize(py, &report)
.map_err(to_pyerr)
.map(|obj| obj.unbind())
}
fn load_replay_mocker_args(
py: Python<'_>,
extra_engine_args: Option<MockEngineArgs>,
) -> PyResult<RsMockEngineArgs> {
let mut args = match extra_engine_args {
Some(extra_args) => extra_args.inner(),
None => RsMockEngineArgs::default(),
};
if let Some(ref backend_name) = args.aic_backend.clone() {
let backend = backend_name.clone();
let system = args.aic_system.as_deref().unwrap_or("h200_sxm").to_string();
let model_name = args
.aic_model_path
.clone()
.ok_or_else(|| PyException::new_err("--aic-perf-model requires --model-path"))?;
let backend_version = args.aic_backend_version.clone();
let tp_size = args.aic_tp_size.unwrap_or(1);
let callback = create_aic_callback(
py,
&backend,
&system,
&model_name,
tp_size,
backend_version.as_deref(),
)
.map_err(|e| {
PyException::new_err(format!(
"Failed to create AIC callback (--aic-perf-model was requested): {}",
e
))
})?;
tracing::info!(
"AIC perf model: backend={}, gpu={}, model={}, version={:?}",
backend,
system,
model_name,
backend_version
);
args.perf_model = Arc::new(PerfModel::from_aic_callback(callback));
}
Ok(args)
}
fn load_replay_router_config(
router_config: Option<KvRouterConfig>,
) -> Option<dynamo_kv_router::config::KvRouterConfig> {
router_config.map(|config| config.inner())
}
fn parse_replay_router_mode(
router_mode: &str,
) -> PyResult<dynamo_mocker::replay::ReplayRouterMode> {
match router_mode {
"round_robin" => Ok(dynamo_mocker::replay::ReplayRouterMode::RoundRobin),
"kv_router" => Ok(dynamo_mocker::replay::ReplayRouterMode::KvRouter),
other => Err(PyException::new_err(format!(
"router_mode must be either 'round_robin' or 'kv_router', got '{}'",
other
))),
}
}
fn parse_replay_concurrency(replay_concurrency: Option<isize>) -> anyhow::Result<Option<usize>> {
match replay_concurrency {
Some(value) if value < 1 => anyhow::bail!("replay_concurrency must be at least 1"),
Some(value) => Ok(Some(value as usize)),
None => Ok(None),
}
}
fn build_synthetic_requests(
input_tokens: usize,
output_tokens: usize,
request_count: usize,
arrival_interval_ms: f64,
include_arrival_timestamps: bool,
) -> anyhow::Result<Vec<DirectRequest>> {
if input_tokens == 0 {
anyhow::bail!("input_tokens must be at least 1");
}
if output_tokens == 0 {
anyhow::bail!("output_tokens must be at least 1");
}
if request_count == 0 {
anyhow::bail!("request_count must be at least 1");
}
if !arrival_interval_ms.is_finite() || arrival_interval_ms < 0.0 {
anyhow::bail!(
"arrival_interval_ms must be a finite non-negative number, got {}",
arrival_interval_ms
);
}
let mut requests = Vec::with_capacity(request_count);
for request_idx in 0..request_count {
let tokens = (0..input_tokens)
.map(|token_idx| synthetic_token_id(request_idx, token_idx))
.collect();
requests.push(DirectRequest {
tokens,
max_output_tokens: output_tokens,
uuid: Some(Uuid::from_u128((request_idx as u128) + 1)),
dp_rank: 0,
arrival_timestamp_ms: include_arrival_timestamps
.then_some(request_idx as f64 * arrival_interval_ms),
});
}
Ok(requests)
}
fn synthetic_token_id(request_idx: usize, token_idx: usize) -> u32 {
let mut value =
(((request_idx as u64) << 32) ^ (token_idx as u64)).wrapping_add(0x9E37_79B9_7F4A_7C15);
value ^= value >> 30;
value = value.wrapping_mul(0xBF58_476D_1CE4_E5B9);
value ^= value >> 27;
value = value.wrapping_mul(0x94D0_49BB_1331_11EB);
value ^= value >> 31;
let token = value as u32;
if token == 0 { 1 } else { token }
}
......@@ -3,7 +3,17 @@
import asyncio
import os
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
)
# Import from specialized modules
from .prometheus_metrics import RuntimeMetrics as PyRuntimeMetrics
......@@ -1104,9 +1114,10 @@ class KvRouterConfig:
router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1048576,
router_prune_target_ratio: float = 0.8,
router_queue_threshold: Optional[float] = 2.0,
router_queue_threshold: Optional[float] = 4.0,
router_event_threads: int = 4,
router_enable_cache_control: bool = False,
min_initial_workers: int = 1,
router_queue_policy: str = "fcfs",
) -> None:
"""
......@@ -1132,7 +1143,7 @@ class KvRouterConfig:
router_ttl_secs: TTL for blocks in seconds when not using KV events (default: 120.0)
router_max_tree_size: Maximum tree size before pruning (default: 1048576, which is 2^20)
router_prune_target_ratio: Target size ratio after pruning (default: 0.8)
router_queue_threshold: Queue threshold fraction for prefill token capacity (default: 2.0).
router_queue_threshold: Queue threshold fraction for prefill token capacity (default: 4.0).
Requests are queued if all workers exceed this fraction of max_num_batched_tokens.
Enables priority scheduling via request priority hints.
Set to None to disable queueing (all requests go directly to the scheduler).
......@@ -1140,12 +1151,111 @@ class KvRouterConfig:
When > 1, uses a concurrent radix tree with a thread pool.
router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's
cache_control service mesh endpoint (default: False).
min_initial_workers: Minimum number of discovered workers required before
router startup continues (default: 1). Ignored when
skip_initial_worker_wait is enabled.
router_queue_policy: Scheduling policy for the router queue (default: "fcfs").
"fcfs": first-come first-served with priority bumps — optimizes tail TTFT.
"lcfs": last-come first-served with priority bumps — intentionally worsens tail behavior for policy comparisons.
"wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
"""
...
@staticmethod
def from_json(config_json: str) -> "KvRouterConfig":
...
class ReasoningConfig:
def __init__(
self,
start_thinking_token_id: int,
end_thinking_token_id: int,
thinking_ratio: float,
) -> None:
...
class SglangArgs:
def __init__(
self,
schedule_policy: Optional[str] = None,
page_size: Optional[int] = None,
max_prefill_tokens: Optional[int] = None,
chunked_prefill_size: Optional[int] = None,
clip_max_new_tokens: Optional[int] = None,
schedule_conservativeness: Optional[float] = None,
) -> None:
...
class MockEngineArgs:
def __init__(
self,
engine_type: str = "vllm",
num_gpu_blocks: int = 16384,
block_size: int = 0,
max_num_seqs: Optional[int] = 256,
max_num_batched_tokens: Optional[int] = 8192,
enable_prefix_caching: bool = True,
enable_chunked_prefill: bool = True,
speedup_ratio: float = 1.0,
decode_speedup_ratio: float = 1.0,
dp_size: int = 1,
startup_time: Optional[float] = None,
worker_type: str = "aggregated",
aic_backend: Optional[str] = None,
aic_system: Optional[str] = None,
aic_backend_version: Optional[str] = None,
aic_tp_size: Optional[int] = None,
aic_model_path: Optional[str] = None,
enable_local_indexer: bool = False,
bootstrap_port: Optional[int] = None,
kv_bytes_per_token: Optional[int] = None,
kv_transfer_bandwidth: Optional[float] = None,
reasoning: Optional[ReasoningConfig] = None,
zmq_kv_events_port: Optional[int] = None,
zmq_replay_port: Optional[int] = None,
preemption_mode: str = "lifo",
router_queue_policy: Optional[str] = None,
sglang: Optional[SglangArgs] = None,
) -> None:
...
@staticmethod
def from_json(config_json: str) -> "MockEngineArgs":
...
@property
def block_size(self) -> int: ...
@property
def num_gpu_blocks(self) -> int: ...
@property
def max_num_seqs(self) -> Optional[int]: ...
@property
def max_num_batched_tokens(self) -> Optional[int]: ...
@property
def enable_local_indexer(self) -> bool: ...
@property
def dp_size(self) -> int: ...
@property
def bootstrap_port(self) -> Optional[int]: ...
def is_prefill(self) -> bool: ...
def is_decode(self) -> bool: ...
def with_overrides(
self,
bootstrap_port: Optional[int] = None,
zmq_kv_events_port: Optional[int] = None,
zmq_replay_port: Optional[int] = None,
kv_bytes_per_token: Optional[int] = None,
) -> "MockEngineArgs": ...
async def register_model(
model_input: ModelInput,
model_type: ModelType,
......@@ -1249,11 +1359,31 @@ async def run_input(runtime: DistributedRuntime, input: str, engine_config: Engi
def run_mocker_trace_replay(
trace_file: str | os.PathLike[str],
extra_engine_args: Optional[str | os.PathLike[str]] = None,
extra_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
num_workers: int = 1,
replay_concurrency: Optional[int] = None,
replay_mode: Literal["offline", "online"] = "offline",
router_mode: Literal["round_robin", "kv_router"] = "round_robin",
arrival_speedup_ratio: float = 1.0,
) -> Dict[str, Any]:
"""Replay a mocker trace file and return the simulation report for aggregated vLLM or SGLang configs."""
...
def run_mocker_synthetic_trace_replay(
input_tokens: int,
output_tokens: int,
request_count: int,
extra_engine_args: Optional[MockEngineArgs] = None,
router_config: Optional[KvRouterConfig] = None,
num_workers: int = 1,
replay_concurrency: Optional[int] = None,
replay_mode: Literal["offline", "online"] = "offline",
router_mode: Literal["round_robin", "kv_router"] = "round_robin",
arrival_speedup_ratio: float = 1.0,
arrival_interval_ms: float = 1.0,
) -> Dict[str, Any]:
"""Replay a mocker trace file and return the simulation report."""
"""Replay a synthetic mocker workload without requiring a trace file."""
...
class Layer:
......@@ -1687,6 +1817,7 @@ class EntrypointArgs:
tls_cert_path: Optional[str] = None,
tls_key_path: Optional[str] = None,
extra_engine_args: Optional[str] = None,
mocker_engine_args: Optional[MockEngineArgs] = None,
runtime_config: Optional[ModelRuntimeConfig] = None,
namespace: Optional[str] = None,
namespace_prefix: Optional[str] = None,
......@@ -1711,7 +1842,8 @@ class EntrypointArgs:
http_metrics_port: HTTP metrics port (for gRPC service)
tls_cert_path: TLS certificate path (PEM format)
tls_key_path: TLS key path (PEM format)
extra_engine_args: Path to extra engine arguments file
extra_engine_args: Optional path to mocker engine arguments JSON
mocker_engine_args: Typed mocker engine arguments
runtime_config: Optional runtime configuration for discovery registration
namespace: Dynamo namespace for model discovery scoping
namespace_prefix: Optional namespace prefix
......
......@@ -18,6 +18,7 @@ from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import LoRADownloader as LoRADownloader
from dynamo._core import MediaDecoder as MediaDecoder
from dynamo._core import MediaFetcher as MediaFetcher
from dynamo._core import MockEngineArgs as MockEngineArgs
from dynamo._core import ModelCardInstanceId as ModelCardInstanceId
from dynamo._core import ModelInput as ModelInput
from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig
......@@ -25,8 +26,10 @@ from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores
from dynamo._core import PythonAsyncEngine as PythonAsyncEngine
from dynamo._core import RadixTree as RadixTree
from dynamo._core import ReasoningConfig as ReasoningConfig
from dynamo._core import RouterConfig as RouterConfig
from dynamo._core import RouterMode as RouterMode
from dynamo._core import SglangArgs as SglangArgs
from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher
from dynamo._core import compute_block_hash_for_seq as compute_block_hash_for_seq
from dynamo._core import fetch_model as fetch_model
......@@ -35,7 +38,7 @@ from dynamo._core import make_engine
from dynamo._core import register_model as register_model
from dynamo._core import run_input
from dynamo._core import run_kv_indexer as run_kv_indexer
from dynamo._core import run_mocker_trace_replay
from dynamo._core import run_mocker_trace_replay as _run_mocker_trace_replay
from dynamo._core import unregister_model as unregister_model
from .exceptions import HttpError
......@@ -44,3 +47,24 @@ from .exceptions import HttpError
fetch_llm = fetch_model
register_llm = register_model
unregister_llm = unregister_model
def run_mocker_trace_replay(
trace_file,
extra_engine_args=None,
router_config=None,
num_workers=1,
replay_concurrency=None,
router_mode="round_robin",
arrival_speedup_ratio=1.0,
):
return _run_mocker_trace_replay(
trace_file,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=num_workers,
replay_concurrency=replay_concurrency,
replay_mode="offline",
router_mode=router_mode,
arrival_speedup_ratio=arrival_speedup_ratio,
)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.replay.api import run_synthetic_trace_replay, run_trace_replay
__all__ = ["run_synthetic_trace_replay", "run_trace_replay"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.replay.main import main
if __name__ == "__main__":
raise SystemExit(main())
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo._core import (
run_mocker_synthetic_trace_replay as _run_mocker_synthetic_trace_replay,
)
from dynamo._core import run_mocker_trace_replay as _run_mocker_trace_replay
def run_trace_replay(
trace_file,
*,
extra_engine_args=None,
router_config=None,
num_workers=1,
replay_concurrency=None,
replay_mode="offline",
router_mode="round_robin",
arrival_speedup_ratio=1.0,
):
return _run_mocker_trace_replay(
trace_file,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=num_workers,
replay_concurrency=replay_concurrency,
replay_mode=replay_mode,
router_mode=router_mode,
arrival_speedup_ratio=arrival_speedup_ratio,
)
def run_synthetic_trace_replay(
input_tokens,
output_tokens,
request_count,
*,
extra_engine_args=None,
router_config=None,
num_workers=1,
replay_concurrency=None,
replay_mode="offline",
router_mode="round_robin",
arrival_speedup_ratio=1.0,
arrival_interval_ms=1.0,
):
return _run_mocker_synthetic_trace_replay(
input_tokens,
output_tokens,
request_count,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=num_workers,
replay_concurrency=replay_concurrency,
replay_mode=replay_mode,
router_mode=router_mode,
arrival_speedup_ratio=arrival_speedup_ratio,
arrival_interval_ms=arrival_interval_ms,
)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import argparse
import json
import os
import sys
from collections.abc import Sequence
os.environ.setdefault("DYNAMO_SKIP_PYTHON_LOG_INIT", "1")
from dynamo.llm import KvRouterConfig, MockEngineArgs
from dynamo.replay import run_synthetic_trace_replay, run_trace_replay
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(prog="python -m dynamo.replay")
parser.add_argument("trace_file", nargs="?")
parser.add_argument("--extra-engine-args")
parser.add_argument("--router-config")
parser.add_argument("--input-tokens", type=int)
parser.add_argument("--output-tokens", type=int)
parser.add_argument("--request-count", type=int)
parser.add_argument("--arrival-interval-ms", type=float, default=1.0)
parser.add_argument("--num-workers", type=int, default=1)
parser.add_argument("--replay-concurrency", type=int)
parser.add_argument(
"--replay-mode",
choices=("offline", "online"),
default="offline",
)
parser.add_argument(
"--router-mode",
choices=("round_robin", "kv_router"),
default="round_robin",
)
parser.add_argument("--arrival-speedup-ratio", type=float, default=1.0)
args = parser.parse_args(list(sys.argv[1:] if argv is None else argv))
using_trace_file = args.trace_file is not None
synthetic_args = (args.input_tokens, args.output_tokens, args.request_count)
using_synthetic = any(value is not None for value in synthetic_args)
if using_trace_file == using_synthetic:
parser.error(
"provide either trace_file or all of --input-tokens/--output-tokens/--request-count"
)
if using_synthetic and not all(value is not None for value in synthetic_args):
parser.error(
"synthetic replay requires --input-tokens, --output-tokens, and --request-count"
)
extra_engine_args = (
MockEngineArgs.from_json(args.extra_engine_args)
if args.extra_engine_args is not None
else None
)
router_config = (
KvRouterConfig.from_json(args.router_config)
if args.router_config is not None
else None
)
if using_trace_file:
report = run_trace_replay(
args.trace_file,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=args.num_workers,
replay_concurrency=args.replay_concurrency,
replay_mode=args.replay_mode,
router_mode=args.router_mode,
arrival_speedup_ratio=args.arrival_speedup_ratio,
)
else:
report = run_synthetic_trace_replay(
args.input_tokens,
args.output_tokens,
args.request_count,
extra_engine_args=extra_engine_args,
router_config=router_config,
num_workers=args.num_workers,
replay_concurrency=args.replay_concurrency,
replay_mode=args.replay_mode,
router_mode=args.router_mode,
arrival_speedup_ratio=args.arrival_speedup_ratio,
arrival_interval_ms=args.arrival_interval_ms,
)
json.dump(report, sys.stdout, indent=2, sort_keys=True)
sys.stdout.write("\n")
return 0
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import pytest
from dynamo.llm import KvRouterConfig, MockEngineArgs
from dynamo.replay import run_synthetic_trace_replay, run_trace_replay
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.parallel,
pytest.mark.pre_merge,
]
MOONCAKE_TRACE_FIRST20 = """{"timestamp": 0, "input_length": 6755, "output_length": 500, "hash_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]}
{"timestamp": 0, "input_length": 7319, "output_length": 490, "hash_ids": [0, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]}
{"timestamp": 0, "input_length": 7234, "output_length": 794, "hash_ids": [0, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]}
{"timestamp": 0, "input_length": 2287, "output_length": 316, "hash_ids": [0, 42, 43, 44, 45]}
{"timestamp": 0, "input_length": 9013, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]}
{"timestamp": 0, "input_length": 6506, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 64]}
{"timestamp": 0, "input_length": 4824, "output_length": 173, "hash_ids": [0, 65, 66, 67, 68, 69, 70, 71, 72, 73]}
{"timestamp": 0, "input_length": 3119, "output_length": 20, "hash_ids": [74, 75, 76, 77, 78, 79, 80]}
{"timestamp": 0, "input_length": 23090, "output_length": 453, "hash_ids": [0, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125]}
{"timestamp": 0, "input_length": 3135, "output_length": 19, "hash_ids": [74, 75, 76, 77, 78, 126, 127]}
{"timestamp": 0, "input_length": 26874, "output_length": 458, "hash_ids": [0, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179]}
{"timestamp": 0, "input_length": 10487, "output_length": 402, "hash_ids": [0, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]}
{"timestamp": 0, "input_length": 17448, "output_length": 610, "hash_ids": [0, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233]}
{"timestamp": 0, "input_length": 6253, "output_length": 3, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 234]}
{"timestamp": 0, "input_length": 6725, "output_length": 32, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 235, 236]}
{"timestamp": 3052, "input_length": 13538, "output_length": 71, "hash_ids": [0, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262]}
{"timestamp": 3052, "input_length": 87162, "output_length": 402, "hash_ids": [0, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432]}
{"timestamp": 3052, "input_length": 6166, "output_length": 24, "hash_ids": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 433]}
{"timestamp": 3052, "input_length": 6320, "output_length": 548, "hash_ids": [0, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445]}
{"timestamp": 3052, "input_length": 2007, "output_length": 354, "hash_ids": [0, 446, 447, 448]}
"""
def _write_trace_and_args(tmp_path):
trace_path = tmp_path / "trace.jsonl"
records = [
{
"timestamp": 1000.0,
"input_length": 64,
"output_length": 2,
"hash_ids": [101],
},
{
"timestamp": 1005.0,
"input_length": 64,
"output_length": 2,
"hash_ids": [101],
},
]
trace_path.write_text(
"\n".join(json.dumps(record) for record in records) + "\n",
encoding="utf-8",
)
return trace_path
def _write_vllm_args(tmp_path):
args_path = tmp_path / "args.json"
args_path.write_text(
json.dumps(
{
"block_size": 64,
"speedup_ratio": 1000.0,
}
),
encoding="utf-8",
)
return args_path
def _vllm_args():
return MockEngineArgs.from_json(
json.dumps(
{
"block_size": 64,
"speedup_ratio": 1000.0,
}
)
)
def _write_sglang_args(tmp_path):
args_path = tmp_path / "sglang_args.json"
args_path.write_text(
json.dumps(
{
"engine_type": "sglang",
"num_gpu_blocks": 512,
"block_size": 64,
"speedup_ratio": 1000.0,
"sglang": {
"page_size": 64,
},
}
),
encoding="utf-8",
)
return args_path
def _sglang_args():
return MockEngineArgs.from_json(
json.dumps(
{
"engine_type": "sglang",
"num_gpu_blocks": 512,
"block_size": 64,
"speedup_ratio": 1000.0,
"sglang": {
"page_size": 64,
},
}
)
)
def _write_router_config(tmp_path):
config_path = tmp_path / "router_config.json"
config_path.write_text(
json.dumps(
{
"router_queue_threshold": 1.25,
"router_event_threads": 1,
"router_queue_policy": "wspt",
"router_temperature": 0.0,
"overlap_score_weight": 1.0,
"use_kv_events": True,
"durable_kv_events": False,
"router_replica_sync": False,
"router_track_active_blocks": True,
"router_track_output_blocks": False,
"router_assume_kv_reuse": True,
"router_snapshot_threshold": 1000000,
"router_reset_states": False,
"router_ttl_secs": 120.0,
"router_max_tree_size": 1048576,
"router_prune_target_ratio": 0.8,
"router_enable_cache_control": False,
"skip_initial_worker_wait": False,
"min_initial_workers": 1,
"remote_indexer_component": None,
}
),
encoding="utf-8",
)
return config_path
def _router_config():
return KvRouterConfig.from_json(
json.dumps(
{
"router_queue_threshold": 1.25,
"router_event_threads": 1,
"router_queue_policy": "wspt",
"router_temperature": 0.0,
"overlap_score_weight": 1.0,
"use_kv_events": True,
"durable_kv_events": False,
"router_replica_sync": False,
"router_track_active_blocks": True,
"router_track_output_blocks": False,
"router_assume_kv_reuse": True,
"router_snapshot_threshold": 1000000,
"router_reset_states": False,
"router_ttl_secs": 120.0,
"router_max_tree_size": 1048576,
"router_prune_target_ratio": 0.8,
"router_enable_cache_control": False,
"skip_initial_worker_wait": False,
"min_initial_workers": 1,
"remote_indexer_component": None,
}
)
)
def _partial_router_config():
return KvRouterConfig(
router_queue_threshold=1.25,
router_event_threads=1,
router_queue_policy="wspt",
)
def _assert_basic_report_counts(report, *, num_requests, input_tokens, output_tokens):
assert report["num_requests"] == num_requests
assert report["completed_requests"] == num_requests
assert report["total_input_tokens"] == num_requests * input_tokens
assert report["total_output_tokens"] == num_requests * output_tokens
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
def test_run_trace_replay_smoke_matrix(tmp_path, engine_type, replay_mode, router_mode):
trace_path = _write_trace_and_args(tmp_path)
args_path = _vllm_args() if engine_type == "vllm" else _sglang_args()
num_workers = 1 if router_mode == "round_robin" else 2
report = run_trace_replay(
trace_path,
extra_engine_args=args_path,
num_workers=num_workers,
replay_mode=replay_mode,
router_mode=router_mode,
)
_assert_basic_report_counts(
report,
num_requests=2,
input_tokens=64,
output_tokens=2,
)
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_invariant_counts_match(tmp_path, engine_type, replay_mode):
trace_path = _write_trace_and_args(tmp_path)
args_path = _vllm_args() if engine_type == "vllm" else _sglang_args()
single = run_trace_replay(
trace_path,
extra_engine_args=args_path,
num_workers=1,
replay_mode=replay_mode,
)
multi_round_robin = run_trace_replay(
trace_path,
extra_engine_args=args_path,
num_workers=4,
replay_mode=replay_mode,
router_mode="round_robin",
)
multi_kv_router = run_trace_replay(
trace_path,
extra_engine_args=args_path,
num_workers=4,
replay_mode=replay_mode,
router_mode="kv_router",
)
for field in (
"num_requests",
"completed_requests",
"total_input_tokens",
"total_output_tokens",
):
assert single[field] == multi_round_robin[field]
assert single[field] == multi_kv_router[field]
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
@pytest.mark.parametrize("router_mode", ["round_robin", "kv_router"])
def test_run_synthetic_trace_replay_smoke_matrix(
tmp_path, engine_type, replay_mode, router_mode
):
args_path = _vllm_args() if engine_type == "vllm" else _sglang_args()
num_workers = 1 if router_mode == "round_robin" else 2
report = run_synthetic_trace_replay(
64,
2,
2,
extra_engine_args=args_path,
num_workers=num_workers,
replay_mode=replay_mode,
router_mode=router_mode,
arrival_interval_ms=5.0,
)
_assert_basic_report_counts(
report,
num_requests=2,
input_tokens=64,
output_tokens=2,
)
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_synthetic_trace_replay_invariant_counts_match(
tmp_path, engine_type, replay_mode
):
args_path = _vllm_args() if engine_type == "vllm" else _sglang_args()
single = run_synthetic_trace_replay(
64,
2,
2,
extra_engine_args=args_path,
num_workers=1,
replay_mode=replay_mode,
arrival_interval_ms=5.0,
)
multi_round_robin = run_synthetic_trace_replay(
64,
2,
2,
extra_engine_args=args_path,
num_workers=4,
replay_mode=replay_mode,
router_mode="round_robin",
arrival_interval_ms=5.0,
)
multi_kv_router = run_synthetic_trace_replay(
64,
2,
2,
extra_engine_args=args_path,
num_workers=4,
replay_mode=replay_mode,
router_mode="kv_router",
arrival_interval_ms=5.0,
)
for field in (
"num_requests",
"completed_requests",
"total_input_tokens",
"total_output_tokens",
):
assert single[field] == multi_round_robin[field]
assert single[field] == multi_kv_router[field]
@pytest.mark.parametrize("engine_type", ["vllm", "sglang"])
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_synthetic_concurrency_replay_counts_match(
tmp_path, engine_type, replay_mode
):
args_path = _vllm_args() if engine_type == "vllm" else _sglang_args()
report = run_synthetic_trace_replay(
64,
2,
3,
extra_engine_args=args_path,
num_workers=2,
replay_mode=replay_mode,
replay_concurrency=2,
)
_assert_basic_report_counts(
report,
num_requests=3,
input_tokens=64,
output_tokens=2,
)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_router_config(tmp_path, replay_mode):
trace_path = _write_trace_and_args(tmp_path)
args_path = _vllm_args()
router_config_path = _router_config()
report = run_trace_replay(
trace_path,
extra_engine_args=args_path,
router_config=router_config_path,
num_workers=2,
replay_mode=replay_mode,
router_mode="kv_router",
)
_assert_basic_report_counts(
report,
num_requests=2,
input_tokens=64,
output_tokens=2,
)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_partial_router_config_json(tmp_path, replay_mode):
trace_path = _write_trace_and_args(tmp_path)
args_path = _vllm_args()
report = run_trace_replay(
trace_path,
extra_engine_args=args_path,
router_config=_partial_router_config(),
num_workers=2,
replay_mode=replay_mode,
router_mode="kv_router",
)
_assert_basic_report_counts(
report,
num_requests=2,
input_tokens=64,
output_tokens=2,
)
@pytest.mark.parametrize("replay_mode", ["offline", "online"])
def test_run_trace_replay_accepts_partial_extra_engine_args_json(tmp_path, replay_mode):
trace_path = _write_trace_and_args(tmp_path)
report = run_trace_replay(
trace_path,
extra_engine_args=MockEngineArgs(block_size=64, speedup_ratio=1000.0),
num_workers=1,
replay_mode=replay_mode,
)
_assert_basic_report_counts(
report,
num_requests=2,
input_tokens=64,
output_tokens=2,
)
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transport abstraction for publishing batched KV cache events.
//!
//! Implementations handle the actual delivery mechanism (NATS event plane,
//! JetStream durable queue, direct indexer application, etc.). The trait lives
//! in this crate so that the batching processor and other routing logic can be
//! written generically; runtime-specific impls stay in `lib/llm`.
use std::future::Future;
use crate::protocols::RouterEvent;
/// Transport abstraction for publishing batched KV cache events.
pub trait EventSink: Send + Sync {
fn publish_event(&self, event: &RouterEvent)
-> impl Future<Output = anyhow::Result<()>> + Send;
}
This diff is collapsed.
......@@ -6,7 +6,6 @@
//! This crate provides the core radix tree implementation and protocols for
//! efficient KV cache lookup and routing in distributed LLM inference systems.
pub mod event_sink;
pub mod indexer;
pub mod protocols;
pub mod scheduling;
......@@ -41,15 +40,15 @@ pub use self::sequence::{ActiveSequences, RequestId};
pub use concurrent_radix_tree::ConcurrentRadixTree;
pub use concurrent_radix_tree_compressed::ConcurrentRadixTreeCompressed;
pub use config::{KvRouterConfig, RouterConfigOverride, RouterQueuePolicy};
pub use event_sink::EventSink;
pub use indexer::{MaybeError, SyncIndexer, ThreadPoolIndexer};
pub use nested_map::PositionalIndexer;
pub use protocols::{
KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, WorkerConfigLike, WorkerId,
compute_block_hash_for_seq,
KvCacheEventError, LocalBlockHash, OverlapScores, RouterEvent, RouterEventSink,
WorkerConfigLike, WorkerId, compute_block_hash_for_seq,
};
pub use queue::SchedulerQueue;
pub use radix_tree::RadixTree;
pub use scheduling::LocalScheduler;
pub use scheduling::policy::{FcfsPolicy, RouterSchedulingPolicy, SchedulingPolicy, WsptPolicy};
pub use scheduling::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
pub use selector::{DefaultWorkerSelector, WorkerSelector};
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::future::Future;
use dynamo_tokens::{SequenceHash, Token};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
......@@ -105,6 +107,12 @@ pub trait WorkerConfigLike {
fn total_kv_blocks(&self) -> Option<u64>;
}
/// Transport abstraction for publishing batched router-visible KV cache events.
pub trait RouterEventSink: Send + Sync {
fn publish_event(&self, event: &RouterEvent)
-> impl Future<Output = anyhow::Result<()>> + Send;
}
/// A worker identifier.
pub type WorkerId = u64;
......
......@@ -11,11 +11,16 @@ use validator::{Validate, ValidationError};
use crate::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block};
const fn default_min_initial_workers() -> usize {
1
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterQueuePolicy {
#[default]
Fcfs,
Lcfs,
Wspt,
}
......@@ -23,6 +28,7 @@ impl fmt::Display for RouterQueuePolicy {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Fcfs => f.write_str("fcfs"),
Self::Lcfs => f.write_str("lcfs"),
Self::Wspt => f.write_str("wspt"),
}
}
......@@ -34,9 +40,10 @@ impl FromStr for RouterQueuePolicy {
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"fcfs" => Ok(Self::Fcfs),
"lcfs" => Ok(Self::Lcfs),
"wspt" => Ok(Self::Wspt),
_ => Err(format!(
"unknown queue policy: {s:?}, expected 'fcfs' or 'wspt'"
"unknown queue policy: {s:?}, expected 'fcfs', 'lcfs', or 'wspt'"
)),
}
}
......@@ -58,6 +65,7 @@ pub struct RouterConfigOverride {
/// KV Router configuration parameters
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(default)]
#[validate(schema(function = "validate_kv_router_config"))]
pub struct KvRouterConfig {
#[validate(range(min = 0.0))]
......@@ -130,6 +138,13 @@ pub struct KvRouterConfig {
/// When true, the router starts immediately without waiting for discovery-based
/// workers and workers are provided externally per-request (e.g., EPP).
pub skip_initial_worker_wait: bool,
/// Minimum number of workers that must be discovered before router startup continues.
/// Default: 1. Ignored when skip_initial_worker_wait=true.
#[serde(default = "default_min_initial_workers")]
#[validate(range(min = 1))]
pub min_initial_workers: usize,
/// Scheduling policy for the router queue.
/// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT.
/// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
......@@ -159,10 +174,11 @@ impl Default for KvRouterConfig {
router_ttl_secs: 120.0,
router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
router_prune_target_ratio: 0.8,
router_queue_threshold: Some(2.0),
router_queue_threshold: Some(4.0),
router_event_threads: 4,
router_enable_cache_control: false,
skip_initial_worker_wait: false,
min_initial_workers: default_min_initial_workers(),
router_queue_policy: RouterQueuePolicy::default(),
remote_indexer_component: None,
}
......@@ -237,3 +253,39 @@ impl KvRouterConfig {
self.use_kv_events && self.overlap_score_weight > 0.0
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn router_queue_policy_display_and_parse_support_lcfs() {
assert_eq!(RouterQueuePolicy::Lcfs.to_string(), "lcfs");
assert_eq!(
"lcfs".parse::<RouterQueuePolicy>().unwrap(),
RouterQueuePolicy::Lcfs
);
}
#[test]
fn router_queue_policy_serde_round_trip_supports_lcfs() {
let serialized = serde_json::to_string(&RouterQueuePolicy::Lcfs).unwrap();
assert_eq!(serialized, "\"lcfs\"");
let deserialized: RouterQueuePolicy = serde_json::from_str(&serialized).unwrap();
assert_eq!(deserialized, RouterQueuePolicy::Lcfs);
}
#[test]
fn kv_router_config_defaults_to_one_initial_worker() {
assert_eq!(KvRouterConfig::default().min_initial_workers, 1);
}
#[test]
fn kv_router_config_rejects_zero_initial_workers() {
let cfg = KvRouterConfig {
min_initial_workers: 0,
..KvRouterConfig::default()
};
assert!(cfg.validate().is_err());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, watch};
use tokio_util::sync::CancellationToken;
use super::policy::{RouterSchedulingPolicy, SchedulingPolicy};
use super::queue::SchedulerQueue;
use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse};
use crate::protocols::{OverlapScores, WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{
ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest,
};
use dynamo_tokens::SequenceHash;
const RECHECK_INTERVAL: Duration = Duration::from_secs(60);
pub struct LocalScheduler<P, C, S = RouterSchedulingPolicy, Sel = DefaultWorkerSelector>
where
P: SequencePublisher,
C: WorkerConfigLike,
S: SchedulingPolicy,
Sel: WorkerSelector<C>,
{
request_tx: mpsc::Sender<SchedulingRequest>,
slots: Arc<ActiveSequencesMultiWorker<P>>,
queue: Arc<SchedulerQueue<P, C, S, Sel>>,
worker_type: &'static str,
}
impl<P, C, S, Sel> LocalScheduler<P, C, S, Sel>
where
P: SequencePublisher + 'static,
C: WorkerConfigLike + Clone + PartialEq + Send + Sync + 'static,
S: SchedulingPolicy + 'static,
Sel: WorkerSelector<C> + Send + Sync + 'static,
{
#[allow(clippy::too_many_arguments)]
pub fn new(
slots: Arc<ActiveSequencesMultiWorker<P>>,
workers_with_configs: watch::Receiver<HashMap<WorkerId, C>>,
threshold_frac: Option<f64>,
block_size: u32,
selector: Sel,
policy: S,
cancellation_token: CancellationToken,
worker_type: &'static str,
monitor_worker_configs: bool,
) -> Self {
if monitor_worker_configs {
let slots_monitor = Arc::clone(&slots);
let mut monitor_rx = workers_with_configs.clone();
let mut last_workers = monitor_rx.borrow().clone();
let monitor_cancel_token = cancellation_token.clone();
tokio::spawn(async move {
tracing::trace!("LocalScheduler workers monitoring task started");
loop {
tokio::select! {
_ = monitor_cancel_token.cancelled() => {
tracing::trace!("LocalScheduler workers monitoring task shutting down");
break;
}
result = monitor_rx.changed() => {
if result.is_err() {
tracing::warn!("LocalScheduler worker config watch dropped, shutting down");
break;
}
}
}
let current_workers = monitor_rx.borrow_and_update().clone();
if current_workers == last_workers {
continue;
}
let dp_range: HashMap<WorkerId, (u32, u32)> = current_workers
.iter()
.map(|(&id, cfg)| {
(
id,
(cfg.data_parallel_start_rank(), cfg.data_parallel_size()),
)
})
.collect();
slots_monitor.update_workers(&dp_range);
last_workers = current_workers;
}
});
}
let queue = Arc::new(SchedulerQueue::new(
Arc::clone(&slots),
workers_with_configs,
threshold_frac,
block_size,
selector,
policy,
));
let (request_tx, request_rx) = mpsc::channel::<SchedulingRequest>(1024);
let queue_clone = Arc::clone(&queue);
tokio::spawn(async move {
let mut request_rx = request_rx;
let mut recheck_interval = tokio::time::interval(RECHECK_INTERVAL);
tracing::trace!("LocalScheduler background task started");
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::trace!("LocalScheduler background task shutting down");
break;
}
request = request_rx.recv() => {
let Some(request) = request else {
tracing::warn!("LocalScheduler request channel closed");
break;
};
tracing::trace!("received request to be scheduled");
queue_clone.enqueue(request).await;
}
_ = recheck_interval.tick() => {
queue_clone.update().await;
}
}
}
});
Self {
request_tx,
slots,
queue,
worker_type,
}
}
#[expect(clippy::too_many_arguments)]
pub async fn schedule(
&self,
maybe_request_id: Option<String>,
isl_tokens: usize,
token_seq: Option<Vec<SequenceHash>>,
overlaps: OverlapScores,
router_config_override: Option<&super::config::RouterConfigOverride>,
update_states: bool,
lora_name: Option<String>,
priority_jump: f64,
expected_output_tokens: Option<u32>,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<SchedulingResponse, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
maybe_request_id,
token_seq,
isl_tokens,
overlaps,
decode_blocks: HashMap::new(),
prefill_tokens: HashMap::new(),
router_config_override: router_config_override.cloned(),
update_states,
lora_name,
priority_jump,
expected_output_tokens,
allowed_worker_ids,
resp_tx: Some(resp_tx),
};
self.request_tx
.send(request)
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
resp_rx
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?
}
pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
self.queue.register_workers(worker_ids);
}
pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
self.slots.add_request(req).await
}
pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots
.mark_prefill_completed(&request_id.to_string())
.await?;
self.queue.update().await;
Ok(())
}
pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
self.slots.free(&request_id.to_string()).await?;
self.queue.update().await;
Ok(())
}
pub fn pending_count(&self) -> usize {
self.queue.pending_count()
}
pub fn worker_type(&self) -> &'static str {
self.worker_type
}
pub fn add_output_block(
&self,
request_id: &str,
decay_fraction: Option<f64>,
) -> Result<(), SequenceError> {
self.slots
.add_output_block(&request_id.to_string(), decay_fraction)
}
pub fn get_potential_loads(
&self,
token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize,
overlaps: OverlapScores,
) -> Vec<PotentialLoad> {
let (decode_blocks, prefill_tokens) =
self.slots
.potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);
let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
workers.extend(decode_blocks.keys().copied());
workers.extend(prefill_tokens.keys().copied());
let mut loads = Vec::with_capacity(workers.len());
for worker in workers {
loads.push(PotentialLoad {
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
potential_prefill_tokens: prefill_tokens
.get(&worker)
.copied()
.unwrap_or(isl_tokens),
potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
});
}
loads
}
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
self.slots.get_active_lora_counts()
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use super::*;
use crate::protocols::OverlapScores;
use crate::scheduling::policy::FcfsPolicy;
use crate::scheduling::selector::DefaultWorkerSelector;
use crate::test_utils::{NoopSequencePublisher, SimpleWorkerConfig};
#[allow(clippy::type_complexity)]
fn make_scheduler(
workers: HashMap<WorkerId, SimpleWorkerConfig>,
threshold_frac: Option<f64>,
monitor_worker_configs: bool,
) -> (
Arc<LocalScheduler<NoopSequencePublisher, SimpleWorkerConfig, FcfsPolicy>>,
Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
watch::Sender<HashMap<WorkerId, SimpleWorkerConfig>>,
CancellationToken,
) {
let dp_range = workers
.iter()
.map(|(&id, cfg)| (id, (cfg.data_parallel_start_rank, cfg.data_parallel_size)))
.collect();
let slots = Arc::new(ActiveSequencesMultiWorker::new(
NoopSequencePublisher,
64,
dp_range,
false,
0,
"test",
));
let (cfg_tx, cfg_rx) = watch::channel(workers);
let cancel_token = CancellationToken::new();
let scheduler = Arc::new(LocalScheduler::new(
Arc::clone(&slots),
cfg_rx,
threshold_frac,
64,
DefaultWorkerSelector::new(None, "test"),
FcfsPolicy,
cancel_token.clone(),
"test",
monitor_worker_configs,
));
(scheduler, slots, cfg_tx, cancel_token)
}
#[tokio::test]
async fn test_schedule_books_request_into_active_sequences() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(64),
..Default::default()
},
);
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);
let response = scheduler
.schedule(
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
None,
true,
Some("adapter-a".to_string()),
0.0,
None,
None,
)
.await
.unwrap();
assert_eq!(response.best_worker.worker_id, 0);
assert_eq!(
scheduler.get_active_lora_counts(),
HashMap::from([(String::from("adapter-a"), 1)])
);
cancel_token.cancel();
}
#[tokio::test]
async fn test_mark_prefill_completed_drains_pending_queue() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(64),
..Default::default()
},
);
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, Some(0.5), true);
scheduler
.schedule(
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
.unwrap();
let queued = {
let scheduler = Arc::clone(&scheduler);
tokio::spawn(async move {
scheduler
.schedule(
Some("req-2".to_string()),
64,
Some(vec![5, 6, 7, 8]),
OverlapScores::default(),
None,
true,
None,
0.0,
None,
None,
)
.await
})
};
tokio::time::sleep(Duration::from_millis(25)).await;
assert_eq!(scheduler.pending_count(), 1);
scheduler.mark_prefill_completed("req-1").await.unwrap();
queued.await.unwrap().unwrap();
assert_eq!(scheduler.pending_count(), 0);
cancel_token.cancel();
}
#[tokio::test]
async fn test_free_updates_active_state() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(64),
..Default::default()
},
);
let (scheduler, _slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);
scheduler
.schedule(
Some("req-1".to_string()),
64,
Some(vec![1, 2, 3, 4]),
OverlapScores::default(),
None,
true,
Some("adapter-a".to_string()),
0.0,
None,
None,
)
.await
.unwrap();
assert_eq!(
scheduler.get_active_lora_counts(),
HashMap::from([(String::from("adapter-a"), 1)])
);
scheduler.free("req-1").await.unwrap();
assert!(scheduler.get_active_lora_counts().is_empty());
cancel_token.cancel();
}
#[tokio::test]
async fn test_get_potential_loads_matches_slots() {
let mut workers = HashMap::new();
workers.insert(
0,
SimpleWorkerConfig {
max_num_batched_tokens: Some(256),
..Default::default()
},
);
workers.insert(
1,
SimpleWorkerConfig {
max_num_batched_tokens: Some(256),
..Default::default()
},
);
let (scheduler, slots, _cfg_tx, cancel_token) = make_scheduler(workers, None, true);
let token_seq = vec![11, 22, 33, 44];
let overlaps = OverlapScores::default();
let (decode_blocks, prefill_tokens) =
slots.potential_blocks_and_tokens(Some(&token_seq), 128, overlaps.clone());
let mut expected: Vec<_> = decode_blocks
.keys()
.map(|worker| PotentialLoad {
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
potential_prefill_tokens: prefill_tokens.get(worker).copied().unwrap_or(128),
potential_decode_blocks: decode_blocks.get(worker).copied().unwrap_or(0),
})
.collect();
expected.sort_by_key(|load| (load.worker_id, load.dp_rank));
let mut actual = scheduler.get_potential_loads(Some(token_seq), 128, overlaps);
actual.sort_by_key(|load| (load.worker_id, load.dp_rank));
assert_eq!(actual.len(), expected.len());
for (actual, expected) in actual.iter().zip(expected.iter()) {
assert_eq!(actual.worker_id, expected.worker_id);
assert_eq!(actual.dp_rank, expected.dp_rank);
assert_eq!(
actual.potential_prefill_tokens,
expected.potential_prefill_tokens
);
assert_eq!(
actual.potential_decode_blocks,
expected.potential_decode_blocks
);
}
cancel_token.cancel();
}
#[tokio::test]
async fn test_register_workers_uses_default_dp_fallback() {
let (scheduler, _slots, _cfg_tx, cancel_token) =
make_scheduler(HashMap::new(), None, false);
scheduler.register_workers(&HashSet::from([42]));
let loads = scheduler.get_potential_loads(None, 64, OverlapScores::default());
assert_eq!(loads.len(), 1);
assert_eq!(loads[0].worker_id, 42);
assert_eq!(loads[0].dp_rank, 0);
cancel_token.cancel();
}
#[tokio::test]
async fn test_worker_watch_updates_slot_ranges() {
let mut workers = HashMap::new();
workers.insert(0, SimpleWorkerConfig::default());
let (scheduler, _slots, cfg_tx, cancel_token) = make_scheduler(workers, None, true);
assert_eq!(
scheduler
.get_potential_loads(None, 64, OverlapScores::default())
.len(),
1
);
let mut updated_workers = HashMap::new();
updated_workers.insert(
0,
SimpleWorkerConfig {
data_parallel_size: 2,
..Default::default()
},
);
updated_workers.insert(1, SimpleWorkerConfig::default());
cfg_tx.send(updated_workers).unwrap();
tokio::time::timeout(Duration::from_secs(1), async {
loop {
if scheduler
.get_potential_loads(None, 64, OverlapScores::default())
.len()
== 3
{
break;
}
tokio::task::yield_now().await;
}
})
.await
.unwrap();
cancel_token.cancel();
}
}
......@@ -2,9 +2,11 @@
// SPDX-License-Identifier: Apache-2.0
pub mod config;
mod local;
pub mod policy;
pub mod queue;
pub mod selector;
mod types;
pub use local::LocalScheduler;
pub use types::*;
......@@ -43,6 +43,21 @@ impl SchedulingPolicy for FcfsPolicy {
}
}
/// LCFS with priority bumps: key = priority_jump + arrival_offset.
/// Later arrival or higher priority_jump produces a higher key, scheduled first.
///
/// This intentionally favors newer arrivals under saturation and is mainly useful
/// for policy comparison experiments.
pub struct LcfsPolicy;
impl SchedulingPolicy for LcfsPolicy {
type Key = OrderedFloat<f64>;
fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
OrderedFloat(request.priority_jump.max(0.0) + arrival_offset.as_secs_f64())
}
}
/// Weighted Shortest Processing Time (Smith's rule):
/// key = (1 + priority_jump) / new_tokens, where new_tokens estimates the
/// actual prefill cost by subtracting the max KV cache overlap from ISL.
......@@ -73,6 +88,7 @@ impl SchedulingPolicy for WsptPolicy {
/// since the variant is fixed at queue construction time.
pub enum RouterSchedulingPolicy {
Fcfs(FcfsPolicy),
Lcfs(LcfsPolicy),
Wspt(WsptPolicy),
}
......@@ -80,6 +96,7 @@ impl RouterSchedulingPolicy {
pub fn new(kind: RouterQueuePolicy, block_size: usize) -> Self {
match kind {
RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy),
RouterQueuePolicy::Lcfs => Self::Lcfs(LcfsPolicy),
RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy { block_size }),
}
}
......@@ -91,6 +108,7 @@ impl SchedulingPolicy for RouterSchedulingPolicy {
fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
match self {
Self::Fcfs(p) => p.enqueue_key(arrival_offset, request),
Self::Lcfs(p) => p.enqueue_key(arrival_offset, request),
Self::Wspt(p) => p.enqueue_key(arrival_offset, request),
}
}
......@@ -178,6 +196,42 @@ mod tests {
assert!(key_b > key_a);
}
#[test]
fn lcfs_later_arrival_scheduled_first() {
let policy = LcfsPolicy;
let req = request_with(512, 0.0, OverlapScores::default());
let early = policy.enqueue_key(Duration::from_secs(1), &req);
let late = policy.enqueue_key(Duration::from_secs(10), &req);
assert!(late > early, "later arrival should have higher key");
}
#[test]
fn lcfs_priority_jump_promotes() {
let policy = LcfsPolicy;
let normal = request_with(512, 0.0, OverlapScores::default());
let boosted = request_with(512, 100.0, OverlapScores::default());
let t = Duration::from_secs(10);
let key_normal = policy.enqueue_key(t, &normal);
let key_boosted = policy.enqueue_key(t, &boosted);
assert!(
key_boosted > key_normal,
"priority_jump should produce a higher key"
);
}
#[test]
fn router_scheduling_policy_matches_fcfs_and_lcfs_ordering() {
let req = request_with(512, 0.0, OverlapScores::default());
let early = Duration::from_secs(1);
let late = Duration::from_secs(10);
let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs, 16);
assert!(fcfs.enqueue_key(early, &req) > fcfs.enqueue_key(late, &req));
let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs, 16);
assert!(lcfs.enqueue_key(late, &req) > lcfs.enqueue_key(early, &req));
}
// ---- WSPT policy tests ----
#[test]
......
......@@ -11,7 +11,7 @@ use tokio::sync::Mutex;
use tokio::sync::watch;
use super::policy::{FcfsPolicy, SchedulingPolicy};
use super::selector::WorkerSelector;
use super::selector::{DefaultWorkerSelector, WorkerSelector};
use super::types::{SchedulingRequest, SchedulingResponse};
use crate::protocols::{WorkerConfigLike, WorkerId, WorkerWithDpRank};
use crate::sequences::{ActiveSequencesMultiWorker, SequencePublisher, SequenceRequest};
......@@ -53,6 +53,7 @@ pub struct SchedulerQueue<
P: SequencePublisher,
C: WorkerConfigLike,
S: SchedulingPolicy = FcfsPolicy,
Sel: WorkerSelector<C> = DefaultWorkerSelector,
> {
pending: Mutex<BinaryHeap<QueueEntry<S::Key>>>,
/// Number of requests currently parked in the pending queue.
......@@ -65,19 +66,23 @@ pub struct SchedulerQueue<
/// Reference instant for computing arrival offsets.
start_time: Instant,
block_size: u32,
selector: Box<dyn WorkerSelector<C> + Send + Sync>,
selector: Sel,
policy: S,
}
impl<P: SequencePublisher + 'static, C: WorkerConfigLike, S: SchedulingPolicy>
SchedulerQueue<P, C, S>
impl<
P: SequencePublisher + 'static,
C: WorkerConfigLike,
S: SchedulingPolicy,
Sel: WorkerSelector<C>,
> SchedulerQueue<P, C, S, Sel>
{
pub fn new(
slots: Arc<ActiveSequencesMultiWorker<P>>,
workers_with_configs: watch::Receiver<HashMap<WorkerId, C>>,
threshold_frac: Option<f64>,
block_size: u32,
selector: Box<dyn WorkerSelector<C> + Send + Sync>,
selector: Sel,
policy: S,
) -> Self {
if let Some(frac) = threshold_frac {
......@@ -341,7 +346,7 @@ mod tests {
}
let (cfg_tx, cfg_rx) = watch::channel(configs);
let selector = Box::new(DefaultWorkerSelector::new(None, "test"));
let selector = DefaultWorkerSelector::new(None, "test");
let queue = Arc::new(SchedulerQueue::new(
Arc::clone(&slots),
cfg_rx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment