"lib/runtime/src/vscode:/vscode.git/clone" did not exist on "31b78e96374545efd452cbea7f06ecc3d7a281f4"
Unverified Commit c388483a authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat(planner/replay): KV reuse awareness in load + throughput scaling (#8314)


Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
Co-authored-by: default avatarClaude Opus 4.7 (1M context) <noreply@anthropic.com>
parent db124db0
...@@ -414,7 +414,7 @@ class NativePlannerBase: ...@@ -414,7 +414,7 @@ class NativePlannerBase:
return num_p, num_d, True return num_p, num_d, True
async def _collect_traffic(self) -> Optional[TrafficObservation]: async def _collect_traffic(self) -> Optional[TrafficObservation]:
"""Pull traffic metrics from Prometheus.""" """Pull traffic metrics from Prometheus over the throughput interval."""
num_p, num_d, _ = await self._get_worker_counts_raw() num_p, num_d, _ = await self._get_worker_counts_raw()
if self.prometheus_port != 0: if self.prometheus_port != 0:
...@@ -458,9 +458,14 @@ class NativePlannerBase: ...@@ -458,9 +458,14 @@ class NativePlannerBase:
m.osl = self.prometheus_traffic_client.get_avg_output_sequence_tokens( m.osl = self.prometheus_traffic_client.get_avg_output_sequence_tokens(
interval_str, self.model_name interval_str, self.model_name
) )
m.kv_hit_rate = self.prometheus_traffic_client.get_avg_kv_hit_rate(
interval_str, self.model_name
)
hit_rate_str = f"{m.kv_hit_rate:.3f}" if m.kv_hit_rate is not None else "n/a"
logger.info( logger.info(
f"Observed num_req: {m.num_req:.2f} isl: {m.isl:.2f} osl: {m.osl:.2f}" f"Observed num_req: {m.num_req:.2f} isl: {m.isl:.2f} osl: {m.osl:.2f} "
f"kv_hit_rate: {hit_rate_str}"
) )
if self.prometheus_port != 0: if self.prometheus_port != 0:
...@@ -483,6 +488,42 @@ class NativePlannerBase: ...@@ -483,6 +488,42 @@ class NativePlannerBase:
num_req=m.num_req, num_req=m.num_req,
isl=m.isl, isl=m.isl,
osl=m.osl, osl=m.osl,
kv_hit_rate=m.kv_hit_rate,
)
async def _collect_kv_hit_rate_observation(
self, duration_s: float
) -> Optional[TrafficObservation]:
"""Pull only the KV hit rate from Prometheus over ``duration_s``.
Used in load-only deployments: the load tick only needs the hit rate
to discount prefill work, so we skip the five other (unused) traffic
queries to keep the per-load-tick scrape cheap.
Returns ``None`` when the router metric is unavailable (e.g.
Prometheus source is "frontend"); the state machine treats that as
a no-discount fallback.
"""
assert self.model_name is not None
if duration_s <= 0:
return None
interval_str = f"{int(duration_s)}s"
hit_rate = self.prometheus_traffic_client.get_avg_kv_hit_rate(
interval_str, self.model_name
)
# Mirror the observed value into Metrics so the diagnostics recorder
# sees the up-to-date hit rate even on load-only ticks.
self._last_metrics.kv_hit_rate = hit_rate
hit_rate_str = f"{hit_rate:.3f}" if hit_rate is not None else "n/a"
logger.info(f"Observed kv_hit_rate over {interval_str}: {hit_rate_str}")
if hit_rate is None:
return None
return TrafficObservation(
duration_s=duration_s,
num_req=0.0,
isl=0.0,
osl=0.0,
kv_hit_rate=hit_rate,
) )
def _collect_fpm(self) -> FpmObservations: def _collect_fpm(self) -> FpmObservations:
...@@ -560,7 +601,17 @@ class NativePlannerBase: ...@@ -560,7 +601,17 @@ class NativePlannerBase:
fpm_obs = None fpm_obs = None
if tick.need_traffic_metrics: if tick.need_traffic_metrics:
# Throughput ticks pull the full traffic snapshot over the
# throughput interval. Load-only deployments instead piggyback
# a cheap kv-hit-rate-only scrape (over the load interval) on
# each load tick so the planner can still discount prefill work
# by recent prefix reuse.
if tick.run_throughput_scaling:
traffic = await self._collect_traffic() traffic = await self._collect_traffic()
else:
traffic = await self._collect_kv_hit_rate_observation(
tick.traffic_metrics_duration_s
)
if tick.need_worker_states: if tick.need_worker_states:
worker_counts = await self._collect_worker_counts() worker_counts = await self._collect_worker_counts()
if tick.need_worker_fpm: if tick.need_worker_fpm:
......
...@@ -368,11 +368,13 @@ class LoadScalingMixin: ...@@ -368,11 +368,13 @@ class LoadScalingMixin:
self._diag_load_reason = "insufficient_data" self._diag_load_reason = "insufficient_data"
return None return None
kv_hit_rate = self._last_kv_hit_rate
estimates: list[float] = [] estimates: list[float] = []
for (wid, dp), fpm in fpm_stats.items(): for (wid, dp), fpm in fpm_stats.items():
est = self._prefill_regression.estimate_next_ttft( est = self._prefill_regression.estimate_next_ttft(
queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens, queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens,
max_num_batched_tokens=max_tokens, max_num_batched_tokens=max_tokens,
kv_hit_rate=kv_hit_rate,
) )
if est is not None: if est is not None:
est_ms = est * 1000 est_ms = est * 1000
...@@ -380,7 +382,8 @@ class LoadScalingMixin: ...@@ -380,7 +382,8 @@ class LoadScalingMixin:
logger.info( logger.info(
f"Prefill engine {wid}:dp{dp}: estimated TTFT {est_ms:.2f}ms " f"Prefill engine {wid}:dp{dp}: estimated TTFT {est_ms:.2f}ms "
f"(queued={fpm.queued_requests.sum_prefill_tokens}, " f"(queued={fpm.queued_requests.sum_prefill_tokens}, "
f"avg_isl={self._prefill_regression.avg_isl:.1f})" f"avg_isl={self._prefill_regression.avg_isl:.1f}, "
f"kv_hit_rate={kv_hit_rate if kv_hit_rate is not None else 'n/a'})"
) )
if estimates: if estimates:
...@@ -432,12 +435,14 @@ class LoadScalingMixin: ...@@ -432,12 +435,14 @@ class LoadScalingMixin:
num_workers: int, num_workers: int,
max_tokens: int, max_tokens: int,
) -> Optional[int]: ) -> Optional[int]:
kv_hit_rate = self._last_kv_hit_rate
estimates: list[float] = [] estimates: list[float] = []
for fpm in fpm_stats.values(): for fpm in fpm_stats.values():
est = self._agg_regression.estimate_next_ttft( est = self._agg_regression.estimate_next_ttft(
queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens, queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens,
max_num_batched_tokens=max_tokens, max_num_batched_tokens=max_tokens,
current_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens, current_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
kv_hit_rate=kv_hit_rate,
) )
if est is not None: if est is not None:
estimates.append(est * 1000) estimates.append(est * 1000)
......
...@@ -13,7 +13,11 @@ from typing import Optional ...@@ -13,7 +13,11 @@ from typing import Optional
import numpy as np import numpy as np
from dynamo.common.forward_pass_metrics import ForwardPassMetrics from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.planner.core.perf_model.base import _BaseRegressionModel, _MovingAverage from dynamo.planner.core.perf_model.base import (
_BaseRegressionModel,
_clamp_kv_hit_rate,
_MovingAverage,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -77,15 +81,22 @@ class AggRegressionModel(_BaseRegressionModel): ...@@ -77,15 +81,22 @@ class AggRegressionModel(_BaseRegressionModel):
queued_prefill_tokens: int, queued_prefill_tokens: int,
max_num_batched_tokens: int, max_num_batched_tokens: int,
current_decode_kv: int, current_decode_kv: int,
kv_hit_rate: Optional[float] = None,
) -> Optional[float]: ) -> Optional[float]:
"""Simulate prefill scheduling with piggybacked decode. """Simulate prefill scheduling with piggybacked decode.
``kv_hit_rate`` (0.0-1.0) discounts the aggregate work ahead --
both the queue backlog and the hypothetical next request's ISL --
because a new arrival will benefit from the same prefix-cache hit
rate as the current workload. See ``PrefillRegressionModel``.
Returns estimated TTFT in seconds, or None if the model is not ready. Returns estimated TTFT in seconds, or None if the model is not ready.
""" """
if not self._ensure_fitted() or max_num_batched_tokens <= 0: if not self._ensure_fitted() or max_num_batched_tokens <= 0:
return None return None
total_tokens = queued_prefill_tokens + self._avg_isl.value scale = 1.0 - _clamp_kv_hit_rate(kv_hit_rate)
total_tokens = (queued_prefill_tokens + self._avg_isl.value) * scale
if total_tokens <= 0: if total_tokens <= 0:
return 0.0 return 0.0
...@@ -121,6 +132,7 @@ class AggRegressionModel(_BaseRegressionModel): ...@@ -121,6 +132,7 @@ class AggRegressionModel(_BaseRegressionModel):
itl_sla: float, itl_sla: float,
max_kv_tokens: Optional[int] = None, max_kv_tokens: Optional[int] = None,
max_num_seqs: Optional[int] = None, max_num_seqs: Optional[int] = None,
kv_hit_rate: Optional[float] = None,
) -> tuple[float, float, float]: ) -> tuple[float, float, float]:
"""Find the maximum agg engine request rate under both SLA targets. """Find the maximum agg engine request rate under both SLA targets.
...@@ -131,6 +143,11 @@ class AggRegressionModel(_BaseRegressionModel): ...@@ -131,6 +143,11 @@ class AggRegressionModel(_BaseRegressionModel):
Request rate is derived via Little's law: Request rate is derived via Little's law:
``engine_rps = best_batch_size / (osl * wall_time_per_iter)``. ``engine_rps = best_batch_size / (osl * wall_time_per_iter)``.
``kv_hit_rate`` discounts only the prefill portion of each
iteration; decode KV residency uses the full ISL because cache
hits reduce prefill compute but do not shrink the KV footprint
used during decode.
The upper bound for the batch-size sweep is the smallest of: The upper bound for the batch-size sweep is the smallest of:
1. KV cache capacity: ``max_kv_tokens / (isl + osl/2)`` 1. KV cache capacity: ``max_kv_tokens / (isl + osl/2)``
2. ``max_num_seqs`` (engine concurrency limit) 2. ``max_num_seqs`` (engine concurrency limit)
...@@ -162,6 +179,9 @@ class AggRegressionModel(_BaseRegressionModel): ...@@ -162,6 +179,9 @@ class AggRegressionModel(_BaseRegressionModel):
): ):
return (0.0, 0.0, 0.0) return (0.0, 0.0, 0.0)
prefill_scale = 1.0 - _clamp_kv_hit_rate(kv_hit_rate)
effective_isl = isl * prefill_scale
avg_ctx = isl + osl / 2.0 avg_ctx = isl + osl / 2.0
# KV cache cap # KV cache cap
...@@ -174,14 +194,17 @@ class AggRegressionModel(_BaseRegressionModel): ...@@ -174,14 +194,17 @@ class AggRegressionModel(_BaseRegressionModel):
seq_cap = max_num_seqs if max_num_seqs and max_num_seqs > 0 else kv_cap seq_cap = max_num_seqs if max_num_seqs and max_num_seqs > 0 else kv_cap
# Prefill/decode balance cap via binary search within [1, min(kv_cap, seq_cap)] # Prefill/decode balance cap via binary search within [1, min(kv_cap, seq_cap)]
# For each candidate x, check: isl / (max_num_batched_tokens - x) <= osl # For each candidate x, check: effective_isl / (max_num_batched_tokens - x) <= osl
# Uses ``effective_isl`` (post-cache) because cache reuse shrinks the
# prefill tokens each new request consumes from the per-iteration
# budget, raising the admissible batch size.
hard_cap = min(kv_cap, seq_cap, max_num_batched_tokens - 1) hard_cap = min(kv_cap, seq_cap, max_num_batched_tokens - 1)
def _prefill_balanced(x: int) -> bool: def _prefill_balanced(x: int) -> bool:
prefill_budget = max_num_batched_tokens - x prefill_budget = max_num_batched_tokens - x
if prefill_budget <= 0: if prefill_budget <= 0:
return False return False
return isl / prefill_budget <= osl return effective_isl / prefill_budget <= osl
lo, hi = 1, max(1, hard_cap) lo, hi = 1, max(1, hard_cap)
while lo < hi: while lo < hi:
...@@ -198,14 +221,27 @@ class AggRegressionModel(_BaseRegressionModel): ...@@ -198,14 +221,27 @@ class AggRegressionModel(_BaseRegressionModel):
for bs in range(1, max_bs + 1): for bs in range(1, max_bs + 1):
decode_kv = bs * avg_ctx decode_kv = bs * avg_ctx
prefill_per_iter = min(bs * isl / max(1.0, osl), max_num_batched_tokens) # Discounted prefill per iter feeds the wall-time regression: the
# engine actually computes ``effective_isl`` tokens per request
# because the cached prefix is skipped.
prefill_per_iter = min(
bs * effective_isl / max(1.0, osl), max_num_batched_tokens
)
wt = self._predict_2d(prefill_per_iter, decode_kv) wt = self._predict_2d(prefill_per_iter, decode_kv)
itl_ms = wt * 1000.0 itl_ms = wt * 1000.0
# ``estimate_next_ttft`` applies the same discount internally to
# both the queued portion and the avg_isl portion. To keep the
# discount uniform, we pass the *raw* prefill_per_iter as the
# queued contribution and forward ``kv_hit_rate`` so the
# function's own ``(1 - clamp(kv_hit_rate))`` factor scales
# both sides consistently.
raw_prefill_per_iter = min(bs * isl / max(1.0, osl), max_num_batched_tokens)
est_ttft = self.estimate_next_ttft( est_ttft = self.estimate_next_ttft(
queued_prefill_tokens=int(prefill_per_iter), queued_prefill_tokens=int(raw_prefill_per_iter),
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
current_decode_kv=int(decode_kv), current_decode_kv=int(decode_kv),
kv_hit_rate=kv_hit_rate,
) )
ttft_ms = est_ttft * 1000.0 if est_ttft is not None else 0.0 ttft_ms = est_ttft * 1000.0 if est_ttft is not None else 0.0
......
...@@ -11,7 +11,7 @@ decode, and agg perf model subclasses. ...@@ -11,7 +11,7 @@ decode, and agg perf model subclasses.
import logging import logging
import math import math
from collections import defaultdict, deque from collections import defaultdict, deque
from typing import Union from typing import Optional, Union
import numpy as np import numpy as np
from sklearn.linear_model import LinearRegression from sklearn.linear_model import LinearRegression
...@@ -20,6 +20,22 @@ from dynamo.common.forward_pass_metrics import ForwardPassMetrics ...@@ -20,6 +20,22 @@ from dynamo.common.forward_pass_metrics import ForwardPassMetrics
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Upper bound on the applied KV hit rate discount. A full 1.0 reading would
# zero out queued/avg prefill tokens and could mask a genuine backlog; cap
# at 0.95 so the planner always sees *some* work ahead.
_MAX_KV_HIT_RATE_DISCOUNT = 0.95
def _clamp_kv_hit_rate(kv_hit_rate: Optional[float]) -> float:
"""Clamp a raw hit rate into the usable discount range.
Returns 0.0 for ``None`` / NaN (no discount, preserves pre-change
behavior), otherwise clamps into ``[0.0, _MAX_KV_HIT_RATE_DISCOUNT]``.
"""
if kv_hit_rate is None or math.isnan(kv_hit_rate):
return 0.0
return max(0.0, min(_MAX_KV_HIT_RATE_DISCOUNT, float(kv_hit_rate)))
class _MovingAverage: class _MovingAverage:
"""Fixed-window moving average that skips leading zeros. """Fixed-window moving average that skips leading zeros.
......
...@@ -13,7 +13,11 @@ from typing import Optional ...@@ -13,7 +13,11 @@ from typing import Optional
import numpy as np import numpy as np
from dynamo.common.forward_pass_metrics import ForwardPassMetrics from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.planner.core.perf_model.base import _BaseRegressionModel, _MovingAverage from dynamo.planner.core.perf_model.base import (
_BaseRegressionModel,
_clamp_kv_hit_rate,
_MovingAverage,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -58,15 +62,23 @@ class PrefillRegressionModel(_BaseRegressionModel): ...@@ -58,15 +62,23 @@ class PrefillRegressionModel(_BaseRegressionModel):
self, self,
queued_prefill_tokens: int, queued_prefill_tokens: int,
max_num_batched_tokens: int, max_num_batched_tokens: int,
kv_hit_rate: Optional[float] = None,
) -> Optional[float]: ) -> Optional[float]:
"""Simulate prefill scheduling to estimate TTFT for the next request. """Simulate prefill scheduling to estimate TTFT for the next request.
``kv_hit_rate`` (0.0-1.0) discounts the aggregate work ahead --
both the queue backlog and the hypothetical next request's ISL --
because a new arrival will benefit from the same prefix-cache hit
rate as the current workload. The regression features themselves
(per-iter chunk sizes) remain unchanged, so no double-counting.
Returns estimated TTFT in seconds, or None if the model is not ready. Returns estimated TTFT in seconds, or None if the model is not ready.
""" """
if not self._ensure_fitted() or max_num_batched_tokens <= 0: if not self._ensure_fitted() or max_num_batched_tokens <= 0:
return None return None
total_tokens = queued_prefill_tokens + self._avg_isl.value scale = 1.0 - _clamp_kv_hit_rate(kv_hit_rate)
total_tokens = (queued_prefill_tokens + self._avg_isl.value) * scale
if total_tokens <= 0: if total_tokens <= 0:
return 0.0 return 0.0
......
...@@ -94,6 +94,9 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin): ...@@ -94,6 +94,9 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
self._num_req_predictor = predictor_cls(config) self._num_req_predictor = predictor_cls(config)
self._isl_predictor = predictor_cls(config) self._isl_predictor = predictor_cls(config)
self._osl_predictor = predictor_cls(config) self._osl_predictor = predictor_cls(config)
# KV hit rate has no good offline-trace proxy, so it is NOT warmed
# via ``warm_load_predictors``; it learns only from live observations.
self._kv_hit_rate_predictor = predictor_cls(config)
self._num_p_workers: int = 0 self._num_p_workers: int = 0
self._num_d_workers: int = 0 self._num_d_workers: int = 0
...@@ -103,6 +106,12 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin): ...@@ -103,6 +106,12 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
self._throughput_lower_bound_p: int = 1 self._throughput_lower_bound_p: int = 1
self._throughput_lower_bound_d: int = 1 self._throughput_lower_bound_d: int = 1
# Most recent observed KV hit rate from the router. Used by load-scaling
# to discount queued/avg prefill tokens in ``estimate_next_ttft``. Sticky
# across ticks because load-scaling and throughput-scaling cadences
# may differ. ``None`` means "no observation yet" -> no discount.
self._last_kv_hit_rate: Optional[float] = None
self._next_load_s: float = float("inf") self._next_load_s: float = float("inf")
self._next_throughput_s: float = float("inf") self._next_throughput_s: float = float("inf")
...@@ -112,6 +121,7 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin): ...@@ -112,6 +121,7 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
self._diag_predicted_num_req: Optional[float] = None self._diag_predicted_num_req: Optional[float] = None
self._diag_predicted_isl: Optional[float] = None self._diag_predicted_isl: Optional[float] = None
self._diag_predicted_osl: Optional[float] = None self._diag_predicted_osl: Optional[float] = None
self._diag_predicted_kv_hit_rate: Optional[float] = None
self._diag_engine_rps_prefill: Optional[float] = None self._diag_engine_rps_prefill: Optional[float] = None
self._diag_engine_rps_decode: Optional[float] = None self._diag_engine_rps_decode: Optional[float] = None
self._diag_load_reason: Optional[str] = None self._diag_load_reason: Optional[str] = None
...@@ -193,6 +203,11 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin): ...@@ -193,6 +203,11 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
) )
if tick.run_load_scaling: if tick.run_load_scaling:
# In load-only deployments the kv-hit-rate scrape rides on the
# load tick, so consume the traffic observation here. In mixed
# mode the throughput branch above already handled it.
if not tick.run_throughput_scaling and tick_input.traffic is not None:
self._observe_traffic(tick_input.traffic)
if tick_input.fpm_observations is not None: if tick_input.fpm_observations is not None:
if not self._is_easy: if not self._is_easy:
self._observe_fpm(tick_input.fpm_observations) self._observe_fpm(tick_input.fpm_observations)
...@@ -216,6 +231,7 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin): ...@@ -216,6 +231,7 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
self._diag_predicted_num_req = None self._diag_predicted_num_req = None
self._diag_predicted_isl = None self._diag_predicted_isl = None
self._diag_predicted_osl = None self._diag_predicted_osl = None
self._diag_predicted_kv_hit_rate = None
self._diag_engine_rps_prefill = None self._diag_engine_rps_prefill = None
self._diag_engine_rps_decode = None self._diag_engine_rps_decode = None
self._diag_load_reason = None self._diag_load_reason = None
...@@ -232,6 +248,7 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin): ...@@ -232,6 +248,7 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
predicted_num_req=self._diag_predicted_num_req, predicted_num_req=self._diag_predicted_num_req,
predicted_isl=self._diag_predicted_isl, predicted_isl=self._diag_predicted_isl,
predicted_osl=self._diag_predicted_osl, predicted_osl=self._diag_predicted_osl,
predicted_kv_hit_rate=self._diag_predicted_kv_hit_rate,
engine_rps_prefill=self._diag_engine_rps_prefill, engine_rps_prefill=self._diag_engine_rps_prefill,
engine_rps_decode=self._diag_engine_rps_decode, engine_rps_decode=self._diag_engine_rps_decode,
throughput_lower_bound_prefill=self._throughput_lower_bound_p, throughput_lower_bound_prefill=self._throughput_lower_bound_p,
...@@ -255,16 +272,27 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin): ...@@ -255,16 +272,27 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
at_s = min(self._next_load_s, self._next_throughput_s) at_s = min(self._next_load_s, self._next_throughput_s)
is_load = self._next_load_s <= at_s + self._MERGE_TOLERANCE_S is_load = self._next_load_s <= at_s + self._MERGE_TOLERANCE_S
is_throughput = self._next_throughput_s <= at_s + self._MERGE_TOLERANCE_S is_throughput = self._next_throughput_s <= at_s + self._MERGE_TOLERANCE_S
# Throughput ticks scrape full traffic over the throughput interval.
# In load-only deployments (no throughput tick ever fires) load ticks
# carry a kv-hit-rate-only scrape over the load interval so the
# planner can still discount prefill work by recent prefix reuse.
if is_throughput:
need_traffic = True
traffic_duration_s = float(self._config.throughput_adjustment_interval)
elif is_load and not self._config.enable_throughput_scaling:
need_traffic = True
traffic_duration_s = float(self._config.load_adjustment_interval)
else:
need_traffic = False
traffic_duration_s = 0.0
return ScheduledTick( return ScheduledTick(
at_s=at_s, at_s=at_s,
run_load_scaling=is_load, run_load_scaling=is_load,
run_throughput_scaling=is_throughput, run_throughput_scaling=is_throughput,
need_worker_states=True, need_worker_states=True,
need_worker_fpm=is_load, need_worker_fpm=is_load,
need_traffic_metrics=is_throughput, need_traffic_metrics=need_traffic,
traffic_metrics_duration_s=( traffic_metrics_duration_s=traffic_duration_s,
self._config.throughput_adjustment_interval if is_throughput else 0.0
),
) )
# ------------------------------------------------------------------ # ------------------------------------------------------------------
...@@ -312,9 +340,25 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin): ...@@ -312,9 +340,25 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
logger.info(f"FPM load stats: {len(obs.decode)} decode engines observed") logger.info(f"FPM load stats: {len(obs.decode)} decode engines observed")
def _observe_traffic(self, traffic: TrafficObservation) -> None: def _observe_traffic(self, traffic: TrafficObservation) -> None:
# Throughput-scaling predictors only have a downstream consumer when
# throughput scaling is enabled. In load-only mode the traffic scrape
# is a kv-hit-rate-only path and num_req/isl/osl arrive as zero
# placeholders, so feeding the predictors would just pollute them.
if self._config.enable_throughput_scaling:
self._num_req_predictor.add_data_point(traffic.num_req) self._num_req_predictor.add_data_point(traffic.num_req)
self._isl_predictor.add_data_point(traffic.isl) self._isl_predictor.add_data_point(traffic.isl)
self._osl_predictor.add_data_point(traffic.osl) self._osl_predictor.add_data_point(traffic.osl)
if traffic.kv_hit_rate is not None and not math.isnan(traffic.kv_hit_rate):
if self._config.enable_throughput_scaling:
# Mixed mode: feed the predictor; ``_last_kv_hit_rate`` will be
# overwritten with the predicted value inside
# ``_advance_throughput`` so load scaling consumes the smoothed
# forecast (not the raw per-window observation).
self._kv_hit_rate_predictor.add_data_point(traffic.kv_hit_rate)
else:
# Load-only mode: there is no predictor path, the load tick
# consumes the freshly observed average directly.
self._last_kv_hit_rate = traffic.kv_hit_rate
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Budget # Budget
......
...@@ -14,6 +14,7 @@ import logging ...@@ -14,6 +14,7 @@ import logging
import math import math
from typing import Optional from typing import Optional
from dynamo.planner.core.perf_model.base import _clamp_kv_hit_rate
from dynamo.planner.core.types import ScalingDecision, TrafficObservation from dynamo.planner.core.types import ScalingDecision, TrafficObservation
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -26,11 +27,14 @@ class ThroughputScalingMixin: ...@@ -26,11 +27,14 @@ class ThroughputScalingMixin:
_diag_predicted_num_req: Optional[float] _diag_predicted_num_req: Optional[float]
_diag_predicted_isl: Optional[float] _diag_predicted_isl: Optional[float]
_diag_predicted_osl: Optional[float] _diag_predicted_osl: Optional[float]
_diag_predicted_kv_hit_rate: Optional[float]
_diag_engine_rps_prefill: Optional[float] _diag_engine_rps_prefill: Optional[float]
_diag_engine_rps_decode: Optional[float] _diag_engine_rps_decode: Optional[float]
_diag_throughput_reason: Optional[str] _diag_throughput_reason: Optional[str]
_diag_throughput_reason_prefill: Optional[str] _diag_throughput_reason_prefill: Optional[str]
_diag_throughput_reason_decode: Optional[str] _diag_throughput_reason_decode: Optional[str]
# Sticky value consumed by the load-scaling path between throughput ticks.
_last_kv_hit_rate: Optional[float]
def _advance_throughput( def _advance_throughput(
self, traffic: TrafficObservation self, traffic: TrafficObservation
...@@ -48,13 +52,27 @@ class ThroughputScalingMixin: ...@@ -48,13 +52,27 @@ class ThroughputScalingMixin:
self._diag_throughput_reason = "no_traffic_data" self._diag_throughput_reason = "no_traffic_data"
return None return None
demand_rps = next_num_req / traffic.duration_s demand_rps = next_num_req / traffic.duration_s
predicted_hit_rate = self._predict_kv_hit_rate()
# Promote the predicted value to the sticky field so subsequent
# load-scaling ticks (between throughput ticks) discount prefill work
# using the smoothed forecast rather than the raw last-window
# observation.
if predicted_hit_rate is not None and not math.isnan(predicted_hit_rate):
self._last_kv_hit_rate = predicted_hit_rate
mode = self._config.mode mode = self._config.mode
if mode == "agg": if mode == "agg":
return self._throughput_agg(demand_rps, next_isl, next_osl) return self._throughput_agg(
demand_rps, next_isl, next_osl, predicted_hit_rate
)
if mode == "disagg": if mode == "disagg":
return self._throughput_disagg(demand_rps, next_isl, next_osl) return self._throughput_disagg(
return self._throughput_single(demand_rps, next_isl, next_osl, mode) demand_rps, next_isl, next_osl, predicted_hit_rate
)
return self._throughput_single(
demand_rps, next_isl, next_osl, mode, predicted_hit_rate
)
def _predict_load(self) -> tuple[Optional[float], Optional[float], Optional[float]]: def _predict_load(self) -> tuple[Optional[float], Optional[float], Optional[float]]:
try: try:
...@@ -73,11 +91,33 @@ class ThroughputScalingMixin: ...@@ -73,11 +91,33 @@ class ThroughputScalingMixin:
self._diag_throughput_reason = "predict_failed" self._diag_throughput_reason = "predict_failed"
return None, None, None return None, None, None
def _predict_kv_hit_rate(self) -> Optional[float]:
"""Predict next-interval KV hit rate.
Returns ``None`` if the predictor isn't ready (cold start: no live
observations yet, no trace-based warmup) -- the caller treats that
as a 0.0 discount, preserving pre-change throughput-scaling behavior.
"""
try:
predicted = self._kv_hit_rate_predictor.predict_next()
except Exception as e:
logger.warning(f"Failed to predict kv_hit_rate: {e}")
self._diag_predicted_kv_hit_rate = None
return None
self._diag_predicted_kv_hit_rate = predicted
logger.info(f"Predicted kv_hit_rate={predicted:.3f}")
return predicted
def _throughput_single( def _throughput_single(
self, demand_rps: float, isl: float, osl: float, component: str self,
demand_rps: float,
isl: float,
osl: float,
component: str,
kv_hit_rate: Optional[float] = None,
) -> Optional[ScalingDecision]: ) -> Optional[ScalingDecision]:
desired = ( desired = (
self._compute_prefill_replicas(demand_rps, isl, osl) self._compute_prefill_replicas(demand_rps, isl, osl, kv_hit_rate)
if component == "prefill" if component == "prefill"
else self._compute_decode_replicas(demand_rps, isl, osl) else self._compute_decode_replicas(demand_rps, isl, osl)
) )
...@@ -102,9 +142,13 @@ class ThroughputScalingMixin: ...@@ -102,9 +142,13 @@ class ThroughputScalingMixin:
) )
def _throughput_disagg( def _throughput_disagg(
self, demand_rps: float, isl: float, osl: float self,
demand_rps: float,
isl: float,
osl: float,
kv_hit_rate: Optional[float] = None,
) -> Optional[ScalingDecision]: ) -> Optional[ScalingDecision]:
num_p = self._compute_prefill_replicas(demand_rps, isl, osl) num_p = self._compute_prefill_replicas(demand_rps, isl, osl, kv_hit_rate)
num_d = self._compute_decode_replicas(demand_rps, isl, osl) num_d = self._compute_decode_replicas(demand_rps, isl, osl)
# _compute_* sets _diag_throughput_reason = "model_not_ready" when # _compute_* sets _diag_throughput_reason = "model_not_ready" when
# the regression isn't fit yet. If one side is not ready, the other # the regression isn't fit yet. If one side is not ready, the other
...@@ -136,7 +180,11 @@ class ThroughputScalingMixin: ...@@ -136,7 +180,11 @@ class ThroughputScalingMixin:
return ScalingDecision(num_prefill=num_p, num_decode=num_d) return ScalingDecision(num_prefill=num_p, num_decode=num_d)
def _throughput_agg( def _throughput_agg(
self, demand_rps: float, isl: float, osl: float self,
demand_rps: float,
isl: float,
osl: float,
kv_hit_rate: Optional[float] = None,
) -> Optional[ScalingDecision]: ) -> Optional[ScalingDecision]:
d_caps = self._capabilities.decode d_caps = self._capabilities.decode
max_tokens = d_caps.max_num_batched_tokens if d_caps else None max_tokens = d_caps.max_num_batched_tokens if d_caps else None
...@@ -159,6 +207,7 @@ class ThroughputScalingMixin: ...@@ -159,6 +207,7 @@ class ThroughputScalingMixin:
itl_sla=self._config.itl, itl_sla=self._config.itl,
max_kv_tokens=d_caps.max_kv_tokens if d_caps else None, max_kv_tokens=d_caps.max_kv_tokens if d_caps else None,
max_num_seqs=d_caps.max_num_seqs if d_caps else None, max_num_seqs=d_caps.max_num_seqs if d_caps else None,
kv_hit_rate=kv_hit_rate,
) )
if engine_rps <= 0: if engine_rps <= 0:
logger.warning("Agg perf model not ready, skipping throughput scaling") logger.warning("Agg perf model not ready, skipping throughput scaling")
...@@ -188,12 +237,20 @@ class ThroughputScalingMixin: ...@@ -188,12 +237,20 @@ class ThroughputScalingMixin:
return ScalingDecision(num_decode=desired) return ScalingDecision(num_decode=desired)
def _compute_prefill_replicas( def _compute_prefill_replicas(
self, demand_rps: float, isl: float, osl: float self,
demand_rps: float,
isl: float,
osl: float,
kv_hit_rate: Optional[float] = None,
) -> Optional[int]: ) -> Optional[int]:
# Prefix cache reuse shrinks the *compute* work of prefill but not
# decode KV residency, so we discount only the ISL fed into the
# prefill regression.
effective_isl = isl * (1.0 - _clamp_kv_hit_rate(kv_hit_rate))
p_caps = self._capabilities.prefill p_caps = self._capabilities.prefill
engine_rps, ttft_ms = self._prefill_regression.find_best_engine_prefill_rps( engine_rps, ttft_ms = self._prefill_regression.find_best_engine_prefill_rps(
ttft_sla=self._config.ttft, ttft_sla=self._config.ttft,
isl=isl, isl=effective_isl,
max_num_batched_tokens=p_caps.max_num_batched_tokens if p_caps else None, max_num_batched_tokens=p_caps.max_num_batched_tokens if p_caps else None,
) )
if engine_rps <= 0: if engine_rps <= 0:
...@@ -209,7 +266,9 @@ class ThroughputScalingMixin: ...@@ -209,7 +266,9 @@ class ThroughputScalingMixin:
result = max(math.ceil(demand_rps / engine_rps), self._config.min_endpoint) result = max(math.ceil(demand_rps / engine_rps), self._config.min_endpoint)
logger.info( logger.info(
f"Prefill: {demand_rps:.2f} rps / {engine_rps:.2f} = {result}, est_ttft={ttft_ms:.1f}ms" f"Prefill: {demand_rps:.2f} rps / {engine_rps:.2f} = {result}, "
f"est_ttft={ttft_ms:.1f}ms, isl_raw={isl:.1f}, "
f"isl_effective={effective_isl:.1f}"
) )
return result return result
......
...@@ -48,6 +48,7 @@ class TrafficObservation: ...@@ -48,6 +48,7 @@ class TrafficObservation:
num_req: float num_req: float
isl: float isl: float
osl: float osl: float
kv_hit_rate: Optional[float] = None
@dataclass @dataclass
...@@ -107,6 +108,7 @@ class TickDiagnostics: ...@@ -107,6 +108,7 @@ class TickDiagnostics:
predicted_num_req: Optional[float] = None predicted_num_req: Optional[float] = None
predicted_isl: Optional[float] = None predicted_isl: Optional[float] = None
predicted_osl: Optional[float] = None predicted_osl: Optional[float] = None
predicted_kv_hit_rate: Optional[float] = None
# Throughput-scaling: single-engine capacity under SLA (req/s) # Throughput-scaling: single-engine capacity under SLA (req/s)
engine_rps_prefill: Optional[float] = None engine_rps_prefill: Optional[float] = None
......
...@@ -55,6 +55,7 @@ class TickSnapshot: ...@@ -55,6 +55,7 @@ class TickSnapshot:
observed_request_duration_seconds: Optional[float] = None observed_request_duration_seconds: Optional[float] = None
observed_input_sequence_tokens: Optional[float] = None observed_input_sequence_tokens: Optional[float] = None
observed_output_sequence_tokens: Optional[float] = None observed_output_sequence_tokens: Optional[float] = None
observed_kv_hit_rate: Optional[float] = None
# Diagnostics from state machine # Diagnostics from state machine
estimated_ttft_ms: Optional[float] = None estimated_ttft_ms: Optional[float] = None
...@@ -62,6 +63,7 @@ class TickSnapshot: ...@@ -62,6 +63,7 @@ class TickSnapshot:
predicted_requests_per_second: Optional[float] = None predicted_requests_per_second: Optional[float] = None
predicted_input_sequence_tokens: Optional[float] = None predicted_input_sequence_tokens: Optional[float] = None
predicted_output_sequence_tokens: Optional[float] = None predicted_output_sequence_tokens: Optional[float] = None
predicted_kv_hit_rate: Optional[float] = None
engine_rps_prefill: Optional[float] = None engine_rps_prefill: Optional[float] = None
engine_rps_decode: Optional[float] = None engine_rps_decode: Optional[float] = None
load_decision_reason: Optional[str] = None load_decision_reason: Optional[str] = None
...@@ -173,6 +175,7 @@ class DiagnosticsRecorder: ...@@ -173,6 +175,7 @@ class DiagnosticsRecorder:
observed_request_duration_seconds=observed.request_duration, observed_request_duration_seconds=observed.request_duration,
observed_input_sequence_tokens=observed.isl, observed_input_sequence_tokens=observed.isl,
observed_output_sequence_tokens=observed.osl, observed_output_sequence_tokens=observed.osl,
observed_kv_hit_rate=observed.kv_hit_rate,
estimated_ttft_ms=diag.estimated_ttft_ms, estimated_ttft_ms=diag.estimated_ttft_ms,
estimated_itl_ms=diag.estimated_itl_ms, estimated_itl_ms=diag.estimated_itl_ms,
predicted_requests_per_second=( predicted_requests_per_second=(
...@@ -182,6 +185,7 @@ class DiagnosticsRecorder: ...@@ -182,6 +185,7 @@ class DiagnosticsRecorder:
), ),
predicted_input_sequence_tokens=diag.predicted_isl, predicted_input_sequence_tokens=diag.predicted_isl,
predicted_output_sequence_tokens=diag.predicted_osl, predicted_output_sequence_tokens=diag.predicted_osl,
predicted_kv_hit_rate=diag.predicted_kv_hit_rate,
engine_rps_prefill=diag.engine_rps_prefill, engine_rps_prefill=diag.engine_rps_prefill,
engine_rps_decode=diag.engine_rps_decode, engine_rps_decode=diag.engine_rps_decode,
load_decision_reason=diag.load_decision_reason, load_decision_reason=diag.load_decision_reason,
......
...@@ -42,6 +42,7 @@ class Metrics: ...@@ -42,6 +42,7 @@ class Metrics:
request_duration: Optional[float] = None request_duration: Optional[float] = None
p_load: Optional[float] = None p_load: Optional[float] = None
d_load: Optional[float] = None d_load: Optional[float] = None
kv_hit_rate: Optional[float] = None
def is_valid(self) -> bool: def is_valid(self) -> bool:
"""Check if all required metrics are valid (not None and not NaN).""" """Check if all required metrics are valid (not None and not NaN)."""
...@@ -301,6 +302,47 @@ class PrometheusAPIClient: ...@@ -301,6 +302,47 @@ class PrometheusAPIClient:
model_name, model_name,
) )
def get_avg_kv_hit_rate(self, interval: str, model_name: str) -> Optional[float]:
"""Average predicted KV cache hit rate (0.0-1.0) from the router.
Only available when metrics_source == "router" (the histogram lives on
the LocalRouter component). In disagg deployments the scrape is
namespace-filtered, so if the planner's ``dynamo_namespace`` matches
the prefill pool, the returned value pools only prefill-router
observations.
Returns ``None`` (not ``0.0``) on missing data — Prometheus scrape
gaps must not be confused with a real "no reuse" signal: the state
machine treats a real ``0.0`` as a valid observation and would
otherwise drag the predictor / sticky value down toward zero on
every scrape failure. The caller's ``_clamp_kv_hit_rate(None)``
falls back to no-discount behavior, which is the safe choice.
"""
if self.metrics_source != "router":
return None
full_metric_name = (
f"{prometheus_names.name_prefix.COMPONENT}_"
f"{prometheus_names.router.KV_HIT_RATE}"
)
try:
ns = self.dynamo_namespace.replace("-", "_")
ns_filter = f'{prometheus_names.labels.NAMESPACE}="{ns}"'
query = (
f"sum(increase({full_metric_name}_sum{{{ns_filter}}}[{interval}])) / "
f"sum(increase({full_metric_name}_count{{{ns_filter}}}[{interval}]))"
)
result = self.prom.custom_query(query=query)
if not result:
logger.info(
f"No prometheus data for {full_metric_name}, returning None"
)
return None
value = float(result[0]["value"][1])
return None if math.isnan(value) else value
except Exception as e:
logger.warning(f"Error getting avg kv hit rate: {e}")
return None
def warn_if_router_not_scraped(self) -> None: def warn_if_router_not_scraped(self) -> None:
"""Warn if Prometheus is not scraping any dynamo_component_router_* series. """Warn if Prometheus is not scraping any dynamo_component_router_* series.
......
...@@ -421,11 +421,16 @@ class ReplayPlannerAdapter: ...@@ -421,11 +421,16 @@ class ReplayPlannerAdapter:
duration_s = t.get("duration_s", 0.0) duration_s = t.get("duration_s", 0.0)
if duration_s > 0: if duration_s > 0:
num_req = float(t.get("num_req", 0)) num_req = float(t.get("num_req", 0))
# The mocker publishes avg_kv_hit_rate as 0.0 when the
# window had no admissions with non-zero ISL blocks;
# pass it through as-is so the state machine can decide
# whether to feed its predictor.
traffic = TrafficObservation( traffic = TrafficObservation(
duration_s=duration_s, duration_s=duration_s,
num_req=num_req, num_req=num_req,
isl=t.get("avg_isl", 0.0), isl=t.get("avg_isl", 0.0),
osl=t.get("avg_osl", 0.0), osl=t.get("avg_osl", 0.0),
kv_hit_rate=t.get("avg_kv_hit_rate"),
) )
# Stash observed TTFT/ITL for the diagnostics recorder. # Stash observed TTFT/ITL for the diagnostics recorder.
# When num_req == 0, the Rust accumulator returns 0 as a # When num_req == 0, the Rust accumulator returns 0 as a
...@@ -437,6 +442,7 @@ class ReplayPlannerAdapter: ...@@ -437,6 +442,7 @@ class ReplayPlannerAdapter:
num_req=traffic.num_req, num_req=traffic.num_req,
isl=traffic.isl, isl=traffic.isl,
osl=traffic.osl, osl=traffic.osl,
kv_hit_rate=traffic.kv_hit_rate,
) )
return TickInput( return TickInput(
......
...@@ -185,6 +185,109 @@ class TestPrefillRegressionModel: ...@@ -185,6 +185,109 @@ class TestPrefillRegressionModel:
) )
assert est is not None assert est is not None
def test_kv_hit_rate_none_equals_zero(self):
model = PrefillRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
for tokens in [500, 1000, 1500, 2000, 2500]:
fpm = _make_fpm(
sum_prefill_tokens=tokens,
num_prefill_requests=1,
wall_time=0.001 * tokens + 0.002,
)
model.add_observation(fpm)
none_est = model.estimate_next_ttft(
queued_prefill_tokens=3000,
max_num_batched_tokens=2048,
kv_hit_rate=None,
)
zero_est = model.estimate_next_ttft(
queued_prefill_tokens=3000,
max_num_batched_tokens=2048,
kv_hit_rate=0.0,
)
assert none_est == zero_est
def test_kv_hit_rate_discounts_queued_and_avg_isl(self):
"""A hit rate of 0.5 should halve the simulated work, roughly halving TTFT."""
model = PrefillRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
# Fit on several points so the regression is stable and ~linear in tokens.
for tokens in [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000]:
fpm = _make_fpm(
sum_prefill_tokens=tokens,
num_prefill_requests=1,
wall_time=0.001 * tokens,
)
model.add_observation(fpm)
max_batched = 100_000 # single-iteration regime, no chunking rounding
est_full = model.estimate_next_ttft(
queued_prefill_tokens=4000,
max_num_batched_tokens=max_batched,
kv_hit_rate=0.0,
)
est_half = model.estimate_next_ttft(
queued_prefill_tokens=4000,
max_num_batched_tokens=max_batched,
kv_hit_rate=0.5,
)
assert est_full is not None and est_half is not None
# With a ~linear regression and no chunking rounding, 0.5 discount
# should produce roughly half the TTFT (within 20% tolerance for
# linearly-fitted intercept noise).
assert est_half < est_full
assert est_half / est_full < 0.75
def test_kv_hit_rate_clamped(self):
model = PrefillRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
for tokens in [500, 1000, 1500, 2000, 2500]:
fpm = _make_fpm(
sum_prefill_tokens=tokens,
num_prefill_requests=1,
wall_time=0.001 * tokens,
)
model.add_observation(fpm)
# kv_hit_rate > 1.0 should clamp to 0.95 (not 1.0) so queued/avg don't
# fully zero out.
est_above = model.estimate_next_ttft(
queued_prefill_tokens=2000,
max_num_batched_tokens=100_000,
kv_hit_rate=1.5,
)
est_cap = model.estimate_next_ttft(
queued_prefill_tokens=2000,
max_num_batched_tokens=100_000,
kv_hit_rate=0.95,
)
assert est_above == est_cap
# Negative values clamp to 0.0 (no discount).
est_negative = model.estimate_next_ttft(
queued_prefill_tokens=2000,
max_num_batched_tokens=100_000,
kv_hit_rate=-0.3,
)
est_zero = model.estimate_next_ttft(
queued_prefill_tokens=2000,
max_num_batched_tokens=100_000,
kv_hit_rate=0.0,
)
assert est_negative == est_zero
# NaN falls back to 0.0.
est_nan = model.estimate_next_ttft(
queued_prefill_tokens=2000,
max_num_batched_tokens=100_000,
kv_hit_rate=float("nan"),
)
assert est_nan == est_zero
# ── Bucketed retirement tests ───────────────────────────────────────── # ── Bucketed retirement tests ─────────────────────────────────────────
...@@ -447,3 +550,103 @@ class TestAggRegressionModel: ...@@ -447,3 +550,103 @@ class TestAggRegressionModel:
itl_sla=50.0, itl_sla=50.0,
) )
assert thpt == 0.0 assert thpt == 0.0
def test_agg_kv_hit_rate_none_equals_zero(self):
model = AggRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_agg(model)
none_est = model.estimate_next_ttft(
queued_prefill_tokens=3000,
max_num_batched_tokens=2048,
current_decode_kv=1000,
kv_hit_rate=None,
)
zero_est = model.estimate_next_ttft(
queued_prefill_tokens=3000,
max_num_batched_tokens=2048,
current_decode_kv=1000,
kv_hit_rate=0.0,
)
assert none_est == zero_est
def test_agg_kv_hit_rate_discounts_prefill(self):
model = AggRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_agg(model)
est_full = model.estimate_next_ttft(
queued_prefill_tokens=3000,
max_num_batched_tokens=100_000,
current_decode_kv=1000,
kv_hit_rate=0.0,
)
est_half = model.estimate_next_ttft(
queued_prefill_tokens=3000,
max_num_batched_tokens=100_000,
current_decode_kv=1000,
kv_hit_rate=0.5,
)
assert est_full is not None and est_half is not None
assert est_half < est_full
def test_agg_find_best_engine_rps_hit_rate_increases_throughput(self):
"""find_best_engine_agg_rps should discount only prefill work,
leaving decode KV at full context; higher hit rate should yield
greater-or-equal engine rps."""
model = AggRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_agg(model)
rps_base, _, _ = model.find_best_engine_agg_rps(
isl=2048.0,
osl=150.0,
max_num_batched_tokens=4096,
ttft_sla=500.0,
itl_sla=50.0,
kv_hit_rate=0.0,
)
rps_hit, _, _ = model.find_best_engine_agg_rps(
isl=2048.0,
osl=150.0,
max_num_batched_tokens=4096,
ttft_sla=500.0,
itl_sla=50.0,
kv_hit_rate=0.6,
)
assert rps_hit >= rps_base
def test_agg_find_best_engine_rps_uniform_discount_in_ttft_estimate(self):
"""``find_best_engine_agg_rps`` must apply the kv_hit_rate discount
uniformly to BOTH the per-iter prefill and the avg_isl portion of
the TTFT simulation. Regression for the bug where the function
passed already-discounted prefill_per_iter to estimate_next_ttft
without forwarding kv_hit_rate, leaving avg_isl at full size and
inflating the predicted TTFT (= over-provisioning replicas)."""
model = AggRegressionModel(
max_num_fpm_samples=50, min_observations=3, bucket_count=16
)
self._train_agg(model)
# With a permissive ITL/TTFT SLA, the only difference in engine_rps
# at high hit rate vs zero hit rate should come from the prefill
# discount. If the bug recurs the high-hit-rate path will under-
# estimate capacity (smaller batch sweep) and produce strictly less
# rps growth than the discount factor warrants.
rps_zero, _, _ = model.find_best_engine_agg_rps(
isl=4000.0,
osl=200.0,
max_num_batched_tokens=8192,
ttft_sla=10_000.0,
itl_sla=10_000.0,
kv_hit_rate=0.0,
)
rps_high, _, _ = model.find_best_engine_agg_rps(
isl=4000.0,
osl=200.0,
max_num_batched_tokens=8192,
ttft_sla=10_000.0,
itl_sla=10_000.0,
kv_hit_rate=0.8,
)
# Strictly greater capacity at 80% hit rate (not just >=).
assert rps_high > rps_zero
...@@ -325,6 +325,29 @@ class TestPrometheusAPIClientRouterSource: ...@@ -325,6 +325,29 @@ class TestPrometheusAPIClientRouterSource:
expected_metric = f"{prometheus_names.name_prefix.COMPONENT}_{prometheus_names.router.OUTPUT_SEQUENCE_TOKENS}" expected_metric = f"{prometheus_names.name_prefix.COMPONENT}_{prometheus_names.router.OUTPUT_SEQUENCE_TOKENS}"
assert expected_metric in call_args assert expected_metric in call_args
def test_get_avg_kv_hit_rate_dispatches_to_router_histogram(self, router_client):
"""get_avg_kv_hit_rate with router source queries dynamo_component_router_kv_hit_rate."""
# Return a plausible 0.0-1.0 ratio rather than the default 42.0 fixture.
router_client.prom.custom_query.return_value = [{"value": [0, "0.35"]}]
result = router_client.get_avg_kv_hit_rate("60s", "mymodel")
assert result == 0.35
call_args = str(router_client.prom.custom_query.call_args)
expected_metric = f"{prometheus_names.name_prefix.COMPONENT}_{prometheus_names.router.KV_HIT_RATE}"
assert expected_metric in call_args
def test_get_avg_kv_hit_rate_returns_none_for_frontend_source(self):
"""Frontend source doesn't publish an aggregate kv_hit_rate, so the
client should short-circuit to None rather than issue a PromQL query."""
client = PrometheusAPIClient(
"http://localhost:9090", "test-fe-namespace", metrics_source="frontend"
)
client.prom = MagicMock()
client.prom.custom_query.return_value = [{"value": [0, "42.0"]}]
result = client.get_avg_kv_hit_rate("60s", "mymodel")
assert result is None
client.prom.custom_query.assert_not_called()
def test_get_avg_request_count_uses_router_requests_total(self, router_client): def test_get_avg_request_count_uses_router_requests_total(self, router_client):
"""get_avg_request_count with router source queries dynamo_component_router_requests_total.""" """get_avg_request_count with router source queries dynamo_component_router_requests_total."""
result = router_client.get_avg_request_count("60s", "mymodel") result = router_client.get_avg_request_count("60s", "mymodel")
......
...@@ -176,7 +176,10 @@ class TestInitialTick: ...@@ -176,7 +176,10 @@ class TestInitialTick:
tick = core.initial_tick(start_s=0.0) tick = core.initial_tick(start_s=0.0)
assert tick.at_s == 5.0 assert tick.at_s == 5.0
assert tick.need_worker_fpm assert tick.need_worker_fpm
assert not tick.need_traffic_metrics # Load-only mode rides a kv-hit-rate scrape on the load tick so the
# planner can discount prefill work by recent prefix reuse.
assert tick.need_traffic_metrics
assert tick.traffic_metrics_duration_s == 5.0
def test_throughput_only(self): def test_throughput_only(self):
core = _make_core(enable_load_scaling=False) core = _make_core(enable_load_scaling=False)
...@@ -426,6 +429,221 @@ class TestThroughputScaling: ...@@ -426,6 +429,221 @@ class TestThroughputScaling:
assert effects.next_tick.at_s == 120.0 assert effects.next_tick.at_s == 120.0
class TestKvHitRatePlumbing:
def test_load_only_observe_traffic_updates_last_kv_hit_rate(self):
core = _make_core(enable_throughput_scaling=False)
core._observe_traffic(
TrafficObservation(
duration_s=5, num_req=100, isl=1000, osl=150, kv_hit_rate=0.3
)
)
assert core._last_kv_hit_rate == 0.3
def test_load_only_skips_throughput_predictor_feeds(self):
"""In load-only mode the throughput predictors have no consumer; we
must not pollute their buffers with placeholder zeros."""
core = _make_core(enable_throughput_scaling=False)
core._observe_traffic(
TrafficObservation(duration_s=5, num_req=0, isl=0, osl=0, kv_hit_rate=0.4)
)
assert core._num_req_predictor.data_buffer == []
assert core._isl_predictor.data_buffer == []
assert core._osl_predictor.data_buffer == []
# kv predictor also untouched in load-only mode (no prediction needed)
assert core._kv_hit_rate_predictor.data_buffer == []
def test_load_only_none_kv_hit_rate_leaves_last_value_unchanged(self):
core = _make_core(enable_throughput_scaling=False)
core._observe_traffic(
TrafficObservation(duration_s=5, num_req=0, isl=0, osl=0, kv_hit_rate=0.42)
)
# Subsequent observation without a hit rate (scrape failure / frontend
# source) must not clobber the sticky value -- the planner keeps
# using the most recent valid reading.
core._observe_traffic(
TrafficObservation(duration_s=5, num_req=0, isl=0, osl=0, kv_hit_rate=None)
)
assert core._last_kv_hit_rate == 0.42
def test_load_only_nan_kv_hit_rate_is_ignored(self):
core = _make_core(enable_throughput_scaling=False)
core._observe_traffic(
TrafficObservation(duration_s=5, num_req=0, isl=0, osl=0, kv_hit_rate=0.5)
)
core._observe_traffic(
TrafficObservation(
duration_s=5,
num_req=0,
isl=0,
osl=0,
kv_hit_rate=float("nan"),
)
)
assert core._last_kv_hit_rate == 0.5
def test_mixed_mode_observe_traffic_feeds_predictor_only(self):
"""In mixed mode the raw observation feeds the predictor; the sticky
``_last_kv_hit_rate`` is *not* updated until ``_advance_throughput``
promotes the predicted value to it."""
core = _make_core() # both load + throughput scaling enabled
assert core._last_kv_hit_rate is None
core._observe_traffic(
TrafficObservation(
duration_s=60, num_req=100, isl=1000, osl=150, kv_hit_rate=0.3
)
)
# Predictor saw the observation
assert len(core._kv_hit_rate_predictor.data_buffer) == 1
# Sticky value is *not* set from the raw observation in mixed mode
assert core._last_kv_hit_rate is None
def test_mixed_mode_advance_throughput_promotes_predicted_value(self):
"""After a throughput tick fires, ``_last_kv_hit_rate`` should hold
the predicted value (used by all subsequent load ticks until the
next throughput tick)."""
core = _make_core(
mode="prefill", enable_load_scaling=True, enable_throughput_scaling=True
)
_train_prefill_regression(core)
# ConstantPredictor returns the last observed value once min_data_points=1.
# Feed a known value and run a throughput tick.
traffic = TrafficObservation(
duration_s=60, num_req=100, isl=1000, osl=150, kv_hit_rate=0.6
)
tick_input = TickInput(
now_s=60.0,
traffic=traffic,
worker_counts=WorkerCounts(ready_num_prefill=1),
)
core.on_tick(_tick_for(tick_input), tick_input)
# Constant predictor returns 0.6, which is then promoted to sticky
assert core._last_kv_hit_rate == pytest.approx(0.6)
def test_load_only_scheduler_sets_need_traffic_on_load_tick(self):
core = _make_core(
mode="prefill",
enable_load_scaling=True,
enable_throughput_scaling=False,
load_adjustment_interval=7,
)
tick = core.initial_tick(start_s=0.0)
# Load-only mode: the load tick should request a kv-hit-rate scrape
# over the load interval.
assert tick.run_load_scaling
assert not tick.run_throughput_scaling
assert tick.need_traffic_metrics
assert tick.traffic_metrics_duration_s == 7.0
def test_throughput_enabled_scheduler_skips_traffic_on_pure_load_tick(self):
core = _make_core(
mode="prefill",
enable_load_scaling=True,
enable_throughput_scaling=True,
load_adjustment_interval=5,
throughput_adjustment_interval=60,
)
tick = core.initial_tick(start_s=0.0)
# First tick is a pure load tick (5s < 60s); traffic scrape is reserved
# for the throughput tick when both modes are enabled.
assert tick.run_load_scaling
assert not tick.run_throughput_scaling
assert not tick.need_traffic_metrics
def test_load_only_load_tick_consumes_traffic(self):
core = _make_core(
mode="prefill",
enable_load_scaling=True,
enable_throughput_scaling=False,
)
tick_input = TickInput(
now_s=5.0,
traffic=TrafficObservation(
duration_s=5, num_req=0, isl=0, osl=0, kv_hit_rate=0.7
),
worker_counts=WorkerCounts(ready_num_prefill=1),
)
core.on_tick(_tick_for(tick_input), tick_input)
assert core._last_kv_hit_rate == 0.7
def test_warm_load_predictors_skips_kv_hit_rate(self):
"""kv_hit_rate has no good offline-trace proxy, so it must not
receive warmup data (only live observations feed it)."""
core = _make_core()
observations = [
TrafficObservation(
duration_s=60, num_req=50 * i, isl=1000, osl=150, kv_hit_rate=0.1 * i
)
for i in range(1, 4)
]
core.warm_load_predictors(observations)
# Other predictors accumulated their respective series
assert len(core._num_req_predictor.data_buffer) == 3
assert len(core._isl_predictor.data_buffer) == 3
assert len(core._osl_predictor.data_buffer) == 3
# kv_hit_rate predictor stayed cold
assert core._kv_hit_rate_predictor.data_buffer == []
def test_throughput_diagnostics_include_predicted_kv_hit_rate(self):
core = _make_core(
mode="prefill", enable_load_scaling=False, enable_throughput_scaling=True
)
_train_prefill_regression(core)
core._observe_traffic(
TrafficObservation(
duration_s=60, num_req=100, isl=1000, osl=150, kv_hit_rate=0.4
)
)
tick = TickInput(
now_s=60.0,
traffic=TrafficObservation(
duration_s=60, num_req=100, isl=1000, osl=150, kv_hit_rate=0.4
),
worker_counts=WorkerCounts(ready_num_prefill=1),
)
effects = core.on_tick(_tick_for(tick), tick)
# ConstantPredictor predicts the last value it saw
assert effects.diagnostics.predicted_kv_hit_rate == 0.4
def test_high_predicted_hit_rate_reduces_prefill_replicas(self):
"""With the same demand + regression, a high predicted hit rate
should yield fewer (or at worst equal) prefill replicas than no
reuse."""
core_base = _make_core(
mode="prefill", enable_load_scaling=False, enable_throughput_scaling=True
)
_train_prefill_regression(core_base)
core_hit = _make_core(
mode="prefill", enable_load_scaling=False, enable_throughput_scaling=True
)
_train_prefill_regression(core_hit)
# Feed several observations so the (constant) predictor locks in.
traffic_base = TrafficObservation(
duration_s=60, num_req=500, isl=4000, osl=150, kv_hit_rate=0.0
)
traffic_hit = TrafficObservation(
duration_s=60, num_req=500, isl=4000, osl=150, kv_hit_rate=0.8
)
core_base._observe_traffic(traffic_base)
core_hit._observe_traffic(traffic_hit)
tick_base = TickInput(
now_s=60.0,
traffic=traffic_base,
worker_counts=WorkerCounts(ready_num_prefill=1),
)
tick_hit = TickInput(
now_s=60.0,
traffic=traffic_hit,
worker_counts=WorkerCounts(ready_num_prefill=1),
)
effects_base = core_base.on_tick(_tick_for(tick_base), tick_base)
effects_hit = core_hit.on_tick(_tick_for(tick_hit), tick_hit)
assert effects_base.scale_to is not None
assert effects_hit.scale_to is not None
assert effects_hit.scale_to.num_prefill <= effects_base.scale_to.num_prefill
# ── FPM reconciliation ─────────────────────────────────────────────── # ── FPM reconciliation ───────────────────────────────────────────────
......
...@@ -1356,6 +1356,13 @@ impl PlannerReplayBridge { ...@@ -1356,6 +1356,13 @@ impl PlannerReplayBridge {
/// - `avg_itl_ms` (f64): mean inter-token latency in milliseconds, /// - `avg_itl_ms` (f64): mean inter-token latency in milliseconds,
/// averaged only over requests that generated /// averaged only over requests that generated
/// at least one token gap (0.0 when no samples) /// at least one token gap (0.0 when no samples)
/// - `avg_kv_hit_rate` (f64): arithmetic mean of per-request
/// ``overlap_blocks / isl_blocks`` ratios
/// across router admissions in the window
/// (one sample per request, not weighted
/// by ISL), matching the real router's
/// `dynamo_component_router_kv_hit_rate`
/// histogram semantics
/// ///
/// Call this only on throughput-scaling ticks so the observation window /// Call this only on throughput-scaling ticks so the observation window
/// covers the full `throughput_adjustment_interval`. /// covers the full `throughput_adjustment_interval`.
...@@ -1374,6 +1381,7 @@ impl PlannerReplayBridge { ...@@ -1374,6 +1381,7 @@ impl PlannerReplayBridge {
"avg_osl": stats.avg_osl, "avg_osl": stats.avg_osl,
"avg_ttft_ms": stats.avg_ttft_ms, "avg_ttft_ms": stats.avg_ttft_ms,
"avg_itl_ms": stats.avg_itl_ms, "avg_itl_ms": stats.avg_itl_ms,
"avg_kv_hit_rate": stats.avg_kv_hit_rate,
}); });
pythonize(py, &result) pythonize(py, &result)
......
...@@ -317,6 +317,8 @@ class router: ...@@ -317,6 +317,8 @@ class router:
INPUT_SEQUENCE_TOKENS = "router_input_sequence_tokens" INPUT_SEQUENCE_TOKENS = "router_input_sequence_tokens"
# Output sequence length in tokens observed at the router # Output sequence length in tokens observed at the router
OUTPUT_SEQUENCE_TOKENS = "router_output_sequence_tokens" OUTPUT_SEQUENCE_TOKENS = "router_output_sequence_tokens"
# Predicted KV cache hit rate at routing time (0.0-1.0)
KV_HIT_RATE = "router_kv_hit_rate"
class router_request: class router_request:
...@@ -379,22 +381,6 @@ class tokio_perf: ...@@ -379,22 +381,6 @@ class tokio_perf:
class transport: class transport:
"""Transport-specific metrics (TCP / NATS)""" """Transport-specific metrics (TCP / NATS)"""
# NOTE: Nested classes added manually because the codegen does not yet
# handle Rust submodules (see TODO in prometheus_parser.rs).
# Re-running gen-python-prometheus-names will overwrite this file and
# lose these classes until the codegen is updated.
class tcp:
POOL_ACTIVE = "tcp_pool_active"
POOL_IDLE = "tcp_pool_idle"
BYTES_SENT_TOTAL = "tcp_bytes_sent_total"
BYTES_RECEIVED_TOTAL = "tcp_bytes_received_total"
ERRORS_TOTAL = "tcp_errors_total"
SERVER_QUEUE_DEPTH = "tcp_server_queue_depth"
class nats:
ERRORS_TOTAL = "nats_errors_total"
class trtllm_additional: class trtllm_additional:
"""Additional TRT-LLM worker metrics beyond what the engine natively provides.""" """Additional TRT-LLM worker metrics beyond what the engine natively provides."""
......
...@@ -257,7 +257,14 @@ impl AggRuntime { ...@@ -257,7 +257,14 @@ impl AggRuntime {
&mut self, &mut self,
admissions: Vec<WorkerAdmission>, admissions: Vec<WorkerAdmission>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
for WorkerAdmission { uuid, worker_idx } in admissions { for WorkerAdmission {
uuid,
worker_idx,
overlap_blocks,
isl_blocks,
} in admissions
{
self.traffic.on_admission(overlap_blocks, isl_blocks);
let request = self let request = self
.requests .requests
.get_mut(&uuid) .get_mut(&uuid)
......
...@@ -35,6 +35,16 @@ use crate::replay::router_shared::{ ...@@ -35,6 +35,16 @@ use crate::replay::router_shared::{
type ReplayQueueKey = <RouterSchedulingPolicy as SchedulingPolicy>::Key; type ReplayQueueKey = <RouterSchedulingPolicy as SchedulingPolicy>::Key;
/// Internal result of a successful ``admit_request`` call: the chosen
/// worker plus the router's view of prefix-cache overlap, so callers can
/// forward the overlap stats to the traffic accumulator.
#[derive(Debug, Clone, Copy)]
struct AdmitOutcome {
worker_idx: usize,
overlap_blocks: u32,
isl_blocks: u32,
}
#[cfg(test)] #[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct OfflinePendingRequestSnapshot { pub(crate) struct OfflinePendingRequestSnapshot {
...@@ -252,12 +262,16 @@ impl OfflineReplayRouter { ...@@ -252,12 +262,16 @@ impl OfflineReplayRouter {
return Ok(RouterEffects::default()); return Ok(RouterEffects::default());
} }
let uuid = request
.uuid
.expect("offline replay requests must have UUIDs before router submission");
let outcome = self.admit_request(pending, decay_now)?;
Ok(RouterEffects { Ok(RouterEffects {
admissions: vec![WorkerAdmission { admissions: vec![WorkerAdmission {
uuid: request uuid,
.uuid worker_idx: outcome.worker_idx,
.expect("offline replay requests must have UUIDs before router submission"), overlap_blocks: outcome.overlap_blocks,
worker_idx: self.admit_request(pending, decay_now)?, isl_blocks: outcome.isl_blocks,
}], }],
}) })
} }
...@@ -279,11 +293,7 @@ impl OfflineReplayRouter { ...@@ -279,11 +293,7 @@ impl OfflineReplayRouter {
.mark_prefill_completed(&uuid.to_string(), decay_now) .mark_prefill_completed(&uuid.to_string(), decay_now)
.map_err(anyhow::Error::from)?; .map_err(anyhow::Error::from)?;
Ok(RouterEffects { Ok(RouterEffects {
admissions: self admissions: self.drain_pending(decay_now)?,
.drain_pending(decay_now)?
.into_iter()
.map(|(uuid, worker_idx)| WorkerAdmission { uuid, worker_idx })
.collect(),
}) })
} }
...@@ -297,11 +307,7 @@ impl OfflineReplayRouter { ...@@ -297,11 +307,7 @@ impl OfflineReplayRouter {
.free(&uuid.to_string(), decay_now) .free(&uuid.to_string(), decay_now)
.map_err(anyhow::Error::from)?; .map_err(anyhow::Error::from)?;
Ok(RouterEffects { Ok(RouterEffects {
admissions: self admissions: self.drain_pending(decay_now)?,
.drain_pending(decay_now)?
.into_iter()
.map(|(uuid, worker_idx)| WorkerAdmission { uuid, worker_idx })
.collect(),
}) })
} }
...@@ -310,11 +316,7 @@ impl OfflineReplayRouter { ...@@ -310,11 +316,7 @@ impl OfflineReplayRouter {
pub(crate) fn try_drain_pending(&mut self, now_ms: f64) -> Result<RouterEffects> { pub(crate) fn try_drain_pending(&mut self, now_ms: f64) -> Result<RouterEffects> {
let decay_now = self.decay_now(now_ms); let decay_now = self.decay_now(now_ms);
Ok(RouterEffects { Ok(RouterEffects {
admissions: self admissions: self.drain_pending(decay_now)?,
.drain_pending(decay_now)?
.into_iter()
.map(|(uuid, worker_idx)| WorkerAdmission { uuid, worker_idx })
.collect(),
}) })
} }
...@@ -359,11 +361,7 @@ impl OfflineReplayRouter { ...@@ -359,11 +361,7 @@ impl OfflineReplayRouter {
} }
let decay_now = self.decay_now(now_ms); let decay_now = self.decay_now(now_ms);
Ok(RouterEffects { Ok(RouterEffects {
admissions: self admissions: self.drain_pending(decay_now)?,
.drain_pending(decay_now)?
.into_iter()
.map(|(uuid, worker_idx)| WorkerAdmission { uuid, worker_idx })
.collect(),
}) })
} }
...@@ -486,7 +484,11 @@ impl OfflineReplayRouter { ...@@ -486,7 +484,11 @@ impl OfflineReplayRouter {
}) })
} }
fn admit_request(&mut self, request: PendingRequest, decay_now: Instant) -> Result<usize> { fn admit_request(
&mut self,
request: PendingRequest,
decay_now: Instant,
) -> Result<AdmitOutcome> {
let (decode_blocks, prefill_tokens) = self let (decode_blocks, prefill_tokens) = self
.slots .slots
.potential_blocks_and_tokens_with_prefill_tracking( .potential_blocks_and_tokens_with_prefill_tracking(
...@@ -511,6 +513,10 @@ impl OfflineReplayRouter { ...@@ -511,6 +513,10 @@ impl OfflineReplayRouter {
request.track_prefill_tokens, request.track_prefill_tokens,
); );
let isl_blocks = u32::try_from(request.isl_tokens.div_ceil(self.block_size as usize))
.unwrap_or(u32::MAX);
let overlap_blocks = selection.overlap_blocks;
self.slots self.slots
.add_request( .add_request(
SequenceRequest { SequenceRequest {
...@@ -526,10 +532,14 @@ impl OfflineReplayRouter { ...@@ -526,10 +532,14 @@ impl OfflineReplayRouter {
) )
.map_err(anyhow::Error::from)?; .map_err(anyhow::Error::from)?;
Ok(worker_idx) Ok(AdmitOutcome {
worker_idx,
overlap_blocks,
isl_blocks,
})
} }
fn drain_pending(&mut self, decay_now: Instant) -> Result<Vec<(Uuid, usize)>> { fn drain_pending(&mut self, decay_now: Instant) -> Result<Vec<WorkerAdmission>> {
let Some(threshold) = self.queue_threshold else { let Some(threshold) = self.queue_threshold else {
return Ok(Vec::new()); return Ok(Vec::new());
}; };
...@@ -540,8 +550,13 @@ impl OfflineReplayRouter { ...@@ -540,8 +550,13 @@ impl OfflineReplayRouter {
break; break;
}; };
let uuid = request.uuid; let uuid = request.uuid;
let worker_idx = self.admit_request(request, decay_now)?; let outcome = self.admit_request(request, decay_now)?;
admissions.push((uuid, worker_idx)); admissions.push(WorkerAdmission {
uuid,
worker_idx: outcome.worker_idx,
overlap_blocks: outcome.overlap_blocks,
isl_blocks: outcome.isl_blocks,
});
} }
Ok(admissions) Ok(admissions)
...@@ -784,6 +799,8 @@ mod tests { ...@@ -784,6 +799,8 @@ mod tests {
vec![WorkerAdmission { vec![WorkerAdmission {
uuid: Uuid::from_u128(1), uuid: Uuid::from_u128(1),
worker_idx: 3, worker_idx: 3,
overlap_blocks: 0,
isl_blocks: 1,
}] }]
); );
} }
...@@ -836,6 +853,8 @@ mod tests { ...@@ -836,6 +853,8 @@ mod tests {
vec![WorkerAdmission { vec![WorkerAdmission {
uuid: Uuid::from_u128(2), uuid: Uuid::from_u128(2),
worker_idx: 1, worker_idx: 1,
overlap_blocks: 0,
isl_blocks: 1,
}] }]
); );
assert_eq!(router.pending_count(), 0); assert_eq!(router.pending_count(), 0);
......
...@@ -25,6 +25,13 @@ pub(in crate::replay::offline) enum EnginePassMode { ...@@ -25,6 +25,13 @@ pub(in crate::replay::offline) enum EnginePassMode {
pub(crate) struct WorkerAdmission { pub(crate) struct WorkerAdmission {
pub(crate) uuid: Uuid, pub(crate) uuid: Uuid,
pub(crate) worker_idx: usize, pub(crate) worker_idx: usize,
/// Number of blocks the router matched against the prefix cache at
/// admission time. Used by the traffic accumulator to derive an
/// average KV hit rate for the planner.
pub(crate) overlap_blocks: u32,
/// Total ISL expressed in blocks (ceil(isl_tokens / block_size)),
/// paired with ``overlap_blocks`` for the hit-rate ratio.
pub(crate) isl_blocks: u32,
} }
#[derive(Debug)] #[derive(Debug)]
...@@ -79,10 +86,21 @@ pub struct TrafficStats { ...@@ -79,10 +86,21 @@ pub struct TrafficStats {
pub avg_osl: f64, pub avg_osl: f64,
pub avg_ttft_ms: f64, pub avg_ttft_ms: f64,
pub avg_itl_ms: f64, pub avg_itl_ms: f64,
/// Mean prefix-cache hit rate (0.0-1.0) across router admissions in
/// the window, computed as ``mean(overlap_blocks / isl_blocks)`` over
/// admitted requests (i.e. the arithmetic mean of per-request
/// ratios). Matches the semantics of the real router's
/// ``dynamo_component_router_kv_hit_rate`` Prometheus histogram,
/// which observes one ``overlap/isl`` sample per request; the
/// PromQL query ``sum(increase(_sum)) / sum(increase(_count))``
/// returns the arithmetic mean of those samples, independent of
/// per-request ISL size.
pub avg_kv_hit_rate: f64,
} }
/// Accumulates traffic statistics between planner ticks for deriving /// Accumulates traffic statistics between planner ticks for deriving
/// `TrafficObservation` (num_req, avg ISL, avg OSL over a window). /// `TrafficObservation` (num_req, avg ISL, avg OSL, avg latencies, avg
/// KV hit rate over a window).
/// ///
/// Latency samples are tracked independently of request counts: a request /// Latency samples are tracked independently of request counts: a request
/// only contributes to ``total_ttft_ms`` / ``ttft_count`` if a positive TTFT /// only contributes to ``total_ttft_ms`` / ``ttft_count`` if a positive TTFT
...@@ -90,6 +108,12 @@ pub struct TrafficStats { ...@@ -90,6 +108,12 @@ pub struct TrafficStats {
/// ``avg_itl_ms`` reflect only requests that actually produced the sample, /// ``avg_itl_ms`` reflect only requests that actually produced the sample,
/// rather than silently underestimating when some requests lack latency /// rather than silently underestimating when some requests lack latency
/// data (e.g. requests that fail before emitting a token). /// data (e.g. requests that fail before emitting a token).
///
/// KV hit-rate observations come from the router at admission time (not
/// completion) and are recorded as per-request ratios, matching the real
/// router's per-request histogram: each admission contributes one
/// ``overlap_blocks / isl_blocks`` sample to the running mean, so large
/// requests don't get weighted more heavily than small ones.
#[derive(Debug)] #[derive(Debug)]
pub(in crate::replay::offline) struct TrafficAccumulator { pub(in crate::replay::offline) struct TrafficAccumulator {
window_start_ms: f64, window_start_ms: f64,
...@@ -100,6 +124,11 @@ pub(in crate::replay::offline) struct TrafficAccumulator { ...@@ -100,6 +124,11 @@ pub(in crate::replay::offline) struct TrafficAccumulator {
total_itl_ms: f64, total_itl_ms: f64,
ttft_count: usize, ttft_count: usize,
itl_count: usize, itl_count: usize,
/// Running sum of per-request hit-rate ratios (``overlap / isl``);
/// divided by ``hit_rate_count`` at drain time to give the mean.
total_hit_rate: f64,
/// Number of admissions with non-zero ISL blocks in the current window.
hit_rate_count: usize,
} }
impl TrafficAccumulator { impl TrafficAccumulator {
...@@ -113,6 +142,8 @@ impl TrafficAccumulator { ...@@ -113,6 +142,8 @@ impl TrafficAccumulator {
total_itl_ms: 0.0, total_itl_ms: 0.0,
ttft_count: 0, ttft_count: 0,
itl_count: 0, itl_count: 0,
total_hit_rate: 0.0,
hit_rate_count: 0,
} }
} }
...@@ -138,6 +169,26 @@ impl TrafficAccumulator { ...@@ -138,6 +169,26 @@ impl TrafficAccumulator {
} }
} }
/// Record one router admission's prefix-cache overlap as a
/// per-request ratio. Called at admission time (not completion) so
/// the mean hit rate reflects the router's view at routing decision
/// — matching the real router's per-request histogram, where each
/// request contributes exactly one ``overlap/isl`` sample.
/// Admissions with ``isl_blocks == 0`` are skipped (no meaningful
/// ratio), mirroring ``RequestTracker::kv_hit_rate()`` returning
/// ``None`` in that case.
pub(in crate::replay::offline) fn on_admission(
&mut self,
overlap_blocks: u32,
isl_blocks: u32,
) {
if isl_blocks == 0 {
return;
}
self.total_hit_rate += f64::from(overlap_blocks) / f64::from(isl_blocks);
self.hit_rate_count += 1;
}
/// Drain the accumulator at the given simulated time, resetting counters. /// Drain the accumulator at the given simulated time, resetting counters.
pub(in crate::replay::offline) fn drain(&mut self, now_ms: f64) -> TrafficStats { pub(in crate::replay::offline) fn drain(&mut self, now_ms: f64) -> TrafficStats {
let duration_s = (now_ms - self.window_start_ms) / 1000.0; let duration_s = (now_ms - self.window_start_ms) / 1000.0;
...@@ -162,6 +213,11 @@ impl TrafficAccumulator { ...@@ -162,6 +213,11 @@ impl TrafficAccumulator {
} else { } else {
0.0 0.0
}; };
let avg_kv_hit_rate = if self.hit_rate_count > 0 {
self.total_hit_rate / self.hit_rate_count as f64
} else {
0.0
};
self.window_start_ms = now_ms; self.window_start_ms = now_ms;
self.num_req = 0; self.num_req = 0;
self.total_isl = 0; self.total_isl = 0;
...@@ -170,6 +226,8 @@ impl TrafficAccumulator { ...@@ -170,6 +226,8 @@ impl TrafficAccumulator {
self.total_itl_ms = 0.0; self.total_itl_ms = 0.0;
self.ttft_count = 0; self.ttft_count = 0;
self.itl_count = 0; self.itl_count = 0;
self.total_hit_rate = 0.0;
self.hit_rate_count = 0;
TrafficStats { TrafficStats {
duration_s, duration_s,
num_req, num_req,
...@@ -177,6 +235,64 @@ impl TrafficAccumulator { ...@@ -177,6 +235,64 @@ impl TrafficAccumulator {
avg_osl, avg_osl,
avg_ttft_ms, avg_ttft_ms,
avg_itl_ms, avg_itl_ms,
avg_kv_hit_rate,
}
} }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn traffic_accumulator_drain_with_no_admissions_reports_zero_hit_rate() {
let mut acc = TrafficAccumulator::new();
acc.on_request(100, 50, None);
let stats = acc.drain(1_000.0);
assert_eq!(stats.num_req, 1);
assert!((stats.avg_isl - 100.0).abs() < 1e-9);
assert!((stats.avg_osl - 50.0).abs() < 1e-9);
assert_eq!(stats.avg_kv_hit_rate, 0.0);
}
#[test]
fn traffic_accumulator_hit_rate_is_mean_of_per_request_ratios() {
let mut acc = TrafficAccumulator::new();
// Small request: mostly hit. Big request: no hit.
acc.on_admission(3, 4); // per-request ratio: 0.75
acc.on_admission(0, 12); // per-request ratio: 0.0
acc.on_request(256, 32, None);
acc.on_request(768, 32, None);
let stats = acc.drain(1_000.0);
assert_eq!(stats.num_req, 2);
// Per-request mean matches the real router's Prometheus histogram:
// (0.75 + 0.0) / 2 = 0.375. Every request contributes one sample
// regardless of ISL size, so large requests don't dominate.
assert!((stats.avg_kv_hit_rate - 0.375).abs() < 1e-9);
}
#[test]
fn traffic_accumulator_skips_admissions_with_zero_isl_blocks() {
let mut acc = TrafficAccumulator::new();
acc.on_admission(0, 0); // skipped -- no meaningful ratio
acc.on_admission(2, 4); // ratio = 0.5
let stats = acc.drain(1_000.0);
// Only the non-zero-ISL sample counts toward the mean.
assert!((stats.avg_kv_hit_rate - 0.5).abs() < 1e-9);
}
#[test]
fn traffic_accumulator_resets_counters_on_drain() {
let mut acc = TrafficAccumulator::new();
acc.on_admission(5, 10);
acc.on_request(100, 50, None);
let _ = acc.drain(1_000.0);
// Second drain on the same accumulator should see no state carried over.
let stats = acc.drain(2_000.0);
assert!((stats.duration_s - 1.0).abs() < 1e-9);
assert_eq!(stats.num_req, 0);
assert_eq!(stats.avg_isl, 0.0);
assert_eq!(stats.avg_osl, 0.0);
assert_eq!(stats.avg_kv_hit_rate, 0.0);
} }
} }
...@@ -303,7 +303,14 @@ impl DisaggRuntime { ...@@ -303,7 +303,14 @@ impl DisaggRuntime {
/// Turn prefill router admissions into concrete worker dispatches. /// Turn prefill router admissions into concrete worker dispatches.
fn dispatch_prefill_admissions(&mut self, admissions: Vec<WorkerAdmission>) -> Result<()> { fn dispatch_prefill_admissions(&mut self, admissions: Vec<WorkerAdmission>) -> Result<()> {
for WorkerAdmission { uuid, worker_idx } in admissions { for WorkerAdmission {
uuid,
worker_idx,
overlap_blocks,
isl_blocks,
} in admissions
{
self.traffic.on_admission(overlap_blocks, isl_blocks);
if self.state(uuid)?.phase != DisaggPhase::QueuedPrefill { if self.state(uuid)?.phase != DisaggPhase::QueuedPrefill {
bail!("offline disagg replay expected queued prefill request for {uuid}"); bail!("offline disagg replay expected queued prefill request for {uuid}");
} }
...@@ -313,8 +320,16 @@ impl DisaggRuntime { ...@@ -313,8 +320,16 @@ impl DisaggRuntime {
} }
/// Turn decode router admissions into concrete worker dispatches. /// Turn decode router admissions into concrete worker dispatches.
///
/// Note: only the prefill router's admissions are fed to
/// ``traffic.on_admission``; decode-router admissions reflect the
/// same requests re-routing after prefill completes and would double
/// count overlap observations.
fn dispatch_decode_admissions(&mut self, admissions: Vec<WorkerAdmission>) -> Result<()> { fn dispatch_decode_admissions(&mut self, admissions: Vec<WorkerAdmission>) -> Result<()> {
for WorkerAdmission { uuid, worker_idx } in admissions { for WorkerAdmission {
uuid, worker_idx, ..
} in admissions
{
if self.state(uuid)?.phase != DisaggPhase::QueuedDecode { if self.state(uuid)?.phase != DisaggPhase::QueuedDecode {
bail!("offline disagg replay expected queued decode request for {uuid}"); bail!("offline disagg replay expected queued decode request for {uuid}");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment