Unverified commit 880be2b1, authored by Martin Hickey, committed by GitHub
Browse files

[Metrics] Some small refactoring for better maintainability (#33898)


Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
parent c0f5fae6
...@@ -126,28 +126,17 @@ class KVConnectorPromMetrics: ...@@ -126,28 +126,17 @@ class KVConnectorPromMetrics:
self._labelnames = labelnames self._labelnames = labelnames
self.per_engine_labelvalues = per_engine_labelvalues self.per_engine_labelvalues = per_engine_labelvalues
def make_per_engine(self, metric: PromMetric) -> dict[int, PromMetric]:
    """
    Create a per-engine child of a prometheus_client.Metric with
    the appropriate labels set. The parent metric must be created
    using the labelnames list.

    Args:
        metric: Parent metric whose ``labels()`` method is called once
            per engine.

    Returns:
        A dict mapping each engine index found in
        ``self.per_engine_labelvalues`` to the child returned by
        ``metric.labels(*labelvalues)`` for that engine.
    """
    # Label values are unpacked positionally into labels().
    return {
        idx: metric.labels(*labelvalues)
        for idx, labelvalues in self.per_engine_labelvalues.items()
    }
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
""" """
Record the supplied transfer statistics to Prometheus metrics. These Record the supplied transfer statistics to Prometheus metrics. These
statistics are engine-specific, and should be recorded to a metric statistics are engine-specific, and should be recorded to a metric
with the appropriate 'engine' label. These metric instances can be with the appropriate 'engine' label. These metric instances can be
created using the make_per_engine() helper method. created using the create_metric_per_engine() helper method.
""" """
raise NotImplementedError raise NotImplementedError
class KVConnectorPrometheus: class KVConnectorProm:
""" """
Support for registering per-connector Prometheus metrics, and Support for registering per-connector Prometheus metrics, and
recording transfer statistics to those metrics. Uses recording transfer statistics to those metrics. Uses
......
...@@ -65,6 +65,7 @@ from vllm.v1.kv_cache_interface import ( ...@@ -65,6 +65,7 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec, SlidingWindowSpec,
UniformTypeKVCacheSpecs, UniformTypeKVCacheSpecs,
) )
from vllm.v1.metrics.utils import create_metric_per_engine
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.utils import select_common_block_size from vllm.v1.worker.utils import select_common_block_size
...@@ -3057,7 +3058,9 @@ class NixlPromMetrics(KVConnectorPromMetrics): ...@@ -3057,7 +3058,9 @@ class NixlPromMetrics(KVConnectorPromMetrics):
buckets=buckets[1:], buckets=buckets[1:],
labelnames=labelnames, labelnames=labelnames,
) )
self.nixl_histogram_xfer_time = self.make_per_engine(nixl_histogram_xfer_time) self.nixl_histogram_xfer_time = create_metric_per_engine(
nixl_histogram_xfer_time, self.per_engine_labelvalues
)
nixl_histogram_post_time = self._histogram_cls( nixl_histogram_post_time = self._histogram_cls(
name="vllm:nixl_post_time_seconds", name="vllm:nixl_post_time_seconds",
documentation="Histogram of transfer post time for NIXL KV" documentation="Histogram of transfer post time for NIXL KV"
...@@ -3065,7 +3068,9 @@ class NixlPromMetrics(KVConnectorPromMetrics): ...@@ -3065,7 +3068,9 @@ class NixlPromMetrics(KVConnectorPromMetrics):
buckets=buckets, buckets=buckets,
labelnames=labelnames, labelnames=labelnames,
) )
self.nixl_histogram_post_time = self.make_per_engine(nixl_histogram_post_time) self.nixl_histogram_post_time = create_metric_per_engine(
nixl_histogram_post_time, self.per_engine_labelvalues
)
# uniform 2kb to 16gb range # uniform 2kb to 16gb range
buckets = [2 ** (10 + i) for i in range(1, 25, 2)] buckets = [2 ** (10 + i) for i in range(1, 25, 2)]
nixl_histogram_bytes_transferred = self._histogram_cls( nixl_histogram_bytes_transferred = self._histogram_cls(
...@@ -3074,8 +3079,8 @@ class NixlPromMetrics(KVConnectorPromMetrics): ...@@ -3074,8 +3079,8 @@ class NixlPromMetrics(KVConnectorPromMetrics):
buckets=buckets, buckets=buckets,
labelnames=labelnames, labelnames=labelnames,
) )
self.nixl_histogram_bytes_transferred = self.make_per_engine( self.nixl_histogram_bytes_transferred = create_metric_per_engine(
nixl_histogram_bytes_transferred nixl_histogram_bytes_transferred, self.per_engine_labelvalues
) )
buckets = [ buckets = [
10, 10,
...@@ -3100,24 +3105,24 @@ class NixlPromMetrics(KVConnectorPromMetrics): ...@@ -3100,24 +3105,24 @@ class NixlPromMetrics(KVConnectorPromMetrics):
buckets=buckets, buckets=buckets,
labelnames=labelnames, labelnames=labelnames,
) )
self.nixl_histogram_num_descriptors = self.make_per_engine( self.nixl_histogram_num_descriptors = create_metric_per_engine(
nixl_histogram_num_descriptors nixl_histogram_num_descriptors, self.per_engine_labelvalues
) )
counter_nixl_num_failed_transfers = self._counter_cls( counter_nixl_num_failed_transfers = self._counter_cls(
name="vllm:nixl_num_failed_transfers", name="vllm:nixl_num_failed_transfers",
documentation="Number of failed NIXL KV Cache transfers.", documentation="Number of failed NIXL KV Cache transfers.",
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_nixl_num_failed_transfers = self.make_per_engine( self.counter_nixl_num_failed_transfers = create_metric_per_engine(
counter_nixl_num_failed_transfers counter_nixl_num_failed_transfers, self.per_engine_labelvalues
) )
counter_nixl_num_failed_notifications = self._counter_cls( counter_nixl_num_failed_notifications = self._counter_cls(
name="vllm:nixl_num_failed_notifications", name="vllm:nixl_num_failed_notifications",
documentation="Number of failed NIXL KV Cache notifications.", documentation="Number of failed NIXL KV Cache notifications.",
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_nixl_num_failed_notifications = self.make_per_engine( self.counter_nixl_num_failed_notifications = create_metric_per_engine(
counter_nixl_num_failed_notifications counter_nixl_num_failed_notifications, self.per_engine_labelvalues
) )
counter_nixl_num_kv_expired_reqs = self._counter_cls( counter_nixl_num_kv_expired_reqs = self._counter_cls(
...@@ -3126,8 +3131,8 @@ class NixlPromMetrics(KVConnectorPromMetrics): ...@@ -3126,8 +3131,8 @@ class NixlPromMetrics(KVConnectorPromMetrics):
"NOTE: This metric is tracked on the P instance.", "NOTE: This metric is tracked on the P instance.",
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_nixl_num_kv_expired_reqs = self.make_per_engine( self.counter_nixl_num_kv_expired_reqs = create_metric_per_engine(
counter_nixl_num_kv_expired_reqs counter_nixl_num_kv_expired_reqs, self.per_engine_labelvalues
) )
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
......
...@@ -5,7 +5,6 @@ import logging ...@@ -5,7 +5,6 @@ import logging
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from typing import TypeAlias
from prometheus_client import Counter, Gauge, Histogram from prometheus_client import Counter, Gauge, Histogram
...@@ -14,7 +13,7 @@ from vllm.compilation.cuda_graph import CUDAGraphLogging ...@@ -14,7 +13,7 @@ from vllm.compilation.cuda_graph import CUDAGraphLogging
from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorLogging, KVConnectorLogging,
KVConnectorPrometheus, KVConnectorProm,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.plugins import STAT_LOGGER_PLUGINS_GROUP, load_plugins_by_group from vllm.plugins import STAT_LOGGER_PLUGINS_GROUP, load_plugins_by_group
...@@ -28,6 +27,7 @@ from vllm.v1.metrics.stats import ( ...@@ -28,6 +27,7 @@ from vllm.v1.metrics.stats import (
PromptTokenStats, PromptTokenStats,
SchedulerStats, SchedulerStats,
) )
from vllm.v1.metrics.utils import create_metric_per_engine
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -391,7 +391,7 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -391,7 +391,7 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
_counter_cls = Counter _counter_cls = Counter
_histogram_cls = Histogram _histogram_cls = Histogram
_spec_decoding_cls = SpecDecodingProm _spec_decoding_cls = SpecDecodingProm
_kv_connector_cls = KVConnectorPrometheus _kv_connector_cls = KVConnectorProm
_perf_metrics_cls = PerfMetricsProm _perf_metrics_cls = PerfMetricsProm
def __init__( def __init__(
...@@ -415,9 +415,10 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -415,9 +415,10 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
model_name = vllm_config.model_config.served_model_name model_name = vllm_config.model_config.served_model_name
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
per_engine_labelvalues: dict[int, list[object]] = { self.per_engine_labelvalues: dict[int, list[object]] = {
idx: [model_name, str(idx)] for idx in engine_indexes idx: [model_name, str(idx)] for idx in engine_indexes
} }
per_engine_labelvalues = self.per_engine_labelvalues
self.spec_decoding_prom = self._spec_decoding_cls( self.spec_decoding_prom = self._spec_decoding_cls(
vllm_config.speculative_config, labelnames, per_engine_labelvalues vllm_config.speculative_config, labelnames, per_engine_labelvalues
...@@ -438,8 +439,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -438,8 +439,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
labelnames=labelnames, labelnames=labelnames,
) )
self.gauge_scheduler_running = make_per_engine( self.gauge_scheduler_running = create_metric_per_engine(
gauge_scheduler_running, engine_indexes, model_name gauge_scheduler_running, per_engine_labelvalues
) )
gauge_scheduler_waiting = self._gauge_cls( gauge_scheduler_waiting = self._gauge_cls(
...@@ -448,8 +449,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -448,8 +449,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
labelnames=labelnames, labelnames=labelnames,
) )
self.gauge_scheduler_waiting = make_per_engine( self.gauge_scheduler_waiting = create_metric_per_engine(
gauge_scheduler_waiting, engine_indexes, model_name gauge_scheduler_waiting, per_engine_labelvalues
) )
gauge_engine_sleep_state = self._gauge_cls( gauge_engine_sleep_state = self._gauge_cls(
...@@ -484,8 +485,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -484,8 +485,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
multiprocess_mode="mostrecent", multiprocess_mode="mostrecent",
labelnames=labelnames, labelnames=labelnames,
) )
self.gauge_kv_cache_usage = make_per_engine( self.gauge_kv_cache_usage = create_metric_per_engine(
gauge_kv_cache_usage, engine_indexes, model_name gauge_kv_cache_usage, per_engine_labelvalues
) )
if envs.VLLM_COMPUTE_NANS_IN_LOGITS: if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
...@@ -497,8 +498,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -497,8 +498,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
), ),
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_corrupted_requests = make_per_engine( self.counter_corrupted_requests = create_metric_per_engine(
counter_corrupted_requests, engine_indexes, model_name counter_corrupted_requests, per_engine_labelvalues
) )
counter_prefix_cache_queries = self._counter_cls( counter_prefix_cache_queries = self._counter_cls(
...@@ -508,8 +509,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -508,8 +509,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
), ),
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_prefix_cache_queries = make_per_engine( self.counter_prefix_cache_queries = create_metric_per_engine(
counter_prefix_cache_queries, engine_indexes, model_name counter_prefix_cache_queries, per_engine_labelvalues
) )
counter_prefix_cache_hits = self._counter_cls( counter_prefix_cache_hits = self._counter_cls(
...@@ -517,8 +518,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -517,8 +518,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation=("Prefix cache hits, in terms of number of cached tokens."), documentation=("Prefix cache hits, in terms of number of cached tokens."),
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_prefix_cache_hits = make_per_engine( self.counter_prefix_cache_hits = create_metric_per_engine(
counter_prefix_cache_hits, engine_indexes, model_name counter_prefix_cache_hits, per_engine_labelvalues
) )
# #
...@@ -533,8 +534,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -533,8 +534,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
), ),
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_connector_prefix_cache_queries = make_per_engine( self.counter_connector_prefix_cache_queries = create_metric_per_engine(
counter_connector_prefix_cache_queries, engine_indexes, model_name counter_connector_prefix_cache_queries, per_engine_labelvalues
) )
counter_connector_prefix_cache_hits = self._counter_cls( counter_connector_prefix_cache_hits = self._counter_cls(
...@@ -545,8 +546,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -545,8 +546,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
), ),
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_connector_prefix_cache_hits = make_per_engine( self.counter_connector_prefix_cache_hits = create_metric_per_engine(
counter_connector_prefix_cache_hits, engine_indexes, model_name counter_connector_prefix_cache_hits, per_engine_labelvalues
) )
# #
...@@ -560,8 +561,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -560,8 +561,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
), ),
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_mm_cache_queries = make_per_engine( self.counter_mm_cache_queries = create_metric_per_engine(
counter_mm_cache_queries, engine_indexes, model_name counter_mm_cache_queries, per_engine_labelvalues
) )
counter_mm_cache_hits = self._counter_cls( counter_mm_cache_hits = self._counter_cls(
...@@ -571,8 +572,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -571,8 +572,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
), ),
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_mm_cache_hits = make_per_engine( self.counter_mm_cache_hits = create_metric_per_engine(
counter_mm_cache_hits, engine_indexes, model_name counter_mm_cache_hits, per_engine_labelvalues
) )
# #
...@@ -583,8 +584,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -583,8 +584,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Cumulative number of preemption from the engine.", documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_num_preempted_reqs = make_per_engine( self.counter_num_preempted_reqs = create_metric_per_engine(
counter_num_preempted_reqs, engine_indexes, model_name counter_num_preempted_reqs, per_engine_labelvalues
) )
counter_prompt_tokens = self._counter_cls( counter_prompt_tokens = self._counter_cls(
...@@ -592,8 +593,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -592,8 +593,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Number of prefill tokens processed.", documentation="Number of prefill tokens processed.",
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_prompt_tokens = make_per_engine( self.counter_prompt_tokens = create_metric_per_engine(
counter_prompt_tokens, engine_indexes, model_name counter_prompt_tokens, per_engine_labelvalues
) )
# Labeled prompt token counters by source # Labeled prompt token counters by source
...@@ -617,8 +618,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -617,8 +618,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Number of cached prompt tokens (local + external).", documentation="Number of cached prompt tokens (local + external).",
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_prompt_tokens_cached = make_per_engine( self.counter_prompt_tokens_cached = create_metric_per_engine(
counter_prompt_tokens_cached, engine_indexes, model_name counter_prompt_tokens_cached, per_engine_labelvalues
) )
# Recomputed tokens (last token recomputed when entire prompt is cached) # Recomputed tokens (last token recomputed when entire prompt is cached)
...@@ -627,8 +628,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -627,8 +628,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Number of cached tokens recomputed for forward pass.", documentation="Number of cached tokens recomputed for forward pass.",
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_prompt_tokens_recomputed = make_per_engine( self.counter_prompt_tokens_recomputed = create_metric_per_engine(
counter_prompt_tokens_recomputed, engine_indexes, model_name counter_prompt_tokens_recomputed, per_engine_labelvalues
) )
counter_generation_tokens = self._counter_cls( counter_generation_tokens = self._counter_cls(
...@@ -636,8 +637,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -636,8 +637,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Number of generation tokens processed.", documentation="Number of generation tokens processed.",
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_generation_tokens = make_per_engine( self.counter_generation_tokens = create_metric_per_engine(
counter_generation_tokens, engine_indexes, model_name counter_generation_tokens, per_engine_labelvalues
) )
self.counter_request_success: dict[FinishReason, dict[int, Counter]] = {} self.counter_request_success: dict[FinishReason, dict[int, Counter]] = {}
...@@ -663,8 +664,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -663,8 +664,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len), buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_num_prompt_tokens_request = make_per_engine( self.histogram_num_prompt_tokens_request = create_metric_per_engine(
histogram_num_prompt_tokens_request, engine_indexes, model_name histogram_num_prompt_tokens_request, per_engine_labelvalues
) )
histogram_num_generation_tokens_request = self._histogram_cls( histogram_num_generation_tokens_request = self._histogram_cls(
...@@ -673,8 +674,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -673,8 +674,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len), buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_num_generation_tokens_request = make_per_engine( self.histogram_num_generation_tokens_request = create_metric_per_engine(
histogram_num_generation_tokens_request, engine_indexes, model_name histogram_num_generation_tokens_request, per_engine_labelvalues
) )
# TODO: This metric might be incorrect in case of using multiple # TODO: This metric might be incorrect in case of using multiple
...@@ -686,8 +687,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -686,8 +687,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_iteration_tokens = make_per_engine( self.histogram_iteration_tokens = create_metric_per_engine(
histogram_iteration_tokens, engine_indexes, model_name histogram_iteration_tokens, per_engine_labelvalues
) )
histogram_max_num_generation_tokens_request = self._histogram_cls( histogram_max_num_generation_tokens_request = self._histogram_cls(
...@@ -696,8 +697,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -696,8 +697,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len), buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_max_num_generation_tokens_request = make_per_engine( self.histogram_max_num_generation_tokens_request = create_metric_per_engine(
histogram_max_num_generation_tokens_request, engine_indexes, model_name histogram_max_num_generation_tokens_request, per_engine_labelvalues
) )
histogram_n_request = self._histogram_cls( histogram_n_request = self._histogram_cls(
...@@ -706,8 +707,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -706,8 +707,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=[1, 2, 5, 10, 20], buckets=[1, 2, 5, 10, 20],
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_n_request = make_per_engine( self.histogram_n_request = create_metric_per_engine(
histogram_n_request, engine_indexes, model_name histogram_n_request, per_engine_labelvalues
) )
histogram_max_tokens_request = self._histogram_cls( histogram_max_tokens_request = self._histogram_cls(
...@@ -716,8 +717,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -716,8 +717,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len), buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_max_tokens_request = make_per_engine( self.histogram_max_tokens_request = create_metric_per_engine(
histogram_max_tokens_request, engine_indexes, model_name histogram_max_tokens_request, per_engine_labelvalues
) )
# #
...@@ -752,8 +753,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -752,8 +753,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
], ],
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_time_to_first_token = make_per_engine( self.histogram_time_to_first_token = create_metric_per_engine(
histogram_time_to_first_token, engine_indexes, model_name histogram_time_to_first_token, per_engine_labelvalues
) )
histogram_inter_token_latency = self._histogram_cls( histogram_inter_token_latency = self._histogram_cls(
...@@ -782,8 +783,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -782,8 +783,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
], ],
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_inter_token_latency = make_per_engine( self.histogram_inter_token_latency = create_metric_per_engine(
histogram_inter_token_latency, engine_indexes, model_name histogram_inter_token_latency, per_engine_labelvalues
) )
histogram_request_time_per_output_token = self._histogram_cls( histogram_request_time_per_output_token = self._histogram_cls(
...@@ -812,8 +813,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -812,8 +813,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
], ],
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_request_time_per_output_token = make_per_engine( self.histogram_request_time_per_output_token = create_metric_per_engine(
histogram_request_time_per_output_token, engine_indexes, model_name histogram_request_time_per_output_token, per_engine_labelvalues
) )
request_latency_buckets = [ request_latency_buckets = [
...@@ -845,8 +846,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -845,8 +846,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets, buckets=request_latency_buckets,
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_e2e_time_request = make_per_engine( self.histogram_e2e_time_request = create_metric_per_engine(
histogram_e2e_time_request, engine_indexes, model_name histogram_e2e_time_request, per_engine_labelvalues
) )
histogram_queue_time_request = self._histogram_cls( histogram_queue_time_request = self._histogram_cls(
...@@ -855,8 +856,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -855,8 +856,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets, buckets=request_latency_buckets,
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_queue_time_request = make_per_engine( self.histogram_queue_time_request = create_metric_per_engine(
histogram_queue_time_request, engine_indexes, model_name histogram_queue_time_request, per_engine_labelvalues
) )
histogram_inference_time_request = self._histogram_cls( histogram_inference_time_request = self._histogram_cls(
...@@ -865,8 +866,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -865,8 +866,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets, buckets=request_latency_buckets,
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_inference_time_request = make_per_engine( self.histogram_inference_time_request = create_metric_per_engine(
histogram_inference_time_request, engine_indexes, model_name histogram_inference_time_request, per_engine_labelvalues
) )
histogram_prefill_time_request = self._histogram_cls( histogram_prefill_time_request = self._histogram_cls(
...@@ -875,8 +876,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -875,8 +876,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets, buckets=request_latency_buckets,
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_prefill_time_request = make_per_engine( self.histogram_prefill_time_request = create_metric_per_engine(
histogram_prefill_time_request, engine_indexes, model_name histogram_prefill_time_request, per_engine_labelvalues
) )
histogram_decode_time_request = self._histogram_cls( histogram_decode_time_request = self._histogram_cls(
...@@ -885,8 +886,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -885,8 +886,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets, buckets=request_latency_buckets,
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_decode_time_request = make_per_engine( self.histogram_decode_time_request = create_metric_per_engine(
histogram_decode_time_request, engine_indexes, model_name histogram_decode_time_request, per_engine_labelvalues
) )
histogram_prefill_kv_computed_request = self._histogram_cls( histogram_prefill_kv_computed_request = self._histogram_cls(
...@@ -898,8 +899,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -898,8 +899,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len), buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_prefill_kv_computed_request = make_per_engine( self.histogram_prefill_kv_computed_request = create_metric_per_engine(
histogram_prefill_kv_computed_request, engine_indexes, model_name histogram_prefill_kv_computed_request, per_engine_labelvalues
) )
# #
...@@ -939,8 +940,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -939,8 +940,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=kv_cache_residency_buckets, buckets=kv_cache_residency_buckets,
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_kv_block_lifetime = make_per_engine( self.histogram_kv_block_lifetime = create_metric_per_engine(
histogram_kv_block_lifetime, engine_indexes, model_name histogram_kv_block_lifetime, per_engine_labelvalues
) )
histogram_kv_block_idle_before_evict = self._histogram_cls( histogram_kv_block_idle_before_evict = self._histogram_cls(
...@@ -952,8 +953,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -952,8 +953,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=kv_cache_residency_buckets, buckets=kv_cache_residency_buckets,
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_kv_block_idle_before_evict = make_per_engine( self.histogram_kv_block_idle_before_evict = create_metric_per_engine(
histogram_kv_block_idle_before_evict, engine_indexes, model_name histogram_kv_block_idle_before_evict, per_engine_labelvalues
) )
histogram_kv_block_reuse_gap = self._histogram_cls( histogram_kv_block_reuse_gap = self._histogram_cls(
...@@ -967,8 +968,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -967,8 +968,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=kv_cache_residency_buckets, buckets=kv_cache_residency_buckets,
labelnames=labelnames, labelnames=labelnames,
) )
self.histogram_kv_block_reuse_gap = make_per_engine( self.histogram_kv_block_reuse_gap = create_metric_per_engine(
histogram_kv_block_reuse_gap, engine_indexes, model_name histogram_kv_block_reuse_gap, per_engine_labelvalues
) )
else: else:
self.histogram_kv_block_lifetime = {} self.histogram_kv_block_lifetime = {}
...@@ -1203,15 +1204,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -1203,15 +1204,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
self.log_metrics_info("cache_config", self.vllm_config.cache_config) self.log_metrics_info("cache_config", self.vllm_config.cache_config)
# Any of the prometheus_client metric types imported by this module.
PromMetric: TypeAlias = Gauge | Counter | Histogram
def make_per_engine(
    metric: PromMetric, engine_idxs: list[int], model_name: object
) -> dict[int, PromMetric]:
    """
    Create one labelled child of ``metric`` per engine index.

    Args:
        metric: Parent metric; its ``labels()`` method is called with
            ``(model_name, str(idx))`` for every engine index.
        engine_idxs: Engine indexes to create children for.
        model_name: Value passed as the first label.

    Returns:
        A dict mapping each engine index to its labelled child metric.
    """
    # The engine index is passed to labels() as a string.
    return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs}
def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]: def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
""" """
Builds a list of buckets with increasing powers of 10 multiplied by Builds a list of buckets with increasing powers of 10 multiplied by
......
...@@ -27,6 +27,7 @@ from vllm.utils.torch_utils import ( ...@@ -27,6 +27,7 @@ from vllm.utils.torch_utils import (
get_kv_cache_torch_dtype, get_kv_cache_torch_dtype,
) )
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.metrics.utils import create_metric_per_engine
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1291,7 +1292,9 @@ class PerfMetricsProm: ...@@ -1291,7 +1292,9 @@ class PerfMetricsProm:
), ),
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_flops = make_per_engine(counter_flops, per_engine_labelvalues) self.counter_flops = create_metric_per_engine(
counter_flops, per_engine_labelvalues
)
counter_read_bytes = self._counter_cls( counter_read_bytes = self._counter_cls(
name="vllm:estimated_read_bytes_per_gpu_total", name="vllm:estimated_read_bytes_per_gpu_total",
...@@ -1301,7 +1304,7 @@ class PerfMetricsProm: ...@@ -1301,7 +1304,7 @@ class PerfMetricsProm:
), ),
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_read_bytes = make_per_engine( self.counter_read_bytes = create_metric_per_engine(
counter_read_bytes, per_engine_labelvalues counter_read_bytes, per_engine_labelvalues
) )
...@@ -1313,7 +1316,7 @@ class PerfMetricsProm: ...@@ -1313,7 +1316,7 @@ class PerfMetricsProm:
), ),
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_write_bytes = make_per_engine( self.counter_write_bytes = create_metric_per_engine(
counter_write_bytes, per_engine_labelvalues counter_write_bytes, per_engine_labelvalues
) )
...@@ -1329,16 +1332,6 @@ class PerfMetricsProm: ...@@ -1329,16 +1332,6 @@ class PerfMetricsProm:
self.counter_write_bytes[engine_idx].inc(perf_stats.num_write_bytes_per_gpu) self.counter_write_bytes[engine_idx].inc(perf_stats.num_write_bytes_per_gpu)
def make_per_engine(
    counter: prometheus_client.Counter,
    per_engine_labelvalues: dict[int, list[object]],
) -> dict[int, prometheus_client.Counter]:
    """Create a labelled child of ``counter`` for each engine index.

    Args:
        counter: Parent metric. Only its ``labels()`` method is used here.
        per_engine_labelvalues: Maps an engine index to the positional
            label values handed to ``counter.labels()``.

    Returns:
        A dict with the same engine-index keys, each mapped to the result
        of ``counter.labels(*labelvalues)`` for that engine.
    """
    return {
        idx: counter.labels(*labelvalues)
        for idx, labelvalues in per_engine_labelvalues.items()
    }
## util functions ## util functions
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorPrometheus from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorProm
from vllm.v1.metrics.loggers import PrometheusStatLogger from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.metrics.perf import PerfMetricsProm from vllm.v1.metrics.perf import PerfMetricsProm
from vllm.v1.spec_decode.metrics import SpecDecodingProm from vllm.v1.spec_decode.metrics import SpecDecodingProm
...@@ -168,9 +168,9 @@ class RaySpecDecodingProm(SpecDecodingProm): ...@@ -168,9 +168,9 @@ class RaySpecDecodingProm(SpecDecodingProm):
_counter_cls = RayCounterWrapper _counter_cls = RayCounterWrapper
class RayKVConnectorPrometheus(KVConnectorPrometheus): class RayKVConnectorProm(KVConnectorProm):
""" """
RayKVConnectorPrometheus is used by RayMetrics to log Ray RayKVConnectorProm is used by RayMetrics to log Ray
metrics. Provides the same metrics as KV connectors but metrics. Provides the same metrics as KV connectors but
uses Ray's util.metrics library. uses Ray's util.metrics library.
""" """
...@@ -197,7 +197,7 @@ class RayPrometheusStatLogger(PrometheusStatLogger): ...@@ -197,7 +197,7 @@ class RayPrometheusStatLogger(PrometheusStatLogger):
_counter_cls = RayCounterWrapper _counter_cls = RayCounterWrapper
_histogram_cls = RayHistogramWrapper _histogram_cls = RayHistogramWrapper
_spec_decoding_cls = RaySpecDecodingProm _spec_decoding_cls = RaySpecDecodingProm
_kv_connector_cls = RayKVConnectorPrometheus _kv_connector_cls = RayKVConnectorProm
_perf_metrics_cls = RayPerfMetricsProm _perf_metrics_cls = RayPerfMetricsProm
@staticmethod @staticmethod
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeAlias
from prometheus_client import Counter, Gauge, Histogram
PromMetric: TypeAlias = Gauge | Counter | Histogram
def create_metric_per_engine(
    metric: PromMetric,
    per_engine_labelvalues: dict[int, list[object]],
) -> dict[int, PromMetric]:
    """Create a labeled metric child for each engine index.

    Args:
        metric: Parent metric (gauge, counter or histogram). Only its
            ``labels()`` method is used here.
        per_engine_labelvalues: Maps an engine index to the positional
            label values handed to ``metric.labels()``.

    Returns:
        A dict with the same engine-index keys, each mapped to the result
        of ``metric.labels(*labelvalues)`` for that engine.
    """
    return {
        idx: metric.labels(*labelvalues)
        for idx, labelvalues in per_engine_labelvalues.items()
    }
...@@ -9,6 +9,7 @@ import prometheus_client ...@@ -9,6 +9,7 @@ import prometheus_client
from vllm.config import SpeculativeConfig from vllm.config import SpeculativeConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.metrics.utils import create_metric_per_engine
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -155,7 +156,7 @@ class SpecDecodingProm: ...@@ -155,7 +156,7 @@ class SpecDecodingProm:
documentation="Number of spec decoding drafts.", documentation="Number of spec decoding drafts.",
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_spec_decode_num_drafts = make_per_engine( self.counter_spec_decode_num_drafts = create_metric_per_engine(
counter_drafts, per_engine_labelvalues counter_drafts, per_engine_labelvalues
) )
...@@ -164,7 +165,7 @@ class SpecDecodingProm: ...@@ -164,7 +165,7 @@ class SpecDecodingProm:
documentation="Number of draft tokens.", documentation="Number of draft tokens.",
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_spec_decode_num_draft_tokens = make_per_engine( self.counter_spec_decode_num_draft_tokens = create_metric_per_engine(
counter_draft_tokens, per_engine_labelvalues counter_draft_tokens, per_engine_labelvalues
) )
...@@ -173,7 +174,7 @@ class SpecDecodingProm: ...@@ -173,7 +174,7 @@ class SpecDecodingProm:
documentation="Number of accepted tokens.", documentation="Number of accepted tokens.",
labelnames=labelnames, labelnames=labelnames,
) )
self.counter_spec_decode_num_accepted_tokens = make_per_engine( self.counter_spec_decode_num_accepted_tokens = create_metric_per_engine(
counter_accepted_tokens, per_engine_labelvalues counter_accepted_tokens, per_engine_labelvalues
) )
...@@ -212,14 +213,3 @@ class SpecDecodingProm: ...@@ -212,14 +213,3 @@ class SpecDecodingProm:
self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx] self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]
): ):
counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos]) counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
def make_per_engine(
    counter: prometheus_client.Counter,
    per_engine_labelvalues: dict[int, list[object]],
) -> dict[int, prometheus_client.Counter]:
    """Create a labelled child of ``counter`` for each engine index.

    Args:
        counter: Parent metric. Only its ``labels()`` method is used here.
        per_engine_labelvalues: Maps an engine index to the positional
            label values handed to ``counter.labels()``.

    Returns:
        A dict with the same engine-index keys, each mapped to the result
        of ``counter.labels(*labelvalues)`` for that engine.
    """
    return {
        idx: counter.labels(*labelvalues)
        for idx, labelvalues in per_engine_labelvalues.items()
    }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment