Unverified Commit c388483a authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat(planner/replay): KV reuse awareness in load + throughput scaling (#8314)


Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
Co-authored-by: default avatarClaude Opus 4.7 (1M context) <noreply@anthropic.com>
parent db124db0
......@@ -414,7 +414,7 @@ class NativePlannerBase:
return num_p, num_d, True
async def _collect_traffic(self) -> Optional[TrafficObservation]:
"""Pull traffic metrics from Prometheus."""
"""Pull traffic metrics from Prometheus over the throughput interval."""
num_p, num_d, _ = await self._get_worker_counts_raw()
if self.prometheus_port != 0:
......@@ -458,9 +458,14 @@ class NativePlannerBase:
m.osl = self.prometheus_traffic_client.get_avg_output_sequence_tokens(
interval_str, self.model_name
)
m.kv_hit_rate = self.prometheus_traffic_client.get_avg_kv_hit_rate(
interval_str, self.model_name
)
hit_rate_str = f"{m.kv_hit_rate:.3f}" if m.kv_hit_rate is not None else "n/a"
logger.info(
f"Observed num_req: {m.num_req:.2f} isl: {m.isl:.2f} osl: {m.osl:.2f}"
f"Observed num_req: {m.num_req:.2f} isl: {m.isl:.2f} osl: {m.osl:.2f} "
f"kv_hit_rate: {hit_rate_str}"
)
if self.prometheus_port != 0:
......@@ -483,6 +488,42 @@ class NativePlannerBase:
num_req=m.num_req,
isl=m.isl,
osl=m.osl,
kv_hit_rate=m.kv_hit_rate,
)
async def _collect_kv_hit_rate_observation(
self, duration_s: float
) -> Optional[TrafficObservation]:
"""Pull only the KV hit rate from Prometheus over ``duration_s``.
Used in load-only deployments: the load tick only needs the hit rate
to discount prefill work, so we skip the five other (unused) traffic
queries to keep the per-load-tick scrape cheap.
Returns ``None`` when the router metric is unavailable (e.g.
Prometheus source is "frontend"); the state machine treats that as
a no-discount fallback.
"""
assert self.model_name is not None
if duration_s <= 0:
return None
interval_str = f"{int(duration_s)}s"
hit_rate = self.prometheus_traffic_client.get_avg_kv_hit_rate(
interval_str, self.model_name
)
# Mirror the observed value into Metrics so the diagnostics recorder
# sees the up-to-date hit rate even on load-only ticks.
self._last_metrics.kv_hit_rate = hit_rate
hit_rate_str = f"{hit_rate:.3f}" if hit_rate is not None else "n/a"
logger.info(f"Observed kv_hit_rate over {interval_str}: {hit_rate_str}")
if hit_rate is None:
return None
return TrafficObservation(
duration_s=duration_s,
num_req=0.0,
isl=0.0,
osl=0.0,
kv_hit_rate=hit_rate,
)
def _collect_fpm(self) -> FpmObservations:
......@@ -560,7 +601,17 @@ class NativePlannerBase:
fpm_obs = None
if tick.need_traffic_metrics:
# Throughput ticks pull the full traffic snapshot over the
# throughput interval. Load-only deployments instead piggyback
# a cheap kv-hit-rate-only scrape (over the load interval) on
# each load tick so the planner can still discount prefill work
# by recent prefix reuse.
if tick.run_throughput_scaling:
traffic = await self._collect_traffic()
else:
traffic = await self._collect_kv_hit_rate_observation(
tick.traffic_metrics_duration_s
)
if tick.need_worker_states:
worker_counts = await self._collect_worker_counts()
if tick.need_worker_fpm:
......
......@@ -368,11 +368,13 @@ class LoadScalingMixin:
self._diag_load_reason = "insufficient_data"
return None
kv_hit_rate = self._last_kv_hit_rate
estimates: list[float] = []
for (wid, dp), fpm in fpm_stats.items():
est = self._prefill_regression.estimate_next_ttft(
queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens,
max_num_batched_tokens=max_tokens,
kv_hit_rate=kv_hit_rate,
)
if est is not None:
est_ms = est * 1000
......@@ -380,7 +382,8 @@ class LoadScalingMixin:
logger.info(
f"Prefill engine {wid}:dp{dp}: estimated TTFT {est_ms:.2f}ms "
f"(queued={fpm.queued_requests.sum_prefill_tokens}, "
f"avg_isl={self._prefill_regression.avg_isl:.1f})"
f"avg_isl={self._prefill_regression.avg_isl:.1f}, "
f"kv_hit_rate={kv_hit_rate if kv_hit_rate is not None else 'n/a'})"
)
if estimates:
......@@ -432,12 +435,14 @@ class LoadScalingMixin:
num_workers: int,
max_tokens: int,
) -> Optional[int]:
kv_hit_rate = self._last_kv_hit_rate
estimates: list[float] = []
for fpm in fpm_stats.values():
est = self._agg_regression.estimate_next_ttft(
queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens,
max_num_batched_tokens=max_tokens,
current_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
kv_hit_rate=kv_hit_rate,
)
if est is not None:
estimates.append(est * 1000)
......
......@@ -13,7 +13,11 @@ from typing import Optional
import numpy as np
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.planner.core.perf_model.base import _BaseRegressionModel, _MovingAverage
from dynamo.planner.core.perf_model.base import (
_BaseRegressionModel,
_clamp_kv_hit_rate,
_MovingAverage,
)
logger = logging.getLogger(__name__)
......@@ -77,15 +81,22 @@ class AggRegressionModel(_BaseRegressionModel):
queued_prefill_tokens: int,
max_num_batched_tokens: int,
current_decode_kv: int,
kv_hit_rate: Optional[float] = None,
) -> Optional[float]:
"""Simulate prefill scheduling with piggybacked decode.
``kv_hit_rate`` (0.0-1.0) discounts the aggregate work ahead --
both the queue backlog and the hypothetical next request's ISL --
because a new arrival will benefit from the same prefix-cache hit
rate as the current workload. See ``PrefillRegressionModel``.
Returns estimated TTFT in seconds, or None if the model is not ready.
"""
if not self._ensure_fitted() or max_num_batched_tokens <= 0:
return None
total_tokens = queued_prefill_tokens + self._avg_isl.value
scale = 1.0 - _clamp_kv_hit_rate(kv_hit_rate)
total_tokens = (queued_prefill_tokens + self._avg_isl.value) * scale
if total_tokens <= 0:
return 0.0
......@@ -121,6 +132,7 @@ class AggRegressionModel(_BaseRegressionModel):
itl_sla: float,
max_kv_tokens: Optional[int] = None,
max_num_seqs: Optional[int] = None,
kv_hit_rate: Optional[float] = None,
) -> tuple[float, float, float]:
"""Find the maximum agg engine request rate under both SLA targets.
......@@ -131,6 +143,11 @@ class AggRegressionModel(_BaseRegressionModel):
Request rate is derived via Little's law:
``engine_rps = best_batch_size / (osl * wall_time_per_iter)``.
``kv_hit_rate`` discounts only the prefill portion of each
iteration; decode KV residency uses the full ISL because cache
hits reduce prefill compute but do not shrink the KV footprint
used during decode.
The upper bound for the batch-size sweep is the smallest of:
1. KV cache capacity: ``max_kv_tokens / (isl + osl/2)``
2. ``max_num_seqs`` (engine concurrency limit)
......@@ -162,6 +179,9 @@ class AggRegressionModel(_BaseRegressionModel):
):
return (0.0, 0.0, 0.0)
prefill_scale = 1.0 - _clamp_kv_hit_rate(kv_hit_rate)
effective_isl = isl * prefill_scale
avg_ctx = isl + osl / 2.0
# KV cache cap
......@@ -174,14 +194,17 @@ class AggRegressionModel(_BaseRegressionModel):
seq_cap = max_num_seqs if max_num_seqs and max_num_seqs > 0 else kv_cap
# Prefill/decode balance cap via binary search within [1, min(kv_cap, seq_cap)]
# For each candidate x, check: isl / (max_num_batched_tokens - x) <= osl
# For each candidate x, check: effective_isl / (max_num_batched_tokens - x) <= osl
# Uses ``effective_isl`` (post-cache) because cache reuse shrinks the
# prefill tokens each new request consumes from the per-iteration
# budget, raising the admissible batch size.
hard_cap = min(kv_cap, seq_cap, max_num_batched_tokens - 1)
def _prefill_balanced(x: int) -> bool:
prefill_budget = max_num_batched_tokens - x
if prefill_budget <= 0:
return False
return isl / prefill_budget <= osl
return effective_isl / prefill_budget <= osl
lo, hi = 1, max(1, hard_cap)
while lo < hi:
......@@ -198,14 +221,27 @@ class AggRegressionModel(_BaseRegressionModel):
for bs in range(1, max_bs + 1):
decode_kv = bs * avg_ctx
prefill_per_iter = min(bs * isl / max(1.0, osl), max_num_batched_tokens)
# Discounted prefill per iter feeds the wall-time regression: the
# engine actually computes ``effective_isl`` tokens per request
# because the cached prefix is skipped.
prefill_per_iter = min(
bs * effective_isl / max(1.0, osl), max_num_batched_tokens
)
wt = self._predict_2d(prefill_per_iter, decode_kv)
itl_ms = wt * 1000.0
# ``estimate_next_ttft`` applies the same discount internally to
# both the queued portion and the avg_isl portion. To keep the
# discount uniform, we pass the *raw* prefill_per_iter as the
# queued contribution and forward ``kv_hit_rate`` so the
# function's own ``(1 - clamp(kv_hit_rate))`` factor scales
# both sides consistently.
raw_prefill_per_iter = min(bs * isl / max(1.0, osl), max_num_batched_tokens)
est_ttft = self.estimate_next_ttft(
queued_prefill_tokens=int(prefill_per_iter),
queued_prefill_tokens=int(raw_prefill_per_iter),
max_num_batched_tokens=max_num_batched_tokens,
current_decode_kv=int(decode_kv),
kv_hit_rate=kv_hit_rate,
)
ttft_ms = est_ttft * 1000.0 if est_ttft is not None else 0.0
......
......@@ -11,7 +11,7 @@ decode, and agg perf model subclasses.
import logging
import math
from collections import defaultdict, deque
from typing import Union
from typing import Optional, Union
import numpy as np
from sklearn.linear_model import LinearRegression
......@@ -20,6 +20,22 @@ from dynamo.common.forward_pass_metrics import ForwardPassMetrics
logger = logging.getLogger(__name__)
# Upper bound on the applied KV hit rate discount. A full 1.0 reading would
# zero out queued/avg prefill tokens and could mask a genuine backlog; cap
# at 0.95 so the planner always sees *some* work ahead.
_MAX_KV_HIT_RATE_DISCOUNT = 0.95
def _clamp_kv_hit_rate(kv_hit_rate: Optional[float]) -> float:
"""Clamp a raw hit rate into the usable discount range.
Returns 0.0 for ``None`` / NaN (no discount, preserves pre-change
behavior), otherwise clamps into ``[0.0, _MAX_KV_HIT_RATE_DISCOUNT]``.
"""
if kv_hit_rate is None or math.isnan(kv_hit_rate):
return 0.0
return max(0.0, min(_MAX_KV_HIT_RATE_DISCOUNT, float(kv_hit_rate)))
class _MovingAverage:
"""Fixed-window moving average that skips leading zeros.
......
......@@ -13,7 +13,11 @@ from typing import Optional
import numpy as np
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.planner.core.perf_model.base import _BaseRegressionModel, _MovingAverage
from dynamo.planner.core.perf_model.base import (
_BaseRegressionModel,
_clamp_kv_hit_rate,
_MovingAverage,
)
logger = logging.getLogger(__name__)
......@@ -58,15 +62,23 @@ class PrefillRegressionModel(_BaseRegressionModel):
self,
queued_prefill_tokens: int,
max_num_batched_tokens: int,
kv_hit_rate: Optional[float] = None,
) -> Optional[float]:
"""Simulate prefill scheduling to estimate TTFT for the next request.
``kv_hit_rate`` (0.0-1.0) discounts the aggregate work ahead --
both the queue backlog and the hypothetical next request's ISL --
because a new arrival will benefit from the same prefix-cache hit
rate as the current workload. The regression features themselves
(per-iter chunk sizes) remain unchanged, so no double-counting.
Returns estimated TTFT in seconds, or None if the model is not ready.
"""
if not self._ensure_fitted() or max_num_batched_tokens <= 0:
return None
total_tokens = queued_prefill_tokens + self._avg_isl.value
scale = 1.0 - _clamp_kv_hit_rate(kv_hit_rate)
total_tokens = (queued_prefill_tokens + self._avg_isl.value) * scale
if total_tokens <= 0:
return 0.0
......
......@@ -94,6 +94,9 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
self._num_req_predictor = predictor_cls(config)
self._isl_predictor = predictor_cls(config)
self._osl_predictor = predictor_cls(config)
# KV hit rate has no good offline-trace proxy, so it is NOT warmed
# via ``warm_load_predictors``; it learns only from live observations.
self._kv_hit_rate_predictor = predictor_cls(config)
self._num_p_workers: int = 0
self._num_d_workers: int = 0
......@@ -103,6 +106,12 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
self._throughput_lower_bound_p: int = 1
self._throughput_lower_bound_d: int = 1
# Most recent observed KV hit rate from the router. Used by load-scaling
# to discount queued/avg prefill tokens in ``estimate_next_ttft``. Sticky
# across ticks because load-scaling and throughput-scaling cadences
# may differ. ``None`` means "no observation yet" -> no discount.
self._last_kv_hit_rate: Optional[float] = None
self._next_load_s: float = float("inf")
self._next_throughput_s: float = float("inf")
......@@ -112,6 +121,7 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
self._diag_predicted_num_req: Optional[float] = None
self._diag_predicted_isl: Optional[float] = None
self._diag_predicted_osl: Optional[float] = None
self._diag_predicted_kv_hit_rate: Optional[float] = None
self._diag_engine_rps_prefill: Optional[float] = None
self._diag_engine_rps_decode: Optional[float] = None
self._diag_load_reason: Optional[str] = None
......@@ -193,6 +203,11 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
)
if tick.run_load_scaling:
# In load-only deployments the kv-hit-rate scrape rides on the
# load tick, so consume the traffic observation here. In mixed
# mode the throughput branch above already handled it.
if not tick.run_throughput_scaling and tick_input.traffic is not None:
self._observe_traffic(tick_input.traffic)
if tick_input.fpm_observations is not None:
if not self._is_easy:
self._observe_fpm(tick_input.fpm_observations)
......@@ -216,6 +231,7 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
self._diag_predicted_num_req = None
self._diag_predicted_isl = None
self._diag_predicted_osl = None
self._diag_predicted_kv_hit_rate = None
self._diag_engine_rps_prefill = None
self._diag_engine_rps_decode = None
self._diag_load_reason = None
......@@ -232,6 +248,7 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
predicted_num_req=self._diag_predicted_num_req,
predicted_isl=self._diag_predicted_isl,
predicted_osl=self._diag_predicted_osl,
predicted_kv_hit_rate=self._diag_predicted_kv_hit_rate,
engine_rps_prefill=self._diag_engine_rps_prefill,
engine_rps_decode=self._diag_engine_rps_decode,
throughput_lower_bound_prefill=self._throughput_lower_bound_p,
......@@ -255,16 +272,27 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
at_s = min(self._next_load_s, self._next_throughput_s)
is_load = self._next_load_s <= at_s + self._MERGE_TOLERANCE_S
is_throughput = self._next_throughput_s <= at_s + self._MERGE_TOLERANCE_S
# Throughput ticks scrape full traffic over the throughput interval.
# In load-only deployments (no throughput tick ever fires) load ticks
# carry a kv-hit-rate-only scrape over the load interval so the
# planner can still discount prefill work by recent prefix reuse.
if is_throughput:
need_traffic = True
traffic_duration_s = float(self._config.throughput_adjustment_interval)
elif is_load and not self._config.enable_throughput_scaling:
need_traffic = True
traffic_duration_s = float(self._config.load_adjustment_interval)
else:
need_traffic = False
traffic_duration_s = 0.0
return ScheduledTick(
at_s=at_s,
run_load_scaling=is_load,
run_throughput_scaling=is_throughput,
need_worker_states=True,
need_worker_fpm=is_load,
need_traffic_metrics=is_throughput,
traffic_metrics_duration_s=(
self._config.throughput_adjustment_interval if is_throughput else 0.0
),
need_traffic_metrics=need_traffic,
traffic_metrics_duration_s=traffic_duration_s,
)
# ------------------------------------------------------------------
......@@ -312,9 +340,25 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
logger.info(f"FPM load stats: {len(obs.decode)} decode engines observed")
def _observe_traffic(self, traffic: TrafficObservation) -> None:
# Throughput-scaling predictors only have a downstream consumer when
# throughput scaling is enabled. In load-only mode the traffic scrape
# is a kv-hit-rate-only path and num_req/isl/osl arrive as zero
# placeholders, so feeding the predictors would just pollute them.
if self._config.enable_throughput_scaling:
self._num_req_predictor.add_data_point(traffic.num_req)
self._isl_predictor.add_data_point(traffic.isl)
self._osl_predictor.add_data_point(traffic.osl)
if traffic.kv_hit_rate is not None and not math.isnan(traffic.kv_hit_rate):
if self._config.enable_throughput_scaling:
# Mixed mode: feed the predictor; ``_last_kv_hit_rate`` will be
# overwritten with the predicted value inside
# ``_advance_throughput`` so load scaling consumes the smoothed
# forecast (not the raw per-window observation).
self._kv_hit_rate_predictor.add_data_point(traffic.kv_hit_rate)
else:
# Load-only mode: there is no predictor path, the load tick
# consumes the freshly observed average directly.
self._last_kv_hit_rate = traffic.kv_hit_rate
# ------------------------------------------------------------------
# Budget
......
......@@ -14,6 +14,7 @@ import logging
import math
from typing import Optional
from dynamo.planner.core.perf_model.base import _clamp_kv_hit_rate
from dynamo.planner.core.types import ScalingDecision, TrafficObservation
logger = logging.getLogger(__name__)
......@@ -26,11 +27,14 @@ class ThroughputScalingMixin:
_diag_predicted_num_req: Optional[float]
_diag_predicted_isl: Optional[float]
_diag_predicted_osl: Optional[float]
_diag_predicted_kv_hit_rate: Optional[float]
_diag_engine_rps_prefill: Optional[float]
_diag_engine_rps_decode: Optional[float]
_diag_throughput_reason: Optional[str]
_diag_throughput_reason_prefill: Optional[str]
_diag_throughput_reason_decode: Optional[str]
# Sticky value consumed by the load-scaling path between throughput ticks.
_last_kv_hit_rate: Optional[float]
def _advance_throughput(
self, traffic: TrafficObservation
......@@ -48,13 +52,27 @@ class ThroughputScalingMixin:
self._diag_throughput_reason = "no_traffic_data"
return None
demand_rps = next_num_req / traffic.duration_s
predicted_hit_rate = self._predict_kv_hit_rate()
# Promote the predicted value to the sticky field so subsequent
# load-scaling ticks (between throughput ticks) discount prefill work
# using the smoothed forecast rather than the raw last-window
# observation.
if predicted_hit_rate is not None and not math.isnan(predicted_hit_rate):
self._last_kv_hit_rate = predicted_hit_rate
mode = self._config.mode
if mode == "agg":
return self._throughput_agg(demand_rps, next_isl, next_osl)
return self._throughput_agg(
demand_rps, next_isl, next_osl, predicted_hit_rate
)
if mode == "disagg":
return self._throughput_disagg(demand_rps, next_isl, next_osl)
return self._throughput_single(demand_rps, next_isl, next_osl, mode)
return self._throughput_disagg(
demand_rps, next_isl, next_osl, predicted_hit_rate
)
return self._throughput_single(
demand_rps, next_isl, next_osl, mode, predicted_hit_rate
)
def _predict_load(self) -> tuple[Optional[float], Optional[float], Optional[float]]:
try:
......@@ -73,11 +91,33 @@ class ThroughputScalingMixin:
self._diag_throughput_reason = "predict_failed"
return None, None, None
def _predict_kv_hit_rate(self) -> Optional[float]:
"""Predict next-interval KV hit rate.
Returns ``None`` if the predictor isn't ready (cold start: no live
observations yet, no trace-based warmup) -- the caller treats that
as a 0.0 discount, preserving pre-change throughput-scaling behavior.
"""
try:
predicted = self._kv_hit_rate_predictor.predict_next()
except Exception as e:
logger.warning(f"Failed to predict kv_hit_rate: {e}")
self._diag_predicted_kv_hit_rate = None
return None
self._diag_predicted_kv_hit_rate = predicted
logger.info(f"Predicted kv_hit_rate={predicted:.3f}")
return predicted
def _throughput_single(
self, demand_rps: float, isl: float, osl: float, component: str
self,
demand_rps: float,
isl: float,
osl: float,
component: str,
kv_hit_rate: Optional[float] = None,
) -> Optional[ScalingDecision]:
desired = (
self._compute_prefill_replicas(demand_rps, isl, osl)
self._compute_prefill_replicas(demand_rps, isl, osl, kv_hit_rate)
if component == "prefill"
else self._compute_decode_replicas(demand_rps, isl, osl)
)
......@@ -102,9 +142,13 @@ class ThroughputScalingMixin:
)
def _throughput_disagg(
self, demand_rps: float, isl: float, osl: float
self,
demand_rps: float,
isl: float,
osl: float,
kv_hit_rate: Optional[float] = None,
) -> Optional[ScalingDecision]:
num_p = self._compute_prefill_replicas(demand_rps, isl, osl)
num_p = self._compute_prefill_replicas(demand_rps, isl, osl, kv_hit_rate)
num_d = self._compute_decode_replicas(demand_rps, isl, osl)
# _compute_* sets _diag_throughput_reason = "model_not_ready" when
# the regression isn't fit yet. If one side is not ready, the other
......@@ -136,7 +180,11 @@ class ThroughputScalingMixin:
return ScalingDecision(num_prefill=num_p, num_decode=num_d)
def _throughput_agg(
self, demand_rps: float, isl: float, osl: float
self,
demand_rps: float,
isl: float,
osl: float,
kv_hit_rate: Optional[float] = None,
) -> Optional[ScalingDecision]:
d_caps = self._capabilities.decode
max_tokens = d_caps.max_num_batched_tokens if d_caps else None
......@@ -159,6 +207,7 @@ class ThroughputScalingMixin:
itl_sla=self._config.itl,
max_kv_tokens=d_caps.max_kv_tokens if d_caps else None,
max_num_seqs=d_caps.max_num_seqs if d_caps else None,
kv_hit_rate=kv_hit_rate,
)
if engine_rps <= 0:
logger.warning("Agg perf model not ready, skipping throughput scaling")
......@@ -188,12 +237,20 @@ class ThroughputScalingMixin:
return ScalingDecision(num_decode=desired)
def _compute_prefill_replicas(
self, demand_rps: float, isl: float, osl: float
self,
demand_rps: float,
isl: float,
osl: float,
kv_hit_rate: Optional[float] = None,
) -> Optional[int]:
# Prefix cache reuse shrinks the *compute* work of prefill but not
# decode KV residency, so we discount only the ISL fed into the
# prefill regression.
effective_isl = isl * (1.0 - _clamp_kv_hit_rate(kv_hit_rate))
p_caps = self._capabilities.prefill
engine_rps, ttft_ms = self._prefill_regression.find_best_engine_prefill_rps(
ttft_sla=self._config.ttft,
isl=isl,
isl=effective_isl,
max_num_batched_tokens=p_caps.max_num_batched_tokens if p_caps else None,
)
if engine_rps <= 0:
......@@ -209,7 +266,9 @@ class ThroughputScalingMixin:
result = max(math.ceil(demand_rps / engine_rps), self._config.min_endpoint)
logger.info(
f"Prefill: {demand_rps:.2f} rps / {engine_rps:.2f} = {result}, est_ttft={ttft_ms:.1f}ms"
f"Prefill: {demand_rps:.2f} rps / {engine_rps:.2f} = {result}, "
f"est_ttft={ttft_ms:.1f}ms, isl_raw={isl:.1f}, "
f"isl_effective={effective_isl:.1f}"
)
return result
......
......@@ -48,6 +48,7 @@ class TrafficObservation:
num_req: float
isl: float
osl: float
kv_hit_rate: Optional[float] = None
@dataclass
......@@ -107,6 +108,7 @@ class TickDiagnostics:
predicted_num_req: Optional[float] = None
predicted_isl: Optional[float] = None
predicted_osl: Optional[float] = None
predicted_kv_hit_rate: Optional[float] = None
# Throughput-scaling: single-engine capacity under SLA (req/s)
engine_rps_prefill: Optional[float] = None
......
......@@ -55,6 +55,7 @@ class TickSnapshot:
observed_request_duration_seconds: Optional[float] = None
observed_input_sequence_tokens: Optional[float] = None
observed_output_sequence_tokens: Optional[float] = None
observed_kv_hit_rate: Optional[float] = None
# Diagnostics from state machine
estimated_ttft_ms: Optional[float] = None
......@@ -62,6 +63,7 @@ class TickSnapshot:
predicted_requests_per_second: Optional[float] = None
predicted_input_sequence_tokens: Optional[float] = None
predicted_output_sequence_tokens: Optional[float] = None
predicted_kv_hit_rate: Optional[float] = None
engine_rps_prefill: Optional[float] = None
engine_rps_decode: Optional[float] = None
load_decision_reason: Optional[str] = None
......@@ -173,6 +175,7 @@ class DiagnosticsRecorder:
observed_request_duration_seconds=observed.request_duration,
observed_input_sequence_tokens=observed.isl,
observed_output_sequence_tokens=observed.osl,
observed_kv_hit_rate=observed.kv_hit_rate,
estimated_ttft_ms=diag.estimated_ttft_ms,
estimated_itl_ms=diag.estimated_itl_ms,
predicted_requests_per_second=(
......@@ -182,6 +185,7 @@ class DiagnosticsRecorder:
),
predicted_input_sequence_tokens=diag.predicted_isl,
predicted_output_sequence_tokens=diag.predicted_osl,
predicted_kv_hit_rate=diag.predicted_kv_hit_rate,
engine_rps_prefill=diag.engine_rps_prefill,
engine_rps_decode=diag.engine_rps_decode,
load_decision_reason=diag.load_decision_reason,
......
......@@ -42,6 +42,7 @@ class Metrics:
request_duration: Optional[float] = None
p_load: Optional[float] = None
d_load: Optional[float] = None
kv_hit_rate: Optional[float] = None
def is_valid(self) -> bool:
"""Check if all required metrics are valid (not None and not NaN)."""
......@@ -301,6 +302,47 @@ class PrometheusAPIClient:
model_name,
)
def get_avg_kv_hit_rate(self, interval: str, model_name: str) -> Optional[float]:
"""Average predicted KV cache hit rate (0.0-1.0) from the router.
Only available when metrics_source == "router" (the histogram lives on
the LocalRouter component). In disagg deployments the scrape is
namespace-filtered, so if the planner's ``dynamo_namespace`` matches
the prefill pool, the returned value pools only prefill-router
observations.
Returns ``None`` (not ``0.0``) on missing data — Prometheus scrape
gaps must not be confused with a real "no reuse" signal: the state
machine treats a real ``0.0`` as a valid observation and would
otherwise drag the predictor / sticky value down toward zero on
every scrape failure. The caller's ``_clamp_kv_hit_rate(None)``
falls back to no-discount behavior, which is the safe choice.
"""
if self.metrics_source != "router":
return None
full_metric_name = (
f"{prometheus_names.name_prefix.COMPONENT}_"
f"{prometheus_names.router.KV_HIT_RATE}"
)
try:
ns = self.dynamo_namespace.replace("-", "_")
ns_filter = f'{prometheus_names.labels.NAMESPACE}="{ns}"'
query = (
f"sum(increase({full_metric_name}_sum{{{ns_filter}}}[{interval}])) / "
f"sum(increase({full_metric_name}_count{{{ns_filter}}}[{interval}]))"
)
result = self.prom.custom_query(query=query)
if not result:
logger.info(
f"No prometheus data for {full_metric_name}, returning None"
)
return None
value = float(result[0]["value"][1])
return None if math.isnan(value) else value
except Exception as e:
logger.warning(f"Error getting avg kv hit rate: {e}")
return None
def warn_if_router_not_scraped(self) -> None:
"""Warn if Prometheus is not scraping any dynamo_component_router_* series.
......
......@@ -421,11 +421,16 @@ class ReplayPlannerAdapter:
duration_s = t.get("duration_s", 0.0)
if duration_s > 0:
num_req = float(t.get("num_req", 0))
# The mocker publishes avg_kv_hit_rate as 0.0 when the
# window had no admissions with non-zero ISL blocks;
# pass it through as-is so the state machine can decide
# whether to feed its predictor.
traffic = TrafficObservation(
duration_s=duration_s,
num_req=num_req,
isl=t.get("avg_isl", 0.0),
osl=t.get("avg_osl", 0.0),
kv_hit_rate=t.get("avg_kv_hit_rate"),
)
# Stash observed TTFT/ITL for the diagnostics recorder.
# When num_req == 0, the Rust accumulator returns 0 as a
......@@ -437,6 +442,7 @@ class ReplayPlannerAdapter:
num_req=traffic.num_req,
isl=traffic.isl,
osl=traffic.osl,
kv_hit_rate=traffic.kv_hit_rate,
)
return TickInput(
......
......@@ -185,6 +185,109 @@ class TestPrefillRegressionModel:
)
assert est is not None
def test_kv_hit_rate_none_equals_zero(self):
model = PrefillRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
for tokens in [500, 1000, 1500, 2000, 2500]:
fpm = _make_fpm(
sum_prefill_tokens=tokens,
num_prefill_requests=1,
wall_time=0.001 * tokens + 0.002,
)
model.add_observation(fpm)
none_est = model.estimate_next_ttft(
queued_prefill_tokens=3000,
max_num_batched_tokens=2048,
kv_hit_rate=None,
)
zero_est = model.estimate_next_ttft(
queued_prefill_tokens=3000,
max_num_batched_tokens=2048,
kv_hit_rate=0.0,
)
assert none_est == zero_est
def test_kv_hit_rate_discounts_queued_and_avg_isl(self):
"""A hit rate of 0.5 should halve the simulated work, roughly halving TTFT."""
model = PrefillRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
# Fit on several points so the regression is stable and ~linear in tokens.
for tokens in [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000]:
fpm = _make_fpm(
sum_prefill_tokens=tokens,
num_prefill_requests=1,
wall_time=0.001 * tokens,
)
model.add_observation(fpm)
max_batched = 100_000 # single-iteration regime, no chunking rounding
est_full = model.estimate_next_ttft(
queued_prefill_tokens=4000,
max_num_batched_tokens=max_batched,
kv_hit_rate=0.0,
)
est_half = model.estimate_next_ttft(
queued_prefill_tokens=4000,
max_num_batched_tokens=max_batched,
kv_hit_rate=0.5,
)
assert est_full is not None and est_half is not None
# With a ~linear regression and no chunking rounding, 0.5 discount
# should produce roughly half the TTFT (within 20% tolerance for
# linearly-fitted intercept noise).
assert est_half < est_full
assert est_half / est_full < 0.75
def test_kv_hit_rate_clamped(self):
model = PrefillRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
for tokens in [500, 1000, 1500, 2000, 2500]:
fpm = _make_fpm(
sum_prefill_tokens=tokens,
num_prefill_requests=1,
wall_time=0.001 * tokens,
)
model.add_observation(fpm)
# kv_hit_rate > 1.0 should clamp to 0.95 (not 1.0) so queued/avg don't
# fully zero out.
est_above = model.estimate_next_ttft(
queued_prefill_tokens=2000,
max_num_batched_tokens=100_000,
kv_hit_rate=1.5,
)
est_cap = model.estimate_next_ttft(
queued_prefill_tokens=2000,
max_num_batched_tokens=100_000,
kv_hit_rate=0.95,
)
assert est_above == est_cap
# Negative values clamp to 0.0 (no discount).
est_negative = model.estimate_next_ttft(
queued_prefill_tokens=2000,
max_num_batched_tokens=100_000,
kv_hit_rate=-0.3,
)
est_zero = model.estimate_next_ttft(
queued_prefill_tokens=2000,
max_num_batched_tokens=100_000,
kv_hit_rate=0.0,
)
assert est_negative == est_zero
# NaN falls back to 0.0.
est_nan = model.estimate_next_ttft(
queued_prefill_tokens=2000,
max_num_batched_tokens=100_000,
kv_hit_rate=float("nan"),
)
assert est_nan == est_zero
# ── Bucketed retirement tests ─────────────────────────────────────────
......@@ -447,3 +550,103 @@ class TestAggRegressionModel:
itl_sla=50.0,
)
assert thpt == 0.0
def test_agg_kv_hit_rate_none_equals_zero(self):
model = AggRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_agg(model)
none_est = model.estimate_next_ttft(
queued_prefill_tokens=3000,
max_num_batched_tokens=2048,
current_decode_kv=1000,
kv_hit_rate=None,
)
zero_est = model.estimate_next_ttft(
queued_prefill_tokens=3000,
max_num_batched_tokens=2048,
current_decode_kv=1000,
kv_hit_rate=0.0,
)
assert none_est == zero_est
def test_agg_kv_hit_rate_discounts_prefill(self):
model = AggRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_agg(model)
est_full = model.estimate_next_ttft(
queued_prefill_tokens=3000,
max_num_batched_tokens=100_000,
current_decode_kv=1000,
kv_hit_rate=0.0,
)
est_half = model.estimate_next_ttft(
queued_prefill_tokens=3000,
max_num_batched_tokens=100_000,
current_decode_kv=1000,
kv_hit_rate=0.5,
)
assert est_full is not None and est_half is not None
assert est_half < est_full
def test_agg_find_best_engine_rps_hit_rate_increases_throughput(self):
"""find_best_engine_agg_rps should discount only prefill work,
leaving decode KV at full context; higher hit rate should yield
greater-or-equal engine rps."""
model = AggRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_agg(model)
rps_base, _, _ = model.find_best_engine_agg_rps(
isl=2048.0,
osl=150.0,
max_num_batched_tokens=4096,
ttft_sla=500.0,
itl_sla=50.0,
kv_hit_rate=0.0,
)
rps_hit, _, _ = model.find_best_engine_agg_rps(
isl=2048.0,
osl=150.0,
max_num_batched_tokens=4096,
ttft_sla=500.0,
itl_sla=50.0,
kv_hit_rate=0.6,
)
assert rps_hit >= rps_base
def test_agg_find_best_engine_rps_uniform_discount_in_ttft_estimate(self):
"""``find_best_engine_agg_rps`` must apply the kv_hit_rate discount
uniformly to BOTH the per-iter prefill and the avg_isl portion of
the TTFT simulation. Regression for the bug where the function
passed already-discounted prefill_per_iter to estimate_next_ttft
without forwarding kv_hit_rate, leaving avg_isl at full size and
inflating the predicted TTFT (= over-provisioning replicas)."""
model = AggRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_agg(model)
# With a permissive ITL/TTFT SLA, the only difference in engine_rps
# at high hit rate vs zero hit rate should come from the prefill
# discount. If the bug recurs the high-hit-rate path will under-
# estimate capacity (smaller batch sweep) and produce strictly less
# rps growth than the discount factor warrants.
rps_zero, _, _ = model.find_best_engine_agg_rps(
isl=4000.0,
osl=200.0,
max_num_batched_tokens=8192,
ttft_sla=10_000.0,
itl_sla=10_000.0,
kv_hit_rate=0.0,
)
rps_high, _, _ = model.find_best_engine_agg_rps(
isl=4000.0,
osl=200.0,
max_num_batched_tokens=8192,
ttft_sla=10_000.0,
itl_sla=10_000.0,
kv_hit_rate=0.8,
)
# Strictly greater capacity at 80% hit rate (not just >=).
assert rps_high > rps_zero
......@@ -325,6 +325,29 @@ class TestPrometheusAPIClientRouterSource:
expected_metric = f"{prometheus_names.name_prefix.COMPONENT}_{prometheus_names.router.OUTPUT_SEQUENCE_TOKENS}"
assert expected_metric in call_args
def test_get_avg_kv_hit_rate_dispatches_to_router_histogram(self, router_client):
"""get_avg_kv_hit_rate with router source queries dynamo_component_router_kv_hit_rate."""
# Return a plausible 0.0-1.0 ratio rather than the default 42.0 fixture.
router_client.prom.custom_query.return_value = [{"value": [0, "0.35"]}]
result = router_client.get_avg_kv_hit_rate("60s", "mymodel")
assert result == 0.35
call_args = str(router_client.prom.custom_query.call_args)
expected_metric = f"{prometheus_names.name_prefix.COMPONENT}_{prometheus_names.router.KV_HIT_RATE}"
assert expected_metric in call_args
def test_get_avg_kv_hit_rate_returns_none_for_frontend_source(self):
"""Frontend source doesn't publish an aggregate kv_hit_rate, so the
client should short-circuit to None rather than issue a PromQL query."""
client = PrometheusAPIClient(
"http://localhost:9090", "test-fe-namespace", metrics_source="frontend"
)
client.prom = MagicMock()
client.prom.custom_query.return_value = [{"value": [0, "42.0"]}]
result = client.get_avg_kv_hit_rate("60s", "mymodel")
assert result is None
client.prom.custom_query.assert_not_called()
def test_get_avg_request_count_uses_router_requests_total(self, router_client):
"""get_avg_request_count with router source queries dynamo_component_router_requests_total."""
result = router_client.get_avg_request_count("60s", "mymodel")
......
......@@ -176,7 +176,10 @@ class TestInitialTick:
tick = core.initial_tick(start_s=0.0)
assert tick.at_s == 5.0
assert tick.need_worker_fpm
assert not tick.need_traffic_metrics
# Load-only mode rides a kv-hit-rate scrape on the load tick so the
# planner can discount prefill work by recent prefix reuse.
assert tick.need_traffic_metrics
assert tick.traffic_metrics_duration_s == 5.0
def test_throughput_only(self):
core = _make_core(enable_load_scaling=False)
......@@ -426,6 +429,221 @@ class TestThroughputScaling:
assert effects.next_tick.at_s == 120.0
class TestKvHitRatePlumbing:
def test_load_only_observe_traffic_updates_last_kv_hit_rate(self):
core = _make_core(enable_throughput_scaling=False)
core._observe_traffic(
TrafficObservation(
duration_s=5, num_req=100, isl=1000, osl=150, kv_hit_rate=0.3
)
)
assert core._last_kv_hit_rate == 0.3
def test_load_only_skips_throughput_predictor_feeds(self):
"""In load-only mode the throughput predictors have no consumer; we
must not pollute their buffers with placeholder zeros."""
core = _make_core(enable_throughput_scaling=False)
core._observe_traffic(
TrafficObservation(duration_s=5, num_req=0, isl=0, osl=0, kv_hit_rate=0.4)
)
assert core._num_req_predictor.data_buffer == []
assert core._isl_predictor.data_buffer == []
assert core._osl_predictor.data_buffer == []
# kv predictor also untouched in load-only mode (no prediction needed)
assert core._kv_hit_rate_predictor.data_buffer == []
def test_load_only_none_kv_hit_rate_leaves_last_value_unchanged(self):
core = _make_core(enable_throughput_scaling=False)
core._observe_traffic(
TrafficObservation(duration_s=5, num_req=0, isl=0, osl=0, kv_hit_rate=0.42)
)
# Subsequent observation without a hit rate (scrape failure / frontend
# source) must not clobber the sticky value -- the planner keeps
# using the most recent valid reading.
core._observe_traffic(
TrafficObservation(duration_s=5, num_req=0, isl=0, osl=0, kv_hit_rate=None)
)
assert core._last_kv_hit_rate == 0.42
def test_load_only_nan_kv_hit_rate_is_ignored(self):
core = _make_core(enable_throughput_scaling=False)
core._observe_traffic(
TrafficObservation(duration_s=5, num_req=0, isl=0, osl=0, kv_hit_rate=0.5)
)
core._observe_traffic(
TrafficObservation(
duration_s=5,
num_req=0,
isl=0,
osl=0,
kv_hit_rate=float("nan"),
)
)
assert core._last_kv_hit_rate == 0.5
def test_mixed_mode_observe_traffic_feeds_predictor_only(self):
"""In mixed mode the raw observation feeds the predictor; the sticky
``_last_kv_hit_rate`` is *not* updated until ``_advance_throughput``
promotes the predicted value to it."""
core = _make_core() # both load + throughput scaling enabled
assert core._last_kv_hit_rate is None
core._observe_traffic(
TrafficObservation(
duration_s=60, num_req=100, isl=1000, osl=150, kv_hit_rate=0.3
)
)
# Predictor saw the observation
assert len(core._kv_hit_rate_predictor.data_buffer) == 1
# Sticky value is *not* set from the raw observation in mixed mode
assert core._last_kv_hit_rate is None
def test_mixed_mode_advance_throughput_promotes_predicted_value(self):
"""After a throughput tick fires, ``_last_kv_hit_rate`` should hold
the predicted value (used by all subsequent load ticks until the
next throughput tick)."""
core = _make_core(
mode="prefill", enable_load_scaling=True, enable_throughput_scaling=True
)
_train_prefill_regression(core)
# ConstantPredictor returns the last observed value once min_data_points=1.
# Feed a known value and run a throughput tick.
traffic = TrafficObservation(
duration_s=60, num_req=100, isl=1000, osl=150, kv_hit_rate=0.6
)
tick_input = TickInput(
now_s=60.0,
traffic=traffic,
worker_counts=WorkerCounts(ready_num_prefill=1),
)
core.on_tick(_tick_for(tick_input), tick_input)
# Constant predictor returns 0.6, which is then promoted to sticky
assert core._last_kv_hit_rate == pytest.approx(0.6)
def test_load_only_scheduler_sets_need_traffic_on_load_tick(self):
core = _make_core(
mode="prefill",
enable_load_scaling=True,
enable_throughput_scaling=False,
load_adjustment_interval=7,
)
tick = core.initial_tick(start_s=0.0)
# Load-only mode: the load tick should request a kv-hit-rate scrape
# over the load interval.
assert tick.run_load_scaling
assert not tick.run_throughput_scaling
assert tick.need_traffic_metrics
assert tick.traffic_metrics_duration_s == 7.0
def test_throughput_enabled_scheduler_skips_traffic_on_pure_load_tick(self):
core = _make_core(
mode="prefill",
enable_load_scaling=True,
enable_throughput_scaling=True,
load_adjustment_interval=5,
throughput_adjustment_interval=60,
)
tick = core.initial_tick(start_s=0.0)
# First tick is a pure load tick (5s < 60s); traffic scrape is reserved
# for the throughput tick when both modes are enabled.
assert tick.run_load_scaling
assert not tick.run_throughput_scaling
assert not tick.need_traffic_metrics
def test_load_only_load_tick_consumes_traffic(self):
core = _make_core(
mode="prefill",
enable_load_scaling=True,
enable_throughput_scaling=False,
)
tick_input = TickInput(
now_s=5.0,
traffic=TrafficObservation(
duration_s=5, num_req=0, isl=0, osl=0, kv_hit_rate=0.7
),
worker_counts=WorkerCounts(ready_num_prefill=1),
)
core.on_tick(_tick_for(tick_input), tick_input)
assert core._last_kv_hit_rate == 0.7
def test_warm_load_predictors_skips_kv_hit_rate(self):
"""kv_hit_rate has no good offline-trace proxy, so it must not
receive warmup data (only live observations feed it)."""
core = _make_core()
observations = [
TrafficObservation(
duration_s=60, num_req=50 * i, isl=1000, osl=150, kv_hit_rate=0.1 * i
)
for i in range(1, 4)
]
core.warm_load_predictors(observations)
# Other predictors accumulated their respective series
assert len(core._num_req_predictor.data_buffer) == 3
assert len(core._isl_predictor.data_buffer) == 3
assert len(core._osl_predictor.data_buffer) == 3
# kv_hit_rate predictor stayed cold
assert core._kv_hit_rate_predictor.data_buffer == []
def test_throughput_diagnostics_include_predicted_kv_hit_rate(self):
core = _make_core(
mode="prefill", enable_load_scaling=False, enable_throughput_scaling=True
)
_train_prefill_regression(core)
core._observe_traffic(
TrafficObservation(
duration_s=60, num_req=100, isl=1000, osl=150, kv_hit_rate=0.4
)
)
tick = TickInput(
now_s=60.0,
traffic=TrafficObservation(
duration_s=60, num_req=100, isl=1000, osl=150, kv_hit_rate=0.4
),
worker_counts=WorkerCounts(ready_num_prefill=1),
)
effects = core.on_tick(_tick_for(tick), tick)
# ConstantPredictor predicts the last value it saw
assert effects.diagnostics.predicted_kv_hit_rate == 0.4
def test_high_predicted_hit_rate_reduces_prefill_replicas(self):
"""With the same demand + regression, a high predicted hit rate
should yield fewer (or at worst equal) prefill replicas than no
reuse."""
core_base = _make_core(
mode="prefill", enable_load_scaling=False, enable_throughput_scaling=True
)
_train_prefill_regression(core_base)
core_hit = _make_core(
mode="prefill", enable_load_scaling=False, enable_throughput_scaling=True
)
_train_prefill_regression(core_hit)
# Feed several observations so the (constant) predictor locks in.
traffic_base = TrafficObservation(
duration_s=60, num_req=500, isl=4000, osl=150, kv_hit_rate=0.0
)
traffic_hit = TrafficObservation(
duration_s=60, num_req=500, isl=4000, osl=150, kv_hit_rate=0.8
)
core_base._observe_traffic(traffic_base)
core_hit._observe_traffic(traffic_hit)
tick_base = TickInput(
now_s=60.0,
traffic=traffic_base,
worker_counts=WorkerCounts(ready_num_prefill=1),
)
tick_hit = TickInput(
now_s=60.0,
traffic=traffic_hit,
worker_counts=WorkerCounts(ready_num_prefill=1),
)
effects_base = core_base.on_tick(_tick_for(tick_base), tick_base)
effects_hit = core_hit.on_tick(_tick_for(tick_hit), tick_hit)
assert effects_base.scale_to is not None
assert effects_hit.scale_to is not None
assert effects_hit.scale_to.num_prefill <= effects_base.scale_to.num_prefill
# ── FPM reconciliation ───────────────────────────────────────────────
......
......@@ -1356,6 +1356,13 @@ impl PlannerReplayBridge {
/// - `avg_itl_ms` (f64): mean inter-token latency in milliseconds,
/// averaged only over requests that generated
/// at least one token gap (0.0 when no samples)
/// - `avg_kv_hit_rate` (f64): arithmetic mean of per-request
/// ``overlap_blocks / isl_blocks`` ratios
/// across router admissions in the window
/// (one sample per request, not weighted
/// by ISL), matching the real router's
/// `dynamo_component_router_kv_hit_rate`
/// histogram semantics
///
/// Call this only on throughput-scaling ticks so the observation window
/// covers the full `throughput_adjustment_interval`.
......@@ -1374,6 +1381,7 @@ impl PlannerReplayBridge {
"avg_osl": stats.avg_osl,
"avg_ttft_ms": stats.avg_ttft_ms,
"avg_itl_ms": stats.avg_itl_ms,
"avg_kv_hit_rate": stats.avg_kv_hit_rate,
});
pythonize(py, &result)
......
......@@ -317,6 +317,8 @@ class router:
INPUT_SEQUENCE_TOKENS = "router_input_sequence_tokens"
# Output sequence length in tokens observed at the router
OUTPUT_SEQUENCE_TOKENS = "router_output_sequence_tokens"
# Predicted KV cache hit rate at routing time (0.0-1.0)
KV_HIT_RATE = "router_kv_hit_rate"
class router_request:
......@@ -379,22 +381,6 @@ class tokio_perf:
class transport:
"""Transport-specific metrics (TCP / NATS)"""
# NOTE: Nested classes added manually because the codegen does not yet
# handle Rust submodules (see TODO in prometheus_parser.rs).
# Re-running gen-python-prometheus-names will overwrite this file and
# lose these classes until the codegen is updated.
class tcp:
POOL_ACTIVE = "tcp_pool_active"
POOL_IDLE = "tcp_pool_idle"
BYTES_SENT_TOTAL = "tcp_bytes_sent_total"
BYTES_RECEIVED_TOTAL = "tcp_bytes_received_total"
ERRORS_TOTAL = "tcp_errors_total"
SERVER_QUEUE_DEPTH = "tcp_server_queue_depth"
class nats:
ERRORS_TOTAL = "nats_errors_total"
class trtllm_additional:
"""Additional TRT-LLM worker metrics beyond what the engine natively provides."""
......
......@@ -257,7 +257,14 @@ impl AggRuntime {
&mut self,
admissions: Vec<WorkerAdmission>,
) -> anyhow::Result<()> {
for WorkerAdmission { uuid, worker_idx } in admissions {
for WorkerAdmission {
uuid,
worker_idx,
overlap_blocks,
isl_blocks,
} in admissions
{
self.traffic.on_admission(overlap_blocks, isl_blocks);
let request = self
.requests
.get_mut(&uuid)
......
......@@ -35,6 +35,16 @@ use crate::replay::router_shared::{
type ReplayQueueKey = <RouterSchedulingPolicy as SchedulingPolicy>::Key;
/// Internal result of a successful ``admit_request`` call: the chosen
/// worker plus the router's view of prefix-cache overlap, so callers can
/// forward the overlap stats to the traffic accumulator.
#[derive(Debug, Clone, Copy)]
struct AdmitOutcome {
worker_idx: usize,
overlap_blocks: u32,
isl_blocks: u32,
}
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct OfflinePendingRequestSnapshot {
......@@ -252,12 +262,16 @@ impl OfflineReplayRouter {
return Ok(RouterEffects::default());
}
let uuid = request
.uuid
.expect("offline replay requests must have UUIDs before router submission");
let outcome = self.admit_request(pending, decay_now)?;
Ok(RouterEffects {
admissions: vec![WorkerAdmission {
uuid: request
.uuid
.expect("offline replay requests must have UUIDs before router submission"),
worker_idx: self.admit_request(pending, decay_now)?,
uuid,
worker_idx: outcome.worker_idx,
overlap_blocks: outcome.overlap_blocks,
isl_blocks: outcome.isl_blocks,
}],
})
}
......@@ -279,11 +293,7 @@ impl OfflineReplayRouter {
.mark_prefill_completed(&uuid.to_string(), decay_now)
.map_err(anyhow::Error::from)?;
Ok(RouterEffects {
admissions: self
.drain_pending(decay_now)?
.into_iter()
.map(|(uuid, worker_idx)| WorkerAdmission { uuid, worker_idx })
.collect(),
admissions: self.drain_pending(decay_now)?,
})
}
......@@ -297,11 +307,7 @@ impl OfflineReplayRouter {
.free(&uuid.to_string(), decay_now)
.map_err(anyhow::Error::from)?;
Ok(RouterEffects {
admissions: self
.drain_pending(decay_now)?
.into_iter()
.map(|(uuid, worker_idx)| WorkerAdmission { uuid, worker_idx })
.collect(),
admissions: self.drain_pending(decay_now)?,
})
}
......@@ -310,11 +316,7 @@ impl OfflineReplayRouter {
pub(crate) fn try_drain_pending(&mut self, now_ms: f64) -> Result<RouterEffects> {
let decay_now = self.decay_now(now_ms);
Ok(RouterEffects {
admissions: self
.drain_pending(decay_now)?
.into_iter()
.map(|(uuid, worker_idx)| WorkerAdmission { uuid, worker_idx })
.collect(),
admissions: self.drain_pending(decay_now)?,
})
}
......@@ -359,11 +361,7 @@ impl OfflineReplayRouter {
}
let decay_now = self.decay_now(now_ms);
Ok(RouterEffects {
admissions: self
.drain_pending(decay_now)?
.into_iter()
.map(|(uuid, worker_idx)| WorkerAdmission { uuid, worker_idx })
.collect(),
admissions: self.drain_pending(decay_now)?,
})
}
......@@ -486,7 +484,11 @@ impl OfflineReplayRouter {
})
}
fn admit_request(&mut self, request: PendingRequest, decay_now: Instant) -> Result<usize> {
fn admit_request(
&mut self,
request: PendingRequest,
decay_now: Instant,
) -> Result<AdmitOutcome> {
let (decode_blocks, prefill_tokens) = self
.slots
.potential_blocks_and_tokens_with_prefill_tracking(
......@@ -511,6 +513,10 @@ impl OfflineReplayRouter {
request.track_prefill_tokens,
);
let isl_blocks = u32::try_from(request.isl_tokens.div_ceil(self.block_size as usize))
.unwrap_or(u32::MAX);
let overlap_blocks = selection.overlap_blocks;
self.slots
.add_request(
SequenceRequest {
......@@ -526,10 +532,14 @@ impl OfflineReplayRouter {
)
.map_err(anyhow::Error::from)?;
Ok(worker_idx)
Ok(AdmitOutcome {
worker_idx,
overlap_blocks,
isl_blocks,
})
}
fn drain_pending(&mut self, decay_now: Instant) -> Result<Vec<(Uuid, usize)>> {
fn drain_pending(&mut self, decay_now: Instant) -> Result<Vec<WorkerAdmission>> {
let Some(threshold) = self.queue_threshold else {
return Ok(Vec::new());
};
......@@ -540,8 +550,13 @@ impl OfflineReplayRouter {
break;
};
let uuid = request.uuid;
let worker_idx = self.admit_request(request, decay_now)?;
admissions.push((uuid, worker_idx));
let outcome = self.admit_request(request, decay_now)?;
admissions.push(WorkerAdmission {
uuid,
worker_idx: outcome.worker_idx,
overlap_blocks: outcome.overlap_blocks,
isl_blocks: outcome.isl_blocks,
});
}
Ok(admissions)
......@@ -784,6 +799,8 @@ mod tests {
vec![WorkerAdmission {
uuid: Uuid::from_u128(1),
worker_idx: 3,
overlap_blocks: 0,
isl_blocks: 1,
}]
);
}
......@@ -836,6 +853,8 @@ mod tests {
vec![WorkerAdmission {
uuid: Uuid::from_u128(2),
worker_idx: 1,
overlap_blocks: 0,
isl_blocks: 1,
}]
);
assert_eq!(router.pending_count(), 0);
......
......@@ -25,6 +25,13 @@ pub(in crate::replay::offline) enum EnginePassMode {
pub(crate) struct WorkerAdmission {
pub(crate) uuid: Uuid,
pub(crate) worker_idx: usize,
/// Number of blocks the router matched against the prefix cache at
/// admission time. Used by the traffic accumulator to derive an
/// average KV hit rate for the planner.
pub(crate) overlap_blocks: u32,
/// Total ISL expressed in blocks (ceil(isl_tokens / block_size)),
/// paired with ``overlap_blocks`` for the hit-rate ratio.
pub(crate) isl_blocks: u32,
}
#[derive(Debug)]
......@@ -79,10 +86,21 @@ pub struct TrafficStats {
pub avg_osl: f64,
pub avg_ttft_ms: f64,
pub avg_itl_ms: f64,
/// Mean prefix-cache hit rate (0.0-1.0) across router admissions in
/// the window, computed as ``mean(overlap_blocks / isl_blocks)`` over
/// admitted requests (i.e. the arithmetic mean of per-request
/// ratios). Matches the semantics of the real router's
/// ``dynamo_component_router_kv_hit_rate`` Prometheus histogram,
/// which observes one ``overlap/isl`` sample per request; the
/// PromQL query ``sum(increase(_sum)) / sum(increase(_count))``
/// returns the arithmetic mean of those samples, independent of
/// per-request ISL size.
pub avg_kv_hit_rate: f64,
}
/// Accumulates traffic statistics between planner ticks for deriving
/// `TrafficObservation` (num_req, avg ISL, avg OSL over a window).
/// `TrafficObservation` (num_req, avg ISL, avg OSL, avg latencies, avg
/// KV hit rate over a window).
///
/// Latency samples are tracked independently of request counts: a request
/// only contributes to ``total_ttft_ms`` / ``ttft_count`` if a positive TTFT
......@@ -90,6 +108,12 @@ pub struct TrafficStats {
/// ``avg_itl_ms`` reflect only requests that actually produced the sample,
/// rather than silently underestimating when some requests lack latency
/// data (e.g. requests that fail before emitting a token).
///
/// KV hit-rate observations come from the router at admission time (not
/// completion) and are recorded as per-request ratios, matching the real
/// router's per-request histogram: each admission contributes one
/// ``overlap_blocks / isl_blocks`` sample to the running mean, so large
/// requests don't get weighted more heavily than small ones.
#[derive(Debug)]
pub(in crate::replay::offline) struct TrafficAccumulator {
window_start_ms: f64,
......@@ -100,6 +124,11 @@ pub(in crate::replay::offline) struct TrafficAccumulator {
total_itl_ms: f64,
ttft_count: usize,
itl_count: usize,
/// Running sum of per-request hit-rate ratios (``overlap / isl``);
/// divided by ``hit_rate_count`` at drain time to give the mean.
total_hit_rate: f64,
/// Number of admissions with non-zero ISL blocks in the current window.
hit_rate_count: usize,
}
impl TrafficAccumulator {
......@@ -113,6 +142,8 @@ impl TrafficAccumulator {
total_itl_ms: 0.0,
ttft_count: 0,
itl_count: 0,
total_hit_rate: 0.0,
hit_rate_count: 0,
}
}
......@@ -138,6 +169,26 @@ impl TrafficAccumulator {
}
}
/// Record one router admission's prefix-cache overlap as a
/// per-request ratio. Called at admission time (not completion) so
/// the mean hit rate reflects the router's view at routing decision
/// — matching the real router's per-request histogram, where each
/// request contributes exactly one ``overlap/isl`` sample.
/// Admissions with ``isl_blocks == 0`` are skipped (no meaningful
/// ratio), mirroring ``RequestTracker::kv_hit_rate()`` returning
/// ``None`` in that case.
pub(in crate::replay::offline) fn on_admission(
&mut self,
overlap_blocks: u32,
isl_blocks: u32,
) {
if isl_blocks == 0 {
return;
}
self.total_hit_rate += f64::from(overlap_blocks) / f64::from(isl_blocks);
self.hit_rate_count += 1;
}
/// Drain the accumulator at the given simulated time, resetting counters.
pub(in crate::replay::offline) fn drain(&mut self, now_ms: f64) -> TrafficStats {
let duration_s = (now_ms - self.window_start_ms) / 1000.0;
......@@ -162,6 +213,11 @@ impl TrafficAccumulator {
} else {
0.0
};
let avg_kv_hit_rate = if self.hit_rate_count > 0 {
self.total_hit_rate / self.hit_rate_count as f64
} else {
0.0
};
self.window_start_ms = now_ms;
self.num_req = 0;
self.total_isl = 0;
......@@ -170,6 +226,8 @@ impl TrafficAccumulator {
self.total_itl_ms = 0.0;
self.ttft_count = 0;
self.itl_count = 0;
self.total_hit_rate = 0.0;
self.hit_rate_count = 0;
TrafficStats {
duration_s,
num_req,
......@@ -177,6 +235,64 @@ impl TrafficAccumulator {
avg_osl,
avg_ttft_ms,
avg_itl_ms,
avg_kv_hit_rate,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn traffic_accumulator_drain_with_no_admissions_reports_zero_hit_rate() {
let mut acc = TrafficAccumulator::new();
acc.on_request(100, 50, None);
let stats = acc.drain(1_000.0);
assert_eq!(stats.num_req, 1);
assert!((stats.avg_isl - 100.0).abs() < 1e-9);
assert!((stats.avg_osl - 50.0).abs() < 1e-9);
assert_eq!(stats.avg_kv_hit_rate, 0.0);
}
#[test]
fn traffic_accumulator_hit_rate_is_mean_of_per_request_ratios() {
let mut acc = TrafficAccumulator::new();
// Small request: mostly hit. Big request: no hit.
acc.on_admission(3, 4); // per-request ratio: 0.75
acc.on_admission(0, 12); // per-request ratio: 0.0
acc.on_request(256, 32, None);
acc.on_request(768, 32, None);
let stats = acc.drain(1_000.0);
assert_eq!(stats.num_req, 2);
// Per-request mean matches the real router's Prometheus histogram:
// (0.75 + 0.0) / 2 = 0.375. Every request contributes one sample
// regardless of ISL size, so large requests don't dominate.
assert!((stats.avg_kv_hit_rate - 0.375).abs() < 1e-9);
}
#[test]
fn traffic_accumulator_skips_admissions_with_zero_isl_blocks() {
let mut acc = TrafficAccumulator::new();
acc.on_admission(0, 0); // skipped -- no meaningful ratio
acc.on_admission(2, 4); // ratio = 0.5
let stats = acc.drain(1_000.0);
// Only the non-zero-ISL sample counts toward the mean.
assert!((stats.avg_kv_hit_rate - 0.5).abs() < 1e-9);
}
#[test]
fn traffic_accumulator_resets_counters_on_drain() {
let mut acc = TrafficAccumulator::new();
acc.on_admission(5, 10);
acc.on_request(100, 50, None);
let _ = acc.drain(1_000.0);
// Second drain on the same accumulator should see no state carried over.
let stats = acc.drain(2_000.0);
assert!((stats.duration_s - 1.0).abs() < 1e-9);
assert_eq!(stats.num_req, 0);
assert_eq!(stats.avg_isl, 0.0);
assert_eq!(stats.avg_osl, 0.0);
assert_eq!(stats.avg_kv_hit_rate, 0.0);
}
}
......@@ -303,7 +303,14 @@ impl DisaggRuntime {
/// Turn prefill router admissions into concrete worker dispatches.
fn dispatch_prefill_admissions(&mut self, admissions: Vec<WorkerAdmission>) -> Result<()> {
for WorkerAdmission { uuid, worker_idx } in admissions {
for WorkerAdmission {
uuid,
worker_idx,
overlap_blocks,
isl_blocks,
} in admissions
{
self.traffic.on_admission(overlap_blocks, isl_blocks);
if self.state(uuid)?.phase != DisaggPhase::QueuedPrefill {
bail!("offline disagg replay expected queued prefill request for {uuid}");
}
......@@ -313,8 +320,16 @@ impl DisaggRuntime {
}
/// Turn decode router admissions into concrete worker dispatches.
///
/// Note: only the prefill router's admissions are fed to
/// ``traffic.on_admission``; decode-router admissions reflect the
/// same requests re-routing after prefill completes and would double
/// count overlap observations.
fn dispatch_decode_admissions(&mut self, admissions: Vec<WorkerAdmission>) -> Result<()> {
for WorkerAdmission { uuid, worker_idx } in admissions {
for WorkerAdmission {
uuid, worker_idx, ..
} in admissions
{
if self.state(uuid)?.phase != DisaggPhase::QueuedDecode {
bail!("offline disagg replay expected queued decode request for {uuid}");
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment