Unverified Commit a58bcc31 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

refactor: load planner using new forwardpass metric and many improvements (#7351)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent db14d63f
...@@ -7,11 +7,26 @@ Receive ForwardPassMetrics via the Dynamo event plane. ...@@ -7,11 +7,26 @@ Receive ForwardPassMetrics via the Dynamo event plane.
Auto-discovers engine publishers through the discovery plane (K8s CRD / Auto-discovers engine publishers through the discovery plane (K8s CRD /
etcd / file) and prints each metric message as JSON. etcd / file) and prints each metric message as JSON.
Supports two modes:
- **recv** (default): pull individual messages one at a time.
- **tracking**: periodically poll ``get_recent_stats()`` to print the
latest snapshot keyed by ``(worker_id, dp_rank)``.
Usage: Usage:
# recv mode (default)
python -m dynamo.common.recv_forward_pass_metrics \\
--namespace dynamo --component backend --endpoint generate
# tracking mode (poll every 2 seconds)
python -m dynamo.common.recv_forward_pass_metrics \\
--namespace dynamo --component backend --endpoint generate \\
--mode tracking --poll-interval 2.0
# recv mode with plot saving
python -m dynamo.common.recv_forward_pass_metrics \\ python -m dynamo.common.recv_forward_pass_metrics \\
--namespace dynamo --component backend --endpoint generate \\ --namespace dynamo --component backend --endpoint generate \\
[--discovery-backend etcd] [--request-plane nats] \\ --save-plot metrics.png
[--save-plot metrics.png]
""" """
import argparse import argparse
...@@ -94,11 +109,24 @@ def main() -> None: ...@@ -94,11 +109,24 @@ def main() -> None:
default=os.environ.get("DYN_REQUEST_PLANE", "nats"), default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Request plane (default: nats)", help="Request plane (default: nats)",
) )
parser.add_argument(
"--mode",
choices=["recv", "tracking"],
default="recv",
help="Consumption mode: 'recv' for individual messages, "
"'tracking' for latest-snapshot polling (default: recv)",
)
parser.add_argument(
"--poll-interval",
type=float,
default=2.0,
help="Polling interval in seconds for tracking mode (default: 2.0)",
)
parser.add_argument( parser.add_argument(
"--save-plot", "--save-plot",
metavar="PATH", metavar="PATH",
default=None, default=None,
help="Save a time-series plot to the given PNG path on exit", help="Save a time-series plot to the given PNG path on exit (recv mode only)",
) )
args = parser.parse_args() args = parser.parse_args()
...@@ -117,15 +145,29 @@ async def run(args: argparse.Namespace) -> None: ...@@ -117,15 +145,29 @@ async def run(args: argparse.Namespace) -> None:
endpoint = runtime.endpoint(f"{args.namespace}.{args.component}.{args.endpoint}") endpoint = runtime.endpoint(f"{args.namespace}.{args.component}.{args.endpoint}")
subscriber = FpmEventSubscriber(endpoint) subscriber = FpmEventSubscriber(endpoint)
json_encoder = msgspec.json.Encoder()
logger.info( logger.info(
"Subscribed to forward-pass-metrics via event plane " "Subscribed to forward-pass-metrics via event plane "
"(namespace=%s, component=%s) Ctrl+C to stop", "(namespace=%s, component=%s, mode=%s) Ctrl+C to stop",
args.namespace, args.namespace,
args.component, args.component,
args.mode,
) )
try:
if args.mode == "tracking":
await _run_tracking(subscriber, args)
else:
await _run_recv(subscriber, args)
except KeyboardInterrupt:
logger.info("Stopped.")
finally:
subscriber.shutdown()
async def _run_recv(subscriber, args: argparse.Namespace) -> None:
"""Pull individual FPM messages and print each as JSON."""
json_encoder = msgspec.json.Encoder()
history: list[tuple[float, ForwardPassMetrics]] = [] history: list[tuple[float, ForwardPassMetrics]] = []
start_time: float | None = None start_time: float | None = None
...@@ -154,13 +196,43 @@ async def run(args: argparse.Namespace) -> None: ...@@ -154,13 +196,43 @@ async def run(args: argparse.Namespace) -> None:
metrics.counter_id, metrics.counter_id,
json.dumps(pretty, indent=2), json.dumps(pretty, indent=2),
) )
except KeyboardInterrupt:
logger.info("Stopped.")
finally: finally:
subscriber.shutdown()
if args.save_plot and history: if args.save_plot and history:
_save_plot(args.save_plot, history) _save_plot(args.save_plot, history)
async def _run_tracking(subscriber, args: argparse.Namespace) -> None:
    """Periodically log the most recent FPM snapshot of every tracked engine.

    Tracking is switched on once via ``subscriber.start_tracking()``; after
    that the coroutine loops forever: sleep ``args.poll_interval`` seconds,
    read ``subscriber.get_recent_stats()``, decode each payload and log the
    result as indented JSON keyed by ``"<worker_id>:dp<dp_rank>"``.

    Args:
        subscriber: Object exposing ``start_tracking()`` and
            ``get_recent_stats()``; the latter is iterated as a mapping from
            ``(worker_id, dp_rank)`` to an encoded payload.
        args: Parsed CLI arguments; only ``poll_interval`` is read here.
    """
    encoder = msgspec.json.Encoder()
    subscriber.start_tracking()
    logger.info("Tracking mode started (poll every %.1fs)", args.poll_interval)

    poll_index = 0
    while True:
        await asyncio.sleep(args.poll_interval)
        recent = subscriber.get_recent_stats()
        if recent:
            latest = {}
            for engine_key, payload in recent.items():
                worker_id, dp_rank = engine_key
                decoded = decode(payload)
                # Payloads that fail to decode are left out of the snapshot.
                if decoded is not None:
                    latest[f"{worker_id}:dp{dp_rank}"] = json.loads(
                        encoder.encode(decoded)
                    )
            logger.info(
                "[poll=%d t=%s engines=%d] %s",
                poll_index,
                time.strftime("%H:%M:%S"),
                len(recent),
                json.dumps(latest, indent=2),
            )
        else:
            logger.info("[poll=%d] (no engines tracked)", poll_index)
        poll_index += 1
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -27,15 +27,15 @@ The SLA Planner supports two scaling modes that can be used independently or tog ...@@ -27,15 +27,15 @@ The SLA Planner supports two scaling modes that can be used independently or tog
Uses pre-deployment profiling data and traffic prediction to compute the number of prefill/decode replicas needed to meet TTFT and ITL SLA targets. Requires profiling data from the Dynamo profiler. Uses pre-deployment profiling data and traffic prediction to compute the number of prefill/decode replicas needed to meet TTFT and ITL SLA targets. Requires profiling data from the Dynamo profiler.
### Load-Based Scaling (Experimental) ### Load-Based Scaling
Uses real-time per-worker load metrics (active prefill tokens, active KV blocks) from the router to make SLA-aware scaling decisions via online linear regression. Does not require profiling data. Responds quickly to traffic bursts. Uses ForwardPassMetrics (FPM) from the Dynamo event plane to make SLA-aware scaling decisions via online linear regression. Does not require profiling data or the KV Router. Responds quickly to traffic bursts. Currently supported only with vLLM, because FPM is only available in vLLM.
When both modes are enabled, throughput-based scaling provides a lower bound on replicas while load-based scaling handles real-time adjustments. When both modes are enabled, throughput-based scaling provides a lower bound on replicas while load-based scaling handles real-time adjustments.
### Support Matrix ### Support Matrix
| Deployment Type | Throughput-Based | Load-Based (Experimental) | | Deployment Type | Throughput-Based | Load-Based |
|-----------------|:----------------:|:-------------------------:| |-----------------|:----------------:|:-------------------------:|
| Disaggregated | Supported | Supported | | Disaggregated | Supported | Supported |
| Aggregated | Unsupported | Supported | | Aggregated | Unsupported | Supported |
......
...@@ -9,6 +9,7 @@ __all__ = [ ...@@ -9,6 +9,7 @@ __all__ = [
"SLAPlannerDefaults", "SLAPlannerDefaults",
"TargetReplica", "TargetReplica",
"SubComponentType", "SubComponentType",
"WorkerInfo",
] ]
# Import the classes # Import the classes
from dynamo.planner.defaults import SLAPlannerDefaults, SubComponentType from dynamo.planner.defaults import SLAPlannerDefaults, SubComponentType
...@@ -16,6 +17,7 @@ from dynamo.planner.global_planner_connector import GlobalPlannerConnector ...@@ -16,6 +17,7 @@ from dynamo.planner.global_planner_connector import GlobalPlannerConnector
from dynamo.planner.kubernetes_connector import KubernetesConnector, TargetReplica from dynamo.planner.kubernetes_connector import KubernetesConnector, TargetReplica
from dynamo.planner.planner_connector import PlannerConnector from dynamo.planner.planner_connector import PlannerConnector
from dynamo.planner.virtual_connector import VirtualConnector from dynamo.planner.virtual_connector import VirtualConnector
from dynamo.planner.worker_info import WorkerInfo
try: try:
from ._version import __version__ from ._version import __version__
......
...@@ -29,10 +29,6 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker ...@@ -29,10 +29,6 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# start planner 30 seconds after the other components to make sure planner can see them
# TODO: remove this delay
INIT_PLANNER_START_DELAY = 30
class RequestType(BaseModel): class RequestType(BaseModel):
text: str text: str
...@@ -51,21 +47,29 @@ async def start_planner(runtime: DistributedRuntime, config: PlannerConfig): ...@@ -51,21 +47,29 @@ async def start_planner(runtime: DistributedRuntime, config: PlannerConfig):
planner = AggPlanner(runtime, config) planner = AggPlanner(runtime, config)
else: else:
raise ValueError(f"Invalid planner mode: {mode}") raise ValueError(f"Invalid planner mode: {mode}")
await planner._async_init()
await planner.run()
await planner._async_init()
async def init_planner(runtime: DistributedRuntime, config: PlannerConfig):
await asyncio.sleep(INIT_PLANNER_START_DELAY)
await start_planner(runtime, config)
async def generate(request: RequestType): async def generate(request: RequestType):
"""Dummy endpoint to satisfy that each component has an endpoint"""
yield "mock endpoint" yield "mock endpoint"
generate_endpoint = runtime.endpoint(f"{config.namespace}.Planner.generate") generate_endpoint = runtime.endpoint(f"{config.namespace}.Planner.generate")
await generate_endpoint.serve_endpoint(generate)
# serve_endpoint registers a health check target and sets HealthStatus::Ready
# once the handler is registered with the transport. Running it concurrently
# with planner.run() ensures the system server (/health, /live) reports the
# planner as ready only after _async_init() has completed.
await asyncio.gather(
generate_endpoint.serve_endpoint(
generate, # type: ignore[arg-type]
health_check_payload={"text": "health"},
),
planner.run(),
)
async def init_planner(runtime: DistributedRuntime, config: PlannerConfig):
await start_planner(runtime, config)
def _parse_config() -> PlannerConfig: def _parse_config() -> PlannerConfig:
......
...@@ -80,9 +80,6 @@ class SLAPlannerDefaults(BasePlannerDefaults): ...@@ -80,9 +80,6 @@ class SLAPlannerDefaults(BasePlannerDefaults):
enable_load_scaling = False enable_load_scaling = False
# Load-based scaling settings # Load-based scaling settings
load_router_metrics_url: Optional[
str
] = None # will be auto-discovered from the DGD in kubernetes mode if not provided
load_adjustment_interval = 5 # in seconds, must be < throughput_adjustment_interval load_adjustment_interval = 5 # in seconds, must be < throughput_adjustment_interval
load_learning_window = 50 # sliding window size for regression load_learning_window = 50 # sliding window size for regression
load_scaling_down_sensitivity = 80 # 0-100 load_scaling_down_sensitivity = 80 # 0-100
......
...@@ -196,7 +196,7 @@ class GlobalPlannerConnector(PlannerConnector): ...@@ -196,7 +196,7 @@ class GlobalPlannerConnector(PlannerConnector):
"(GlobalPlanner will validate on its side)" "(GlobalPlanner will validate on its side)"
) )
async def wait_for_deployment_ready(self): async def wait_for_deployment_ready(self, include_planner: bool = True):
""" """
Wait for deployment to be ready (no-op for GlobalPlanner). Wait for deployment to be ready (no-op for GlobalPlanner).
......
...@@ -204,31 +204,60 @@ class KubernetesAPI: ...@@ -204,31 +204,60 @@ class KubernetesAPI:
async def wait_for_graph_deployment_ready( async def wait_for_graph_deployment_ready(
self, self,
graph_deployment_name: str, graph_deployment_name: str,
include_planner: bool = True,
max_attempts: int = 180, # default: 30 minutes total max_attempts: int = 180, # default: 30 minutes total
delay_seconds: int = 10, # default: check every 10 seconds delay_seconds: int = 10, # default: check every 10 seconds
) -> None: ) -> None:
"""Wait for a graph deployment to be ready""" """Wait for a graph deployment to be ready.
Args:
graph_deployment_name: Name of the DGD to wait for.
include_planner: If False, skip services with componentType "planner"
and check per-service readiness instead of the global DGD Ready
condition. This avoids a circular wait when the planner itself
is one of the services in the DGD.
max_attempts: Maximum polling iterations.
delay_seconds: Seconds between polls.
"""
for attempt in range(max_attempts): for attempt in range(max_attempts):
await asyncio.sleep(delay_seconds) await asyncio.sleep(delay_seconds)
graph_deployment = self.get_graph_deployment(graph_deployment_name) graph_deployment = self.get_graph_deployment(graph_deployment_name)
conditions = graph_deployment.get("status", {}).get("conditions", []) if include_planner:
ready_condition = next( conditions = graph_deployment.get("status", {}).get("conditions", [])
(c for c in conditions if c.get("type") == "Ready"), None ready_condition = next(
) (c for c in conditions if c.get("type") == "Ready"), None
)
if ready_condition and ready_condition.get("status") == "True":
return
if ready_condition and ready_condition.get("status") == "True": logger.info(
return # Deployment is ready f"[Attempt {attempt + 1}/{max_attempts}] "
f"(status: {ready_condition.get('status') if ready_condition else 'N/A'}, "
f"message: {ready_condition.get('message') if ready_condition else 'no condition found'})"
)
else:
services = graph_deployment.get("spec", {}).get("services", {})
not_ready: list[str] = []
for svc_name, svc_spec in services.items():
if svc_spec.get("componentType", "") == "planner":
continue
_, is_stable = self.get_service_replica_status(
graph_deployment, svc_name
)
if not is_stable:
not_ready.append(svc_name)
if not not_ready:
return
logger.info( logger.info(
f"[Attempt {attempt + 1}/{max_attempts}] " f"[Attempt {attempt + 1}/{max_attempts}] "
f"(status: {ready_condition.get('status') if ready_condition else 'N/A'}, " f"Waiting for services (excluding planner): "
f"message: {ready_condition.get('message') if ready_condition else 'no condition found'})" f"not ready: {not_ready}"
) )
# Raise after all attempts exhausted (without additional delay)
raise TimeoutError( raise TimeoutError(
f"Graph deployment '{graph_deployment_name}' " f"Graph deployment '{graph_deployment_name}' "
f"is not ready after {max_attempts * delay_seconds} seconds" f"is not ready after {max_attempts * delay_seconds} seconds"
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging import logging
import os import os
from typing import Optional from typing import Optional
...@@ -33,6 +34,7 @@ from dynamo.planner.utils.exceptions import ( ...@@ -33,6 +34,7 @@ from dynamo.planner.utils.exceptions import (
PlannerError, PlannerError,
UserProvidedModelNameMismatchError, UserProvidedModelNameMismatchError,
) )
from dynamo.planner.worker_info import WorkerInfo, build_worker_info_from_defaults
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging() configure_dynamo_logging()
...@@ -307,11 +309,275 @@ class KubernetesConnector(PlannerConnector): ...@@ -307,11 +309,275 @@ class KubernetesConnector(PlannerConnector):
return None return None
async def wait_for_deployment_ready(self): async def wait_for_deployment_ready(self, include_planner: bool = True):
"""Wait for the deployment to be ready""" """Wait for the deployment to be ready.
Args:
include_planner: If False, skip the planner service when checking
readiness. This lets the planner read MDC from worker pods
without waiting for itself to be marked ready in the DGD.
"""
await self.kube_api.wait_for_graph_deployment_ready( await self.kube_api.wait_for_graph_deployment_ready(
self.graph_deployment_name, self.graph_deployment_name,
include_planner=include_planner,
)
def _list_worker_metadata_crs(self) -> list[dict]:
    """Return every DynamoWorkerMetadata CR found in the current namespace.

    A 404 from the API server is treated as "CRD not installed yet" and
    yields an empty list. Any other ``ApiException`` (RBAC, connectivity,
    ...) propagates unchanged so the caller can deal with it explicitly.
    """
    from kubernetes.client import ApiException

    try:
        response = self.kube_api.custom_api.list_namespaced_custom_object(
            group="nvidia.com",
            version="v1alpha1",
            namespace=self.kube_api.current_namespace,
            plural="dynamoworkermetadatas",
        )
    except ApiException as exc:
        if exc.status != 404:
            raise
        logger.info("DynamoWorkerMetadata CRD not found, skipping MDC")
        return []
    return response.get("items", [])
def _extract_mdc_entries(
self,
) -> list[dict]:
"""Extract MDC entries belonging to this DGD.
CRs are named after the worker pod (e.g. ``<dgd>-0-<service>-<hash>``),
so we filter by the DGD name prefix to avoid picking up entries from
other deployments sharing the namespace.
Returns a list of dicts, each containing:
namespace, component, endpoint, instance_id, card_json
"""
crs = self._list_worker_metadata_crs()
dgd_prefix = f"{self.graph_deployment_name}-"
entries: list[dict] = []
for cr in crs:
cr_name = cr.get("metadata", {}).get("name", "")
if not cr_name.startswith(dgd_prefix):
continue
data = cr.get("spec", {}).get("data", {})
if isinstance(data, str):
try:
data = json.loads(data)
except json.JSONDecodeError:
continue
model_cards = data.get("model_cards", {})
for _key, instance in model_cards.items():
if instance.get("type") != "Model":
continue
entries.append(instance)
return entries
def _mdc_entry_is_prefill(self, entry: dict) -> bool:
"""Check if an MDC entry is a prefill worker.
model_type can be serialized as:
- An integer bitflag (ModelType::Prefill = 1 << 4 = 16)
- A dict with a "bits" key (serde bitflags format)
- A string like "Prefill" or "Chat|Completions"
"""
card = entry.get("card_json", {})
model_type = card.get("model_type", 0)
if isinstance(model_type, str):
return "prefill" in model_type.lower()
if isinstance(model_type, dict):
model_type = model_type.get("bits", 0)
return bool(model_type & 0x10)
def _build_worker_info_from_mdc(
    self,
    entry: dict,
    sub_component_type: SubComponentType,
) -> WorkerInfo:
    """Build a WorkerInfo from an MDC entry, applying the fallback chain.

    Priority: MDC -> DGD container-arg parsing -> hard-coded defaults.

    Args:
        entry: One MDC entry; ``component``, ``endpoint`` and ``card_json``
            are read from it.
        sub_component_type: Worker role, used to pick the defaults and to
            locate the matching service in the DGD.

    Returns:
        A WorkerInfo whose component/endpoint fall back to the backend
        defaults, whose model name falls back to the DGD container args, and
        whose ``k8s_name`` is the DGD service name (or the default when the
        service cannot be resolved). Runtime-config fields are passed through
        as found in the MDC and may be ``None``.
    """
    role = sub_component_type.value

    # _backend_hint is assigned by get_worker_info(); read it defensively so
    # a direct call to this helper does not fail with AttributeError.
    backend = getattr(self, "_backend_hint", None) or "vllm"
    defaults = build_worker_info_from_defaults(backend, sub_component_type)

    card = entry.get("card_json") or {}
    runtime_cfg = card.get("runtime_config") or {}

    # Resolve the DGD service once; both the model-name fallback and the
    # k8s_name lookup below use it. A failure is reported in the k8s_name
    # section further down.
    service = None
    try:
        deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)
        service = get_service_from_sub_component_type_or_name(
            deployment, sub_component_type
        )
    except PlannerError:
        pass

    # --- component / endpoint names from MDC wrapper ---
    mdc_component = entry.get("component")
    mdc_endpoint = entry.get("endpoint")
    component_name = mdc_component or defaults.component_name
    endpoint = mdc_endpoint or defaults.endpoint
    if not mdc_component:
        logger.info(
            f"MDC missing 'component' for {role}, "
            f"falling back to default: {defaults.component_name}"
        )
    if not mdc_endpoint:
        logger.info(
            f"MDC missing 'endpoint' for {role}, "
            f"falling back to default: {defaults.endpoint}"
        )

    # --- model name ---
    model_name = card.get("display_name")
    if not model_name and service is not None:
        # Fallback: parse from DGD container args
        try:
            model_name = service.get_model_name()
        except PlannerError:
            pass
        if model_name:
            logger.info(
                f"MDC missing model name for {role}, "
                f"fell back to DGD container args: {model_name}"
            )
    if not model_name:
        logger.warning(
            f"Could not determine model name for {role} "
            f"from MDC or DGD container args"
        )

    # --- runtime config fields ---
    total_kv_blocks = runtime_cfg.get("total_kv_blocks")
    max_num_seqs = runtime_cfg.get("max_num_seqs")
    max_num_batched_tokens = runtime_cfg.get("max_num_batched_tokens")
    kv_cache_block_size = card.get("kv_cache_block_size")
    context_length = card.get("context_length")
    if total_kv_blocks is None:
        logger.info(f"MDC missing total_kv_blocks for {role}")
    if max_num_seqs is None:
        logger.info(f"MDC missing max_num_seqs for {role}")

    # --- k8s_name: resolve from DGD subComponentType ---
    if service is not None:
        k8s_name = service.name
    else:
        k8s_name = defaults.k8s_name
        logger.info(
            f"Could not resolve k8s service name for {role}, "
            f"using default: {defaults.k8s_name}"
        )

    return WorkerInfo(
        k8s_name=k8s_name,
        component_name=component_name,
        endpoint=endpoint,
        model_name=model_name,
        total_kv_blocks=total_kv_blocks,
        kv_cache_block_size=kv_cache_block_size,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        context_length=context_length,
    )
def get_worker_info(
    self,
    sub_component_type: SubComponentType,
    backend: str = "vllm",
) -> WorkerInfo:
    """Return the WorkerInfo for one worker role, preferring MDC data.

    The first MDC entry whose prefill/decode role and component match is
    turned into a WorkerInfo. When none matches, the backend defaults are
    used, enriched with the service name and model name from the DGD when
    those can be resolved.

    Args:
        sub_component_type: PREFILL or DECODE
        backend: Backend framework name (for default fallback)
    """
    # Remembered on the instance for _build_worker_info_from_mdc().
    self._backend_hint = backend
    mdc_entries = self._extract_mdc_entries()
    role = sub_component_type.value

    # Work out which component the entry must belong to, so LoRA-adapter
    # cards sharing the same CR but carrying a different component/model
    # identity are not selected.
    # NOTE(review): the DGD branch compares against the k8s service name
    # while the fallback uses the default component_name -- confirm both
    # agree with the MDC "component" field.
    expected_component: Optional[str] = None
    try:
        dgd = self.kube_api.get_graph_deployment(self.graph_deployment_name)
        dgd_service = get_service_from_sub_component_type_or_name(
            dgd, sub_component_type
        )
        expected_component = dgd_service.name
    except PlannerError:
        fallback_defaults = build_worker_info_from_defaults(
            backend, sub_component_type
        )
        expected_component = fallback_defaults.component_name

    wants_prefill = sub_component_type == SubComponentType.PREFILL
    wants_decode = sub_component_type == SubComponentType.DECODE

    for candidate in mdc_entries:
        candidate_is_prefill = self._mdc_entry_is_prefill(candidate)
        role_mismatch = (wants_prefill and not candidate_is_prefill) or (
            wants_decode and candidate_is_prefill
        )
        if role_mismatch:
            continue

        entry_component = candidate.get("component")
        if (
            entry_component
            and expected_component
            and entry_component != expected_component
        ):
            logger.debug(
                f"Skipping MDC entry with component={entry_component!r}, "
                f"expected {expected_component!r} for {role}"
            )
            continue

        mdc_info = self._build_worker_info_from_mdc(candidate, sub_component_type)
        logger.info(f"Built {role} WorkerInfo from MDC: " f"{mdc_info.summary()}")
        return mdc_info

    # No MDC entry found -- fall back entirely to defaults + DGD arg parsing
    logger.warning(
        f"No DynamoWorkerMetadata CR found for {role}. "
        f"Workers may not be registered yet. Falling back to defaults."
    )
    info = build_worker_info_from_defaults(backend, sub_component_type)

    # Try to enrich the defaults with the service name and model name
    # taken from the DGD.
    try:
        dgd = self.kube_api.get_graph_deployment(self.graph_deployment_name)
        dgd_service = get_service_from_sub_component_type_or_name(
            dgd, sub_component_type
        )
        info.k8s_name = dgd_service.name
        arg_model = dgd_service.get_model_name()
        if arg_model:
            info.model_name = arg_model
            logger.info(
                f"Enriched {role} WorkerInfo model name "
                f"from DGD args: {arg_model}"
            )
    except PlannerError as e:
        logger.info(f"Could not enrich WorkerInfo from DGD for {role}: {e}")

    logger.info(f"Using fallback WorkerInfo for {role}: {info.summary()}")
    return info
def get_actual_worker_counts( def get_actual_worker_counts(
self, self,
......
...@@ -3,11 +3,14 @@ ...@@ -3,11 +3,14 @@
import asyncio import asyncio
import logging import logging
from typing import Optional from typing import TYPE_CHECKING, Optional
from dynamo.planner import SubComponentType, TargetReplica from dynamo.planner import SubComponentType, TargetReplica
from dynamo.planner.utils.load_based_regression import LoadBasedRegressionModel from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
from dynamo.planner.utils.planner_config import PlannerConfig from dynamo.planner.utils.planner_config import PlannerConfig
if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.planner.utils.planner_core import ( from dynamo.planner.utils.planner_core import (
BasePlanner, BasePlanner,
PlannerPrometheusMetrics, PlannerPrometheusMetrics,
...@@ -15,7 +18,6 @@ from dynamo.planner.utils.planner_core import ( ...@@ -15,7 +18,6 @@ from dynamo.planner.utils.planner_core import (
_apply_component_gpu_budget, _apply_component_gpu_budget,
_initialize_gpu_counts, _initialize_gpu_counts,
) )
from dynamo.planner.utils.prometheus import CachedLoadMetrics
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -24,24 +26,24 @@ logger = logging.getLogger(__name__) ...@@ -24,24 +26,24 @@ logger = logging.getLogger(__name__)
class AggPlanner: class AggPlanner:
"""Aggregated planner: load-based scaling only, single engine type. """Aggregated planner: FPM-driven load-based scaling, single engine type.
In aggregated mode, engines handle both prefill and decode (chunked prefill). In aggregated mode, engines handle both prefill and decode (chunked prefill).
Engine metrics are labeled "decode" by the router. A single AggRegressionModel maps (sum_prefill_tokens, sum_decode_kv_tokens)
to wall_time using 2D linear regression.
Scaling logic: Scaling logic:
- TTFT and ITL regression models are both maintained. - Estimate next TTFT per engine by simulating prefill chunking with
- Regression uses per-worker time-averaged metrics (not latest snapshot) piggybacked decode (steady-state decode load).
because chunked prefill adds noise to instantaneous TTFT/ITL. - Estimate next ITL per engine by predicting decode iteration time with
- Scale up if either prefill or decode target is exceeded. average piggybacked prefill load.
- Scale down if both prefill and decode are below their boundaries. - Scale up if (ALL TTFT > SLA) OR (ALL ITL > SLA).
- Scale down if (ALL TTFT < SLA * sensitivity) AND (ALL ITL < SLA * sensitivity).
""" """
# Engine metrics from agg workers are labeled "decode" by the router
ENGINE_WORKER_TYPE = "decode"
def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None: def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None:
self.config = config self.config = config
self.runtime = runtime
self.shared_state = PlannerSharedState() self.shared_state = PlannerSharedState()
if config.enable_throughput_scaling: if config.enable_throughput_scaling:
...@@ -56,8 +58,6 @@ class AggPlanner: ...@@ -56,8 +58,6 @@ class AggPlanner:
prometheus_metrics = PlannerPrometheusMetrics() prometheus_metrics = PlannerPrometheusMetrics()
# Use a single BasePlanner instance for infra (connector, prometheus, etc.)
# We use DECODE component_type because engine metrics are labeled "decode"
self.planner = BasePlanner( self.planner = BasePlanner(
runtime, runtime,
config, config,
...@@ -67,28 +67,27 @@ class AggPlanner: ...@@ -67,28 +67,27 @@ class AggPlanner:
component_type=SubComponentType.DECODE, component_type=SubComponentType.DECODE,
) )
# Create both regression models (agg needs both TTFT and ITL) from dynamo.planner.utils.fpm_regression import AggRegressionModel
self.ttft_regression = LoadBasedRegressionModel(
window_size=config.load_learning_window, self.regression = AggRegressionModel(
min_observations=config.load_min_observations,
)
self.itl_regression = LoadBasedRegressionModel(
window_size=config.load_learning_window, window_size=config.load_learning_window,
min_observations=config.load_min_observations, min_observations=config.load_min_observations,
) )
self.cached_load_metrics = CachedLoadMetrics()
async def _async_init(self): async def _async_init(self):
await self.planner._async_init() defaults = WORKER_COMPONENT_NAMES.get(self.config.backend)
async def run(self):
if not self.config.no_operation: if not self.config.no_operation:
connector = getattr(self.planner, "connector", None)
if connector and hasattr(connector, "_async_init"):
await connector._async_init()
logger.info("Validating deployment...") logger.info("Validating deployment...")
# Agg mode: only decode component exists (engines serve both P and D)
await self.planner.connector.validate_deployment( await self.planner.connector.validate_deployment(
prefill_component_name=None, prefill_component_name=None,
decode_component_name=self.planner.decode_component_name, decode_component_name=(
defaults.decode_worker_k8s_name if defaults else None
),
require_prefill=False, require_prefill=False,
require_decode=True, require_decode=True,
) )
...@@ -101,208 +100,102 @@ class AggPlanner: ...@@ -101,208 +100,102 @@ class AggPlanner:
require_decode=True, require_decode=True,
) )
await self.planner.connector.wait_for_deployment_ready() await self.planner.connector.wait_for_deployment_ready(
include_planner=False
# Model name discovery runs in all modes (needed for metrics collection)
if not self.config.no_operation:
model_name = await self.planner._get_model_name(
require_prefill=False, require_decode=True
)
logger.info(f"Detected model name from deployment: {model_name}")
self.planner.model_name = model_name.lower()
else:
if not self.config.model_name:
raise ValueError(
"Model name is required in no-operation mode. "
"Please set model_name in the config."
)
self.planner.model_name = self.config.model_name.lower()
loops = [
self._load_loop(),
self.planner.prometheus_engine_client.run_sampling_loop(
self.config.load_metric_samples,
self.config.load_adjustment_interval,
),
]
await asyncio.gather(*loops)
async def _observe_engine_load_stats(self) -> None:
"""Fetch metrics and update regression models using per-worker time-averaged data."""
result = self.planner.prometheus_engine_client.get_recent_and_averaged_metrics(
self.ENGINE_WORKER_TYPE
)
if result is None:
logger.warning(
f"No per-worker metrics available yet for {self.ENGINE_WORKER_TYPE} (buffer empty)"
)
return
recent, per_worker_averaged, cluster_averaged = result
self.cached_load_metrics = CachedLoadMetrics(
recent=recent,
per_worker_averaged=per_worker_averaged,
cluster_averaged=cluster_averaged,
)
# Agg uses per-worker time-averaged metrics for regression
# because chunked prefill adds noise to instantaneous TTFT/ITL
for wid, m in per_worker_averaged.items():
# TTFT regression: (active_prefill_tokens + ISL) -> TTFT
active_prefill = m.get("active_prefill_tokens", 0.0)
last_isl = m.get("last_isl", 0.0)
last_ttft = m.get("last_ttft", 0.0)
if last_ttft > 0 and last_isl > 0:
x = active_prefill + last_isl
y = last_ttft * 1000 # seconds -> ms
logger.info(
f"Agg Worker {wid} prefill observation: TTFT {y:.2f}ms @ tokens {x:.2f}"
)
self.ttft_regression.add_observation(x, y)
# ITL regression: active_decode_blocks -> ITL
active_decode = m.get("active_decode_blocks", 0.0)
last_itl = m.get("last_itl", 0.0)
if last_itl > 0 and active_decode > 0:
x = active_decode
y = last_itl * 1000 # seconds -> ms
logger.info(
f"Agg Worker {wid} decode observation: ITL {y:.2f}ms @ blocks {x:.2f}"
)
self.itl_regression.add_observation(x, y)
def _prefill_scaling_decision(self, num_workers: int) -> Optional[str]:
"""Returns "up", "down", or None for prefill dimension."""
if not self.cached_load_metrics.recent:
return None
if not self.ttft_regression.has_sufficient_data():
logger.info(
f"TTFT regression: insufficient data ({self.ttft_regression.num_observations}"
f"/{self.ttft_regression.min_observations}), skipping"
)
return None
x_sla = self.ttft_regression.predict_x_from_sla(self.config.ttft)
if x_sla is None:
return None
recent = self.cached_load_metrics.recent
cluster_averaged = self.cached_load_metrics.cluster_averaged
avg_isl = cluster_averaged.get("last_isl", 0.0)
target = x_sla - avg_isl
if target <= 0:
logger.warning(
f"Agg TTFT SLA unachievable at current ISL: x_sla={x_sla:.1f}, "
f"avg_isl={avg_isl:.1f}, skipping prefill scaling decision"
)
return None
logger.info(
f"Agg prefill: x_sla={x_sla:.1f}, avg_isl={avg_isl:.1f}, "
f"target_active_tokens={target:.1f}, workers={num_workers}"
)
# Scale up: ALL workers above target
if all(m.get("active_prefill_tokens", 0.0) > target for m in recent.values()):
return "up"
# Scale down: ALL workers below boundary
if num_workers > self.config.min_endpoint:
sensitivity = self.config.load_scaling_down_sensitivity / 100.0
boundary = target * (num_workers - 1) / num_workers * sensitivity
if all(
m.get("active_prefill_tokens", 0.0) < boundary for m in recent.values()
):
return "down"
return None
def _decode_scaling_decision(self, num_workers: int) -> Optional[str]:
"""Returns "up", "down", or None for decode dimension."""
if not self.cached_load_metrics.recent:
return None
if not self.itl_regression.has_sufficient_data():
logger.info(
f"ITL regression: insufficient data ({self.itl_regression.num_observations}"
f"/{self.itl_regression.min_observations}), skipping"
)
return None
x_sla = self.itl_regression.predict_x_from_sla(self.config.itl)
if x_sla is None:
return None
if x_sla <= 0:
logger.warning(
f"Agg ITL SLA unachievable: x_sla={x_sla:.1f}, "
"skipping decode scaling decision"
) )
return None
recent = self.cached_load_metrics.recent await self.planner._init_worker_info(require_prefill=False, require_decode=True)
logger.info(f"Agg decode: x_sla={x_sla:.1f}, workers={num_workers}") # Delegate FPM tracking to the inner BasePlanner (component_type=DECODE).
if self.runtime is not None:
await self.planner._init_fpm_subscriber()
# Scale up: ALL workers above target async def run(self):
if all(m.get("active_decode_blocks", 0.0) > x_sla for m in recent.values()): """Main scaling loop. Call _async_init() before this."""
return "up" await asyncio.gather(self._load_loop())
# Scale down: ALL workers below boundary
# TODO: should we strictly enforce all workers below boundary?
# how about user-configurable percentage?
if num_workers > self.config.min_endpoint:
sensitivity = self.config.load_scaling_down_sensitivity / 100.0
boundary = x_sla * (num_workers - 1) / num_workers * sensitivity
if all(
m.get("active_decode_blocks", 0.0) < boundary for m in recent.values()
):
return "down"
return None
async def _load_loop(self) -> None: async def _load_loop(self) -> None:
"""Load-based scaling loop for aggregated mode.""" """FPM-driven load-based scaling loop for aggregated mode."""
pending_desired: Optional[int] = None
while True: while True:
await asyncio.sleep(self.config.load_adjustment_interval) await asyncio.sleep(self.config.load_adjustment_interval)
logger.info("New agg load-based adjustment interval started!") logger.info("New agg load-based adjustment interval started!")
# Query DGD for fresh worker counts
_, num_d, _ = await self.planner.get_workers_info( _, num_d, _ = await self.planner.get_workers_info(
require_prefill=False, require_decode=True require_prefill=False, require_decode=True
) )
self.shared_state.num_d_workers = num_d self.shared_state.num_d_workers = num_d
num_workers = num_d num_workers = num_d
# Observe per-worker metrics # Always observe FPM stats and update regression, even during scaling.
await self._observe_engine_load_stats() fpm_stats = self.planner._get_fpm_stats()
if not fpm_stats:
logger.warning("No FPM data available for agg engines")
continue
# Reconcile worker counts for (wid, dp), fpm in fpm_stats.items():
prom_count = len(self.cached_load_metrics.recent) BasePlanner._log_fpm(wid, dp, fpm, "agg")
if prom_count != num_workers: self.regression.add_observation(fpm)
logger.warning(
f"Worker count mismatch: DGD reports {num_workers}, " # If a previous scaling action is still in progress, skip decisions.
f"router metrics reports {prom_count}. Skipping." if pending_desired is not None:
if num_workers == pending_desired:
logger.info(
f"Scaling to {pending_desired} complete, resuming decisions"
)
pending_desired = None
else:
logger.info(
f"Scaling in progress ({num_workers} -> {pending_desired}), "
"observing only"
)
continue
if not BasePlanner._reconcile_fpm_worker_count(
fpm_stats, num_workers, "agg"
):
continue
if not self.regression.has_sufficient_data():
logger.info(
f"Agg regression: insufficient data "
f"({self.regression.num_observations}/{self.regression.min_observations})"
) )
continue continue
if not self.cached_load_metrics.recent: max_num_batched_tokens = getattr(
self.planner.decode_worker_info, "max_num_batched_tokens", None
)
if not max_num_batched_tokens or max_num_batched_tokens <= 0:
logger.warning(
"max_num_batched_tokens not available from WorkerInfo, "
"skipping agg scaling"
)
continue continue
# Make scaling decisions separately for prefill and decode p_desired = self._prefill_scaling_decision(
p_decision = self._prefill_scaling_decision(num_workers) fpm_stats, num_workers, max_num_batched_tokens
d_decision = self._decode_scaling_decision(num_workers) )
d_desired = self._decode_scaling_decision(fpm_stats, num_workers)
logger.info( logger.info(
f"Agg scaling decisions: prefill={p_decision}, decode={d_decision}" f"Agg scaling decisions: prefill={p_desired}, decode={d_desired} "
f"(current={num_workers})"
) )
# Scale up if EITHER needs scale up # Scale up if EITHER dimension wants more workers.
# Scale down if BOTH need scale down # Scale down only if BOTH dimensions agree on fewer.
if p_decision == "up" or d_decision == "up": if p_desired is not None and p_desired > num_workers:
desired = num_workers + 1 desired = p_desired
elif p_decision == "down" and d_decision == "down": elif d_desired is not None and d_desired > num_workers:
desired = num_workers - 1 desired = d_desired
elif (
p_desired is not None
and p_desired < num_workers
and d_desired is not None
and d_desired < num_workers
):
desired = max(p_desired, d_desired)
else: else:
logger.info("Agg scaling: no scaling needed") logger.info("Agg scaling: no scaling needed")
continue continue
...@@ -322,13 +215,54 @@ class AggPlanner: ...@@ -322,13 +215,54 @@ class AggPlanner:
self.planner.prometheus_metrics.predicted_num_d.set(desired) self.planner.prometheus_metrics.predicted_num_d.set(desired)
if not self.config.no_operation: if not self.config.no_operation:
pending_desired = desired
target_replicas = [ target_replicas = [
TargetReplica( TargetReplica(
sub_component_type=SubComponentType.DECODE, sub_component_type=SubComponentType.DECODE,
component_name=self.planner.decode_component_name, component_name=self.planner.decode_worker_info.k8s_name,
desired_replicas=desired, desired_replicas=desired,
) )
] ]
await self.planner.connector.set_component_replicas( await self.planner.connector.set_component_replicas(
target_replicas, blocking=True target_replicas, blocking=False
) )
def _prefill_scaling_decision(
self,
fpm_stats: "dict[tuple[str, int], ForwardPassMetrics]",
num_workers: int,
max_num_batched_tokens: int,
) -> Optional[int]:
"""Returns desired replica count for the prefill (TTFT) dimension, or None."""
estimated_ttfts: list[float] = []
for (wid, dp), fpm in fpm_stats.items():
est = self.regression.estimate_next_ttft(
queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens,
max_num_batched_tokens=max_num_batched_tokens,
current_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
)
if est is not None:
estimated_ttfts.append(est * 1000)
return self.planner._load_based_scaling_decision_from_estimates(
estimated_ttfts, self.config.ttft, num_workers, "agg TTFT"
)
def _decode_scaling_decision(
self,
fpm_stats: "dict[tuple[str, int], ForwardPassMetrics]",
num_workers: int,
) -> Optional[int]:
"""Returns desired replica count for the decode (ITL) dimension, or None."""
estimated_itls: list[float] = []
for (wid, dp), fpm in fpm_stats.items():
est = self.regression.estimate_next_itl(
scheduled_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
queued_decode_kv=fpm.queued_requests.sum_decode_kv_tokens,
)
if est is not None:
estimated_itls.append(est * 1000)
return self.planner._load_based_scaling_decision_from_estimates(
estimated_itls, self.config.itl, num_workers, "agg ITL"
)
...@@ -17,7 +17,15 @@ class DecodePlanner(BasePlanner): ...@@ -17,7 +17,15 @@ class DecodePlanner(BasePlanner):
component_type = SubComponentType.DECODE component_type = SubComponentType.DECODE
def load_plan_adjustment(self) -> Optional[int]: def load_plan_adjustment(self) -> Optional[int]:
"""Load-based scaling decision for decode. Returns desired_replicas or None.""" """Load-based scaling decision for decode using FPM data.
For each engine, estimates next decode ITL:
- Uses scheduled + queued decode KV tokens + avg decode length
- Predicts wall time via regression
Scale up if ALL engines' estimated ITL > SLA.
Scale down if ALL engines' estimated ITL < SLA * sensitivity.
"""
if not self.itl_regression.has_sufficient_data(): if not self.itl_regression.has_sufficient_data():
logger.info( logger.info(
f"ITL regression: insufficient data ({self.itl_regression.num_observations}" f"ITL regression: insufficient data ({self.itl_regression.num_observations}"
...@@ -25,64 +33,38 @@ class DecodePlanner(BasePlanner): ...@@ -25,64 +33,38 @@ class DecodePlanner(BasePlanner):
) )
return None return None
x_sla = self.itl_regression.predict_x_from_sla(self.config.itl) fpm_stats = self._get_fpm_stats()
if x_sla is None: if not fpm_stats:
return None return None
if x_sla <= 0:
logger.warning(
f"ITL SLA unachievable: x_sla={x_sla:.1f}, "
"skipping load-based decode scaling"
)
return None
if not self.cached_load_metrics.recent:
return None
recent = self.cached_load_metrics.recent
num_workers = self.shared_state.num_d_workers num_workers = self.shared_state.num_d_workers
if num_workers == 0: if num_workers == 0:
return None return None
logger.info( estimated_itls: list[float] = []
f"Load-based decode: x_sla={x_sla:.1f}, workers={num_workers}, " for (wid, dp), fpm in fpm_stats.items():
f"slope={self.itl_regression.slope:.6f}, intercept={self.itl_regression.intercept:.3f}" scheduled_kv = fpm.scheduled_requests.sum_decode_kv_tokens
) queued_kv = fpm.queued_requests.sum_decode_kv_tokens
est = self.itl_regression.estimate_next_itl(
# Scale up: ALL workers above target (use recent metrics) scheduled_decode_kv=scheduled_kv,
all_above = all( queued_decode_kv=queued_kv,
m.get("active_decode_blocks", 0.0) > x_sla for m in recent.values()
)
if all_above:
logger.info(
f"Load-based decode: ALL workers above target ({x_sla:.1f}), "
f"scaling up to {num_workers + 1}"
) )
return num_workers + 1 if est is None:
continue
# Scale down: ALL workers below boundary (use recent metrics) est_ms = est * 1000
if num_workers > 1: estimated_itls.append(est_ms)
sensitivity = self.config.load_scaling_down_sensitivity / 100.0 logger.info(
boundary = x_sla * (num_workers - 1) / num_workers * sensitivity f"Decode engine {wid}:dp{dp}: estimated ITL {est_ms:.2f}ms "
all_below = all( f"(sched_kv={scheduled_kv}, queued_kv={queued_kv}, "
m.get("active_decode_blocks", 0.0) < boundary for m in recent.values() f"avg_decode_len={self.itl_regression.avg_decode_length:.1f})"
) )
if all_below:
if num_workers - 1 < self.config.min_endpoint:
logger.info(
f"Load-based decode: ALL workers below boundary ({boundary:.1f}), "
f"but cannot scale down below min_endpoint ({self.config.min_endpoint}); "
f"maintaining {num_workers} decode workers"
)
return num_workers
logger.info(
f"Load-based decode: ALL workers below boundary ({boundary:.1f}), "
f"scaling down to {num_workers - 1}"
)
return num_workers - 1
return None return self._load_based_scaling_decision_from_estimates(
estimates=estimated_itls,
sla=self.config.itl,
num_workers=num_workers,
label="decode ITL",
)
def _update_correction_factor(self) -> bool: def _update_correction_factor(self) -> bool:
if self.shared_state.num_d_workers == 0: if self.shared_state.num_d_workers == 0:
......
...@@ -6,9 +6,11 @@ import logging ...@@ -6,9 +6,11 @@ import logging
import time import time
from dynamo.planner import SubComponentType, TargetReplica from dynamo.planner import SubComponentType, TargetReplica
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
from dynamo.planner.utils.decode_planner import DecodePlanner from dynamo.planner.utils.decode_planner import DecodePlanner
from dynamo.planner.utils.planner_config import PlannerConfig from dynamo.planner.utils.planner_config import PlannerConfig
from dynamo.planner.utils.planner_core import ( from dynamo.planner.utils.planner_core import (
BasePlanner,
PlannerPrometheusMetrics, PlannerPrometheusMetrics,
PlannerSharedState, PlannerSharedState,
_apply_global_gpu_budget, _apply_global_gpu_budget,
...@@ -46,29 +48,34 @@ class DisaggPlanner: ...@@ -46,29 +48,34 @@ class DisaggPlanner:
prometheus_traffic_client=getattr( prometheus_traffic_client=getattr(
self.prefill_planner, "prometheus_traffic_client", None self.prefill_planner, "prometheus_traffic_client", None
), ),
prometheus_engine_client=getattr(
self.prefill_planner, "prometheus_engine_client", None
),
connector=getattr(self.prefill_planner, "connector", None), connector=getattr(self.prefill_planner, "connector", None),
start_prometheus_server=False, start_prometheus_server=False,
) )
async def _async_init(self): async def _async_init(self):
# Prefill/Decode share the same connector instance in disagg mode. # DisaggPlanner overrides _async_init to handle both prefill+decode
await self.prefill_planner._async_init() # and share WorkerInfo between the two sub-planners.
defaults = WORKER_COMPONENT_NAMES.get(self.config.backend)
async def run(self):
if not self.config.no_operation: if not self.config.no_operation:
# Connector init (prefill/decode share the same connector)
connector = getattr(self.prefill_planner, "connector", None)
if connector and hasattr(connector, "_async_init"):
await connector._async_init()
logger.info("Validating deployment...") logger.info("Validating deployment...")
await self.prefill_planner.connector.validate_deployment( await self.prefill_planner.connector.validate_deployment(
prefill_component_name=self.prefill_planner.prefill_component_name, prefill_component_name=(
decode_component_name=self.prefill_planner.decode_component_name, defaults.prefill_worker_k8s_name if defaults else None
),
decode_component_name=(
defaults.decode_worker_k8s_name if defaults else None
),
require_prefill=True, require_prefill=True,
require_decode=True, require_decode=True,
) )
logger.info("Successfully validated the deployment") logger.info("Successfully validated the deployment")
# Initialize GPU counts
_initialize_gpu_counts( _initialize_gpu_counts(
self.config, self.config,
self.prefill_planner.connector, self.prefill_planner.connector,
...@@ -76,40 +83,40 @@ class DisaggPlanner: ...@@ -76,40 +83,40 @@ class DisaggPlanner:
require_decode=True, require_decode=True,
) )
await self.prefill_planner.connector.wait_for_deployment_ready() await self.prefill_planner.connector.wait_for_deployment_ready(
include_planner=False
# Model name discovery runs in all modes (needed for metrics collection)
if not self.config.no_operation:
model_name = await self.prefill_planner._get_model_name(
require_prefill=True, require_decode=True
) )
logger.info(f"Detected model name from deployment: {model_name}")
model_name = model_name.lower()
else:
if not self.config.model_name:
raise ValueError(
"Model name is required in no-operation mode. "
"Please set model_name in the config."
)
model_name = self.config.model_name.lower()
self.prefill_planner.model_name = model_name
self.decode_planner.model_name = model_name
await self.prefill_planner._init_worker_info(
require_prefill=True, require_decode=True
)
# Share WorkerInfo and model name with decode planner
self.decode_planner.prefill_worker_info = (
self.prefill_planner.prefill_worker_info
)
self.decode_planner.decode_worker_info = self.prefill_planner.decode_worker_info
self.decode_planner.model_name = self.prefill_planner.model_name
# Start FPM tracking for both planners. DisaggPlanner bypasses each
# sub-planner's _async_init(), so we init subscribers explicitly here.
if self.enable_load:
if self.prefill_planner.runtime is not None:
await self.prefill_planner._init_fpm_subscriber()
if self.decode_planner.runtime is not None:
await self.decode_planner._init_fpm_subscriber()
async def run(self):
"""Main scaling loop. Call _async_init() before this."""
self.shared_state.last_adjustment_time = time.time() self.shared_state.last_adjustment_time = time.time()
self.shared_state.last_load_adjustment_time = time.time() self.shared_state.last_load_adjustment_time = time.time()
# Build list of concurrent loops based on enabled scaling modes # FPM tracking (started in _async_init) replaces the former
# DirectRouterMetricsClient.run_sampling_loop().
loops = [] loops = []
if self.enable_throughput: if self.enable_throughput:
loops.append(self._throughput_loop()) loops.append(self._throughput_loop())
if self.enable_load: if self.enable_load:
loops.append(self._load_loop()) loops.append(self._load_loop())
loops.append(
self.prefill_planner.prometheus_engine_client.run_sampling_loop(
self.config.load_metric_samples,
self.config.load_adjustment_interval,
)
)
await asyncio.gather(*loops) await asyncio.gather(*loops)
...@@ -156,12 +163,12 @@ class DisaggPlanner: ...@@ -156,12 +163,12 @@ class DisaggPlanner:
target_replicas = [ target_replicas = [
TargetReplica( TargetReplica(
sub_component_type=SubComponentType.PREFILL, sub_component_type=SubComponentType.PREFILL,
component_name=self.prefill_planner.prefill_component_name, component_name=self.prefill_planner.prefill_worker_info.k8s_name,
desired_replicas=next_num_p, desired_replicas=next_num_p,
), ),
TargetReplica( TargetReplica(
sub_component_type=SubComponentType.DECODE, sub_component_type=SubComponentType.DECODE,
component_name=self.prefill_planner.decode_component_name, component_name=self.prefill_planner.decode_worker_info.k8s_name,
desired_replicas=next_num_d, desired_replicas=next_num_d,
), ),
] ]
...@@ -172,34 +179,34 @@ class DisaggPlanner: ...@@ -172,34 +179,34 @@ class DisaggPlanner:
await asyncio.sleep(self.config.throughput_adjustment_interval / 10) await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
async def _load_loop(self) -> None: async def _load_loop(self) -> None:
"""Load-based scaling loop for disagg mode at shorter interval.""" """FPM-driven load-based scaling loop for disagg mode."""
while True: while True:
await asyncio.sleep(self.config.load_adjustment_interval) await asyncio.sleep(self.config.load_adjustment_interval)
logger.info("New load-based adjustment interval started!") logger.info("New load-based adjustment interval started!")
# Query DGD for fresh worker counts
num_p, num_d, _ = await self.prefill_planner.get_workers_info( num_p, num_d, _ = await self.prefill_planner.get_workers_info(
require_prefill=True, require_decode=True require_prefill=True, require_decode=True
) )
self.shared_state.num_p_workers = num_p self.shared_state.num_p_workers = num_p
self.shared_state.num_d_workers = num_d self.shared_state.num_d_workers = num_d
# Observe per-worker metrics from router # Observe FPM stats and feed into regression models
await self.prefill_planner.observe_engine_load_stats() p_stats = self.prefill_planner.observe_fpm_load_stats()
await self.decode_planner.observe_engine_load_stats() d_stats = self.decode_planner.observe_fpm_load_stats()
# Reconcile DGD worker counts with router Prometheus counts if not p_stats and not d_stats:
p_prom_count = len(self.prefill_planner.cached_load_metrics.recent) logger.warning("No FPM data for either prefill or decode, skipping")
d_prom_count = len(self.decode_planner.cached_load_metrics.recent) continue
if p_prom_count != num_p or d_prom_count != num_d:
logger.warning( if p_stats and not BasePlanner._reconcile_fpm_worker_count(
f"Worker count mismatch: DGD reports P={num_p}, D={num_d}; " p_stats, num_p, "prefill"
f"router metrics reports P={p_prom_count}, D={d_prom_count}. " ):
"Skipping load-based scaling adjustment." continue
) if d_stats and not BasePlanner._reconcile_fpm_worker_count(
d_stats, num_d, "decode"
):
continue continue
# Scale prefill and decode independently
p_desired = self.prefill_planner.load_plan_adjustment() p_desired = self.prefill_planner.load_plan_adjustment()
d_desired = self.decode_planner.load_plan_adjustment() d_desired = self.decode_planner.load_plan_adjustment()
...@@ -241,12 +248,12 @@ class DisaggPlanner: ...@@ -241,12 +248,12 @@ class DisaggPlanner:
target_replicas = [ target_replicas = [
TargetReplica( TargetReplica(
sub_component_type=SubComponentType.PREFILL, sub_component_type=SubComponentType.PREFILL,
component_name=self.prefill_planner.prefill_component_name, component_name=self.prefill_planner.prefill_worker_info.k8s_name,
desired_replicas=final_p, desired_replicas=final_p,
), ),
TargetReplica( TargetReplica(
sub_component_type=SubComponentType.DECODE, sub_component_type=SubComponentType.DECODE,
component_name=self.prefill_planner.decode_component_name, component_name=self.prefill_planner.decode_worker_info.k8s_name,
desired_replicas=final_d, desired_replicas=final_d,
), ),
] ]
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""FPM-driven regression models for load-based scaling.
Each model takes ForwardPassMetrics observations and estimates per-engine
TTFT or ITL by simulating the scheduler's chunked prefill / decode
iteration pipeline.
- PrefillRegressionModel: 1D regression (sum_prefill_tokens -> wall_time)
- DecodeRegressionModel: 1D regression (sum_decode_kv_tokens -> wall_time)
- AggRegressionModel: 2D regression (sum_prefill_tokens, sum_decode_kv_tokens -> wall_time)
"""
import logging
import math
from collections import deque
from typing import Optional, Union
import numpy as np
from sklearn.linear_model import LinearRegression
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
logger = logging.getLogger(__name__)
class _MovingAverage:
"""Fixed-window moving average that skips leading zeros.
Initial zero values (pre-traffic idle period) are ignored until the
first non-zero value arrives, matching the throughput planner's
load predictor behavior.
"""
__slots__ = ("_window", "_sum", "_seen_nonzero")
def __init__(self, window_size: int):
self._window: deque[float] = deque(maxlen=window_size)
self._sum: float = 0.0
self._seen_nonzero: bool = False
def add(self, value: float) -> None:
if value == 0.0 and not self._seen_nonzero:
return
if value != 0.0:
self._seen_nonzero = True
if len(self._window) == self._window.maxlen:
self._sum -= self._window[0]
self._window.append(value)
self._sum += value
@property
def value(self) -> float:
if not self._window:
return 0.0
return self._sum / len(self._window)
def __len__(self) -> int:
return len(self._window)
class _BaseRegressionModel:
    """Common sliding-window linear-regression plumbing for FPM-based models.

    Subclasses supply ``_extract_x`` (the regression input taken from a
    ForwardPassMetrics snapshot) and ``_update_moving_averages``. This base
    keeps a bounded window of ``(x, wall_time)`` pairs and refits lazily:
    adding an observation only marks the model stale, and the fit happens
    the next time ``_ensure_fitted`` is called.
    """

    def __init__(self, window_size: int, min_observations: int = 5, ndim: int = 1):
        """Set up an empty observation window and an unfitted model.

        Args:
            window_size: maximum number of observations retained.
            min_observations: observations required before fitting.
            ndim: number of regression inputs per observation.
        """
        self.window_size = window_size
        self.min_observations = min_observations
        self._ndim = ndim
        self._observations: deque[tuple[Union[float, list[float]], float]] = deque(
            maxlen=window_size
        )
        self._model = LinearRegression()
        self._is_fitted = False

    def _extract_x(self, fpm: ForwardPassMetrics) -> Union[float, list[float]]:
        """Return the regression input(s) from an FPM snapshot."""
        raise NotImplementedError

    def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
        """Update moving averages (called for every FPM, including idle)."""
        raise NotImplementedError

    def add_observation(self, fpm: ForwardPassMetrics) -> None:
        """Record one FPM snapshot.

        Moving averages are refreshed for every snapshot so idle state is
        reflected; only snapshots with a non-zero ``wall_time`` become
        regression observations.
        """
        self._update_moving_averages(fpm)
        if fpm.wall_time != 0.0:
            self._observations.append((self._extract_x(fpm), fpm.wall_time))
            # New data invalidates the previous fit.
            self._is_fitted = False

    def _fit(self) -> bool:
        """Fit the model on the current window; return False if too few samples."""
        samples = self._observations
        if len(samples) < self.min_observations:
            return False
        inputs, targets = zip(*samples)
        features = np.array(inputs)
        if self._ndim == 1:
            # A single input per sample needs an explicit column axis.
            features = features.reshape(-1, 1)
        self._model.fit(features, np.array(targets))
        self._is_fitted = True
        return True

    def _ensure_fitted(self) -> bool:
        """Return True when a usable fit exists, fitting on demand."""
        if self._is_fitted:
            return True
        return self._fit()

    def has_sufficient_data(self) -> bool:
        """Whether enough observations have been collected to fit."""
        return len(self._observations) >= self.min_observations

    @property
    def num_observations(self) -> int:
        """Number of observations currently in the window."""
        return len(self._observations)
class PrefillRegressionModel(_BaseRegressionModel):
    """Predict per-iteration wall time from scheduled prefill tokens.

    Regression: wall_time = f(sum_prefill_tokens)
    Simulation: estimate TTFT by chunking queued_prefill_tokens + avg_isl
                into max_num_batched_tokens-sized iterations and summing
                the predicted wall time for each.
    """

    def __init__(self, window_size: int, min_observations: int = 5):
        """Create a 1D model plus moving averages sized to the same window."""
        super().__init__(window_size, min_observations, ndim=1)
        self._avg_isl = _MovingAverage(window_size)
        self._avg_num_prefill = _MovingAverage(window_size)

    def _extract_x(self, fpm: ForwardPassMetrics) -> float:
        """Regression input: prefill tokens scheduled in this forward pass."""
        return float(fpm.scheduled_requests.sum_prefill_tokens)

    def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
        """Track tokens per scheduled prefill request and the request count."""
        sched = fpm.scheduled_requests
        if sched.num_prefill_requests > 0:
            self._avg_isl.add(sched.sum_prefill_tokens / sched.num_prefill_requests)
        self._avg_num_prefill.add(float(sched.num_prefill_requests))

    @property
    def avg_isl(self) -> float:
        """Moving average of prefill tokens per scheduled prefill request."""
        return self._avg_isl.value

    def estimate_next_ttft(
        self,
        queued_prefill_tokens: int,
        max_num_batched_tokens: int,
    ) -> Optional[float]:
        """Simulate prefill scheduling to estimate TTFT for the next request.

        The scheduler processes prefill tokens in chunks of
        max_num_batched_tokens per iteration. We sum the regression-predicted
        wall time for each chunk to approximate TTFT.

        Every chunk except the last has exactly max_num_batched_tokens
        tokens and therefore the same prediction, so the model is queried
        once for both distinct chunk sizes instead of once per iteration.

        Args:
            queued_prefill_tokens: tokens already queued ahead of the next request.
            max_num_batched_tokens: per-iteration token budget (from WorkerInfo/MDC).

        Returns:
            Estimated TTFT in seconds, or None if the model is not ready.
        """
        if not self._ensure_fitted() or max_num_batched_tokens <= 0:
            return None

        total_tokens = queued_prefill_tokens + self._avg_isl.value
        if total_tokens <= 0:
            return 0.0

        num_iterations = math.ceil(total_tokens / max_num_batched_tokens)
        full_chunks = num_iterations - 1
        # Whatever is left after the full-size chunks, capped at the budget.
        last_chunk = min(
            total_tokens - full_chunks * max_num_batched_tokens,
            max_num_batched_tokens,
        )

        full_pred, last_pred = self._model.predict(
            np.array([[float(max_num_batched_tokens)], [float(last_chunk)]])
        )
        # Negative predictions are clamped per chunk, as a wall time cannot
        # be negative.
        total_time = max(0.0, float(last_pred))
        if full_chunks > 0:
            total_time += full_chunks * max(0.0, float(full_pred))
        return total_time
class DecodeRegressionModel(_BaseRegressionModel):
    """Model per-iteration wall time as a function of scheduled decode KV tokens.

    Regression: wall_time = f(sum_decode_kv_tokens)
    Estimation: the next decode step's ITL is predicted for the scheduled
                load plus queued (preempted) decode load plus one more
                request of average decode length.
    """

    def __init__(self, window_size: int, min_observations: int = 5):
        """Create a 1D model plus moving averages sized to the same window."""
        super().__init__(window_size, min_observations, ndim=1)
        self._avg_decode_len = _MovingAverage(window_size)
        self._avg_num_decode = _MovingAverage(window_size)

    def _extract_x(self, fpm: ForwardPassMetrics) -> float:
        """Regression input: decode KV tokens scheduled in this forward pass."""
        scheduled = fpm.scheduled_requests
        return float(scheduled.sum_decode_kv_tokens)

    def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
        """Track KV tokens per scheduled decode request and the request count."""
        scheduled = fpm.scheduled_requests
        request_count = scheduled.num_decode_requests
        if request_count > 0:
            self._avg_decode_len.add(scheduled.sum_decode_kv_tokens / request_count)
        self._avg_num_decode.add(float(scheduled.num_decode_requests))

    @property
    def avg_decode_length(self) -> float:
        """Moving average of decode KV tokens per scheduled decode request."""
        return self._avg_decode_len.value

    def estimate_next_itl(
        self,
        scheduled_decode_kv: int,
        queued_decode_kv: int,
    ) -> Optional[float]:
        """Estimate the wall time of the next decode iteration.

        The prediction is made for the combined decode KV load: what is
        scheduled now, what is queued (preempted), and one extra request of
        average decode length.

        Args:
            scheduled_decode_kv: sum_decode_kv_tokens from the latest FPM.
            queued_decode_kv: sum_decode_kv_tokens from the queued metrics.

        Returns:
            Estimated ITL in seconds (never negative), or None if the model
            is not ready.
        """
        if not self._ensure_fitted():
            return None
        projected_kv = (
            scheduled_decode_kv + queued_decode_kv + self._avg_decode_len.value
        )
        features = np.array([[projected_kv]])
        predicted = float(self._model.predict(features)[0])
        return max(0.0, predicted)
class AggRegressionModel(_BaseRegressionModel):
"""2D regression for aggregated (chunked prefill + decode) engines.
Regression: wall_time = f(sum_prefill_tokens, sum_decode_kv_tokens)
Estimation: estimate TTFT by simulating prefill chunking while assuming
steady-state decode load; estimate ITL by predicting decode
iteration time while assuming average piggybacked prefill load.
"""
def __init__(self, window_size: int, min_observations: int = 5):
super().__init__(window_size, min_observations, ndim=2)
self._avg_isl = _MovingAverage(window_size)
self._avg_decode_len = _MovingAverage(window_size)
self._avg_prefill_tokens = _MovingAverage(window_size)
self._avg_num_prefill = _MovingAverage(window_size)
self._avg_num_decode = _MovingAverage(window_size)
def _extract_x(self, fpm: ForwardPassMetrics) -> list[float]:
sched = fpm.scheduled_requests
return [float(sched.sum_prefill_tokens), float(sched.sum_decode_kv_tokens)]
def _update_moving_averages(self, fpm: ForwardPassMetrics) -> None:
sched = fpm.scheduled_requests
if sched.num_prefill_requests > 0:
self._avg_isl.add(sched.sum_prefill_tokens / sched.num_prefill_requests)
if sched.num_decode_requests > 0:
self._avg_decode_len.add(
sched.sum_decode_kv_tokens / sched.num_decode_requests
)
self._avg_prefill_tokens.add(float(sched.sum_prefill_tokens))
self._avg_num_prefill.add(float(sched.num_prefill_requests))
self._avg_num_decode.add(float(sched.num_decode_requests))
@property
def avg_isl(self) -> float:
return self._avg_isl.value
@property
def avg_decode_length(self) -> float:
return self._avg_decode_len.value
@property
def avg_prefill_tokens(self) -> float:
return self._avg_prefill_tokens.value
def _predict_2d(self, prefill_tokens: float, decode_kv_tokens: float) -> float:
return float(
self._model.predict(np.array([[prefill_tokens, decode_kv_tokens]]))[0]
)
def estimate_next_ttft(
self,
queued_prefill_tokens: int,
max_num_batched_tokens: int,
current_decode_kv: int,
) -> Optional[float]:
"""Simulate prefill scheduling with piggybacked decode.
Same chunking simulation as PrefillRegressionModel, but each
iteration also carries the current decode KV load (steady state).
Args:
queued_prefill_tokens: prefill tokens queued ahead of the next request.
max_num_batched_tokens: per-iteration token budget (from MDC).
current_decode_kv: scheduled decode KV tokens from the latest FPM
(assumed steady during prefill).
Returns:
Estimated TTFT in seconds, or None if the model is not ready.
"""
if not self._ensure_fitted() or max_num_batched_tokens <= 0:
return None
total_tokens = queued_prefill_tokens + self._avg_isl.value
if total_tokens <= 0:
return 0.0
num_iterations = math.ceil(total_tokens / max_num_batched_tokens)
total_time = 0.0
remaining = total_tokens
for _ in range(num_iterations):
chunk = min(remaining, max_num_batched_tokens)
total_time += max(0.0, self._predict_2d(chunk, float(current_decode_kv)))
remaining -= chunk
return total_time
def estimate_next_itl(
    self,
    scheduled_decode_kv: int,
    queued_decode_kv: int,
) -> Optional[float]:
    """Estimate the next decode iteration time with piggybacked prefill.

    The decode load is the scheduled plus queued decode KV tokens plus
    ``self._avg_decode_len.value``. The prefill load riding along in the
    same iteration is taken from ``self._avg_prefill_tokens.value``.

    Args:
        scheduled_decode_kv: sum_decode_kv_tokens from the latest FPM.
        queued_decode_kv: sum_decode_kv_tokens from the queued metrics.

    Returns:
        Estimated ITL in seconds, clamped to be non-negative, or ``None``
        if the model is not ready.
    """
    model_ready = self._ensure_fitted()
    if not model_ready:
        return None

    decode_load = scheduled_decode_kv + queued_decode_kv + self._avg_decode_len.value
    prefill_load = self._avg_prefill_tokens.value
    predicted = self._predict_2d(prefill_load, decode_load)
    return max(0.0, predicted)
...@@ -108,7 +108,6 @@ class PlannerConfig(BaseModel): ...@@ -108,7 +108,6 @@ class PlannerConfig(BaseModel):
enable_load_scaling: bool = SLAPlannerDefaults.enable_load_scaling enable_load_scaling: bool = SLAPlannerDefaults.enable_load_scaling
# Load-based scaling settings # Load-based scaling settings
load_router_metrics_url: Optional[str] = SLAPlannerDefaults.load_router_metrics_url
load_adjustment_interval: int = SLAPlannerDefaults.load_adjustment_interval load_adjustment_interval: int = SLAPlannerDefaults.load_adjustment_interval
load_learning_window: int = SLAPlannerDefaults.load_learning_window load_learning_window: int = SLAPlannerDefaults.load_learning_window
load_scaling_down_sensitivity: int = ( load_scaling_down_sensitivity: int = (
...@@ -146,13 +145,6 @@ class PlannerConfig(BaseModel): ...@@ -146,13 +145,6 @@ class PlannerConfig(BaseModel):
) )
if self.enable_load_scaling: if self.enable_load_scaling:
# Router metrics URL is required outside kubernetes mode
if not self.load_router_metrics_url and self.environment != "kubernetes":
raise ValueError(
"load_router_metrics_url is required when "
"load-based scaling is enabled outside kubernetes mode"
)
# Load-based interval must be shorter than throughput interval # Load-based interval must be shorter than throughput interval
if self.enable_throughput_scaling: if self.enable_throughput_scaling:
if self.load_adjustment_interval >= self.throughput_adjustment_interval: if self.load_adjustment_interval >= self.throughput_adjustment_interval:
......
...@@ -6,7 +6,7 @@ import logging ...@@ -6,7 +6,7 @@ import logging
import math import math
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Union from typing import TYPE_CHECKING, Optional, Union
from prometheus_client import Gauge, start_http_server from prometheus_client import Gauge, start_http_server
...@@ -19,6 +19,10 @@ from dynamo.planner import ( ...@@ -19,6 +19,10 @@ from dynamo.planner import (
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
from dynamo.planner.global_planner_connector import GlobalPlannerConnector from dynamo.planner.global_planner_connector import GlobalPlannerConnector
from dynamo.planner.utils.exceptions import DeploymentValidationError from dynamo.planner.utils.exceptions import DeploymentValidationError
if TYPE_CHECKING:
from dynamo.common.forward_pass_metrics import ForwardPassMetrics
from dynamo.llm import FpmEventSubscriber
from dynamo.planner.utils.load_predictor import LOAD_PREDICTORS from dynamo.planner.utils.load_predictor import LOAD_PREDICTORS
from dynamo.planner.utils.perf_interpolation import ( from dynamo.planner.utils.perf_interpolation import (
DecodeInterpolator, DecodeInterpolator,
...@@ -26,13 +30,9 @@ from dynamo.planner.utils.perf_interpolation import ( ...@@ -26,13 +30,9 @@ from dynamo.planner.utils.perf_interpolation import (
) )
from dynamo.planner.utils.planner_config import PlannerConfig from dynamo.planner.utils.planner_config import PlannerConfig
from dynamo.planner.utils.pre_swept_results_utils import PreSweptResultsHelper from dynamo.planner.utils.pre_swept_results_utils import PreSweptResultsHelper
from dynamo.planner.utils.prometheus import ( from dynamo.planner.utils.prometheus import Metrics, PrometheusAPIClient
CachedLoadMetrics,
DirectRouterMetricsClient,
Metrics,
PrometheusAPIClient,
)
from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_mooncake from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_mooncake
from dynamo.planner.worker_info import WorkerInfo, resolve_worker_info
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -257,7 +257,6 @@ class BasePlanner: ...@@ -257,7 +257,6 @@ class BasePlanner:
shared_state: Optional[PlannerSharedState] = None, shared_state: Optional[PlannerSharedState] = None,
prometheus_metrics: Optional[PlannerPrometheusMetrics] = None, prometheus_metrics: Optional[PlannerPrometheusMetrics] = None,
prometheus_traffic_client: Optional[PrometheusAPIClient] = None, prometheus_traffic_client: Optional[PrometheusAPIClient] = None,
prometheus_engine_client: Optional[DirectRouterMetricsClient] = None,
connector: Optional[ConnectorType] = None, connector: Optional[ConnectorType] = None,
start_prometheus_server: bool = True, start_prometheus_server: bool = True,
component_type: Optional[SubComponentType] = None, component_type: Optional[SubComponentType] = None,
...@@ -392,12 +391,10 @@ class BasePlanner: ...@@ -392,12 +391,10 @@ class BasePlanner:
config.profile_results_dir config.profile_results_dir
) )
self.prefill_component_name = WORKER_COMPONENT_NAMES[ # WorkerInfo: finalized by _init_worker_info() at the start of run().
self.config.backend # Empty placeholders until then.
].prefill_worker_k8s_name self.prefill_worker_info = WorkerInfo()
self.decode_component_name = WORKER_COMPONENT_NAMES[ self.decode_worker_info = WorkerInfo()
self.config.backend
].decode_worker_k8s_name
self.prometheus_metrics: PlannerPrometheusMetrics | None = None self.prometheus_metrics: PlannerPrometheusMetrics | None = None
if not self.dryrun: if not self.dryrun:
...@@ -432,47 +429,20 @@ class BasePlanner: ...@@ -432,47 +429,20 @@ class BasePlanner:
self.no_correction = config.no_correction self.no_correction = config.no_correction
if self.enable_load: if self.enable_load:
if prometheus_engine_client is not None: from dynamo.planner.utils.fpm_regression import (
self.prometheus_engine_client = prometheus_engine_client DecodeRegressionModel,
else: PrefillRegressionModel,
# Auto-discover frontend metrics URL in Kubernetes mode
connector = getattr(self, "connector", None)
if not config.load_router_metrics_url and isinstance(
connector, KubernetesConnector
):
config.load_router_metrics_url = (
connector.get_frontend_metrics_url()
)
if not config.load_router_metrics_url:
raise ValueError(
"Could not auto-discover frontend metrics URL from DGD. "
"No service with componentType 'frontend' found. "
"Please set load_router_metrics_url in the config."
)
else:
logger.info(
f"Auto-discovered frontend metrics URL: {config.load_router_metrics_url}"
)
assert (
config.load_router_metrics_url is not None
), "load_router_metrics_url must be set when load-based scaling is enabled"
self.prometheus_engine_client = DirectRouterMetricsClient(
config.load_router_metrics_url, config.namespace
)
self.cached_load_metrics = CachedLoadMetrics()
from dynamo.planner.utils.load_based_regression import (
LoadBasedRegressionModel,
) )
self.fpm_subscriber: "Optional[FpmEventSubscriber]" = None
if self.component_type == SubComponentType.PREFILL: if self.component_type == SubComponentType.PREFILL:
self.ttft_regression = LoadBasedRegressionModel( self.ttft_regression = PrefillRegressionModel(
window_size=self.config.load_learning_window, window_size=self.config.load_learning_window,
min_observations=self.config.load_min_observations, min_observations=self.config.load_min_observations,
) )
elif self.component_type == SubComponentType.DECODE: elif self.component_type == SubComponentType.DECODE:
self.itl_regression = LoadBasedRegressionModel( self.itl_regression = DecodeRegressionModel(
window_size=self.config.load_learning_window, window_size=self.config.load_learning_window,
min_observations=self.config.load_min_observations, min_observations=self.config.load_min_observations,
) )
...@@ -485,8 +455,26 @@ class BasePlanner: ...@@ -485,8 +455,26 @@ class BasePlanner:
def last_metrics(self, value: Metrics) -> None: def last_metrics(self, value: Metrics) -> None:
self.shared_state.last_metrics = value self.shared_state.last_metrics = value
async def _init_worker_info(
    self, require_prefill: bool, require_decode: bool
) -> None:
    """Resolve both WorkerInfo objects and derive ``self.model_name``.

    Delegates to ``resolve_worker_info`` using the configured backend, the
    connector (if this instance has one), the optional ``model_name`` from
    the config, and the ``no_operation`` flag.

    Args:
        require_prefill: forwarded to ``resolve_worker_info``.
        require_decode: forwarded to ``resolve_worker_info``.
    """
    cfg = self.config
    prefill_info, decode_info = resolve_worker_info(
        backend=cfg.backend,
        require_prefill=require_prefill,
        require_decode=require_decode,
        # The connector attribute may not exist on this instance.
        connector=getattr(self, "connector", None),
        config_model_name=getattr(cfg, "model_name", ""),
        no_operation=cfg.no_operation,
    )
    self.prefill_worker_info = prefill_info
    self.decode_worker_info = decode_info

    # Prefer the decode side's model name; fall back to the prefill side.
    self.model_name = decode_info.model_name or prefill_info.model_name
async def _async_init(self): async def _async_init(self):
"""Async initialization for components that need it""" """Async initialization: connector init, deployment validation, WorkerInfo."""
if ( if (
not self.dryrun not self.dryrun
and hasattr(self, "connector") and hasattr(self, "connector")
...@@ -494,13 +482,87 @@ class BasePlanner: ...@@ -494,13 +482,87 @@ class BasePlanner:
): ):
await self.connector._async_init() await self.connector._async_init()
async def _get_model_name(self, require_prefill: bool, require_decode: bool) -> str: require_prefill = self.component_type == SubComponentType.PREFILL
model_name = self.connector.get_model_name( require_decode = self.component_type == SubComponentType.DECODE
require_prefill=require_prefill, require_decode=require_decode
if not self.dryrun and not self.config.no_operation:
defaults = WORKER_COMPONENT_NAMES.get(self.config.backend)
logger.info("Validating deployment...")
await self.connector.validate_deployment(
prefill_component_name=(
defaults.prefill_worker_k8s_name
if require_prefill and defaults
else None
),
decode_component_name=(
defaults.decode_worker_k8s_name
if require_decode and defaults
else None
),
require_prefill=require_prefill,
require_decode=require_decode,
)
logger.info("Successfully validated the deployment")
_initialize_gpu_counts(
self.config,
self.connector,
require_prefill=require_prefill,
require_decode=require_decode,
)
await self.connector.wait_for_deployment_ready(include_planner=False)
await self._init_worker_info(
require_prefill=require_prefill,
require_decode=require_decode,
)
# Start FPM tracking if load-based scaling is enabled.
# The subscriber auto-discovers FPM publishers for this component.
if self.enable_load and self.runtime is not None:
await self._init_fpm_subscriber()
async def _init_fpm_subscriber(self) -> None:
    """Build an ``FpmEventSubscriber`` for this component and start tracking.

    The WorkerInfo matching ``self.component_type`` (prefill, otherwise
    decode) supplies the component and endpoint names. If either is
    missing, a warning is logged and no subscriber is created.
    """
    from dynamo.llm import FpmEventSubscriber

    if self.component_type == SubComponentType.PREFILL:
        info = self.prefill_worker_info
    else:
        info = self.decode_worker_info

    if not (info.component_name and info.endpoint):
        logger.warning(
            "WorkerInfo missing component_name or endpoint, "
            "cannot create FPM subscriber"
        )
        return

    assert self.runtime is not None
    # Endpoint path is "<namespace>.<component>.<endpoint>".
    endpoint_path = f"{self.namespace}.{info.component_name}.{info.endpoint}"
    target = self.runtime.endpoint(endpoint_path)

    # Store the subscriber before starting it, as the original does.
    self.fpm_subscriber = FpmEventSubscriber(target)
    self.fpm_subscriber.start_tracking()
    logger.info(
        f"FPM tracker started for {info.component_name}.{info.endpoint}"
    )
if asyncio.iscoroutine(model_name):
model_name = await model_name def _get_fpm_stats(self) -> "dict[tuple[str, int], ForwardPassMetrics]":
return model_name """Get decoded FPM stats from the subscriber, keyed by (worker_id, dp_rank)."""
from dynamo.common.forward_pass_metrics import decode as decode_fpm
if self.fpm_subscriber is None:
return {}
raw_stats = self.fpm_subscriber.get_recent_stats()
result = {}
for key, raw_bytes in raw_stats.items():
fpm = decode_fpm(raw_bytes)
if fpm is not None:
result[key] = fpm
return result
async def _get_or_create_client(self, component_name: str, endpoint_name: str): async def _get_or_create_client(self, component_name: str, endpoint_name: str):
"""Create a client for the given component and endpoint, with a brief sleep for state sync.""" """Create a client for the given component and endpoint, with a brief sleep for state sync."""
...@@ -535,10 +597,10 @@ class BasePlanner: ...@@ -535,10 +597,10 @@ class BasePlanner:
is_stable, is_stable,
) = self.connector.get_actual_worker_counts( ) = self.connector.get_actual_worker_counts(
prefill_component_name=( prefill_component_name=(
self.prefill_component_name if require_prefill else None self.prefill_worker_info.k8s_name if require_prefill else None
), ),
decode_component_name=( decode_component_name=(
self.decode_component_name if require_decode else None self.decode_worker_info.k8s_name if require_decode else None
), ),
) )
num_p_workers = prefill_count if require_prefill else 0 num_p_workers = prefill_count if require_prefill else 0
...@@ -549,14 +611,14 @@ class BasePlanner: ...@@ -549,14 +611,14 @@ class BasePlanner:
if self.runtime is None: if self.runtime is None:
raise RuntimeError("Runtime is not initialized") raise RuntimeError("Runtime is not initialized")
worker_names = WORKER_COMPONENT_NAMES[self.config.backend]
if require_prefill: if require_prefill:
try: try:
if self.prefill_client is None: if self.prefill_client is None:
assert self.prefill_worker_info.component_name is not None
assert self.prefill_worker_info.endpoint is not None
self.prefill_client = await self._get_or_create_client( self.prefill_client = await self._get_or_create_client(
worker_names.prefill_worker_component_name, self.prefill_worker_info.component_name,
worker_names.prefill_worker_endpoint, self.prefill_worker_info.endpoint,
) )
num_p_workers = len(self.prefill_client.instance_ids()) # type: ignore num_p_workers = len(self.prefill_client.instance_ids()) # type: ignore
except Exception: except Exception:
...@@ -568,9 +630,11 @@ class BasePlanner: ...@@ -568,9 +630,11 @@ class BasePlanner:
if require_decode: if require_decode:
try: try:
if self.workers_client is None: if self.workers_client is None:
assert self.decode_worker_info.component_name is not None
assert self.decode_worker_info.endpoint is not None
self.workers_client = await self._get_or_create_client( self.workers_client = await self._get_or_create_client(
worker_names.decode_worker_component_name, self.decode_worker_info.component_name,
worker_names.decode_worker_endpoint, self.decode_worker_info.endpoint,
) )
num_d_workers = len(self.workers_client.instance_ids()) # type: ignore num_d_workers = len(self.workers_client.instance_ids()) # type: ignore
except Exception as e: except Exception as e:
...@@ -757,8 +821,10 @@ class BasePlanner: ...@@ -757,8 +821,10 @@ class BasePlanner:
def _component_name(self) -> str: def _component_name(self) -> str:
if self.component_type == SubComponentType.PREFILL: if self.component_type == SubComponentType.PREFILL:
return self.prefill_component_name assert self.prefill_worker_info.k8s_name is not None
return self.decode_component_name return self.prefill_worker_info.k8s_name
assert self.decode_worker_info.k8s_name is not None
return self.decode_worker_info.k8s_name
def _engine_num_gpu(self) -> int: def _engine_num_gpu(self) -> int:
if self.component_type == SubComponentType.PREFILL: if self.component_type == SubComponentType.PREFILL:
...@@ -787,7 +853,7 @@ class BasePlanner: ...@@ -787,7 +853,7 @@ class BasePlanner:
await self.connector.set_component_replicas(target_replicas, blocking=False) await self.connector.set_component_replicas(target_replicas, blocking=False)
async def _apply_scaling_blocking(self, desired_replicas: int) -> None: async def _apply_scaling_blocking(self, desired_replicas: int) -> None:
"""Apply scaling with blocking=True (wait for deployment ready).""" """Apply scaling without blocking so the loop continues observing metrics."""
if self.config.no_operation: if self.config.no_operation:
return return
target_replicas = [ target_replicas = [
...@@ -797,53 +863,156 @@ class BasePlanner: ...@@ -797,53 +863,156 @@ class BasePlanner:
desired_replicas=desired_replicas, desired_replicas=desired_replicas,
) )
] ]
await self.connector.set_component_replicas(target_replicas, blocking=True) await self.connector.set_component_replicas(target_replicas, blocking=False)
@staticmethod
def _reconcile_fpm_worker_count(
fpm_stats: "dict[tuple[str, int], ForwardPassMetrics]",
dgd_count: int,
label: str,
) -> bool:
"""Validate that FPM coverage matches DGD worker count, accounting for DP.
With attention DP, each worker emits FPM per dp_rank. We check that
the number of unique worker IDs matches DGD, and that all workers
have the same number of dp_ranks (complete coverage).
Returns True if counts match, False otherwise.
"""
workers_to_dp: dict[str, set[int]] = {}
for wid, dp in fpm_stats:
workers_to_dp.setdefault(wid, set()).add(dp)
fpm_worker_count = len(workers_to_dp)
if fpm_worker_count != dgd_count:
logger.warning(
f"Worker count mismatch: DGD reports {dgd_count}, "
f"FPM reports {fpm_worker_count} workers for {label}. "
"Skipping scaling."
)
return False
dp_sizes = {len(dps) for dps in workers_to_dp.values()}
if len(dp_sizes) > 1:
logger.warning(
f"Inconsistent DP ranks across workers for {label}: "
f"{dict(workers_to_dp)}. Skipping scaling."
)
return False
dp_size = dp_sizes.pop() if dp_sizes else 1
expected_total = dgd_count * dp_size
actual_total = len(fpm_stats)
if actual_total != expected_total:
logger.warning(
f"Incomplete FPM coverage for {label}: expected "
f"{dgd_count} workers × {dp_size} dp_ranks = {expected_total}, "
f"got {actual_total}. Skipping scaling."
)
return False
async def observe_engine_load_stats(self) -> None: if dp_size > 1:
"""Query DirectRouterMetricsClient for per-worker metrics, update regression.""" logger.info(
worker_type = self.component_type.value # "prefill" or "decode" f"FPM {label}: {fpm_worker_count} workers × {dp_size} dp_ranks "
result = self.prometheus_engine_client.get_recent_and_averaged_metrics( f"= {actual_total} engines"
worker_type )
return True
@staticmethod
def _log_fpm(wid: str, dp: int, fpm: "ForwardPassMetrics", label: str) -> None:
    """Log a one-line INFO summary of a single engine's forward-pass metrics.

    Args:
        wid: worker identifier shown in the message.
        dp: dp_rank shown in the message.
        fpm: metrics to summarise; ``wall_time``, ``scheduled_requests``
            and ``queued_requests`` are read.
        label: prefix identifying the component in the log line.
    """
    scheduled = fpm.scheduled_requests
    waiting = fpm.queued_requests

    scheduled_part = (
        f"sched(prefill_tok={scheduled.sum_prefill_tokens}, "
        f"prefill_req={scheduled.num_prefill_requests}, "
        f"decode_kv={scheduled.sum_decode_kv_tokens}, "
        f"decode_req={scheduled.num_decode_requests})"
    )
    queued_part = (
        f"queued(prefill_tok={waiting.sum_prefill_tokens}, "
        f"decode_kv={waiting.sum_decode_kv_tokens})"
    )
    logger.info(
        f"FPM {label} engine {wid}:dp{dp}: "
        f"wall_time={fpm.wall_time:.4f}s, "
        f"{scheduled_part}, {queued_part}"
    )
if result is None:
def observe_fpm_load_stats(
self,
) -> "dict[tuple[str, int], ForwardPassMetrics]":
"""Get latest FPM stats and feed observations into the regression model.
Returns:
The decoded FPM stats dict for use by load_plan_adjustment().
"""
fpm_stats = self._get_fpm_stats()
if not fpm_stats:
logger.warning( logger.warning(
f"No per-worker metrics available yet for {worker_type} (buffer empty)" f"No FPM data available for {self.component_type.value} (tracker empty)"
) )
return return {}
recent, per_worker_averaged, cluster_averaged = result for (wid, dp), fpm in fpm_stats.items():
self.cached_load_metrics = CachedLoadMetrics( self._log_fpm(wid, dp, fpm, self.component_type.value)
recent=recent, if self.component_type == SubComponentType.PREFILL:
per_worker_averaged=per_worker_averaged, self.ttft_regression.add_observation(fpm)
cluster_averaged=cluster_averaged, elif self.component_type == SubComponentType.DECODE:
self.itl_regression.add_observation(fpm)
logger.info(
f"FPM load stats: {len(fpm_stats)} engines observed for "
f"{self.component_type.value}"
) )
return fpm_stats
if self.component_type == SubComponentType.PREFILL: def _load_based_scaling_decision_from_estimates(
for wid, m in recent.items(): self,
active_prefill = m.get("active_prefill_tokens", 0.0) estimates: list[float],
last_isl = m.get("last_isl", 0.0) sla: float,
last_ttft = m.get("last_ttft", 0.0) num_workers: int,
if last_ttft > 0 and last_isl > 0: label: str,
x = active_prefill + last_isl ) -> Optional[int]:
# last_ttft is in seconds from Prometheus, convert to ms """Shared scale-up/down logic from per-engine latency estimates (ms).
y = last_ttft * 1000
Args:
estimates: per-engine estimated latencies in ms.
sla: target SLA in ms (e.g. config.ttft or config.itl).
num_workers: current worker count for this component.
label: human-readable label for log messages (e.g. "prefill TTFT").
Returns:
Desired replica count, or None if no scaling action needed.
"""
if not estimates:
return None
sensitivity = self.config.load_scaling_down_sensitivity / 100.0
logger.info(
f"Load-based {label}: workers={num_workers}, sla={sla:.1f}ms, "
f"estimates={[f'{t:.1f}' for t in estimates]}"
)
if all(t > sla for t in estimates):
logger.info(
f"Load-based {label}: ALL engines above SLA ({sla:.1f}ms), "
f"scaling up to {num_workers + 1}"
)
return num_workers + 1
if num_workers > 1:
threshold = sla * sensitivity
if all(t < threshold for t in estimates):
desired = max(num_workers - 1, self.config.min_endpoint)
if desired == num_workers:
logger.info( logger.info(
f"{SubComponentType.PREFILL.value} Worker {wid} observed status: TTFT {y:.2f}ms @ prefill tokens {x:.2f}" f"Load-based {label}: ALL engines below threshold "
f"({threshold:.1f}ms), but at min_endpoint ({self.config.min_endpoint})"
) )
self.ttft_regression.add_observation(x, y) else:
elif self.component_type == SubComponentType.DECODE:
for wid, m in recent.items():
active_decode = m.get("active_decode_blocks", 0.0)
last_itl = m.get("last_itl", 0.0)
if last_itl > 0 and active_decode > 0:
x = active_decode
# last_itl is in seconds from Prometheus, convert to ms
y = last_itl * 1000
logger.info( logger.info(
f"{SubComponentType.DECODE.value} Worker {wid} observed status: ITL {y:.2f}ms @ decode blocks {x:.2f}" f"Load-based {label}: ALL engines below threshold "
f"({threshold:.1f}ms), scaling down to {desired}"
) )
self.itl_regression.add_observation(x, y) return desired
return None
def load_plan_adjustment(self) -> Optional[int]: def load_plan_adjustment(self) -> Optional[int]:
"""Load-based scaling decision. Override in subclasses.""" """Load-based scaling decision. Override in subclasses."""
...@@ -892,32 +1061,51 @@ class BasePlanner: ...@@ -892,32 +1061,51 @@ class BasePlanner:
await asyncio.sleep(self.config.throughput_adjustment_interval / 10) await asyncio.sleep(self.config.throughput_adjustment_interval / 10)
async def _load_loop(self, require_prefill: bool, require_decode: bool) -> None: async def _load_loop(self, require_prefill: bool, require_decode: bool) -> None:
"""Load-based scaling loop at shorter interval.""" """Load-based scaling loop at shorter interval.
Uses FPM stats from the event plane (via FpmEventSubscriber) instead
of scraping the router's /metrics endpoint.
"""
pending_desired: Optional[int] = None
while True: while True:
await asyncio.sleep(self.config.load_adjustment_interval) await asyncio.sleep(self.config.load_adjustment_interval)
logger.info("New load-based adjustment interval started!") logger.info("New load-based adjustment interval started!")
# Query DGD for fresh worker counts # Query DGD for fresh worker counts
num_p, num_d, _ = await self.get_workers_info( num_p, num_d, is_stable = await self.get_workers_info(
require_prefill=require_prefill, require_decode=require_decode require_prefill=require_prefill, require_decode=require_decode
) )
self.shared_state.num_p_workers = num_p self.shared_state.num_p_workers = num_p
self.shared_state.num_d_workers = num_d self.shared_state.num_d_workers = num_d
# Observe per-worker metrics from router # Always observe FPM stats and update regression, even during scaling.
await self.observe_engine_load_stats() fpm_stats = self.observe_fpm_load_stats()
if not fpm_stats:
continue
# If a previous scaling action is still in progress, skip decisions.
if pending_desired is not None:
dgd_count = (
num_p if self.component_type == SubComponentType.PREFILL else num_d
)
if dgd_count == pending_desired:
logger.info(
f"Scaling to {pending_desired} complete, resuming decisions"
)
pending_desired = None
else:
logger.info(
f"Scaling in progress ({dgd_count} -> {pending_desired}), "
"observing only"
)
continue
# Reconcile DGD worker count with router Prometheus count
prom_count = len(self.cached_load_metrics.recent)
dgd_count = ( dgd_count = (
num_p if self.component_type == SubComponentType.PREFILL else num_d num_p if self.component_type == SubComponentType.PREFILL else num_d
) )
if prom_count != dgd_count: if not self._reconcile_fpm_worker_count(
logger.warning( fpm_stats, dgd_count, self.component_type.value
f"Worker count mismatch: DGD reports {dgd_count} workers, " ):
f"router metrics reports {prom_count} workers. "
"Skipping load-based scaling adjustment."
)
continue continue
desired_replicas = self.load_plan_adjustment() desired_replicas = self.load_plan_adjustment()
...@@ -932,70 +1120,24 @@ class BasePlanner: ...@@ -932,70 +1120,24 @@ class BasePlanner:
desired_replicas = max(desired_replicas, lower_bound) desired_replicas = max(desired_replicas, lower_bound)
desired_replicas = self.apply_component_budget(desired_replicas) desired_replicas = self.apply_component_budget(desired_replicas)
self.update_predicted_replicas_metric(desired_replicas) self.update_predicted_replicas_metric(desired_replicas)
# Load-based planner needs blocking scaling because it only checks pending_desired = desired_replicas
# the current status of the engine, not the predicted load.
# We need to wait for the deployment to be steady before making another one.
await self._apply_scaling_blocking(desired_replicas) await self._apply_scaling_blocking(desired_replicas)
async def run(self): async def run(self):
"""Main loop for the planner""" """Main scaling loop. Call _async_init() before this."""
require_prefill = self.component_type == SubComponentType.PREFILL require_prefill = self.component_type == SubComponentType.PREFILL
require_decode = self.component_type == SubComponentType.DECODE require_decode = self.component_type == SubComponentType.DECODE
if not self.config.no_operation:
logger.info("Validating deployment...")
await self.connector.validate_deployment(
prefill_component_name=(
self.prefill_component_name if require_prefill else None
),
decode_component_name=(
self.decode_component_name if require_decode else None
),
require_prefill=require_prefill,
require_decode=require_decode,
)
logger.info("Successfully validated the deployment")
# Initialize GPU counts
_initialize_gpu_counts(
self.config,
self.connector,
require_prefill=require_prefill,
require_decode=require_decode,
)
await self.connector.wait_for_deployment_ready()
# Model name discovery runs in all modes (needed for metrics collection)
if not self.config.no_operation:
model_name = await self._get_model_name(
require_prefill=require_prefill, require_decode=require_decode
)
logger.info(f"Detected model name from deployment: {model_name}")
self.model_name = model_name.lower()
else:
model_name = getattr(self.config, "model_name", "")
if not model_name:
raise ValueError(
"Model name is required in no-operation mode. "
"Please set model_name in the config."
)
self.model_name = model_name.lower()
self.shared_state.last_adjustment_time = time.time() self.shared_state.last_adjustment_time = time.time()
self.shared_state.last_load_adjustment_time = time.time() self.shared_state.last_load_adjustment_time = time.time()
# Build list of concurrent loops based on enabled scaling modes # Build list of concurrent loops based on enabled scaling modes.
# FPM tracking (started in _async_init) replaces the former
# DirectRouterMetricsClient.run_sampling_loop().
loops = [] loops = []
if self.enable_throughput: if self.enable_throughput:
loops.append(self._throughput_loop(require_prefill, require_decode)) loops.append(self._throughput_loop(require_prefill, require_decode))
if self.enable_load: if self.enable_load:
loops.append(self._load_loop(require_prefill, require_decode)) loops.append(self._load_loop(require_prefill, require_decode))
loops.append(
self.prometheus_engine_client.run_sampling_loop(
self.config.load_metric_samples,
self.config.load_adjustment_interval,
)
)
await asyncio.gather(*loops) await asyncio.gather(*loops)
...@@ -17,7 +17,16 @@ class PrefillPlanner(BasePlanner): ...@@ -17,7 +17,16 @@ class PrefillPlanner(BasePlanner):
component_type = SubComponentType.PREFILL component_type = SubComponentType.PREFILL
def load_plan_adjustment(self) -> Optional[int]: def load_plan_adjustment(self) -> Optional[int]:
"""Load-based scaling decision for prefill. Returns desired_replicas or None.""" """Load-based scaling decision for prefill using FPM data.
For each engine, simulates prefill scheduling to estimate next TTFT:
- Uses queued prefill tokens + avg ISL as total tokens to process
- Chunks into max_num_batched_tokens-sized iterations
- Sums regression-predicted wall time per chunk
Scale up if ALL engines' estimated TTFT > SLA.
Scale down if ALL engines' estimated TTFT < SLA * sensitivity.
"""
if not self.ttft_regression.has_sufficient_data(): if not self.ttft_regression.has_sufficient_data():
logger.info( logger.info(
f"TTFT regression: insufficient data ({self.ttft_regression.num_observations}" f"TTFT regression: insufficient data ({self.ttft_regression.num_observations}"
...@@ -25,73 +34,46 @@ class PrefillPlanner(BasePlanner): ...@@ -25,73 +34,46 @@ class PrefillPlanner(BasePlanner):
) )
return None return None
x_sla = self.ttft_regression.predict_x_from_sla(self.config.ttft) fpm_stats = self._get_fpm_stats()
if x_sla is None: if not fpm_stats:
return None
if not self.cached_load_metrics.recent:
return None
recent = self.cached_load_metrics.recent
cluster_averaged = self.cached_load_metrics.cluster_averaged
# Averaged ISL across all workers in the past adjustment interval
avg_isl = cluster_averaged.get("last_isl", 0.0)
target_active_tokens = x_sla - avg_isl
if target_active_tokens <= 0:
logger.warning(
f"TTFT SLA unachievable at current ISL: x_sla={x_sla:.1f}, "
f"avg_isl={avg_isl:.1f}, skipping load-based prefill scaling"
)
return None return None
num_workers = self.shared_state.num_p_workers num_workers = self.shared_state.num_p_workers
if num_workers == 0: if num_workers == 0:
return None return None
logger.info( max_num_batched_tokens = getattr(
f"Load-based prefill: x_sla={x_sla:.1f}, avg_isl={avg_isl:.1f}, " self.prefill_worker_info, "max_num_batched_tokens", None
f"target_active_tokens={target_active_tokens:.1f}, workers={num_workers}, "
f"slope={self.ttft_regression.slope:.6f}, intercept={self.ttft_regression.intercept:.3f}"
)
# Scale up: ALL workers above target (use recent metrics)
all_above = all(
m.get("active_prefill_tokens", 0.0) > target_active_tokens
for m in recent.values()
) )
if all_above: if not max_num_batched_tokens or max_num_batched_tokens <= 0:
logger.info( logger.warning(
f"Load-based prefill: ALL workers above target ({target_active_tokens:.1f}), " "max_num_batched_tokens not available from WorkerInfo, "
f"scaling up to {num_workers + 1}" "skipping prefill load-based scaling"
) )
return num_workers + 1 return None
# Scale down: ALL workers below boundary (use recent metrics) estimated_ttfts: list[float] = []
if num_workers > 1: for (wid, dp), fpm in fpm_stats.items():
sensitivity = self.config.load_scaling_down_sensitivity / 100.0 queued_prefill = fpm.queued_requests.sum_prefill_tokens
boundary = ( est = self.ttft_regression.estimate_next_ttft(
target_active_tokens * (num_workers - 1) / num_workers * sensitivity queued_prefill_tokens=queued_prefill,
max_num_batched_tokens=max_num_batched_tokens,
) )
all_below = all( if est is None:
m.get("active_prefill_tokens", 0.0) < boundary for m in recent.values() continue
est_ms = est * 1000
estimated_ttfts.append(est_ms)
logger.info(
f"Prefill engine {wid}:dp{dp}: estimated TTFT {est_ms:.2f}ms "
f"(queued_prefill={queued_prefill}, avg_isl={self.ttft_regression.avg_isl:.1f})"
) )
if all_below:
if num_workers - 1 < self.config.min_endpoint: return self._load_based_scaling_decision_from_estimates(
logger.info( estimates=estimated_ttfts,
f"Load-based prefill: ALL workers below boundary ({boundary:.1f}), " sla=self.config.ttft,
f"but cannot scale down below min_endpoint ({self.config.min_endpoint}); " num_workers=num_workers,
f"maintaining {num_workers} prefill workers" label="prefill TTFT",
) )
return num_workers
logger.info(
f"Load-based prefill: ALL workers below boundary ({boundary:.1f}), "
f"scaling down to {num_workers - 1}"
)
return num_workers - 1
return None
def _update_correction_factor(self) -> bool: def _update_correction_factor(self) -> bool:
assert self.last_metrics.isl is not None and self.last_metrics.ttft is not None assert self.last_metrics.isl is not None and self.last_metrics.ttft is not None
......
...@@ -136,7 +136,7 @@ class VirtualConnector(PlannerConnector): ...@@ -136,7 +136,7 @@ class VirtualConnector(PlannerConnector):
"""Validate the deployment""" """Validate the deployment"""
pass pass
async def wait_for_deployment_ready(self): async def wait_for_deployment_ready(self, include_planner: bool = True):
"""Wait for the deployment to be ready""" """Wait for the deployment to be ready"""
await self._wait_for_scaling_completion() await self._wait_for_scaling_completion()
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import Any, Optional
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SubComponentType
logger = logging.getLogger(__name__)
@dataclass
class WorkerInfo:
    """Per-worker metadata record consumed by the planner.

    Every field defaults to ``None`` so a partially known worker can still be
    represented. According to the module's own description the values come
    from MDC (DynamoWorkerMetadata CRs) in Kubernetes mode, then from DGD
    container-arg parsing, then from hard-coded defaults.
    """

    # Names used for scaling and for creating runtime clients.
    k8s_name: Optional[str] = None
    component_name: Optional[str] = None
    endpoint: Optional[str] = None

    # Runtime configuration reported through MDC.
    model_name: Optional[str] = None
    total_kv_blocks: Optional[int] = None
    kv_cache_block_size: Optional[int] = None
    max_num_seqs: Optional[int] = None
    max_num_batched_tokens: Optional[int] = None
    context_length: Optional[int] = None

    @property
    def max_kv_tokens(self) -> Optional[int]:
        """Return ``total_kv_blocks * kv_cache_block_size``.

        Returns ``None`` when either factor is unknown.
        """
        block_count = self.total_kv_blocks
        block_size = self.kv_cache_block_size
        if block_count is None or block_size is None:
            return None
        return block_count * block_size

    def summary(self) -> str:
        """Return a one-line, comma-separated description for logging.

        The three naming fields are always present (even when ``None``);
        the remaining entries appear only when their value is known.
        """
        parts = [
            f"k8s_name={self.k8s_name}",
            f"component={self.component_name}",
            f"endpoint={self.endpoint}",
        ]
        # (label, value) pairs in the order they must appear in the output.
        optional_entries = (
            ("model", self.model_name),
            ("max_kv_tokens", self.max_kv_tokens),
            ("max_num_seqs(max_bs)", self.max_num_seqs),
            ("max_num_batched_tokens", self.max_num_batched_tokens),
            ("context_length", self.context_length),
        )
        parts.extend(
            f"{label}={value}"
            for label, value in optional_entries
            if value is not None
        )
        return ", ".join(parts)
def build_worker_info_from_defaults(
    backend: str, sub_component_type: SubComponentType
) -> WorkerInfo:
    """Create a WorkerInfo using only the hard-coded names for ``backend``.

    Args:
        backend: Key looked up in ``WORKER_COMPONENT_NAMES``.
        sub_component_type: ``SubComponentType.PREFILL`` selects the prefill
            names; any other value selects the decode names.

    Returns:
        A WorkerInfo with ``k8s_name``, ``component_name`` and ``endpoint``
        filled in, or an all-``None`` WorkerInfo when ``backend`` has no
        entry in ``WORKER_COMPONENT_NAMES``.
    """
    backend_names = WORKER_COMPONENT_NAMES.get(backend)
    if backend_names is None:
        return WorkerInfo()

    # Pick the role-specific triple first, then build the record once.
    if sub_component_type == SubComponentType.PREFILL:
        k8s_name = backend_names.prefill_worker_k8s_name
        component_name = backend_names.prefill_worker_component_name
        endpoint = backend_names.prefill_worker_endpoint
    else:
        k8s_name = backend_names.decode_worker_k8s_name
        component_name = backend_names.decode_worker_component_name
        endpoint = backend_names.decode_worker_endpoint

    return WorkerInfo(
        k8s_name=k8s_name,
        component_name=component_name,
        endpoint=endpoint,
    )
def resolve_worker_info(
    backend: str,
    require_prefill: bool,
    require_decode: bool,
    connector: Any = None,
    config_model_name: str = "",
    no_operation: bool = False,
) -> tuple[WorkerInfo, WorkerInfo]:
    """Produce the prefill/decode WorkerInfo pair and settle the model name.

    A connector that exposes ``get_worker_info`` is asked for each required
    role; without such a connector the hard-coded backend defaults are used.
    A role that is not required keeps an empty WorkerInfo.

    Model-name resolution:
      * no-operation mode: ``config_model_name`` is mandatory and always wins;
      * otherwise the first available of: the name already present on the
        decode (then prefill) WorkerInfo, ``connector.get_model_name(...)``
        when the connector supports worker-info queries, ``config_model_name``.

    The chosen name is stored on both returned objects, so callers may read
    it from either ``prefill_info.model_name`` or ``decode_info.model_name``.

    Returns:
        (prefill_worker_info, decode_worker_info)

    Raises:
        ValueError: if no model name can be determined.
    """
    can_query_mdc = connector is not None and hasattr(connector, "get_worker_info")

    def _lookup(kind: SubComponentType) -> WorkerInfo:
        # Single place that decides between connector data and defaults.
        if can_query_mdc:
            return connector.get_worker_info(kind, backend)
        return build_worker_info_from_defaults(backend, kind)

    # --- Build WorkerInfo ---
    prefill_info = (
        _lookup(SubComponentType.PREFILL) if require_prefill else WorkerInfo()
    )
    decode_info = _lookup(SubComponentType.DECODE) if require_decode else WorkerInfo()

    if require_prefill:
        logger.info(f"Prefill WorkerInfo: {prefill_info.summary()}")
    if require_decode:
        logger.info(f"Decode WorkerInfo: {decode_info.summary()}")

    # Cross-validate model names (only meaningful when both roles are in use
    # and both actually report a name).
    p_model, d_model = prefill_info.model_name, decode_info.model_name
    if require_prefill and require_decode:
        if p_model and d_model and p_model != d_model:
            logger.warning(
                f"Model name mismatch between prefill ({p_model}) and "
                f"decode ({d_model}) WorkerInfo"
            )

    # --- Resolve model name and write back into both WorkerInfo ---
    if no_operation:
        if not config_model_name:
            raise ValueError(
                "Model name is required in no-operation mode. "
                "Please set model_name in the config."
            )
        model_name = config_model_name
    else:
        # Decode takes precedence over prefill when both carry a name.
        model_name = decode_info.model_name or prefill_info.model_name
        if model_name:
            logger.info(f"Using model name from MDC: {model_name}")
        elif can_query_mdc:
            model_name = connector.get_model_name(
                require_prefill=require_prefill,
                require_decode=require_decode,
            )
            logger.info(f"Detected model name from DGD container args: {model_name}")
        elif config_model_name:
            model_name = config_model_name
            logger.info(f"Using model name from config: {model_name}")
        else:
            raise ValueError(
                "Could not determine model name. "
                "Please set model_name in the config."
            )

    for info in (prefill_info, decode_info):
        info.model_name = model_name
    return prefill_info, decode_info
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/intstr"
) )
// PlannerDefaults implements ComponentDefaults for Planner components // PlannerDefaults implements ComponentDefaults for Planner components
...@@ -29,12 +30,61 @@ func (p *PlannerDefaults) GetBaseContainer(context ComponentContext) (corev1.Con ...@@ -29,12 +30,61 @@ func (p *PlannerDefaults) GetBaseContainer(context ComponentContext) (corev1.Con
Name: commonconsts.DynamoMetricsPortName, Name: commonconsts.DynamoMetricsPortName,
ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort), ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort),
}, },
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoSystemPortName,
ContainerPort: int32(commonconsts.DynamoSystemPort),
},
}
container.LivenessProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/live",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 5,
TimeoutSeconds: 4,
FailureThreshold: 1,
} }
container.ReadinessProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 10,
TimeoutSeconds: 4,
FailureThreshold: 3,
}
// Startup probe with generous timeout: the planner waits for worker
// services to become ready before it can initialise, so it needs more
// time than a typical worker.
container.StartupProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/live",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 10,
TimeoutSeconds: 5,
FailureThreshold: 720, // 10s * 720 = 7200s = 2h
}
container.Env = append(container.Env, []corev1.EnvVar{ container.Env = append(container.Env, []corev1.EnvVar{
{ {
Name: "PLANNER_PROMETHEUS_PORT", Name: "PLANNER_PROMETHEUS_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoPlannerMetricsPort), Value: fmt.Sprintf("%d", commonconsts.DynamoPlannerMetricsPort),
}, },
{
Name: "DYN_SYSTEM_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoSystemPort),
},
}...) }...)
return container, nil return container, nil
} }
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/intstr"
) )
func TestPlannerDefaults_GetBaseContainer(t *testing.T) { func TestPlannerDefaults_GetBaseContainer(t *testing.T) {
...@@ -45,6 +46,40 @@ func TestPlannerDefaults_GetBaseContainer(t *testing.T) { ...@@ -45,6 +46,40 @@ func TestPlannerDefaults_GetBaseContainer(t *testing.T) {
}, },
Ports: []corev1.ContainerPort{ Ports: []corev1.ContainerPort{
{Name: commonconsts.DynamoMetricsPortName, ContainerPort: commonconsts.DynamoPlannerMetricsPort, Protocol: corev1.ProtocolTCP}, {Name: commonconsts.DynamoMetricsPortName, ContainerPort: commonconsts.DynamoPlannerMetricsPort, Protocol: corev1.ProtocolTCP},
{Name: commonconsts.DynamoSystemPortName, ContainerPort: int32(commonconsts.DynamoSystemPort), Protocol: corev1.ProtocolTCP},
},
LivenessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/live",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 5,
TimeoutSeconds: 4,
FailureThreshold: 1,
},
ReadinessProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/health",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 10,
TimeoutSeconds: 4,
FailureThreshold: 3,
},
StartupProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/live",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 10,
TimeoutSeconds: 5,
FailureThreshold: 720,
}, },
Env: []corev1.EnvVar{ Env: []corev1.EnvVar{
{Name: commonconsts.DynamoNamespaceEnvVar, Value: "dynamo-namespace"}, {Name: commonconsts.DynamoNamespaceEnvVar, Value: "dynamo-namespace"},
...@@ -71,6 +106,7 @@ func TestPlannerDefaults_GetBaseContainer(t *testing.T) { ...@@ -71,6 +106,7 @@ func TestPlannerDefaults_GetBaseContainer(t *testing.T) {
}}, }},
{Name: commonconsts.DynamoDiscoveryBackendEnvVar, Value: "kubernetes"}, {Name: commonconsts.DynamoDiscoveryBackendEnvVar, Value: "kubernetes"},
{Name: "PLANNER_PROMETHEUS_PORT", Value: fmt.Sprintf("%d", commonconsts.DynamoPlannerMetricsPort)}, {Name: "PLANNER_PROMETHEUS_PORT", Value: fmt.Sprintf("%d", commonconsts.DynamoPlannerMetricsPort)},
{Name: "DYN_SYSTEM_PORT", Value: fmt.Sprintf("%d", commonconsts.DynamoSystemPort)},
}, },
}, },
}, },
......
...@@ -1655,6 +1655,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { ...@@ -1655,6 +1655,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
"--planner-env-1", "--planner-env-1",
"1", "1",
}, },
Ports: []corev1.ContainerPort{
{Name: commonconsts.DynamoMetricsPortName, ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort), Protocol: corev1.ProtocolTCP},
{Name: commonconsts.DynamoSystemPortName, ContainerPort: int32(commonconsts.DynamoSystemPort), Protocol: corev1.ProtocolTCP},
},
EnvFrom: []corev1.EnvFromSource{ EnvFrom: []corev1.EnvFromSource{
{ {
SecretRef: &corev1.SecretEnvSource{ SecretRef: &corev1.SecretEnvSource{
...@@ -1680,6 +1684,17 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { ...@@ -1680,6 +1684,17 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
}, },
}, },
}, },
StartupProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/live",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 10,
TimeoutSeconds: 5,
FailureThreshold: 720,
},
Env: []corev1.EnvVar{ Env: []corev1.EnvVar{
{ {
Name: "DYNAMO_POD_GANG_SET_REPLICAS", Name: "DYNAMO_POD_GANG_SET_REPLICAS",
...@@ -1717,6 +1732,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { ...@@ -1717,6 +1732,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
Name: "DYN_PARENT_DGD_K8S_NAMESPACE", Name: "DYN_PARENT_DGD_K8S_NAMESPACE",
Value: "test-namespace", Value: "test-namespace",
}, },
{
Name: "DYN_SYSTEM_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoSystemPort),
},
{ {
Name: "MODEL_EXPRESS_URL", Name: "MODEL_EXPRESS_URL",
Value: "model-express-url", Value: "model-express-url",
...@@ -1775,13 +1794,6 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { ...@@ -1775,13 +1794,6 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
MountPath: commonconsts.DefaultSharedMemoryMountPath, MountPath: commonconsts.DefaultSharedMemoryMountPath,
}, },
}, },
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoMetricsPortName,
ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort),
},
},
}, },
}, },
}, },
...@@ -2632,6 +2644,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { ...@@ -2632,6 +2644,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
"--planner-env-1", "--planner-env-1",
"1", "1",
}, },
Ports: []corev1.ContainerPort{
{Name: commonconsts.DynamoMetricsPortName, ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort), Protocol: corev1.ProtocolTCP},
{Name: commonconsts.DynamoSystemPortName, ContainerPort: int32(commonconsts.DynamoSystemPort), Protocol: corev1.ProtocolTCP},
},
EnvFrom: []corev1.EnvFromSource{ EnvFrom: []corev1.EnvFromSource{
{ {
SecretRef: &corev1.SecretEnvSource{ SecretRef: &corev1.SecretEnvSource{
...@@ -2657,6 +2673,17 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { ...@@ -2657,6 +2673,17 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
}, },
}, },
}, },
StartupProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{
Path: "/live",
Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
},
PeriodSeconds: 10,
TimeoutSeconds: 5,
FailureThreshold: 720,
},
Env: []corev1.EnvVar{ Env: []corev1.EnvVar{
{ {
Name: "DYNAMO_POD_GANG_SET_REPLICAS", Name: "DYNAMO_POD_GANG_SET_REPLICAS",
...@@ -2694,6 +2721,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { ...@@ -2694,6 +2721,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
Name: "DYN_PARENT_DGD_K8S_NAMESPACE", Name: "DYN_PARENT_DGD_K8S_NAMESPACE",
Value: "test-namespace", Value: "test-namespace",
}, },
{
Name: "DYN_SYSTEM_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoSystemPort),
},
{ {
Name: "PLANNER_PROMETHEUS_PORT", Name: "PLANNER_PROMETHEUS_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoPlannerMetricsPort), Value: fmt.Sprintf("%d", commonconsts.DynamoPlannerMetricsPort),
...@@ -2744,13 +2775,6 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { ...@@ -2744,13 +2775,6 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
MountPath: commonconsts.DefaultSharedMemoryMountPath, MountPath: commonconsts.DefaultSharedMemoryMountPath,
}, },
}, },
Ports: []corev1.ContainerPort{
{
Protocol: corev1.ProtocolTCP,
Name: commonconsts.DynamoMetricsPortName,
ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort),
},
},
}, },
}, },
}, },
...@@ -3618,6 +3642,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { ...@@ -3618,6 +3642,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
"--planner-env-1", "--planner-env-1",
"1", "1",
}, },
Ports: []corev1.ContainerPort{
{Name: commonconsts.DynamoMetricsPortName, ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort), Protocol: corev1.ProtocolTCP},
{Name: commonconsts.DynamoSystemPortName, ContainerPort: int32(commonconsts.DynamoSystemPort), Protocol: corev1.ProtocolTCP},
},
EnvFrom: []corev1.EnvFromSource{ EnvFrom: []corev1.EnvFromSource{
{ {
SecretRef: &corev1.SecretEnvSource{ SecretRef: &corev1.SecretEnvSource{
...@@ -3643,12 +3671,16 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { ...@@ -3643,12 +3671,16 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
}, },
}, },
}, },
Ports: []corev1.ContainerPort{ StartupProbe: &corev1.Probe{
{ ProbeHandler: corev1.ProbeHandler{
Protocol: corev1.ProtocolTCP, HTTPGet: &corev1.HTTPGetAction{
Name: commonconsts.DynamoMetricsPortName, Path: "/live",
ContainerPort: int32(commonconsts.DynamoPlannerMetricsPort), Port: intstr.FromString(commonconsts.DynamoSystemPortName),
},
}, },
PeriodSeconds: 10,
TimeoutSeconds: 5,
FailureThreshold: 720,
}, },
Env: []corev1.EnvVar{ Env: []corev1.EnvVar{
{ {
...@@ -3687,6 +3719,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { ...@@ -3687,6 +3719,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) {
Name: "DYN_PARENT_DGD_K8S_NAMESPACE", Name: "DYN_PARENT_DGD_K8S_NAMESPACE",
Value: "test-namespace", Value: "test-namespace",
}, },
{
Name: "DYN_SYSTEM_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoSystemPort),
},
{ {
Name: "PLANNER_PROMETHEUS_PORT", Name: "PLANNER_PROMETHEUS_PORT",
Value: fmt.Sprintf("%d", commonconsts.DynamoPlannerMetricsPort), Value: fmt.Sprintf("%d", commonconsts.DynamoPlannerMetricsPort),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment