Unverified commit 22274b21, authored by Seiji Eicher, committed by GitHub
Browse files

[Misc] Add ReplicaId to Ray metrics (#24267)


Signed-off-by: Seiji Eicher <seiji@anyscale.com>
Co-authored-by: rongfu.leng <1275177125@qq.com>
parent fc95521b
...@@ -7,37 +7,55 @@ from vllm.v1.metrics.loggers import PrometheusStatLogger ...@@ -7,37 +7,55 @@ from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.spec_decode.metrics import SpecDecodingProm from vllm.v1.spec_decode.metrics import SpecDecodingProm
try: try:
from ray import serve as ray_serve
from ray.util import metrics as ray_metrics from ray.util import metrics as ray_metrics
from ray.util.metrics import Metric from ray.util.metrics import Metric
except ImportError: except ImportError:
ray_metrics = None ray_metrics = None
ray_serve = None
import regex as re import regex as re
def _get_replica_id() -> str | None:
    """Return the unique ID of the current Ray Serve replica, if there is one.

    Returns:
        The ``unique_id`` of the replica context's ``replica_id``, or ``None``
        when Ray Serve could not be imported or when Ray Serve raises
        ``RayServeException`` (i.e. the caller is not in a Serve context).
    """
    # ray_serve is set to None by the guarded import when Ray is unavailable.
    if ray_serve is None:
        return None

    try:
        replica_context = ray_serve.get_replica_context()
        return replica_context.replica_id.unique_id
    except ray_serve.exceptions.RayServeException:
        # Not running inside a Serve replica; there is no ID to report.
        return None
class RayPrometheusMetric: class RayPrometheusMetric:
def __init__(self): def __init__(self):
if ray_metrics is None: if ray_metrics is None:
raise ImportError("RayPrometheusMetric requires Ray to be installed.") raise ImportError("RayPrometheusMetric requires Ray to be installed.")
self.metric: Metric = None self.metric: Metric = None
def labels(self, *labels, **labelskwargs): @staticmethod
if labelskwargs: def _get_tag_keys(labelnames: list[str] | None) -> tuple[str, ...]:
for k, v in labelskwargs.items(): labels = list(labelnames) if labelnames else []
if not isinstance(v, str): labels.append("ReplicaId")
labelskwargs[k] = str(v) return tuple(labels)
self.metric.set_default_tags(labelskwargs)
def labels(self, *labels, **labelskwargs):
if labels: if labels:
if len(labels) != len(self.metric._tag_keys): # -1 because ReplicaId was added automatically
expected = len(self.metric._tag_keys) - 1
if len(labels) != expected:
raise ValueError( raise ValueError(
"Number of labels must match the number of tag keys. " "Number of labels must match the number of tag keys. "
f"Expected {len(self.metric._tag_keys)}, got {len(labels)}" f"Expected {expected}, got {len(labels)}"
) )
labelskwargs.update(zip(self.metric._tag_keys, labels))
self.metric.set_default_tags(dict(zip(self.metric._tag_keys, labels))) labelskwargs["ReplicaId"] = _get_replica_id() or ""
if labelskwargs:
for k, v in labelskwargs.items():
if not isinstance(v, str):
labelskwargs[k] = str(v)
self.metric.set_default_tags(labelskwargs)
return self return self
@staticmethod @staticmethod
...@@ -71,10 +89,14 @@ class RayGaugeWrapper(RayPrometheusMetric): ...@@ -71,10 +89,14 @@ class RayGaugeWrapper(RayPrometheusMetric):
# "mostrecent", "all", "sum" do not apply. This logic can be manually # "mostrecent", "all", "sum" do not apply. This logic can be manually
# implemented at the observability layer (Prometheus/Grafana). # implemented at the observability layer (Prometheus/Grafana).
del multiprocess_mode del multiprocess_mode
labelnames_tuple = tuple(labelnames) if labelnames else None
tag_keys = self._get_tag_keys(labelnames)
name = self._get_sanitized_opentelemetry_name(name) name = self._get_sanitized_opentelemetry_name(name)
self.metric = ray_metrics.Gauge( self.metric = ray_metrics.Gauge(
name=name, description=documentation, tag_keys=labelnames_tuple name=name,
description=documentation,
tag_keys=tag_keys,
) )
def set(self, value: int | float): def set(self, value: int | float):
...@@ -95,10 +117,12 @@ class RayCounterWrapper(RayPrometheusMetric): ...@@ -95,10 +117,12 @@ class RayCounterWrapper(RayPrometheusMetric):
documentation: str | None = "", documentation: str | None = "",
labelnames: list[str] | None = None, labelnames: list[str] | None = None,
): ):
labelnames_tuple = tuple(labelnames) if labelnames else None tag_keys = self._get_tag_keys(labelnames)
name = self._get_sanitized_opentelemetry_name(name) name = self._get_sanitized_opentelemetry_name(name)
self.metric = ray_metrics.Counter( self.metric = ray_metrics.Counter(
name=name, description=documentation, tag_keys=labelnames_tuple name=name,
description=documentation,
tag_keys=tag_keys,
) )
def inc(self, value: int | float = 1.0): def inc(self, value: int | float = 1.0):
...@@ -118,13 +142,14 @@ class RayHistogramWrapper(RayPrometheusMetric): ...@@ -118,13 +142,14 @@ class RayHistogramWrapper(RayPrometheusMetric):
labelnames: list[str] | None = None, labelnames: list[str] | None = None,
buckets: list[float] | None = None, buckets: list[float] | None = None,
): ):
labelnames_tuple = tuple(labelnames) if labelnames else None tag_keys = self._get_tag_keys(labelnames)
name = self._get_sanitized_opentelemetry_name(name) name = self._get_sanitized_opentelemetry_name(name)
boundaries = buckets if buckets else [] boundaries = buckets if buckets else []
self.metric = ray_metrics.Histogram( self.metric = ray_metrics.Histogram(
name=name, name=name,
description=documentation, description=documentation,
tag_keys=labelnames_tuple, tag_keys=tag_keys,
boundaries=boundaries, boundaries=boundaries,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment