Unverified Commit 89d67172 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat(replay): planner-in-the-loop offline replay (#8187)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent ee9d67e4
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Adapter that drives PlannerStateMachine via the PlannerReplayBridge.
The bridge (Rust, PyO3) runs the offline simulation step-by-step.
This adapter sits between the bridge and the planner state machine:
Bridge.advance_to(tick_ms) → raw metrics dict
Adapter._build_tick_input() → TickInput
StateMachine.on_tick() → PlannerEffects
Adapter → Bridge.apply_scaling(prefill, decode)
Supports both aggregated and disaggregated topologies. No I/O, no runtime
dependencies. Fully deterministic when used with offline replay.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any, Optional
from dynamo.common.forward_pass_metrics import (
ForwardPassMetrics,
QueuedRequestMetrics,
ScheduledRequestMetrics,
)
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.state_machine import PlannerStateMachine
from dynamo.planner.core.types import (
FpmObservations,
PlannerEffects,
ScheduledTick,
TickDiagnostics,
TickInput,
TrafficObservation,
WorkerCapabilities,
WorkerCounts,
)
logger = logging.getLogger(__name__)
@dataclass
class ScalingEvent:
"""Record of a single scaling decision."""
at_s: float
component: str # "agg", "prefill", or "decode"
from_count: int
to_count: int
reason: Optional[str] = None
@dataclass
class ReplayPlannerReport:
"""Enriched report combining trace metrics and planner diagnostics."""
trace_report: dict[str, Any]
scaling_events: list[ScalingEvent] = field(default_factory=list)
diagnostics_log: list[TickDiagnostics] = field(default_factory=list)
total_ticks: int = 0
def _build_fpm_from_dict(d: dict[str, Any]) -> ForwardPassMetrics:
"""Convert a bridge FPM snapshot dict into a ForwardPassMetrics struct."""
return ForwardPassMetrics(
worker_id=str(d["worker_id"]),
dp_rank=0,
wall_time=d["wall_time"],
scheduled_requests=ScheduledRequestMetrics(
num_prefill_requests=d["num_prefill_requests"],
sum_prefill_tokens=d["sum_prefill_tokens"],
var_prefill_length=d["var_prefill_length"],
sum_prefill_kv_tokens=d["sum_prefill_kv_tokens"],
num_decode_requests=d["num_decode_requests"],
sum_decode_kv_tokens=d["sum_decode_kv_tokens"],
var_decode_kv_tokens=d["var_decode_kv_tokens"],
),
queued_requests=QueuedRequestMetrics(
num_prefill_requests=d["num_queued_prefill"],
sum_prefill_tokens=d["sum_queued_prefill_tokens"],
var_prefill_length=d["var_queued_prefill_length"],
num_decode_requests=d["num_queued_decode"],
sum_decode_kv_tokens=d["sum_queued_decode_kv_tokens"],
var_decode_kv_tokens=d["var_queued_decode_kv_tokens"],
),
)
def _update_fpm_cache(
cache: dict[tuple[str, int], ForwardPassMetrics],
snapshots: list[dict[str, Any]],
active_count: int,
) -> None:
"""Update a last-seen FPM cache with new snapshots and prune removed workers."""
for snap in snapshots:
fpm = _build_fpm_from_dict(snap)
cache[(fpm.worker_id, fpm.dp_rank)] = fpm
# Prune cache down to active_count entries. Workers are removed
# highest-ID-first during scale-down, so keep the lowest IDs.
while len(cache) > active_count:
# Remove the highest worker ID entry
worst_key = max(cache.keys(), key=lambda k: int(k[0]))
del cache[worst_key]
class ReplayPlannerAdapter:
"""Drives the planner state machine using the PlannerReplayBridge.
Supports both ``mode="agg"`` and ``mode="disagg"``.
"""
def __init__(
self,
planner_config: PlannerConfig,
bridge: Any, # PlannerReplayBridge (Rust pyclass)
capabilities: Optional[WorkerCapabilities] = None,
warmup_observations: Optional[list[TrafficObservation]] = None,
) -> None:
self._config = planner_config
self._bridge = bridge
self._sm = PlannerStateMachine(planner_config, capabilities)
self._is_disagg = planner_config.mode == "disagg"
# Last-seen FPM caches (separate for prefill/decode)
self._prefill_fpm_cache: dict[tuple[str, int], ForwardPassMetrics] = {}
self._decode_fpm_cache: dict[tuple[str, int], ForwardPassMetrics] = {}
# Scaling targets — used as `expected` in WorkerCounts
self._scaling_target_prefill: Optional[int] = None
self._scaling_target_decode: Optional[int] = None
if warmup_observations:
self._sm.warm_load_predictors(warmup_observations)
def run(self) -> ReplayPlannerReport:
"""Run the full replay with planner-in-the-loop."""
next_tick = self._sm.initial_tick(0.0)
scaling_events: list[ScalingEvent] = []
diagnostics_log: list[TickDiagnostics] = []
total_ticks = 0
while True:
tick_ms = next_tick.at_s * 1000.0
result = self._bridge.advance_to(tick_ms)
if result["is_done"]:
break
tick_input = self._build_tick_input(next_tick, result)
effects: PlannerEffects = self._sm.on_tick(next_tick, tick_input)
diagnostics_log.append(effects.diagnostics)
total_ticks += 1
# Clear scaling targets once active counts match
active_p = result["active_prefill_count"]
active_d = result["active_decode_count"]
if (
self._scaling_target_prefill is not None
and active_p == self._scaling_target_prefill
):
self._scaling_target_prefill = None
if (
self._scaling_target_decode is not None
and active_d == self._scaling_target_decode
):
self._scaling_target_decode = None
if effects.scale_to is not None:
self._apply_scaling(effects, result, tick_input.now_s, scaling_events)
if effects.next_tick is None:
break
next_tick = effects.next_tick
trace_report = self._bridge.finalize()
return ReplayPlannerReport(
trace_report=trace_report,
scaling_events=scaling_events,
diagnostics_log=diagnostics_log,
total_ticks=total_ticks,
)
def _apply_scaling(
self,
effects: PlannerEffects,
result: dict[str, Any],
now_s: float,
scaling_events: list[ScalingEvent],
) -> None:
"""Apply scaling decisions and record events."""
scale = effects.scale_to
assert scale is not None
current_p = result["active_prefill_count"]
current_d = result["active_decode_count"]
target_p = scale.num_prefill if scale.num_prefill is not None else current_p
target_d = scale.num_decode if scale.num_decode is not None else current_d
if target_p == current_p and target_d == current_d:
return
self._bridge.apply_scaling(target_p, target_d)
if self._is_disagg:
if scale.num_prefill is not None and target_p != current_p:
direction = "scale_up" if target_p > current_p else "scale_down"
logger.info(
"Planner scaling prefill: %d -> %d at t=%.1fs (%s)",
current_p,
target_p,
now_s,
direction,
)
self._scaling_target_prefill = target_p
scaling_events.append(
ScalingEvent(
at_s=now_s,
component="prefill",
from_count=current_p,
to_count=target_p,
reason=direction,
)
)
if scale.num_decode is not None and target_d != current_d:
direction = "scale_up" if target_d > current_d else "scale_down"
logger.info(
"Planner scaling decode: %d -> %d at t=%.1fs (%s)",
current_d,
target_d,
now_s,
direction,
)
self._scaling_target_decode = target_d
scaling_events.append(
ScalingEvent(
at_s=now_s,
component="decode",
from_count=current_d,
to_count=target_d,
reason=direction,
)
)
else:
direction = "scale_up" if target_d > current_d else "scale_down"
logger.info(
"Planner scaling: %d -> %d workers at t=%.1fs (%s)",
current_d,
target_d,
now_s,
direction,
)
self._scaling_target_decode = target_d
scaling_events.append(
ScalingEvent(
at_s=now_s,
component="agg",
from_count=current_d,
to_count=target_d,
reason=direction,
)
)
def _feed_extra_fpm_to_regression(
self,
decode_snaps: list[dict[str, Any]],
prefill_snaps: list[dict[str, Any]],
) -> None:
"""Feed accumulated FPM snapshots to regression, excluding the last
per worker (which will be added by _observe_fpm via fpm_observations).
This avoids double-counting the cached snapshot."""
if not hasattr(self._sm, "_is_easy") or self._sm._is_easy:
return # easy mode has no regression models
if self._sm._is_agg:
# Exclude the last snapshot per worker (it's in the cache and
# will be added by _observe_fpm)
last_idx_per_worker: dict[int, int] = {}
for i, snap in enumerate(decode_snaps):
last_idx_per_worker[snap["worker_id"]] = i
exclude = set(last_idx_per_worker.values())
for i, snap in enumerate(decode_snaps):
if i in exclude:
continue
fpm = _build_fpm_from_dict(snap)
if fpm.wall_time > 0.0:
self._sm._agg_regression.add_observation(fpm)
else:
if self._sm._has_prefill:
last_idx: dict[int, int] = {}
for i, snap in enumerate(prefill_snaps):
last_idx[snap["worker_id"]] = i
exclude = set(last_idx.values())
for i, snap in enumerate(prefill_snaps):
if i in exclude:
continue
fpm = _build_fpm_from_dict(snap)
if fpm.wall_time > 0.0:
self._sm._prefill_regression.add_observation(fpm)
if self._sm._has_decode:
last_idx = {}
for i, snap in enumerate(decode_snaps):
last_idx[snap["worker_id"]] = i
exclude = set(last_idx.values())
for i, snap in enumerate(decode_snaps):
if i in exclude:
continue
fpm = _build_fpm_from_dict(snap)
if fpm.wall_time > 0.0:
self._sm._decode_regression.add_observation(fpm)
def _build_tick_input(
self, tick: ScheduledTick, result: dict[str, Any]
) -> TickInput:
"""Convert bridge result dict to planner TickInput."""
now_s = result["now_ms"] / 1000.0
worker_counts = None
if tick.need_worker_states:
active_p = result["active_prefill_count"]
active_d = result["active_decode_count"]
expected_p = (
self._scaling_target_prefill
if self._scaling_target_prefill is not None
else active_p
)
expected_d = (
self._scaling_target_decode
if self._scaling_target_decode is not None
else active_d
)
worker_counts = WorkerCounts(
ready_num_prefill=active_p if self._is_disagg else None,
ready_num_decode=active_d,
expected_num_prefill=expected_p if self._is_disagg else None,
expected_num_decode=expected_d,
)
fpm_observations = None
if tick.need_worker_fpm:
prefill_snaps = result.get("prefill_fpm_snapshots", [])
decode_snaps = result.get("decode_fpm_snapshots", [])
_update_fpm_cache(
self._prefill_fpm_cache, prefill_snaps, result["active_prefill_count"]
)
_update_fpm_cache(
self._decode_fpm_cache, decode_snaps, result["active_decode_count"]
)
# In offline replay, we accumulate many FPM snapshots per tick
# (one per engine pass). Feed ALL non-idle snapshots directly to
# the regression models for a representative fit. The last-per-worker
# cache is only used for the FpmObservations dict (worker count
# reconciliation), not as the sole regression input.
prefill_dict = (
dict(self._prefill_fpm_cache) if self._prefill_fpm_cache else None
)
decode_dict = (
dict(self._decode_fpm_cache) if self._decode_fpm_cache else None
)
self._feed_extra_fpm_to_regression(decode_snaps, prefill_snaps)
fpm_observations = FpmObservations(
prefill=prefill_dict,
decode=decode_dict,
)
traffic = None
if tick.need_traffic_metrics:
t = result.get("traffic", {})
duration_s = t.get("duration_s", 0.0)
if duration_s > 0:
traffic = TrafficObservation(
duration_s=duration_s,
num_req=float(t.get("num_req", 0)),
isl=t.get("avg_isl", 0.0),
osl=t.get("avg_osl", 0.0),
)
return TickInput(
now_s=now_s,
traffic=traffic,
worker_counts=worker_counts,
fpm_observations=fpm_observations,
)
...@@ -177,6 +177,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -177,6 +177,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::replay::ReasoningConfig>()?; m.add_class::<llm::replay::ReasoningConfig>()?;
m.add_class::<llm::replay::SglangArgs>()?; m.add_class::<llm::replay::SglangArgs>()?;
m.add_class::<llm::replay::MockEngineArgs>()?; m.add_class::<llm::replay::MockEngineArgs>()?;
m.add_class::<llm::replay::PlannerReplayBridge>()?;
m.add_class::<llm::kv::WorkerMetricsPublisher>()?; m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?; // Internal: only in _internal, not public API m.add_class::<llm::model_card::ModelDeploymentCard>()?; // Internal: only in _internal, not public API
m.add_class::<llm::local_model::ModelRuntimeConfig>()?; m.add_class::<llm::local_model::ModelRuntimeConfig>()?;
......
...@@ -1196,3 +1196,179 @@ fn synthetic_token_id(request_idx: usize, token_idx: usize) -> u32 { ...@@ -1196,3 +1196,179 @@ fn synthetic_token_id(request_idx: usize, token_idx: usize) -> u32 {
let token = value as u32; let token = value as u32;
if token == 0 { 1 } else { token } if token == 0 { 1 } else { token }
} }
// ---------------------------------------------------------------------------
// Planner-in-the-loop replay bridge
// ---------------------------------------------------------------------------
fn fpm_snapshots_to_json(
snapshots: Vec<(usize, dynamo_mocker::common::protocols::ForwardPassSnapshot)>,
) -> Vec<serde_json::Value> {
snapshots
.into_iter()
.map(|(worker_id, fpm)| {
json!({
"worker_id": worker_id,
"wall_time": fpm.wall_time_secs,
"num_prefill_requests": fpm.num_prefill_requests,
"sum_prefill_tokens": fpm.sum_prefill_tokens,
"var_prefill_length": fpm.var_prefill_length,
"sum_prefill_kv_tokens": fpm.sum_prefill_kv_tokens,
"num_decode_requests": fpm.num_decode_requests,
"sum_decode_kv_tokens": fpm.sum_decode_kv_tokens,
"var_decode_kv_tokens": fpm.var_decode_kv_tokens,
"num_queued_prefill": fpm.num_queued_prefill,
"sum_queued_prefill_tokens": fpm.sum_queued_prefill_tokens,
"var_queued_prefill_length": fpm.var_queued_prefill_length,
"num_queued_decode": fpm.num_queued_decode,
"sum_queued_decode_kv_tokens": fpm.sum_queued_decode_kv_tokens,
"var_queued_decode_kv_tokens": fpm.var_queued_decode_kv_tokens,
})
})
.collect()
}
/// Step-based bridge for driving an offline replay with a Python planner.
///
/// Supports both aggregated and disaggregated topologies. The Python adapter
/// calls `advance_to()` to run the simulation forward, collects FPM/traffic
/// metrics, feeds them to the planner state machine, then calls
/// `apply_scaling()` to resize worker pools.
#[pyclass(unsendable)]
pub struct PlannerReplayBridge {
handle: Option<dynamo_mocker::replay::PlannerReplayHandle>,
}
#[pymethods]
impl PlannerReplayBridge {
/// Create a bridge for an aggregated Mooncake-style JSONL trace replay.
#[new]
#[pyo3(signature = (trace_file, extra_engine_args, num_workers, router_mode="round_robin", router_config=None, arrival_speedup_ratio=1.0, trace_block_size=512))]
fn new(
trace_file: PathBuf,
extra_engine_args: &MockEngineArgs,
num_workers: usize,
router_mode: &str,
router_config: Option<KvRouterConfig>,
arrival_speedup_ratio: f64,
trace_block_size: usize,
) -> PyResult<Self> {
let args = extra_engine_args.inner();
let router_mode = parse_replay_router_mode(router_mode)?;
let router_config = load_replay_router_config(router_config);
let handle = dynamo_mocker::replay::PlannerReplayHandle::from_trace_file(
args,
router_config,
None,
&trace_file,
trace_block_size,
num_workers,
arrival_speedup_ratio,
router_mode,
)
.map_err(to_pyerr)?;
Ok(Self {
handle: Some(handle),
})
}
/// Create a bridge for a disaggregated Mooncake-style JSONL trace replay.
#[staticmethod]
#[pyo3(signature = (trace_file, prefill_engine_args, decode_engine_args, num_prefill_workers, num_decode_workers, router_mode="round_robin", router_config=None, arrival_speedup_ratio=1.0, trace_block_size=512))]
#[allow(clippy::too_many_arguments)]
fn create_disagg(
trace_file: PathBuf,
prefill_engine_args: &MockEngineArgs,
decode_engine_args: &MockEngineArgs,
num_prefill_workers: usize,
num_decode_workers: usize,
router_mode: &str,
router_config: Option<KvRouterConfig>,
arrival_speedup_ratio: f64,
trace_block_size: usize,
) -> PyResult<Self> {
let config = dynamo_mocker::replay::OfflineDisaggReplayConfig {
prefill_args: prefill_engine_args.inner(),
decode_args: decode_engine_args.inner(),
num_prefill_workers,
num_decode_workers,
};
let router_mode = parse_replay_router_mode(router_mode)?;
let router_config = load_replay_router_config(router_config);
let handle = dynamo_mocker::replay::PlannerReplayHandle::from_trace_file_disagg(
config,
router_config,
None,
&trace_file,
trace_block_size,
arrival_speedup_ratio,
router_mode,
)
.map_err(to_pyerr)?;
Ok(Self {
handle: Some(handle),
})
}
/// Advance the simulation to `until_ms` simulated time.
///
/// Returns a dict with separate prefill/decode worker counts and FPM snapshots.
fn advance_to(&mut self, py: Python<'_>, until_ms: f64) -> PyResult<PyObject> {
let handle = self
.handle
.as_mut()
.ok_or_else(|| PyException::new_err("bridge has been finalized"))?;
let tick_data = handle.advance_to(until_ms).map_err(to_pyerr)?;
let (duration_s, num_req, avg_isl, avg_osl) = tick_data.traffic;
let result = json!({
"now_ms": tick_data.now_ms,
"is_done": tick_data.is_done,
"prefill_fpm_snapshots": fpm_snapshots_to_json(tick_data.prefill_fpm_snapshots),
"decode_fpm_snapshots": fpm_snapshots_to_json(tick_data.decode_fpm_snapshots),
"traffic": {
"duration_s": duration_s,
"num_req": num_req,
"avg_isl": avg_isl,
"avg_osl": avg_osl,
},
"active_prefill_count": tick_data.active_prefill_count,
"active_decode_count": tick_data.active_decode_count,
"total_prefill_count": tick_data.total_prefill_count,
"total_decode_count": tick_data.total_decode_count,
});
pythonize(py, &result)
.map_err(to_pyerr)
.map(|obj| obj.unbind())
}
/// Apply a scaling decision with separate prefill and decode targets.
/// For agg mode, `target_prefill` is ignored (pass 0).
fn apply_scaling(&mut self, target_prefill: usize, target_decode: usize) -> PyResult<()> {
let handle = self
.handle
.as_mut()
.ok_or_else(|| PyException::new_err("bridge has been finalized"))?;
handle
.apply_scaling(target_prefill, target_decode)
.map_err(to_pyerr)
}
/// Finalize the replay and return the trace simulation report.
fn finalize(&mut self, py: Python<'_>) -> PyResult<PyObject> {
let handle = self
.handle
.take()
.ok_or_else(|| PyException::new_err("bridge has already been finalized"))?;
let report = handle.finalize();
pythonize(py, &report)
.map_err(to_pyerr)
.map(|obj| obj.unbind())
}
}
...@@ -1570,6 +1570,37 @@ def run_mocker_synthetic_trace_replay( ...@@ -1570,6 +1570,37 @@ def run_mocker_synthetic_trace_replay(
"""Replay a synthetic mocker workload without requiring a trace file.""" """Replay a synthetic mocker workload without requiring a trace file."""
... ...
class PlannerReplayBridge:
"""Step-based bridge for driving an offline replay with a Python planner."""
def __init__(
self,
trace_file: str | os.PathLike[str],
extra_engine_args: MockEngineArgs,
num_workers: int,
router_mode: str = "round_robin",
router_config: Optional[KvRouterConfig] = None,
arrival_speedup_ratio: float = 1.0,
trace_block_size: int = 512,
) -> None: ...
@staticmethod
def create_disagg(
trace_file: str | os.PathLike[str],
prefill_engine_args: MockEngineArgs,
decode_engine_args: MockEngineArgs,
num_prefill_workers: int,
num_decode_workers: int,
router_mode: str = "round_robin",
router_config: Optional[KvRouterConfig] = None,
arrival_speedup_ratio: float = 1.0,
trace_block_size: int = 512,
) -> "PlannerReplayBridge": ...
def advance_to(self, until_ms: float) -> Dict[str, Any]: ...
def apply_scaling(self, target_prefill: int, target_decode: int) -> None: ...
def finalize(self) -> Dict[str, Any]: ...
class Layer: class Layer:
""" """
A KV cache block layer A KV cache block layer
......
...@@ -25,6 +25,7 @@ from dynamo._core import ModelInput as ModelInput ...@@ -25,6 +25,7 @@ from dynamo._core import ModelInput as ModelInput
from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig
from dynamo._core import ModelType as ModelType from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores from dynamo._core import OverlapScores as OverlapScores
from dynamo._core import PlannerReplayBridge as PlannerReplayBridge
from dynamo._core import PythonAsyncEngine as PythonAsyncEngine from dynamo._core import PythonAsyncEngine as PythonAsyncEngine
from dynamo._core import RadixTree as RadixTree from dynamo._core import RadixTree as RadixTree
from dynamo._core import ReasoningConfig as ReasoningConfig from dynamo._core import ReasoningConfig as ReasoningConfig
......
...@@ -11,7 +11,10 @@ import sys ...@@ -11,7 +11,10 @@ import sys
from collections.abc import Sequence from collections.abc import Sequence
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
from typing import Protocol from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from dynamo.planner.core.types import EngineCapabilities
os.environ.setdefault("DYNAMO_SKIP_PYTHON_LOG_INIT", "1") os.environ.setdefault("DYNAMO_SKIP_PYTHON_LOG_INIT", "1")
...@@ -106,6 +109,100 @@ def _load_aic_perf_config(args: argparse.Namespace): ...@@ -106,6 +109,100 @@ def _load_aic_perf_config(args: argparse.Namespace):
) )
def _engine_caps(args: MockEngineArgs) -> EngineCapabilities:
"""Derive EngineCapabilities from MockEngineArgs."""
from dynamo.planner.core.types import EngineCapabilities
max_kv_tokens = args.num_gpu_blocks * args.block_size
return EngineCapabilities(
num_gpu=1,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_seqs=args.max_num_seqs,
context_length=max_kv_tokens if max_kv_tokens > 0 else None,
max_kv_tokens=max_kv_tokens if max_kv_tokens > 0 else None,
)
def _run_planner_replay(
trace_file: str,
extra_engine_args: MockEngineArgs | None,
prefill_engine_args: MockEngineArgs | None,
decode_engine_args: MockEngineArgs | None,
router_config: KvRouterConfig | None,
num_workers: int,
num_prefill_workers: int,
num_decode_workers: int,
router_mode: str,
arrival_speedup_ratio: float,
trace_block_size: int,
planner_config_arg: str,
):
"""Run an offline replay with planner-in-the-loop (agg or disagg).
# TODO(jthomson04): SLA-based scaling (optimization_target="sla") with
# disagg mode requires planner_profile_data (NPZ) or AIC-backed engine
# args. The default polynomial perf model does not account for batch
# size in its decode timing, causing the DecodeRegressionModel's
# num_decode_requests coefficient to go negative and reject the fit.
# Fix the polynomial model to incorporate batch_size, or gate disagg
# SLA mode on having a non-polynomial perf model.
"""
from dynamo.llm import PlannerReplayBridge
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.types import WorkerCapabilities
from dynamo.planner.offline.replay_adapter import ReplayPlannerAdapter
planner_config = PlannerConfig.from_config_arg(planner_config_arg)
planner_config.no_operation = True
if planner_config.mode == "agg":
if extra_engine_args is None:
extra_engine_args = MockEngineArgs()
bridge = PlannerReplayBridge(
trace_file=trace_file,
extra_engine_args=extra_engine_args,
num_workers=num_workers,
router_mode=router_mode,
router_config=router_config,
arrival_speedup_ratio=arrival_speedup_ratio,
trace_block_size=trace_block_size,
)
capabilities = WorkerCapabilities(decode=_engine_caps(extra_engine_args))
elif planner_config.mode == "disagg":
if prefill_engine_args is None or decode_engine_args is None:
raise ValueError(
"disagg planner replay requires --prefill-engine-args and --decode-engine-args"
)
bridge = PlannerReplayBridge.create_disagg(
trace_file=trace_file,
prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args,
num_prefill_workers=num_prefill_workers,
num_decode_workers=num_decode_workers,
router_mode=router_mode,
router_config=router_config,
arrival_speedup_ratio=arrival_speedup_ratio,
trace_block_size=trace_block_size,
)
capabilities = WorkerCapabilities(
prefill=_engine_caps(prefill_engine_args),
decode=_engine_caps(decode_engine_args),
)
else:
raise ValueError(
f"planner-in-the-loop replay supports mode='agg' or 'disagg', got '{planner_config.mode}'"
)
adapter = ReplayPlannerAdapter(
planner_config=planner_config,
bridge=bridge,
capabilities=capabilities,
)
return adapter.run()
def main(argv: Sequence[str] | None = None) -> int: def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(prog="python -m dynamo.replay") parser = argparse.ArgumentParser(prog="python -m dynamo.replay")
parser.add_argument("trace_file", nargs="?") parser.add_argument("trace_file", nargs="?")
...@@ -155,6 +252,10 @@ def main(argv: Sequence[str] | None = None) -> int: ...@@ -155,6 +252,10 @@ def main(argv: Sequence[str] | None = None) -> int:
"--report-json", "--report-json",
help="path to save the full replay report JSON; defaults to a timestamped file in the current directory", help="path to save the full replay report JSON; defaults to a timestamped file in the current directory",
) )
parser.add_argument(
"--planner-config",
help="path to planner config YAML/JSON or inline JSON; enables planner-in-the-loop replay (offline agg only)",
)
args = parser.parse_args(list(sys.argv[1:] if argv is None else argv)) args = parser.parse_args(list(sys.argv[1:] if argv is None else argv))
using_trace_file = args.trace_file is not None using_trace_file = args.trace_file is not None
...@@ -190,6 +291,43 @@ def main(argv: Sequence[str] | None = None) -> int: ...@@ -190,6 +291,43 @@ def main(argv: Sequence[str] | None = None) -> int:
except ValueError as exc: except ValueError as exc:
parser.error(str(exc)) parser.error(str(exc))
# Planner-in-the-loop mode
if args.planner_config is not None:
if args.replay_mode != "offline":
parser.error("--planner-config only supports --replay-mode=offline")
if not using_trace_file:
parser.error("--planner-config requires a trace file (not synthetic)")
planner_report = _run_planner_replay(
trace_file=args.trace_file,
extra_engine_args=extra_engine_args,
prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args,
router_config=router_config,
num_workers=args.num_workers,
num_prefill_workers=args.num_prefill_workers,
num_decode_workers=args.num_decode_workers,
router_mode=args.router_mode,
arrival_speedup_ratio=args.arrival_speedup_ratio,
trace_block_size=args.trace_block_size,
planner_config_arg=args.planner_config,
)
report = planner_report.trace_report
if planner_report.scaling_events:
sys.stdout.write("\nScaling events:\n")
for event in planner_report.scaling_events:
sys.stdout.write(
f" t={event.at_s:.1f}s [{event.component}]: "
f"{event.from_count} -> {event.to_count} workers"
f" ({event.reason})\n"
)
report_path = write_report_json(report, args.report_json)
sys.stdout.write(format_report_table(report))
sys.stdout.write("\n")
sys.stdout.write(f"Saved full report to: {report_path}\n")
sys.stdout.write(f"Planner ticks: {planner_report.total_ticks}\n")
return 0
if using_trace_file: if using_trace_file:
report = run_trace_replay( report = run_trace_replay(
args.trace_file, args.trace_file,
......
...@@ -240,7 +240,7 @@ impl PerfModel { ...@@ -240,7 +240,7 @@ impl PerfModel {
/// Predict decode time in milliseconds. /// Predict decode time in milliseconds.
/// ///
/// Callers always pass all parameters; each variant uses what it needs: /// Callers always pass all parameters; each variant uses what it needs:
/// - Polynomial: uses active_kv_tokens /// - Polynomial: uses (active_kv_tokens, total_kv_tokens) as utilization
/// - Interpolated: uses (active_kv_tokens, context_length) /// - Interpolated: uses (active_kv_tokens, context_length)
/// - Aiconfigurator: uses (batch_size, context_length) /// - Aiconfigurator: uses (batch_size, context_length)
pub fn predict_decode_time( pub fn predict_decode_time(
...@@ -248,13 +248,19 @@ impl PerfModel { ...@@ -248,13 +248,19 @@ impl PerfModel {
batch_size: usize, batch_size: usize,
active_kv_tokens: usize, active_kv_tokens: usize,
context_length: usize, context_length: usize,
total_kv_tokens: usize,
) -> f64 { ) -> f64 {
if batch_size == 0 { if batch_size == 0 {
return 0.0; return 0.0;
} }
let time = match self { let time = match self {
PerfModel::Polynomial => { PerfModel::Polynomial => {
let active_perc = active_kv_tokens as f64 / 16384.0; let active_perc = if total_kv_tokens > 0 {
active_kv_tokens as f64 / total_kv_tokens as f64
} else {
tracing::warn!("Total KV tokens is 0, using 1.0 as capacity");
1.0
};
-25.74 * active_perc.powi(2) + 54.01 * active_perc + 5.74 -25.74 * active_perc.powi(2) + 54.01 * active_perc + 5.74
} }
PerfModel::Interpolated { decode_interp, .. } => decode_interp PerfModel::Interpolated { decode_interp, .. } => decode_interp
......
...@@ -6,6 +6,7 @@ mod collector; ...@@ -6,6 +6,7 @@ mod collector;
mod entrypoints; mod entrypoints;
pub(crate) mod offline; pub(crate) mod offline;
mod online; mod online;
mod planner_handle;
mod router_shared; mod router_shared;
mod validate; mod validate;
...@@ -76,6 +77,7 @@ pub use entrypoints::{ ...@@ -76,6 +77,7 @@ pub use entrypoints::{
simulate_trace_requests_with_router_mode, simulate_trace_workload, simulate_trace_requests_with_router_mode, simulate_trace_workload,
simulate_trace_workload_disagg_with_router_mode, simulate_trace_workload_with_router_mode, simulate_trace_workload_disagg_with_router_mode, simulate_trace_workload_with_router_mode,
}; };
pub use planner_handle::{PlannerReplayHandle, PlannerTickData};
pub use validate::validate_replay_args_mode; pub use validate::validate_replay_args_mode;
pub(crate) fn normalize_trace_requests( pub(crate) fn normalize_trace_requests(
......
...@@ -16,11 +16,11 @@ use super::state::OfflineWorkerSnapshot; ...@@ -16,11 +16,11 @@ use super::state::OfflineWorkerSnapshot;
use super::{ use super::{
components::{ components::{
AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter, AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter,
ScheduledWorkerCompletion, WorkerAdmission, ScheduledWorkerCompletion, TrafficAccumulator, WorkerAdmission,
}, },
state::AggRequestState, state::AggRequestState,
}; };
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal}; use crate::common::protocols::{DirectRequest, ForwardPassSnapshot, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, WorkloadDriver}; use crate::loadgen::{ReplayRequestHashes, WorkloadDriver};
use crate::replay::{ReplayPrefillLoadEstimator, ReplayRouterMode, TraceCollector}; use crate::replay::{ReplayPrefillLoadEstimator, ReplayRouterMode, TraceCollector};
use anyhow::bail; use anyhow::bail;
...@@ -59,7 +59,7 @@ struct AggRuntimeSnapshot { ...@@ -59,7 +59,7 @@ struct AggRuntimeSnapshot {
#[derive(Debug, Default, Clone, PartialEq, Eq)] #[derive(Debug, Default, Clone, PartialEq, Eq)]
pub(super) struct AggRuntimeStats; pub(super) struct AggRuntimeStats;
pub(super) struct AggRuntime { pub(in crate::replay) struct AggRuntime {
now_ms: f64, now_ms: f64,
next_worker_idx: usize, next_worker_idx: usize,
next_event_seq: u64, next_event_seq: u64,
...@@ -71,6 +71,10 @@ pub(super) struct AggRuntime { ...@@ -71,6 +71,10 @@ pub(super) struct AggRuntime {
router: Option<OfflineReplayRouter>, router: Option<OfflineReplayRouter>,
progress: ReplayProgress, progress: ReplayProgress,
stats: AggRuntimeStats, stats: AggRuntimeStats,
/// Forward pass metrics accumulated between planner ticks.
fpm_buffer: Vec<(usize, ForwardPassSnapshot)>,
/// Traffic statistics accumulated between planner ticks.
traffic: TrafficAccumulator,
#[cfg(test)] #[cfg(test)]
worker_active_requests: Vec<Vec<Uuid>>, worker_active_requests: Vec<Vec<Uuid>>,
#[cfg(test)] #[cfg(test)]
...@@ -79,7 +83,7 @@ pub(super) struct AggRuntime { ...@@ -79,7 +83,7 @@ pub(super) struct AggRuntime {
impl AggRuntime { impl AggRuntime {
/// Create an aggregated offline runtime seeded from an explicit request queue. /// Create an aggregated offline runtime seeded from an explicit request queue.
pub(super) fn new( pub(in crate::replay) fn new(
args: &MockEngineArgs, args: &MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>, prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
...@@ -99,7 +103,7 @@ impl AggRuntime { ...@@ -99,7 +103,7 @@ impl AggRuntime {
} }
/// Create an aggregated offline runtime whose admissions come from a workload driver. /// Create an aggregated offline runtime whose admissions come from a workload driver.
pub(super) fn new_workload( pub(in crate::replay) fn new_workload(
args: &MockEngineArgs, args: &MockEngineArgs,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>, prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
...@@ -139,7 +143,7 @@ impl AggRuntime { ...@@ -139,7 +143,7 @@ impl AggRuntime {
)?), )?),
}; };
let capture_kv_events = router.is_some(); let capture_kv_events = router.is_some();
let engine = EngineComponent::new( let mut engine = EngineComponent::new(
SimulationWorkerStage::Aggregated, SimulationWorkerStage::Aggregated,
EnginePassMode::Visible, EnginePassMode::Visible,
(0..num_workers) (0..num_workers)
...@@ -152,6 +156,7 @@ impl AggRuntime { ...@@ -152,6 +156,7 @@ impl AggRuntime {
}) })
.collect(), .collect(),
); );
engine.set_scaling_args(args, capture_kv_events);
Ok(Self { Ok(Self {
now_ms: 0.0, now_ms: 0.0,
...@@ -168,6 +173,8 @@ impl AggRuntime { ...@@ -168,6 +173,8 @@ impl AggRuntime {
stats: AggRuntimeStats::default(), stats: AggRuntimeStats::default(),
#[cfg(not(test))] #[cfg(not(test))]
stats: AggRuntimeStats, stats: AggRuntimeStats,
fpm_buffer: Vec::new(),
traffic: TrafficAccumulator::new(),
#[cfg(test)] #[cfg(test)]
worker_active_requests: vec![Vec::new(); num_workers], worker_active_requests: vec![Vec::new(); num_workers],
#[cfg(test)] #[cfg(test)]
...@@ -208,11 +215,13 @@ impl AggRuntime { ...@@ -208,11 +215,13 @@ impl AggRuntime {
} }
} }
/// Pick the next worker in round-robin order. /// Pick the next active worker in round-robin order.
fn next_worker(&mut self) -> usize { fn next_worker(&mut self) -> usize {
let worker_idx = self.next_worker_idx; let active = self.engine.active_worker_ids();
self.next_worker_idx = (self.next_worker_idx + 1) % self.engine.worker_count(); debug_assert!(!active.is_empty(), "no active workers for round-robin");
worker_idx let idx = self.next_worker_idx % active.len();
self.next_worker_idx = idx + 1;
active[idx]
} }
/// Record which worker accepted a request and refresh in-flight stats. /// Record which worker accepted a request and refresh in-flight stats.
...@@ -281,7 +290,10 @@ impl AggRuntime { ...@@ -281,7 +290,10 @@ impl AggRuntime {
); );
if self.router.is_none() { if self.router.is_none() {
self.requests.insert(uuid, AggRequestState::new_running()); self.requests.insert(
uuid,
AggRequestState::new_running(request.tokens.len(), request.max_output_tokens),
);
let worker_idx = self.next_worker(); let worker_idx = self.next_worker();
self.dispatch_to_worker(request, uuid, worker_idx)?; self.dispatch_to_worker(request, uuid, worker_idx)?;
return Ok(uuid); return Ok(uuid);
...@@ -346,9 +358,11 @@ impl AggRuntime { ...@@ -346,9 +358,11 @@ impl AggRuntime {
} }
self.record_router_pending(); self.record_router_pending();
} }
self.requests.remove(&signal.uuid).ok_or_else(|| { let removed_state = self.requests.remove(&signal.uuid).ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid) anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?; })?;
self.traffic
.on_request(removed_state.input_tokens, removed_state.output_tokens);
self.admission self.admission
.on_request_completed(signal.uuid, self.now_ms)?; .on_request_completed(signal.uuid, self.now_ms)?;
self.progress.inc_completed(); self.progress.inc_completed();
...@@ -465,6 +479,7 @@ impl AggRuntime { ...@@ -465,6 +479,7 @@ impl AggRuntime {
} }
fn handle_engine_effects(&mut self, effects: EngineEffects) -> anyhow::Result<()> { fn handle_engine_effects(&mut self, effects: EngineEffects) -> anyhow::Result<()> {
self.fpm_buffer.extend(effects.fpm_snapshots);
self.apply_router_events(effects.pass_start_kv_events)?; self.apply_router_events(effects.pass_start_kv_events)?;
for payload in effects.immediate_completions { for payload in effects.immediate_completions {
let payload = self.engine.on_scheduled_completion(payload)?; let payload = self.engine.on_scheduled_completion(payload)?;
...@@ -496,8 +511,92 @@ impl AggRuntime { ...@@ -496,8 +511,92 @@ impl AggRuntime {
Ok(()) Ok(())
} }
// ------------------------------------------------------------------
// Planner integration: step-based execution
// ------------------------------------------------------------------
/// Advance the simulation up to `until_ms` simulated time, then pause.
/// Returns `true` if the replay is done (no more work).
pub(in crate::replay) fn advance_to(&mut self, until_ms: f64) -> anyhow::Result<bool> {
self.drain_current_timestamp()?;
while !self.is_done() {
let Some(next_timestamp_ms) = self.next_timestamp() else {
bail!(
"offline replay reached a dead end with {} in-flight requests remaining",
self.cluster_in_flight()
);
};
if next_timestamp_ms > until_ms {
break;
}
self.now_ms = next_timestamp_ms;
self.drain_current_timestamp()?;
}
Ok(self.is_done())
}
/// Current simulated time in milliseconds.
pub(in crate::replay) fn now_ms(&self) -> f64 {
self.now_ms
}
/// Number of active (non-pending-removal) workers.
pub(in crate::replay) fn active_worker_count(&self) -> usize {
self.engine.active_worker_ids().len()
}
/// Total worker count including pending-removal.
pub(in crate::replay) fn total_worker_count(&self) -> usize {
self.engine.worker_count()
}
/// Drain accumulated FPM snapshots since the last drain.
pub(in crate::replay) fn drain_fpm(&mut self) -> Vec<(usize, ForwardPassSnapshot)> {
std::mem::take(&mut self.fpm_buffer)
}
/// Drain accumulated traffic stats since the last drain.
pub(in crate::replay) fn drain_traffic(&mut self) -> (f64, usize, f64, f64) {
self.traffic.drain(self.now_ms)
}
/// Apply a scaling decision: set the target number of workers.
/// Scale-up is immediate; scale-down removes the worker from the router
/// immediately (so no new requests land on it) and lets it drain in-flight
/// work in the engine.
pub(in crate::replay) fn apply_scaling(&mut self, target_workers: usize) -> anyhow::Result<()> {
let (added, newly_marked) = self.engine.apply_target_count(target_workers);
if let Some(router) = self.router.as_mut() {
for id in added {
router.add_worker(id)?;
}
for id in newly_marked {
router.remove_worker(id)?;
}
}
Ok(())
}
/// Finalize the replay: finish progress bar, return collector and stats.
pub(in crate::replay::offline) fn finalize(self) -> (TraceCollector, AggRuntimeStats) {
self.progress.finish();
(self.collector, self.stats)
}
/// Finalize the replay and return the simulation report directly.
pub(in crate::replay) fn finalize_report(self) -> crate::replay::TraceSimulationReport {
let (collector, _stats) = self.finalize();
collector.finish()
}
/// Run the aggregated offline replay until all arrivals and worker work are exhausted. /// Run the aggregated offline replay until all arrivals and worker work are exhausted.
pub(super) fn run(mut self) -> anyhow::Result<(TraceCollector, AggRuntimeStats)> { pub(in crate::replay::offline) fn run(
mut self,
) -> anyhow::Result<(TraceCollector, AggRuntimeStats)> {
self.drain_current_timestamp()?; self.drain_current_timestamp()?;
while !self.is_done() { while !self.is_done() {
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::{BTreeMap, BTreeSet};
use anyhow::bail; use anyhow::bail;
use super::super::events::SimulationWorkerStage; use super::super::events::SimulationWorkerStage;
...@@ -9,14 +11,23 @@ use super::super::runtime_utils::WorkerCompletionPayload; ...@@ -9,14 +11,23 @@ use super::super::runtime_utils::WorkerCompletionPayload;
use super::super::state::OfflineWorkerSnapshot; use super::super::state::OfflineWorkerSnapshot;
use super::super::state::OfflineWorkerState; use super::super::state::OfflineWorkerState;
use super::{EngineEffects, EnginePassMode, ScheduledWorkerCompletion}; use super::{EngineEffects, EnginePassMode, ScheduledWorkerCompletion};
use crate::common::protocols::DirectRequest; use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::replay::TraceCollector; use crate::replay::TraceCollector;
use crate::scheduler::RouterEventVisibility; use crate::scheduler::RouterEventVisibility;
pub(in crate::replay::offline) struct EngineComponent { pub(in crate::replay::offline) struct EngineComponent {
stage: SimulationWorkerStage, stage: SimulationWorkerStage,
pass_mode: EnginePassMode, pass_mode: EnginePassMode,
workers: Vec<OfflineWorkerState>, /// Workers keyed by stable ID (monotonic, never reused).
workers: BTreeMap<usize, OfflineWorkerState>,
/// Counter for generating the next stable worker ID.
next_id: usize,
/// Workers marked for removal — skipped by round-robin, removed when drained.
pending_removal: BTreeSet<usize>,
/// Engine args used to construct new workers during scale-up.
args: MockEngineArgs,
/// Whether new workers should capture KV events (true when a router is present).
capture_kv_events: bool,
} }
impl EngineComponent { impl EngineComponent {
...@@ -25,20 +36,115 @@ impl EngineComponent { ...@@ -25,20 +36,115 @@ impl EngineComponent {
pass_mode: EnginePassMode, pass_mode: EnginePassMode,
workers: Vec<OfflineWorkerState>, workers: Vec<OfflineWorkerState>,
) -> Self { ) -> Self {
let count = workers.len();
let map: BTreeMap<usize, OfflineWorkerState> = workers.into_iter().enumerate().collect();
Self { Self {
stage, stage,
pass_mode, pass_mode,
workers, workers: map,
next_id: count,
pending_removal: BTreeSet::new(),
args: MockEngineArgs::default(),
capture_kv_events: false,
}
}
/// Set the engine args and KV capture flag used when adding workers dynamically.
pub(in crate::replay::offline) fn set_scaling_args(
&mut self,
args: MockEngineArgs,
capture_kv_events: bool,
) {
self.args = args;
self.capture_kv_events = capture_kv_events;
}
/// Add a new worker, returning its stable ID.
pub(in crate::replay::offline) fn add_worker(&mut self) -> usize {
let id = self.next_id;
self.next_id += 1;
let worker = OfflineWorkerState::new(id, self.args.clone(), self.capture_kv_events);
self.workers.insert(id, worker);
id
}
/// Mark a worker for removal. It will be skipped by `drive_ready` and
/// removed once fully drained.
pub(in crate::replay::offline) fn mark_for_removal(&mut self, worker_id: usize) {
self.pending_removal.insert(worker_id);
}
/// Remove all marked workers that have fully drained, returning their IDs.
pub(in crate::replay::offline) fn try_remove_drained(&mut self) -> Vec<usize> {
let mut removed = Vec::new();
self.pending_removal.retain(|&id| {
if let Some(worker) = self.workers.get(&id) {
if worker.is_drained() {
removed.push(id);
return false; // remove from pending set
}
} else {
// Worker already gone
return false;
}
true // keep in pending set
});
for &id in &removed {
self.workers.remove(&id);
}
removed
}
/// Apply a target worker count: add new workers or mark excess for removal.
/// Returns `(added_ids, newly_marked_ids)` so the caller can update the
/// router immediately. Newly marked workers should be removed from the
/// router right away to prevent new requests from landing on them, even
/// though the workers themselves remain in the engine until fully drained.
pub(in crate::replay::offline) fn apply_target_count(
&mut self,
target: usize,
) -> (Vec<usize>, Vec<usize>) {
let active_ids = self.active_worker_ids();
let current = active_ids.len();
let mut added = Vec::new();
let mut newly_marked = Vec::new();
if target > current {
for _ in 0..(target - current) {
added.push(self.add_worker());
}
} else if target < current {
let excess = current - target;
for &id in active_ids.iter().rev().take(excess) {
self.mark_for_removal(id);
newly_marked.push(id);
}
} }
// Clean up any workers that have already fully drained.
self.try_remove_drained();
(added, newly_marked)
}
/// Return stable IDs of all active (non-pending-removal) workers.
pub(in crate::replay::offline) fn active_worker_ids(&self) -> Vec<usize> {
self.workers
.keys()
.filter(|id| !self.pending_removal.contains(id))
.copied()
.collect()
} }
pub(in crate::replay::offline) fn dispatch( pub(in crate::replay::offline) fn dispatch(
&mut self, &mut self,
worker_idx: usize, worker_id: usize,
request: DirectRequest, request: DirectRequest,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
self.validate_worker_idx(worker_idx)?; let worker = self
self.workers[worker_idx].receive_request(request); .workers
.get_mut(&worker_id)
.ok_or_else(|| anyhow::anyhow!("offline replay selected unknown worker {worker_id}"))?;
worker.receive_request(request);
Ok(()) Ok(())
} }
...@@ -47,8 +153,11 @@ impl EngineComponent { ...@@ -47,8 +153,11 @@ impl EngineComponent {
now_ms: f64, now_ms: f64,
mut collector: Option<&mut TraceCollector>, mut collector: Option<&mut TraceCollector>,
) -> anyhow::Result<EngineEffects> { ) -> anyhow::Result<EngineEffects> {
for worker_idx in 0..self.workers.len() { // Collect worker IDs first to avoid borrow issues.
if !self.workers[worker_idx].is_ready() { let worker_ids: Vec<usize> = self.workers.keys().copied().collect();
for worker_id in worker_ids {
let worker = self.workers.get(&worker_id).unwrap();
if !worker.is_ready() {
continue; continue;
} }
...@@ -57,15 +166,25 @@ impl EngineComponent { ...@@ -57,15 +166,25 @@ impl EngineComponent {
let Some(collector) = collector.as_deref_mut() else { let Some(collector) = collector.as_deref_mut() else {
bail!("offline replay visible engine pass requires a collector"); bail!("offline replay visible engine pass requires a collector");
}; };
self.workers[worker_idx].execute_pass(collector, now_ms) self.workers
.get_mut(&worker_id)
.unwrap()
.execute_pass(collector, now_ms)
} }
EnginePassMode::Hidden => self.workers[worker_idx].execute_hidden_pass(now_ms), EnginePassMode::Hidden => self
.workers
.get_mut(&worker_id)
.unwrap()
.execute_hidden_pass(now_ms),
}; };
let mut effects = EngineEffects { let mut effects = EngineEffects {
admissions: executed.admissions, admissions: executed.admissions,
..EngineEffects::default() ..EngineEffects::default()
}; };
if let Some(fpm) = executed.fpm {
effects.fpm_snapshots.push((worker_id, fpm));
}
let completion_kv_events = let completion_kv_events =
if executed.router_event_visibility == RouterEventVisibility::PassStart { if executed.router_event_visibility == RouterEventVisibility::PassStart {
effects.pass_start_kv_events = executed.kv_events; effects.pass_start_kv_events = executed.kv_events;
...@@ -75,7 +194,7 @@ impl EngineComponent { ...@@ -75,7 +194,7 @@ impl EngineComponent {
}; };
let payload = WorkerCompletionPayload { let payload = WorkerCompletionPayload {
stage: self.stage, stage: self.stage,
worker_idx, worker_idx: worker_id,
completed_requests: executed.completed_requests, completed_requests: executed.completed_requests,
output_signals: executed.output_signals, output_signals: executed.output_signals,
kv_events: completion_kv_events, kv_events: completion_kv_events,
...@@ -86,7 +205,7 @@ impl EngineComponent { ...@@ -86,7 +205,7 @@ impl EngineComponent {
return Ok(effects); return Ok(effects);
} }
self.workers[worker_idx].mark_busy(); self.workers.get_mut(&worker_id).unwrap().mark_busy();
effects effects
.scheduled_completions .scheduled_completions
.push(ScheduledWorkerCompletion { .push(ScheduledWorkerCompletion {
...@@ -110,35 +229,42 @@ impl EngineComponent { ...@@ -110,35 +229,42 @@ impl EngineComponent {
payload.stage payload.stage
); );
} }
self.validate_worker_idx(payload.worker_idx)?; let worker = self.workers.get_mut(&payload.worker_idx).ok_or_else(|| {
self.workers[payload.worker_idx].mark_idle(); anyhow::anyhow!(
self.workers[payload.worker_idx].mark_completed(payload.completed_requests); "offline replay completion for unknown worker {}",
payload.worker_idx
)
})?;
worker.mark_idle();
worker.mark_completed(payload.completed_requests);
// Eagerly clean up drained workers that are pending removal so they
// don't linger indefinitely when no further scaling events trigger
// apply_target_count.
if self.pending_removal.contains(&payload.worker_idx) {
self.try_remove_drained();
}
Ok(payload) Ok(payload)
} }
pub(in crate::replay::offline) fn in_flight(&self) -> usize { pub(in crate::replay::offline) fn in_flight(&self) -> usize {
self.workers.iter().map(OfflineWorkerState::in_flight).sum() self.workers
.values()
.map(OfflineWorkerState::in_flight)
.sum()
} }
pub(in crate::replay::offline) fn is_drained(&self) -> bool { pub(in crate::replay::offline) fn is_drained(&self) -> bool {
self.workers.iter().all(OfflineWorkerState::is_drained) self.workers.values().all(OfflineWorkerState::is_drained)
} }
pub(in crate::replay::offline) fn worker_count(&self) -> usize { pub(in crate::replay::offline) fn worker_count(&self) -> usize {
self.workers.len() self.workers.len()
} }
fn validate_worker_idx(&self, worker_idx: usize) -> anyhow::Result<()> {
if worker_idx >= self.workers.len() {
bail!("offline replay selected unknown worker index {worker_idx}");
}
Ok(())
}
#[cfg(test)] #[cfg(test)]
pub(crate) fn debug_snapshots(&self) -> Vec<OfflineWorkerSnapshot> { pub(crate) fn debug_snapshots(&self) -> Vec<OfflineWorkerSnapshot> {
self.workers self.workers
.iter() .values()
.map(OfflineWorkerState::debug_snapshot) .map(OfflineWorkerState::debug_snapshot)
.collect() .collect()
} }
......
...@@ -11,7 +11,8 @@ pub(in crate::replay::offline) use engine::EngineComponent; ...@@ -11,7 +11,8 @@ pub(in crate::replay::offline) use engine::EngineComponent;
pub(crate) use router::OfflineReplayRouter; pub(crate) use router::OfflineReplayRouter;
#[cfg(test)] #[cfg(test)]
pub(crate) use router::OfflineRouterSnapshot; pub(crate) use router::OfflineRouterSnapshot;
pub(in crate::replay) use types::ReplayMode;
pub(in crate::replay::offline) use types::{ pub(in crate::replay::offline) use types::{
EngineEffects, EnginePassMode, ReadyArrival, ReplayMode, ScheduledWorkerCompletion, EngineEffects, EnginePassMode, ReadyArrival, ScheduledWorkerCompletion, TrafficAccumulator,
}; };
pub(crate) use types::{RouterEffects, WorkerAdmission}; pub(crate) use types::{RouterEffects, WorkerAdmission};
...@@ -309,6 +309,48 @@ impl OfflineReplayRouter { ...@@ -309,6 +309,48 @@ impl OfflineReplayRouter {
self.pending.len() self.pending.len()
} }
/// Register a new worker with the router, cloning the config from existing workers.
pub(crate) fn add_worker(&mut self, worker_id: usize) -> Result<()> {
let config = self
.workers_with_configs
.values()
.next()
.ok_or_else(|| anyhow!("cannot add worker to router with no existing workers"))?
.clone();
let wid = worker_id as WorkerId;
self.workers_with_configs.insert(wid, config);
// Rebuild the slots with the full worker set
let dp_range: HashMap<u64, (u32, u32)> = self
.workers_with_configs
.keys()
.map(|&id| (id, (0u32, 1u32)))
.collect();
self.slots.update_workers(&dp_range);
// Enable queueing if we now have more than one worker
if self.workers_with_configs.len() > 1 && self.queue_threshold.is_none() {
self.queue_threshold = self.config.router_queue_threshold;
}
Ok(())
}
/// Remove a worker from routing eligibility.
///
/// Only removes the worker from the config map so the selector won't
/// pick it for new requests. The radix tree and active-sequence slots
/// are left intact so that in-flight requests on this worker can still
/// complete (free / mark_prefill_completed) and KV events can still
/// reference existing blocks without "parent block not found" errors.
/// Stale slot and indexer state is harmless — the selector and
/// `all_workers_busy` both skip workers absent from `workers_with_configs`.
pub(crate) fn remove_worker(&mut self, worker_id: usize) -> Result<()> {
let wid = worker_id as WorkerId;
self.workers_with_configs.remove(&wid);
Ok(())
}
#[cfg(test)] #[cfg(test)]
pub(crate) fn debug_snapshot(&self, now_ms: f64) -> OfflineRouterSnapshot { pub(crate) fn debug_snapshot(&self, now_ms: f64) -> OfflineRouterSnapshot {
let decay_now = self.decay_now(now_ms); let decay_now = self.decay_now(now_ms);
......
...@@ -5,12 +5,12 @@ use dynamo_kv_router::protocols::RouterEvent; ...@@ -5,12 +5,12 @@ use dynamo_kv_router::protocols::RouterEvent;
use uuid::Uuid; use uuid::Uuid;
use super::super::runtime_utils::WorkerCompletionPayload; use super::super::runtime_utils::WorkerCompletionPayload;
use crate::common::protocols::DirectRequest; use crate::common::protocols::{DirectRequest, ForwardPassSnapshot};
use crate::loadgen::ReplayRequestHashes; use crate::loadgen::ReplayRequestHashes;
use crate::scheduler::AdmissionEvent; use crate::scheduler::AdmissionEvent;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub(in crate::replay::offline) enum ReplayMode { pub(in crate::replay) enum ReplayMode {
Trace, Trace,
Concurrency { max_in_flight: usize }, Concurrency { max_in_flight: usize },
} }
...@@ -39,6 +39,9 @@ pub(in crate::replay::offline) struct EngineEffects { ...@@ -39,6 +39,9 @@ pub(in crate::replay::offline) struct EngineEffects {
pub(in crate::replay::offline) pass_start_kv_events: Vec<RouterEvent>, pub(in crate::replay::offline) pass_start_kv_events: Vec<RouterEvent>,
pub(in crate::replay::offline) immediate_completions: Vec<WorkerCompletionPayload>, pub(in crate::replay::offline) immediate_completions: Vec<WorkerCompletionPayload>,
pub(in crate::replay::offline) scheduled_completions: Vec<ScheduledWorkerCompletion>, pub(in crate::replay::offline) scheduled_completions: Vec<ScheduledWorkerCompletion>,
/// Forward pass metrics snapshots emitted by workers during this drive cycle,
/// keyed by worker index. Collected for planner integration.
pub(in crate::replay::offline) fpm_snapshots: Vec<(usize, ForwardPassSnapshot)>,
} }
impl EngineEffects { impl EngineEffects {
...@@ -61,3 +64,57 @@ pub(in crate::replay::offline) struct ReadyArrival { ...@@ -61,3 +64,57 @@ pub(in crate::replay::offline) struct ReadyArrival {
pub(in crate::replay::offline) arrival_time_ms: f64, pub(in crate::replay::offline) arrival_time_ms: f64,
pub(in crate::replay::offline) replay_hashes: Option<ReplayRequestHashes>, pub(in crate::replay::offline) replay_hashes: Option<ReplayRequestHashes>,
} }
/// Accumulates traffic statistics between planner ticks for deriving
/// `TrafficObservation` (num_req, avg ISL, avg OSL over a window).
#[derive(Debug)]
pub(in crate::replay::offline) struct TrafficAccumulator {
window_start_ms: f64,
num_req: usize,
total_isl: usize,
total_osl: usize,
}
impl TrafficAccumulator {
pub(in crate::replay::offline) fn new() -> Self {
Self {
window_start_ms: 0.0,
num_req: 0,
total_isl: 0,
total_osl: 0,
}
}
/// Record one admitted request.
pub(in crate::replay::offline) fn on_request(
&mut self,
input_tokens: usize,
output_tokens: usize,
) {
self.num_req += 1;
self.total_isl += input_tokens;
self.total_osl += output_tokens;
}
/// Drain the accumulator at the given simulated time, returning
/// (duration_s, num_req, avg_isl, avg_osl) and resetting counters.
pub(in crate::replay::offline) fn drain(&mut self, now_ms: f64) -> (f64, usize, f64, f64) {
let duration_s = (now_ms - self.window_start_ms) / 1000.0;
let num_req = self.num_req;
let avg_isl = if num_req > 0 {
self.total_isl as f64 / num_req as f64
} else {
0.0
};
let avg_osl = if num_req > 0 {
self.total_osl as f64 / num_req as f64
} else {
0.0
};
self.window_start_ms = now_ms;
self.num_req = 0;
self.total_isl = 0;
self.total_osl = 0;
(duration_s, num_req, avg_isl, avg_osl)
}
}
...@@ -11,7 +11,7 @@ use uuid::Uuid; ...@@ -11,7 +11,7 @@ use uuid::Uuid;
pub(super) use super::components::ReplayMode; pub(super) use super::components::ReplayMode;
use super::components::{ use super::components::{
AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter, AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter,
ScheduledWorkerCompletion, WorkerAdmission, ScheduledWorkerCompletion, TrafficAccumulator, WorkerAdmission,
}; };
use super::events::{SimulationEvent, SimulationWorkerStage}; use super::events::{SimulationEvent, SimulationWorkerStage};
use super::progress::ReplayProgress; use super::progress::ReplayProgress;
...@@ -22,7 +22,7 @@ use super::runtime_utils::{ ...@@ -22,7 +22,7 @@ use super::runtime_utils::{
#[cfg(test)] #[cfg(test)]
use super::state::DisaggRequestSnapshot; use super::state::DisaggRequestSnapshot;
use super::state::{DisaggPhase, DisaggRequestState}; use super::state::{DisaggPhase, DisaggRequestState};
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal}; use crate::common::protocols::{DirectRequest, ForwardPassSnapshot, MockEngineArgs, OutputSignal};
use crate::loadgen::{ReplayRequestHashes, WorkloadDriver}; use crate::loadgen::{ReplayRequestHashes, WorkloadDriver};
use crate::replay::{ use crate::replay::{
OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, TraceCollector, OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, TraceCollector,
...@@ -60,7 +60,7 @@ pub(super) struct DisaggRuntimeStats { ...@@ -60,7 +60,7 @@ pub(super) struct DisaggRuntimeStats {
#[derive(Debug, Default, Clone, PartialEq, Eq)] #[derive(Debug, Default, Clone, PartialEq, Eq)]
pub(super) struct DisaggRuntimeStats; pub(super) struct DisaggRuntimeStats;
pub(super) struct DisaggRuntime { pub(in crate::replay) struct DisaggRuntime {
now_ms: f64, now_ms: f64,
next_prefill_worker_idx: usize, next_prefill_worker_idx: usize,
next_decode_worker_idx: usize, next_decode_worker_idx: usize,
...@@ -75,11 +75,16 @@ pub(super) struct DisaggRuntime { ...@@ -75,11 +75,16 @@ pub(super) struct DisaggRuntime {
events: BinaryHeap<SimulationEvent>, events: BinaryHeap<SimulationEvent>,
progress: ReplayProgress, progress: ReplayProgress,
stats: DisaggRuntimeStats, stats: DisaggRuntimeStats,
/// Forward pass metrics accumulated between planner ticks, keyed by (stage, worker_idx).
prefill_fpm_buffer: Vec<(usize, ForwardPassSnapshot)>,
decode_fpm_buffer: Vec<(usize, ForwardPassSnapshot)>,
/// Traffic statistics accumulated between planner ticks.
traffic: TrafficAccumulator,
} }
impl DisaggRuntime { impl DisaggRuntime {
/// Create a disaggregated offline runtime seeded from an explicit request queue. /// Create a disaggregated offline runtime seeded from an explicit request queue.
pub(super) fn new( pub(in crate::replay) fn new(
config: &OfflineDisaggReplayConfig, config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>, prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
...@@ -97,7 +102,7 @@ impl DisaggRuntime { ...@@ -97,7 +102,7 @@ impl DisaggRuntime {
} }
/// Create a disaggregated offline runtime whose admissions come from a workload driver. /// Create a disaggregated offline runtime whose admissions come from a workload driver.
pub(super) fn new_workload( pub(in crate::replay) fn new_workload(
config: &OfflineDisaggReplayConfig, config: &OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>, router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>, prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
...@@ -147,7 +152,8 @@ impl DisaggRuntime { ...@@ -147,7 +152,8 @@ impl DisaggRuntime {
} }
}; };
let prefill_engine = EngineComponent::new( let prefill_capture_kv = prefill_router.is_some();
let mut prefill_engine = EngineComponent::new(
SimulationWorkerStage::Prefill, SimulationWorkerStage::Prefill,
EnginePassMode::Hidden, EnginePassMode::Hidden,
(0..config.num_prefill_workers) (0..config.num_prefill_workers)
...@@ -155,12 +161,13 @@ impl DisaggRuntime { ...@@ -155,12 +161,13 @@ impl DisaggRuntime {
super::state::OfflineWorkerState::new( super::state::OfflineWorkerState::new(
worker_idx, worker_idx,
config.prefill_args.clone(), config.prefill_args.clone(),
prefill_router.is_some(), prefill_capture_kv,
) )
}) })
.collect(), .collect(),
); );
let decode_engine = EngineComponent::new( prefill_engine.set_scaling_args(config.prefill_args.clone(), prefill_capture_kv);
let mut decode_engine = EngineComponent::new(
SimulationWorkerStage::Decode, SimulationWorkerStage::Decode,
EnginePassMode::Visible, EnginePassMode::Visible,
(0..config.num_decode_workers) (0..config.num_decode_workers)
...@@ -173,6 +180,7 @@ impl DisaggRuntime { ...@@ -173,6 +180,7 @@ impl DisaggRuntime {
}) })
.collect(), .collect(),
); );
decode_engine.set_scaling_args(config.decode_args.clone(), false);
Ok(Self { Ok(Self {
now_ms: 0.0, now_ms: 0.0,
...@@ -192,6 +200,9 @@ impl DisaggRuntime { ...@@ -192,6 +200,9 @@ impl DisaggRuntime {
stats: DisaggRuntimeStats::default(), stats: DisaggRuntimeStats::default(),
#[cfg(not(test))] #[cfg(not(test))]
stats: DisaggRuntimeStats, stats: DisaggRuntimeStats,
prefill_fpm_buffer: Vec::new(),
decode_fpm_buffer: Vec::new(),
traffic: TrafficAccumulator::new(),
}) })
} }
...@@ -209,20 +220,28 @@ impl DisaggRuntime { ...@@ -209,20 +220,28 @@ impl DisaggRuntime {
.map_or(0, OfflineReplayRouter::pending_count) .map_or(0, OfflineReplayRouter::pending_count)
} }
/// Pick the next prefill worker in round-robin order. /// Pick the next active prefill worker in round-robin order.
fn next_prefill_worker(&mut self) -> usize { fn next_prefill_worker(&mut self) -> usize {
let worker_idx = self.next_prefill_worker_idx; let active = self.prefill_engine.active_worker_ids();
self.next_prefill_worker_idx = debug_assert!(
(self.next_prefill_worker_idx + 1) % self.prefill_engine.worker_count(); !active.is_empty(),
worker_idx "no active prefill workers for round-robin"
);
let idx = self.next_prefill_worker_idx % active.len();
self.next_prefill_worker_idx = idx + 1;
active[idx]
} }
/// Pick the next decode worker in round-robin order. /// Pick the next active decode worker in round-robin order.
fn next_decode_worker(&mut self) -> usize { fn next_decode_worker(&mut self) -> usize {
let worker_idx = self.next_decode_worker_idx; let active = self.decode_engine.active_worker_ids();
self.next_decode_worker_idx = debug_assert!(
(self.next_decode_worker_idx + 1) % self.decode_engine.worker_count(); !active.is_empty(),
worker_idx "no active decode workers for round-robin"
);
let idx = self.next_decode_worker_idx % active.len();
self.next_decode_worker_idx = idx + 1;
active[idx]
} }
/// Track the peak number of requests parked in each stage router. /// Track the peak number of requests parked in each stage router.
...@@ -355,7 +374,6 @@ impl DisaggRuntime { ...@@ -355,7 +374,6 @@ impl DisaggRuntime {
request.tokens.len(), request.tokens.len(),
request.max_output_tokens, request.max_output_tokens,
); );
let queued_request = request.clone(); let queued_request = request.clone();
self.requests self.requests
.insert(uuid, DisaggRequestState::new(request, arrival_time_ms)); .insert(uuid, DisaggRequestState::new(request, arrival_time_ms));
...@@ -479,6 +497,11 @@ impl DisaggRuntime { ...@@ -479,6 +497,11 @@ impl DisaggRuntime {
.transition_log .transition_log
.push(DisaggTransition::WorkloadCompleted { uuid: signal.uuid }); .push(DisaggTransition::WorkloadCompleted { uuid: signal.uuid });
} }
let state = self.state(signal.uuid)?;
let original = state.original_request()?;
let input_tokens = original.tokens.len();
let output_tokens = original.max_output_tokens;
self.traffic.on_request(input_tokens, output_tokens);
self.state_mut(signal.uuid)?.mark_done(); self.state_mut(signal.uuid)?.mark_done();
#[cfg(test)] #[cfg(test)]
{ {
...@@ -626,6 +649,7 @@ impl DisaggRuntime { ...@@ -626,6 +649,7 @@ impl DisaggRuntime {
} }
fn handle_prefill_engine_effects(&mut self, effects: EngineEffects) -> Result<()> { fn handle_prefill_engine_effects(&mut self, effects: EngineEffects) -> Result<()> {
self.prefill_fpm_buffer.extend(effects.fpm_snapshots);
self.record_prefill_admissions(effects.admissions); self.record_prefill_admissions(effects.admissions);
self.apply_prefill_router_events(effects.pass_start_kv_events)?; self.apply_prefill_router_events(effects.pass_start_kv_events)?;
for payload in effects.immediate_completions { for payload in effects.immediate_completions {
...@@ -651,6 +675,7 @@ impl DisaggRuntime { ...@@ -651,6 +675,7 @@ impl DisaggRuntime {
} }
fn handle_decode_engine_effects(&mut self, effects: EngineEffects) -> Result<()> { fn handle_decode_engine_effects(&mut self, effects: EngineEffects) -> Result<()> {
self.decode_fpm_buffer.extend(effects.fpm_snapshots);
for payload in effects.immediate_completions { for payload in effects.immediate_completions {
let payload = self.decode_engine.on_scheduled_completion(payload)?; let payload = self.decode_engine.on_scheduled_completion(payload)?;
self.process_decode_pass( self.process_decode_pass(
...@@ -693,6 +718,105 @@ impl DisaggRuntime { ...@@ -693,6 +718,105 @@ impl DisaggRuntime {
} }
} }
// ------------------------------------------------------------------
// Planner integration: step-based execution
// ------------------------------------------------------------------
/// Advance the simulation up to `until_ms` simulated time, then pause.
/// Returns `true` if the replay is done (no more work).
pub(in crate::replay) fn advance_to(&mut self, until_ms: f64) -> Result<bool> {
self.drain_current_timestamp()?;
while !self.is_done() {
let Some(next_timestamp_ms) = self.next_timestamp() else {
bail!(
"offline disagg replay reached a dead end with {} in-flight requests remaining",
self.cluster_in_flight()
);
};
if next_timestamp_ms > until_ms {
break;
}
self.now_ms = next_timestamp_ms;
self.drain_current_timestamp()?;
}
Ok(self.is_done())
}
/// Current simulated time in milliseconds.
pub(in crate::replay) fn now_ms(&self) -> f64 {
self.now_ms
}
pub(in crate::replay) fn active_prefill_count(&self) -> usize {
self.prefill_engine.active_worker_ids().len()
}
pub(in crate::replay) fn active_decode_count(&self) -> usize {
self.decode_engine.active_worker_ids().len()
}
pub(in crate::replay) fn total_prefill_count(&self) -> usize {
self.prefill_engine.worker_count()
}
pub(in crate::replay) fn total_decode_count(&self) -> usize {
self.decode_engine.worker_count()
}
/// Drain accumulated prefill FPM snapshots since the last drain.
pub(in crate::replay) fn drain_prefill_fpm(&mut self) -> Vec<(usize, ForwardPassSnapshot)> {
std::mem::take(&mut self.prefill_fpm_buffer)
}
/// Drain accumulated decode FPM snapshots since the last drain.
pub(in crate::replay) fn drain_decode_fpm(&mut self) -> Vec<(usize, ForwardPassSnapshot)> {
std::mem::take(&mut self.decode_fpm_buffer)
}
/// Drain accumulated traffic stats since the last drain.
pub(in crate::replay) fn drain_traffic(&mut self) -> (f64, usize, f64, f64) {
self.traffic.drain(self.now_ms)
}
/// Apply a scaling decision with separate prefill and decode targets.
/// Newly marked workers are removed from the router immediately so no
/// new requests land on them while they drain in-flight work.
pub(in crate::replay) fn apply_scaling(
&mut self,
target_prefill: usize,
target_decode: usize,
) -> Result<()> {
let (added, newly_marked) = self.prefill_engine.apply_target_count(target_prefill);
if let Some(router) = self.prefill_router.as_mut() {
for id in added {
router.add_worker(id)?;
}
for id in newly_marked {
router.remove_worker(id)?;
}
}
let (added, newly_marked) = self.decode_engine.apply_target_count(target_decode);
if let Some(router) = self.decode_router.as_mut() {
for id in added {
router.add_worker(id)?;
}
for id in newly_marked {
router.remove_worker(id)?;
}
}
Ok(())
}
/// Finalize the replay and return the simulation report directly.
pub(in crate::replay) fn finalize_report(self) -> crate::replay::TraceSimulationReport {
self.progress.finish();
self.collector.finish()
}
/// Run the staged offline replay until both prefill and decode pipelines are drained. /// Run the staged offline replay until both prefill and decode pipelines are drained.
pub(super) fn run(mut self) -> Result<(TraceCollector, DisaggRuntimeStats)> { pub(super) fn run(mut self) -> Result<(TraceCollector, DisaggRuntimeStats)> {
self.drain_current_timestamp()?; self.drain_current_timestamp()?;
......
...@@ -19,22 +19,30 @@ pub(crate) struct AggRequestState { ...@@ -19,22 +19,30 @@ pub(crate) struct AggRequestState {
request: Option<DirectRequest>, request: Option<DirectRequest>,
pub(in crate::replay::offline) phase: AggRequestPhase, pub(in crate::replay::offline) phase: AggRequestPhase,
pub(in crate::replay::offline) prefill_completed: bool, pub(in crate::replay::offline) prefill_completed: bool,
pub(in crate::replay::offline) input_tokens: usize,
pub(in crate::replay::offline) output_tokens: usize,
} }
impl AggRequestState { impl AggRequestState {
pub(crate) fn new_queued(request: DirectRequest) -> Self { pub(crate) fn new_queued(request: DirectRequest) -> Self {
let input_tokens = request.tokens.len();
let output_tokens = request.max_output_tokens;
Self { Self {
request: Some(request), request: Some(request),
phase: AggRequestPhase::QueuedAtRouter, phase: AggRequestPhase::QueuedAtRouter,
prefill_completed: false, prefill_completed: false,
input_tokens,
output_tokens,
} }
} }
pub(crate) fn new_running() -> Self { pub(crate) fn new_running(input_tokens: usize, output_tokens: usize) -> Self {
Self { Self {
request: None, request: None,
phase: AggRequestPhase::Running, phase: AggRequestPhase::Running,
prefill_completed: false, prefill_completed: false,
input_tokens,
output_tokens,
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Public handle for driving an offline replay with planner-in-the-loop.
//!
//! Supports both aggregated and disaggregated topologies via [`RuntimeKind`].
//! The Python planner adapter calls [`PlannerReplayHandle::advance_to`] to
//! step the simulation, collects metrics, and calls [`PlannerReplayHandle::apply_scaling`]
//! to resize worker pools.
use std::path::Path;
use std::time::Instant;
use anyhow::Result;
use dynamo_kv_router::config::KvRouterConfig;
use super::offline::agg::AggRuntime;
use super::offline::components::ReplayMode;
use super::offline::disagg::DisaggRuntime;
use super::{
OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, TraceSimulationReport,
};
use crate::common::protocols::{ForwardPassSnapshot, MockEngineArgs};
use crate::loadgen::Trace;
/// Snapshot of metrics collected between planner ticks.
///
/// For aggregated mode, prefill fields are 0 and all data is in decode fields
/// (matching how the planner treats agg as a single decode-stage engine).
pub struct PlannerTickData {
/// Current simulated time in milliseconds.
pub now_ms: f64,
/// Whether the replay has finished (no more work).
pub is_done: bool,
/// Prefill FPM snapshots since last tick: (worker_id, snapshot).
pub prefill_fpm_snapshots: Vec<(usize, ForwardPassSnapshot)>,
/// Decode (or agg) FPM snapshots since last tick: (worker_id, snapshot).
pub decode_fpm_snapshots: Vec<(usize, ForwardPassSnapshot)>,
/// Traffic observation: (duration_s, num_req, avg_isl, avg_osl).
pub traffic: (f64, usize, f64, f64),
/// Active prefill workers (0 for agg mode).
pub active_prefill_count: usize,
/// Active decode workers (or total active for agg mode).
pub active_decode_count: usize,
/// Total prefill workers including pending removal (0 for agg mode).
pub total_prefill_count: usize,
/// Total decode workers including pending removal (or total for agg mode).
pub total_decode_count: usize,
}
#[allow(clippy::large_enum_variant)]
enum RuntimeKind {
Agg(AggRuntime),
Disagg(DisaggRuntime),
}
pub struct PlannerReplayHandle {
runtime: RuntimeKind,
started_at: Instant,
}
impl PlannerReplayHandle {
/// Create a handle for an aggregated trace-file replay.
#[allow(clippy::too_many_arguments)]
pub fn from_trace_file(
args: MockEngineArgs,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path,
trace_block_size: usize,
num_workers: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<Self> {
let args = args.normalized()?;
let trace = Trace::from_mooncake(trace_path, trace_block_size)?
.normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?;
let runtime = AggRuntime::new_workload(
&args,
router_config,
prefill_load_estimator,
trace.into_trace_driver_with_block_size(args.block_size)?,
num_workers,
ReplayMode::Trace,
router_mode,
)?;
Ok(Self {
runtime: RuntimeKind::Agg(runtime),
started_at: Instant::now(),
})
}
/// Create a handle for a disaggregated trace-file replay.
pub fn from_trace_file_disagg(
config: OfflineDisaggReplayConfig,
router_config: Option<KvRouterConfig>,
prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
trace_path: &Path,
trace_block_size: usize,
arrival_speedup_ratio: f64,
router_mode: ReplayRouterMode,
) -> Result<Self> {
let config = config.normalized()?;
let trace = Trace::from_mooncake(trace_path, trace_block_size)?
.normalize_session_starts()?
.speed_up_timing(arrival_speedup_ratio)?;
let runtime = DisaggRuntime::new_workload(
&config,
router_config,
prefill_load_estimator,
trace.into_trace_driver_with_block_size(config.decode_args.block_size)?,
ReplayMode::Trace,
router_mode,
)?;
Ok(Self {
runtime: RuntimeKind::Disagg(runtime),
started_at: Instant::now(),
})
}
/// Advance the simulation up to `until_ms`, collect metrics, return tick data.
pub fn advance_to(&mut self, until_ms: f64) -> Result<PlannerTickData> {
match &mut self.runtime {
RuntimeKind::Agg(rt) => {
let is_done = rt.advance_to(until_ms)?;
let fpm = rt.drain_fpm();
let traffic = rt.drain_traffic();
Ok(PlannerTickData {
now_ms: rt.now_ms(),
is_done,
prefill_fpm_snapshots: Vec::new(),
decode_fpm_snapshots: fpm,
traffic,
active_prefill_count: 0,
active_decode_count: rt.active_worker_count(),
total_prefill_count: 0,
total_decode_count: rt.total_worker_count(),
})
}
RuntimeKind::Disagg(rt) => {
let is_done = rt.advance_to(until_ms)?;
let prefill_fpm = rt.drain_prefill_fpm();
let decode_fpm = rt.drain_decode_fpm();
let traffic = rt.drain_traffic();
Ok(PlannerTickData {
now_ms: rt.now_ms(),
is_done,
prefill_fpm_snapshots: prefill_fpm,
decode_fpm_snapshots: decode_fpm,
traffic,
active_prefill_count: rt.active_prefill_count(),
active_decode_count: rt.active_decode_count(),
total_prefill_count: rt.total_prefill_count(),
total_decode_count: rt.total_decode_count(),
})
}
}
}
/// Apply a scaling decision with separate prefill and decode targets.
/// For agg mode, `target_prefill` is ignored.
pub fn apply_scaling(&mut self, target_prefill: usize, target_decode: usize) -> Result<()> {
match &mut self.runtime {
RuntimeKind::Agg(rt) => rt.apply_scaling(target_decode),
RuntimeKind::Disagg(rt) => rt.apply_scaling(target_prefill, target_decode),
}
}
/// Finalize the replay and return the report.
pub fn finalize(self) -> TraceSimulationReport {
let report = match self.runtime {
RuntimeKind::Agg(rt) => rt.finalize_report(),
RuntimeKind::Disagg(rt) => rt.finalize_report(),
};
report.with_wall_time_ms(self.started_at.elapsed().as_secs_f64() * 1000.0)
}
}
...@@ -36,6 +36,7 @@ pub(super) struct SglangConfig { ...@@ -36,6 +36,7 @@ pub(super) struct SglangConfig {
pub(super) decode_speedup_ratio: f64, pub(super) decode_speedup_ratio: f64,
pub(super) worker_type: WorkerType, pub(super) worker_type: WorkerType,
pub(super) block_size: usize, pub(super) block_size: usize,
pub(super) total_kv_tokens: usize,
pub(super) kv_bytes_per_token: Option<usize>, pub(super) kv_bytes_per_token: Option<usize>,
pub(super) kv_transfer_bandwidth: Option<f64>, pub(super) kv_transfer_bandwidth: Option<f64>,
} }
...@@ -83,6 +84,7 @@ impl SglangConfig { ...@@ -83,6 +84,7 @@ impl SglangConfig {
decode_speedup_ratio: args.decode_speedup_ratio, decode_speedup_ratio: args.decode_speedup_ratio,
worker_type: args.worker_type, worker_type: args.worker_type,
block_size: args.block_size, block_size: args.block_size,
total_kv_tokens: args.num_gpu_blocks * args.block_size,
kv_bytes_per_token: args.kv_bytes_per_token, kv_bytes_per_token: args.kv_bytes_per_token,
kv_transfer_bandwidth: args.kv_transfer_bandwidth, kv_transfer_bandwidth: args.kv_transfer_bandwidth,
} }
......
...@@ -140,10 +140,13 @@ pub(super) fn simulate_decode_step( ...@@ -140,10 +140,13 @@ pub(super) fn simulate_decode_step(
.map(SglangRequest::current_sequence_len) .map(SglangRequest::current_sequence_len)
.sum(); .sum();
let avg_context = total_context / running.len(); let avg_context = total_context / running.len();
let decode_time = let active_kv_tokens = total_context.min(config.total_kv_tokens);
config let decode_time = config.perf_model.predict_decode_time(
.perf_model running.len(),
.predict_decode_time(running.len(), total_context, avg_context); active_kv_tokens,
avg_context,
config.total_kv_tokens,
);
let unscaled_time = Duration::from_secs_f64(decode_time / 1000.0); let unscaled_time = Duration::from_secs_f64(decode_time / 1000.0);
let effective_ratio = config.speedup_ratio * config.decode_speedup_ratio; let effective_ratio = config.speedup_ratio * config.decode_speedup_ratio;
let total_time = if apply_speedup && effective_ratio > 0.0 && unscaled_time > Duration::ZERO { let total_time = if apply_speedup && effective_ratio > 0.0 && unscaled_time > Duration::ZERO {
......
...@@ -657,19 +657,28 @@ impl VllmCore { ...@@ -657,19 +657,28 @@ impl VllmCore {
return (Duration::ZERO, Vec::new()); return (Duration::ZERO, Vec::new());
} }
// For prefill workers, the first decode token is produced as part of
// the prefill forward pass — no separate decode iteration needed.
let (decode_time, decode_end_ms) = if self.args.worker_type == WorkerType::Prefill {
(Duration::ZERO, decode_start_ms)
} else {
let active_kv_tokens = self.kv_manager.num_active_blocks() * self.args.block_size; let active_kv_tokens = self.kv_manager.num_active_blocks() * self.args.block_size;
let total_kv_tokens = self.args.num_gpu_blocks * self.args.block_size;
let total_length = ready let total_length = ready
.iter() .iter()
.filter_map(|uuid| self.state.requests.get(uuid)) .filter_map(|uuid| self.state.requests.get(uuid))
.map(|request| request.sequence.len()) .map(|request| request.sequence.len())
.sum::<usize>(); .sum::<usize>();
let context_length = total_length / ready.len(); let context_length = total_length / ready.len();
let decode_ms = let decode_ms = self.args.perf_model.predict_decode_time(
self.args ready.len(),
.perf_model active_kv_tokens,
.predict_decode_time(ready.len(), active_kv_tokens, context_length); context_length,
let decode_time = scale_decode_time(decode_ms, &self.args); total_kv_tokens,
let decode_end_ms = decode_start_ms + decode_time.as_secs_f64() * 1000.0; );
let dt = scale_decode_time(decode_ms, &self.args);
(dt, decode_start_ms + dt.as_secs_f64() * 1000.0)
};
let mut output_signals = Vec::with_capacity(ready.len()); let mut output_signals = Vec::with_capacity(ready.len());
for uuid in ready { for uuid in ready {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment