Unverified Commit 0c7068dd authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

fix(replay/planner): planner-in-the-loop replay diagnostics + planner scaling logics fixes (#8280)


Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 4e7c6afd
......@@ -41,6 +41,8 @@ class LoadScalingMixin:
_diag_estimated_ttft_ms: Optional[float]
_diag_estimated_itl_ms: Optional[float]
_diag_load_reason: Optional[str]
_diag_load_reason_prefill: Optional[str]
_diag_load_reason_decode: Optional[str]
def _advance_load(self, obs: FpmObservations) -> Optional[ScalingDecision]:
if not self._config.enable_load_scaling:
......@@ -122,16 +124,22 @@ class LoadScalingMixin:
if not p_stats and not d_stats:
logger.warning("No FPM data for either prefill or decode, skipping")
self._diag_load_reason = "no_fpm_data"
self._diag_load_reason_prefill = "no_fpm_data"
self._diag_load_reason_decode = "no_fpm_data"
return None
if p_stats and not self._reconcile_fpm_worker_count(
p_stats, self._num_p_workers, "prefill"
):
self._diag_load_reason = "worker_count_mismatch"
self._diag_load_reason_prefill = "worker_count_mismatch"
self._diag_load_reason_decode = "worker_count_mismatch"
return None
if d_stats and not self._reconcile_fpm_worker_count(
d_stats, self._num_d_workers, "decode"
):
self._diag_load_reason = "worker_count_mismatch"
self._diag_load_reason_prefill = "worker_count_mismatch"
self._diag_load_reason_decode = "worker_count_mismatch"
return None
easy = self._config.optimization_target != "sla"
......@@ -157,30 +165,61 @@ class LoadScalingMixin:
final_p = p_desired if p_desired is not None else self._num_p_workers
final_d = d_desired if d_desired is not None else self._num_d_workers
if final_p == self._num_p_workers and final_d == self._num_d_workers:
logger.info("Load-based scaling: no scaling needed")
self._diag_load_reason = "no_change"
return None
# Enforce bounds first so "no change" comparison is against the
# post-floor target, not the raw load decision. Otherwise a load
# decision of "no change" would skip the floor and let replicas
# stay below a throughput-scaling lower bound that was raised on
# a previous (or same) tick.
original_p, original_d = final_p, final_d
# Apply throughput floor first and track the post-floor value so we
# can attribute later lifts to their real source -- throughput
# capping is a distinct diagnostic from min_endpoint / global-budget
# lifts, which should not be labelled "scale_down_capped_by_throughput".
if self._config.enable_throughput_scaling:
final_p = max(final_p, self._throughput_lower_bound_p)
final_d = max(final_d, self._throughput_lower_bound_d)
post_floor_p, post_floor_d = final_p, final_d
final_p = max(final_p, self._config.min_endpoint)
final_d = max(final_d, self._config.min_endpoint)
final_p, final_d = self._apply_global_budget(final_p, final_d)
if (final_p > original_p or final_d > original_d) and (
original_p < self._num_p_workers or original_d < self._num_d_workers
):
self._diag_load_reason = "scale_down_capped_by_throughput"
elif final_p > self._num_p_workers or final_d > self._num_d_workers:
self._diag_load_reason = "scale_up"
elif final_p < self._num_p_workers or final_d < self._num_d_workers:
self._diag_load_reason = "scale_down"
else:
self._diag_load_reason = "no_change"
# Per-component reasons
def _reason(final: int, original: int, post_floor: int, current: int) -> str:
# Only classify as throughput-capped when the throughput floor
# itself lifted the load decision; later min_endpoint / budget
# adjustments don't count.
floor_capped = post_floor > original and original < current
if final > current:
return "scale_up"
if final < current:
return (
"scale_down_capped_by_throughput" if floor_capped else "scale_down"
)
return "scale_down_capped_by_throughput" if floor_capped else "no_change"
self._diag_load_reason_prefill = _reason(
final_p, original_p, post_floor_p, self._num_p_workers
)
self._diag_load_reason_decode = _reason(
final_d, original_d, post_floor_d, self._num_d_workers
)
# Aggregate reason: prioritise "most interesting" across components.
_PRIORITY = {
"scale_up": 4,
"scale_down_capped_by_throughput": 3,
"scale_down": 2,
"no_change": 1,
}
self._diag_load_reason = max(
(self._diag_load_reason_prefill, self._diag_load_reason_decode),
key=lambda r: _PRIORITY.get(r or "", 0),
)
if final_p == self._num_p_workers and final_d == self._num_d_workers:
logger.info("Load-based scaling: no scaling needed")
return None
logger.info(
f"Load-based disagg scaling: prefill {self._num_p_workers}->{final_p}, "
......@@ -270,9 +309,10 @@ class LoadScalingMixin:
):
desired = max(p_desired, d_desired)
else:
logger.info("Agg scaling: no scaling needed")
self._diag_load_reason = "no_change"
return None
# Load scaling sees "no change" -- but the throughput floor may
# still require scaling up, so keep processing rather than
# returning early.
desired = num_workers
original_desired = desired
desired = max(desired, self._config.min_endpoint)
......@@ -280,15 +320,23 @@ class LoadScalingMixin:
desired = max(desired, self._throughput_lower_bound_d)
desired = self._apply_single_budget(desired, "decode")
# Preserve "load wanted to scale down but floor lifted it" as a
# distinct diagnostic reason even when the net result is no change.
floor_capped = desired > original_desired and original_desired < num_workers
if desired == num_workers:
logger.info("Agg scaling: no scaling needed")
self._diag_load_reason = (
"scale_down_capped_by_throughput" if floor_capped else "no_change"
)
return None
if desired < num_workers:
if desired > original_desired:
self._diag_load_reason = "scale_down_capped_by_throughput"
else:
self._diag_load_reason = "scale_down"
elif desired > num_workers:
self._diag_load_reason = (
"scale_down_capped_by_throughput" if floor_capped else "scale_down"
)
else: # desired > num_workers (equality returned above)
self._diag_load_reason = "scale_up"
else:
self._diag_load_reason = "no_change"
logger.info(f"Agg load-based scaling: {num_workers} -> {desired}")
return ScalingDecision(num_decode=desired)
......
......@@ -119,6 +119,8 @@ class AggRegressionModel(_BaseRegressionModel):
max_num_batched_tokens: int,
ttft_sla: float,
itl_sla: float,
max_kv_tokens: Optional[int] = None,
max_num_seqs: Optional[int] = None,
) -> tuple[float, float, float]:
"""Find the maximum agg engine request rate under both SLA targets.
......@@ -129,19 +131,28 @@ class AggRegressionModel(_BaseRegressionModel):
Request rate is derived via Little's law:
``engine_rps = best_batch_size / (osl * wall_time_per_iter)``.
Args:
isl: average input sequence length (tokens).
osl: average output sequence length (tokens).
max_num_batched_tokens: per-iteration token budget.
ttft_sla: TTFT target in milliseconds.
itl_sla: ITL target in milliseconds.
Returns:
(engine_rps, actual_ttft_ms, actual_itl_ms) -- 0 rps
signals an error (model not fitted or invalid input);
positive rps is the best achievable rate with the
predicted TTFT/ITL. If SLAs are violated, a warning
is logged but the rate is still returned.
The upper bound for the batch-size sweep is the smallest of:
1. KV cache capacity: ``max_kv_tokens / (isl + osl/2)``
2. ``max_num_seqs`` (engine concurrency limit)
3. The prefill/decode rate-balance point (steady state). For a
batch of size ``x``:
- Decode egress rate: ``x / osl`` requests finish per iter
(x concurrent streams, each taking osl decode iters).
- Prefill admission rate: ``(max_num_batched_tokens - x) / isl``
requests admitted per iter (the budget left after decode
takes one slot per in-flight request, divided by isl tokens
per new request).
Steady state requires admission >= egress:
``(max_num_batched_tokens - x) / isl >= x / osl``,
which simplifies to
``isl / (max_num_batched_tokens - x) <= osl``
(the check implemented below), or equivalently
``x <= osl * max_num_batched_tokens / (isl + osl)``.
Above this, prefill becomes the bottleneck and TTFT grows
unbounded.
The caller guarantees ``osl > 0`` and ``max_num_batched_tokens > 0``
via the early-return validation above.
"""
if (
not self._ensure_fitted()
......@@ -152,7 +163,34 @@ class AggRegressionModel(_BaseRegressionModel):
return (0.0, 0.0, 0.0)
avg_ctx = isl + osl / 2.0
max_bs = max(1, int(max_num_batched_tokens / max(1, avg_ctx))) * 2
# KV cache cap
kv_cap = (
max(1, int(max_kv_tokens / max(1.0, avg_ctx)))
if max_kv_tokens and max_kv_tokens > 0
else 1024 # large fallback when capability not known
)
# Concurrency cap
seq_cap = max_num_seqs if max_num_seqs and max_num_seqs > 0 else kv_cap
# Prefill/decode balance cap via binary search within [1, min(kv_cap, seq_cap)]
# For each candidate x, check: isl / (max_num_batched_tokens - x) <= osl
hard_cap = min(kv_cap, seq_cap, max_num_batched_tokens - 1)
def _prefill_balanced(x: int) -> bool:
prefill_budget = max_num_batched_tokens - x
if prefill_budget <= 0:
return False
return isl / prefill_budget <= osl
lo, hi = 1, max(1, hard_cap)
while lo < hi:
mid = (lo + hi + 1) // 2
if _prefill_balanced(mid):
lo = mid
else:
hi = mid - 1
max_bs = lo
best_rps = 0.0
best_ttft_ms = 0.0
......
......@@ -186,6 +186,28 @@ class _BaseRegressionModel:
def _gather_observations(self) -> list[tuple[Union[float, list[float]], float]]:
return [obs for bucket in self._buckets.values() for obs in bucket]
# Feature indices whose coefficients are allowed to go slightly negative
# under noisy fits without rejecting the whole model. Used when a feature
# has weak signal and can flip sign under multicollinearity without
# violating physical monotonicity (e.g. num_decode_requests vs
# sum_decode_kv_tokens: total KV dominates wall time, so a small negative
# batch-size coefficient is numerical noise, not "more requests → less
# work"). Subclasses override via ``_relaxable_feature_indices``.
#
# Most features should stay non-negative: both AggRegressionModel and
# PrefillRegressionModel operate on token counts (sum_prefill_tokens,
# sum_decode_kv_tokens) that directly drive GPU compute and therefore
# must have positive coefficients. Only DecodeRegressionModel relaxes
# index 0 (num_decode_requests), which is a weaker secondary feature.
_relaxable_feature_indices: frozenset[int] = frozenset()
# Coefficients within this band of 0 are treated as numerical noise
# when a feature is marked relaxable. Anything more negative than this
# implies the regression is learning an inverted relationship and is
# rejected (or clipped, for relaxable features) so the planner does not
# scale on physically impossible predictions.
_RELAXABLE_NEG_TOLERANCE = 1e-6
def _fit(self) -> bool:
observations = self._gather_observations()
if len(observations) < self.min_observations:
......@@ -199,13 +221,37 @@ class _BaseRegressionModel:
# Negative coefficients mean "more load → less compute time", which
# is physically impossible. Reject the fit so callers see the model
# as not ready rather than making inverted scaling decisions.
if np.any(self._model.coef_ < 0):
# Exception: features in ``_relaxable_feature_indices`` may go
# slightly negative due to multicollinearity / noise; accept tiny
# (within-tolerance) negatives as-is, and clamp larger relaxable
# negatives to 0 so predictions remain monotone in that feature.
coef = self._model.coef_
neg_mask = coef < 0
if np.any(neg_mask):
non_relaxable_negs = [
i
for i in range(len(coef))
if neg_mask[i] and i not in self._relaxable_feature_indices
]
if non_relaxable_negs:
logger.warning(
f"Regression produced negative coefficients {self._model.coef_.tolist()}, "
f"Regression produced negative coefficients {coef.tolist()}, "
"model rejected — scaling will be skipped until more data arrives"
)
self._is_fitted = False
return False
# Any negatives remaining here are on relaxable features. Clamp
# those that exceed the noise tolerance so the model never
# predicts lower wall time for higher values of that feature.
large_negs = neg_mask & (coef < -self._RELAXABLE_NEG_TOLERANCE)
if np.any(large_negs):
logger.debug(
"Clamped large negative relaxable coefficients at indices "
"%s from %s to 0",
[i for i in range(len(coef)) if large_negs[i]],
coef.tolist(),
)
coef[large_negs] = 0.0
self._is_fitted = True
return True
......
......@@ -18,7 +18,20 @@ logger = logging.getLogger(__name__)
class DecodeRegressionModel(_BaseRegressionModel):
"""Predict per-iteration wall time from decode batch composition."""
"""Predict per-iteration wall time from decode batch composition.
Features: ``[num_decode_requests, sum_decode_kv_tokens]``. The
``sum_decode_kv_tokens`` feature dominates wall time via attention
compute, while ``num_decode_requests`` has a weaker secondary effect
from linear-layer work. Under multicollinearity (both features scale
with batch size), the ``num_decode_requests`` coefficient can flip
sign under noisy fits; we accept the small negative value since
``sum_decode_kv_tokens`` keeps the overall prediction monotone.
"""
# num_decode_requests (index 0) is relaxable; sum_decode_kv_tokens (index 1)
# must remain non-negative.
_relaxable_feature_indices = frozenset({0})
def __init__(
self,
......@@ -69,7 +82,12 @@ class DecodeRegressionModel(_BaseRegressionModel):
return self._predict_2d(num_req, total_kv)
def find_best_engine_decode_rps(
self, itl: float, context_length: float, osl: float
self,
itl: float,
context_length: float,
osl: float,
max_kv_tokens: Optional[int] = None,
max_num_seqs: Optional[int] = None,
) -> tuple[float, float]:
"""Find the maximum decode engine request rate within an ITL target.
......@@ -81,6 +99,12 @@ class DecodeRegressionModel(_BaseRegressionModel):
Request rate is derived via Little's law:
``engine_rps = best_batch_size / (osl * wall_time_per_iter)``.
The upper bound of the sweep is the smallest of:
- ``max_kv_tokens / context_length`` -- KV cache capacity
- ``max_num_seqs`` -- engine concurrency limit
Falls back to ``_max_observed_kv / context_length`` (or 256) if
neither capability is provided.
Returns:
(engine_rps, actual_itl_ms) -- 0 rps signals an error
(model not fitted or invalid input); positive rps is
......@@ -89,11 +113,14 @@ class DecodeRegressionModel(_BaseRegressionModel):
if not self._ensure_fitted() or context_length <= 0 or osl <= 0 or itl <= 0:
return (0.0, 0.0)
max_batch = (
max(1, int(self._max_observed_kv / context_length))
if self._max_observed_kv > 0
else 256
)
if max_kv_tokens and max_kv_tokens > 0:
kv_cap = max(1, int(max_kv_tokens / context_length))
elif self._max_observed_kv > 0:
kv_cap = max(1, int(self._max_observed_kv / context_length))
else:
kv_cap = 256
seq_cap = max_num_seqs if max_num_seqs and max_num_seqs > 0 else kv_cap
max_batch = max(1, min(kv_cap, seq_cap))
lo, hi = 1, max_batch
best_bs, best_wt = 1, self._predict_2d(1, context_length)
......
......@@ -80,14 +80,28 @@ class PrefillRegressionModel(_BaseRegressionModel):
return total_time
def find_best_engine_prefill_rps(
self, ttft_sla: float, isl: float
self,
ttft_sla: float,
isl: float,
max_num_batched_tokens: Optional[int] = None,
) -> tuple[float, float]:
"""Find prefill engine request rate under a TTFT target.
Predicts wall_time for a single prefill at the given ISL.
If the predicted TTFT exceeds the SLA, logs a warning but
still returns the best achievable rate so the caller can
scale based on load matching.
Predicts wall_time for a single prefill at the given ISL and
derives engine_rps = 1 / wt. This formula assumes the regression
scales roughly linearly with sum_prefill_tokens: under that
assumption batching multiple prefills (each with ISL tokens) gives
the same engine_rps as one-request-at-a-time, because a batch of
B requests has wt ≈ k·B·ISL, so rate = B/wt = 1/(k·ISL).
If ISL exceeds max_num_batched_tokens, a single request must be
chunked across multiple forward passes. We compute wall_time as
ceil(ISL / MBT) * wt(MBT-sized chunk) to stay within the model's
training domain.
If the predicted TTFT exceeds the SLA, logs a warning but still
returns the best achievable rate so the caller can scale based
on load matching.
Returns:
(engine_rps, actual_ttft_ms) -- 0 rps signals an error
......@@ -96,7 +110,28 @@ class PrefillRegressionModel(_BaseRegressionModel):
"""
if not self._ensure_fitted() or isl <= 0:
return (0.0, 0.0)
# Chunk long prefills so we stay within the regression's training
# domain: a single forward pass never processes more than
# max_num_batched_tokens tokens. At the boundary isl ==
# max_num_batched_tokens, the `else` branch handles it as a single
# pass (no chunking needed); strict `>` inequality is deliberate.
if (
max_num_batched_tokens
and max_num_batched_tokens > 0
and isl > max_num_batched_tokens
):
num_chunks = math.ceil(isl / max_num_batched_tokens)
# remainder is the size of the final (possibly partial) chunk.
# Invariant: remainder ∈ (0, max_num_batched_tokens] by
# construction of num_chunks via math.ceil.
remainder = isl - (num_chunks - 1) * max_num_batched_tokens
wt = (num_chunks - 1) * self._predict_wall_time(
float(max_num_batched_tokens)
) + self._predict_wall_time(remainder)
else:
wt = self._predict_wall_time(isl)
actual_ttft_ms = wt * 1000.0
engine_rps = 1.0 / wt
if actual_ttft_ms > ttft_sla:
......
......@@ -116,6 +116,10 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
self._diag_engine_rps_decode: Optional[float] = None
self._diag_load_reason: Optional[str] = None
self._diag_throughput_reason: Optional[str] = None
self._diag_load_reason_prefill: Optional[str] = None
self._diag_load_reason_decode: Optional[str] = None
self._diag_throughput_reason_prefill: Optional[str] = None
self._diag_throughput_reason_decode: Optional[str] = None
# ------------------------------------------------------------------
# Public API
......@@ -170,6 +174,24 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
if tick_input.worker_counts is not None:
self._update_inventory(tick_input.worker_counts)
# Run throughput scaling first so any updated lower bound is visible
# to the load scaling pass on a combined tick. Otherwise load scaling
# reads the stale bound, potentially deciding to scale below the new
# floor set in this same tick.
#
# We always advance _next_throughput_s on a throughput tick, even if
# no traffic was available, so the planner keeps the throughput
# cadence stable rather than re-firing back-to-back ticks whenever
# traffic is temporarily absent.
throughput_decision = None
if tick.run_throughput_scaling:
if tick_input.traffic is not None:
self._observe_traffic(tick_input.traffic)
throughput_decision = self._advance_throughput(tick_input.traffic)
self._next_throughput_s = (
tick_input.now_s + self._config.throughput_adjustment_interval
)
if tick.run_load_scaling:
if tick_input.fpm_observations is not None:
if not self._is_easy:
......@@ -179,16 +201,10 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
effects.scale_to = load_decision
self._next_load_s = tick_input.now_s + self._config.load_adjustment_interval
if tick.run_throughput_scaling:
if tick_input.traffic is not None:
self._observe_traffic(tick_input.traffic)
throughput_decision = self._advance_throughput(tick_input.traffic)
if throughput_decision is not None:
if effects.scale_to is None:
# Load scaling has precedence when it produced a decision; otherwise
# fall back to the throughput-scaling decision.
if effects.scale_to is None and throughput_decision is not None:
effects.scale_to = throughput_decision
self._next_throughput_s = (
tick_input.now_s + self._config.throughput_adjustment_interval
)
effects.diagnostics = self._build_diagnostics()
effects.next_tick = self._next_scheduled_tick()
......@@ -204,6 +220,10 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
self._diag_engine_rps_decode = None
self._diag_load_reason = None
self._diag_throughput_reason = None
self._diag_load_reason_prefill = None
self._diag_load_reason_decode = None
self._diag_throughput_reason_prefill = None
self._diag_throughput_reason_decode = None
def _build_diagnostics(self) -> TickDiagnostics:
return TickDiagnostics(
......@@ -214,8 +234,14 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
predicted_osl=self._diag_predicted_osl,
engine_rps_prefill=self._diag_engine_rps_prefill,
engine_rps_decode=self._diag_engine_rps_decode,
throughput_lower_bound_prefill=self._throughput_lower_bound_p,
throughput_lower_bound_decode=self._throughput_lower_bound_d,
load_decision_reason=self._diag_load_reason,
throughput_decision_reason=self._diag_throughput_reason,
load_decision_reason_prefill=self._diag_load_reason_prefill,
load_decision_reason_decode=self._diag_load_reason_decode,
throughput_decision_reason_prefill=self._diag_throughput_reason_prefill,
throughput_decision_reason_decode=self._diag_throughput_reason_decode,
)
# ------------------------------------------------------------------
......
......@@ -29,6 +29,8 @@ class ThroughputScalingMixin:
_diag_engine_rps_prefill: Optional[float]
_diag_engine_rps_decode: Optional[float]
_diag_throughput_reason: Optional[str]
_diag_throughput_reason_prefill: Optional[str]
_diag_throughput_reason_decode: Optional[str]
def _advance_throughput(
self, traffic: TrafficObservation
......@@ -104,9 +106,24 @@ class ThroughputScalingMixin:
) -> Optional[ScalingDecision]:
num_p = self._compute_prefill_replicas(demand_rps, isl, osl)
num_d = self._compute_decode_replicas(demand_rps, isl, osl)
# _compute_* sets _diag_throughput_reason = "model_not_ready" when
# the regression isn't fit yet. If one side is not ready, the other
# side's computation was still valid but its decision is blocked,
# so we label it "partner_not_ready" to keep per-component
# diagnostics consistent with the aggregate reason.
if num_p is None or num_d is None:
self._diag_throughput_reason_prefill = (
"model_not_ready" if num_p is None else "partner_not_ready"
)
self._diag_throughput_reason_decode = (
"model_not_ready" if num_d is None else "partner_not_ready"
)
return None
reason = "set_lower_bound" if self._config.enable_load_scaling else "scale"
self._diag_throughput_reason_prefill = reason
self._diag_throughput_reason_decode = reason
if self._config.enable_load_scaling:
self._throughput_lower_bound_p = num_p
self._throughput_lower_bound_d = num_d
......@@ -140,6 +157,8 @@ class ThroughputScalingMixin:
max_num_batched_tokens=max_tokens,
ttft_sla=self._config.ttft,
itl_sla=self._config.itl,
max_kv_tokens=d_caps.max_kv_tokens if d_caps else None,
max_num_seqs=d_caps.max_num_seqs if d_caps else None,
)
if engine_rps <= 0:
logger.warning("Agg perf model not ready, skipping throughput scaling")
......@@ -171,8 +190,11 @@ class ThroughputScalingMixin:
def _compute_prefill_replicas(
self, demand_rps: float, isl: float, osl: float
) -> Optional[int]:
p_caps = self._capabilities.prefill
engine_rps, ttft_ms = self._prefill_regression.find_best_engine_prefill_rps(
ttft_sla=self._config.ttft, isl=isl
ttft_sla=self._config.ttft,
isl=isl,
max_num_batched_tokens=p_caps.max_num_batched_tokens if p_caps else None,
)
if engine_rps <= 0:
logger.warning("Prefill perf model not ready, skipping throughput scaling")
......@@ -194,10 +216,13 @@ class ThroughputScalingMixin:
def _compute_decode_replicas(
self, demand_rps: float, isl: float, osl: float
) -> Optional[int]:
d_caps = self._capabilities.decode
engine_rps, itl_ms = self._decode_regression.find_best_engine_decode_rps(
itl=self._config.itl,
context_length=isl + osl / 2,
osl=osl,
max_kv_tokens=d_caps.max_kv_tokens if d_caps else None,
max_num_seqs=d_caps.max_num_seqs if d_caps else None,
)
if engine_rps <= 0:
logger.warning("Decode perf model not ready, skipping throughput scaling")
......
......@@ -112,9 +112,20 @@ class TickDiagnostics:
engine_rps_prefill: Optional[float] = None
engine_rps_decode: Optional[float] = None
# Throughput-scaling: lower bound on replicas
throughput_lower_bound_prefill: Optional[int] = None
throughput_lower_bound_decode: Optional[int] = None
# Scaling decision reasons (set by the mixin that ran)
# Aggregate reasons (agg mode, or combined disagg).
load_decision_reason: Optional[str] = None
throughput_decision_reason: Optional[str] = None
# Per-component reasons (populated in disagg mode for separate
# prefill / decode decision timelines).
load_decision_reason_prefill: Optional[str] = None
load_decision_reason_decode: Optional[str] = None
throughput_decision_reason_prefill: Optional[str] = None
throughput_decision_reason_decode: Optional[str] = None
@dataclass
......
......@@ -66,11 +66,19 @@ class TickSnapshot:
engine_rps_decode: Optional[float] = None
load_decision_reason: Optional[str] = None
throughput_decision_reason: Optional[str] = None
load_decision_reason_prefill: Optional[str] = None
load_decision_reason_decode: Optional[str] = None
throughput_decision_reason_prefill: Optional[str] = None
throughput_decision_reason_decode: Optional[str] = None
# Per-engine FPM queue depths
prefill_engines: list[PerEngineFpm] = field(default_factory=list)
decode_engines: list[PerEngineFpm] = field(default_factory=list)
# Throughput lower bound
throughput_lower_bound_prefill: Optional[int] = None
throughput_lower_bound_decode: Optional[int] = None
# Scaling decision
scale_to_prefill: Optional[int] = None
scale_to_decode: Optional[int] = None
......@@ -88,6 +96,7 @@ class DiagnosticsRecorder:
"""
config: PlannerConfig
max_kv_tokens: Optional[int] = None
_snapshots: list[TickSnapshot] = field(default_factory=list)
_last_report_s: float = 0.0
_report_count: int = 0
......@@ -177,6 +186,12 @@ class DiagnosticsRecorder:
engine_rps_decode=diag.engine_rps_decode,
load_decision_reason=diag.load_decision_reason,
throughput_decision_reason=diag.throughput_decision_reason,
load_decision_reason_prefill=diag.load_decision_reason_prefill,
load_decision_reason_decode=diag.load_decision_reason_decode,
throughput_decision_reason_prefill=diag.throughput_decision_reason_prefill,
throughput_decision_reason_decode=diag.throughput_decision_reason_decode,
throughput_lower_bound_prefill=diag.throughput_lower_bound_prefill,
throughput_lower_bound_decode=diag.throughput_lower_bound_decode,
prefill_engines=prefill_engines,
decode_engines=decode_engines,
scale_to_prefill=(
......@@ -201,6 +216,11 @@ class DiagnosticsRecorder:
This method has no side effects (no file I/O, no snapshot clearing).
"""
# TODO: link x-axes across all subplots (e.g. ``fig.update_xaxes(
# matches="x")`` or shared_xaxes=True in make_subplots) so zooming
# into a time range on one chart also zooms the others. Currently
# a user has to zoom each subplot independently to narrow down on a
# specific time window.
ts = [s.timestamp_s for s in snaps]
labels = [
datetime.fromtimestamp(t, tz=timezone.utc).strftime("%H:%M:%S") for t in ts
......@@ -253,6 +273,32 @@ class DiagnosticsRecorder:
row=1,
col=1,
)
tp_lower_p = _vals("throughput_lower_bound_prefill")
if any(v is not None and v > 1 for v in tp_lower_p):
fig.add_trace(
go.Scatter(
x=labels,
y=tp_lower_p,
name="Prefill TP Lower Bound",
mode="lines",
line=dict(dash="dash", color="darkblue"),
),
row=1,
col=1,
)
tp_lower_d = _vals("throughput_lower_bound_decode")
if any(v is not None and v > 1 for v in tp_lower_d):
fig.add_trace(
go.Scatter(
x=labels,
y=tp_lower_d,
name="Decode TP Lower Bound",
mode="lines",
line=dict(dash="dash", color="red"),
),
row=1,
col=1,
)
# 1b. Request rate
fig.add_trace(
......@@ -272,6 +318,7 @@ class DiagnosticsRecorder:
name="Predicted RPS",
mode="lines",
line=dict(dash="dot"),
connectgaps=True,
),
row=1,
col=2,
......@@ -396,39 +443,35 @@ class DiagnosticsRecorder:
col=1,
)
# 4b. Decode engine load (queued + inflight, one line each per engine)
# 4b. Decode engine load (queued + inflight combined per engine)
for eid in sorted(decode_engine_ids):
y_queued = []
y_inflight = []
y_total = []
for s in snaps:
q, f_ = None, None
val = None
for e in s.decode_engines:
if f"{e.worker_id}:dp{e.dp_rank}" == eid:
q = e.queued_decode_kv_tokens
f_ = e.inflight_decode_kv_tokens
val = e.queued_decode_kv_tokens + e.inflight_decode_kv_tokens
break
y_queued.append(q)
y_inflight.append(f_)
y_total.append(val)
fig.add_trace(
go.Scatter(
x=labels,
y=y_queued,
name=f"D {eid} queued",
y=y_total,
name=f"D {eid} total KV",
mode="lines+markers",
showlegend=False,
),
row=4,
col=2,
)
fig.add_trace(
go.Scatter(
x=labels,
y=y_inflight,
name=f"D {eid} inflight",
mode="lines",
line=dict(dash="dot"),
showlegend=False,
),
# KV capacity line (set by adapter if available)
if self.max_kv_tokens is not None and self.max_kv_tokens > 0:
fig.add_hline(
y=self.max_kv_tokens,
line_dash="dash",
line_color="red",
annotation_text=f"KV Capacity ({self.max_kv_tokens:,})",
row=4,
col=2,
)
......@@ -442,6 +485,7 @@ class DiagnosticsRecorder:
y=_vals("engine_rps_prefill"),
name="Prefill Engine RPS",
mode="lines+markers",
connectgaps=True,
),
row=5,
col=1,
......@@ -452,6 +496,7 @@ class DiagnosticsRecorder:
y=_vals("engine_rps_decode"),
name="Decode Engine RPS",
mode="lines+markers",
connectgaps=True,
),
row=5,
col=1,
......@@ -475,6 +520,7 @@ class DiagnosticsRecorder:
name="Predicted ISL",
mode="lines",
line=dict(dash="dot"),
connectgaps=True,
),
row=5,
col=2,
......@@ -496,14 +542,25 @@ class DiagnosticsRecorder:
name="Predicted OSL",
mode="lines",
line=dict(dash="dot"),
connectgaps=True,
),
row=5,
col=2,
)
# -- Row 6: Decision timelines -----------------------------------
#
# Layout adapts to scaling mode:
# - Disagg (per-component reasons populated): two tracks per
# subplot, prefill on y=2, decode on y=1, with "prefill" /
# "decode" y-axis labels.
# - Agg / easy mode (only aggregate reason populated): single
# track at y=1 with "Load Decision" / "Throughput Decision"
# labels.
# We detect the mode by whether any snapshot has a per-component
# reason set; switching mode mid-run would produce a mixed chart,
# but that doesn't happen because mode is fixed at planner init.
load_reasons = _vals("load_decision_reason")
_LOAD_COLORS = {
"scale_up": "green",
"scale_down": "blue",
......@@ -515,25 +572,6 @@ class DiagnosticsRecorder:
"worker_count_mismatch": "red",
"insufficient_data": "pink",
}
fig.add_trace(
go.Scatter(
x=labels,
y=[1] * len(labels),
mode="markers",
marker=dict(
color=[_LOAD_COLORS.get(r or "", "gray") for r in load_reasons],
size=10,
symbol="square",
),
text=load_reasons,
name="Load Decision",
hoverinfo="text+x",
),
row=6,
col=1,
)
tp_reasons = _vals("throughput_decision_reason")
_TP_COLORS = {
"scale": "green",
"set_lower_bound": "blue",
......@@ -541,32 +579,148 @@ class DiagnosticsRecorder:
"no_traffic_data": "yellow",
"predict_failed": "red",
"model_not_ready": "orange",
"partner_not_ready": "pink",
}
# Detect disagg mode: if any per-component reason is populated,
# plot two horizontal tracks (prefill at y=2, decode at y=1);
# otherwise plot a single aggregate track at y=1.
has_per_component_load = any(
s.load_decision_reason_prefill is not None
or s.load_decision_reason_decode is not None
for s in snaps
)
has_per_component_tp = any(
s.throughput_decision_reason_prefill is not None
or s.throughput_decision_reason_decode is not None
for s in snaps
)
def _add_decision_track(
field_name: str,
y_value: int,
label: str,
colors: dict,
symbol: str,
row: int,
col: int,
) -> None:
reasons = _vals(field_name)
fig.add_trace(
go.Scatter(
x=labels,
y=[1] * len(labels),
y=[y_value] * len(labels),
mode="markers",
marker=dict(
color=[_TP_COLORS.get(r or "", "gray") for r in tp_reasons],
color=[colors.get(r or "", "gray") for r in reasons],
size=10,
symbol="diamond",
symbol=symbol,
),
text=tp_reasons,
name="Throughput Decision",
text=reasons,
name=label,
hoverinfo="text+x",
showlegend=False,
),
row=row,
col=col,
)
if has_per_component_load:
_add_decision_track(
"load_decision_reason_prefill",
2,
"Load (prefill)",
_LOAD_COLORS,
"square",
6,
1,
)
_add_decision_track(
"load_decision_reason_decode",
1,
"Load (decode)",
_LOAD_COLORS,
"square",
6,
1,
)
fig.update_yaxes(
tickmode="array",
tickvals=[1, 2],
ticktext=["decode", "prefill"],
range=[0.5, 2.5],
row=6,
col=1,
)
else:
_add_decision_track(
"load_decision_reason",
1,
"Load Decision",
_LOAD_COLORS,
"square",
6,
1,
)
if has_per_component_tp:
_add_decision_track(
"throughput_decision_reason_prefill",
2,
"TP (prefill)",
_TP_COLORS,
"diamond",
6,
2,
)
_add_decision_track(
"throughput_decision_reason_decode",
1,
"TP (decode)",
_TP_COLORS,
"diamond",
6,
2,
)
fig.update_yaxes(
tickmode="array",
tickvals=[1, 2],
ticktext=["decode", "prefill"],
range=[0.5, 2.5],
row=6,
col=2,
)
else:
_add_decision_track(
"throughput_decision_reason",
1,
"Throughput Decision",
_TP_COLORS,
"diamond",
6,
2,
)
# -- Layout -------------------------------------------------------
num_scaling_events = sum(
1
for s in snaps
if s.scale_to_prefill is not None or s.scale_to_decode is not None
)
# Count actual replica transitions, not just ticks where a decision
# was recorded: two consecutive ticks with scale_to=5 aren't two
# scaling events.
num_scaling_events = 0
prev_p: Optional[int] = None
prev_d: Optional[int] = None
for s in snaps:
cur_p = s.num_prefill_replicas
cur_d = s.num_decode_replicas
if prev_p is not None and cur_p is not None and cur_p != prev_p:
num_scaling_events += 1
if prev_d is not None and cur_d is not None and cur_d != prev_d:
num_scaling_events += 1
if cur_p is not None:
prev_p = cur_p
if cur_d is not None:
prev_d = cur_d
t0 = datetime.fromtimestamp(ts[0], tz=timezone.utc).strftime(
"%Y-%m-%d %H:%M:%S UTC"
)
......@@ -576,7 +730,7 @@ class DiagnosticsRecorder:
summary = (
f"<b>Planner Diagnostics Report</b><br>"
f"Time range: {t0}{t1} ({len(snaps)} ticks)<br>"
f"Scaling events: {num_scaling_events} | "
f"Replica transitions: {num_scaling_events} | "
f"GPU hours: {snaps[-1].gpu_hours:.2f}<br>"
f"SLA targets: TTFT={self.config.ttft:.0f}ms, ITL={self.config.itl:.0f}ms"
)
......
......@@ -137,7 +137,14 @@ class ReplayPlannerAdapter:
self._scaling_target_decode: Optional[int] = None
# Diagnostics recorder for HTML report generation
self._recorder = DiagnosticsRecorder(config=planner_config)
decode_max_kv = (
capabilities.decode.max_kv_tokens
if capabilities and capabilities.decode
else None
)
self._recorder = DiagnosticsRecorder(
config=planner_config, max_kv_tokens=decode_max_kv
)
self._cumulative_gpu_hours: float = 0.0
self._last_tick_s: float = 0.0
self._last_traffic: Metrics = Metrics()
......@@ -218,15 +225,6 @@ class ReplayPlannerAdapter:
self._cumulative_gpu_hours += (num_p * gpu_p + num_d * gpu_d) * dt_h
self._last_tick_s = now_s
# Build observed Metrics from traffic in tick_input
if tick_input.traffic is not None:
t = tick_input.traffic
self._last_traffic = Metrics(
num_req=t.num_req,
isl=t.isl,
osl=t.osl,
)
self._recorder.record(
tick_input,
effects,
......@@ -422,12 +420,24 @@ class ReplayPlannerAdapter:
t = self._bridge.drain_traffic()
duration_s = t.get("duration_s", 0.0)
if duration_s > 0:
num_req = float(t.get("num_req", 0))
traffic = TrafficObservation(
duration_s=duration_s,
num_req=float(t.get("num_req", 0)),
num_req=num_req,
isl=t.get("avg_isl", 0.0),
osl=t.get("avg_osl", 0.0),
)
# Stash observed TTFT/ITL for the diagnostics recorder.
# When num_req == 0, the Rust accumulator returns 0 as a
# placeholder; only record latency values when we actually
# observed requests in this window.
self._last_traffic = Metrics(
ttft=t.get("avg_ttft_ms") if num_req > 0 else None,
itl=t.get("avg_itl_ms") if num_req > 0 else None,
num_req=traffic.num_req,
isl=traffic.isl,
osl=traffic.osl,
)
return TickInput(
now_s=now_s,
......
......@@ -1345,7 +1345,18 @@ impl PlannerReplayBridge {
/// Drain accumulated traffic metrics since the last drain.
///
/// Returns a dict with `duration_s`, `num_req`, `avg_isl`, `avg_osl`.
/// Returns a dict with:
/// - `duration_s` (f64): window length in seconds
/// - `num_req` (usize): completed requests in the window
/// - `avg_isl` (f64): mean input sequence length (tokens)
/// - `avg_osl` (f64): mean output sequence length (tokens)
/// - `avg_ttft_ms` (f64): mean time-to-first-token in milliseconds,
/// averaged only over requests that reported
/// a TTFT sample (0.0 when no samples)
/// - `avg_itl_ms` (f64): mean inter-token latency in milliseconds,
/// averaged only over requests that generated
/// at least one token gap (0.0 when no samples)
///
/// Call this only on throughput-scaling ticks so the observation window
/// covers the full `throughput_adjustment_interval`.
fn drain_traffic(&mut self, py: Python<'_>) -> PyResult<PyObject> {
......@@ -1354,13 +1365,15 @@ impl PlannerReplayBridge {
.as_mut()
.ok_or_else(|| PyException::new_err("bridge has been finalized"))?;
let (duration_s, num_req, avg_isl, avg_osl) = handle.drain_traffic();
let stats = handle.drain_traffic();
let result = json!({
"duration_s": duration_s,
"num_req": num_req,
"avg_isl": avg_isl,
"avg_osl": avg_osl,
"duration_s": stats.duration_s,
"num_req": stats.num_req,
"avg_isl": stats.avg_isl,
"avg_osl": stats.avg_osl,
"avg_ttft_ms": stats.avg_ttft_ms,
"avg_itl_ms": stats.avg_itl_ms,
});
pythonize(py, &result)
......
......@@ -18,6 +18,10 @@ if TYPE_CHECKING:
os.environ.setdefault("DYNAMO_SKIP_PYTHON_LOG_INIT", "1")
from dynamo.common.forward_pass_metrics import (
ForwardPassMetrics,
ScheduledRequestMetrics,
)
from dynamo.llm import AicPerfConfig, KvRouterConfig, MockEngineArgs
from dynamo.replay import run_synthetic_trace_replay, run_trace_replay
from dynamo.replay.reporting import format_report_table, write_report_json
......@@ -123,6 +127,76 @@ def _engine_caps(args: MockEngineArgs) -> EngineCapabilities:
)
def _generate_aic_prefill_fpms(
aic_session,
engine_args: MockEngineArgs,
granularity: int = 8,
) -> list[ForwardPassMetrics]:
"""Generate prefill benchmark FPMs using AIC predictions.
Sweeps ISL at batch_size=1 within the engine's per-pass token budget
(``max_num_batched_tokens``); a single forward pass physically cannot
process more than that, so the regression shouldn't see larger sums.
For longer ISL, callers use chunked TTFT estimation.
"""
prefill_max = engine_args.max_num_batched_tokens or 8192
prefill_step = max(1, (prefill_max - 100) // granularity)
prefill_fpms: list[ForwardPassMetrics] = []
for isl in range(100, prefill_max + 1, prefill_step):
ttft_ms = aic_session.predict_prefill(1, isl, 0)
if ttft_ms > 0:
prefill_fpms.append(
ForwardPassMetrics(
wall_time=ttft_ms / 1000.0,
scheduled_requests=ScheduledRequestMetrics(
num_prefill_requests=1,
sum_prefill_tokens=isl,
),
)
)
return prefill_fpms
def _generate_aic_decode_fpms(
aic_session,
engine_args: MockEngineArgs,
granularity: int = 8,
) -> list[ForwardPassMetrics]:
"""Generate decode benchmark FPMs using AIC predictions.
Sweeps (batch_size x context_length). ``granularity`` controls the
number of sample points per axis; the batch-size ceiling comes from
the engine's ``max_num_seqs`` so the regression sees realistic
concurrency, not an artificial cap at the sweep density.
"""
max_kv_tokens = engine_args.num_gpu_blocks * engine_args.block_size
if max_kv_tokens <= 0:
max_kv_tokens = 16384 * 16
decode_fpms: list[ForwardPassMetrics] = []
ctx_lengths = [500, 2000, 4000, 8000]
bs_max = engine_args.max_num_seqs or 256
bs_step = max(1, bs_max // granularity)
for ctx_len in ctx_lengths:
for bs in range(1, bs_max + 1, bs_step):
sum_kv = bs * ctx_len
if sum_kv > max_kv_tokens:
break
itl_ms = aic_session.predict_decode(bs, ctx_len, 2)
if itl_ms > 0:
decode_fpms.append(
ForwardPassMetrics(
wall_time=itl_ms / 1000.0,
scheduled_requests=ScheduledRequestMetrics(
num_decode_requests=bs,
sum_decode_kv_tokens=sum_kv,
),
)
)
return decode_fpms
def _run_planner_replay(
trace_file: str,
extra_engine_args: MockEngineArgs | None,
......@@ -136,6 +210,7 @@ def _run_planner_replay(
arrival_speedup_ratio: float,
trace_block_size: int,
planner_config_arg: str,
benchmark_granularity: int = 8,
):
"""Run an offline replay with planner-in-the-loop (agg or disagg).
......@@ -200,6 +275,107 @@ def _run_planner_replay(
bridge=bridge,
capabilities=capabilities,
)
# Bootstrap regression models from mocker's perf model.
# AIC provides accurate batch-size-aware timing that works with the
# planner's linear regression. The default polynomial model cannot
# feed the throughput regression (its decode formula is quadratic in
# utilization ratio, causing negative regression coefficients).
if not adapter._sm._is_easy:
ref_args = extra_engine_args or prefill_engine_args or MockEngineArgs()
aic_backend = ref_args.aic_backend
if (
aic_backend is None
or ref_args.aic_system is None
or ref_args.aic_model_path is None
):
sys.stderr.write(
"Note: throughput-based scaling regression requires AIC perf model "
"(set aic_backend/aic_system/aic_model_path in --extra-engine-args). "
"Falling back to load-based scaling only.\n"
)
else:
# Create AIC session -- narrow to the concrete exception types
# AIC/PyO3 can raise so we degrade gracefully on missing
# dependencies or bad config, but don't swallow unrelated bugs
# (AttributeError, KeyboardInterrupt, etc.) introduced by
# refactors.
try:
from dynamo._internal.aic import create_session
aic_session = create_session(
backend_name=aic_backend,
system=ref_args.aic_system,
model_path=ref_args.aic_model_path,
tp_size=ref_args.aic_tp_size or 1,
backend_version=ref_args.aic_backend_version,
)
except (
ImportError,
RuntimeError,
ValueError,
KeyError,
FileNotFoundError,
) as e:
sys.stderr.write(
f"Warning: AIC session creation failed ({e}); "
"throughput regression will not be bootstrapped.\n"
)
aic_session = None
# Generate benchmark FPMs and load into regression. Disagg
# prefill and decode engines typically have different
# max_num_seqs and KV cache sizes, so each sweep uses its own
# engine args. Agg has a single engine so uses one set. AIC's
# predict_* can raise on unsupported model/system combos or
# numerical edge cases; log and fall back in those cases.
if aic_session is not None:
p_args = (
extra_engine_args
if planner_config.mode == "agg"
else prefill_engine_args
) or ref_args
d_args = (
extra_engine_args
if planner_config.mode == "agg"
else decode_engine_args
) or ref_args
try:
prefill_fpms = _generate_aic_prefill_fpms(
aic_session, p_args, benchmark_granularity
)
decode_fpms = _generate_aic_decode_fpms(
aic_session, d_args, benchmark_granularity
)
except (RuntimeError, ValueError, KeyError, ArithmeticError) as e:
sys.stderr.write(
f"Warning: AIC benchmark generation failed ({e}); "
"throughput regression will not be bootstrapped.\n"
)
prefill_fpms, decode_fpms = [], []
if planner_config.mode == "agg":
# Agg regression fits on (sum_prefill_tokens, sum_decode_kv_tokens);
# combine prefill-only and decode-only points so both features
# have variance.
agg_fpms = prefill_fpms + decode_fpms
if agg_fpms:
adapter._sm.load_benchmark_fpms(agg_fpms=agg_fpms)
else:
sys.stderr.write(
"Warning: AIC produced no agg benchmark FPMs\n"
)
else:
if prefill_fpms and decode_fpms:
adapter._sm.load_benchmark_fpms(
prefill_fpms=prefill_fpms, decode_fpms=decode_fpms
)
else:
sys.stderr.write(
f"Warning: AIC produced empty benchmark FPMs "
f"(prefill={len(prefill_fpms)}, decode={len(decode_fpms)})\n"
)
return adapter.run()
......@@ -256,6 +432,12 @@ def main(argv: Sequence[str] | None = None) -> int:
"--planner-config",
help="path to planner config YAML/JSON or inline JSON; enables planner-in-the-loop replay (offline agg only)",
)
parser.add_argument(
"--benchmark-granularity",
type=int,
default=8,
help="number of sweep points for synthetic perf model benchmark (default: 8, matching profiler)",
)
args = parser.parse_args(list(sys.argv[1:] if argv is None else argv))
using_trace_file = args.trace_file is not None
......@@ -311,6 +493,7 @@ def main(argv: Sequence[str] | None = None) -> int:
arrival_speedup_ratio=args.arrival_speedup_ratio,
trace_block_size=args.trace_block_size,
planner_config_arg=args.planner_config,
benchmark_granularity=args.benchmark_granularity,
)
report = planner_report.trace_report
if planner_report.scaling_events:
......
......@@ -256,6 +256,15 @@ impl TraceCollector {
}
}
/// Return (ttft_ms, mean_itl_ms) for a completed request, if available.
pub(crate) fn request_latencies(&self, uuid: Uuid) -> Option<(f64, f64)> {
let stats = self.requests.get(&uuid)?;
let first_token_ms = stats.first_token_ms()?;
let ttft_ms = (first_token_ms - stats.arrival_time_ms).max(0.0);
let mean_itl_ms = stats.mean_tpot_ms().unwrap_or(0.0);
Some((ttft_ms, mean_itl_ms))
}
pub(crate) fn finish(self) -> TraceSimulationReport {
let requests = self.requests;
let request_count = requests.len();
......
......@@ -17,7 +17,7 @@ use super::state::OfflineWorkerSnapshot;
use super::{
components::{
AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter,
ScheduledWorkerCompletion, TrafficAccumulator, WorkerAdmission,
ScheduledWorkerCompletion, TrafficAccumulator, TrafficStats, WorkerAdmission,
},
state::AggRequestState,
};
......@@ -381,8 +381,12 @@ impl AggRuntime {
let removed_state = self.requests.remove(&signal.uuid).ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?;
self.traffic
.on_request(removed_state.input_tokens, removed_state.output_tokens);
let latencies = self.collector.request_latencies(signal.uuid);
self.traffic.on_request(
removed_state.input_tokens,
removed_state.output_tokens,
latencies,
);
self.admission
.on_request_completed(signal.uuid, self.now_ms)?;
self.progress.inc_completed();
......@@ -603,7 +607,7 @@ impl AggRuntime {
}
/// Drain accumulated traffic stats since the last drain.
pub(in crate::replay) fn drain_traffic(&mut self) -> (f64, usize, f64, f64) {
pub(in crate::replay) fn drain_traffic(&mut self) -> TrafficStats {
self.traffic.drain(self.now_ms)
}
......
......@@ -12,6 +12,7 @@ pub(crate) use router::OfflineReplayRouter;
#[cfg(test)]
pub(crate) use router::OfflineRouterSnapshot;
pub(in crate::replay) use types::ReplayMode;
pub use types::TrafficStats;
pub(in crate::replay::offline) use types::{
EngineEffects, EnginePassMode, ReadyArrival, ScheduledWorkerCompletion, TrafficAccumulator,
};
......
......@@ -65,14 +65,41 @@ pub(in crate::replay::offline) struct ReadyArrival {
pub(in crate::replay::offline) replay_hashes: Option<ReplayRequestHashes>,
}
/// Accumulated traffic statistics returned by [`TrafficAccumulator::drain`].
///
/// IMPORTANT: When fields here are added or renamed, update the PyO3
/// binding in ``lib/bindings/python/rust/llm/replay.rs`` (drain_traffic
/// method) so the exported JSON dict matches. The Python adapter in
/// ``replay_adapter.py`` reads these keys by name.
#[derive(Debug, Clone)]
pub struct TrafficStats {
pub duration_s: f64,
pub num_req: usize,
pub avg_isl: f64,
pub avg_osl: f64,
pub avg_ttft_ms: f64,
pub avg_itl_ms: f64,
}
/// Accumulates traffic statistics between planner ticks for deriving
/// `TrafficObservation` (num_req, avg ISL, avg OSL over a window).
///
/// Latency samples are tracked independently of request counts: a request
/// only contributes to ``total_ttft_ms`` / ``ttft_count`` if a positive TTFT
/// was recorded, and similarly for ITL. This means ``avg_ttft_ms`` and
/// ``avg_itl_ms`` reflect only requests that actually produced the sample,
/// rather than silently underestimating when some requests lack latency
/// data (e.g. requests that fail before emitting a token).
#[derive(Debug)]
pub(in crate::replay::offline) struct TrafficAccumulator {
window_start_ms: f64,
num_req: usize,
total_isl: usize,
total_osl: usize,
total_ttft_ms: f64,
total_itl_ms: f64,
ttft_count: usize,
itl_count: usize,
}
impl TrafficAccumulator {
......@@ -82,23 +109,37 @@ impl TrafficAccumulator {
num_req: 0,
total_isl: 0,
total_osl: 0,
total_ttft_ms: 0.0,
total_itl_ms: 0.0,
ttft_count: 0,
itl_count: 0,
}
}
/// Record one admitted request.
/// Record one completed request with optional latency data.
pub(in crate::replay::offline) fn on_request(
&mut self,
input_tokens: usize,
output_tokens: usize,
latencies: Option<(f64, f64)>,
) {
self.num_req += 1;
self.total_isl += input_tokens;
self.total_osl += output_tokens;
if let Some((ttft_ms, mean_itl_ms)) = latencies {
if ttft_ms > 0.0 {
self.total_ttft_ms += ttft_ms;
self.ttft_count += 1;
}
if mean_itl_ms > 0.0 {
self.total_itl_ms += mean_itl_ms;
self.itl_count += 1;
}
}
}
/// Drain the accumulator at the given simulated time, returning
/// (duration_s, num_req, avg_isl, avg_osl) and resetting counters.
pub(in crate::replay::offline) fn drain(&mut self, now_ms: f64) -> (f64, usize, f64, f64) {
/// Drain the accumulator at the given simulated time, resetting counters.
pub(in crate::replay::offline) fn drain(&mut self, now_ms: f64) -> TrafficStats {
let duration_s = (now_ms - self.window_start_ms) / 1000.0;
let num_req = self.num_req;
let avg_isl = if num_req > 0 {
......@@ -111,10 +152,31 @@ impl TrafficAccumulator {
} else {
0.0
};
let avg_ttft_ms = if self.ttft_count > 0 {
self.total_ttft_ms / self.ttft_count as f64
} else {
0.0
};
let avg_itl_ms = if self.itl_count > 0 {
self.total_itl_ms / self.itl_count as f64
} else {
0.0
};
self.window_start_ms = now_ms;
self.num_req = 0;
self.total_isl = 0;
self.total_osl = 0;
(duration_s, num_req, avg_isl, avg_osl)
self.total_ttft_ms = 0.0;
self.total_itl_ms = 0.0;
self.ttft_count = 0;
self.itl_count = 0;
TrafficStats {
duration_s,
num_req,
avg_isl,
avg_osl,
avg_ttft_ms,
avg_itl_ms,
}
}
}
......@@ -11,7 +11,7 @@ use uuid::Uuid;
pub(super) use super::components::ReplayMode;
use super::components::{
AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter,
ScheduledWorkerCompletion, TrafficAccumulator, WorkerAdmission,
ScheduledWorkerCompletion, TrafficAccumulator, TrafficStats, WorkerAdmission,
};
use super::events::{SimulationEvent, SimulationWorkerStage};
use super::progress::ReplayProgress;
......@@ -519,7 +519,9 @@ impl DisaggRuntime {
let original = state.original_request()?;
let input_tokens = original.tokens.len();
let output_tokens = original.max_output_tokens;
self.traffic.on_request(input_tokens, output_tokens);
let latencies = self.collector.request_latencies(signal.uuid);
self.traffic
.on_request(input_tokens, output_tokens, latencies);
self.state_mut(signal.uuid)?.mark_done();
#[cfg(test)]
{
......@@ -831,7 +833,7 @@ impl DisaggRuntime {
}
/// Drain accumulated traffic stats since the last drain.
pub(in crate::replay) fn drain_traffic(&mut self) -> (f64, usize, f64, f64) {
pub(in crate::replay) fn drain_traffic(&mut self) -> TrafficStats {
self.traffic.drain(self.now_ms)
}
......
......@@ -15,7 +15,7 @@ use anyhow::Result;
use dynamo_kv_router::config::KvRouterConfig;
use super::offline::agg::AggRuntime;
use super::offline::components::ReplayMode;
use super::offline::components::{ReplayMode, TrafficStats};
use super::offline::disagg::DisaggRuntime;
use super::{
OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, TraceSimulationReport,
......@@ -162,10 +162,9 @@ impl PlannerReplayHandle {
/// Drain accumulated traffic metrics since the last drain.
///
/// Returns `(duration_s, num_req, avg_isl, avg_osl)`. Call this only on
/// throughput-scaling ticks so the window covers the full
/// Call this only on throughput-scaling ticks so the window covers the full
/// `throughput_adjustment_interval`, not just the gap between load ticks.
pub fn drain_traffic(&mut self) -> (f64, usize, f64, f64) {
pub fn drain_traffic(&mut self) -> TrafficStats {
match &mut self.runtime {
RuntimeKind::Agg(rt) => rt.drain_traffic(),
RuntimeKind::Disagg(rt) => rt.drain_traffic(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment