Unverified Commit 0c7068dd authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

fix(replay/planner): planner-in-the-loop replay diagnostics + planner scaling logics fixes (#8280)


Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 4e7c6afd
...@@ -41,6 +41,8 @@ class LoadScalingMixin: ...@@ -41,6 +41,8 @@ class LoadScalingMixin:
_diag_estimated_ttft_ms: Optional[float] _diag_estimated_ttft_ms: Optional[float]
_diag_estimated_itl_ms: Optional[float] _diag_estimated_itl_ms: Optional[float]
_diag_load_reason: Optional[str] _diag_load_reason: Optional[str]
_diag_load_reason_prefill: Optional[str]
_diag_load_reason_decode: Optional[str]
def _advance_load(self, obs: FpmObservations) -> Optional[ScalingDecision]: def _advance_load(self, obs: FpmObservations) -> Optional[ScalingDecision]:
if not self._config.enable_load_scaling: if not self._config.enable_load_scaling:
...@@ -122,16 +124,22 @@ class LoadScalingMixin: ...@@ -122,16 +124,22 @@ class LoadScalingMixin:
if not p_stats and not d_stats: if not p_stats and not d_stats:
logger.warning("No FPM data for either prefill or decode, skipping") logger.warning("No FPM data for either prefill or decode, skipping")
self._diag_load_reason = "no_fpm_data" self._diag_load_reason = "no_fpm_data"
self._diag_load_reason_prefill = "no_fpm_data"
self._diag_load_reason_decode = "no_fpm_data"
return None return None
if p_stats and not self._reconcile_fpm_worker_count( if p_stats and not self._reconcile_fpm_worker_count(
p_stats, self._num_p_workers, "prefill" p_stats, self._num_p_workers, "prefill"
): ):
self._diag_load_reason = "worker_count_mismatch" self._diag_load_reason = "worker_count_mismatch"
self._diag_load_reason_prefill = "worker_count_mismatch"
self._diag_load_reason_decode = "worker_count_mismatch"
return None return None
if d_stats and not self._reconcile_fpm_worker_count( if d_stats and not self._reconcile_fpm_worker_count(
d_stats, self._num_d_workers, "decode" d_stats, self._num_d_workers, "decode"
): ):
self._diag_load_reason = "worker_count_mismatch" self._diag_load_reason = "worker_count_mismatch"
self._diag_load_reason_prefill = "worker_count_mismatch"
self._diag_load_reason_decode = "worker_count_mismatch"
return None return None
easy = self._config.optimization_target != "sla" easy = self._config.optimization_target != "sla"
...@@ -157,30 +165,61 @@ class LoadScalingMixin: ...@@ -157,30 +165,61 @@ class LoadScalingMixin:
final_p = p_desired if p_desired is not None else self._num_p_workers final_p = p_desired if p_desired is not None else self._num_p_workers
final_d = d_desired if d_desired is not None else self._num_d_workers final_d = d_desired if d_desired is not None else self._num_d_workers
if final_p == self._num_p_workers and final_d == self._num_d_workers: # Enforce bounds first so "no change" comparison is against the
logger.info("Load-based scaling: no scaling needed") # post-floor target, not the raw load decision. Otherwise a load
self._diag_load_reason = "no_change" # decision of "no change" would skip the floor and let replicas
return None # stay below a throughput-scaling lower bound that was raised on
# a previous (or same) tick.
original_p, original_d = final_p, final_d original_p, original_d = final_p, final_d
# Apply throughput floor first and track the post-floor value so we
# can attribute later lifts to their real source -- throughput
# capping is a distinct diagnostic from min_endpoint / global-budget
# lifts, which should not be labelled "scale_down_capped_by_throughput".
if self._config.enable_throughput_scaling: if self._config.enable_throughput_scaling:
final_p = max(final_p, self._throughput_lower_bound_p) final_p = max(final_p, self._throughput_lower_bound_p)
final_d = max(final_d, self._throughput_lower_bound_d) final_d = max(final_d, self._throughput_lower_bound_d)
post_floor_p, post_floor_d = final_p, final_d
final_p = max(final_p, self._config.min_endpoint) final_p = max(final_p, self._config.min_endpoint)
final_d = max(final_d, self._config.min_endpoint) final_d = max(final_d, self._config.min_endpoint)
final_p, final_d = self._apply_global_budget(final_p, final_d) final_p, final_d = self._apply_global_budget(final_p, final_d)
if (final_p > original_p or final_d > original_d) and ( # Per-component reasons
original_p < self._num_p_workers or original_d < self._num_d_workers def _reason(final: int, original: int, post_floor: int, current: int) -> str:
): # Only classify as throughput-capped when the throughput floor
self._diag_load_reason = "scale_down_capped_by_throughput" # itself lifted the load decision; later min_endpoint / budget
elif final_p > self._num_p_workers or final_d > self._num_d_workers: # adjustments don't count.
self._diag_load_reason = "scale_up" floor_capped = post_floor > original and original < current
elif final_p < self._num_p_workers or final_d < self._num_d_workers: if final > current:
self._diag_load_reason = "scale_down" return "scale_up"
else: if final < current:
self._diag_load_reason = "no_change" return (
"scale_down_capped_by_throughput" if floor_capped else "scale_down"
)
return "scale_down_capped_by_throughput" if floor_capped else "no_change"
self._diag_load_reason_prefill = _reason(
final_p, original_p, post_floor_p, self._num_p_workers
)
self._diag_load_reason_decode = _reason(
final_d, original_d, post_floor_d, self._num_d_workers
)
# Aggregate reason: prioritise "most interesting" across components.
_PRIORITY = {
"scale_up": 4,
"scale_down_capped_by_throughput": 3,
"scale_down": 2,
"no_change": 1,
}
self._diag_load_reason = max(
(self._diag_load_reason_prefill, self._diag_load_reason_decode),
key=lambda r: _PRIORITY.get(r or "", 0),
)
if final_p == self._num_p_workers and final_d == self._num_d_workers:
logger.info("Load-based scaling: no scaling needed")
return None
logger.info( logger.info(
f"Load-based disagg scaling: prefill {self._num_p_workers}->{final_p}, " f"Load-based disagg scaling: prefill {self._num_p_workers}->{final_p}, "
...@@ -270,9 +309,10 @@ class LoadScalingMixin: ...@@ -270,9 +309,10 @@ class LoadScalingMixin:
): ):
desired = max(p_desired, d_desired) desired = max(p_desired, d_desired)
else: else:
logger.info("Agg scaling: no scaling needed") # Load scaling sees "no change" -- but the throughput floor may
self._diag_load_reason = "no_change" # still require scaling up, so keep processing rather than
return None # returning early.
desired = num_workers
original_desired = desired original_desired = desired
desired = max(desired, self._config.min_endpoint) desired = max(desired, self._config.min_endpoint)
...@@ -280,15 +320,23 @@ class LoadScalingMixin: ...@@ -280,15 +320,23 @@ class LoadScalingMixin:
desired = max(desired, self._throughput_lower_bound_d) desired = max(desired, self._throughput_lower_bound_d)
desired = self._apply_single_budget(desired, "decode") desired = self._apply_single_budget(desired, "decode")
# Preserve "load wanted to scale down but floor lifted it" as a
# distinct diagnostic reason even when the net result is no change.
floor_capped = desired > original_desired and original_desired < num_workers
if desired == num_workers:
logger.info("Agg scaling: no scaling needed")
self._diag_load_reason = (
"scale_down_capped_by_throughput" if floor_capped else "no_change"
)
return None
if desired < num_workers: if desired < num_workers:
if desired > original_desired: self._diag_load_reason = (
self._diag_load_reason = "scale_down_capped_by_throughput" "scale_down_capped_by_throughput" if floor_capped else "scale_down"
else: )
self._diag_load_reason = "scale_down" else: # desired > num_workers (equality returned above)
elif desired > num_workers:
self._diag_load_reason = "scale_up" self._diag_load_reason = "scale_up"
else:
self._diag_load_reason = "no_change"
logger.info(f"Agg load-based scaling: {num_workers} -> {desired}") logger.info(f"Agg load-based scaling: {num_workers} -> {desired}")
return ScalingDecision(num_decode=desired) return ScalingDecision(num_decode=desired)
......
...@@ -119,6 +119,8 @@ class AggRegressionModel(_BaseRegressionModel): ...@@ -119,6 +119,8 @@ class AggRegressionModel(_BaseRegressionModel):
max_num_batched_tokens: int, max_num_batched_tokens: int,
ttft_sla: float, ttft_sla: float,
itl_sla: float, itl_sla: float,
max_kv_tokens: Optional[int] = None,
max_num_seqs: Optional[int] = None,
) -> tuple[float, float, float]: ) -> tuple[float, float, float]:
"""Find the maximum agg engine request rate under both SLA targets. """Find the maximum agg engine request rate under both SLA targets.
...@@ -129,19 +131,28 @@ class AggRegressionModel(_BaseRegressionModel): ...@@ -129,19 +131,28 @@ class AggRegressionModel(_BaseRegressionModel):
Request rate is derived via Little's law: Request rate is derived via Little's law:
``engine_rps = best_batch_size / (osl * wall_time_per_iter)``. ``engine_rps = best_batch_size / (osl * wall_time_per_iter)``.
Args: The upper bound for the batch-size sweep is the smallest of:
isl: average input sequence length (tokens). 1. KV cache capacity: ``max_kv_tokens / (isl + osl/2)``
osl: average output sequence length (tokens). 2. ``max_num_seqs`` (engine concurrency limit)
max_num_batched_tokens: per-iteration token budget. 3. The prefill/decode rate-balance point (steady state). For a
ttft_sla: TTFT target in milliseconds. batch of size ``x``:
itl_sla: ITL target in milliseconds. - Decode egress rate: ``x / osl`` requests finish per iter
(x concurrent streams, each taking osl decode iters).
Returns: - Prefill admission rate: ``(max_num_batched_tokens - x) / isl``
(engine_rps, actual_ttft_ms, actual_itl_ms) -- 0 rps requests admitted per iter (the budget left after decode
signals an error (model not fitted or invalid input); takes one slot per in-flight request, divided by isl tokens
positive rps is the best achievable rate with the per new request).
predicted TTFT/ITL. If SLAs are violated, a warning Steady state requires admission >= egress:
is logged but the rate is still returned. ``(max_num_batched_tokens - x) / isl >= x / osl``,
which simplifies to
``isl / (max_num_batched_tokens - x) <= osl``
(the check implemented below), or equivalently
``x <= osl * max_num_batched_tokens / (isl + osl)``.
Above this, prefill becomes the bottleneck and TTFT grows
unbounded.
The caller guarantees ``osl > 0`` and ``max_num_batched_tokens > 0``
via the early-return validation above.
""" """
if ( if (
not self._ensure_fitted() not self._ensure_fitted()
...@@ -152,7 +163,34 @@ class AggRegressionModel(_BaseRegressionModel): ...@@ -152,7 +163,34 @@ class AggRegressionModel(_BaseRegressionModel):
return (0.0, 0.0, 0.0) return (0.0, 0.0, 0.0)
avg_ctx = isl + osl / 2.0 avg_ctx = isl + osl / 2.0
max_bs = max(1, int(max_num_batched_tokens / max(1, avg_ctx))) * 2
# KV cache cap
kv_cap = (
max(1, int(max_kv_tokens / max(1.0, avg_ctx)))
if max_kv_tokens and max_kv_tokens > 0
else 1024 # large fallback when capability not known
)
# Concurrency cap
seq_cap = max_num_seqs if max_num_seqs and max_num_seqs > 0 else kv_cap
# Prefill/decode balance cap via binary search within [1, min(kv_cap, seq_cap)]
# For each candidate x, check: isl / (max_num_batched_tokens - x) <= osl
hard_cap = min(kv_cap, seq_cap, max_num_batched_tokens - 1)
def _prefill_balanced(x: int) -> bool:
prefill_budget = max_num_batched_tokens - x
if prefill_budget <= 0:
return False
return isl / prefill_budget <= osl
lo, hi = 1, max(1, hard_cap)
while lo < hi:
mid = (lo + hi + 1) // 2
if _prefill_balanced(mid):
lo = mid
else:
hi = mid - 1
max_bs = lo
best_rps = 0.0 best_rps = 0.0
best_ttft_ms = 0.0 best_ttft_ms = 0.0
......
...@@ -186,6 +186,28 @@ class _BaseRegressionModel: ...@@ -186,6 +186,28 @@ class _BaseRegressionModel:
def _gather_observations(self) -> list[tuple[Union[float, list[float]], float]]: def _gather_observations(self) -> list[tuple[Union[float, list[float]], float]]:
return [obs for bucket in self._buckets.values() for obs in bucket] return [obs for bucket in self._buckets.values() for obs in bucket]
# Feature indices whose coefficients are allowed to go slightly negative
# under noisy fits without rejecting the whole model. Used when a feature
# has weak signal and can flip sign under multicollinearity without
# violating physical monotonicity (e.g. num_decode_requests vs
# sum_decode_kv_tokens: total KV dominates wall time, so a small negative
# batch-size coefficient is numerical noise, not "more requests → less
# work"). Subclasses override via ``_relaxable_feature_indices``.
#
# Most features should stay non-negative: both AggRegressionModel and
# PrefillRegressionModel operate on token counts (sum_prefill_tokens,
# sum_decode_kv_tokens) that directly drive GPU compute and therefore
# must have positive coefficients. Only DecodeRegressionModel relaxes
# index 0 (num_decode_requests), which is a weaker secondary feature.
_relaxable_feature_indices: frozenset[int] = frozenset()
# Coefficients within this band of 0 are treated as numerical noise
# when a feature is marked relaxable. Anything more negative than this
# implies the regression is learning an inverted relationship and is
# rejected (or clipped, for relaxable features) so the planner does not
# scale on physically impossible predictions.
_RELAXABLE_NEG_TOLERANCE = 1e-6
def _fit(self) -> bool: def _fit(self) -> bool:
observations = self._gather_observations() observations = self._gather_observations()
if len(observations) < self.min_observations: if len(observations) < self.min_observations:
...@@ -199,13 +221,37 @@ class _BaseRegressionModel: ...@@ -199,13 +221,37 @@ class _BaseRegressionModel:
# Negative coefficients mean "more load → less compute time", which # Negative coefficients mean "more load → less compute time", which
# is physically impossible. Reject the fit so callers see the model # is physically impossible. Reject the fit so callers see the model
# as not ready rather than making inverted scaling decisions. # as not ready rather than making inverted scaling decisions.
if np.any(self._model.coef_ < 0): # Exception: features in ``_relaxable_feature_indices`` may go
# slightly negative due to multicollinearity / noise; accept tiny
# (within-tolerance) negatives as-is, and clamp larger relaxable
# negatives to 0 so predictions remain monotone in that feature.
coef = self._model.coef_
neg_mask = coef < 0
if np.any(neg_mask):
non_relaxable_negs = [
i
for i in range(len(coef))
if neg_mask[i] and i not in self._relaxable_feature_indices
]
if non_relaxable_negs:
logger.warning( logger.warning(
f"Regression produced negative coefficients {self._model.coef_.tolist()}, " f"Regression produced negative coefficients {coef.tolist()}, "
"model rejected — scaling will be skipped until more data arrives" "model rejected — scaling will be skipped until more data arrives"
) )
self._is_fitted = False self._is_fitted = False
return False return False
# Any negatives remaining here are on relaxable features. Clamp
# those that exceed the noise tolerance so the model never
# predicts lower wall time for higher values of that feature.
large_negs = neg_mask & (coef < -self._RELAXABLE_NEG_TOLERANCE)
if np.any(large_negs):
logger.debug(
"Clamped large negative relaxable coefficients at indices "
"%s from %s to 0",
[i for i in range(len(coef)) if large_negs[i]],
coef.tolist(),
)
coef[large_negs] = 0.0
self._is_fitted = True self._is_fitted = True
return True return True
......
...@@ -18,7 +18,20 @@ logger = logging.getLogger(__name__) ...@@ -18,7 +18,20 @@ logger = logging.getLogger(__name__)
class DecodeRegressionModel(_BaseRegressionModel): class DecodeRegressionModel(_BaseRegressionModel):
"""Predict per-iteration wall time from decode batch composition.""" """Predict per-iteration wall time from decode batch composition.
Features: ``[num_decode_requests, sum_decode_kv_tokens]``. The
``sum_decode_kv_tokens`` feature dominates wall time via attention
compute, while ``num_decode_requests`` has a weaker secondary effect
from linear-layer work. Under multicollinearity (both features scale
with batch size), the ``num_decode_requests`` coefficient can flip
sign under noisy fits; we accept the small negative value since
``sum_decode_kv_tokens`` keeps the overall prediction monotone.
"""
# num_decode_requests (index 0) is relaxable; sum_decode_kv_tokens (index 1)
# must remain non-negative.
_relaxable_feature_indices = frozenset({0})
def __init__( def __init__(
self, self,
...@@ -69,7 +82,12 @@ class DecodeRegressionModel(_BaseRegressionModel): ...@@ -69,7 +82,12 @@ class DecodeRegressionModel(_BaseRegressionModel):
return self._predict_2d(num_req, total_kv) return self._predict_2d(num_req, total_kv)
def find_best_engine_decode_rps( def find_best_engine_decode_rps(
self, itl: float, context_length: float, osl: float self,
itl: float,
context_length: float,
osl: float,
max_kv_tokens: Optional[int] = None,
max_num_seqs: Optional[int] = None,
) -> tuple[float, float]: ) -> tuple[float, float]:
"""Find the maximum decode engine request rate within an ITL target. """Find the maximum decode engine request rate within an ITL target.
...@@ -81,6 +99,12 @@ class DecodeRegressionModel(_BaseRegressionModel): ...@@ -81,6 +99,12 @@ class DecodeRegressionModel(_BaseRegressionModel):
Request rate is derived via Little's law: Request rate is derived via Little's law:
``engine_rps = best_batch_size / (osl * wall_time_per_iter)``. ``engine_rps = best_batch_size / (osl * wall_time_per_iter)``.
The upper bound of the sweep is the smallest of:
- ``max_kv_tokens / context_length`` -- KV cache capacity
- ``max_num_seqs`` -- engine concurrency limit
Falls back to ``_max_observed_kv / context_length`` (or 256) if
neither capability is provided.
Returns: Returns:
(engine_rps, actual_itl_ms) -- 0 rps signals an error (engine_rps, actual_itl_ms) -- 0 rps signals an error
(model not fitted or invalid input); positive rps is (model not fitted or invalid input); positive rps is
...@@ -89,11 +113,14 @@ class DecodeRegressionModel(_BaseRegressionModel): ...@@ -89,11 +113,14 @@ class DecodeRegressionModel(_BaseRegressionModel):
if not self._ensure_fitted() or context_length <= 0 or osl <= 0 or itl <= 0: if not self._ensure_fitted() or context_length <= 0 or osl <= 0 or itl <= 0:
return (0.0, 0.0) return (0.0, 0.0)
max_batch = ( if max_kv_tokens and max_kv_tokens > 0:
max(1, int(self._max_observed_kv / context_length)) kv_cap = max(1, int(max_kv_tokens / context_length))
if self._max_observed_kv > 0 elif self._max_observed_kv > 0:
else 256 kv_cap = max(1, int(self._max_observed_kv / context_length))
) else:
kv_cap = 256
seq_cap = max_num_seqs if max_num_seqs and max_num_seqs > 0 else kv_cap
max_batch = max(1, min(kv_cap, seq_cap))
lo, hi = 1, max_batch lo, hi = 1, max_batch
best_bs, best_wt = 1, self._predict_2d(1, context_length) best_bs, best_wt = 1, self._predict_2d(1, context_length)
......
...@@ -80,14 +80,28 @@ class PrefillRegressionModel(_BaseRegressionModel): ...@@ -80,14 +80,28 @@ class PrefillRegressionModel(_BaseRegressionModel):
return total_time return total_time
def find_best_engine_prefill_rps( def find_best_engine_prefill_rps(
self, ttft_sla: float, isl: float self,
ttft_sla: float,
isl: float,
max_num_batched_tokens: Optional[int] = None,
) -> tuple[float, float]: ) -> tuple[float, float]:
"""Find prefill engine request rate under a TTFT target. """Find prefill engine request rate under a TTFT target.
Predicts wall_time for a single prefill at the given ISL. Predicts wall_time for a single prefill at the given ISL and
If the predicted TTFT exceeds the SLA, logs a warning but derives engine_rps = 1 / wt. This formula assumes the regression
still returns the best achievable rate so the caller can scales roughly linearly with sum_prefill_tokens: under that
scale based on load matching. assumption batching multiple prefills (each with ISL tokens) gives
the same engine_rps as one-request-at-a-time, because a batch of
B requests has wt ≈ k·B·ISL, so rate = B/wt = 1/(k·ISL).
If ISL exceeds max_num_batched_tokens, a single request must be
chunked across multiple forward passes. We compute wall_time as
ceil(ISL / MBT) * wt(MBT-sized chunk) to stay within the model's
training domain.
If the predicted TTFT exceeds the SLA, logs a warning but still
returns the best achievable rate so the caller can scale based
on load matching.
Returns: Returns:
(engine_rps, actual_ttft_ms) -- 0 rps signals an error (engine_rps, actual_ttft_ms) -- 0 rps signals an error
...@@ -96,7 +110,28 @@ class PrefillRegressionModel(_BaseRegressionModel): ...@@ -96,7 +110,28 @@ class PrefillRegressionModel(_BaseRegressionModel):
""" """
if not self._ensure_fitted() or isl <= 0: if not self._ensure_fitted() or isl <= 0:
return (0.0, 0.0) return (0.0, 0.0)
# Chunk long prefills so we stay within the regression's training
# domain: a single forward pass never processes more than
# max_num_batched_tokens tokens. At the boundary isl ==
# max_num_batched_tokens, the `else` branch handles it as a single
# pass (no chunking needed); strict `>` inequality is deliberate.
if (
max_num_batched_tokens
and max_num_batched_tokens > 0
and isl > max_num_batched_tokens
):
num_chunks = math.ceil(isl / max_num_batched_tokens)
# remainder is the size of the final (possibly partial) chunk.
# Invariant: remainder ∈ (0, max_num_batched_tokens] by
# construction of num_chunks via math.ceil.
remainder = isl - (num_chunks - 1) * max_num_batched_tokens
wt = (num_chunks - 1) * self._predict_wall_time(
float(max_num_batched_tokens)
) + self._predict_wall_time(remainder)
else:
wt = self._predict_wall_time(isl) wt = self._predict_wall_time(isl)
actual_ttft_ms = wt * 1000.0 actual_ttft_ms = wt * 1000.0
engine_rps = 1.0 / wt engine_rps = 1.0 / wt
if actual_ttft_ms > ttft_sla: if actual_ttft_ms > ttft_sla:
......
...@@ -116,6 +116,10 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin): ...@@ -116,6 +116,10 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
self._diag_engine_rps_decode: Optional[float] = None self._diag_engine_rps_decode: Optional[float] = None
self._diag_load_reason: Optional[str] = None self._diag_load_reason: Optional[str] = None
self._diag_throughput_reason: Optional[str] = None self._diag_throughput_reason: Optional[str] = None
self._diag_load_reason_prefill: Optional[str] = None
self._diag_load_reason_decode: Optional[str] = None
self._diag_throughput_reason_prefill: Optional[str] = None
self._diag_throughput_reason_decode: Optional[str] = None
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Public API # Public API
...@@ -170,6 +174,24 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin): ...@@ -170,6 +174,24 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
if tick_input.worker_counts is not None: if tick_input.worker_counts is not None:
self._update_inventory(tick_input.worker_counts) self._update_inventory(tick_input.worker_counts)
# Run throughput scaling first so any updated lower bound is visible
# to the load scaling pass on a combined tick. Otherwise load scaling
# reads the stale bound, potentially deciding to scale below the new
# floor set in this same tick.
#
# We always advance _next_throughput_s on a throughput tick, even if
# no traffic was available, so the planner keeps the throughput
# cadence stable rather than re-firing back-to-back ticks whenever
# traffic is temporarily absent.
throughput_decision = None
if tick.run_throughput_scaling:
if tick_input.traffic is not None:
self._observe_traffic(tick_input.traffic)
throughput_decision = self._advance_throughput(tick_input.traffic)
self._next_throughput_s = (
tick_input.now_s + self._config.throughput_adjustment_interval
)
if tick.run_load_scaling: if tick.run_load_scaling:
if tick_input.fpm_observations is not None: if tick_input.fpm_observations is not None:
if not self._is_easy: if not self._is_easy:
...@@ -179,16 +201,10 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin): ...@@ -179,16 +201,10 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
effects.scale_to = load_decision effects.scale_to = load_decision
self._next_load_s = tick_input.now_s + self._config.load_adjustment_interval self._next_load_s = tick_input.now_s + self._config.load_adjustment_interval
if tick.run_throughput_scaling: # Load scaling has precedence when it produced a decision; otherwise
if tick_input.traffic is not None: # fall back to the throughput-scaling decision.
self._observe_traffic(tick_input.traffic) if effects.scale_to is None and throughput_decision is not None:
throughput_decision = self._advance_throughput(tick_input.traffic)
if throughput_decision is not None:
if effects.scale_to is None:
effects.scale_to = throughput_decision effects.scale_to = throughput_decision
self._next_throughput_s = (
tick_input.now_s + self._config.throughput_adjustment_interval
)
effects.diagnostics = self._build_diagnostics() effects.diagnostics = self._build_diagnostics()
effects.next_tick = self._next_scheduled_tick() effects.next_tick = self._next_scheduled_tick()
...@@ -204,6 +220,10 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin): ...@@ -204,6 +220,10 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
self._diag_engine_rps_decode = None self._diag_engine_rps_decode = None
self._diag_load_reason = None self._diag_load_reason = None
self._diag_throughput_reason = None self._diag_throughput_reason = None
self._diag_load_reason_prefill = None
self._diag_load_reason_decode = None
self._diag_throughput_reason_prefill = None
self._diag_throughput_reason_decode = None
def _build_diagnostics(self) -> TickDiagnostics: def _build_diagnostics(self) -> TickDiagnostics:
return TickDiagnostics( return TickDiagnostics(
...@@ -214,8 +234,14 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin): ...@@ -214,8 +234,14 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
predicted_osl=self._diag_predicted_osl, predicted_osl=self._diag_predicted_osl,
engine_rps_prefill=self._diag_engine_rps_prefill, engine_rps_prefill=self._diag_engine_rps_prefill,
engine_rps_decode=self._diag_engine_rps_decode, engine_rps_decode=self._diag_engine_rps_decode,
throughput_lower_bound_prefill=self._throughput_lower_bound_p,
throughput_lower_bound_decode=self._throughput_lower_bound_d,
load_decision_reason=self._diag_load_reason, load_decision_reason=self._diag_load_reason,
throughput_decision_reason=self._diag_throughput_reason, throughput_decision_reason=self._diag_throughput_reason,
load_decision_reason_prefill=self._diag_load_reason_prefill,
load_decision_reason_decode=self._diag_load_reason_decode,
throughput_decision_reason_prefill=self._diag_throughput_reason_prefill,
throughput_decision_reason_decode=self._diag_throughput_reason_decode,
) )
# ------------------------------------------------------------------ # ------------------------------------------------------------------
......
...@@ -29,6 +29,8 @@ class ThroughputScalingMixin: ...@@ -29,6 +29,8 @@ class ThroughputScalingMixin:
_diag_engine_rps_prefill: Optional[float] _diag_engine_rps_prefill: Optional[float]
_diag_engine_rps_decode: Optional[float] _diag_engine_rps_decode: Optional[float]
_diag_throughput_reason: Optional[str] _diag_throughput_reason: Optional[str]
_diag_throughput_reason_prefill: Optional[str]
_diag_throughput_reason_decode: Optional[str]
def _advance_throughput( def _advance_throughput(
self, traffic: TrafficObservation self, traffic: TrafficObservation
...@@ -104,9 +106,24 @@ class ThroughputScalingMixin: ...@@ -104,9 +106,24 @@ class ThroughputScalingMixin:
) -> Optional[ScalingDecision]: ) -> Optional[ScalingDecision]:
num_p = self._compute_prefill_replicas(demand_rps, isl, osl) num_p = self._compute_prefill_replicas(demand_rps, isl, osl)
num_d = self._compute_decode_replicas(demand_rps, isl, osl) num_d = self._compute_decode_replicas(demand_rps, isl, osl)
# _compute_* sets _diag_throughput_reason = "model_not_ready" when
# the regression isn't fit yet. If one side is not ready, the other
# side's computation was still valid but its decision is blocked,
# so we label it "partner_not_ready" to keep per-component
# diagnostics consistent with the aggregate reason.
if num_p is None or num_d is None: if num_p is None or num_d is None:
self._diag_throughput_reason_prefill = (
"model_not_ready" if num_p is None else "partner_not_ready"
)
self._diag_throughput_reason_decode = (
"model_not_ready" if num_d is None else "partner_not_ready"
)
return None return None
reason = "set_lower_bound" if self._config.enable_load_scaling else "scale"
self._diag_throughput_reason_prefill = reason
self._diag_throughput_reason_decode = reason
if self._config.enable_load_scaling: if self._config.enable_load_scaling:
self._throughput_lower_bound_p = num_p self._throughput_lower_bound_p = num_p
self._throughput_lower_bound_d = num_d self._throughput_lower_bound_d = num_d
...@@ -140,6 +157,8 @@ class ThroughputScalingMixin: ...@@ -140,6 +157,8 @@ class ThroughputScalingMixin:
max_num_batched_tokens=max_tokens, max_num_batched_tokens=max_tokens,
ttft_sla=self._config.ttft, ttft_sla=self._config.ttft,
itl_sla=self._config.itl, itl_sla=self._config.itl,
max_kv_tokens=d_caps.max_kv_tokens if d_caps else None,
max_num_seqs=d_caps.max_num_seqs if d_caps else None,
) )
if engine_rps <= 0: if engine_rps <= 0:
logger.warning("Agg perf model not ready, skipping throughput scaling") logger.warning("Agg perf model not ready, skipping throughput scaling")
...@@ -171,8 +190,11 @@ class ThroughputScalingMixin: ...@@ -171,8 +190,11 @@ class ThroughputScalingMixin:
def _compute_prefill_replicas( def _compute_prefill_replicas(
self, demand_rps: float, isl: float, osl: float self, demand_rps: float, isl: float, osl: float
) -> Optional[int]: ) -> Optional[int]:
p_caps = self._capabilities.prefill
engine_rps, ttft_ms = self._prefill_regression.find_best_engine_prefill_rps( engine_rps, ttft_ms = self._prefill_regression.find_best_engine_prefill_rps(
ttft_sla=self._config.ttft, isl=isl ttft_sla=self._config.ttft,
isl=isl,
max_num_batched_tokens=p_caps.max_num_batched_tokens if p_caps else None,
) )
if engine_rps <= 0: if engine_rps <= 0:
logger.warning("Prefill perf model not ready, skipping throughput scaling") logger.warning("Prefill perf model not ready, skipping throughput scaling")
...@@ -194,10 +216,13 @@ class ThroughputScalingMixin: ...@@ -194,10 +216,13 @@ class ThroughputScalingMixin:
def _compute_decode_replicas( def _compute_decode_replicas(
self, demand_rps: float, isl: float, osl: float self, demand_rps: float, isl: float, osl: float
) -> Optional[int]: ) -> Optional[int]:
d_caps = self._capabilities.decode
engine_rps, itl_ms = self._decode_regression.find_best_engine_decode_rps( engine_rps, itl_ms = self._decode_regression.find_best_engine_decode_rps(
itl=self._config.itl, itl=self._config.itl,
context_length=isl + osl / 2, context_length=isl + osl / 2,
osl=osl, osl=osl,
max_kv_tokens=d_caps.max_kv_tokens if d_caps else None,
max_num_seqs=d_caps.max_num_seqs if d_caps else None,
) )
if engine_rps <= 0: if engine_rps <= 0:
logger.warning("Decode perf model not ready, skipping throughput scaling") logger.warning("Decode perf model not ready, skipping throughput scaling")
......
...@@ -112,9 +112,20 @@ class TickDiagnostics: ...@@ -112,9 +112,20 @@ class TickDiagnostics:
engine_rps_prefill: Optional[float] = None engine_rps_prefill: Optional[float] = None
engine_rps_decode: Optional[float] = None engine_rps_decode: Optional[float] = None
# Throughput-scaling: lower bound on replicas
throughput_lower_bound_prefill: Optional[int] = None
throughput_lower_bound_decode: Optional[int] = None
# Scaling decision reasons (set by the mixin that ran) # Scaling decision reasons (set by the mixin that ran)
# Aggregate reasons (agg mode, or combined disagg).
load_decision_reason: Optional[str] = None load_decision_reason: Optional[str] = None
throughput_decision_reason: Optional[str] = None throughput_decision_reason: Optional[str] = None
# Per-component reasons (populated in disagg mode for separate
# prefill / decode decision timelines).
load_decision_reason_prefill: Optional[str] = None
load_decision_reason_decode: Optional[str] = None
throughput_decision_reason_prefill: Optional[str] = None
throughput_decision_reason_decode: Optional[str] = None
@dataclass @dataclass
......
...@@ -66,11 +66,19 @@ class TickSnapshot: ...@@ -66,11 +66,19 @@ class TickSnapshot:
engine_rps_decode: Optional[float] = None engine_rps_decode: Optional[float] = None
load_decision_reason: Optional[str] = None load_decision_reason: Optional[str] = None
throughput_decision_reason: Optional[str] = None throughput_decision_reason: Optional[str] = None
load_decision_reason_prefill: Optional[str] = None
load_decision_reason_decode: Optional[str] = None
throughput_decision_reason_prefill: Optional[str] = None
throughput_decision_reason_decode: Optional[str] = None
# Per-engine FPM queue depths # Per-engine FPM queue depths
prefill_engines: list[PerEngineFpm] = field(default_factory=list) prefill_engines: list[PerEngineFpm] = field(default_factory=list)
decode_engines: list[PerEngineFpm] = field(default_factory=list) decode_engines: list[PerEngineFpm] = field(default_factory=list)
# Throughput lower bound
throughput_lower_bound_prefill: Optional[int] = None
throughput_lower_bound_decode: Optional[int] = None
# Scaling decision # Scaling decision
scale_to_prefill: Optional[int] = None scale_to_prefill: Optional[int] = None
scale_to_decode: Optional[int] = None scale_to_decode: Optional[int] = None
...@@ -88,6 +96,7 @@ class DiagnosticsRecorder: ...@@ -88,6 +96,7 @@ class DiagnosticsRecorder:
""" """
config: PlannerConfig config: PlannerConfig
max_kv_tokens: Optional[int] = None
_snapshots: list[TickSnapshot] = field(default_factory=list) _snapshots: list[TickSnapshot] = field(default_factory=list)
_last_report_s: float = 0.0 _last_report_s: float = 0.0
_report_count: int = 0 _report_count: int = 0
...@@ -177,6 +186,12 @@ class DiagnosticsRecorder: ...@@ -177,6 +186,12 @@ class DiagnosticsRecorder:
engine_rps_decode=diag.engine_rps_decode, engine_rps_decode=diag.engine_rps_decode,
load_decision_reason=diag.load_decision_reason, load_decision_reason=diag.load_decision_reason,
throughput_decision_reason=diag.throughput_decision_reason, throughput_decision_reason=diag.throughput_decision_reason,
load_decision_reason_prefill=diag.load_decision_reason_prefill,
load_decision_reason_decode=diag.load_decision_reason_decode,
throughput_decision_reason_prefill=diag.throughput_decision_reason_prefill,
throughput_decision_reason_decode=diag.throughput_decision_reason_decode,
throughput_lower_bound_prefill=diag.throughput_lower_bound_prefill,
throughput_lower_bound_decode=diag.throughput_lower_bound_decode,
prefill_engines=prefill_engines, prefill_engines=prefill_engines,
decode_engines=decode_engines, decode_engines=decode_engines,
scale_to_prefill=( scale_to_prefill=(
...@@ -201,6 +216,11 @@ class DiagnosticsRecorder: ...@@ -201,6 +216,11 @@ class DiagnosticsRecorder:
This method has no side effects (no file I/O, no snapshot clearing). This method has no side effects (no file I/O, no snapshot clearing).
""" """
# TODO: link x-axes across all subplots (e.g. ``fig.update_xaxes(
# matches="x")`` or shared_xaxes=True in make_subplots) so zooming
# into a time range on one chart also zooms the others. Currently
# a user has to zoom each subplot independently to narrow down on a
# specific time window.
ts = [s.timestamp_s for s in snaps] ts = [s.timestamp_s for s in snaps]
labels = [ labels = [
datetime.fromtimestamp(t, tz=timezone.utc).strftime("%H:%M:%S") for t in ts datetime.fromtimestamp(t, tz=timezone.utc).strftime("%H:%M:%S") for t in ts
...@@ -253,6 +273,32 @@ class DiagnosticsRecorder: ...@@ -253,6 +273,32 @@ class DiagnosticsRecorder:
row=1, row=1,
col=1, col=1,
) )
tp_lower_p = _vals("throughput_lower_bound_prefill")
if any(v is not None and v > 1 for v in tp_lower_p):
fig.add_trace(
go.Scatter(
x=labels,
y=tp_lower_p,
name="Prefill TP Lower Bound",
mode="lines",
line=dict(dash="dash", color="darkblue"),
),
row=1,
col=1,
)
tp_lower_d = _vals("throughput_lower_bound_decode")
if any(v is not None and v > 1 for v in tp_lower_d):
fig.add_trace(
go.Scatter(
x=labels,
y=tp_lower_d,
name="Decode TP Lower Bound",
mode="lines",
line=dict(dash="dash", color="red"),
),
row=1,
col=1,
)
# 1b. Request rate # 1b. Request rate
fig.add_trace( fig.add_trace(
...@@ -272,6 +318,7 @@ class DiagnosticsRecorder: ...@@ -272,6 +318,7 @@ class DiagnosticsRecorder:
name="Predicted RPS", name="Predicted RPS",
mode="lines", mode="lines",
line=dict(dash="dot"), line=dict(dash="dot"),
connectgaps=True,
), ),
row=1, row=1,
col=2, col=2,
...@@ -396,39 +443,35 @@ class DiagnosticsRecorder: ...@@ -396,39 +443,35 @@ class DiagnosticsRecorder:
col=1, col=1,
) )
# 4b. Decode engine load (queued + inflight, one line each per engine) # 4b. Decode engine load (queued + inflight combined per engine)
for eid in sorted(decode_engine_ids): for eid in sorted(decode_engine_ids):
y_queued = [] y_total = []
y_inflight = []
for s in snaps: for s in snaps:
q, f_ = None, None val = None
for e in s.decode_engines: for e in s.decode_engines:
if f"{e.worker_id}:dp{e.dp_rank}" == eid: if f"{e.worker_id}:dp{e.dp_rank}" == eid:
q = e.queued_decode_kv_tokens val = e.queued_decode_kv_tokens + e.inflight_decode_kv_tokens
f_ = e.inflight_decode_kv_tokens
break break
y_queued.append(q) y_total.append(val)
y_inflight.append(f_)
fig.add_trace( fig.add_trace(
go.Scatter( go.Scatter(
x=labels, x=labels,
y=y_queued, y=y_total,
name=f"D {eid} queued", name=f"D {eid} total KV",
mode="lines+markers", mode="lines+markers",
showlegend=False, showlegend=False,
), ),
row=4, row=4,
col=2, col=2,
) )
fig.add_trace(
go.Scatter( # KV capacity line (set by adapter if available)
x=labels, if self.max_kv_tokens is not None and self.max_kv_tokens > 0:
y=y_inflight, fig.add_hline(
name=f"D {eid} inflight", y=self.max_kv_tokens,
mode="lines", line_dash="dash",
line=dict(dash="dot"), line_color="red",
showlegend=False, annotation_text=f"KV Capacity ({self.max_kv_tokens:,})",
),
row=4, row=4,
col=2, col=2,
) )
...@@ -442,6 +485,7 @@ class DiagnosticsRecorder: ...@@ -442,6 +485,7 @@ class DiagnosticsRecorder:
y=_vals("engine_rps_prefill"), y=_vals("engine_rps_prefill"),
name="Prefill Engine RPS", name="Prefill Engine RPS",
mode="lines+markers", mode="lines+markers",
connectgaps=True,
), ),
row=5, row=5,
col=1, col=1,
...@@ -452,6 +496,7 @@ class DiagnosticsRecorder: ...@@ -452,6 +496,7 @@ class DiagnosticsRecorder:
y=_vals("engine_rps_decode"), y=_vals("engine_rps_decode"),
name="Decode Engine RPS", name="Decode Engine RPS",
mode="lines+markers", mode="lines+markers",
connectgaps=True,
), ),
row=5, row=5,
col=1, col=1,
...@@ -475,6 +520,7 @@ class DiagnosticsRecorder: ...@@ -475,6 +520,7 @@ class DiagnosticsRecorder:
name="Predicted ISL", name="Predicted ISL",
mode="lines", mode="lines",
line=dict(dash="dot"), line=dict(dash="dot"),
connectgaps=True,
), ),
row=5, row=5,
col=2, col=2,
...@@ -496,14 +542,25 @@ class DiagnosticsRecorder: ...@@ -496,14 +542,25 @@ class DiagnosticsRecorder:
name="Predicted OSL", name="Predicted OSL",
mode="lines", mode="lines",
line=dict(dash="dot"), line=dict(dash="dot"),
connectgaps=True,
), ),
row=5, row=5,
col=2, col=2,
) )
# -- Row 6: Decision timelines ----------------------------------- # -- Row 6: Decision timelines -----------------------------------
#
# Layout adapts to scaling mode:
# - Disagg (per-component reasons populated): two tracks per
# subplot, prefill on y=2, decode on y=1, with "prefill" /
# "decode" y-axis labels.
# - Agg / easy mode (only aggregate reason populated): single
# track at y=1 with "Load Decision" / "Throughput Decision"
# labels.
# We detect the mode by whether any snapshot has a per-component
# reason set; switching mode mid-run would produce a mixed chart,
# but that doesn't happen because mode is fixed at planner init.
load_reasons = _vals("load_decision_reason")
_LOAD_COLORS = { _LOAD_COLORS = {
"scale_up": "green", "scale_up": "green",
"scale_down": "blue", "scale_down": "blue",
...@@ -515,25 +572,6 @@ class DiagnosticsRecorder: ...@@ -515,25 +572,6 @@ class DiagnosticsRecorder:
"worker_count_mismatch": "red", "worker_count_mismatch": "red",
"insufficient_data": "pink", "insufficient_data": "pink",
} }
fig.add_trace(
go.Scatter(
x=labels,
y=[1] * len(labels),
mode="markers",
marker=dict(
color=[_LOAD_COLORS.get(r or "", "gray") for r in load_reasons],
size=10,
symbol="square",
),
text=load_reasons,
name="Load Decision",
hoverinfo="text+x",
),
row=6,
col=1,
)
tp_reasons = _vals("throughput_decision_reason")
_TP_COLORS = { _TP_COLORS = {
"scale": "green", "scale": "green",
"set_lower_bound": "blue", "set_lower_bound": "blue",
...@@ -541,32 +579,148 @@ class DiagnosticsRecorder: ...@@ -541,32 +579,148 @@ class DiagnosticsRecorder:
"no_traffic_data": "yellow", "no_traffic_data": "yellow",
"predict_failed": "red", "predict_failed": "red",
"model_not_ready": "orange", "model_not_ready": "orange",
"partner_not_ready": "pink",
} }
# Detect disagg mode: if any per-component reason is populated,
# plot two horizontal tracks (prefill at y=2, decode at y=1);
# otherwise plot a single aggregate track at y=1.
has_per_component_load = any(
s.load_decision_reason_prefill is not None
or s.load_decision_reason_decode is not None
for s in snaps
)
has_per_component_tp = any(
s.throughput_decision_reason_prefill is not None
or s.throughput_decision_reason_decode is not None
for s in snaps
)
def _add_decision_track(
field_name: str,
y_value: int,
label: str,
colors: dict,
symbol: str,
row: int,
col: int,
) -> None:
reasons = _vals(field_name)
fig.add_trace( fig.add_trace(
go.Scatter( go.Scatter(
x=labels, x=labels,
y=[1] * len(labels), y=[y_value] * len(labels),
mode="markers", mode="markers",
marker=dict( marker=dict(
color=[_TP_COLORS.get(r or "", "gray") for r in tp_reasons], color=[colors.get(r or "", "gray") for r in reasons],
size=10, size=10,
symbol="diamond", symbol=symbol,
), ),
text=tp_reasons, text=reasons,
name="Throughput Decision", name=label,
hoverinfo="text+x", hoverinfo="text+x",
showlegend=False,
), ),
row=row,
col=col,
)
if has_per_component_load:
_add_decision_track(
"load_decision_reason_prefill",
2,
"Load (prefill)",
_LOAD_COLORS,
"square",
6,
1,
)
_add_decision_track(
"load_decision_reason_decode",
1,
"Load (decode)",
_LOAD_COLORS,
"square",
6,
1,
)
fig.update_yaxes(
tickmode="array",
tickvals=[1, 2],
ticktext=["decode", "prefill"],
range=[0.5, 2.5],
row=6,
col=1,
)
else:
_add_decision_track(
"load_decision_reason",
1,
"Load Decision",
_LOAD_COLORS,
"square",
6,
1,
)
if has_per_component_tp:
_add_decision_track(
"throughput_decision_reason_prefill",
2,
"TP (prefill)",
_TP_COLORS,
"diamond",
6,
2,
)
_add_decision_track(
"throughput_decision_reason_decode",
1,
"TP (decode)",
_TP_COLORS,
"diamond",
6,
2,
)
fig.update_yaxes(
tickmode="array",
tickvals=[1, 2],
ticktext=["decode", "prefill"],
range=[0.5, 2.5],
row=6, row=6,
col=2, col=2,
) )
else:
_add_decision_track(
"throughput_decision_reason",
1,
"Throughput Decision",
_TP_COLORS,
"diamond",
6,
2,
)
# -- Layout ------------------------------------------------------- # -- Layout -------------------------------------------------------
num_scaling_events = sum( # Count actual replica transitions, not just ticks where a decision
1 # was recorded: two consecutive ticks with scale_to=5 aren't two
for s in snaps # scaling events.
if s.scale_to_prefill is not None or s.scale_to_decode is not None num_scaling_events = 0
) prev_p: Optional[int] = None
prev_d: Optional[int] = None
for s in snaps:
cur_p = s.num_prefill_replicas
cur_d = s.num_decode_replicas
if prev_p is not None and cur_p is not None and cur_p != prev_p:
num_scaling_events += 1
if prev_d is not None and cur_d is not None and cur_d != prev_d:
num_scaling_events += 1
if cur_p is not None:
prev_p = cur_p
if cur_d is not None:
prev_d = cur_d
t0 = datetime.fromtimestamp(ts[0], tz=timezone.utc).strftime( t0 = datetime.fromtimestamp(ts[0], tz=timezone.utc).strftime(
"%Y-%m-%d %H:%M:%S UTC" "%Y-%m-%d %H:%M:%S UTC"
) )
...@@ -576,7 +730,7 @@ class DiagnosticsRecorder: ...@@ -576,7 +730,7 @@ class DiagnosticsRecorder:
summary = ( summary = (
f"<b>Planner Diagnostics Report</b><br>" f"<b>Planner Diagnostics Report</b><br>"
f"Time range: {t0}{t1} ({len(snaps)} ticks)<br>" f"Time range: {t0}{t1} ({len(snaps)} ticks)<br>"
f"Scaling events: {num_scaling_events} | " f"Replica transitions: {num_scaling_events} | "
f"GPU hours: {snaps[-1].gpu_hours:.2f}<br>" f"GPU hours: {snaps[-1].gpu_hours:.2f}<br>"
f"SLA targets: TTFT={self.config.ttft:.0f}ms, ITL={self.config.itl:.0f}ms" f"SLA targets: TTFT={self.config.ttft:.0f}ms, ITL={self.config.itl:.0f}ms"
) )
......
...@@ -137,7 +137,14 @@ class ReplayPlannerAdapter: ...@@ -137,7 +137,14 @@ class ReplayPlannerAdapter:
self._scaling_target_decode: Optional[int] = None self._scaling_target_decode: Optional[int] = None
# Diagnostics recorder for HTML report generation # Diagnostics recorder for HTML report generation
self._recorder = DiagnosticsRecorder(config=planner_config) decode_max_kv = (
capabilities.decode.max_kv_tokens
if capabilities and capabilities.decode
else None
)
self._recorder = DiagnosticsRecorder(
config=planner_config, max_kv_tokens=decode_max_kv
)
self._cumulative_gpu_hours: float = 0.0 self._cumulative_gpu_hours: float = 0.0
self._last_tick_s: float = 0.0 self._last_tick_s: float = 0.0
self._last_traffic: Metrics = Metrics() self._last_traffic: Metrics = Metrics()
...@@ -218,15 +225,6 @@ class ReplayPlannerAdapter: ...@@ -218,15 +225,6 @@ class ReplayPlannerAdapter:
self._cumulative_gpu_hours += (num_p * gpu_p + num_d * gpu_d) * dt_h self._cumulative_gpu_hours += (num_p * gpu_p + num_d * gpu_d) * dt_h
self._last_tick_s = now_s self._last_tick_s = now_s
# Build observed Metrics from traffic in tick_input
if tick_input.traffic is not None:
t = tick_input.traffic
self._last_traffic = Metrics(
num_req=t.num_req,
isl=t.isl,
osl=t.osl,
)
self._recorder.record( self._recorder.record(
tick_input, tick_input,
effects, effects,
...@@ -422,12 +420,24 @@ class ReplayPlannerAdapter: ...@@ -422,12 +420,24 @@ class ReplayPlannerAdapter:
t = self._bridge.drain_traffic() t = self._bridge.drain_traffic()
duration_s = t.get("duration_s", 0.0) duration_s = t.get("duration_s", 0.0)
if duration_s > 0: if duration_s > 0:
num_req = float(t.get("num_req", 0))
traffic = TrafficObservation( traffic = TrafficObservation(
duration_s=duration_s, duration_s=duration_s,
num_req=float(t.get("num_req", 0)), num_req=num_req,
isl=t.get("avg_isl", 0.0), isl=t.get("avg_isl", 0.0),
osl=t.get("avg_osl", 0.0), osl=t.get("avg_osl", 0.0),
) )
# Stash observed TTFT/ITL for the diagnostics recorder.
# When num_req == 0, the Rust accumulator returns 0 as a
# placeholder; only record latency values when we actually
# observed requests in this window.
self._last_traffic = Metrics(
ttft=t.get("avg_ttft_ms") if num_req > 0 else None,
itl=t.get("avg_itl_ms") if num_req > 0 else None,
num_req=traffic.num_req,
isl=traffic.isl,
osl=traffic.osl,
)
return TickInput( return TickInput(
now_s=now_s, now_s=now_s,
......
...@@ -1345,7 +1345,18 @@ impl PlannerReplayBridge { ...@@ -1345,7 +1345,18 @@ impl PlannerReplayBridge {
/// Drain accumulated traffic metrics since the last drain. /// Drain accumulated traffic metrics since the last drain.
/// ///
/// Returns a dict with `duration_s`, `num_req`, `avg_isl`, `avg_osl`. /// Returns a dict with:
/// - `duration_s` (f64): window length in seconds
/// - `num_req` (usize): completed requests in the window
/// - `avg_isl` (f64): mean input sequence length (tokens)
/// - `avg_osl` (f64): mean output sequence length (tokens)
/// - `avg_ttft_ms` (f64): mean time-to-first-token in milliseconds,
/// averaged only over requests that reported
/// a TTFT sample (0.0 when no samples)
/// - `avg_itl_ms` (f64): mean inter-token latency in milliseconds,
/// averaged only over requests that generated
/// at least one token gap (0.0 when no samples)
///
/// Call this only on throughput-scaling ticks so the observation window /// Call this only on throughput-scaling ticks so the observation window
/// covers the full `throughput_adjustment_interval`. /// covers the full `throughput_adjustment_interval`.
fn drain_traffic(&mut self, py: Python<'_>) -> PyResult<PyObject> { fn drain_traffic(&mut self, py: Python<'_>) -> PyResult<PyObject> {
...@@ -1354,13 +1365,15 @@ impl PlannerReplayBridge { ...@@ -1354,13 +1365,15 @@ impl PlannerReplayBridge {
.as_mut() .as_mut()
.ok_or_else(|| PyException::new_err("bridge has been finalized"))?; .ok_or_else(|| PyException::new_err("bridge has been finalized"))?;
let (duration_s, num_req, avg_isl, avg_osl) = handle.drain_traffic(); let stats = handle.drain_traffic();
let result = json!({ let result = json!({
"duration_s": duration_s, "duration_s": stats.duration_s,
"num_req": num_req, "num_req": stats.num_req,
"avg_isl": avg_isl, "avg_isl": stats.avg_isl,
"avg_osl": avg_osl, "avg_osl": stats.avg_osl,
"avg_ttft_ms": stats.avg_ttft_ms,
"avg_itl_ms": stats.avg_itl_ms,
}); });
pythonize(py, &result) pythonize(py, &result)
......
...@@ -18,6 +18,10 @@ if TYPE_CHECKING: ...@@ -18,6 +18,10 @@ if TYPE_CHECKING:
os.environ.setdefault("DYNAMO_SKIP_PYTHON_LOG_INIT", "1") os.environ.setdefault("DYNAMO_SKIP_PYTHON_LOG_INIT", "1")
from dynamo.common.forward_pass_metrics import (
ForwardPassMetrics,
ScheduledRequestMetrics,
)
from dynamo.llm import AicPerfConfig, KvRouterConfig, MockEngineArgs from dynamo.llm import AicPerfConfig, KvRouterConfig, MockEngineArgs
from dynamo.replay import run_synthetic_trace_replay, run_trace_replay from dynamo.replay import run_synthetic_trace_replay, run_trace_replay
from dynamo.replay.reporting import format_report_table, write_report_json from dynamo.replay.reporting import format_report_table, write_report_json
...@@ -123,6 +127,76 @@ def _engine_caps(args: MockEngineArgs) -> EngineCapabilities: ...@@ -123,6 +127,76 @@ def _engine_caps(args: MockEngineArgs) -> EngineCapabilities:
) )
def _generate_aic_prefill_fpms(
aic_session,
engine_args: MockEngineArgs,
granularity: int = 8,
) -> list[ForwardPassMetrics]:
"""Generate prefill benchmark FPMs using AIC predictions.
Sweeps ISL at batch_size=1 within the engine's per-pass token budget
(``max_num_batched_tokens``); a single forward pass physically cannot
process more than that, so the regression shouldn't see larger sums.
For longer ISL, callers use chunked TTFT estimation.
"""
prefill_max = engine_args.max_num_batched_tokens or 8192
prefill_step = max(1, (prefill_max - 100) // granularity)
prefill_fpms: list[ForwardPassMetrics] = []
for isl in range(100, prefill_max + 1, prefill_step):
ttft_ms = aic_session.predict_prefill(1, isl, 0)
if ttft_ms > 0:
prefill_fpms.append(
ForwardPassMetrics(
wall_time=ttft_ms / 1000.0,
scheduled_requests=ScheduledRequestMetrics(
num_prefill_requests=1,
sum_prefill_tokens=isl,
),
)
)
return prefill_fpms
def _generate_aic_decode_fpms(
aic_session,
engine_args: MockEngineArgs,
granularity: int = 8,
) -> list[ForwardPassMetrics]:
"""Generate decode benchmark FPMs using AIC predictions.
Sweeps (batch_size x context_length). ``granularity`` controls the
number of sample points per axis; the batch-size ceiling comes from
the engine's ``max_num_seqs`` so the regression sees realistic
concurrency, not an artificial cap at the sweep density.
"""
max_kv_tokens = engine_args.num_gpu_blocks * engine_args.block_size
if max_kv_tokens <= 0:
max_kv_tokens = 16384 * 16
decode_fpms: list[ForwardPassMetrics] = []
ctx_lengths = [500, 2000, 4000, 8000]
bs_max = engine_args.max_num_seqs or 256
bs_step = max(1, bs_max // granularity)
for ctx_len in ctx_lengths:
for bs in range(1, bs_max + 1, bs_step):
sum_kv = bs * ctx_len
if sum_kv > max_kv_tokens:
break
itl_ms = aic_session.predict_decode(bs, ctx_len, 2)
if itl_ms > 0:
decode_fpms.append(
ForwardPassMetrics(
wall_time=itl_ms / 1000.0,
scheduled_requests=ScheduledRequestMetrics(
num_decode_requests=bs,
sum_decode_kv_tokens=sum_kv,
),
)
)
return decode_fpms
def _run_planner_replay( def _run_planner_replay(
trace_file: str, trace_file: str,
extra_engine_args: MockEngineArgs | None, extra_engine_args: MockEngineArgs | None,
...@@ -136,6 +210,7 @@ def _run_planner_replay( ...@@ -136,6 +210,7 @@ def _run_planner_replay(
arrival_speedup_ratio: float, arrival_speedup_ratio: float,
trace_block_size: int, trace_block_size: int,
planner_config_arg: str, planner_config_arg: str,
benchmark_granularity: int = 8,
): ):
"""Run an offline replay with planner-in-the-loop (agg or disagg). """Run an offline replay with planner-in-the-loop (agg or disagg).
...@@ -200,6 +275,107 @@ def _run_planner_replay( ...@@ -200,6 +275,107 @@ def _run_planner_replay(
bridge=bridge, bridge=bridge,
capabilities=capabilities, capabilities=capabilities,
) )
# Bootstrap regression models from mocker's perf model.
# AIC provides accurate batch-size-aware timing that works with the
# planner's linear regression. The default polynomial model cannot
# feed the throughput regression (its decode formula is quadratic in
# utilization ratio, causing negative regression coefficients).
if not adapter._sm._is_easy:
ref_args = extra_engine_args or prefill_engine_args or MockEngineArgs()
aic_backend = ref_args.aic_backend
if (
aic_backend is None
or ref_args.aic_system is None
or ref_args.aic_model_path is None
):
sys.stderr.write(
"Note: throughput-based scaling regression requires AIC perf model "
"(set aic_backend/aic_system/aic_model_path in --extra-engine-args). "
"Falling back to load-based scaling only.\n"
)
else:
# Create AIC session -- narrow to the concrete exception types
# AIC/PyO3 can raise so we degrade gracefully on missing
# dependencies or bad config, but don't swallow unrelated bugs
# (AttributeError, KeyboardInterrupt, etc.) introduced by
# refactors.
try:
from dynamo._internal.aic import create_session
aic_session = create_session(
backend_name=aic_backend,
system=ref_args.aic_system,
model_path=ref_args.aic_model_path,
tp_size=ref_args.aic_tp_size or 1,
backend_version=ref_args.aic_backend_version,
)
except (
ImportError,
RuntimeError,
ValueError,
KeyError,
FileNotFoundError,
) as e:
sys.stderr.write(
f"Warning: AIC session creation failed ({e}); "
"throughput regression will not be bootstrapped.\n"
)
aic_session = None
# Generate benchmark FPMs and load into regression. Disagg
# prefill and decode engines typically have different
# max_num_seqs and KV cache sizes, so each sweep uses its own
# engine args. Agg has a single engine so uses one set. AIC's
# predict_* can raise on unsupported model/system combos or
# numerical edge cases; log and fall back in those cases.
if aic_session is not None:
p_args = (
extra_engine_args
if planner_config.mode == "agg"
else prefill_engine_args
) or ref_args
d_args = (
extra_engine_args
if planner_config.mode == "agg"
else decode_engine_args
) or ref_args
try:
prefill_fpms = _generate_aic_prefill_fpms(
aic_session, p_args, benchmark_granularity
)
decode_fpms = _generate_aic_decode_fpms(
aic_session, d_args, benchmark_granularity
)
except (RuntimeError, ValueError, KeyError, ArithmeticError) as e:
sys.stderr.write(
f"Warning: AIC benchmark generation failed ({e}); "
"throughput regression will not be bootstrapped.\n"
)
prefill_fpms, decode_fpms = [], []
if planner_config.mode == "agg":
# Agg regression fits on (sum_prefill_tokens, sum_decode_kv_tokens);
# combine prefill-only and decode-only points so both features
# have variance.
agg_fpms = prefill_fpms + decode_fpms
if agg_fpms:
adapter._sm.load_benchmark_fpms(agg_fpms=agg_fpms)
else:
sys.stderr.write(
"Warning: AIC produced no agg benchmark FPMs\n"
)
else:
if prefill_fpms and decode_fpms:
adapter._sm.load_benchmark_fpms(
prefill_fpms=prefill_fpms, decode_fpms=decode_fpms
)
else:
sys.stderr.write(
f"Warning: AIC produced empty benchmark FPMs "
f"(prefill={len(prefill_fpms)}, decode={len(decode_fpms)})\n"
)
return adapter.run() return adapter.run()
...@@ -256,6 +432,12 @@ def main(argv: Sequence[str] | None = None) -> int: ...@@ -256,6 +432,12 @@ def main(argv: Sequence[str] | None = None) -> int:
"--planner-config", "--planner-config",
help="path to planner config YAML/JSON or inline JSON; enables planner-in-the-loop replay (offline agg only)", help="path to planner config YAML/JSON or inline JSON; enables planner-in-the-loop replay (offline agg only)",
) )
parser.add_argument(
"--benchmark-granularity",
type=int,
default=8,
help="number of sweep points for synthetic perf model benchmark (default: 8, matching profiler)",
)
args = parser.parse_args(list(sys.argv[1:] if argv is None else argv)) args = parser.parse_args(list(sys.argv[1:] if argv is None else argv))
using_trace_file = args.trace_file is not None using_trace_file = args.trace_file is not None
...@@ -311,6 +493,7 @@ def main(argv: Sequence[str] | None = None) -> int: ...@@ -311,6 +493,7 @@ def main(argv: Sequence[str] | None = None) -> int:
arrival_speedup_ratio=args.arrival_speedup_ratio, arrival_speedup_ratio=args.arrival_speedup_ratio,
trace_block_size=args.trace_block_size, trace_block_size=args.trace_block_size,
planner_config_arg=args.planner_config, planner_config_arg=args.planner_config,
benchmark_granularity=args.benchmark_granularity,
) )
report = planner_report.trace_report report = planner_report.trace_report
if planner_report.scaling_events: if planner_report.scaling_events:
......
...@@ -256,6 +256,15 @@ impl TraceCollector { ...@@ -256,6 +256,15 @@ impl TraceCollector {
} }
} }
/// Return (ttft_ms, mean_itl_ms) for a completed request, if available.
pub(crate) fn request_latencies(&self, uuid: Uuid) -> Option<(f64, f64)> {
let stats = self.requests.get(&uuid)?;
let first_token_ms = stats.first_token_ms()?;
let ttft_ms = (first_token_ms - stats.arrival_time_ms).max(0.0);
let mean_itl_ms = stats.mean_tpot_ms().unwrap_or(0.0);
Some((ttft_ms, mean_itl_ms))
}
pub(crate) fn finish(self) -> TraceSimulationReport { pub(crate) fn finish(self) -> TraceSimulationReport {
let requests = self.requests; let requests = self.requests;
let request_count = requests.len(); let request_count = requests.len();
......
...@@ -17,7 +17,7 @@ use super::state::OfflineWorkerSnapshot; ...@@ -17,7 +17,7 @@ use super::state::OfflineWorkerSnapshot;
use super::{ use super::{
components::{ components::{
AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter, AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter,
ScheduledWorkerCompletion, TrafficAccumulator, WorkerAdmission, ScheduledWorkerCompletion, TrafficAccumulator, TrafficStats, WorkerAdmission,
}, },
state::AggRequestState, state::AggRequestState,
}; };
...@@ -381,8 +381,12 @@ impl AggRuntime { ...@@ -381,8 +381,12 @@ impl AggRuntime {
let removed_state = self.requests.remove(&signal.uuid).ok_or_else(|| { let removed_state = self.requests.remove(&signal.uuid).ok_or_else(|| {
anyhow::anyhow!("offline replay missing request state for {}", signal.uuid) anyhow::anyhow!("offline replay missing request state for {}", signal.uuid)
})?; })?;
self.traffic let latencies = self.collector.request_latencies(signal.uuid);
.on_request(removed_state.input_tokens, removed_state.output_tokens); self.traffic.on_request(
removed_state.input_tokens,
removed_state.output_tokens,
latencies,
);
self.admission self.admission
.on_request_completed(signal.uuid, self.now_ms)?; .on_request_completed(signal.uuid, self.now_ms)?;
self.progress.inc_completed(); self.progress.inc_completed();
...@@ -603,7 +607,7 @@ impl AggRuntime { ...@@ -603,7 +607,7 @@ impl AggRuntime {
} }
/// Drain accumulated traffic stats since the last drain. /// Drain accumulated traffic stats since the last drain.
pub(in crate::replay) fn drain_traffic(&mut self) -> (f64, usize, f64, f64) { pub(in crate::replay) fn drain_traffic(&mut self) -> TrafficStats {
self.traffic.drain(self.now_ms) self.traffic.drain(self.now_ms)
} }
......
...@@ -12,6 +12,7 @@ pub(crate) use router::OfflineReplayRouter; ...@@ -12,6 +12,7 @@ pub(crate) use router::OfflineReplayRouter;
#[cfg(test)] #[cfg(test)]
pub(crate) use router::OfflineRouterSnapshot; pub(crate) use router::OfflineRouterSnapshot;
pub(in crate::replay) use types::ReplayMode; pub(in crate::replay) use types::ReplayMode;
pub use types::TrafficStats;
pub(in crate::replay::offline) use types::{ pub(in crate::replay::offline) use types::{
EngineEffects, EnginePassMode, ReadyArrival, ScheduledWorkerCompletion, TrafficAccumulator, EngineEffects, EnginePassMode, ReadyArrival, ScheduledWorkerCompletion, TrafficAccumulator,
}; };
......
...@@ -65,14 +65,41 @@ pub(in crate::replay::offline) struct ReadyArrival { ...@@ -65,14 +65,41 @@ pub(in crate::replay::offline) struct ReadyArrival {
pub(in crate::replay::offline) replay_hashes: Option<ReplayRequestHashes>, pub(in crate::replay::offline) replay_hashes: Option<ReplayRequestHashes>,
} }
/// Accumulated traffic statistics returned by [`TrafficAccumulator::drain`].
///
/// IMPORTANT: When fields here are added or renamed, update the PyO3
/// binding in ``lib/bindings/python/rust/llm/replay.rs`` (drain_traffic
/// method) so the exported JSON dict matches. The Python adapter in
/// ``replay_adapter.py`` reads these keys by name.
#[derive(Debug, Clone)]
pub struct TrafficStats {
pub duration_s: f64,
pub num_req: usize,
pub avg_isl: f64,
pub avg_osl: f64,
pub avg_ttft_ms: f64,
pub avg_itl_ms: f64,
}
/// Accumulates traffic statistics between planner ticks for deriving /// Accumulates traffic statistics between planner ticks for deriving
/// `TrafficObservation` (num_req, avg ISL, avg OSL over a window). /// `TrafficObservation` (num_req, avg ISL, avg OSL over a window).
///
/// Latency samples are tracked independently of request counts: a request
/// only contributes to ``total_ttft_ms`` / ``ttft_count`` if a positive TTFT
/// was recorded, and similarly for ITL. This means ``avg_ttft_ms`` and
/// ``avg_itl_ms`` reflect only requests that actually produced the sample,
/// rather than silently underestimating when some requests lack latency
/// data (e.g. requests that fail before emitting a token).
#[derive(Debug)] #[derive(Debug)]
pub(in crate::replay::offline) struct TrafficAccumulator { pub(in crate::replay::offline) struct TrafficAccumulator {
window_start_ms: f64, window_start_ms: f64,
num_req: usize, num_req: usize,
total_isl: usize, total_isl: usize,
total_osl: usize, total_osl: usize,
total_ttft_ms: f64,
total_itl_ms: f64,
ttft_count: usize,
itl_count: usize,
} }
impl TrafficAccumulator { impl TrafficAccumulator {
...@@ -82,23 +109,37 @@ impl TrafficAccumulator { ...@@ -82,23 +109,37 @@ impl TrafficAccumulator {
num_req: 0, num_req: 0,
total_isl: 0, total_isl: 0,
total_osl: 0, total_osl: 0,
total_ttft_ms: 0.0,
total_itl_ms: 0.0,
ttft_count: 0,
itl_count: 0,
} }
} }
/// Record one admitted request. /// Record one completed request with optional latency data.
pub(in crate::replay::offline) fn on_request( pub(in crate::replay::offline) fn on_request(
&mut self, &mut self,
input_tokens: usize, input_tokens: usize,
output_tokens: usize, output_tokens: usize,
latencies: Option<(f64, f64)>,
) { ) {
self.num_req += 1; self.num_req += 1;
self.total_isl += input_tokens; self.total_isl += input_tokens;
self.total_osl += output_tokens; self.total_osl += output_tokens;
if let Some((ttft_ms, mean_itl_ms)) = latencies {
if ttft_ms > 0.0 {
self.total_ttft_ms += ttft_ms;
self.ttft_count += 1;
}
if mean_itl_ms > 0.0 {
self.total_itl_ms += mean_itl_ms;
self.itl_count += 1;
}
}
} }
/// Drain the accumulator at the given simulated time, returning /// Drain the accumulator at the given simulated time, resetting counters.
/// (duration_s, num_req, avg_isl, avg_osl) and resetting counters. pub(in crate::replay::offline) fn drain(&mut self, now_ms: f64) -> TrafficStats {
pub(in crate::replay::offline) fn drain(&mut self, now_ms: f64) -> (f64, usize, f64, f64) {
let duration_s = (now_ms - self.window_start_ms) / 1000.0; let duration_s = (now_ms - self.window_start_ms) / 1000.0;
let num_req = self.num_req; let num_req = self.num_req;
let avg_isl = if num_req > 0 { let avg_isl = if num_req > 0 {
...@@ -111,10 +152,31 @@ impl TrafficAccumulator { ...@@ -111,10 +152,31 @@ impl TrafficAccumulator {
} else { } else {
0.0 0.0
}; };
let avg_ttft_ms = if self.ttft_count > 0 {
self.total_ttft_ms / self.ttft_count as f64
} else {
0.0
};
let avg_itl_ms = if self.itl_count > 0 {
self.total_itl_ms / self.itl_count as f64
} else {
0.0
};
self.window_start_ms = now_ms; self.window_start_ms = now_ms;
self.num_req = 0; self.num_req = 0;
self.total_isl = 0; self.total_isl = 0;
self.total_osl = 0; self.total_osl = 0;
(duration_s, num_req, avg_isl, avg_osl) self.total_ttft_ms = 0.0;
self.total_itl_ms = 0.0;
self.ttft_count = 0;
self.itl_count = 0;
TrafficStats {
duration_s,
num_req,
avg_isl,
avg_osl,
avg_ttft_ms,
avg_itl_ms,
}
} }
} }
...@@ -11,7 +11,7 @@ use uuid::Uuid; ...@@ -11,7 +11,7 @@ use uuid::Uuid;
pub(super) use super::components::ReplayMode; pub(super) use super::components::ReplayMode;
use super::components::{ use super::components::{
AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter, AdmissionQueue, EngineComponent, EngineEffects, EnginePassMode, OfflineReplayRouter,
ScheduledWorkerCompletion, TrafficAccumulator, WorkerAdmission, ScheduledWorkerCompletion, TrafficAccumulator, TrafficStats, WorkerAdmission,
}; };
use super::events::{SimulationEvent, SimulationWorkerStage}; use super::events::{SimulationEvent, SimulationWorkerStage};
use super::progress::ReplayProgress; use super::progress::ReplayProgress;
...@@ -519,7 +519,9 @@ impl DisaggRuntime { ...@@ -519,7 +519,9 @@ impl DisaggRuntime {
let original = state.original_request()?; let original = state.original_request()?;
let input_tokens = original.tokens.len(); let input_tokens = original.tokens.len();
let output_tokens = original.max_output_tokens; let output_tokens = original.max_output_tokens;
self.traffic.on_request(input_tokens, output_tokens); let latencies = self.collector.request_latencies(signal.uuid);
self.traffic
.on_request(input_tokens, output_tokens, latencies);
self.state_mut(signal.uuid)?.mark_done(); self.state_mut(signal.uuid)?.mark_done();
#[cfg(test)] #[cfg(test)]
{ {
...@@ -831,7 +833,7 @@ impl DisaggRuntime { ...@@ -831,7 +833,7 @@ impl DisaggRuntime {
} }
/// Drain accumulated traffic stats since the last drain. /// Drain accumulated traffic stats since the last drain.
pub(in crate::replay) fn drain_traffic(&mut self) -> (f64, usize, f64, f64) { pub(in crate::replay) fn drain_traffic(&mut self) -> TrafficStats {
self.traffic.drain(self.now_ms) self.traffic.drain(self.now_ms)
} }
......
...@@ -15,7 +15,7 @@ use anyhow::Result; ...@@ -15,7 +15,7 @@ use anyhow::Result;
use dynamo_kv_router::config::KvRouterConfig; use dynamo_kv_router::config::KvRouterConfig;
use super::offline::agg::AggRuntime; use super::offline::agg::AggRuntime;
use super::offline::components::ReplayMode; use super::offline::components::{ReplayMode, TrafficStats};
use super::offline::disagg::DisaggRuntime; use super::offline::disagg::DisaggRuntime;
use super::{ use super::{
OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, TraceSimulationReport, OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, TraceSimulationReport,
...@@ -162,10 +162,9 @@ impl PlannerReplayHandle { ...@@ -162,10 +162,9 @@ impl PlannerReplayHandle {
/// Drain accumulated traffic metrics since the last drain. /// Drain accumulated traffic metrics since the last drain.
/// ///
/// Returns `(duration_s, num_req, avg_isl, avg_osl)`. Call this only on /// Call this only on throughput-scaling ticks so the window covers the full
/// throughput-scaling ticks so the window covers the full
/// `throughput_adjustment_interval`, not just the gap between load ticks. /// `throughput_adjustment_interval`, not just the gap between load ticks.
pub fn drain_traffic(&mut self) -> (f64, usize, f64, f64) { pub fn drain_traffic(&mut self) -> TrafficStats {
match &mut self.runtime { match &mut self.runtime {
RuntimeKind::Agg(rt) => rt.drain_traffic(), RuntimeKind::Agg(rt) => rt.drain_traffic(),
RuntimeKind::Disagg(rt) => rt.drain_traffic(), RuntimeKind::Disagg(rt) => rt.drain_traffic(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment