Unverified Commit 44190094 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

fix(planner): backfill max_num_batched_tokens from discovery for VirtualConnector (#8042)


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
parent 5c8f8ffb
......@@ -21,6 +21,12 @@ from typing import Optional
from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.connectors.base import PlannerConnector
from dynamo.planner.connectors.kubernetes_api import KubernetesAPI
from dynamo.planner.connectors.mdc import (
MdcEntry,
is_model_card,
select_entry,
worker_info_from_mdc,
)
from dynamo.planner.errors import (
DeploymentModelNameMismatchError,
DeploymentValidationError,
......@@ -340,22 +346,18 @@ class KubernetesConnector(PlannerConnector):
return []
raise
def _extract_mdc_entries(
self,
) -> list[dict]:
def _extract_mdc_entries(self) -> list[MdcEntry]:
"""Extract MDC entries belonging to this DGD.
CRs are named after the worker pod (e.g. ``<dgd>-0-<service>-<hash>``),
so we filter by the DGD name prefix to avoid picking up entries from
other deployments sharing the namespace.
Returns a list of dicts, each containing:
namespace, component, endpoint, instance_id, card_json
other deployments sharing the namespace. LoRA-adapter wrappers are
dropped via :func:`is_model_card`.
"""
crs = self._list_worker_metadata_crs()
dgd_prefix = f"{self.graph_deployment_name}-"
entries: list[dict] = []
entries: list[MdcEntry] = []
for cr in crs:
cr_name = cr.get("metadata", {}).get("name", "")
if not cr_name.startswith(dgd_prefix):
......@@ -368,125 +370,37 @@ class KubernetesConnector(PlannerConnector):
except json.JSONDecodeError:
continue
model_cards = data.get("model_cards", {})
for _key, instance in model_cards.items():
if instance.get("type") != "Model":
for _key, wrapper in model_cards.items():
if not is_model_card(wrapper):
continue
entries.append(instance)
return entries
def _mdc_entry_is_prefill(self, entry: dict) -> bool:
"""Check if an MDC entry is a prefill worker.
model_type can be serialized as:
- An integer bitflag (ModelType::Prefill = 1 << 4 = 16)
- A dict with a "bits" key (serde bitflags format)
- A string like "Prefill" or "Chat|Completions"
"""
card = entry.get("card_json", {})
model_type = card.get("model_type", 0)
if isinstance(model_type, str):
return "prefill" in model_type.lower()
if isinstance(model_type, dict):
model_type = model_type.get("bits", 0)
return bool(model_type & 0x10)
def _build_worker_info_from_mdc(
self,
entry: dict,
sub_component_type: SubComponentType,
) -> WorkerInfo:
"""Build a WorkerInfo from an MDC entry, applying the fallback chain.
Priority: MDC -> DGD container-arg parsing -> hard-coded defaults.
"""
defaults = build_worker_info_from_defaults(
self._backend_hint or "vllm", sub_component_type
entries.append(
MdcEntry(
card_json=wrapper.get("card_json") or {},
component=wrapper.get("component"),
endpoint=wrapper.get("endpoint"),
instance_id=wrapper.get("instance_id"),
)
card = entry.get("card_json", {})
runtime_cfg = card.get("runtime_config", {})
# --- component / endpoint names from MDC wrapper ---
mdc_component = entry.get("component")
mdc_endpoint = entry.get("endpoint")
component_name = mdc_component or defaults.component_name
endpoint = mdc_endpoint or defaults.endpoint
if not mdc_component:
logger.info(
f"MDC missing 'component' for {sub_component_type.value}, "
f"falling back to default: {defaults.component_name}"
)
if not mdc_endpoint:
logger.info(
f"MDC missing 'endpoint' for {sub_component_type.value}, "
f"falling back to default: {defaults.endpoint}"
)
return entries
# --- model name ---
mdc_model = card.get("display_name")
model_name = mdc_model
if not model_name:
# Fallback: parse from DGD container args
try:
deployment = self.kube_api.get_graph_deployment(
self.graph_deployment_name
)
service = get_service_from_sub_component_type_or_name(
deployment, sub_component_type
)
model_name = service.get_model_name()
if model_name:
logger.info(
f"MDC missing model name for {sub_component_type.value}, "
f"fell back to DGD container args: {model_name}"
)
except PlannerError:
pass
if not model_name:
logger.warning(
f"Could not determine model name for {sub_component_type.value} "
f"from MDC or DGD container args"
)
# --- runtime config fields ---
total_kv_blocks = runtime_cfg.get("total_kv_blocks")
max_num_seqs = runtime_cfg.get("max_num_seqs")
max_num_batched_tokens = runtime_cfg.get("max_num_batched_tokens")
kv_cache_block_size = card.get("kv_cache_block_size")
context_length = card.get("context_length")
if total_kv_blocks is None:
logger.info(f"MDC missing total_kv_blocks for {sub_component_type.value}")
if max_num_seqs is None:
logger.info(f"MDC missing max_num_seqs for {sub_component_type.value}")
def _resolve_dgd_service(
self, sub_component_type: SubComponentType, backend: str
) -> tuple[Optional[str], str]:
"""Return (dgd_service_name, component_name_for_filter).
# --- k8s_name: resolve from DGD subComponentType ---
k8s_name = defaults.k8s_name
Uses the DGD when available; otherwise falls back to backend defaults
so that stale/missing DGD state still lets us filter out LoRA cards
by their expected component name.
"""
try:
deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)
service = get_service_from_sub_component_type_or_name(
deployment, sub_component_type
)
k8s_name = service.name
return service.name, service.name
except PlannerError:
logger.info(
f"Could not resolve k8s service name for {sub_component_type.value}, "
f"using default: {defaults.k8s_name}"
)
info = WorkerInfo(
k8s_name=k8s_name,
component_name=component_name,
endpoint=endpoint,
model_name=model_name,
total_kv_blocks=total_kv_blocks,
kv_cache_block_size=kv_cache_block_size,
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
context_length=context_length,
)
return info
defaults = build_worker_info_from_defaults(backend, sub_component_type)
return None, defaults.component_name or ""
def get_worker_info(
self,
......@@ -499,75 +413,58 @@ class KubernetesConnector(PlannerConnector):
sub_component_type: PREFILL or DECODE
backend: Backend framework name (for default fallback)
"""
self._backend_hint = backend
entries = self._extract_mdc_entries()
dgd_service_name, expected_component = self._resolve_dgd_service(
sub_component_type, backend
)
# Resolve the expected component name so we can scope card selection
# and avoid picking up LoRA-adapter cards that share the same CR but
# carry a different component/model identity.
expected_component: Optional[str] = None
def _dgd_model_name() -> Optional[str]:
try:
deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)
deployment = self.kube_api.get_graph_deployment(
self.graph_deployment_name
)
service = get_service_from_sub_component_type_or_name(
deployment, sub_component_type
)
expected_component = service.name
return service.get_model_name()
except PlannerError:
expected_component = build_worker_info_from_defaults(
backend, sub_component_type
).component_name
for entry in entries:
is_prefill = self._mdc_entry_is_prefill(entry)
if sub_component_type == SubComponentType.PREFILL and not is_prefill:
continue
if sub_component_type == SubComponentType.DECODE and is_prefill:
continue
return None
entry_component = entry.get("component")
if (
entry_component
and expected_component
and entry_component != expected_component
):
logger.debug(
f"Skipping MDC entry with component={entry_component!r}, "
f"expected {expected_component!r} for {sub_component_type.value}"
entry = select_entry(entries, sub_component_type, expected_component)
if entry is not None:
info = worker_info_from_mdc(
entry,
sub_component_type,
backend=backend,
model_name_fallback=_dgd_model_name,
k8s_name_override=dgd_service_name,
)
if not info.model_name:
logger.warning(
f"Could not determine model name for {sub_component_type.value} "
f"from MDC or DGD container args"
)
continue
info = self._build_worker_info_from_mdc(entry, sub_component_type)
logger.info(
f"Built {sub_component_type.value} WorkerInfo from MDC: "
f"{info.summary()}"
)
return info
# No MDC entry found -- fall back entirely to defaults + DGD arg parsing
# No MDC entry found -- fall back entirely to defaults + DGD arg parsing.
logger.warning(
f"No DynamoWorkerMetadata CR found for {sub_component_type.value}. "
f"Workers may not be registered yet. Falling back to defaults."
)
info = build_worker_info_from_defaults(backend, sub_component_type)
# Try to enrich model_name from DGD container args
try:
deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)
service = get_service_from_sub_component_type_or_name(
deployment, sub_component_type
)
info.k8s_name = service.name
arg_model = service.get_model_name()
if dgd_service_name is not None:
info.k8s_name = dgd_service_name
arg_model = _dgd_model_name()
if arg_model:
info.model_name = arg_model
logger.info(
f"Enriched {sub_component_type.value} WorkerInfo model name "
f"from DGD args: {arg_model}"
)
except PlannerError as e:
logger.info(
f"Could not enrich WorkerInfo from DGD for {sub_component_type.value}: {e}"
)
logger.info(
f"Using fallback WorkerInfo for {sub_component_type.value}: {info.summary()}"
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared ModelDeploymentCard (MDC) plumbing for planner connectors.

Connectors that want to populate ``WorkerInfo`` from MDCs build
:class:`MdcEntry` records and feed them into :func:`worker_info_from_mdc`.
The transform has no K8s or discovery dependency of its own; anything
connector-specific is injected by the caller. Each connector fetches
entries from whatever backing store it uses (Kubernetes CRDs, the dynamo
discovery watch, etc.), and :class:`MdcSource` describes the shape such a
source is expected to have.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Protocol
from dynamo.planner.config.defaults import SubComponentType
from dynamo.planner.monitoring.worker_info import (
WorkerInfo,
build_worker_info_from_defaults,
)
logger = logging.getLogger(__name__)
# ModelType::Prefill bit in ModelDeploymentCard::model_type (bitflags: 1 << 4).
# is_prefill_card masks the integer form of ``model_type`` with this value.
_MODEL_TYPE_PREFILL_BIT = 0x10
@dataclass
class MdcEntry:
    """A single normalized model-deployment-card record.

    This is the input shape expected by :func:`worker_info_from_mdc`.

    In K8s mode ``component`` and ``endpoint`` are copied from the CRD
    wrapper. Discovery-sourced entries typically leave them as ``None``,
    in which case backend defaults are used in their place.
    """

    # Parsed card payload; every instance gets its own empty dict by default.
    card_json: dict = field(default_factory=dict)
    # Wrapper-level identity fields; all optional.
    component: Optional[str] = None
    endpoint: Optional[str] = None
    instance_id: Optional[str] = None
class MdcSource(Protocol):
    """Structural interface for a provider of :class:`MdcEntry` records.

    Implementations return the entries that belong to one sub-component
    type.
    """

    def get_entries(self, sub_component_type: SubComponentType) -> list[MdcEntry]:
        ...
def is_model_card(wrapper: dict) -> bool:
    """Return ``True`` only for CRD wrappers whose ``type`` is ``"Model"``.

    K8s CRDs store both Model and LoRA cards together; discovery-sourced
    entries are already scoped to the ``Model`` variant, so this filter
    only matters for the CRD path.

    Args:
        wrapper: One value from a CRD's ``model_cards`` mapping.

    Returns:
        ``True`` when ``wrapper["type"] == "Model"``; ``False`` for any
        other type, a missing ``type`` key, or a wrapper that is not
        mapping-like at all.
    """
    try:
        card_type = wrapper.get("type")
    except AttributeError:
        # CRD payloads are decoded JSON, so a value may be null, a list or a
        # scalar. Treat those as "not a model card" instead of crashing.
        return False
    return card_type == "Model"
def is_prefill_card(card_json: dict) -> bool:
    """Tell whether ``card_json`` describes a prefill worker.

    The producer may serialize ``model_type`` in one of three shapes:

    * a human-readable string such as ``"Prefill"`` or
      ``"Chat|Completions"``;
    * a serde-bitflags dict carrying the flags under ``bits``;
    * a plain integer bitflag.

    A missing ``model_type`` is treated as ``0``, and anything that cannot
    be converted to an integer yields ``False``.
    """
    raw: Any = card_json.get("model_type", 0)

    # String form: a case-insensitive substring match decides the answer.
    if isinstance(raw, str):
        return "prefill" in raw.lower()

    # Dict form: unwrap the numeric payload and continue with the int path.
    bits = raw.get("bits", 0) if isinstance(raw, dict) else raw

    try:
        flags = int(bits)
    except (TypeError, ValueError):
        return False
    return (flags & _MODEL_TYPE_PREFILL_BIT) != 0
def select_entry(
    entries: list[MdcEntry],
    sub_component_type: SubComponentType,
    expected_component: Optional[str] = None,
) -> Optional[MdcEntry]:
    """Return the first entry that fits the requested sub-component.

    An entry fits when its prefill/decode classification agrees with
    ``sub_component_type`` and, if both the entry's ``component`` and
    ``expected_component`` are set, the two are equal. This is how
    LoRA-adapter cards and stray entries from sibling deployments are
    skipped. ``None`` is returned when nothing fits.
    """
    prefill_wanted = sub_component_type == SubComponentType.PREFILL

    def _fits(candidate) -> bool:
        if is_prefill_card(candidate.card_json) != prefill_wanted:
            return False
        # A component mismatch only disqualifies when both sides are known.
        both_known = bool(candidate.component) and bool(expected_component)
        return not (both_known and candidate.component != expected_component)

    return next((candidate for candidate in entries if _fits(candidate)), None)
def worker_info_from_mdc(
    entry: MdcEntry,
    sub_component_type: SubComponentType,
    backend: str,
    model_name_fallback: Optional[Callable[[], Optional[str]]] = None,
    k8s_name_override: Optional[str] = None,
) -> WorkerInfo:
    """Translate one :class:`MdcEntry` into a :class:`WorkerInfo`.

    Connector-specific enrichment is supplied by the caller:

    * ``model_name_fallback`` is invoked only when the card has no
      ``display_name``; ``RuntimeError`` / ``OSError`` / ``ValueError``
      raised by it are logged at debug level and leave the model name
      ``None``.
    * ``k8s_name_override`` replaces the backend-default ``k8s_name`` when
      it is not ``None``.

    ``component`` / ``endpoint`` fall back to the backend defaults when the
    entry does not carry them, and a non-dict ``runtime_config`` is treated
    as empty.
    """
    fallback_info = build_worker_info_from_defaults(backend, sub_component_type)

    card_payload = entry.card_json if entry.card_json else {}
    runtime = card_payload.get("runtime_config")
    if not isinstance(runtime, dict):
        runtime = {}

    resolved_model: Optional[str] = card_payload.get("display_name")
    if model_name_fallback is not None and not resolved_model:
        try:
            resolved_model = model_name_fallback()
        except (RuntimeError, OSError, ValueError) as e:
            logger.debug(f"Model name fallback raised: {e}")
            resolved_model = None

    if k8s_name_override is None:
        resolved_k8s_name = fallback_info.k8s_name
    else:
        resolved_k8s_name = k8s_name_override

    return WorkerInfo(
        k8s_name=resolved_k8s_name,
        component_name=entry.component or fallback_info.component_name,
        endpoint=entry.endpoint or fallback_info.endpoint,
        model_name=resolved_model,
        total_kv_blocks=runtime.get("total_kv_blocks"),
        kv_cache_block_size=card_payload.get("kv_cache_block_size"),
        max_num_seqs=runtime.get("max_num_seqs"),
        max_num_batched_tokens=runtime.get("max_num_batched_tokens"),
        context_length=card_payload.get("context_length"),
    )
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import os
from typing import Optional
from typing import TYPE_CHECKING, Optional
from dynamo._core import VirtualConnectorCoordinator
from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.connectors.base import PlannerConnector
from dynamo.planner.connectors.mdc import MdcEntry, select_entry, worker_info_from_mdc
from dynamo.planner.errors import EmptyTargetReplicasError
from dynamo.planner.monitoring.worker_info import (
WorkerInfo,
build_worker_info_from_defaults,
)
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
if TYPE_CHECKING:
from dynamo.llm import FpmEventSubscriber
configure_dynamo_logging()
logger = logging.getLogger(__name__)
......@@ -25,6 +34,35 @@ SCALING_MAX_WAIT_TIME = int(
SCALING_MAX_RETRIES = SCALING_MAX_WAIT_TIME // SCALING_CHECK_INTERVAL # 180 retries
def _mdc_entries_from_subscriber(
subscriber: Optional["FpmEventSubscriber"],
) -> list[MdcEntry]:
"""Read the discovery-captured card JSON snapshot from an FPM subscriber.
Returns an empty list if tracking has not been started yet or no cards
have been observed. Discovery-sourced entries have no wrapper
component/endpoint (those come from the CRD in K8s mode); worker_info_from_mdc
falls back to backend defaults for those fields.
"""
if subscriber is None:
return []
try:
cards = subscriber.get_model_cards()
except RuntimeError:
# start_tracking() not called yet.
return []
entries: list[MdcEntry] = []
for worker_id, card_str in cards.items():
try:
card_json = json.loads(card_str)
except json.JSONDecodeError:
logger.warning(f"Skipping malformed MDC card JSON for worker {worker_id}")
continue
entries.append(MdcEntry(card_json=card_json, instance_id=worker_id))
return entries
class VirtualConnector(PlannerConnector):
"""
This is a virtual connector for planner to output scaling decisions to non-native environments
......@@ -53,6 +91,55 @@ class VirtualConnector(PlannerConnector):
self.dynamo_namespace = dynamo_namespace
# MDC sources injected by NativePlannerBase after FPM subscribers exist.
self._prefill_mdc_sub: Optional["FpmEventSubscriber"] = None
self._decode_mdc_sub: Optional["FpmEventSubscriber"] = None
def set_mdc_subscribers(
    self,
    prefill: Optional["FpmEventSubscriber"] = None,
    decode: Optional["FpmEventSubscriber"] = None,
) -> None:
    """Attach the FPM subscribers that get_worker_info reads model cards from.

    There are no K8s CRDs for VirtualConnector to consult, so model cards
    are taken from the discovery watch kept by the FPM subscribers. While
    no subscriber is attached, get_worker_info can only return backend
    defaults.

    Both slots are always overwritten; omitting an argument resets that
    slot to ``None``.
    """
    self._prefill_mdc_sub, self._decode_mdc_sub = prefill, decode
def get_worker_info(
    self,
    sub_component_type: SubComponentType,
    backend: str = "vllm",
) -> WorkerInfo:
    """Build WorkerInfo from discovery-sourced MDCs, else from backend defaults.

    The subscriber matching ``sub_component_type`` is read for model cards.
    If a suitable card is found, the WorkerInfo is derived from it;
    otherwise the backend defaults are used. In both cases a missing model
    name is filled in from ``self.model_name``.

    Called by ``resolve_worker_info`` (once, at init) and by the tick-loop
    refresh (once cards are available in discovery).
    """
    if sub_component_type == SubComponentType.PREFILL:
        source = self._prefill_mdc_sub
    else:
        source = self._decode_mdc_sub

    chosen = select_entry(_mdc_entries_from_subscriber(source), sub_component_type)
    if chosen is None:
        # Nothing usable in discovery yet: defaults plus the configured name.
        info = build_worker_info_from_defaults(backend, sub_component_type)
        info.model_name = self.model_name
        return info

    info = worker_info_from_mdc(
        chosen,
        sub_component_type,
        backend=backend,
    )
    if not info.model_name:
        info.model_name = self.model_name
    return info
async def _async_init(self):
"""Async initialization that must be called after __init__"""
await self.connector.async_init()
......
......@@ -24,7 +24,7 @@ import aiohttp.web
from prometheus_client import start_http_server
from dynamo.planner.config.backend_components import WORKER_COMPONENT_NAMES
from dynamo.planner.config.defaults import TargetReplica
from dynamo.planner.config.defaults import SubComponentType, TargetReplica
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.connectors.global_planner import GlobalPlannerConnector
from dynamo.planner.connectors.kubernetes import KubernetesConnector
......@@ -259,6 +259,10 @@ class NativePlannerBase:
)
await self.connector.wait_for_deployment_ready(include_planner=False)
# Resolve WorkerInfo once from the connector. For K8s this populates
# runtime_config fields from MDC CRDs; for Virtual it returns backend
# defaults (subscribers aren't attached yet) which is enough to
# construct the FPM endpoint.
await self._init_worker_info()
if self.runtime is not None:
......@@ -267,6 +271,15 @@ class NativePlannerBase:
if self.require_decode:
await self._init_fpm_subscriber("decode")
# VirtualConnector reads MDC from the FPM subscriber's discovery watch;
# hand it the subscribers now that they exist. The tick-loop refresh
# will backfill runtime_config fields once discovery sees the workers.
if isinstance(self.connector, VirtualConnector):
self.connector.set_mdc_subscribers(
prefill=self._prefill_fpm_sub,
decode=self._decode_fpm_sub,
)
await self._bootstrap_regression()
# Log operating mode at startup
......@@ -332,6 +345,73 @@ class NativePlannerBase:
"""Override in subclasses to bootstrap regression models."""
pass
# ------------------------------------------------------------------
# Discovery refresh
# ------------------------------------------------------------------
# WorkerInfo attribute names that _refresh_worker_info_from_connector copies
# from a freshly queried WorkerInfo whenever the fresh value is not None.
_MDC_REFRESH_FIELDS = (
"total_kv_blocks",
"kv_cache_block_size",
"max_num_seqs",
"max_num_batched_tokens",
"context_length",
)
def _refresh_worker_info_from_connector(self) -> None:
"""Re-query the connector for any sub-component whose WorkerInfo is
still missing runtime-config fields.
This handles the cold-start path where workers haven't registered
their model cards yet when ``_init_worker_info`` first runs. It is
a no-op for K8s mode once CRDs are present, and drives the
VirtualConnector's discovery-sourced population once cards arrive.
"""
if not hasattr(self.connector, "get_worker_info"):
return
targets: list[tuple[WorkerInfo, SubComponentType]] = []
if self.require_prefill:
targets.append((self.prefill_worker_info, SubComponentType.PREFILL))
if self.require_decode:
targets.append((self.decode_worker_info, SubComponentType.DECODE))
changed = False
for worker_info, sub_type in targets:
if worker_info.max_num_batched_tokens is not None:
continue
try:
fresh = self.connector.get_worker_info(sub_type, self.config.backend)
except Exception as e:
logger.debug(
f"get_worker_info refresh for {sub_type.value} failed: {e}"
)
continue
updated = False
for field_name in self._MDC_REFRESH_FIELDS:
fresh_val = getattr(fresh, field_name)
if (
fresh_val is not None
and getattr(worker_info, field_name) != fresh_val
):
setattr(worker_info, field_name, fresh_val)
updated = True
if updated:
changed = True
logger.info(
f"Refreshed {sub_type.value} WorkerInfo from connector: "
f"{worker_info.summary()}"
)
if changed and self._state_machine is not None:
self._state_machine.update_capabilities(
build_worker_capabilities(
self.config,
self.prefill_worker_info,
self.decode_worker_info,
)
)
# ------------------------------------------------------------------
# Data collection (runtime I/O)
# ------------------------------------------------------------------
......@@ -731,6 +811,8 @@ class NativePlannerBase:
await asyncio.sleep(min(next_tick.at_s - now, poll_interval))
continue
self._refresh_worker_info_from_connector()
tick_input = await self._gather_tick_input(next_tick)
effects = self.state_machine.on_tick(next_tick, tick_input)
await self._apply_effects(effects)
......
......@@ -63,6 +63,7 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
) -> None:
self._config = config
self._capabilities = capabilities or WorkerCapabilities()
self._is_agg = config.mode == "agg"
self._has_prefill = config.mode in ("disagg", "prefill")
self._has_decode = config.mode in ("disagg", "decode", "agg")
......@@ -135,6 +136,10 @@ class PlannerStateMachine(LoadScalingMixin, ThroughputScalingMixin):
# Public API
# ------------------------------------------------------------------
def update_capabilities(self, capabilities: WorkerCapabilities) -> None:
    """Swap in ``capabilities`` as the state machine's worker capabilities.

    The previously stored value is discarded.
    """
    self._capabilities = capabilities
def initial_tick(self, start_s: float) -> ScheduledTick:
self._next_load_s = start_s + self._config.load_adjustment_interval
if self._config.enable_throughput_scaling:
......
......@@ -9,6 +9,8 @@ DecodeRegressionModel, AggRegressionModel) without any planner adapter.
FPM-driven scaling integration tests live in test_state_machine.py.
"""
from unittest.mock import Mock, patch
import pytest
try:
......@@ -24,11 +26,14 @@ try:
)
except ImportError:
pytest.skip("forward_pass_metrics not available", allow_module_level=True)
from dynamo.planner.config.planner_config import PlannerConfig
from dynamo.planner.core.base import NativePlannerBase
from dynamo.planner.core.perf_model import (
AggRegressionModel,
DecodeRegressionModel,
PrefillRegressionModel,
)
from dynamo.planner.monitoring.worker_info import WorkerInfo
pytestmark = [
pytest.mark.gpu_0,
......@@ -650,3 +655,153 @@ class TestAggRegressionModel:
)
# Strictly greater capacity at 80% hit rate (not just >=).
assert rps_high > rps_zero
# ── Connector-driven refresh tests ──────────────────────────────────
class TestRefreshWorkerInfoFromConnector:
    """Tests for NativePlannerBase._refresh_worker_info_from_connector.

    The tick-loop refresh delegates to the connector's ``get_worker_info``,
    which is where each connector implements its own MDC source (K8s CRDs
    for KubernetesConnector, discovery watch for VirtualConnector). These
    tests exercise the shared refresh plumbing with a mock connector.
    """

    def _make_planner(self, require_prefill=False, require_decode=True):
        """Build a minimal NativePlannerBase with no_operation=True.

        The prometheus ``Gauge`` is patched while the planner is built, and
        both WorkerInfo slots start out empty so every refreshable field is
        ``None``.
        """
        with patch("dynamo.planner.monitoring.planner_metrics.Gauge") as mock_gauge:
            mock_gauge.return_value = Mock()
            config = PlannerConfig.model_construct(
                throughput_adjustment_interval=60,
                prefill_engine_num_gpu=1,
                decode_engine_num_gpu=1,
                min_endpoint=1,
                max_gpu_budget=-1,
                ttft=500.0,
                itl=50.0,
                backend="vllm",
                no_operation=True,
                metric_pulling_prometheus_endpoint="http://localhost:9090",
                metric_reporting_prometheus_port=0,
                load_predictor="constant",
                environment="kubernetes",
                namespace="test-namespace",
                mode="agg",
                enable_load_scaling=True,
                enable_throughput_scaling=True,
                load_adjustment_interval=5,
                max_num_fpm_samples=50,
                fpm_sample_bucket_size=16,
                load_scaling_down_sensitivity=80,
                load_metric_samples=10,
                load_min_observations=5,
            )
            planner = NativePlannerBase(None, config)
            planner.require_prefill = require_prefill
            planner.require_decode = require_decode
            planner.prefill_worker_info = WorkerInfo()
            planner.decode_worker_info = WorkerInfo()
            return planner

    def _install_mock_connector(self, planner, **fresh_info_kwargs):
        """Replace planner.connector with a Mock returning a fresh WorkerInfo."""
        fresh = WorkerInfo(**fresh_info_kwargs)
        mock_connector = Mock()
        mock_connector.get_worker_info.return_value = fresh
        planner.connector = mock_connector
        return mock_connector

    def test_refresh_populates_missing_fields(self):
        """Connector returns a populated WorkerInfo; missing fields backfill."""
        planner = self._make_planner()
        assert planner.decode_worker_info.max_num_batched_tokens is None
        self._install_mock_connector(
            planner,
            max_num_batched_tokens=8192,
            total_kv_blocks=1024,
            max_num_seqs=256,
            kv_cache_block_size=16,
            context_length=4096,
        )
        planner._refresh_worker_info_from_connector()
        assert planner.decode_worker_info.max_num_batched_tokens == 8192
        assert planner.decode_worker_info.total_kv_blocks == 1024
        assert planner.decode_worker_info.max_num_seqs == 256
        assert planner.decode_worker_info.kv_cache_block_size == 16
        assert planner.decode_worker_info.context_length == 4096

    def test_noop_when_already_set(self):
        """Does not re-query once max_num_batched_tokens is populated."""
        planner = self._make_planner()
        planner.decode_worker_info = WorkerInfo(max_num_batched_tokens=2048)
        mock_connector = self._install_mock_connector(
            planner, max_num_batched_tokens=8192
        )
        planner._refresh_worker_info_from_connector()
        assert planner.decode_worker_info.max_num_batched_tokens == 2048
        mock_connector.get_worker_info.assert_not_called()

    def test_noop_when_connector_lacks_get_worker_info(self):
        """Silently does nothing if the connector does not implement get_worker_info."""
        planner = self._make_planner()

        class _StubConnector:
            pass

        planner.connector = _StubConnector()
        planner._refresh_worker_info_from_connector()
        assert planner.decode_worker_info.max_num_batched_tokens is None

    def test_noop_when_connector_returns_none_fields(self):
        """Fresh WorkerInfo with None everywhere does not overwrite anything."""
        planner = self._make_planner()
        self._install_mock_connector(planner)  # All Nones
        planner._refresh_worker_info_from_connector()
        assert planner.decode_worker_info.max_num_batched_tokens is None

    def test_exception_does_not_propagate(self):
        """If connector.get_worker_info throws, refresh is a no-op."""
        planner = self._make_planner()
        mock_connector = Mock()
        mock_connector.get_worker_info.side_effect = RuntimeError("boom")
        planner.connector = mock_connector
        planner._refresh_worker_info_from_connector()  # must not raise
        assert planner.decode_worker_info.max_num_batched_tokens is None

    def test_updates_state_machine_capabilities(self):
        """State machine capabilities are updated via update_capabilities()."""
        planner = self._make_planner()
        # Touch the property so the state machine exists before the refresh.
        _ = planner.state_machine
        assert planner._state_machine is not None
        self._install_mock_connector(planner, max_num_batched_tokens=4096)
        planner._refresh_worker_info_from_connector()
        assert planner.decode_worker_info.max_num_batched_tokens == 4096
        assert (
            planner._state_machine._capabilities.decode.max_num_batched_tokens == 4096
        )

    def test_refresh_skips_unneeded_sub_component(self):
        """Only sub-components with require_* True are refreshed."""
        planner = self._make_planner(require_prefill=False, require_decode=True)

        # Record the calls rather than asserting inside the side effect: the
        # refresh wraps get_worker_info in ``except Exception``, which would
        # swallow an AssertionError raised here and let a wrong PREFILL
        # query go unnoticed.
        queried = []

        def _side_effect(sub_type, backend):
            queried.append(sub_type.value)
            return WorkerInfo(max_num_batched_tokens=4096)

        mock_connector = Mock()
        mock_connector.get_worker_info.side_effect = _side_effect
        planner.connector = mock_connector
        planner._refresh_worker_info_from_connector()
        # Should only be called for DECODE.
        assert queried == ["decode"]
        assert planner.prefill_worker_info.max_num_batched_tokens is None
        assert planner.decode_worker_info.max_num_batched_tokens == 4096
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for the shared MDC helpers.
These cover the pure transform shared by KubernetesConnector and
VirtualConnector (see PR #8042 review): parsing ``card_json`` /
``runtime_config`` into ``WorkerInfo``, the prefill/decode heuristic,
and the LoRA-card filter.
"""
import pytest
from dynamo.planner.config.defaults import SubComponentType
from dynamo.planner.connectors.mdc import (
MdcEntry,
is_model_card,
is_prefill_card,
select_entry,
worker_info_from_mdc,
)
pytestmark = [
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.planner,
]
# ── is_model_card ──────────────────────────────────────────────────
class TestIsModelCard:
    """Only wrappers whose ``type`` is ``"Model"`` pass the LoRA-card filter."""

    def test_model_wrapper_accepted(self):
        wrapper = {"type": "Model"}
        assert is_model_card(wrapper)

    def test_lora_wrapper_rejected(self):
        wrapper = {"type": "LoRA"}
        assert not is_model_card(wrapper)

    def test_missing_type_rejected(self):
        wrapper = {}
        assert not is_model_card(wrapper)
# ── is_prefill_card ────────────────────────────────────────────────
class TestIsPrefillCard:
    """Prefill detection across the int, bitflags-dict and string encodings."""

    def test_int_prefill_bit_set(self):
        # ModelType::Prefill = 1 << 4 = 16
        card = {"model_type": 16}
        assert is_prefill_card(card)

    def test_int_prefill_bit_unset(self):
        # Chat | Completions = 1 | 2 = 3
        card = {"model_type": 3}
        assert not is_prefill_card(card)

    def test_int_prefill_combined_with_other_bits(self):
        # Prefill | Chat = 16 | 1 = 17
        card = {"model_type": 17}
        assert is_prefill_card(card)

    def test_bitflags_dict_with_prefill(self):
        card = {"model_type": {"bits": 16}}
        assert is_prefill_card(card)

    def test_bitflags_dict_without_prefill(self):
        card = {"model_type": {"bits": 2}}
        assert not is_prefill_card(card)

    def test_string_prefill(self):
        card = {"model_type": "Prefill"}
        assert is_prefill_card(card)

    def test_string_prefill_combined(self):
        card = {"model_type": "Prefill|Chat"}
        assert is_prefill_card(card)

    def test_string_no_prefill(self):
        card = {"model_type": "Chat|Completions"}
        assert not is_prefill_card(card)

    def test_missing_model_type_defaults_to_decode(self):
        card = {}
        assert not is_prefill_card(card)

    def test_unparseable_model_type_defaults_to_decode(self):
        unrecognized_string = {"model_type": "not-a-thing"}
        null_type = {"model_type": None}
        assert not is_prefill_card(unrecognized_string)
        assert not is_prefill_card(null_type)
# ── worker_info_from_mdc ──────────────────────────────────────────
def _card(**runtime_config_overrides) -> dict:
"""Build a minimal realistic card_json payload."""
return {
"display_name": "meta-llama/Llama-3.1-8B",
"model_type": 2, # Completions (not prefill)
"kv_cache_block_size": 16,
"context_length": 8192,
"runtime_config": {
"total_kv_blocks": 1024,
"max_num_seqs": 256,
"max_num_batched_tokens": 8192,
**runtime_config_overrides,
},
}
class TestWorkerInfoFromMdc:
    """Tests for ``worker_info_from_mdc`` using cards built by ``_card``."""

    def test_happy_path_populates_all_fields(self):
        """A fully populated entry yields a fully populated worker info."""
        mdc = MdcEntry(
            card_json=_card(),
            component="backend",
            endpoint="generate",
        )
        worker = worker_info_from_mdc(mdc, SubComponentType.DECODE, backend="vllm")

        expected = {
            "model_name": "meta-llama/Llama-3.1-8B",
            "component_name": "backend",
            "endpoint": "generate",
            "total_kv_blocks": 1024,
            "max_num_seqs": 256,
            "max_num_batched_tokens": 8192,
            "kv_cache_block_size": 16,
            "context_length": 8192,
        }
        for attr, want in expected.items():
            assert getattr(worker, attr) == want

    def test_missing_wrapper_fields_fall_back_to_defaults(self):
        """Without wrapper component/endpoint, the decode defaults are used."""
        worker = worker_info_from_mdc(
            MdcEntry(card_json=_card()), SubComponentType.DECODE, backend="vllm"
        )
        # From VllmComponentName
        assert worker.component_name == "backend"
        assert worker.endpoint == "generate"
        assert worker.k8s_name == "VllmDecodeWorker"

    def test_prefill_defaults(self):
        """Without wrapper component/endpoint, the prefill defaults are used."""
        worker = worker_info_from_mdc(
            MdcEntry(card_json=_card()), SubComponentType.PREFILL, backend="vllm"
        )
        assert worker.component_name == "prefill"
        assert worker.k8s_name == "VllmPrefillWorker"

    def test_model_name_fallback_invoked_when_card_missing(self):
        """The fallback supplies the model name when the card has none."""
        payload = _card()
        payload.pop("display_name")

        def fallback():
            return "from-dgd-args"

        worker = worker_info_from_mdc(
            MdcEntry(card_json=payload),
            SubComponentType.DECODE,
            backend="vllm",
            model_name_fallback=fallback,
        )
        assert worker.model_name == "from-dgd-args"

    def test_model_name_fallback_not_invoked_when_card_has_it(self):
        """The fallback is never called when the card carries a display name."""
        invocations = []

        def fallback():
            # Record the call so the test can prove it never happened.
            invocations.append(1)
            return "other"

        worker = worker_info_from_mdc(
            MdcEntry(card_json=_card()),
            SubComponentType.DECODE,
            backend="vllm",
            model_name_fallback=fallback,
        )
        assert worker.model_name == "meta-llama/Llama-3.1-8B"
        assert not invocations

    def test_k8s_name_override(self):
        """An explicit ``k8s_name_override`` is used as the k8s name."""
        worker = worker_info_from_mdc(
            MdcEntry(card_json=_card()),
            SubComponentType.DECODE,
            backend="vllm",
            k8s_name_override="custom-decode-svc",
        )
        assert worker.k8s_name == "custom-decode-svc"

    def test_partial_runtime_config_populates_available_fields(self):
        """A missing runtime_config key becomes None; present keys still load."""
        payload = _card()
        payload["runtime_config"].pop("max_num_seqs")
        worker = worker_info_from_mdc(
            MdcEntry(card_json=payload), SubComponentType.DECODE, backend="vllm"
        )
        assert worker.max_num_batched_tokens == 8192
        assert worker.max_num_seqs is None

    def test_missing_runtime_config(self):
        """Without ``runtime_config`` only the top-level card fields populate."""
        payload = _card()
        payload.pop("runtime_config")
        worker = worker_info_from_mdc(
            MdcEntry(card_json=payload), SubComponentType.DECODE, backend="vllm"
        )
        assert worker.max_num_batched_tokens is None
        assert worker.total_kv_blocks is None
        # Non-runtime_config fields still populate
        assert worker.kv_cache_block_size == 16
        assert worker.context_length == 8192

    def test_non_dict_runtime_config_treated_as_missing(self):
        """A ``runtime_config`` that is not a dict yields None runtime fields."""
        payload = _card()
        payload["runtime_config"] = "not-a-dict"
        worker = worker_info_from_mdc(
            MdcEntry(card_json=payload), SubComponentType.DECODE, backend="vllm"
        )
        for attr in ("max_num_batched_tokens", "total_kv_blocks", "max_num_seqs"):
            assert getattr(worker, attr) is None

    def test_empty_card_json(self):
        """An empty card still resolves defaults and leaves card fields None."""
        worker = worker_info_from_mdc(
            MdcEntry(card_json={}), SubComponentType.DECODE, backend="vllm"
        )
        # Component/endpoint/k8s_name from defaults
        assert worker.component_name == "backend"
        # Runtime fields all None
        assert worker.max_num_batched_tokens is None
        assert worker.model_name is None

    def test_fallback_exception_produces_none_model_name(self):
        """A raising fallback leaves ``model_name`` as None instead of failing."""
        payload = _card()
        payload.pop("display_name")

        def failing_fallback():
            raise RuntimeError("boom")

        worker = worker_info_from_mdc(
            MdcEntry(card_json=payload),
            SubComponentType.DECODE,
            backend="vllm",
            model_name_fallback=failing_fallback,
        )
        assert worker.model_name is None
# ── select_entry ───────────────────────────────────────────────────
class TestSelectEntry:
    """Tests for ``select_entry`` choosing an MDC entry per sub-component type."""

    def test_selects_prefill_entry(self):
        """With one decode and one prefill card, PREFILL picks the prefill one."""
        decode = MdcEntry(card_json=dict(_card(), model_type=2), component="backend")
        prefill = MdcEntry(card_json=dict(_card(), model_type=16), component="prefill")
        chosen = select_entry([decode, prefill], SubComponentType.PREFILL)
        assert chosen is not None
        assert chosen.component == "prefill"

    def test_selects_decode_entry(self):
        """With one decode and one prefill card, DECODE picks the decode one."""
        decode = MdcEntry(card_json=dict(_card(), model_type=2), component="backend")
        prefill = MdcEntry(card_json=dict(_card(), model_type=16), component="prefill")
        chosen = select_entry([decode, prefill], SubComponentType.DECODE)
        assert chosen is not None
        assert chosen.component == "backend"

    def test_component_filter_skips_mismatched(self):
        """``expected_component`` skips an earlier entry on another component."""
        # Simulates a LoRA wrapper that slipped past is_model_card but points
        # at a different component — should be skipped when we know the
        # expected component.
        stray = MdcEntry(
            card_json=dict(_card(), model_type=2), component="lora-adapter"
        )
        wanted = MdcEntry(card_json=dict(_card(), model_type=2), component="backend")
        chosen = select_entry(
            [stray, wanted], SubComponentType.DECODE, expected_component="backend"
        )
        assert chosen is not None
        assert chosen.component == "backend"

    def test_returns_none_when_no_match(self):
        """Asking for PREFILL when only a decode card exists returns None."""
        decode_only = [
            MdcEntry(card_json=dict(_card(), model_type=2), component="backend"),
        ]
        chosen = select_entry(decode_only, SubComponentType.PREFILL)
        assert chosen is None
......@@ -23,7 +23,7 @@ use super::*;
use crate::Endpoint;
use crate::to_pyerr;
use dynamo_runtime::component::Component;
use dynamo_runtime::discovery::{DiscoveryEvent, DiscoveryQuery};
use dynamo_runtime::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryQuery};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::transports::event_plane::EventSubscriber;
......@@ -310,6 +310,11 @@ pub(crate) struct FpmEventSubscriber {
// (insert on Added, remove on Removed). Used by get_recent_stats()
// to filter out ghost entries without contending with Task 1's writes.
known_workers: Arc<DashSet<String>>,
// Serialized ModelDeploymentCard per worker_id, captured on discovery
// Added events. Exposed via get_model_cards() so connectors can
// construct WorkerInfo from the same MDC stream the liveness watch
// uses, without the subscriber having to interpret card fields itself.
worker_model_cards: Arc<DashMap<String, String>>,
}
#[pymethods]
......@@ -332,6 +337,7 @@ impl FpmEventSubscriber {
tracking_started: Arc::new(AtomicBool::new(false)),
latest_stats: Arc::new(DashMap::new()),
known_workers: Arc::new(DashSet::new()),
worker_model_cards: Arc::new(DashMap::new()),
})
}
......@@ -511,11 +517,13 @@ impl FpmEventSubscriber {
// normal scale-down path. Any ghost entries created by the race
// condition (FPM arriving *after* the Removed event) are caught by the
// known_workers filter in get_recent_stats().
let cards = self.worker_model_cards.clone();
rt.spawn({
let cancel = cancel.clone();
let component = component.clone();
let stats = stats.clone();
let known = known.clone();
let cards = cards.clone();
async move {
let discovery = component.drt().discovery();
let query = DiscoveryQuery::ComponentModels {
......@@ -545,12 +553,28 @@ impl FpmEventSubscriber {
match event {
Some(Ok(DiscoveryEvent::Added(instance))) => {
let wid = instance.instance_id().to_string();
// Capture the full card JSON so connectors can build WorkerInfo
// from runtime_config / display_name / kv_cache_block_size / etc.
// without the subscriber having to know which fields matter.
if let DiscoveryInstance::Model { ref card_json, .. } = instance {
match serde_json::to_string(card_json) {
Ok(s) => {
cards.insert(wid.clone(), s);
}
Err(e) => {
tracing::warn!(
"FPM tracker: failed to serialize card_json for {wid}: {e}"
);
}
}
}
known.insert(wid.clone());
tracing::debug!("FPM tracker: worker {wid} added to known set");
}
Some(Ok(DiscoveryEvent::Removed(id))) => {
let removed_id = id.instance_id().to_string();
known.remove(&removed_id);
cards.remove(&removed_id);
// Eagerly prune latest_stats for the common case
// (worker removed cleanly before any late FPMs arrive).
......@@ -608,6 +632,32 @@ impl FpmEventSubscriber {
Ok(snapshot)
}
/// Return a point-in-time copy of the stored model deployment cards,
/// keyed by worker id.
///
/// Only workers currently present in `known_workers` are included, so a
/// card belonging to a worker that has already been removed is never
/// returned. Each value is the `ModelDeploymentCard` as a JSON string;
/// callers parse out the fields they care about (e.g. `runtime_config`,
/// `display_name`).
///
/// Raises:
///     RuntimeError if `start_tracking()` has not been called.
///
/// Returns:
///     dict mapping `worker_id: str` to `card_json: str`.
fn get_model_cards(&self) -> PyResult<HashMap<String, String>> {
    let started = self.tracking_started.load(Ordering::SeqCst);
    if !started {
        return Err(PyRuntimeError::new_err(
            "start_tracking() has not been called",
        ));
    }

    let mut cards = HashMap::new();
    for entry in self.worker_model_cards.iter() {
        let worker_id = entry.key();
        // Skip cards whose worker is no longer in the known set.
        if self.known_workers.contains(worker_id) {
            cards.insert(worker_id.clone(), entry.value().clone());
        }
    }
    Ok(cards)
}
/// Shut down the subscriber (all background tasks).
fn shutdown(&self) {
self.cancel.cancel();
......
......@@ -906,6 +906,23 @@ class FpmEventSubscriber:
"""
...
def get_model_cards(self) -> dict[str, str]:
    """
    Return a snapshot of model deployment cards, keyed by worker id.

    Only workers in the known-workers set are included, so cards for
    workers that have already been removed are left out. Each value is
    the raw ``ModelDeploymentCard`` serialized as a JSON string; parse it
    to read the fields you need (e.g. ``runtime_config``,
    ``display_name``).

    Raises:
        RuntimeError: if ``start_tracking()`` has not been called.

    Returns:
        dict mapping ``worker_id`` to ``card_json`` (JSON string).
    """
    ...
def shutdown(self) -> None:
    """Stop the subscriber, including all of its background tasks."""
    ...
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment