Unverified Commit 89d67172 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat(replay): planner-in-the-loop offline replay (#8187)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent ee9d67e4
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Adapter that drives PlannerStateMachine via the PlannerReplayBridge.
The bridge (Rust, PyO3) runs the offline simulation step-by-step.
This adapter sits between the bridge and the planner state machine:
Bridge.advance_to(tick_ms) → raw metrics dict
Adapter._build_tick_input() → TickInput
StateMachine.on_tick() → PlannerEffects
Adapter → Bridge.apply_scaling(prefill, decode)
Supports both aggregated and disaggregated topologies. No I/O, no runtime
dependencies. Fully deterministic when used with offline replay.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any, Optional
from dynamo.common.forward_pass_metrics import (
ForwardPassMetrics,
QueuedRequestMetrics,
ScheduledRequestMetrics,
)
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.state_machine import PlannerStateMachine
from dynamo.planner.core.types import (
FpmObservations,
PlannerEffects,
ScheduledTick,
TickDiagnostics,
TickInput,
TrafficObservation,
WorkerCapabilities,
WorkerCounts,
)
logger = logging.getLogger(__name__)
@dataclass
class ScalingEvent:
"""Record of a single scaling decision."""
at_s: float
component: str # "agg", "prefill", or "decode"
from_count: int
to_count: int
reason: Optional[str] = None
@dataclass
class ReplayPlannerReport:
"""Enriched report combining trace metrics and planner diagnostics."""
trace_report: dict[str, Any]
scaling_events: list[ScalingEvent] = field(default_factory=list)
diagnostics_log: list[TickDiagnostics] = field(default_factory=list)
total_ticks: int = 0
def _build_fpm_from_dict(d: dict[str, Any]) -> ForwardPassMetrics:
"""Convert a bridge FPM snapshot dict into a ForwardPassMetrics struct."""
return ForwardPassMetrics(
worker_id=str(d["worker_id"]),
dp_rank=0,
wall_time=d["wall_time"],
scheduled_requests=ScheduledRequestMetrics(
num_prefill_requests=d["num_prefill_requests"],
sum_prefill_tokens=d["sum_prefill_tokens"],
var_prefill_length=d["var_prefill_length"],
sum_prefill_kv_tokens=d["sum_prefill_kv_tokens"],
num_decode_requests=d["num_decode_requests"],
sum_decode_kv_tokens=d["sum_decode_kv_tokens"],
var_decode_kv_tokens=d["var_decode_kv_tokens"],
),
queued_requests=QueuedRequestMetrics(
num_prefill_requests=d["num_queued_prefill"],
sum_prefill_tokens=d["sum_queued_prefill_tokens"],
var_prefill_length=d["var_queued_prefill_length"],
num_decode_requests=d["num_queued_decode"],
sum_decode_kv_tokens=d["sum_queued_decode_kv_tokens"],
var_decode_kv_tokens=d["var_queued_decode_kv_tokens"],
),
)
def _update_fpm_cache(
cache: dict[tuple[str, int], ForwardPassMetrics],
snapshots: list[dict[str, Any]],
active_count: int,
) -> None:
"""Update a last-seen FPM cache with new snapshots and prune removed workers."""
for snap in snapshots:
fpm = _build_fpm_from_dict(snap)
cache[(fpm.worker_id, fpm.dp_rank)] = fpm
# Prune cache down to active_count entries. Workers are removed
# highest-ID-first during scale-down, so keep the lowest IDs.
while len(cache) > active_count:
# Remove the highest worker ID entry
worst_key = max(cache.keys(), key=lambda k: int(k[0]))
del cache[worst_key]
class ReplayPlannerAdapter:
"""Drives the planner state machine using the PlannerReplayBridge.
Supports both ``mode="agg"`` and ``mode="disagg"``.
"""
def __init__(
self,
planner_config: PlannerConfig,
bridge: Any, # PlannerReplayBridge (Rust pyclass)
capabilities: Optional[WorkerCapabilities] = None,
warmup_observations: Optional[list[TrafficObservation]] = None,
) -> None:
self._config = planner_config
self._bridge = bridge
self._sm = PlannerStateMachine(planner_config, capabilities)
self._is_disagg = planner_config.mode == "disagg"
# Last-seen FPM caches (separate for prefill/decode)
self._prefill_fpm_cache: dict[tuple[str, int], ForwardPassMetrics] = {}
self._decode_fpm_cache: dict[tuple[str, int], ForwardPassMetrics] = {}
# Scaling targets — used as `expected` in WorkerCounts
self._scaling_target_prefill: Optional[int] = None
self._scaling_target_decode: Optional[int] = None
if warmup_observations:
self._sm.warm_load_predictors(warmup_observations)
def run(self) -> ReplayPlannerReport:
"""Run the full replay with planner-in-the-loop."""
next_tick = self._sm.initial_tick(0.0)
scaling_events: list[ScalingEvent] = []
diagnostics_log: list[TickDiagnostics] = []
total_ticks = 0
while True:
tick_ms = next_tick.at_s * 1000.0
result = self._bridge.advance_to(tick_ms)
if result["is_done"]:
break
tick_input = self._build_tick_input(next_tick, result)
effects: PlannerEffects = self._sm.on_tick(next_tick, tick_input)
diagnostics_log.append(effects.diagnostics)
total_ticks += 1
# Clear scaling targets once active counts match
active_p = result["active_prefill_count"]
active_d = result["active_decode_count"]
if (
self._scaling_target_prefill is not None
and active_p == self._scaling_target_prefill
):
self._scaling_target_prefill = None
if (
self._scaling_target_decode is not None
and active_d == self._scaling_target_decode
):
self._scaling_target_decode = None
if effects.scale_to is not None:
self._apply_scaling(effects, result, tick_input.now_s, scaling_events)
if effects.next_tick is None:
break
next_tick = effects.next_tick
trace_report = self._bridge.finalize()
return ReplayPlannerReport(
trace_report=trace_report,
scaling_events=scaling_events,
diagnostics_log=diagnostics_log,
total_ticks=total_ticks,
)
def _apply_scaling(
self,
effects: PlannerEffects,
result: dict[str, Any],
now_s: float,
scaling_events: list[ScalingEvent],
) -> None:
"""Apply scaling decisions and record events."""
scale = effects.scale_to
assert scale is not None
current_p = result["active_prefill_count"]
current_d = result["active_decode_count"]
target_p = scale.num_prefill if scale.num_prefill is not None else current_p
target_d = scale.num_decode if scale.num_decode is not None else current_d
if target_p == current_p and target_d == current_d:
return
self._bridge.apply_scaling(target_p, target_d)
if self._is_disagg:
if scale.num_prefill is not None and target_p != current_p:
direction = "scale_up" if target_p > current_p else "scale_down"
logger.info(
"Planner scaling prefill: %d -> %d at t=%.1fs (%s)",
current_p,
target_p,
now_s,
direction,
)
self._scaling_target_prefill = target_p
scaling_events.append(
ScalingEvent(
at_s=now_s,
component="prefill",
from_count=current_p,
to_count=target_p,
reason=direction,
)
)
if scale.num_decode is not None and target_d != current_d:
direction = "scale_up" if target_d > current_d else "scale_down"
logger.info(
"Planner scaling decode: %d -> %d at t=%.1fs (%s)",
current_d,
target_d,
now_s,
direction,
)
self._scaling_target_decode = target_d
scaling_events.append(
ScalingEvent(
at_s=now_s,
component="decode",
from_count=current_d,
to_count=target_d,
reason=direction,
)
)
else:
direction = "scale_up" if target_d > current_d else "scale_down"
logger.info(
"Planner scaling: %d -> %d workers at t=%.1fs (%s)",
current_d,
target_d,
now_s,
direction,
)
self._scaling_target_decode = target_d
scaling_events.append(
ScalingEvent(
at_s=now_s,
component="agg",
from_count=current_d,
to_count=target_d,
reason=direction,
)
)
def _feed_extra_fpm_to_regression(
self,
decode_snaps: list[dict[str, Any]],
prefill_snaps: list[dict[str, Any]],
) -> None:
"""Feed accumulated FPM snapshots to regression, excluding the last
per worker (which will be added by _observe_fpm via fpm_observations).
This avoids double-counting the cached snapshot."""
if not hasattr(self._sm, "_is_easy") or self._sm._is_easy:
return # easy mode has no regression models
if self._sm._is_agg:
# Exclude the last snapshot per worker (it's in the cache and
# will be added by _observe_fpm)
last_idx_per_worker: dict[int, int] = {}
for i, snap in enumerate(decode_snaps):
last_idx_per_worker[snap["worker_id"]] = i
exclude = set(last_idx_per_worker.values())
for i, snap in enumerate(decode_snaps):
if i in exclude:
continue
fpm = _build_fpm_from_dict(snap)
if fpm.wall_time > 0.0:
self._sm._agg_regression.add_observation(fpm)
else:
if self._sm._has_prefill:
last_idx: dict[int, int] = {}
for i, snap in enumerate(prefill_snaps):
last_idx[snap["worker_id"]] = i
exclude = set(last_idx.values())
for i, snap in enumerate(prefill_snaps):
if i in exclude:
continue
fpm = _build_fpm_from_dict(snap)
if fpm.wall_time > 0.0:
self._sm._prefill_regression.add_observation(fpm)
if self._sm._has_decode:
last_idx = {}
for i, snap in enumerate(decode_snaps):
last_idx[snap["worker_id"]] = i
exclude = set(last_idx.values())
for i, snap in enumerate(decode_snaps):
if i in exclude:
continue
fpm = _build_fpm_from_dict(snap)
if fpm.wall_time > 0.0:
self._sm._decode_regression.add_observation(fpm)
def _build_tick_input(
self, tick: ScheduledTick, result: dict[str, Any]
) -> TickInput:
"""Convert bridge result dict to planner TickInput."""
now_s = result["now_ms"] / 1000.0
worker_counts = None
if tick.need_worker_states:
active_p = result["active_prefill_count"]
active_d = result["active_decode_count"]
expected_p = (
self._scaling_target_prefill
if self._scaling_target_prefill is not None
else active_p
)
expected_d = (
self._scaling_target_decode
if self._scaling_target_decode is not None
else active_d
)
worker_counts = WorkerCounts(
ready_num_prefill=active_p if self._is_disagg else None,
ready_num_decode=active_d,
expected_num_prefill=expected_p if self._is_disagg else None,
expected_num_decode=expected_d,
)
fpm_observations = None
if tick.need_worker_fpm:
prefill_snaps = result.get("prefill_fpm_snapshots", [])
decode_snaps = result.get("decode_fpm_snapshots", [])
_update_fpm_cache(
self._prefill_fpm_cache, prefill_snaps, result["active_prefill_count"]
)
_update_fpm_cache(
self._decode_fpm_cache, decode_snaps, result["active_decode_count"]
)
# In offline replay, we accumulate many FPM snapshots per tick
# (one per engine pass). Feed ALL non-idle snapshots directly to
# the regression models for a representative fit. The last-per-worker
# cache is only used for the FpmObservations dict (worker count
# reconciliation), not as the sole regression input.
prefill_dict = (
dict(self._prefill_fpm_cache) if self._prefill_fpm_cache else None
)
decode_dict = (
dict(self._decode_fpm_cache) if self._decode_fpm_cache else None
)
self._feed_extra_fpm_to_regression(decode_snaps, prefill_snaps)
fpm_observations = FpmObservations(
prefill=prefill_dict,
decode=decode_dict,
)
traffic = None
if tick.need_traffic_metrics:
t = result.get("traffic", {})
duration_s = t.get("duration_s", 0.0)
if duration_s > 0:
traffic = TrafficObservation(
duration_s=duration_s,
num_req=float(t.get("num_req", 0)),
isl=t.get("avg_isl", 0.0),
osl=t.get("avg_osl", 0.0),
)
return TickInput(
now_s=now_s,
traffic=traffic,
worker_counts=worker_counts,
fpm_observations=fpm_observations,
)
......@@ -177,6 +177,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::replay::ReasoningConfig>()?;
m.add_class::<llm::replay::SglangArgs>()?;
m.add_class::<llm::replay::MockEngineArgs>()?;
m.add_class::<llm::replay::PlannerReplayBridge>()?;
m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?; // Internal: only in _internal, not public API
m.add_class::<llm::local_model::ModelRuntimeConfig>()?;
......
......@@ -1196,3 +1196,179 @@ fn synthetic_token_id(request_idx: usize, token_idx: usize) -> u32 {
let token = value as u32;
if token == 0 { 1 } else { token }
}
// ---------------------------------------------------------------------------
// Planner-in-the-loop replay bridge
// ---------------------------------------------------------------------------
fn fpm_snapshots_to_json(
snapshots: Vec<(usize, dynamo_mocker::common::protocols::ForwardPassSnapshot)>,
) -> Vec<serde_json::Value> {
snapshots
.into_iter()
.map(|(worker_id, fpm)| {
json!({
"worker_id": worker_id,
"wall_time": fpm.wall_time_secs,
"num_prefill_requests": fpm.num_prefill_requests,
"sum_prefill_tokens": fpm.sum_prefill_tokens,
"var_prefill_length": fpm.var_prefill_length,
"sum_prefill_kv_tokens": fpm.sum_prefill_kv_tokens,
"num_decode_requests": fpm.num_decode_requests,
"sum_decode_kv_tokens": fpm.sum_decode_kv_tokens,
"var_decode_kv_tokens": fpm.var_decode_kv_tokens,
"num_queued_prefill": fpm.num_queued_prefill,
"sum_queued_prefill_tokens": fpm.sum_queued_prefill_tokens,
"var_queued_prefill_length": fpm.var_queued_prefill_length,
"num_queued_decode": fpm.num_queued_decode,
"sum_queued_decode_kv_tokens": fpm.sum_queued_decode_kv_tokens,
"var_queued_decode_kv_tokens": fpm.var_queued_decode_kv_tokens,
})
})
.collect()
}
/// Step-based bridge for driving an offline replay with a Python planner.
///
/// Supports both aggregated and disaggregated topologies. The Python adapter
/// calls `advance_to()` to run the simulation forward, collects FPM/traffic
/// metrics, feeds them to the planner state machine, then calls
/// `apply_scaling()` to resize worker pools.
#[pyclass(unsendable)]
pub struct PlannerReplayBridge {
handle: Option<dynamo_mocker::replay::PlannerReplayHandle>,
}
#[pymethods]
impl PlannerReplayBridge {
/// Create a bridge for an aggregated Mooncake-style JSONL trace replay.
#[new]
#[pyo3(signature = (trace_file, extra_engine_args, num_workers, router_mode="round_robin", router_config=None, arrival_speedup_ratio=1.0, trace_block_size=512))]
fn new(
trace_file: PathBuf,
extra_engine_args: &MockEngineArgs,
num_workers: usize,
router_mode: &str,
router_config: Option<KvRouterConfig>,
arrival_speedup_ratio: f64,
trace_block_size: usize,
) -> PyResult<Self> {
let args = extra_engine_args.inner();
let router_mode = parse_replay_router_mode(router_mode)?;
let router_config = load_replay_router_config(router_config);
let handle = dynamo_mocker::replay::PlannerReplayHandle::from_trace_file(
args,
router_config,
None,
&trace_file,
trace_block_size,
num_workers,
arrival_speedup_ratio,
router_mode,
)
.map_err(to_pyerr)?;
Ok(Self {
handle: Some(handle),
})
}
/// Create a bridge for a disaggregated Mooncake-style JSONL trace replay.
#[staticmethod]
#[pyo3(signature = (trace_file, prefill_engine_args, decode_engine_args, num_prefill_workers, num_decode_workers, router_mode="round_robin", router_config=None, arrival_speedup_ratio=1.0, trace_block_size=512))]
#[allow(clippy::too_many_arguments)]
fn create_disagg(
trace_file: PathBuf,
prefill_engine_args: &MockEngineArgs,
decode_engine_args: &MockEngineArgs,
num_prefill_workers: usize,
num_decode_workers: usize,
router_mode: &str,
router_config: Option<KvRouterConfig>,
arrival_speedup_ratio: f64,
trace_block_size: usize,
) -> PyResult<Self> {
let config = dynamo_mocker::replay::OfflineDisaggReplayConfig {
prefill_args: prefill_engine_args.inner(),
decode_args: decode_engine_args.inner(),
num_prefill_workers,
num_decode_workers,
};
let router_mode = parse_replay_router_mode(router_mode)?;
let router_config = load_replay_router_config(router_config);
let handle = dynamo_mocker::replay::PlannerReplayHandle::from_trace_file_disagg(
config,
router_config,
None,
&trace_file,
trace_block_size,
arrival_speedup_ratio,
router_mode,
)
.map_err(to_pyerr)?;
Ok(Self {
handle: Some(handle),
})
}
/// Advance the simulation to `until_ms` simulated time.
///
/// Returns a dict with separate prefill/decode worker counts and FPM snapshots.
fn advance_to(&mut self, py: Python<'_>, until_ms: f64) -> PyResult<PyObject> {
let handle = self
.handle
.as_mut()
.ok_or_else(|| PyException::new_err("bridge has been finalized"))?;
let tick_data = handle.advance_to(until_ms).map_err(to_pyerr)?;
let (duration_s, num_req, avg_isl, avg_osl) = tick_data.traffic;
let result = json!({
"now_ms": tick_data.now_ms,
"is_done": tick_data.is_done,
"prefill_fpm_snapshots": fpm_snapshots_to_json(tick_data.prefill_fpm_snapshots),
"decode_fpm_snapshots": fpm_snapshots_to_json(tick_data.decode_fpm_snapshots),
"traffic": {
"duration_s": duration_s,
"num_req": num_req,
"avg_isl": avg_isl,
"avg_osl": avg_osl,
},
"active_prefill_count": tick_data.active_prefill_count,
"active_decode_count": tick_data.active_decode_count,
"total_prefill_count": tick_data.total_prefill_count,
"total_decode_count": tick_data.total_decode_count,
});
pythonize(py, &result)
.map_err(to_pyerr)
.map(|obj| obj.unbind())
}
/// Apply a scaling decision with separate prefill and decode targets.
/// For agg mode, `target_prefill` is ignored (pass 0).
fn apply_scaling(&mut self, target_prefill: usize, target_decode: usize) -> PyResult<()> {
let handle = self
.handle
.as_mut()
.ok_or_else(|| PyException::new_err("bridge has been finalized"))?;
handle
.apply_scaling(target_prefill, target_decode)
.map_err(to_pyerr)
}
/// Finalize the replay and return the trace simulation report.
fn finalize(&mut self, py: Python<'_>) -> PyResult<PyObject> {
let handle = self
.handle
.take()
.ok_or_else(|| PyException::new_err("bridge has already been finalized"))?;
let report = handle.finalize();
pythonize(py, &report)
.map_err(to_pyerr)
.map(|obj| obj.unbind())
}
}
......@@ -1570,6 +1570,37 @@ def run_mocker_synthetic_trace_replay(
"""Replay a synthetic mocker workload without requiring a trace file."""
...
class PlannerReplayBridge:
"""Step-based bridge for driving an offline replay with a Python planner."""
def __init__(
self,
trace_file: str | os.PathLike[str],
extra_engine_args: MockEngineArgs,
num_workers: int,
router_mode: str = "round_robin",
router_config: Optional[KvRouterConfig] = None,
arrival_speedup_ratio: float = 1.0,
trace_block_size: int = 512,
) -> None: ...
@staticmethod
def create_disagg(
trace_file: str | os.PathLike[str],
prefill_engine_args: MockEngineArgs,
decode_engine_args: MockEngineArgs,
num_prefill_workers: int,
num_decode_workers: int,
router_mode: str = "round_robin",
router_config: Optional[KvRouterConfig] = None,
arrival_speedup_ratio: float = 1.0,
trace_block_size: int = 512,
) -> "PlannerReplayBridge": ...
def advance_to(self, until_ms: float) -> Dict[str, Any]: ...
def apply_scaling(self, target_prefill: int, target_decode: int) -> None: ...
def finalize(self) -> Dict[str, Any]: ...
class Layer:
"""
A KV cache block layer
......
......@@ -25,6 +25,7 @@ from dynamo._core import ModelInput as ModelInput
from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig
from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores
from dynamo._core import PlannerReplayBridge as PlannerReplayBridge
from dynamo._core import PythonAsyncEngine as PythonAsyncEngine
from dynamo._core import RadixTree as RadixTree
from dynamo._core import ReasoningConfig as ReasoningConfig
......
......@@ -11,7 +11,10 @@ import sys
from collections.abc import Sequence
from pathlib import Path
from types import SimpleNamespace
from typing import Protocol
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from dynamo.planner.core.types import EngineCapabilities
os.environ.setdefault("DYNAMO_SKIP_PYTHON_LOG_INIT", "1")
......@@ -106,6 +109,100 @@ def _load_aic_perf_config(args: argparse.Namespace):
)
def _engine_caps(args: MockEngineArgs) -> EngineCapabilities:
"""Derive EngineCapabilities from MockEngineArgs."""
from dynamo.planner.core.types import EngineCapabilities
max_kv_tokens = args.num_gpu_blocks * args.block_size
return EngineCapabilities(
num_gpu=1,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_seqs=args.max_num_seqs,
context_length=max_kv_tokens if max_kv_tokens > 0 else None,
max_kv_tokens=max_kv_tokens if max_kv_tokens > 0 else None,
)
def _run_planner_replay(
trace_file: str,
extra_engine_args: MockEngineArgs | None,
prefill_engine_args: MockEngineArgs | None,
decode_engine_args: MockEngineArgs | None,
router_config: KvRouterConfig | None,
num_workers: int,
num_prefill_workers: int,
num_decode_workers: int,
router_mode: str,
arrival_speedup_ratio: float,
trace_block_size: int,
planner_config_arg: str,
):
"""Run an offline replay with planner-in-the-loop (agg or disagg).
# TODO(jthomson04): SLA-based scaling (optimization_target="sla") with
# disagg mode requires planner_profile_data (NPZ) or AIC-backed engine
# args. The default polynomial perf model does not account for batch
# size in its decode timing, causing the DecodeRegressionModel's
# num_decode_requests coefficient to go negative and reject the fit.
# Fix the polynomial model to incorporate batch_size, or gate disagg
# SLA mode on having a non-polynomial perf model.
"""
from dynamo.llm import PlannerReplayBridge
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.types import WorkerCapabilities
from dynamo.planner.offline.replay_adapter import ReplayPlannerAdapter
planner_config = PlannerConfig.from_config_arg(planner_config_arg)
planner_config.no_operation = True
if planner_config.mode == "agg":
if extra_engine_args is None:
extra_engine_args = MockEngineArgs()
bridge = PlannerReplayBridge(
trace_file=trace_file,
extra_engine_args=extra_engine_args,
num_workers=num_workers,
router_mode=router_mode,
router_config=router_config,
arrival_speedup_ratio=arrival_speedup_ratio,
trace_block_size=trace_block_size,
)
capabilities = WorkerCapabilities(decode=_engine_caps(extra_engine_args))
elif planner_config.mode == "disagg":
if prefill_engine_args is None or decode_engine_args is None:
raise ValueError(
"disagg planner replay requires --prefill-engine-args and --decode-engine-args"
)
bridge = PlannerReplayBridge.create_disagg(
trace_file=trace_file,
prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args,
num_prefill_workers=num_prefill_workers,
num_decode_workers=num_decode_workers,
router_mode=router_mode,
router_config=router_config,
arrival_speedup_ratio=arrival_speedup_ratio,
trace_block_size=trace_block_size,
)
capabilities = WorkerCapabilities(
prefill=_engine_caps(prefill_engine_args),
decode=_engine_caps(decode_engine_args),
)
else:
raise ValueError(
f"planner-in-the-loop replay supports mode='agg' or 'disagg', got '{planner_config.mode}'"
)
adapter = ReplayPlannerAdapter(
planner_config=planner_config,
bridge=bridge,
capabilities=capabilities,
)
return adapter.run()
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(prog="python -m dynamo.replay")
parser.add_argument("trace_file", nargs="?")
......@@ -155,6 +252,10 @@ def main(argv: Sequence[str] | None = None) -> int:
"--report-json",
help="path to save the full replay report JSON; defaults to a timestamped file in the current directory",
)
parser.add_argument(
"--planner-config",
help="path to planner config YAML/JSON or inline JSON; enables planner-in-the-loop replay (offline agg only)",
)
args = parser.parse_args(list(sys.argv[1:] if argv is None else argv))
using_trace_file = args.trace_file is not None
......@@ -190,6 +291,43 @@ def main(argv: Sequence[str] | None = None) -> int:
except ValueError as exc:
parser.error(str(exc))
# Planner-in-the-loop mode
if args.planner_config is not None:
if args.replay_mode != "offline":
parser.error("--planner-config only supports --replay-mode=offline")
if not using_trace_file:
parser.error("--planner-config requires a trace file (not synthetic)")
planner_report = _run_planner_replay(
trace_file=args.trace_file,
extra_engine_args=extra_engine_args,
prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args,
router_config=router_config,
num_workers=args.num_workers,
num_prefill_workers=args.num_prefill_workers,
num_decode_workers=args.num_decode_workers,
router_mode=args.router_mode,
arrival_speedup_ratio=args.arrival_speedup_ratio,
trace_block_size=args.trace_block_size,
planner_config_arg=args.planner_config,
)
report = planner_report.trace_report
if planner_report.scaling_events:
sys.stdout.write("\nScaling events:\n")
for event in planner_report.scaling_events:
sys.stdout.write(
f" t={event.at_s:.1f}s [{event.component}]: "
f"{event.from_count} -> {event.to_count} workers"
f" ({event.reason})\n"
)
report_path = write_report_json(report, args.report_json)
sys.stdout.write(format_report_table(report))
sys.stdout.write("\n")
sys.stdout.write(f"Saved full report to: {report_path}\n")
sys.stdout.write(f"Planner ticks: {planner_report.total_ticks}\n")
return 0
if using_trace_file:
report = run_trace_replay(
args.trace_file,
......
......@@ -240,7 +240,7 @@ impl PerfModel {
/// Predict decode time in milliseconds.
///
/// Callers always pass all parameters; each variant uses what it needs:
/// - Polynomial: uses active_kv_tokens
/// - Polynomial: uses (active_kv_tokens, total_kv_tokens) as utilization
/// - Interpolated: uses (active_kv_tokens, context_length)
/// - Aiconfigurator: uses (batch_size, context_length)
pub fn predict_decode_time(
......@@ -248,13 +248,19 @@ impl PerfModel {
batch_size: usize,
active_kv_tokens: usize,
context_length: usize,
total_kv_tokens: usize,
) -> f64 {
if batch_size == 0 {
return 0.0;
}
let time = match self {
PerfModel::Polynomial => {
let active_perc = active_kv_tokens as f64 / 16384.0;
let active_perc = if total_kv_tokens > 0 {
active_kv_tokens as f64 / total_kv_tokens as f64
} else {
tracing::warn!("Total KV tokens is 0, using 1.0 as capacity");
1.0
};
-25.74 * active_perc.powi(2) + 54.01 * active_perc + 5.74
}
PerfModel::Interpolated { decode_interp, .. } => decode_interp
......
......@@ -6,6 +6,7 @@ mod collector;
mod entrypoints;
pub(crate) mod offline;
mod online;
mod planner_handle;
mod router_shared;
mod validate;
......@@ -76,6 +77,7 @@ pub use entrypoints::{
simulate_trace_requests_with_router_mode, simulate_trace_workload,
simulate_trace_workload_disagg_with_router_mode, simulate_trace_workload_with_router_mode,
};
pub use planner_handle::{PlannerReplayHandle, PlannerTickData};
pub use validate::validate_replay_args_mode;
pub(crate) fn normalize_trace_requests(
......
......@@ -16,11 +16,11 @@ use super::state::OfflineWorkerSnapshot;
use super::{
components::{
AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter,
ScheduledWorkerCompletion, WorkerAdmission,
ScheduledWorkerCompletion, TrafficAccumulator, WorkerAdmission,
},
state::AggRequestState,
};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::common::protocols::{DirectRequest, ForwardPassSnapshot, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, WorkloadDriver};
use crate::replay::{ReplayPrefillLoadEstimator, ReplayRouterMode, TraceCollector};
use anyhow::bail;
......@@ -59,7 +59,7 @@ struct AggRuntimeSnapshot {
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub(super) struct AggRuntimeStats;
pub(super) struct AggRuntime {
pub(in crate::replay) struct AggRuntime {
now_ms: f64,
next_worker_idx: usize,
next_event_seq: u64,
......@@ -71,6 +71,10 @@ pub(super) struct AggRuntime {
router: Option<OfflineReplayRouter>,
progress: ReplayProgress,
stats: AggRuntimeStats,
/// Forward pass metrics accumulated between planner ticks.
fpm_buffer: Vec<(usize, ForwardPassSnapshot)>,
/// Traffic statistics accumulated between planner ticks.
traffic: TrafficAccumulator,
#[cfg(test)]
worker_active_requests: Vec<Vec<Uuid>>,
#[cfg(test)]
......@@ -79,7 +83,7 @@ pub(super) struct AggRuntime {
impl AggRuntime {
/// Create an aggregated offline runtime seeded from an explicit request queue.
pub(super) fn new(
pub(in crate::replay) fn new(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
......@@ -99,7 +103,7 @@ impl AggRuntime {
}
/// Create an aggregated offline runtime whose admissions come from a workload driver.
pub(super) fn new_workload(
pub(in crate::replay) fn new_workload(
args: &MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
......@@ -139,7 +143,7 @@ impl AggRuntime {
)?),
};
let capture_kv_events = router.is_some();
let engine = EngineComponent::new(
let mut engine = EngineComponent::new(
SimulationWorkerStage::Aggregated,
EnginePassMode::Visible,
(0..num_workers)
......@@ -152,6 +156,7 @@ impl AggRuntime {
})
.collect(),
);
engine.set_scaling_args(args, capture_kv_events);
Ok(Self {
now_ms: 0.0,
......@@ -168,6 +173,8 @@ impl AggRuntime {
stats: AggRuntimeStats::default(),
#[cfg(not(test))]
stats: AggRuntimeStats,
fpm_buffer: Vec::new(),
traffic: TrafficAccumulator::new(),
#[cfg(test)]
worker_active_requests: vec![Vec::new(); num_workers],
#[cfg(test)]
......@@ -208,11 +215,13 @@ impl AggRuntime {
}
}
/// Pick the next worker in round-robin order.
/// Pick the next active worker in round-robin order.
fn next_worker(&mut self) -> usize {
let worker_idx = self.next_worker_idx;
self.next_worker_idx = (self.next_worker_idx + 1) % self.engine.worker_count();
worker_idx
let active = self.engine.active_worker_ids();
debug_assert!(!active.is_empty(), "no active workers for round-robin");
let idx = self.next_worker_idx % active.len();
self.next_worker_idx = idx + 1;
active[idx]
}
/// Record which worker accepted a request and refresh in-flight stats.
......@@ -281,7 +290,10 @@ impl AggRuntime {
);
if self.router.is_none() {
self.requests.insert(uuid, AggRequestState::new_running());
self.requests.insert(
uuid,
AggRequestState::new_running(request.tokens.len(), request.max_output_tokens),
);
let worker_idx = self.next_worker();
self.dispatch_to_worker(request, uuid, worker_idx)?;
return Ok(uuid);
......@@ -346,9 +358,11 @@ impl AggRuntime {
}
self.record_router_pending();
}
self.requests.remove(&signal.uuid).ok_or_else(|| {
let removed_state = self.requests.remove(&signal.uuid).ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?;
self.traffic
.on_request(removed_state.input_tokens, removed_state.output_tokens);
self.admission
.on_request_completed(signal.uuid, self.now_ms)?;
self.progress.inc_completed();
......@@ -465,6 +479,7 @@ impl AggRuntime {
}
fn handle_engine_effects(&mut self, effects: EngineEffects) -> anyhow::Result<()> {
self.fpm_buffer.extend(effects.fpm_snapshots);
self.apply_router_events(effects.pass_start_kv_events)?;
for payload in effects.immediate_completions {
let payload = self.engine.on_scheduled_completion(payload)?;
......@@ -496,8 +511,92 @@ impl AggRuntime {
Ok(())
}
// ------------------------------------------------------------------
// Planner integration: step-based execution
// ------------------------------------------------------------------
/// Advance the simulation up to `until_ms` simulated time, then pause.
/// Returns `true` if the replay is done (no more work).
pub(in crate::replay) fn advance_to(&mut self, until_ms: f64) -> anyhow::Result<bool> {
self.drain_current_timestamp()?;
while !self.is_done() {
let Some(next_timestamp_ms) = self.next_timestamp() else {
bail!(
"offline replay reached a dead end with {} in-flight requests remaining",
self.cluster_in_flight()
);
};
if next_timestamp_ms > until_ms {
break;
}
self.now_ms = next_timestamp_ms;
self.drain_current_timestamp()?;
}
Ok(self.is_done())
}
/// Current simulated time in milliseconds.
pub(in crate::replay) fn now_ms(&self) -> f64 {
self.now_ms
}
/// Number of active (non-pending-removal) workers.
pub(in crate::replay) fn active_worker_count(&self) -> usize {
self.engine.active_worker_ids().len()
}
/// Total worker count including pending-removal.
pub(in crate::replay) fn total_worker_count(&self) -> usize {
self.engine.worker_count()
}
/// Drain accumulated FPM snapshots since the last drain.
pub(in crate::replay) fn drain_fpm(&mut self) -> Vec<(usize, ForwardPassSnapshot)> {
std::mem::take(&mut self.fpm_buffer)
}
/// Drain accumulated traffic stats since the last drain.
pub(in crate::replay) fn drain_traffic(&mut self) -> (f64, usize, f64, f64) {
self.traffic.drain(self.now_ms)
}
/// Apply a scaling decision: set the target number of workers.
/// Scale-up is immediate; scale-down removes the worker from the router
/// immediately (so no new requests land on it) and lets it drain in-flight
/// work in the engine.
pub(in crate::replay) fn apply_scaling(&mut self, target_workers: usize) -> anyhow::Result<()> {
let (added, newly_marked) = self.engine.apply_target_count(target_workers);
if let Some(router) = self.router.as_mut() {
for id in added {
router.add_worker(id)?;
}
for id in newly_marked {
router.remove_worker(id)?;
}
}
Ok(())
}
/// Finalize the replay: finish progress bar, return collector and stats.
pub(in crate::replay::offline) fn finalize(self) -> (TraceCollector, AggRuntimeStats) {
self.progress.finish();
(self.collector, self.stats)
}
/// Finalize the replay and return the simulation report directly.
pub(in crate::replay) fn finalize_report(self) -> crate::replay::TraceSimulationReport {
let (collector, _stats) = self.finalize();
collector.finish()
}
/// Run the aggregated offline replay until all arrivals and worker work are exhausted.
pub(super) fn run(mut self) -> anyhow::Result<(TraceCollector, AggRuntimeStats)> {
pub(in crate::replay::offline) fn run(
mut self,
) -> anyhow::Result<(TraceCollector, AggRuntimeStats)> {
self.drain_current_timestamp()?;
while !self.is_done() {
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{BTreeMap, BTreeSet};
use anyhow::bail;
use super::super::events::SimulationWorkerStage;
......@@ -9,14 +11,23 @@ use super::super::runtime_utils::WorkerCompletionPayload;
use super::super::state::OfflineWorkerSnapshot;
use super::super::state::OfflineWorkerState;
use super::{EngineEffects, EnginePassMode, ScheduledWorkerCompletion};
use crate::common::protocols::DirectRequest;
use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::replay::TraceCollector;
use crate::scheduler::RouterEventVisibility;
pub(in crate::replay::offline) struct EngineComponent {
stage: SimulationWorkerStage,
pass_mode: EnginePassMode,
workers: Vec<OfflineWorkerState>,
/// Workers keyed by stable ID (monotonic, never reused).
workers: BTreeMap<usize, OfflineWorkerState>,
/// Counter for generating the next stable worker ID.
next_id: usize,
/// Workers marked for removal — skipped by round-robin, removed when drained.
pending_removal: BTreeSet<usize>,
/// Engine args used to construct new workers during scale-up.
args: MockEngineArgs,
/// Whether new workers should capture KV events (true when a router is present).
capture_kv_events: bool,
}
impl EngineComponent {
......@@ -25,20 +36,115 @@ impl EngineComponent {
pass_mode: EnginePassMode,
workers: Vec<OfflineWorkerState>,
) -> Self {
let count = workers.len();
let map: BTreeMap<usize, OfflineWorkerState> = workers.into_iter().enumerate().collect();
Self {
stage,
pass_mode,
workers,
workers: map,
next_id: count,
pending_removal: BTreeSet::new(),
args: MockEngineArgs::default(),
capture_kv_events: false,
}
}
/// Set the engine args and KV capture flag used when adding workers dynamically.
pub(in crate::replay::offline) fn set_scaling_args(
&mut self,
args: MockEngineArgs,
capture_kv_events: bool,
) {
self.args = args;
self.capture_kv_events = capture_kv_events;
}
/// Add a new worker, returning its stable ID.
pub(in crate::replay::offline) fn add_worker(&mut self) -> usize {
let id = self.next_id;
self.next_id += 1;
let worker = OfflineWorkerState::new(id, self.args.clone(), self.capture_kv_events);
self.workers.insert(id, worker);
id
}
/// Mark a worker for removal. It will be skipped by `drive_ready` and
/// removed once fully drained.
pub(in crate::replay::offline) fn mark_for_removal(&mut self, worker_id: usize) {
self.pending_removal.insert(worker_id);
}
/// Remove all marked workers that have fully drained, returning their IDs.
pub(in crate::replay::offline) fn try_remove_drained(&mut self) -> Vec<usize> {
let mut removed = Vec::new();
self.pending_removal.retain(|&id| {
if let Some(worker) = self.workers.get(&id) {
if worker.is_drained() {
removed.push(id);
return false; // remove from pending set
}
} else {
// Worker already gone
return false;
}
true // keep in pending set
});
for &id in &removed {
self.workers.remove(&id);
}
removed
}
/// Apply a target worker count: add new workers or mark excess for removal.
/// Returns `(added_ids, newly_marked_ids)` so the caller can update the
/// router immediately. Newly marked workers should be removed from the
/// router right away to prevent new requests from landing on them, even
/// though the workers themselves remain in the engine until fully drained.
pub(in crate::replay::offline) fn apply_target_count(
&mut self,
target: usize,
) -> (Vec<usize>, Vec<usize>) {
let active_ids = self.active_worker_ids();
let current = active_ids.len();
let mut added = Vec::new();
let mut newly_marked = Vec::new();
if target > current {
for _ in 0..(target - current) {
added.push(self.add_worker());
}
} else if target < current {
let excess = current - target;
for &id in active_ids.iter().rev().take(excess) {
self.mark_for_removal(id);
newly_marked.push(id);
}
}
// Clean up any workers that have already fully drained.
self.try_remove_drained();
(added, newly_marked)
}
/// Return stable IDs of all active (non-pending-removal) workers.
pub(in crate::replay::offline) fn active_worker_ids(&self) -> Vec<usize> {
self.workers
.keys()
.filter(|id| !self.pending_removal.contains(id))
.copied()
.collect()
}
pub(in crate::replay::offline) fn dispatch(
&mut self,
worker_idx: usize,
worker_id: usize,
request: DirectRequest,
) -> anyhow::Result<()> {
self.validate_worker_idx(worker_idx)?;
self.workers[worker_idx].receive_request(request);
let worker = self
.workers
.get_mut(&worker_id)
.ok_or_else(|| anyhow::anyhow!("offline replay selected unknown worker {worker_id}"))?;
worker.receive_request(request);
Ok(())
}
......@@ -47,8 +153,11 @@ impl EngineComponent {
now_ms: f64,
mut collector: Option<&mut TraceCollector>,
) -> anyhow::Result<EngineEffects> {
for worker_idx in 0..self.workers.len() {
if !self.workers[worker_idx].is_ready() {
// Collect worker IDs first to avoid borrow issues.
let worker_ids: Vec<usize> = self.workers.keys().copied().collect();
for worker_id in worker_ids {
let worker = self.workers.get(&worker_id).unwrap();
if !worker.is_ready() {
continue;
}
......@@ -57,15 +166,25 @@ impl EngineComponent {
let Some(collector) = collector.as_deref_mut() else {
bail!("offline replay visible engine pass requires a collector");
};
self.workers[worker_idx].execute_pass(collector, now_ms)
self.workers
.get_mut(&worker_id)
.unwrap()
.execute_pass(collector, now_ms)
}
EnginePassMode::Hidden => self.workers[worker_idx].execute_hidden_pass(now_ms),
EnginePassMode::Hidden => self
.workers
.get_mut(&worker_id)
.unwrap()
.execute_hidden_pass(now_ms),
};
let mut effects = EngineEffects {
admissions: executed.admissions,
..EngineEffects::default()
};
if let Some(fpm) = executed.fpm {
effects.fpm_snapshots.push((worker_id, fpm));
}
let completion_kv_events =
if executed.router_event_visibility == RouterEventVisibility::PassStart {
effects.pass_start_kv_events = executed.kv_events;
......@@ -75,7 +194,7 @@ impl EngineComponent {
};
let payload = WorkerCompletionPayload {
stage: self.stage,
worker_idx,
worker_idx: worker_id,
completed_requests: executed.completed_requests,
output_signals: executed.output_signals,
kv_events: completion_kv_events,
......@@ -86,7 +205,7 @@ impl EngineComponent {
return Ok(effects);
}
self.workers[worker_idx].mark_busy();
self.workers.get_mut(&worker_id).unwrap().mark_busy();
effects
.scheduled_completions
.push(ScheduledWorkerCompletion {
......@@ -110,35 +229,42 @@ impl EngineComponent {
payload.stage
);
}
self.validate_worker_idx(payload.worker_idx)?;
self.workers[payload.worker_idx].mark_idle();
self.workers[payload.worker_idx].mark_completed(payload.completed_requests);
let worker = self.workers.get_mut(&payload.worker_idx).ok_or_else(|| {
anyhow::anyhow!(
"offline replay completion for unknown worker {}",
payload.worker_idx
)
})?;
worker.mark_idle();
worker.mark_completed(payload.completed_requests);
// Eagerly clean up drained workers that are pending removal so they
// don't linger indefinitely when no further scaling events trigger
// apply_target_count.
if self.pending_removal.contains(&payload.worker_idx) {
self.try_remove_drained();
}
Ok(payload)
}
pub(in crate::replay::offline) fn in_flight(&self) -> usize {
self.workers.iter().map(OfflineWorkerState::in_flight).sum()
self.workers
.values()
.map(OfflineWorkerState::in_flight)
.sum()
}
pub(in crate::replay::offline) fn is_drained(&self) -> bool {
self.workers.iter().all(OfflineWorkerState::is_drained)
self.workers.values().all(OfflineWorkerState::is_drained)
}
pub(in crate::replay::offline) fn worker_count(&self) -> usize {
self.workers.len()
}
fn validate_worker_idx(&self, worker_idx: usize) -> anyhow::Result<()> {
if worker_idx >= self.workers.len() {
bail!("offline replay selected unknown worker index {worker_idx}");
}
Ok(())
}
#[cfg(test)]
pub(crate) fn debug_snapshots(&self) -> Vec<OfflineWorkerSnapshot> {
self.workers
.iter()
.values()
.map(OfflineWorkerState::debug_snapshot)
.collect()
}
......
......@@ -11,7 +11,8 @@ pub(in crate::replay::offline) use engine::EngineComponent;
pub(crate) use router::OfflineReplayRouter;
#[cfg(test)]
pub(crate) use router::OfflineRouterSnapshot;
pub(in crate::replay) use types::ReplayMode;
pub(in crate::replay::offline) use types::{
EngineEffects, EnginePassMode, ReadyArrival, ReplayMode, ScheduledWorkerCompletion,
EngineEffects, EnginePassMode, ReadyArrival, ScheduledWorkerCompletion, TrafficAccumulator,
};
pub(crate) use types::{RouterEffects, WorkerAdmission};
......@@ -309,6 +309,48 @@ impl OfflineReplayRouter {
self.pending.len()
}
/// Register a new worker with the router, cloning the config from existing workers.
pub(crate) fn add_worker(&mut self, worker_id: usize) -> Result<()> {
let config = self
.workers_with_configs
.values()
.next()
.ok_or_else(|| anyhow!("cannot add worker to router with no existing workers"))?
.clone();
let wid = worker_id as WorkerId;
self.workers_with_configs.insert(wid, config);
// Rebuild the slots with the full worker set
let dp_range: HashMap<u64, (u32, u32)> = self
.workers_with_configs
.keys()
.map(|&id| (id, (0u32, 1u32)))
.collect();
self.slots.update_workers(&dp_range);
// Enable queueing if we now have more than one worker
if self.workers_with_configs.len() > 1 && self.queue_threshold.is_none() {
self.queue_threshold = self.config.router_queue_threshold;
}
Ok(())
}
/// Remove a worker from routing eligibility.
///
/// Only removes the worker from the config map so the selector won't
/// pick it for new requests. The radix tree and active-sequence slots
/// are left intact so that in-flight requests on this worker can still
/// complete (free / mark_prefill_completed) and KV events can still
/// reference existing blocks without "parent block not found" errors.
/// Stale slot and indexer state is harmless — the selector and
/// `all_workers_busy` both skip workers absent from `workers_with_configs`.
pub(crate) fn remove_worker(&mut self, worker_id: usize) -> Result<()> {
let wid = worker_id as WorkerId;
self.workers_with_configs.remove(&wid);
Ok(())
}
#[cfg(test)]
pub(crate) fn debug_snapshot(&self, now_ms: f64) -> OfflineRouterSnapshot {
let decay_now = self.decay_now(now_ms);
......
......@@ -5,12 +5,12 @@ use dynamo_kv_router::protocols::RouterEvent;
use uuid::Uuid;
use super::super::runtime_utils::WorkerCompletionPayload;
use crate::common::protocols::DirectRequest;
use crate::common::protocols::{DirectRequest, ForwardPassSnapshot};
use crate::loadgen::ReplayRequestHashes;
use crate::scheduler::AdmissionEvent;
#[derive(Debug, Clone, Copy)]
pub(in crate::replay::offline) enum ReplayMode {
pub(in crate::replay) enum ReplayMode {
Trace,
Concurrency { max_in_flight: usize },
}
......@@ -39,6 +39,9 @@ pub(in crate::replay::offline) struct EngineEffects {
pub(in crate::replay::offline) pass_start_kv_events: Vec<RouterEvent>,
pub(in crate::replay::offline) immediate_completions: Vec<WorkerCompletionPayload>,
pub(in crate::replay::offline) scheduled_completions: Vec<ScheduledWorkerCompletion>,
/// Forward pass metrics snapshots emitted by workers during this drive cycle,
/// keyed by worker index. Collected for planner integration.
pub(in crate::replay::offline) fpm_snapshots: Vec<(usize, ForwardPassSnapshot)>,
}
impl EngineEffects {
......@@ -61,3 +64,57 @@ pub(in crate::replay::offline) struct ReadyArrival {
pub(in crate::replay::offline) arrival_time_ms: f64,
pub(in crate::replay::offline) replay_hashes: Option<ReplayRequestHashes>,
}
/// Accumulates traffic statistics between planner ticks for deriving
/// `TrafficObservation` (num_req, avg ISL, avg OSL over a window).
#[derive(Debug)]
pub(in crate::replay::offline) struct TrafficAccumulator {
window_start_ms: f64,
num_req: usize,
total_isl: usize,
total_osl: usize,
}
impl TrafficAccumulator {
pub(in crate::replay::offline) fn new() -> Self {
Self {
window_start_ms: 0.0,
num_req: 0,
total_isl: 0,
total_osl: 0,
}
}
/// Record one admitted request.
pub(in crate::replay::offline) fn on_request(
&mut self,
input_tokens: usize,
output_tokens: usize,
) {
self.num_req += 1;
self.total_isl += input_tokens;
self.total_osl += output_tokens;
}
/// Drain the accumulator at the given simulated time, returning
/// (duration_s, num_req, avg_isl, avg_osl) and resetting counters.
pub(in crate::replay::offline) fn drain(&mut self, now_ms: f64) -> (f64, usize, f64, f64) {
let duration_s = (now_ms - self.window_start_ms) / 1000.0;
let num_req = self.num_req;
let avg_isl = if num_req > 0 {
self.total_isl as f64 / num_req as f64
} else {
0.0
};
let avg_osl = if num_req > 0 {
self.total_osl as f64 / num_req as f64
} else {
0.0
};
self.window_start_ms = now_ms;
self.num_req = 0;
self.total_isl = 0;
self.total_osl = 0;
(duration_s, num_req, avg_isl, avg_osl)
}
}
......@@ -11,7 +11,7 @@ use uuid::Uuid;
pub(super) use super::components::ReplayMode;
use super::components::{
AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter,
ScheduledWorkerCompletion, WorkerAdmission,
ScheduledWorkerCompletion, TrafficAccumulator, WorkerAdmission,
};
use super::events::{SimulationEvent, SimulationWorkerStage};
use super::progress::ReplayProgress;
......@@ -22,7 +22,7 @@ use super::runtime_utils::{
#[cfg(test)]
use super::state::DisaggRequestSnapshot;
use super::state::{DisaggPhase, DisaggRequestState};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::common::protocols::{DirectRequest, ForwardPassSnapshot, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, WorkloadDriver};
use crate::replay::{
OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, TraceCollector,
......@@ -60,7 +60,7 @@ pub(super) struct DisaggRuntimeStats {
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub(super) struct DisaggRuntimeStats;
pub(super) struct DisaggRuntime {
pub(in crate::replay) struct DisaggRuntime {
now_ms: f64,
next_prefill_worker_idx: usize,
next_decode_worker_idx: usize,
......@@ -75,11 +75,16 @@ pub(super) struct DisaggRuntime {
events: BinaryHeap<SimulationEvent>,
progress: ReplayProgress,
stats: DisaggRuntimeStats,
/// Forward pass metrics accumulated between planner ticks, keyed by (stage, worker_idx).
prefill_fpm_buffer: Vec<(usize, ForwardPassSnapshot)>,
decode_fpm_buffer: Vec<(usize, ForwardPassSnapshot)>,
/// Traffic statistics accumulated between planner ticks.
traffic: TrafficAccumulator,
}
impl DisaggRuntime {
/// Create a disaggregated offline runtime seeded from an explicit request queue.
pub(super) fn new(
pub(in crate::replay) fn new(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
......@@ -97,7 +102,7 @@ impl DisaggRuntime {
}
/// Create a disaggregated offline runtime whose admissions come from a workload driver.
pub(super) fn new_workload(
pub(in crate::replay) fn new_workload(
config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
......@@ -147,7 +152,8 @@ impl DisaggRuntime {
}
};
let prefill_engine = EngineComponent::new(
let prefill_capture_kv = prefill_router.is_some();
let mut prefill_engine = EngineComponent::new(
SimulationWorkerStage::Prefill,
EnginePassMode::Hidden,
(0..config.num_prefill_workers)
......@@ -155,12 +161,13 @@ impl DisaggRuntime {
super::state::OfflineWorkerState::new(
worker_idx,
config.prefill_args.clone(),
prefill_router.is_some(),
prefill_capture_kv,
)
})
.collect(),
);
let decode_engine = EngineComponent::new(
prefill_engine.set_scaling_args(config.prefill_args.clone(), prefill_capture_kv);
let mut decode_engine = EngineComponent::new(
SimulationWorkerStage::Decode,
EnginePassMode::Visible,
(0..config.num_decode_workers)
......@@ -173,6 +180,7 @@ impl DisaggRuntime {
})
.collect(),
);
decode_engine.set_scaling_args(config.decode_args.clone(), false);
Ok(Self {
now_ms: 0.0,
......@@ -192,6 +200,9 @@ impl DisaggRuntime {
stats: DisaggRuntimeStats::default(),
#[cfg(not(test))]
stats: DisaggRuntimeStats,
prefill_fpm_buffer: Vec::new(),
decode_fpm_buffer: Vec::new(),
traffic: TrafficAccumulator::new(),
})
}
......@@ -209,20 +220,28 @@ impl DisaggRuntime {
.map_or(0, OfflineReplayRouter::pending_count)
}
/// Pick the next prefill worker in round-robin order.
/// Pick the next active prefill worker in round-robin order.
fn next_prefill_worker(&mut self) -> usize {
let worker_idx = self.next_prefill_worker_idx;
self.next_prefill_worker_idx =
(self.next_prefill_worker_idx + 1) % self.prefill_engine.worker_count();
worker_idx
let active = self.prefill_engine.active_worker_ids();
debug_assert!(
!active.is_empty(),
"no active prefill workers for round-robin"
);
let idx = self.next_prefill_worker_idx % active.len();
self.next_prefill_worker_idx = idx + 1;
active[idx]
}
/// Pick the next decode worker in round-robin order.
/// Pick the next active decode worker in round-robin order.
fn next_decode_worker(&mut self) -> usize {
let worker_idx = self.next_decode_worker_idx;
self.next_decode_worker_idx =
(self.next_decode_worker_idx + 1) % self.decode_engine.worker_count();
worker_idx
let active = self.decode_engine.active_worker_ids();
debug_assert!(
!active.is_empty(),
"no active decode workers for round-robin"
);
let idx = self.next_decode_worker_idx % active.len();
self.next_decode_worker_idx = idx + 1;
active[idx]
}
/// Track the peak number of requests parked in each stage router.
......@@ -355,7 +374,6 @@ impl DisaggRuntime {
request.tokens.len(),
request.max_output_tokens,
);
let queued_request = request.clone();
self.requests
.insert(uuid, DisaggRequestState::new(request, arrival_time_ms));
......@@ -479,6 +497,11 @@ impl DisaggRuntime {
.transition_log
.push(DisaggTransition::WorkloadCompleted { uuid: signal.uuid });
}
let state = self.state(signal.uuid)?;
let original = state.original_request()?;
let input_tokens = original.tokens.len();
let output_tokens = original.max_output_tokens;
self.traffic.on_request(input_tokens, output_tokens);
self.state_mut(signal.uuid)?.mark_done();
#[cfg(test)]
{
......@@ -626,6 +649,7 @@ impl DisaggRuntime {
}
fn handle_prefill_engine_effects(&mut self, effects: EngineEffects) -> Result<()> {
self.prefill_fpm_buffer.extend(effects.fpm_snapshots);
self.record_prefill_admissions(effects.admissions);
self.apply_prefill_router_events(effects.pass_start_kv_events)?;
for payload in effects.immediate_completions {
......@@ -651,6 +675,7 @@ impl DisaggRuntime {
}
fn handle_decode_engine_effects(&mut self, effects: EngineEffects) -> Result<()> {
self.decode_fpm_buffer.extend(effects.fpm_snapshots);
for payload in effects.immediate_completions {
let payload = self.decode_engine.on_scheduled_completion(payload)?;
self.process_decode_pass(
......@@ -693,6 +718,105 @@ impl DisaggRuntime {
}
}
// ------------------------------------------------------------------
// Planner integration: step-based execution
// ------------------------------------------------------------------
/// Advance the simulation up to `until_ms` simulated time, then pause.
/// Returns `true` if the replay is done (no more work).
pub(in crate::replay) fn advance_to(&mut self, until_ms: f64) -> Result<bool> {
self.drain_current_timestamp()?;
while !self.is_done() {
let Some(next_timestamp_ms) = self.next_timestamp() else {
bail!(
"offline disagg replay reached a dead end with {} in-flight requests remaining",
self.cluster_in_flight()
);
};
if next_timestamp_ms > until_ms {
break;
}
self.now_ms = next_timestamp_ms;
self.drain_current_timestamp()?;
}
Ok(self.is_done())
}
/// Current simulated time in milliseconds.
pub(in crate::replay) fn now_ms(&self) -> f64 {
self.now_ms
}
pub(in crate::replay) fn active_prefill_count(&self) -> usize {
self.prefill_engine.active_worker_ids().len()
}
pub(in crate::replay) fn active_decode_count(&self) -> usize {
self.decode_engine.active_worker_ids().len()
}
pub(in crate::replay) fn total_prefill_count(&self) -> usize {
self.prefill_engine.worker_count()
}
pub(in crate::replay) fn total_decode_count(&self) -> usize {
self.decode_engine.worker_count()
}
/// Drain accumulated prefill FPM snapshots since the last drain.
pub(in crate::replay) fn drain_prefill_fpm(&mut self) -> Vec<(usize, ForwardPassSnapshot)> {
std::mem::take(&mut self.prefill_fpm_buffer)
}
/// Drain accumulated decode FPM snapshots since the last drain.
pub(in crate::replay) fn drain_decode_fpm(&mut self) -> Vec<(usize, ForwardPassSnapshot)> {
std::mem::take(&mut self.decode_fpm_buffer)
}
/// Drain accumulated traffic stats since the last drain.
pub(in crate::replay) fn drain_traffic(&mut self) -> (f64, usize, f64, f64) {
self.traffic.drain(self.now_ms)
}
/// Apply a scaling decision with separate prefill and decode targets.
/// Newly marked workers are removed from the router immediately so no
/// new requests land on them while they drain in-flight work.
pub(in crate::replay) fn apply_scaling(
&mut self,
target_prefill: usize,
target_decode: usize,
) -> Result<()> {
let (added, newly_marked) = self.prefill_engine.apply_target_count(target_prefill);
if let Some(router) = self.prefill_router.as_mut() {
for id in added {
router.add_worker(id)?;
}
for id in newly_marked {
router.remove_worker(id)?;
}
}
let (added, newly_marked) = self.decode_engine.apply_target_count(target_decode);
if let Some(router) = self.decode_router.as_mut() {
for id in added {
router.add_worker(id)?;
}
for id in newly_marked {
router.remove_worker(id)?;
}
}
Ok(())
}
/// Finalize the replay and return the simulation report directly.
pub(in crate::replay) fn finalize_report(self) -> crate::replay::TraceSimulationReport {
self.progress.finish();
self.collector.finish()
}
/// Run the staged offline replay until both prefill and decode pipelines are drained.
pub(super) fn run(mut self) -> Result<(TraceCollector, DisaggRuntimeStats)> {
self.drain_current_timestamp()?;
......
......@@ -19,22 +19,30 @@ pub(crate) struct AggRequestState {
request: Option<DirectRequest>,
pub(in crate::replay::offline) phase: AggRequestPhase,
pub(in crate::replay::offline) prefill_completed: bool,
pub(in crate::replay::offline) input_tokens: usize,
pub(in crate::replay::offline) output_tokens: usize,
}
impl AggRequestState {
pub(crate) fn new_queued(request: DirectRequest) -> Self {
let input_tokens = request.tokens.len();
let output_tokens = request.max_output_tokens;
Self {
request: Some(request),
phase: AggRequestPhase::QueuedAtRouter,
prefill_completed: false,
input_tokens,
output_tokens,
}
}
pub(crate) fn new_running() -> Self {
pub(crate) fn new_running(input_tokens: usize, output_tokens: usize) -> Self {
Self {
request: None,
phase: AggRequestPhase::Running,
prefill_completed: false,
input_tokens,
output_tokens,
}
}
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Public handle for driving an offline replay with planner-in-the-loop.
//!
//! Supports both aggregated and disaggregated topologies via [`RuntimeKind`].
//! The Python planner adapter calls [`PlannerReplayHandle::advance_to`] to
//! step the simulation, collects metrics, and calls [`PlannerReplayHandle::apply_scaling`]
//! to resize worker pools.
use std::path::Path;
use std::time::Instant;
use anyhow::Result;
use dynamo_kv_router::config::KvRouterConfig;
use super::offline::agg::AggRuntime;
use super::offline::components::ReplayMode;
use super::offline::disagg::DisaggRuntime;
use super::{
OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, TraceSimulationReport,
};
use crate::common::protocols::{ForwardPassSnapshot, MockEngineArgs};
use crate::loadgen::Trace;
/// Snapshot of metrics collected between planner ticks.
///
/// For aggregated mode, prefill fields are 0 and all data is in decode fields
/// (matching how the planner treats agg as a single decode-stage engine).
pub struct PlannerTickData {
/// Current simulated time in milliseconds.
pub now_ms: f64,
/// Whether the replay has finished (no more work).
pub is_done: bool,
/// Prefill FPM snapshots since last tick: (worker_id, snapshot).
pub prefill_fpm_snapshots: Vec<(usize, ForwardPassSnapshot)>,
/// Decode (or agg) FPM snapshots since last tick: (worker_id, snapshot).
pub decode_fpm_snapshots: Vec<(usize, ForwardPassSnapshot)>,
/// Traffic observation: (duration_s, num_req, avg_isl, avg_osl).
pub traffic: (f64, usize, f64, f64),
/// Active prefill workers (0 for agg mode).
pub active_prefill_count: usize,
/// Active decode workers (or total active for agg mode).
pub active_decode_count: usize,
/// Total prefill workers including pending removal (0 for agg mode).
pub total_prefill_count: usize,
/// Total decode workers including pending removal (or total for agg mode).
pub total_decode_count: usize,
}
#[allow(clippy::large_enum_variant)]
enum RuntimeKind {
Agg(AggRuntime),
Disagg(DisaggRuntime),
}
pub struct PlannerReplayHandle {
runtime: RuntimeKind,
started_at: Instant,
}
impl PlannerReplayHandle {
/// Create a handle for an aggregated trace-file replay.
#[allow(clippy::too_many_arguments)]
pub fn from_trace_file(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path,
trace_block_size: usize,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<Self> {
let args = args.normalized()?;
let trace = Trace::from_mooncake(trace_path, trace_block_size)?
.normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?;
let runtime = AggRuntime::new_workload(
&args,
router_config,
prefill_load_estimator,
trace.into_trace_driver_with_block_size(args.block_size)?,
num_workers,
ReplayMode::Trace,
router_mode,
)?;
Ok(Self {
runtime: RuntimeKind::Agg(runtime),
started_at: Instant::now(),
})
}
/// Create a handle for a disaggregated trace-file replay.
pub fn from_trace_file_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path,
trace_block_size: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<Self> {
let config = config.normalized()?;
let trace = Trace::from_mooncake(trace_path, trace_block_size)?
.normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?;
let runtime = DisaggRuntime::new_workload(
&config,
router_config,
prefill_load_estimator,
trace.into_trace_driver_with_block_size(config.decode_args.block_size)?,
ReplayMode::Trace,
router_mode,
)?;
Ok(Self {
runtime: RuntimeKind::Disagg(runtime),
started_at: Instant::now(),
})
}
/// Advance the simulation up to `until_ms`, collect metrics, return tick data.
pub fn advance_to(&mut self, until_ms: f64) -> Result<PlannerTickData> {
match &mut self.runtime {
RuntimeKind::Agg(rt) => {
let is_done = rt.advance_to(until_ms)?;
let fpm = rt.drain_fpm();
let traffic = rt.drain_traffic();
Ok(PlannerTickData {
now_ms: rt.now_ms(),
is_done,
prefill_fpm_snapshots: Vec::new(),
decode_fpm_snapshots: fpm,
traffic,
active_prefill_count: 0,
active_decode_count: rt.active_worker_count(),
total_prefill_count: 0,
total_decode_count: rt.total_worker_count(),
})
}
RuntimeKind::Disagg(rt) => {
let is_done = rt.advance_to(until_ms)?;
let prefill_fpm = rt.drain_prefill_fpm();
let decode_fpm = rt.drain_decode_fpm();
let traffic = rt.drain_traffic();
Ok(PlannerTickData {
now_ms: rt.now_ms(),
is_done,
prefill_fpm_snapshots: prefill_fpm,
decode_fpm_snapshots: decode_fpm,
traffic,
active_prefill_count: rt.active_prefill_count(),
active_decode_count: rt.active_decode_count(),
total_prefill_count: rt.total_prefill_count(),
total_decode_count: rt.total_decode_count(),
})
}
}
}
/// Apply a scaling decision with separate prefill and decode targets.
/// For agg mode, `target_prefill` is ignored.
pub fn apply_scaling(&mut self, target_prefill: usize, target_decode: usize) -> Result<()> {
match &mut self.runtime {
RuntimeKind::Agg(rt) => rt.apply_scaling(target_decode),
RuntimeKind::Disagg(rt) => rt.apply_scaling(target_prefill, target_decode),
}
}
/// Finalize the replay and return the report.
pub fn finalize(self) -> TraceSimulationReport {
let report = match self.runtime {
RuntimeKind::Agg(rt) => rt.finalize_report(),
RuntimeKind::Disagg(rt) => rt.finalize_report(),
};
report.with_wall_time_ms(self.started_at.elapsed().as_secs_f64() * 1000.0)
}
}
......@@ -36,6 +36,7 @@ pub(super) struct SglangConfig {
pub(super) decode_speedup_ratio: f64,
pub(super) worker_type: WorkerType,
pub(super) block_size: usize,
pub(super) total_kv_tokens: usize,
pub(super) kv_bytes_per_token: Option<usize>,
pub(super) kv_transfer_bandwidth: Option<f64>,
}
......@@ -83,6 +84,7 @@ impl SglangConfig {
decode_speedup_ratio: args.decode_speedup_ratio,
worker_type: args.worker_type,
block_size: args.block_size,
total_kv_tokens: args.num_gpu_blocks * args.block_size,
kv_bytes_per_token: args.kv_bytes_per_token,
kv_transfer_bandwidth: args.kv_transfer_bandwidth,
}
......
......@@ -140,10 +140,13 @@ pub(super) fn simulate_decode_step(
.map(SglangRequest::current_sequence_len)
.sum();
let avg_context = total_context / running.len();
let decode_time =
config
.perf_model
.predict_decode_time(running.len(), total_context, avg_context);
let active_kv_tokens = total_context.min(config.total_kv_tokens);
let decode_time = config.perf_model.predict_decode_time(
running.len(),
active_kv_tokens,
avg_context,
config.total_kv_tokens,
);
let unscaled_time = Duration::from_secs_f64(decode_time / 1000.0);
let effective_ratio = config.speedup_ratio * config.decode_speedup_ratio;
let total_time = if apply_speedup && effective_ratio > 0.0 && unscaled_time > Duration::ZERO {
......
......@@ -657,19 +657,28 @@ impl VllmCore {
return (Duration::ZERO, Vec::new());
}
// For prefill workers, the first decode token is produced as part of
// the prefill forward pass — no separate decode iteration needed.
let (decode_time, decode_end_ms) = if self.args.worker_type == WorkerType::Prefill {
(Duration::ZERO, decode_start_ms)
} else {
let active_kv_tokens = self.kv_manager.num_active_blocks() * self.args.block_size;
let total_kv_tokens = self.args.num_gpu_blocks * self.args.block_size;
let total_length = ready
.iter()
.filter_map(|uuid| self.state.requests.get(uuid))
.map(|request| request.sequence.len())
.sum::<usize>();
let context_length = total_length / ready.len();
let decode_ms =
self.args
.perf_model
.predict_decode_time(ready.len(), active_kv_tokens, context_length);
let decode_time = scale_decode_time(decode_ms, &self.args);
let decode_end_ms = decode_start_ms + decode_time.as_secs_f64() * 1000.0;
let decode_ms = self.args.perf_model.predict_decode_time(
ready.len(),
active_kv_tokens,
context_length,
total_kv_tokens,
);
let dt = scale_decode_time(decode_ms, &self.args);
(dt, decode_start_ms + dt.as_secs_f64() * 1000.0)
};
let mut output_signals = Vec::with_capacity(ready.len());
for uuid in ready {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment