Unverified Commit 880be2b1 authored by Martin Hickey's avatar Martin Hickey Committed by GitHub
Browse files

[Metrics] Some small refactoring for better maintainability (#33898)


Signed-off-by: default avatarMartin Hickey <martin.hickey@ie.ibm.com>
parent c0f5fae6
......@@ -126,28 +126,17 @@ class KVConnectorPromMetrics:
self._labelnames = labelnames
self.per_engine_labelvalues = per_engine_labelvalues
def make_per_engine(self, metric: PromMetric) -> dict[int, PromMetric]:
"""
Create a per-engine child of a prometheus_client.Metric with
the appropriate labels set. The parent metric must be created
using the labelnames list.
"""
return {
idx: metric.labels(*labelvalues)
for idx, labelvalues in self.per_engine_labelvalues.items()
}
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
"""
Record the supplied transfer statistics to Prometheus metrics. These
statistics are engine-specific, and should be recorded to a metric
with the appropriate 'engine' label. These metric instances can be
created using the make_per_engine() helper method.
created using the create_metric_per_engine() helper method.
"""
raise NotImplementedError
class KVConnectorPrometheus:
class KVConnectorProm:
"""
Support for registering per-connector Prometheus metrics, and
recording transfer statistics to those metrics. Uses
......
......@@ -65,6 +65,7 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.metrics.utils import create_metric_per_engine
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.utils import select_common_block_size
......@@ -3057,7 +3058,9 @@ class NixlPromMetrics(KVConnectorPromMetrics):
buckets=buckets[1:],
labelnames=labelnames,
)
self.nixl_histogram_xfer_time = self.make_per_engine(nixl_histogram_xfer_time)
self.nixl_histogram_xfer_time = create_metric_per_engine(
nixl_histogram_xfer_time, self.per_engine_labelvalues
)
nixl_histogram_post_time = self._histogram_cls(
name="vllm:nixl_post_time_seconds",
documentation="Histogram of transfer post time for NIXL KV"
......@@ -3065,7 +3068,9 @@ class NixlPromMetrics(KVConnectorPromMetrics):
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_post_time = self.make_per_engine(nixl_histogram_post_time)
self.nixl_histogram_post_time = create_metric_per_engine(
nixl_histogram_post_time, self.per_engine_labelvalues
)
# uniform 2kb to 16gb range
buckets = [2 ** (10 + i) for i in range(1, 25, 2)]
nixl_histogram_bytes_transferred = self._histogram_cls(
......@@ -3074,8 +3079,8 @@ class NixlPromMetrics(KVConnectorPromMetrics):
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_bytes_transferred = self.make_per_engine(
nixl_histogram_bytes_transferred
self.nixl_histogram_bytes_transferred = create_metric_per_engine(
nixl_histogram_bytes_transferred, self.per_engine_labelvalues
)
buckets = [
10,
......@@ -3100,24 +3105,24 @@ class NixlPromMetrics(KVConnectorPromMetrics):
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_num_descriptors = self.make_per_engine(
nixl_histogram_num_descriptors
self.nixl_histogram_num_descriptors = create_metric_per_engine(
nixl_histogram_num_descriptors, self.per_engine_labelvalues
)
counter_nixl_num_failed_transfers = self._counter_cls(
name="vllm:nixl_num_failed_transfers",
documentation="Number of failed NIXL KV Cache transfers.",
labelnames=labelnames,
)
self.counter_nixl_num_failed_transfers = self.make_per_engine(
counter_nixl_num_failed_transfers
self.counter_nixl_num_failed_transfers = create_metric_per_engine(
counter_nixl_num_failed_transfers, self.per_engine_labelvalues
)
counter_nixl_num_failed_notifications = self._counter_cls(
name="vllm:nixl_num_failed_notifications",
documentation="Number of failed NIXL KV Cache notifications.",
labelnames=labelnames,
)
self.counter_nixl_num_failed_notifications = self.make_per_engine(
counter_nixl_num_failed_notifications
self.counter_nixl_num_failed_notifications = create_metric_per_engine(
counter_nixl_num_failed_notifications, self.per_engine_labelvalues
)
counter_nixl_num_kv_expired_reqs = self._counter_cls(
......@@ -3126,8 +3131,8 @@ class NixlPromMetrics(KVConnectorPromMetrics):
"NOTE: This metric is tracked on the P instance.",
labelnames=labelnames,
)
self.counter_nixl_num_kv_expired_reqs = self.make_per_engine(
counter_nixl_num_kv_expired_reqs
self.counter_nixl_num_kv_expired_reqs = create_metric_per_engine(
counter_nixl_num_kv_expired_reqs, self.per_engine_labelvalues
)
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
......
......@@ -5,7 +5,6 @@ import logging
import time
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TypeAlias
from prometheus_client import Counter, Gauge, Histogram
......@@ -14,7 +13,7 @@ from vllm.compilation.cuda_graph import CUDAGraphLogging
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorLogging,
KVConnectorPrometheus,
KVConnectorProm,
)
from vllm.logger import init_logger
from vllm.plugins import STAT_LOGGER_PLUGINS_GROUP, load_plugins_by_group
......@@ -28,6 +27,7 @@ from vllm.v1.metrics.stats import (
PromptTokenStats,
SchedulerStats,
)
from vllm.v1.metrics.utils import create_metric_per_engine
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__)
......@@ -391,7 +391,7 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
_counter_cls = Counter
_histogram_cls = Histogram
_spec_decoding_cls = SpecDecodingProm
_kv_connector_cls = KVConnectorPrometheus
_kv_connector_cls = KVConnectorProm
_perf_metrics_cls = PerfMetricsProm
def __init__(
......@@ -415,9 +415,10 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
model_name = vllm_config.model_config.served_model_name
max_model_len = vllm_config.model_config.max_model_len
per_engine_labelvalues: dict[int, list[object]] = {
self.per_engine_labelvalues: dict[int, list[object]] = {
idx: [model_name, str(idx)] for idx in engine_indexes
}
per_engine_labelvalues = self.per_engine_labelvalues
self.spec_decoding_prom = self._spec_decoding_cls(
vllm_config.speculative_config, labelnames, per_engine_labelvalues
......@@ -438,8 +439,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
multiprocess_mode="mostrecent",
labelnames=labelnames,
)
self.gauge_scheduler_running = make_per_engine(
gauge_scheduler_running, engine_indexes, model_name
self.gauge_scheduler_running = create_metric_per_engine(
gauge_scheduler_running, per_engine_labelvalues
)
gauge_scheduler_waiting = self._gauge_cls(
......@@ -448,8 +449,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
multiprocess_mode="mostrecent",
labelnames=labelnames,
)
self.gauge_scheduler_waiting = make_per_engine(
gauge_scheduler_waiting, engine_indexes, model_name
self.gauge_scheduler_waiting = create_metric_per_engine(
gauge_scheduler_waiting, per_engine_labelvalues
)
gauge_engine_sleep_state = self._gauge_cls(
......@@ -484,8 +485,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
multiprocess_mode="mostrecent",
labelnames=labelnames,
)
self.gauge_kv_cache_usage = make_per_engine(
gauge_kv_cache_usage, engine_indexes, model_name
self.gauge_kv_cache_usage = create_metric_per_engine(
gauge_kv_cache_usage, per_engine_labelvalues
)
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
......@@ -497,8 +498,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
),
labelnames=labelnames,
)
self.counter_corrupted_requests = make_per_engine(
counter_corrupted_requests, engine_indexes, model_name
self.counter_corrupted_requests = create_metric_per_engine(
counter_corrupted_requests, per_engine_labelvalues
)
counter_prefix_cache_queries = self._counter_cls(
......@@ -508,8 +509,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
),
labelnames=labelnames,
)
self.counter_prefix_cache_queries = make_per_engine(
counter_prefix_cache_queries, engine_indexes, model_name
self.counter_prefix_cache_queries = create_metric_per_engine(
counter_prefix_cache_queries, per_engine_labelvalues
)
counter_prefix_cache_hits = self._counter_cls(
......@@ -517,8 +518,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation=("Prefix cache hits, in terms of number of cached tokens."),
labelnames=labelnames,
)
self.counter_prefix_cache_hits = make_per_engine(
counter_prefix_cache_hits, engine_indexes, model_name
self.counter_prefix_cache_hits = create_metric_per_engine(
counter_prefix_cache_hits, per_engine_labelvalues
)
#
......@@ -533,8 +534,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
),
labelnames=labelnames,
)
self.counter_connector_prefix_cache_queries = make_per_engine(
counter_connector_prefix_cache_queries, engine_indexes, model_name
self.counter_connector_prefix_cache_queries = create_metric_per_engine(
counter_connector_prefix_cache_queries, per_engine_labelvalues
)
counter_connector_prefix_cache_hits = self._counter_cls(
......@@ -545,8 +546,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
),
labelnames=labelnames,
)
self.counter_connector_prefix_cache_hits = make_per_engine(
counter_connector_prefix_cache_hits, engine_indexes, model_name
self.counter_connector_prefix_cache_hits = create_metric_per_engine(
counter_connector_prefix_cache_hits, per_engine_labelvalues
)
#
......@@ -560,8 +561,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
),
labelnames=labelnames,
)
self.counter_mm_cache_queries = make_per_engine(
counter_mm_cache_queries, engine_indexes, model_name
self.counter_mm_cache_queries = create_metric_per_engine(
counter_mm_cache_queries, per_engine_labelvalues
)
counter_mm_cache_hits = self._counter_cls(
......@@ -571,8 +572,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
),
labelnames=labelnames,
)
self.counter_mm_cache_hits = make_per_engine(
counter_mm_cache_hits, engine_indexes, model_name
self.counter_mm_cache_hits = create_metric_per_engine(
counter_mm_cache_hits, per_engine_labelvalues
)
#
......@@ -583,8 +584,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames,
)
self.counter_num_preempted_reqs = make_per_engine(
counter_num_preempted_reqs, engine_indexes, model_name
self.counter_num_preempted_reqs = create_metric_per_engine(
counter_num_preempted_reqs, per_engine_labelvalues
)
counter_prompt_tokens = self._counter_cls(
......@@ -592,8 +593,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Number of prefill tokens processed.",
labelnames=labelnames,
)
self.counter_prompt_tokens = make_per_engine(
counter_prompt_tokens, engine_indexes, model_name
self.counter_prompt_tokens = create_metric_per_engine(
counter_prompt_tokens, per_engine_labelvalues
)
# Labeled prompt token counters by source
......@@ -617,8 +618,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Number of cached prompt tokens (local + external).",
labelnames=labelnames,
)
self.counter_prompt_tokens_cached = make_per_engine(
counter_prompt_tokens_cached, engine_indexes, model_name
self.counter_prompt_tokens_cached = create_metric_per_engine(
counter_prompt_tokens_cached, per_engine_labelvalues
)
# Recomputed tokens (last token recomputed when entire prompt is cached)
......@@ -627,8 +628,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Number of cached tokens recomputed for forward pass.",
labelnames=labelnames,
)
self.counter_prompt_tokens_recomputed = make_per_engine(
counter_prompt_tokens_recomputed, engine_indexes, model_name
self.counter_prompt_tokens_recomputed = create_metric_per_engine(
counter_prompt_tokens_recomputed, per_engine_labelvalues
)
counter_generation_tokens = self._counter_cls(
......@@ -636,8 +637,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Number of generation tokens processed.",
labelnames=labelnames,
)
self.counter_generation_tokens = make_per_engine(
counter_generation_tokens, engine_indexes, model_name
self.counter_generation_tokens = create_metric_per_engine(
counter_generation_tokens, per_engine_labelvalues
)
self.counter_request_success: dict[FinishReason, dict[int, Counter]] = {}
......@@ -663,8 +664,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames,
)
self.histogram_num_prompt_tokens_request = make_per_engine(
histogram_num_prompt_tokens_request, engine_indexes, model_name
self.histogram_num_prompt_tokens_request = create_metric_per_engine(
histogram_num_prompt_tokens_request, per_engine_labelvalues
)
histogram_num_generation_tokens_request = self._histogram_cls(
......@@ -673,8 +674,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames,
)
self.histogram_num_generation_tokens_request = make_per_engine(
histogram_num_generation_tokens_request, engine_indexes, model_name
self.histogram_num_generation_tokens_request = create_metric_per_engine(
histogram_num_generation_tokens_request, per_engine_labelvalues
)
# TODO: This metric might be incorrect in case of using multiple
......@@ -686,8 +687,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
labelnames=labelnames,
)
self.histogram_iteration_tokens = make_per_engine(
histogram_iteration_tokens, engine_indexes, model_name
self.histogram_iteration_tokens = create_metric_per_engine(
histogram_iteration_tokens, per_engine_labelvalues
)
histogram_max_num_generation_tokens_request = self._histogram_cls(
......@@ -696,8 +697,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames,
)
self.histogram_max_num_generation_tokens_request = make_per_engine(
histogram_max_num_generation_tokens_request, engine_indexes, model_name
self.histogram_max_num_generation_tokens_request = create_metric_per_engine(
histogram_max_num_generation_tokens_request, per_engine_labelvalues
)
histogram_n_request = self._histogram_cls(
......@@ -706,8 +707,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=[1, 2, 5, 10, 20],
labelnames=labelnames,
)
self.histogram_n_request = make_per_engine(
histogram_n_request, engine_indexes, model_name
self.histogram_n_request = create_metric_per_engine(
histogram_n_request, per_engine_labelvalues
)
histogram_max_tokens_request = self._histogram_cls(
......@@ -716,8 +717,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames,
)
self.histogram_max_tokens_request = make_per_engine(
histogram_max_tokens_request, engine_indexes, model_name
self.histogram_max_tokens_request = create_metric_per_engine(
histogram_max_tokens_request, per_engine_labelvalues
)
#
......@@ -752,8 +753,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
],
labelnames=labelnames,
)
self.histogram_time_to_first_token = make_per_engine(
histogram_time_to_first_token, engine_indexes, model_name
self.histogram_time_to_first_token = create_metric_per_engine(
histogram_time_to_first_token, per_engine_labelvalues
)
histogram_inter_token_latency = self._histogram_cls(
......@@ -782,8 +783,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
],
labelnames=labelnames,
)
self.histogram_inter_token_latency = make_per_engine(
histogram_inter_token_latency, engine_indexes, model_name
self.histogram_inter_token_latency = create_metric_per_engine(
histogram_inter_token_latency, per_engine_labelvalues
)
histogram_request_time_per_output_token = self._histogram_cls(
......@@ -812,8 +813,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
],
labelnames=labelnames,
)
self.histogram_request_time_per_output_token = make_per_engine(
histogram_request_time_per_output_token, engine_indexes, model_name
self.histogram_request_time_per_output_token = create_metric_per_engine(
histogram_request_time_per_output_token, per_engine_labelvalues
)
request_latency_buckets = [
......@@ -845,8 +846,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets,
labelnames=labelnames,
)
self.histogram_e2e_time_request = make_per_engine(
histogram_e2e_time_request, engine_indexes, model_name
self.histogram_e2e_time_request = create_metric_per_engine(
histogram_e2e_time_request, per_engine_labelvalues
)
histogram_queue_time_request = self._histogram_cls(
......@@ -855,8 +856,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets,
labelnames=labelnames,
)
self.histogram_queue_time_request = make_per_engine(
histogram_queue_time_request, engine_indexes, model_name
self.histogram_queue_time_request = create_metric_per_engine(
histogram_queue_time_request, per_engine_labelvalues
)
histogram_inference_time_request = self._histogram_cls(
......@@ -865,8 +866,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets,
labelnames=labelnames,
)
self.histogram_inference_time_request = make_per_engine(
histogram_inference_time_request, engine_indexes, model_name
self.histogram_inference_time_request = create_metric_per_engine(
histogram_inference_time_request, per_engine_labelvalues
)
histogram_prefill_time_request = self._histogram_cls(
......@@ -875,8 +876,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets,
labelnames=labelnames,
)
self.histogram_prefill_time_request = make_per_engine(
histogram_prefill_time_request, engine_indexes, model_name
self.histogram_prefill_time_request = create_metric_per_engine(
histogram_prefill_time_request, per_engine_labelvalues
)
histogram_decode_time_request = self._histogram_cls(
......@@ -885,8 +886,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets,
labelnames=labelnames,
)
self.histogram_decode_time_request = make_per_engine(
histogram_decode_time_request, engine_indexes, model_name
self.histogram_decode_time_request = create_metric_per_engine(
histogram_decode_time_request, per_engine_labelvalues
)
histogram_prefill_kv_computed_request = self._histogram_cls(
......@@ -898,8 +899,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames,
)
self.histogram_prefill_kv_computed_request = make_per_engine(
histogram_prefill_kv_computed_request, engine_indexes, model_name
self.histogram_prefill_kv_computed_request = create_metric_per_engine(
histogram_prefill_kv_computed_request, per_engine_labelvalues
)
#
......@@ -939,8 +940,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=kv_cache_residency_buckets,
labelnames=labelnames,
)
self.histogram_kv_block_lifetime = make_per_engine(
histogram_kv_block_lifetime, engine_indexes, model_name
self.histogram_kv_block_lifetime = create_metric_per_engine(
histogram_kv_block_lifetime, per_engine_labelvalues
)
histogram_kv_block_idle_before_evict = self._histogram_cls(
......@@ -952,8 +953,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=kv_cache_residency_buckets,
labelnames=labelnames,
)
self.histogram_kv_block_idle_before_evict = make_per_engine(
histogram_kv_block_idle_before_evict, engine_indexes, model_name
self.histogram_kv_block_idle_before_evict = create_metric_per_engine(
histogram_kv_block_idle_before_evict, per_engine_labelvalues
)
histogram_kv_block_reuse_gap = self._histogram_cls(
......@@ -967,8 +968,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=kv_cache_residency_buckets,
labelnames=labelnames,
)
self.histogram_kv_block_reuse_gap = make_per_engine(
histogram_kv_block_reuse_gap, engine_indexes, model_name
self.histogram_kv_block_reuse_gap = create_metric_per_engine(
histogram_kv_block_reuse_gap, per_engine_labelvalues
)
else:
self.histogram_kv_block_lifetime = {}
......@@ -1203,15 +1204,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
self.log_metrics_info("cache_config", self.vllm_config.cache_config)
PromMetric: TypeAlias = Gauge | Counter | Histogram
def make_per_engine(
metric: PromMetric, engine_idxs: list[int], model_name: object
) -> dict[int, PromMetric]:
return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs}
def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
"""
Builds a list of buckets with increasing powers of 10 multiplied by
......
......@@ -27,6 +27,7 @@ from vllm.utils.torch_utils import (
get_kv_cache_torch_dtype,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.metrics.utils import create_metric_per_engine
logger = init_logger(__name__)
......@@ -1291,7 +1292,9 @@ class PerfMetricsProm:
),
labelnames=labelnames,
)
self.counter_flops = make_per_engine(counter_flops, per_engine_labelvalues)
self.counter_flops = create_metric_per_engine(
counter_flops, per_engine_labelvalues
)
counter_read_bytes = self._counter_cls(
name="vllm:estimated_read_bytes_per_gpu_total",
......@@ -1301,7 +1304,7 @@ class PerfMetricsProm:
),
labelnames=labelnames,
)
self.counter_read_bytes = make_per_engine(
self.counter_read_bytes = create_metric_per_engine(
counter_read_bytes, per_engine_labelvalues
)
......@@ -1313,7 +1316,7 @@ class PerfMetricsProm:
),
labelnames=labelnames,
)
self.counter_write_bytes = make_per_engine(
self.counter_write_bytes = create_metric_per_engine(
counter_write_bytes, per_engine_labelvalues
)
......@@ -1329,16 +1332,6 @@ class PerfMetricsProm:
self.counter_write_bytes[engine_idx].inc(perf_stats.num_write_bytes_per_gpu)
def make_per_engine(
counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[object]]
):
"""Create a counter for each label value."""
return {
idx: counter.labels(*labelvalues)
for idx, labelvalues in per_engine_labelvalues.items()
}
## util functions
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorPrometheus
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorProm
from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.metrics.perf import PerfMetricsProm
from vllm.v1.spec_decode.metrics import SpecDecodingProm
......@@ -168,9 +168,9 @@ class RaySpecDecodingProm(SpecDecodingProm):
_counter_cls = RayCounterWrapper
class RayKVConnectorPrometheus(KVConnectorPrometheus):
class RayKVConnectorProm(KVConnectorProm):
"""
RayKVConnectorPrometheus is used by RayMetrics to log Ray
RayKVConnectorProm is used by RayMetrics to log Ray
metrics. Provides the same metrics as KV connectors but
uses Ray's util.metrics library.
"""
......@@ -197,7 +197,7 @@ class RayPrometheusStatLogger(PrometheusStatLogger):
_counter_cls = RayCounterWrapper
_histogram_cls = RayHistogramWrapper
_spec_decoding_cls = RaySpecDecodingProm
_kv_connector_cls = RayKVConnectorPrometheus
_kv_connector_cls = RayKVConnectorProm
_perf_metrics_cls = RayPerfMetricsProm
@staticmethod
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeAlias
from prometheus_client import Counter, Gauge, Histogram
PromMetric: TypeAlias = Gauge | Counter | Histogram
def create_metric_per_engine(
metric: PromMetric,
per_engine_labelvalues: dict[int, list[object]],
) -> dict[int, PromMetric]:
"""Create a labeled metric child for each engine index."""
return {
idx: metric.labels(*labelvalues)
for idx, labelvalues in per_engine_labelvalues.items()
}
......@@ -9,6 +9,7 @@ import prometheus_client
from vllm.config import SpeculativeConfig
from vllm.logger import init_logger
from vllm.v1.metrics.utils import create_metric_per_engine
logger = init_logger(__name__)
......@@ -155,7 +156,7 @@ class SpecDecodingProm:
documentation="Number of spec decoding drafts.",
labelnames=labelnames,
)
self.counter_spec_decode_num_drafts = make_per_engine(
self.counter_spec_decode_num_drafts = create_metric_per_engine(
counter_drafts, per_engine_labelvalues
)
......@@ -164,7 +165,7 @@ class SpecDecodingProm:
documentation="Number of draft tokens.",
labelnames=labelnames,
)
self.counter_spec_decode_num_draft_tokens = make_per_engine(
self.counter_spec_decode_num_draft_tokens = create_metric_per_engine(
counter_draft_tokens, per_engine_labelvalues
)
......@@ -173,7 +174,7 @@ class SpecDecodingProm:
documentation="Number of accepted tokens.",
labelnames=labelnames,
)
self.counter_spec_decode_num_accepted_tokens = make_per_engine(
self.counter_spec_decode_num_accepted_tokens = create_metric_per_engine(
counter_accepted_tokens, per_engine_labelvalues
)
......@@ -212,14 +213,3 @@ class SpecDecodingProm:
self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]
):
counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
def make_per_engine(
counter: prometheus_client.Counter,
per_engine_labelvalues: dict[int, list[object]],
):
"""Create a counter for each label value."""
return {
idx: counter.labels(*labelvalues)
for idx, labelvalues in per_engine_labelvalues.items()
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment