Unverified Commit b12e6710 authored by Thomas Montfort's avatar Thomas Montfort Committed by GitHub
Browse files

feat: read prefill/decode worker counts from DGD status (#5934)

parent ef292944
......@@ -153,6 +153,44 @@ class KubernetesAPI:
return ready_condition is not None and ready_condition.get("status") == "True"
def get_service_replica_status(
self, deployment: dict, service_name: str
) -> tuple[int, bool]:
"""
Get the actual ready replica count for a service from DGD status.
Returns:
tuple[int, bool]: (replica_count, is_stable)
- replica_count: number of replicas serving traffic (availableReplicas if present, else readyReplicas)
- is_stable: no rollout is in progress (desired == updated == ready/available)
"""
# Get desired replicas from spec
service_spec = (
deployment.get("spec", {}).get("services", {}).get(service_name, {})
)
desired_replicas = service_spec.get("replicas", 0)
# Get status fields
service_status = (
deployment.get("status", {}).get("services", {}).get(service_name, {})
)
available = service_status.get("availableReplicas")
ready = service_status.get("readyReplicas", 0)
updated = service_status.get("updatedReplicas", 0)
# availableReplicas takes precedence over readyReplicas for the count
# refer to ServiceReplicaStatus type (https://github.com/ai-dynamo/dynamo/blob/main/deploy/operator/api/v1alpha1/dynamographdeployment_types.go#L157)
if available is not None:
traffic_serving_replicas = available
else:
traffic_serving_replicas = ready
# Stable means: desired == updated == ready/available
# This ensures we're not in a scale-up, scale-down, or rollout
is_stable = desired_replicas == updated == traffic_serving_replicas
return traffic_serving_replicas, is_stable
async def wait_for_graph_deployment_ready(
self,
graph_deployment_name: str,
......
......@@ -284,6 +284,52 @@ class KubernetesConnector(PlannerConnector):
self.graph_deployment_name,
)
def get_actual_worker_counts(
self,
prefill_component_name: Optional[str] = None,
decode_component_name: Optional[str] = None,
) -> tuple[int, int, bool]:
"""
Get actual ready worker counts for prefill and decode from DGD status.
Returns:
tuple[int, int, bool]: (prefill_count, decode_count, is_stable)
- is_stable: False if any service is in a rollout (scaling should be skipped)
"""
deployment = self.kube_api.get_graph_deployment(self.graph_deployment_name)
prefill_count = 0
decode_count = 0
all_stable = True
if prefill_component_name:
service = get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.PREFILL,
component_name=prefill_component_name,
)
ready_replicas, is_stable = self.kube_api.get_service_replica_status(
deployment, service.name
)
if not is_stable:
all_stable = False
prefill_count = ready_replicas
if decode_component_name:
service = get_service_from_sub_component_type_or_name(
deployment,
SubComponentType.DECODE,
component_name=decode_component_name,
)
ready_replicas, is_stable = self.kube_api.get_service_replica_status(
deployment, service.name
)
if not is_stable:
all_stable = False
decode_count = ready_replicas
return prefill_count, decode_count, all_stable
async def set_component_replicas(
self, target_replicas: list[TargetReplica], blocking: bool = True
):
......
......@@ -123,8 +123,8 @@ class PlannerPrometheusMetrics:
@dataclass
class PlannerSharedState:
last_metrics: Metrics = field(default_factory=Metrics)
p_endpoints: list = field(default_factory=list)
d_endpoints: list = field(default_factory=list)
num_p_workers: int = 0
num_d_workers: int = 0
cumulative_gpu_hours: float = 0.0
last_adjustment_time: float = 0.0
......@@ -444,12 +444,41 @@ class BasePlanner:
async def get_workers_info(
self, require_prefill: bool = True, require_decode: bool = True
):
) -> tuple[int, int, bool]:
"""
Get worker counts for prefill and decode components.
Returns:
tuple[int, int, bool]: (num_p_workers, num_d_workers, is_stable)
- is_stable: False if rollout in progress (scaling should be skipped)
"""
num_p_workers = 0
num_d_workers = 0
# For Kubernetes, use DGD status instead of runtime client
if hasattr(self, "connector") and isinstance(
self.connector, KubernetesConnector
):
(
prefill_count,
decode_count,
is_stable,
) = self.connector.get_actual_worker_counts(
prefill_component_name=(
self.prefill_component_name if require_prefill else None
),
decode_component_name=(
self.decode_component_name if require_decode else None
),
)
num_p_workers = prefill_count if require_prefill else 0
num_d_workers = decode_count if require_decode else 0
return num_p_workers, num_d_workers, is_stable
# Fall back to runtime client for non-Kubernetes environments
if self.runtime is None:
raise RuntimeError("Runtime is not initialized")
p_endpoints = []
d_endpoints = []
worker_names = WORKER_COMPONENT_NAMES[self.args.backend]
if require_prefill:
......@@ -459,9 +488,9 @@ class BasePlanner:
worker_names.prefill_worker_component_name,
worker_names.prefill_worker_endpoint,
)
p_endpoints = self.prefill_client.instance_ids() # type: ignore
num_p_workers = len(self.prefill_client.instance_ids()) # type: ignore
except Exception:
p_endpoints = []
num_p_workers = 0
logger.warning(
"No prefill workers found, aggregated mode is not supported yet"
)
......@@ -473,35 +502,39 @@ class BasePlanner:
worker_names.decode_worker_component_name,
worker_names.decode_worker_endpoint,
)
d_endpoints = self.workers_client.instance_ids() # type: ignore
num_d_workers = len(self.workers_client.instance_ids()) # type: ignore
except Exception as e:
raise RuntimeError(f"Failed to get decode worker endpoints: {e}")
return p_endpoints, d_endpoints
return num_p_workers, num_d_workers, True # Always stable for non-K8s
async def observe_metrics(
self, require_prefill: bool = True, require_decode: bool = True
):
p_endpoints, d_endpoints = await self.get_workers_info(
) -> None:
"""
Observe metrics from Prometheus and update shared state.
"""
num_p_workers, num_d_workers, _ = await self.get_workers_info(
require_prefill=require_prefill, require_decode=require_decode
)
self.shared_state.p_endpoints = p_endpoints
self.shared_state.d_endpoints = d_endpoints
self.shared_state.num_p_workers = num_p_workers
self.shared_state.num_d_workers = num_d_workers
logger.debug(
f"Number of prefill workers: {len(p_endpoints)}, number of decode workers: {len(d_endpoints)}"
f"Number of prefill workers: {num_p_workers}, number of decode workers: {num_d_workers}"
)
# Update Prometheus metrics if server is running
if self.prometheus_port != 0 and self.prometheus_metrics is not None:
self.prometheus_metrics.num_p_workers.set(len(p_endpoints))
self.prometheus_metrics.num_d_workers.set(len(d_endpoints))
self.prometheus_metrics.num_p_workers.set(num_p_workers)
self.prometheus_metrics.num_d_workers.set(num_d_workers)
# Calculate and accumulate GPU hours for this interval
# TODO: track startup and shutdown times to get more accurate GPU hours
interval_gpu_hours = (
(
len(p_endpoints) * self.args.prefill_engine_num_gpu
+ len(d_endpoints) * self.args.decode_engine_num_gpu
num_p_workers * self.args.prefill_engine_num_gpu
+ num_d_workers * self.args.decode_engine_num_gpu
)
* self.args.adjustment_interval
/ 3600
......@@ -769,14 +802,14 @@ class DecodePlanner(BasePlanner):
component_type = SubComponentType.DECODE
def _update_correction_factor(self) -> bool:
if not self.shared_state.d_endpoints:
if self.shared_state.num_d_workers == 0:
logger.warning(
"No decode workers found for correction factor, skipping correction update"
)
return True
expect_itl = self.decode_interpolator.interpolate_itl(
concurrency=self.last_metrics.num_req # type: ignore
/ len(self.shared_state.d_endpoints)
/ self.shared_state.num_d_workers
* self.last_metrics.request_duration # type: ignore
/ self.args.adjustment_interval,
context_length=self.last_metrics.isl + self.last_metrics.osl / 2, # type: ignore
......
......@@ -346,3 +346,162 @@ async def test_get_graph_deployment_not_found(k8s_api, mock_custom_api):
exception = exc_info.value
assert exception.deployment_name == "parent-dgd"
assert exception.namespace == "default"
# Tests for get_service_replica_status
def test_get_service_replica_status_stable_with_available_replicas(
k8s_api, mock_custom_api
):
"""Test stable case with availableReplicas present (takes precedence over readyReplicas)"""
deployment: Dict[str, Any] = {
"spec": {"services": {"prefill-worker": {"replicas": 2}}},
"status": {
"services": {
"prefill-worker": {
"availableReplicas": 2,
"readyReplicas": 2,
"updatedReplicas": 2,
}
}
},
}
count, is_stable = k8s_api.get_service_replica_status(deployment, "prefill-worker")
assert count == 2
assert is_stable is True
def test_get_service_replica_status_stable_with_ready_replicas_fallback(
k8s_api, mock_custom_api
):
"""Test stable case falling back to readyReplicas when availableReplicas is not present"""
deployment: Dict[str, Any] = {
"spec": {"services": {"decode-worker": {"replicas": 4}}},
"status": {
"services": {
"decode-worker": {
"readyReplicas": 4,
"updatedReplicas": 4,
}
}
},
}
count, is_stable = k8s_api.get_service_replica_status(deployment, "decode-worker")
assert count == 4
assert is_stable is True
def test_get_service_replica_status_scale_up_in_progress(k8s_api, mock_custom_api):
"""Test scale-up in progress: desired=4, updated=2, ready=2"""
deployment: Dict[str, Any] = {
"spec": {"services": {"prefill-worker": {"replicas": 4}}},
"status": {
"services": {
"prefill-worker": {
"availableReplicas": 2,
"readyReplicas": 2,
"updatedReplicas": 2,
}
}
},
}
count, is_stable = k8s_api.get_service_replica_status(deployment, "prefill-worker")
assert count == 2
assert is_stable is False
def test_get_service_replica_status_scale_down_in_progress(k8s_api, mock_custom_api):
"""Test scale-down in progress: desired=2, updated=4, ready=4"""
deployment: Dict[str, Any] = {
"spec": {"services": {"decode-worker": {"replicas": 2}}},
"status": {
"services": {
"decode-worker": {
"availableReplicas": 4,
"readyReplicas": 4,
"updatedReplicas": 4,
}
}
},
}
count, is_stable = k8s_api.get_service_replica_status(deployment, "decode-worker")
assert count == 4
assert is_stable is False
def test_get_service_replica_status_rollout_in_progress(k8s_api, mock_custom_api):
"""Test rollout in progress: desired=4, updated=2, ready=4 (old replicas still running)"""
deployment: Dict[str, Any] = {
"spec": {"services": {"prefill-worker": {"replicas": 4}}},
"status": {
"services": {
"prefill-worker": {
"availableReplicas": 4,
"readyReplicas": 4,
"updatedReplicas": 2,
}
}
},
}
count, is_stable = k8s_api.get_service_replica_status(deployment, "prefill-worker")
assert count == 4
assert is_stable is False
def test_get_service_replica_status_missing_status_fields(k8s_api, mock_custom_api):
"""Test handling when status fields are missing"""
deployment: Dict[str, Any] = {
"spec": {"services": {"prefill-worker": {"replicas": 2}}},
"status": {"services": {}},
}
count, is_stable = k8s_api.get_service_replica_status(deployment, "prefill-worker")
# Should default to 0 for missing fields
assert count == 0
# desired=2, updated=0, count=0 -> not stable
assert is_stable is False
def test_get_service_replica_status_empty_deployment(k8s_api, mock_custom_api):
"""Test handling when deployment has no spec or status"""
deployment: Dict[str, Any] = {}
count, is_stable = k8s_api.get_service_replica_status(deployment, "prefill-worker")
# All values default to 0, which makes it "stable" (0 == 0 == 0)
assert count == 0
assert is_stable is True
def test_get_service_replica_status_available_replicas_zero(k8s_api, mock_custom_api):
"""Test when availableReplicas is explicitly 0 (should use 0, not fall back to ready)"""
deployment: Dict[str, Any] = {
"spec": {"services": {"prefill-worker": {"replicas": 0}}},
"status": {
"services": {
"prefill-worker": {
"availableReplicas": 0,
"readyReplicas": 2, # Should be ignored
"updatedReplicas": 0,
}
}
},
}
count, is_stable = k8s_api.get_service_replica_status(deployment, "prefill-worker")
# availableReplicas=0 should be used (not readyReplicas)
assert count == 0
assert is_stable is True
......@@ -888,3 +888,147 @@ def test_get_gpu_counts_service_not_found_raises_error(
kubernetes_connector.get_gpu_counts()
assert "decode GPU count" in str(exc_info.value)
# Tests for get_actual_worker_counts
def test_get_actual_worker_counts_stable(kubernetes_connector, mock_kube_api):
"""Test get_actual_worker_counts when both services are stable"""
mock_deployment = {
"metadata": {"name": "test-graph"},
"spec": {
"services": {
"prefill-component": {},
"decode-component": {},
}
},
}
mock_kube_api.get_graph_deployment.return_value = mock_deployment
mock_kube_api.get_service_replica_status.side_effect = [(2, True), (4, True)]
(
prefill_count,
decode_count,
is_stable,
) = kubernetes_connector.get_actual_worker_counts(
prefill_component_name="prefill-component",
decode_component_name="decode-component",
)
assert prefill_count == 2
assert decode_count == 4
assert is_stable is True
def test_get_actual_worker_counts_prefill_rollout_in_progress(
kubernetes_connector, mock_kube_api
):
"""Test get_actual_worker_counts when prefill has rollout in progress"""
mock_deployment = {
"metadata": {"name": "test-graph"},
"spec": {
"services": {
"prefill-component": {},
"decode-component": {},
}
},
}
mock_kube_api.get_graph_deployment.return_value = mock_deployment
mock_kube_api.get_service_replica_status.side_effect = [(2, False), (4, True)]
(
prefill_count,
decode_count,
is_stable,
) = kubernetes_connector.get_actual_worker_counts(
prefill_component_name="prefill-component",
decode_component_name="decode-component",
)
assert prefill_count == 2
assert decode_count == 4
assert is_stable is False
def test_get_actual_worker_counts_prefill_only(kubernetes_connector, mock_kube_api):
"""Test get_actual_worker_counts with only prefill component"""
mock_deployment = {
"metadata": {"name": "test-graph"},
"spec": {
"services": {
"prefill-component": {
"replicas": 2,
"subComponentType": "prefill",
},
}
},
}
mock_kube_api.get_graph_deployment.return_value = mock_deployment
mock_kube_api.get_service_replica_status.return_value = (2, True)
(
prefill_count,
decode_count,
is_stable,
) = kubernetes_connector.get_actual_worker_counts(
prefill_component_name="prefill-component",
decode_component_name=None,
)
assert prefill_count == 2
assert decode_count == 0
assert is_stable is True
def test_get_actual_worker_counts_decode_only(kubernetes_connector, mock_kube_api):
"""Test get_actual_worker_counts with only decode component"""
mock_deployment = {
"metadata": {"name": "test-graph"},
"spec": {
"services": {
"decode-component": {
"replicas": 4,
"subComponentType": "decode",
},
}
},
}
mock_kube_api.get_graph_deployment.return_value = mock_deployment
mock_kube_api.get_service_replica_status.return_value = (4, True)
(
prefill_count,
decode_count,
is_stable,
) = kubernetes_connector.get_actual_worker_counts(
prefill_component_name=None,
decode_component_name="decode-component",
)
assert prefill_count == 0
assert decode_count == 4
assert is_stable is True
def test_get_actual_worker_counts_no_components(kubernetes_connector, mock_kube_api):
"""Test get_actual_worker_counts with no components specified"""
mock_deployment = {
"metadata": {"name": "test-graph"},
"spec": {"services": {}},
"status": {"services": {}},
}
mock_kube_api.get_graph_deployment.return_value = mock_deployment
(
prefill_count,
decode_count,
is_stable,
) = kubernetes_connector.get_actual_worker_counts(
prefill_component_name=None,
decode_component_name=None,
)
assert prefill_count == 0
assert decode_count == 0
assert is_stable is True
......@@ -88,8 +88,9 @@ def _build_planners(args, prometheus_client):
async def mock_get_workers_info(require_prefill=True, require_decode=True):
return (
["prefill-0"] if require_prefill else [],
["decode-0"] if require_decode else [],
1 if require_prefill else 0,
1 if require_decode else 0,
True, # is_stable
)
prefill_planner.get_workers_info = mock_get_workers_info
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment