Unverified commit 197e0227, authored by daiyaanarfeen, committed by GitHub
Browse files

feat: migrate planner metrics to backend (1/3 vLLM) (#4134)


Signed-off-by: Daiyaan <darfeen@nvidia.com>
Co-authored-by: tmontfort <tmontfort@nvidia.com>
parent 0dd51694
...@@ -24,7 +24,7 @@ from dynamo.planner.utils.perf_interpolation import ( ...@@ -24,7 +24,7 @@ from dynamo.planner.utils.perf_interpolation import (
PrefillInterpolator, PrefillInterpolator,
) )
from dynamo.planner.utils.pre_swept_results_utils import PreSweptResultsHelper from dynamo.planner.utils.pre_swept_results_utils import PreSweptResultsHelper
from dynamo.planner.utils.prometheus import PrometheusAPIClient from dynamo.planner.utils.prometheus import MetricSource, PrometheusAPIClient
from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_mooncake from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_mooncake
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -150,9 +150,20 @@ class Planner: ...@@ -150,9 +150,20 @@ class Planner:
else: else:
raise ValueError(f"Invalid environment: {args.environment}") raise ValueError(f"Invalid environment: {args.environment}")
# Use backend metrics for vLLM (queries vllm:* metrics directly from workers)
# Use frontend metrics for other backends (queries dynamo_frontend_* metrics)
metric_source = (
MetricSource.VLLM
if args.backend.lower() == "vllm"
else MetricSource.FRONTEND
)
logger.info(
f"Initializing Prometheus client with metric_source='{metric_source}' for backend '{args.backend}'"
)
self.prometheus_api_client = PrometheusAPIClient( self.prometheus_api_client = PrometheusAPIClient(
args.metric_pulling_prometheus_endpoint, args.metric_pulling_prometheus_endpoint,
args.namespace, args.namespace,
metric_source=metric_source,
) )
self.num_req_predictor = LOAD_PREDICTORS[args.load_predictor]( self.num_req_predictor = LOAD_PREDICTORS[args.load_predictor](
......
...@@ -15,11 +15,15 @@ ...@@ -15,11 +15,15 @@
import logging import logging
import typing import typing
from enum import Enum
from prometheus_api_client import PrometheusConnect from prometheus_api_client import PrometheusConnect
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from dynamo import prometheus_names from dynamo import prometheus_names
from dynamo.prometheus_names import (
frontend_service as metric_names, # Note that we are mapping from frontend metric names to VLLM
)
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging() configure_dynamo_logging()
...@@ -32,9 +36,11 @@ class FrontendMetric(BaseModel): ...@@ -32,9 +36,11 @@ class FrontendMetric(BaseModel):
endpoint: typing.Optional[str] = None endpoint: typing.Optional[str] = None
instance: typing.Optional[str] = None instance: typing.Optional[str] = None
job: typing.Optional[str] = None job: typing.Optional[str] = None
model: typing.Optional[str] = None model: typing.Optional[str] = None # Frontend uses this label
namespace: typing.Optional[str] = None model_name: typing.Optional[str] = None # Backend (vLLM) uses this label
pod: typing.Optional[str] = None namespace: typing.Optional[str] = None # Kubernetes namespace
pod: typing.Optional[str] = None # Pod name (used for backend filtering)
engine: typing.Optional[str] = None # vLLM engine index
class FrontendMetricContainer(BaseModel): class FrontendMetricContainer(BaseModel):
...@@ -42,10 +48,78 @@ class FrontendMetricContainer(BaseModel): ...@@ -42,10 +48,78 @@ class FrontendMetricContainer(BaseModel):
value: typing.Tuple[float, float] # [timestamp, value] value: typing.Tuple[float, float] # [timestamp, value]
class MetricSource(Enum):
    """Origin of the Prometheus series that PrometheusAPIClient queries.

    FRONTEND selects the frontend-prefixed series and VLLM selects the
    ``vllm:*`` series. SGLANG and TRTLLM are declared as placeholders only:
    they have no entry in METRIC_SOURCE_MAP or METRIC_SOURCE_MODEL_ATTR, so
    a lookup with either of them raises KeyError.
    """

    FRONTEND = "frontend"
    VLLM = "vllm"
    SGLANG = "sglang"  # not supported yet
    TRTLLM = "trtllm"  # not supported yet
# Two-level lookup: metric source -> canonical (frontend_service) metric name
# -> the series name to query for that source.
# Names are sourced from prometheus_names.py.
METRIC_SOURCE_MAP = {
    MetricSource.VLLM: {
        # Latency metrics are histograms on the vLLM side.
        metric_names.TIME_TO_FIRST_TOKEN_SECONDS: "vllm:time_to_first_token_seconds",
        metric_names.INTER_TOKEN_LATENCY_SECONDS: "vllm:inter_token_latency_seconds",
        metric_names.REQUEST_DURATION_SECONDS: "vllm:e2e_request_latency_seconds",
        # Token and request metrics are plain counters on the vLLM side.
        metric_names.INPUT_SEQUENCE_TOKENS: "vllm:prompt_tokens_total",
        metric_names.OUTPUT_SEQUENCE_TOKENS: "vllm:generation_tokens_total",
        metric_names.REQUESTS_TOTAL: "vllm:request_success_total",
    },
    # Every frontend series is the canonical name behind the frontend prefix.
    MetricSource.FRONTEND: {
        canonical: f"{prometheus_names.name_prefix.FRONTEND}_{canonical}"
        for canonical in (
            metric_names.TIME_TO_FIRST_TOKEN_SECONDS,
            metric_names.INTER_TOKEN_LATENCY_SECONDS,
            metric_names.REQUEST_DURATION_SECONDS,
            metric_names.INPUT_SEQUENCE_TOKENS,
            metric_names.OUTPUT_SEQUENCE_TOKENS,
            metric_names.REQUESTS_TOTAL,
        )
    },
}

# Name of the sample label that holds the model name, per metric source.
METRIC_SOURCE_MODEL_ATTR = {
    MetricSource.VLLM: "model_name",
    MetricSource.FRONTEND: "model",
}
class PrometheusAPIClient: class PrometheusAPIClient:
def __init__(self, url: str, dynamo_namespace: str): """
Client for querying Dynamo metrics from Prometheus.
Supports querying both frontend and backend metrics:
- Frontend metrics: {prometheus_names.name_prefix.FRONTEND}_* (from Dynamo HTTP frontend)
- Backend metrics: vllm:* (from vLLM engine workers)
Usage:
# Query frontend metrics (default)
frontend_client = PrometheusAPIClient(url="http://prometheus:9090",
dynamo_namespace="my-deployment")
ttft = frontend_client.get_avg_time_to_first_token("60s", "llama-3-8b")
# Query backend worker metrics
backend_client = PrometheusAPIClient(url="http://prometheus:9090",
dynamo_namespace="my-deployment",
metric_source=MetricSource.VLLM)
ttft = backend_client.get_avg_time_to_first_token("60s", "llama-3-8b")
"""
def __init__(
self,
url: str,
dynamo_namespace: str,
metric_source: MetricSource = MetricSource.FRONTEND,
):
"""
Initialize Prometheus API client.
Args:
url: Prometheus server URL
dynamo_namespace: Dynamo namespace to filter metrics
metric_source: Either MetricSource.FRONTEND or MetricSource.VLLM.
"""
self.prom = PrometheusConnect(url=url, disable_ssl=True) self.prom = PrometheusConnect(url=url, disable_ssl=True)
self.dynamo_namespace = dynamo_namespace self.dynamo_namespace = dynamo_namespace
self.metric_source = metric_source
self.model_attr = METRIC_SOURCE_MODEL_ATTR[self.metric_source]
def _get_average_metric( def _get_average_metric(
self, full_metric_name: str, interval: str, operation_name: str, model_name: str self, full_metric_name: str, interval: str, operation_name: str, model_name: str
...@@ -55,45 +129,127 @@ class PrometheusAPIClient: ...@@ -55,45 +129,127 @@ class PrometheusAPIClient:
increase(metric_sum[interval])/increase(metric_count[interval]) increase(metric_sum[interval])/increase(metric_count[interval])
Args: Args:
full_metric_name: Full metric name (e.g., 'dynamo_frontend_inter_token_latency_seconds') full_metric_name: Full metric name (e.g., metric_names.INTER_TOKEN_LATENCY_SECONDS or metric_names.TIME_TO_FIRST_TOKEN_SECONDS)
interval: Time interval for the query (e.g., '60s') interval: Time interval for the query (e.g., '60s')
operation_name: Human-readable name for error logging operation_name: Human-readable name for error logging
model_name: Model name to filter by
Returns: Returns:
Average metric value or 0 if no data/error Average metric value or 0 if no data/error
""" """
try: try:
# Prepend the frontend metric prefix if not already present full_metric_name = METRIC_SOURCE_MAP[self.metric_source][full_metric_name]
if not full_metric_name.startswith(prometheus_names.name_prefix.FRONTEND):
full_metric_name = ( # Query sum and count separately
f"{prometheus_names.name_prefix.FRONTEND}_{full_metric_name}" sum_query = f"increase({full_metric_name}_sum[{interval}])"
) count_query = f"increase({full_metric_name}_count[{interval}])"
query = f"increase({full_metric_name}_sum[{interval}])/increase({full_metric_name}_count[{interval}])"
result = self.prom.custom_query(query=query) sum_result = self.prom.custom_query(query=sum_query)
if not result: count_result = self.prom.custom_query(query=count_query)
if not sum_result or not count_result:
# No data available yet (no requests made) - return 0 silently # No data available yet (no requests made) - return 0 silently
logger.warning( logger.warning(
f"No prometheus metric data available for {full_metric_name}, use 0 instead" f"No prometheus metric data available for {full_metric_name}, use 0 instead"
) )
return 0 return 0
metrics_containers = parse_frontend_metric_containers(result)
values = [] sum_containers = parse_frontend_metric_containers(sum_result)
for container in metrics_containers: count_containers = parse_frontend_metric_containers(count_result)
# Frontend lowercases model names for Prometheus labels so we need to do case-insensitive comparison
if ( # Sum up values for matching containers
container.metric.model total_sum = 0.0
and container.metric.model.lower() == model_name.lower() total_count = 0.0
and container.metric.dynamo_namespace == self.dynamo_namespace
): for container in sum_containers:
values.append(container.value[1]) model_value = getattr(container.metric, self.model_attr, None)
model_match = model_value and model_value.lower() == model_name.lower()
if not values: namespace_match = (
container.metric.dynamo_namespace == self.dynamo_namespace
)
# Filter by model and namespace
if model_match and namespace_match:
total_sum += container.value[1]
for container in count_containers:
model_value = getattr(container.metric, self.model_attr, None)
model_match = model_value and model_value.lower() == model_name.lower()
namespace_match = (
container.metric.dynamo_namespace == self.dynamo_namespace
)
# Filter by model and namespace
if model_match and namespace_match:
total_count += container.value[1]
if total_count == 0:
logger.warning( logger.warning(
f"No prometheus metric data available for {full_metric_name} with model {model_name} and dynamo namespace {self.dynamo_namespace}, use 0 instead" f"No prometheus metric data available for {full_metric_name} with model {model_name} and dynamo namespace {self.dynamo_namespace}, use 0 instead"
) )
return 0 return 0
return sum(values) / len(values)
return total_sum / total_count
except Exception as e:
logger.error(f"Error getting {operation_name}: {e}")
return 0
def _get_counter_average(
self, counter_metric: str, interval: str, model_name: str, operation_name: str
) -> float:
"""
Get average value from a counter metric by dividing total increase by request count increase.
Used for vLLM token counters (prompt_tokens_total, generation_tokens_total).
Formula: increase(counter_total[interval]) / increase(request_success_total[interval])
"""
try:
full_metric_name = METRIC_SOURCE_MAP[self.metric_source][counter_metric]
requests_metric = METRIC_SOURCE_MAP[self.metric_source][
metric_names.REQUESTS_TOTAL
]
# Query both the counter and request count
counter_query = f"increase({full_metric_name}[{interval}])"
requests_query = f"increase({requests_metric}[{interval}])"
counter_result = self.prom.custom_query(query=counter_query)
requests_result = self.prom.custom_query(query=requests_query)
if not counter_result or not requests_result:
logger.warning(
f"No prometheus metric data available for {full_metric_name}, use 0 instead"
)
return 0
counter_containers = parse_frontend_metric_containers(counter_result)
requests_containers = parse_frontend_metric_containers(requests_result)
# Sum up values for matching pods
total_counter = 0.0
total_requests = 0.0
for container in counter_containers:
model_value = getattr(container.metric, self.model_attr, None)
if model_value and model_value.lower() == model_name.lower():
if container.metric.dynamo_namespace == self.dynamo_namespace:
total_counter += container.value[1]
for container in requests_containers:
model_value = getattr(container.metric, self.model_attr, None)
if model_value and model_value.lower() == model_name.lower():
if container.metric.dynamo_namespace == self.dynamo_namespace:
total_requests += container.value[1]
if total_requests == 0:
logger.warning(
f"No requests for {operation_name} calculation, use 0 instead"
)
return 0
average = total_counter / total_requests
return average
except Exception as e: except Exception as e:
logger.error(f"Error getting {operation_name}: {e}") logger.error(f"Error getting {operation_name}: {e}")
...@@ -101,7 +257,7 @@ class PrometheusAPIClient: ...@@ -101,7 +257,7 @@ class PrometheusAPIClient:
def get_avg_inter_token_latency(self, interval: str, model_name: str): def get_avg_inter_token_latency(self, interval: str, model_name: str):
return self._get_average_metric( return self._get_average_metric(
prometheus_names.frontend_service.INTER_TOKEN_LATENCY_SECONDS, metric_names.INTER_TOKEN_LATENCY_SECONDS,
interval, interval,
"avg inter token latency", "avg inter token latency",
model_name, model_name,
...@@ -109,7 +265,7 @@ class PrometheusAPIClient: ...@@ -109,7 +265,7 @@ class PrometheusAPIClient:
def get_avg_time_to_first_token(self, interval: str, model_name: str): def get_avg_time_to_first_token(self, interval: str, model_name: str):
return self._get_average_metric( return self._get_average_metric(
prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS, metric_names.TIME_TO_FIRST_TOKEN_SECONDS,
interval, interval,
"avg time to first token", "avg time to first token",
model_name, model_name,
...@@ -117,35 +273,38 @@ class PrometheusAPIClient: ...@@ -117,35 +273,38 @@ class PrometheusAPIClient:
def get_avg_request_duration(self, interval: str, model_name: str): def get_avg_request_duration(self, interval: str, model_name: str):
return self._get_average_metric( return self._get_average_metric(
prometheus_names.frontend_service.REQUEST_DURATION_SECONDS, metric_names.REQUEST_DURATION_SECONDS,
interval, interval,
"avg request duration", "avg request duration",
model_name, model_name,
) )
def get_avg_request_count(self, interval: str, model_name: str): def get_avg_request_count(self, interval: str, model_name: str):
# This function follows a different query pattern than the other metrics """
Get request count over the specified interval.
For frontend: queries dynamo_frontend_requests_total
For backend: queries vllm:request_success_total
"""
try: try:
requests_total_metric = prometheus_names.frontend_service.REQUESTS_TOTAL requests_total_metric = METRIC_SOURCE_MAP[self.metric_source][
# Prepend the frontend metric prefix if not already present metric_names.REQUESTS_TOTAL
if not requests_total_metric.startswith( ]
prometheus_names.name_prefix.FRONTEND
):
requests_total_metric = (
f"{prometheus_names.name_prefix.FRONTEND}_{requests_total_metric}"
)
raw_res = self.prom.custom_query( raw_res = self.prom.custom_query(
query=f"increase({requests_total_metric}[{interval}])" query=f"increase({requests_total_metric}[{interval}])"
) )
metrics_containers = parse_frontend_metric_containers(raw_res) metrics_containers = parse_frontend_metric_containers(raw_res)
total_count = 0.0 total_count = 0.0
for container in metrics_containers: for container in metrics_containers:
# Frontend lowercases model names for Prometheus labels so we need to do case-insensitive comparison model_value = getattr(container.metric, self.model_attr, None)
if ( model_match = model_value and model_value.lower() == model_name.lower()
container.metric.model namespace_match = (
and container.metric.model.lower() == model_name.lower() container.metric.dynamo_namespace == self.dynamo_namespace
and container.metric.dynamo_namespace == self.dynamo_namespace )
):
# Filter by model and namespace
if model_match and namespace_match:
total_count += container.value[1] total_count += container.value[1]
return total_count return total_count
except Exception as e: except Exception as e:
...@@ -153,16 +312,32 @@ class PrometheusAPIClient: ...@@ -153,16 +312,32 @@ class PrometheusAPIClient:
return 0 return 0
def get_avg_input_sequence_tokens(self, interval: str, model_name: str): def get_avg_input_sequence_tokens(self, interval: str, model_name: str):
if self.metric_source == MetricSource.VLLM:
# Backend uses prompt_tokens counter (not histogram)
return self._get_counter_average(
metric_names.INPUT_SEQUENCE_TOKENS,
interval,
model_name,
"input_sequence_tokens",
)
return self._get_average_metric( return self._get_average_metric(
prometheus_names.frontend_service.INPUT_SEQUENCE_TOKENS, metric_names.INPUT_SEQUENCE_TOKENS,
interval, interval,
"avg input sequence tokens", "avg input sequence tokens",
model_name, model_name,
) )
def get_avg_output_sequence_tokens(self, interval: str, model_name: str): def get_avg_output_sequence_tokens(self, interval: str, model_name: str):
if self.metric_source == MetricSource.VLLM:
# Backend uses generation_tokens counter (not histogram)
return self._get_counter_average(
metric_names.OUTPUT_SEQUENCE_TOKENS,
interval,
model_name,
"output_sequence_tokens",
)
return self._get_average_metric( return self._get_average_metric(
prometheus_names.frontend_service.OUTPUT_SEQUENCE_TOKENS, metric_names.OUTPUT_SEQUENCE_TOKENS,
interval, interval,
"avg output sequence tokens", "avg output sequence tokens",
model_name, model_name,
......
...@@ -57,6 +57,11 @@ spec: ...@@ -57,6 +57,11 @@ spec:
- interval: 5s - interval: 5s
path: /metrics path: /metrics
port: system port: system
relabelings:
- action: replace
sourceLabels:
- __meta_kubernetes_pod_label_nvidia_com_dynamo_namespace
targetLabel: dynamo_namespace
selector: selector:
matchLabels: matchLabels:
nvidia.com/dynamo-component-type: worker nvidia.com/dynamo-component-type: worker
......
...@@ -18,9 +18,12 @@ from unittest.mock import patch ...@@ -18,9 +18,12 @@ from unittest.mock import patch
import pytest import pytest
from dynamo import prometheus_names
from dynamo.planner.utils.prometheus import ( from dynamo.planner.utils.prometheus import (
METRIC_SOURCE_MAP,
FrontendMetric, FrontendMetric,
FrontendMetricContainer, FrontendMetricContainer,
MetricSource,
PrometheusAPIClient, PrometheusAPIClient,
) )
...@@ -33,8 +36,8 @@ pytestmark = [ ...@@ -33,8 +36,8 @@ pytestmark = [
@pytest.fixture @pytest.fixture
def mock_prometheus_result(): def mock_prometheus_sum_result():
"""Fixture providing mock prometheus result data for testing""" """Fixture providing mock prometheus sum metric data for testing"""
return [ return [
{ {
"metric": { "metric": {
...@@ -75,6 +78,49 @@ def mock_prometheus_result(): ...@@ -75,6 +78,49 @@ def mock_prometheus_result():
] ]
@pytest.fixture
def mock_prometheus_count_result():
    """Return four mock ``_count`` samples, each carrying the value 1.0.

    The first sample is labelled different_namespace/different_model; the
    other three are labelled target_namespace/target_model and differ only
    in their container label (and, for the last one, the timestamp).
    """

    def _sample(container, dynamo_namespace, model, timestamp):
        # One Prometheus instant-vector entry: labels plus [timestamp, value].
        return {
            "metric": {
                "container": container,
                "dynamo_namespace": dynamo_namespace,
                "model": model,
                "namespace": "dynamo-system",
            },
            "value": [timestamp, 1.0],
        }

    rows = (
        ("main", "different_namespace", "different_model", 1758857776.071),
        ("main", "target_namespace", "target_model", 1758857776.071),
        ("worker", "target_namespace", "target_model", 1758857776.071),
        ("sidecar", "target_namespace", "target_model", 30.0),
    )
    return [_sample(*row) for row in rows]
def test_frontend_metric_container_with_nan_value(): def test_frontend_metric_container_with_nan_value():
test_data = { test_data = {
"metric": { "metric": {
...@@ -140,7 +186,7 @@ def test_get_average_metric_none_result(): ...@@ -140,7 +186,7 @@ def test_get_average_metric_none_result():
mock_query.return_value = None mock_query.return_value = None
result = client._get_average_metric( result = client._get_average_metric(
full_metric_name="test_metric", full_metric_name=prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS,
interval="60s", interval="60s",
operation_name="test operation", operation_name="test operation",
model_name="test_model", model_name="test_model",
...@@ -157,7 +203,7 @@ def test_get_average_metric_empty_result(): ...@@ -157,7 +203,7 @@ def test_get_average_metric_empty_result():
mock_query.return_value = [] mock_query.return_value = []
result = client._get_average_metric( result = client._get_average_metric(
full_metric_name="test_metric", full_metric_name=prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS,
interval="60s", interval="60s",
operation_name="test operation", operation_name="test operation",
model_name="test_model", model_name="test_model",
...@@ -166,16 +212,21 @@ def test_get_average_metric_empty_result(): ...@@ -166,16 +212,21 @@ def test_get_average_metric_empty_result():
assert result == 0 assert result == 0
def test_get_average_metric_no_matching_containers(mock_prometheus_result): def test_get_average_metric_no_matching_containers(
mock_prometheus_sum_result, mock_prometheus_count_result
):
"""Test _get_average_metric with valid containers but no matches""" """Test _get_average_metric with valid containers but no matches"""
client = PrometheusAPIClient("http://localhost:9090", "test_namespace") client = PrometheusAPIClient("http://localhost:9090", "test_namespace")
with patch.object(client.prom, "custom_query") as mock_query: with patch.object(client.prom, "custom_query") as mock_query:
# Use only the first container which doesn't match target criteria # Use only the first container which doesn't match target criteria
mock_query.return_value = [mock_prometheus_result[0]] mock_query.side_effect = [
[mock_prometheus_sum_result[0]], # sum_result
[mock_prometheus_count_result[0]], # count_result
]
result = client._get_average_metric( result = client._get_average_metric(
full_metric_name="test_metric", full_metric_name=prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS,
interval="60s", interval="60s",
operation_name="test operation", operation_name="test operation",
model_name="target_model", model_name="target_model",
...@@ -184,21 +235,41 @@ def test_get_average_metric_no_matching_containers(mock_prometheus_result): ...@@ -184,21 +235,41 @@ def test_get_average_metric_no_matching_containers(mock_prometheus_result):
assert result == 0 assert result == 0
def test_get_average_metric_one_matching_container(mock_prometheus_result): def test_get_average_metric_one_matching_container(
mock_prometheus_sum_result, mock_prometheus_count_result
):
"""Test _get_average_metric with one matching container""" """Test _get_average_metric with one matching container"""
client = PrometheusAPIClient("http://localhost:9090", "target_namespace") client = PrometheusAPIClient("http://localhost:9090", "target_namespace")
with patch.object(client.prom, "custom_query") as mock_query: with patch.object(client.prom, "custom_query") as mock_query:
# Use first two containers - one doesn't match, one does # Use first two containers - one doesn't match, one does
mock_query.return_value = mock_prometheus_result[:2] mock_query.side_effect = [
mock_prometheus_sum_result[:2], # sum_result
mock_prometheus_count_result[:2], # count_result
]
result = client._get_average_metric( result = client._get_average_metric(
full_metric_name="test_metric", full_metric_name=prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS,
interval="60s", interval="60s",
operation_name="test operation", operation_name="test operation",
model_name="target_model", model_name="target_model",
) )
# Verify the correct queries were made
assert mock_query.call_count == 2
sum_call = mock_query.call_args_list[0]
assert (
sum_call.kwargs["query"]
== f"increase({METRIC_SOURCE_MAP[MetricSource.FRONTEND][prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS]}_sum[60s])"
)
count_call = mock_query.call_args_list[1]
assert (
count_call.kwargs["query"]
== f"increase({METRIC_SOURCE_MAP[MetricSource.FRONTEND][prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS]}_count[60s])"
)
assert result == 42.7 assert result == 42.7
...@@ -206,7 +277,7 @@ def test_get_average_metric_with_validation_error(): ...@@ -206,7 +277,7 @@ def test_get_average_metric_with_validation_error():
"""Test _get_average_metric with one valid container and one that fails validation""" """Test _get_average_metric with one valid container and one that fails validation"""
client = PrometheusAPIClient("http://localhost:9090", "target_namespace") client = PrometheusAPIClient("http://localhost:9090", "target_namespace")
mock_result = [ mock_sum_result = [
{ {
"metric": { "metric": {
"container": "main", "container": "main",
...@@ -223,11 +294,28 @@ def test_get_average_metric_with_validation_error(): ...@@ -223,11 +294,28 @@ def test_get_average_metric_with_validation_error():
}, },
] ]
mock_count_result = [
{
"metric": {
"container": "main",
"dynamo_namespace": "target_namespace",
"model": "target_model",
"namespace": "dynamo-system",
},
"value": [1758857776.071, 1.0],
},
{
# Invalid structure - missing required fields that will cause validation error
"invalid_structure": "bad_data",
"value": "not_a_tuple",
},
]
with patch.object(client.prom, "custom_query") as mock_query: with patch.object(client.prom, "custom_query") as mock_query:
mock_query.return_value = mock_result mock_query.side_effect = [mock_sum_result, mock_count_result]
result = client._get_average_metric( result = client._get_average_metric(
full_metric_name="test_metric", full_metric_name=prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS,
interval="60s", interval="60s",
operation_name="test operation", operation_name="test operation",
model_name="target_model", model_name="target_model",
...@@ -236,21 +324,26 @@ def test_get_average_metric_with_validation_error(): ...@@ -236,21 +324,26 @@ def test_get_average_metric_with_validation_error():
assert result == 25.5 assert result == 25.5
def test_get_average_metric_multiple_matching_containers(mock_prometheus_result): def test_get_average_metric_multiple_matching_containers(
mock_prometheus_sum_result, mock_prometheus_count_result
):
"""Test _get_average_metric with multiple matching containers returns average""" """Test _get_average_metric with multiple matching containers returns average"""
client = PrometheusAPIClient("http://localhost:9090", "target_namespace") client = PrometheusAPIClient("http://localhost:9090", "target_namespace")
with patch.object(client.prom, "custom_query") as mock_query: with patch.object(client.prom, "custom_query") as mock_query:
# Use containers 1, 2, 3 which all match target criteria # Use containers 1, 2, 3 which all match target criteria
mock_query.return_value = mock_prometheus_result[1:] mock_query.side_effect = [
mock_prometheus_sum_result[1:], # sum_result
mock_prometheus_count_result[1:], # count_result
]
result = client._get_average_metric( result = client._get_average_metric(
full_metric_name="test_metric", full_metric_name=prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS,
interval="60s", interval="60s",
operation_name="test operation", operation_name="test operation",
model_name="target_model", model_name="target_model",
) )
# Average of 42.7, 35.5, and 15.5 (using value[1] from each container) # Total sum: 42.7 + 35.5 + 15.5 = 93.7, Total count: 1.0 + 1.0 + 1.0 = 3.0
expected = (42.7 + 35.5 + 15.5) / 3 expected = (42.7 + 35.5 + 15.5) / 3
assert result == expected assert result == expected
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment