Unverified commit a58bcc31, authored by Hongkuan Zhou and committed by GitHub
Browse files

refactor: load planner using new forwardpass metric and many improvements (#7351)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent db14d63f
......@@ -7,11 +7,26 @@ Receive ForwardPassMetrics via the Dynamo event plane.
Auto-discovers engine publishers through the discovery plane (K8s CRD /
etcd / file) and prints each metric message as JSON.
Supports two modes:
- **recv** (default): pull individual messages one at a time.
- **tracking**: periodically poll ``get_recent_stats()`` to print the
latest snapshot keyed by ``(worker_id, dp_rank)``.
Usage:
# recv mode (default)
python -m dynamo.common.recv_forward_pass_metrics \\
--namespace dynamo --component backend --endpoint generate
# tracking mode (poll every 2 seconds)
python -m dynamo.common.recv_forward_pass_metrics \\
--namespace dynamo --component backend --endpoint generate \\
[--discovery-backend etcd] [--request-plane nats] \\
[--save-plot metrics.png]
--mode tracking --poll-interval 2.0
# recv mode with plot saving
python -m dynamo.common.recv_forward_pass_metrics \\
--namespace dynamo --component backend --endpoint generate \\
--save-plot metrics.png
"""
import argparse
......@@ -94,11 +109,24 @@ def main() -> None:
default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Request plane (default: nats)",
)
parser.add_argument(
"--mode",
choices=["recv", "tracking"],
default="recv",
help="Consumption mode: 'recv' for individual messages, "
"'tracking' for latest-snapshot polling (default: recv)",
)
parser.add_argument(
"--poll-interval",
type=float,
default=2.0,
help="Polling interval in seconds for tracking mode (default: 2.0)",
)
parser.add_argument(
"--save-plot",
metavar="PATH",
default=None,
help="Save a time-series plot to the given PNG path on exit",
help="Save a time-series plot to the given PNG path on exit (recv mode only)",
)
args = parser.parse_args()
......@@ -117,15 +145,29 @@ async def run(args: argparse.Namespace) -> None:
endpoint = runtime.endpoint(f"{args.namespace}.{args.component}.{args.endpoint}")
subscriber = FpmEventSubscriber(endpoint)
json_encoder = msgspec.json.Encoder()
logger.info(
"Subscribed to forward-pass-metrics via event plane "
"(namespace=%s, component=%s) Ctrl+C to stop",
"(namespace=%s, component=%s, mode=%s) Ctrl+C to stop",
args.namespace,
args.component,
args.mode,
)
try:
if args.mode == "tracking":
await _run_tracking(subscriber, args)
else:
await _run_recv(subscriber, args)
except KeyboardInterrupt:
logger.info("Stopped.")
finally:
subscriber.shutdown()
async def _run_recv(subscriber, args: argparse.Namespace) -> None:
"""Pull individual FPM messages and print each as JSON."""
json_encoder = msgspec.json.Encoder()
history: list[tuple[float, ForwardPassMetrics]] = []
start_time: float | None = None
......@@ -154,13 +196,43 @@ async def run(args: argparse.Namespace) -> None:
metrics.counter_id,
json.dumps(pretty, indent=2),
)
except KeyboardInterrupt:
logger.info("Stopped.")
finally:
subscriber.shutdown()
if args.save_plot and history:
_save_plot(args.save_plot, history)
async def _run_tracking(subscriber, args: argparse.Namespace) -> None:
    """Log the latest per-engine FPM snapshot at a fixed cadence.

    Turns on tracking for ``subscriber`` and then loops without an exit
    condition of its own: sleep ``args.poll_interval`` seconds, call
    ``subscriber.get_recent_stats()``, and log either a "no engines" notice
    or a JSON document keyed by ``"<worker_id>:dp<dp_rank>"``.

    Args:
        subscriber: Object exposing ``start_tracking()`` and
            ``get_recent_stats()``; the latter is iterated as a mapping of
            ``(worker_id, dp_rank)`` to raw payload bytes.
        args: Parsed CLI arguments; only ``poll_interval`` is read here.
    """
    encoder = msgspec.json.Encoder()
    subscriber.start_tracking()
    logger.info("Tracking mode started (poll every %.1fs)", args.poll_interval)

    poll = 0
    while True:
        await asyncio.sleep(args.poll_interval)
        stats = subscriber.get_recent_stats()
        if stats:
            snapshot = {}
            for (worker_id, dp_rank), raw_bytes in stats.items():
                metrics = decode(raw_bytes)
                if metrics is not None:
                    # Encode then re-parse so the decoded metrics object is
                    # reduced to plain JSON-compatible builtins.
                    snapshot[f"{worker_id}:dp{dp_rank}"] = json.loads(
                        encoder.encode(metrics)
                    )
            ts = time.strftime("%H:%M:%S")
            # engines= reports every tracked key, including any whose payload
            # decode() returned None for and which is absent from the snapshot.
            logger.info(
                "[poll=%d t=%s engines=%d] %s",
                poll,
                ts,
                len(stats),
                json.dumps(snapshot, indent=2),
            )
        else:
            logger.info("[poll=%d] (no engines tracked)", poll)
        poll += 1
if __name__ == "__main__":
main()
......@@ -27,15 +27,15 @@ The SLA Planner supports two scaling modes that can be used independently or tog
Uses pre-deployment profiling data and traffic prediction to compute the number of prefill/decode replicas needed to meet TTFT and ITL SLA targets. Requires profiling data from the Dynamo profiler.
### Load-Based Scaling (Experimental)
### Load-Based Scaling
Uses real-time per-worker load metrics (active prefill tokens, active KV blocks) from the router to make SLA-aware scaling decisions via online linear regression. Does not require profiling data. Responds quickly to traffic bursts.
Uses ForwardPassMetrics (FPM) from the Dynamo event plane to make SLA-aware scaling decisions via online linear regression. Does not require profiling data or the KV Router. Responds quickly to traffic bursts. Currently supported only with vLLM, because FPM is only available in the vLLM backend.
When both modes are enabled, throughput-based scaling provides a lower bound on replicas while load-based scaling handles real-time adjustments.
### Support Matrix
| Deployment Type | Throughput-Based | Load-Based (Experimental) |
| Deployment Type | Throughput-Based | Load-Based |
|-----------------|:----------------:|:-------------------------:|
| Disaggregated | Supported | Supported |
| Aggregated | Unsupported | Supported |
......
......@@ -9,6 +9,7 @@ __all__ = [
"SLAPlannerDefaults",
"TargetReplica",
"SubComponentType",
"WorkerInfo",
]
# Import the classes
from dynamo.planner.defaults import SLAPlannerDefaults, SubComponentType
......@@ -16,6 +17,7 @@ from dynamo.planner.global_planner_connector import GlobalPlannerConnector
from dynamo.planner.kubernetes_connector import KubernetesConnector, TargetReplica
from dynamo.planner.planner_connector import PlannerConnector
from dynamo.planner.virtual_connector import VirtualConnector
from dynamo.planner.worker_info import WorkerInfo
try:
from ._version import __version__
......
......@@ -29,10 +29,6 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
logger = logging.getLogger(__name__)
# start planner 30 seconds after the other components to make sure planner can see them
# TODO: remove this delay
INIT_PLANNER_START_DELAY = 30
class RequestType(BaseModel):
text: str
......@@ -51,21 +47,29 @@ async def start_planner(runtime: DistributedRuntime, config: PlannerConfig):
planner = AggPlanner(runtime, config)
else:
raise ValueError(f"Invalid planner mode: {mode}")
await planner._async_init()
await planner.run()
async def init_planner(runtime: DistributedRuntime, config: PlannerConfig):
await asyncio.sleep(INIT_PLANNER_START_DELAY)
await start_planner(runtime, config)
await planner._async_init()
async def generate(request: RequestType):
"""Dummy endpoint to satisfy that each component has an endpoint"""
yield "mock endpoint"
generate_endpoint = runtime.endpoint(f"{config.namespace}.Planner.generate")
await generate_endpoint.serve_endpoint(generate)
# serve_endpoint registers a health check target and sets HealthStatus::Ready
# once the handler is registered with the transport. Running it concurrently
# with planner.run() ensures the system server (/health, /live) reports the
# planner as ready only after _async_init() has completed.
await asyncio.gather(
generate_endpoint.serve_endpoint(
generate, # type: ignore[arg-type]
health_check_payload={"text": "health"},
),
planner.run(),
)
async def init_planner(runtime: DistributedRuntime, config: PlannerConfig):
await start_planner(runtime, config)
def _parse_config() -> PlannerConfig:
......
......@@ -80,9 +80,6 @@ class SLAPlannerDefaults(BasePlannerDefaults):
enable_load_scaling = False
# Load-based scaling settings
load_router_metrics_url: Optional[
str
] = None # will be auto-discovered from the DGD in kubernetes mode if not provided
load_adjustment_interval = 5 # in seconds, must be < throughput_adjustment_interval
load_learning_window = 50 # sliding window size for regression
load_scaling_down_sensitivity = 80 # 0-100
......
......@@ -196,7 +196,7 @@ class GlobalPlannerConnector(PlannerConnector):
"(GlobalPlanner will validate on its side)"
)
async def wait_for_deployment_ready(self):
async def wait_for_deployment_ready(self, include_planner: bool = True):
"""
Wait for deployment to be ready (no-op for GlobalPlanner).
......
......@@ -204,31 +204,60 @@ class KubernetesAPI:
async def wait_for_graph_deployment_ready(
self,
graph_deployment_name: str,
include_planner: bool = True,
max_attempts: int = 180, # default: 30 minutes total
delay_seconds: int = 10, # default: check every 10 seconds
) -> None:
"""Wait for a graph deployment to be ready"""
"""Wait for a graph deployment to be ready.
Args:
graph_deployment_name: Name of the DGD to wait for.
include_planner: If False, skip services with componentType "planner"
and check per-service readiness instead of the global DGD Ready
condition. This avoids a circular wait when the planner itself
is one of the services in the DGD.
max_attempts: Maximum polling iterations.
delay_seconds: Seconds between polls.
"""
for attempt in range(max_attempts):
await asyncio.sleep(delay_seconds)
graph_deployment = self.get_graph_deployment(graph_deployment_name)
if include_planner:
conditions = graph_deployment.get("status", {}).get("conditions", [])
ready_condition = next(
(c for c in conditions if c.get("type") == "Ready"), None
)
if ready_condition and ready_condition.get("status") == "True":
return # Deployment is ready
return
logger.info(
f"[Attempt {attempt + 1}/{max_attempts}] "
f"(status: {ready_condition.get('status') if ready_condition else 'N/A'}, "
f"message: {ready_condition.get('message') if ready_condition else 'no condition found'})"
)
else:
services = graph_deployment.get("spec", {}).get("services", {})
not_ready: list[str] = []
for svc_name, svc_spec in services.items():
if svc_spec.get("componentType", "") == "planner":
continue
_, is_stable = self.get_service_replica_status(
graph_deployment, svc_name
)
if not is_stable:
not_ready.append(svc_name)
if not not_ready:
return
logger.info(
f"[Attempt {attempt + 1}/{max_attempts}] "
f"Waiting for services (excluding planner): "
f"not ready: {not_ready}"
)
# Raise after all attempts exhausted (without additional delay)
raise TimeoutError(
f"Graph deployment '{graph_deployment_name}' "
f"is not ready after {max_attempts * delay_seconds} seconds"
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
from typing import Optional
......@@ -33,6 +34,7 @@ from dynamo.planner.utils.exceptions import (
PlannerError,
UserProvidedModelNameMismatchError,
)
from dynamo.planner.worker_info import WorkerInfo, build_worker_info_from_defaults
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
......@@ -307,11 +309,275 @@ class KubernetesConnector(PlannerConnector):
return None
async def wait_for_deployment_ready(self):
"""Wait for the deployment to be ready"""
async def wait_for_deployment_ready(self, include_planner: bool = True):
    """Block until this connector's graph deployment is reported ready.

    Delegates to ``kube_api.wait_for_graph_deployment_ready`` for
    ``self.graph_deployment_name``.

    Args:
        include_planner: Forwarded unchanged to the Kubernetes API helper.
            Pass False to leave the planner service out of the readiness
            check, so the planner can read MDC from worker pods without
            waiting for itself to be marked ready in the DGD.
    """
    dgd_name = self.graph_deployment_name
    await self.kube_api.wait_for_graph_deployment_ready(
        dgd_name, include_planner=include_planner
    )
def _list_worker_metadata_crs(self) -> list[dict]:
    """Fetch every DynamoWorkerMetadata custom resource in the current namespace.

    Returns:
        The ``items`` list of the API response (empty when the key is
        absent), or an empty list when the API answers 404, which is
        treated as "CRD not installed yet".

    Raises:
        ApiException: For any API failure other than 404 (for example RBAC
            or connectivity problems), so callers can handle it explicitly.
    """
    # Imported lazily, as in the rest of this method's history, so the
    # kubernetes client is only required when this path is exercised.
    from kubernetes.client import ApiException

    try:
        listing = self.kube_api.custom_api.list_namespaced_custom_object(
            group="nvidia.com",
            version="v1alpha1",
            namespace=self.kube_api.current_namespace,
            plural="dynamoworkermetadatas",
        )
    except ApiException as exc:
        if exc.status != 404:
            raise
        logger.info("DynamoWorkerMetadata CRD not found, skipping MDC")
        return []
    return listing.get("items", [])
def _extract_mdc_entries(
self,
) -> list[dict]:
"""Extract MDC entries belonging to this DGD.
CRs are named after the worker pod (e.g. ``<dgd>-0-<service>-<hash>``),
so we filter by the DGD name prefix to avoid picking up entries from
other deployments sharing the namespace.
Returns a list of dicts, each containing:
namespace, component, endpoint, instance_id, card_json
"""
crs = self._list_worker_metadata_crs()
dgd_prefix = f"{self.graph_deployment_name}-"
entries: list[dict] = []
for cr in crs:
cr_name = cr.get("metadata", {}).get("name", "")
if not cr_name.startswith(dgd_prefix):
continue
data = cr.get("spec", {}).get("data", {})
if isinstance(data, str):
try:
data = json.loads(data)
except json.JSONDecodeError:
continue
model_cards = data.get("model_cards", {})
for _key, instance in model_cards.items():
if instance.get("type") != "Model":
continue
entries.append(instance)
return entries
def _mdc_entry_is_prefill(self, entry: dict) -> bool:
"""Check if an MDC entry is a prefill worker.
model_type can be serialized as:
- An integer bitflag (ModelType::Prefill = 1 << 4 = 16)
- A dict with a "bits" key (serde bitflags format)
- A string like "Prefill" or "Chat|Completions"
"""
card = entry.get("card_json", {})
model_type = card.get("model_type", 0)
if isinstance(model_type, str):
return "prefill" in model_type.lower()
if isinstance(model_type, dict):
model_type = model_type.get("bits", 0)
return bool(model_type & 0x10)
def _build_worker_info_from_mdc(
    self,
    entry: dict,
    sub_component_type: SubComponentType,
) -> WorkerInfo:
    """Build a WorkerInfo from an MDC entry, applying the fallback chain.

    Per-field sources:
    - component / endpoint: the MDC wrapper, else the backend defaults.
    - model name: ``card_json.display_name``, else the DGD service's
      container args, else left unset (with a warning).
    - k8s name: the DGD service resolved for ``sub_component_type``, else
      the backend default.
    - runtime limits (KV blocks, max seqs, batched tokens, block size,
      context length): the MDC card only; missing values stay ``None``.

    The DGD is fetched once and the resolved service is reused for both the
    model-name fallback and the k8s name.

    Args:
        entry: One MDC entry as produced by ``_extract_mdc_entries``.
        sub_component_type: The role (prefill / decode) being described.

    Returns:
        The assembled WorkerInfo.
    """
    role = sub_component_type.value

    # get_worker_info() records the backend in _backend_hint; tolerate a
    # direct call made before that attribute exists.
    backend = getattr(self, "_backend_hint", None) or "vllm"
    defaults = build_worker_info_from_defaults(backend, sub_component_type)

    # `or {}` also covers keys that are present but null.
    card = entry.get("card_json") or {}
    runtime_cfg = card.get("runtime_config") or {}

    # --- component / endpoint names from MDC wrapper ---
    mdc_component = entry.get("component")
    mdc_endpoint = entry.get("endpoint")
    component_name = mdc_component or defaults.component_name
    endpoint = mdc_endpoint or defaults.endpoint
    if not mdc_component:
        logger.info(
            f"MDC missing 'component' for {role}, "
            f"falling back to default: {defaults.component_name}"
        )
    if not mdc_endpoint:
        logger.info(
            f"MDC missing 'endpoint' for {role}, "
            f"falling back to default: {defaults.endpoint}"
        )

    # --- resolve the DGD service once (used for model name and k8s name) ---
    service = None
    try:
        deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)
        service = get_service_from_sub_component_type_or_name(
            deployment, sub_component_type
        )
    except PlannerError:
        logger.info(
            f"Could not resolve k8s service name for {role}, "
            f"using default: {defaults.k8s_name}"
        )
    k8s_name = service.name if service is not None else defaults.k8s_name

    # --- model name ---
    model_name = card.get("display_name")
    if not model_name and service is not None:
        # Fallback: parse from DGD container args.
        try:
            model_name = service.get_model_name()
        except PlannerError as exc:
            # Best-effort lookup; the warning below reports the final outcome.
            logger.debug(
                f"Could not parse model name from DGD container args "
                f"for {role}: {exc}"
            )
        if model_name:
            logger.info(
                f"MDC missing model name for {role}, "
                f"fell back to DGD container args: {model_name}"
            )
    if not model_name:
        logger.warning(
            f"Could not determine model name for {role} "
            f"from MDC or DGD container args"
        )

    # --- runtime config fields ---
    total_kv_blocks = runtime_cfg.get("total_kv_blocks")
    max_num_seqs = runtime_cfg.get("max_num_seqs")
    if total_kv_blocks is None:
        logger.info(f"MDC missing total_kv_blocks for {role}")
    if max_num_seqs is None:
        logger.info(f"MDC missing max_num_seqs for {role}")

    return WorkerInfo(
        k8s_name=k8s_name,
        component_name=component_name,
        endpoint=endpoint,
        model_name=model_name,
        total_kv_blocks=total_kv_blocks,
        kv_cache_block_size=card.get("kv_cache_block_size"),
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=runtime_cfg.get("max_num_batched_tokens"),
        context_length=card.get("context_length"),
    )
def get_worker_info(
    self,
    sub_component_type: SubComponentType,
    backend: str = "vllm",
) -> WorkerInfo:
    """Resolve the WorkerInfo for one sub-component: MDC first, then fallbacks.

    The first MDC entry whose prefill/decode role and component match is
    turned into a WorkerInfo. When none matches, the backend defaults are
    used and enriched with the k8s service name and model name from the DGD
    when those can be resolved.

    Side effect: stores ``backend`` on ``self._backend_hint``.

    Args:
        sub_component_type: PREFILL or DECODE
        backend: Backend framework name (for default fallback)

    Returns:
        The WorkerInfo built from MDC, or the (possibly enriched) defaults.
    """
    self._backend_hint = backend
    candidates = self._extract_mdc_entries()

    # Work out which component the cards must carry so LoRA-adapter cards
    # that share the CR but have a different component/model identity are
    # not selected.
    # NOTE(review): on success this is the DGD service's `name`, on failure
    # the defaults' `component_name` -- confirm both are comparable with the
    # MDC "component" field.
    expected_component: Optional[str] = None
    try:
        dgd = self.kube_api.get_graph_deployment(self.graph_deployment_name)
        expected_component = get_service_from_sub_component_type_or_name(
            dgd, sub_component_type
        ).name
    except PlannerError:
        fallback_defaults = build_worker_info_from_defaults(
            backend, sub_component_type
        )
        expected_component = fallback_defaults.component_name

    wants_prefill = sub_component_type == SubComponentType.PREFILL
    wants_decode = sub_component_type == SubComponentType.DECODE
    for candidate in candidates:
        is_prefill = self._mdc_entry_is_prefill(candidate)
        if (wants_prefill and not is_prefill) or (wants_decode and is_prefill):
            continue

        entry_component = candidate.get("component")
        # Only reject when both sides are known and actually differ.
        if (
            entry_component
            and expected_component
            and entry_component != expected_component
        ):
            logger.debug(
                f"Skipping MDC entry with component={entry_component!r}, "
                f"expected {expected_component!r} for {sub_component_type.value}"
            )
            continue

        built = self._build_worker_info_from_mdc(candidate, sub_component_type)
        logger.info(
            f"Built {sub_component_type.value} WorkerInfo from MDC: "
            f"{built.summary()}"
        )
        return built

    # Nothing usable in MDC: start from defaults and enrich from the DGD.
    logger.warning(
        f"No DynamoWorkerMetadata CR found for {sub_component_type.value}. "
        f"Workers may not be registered yet. Falling back to defaults."
    )
    fallback = build_worker_info_from_defaults(backend, sub_component_type)

    try:
        dgd = self.kube_api.get_graph_deployment(self.graph_deployment_name)
        svc = get_service_from_sub_component_type_or_name(dgd, sub_component_type)
        fallback.k8s_name = svc.name
        arg_model = svc.get_model_name()
        if arg_model:
            fallback.model_name = arg_model
            logger.info(
                f"Enriched {sub_component_type.value} WorkerInfo model name "
                f"from DGD args: {arg_model}"
            )
    except PlannerError as exc:
        logger.info(
            f"Could not enrich WorkerInfo from DGD for {sub_component_type.value}: {exc}"
        )

    logger.info(
        f"Using fallback WorkerInfo for {sub_component_type.value}: {fallback.summary()}"
    )
    return fallback
def get_actual_worker_counts(
self,
......
......@@ -3,11 +3,14 @@
import asyncio
import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional
from dynamo.planner import SubComponentType, TargetReplica
from dynamo.planner.utils.load_based_regression import LoadBasedRegressionModel
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
from dynamo.planner.utils.planner_config import PlannerConfig
if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.planner.utils.planner_core import (
BasePlanner,
PlannerPrometheusMetrics,
......@@ -15,7 +18,6 @@ from dynamo.planner.utils.planner_core import (
_apply_component_gpu_budget,
_initialize_gpu_counts,
)
from dynamo.planner.utils.prometheus import CachedLoadMetrics
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -24,24 +26,24 @@ logger = logging.getLogger(__name__)
class AggPlanner:
"""Aggregated planner: load-based scaling only, single engine type.
"""Aggregated planner: FPM-driven load-based scaling, single engine type.
In aggregated mode, engines handle both prefill and decode (chunked prefill).
Engine metrics are labeled "decode" by the router.
A single AggRegressionModel maps (sum_prefill_tokens, sum_decode_kv_tokens)
to wall_time using 2D linear regression.
Scaling logic:
- TTFT and ITL regression models are both maintained.
- Regression uses per-worker time-averaged metrics (not latest snapshot)
because chunked prefill adds noise to instantaneous TTFT/ITL.
- Scale up if either prefill or decode target is exceeded.
- Scale down if both prefill and decode are below their boundaries.
- Estimate next TTFT per engine by simulating prefill chunking with
piggybacked decode (steady-state decode load).
- Estimate next ITL per engine by predicting decode iteration time with
average piggybacked prefill load.
- Scale up if (ALL TTFT > SLA) OR (ALL ITL > SLA).
- Scale down if (ALL TTFT < SLA * sensitivity) AND (ALL ITL < SLA * sensitivity).
"""
# Engine metrics from agg workers are labeled "decode" by the router
ENGINE_WORKER_TYPE = "decode"
def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None:
self.config = config
self.runtime = runtime
self.shared_state = PlannerSharedState()
if config.enable_throughput_scaling:
......@@ -56,8 +58,6 @@ class AggPlanner:
prometheus_metrics = PlannerPrometheusMetrics()
# Use a single BasePlanner instance for infra (connector, prometheus, etc.)
# We use DECODE component_type because engine metrics are labeled "decode"
self.planner = BasePlanner(
runtime,
config,
......@@ -67,28 +67,27 @@ class AggPlanner:
component_type=SubComponentType.DECODE,
)
# Create both regression models (agg needs both TTFT and ITL)
self.ttft_regression = LoadBasedRegressionModel(
window_size=config.load_learning_window,
min_observations=config.load_min_observations,
)
self.itl_regression = LoadBasedRegressionModel(
from dynamo.planner.utils.fpm_regression import AggRegressionModel
self.regression = AggRegressionModel(
window_size=config.load_learning_window,
min_observations=config.load_min_observations,
)
self.cached_load_metrics = CachedLoadMetrics()
async def _async_init(self):
await self.planner._async_init()
defaults = WORKER_COMPONENT_NAMES.get(self.config.backend)
async def run(self):
if not self.config.no_operation:
connector = getattr(self.planner, "connector", None)
if connector and hasattr(connector, "_async_init"):
await connector._async_init()
logger.info("Validating deployment...")
# Agg mode: only decode component exists (engines serve both P and D)
await self.planner.connector.validate_deployment(
prefill_component_name=None,
decode_component_name=self.planner.decode_component_name,
decode_component_name=(
defaults.decode_worker_k8s_name if defaults else None
),
require_prefill=False,
require_decode=True,
)
......@@ -101,208 +100,102 @@ class AggPlanner:
require_decode=True,
)
await self.planner.connector.wait_for_deployment_ready()
# Model name discovery runs in all modes (needed for metrics collection)
if not self.config.no_operation:
model_name = await self.planner._get_model_name(
require_prefill=False, require_decode=True
)
logger.info(f"Detected model name from deployment: {model_name}")
self.planner.model_name = model_name.lower()
else:
if not self.config.model_name:
raise ValueError(
"Model name is required in no-operation mode. "
"Please set model_name in the config."
await self.planner.connector.wait_for_deployment_ready(
include_planner=False
)
self.planner.model_name = self.config.model_name.lower()
loops = [
self._load_loop(),
self.planner.prometheus_engine_client.run_sampling_loop(
self.config.load_metric_samples,
self.config.load_adjustment_interval,
),
]
await asyncio.gather(*loops)
await self.planner._init_worker_info(require_prefill=False, require_decode=True)
async def _observe_engine_load_stats(self) -> None:
"""Fetch metrics and update regression models using per-worker time-averaged data."""
result = self.planner.prometheus_engine_client.get_recent_and_averaged_metrics(
self.ENGINE_WORKER_TYPE
)
if result is None:
logger.warning(
f"No per-worker metrics available yet for {self.ENGINE_WORKER_TYPE} (buffer empty)"
)
return
# Delegate FPM tracking to the inner BasePlanner (component_type=DECODE).
if self.runtime is not None:
await self.planner._init_fpm_subscriber()
recent, per_worker_averaged, cluster_averaged = result
self.cached_load_metrics = CachedLoadMetrics(
recent=recent,
per_worker_averaged=per_worker_averaged,
cluster_averaged=cluster_averaged,
)
async def run(self):
"""Main scaling loop. Call _async_init() before this."""
await asyncio.gather(self._load_loop())
# Agg uses per-worker time-averaged metrics for regression
# because chunked prefill adds noise to instantaneous TTFT/ITL
for wid, m in per_worker_averaged.items():
# TTFT regression: (active_prefill_tokens + ISL) -> TTFT
active_prefill = m.get("active_prefill_tokens", 0.0)
last_isl = m.get("last_isl", 0.0)
last_ttft = m.get("last_ttft", 0.0)
if last_ttft > 0 and last_isl > 0:
x = active_prefill + last_isl
y = last_ttft * 1000 # seconds -> ms
logger.info(
f"Agg Worker {wid} prefill observation: TTFT {y:.2f}ms @ tokens {x:.2f}"
)
self.ttft_regression.add_observation(x, y)
# ITL regression: active_decode_blocks -> ITL
active_decode = m.get("active_decode_blocks", 0.0)
last_itl = m.get("last_itl", 0.0)
if last_itl > 0 and active_decode > 0:
x = active_decode
y = last_itl * 1000 # seconds -> ms
logger.info(
f"Agg Worker {wid} decode observation: ITL {y:.2f}ms @ blocks {x:.2f}"
)
self.itl_regression.add_observation(x, y)
async def _load_loop(self) -> None:
"""FPM-driven load-based scaling loop for aggregated mode."""
pending_desired: Optional[int] = None
while True:
await asyncio.sleep(self.config.load_adjustment_interval)
logger.info("New agg load-based adjustment interval started!")
def _prefill_scaling_decision(self, num_workers: int) -> Optional[str]:
"""Returns "up", "down", or None for prefill dimension."""
if not self.cached_load_metrics.recent:
return None
if not self.ttft_regression.has_sufficient_data():
logger.info(
f"TTFT regression: insufficient data ({self.ttft_regression.num_observations}"
f"/{self.ttft_regression.min_observations}), skipping"
_, num_d, _ = await self.planner.get_workers_info(
require_prefill=False, require_decode=True
)
return None
self.shared_state.num_d_workers = num_d
num_workers = num_d
x_sla = self.ttft_regression.predict_x_from_sla(self.config.ttft)
if x_sla is None:
return None
# Always observe FPM stats and update regression, even during scaling.
fpm_stats = self.planner._get_fpm_stats()
if not fpm_stats:
logger.warning("No FPM data available for agg engines")
continue
recent = self.cached_load_metrics.recent
cluster_averaged = self.cached_load_metrics.cluster_averaged
avg_isl = cluster_averaged.get("last_isl", 0.0)
target = x_sla - avg_isl
for (wid, dp), fpm in fpm_stats.items():
BasePlanner._log_fpm(wid, dp, fpm, "agg")
self.regression.add_observation(fpm)
if target <= 0:
logger.warning(
f"Agg TTFT SLA unachievable at current ISL: x_sla={x_sla:.1f}, "
f"avg_isl={avg_isl:.1f}, skipping prefill scaling decision"
# If a previous scaling action is still in progress, skip decisions.
if pending_desired is not None:
if num_workers == pending_desired:
logger.info(
f"Scaling to {pending_desired} complete, resuming decisions"
)
return None
pending_desired = None
else:
logger.info(
f"Agg prefill: x_sla={x_sla:.1f}, avg_isl={avg_isl:.1f}, "
f"target_active_tokens={target:.1f}, workers={num_workers}"
f"Scaling in progress ({num_workers} -> {pending_desired}), "
"observing only"
)
continue
# Scale up: ALL workers above target
if all(m.get("active_prefill_tokens", 0.0) > target for m in recent.values()):
return "up"
# Scale down: ALL workers below boundary
if num_workers > self.config.min_endpoint:
sensitivity = self.config.load_scaling_down_sensitivity / 100.0
boundary = target * (num_workers - 1) / num_workers * sensitivity
if all(
m.get("active_prefill_tokens", 0.0) < boundary for m in recent.values()
if not BasePlanner._reconcile_fpm_worker_count(
fpm_stats, num_workers, "agg"
):
return "down"
return None
continue
def _decode_scaling_decision(self, num_workers: int) -> Optional[str]:
"""Returns "up", "down", or None for decode dimension."""
if not self.cached_load_metrics.recent:
return None
if not self.itl_regression.has_sufficient_data():
if not self.regression.has_sufficient_data():
logger.info(
f"ITL regression: insufficient data ({self.itl_regression.num_observations}"
f"/{self.itl_regression.min_observations}), skipping"
f"Agg regression: insufficient data "
f"({self.regression.num_observations}/{self.regression.min_observations})"
)
return None
x_sla = self.itl_regression.predict_x_from_sla(self.config.itl)
if x_sla is None:
return None
if x_sla <= 0:
logger.warning(
f"Agg ITL SLA unachievable: x_sla={x_sla:.1f}, "
"skipping decode scaling decision"
)
return None
recent = self.cached_load_metrics.recent
logger.info(f"Agg decode: x_sla={x_sla:.1f}, workers={num_workers}")
# Scale up: ALL workers above target
if all(m.get("active_decode_blocks", 0.0) > x_sla for m in recent.values()):
return "up"
# Scale down: ALL workers below boundary
# TODO: should we strictly enforce all workers below boundary?
# how about user-configurable percentage?
if num_workers > self.config.min_endpoint:
sensitivity = self.config.load_scaling_down_sensitivity / 100.0
boundary = x_sla * (num_workers - 1) / num_workers * sensitivity
if all(
m.get("active_decode_blocks", 0.0) < boundary for m in recent.values()
):
return "down"
return None
async def _load_loop(self) -> None:
"""Load-based scaling loop for aggregated mode."""
while True:
await asyncio.sleep(self.config.load_adjustment_interval)
logger.info("New agg load-based adjustment interval started!")
continue
# Query DGD for fresh worker counts
_, num_d, _ = await self.planner.get_workers_info(
require_prefill=False, require_decode=True
max_num_batched_tokens = getattr(
self.planner.decode_worker_info, "max_num_batched_tokens", None
)
self.shared_state.num_d_workers = num_d
num_workers = num_d
# Observe per-worker metrics
await self._observe_engine_load_stats()
# Reconcile worker counts
prom_count = len(self.cached_load_metrics.recent)
if prom_count != num_workers:
if not max_num_batched_tokens or max_num_batched_tokens <= 0:
logger.warning(
f"Worker count mismatch: DGD reports {num_workers}, "
f"router metrics reports {prom_count}. Skipping."
"max_num_batched_tokens not available from WorkerInfo, "
"skipping agg scaling"
)
continue
if not self.cached_load_metrics.recent:
continue
# Make scaling decisions separately for prefill and decode
p_decision = self._prefill_scaling_decision(num_workers)
d_decision = self._decode_scaling_decision(num_workers)
p_desired = self._prefill_scaling_decision(
fpm_stats, num_workers, max_num_batched_tokens
)
d_desired = self._decode_scaling_decision(fpm_stats, num_workers)
logger.info(
f"Agg scaling decisions: prefill={p_decision}, decode={d_decision}"
f"Agg scaling decisions: prefill={p_desired}, decode={d_desired} "
f"(current={num_workers})"
)
# Scale up if EITHER needs scale up
# Scale down if BOTH need scale down
if p_decision == "up" or d_decision == "up":
desired = num_workers + 1
elif p_decision == "down" and d_decision == "down":
desired = num_workers - 1
# Scale up if EITHER dimension wants more workers.
# Scale down only if BOTH dimensions agree on fewer.
if p_desired is not None and p_desired > num_workers:
desired = p_desired
elif d_desired is not None and d_desired > num_workers:
desired = d_desired
elif (
p_desired is not None
and p_desired < num_workers
and d_desired is not None
and d_desired < num_workers
):
desired = max(p_desired, d_desired)
else:
logger.info("Agg scaling: no scaling needed")
continue
......@@ -322,13 +215,54 @@ class AggPlanner:
self.planner.prometheus_metrics.predicted_num_d.set(desired)
if not self.config.no_operation:
pending_desired = desired
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.planner.decode_component_name,
component_name=self.planner.decode_worker_info.k8s_name,
desired_replicas=desired,
)
]
await self.planner.connector.set_component_replicas(
target_replicas, blocking=True
target_replicas, blocking=False
)
def _prefill_scaling_decision(
self,
fpm_stats: "dict[tuple[str, int], ForwardPassMetrics]",
num_workers: int,
max_num_batched_tokens: int,
) -> Optional[int]:
"""Returns desired replica count for the prefill (TTFT) dimension, or None."""
estimated_ttfts: list[float] = []
for (wid, dp), fpm in fpm_stats.items():
est = self.regression.estimate_next_ttft(
queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens,
max_num_batched_tokens=max_num_batched_tokens,
current_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
)
if est is not None:
estimated_ttfts.append(est * 1000)
return self.planner._load_based_scaling_decision_from_estimates(
estimated_ttfts, self.config.ttft, num_workers, "agg TTFT"
)
def _decode_scaling_decision(
self,
fpm_stats: "dict[tuple[str, int], ForwardPassMetrics]",
num_workers: int,
) -> Optional[int]:
"""Returns desired replica count for the decode (ITL) dimension, or None."""
estimated_itls: list[float] = []
for (wid, dp), fpm in fpm_stats.items():
est = self.regression.estimate_next_itl(
scheduled_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
queued_decode_kv=fpm.queued_requests.sum_decode_kv_tokens,
)
if est is not None:
estimated_itls.append(est * 1000)
return self.planner._load_based_scaling_decision_from_estimates(
estimated_itls, self.config.itl, num_workers, "agg ITL"
)
......@@ -17,7 +17,15 @@ class DecodePlanner(BasePlanner):
component_type = SubComponentType.DECODE
def load_plan_adjustment(self) -> Optional[int]:
"""Load-based scaling decision for decode. Returns desired_replicas or None."""
"""Load-based scaling decision for decode using FPM data.
For each engine, estimates next decode ITL:
- Uses scheduled + queued decode KV tokens + avg decode length
- Predicts wall time via regression
Scale up if ALL engines' estimated ITL > SLA.
Scale down if ALL engines' estimated ITL < SLA * sensitivity.
"""
if not self.itl_regression.has_sufficient_data():
logger.info(
f"ITL regression: insufficient data ({self.itl_regression.num_observations}"
......@@ -25,64 +33,38 @@ class DecodePlanner(BasePlanner):
)
return None
x_sla = self.itl_regression.predict_x_from_sla(self.config.itl)
if x_sla is None:
return None
if x_sla <= 0:
logger.warning(
f"ITL SLA unachievable: x_sla={x_sla:.1f}, "
"skipping load-based decode scaling"
)
return None
if not self.cached_load_metrics.recent:
fpm_stats = self._get_fpm_stats()
if not fpm_stats:
return None
recent = self.cached_load_metrics.recent
num_workers = self.shared_state.num_d_workers
if num_workers == 0:
return None
estimated_itls: list[float] = []
for (wid, dp), fpm in fpm_stats.items():
scheduled_kv = fpm.scheduled_requests.sum_decode_kv_tokens
queued_kv = fpm.queued_requests.sum_decode_kv_tokens
est = self.itl_regression.estimate_next_itl(
scheduled_decode_kv=scheduled_kv,
queued_decode_kv=queued_kv,
)
if est is None:
continue
est_ms = est * 1000
estimated_itls.append(est_ms)
logger.info(
f"Load-based decode: x_sla={x_sla:.1f}, workers={num_workers}, "
f"slope={self.itl_regression.slope:.6f}, intercept={self.itl_regression.intercept:.3f}"
f"Decode engine {wid}:dp{dp}: estimated ITL {est_ms:.2f}ms "
f"(sched_kv={scheduled_kv}, queued_kv={queued_kv}, "
f"avg_decode_len={self.itl_regression.avg_decode_length:.1f})"
)
# Scale up: ALL workers above target (use recent metrics)
all_above = all(
m.get("active_decode_blocks", 0.0) > x_sla for m in recent.values()
return self._load_based_scaling_decision_from_estimates(
estimates=estimated_itls,
sla=self.config.itl,
num_workers=num_workers,
label="decode ITL",
)
if all_above:
logger.info(
f"Load-based decode: ALL workers above target ({x_sla:.1f}), "
f"scaling up to {num_workers + 1}"
)
return num_workers + 1
# Scale down: ALL workers below boundary (use recent metrics)
if num_workers > 1:
sensitivity = self.config.load_scaling_down_sensitivity / 100.0
boundary = x_sla * (num_workers - 1) / num_workers * sensitivity
all_below = all(
m.get("active_decode_blocks", 0.0) < boundary for m in recent.values()
)
if all_below:
if num_workers - 1 < self.config.min_endpoint:
logger.info(
f"Load-based decode: ALL workers below boundary ({boundary:.1f}), "
f"but cannot scale down below min_endpoint ({self.config.min_endpoint}); "
f"maintaining {num_workers} decode workers"
)
return num_workers
logger.info(
f"Load-based decode: ALL workers below boundary ({boundary:.1f}), "
f"scaling down to {num_workers - 1}"
)
return num_workers - 1
return None
def _update_correction_factor(self) -> bool:
if self.shared_state.num_d_workers == 0:
......
......@@ -6,9 +6,11 @@ import logging
import time
from dynamo.planner import SubComponentType, TargetReplica
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
from dynamo.planner.utils.decode_planner import DecodePlanner
from dynamo.planner.utils.planner_config import PlannerConfig
from dynamo.planner.utils.planner_core import (
BasePlanner,
PlannerPrometheusMetrics,
PlannerSharedState,
_apply_global_gpu_budget,
......@@ -46,29 +48,34 @@ class DisaggPlanner:
prometheus_traffic_client=getattr(
self.prefill_planner, "prometheus_traffic_client", None
),
prometheus_engine_client=getattr(
self.prefill_planner, "prometheus_engine_client", None
),
connector=getattr(self.prefill_planner, "connector", None),
start_prometheus_server=False,
)
async def _async_init(self):
# Prefill/Decode share the same connector instance in disagg mode.
await self.prefill_planner._async_init()
# DisaggPlanner overrides _async_init to handle both prefill+decode
# and share WorkerInfo between the two sub-planners.
defaults = WORKER_COMPONENT_NAMES.get(self.config.backend)
async def run(self):
if not self.config.no_operation:
# Connector init (prefill/decode share the same connector)
connector = getattr(self.prefill_planner, "connector", None)
if connector and hasattr(connector, "_async_init"):
await connector._async_init()
logger.info("Validating deployment...")
await self.prefill_planner.connector.validate_deployment(
prefill_component_name=self.prefill_planner.prefill_component_name,
decode_component_name=self.prefill_planner.decode_component_name,
prefill_component_name=(
defaults.prefill_worker_k8s_name if defaults else None
),
decode_component_name=(
defaults.decode_worker_k8s_name if defaults else None
),
require_prefill=True,
require_decode=True,
)
logger.info("Successfully validated the deployment")
# Initialize GPU counts
_initialize_gpu_counts(
self.config,
self.prefill_planner.connector,
......@@ -76,40 +83,40 @@ class DisaggPlanner:
require_decode=True,
)
await self.prefill_planner.connector.wait_for_deployment_ready()
await self.prefill_planner.connector.wait_for_deployment_ready(
include_planner=False
)
# Model name discovery runs in all modes (needed for metrics collection)
if not self.config.no_operation:
model_name = await self.prefill_planner._get_model_name(
await self.prefill_planner._init_worker_info(
require_prefill=True, require_decode=True
)
logger.info(f"Detected model name from deployment: {model_name}")
model_name = model_name.lower()
else:
if not self.config.model_name:
raise ValueError(
"Model name is required in no-operation mode. "
"Please set model_name in the config."
# Share WorkerInfo and model name with decode planner
self.decode_planner.prefill_worker_info = (
self.prefill_planner.prefill_worker_info
)
model_name = self.config.model_name.lower()
self.prefill_planner.model_name = model_name
self.decode_planner.model_name = model_name
self.decode_planner.decode_worker_info = self.prefill_planner.decode_worker_info
self.decode_planner.model_name = self.prefill_planner.model_name
# Start FPM tracking for both planners. DisaggPlanner bypasses each
# sub-planner's _async_init(), so we init subscribers explicitly here.
if self.enable_load:
if self.prefill_planner.runtime is not None:
await self.prefill_planner._init_fpm_subscriber()
if self.decode_planner.runtime is not None:
await self.decode_planner._init_fpm_subscriber()
async def run(self):
"""Main scaling loop. Call _async_init() before this."""
self.shared_state.last_adjustment_time = time.time()
self.shared_state.last_load_adjustment_time = time.time()
# Build list of concurrent loops based on enabled scaling modes
# FPM tracking (started in _async_init) replaces the former
# DirectRouterMetricsClient.run_sampling_loop().
loops = []
if self.enable_throughput:
loops.append(self._throughput_loop())
if self.enable_load:
loops.append(self._load_loop())
loops.append(
self.prefill_planner.prometheus_engine_client.run_sampling_loop(
self.config.load_metric_samples,
self.config.load_adjustment_interval,
)
)
await asyncio.gather(*loops)
......@@ -156,12 +163,12 @@ class DisaggPlanner:
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.PREFILL,
component_name=self.prefill_planner.prefill_component_name,
component_name=self.prefill_planner.prefill_worker_info.k8s_name,
desired_replicas=next_num_p,
),
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.prefill_planner.decode_component_name,
component_name=self.prefill_planner.decode_worker_info.k8s_name,
desired_replicas=next_num_d,
),
]
......@@ -172,34 +179,34 @@ class DisaggPlanner:
await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
async def _load_loop(self) -> None:
"""Load-based scaling loop for disagg mode at shorter interval."""
"""FPM-driven load-based scaling loop for disagg mode."""
while True:
await asyncio.sleep(self.config.load_adjustment_interval)
logger.info("New load-based adjustment interval started!")
# Query DGD for fresh worker counts
num_p, num_d, _ = await self.prefill_planner.get_workers_info(
require_prefill=True, require_decode=True
)
self.shared_state.num_p_workers = num_p
self.shared_state.num_d_workers = num_d
# Observe per-worker metrics from router
await self.prefill_planner.observe_engine_load_stats()
await self.decode_planner.observe_engine_load_stats()
# Observe FPM stats and feed into regression models
p_stats = self.prefill_planner.observe_fpm_load_stats()
d_stats = self.decode_planner.observe_fpm_load_stats()
# Reconcile DGD worker counts with router Prometheus counts
p_prom_count = len(self.prefill_planner.cached_load_metrics.recent)
d_prom_count = len(self.decode_planner.cached_load_metrics.recent)
if p_prom_count != num_p or d_prom_count != num_d:
logger.warning(
f"Worker count mismatch: DGD reports P={num_p}, D={num_d}; "
f"router metrics reports P={p_prom_count}, D={d_prom_count}. "
"Skipping load-based scaling adjustment."
)
if not p_stats and not d_stats:
logger.warning("No FPM data for either prefill or decode, skipping")
continue
if p_stats and not BasePlanner._reconcile_fpm_worker_count(
p_stats, num_p, "prefill"
):
continue
if d_stats and not BasePlanner._reconcile_fpm_worker_count(
d_stats, num_d, "decode"
):
continue
# Scale prefill and decode independently
p_desired = self.prefill_planner.load_plan_adjustment()
d_desired = self.decode_planner.load_plan_adjustment()
......@@ -241,12 +248,12 @@ class DisaggPlanner:
target_replicas = [
TargetReplica(
sub_component_type=SubComponentType.PREFILL,
component_name=self.prefill_planner.prefill_component_name,
component_name=self.prefill_planner.prefill_worker_info.k8s_name,
desired_replicas=final_p,
),
TargetReplica(
sub_component_type=SubComponentType.DECODE,
component_name=self.prefill_planner.decode_component_name,
component_name=self.prefill_planner.decode_worker_info.k8s_name,
desired_replicas=final_d,
),
]
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""FPM-driven regression models for load-based scaling.
Each model takes ForwardPassMetrics observations and estimates per-engine
TTFT or ITL by simulating the scheduler's chunked prefill / decode
iteration pipeline.
- PrefillRegressionModel: 1D regression (sum_prefill_tokens -> wall_time)
- DecodeRegressionModel: 1D regression (sum_decode_kv_tokens -> wall_time)
- AggRegressionModel: 2D regression (sum_prefill_tokens, sum_decode_kv_tokens -> wall_time)
"""
import logging
import math
from collections import deque
from typing import Optional, Union
import numpy as np
from sklearn.linear_model import LinearRegression
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
logger = logging.getLogger(__name__)
class _MovingAverage:
"""Fixed-window moving average that skips leading zeros.
Initial zero values (pre-traffic idle period) are ignored until the
first non-zero value arrives, matching the throughput planner's
load predictor behavior.
"""
__slots__ = ("_window", "_sum", "_seen_nonzero")
def __init__(self, window_size: int):
self._window: deque[float] = deque(maxlen=window_size)
self._sum: float = 0.0
self._seen_nonzero: bool = False
def add(self, value: float) -> None:
if value == 0.0 and not self._seen_nonzero:
return
if value != 0.0:
self._seen_nonzero = True
if len(self._window) == self._window.maxlen:
self._sum -= self._window[0]
self._window.append(value)
self._sum += value
@property
def value(self) -> float:
if not self._window:
return 0.0
return self._sum / len(self._window)
def __len__(self) -> int:
return len(self._window)
class _BaseRegressionModel:
    """Observation buffer plus lazily refitted linear model for FPM data.

    Subclasses provide ``_extract_x`` (regression input for one FPM
    snapshot) and ``_update_moving_averages`` (bookkeeping done for every
    snapshot).  The model is refitted on demand after new observations
    have been added.
    """

    def __init__(self, window_size: int, min_observations: int = 5, ndim: int = 1):
        self.window_size = window_size
        self.min_observations = min_observations
        self._ndim = ndim
        # (regression input, wall_time) pairs; the deque drops the oldest
        # pair once ``window_size`` entries are held.
        self._observations: deque[tuple[Union[float, list[float]], float]] = deque(
            maxlen=window_size
        )
        self._model = LinearRegression()
        self._is_fitted = False

    def _extract_x(self, fpm: ForwardPassMetrics) -> Union[float, list[float]]:
        """Return the regression input(s) from an FPM snapshot."""
        raise NotImplementedError

    def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
        """Update moving averages (called for every FPM, including idle)."""
        raise NotImplementedError

    def add_observation(self, fpm: ForwardPassMetrics) -> None:
        """Feed one FPM snapshot into the averages and, if timed, the fit data."""
        # Moving averages see every snapshot so idle state is reflected.
        self._update_moving_averages(fpm)
        if fpm.wall_time != 0.0:
            # Only snapshots with a non-zero wall time become regression samples.
            self._observations.append((self._extract_x(fpm), fpm.wall_time))
            # New data invalidates the previous fit.
            self._is_fitted = False

    def _fit(self) -> bool:
        """Fit the model on the buffered samples; False if too few are held."""
        if len(self._observations) < self.min_observations:
            return False
        features = np.array([x for x, _ in self._observations])
        if self._ndim == 1:
            # Single-feature inputs must be a column vector for the regressor.
            features = features.reshape(-1, 1)
        targets = np.array([wall_time for _, wall_time in self._observations])
        self._model.fit(features, targets)
        self._is_fitted = True
        return True

    def _ensure_fitted(self) -> bool:
        """Return True when a current fit exists, fitting first if needed."""
        if self._is_fitted:
            return True
        return self._fit()

    def has_sufficient_data(self) -> bool:
        """Whether at least ``min_observations`` samples are buffered."""
        return len(self._observations) >= self.min_observations

    @property
    def num_observations(self) -> int:
        """Number of samples currently buffered."""
        return len(self._observations)
class PrefillRegressionModel(_BaseRegressionModel):
    """Predict per-iteration wall time from scheduled prefill tokens.

    Regression: wall_time = f(sum_prefill_tokens)
    Simulation: estimate TTFT by chunking queued_prefill_tokens + avg_isl
                into max_num_batched_tokens-sized iterations and summing
                the predicted wall time for each.
    """

    def __init__(self, window_size: int, min_observations: int = 5):
        super().__init__(window_size, min_observations, ndim=1)
        self._avg_isl = _MovingAverage(window_size)
        self._avg_num_prefill = _MovingAverage(window_size)

    def _extract_x(self, fpm: ForwardPassMetrics) -> float:
        """Regression input: scheduled prefill tokens of the snapshot."""
        return float(fpm.scheduled_requests.sum_prefill_tokens)

    def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
        """Track average tokens per prefill request and prefill request count."""
        sched = fpm.scheduled_requests
        if sched.num_prefill_requests > 0:
            # Per-request input length; only defined when prefill was scheduled.
            self._avg_isl.add(sched.sum_prefill_tokens / sched.num_prefill_requests)
        self._avg_num_prefill.add(float(sched.num_prefill_requests))

    @property
    def avg_isl(self) -> float:
        """Moving average of prefill tokens per scheduled prefill request."""
        return self._avg_isl.value

    def _predict_nonneg(self, prefill_tokens: float) -> float:
        """Predicted wall time for one iteration, clamped at zero."""
        pred = self._model.predict(np.array([[prefill_tokens]]))[0]
        return max(0.0, float(pred))

    def estimate_next_ttft(
        self,
        queued_prefill_tokens: int,
        max_num_batched_tokens: int,
    ) -> Optional[float]:
        """Simulate prefill scheduling to estimate TTFT for the next request.

        The scheduler processes prefill tokens in chunks of
        max_num_batched_tokens per iteration. We sum the regression-predicted
        wall time for each chunk to approximate TTFT.

        Every chunk except the last is exactly ``max_num_batched_tokens``
        tokens, so its prediction is computed once and multiplied by the
        number of full chunks instead of being re-predicted per iteration.

        Args:
            queued_prefill_tokens: tokens already queued ahead of the next request.
            max_num_batched_tokens: per-iteration token budget (from WorkerInfo/MDC).

        Returns:
            Estimated TTFT in seconds, or None if the model is not ready.
        """
        if not self._ensure_fitted() or max_num_batched_tokens <= 0:
            return None

        total_tokens = queued_prefill_tokens + self._avg_isl.value
        if total_tokens <= 0:
            return 0.0

        num_iterations = math.ceil(total_tokens / max_num_batched_tokens)
        full_chunks = num_iterations - 1
        # Whatever is left after the full chunks, never more than one budget.
        last_chunk = min(
            total_tokens - full_chunks * max_num_batched_tokens,
            max_num_batched_tokens,
        )

        total_time = 0.0
        if full_chunks:
            total_time += full_chunks * self._predict_nonneg(max_num_batched_tokens)
        total_time += self._predict_nonneg(last_chunk)
        return total_time
class DecodeRegressionModel(_BaseRegressionModel):
    """Single-feature model: decode KV tokens in an iteration -> wall time.

    Regression: wall_time = f(sum_decode_kv_tokens)
    Estimation: ITL of the next decode step is predicted for the scheduled
                load plus the queued (preempted) load plus one more request
                of average decode length.
    """

    def __init__(self, window_size: int, min_observations: int = 5):
        super().__init__(window_size, min_observations, ndim=1)
        self._avg_decode_len = _MovingAverage(window_size)
        self._avg_num_decode = _MovingAverage(window_size)

    def _extract_x(self, fpm: ForwardPassMetrics) -> float:
        """Regression input: scheduled decode KV tokens of the snapshot."""
        return float(fpm.scheduled_requests.sum_decode_kv_tokens)

    def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
        """Track average KV tokens per decode request and decode request count."""
        scheduled = fpm.scheduled_requests
        num_decode = scheduled.num_decode_requests
        if num_decode > 0:
            # Per-request decode context; only defined when decode was scheduled.
            self._avg_decode_len.add(scheduled.sum_decode_kv_tokens / num_decode)
        self._avg_num_decode.add(float(num_decode))

    @property
    def avg_decode_length(self) -> float:
        """Moving average of KV tokens per scheduled decode request."""
        return self._avg_decode_len.value

    def estimate_next_itl(
        self,
        scheduled_decode_kv: int,
        queued_decode_kv: int,
    ) -> Optional[float]:
        """Estimate the wall time of the next decode iteration.

        The prediction input is the scheduled decode KV load, plus the
        queued decode KV load, plus the moving-average decode length
        (standing in for one additional request).

        Args:
            scheduled_decode_kv: sum_decode_kv_tokens from the latest FPM.
            queued_decode_kv: sum_decode_kv_tokens from the queued metrics.

        Returns:
            Estimated ITL in seconds (never negative), or None if the
            model is not ready.
        """
        if not self._ensure_fitted():
            return None

        projected_kv = (
            scheduled_decode_kv + queued_decode_kv + self._avg_decode_len.value
        )
        features = np.array([[projected_kv]])
        predicted = float(self._model.predict(features)[0])
        # A linear fit can extrapolate below zero; clamp to a valid duration.
        return max(0.0, predicted)
class AggRegressionModel(_BaseRegressionModel):
    """2D regression for aggregated (chunked prefill + decode) engines.

    Regression: wall_time = f(sum_prefill_tokens, sum_decode_kv_tokens)
    Estimation: estimate TTFT by simulating prefill chunking while assuming
                steady-state decode load; estimate ITL by predicting decode
                iteration time while assuming average piggybacked prefill load.
    """

    def __init__(self, window_size: int, min_observations: int = 5):
        super().__init__(window_size, min_observations, ndim=2)
        self._avg_isl = _MovingAverage(window_size)
        self._avg_decode_len = _MovingAverage(window_size)
        self._avg_prefill_tokens = _MovingAverage(window_size)
        self._avg_num_prefill = _MovingAverage(window_size)
        self._avg_num_decode = _MovingAverage(window_size)

    def _extract_x(self, fpm: ForwardPassMetrics) -> list[float]:
        """Regression inputs: [scheduled prefill tokens, scheduled decode KV tokens]."""
        sched = fpm.scheduled_requests
        return [float(sched.sum_prefill_tokens), float(sched.sum_decode_kv_tokens)]

    def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
        """Track per-request averages and per-iteration prefill/decode volumes."""
        sched = fpm.scheduled_requests
        if sched.num_prefill_requests > 0:
            # Per-request input length; only defined when prefill was scheduled.
            self._avg_isl.add(sched.sum_prefill_tokens / sched.num_prefill_requests)
        if sched.num_decode_requests > 0:
            # Per-request decode context; only defined when decode was scheduled.
            self._avg_decode_len.add(
                sched.sum_decode_kv_tokens / sched.num_decode_requests
            )
        self._avg_prefill_tokens.add(float(sched.sum_prefill_tokens))
        self._avg_num_prefill.add(float(sched.num_prefill_requests))
        self._avg_num_decode.add(float(sched.num_decode_requests))

    @property
    def avg_isl(self) -> float:
        """Moving average of prefill tokens per scheduled prefill request."""
        return self._avg_isl.value

    @property
    def avg_decode_length(self) -> float:
        """Moving average of KV tokens per scheduled decode request."""
        return self._avg_decode_len.value

    @property
    def avg_prefill_tokens(self) -> float:
        """Moving average of scheduled prefill tokens per iteration."""
        return self._avg_prefill_tokens.value

    def _predict_2d(self, prefill_tokens: float, decode_kv_tokens: float) -> float:
        """Raw (unclamped) predicted wall time for one iteration."""
        return float(
            self._model.predict(np.array([[prefill_tokens, decode_kv_tokens]]))[0]
        )

    def estimate_next_ttft(
        self,
        queued_prefill_tokens: int,
        max_num_batched_tokens: int,
        current_decode_kv: int,
    ) -> Optional[float]:
        """Simulate prefill scheduling with piggybacked decode.

        Same chunking simulation as PrefillRegressionModel, but each
        iteration also carries the current decode KV load (steady state).

        Every chunk except the last is exactly ``max_num_batched_tokens``
        tokens with the same decode load, so its prediction is computed
        once and multiplied by the number of full chunks instead of being
        re-predicted per iteration.

        Args:
            queued_prefill_tokens: prefill tokens queued ahead of the next request.
            max_num_batched_tokens: per-iteration token budget (from MDC).
            current_decode_kv: scheduled decode KV tokens from the latest FPM
                (assumed steady during prefill).

        Returns:
            Estimated TTFT in seconds, or None if the model is not ready.
        """
        if not self._ensure_fitted() or max_num_batched_tokens <= 0:
            return None

        total_tokens = queued_prefill_tokens + self._avg_isl.value
        if total_tokens <= 0:
            return 0.0

        num_iterations = math.ceil(total_tokens / max_num_batched_tokens)
        full_chunks = num_iterations - 1
        # Whatever is left after the full chunks, never more than one budget.
        last_chunk = min(
            total_tokens - full_chunks * max_num_batched_tokens,
            max_num_batched_tokens,
        )
        decode_kv = float(current_decode_kv)

        total_time = 0.0
        if full_chunks:
            total_time += full_chunks * max(
                0.0, self._predict_2d(max_num_batched_tokens, decode_kv)
            )
        total_time += max(0.0, self._predict_2d(last_chunk, decode_kv))
        return total_time

    def estimate_next_itl(
        self,
        scheduled_decode_kv: int,
        queued_decode_kv: int,
    ) -> Optional[float]:
        """Estimate decode iteration time with piggybacked prefill.

        Uses the moving average of scheduled prefill tokens as the
        piggybacked prefill load in the next iteration.

        Args:
            scheduled_decode_kv: sum_decode_kv_tokens from the latest FPM.
            queued_decode_kv: sum_decode_kv_tokens from the queued metrics.

        Returns:
            Estimated ITL in seconds, or None if the model is not ready.
        """
        if not self._ensure_fitted():
            return None

        total_kv = scheduled_decode_kv + queued_decode_kv + self._avg_decode_len.value
        return max(0.0, self._predict_2d(self._avg_prefill_tokens.value, total_kv))
......@@ -108,7 +108,6 @@ class PlannerConfig(BaseModel):
enable_load_scaling: bool = SLAPlannerDefaults.enable_load_scaling
# Load-based scaling settings
load_router_metrics_url: Optional[str] = SLAPlannerDefaults.load_router_metrics_url
load_adjustment_interval: int = SLAPlannerDefaults.load_adjustment_interval
load_learning_window: int = SLAPlannerDefaults.load_learning_window
load_scaling_down_sensitivity: int = (
......@@ -146,13 +145,6 @@ class PlannerConfig(BaseModel):
)
if self.enable_load_scaling:
# Router metrics URL is required outside kubernetes mode
if not self.load_router_metrics_url and self.environment != "kubernetes":
raise ValueError(
"load_router_metrics_url is required when "
"load-based scaling is enabled outside kubernetes mode"
)
# Load-based interval must be shorter than throughput interval
if self.enable_throughput_scaling:
if self.load_adjustment_interval >= self.throughput_adjustment_interval:
......
......@@ -17,7 +17,16 @@ class PrefillPlanner(BasePlanner):
component_type = SubComponentType.PREFILL
def load_plan_adjustment(self) -> Optional[int]:
"""Load-based scaling decision for prefill. Returns desired_replicas or None."""
"""Load-based scaling decision for prefill using FPM data.
For each engine, simulates prefill scheduling to estimate next TTFT:
- Uses queued prefill tokens + avg ISL as total tokens to process
- Chunks into max_num_batched_tokens-sized iterations
- Sums regression-predicted wall time per chunk
Scale up if ALL engines' estimated TTFT > SLA.
Scale down if ALL engines' estimated TTFT < SLA * sensitivity.
"""
if not self.ttft_regression.has_sufficient_data():
logger.info(
f"TTFT regression: insufficient data ({self.ttft_regression.num_observations}"
......@@ -25,73 +34,46 @@ class PrefillPlanner(BasePlanner):
)
return None
x_sla = self.ttft_regression.predict_x_from_sla(self.config.ttft)
if x_sla is None:
return None
if not self.cached_load_metrics.recent:
return None
recent = self.cached_load_metrics.recent
cluster_averaged = self.cached_load_metrics.cluster_averaged
# Averaged ISL across all workers in the past adjustment interval
avg_isl = cluster_averaged.get("last_isl", 0.0)
target_active_tokens = x_sla - avg_isl
if target_active_tokens <= 0:
logger.warning(
f"TTFT SLA unachievable at current ISL: x_sla={x_sla:.1f}, "
f"avg_isl={avg_isl:.1f}, skipping load-based prefill scaling"
)
fpm_stats = self._get_fpm_stats()
if not fpm_stats:
return None
num_workers = self.shared_state.num_p_workers
if num_workers == 0:
return None
logger.info(
f"Load-based prefill: x_sla={x_sla:.1f}, avg_isl={avg_isl:.1f}, "
f"target_active_tokens={target_active_tokens:.1f}, workers={num_workers}, "
f"slope={self.ttft_regression.slope:.6f}, intercept={self.ttft_regression.intercept:.3f}"
max_num_batched_tokens = getattr(
self.prefill_worker_info, "max_num_batched_tokens", None
)
# Scale up: ALL workers above target (use recent metrics)
all_above = all(
m.get("active_prefill_tokens", 0.0) > target_active_tokens
for m in recent.values()
)
if all_above:
logger.info(
f"Load-based prefill: ALL workers above target ({target_active_tokens:.1f}), "
f"scaling up to {num_workers + 1}"
if not max_num_batched_tokens or max_num_batched_tokens <= 0:
logger.warning(
"max_num_batched_tokens not available from WorkerInfo, "
"skipping prefill load-based scaling"
)
return num_workers + 1
return None
# Scale down: ALL workers below boundary (use recent metrics)
if num_workers > 1:
sensitivity = self.config.load_scaling_down_sensitivity / 100.0
boundary = (
target_active_tokens * (num_workers - 1) / num_workers * sensitivity
estimated_ttfts: list[float] = []
for (wid, dp), fpm in fpm_stats.items():
queued_prefill = fpm.queued_requests.sum_prefill_tokens
est = self.ttft_regression.estimate_next_ttft(
queued_prefill_tokens=queued_prefill,
max_num_batched_tokens=max_num_batched_tokens,
)
all_below = all(
m.get("active_prefill_tokens", 0.0) < boundary for m in recent.values()
)
if all_below:
if num_workers - 1 < self.config.min_endpoint:
logger.info(
f"Load-based prefill: ALL workers below boundary ({boundary:.1f}), "
f"but cannot scale down below min_endpoint ({self.config.min_endpoint}); "
f"maintaining {num_workers} prefill workers"
)
return num_workers
if est is None:
continue
est_ms = est * 1000
estimated_ttfts.append(est_ms)
logger.info(
f"Load-based prefill: ALL workers below boundary ({boundary:.1f}), "
f"scaling down to {num_workers - 1}"
f"Prefill engine {wid}:dp{dp}: estimated TTFT {est_ms:.2f}ms "
f"(queued_prefill={queued_prefill}, avg_isl={self.ttft_regression.avg_isl:.1f})"
)
return num_workers - 1
return None
return self._load_based_scaling_decision_from_estimates(
estimates=estimated_ttfts,
sla=self.config.ttft,
num_workers=num_workers,
label="prefill TTFT",
)
def _update_correction_factor(self) -> bool:
assert self.last_metrics.isl is not None and self.last_metrics.ttft is not None
......
......@@ -136,7 +136,7 @@ class VirtualConnector(PlannerConnector):
"""Validate the deployment"""
pass
async def wait_for_deployment_ready(self):
async def wait_for_deployment_ready(self, include_planner: bool = True):
"""Wait for the deployment to be ready"""
await self._wait_for_scaling_completion()
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import Any, Optional
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SubComponentType
logger = logging.getLogger(__name__)
@dataclass
class WorkerInfo:
    """Worker metadata bundle consumed by the planner.

    Populated from MDC (DynamoWorkerMetadata CRs) in Kubernetes mode,
    with fallback to DGD container-arg parsing, then hard-coded defaults.
    Every field is optional and defaults to ``None``.
    """

    # Names used for scaling and for creating runtime clients.
    k8s_name: Optional[str] = None
    component_name: Optional[str] = None
    endpoint: Optional[str] = None

    # Runtime configuration reported through MDC.
    model_name: Optional[str] = None
    total_kv_blocks: Optional[int] = None
    kv_cache_block_size: Optional[int] = None
    max_num_seqs: Optional[int] = None
    max_num_batched_tokens: Optional[int] = None
    context_length: Optional[int] = None

    @property
    def max_kv_tokens(self) -> Optional[int]:
        """KV capacity in tokens (blocks x block size), or None if either is unset."""
        blocks, block_size = self.total_kv_blocks, self.kv_cache_block_size
        if blocks is None or block_size is None:
            return None
        return blocks * block_size

    def summary(self) -> str:
        """Return a comma-separated ``key=value`` description of this worker.

        The three name fields are always included; the runtime fields are
        included only when they are not ``None``.
        """
        pieces = [
            f"k8s_name={self.k8s_name}",
            f"component={self.component_name}",
            f"endpoint={self.endpoint}",
        ]
        # (value, rendered text) pairs for the fields that may be absent.
        optional_pieces = (
            (self.model_name, f"model={self.model_name}"),
            (self.max_kv_tokens, f"max_kv_tokens={self.max_kv_tokens}"),
            (self.max_num_seqs, f"max_num_seqs(max_bs)={self.max_num_seqs}"),
            (
                self.max_num_batched_tokens,
                f"max_num_batched_tokens={self.max_num_batched_tokens}",
            ),
            (self.context_length, f"context_length={self.context_length}"),
        )
        pieces.extend(text for value, text in optional_pieces if value is not None)
        return ", ".join(pieces)
def build_worker_info_from_defaults(
    backend: str, sub_component_type: SubComponentType
) -> WorkerInfo:
    """Create a WorkerInfo that carries only the backend's hard-coded names.

    Looks ``backend`` up in ``WORKER_COMPONENT_NAMES``.  An unknown backend
    yields an empty ``WorkerInfo``.  Otherwise the prefill name triple is
    used for ``SubComponentType.PREFILL`` and the decode triple for any
    other sub-component type.
    """
    defaults = WORKER_COMPONENT_NAMES.get(backend)
    if defaults is None:
        return WorkerInfo()

    # Select the (k8s name, component name, endpoint) triple for the role.
    k8s_name, component_name, endpoint = (
        (
            defaults.prefill_worker_k8s_name,
            defaults.prefill_worker_component_name,
            defaults.prefill_worker_endpoint,
        )
        if sub_component_type == SubComponentType.PREFILL
        else (
            defaults.decode_worker_k8s_name,
            defaults.decode_worker_component_name,
            defaults.decode_worker_endpoint,
        )
    )
    return WorkerInfo(
        k8s_name=k8s_name,
        component_name=component_name,
        endpoint=endpoint,
    )
def resolve_worker_info(
    backend: str,
    require_prefill: bool,
    require_decode: bool,
    connector: Any = None,
    config_model_name: str = "",
    no_operation: bool = False,
) -> tuple[WorkerInfo, WorkerInfo]:
    """Produce prefill/decode WorkerInfo objects with a shared model name.

    When ``connector`` exposes ``get_worker_info`` it is asked for each
    required role; otherwise the hard-coded backend defaults are used.
    A role that is not required gets an empty ``WorkerInfo``.

    Model name precedence:
      * no-operation mode: ``config_model_name`` (mandatory);
      * otherwise: the decode info's name, then the prefill info's name,
        then ``connector.get_model_name`` when the connector can be
        queried, then ``config_model_name``.

    The chosen name is stored on both returned objects.

    Returns:
        (prefill_worker_info, decode_worker_info)

    Raises:
        ValueError: if no model name can be determined.
    """
    can_query_mdc = connector is not None and hasattr(connector, "get_worker_info")

    def _load(kind: SubComponentType) -> WorkerInfo:
        # Connector lookup when available, hard-coded defaults otherwise.
        if can_query_mdc:
            return connector.get_worker_info(kind, backend)
        return build_worker_info_from_defaults(backend, kind)

    # --- Build WorkerInfo (prefill first, then decode) ---
    prefill_info = _load(SubComponentType.PREFILL) if require_prefill else WorkerInfo()
    decode_info = _load(SubComponentType.DECODE) if require_decode else WorkerInfo()

    if require_prefill:
        logger.info(f"Prefill WorkerInfo: {prefill_info.summary()}")
    if require_decode:
        logger.info(f"Decode WorkerInfo: {decode_info.summary()}")

    # --- Cross-validate model names (only meaningful when both roles exist) ---
    p_model = prefill_info.model_name
    d_model = decode_info.model_name
    both_named = bool(p_model) and bool(d_model)
    if require_prefill and require_decode and both_named and p_model != d_model:
        logger.warning(
            f"Model name mismatch between prefill ({p_model}) and "
            f"decode ({d_model}) WorkerInfo"
        )

    # --- Resolve the model name ---
    mdc_model = d_model or p_model
    if no_operation:
        if not config_model_name:
            raise ValueError(
                "Model name is required in no-operation mode. "
                "Please set model_name in the config."
            )
        model_name = config_model_name
    elif mdc_model:
        model_name = mdc_model
        logger.info(f"Using model name from MDC: {model_name}")
    elif can_query_mdc:
        model_name = connector.get_model_name(
            require_prefill=require_prefill,
            require_decode=require_decode,
        )
        logger.info(f"Detected model name from DGD container args: {model_name}")
    elif config_model_name:
        model_name = config_model_name
        logger.info(f"Using model name from config: {model_name}")
    else:
        raise ValueError(
            "Could not determine model name. " "Please set model_name in the config."
        )

    # --- Write the resolved name back so either object can be consulted ---
    for info in (prefill_info, decode_info):
        info.model_name = model_name
    return prefill_info, decode_info
......@@ -10,6 +10,7 @@ import (
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/intstr"
)
// PlannerDefaults implements ComponentDefaults for Planner components
......@@ -29,12 +30,61 @@ func (p *PlannerDefaults) GetBaseContainer(context ComponentContext) (corev1.Con
Name: commonconsts.DynamoMetricsPortName,
ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort),
},
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoSystemPortName,
ContainerPort: int32(commonconsts.DynamoSystemPort),
},
}
container.LivenessProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/live",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 5,
TimeoutSeconds: 4,
FailureThreshold: 1,
}
container.ReadinessProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 10,
TimeoutSeconds: 4,
FailureThreshold: 3,
}
// Startup probe with generous timeout: the planner waits for worker
// services to become ready before it can initialise, so it needs more
// time than a typical worker.
container.StartupProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/live",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 10,
TimeoutSeconds: 5,
FailureThreshold: 720, // 10s * 720 = 7200s = 2h
}
container.Env = append(container.Env, []corev1.EnvVar{
{
Name: "PLANNER_PROMETHEUS_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoPlannerMetricsPort),
},
{
Name: "DYN_SYSTEM_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoSystemPort),
},
}...)
return container, nil
}
......
......@@ -12,6 +12,7 @@ import (
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
"github.com/google/go-cmp/cmp"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/intstr"
)
func TestPlannerDefaults_GetBaseContainer(t *testing.T) {
......@@ -45,6 +46,40 @@ func TestPlannerDefaults_GetBaseContainer(t *testing.T) {
},
Ports: []corev1.ContainerPort{
{Name: commonconsts.DynamoMetricsPortName, ContainerPort: commonconsts.DynamoPlannerMetricsPort, Protocol: corev1.ProtocolTCP},
{Name: commonconsts.DynamoSystemPortName, ContainerPort: int32(commonconsts.DynamoSystemPort), Protocol: corev1.ProtocolTCP},
},
LivenessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/live",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 5,
TimeoutSeconds: 4,
FailureThreshold: 1,
},
ReadinessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 10,
TimeoutSeconds: 4,
FailureThreshold: 3,
},
StartupProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/live",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 10,
TimeoutSeconds: 5,
FailureThreshold: 720,
},
Env: []corev1.EnvVar{
{Name: commonconsts.DynamoNamespaceEnvVar, Value: "dynamo-namespace"},
......@@ -71,6 +106,7 @@ func TestPlannerDefaults_GetBaseContainer(t *testing.T) {
}},
{Name: commonconsts.DynamoDiscoveryBackendEnvVar, Value: "kubernetes"},
{Name: "PLANNER_PROMETHEUS_PORT", Value: fmt.Sprintf("%d", commonconsts.DynamoPlannerMetricsPort)},
{Name: "DYN_SYSTEM_PORT", Value: fmt.Sprintf("%d", commonconsts.DynamoSystemPort)},
},
},
},
......
......@@ -1655,6 +1655,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
"--planner-env-1",
"1",
},
Ports: []corev1.ContainerPort{
{Name: commonconsts.DynamoMetricsPortName, ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort), Protocol: corev1.ProtocolTCP},
{Name: commonconsts.DynamoSystemPortName, ContainerPort: int32(commonconsts.DynamoSystemPort), Protocol: corev1.ProtocolTCP},
},
EnvFrom: []corev1.EnvFromSource{
{
SecretRef: &corev1.SecretEnvSource{
......@@ -1680,6 +1684,17 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
},
},
},
StartupProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/live",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 10,
TimeoutSeconds: 5,
FailureThreshold: 720,
},
Env: []corev1.EnvVar{
{
Name: "DYNAMO_POD_GANG_SET_REPLICAS",
......@@ -1717,6 +1732,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
Name: "DYN_PARENT_DGD_K8S_NAMESPACE",
Value: "test-namespace",
},
{
Name: "DYN_SYSTEM_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoSystemPort),
},
{
Name: "MODEL_EXPRESS_URL",
Value: "model-express-url",
......@@ -1775,13 +1794,6 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
MountPath: commonconsts.DefaultSharedMemoryMountPath,
},
},
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoMetricsPortName,
ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort),
},
},
},
},
},
......@@ -2632,6 +2644,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
"--planner-env-1",
"1",
},
Ports: []corev1.ContainerPort{
{Name: commonconsts.DynamoMetricsPortName, ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort), Protocol: corev1.ProtocolTCP},
{Name: commonconsts.DynamoSystemPortName, ContainerPort: int32(commonconsts.DynamoSystemPort), Protocol: corev1.ProtocolTCP},
},
EnvFrom: []corev1.EnvFromSource{
{
SecretRef: &corev1.SecretEnvSource{
......@@ -2657,6 +2673,17 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
},
},
},
StartupProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/live",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 10,
TimeoutSeconds: 5,
FailureThreshold: 720,
},
Env: []corev1.EnvVar{
{
Name: "DYNAMO_POD_GANG_SET_REPLICAS",
......@@ -2694,6 +2721,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
Name: "DYN_PARENT_DGD_K8S_NAMESPACE",
Value: "test-namespace",
},
{
Name: "DYN_SYSTEM_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoSystemPort),
},
{
Name: "PLANNER_PROMETHEUS_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoPlannerMetricsPort),
......@@ -2744,13 +2775,6 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
MountPath: commonconsts.DefaultSharedMemoryMountPath,
},
},
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoMetricsPortName,
ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort),
},
},
},
},
},
......@@ -3618,6 +3642,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
"--planner-env-1",
"1",
},
Ports: []corev1.ContainerPort{
{Name: commonconsts.DynamoMetricsPortName, ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort), Protocol: corev1.ProtocolTCP},
{Name: commonconsts.DynamoSystemPortName, ContainerPort: int32(commonconsts.DynamoSystemPort), Protocol: corev1.ProtocolTCP},
},
EnvFrom: []corev1.EnvFromSource{
{
SecretRef: &corev1.SecretEnvSource{
......@@ -3643,13 +3671,17 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
},
},
},
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoMetricsPortName,
ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort),
StartupProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/live",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 10,
TimeoutSeconds: 5,
FailureThreshold: 720,
},
Env: []corev1.EnvVar{
{
Name: "DYNAMO_POD_GANG_SET_REPLICAS",
......@@ -3687,6 +3719,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
Name: "DYN_PARENT_DGD_K8S_NAMESPACE",
Value: "test-namespace",
},
{
Name: "DYN_SYSTEM_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoSystemPort),
},
{
Name: "PLANNER_PROMETHEUS_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoPlannerMetricsPort),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment