loggers.py 39.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import logging
5
6
import time
from abc import ABC, abstractmethod
7
8
from collections.abc import Callable
from typing import TypeAlias
9

10
from prometheus_client import Counter, Gauge, Histogram
11

12
from vllm.config import SupportsMetricsInfo, VllmConfig
13
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging
14
from vllm.logger import init_logger
15
from vllm.plugins import load_plugins_by_group
16
from vllm.v1.engine import FinishReason
17
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
18
19
20
21
22
23
from vllm.v1.metrics.stats import (
    CachingMetrics,
    IterationStats,
    MultiModalCacheStats,
    SchedulerStats,
)
24
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
25
26
27

logger = init_logger(__name__)

28
29
30
# Factory for a single-engine logger: called with the engine config and the
# engine index, returns a `StatLoggerBase` instance.
PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
# Factory for a logger that covers several engines at once: the class object
# of an `AggregateStatLoggerBase` subclass.
AggregateStatLoggerFactory = type["AggregateStatLoggerBase"]
# Either kind of factory.
StatLoggerFactory = AggregateStatLoggerFactory | PerEngineStatLoggerFactory
31

32
33

class StatLoggerBase(ABC):
    """Abstract contract that every metrics logger fulfils.

    API users may supply their own loggers by deriving from this class.
    Be aware that the `SchedulerStats` and `IterationStats` classes are
    not considered stable interfaces and may change in future versions.
    """

    @abstractmethod
    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        """Build a logger from the engine config and an engine index."""

    @abstractmethod
    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int = 0,
    ):
        """Consume one set of stats reported for engine `engine_idx`.

        Each of the stats arguments may be `None`.
        """

    @abstractmethod
    def log_engine_initialized(self):
        """Report that the engine has been initialized."""

    def log(self):  # noqa
        """Optional output hook; the base implementation does nothing."""
        return None

59

60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def load_stat_logger_plugin_factories() -> list[StatLoggerFactory]:
    """Collect stat logger factories from the `vllm.stat_logger_plugins` group.

    Returns:
        The loaded plugin classes, in the order the plugin loader yields them.

    Raises:
        TypeError: If a loaded plugin is not a class derived from
            `StatLoggerBase`.
    """
    plugins = load_plugins_by_group("vllm.stat_logger_plugins")
    factories: list[StatLoggerFactory] = []

    for name, plugin_class in plugins.items():
        # `issubclass` requires a class, so test `isinstance(..., type)` first.
        is_stat_logger = isinstance(plugin_class, type) and issubclass(
            plugin_class, StatLoggerBase
        )
        if is_stat_logger:
            factories.append(plugin_class)
            continue

        raise TypeError(
            f"Stat logger plugin {name!r} must be a subclass of "
            f"StatLoggerBase (got {plugin_class!r})."
        )

    return factories


77
78
79
80
81
82
83
84
class AggregateStatLoggerBase(StatLoggerBase):
    """Abstract base for loggers that aggregate stats across
    multiple DP engines, and are therefore built from a list of
    engine indexes rather than a single one."""

    @abstractmethod
    def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]):
        """Build a logger covering every engine in `engine_indexes`."""


85
class LoggingStatLogger(StatLoggerBase):
    """Stat logger that writes a periodic text summary through the module logger.

    Token counts are accumulated by `record()`. Each `log()` call turns them
    into average throughputs over the time elapsed since the previous
    `log()` call, resets the counters, and emits one summary line.
    """

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        self.vllm_config = vllm_config
        self.engine_index = engine_index
        self._reset(time.monotonic())

        self.last_scheduler_stats = SchedulerStats()

        # Caching metrics. This cannot be reset.
        # TODO: Make the interval configurable.
        self.prefix_caching_metrics = CachingMetrics()
        self.mm_caching_metrics = CachingMetrics()

        self.spec_decoding_logging = SpecDecodingLogging()
        self.kv_connector_logging = KVConnectorLogging(
            self.vllm_config.kv_transfer_config
        )

        self.last_prompt_throughput: float = 0.0
        self.last_generation_throughput: float = 0.0

        self.engine_is_idle = False
        # Set to True by subclasses that compute `last_scheduler_stats`
        # themselves; `record()` then leaves that attribute alone.
        self.aggregated = False

    def _reset(self, now):
        """Start a new logging interval beginning at `now`."""
        # Token counters tracked over the current local logging interval.
        self.num_prompt_tokens: int = 0
        self.num_generation_tokens: int = 0
        self.last_log_time = now

    def _track_iteration_stats(self, iteration_stats: IterationStats):
        """Add the iteration's token counts to the interval counters."""
        self.num_generation_tokens += iteration_stats.num_generation_tokens
        self.num_prompt_tokens += iteration_stats.num_prompt_tokens

    def _get_throughput(self, tracked_stats: int, now: float) -> float:
        """Return `tracked_stats` per second since the last log time.

        Returns 0.0 when no time has elapsed (or `now` is in the past).
        """
        elapsed = now - self.last_log_time
        return 0.0 if elapsed <= 0.0 else float(tracked_stats / elapsed)

    @property
    def log_prefix(self):
        """Prefix placed in front of the summary line."""
        return f"Engine {self.engine_index:03d}: "

    def _observe_scheduler_stats(self, scheduler_stats: SchedulerStats):
        """Feed one scheduler snapshot into the cache/spec-decode/KV trackers."""
        self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)

        spec_stats = scheduler_stats.spec_decoding_stats
        if spec_stats is not None:
            self.spec_decoding_logging.observe(spec_stats)

        kv_stats = scheduler_stats.kv_connector_stats
        if kv_stats:
            self.kv_connector_logging.observe(kv_stats)

        if not self.aggregated:
            self.last_scheduler_stats = scheduler_stats

    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int = 0,
    ):
        """Fold the given stats into this logger's running state.

        Nothing is written to the log here; output happens in `log()`.
        `engine_idx` is accepted for interface compatibility and not used.
        """
        if iteration_stats:
            self._track_iteration_stats(iteration_stats)

        if scheduler_stats is not None:
            self._observe_scheduler_stats(scheduler_stats)

        if mm_cache_stats:
            self.mm_caching_metrics.observe(mm_cache_stats)

    def _update_stats(self):
        """Close the current interval and refresh throughputs and idle flag."""
        now = time.monotonic()
        prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
        generation_throughput = self._get_throughput(self.num_generation_tokens, now)
        self._reset(now)

        # Idle means no throughput in this interval nor in the previous one.
        recent_throughputs = (
            prompt_throughput,
            generation_throughput,
            self.last_prompt_throughput,
            self.last_generation_throughput,
        )
        self.engine_is_idle = not any(recent_throughputs)

        self.last_prompt_throughput = prompt_throughput
        self.last_generation_throughput = generation_throughput

    def aggregate_scheduler_stats(self):
        """Hook for aggregating subclasses; a no-op for per-engine loggers."""

    def log(self):
        """Emit the summary line, then the spec-decoding and KV connector logs."""
        self._update_stats()
        self.aggregate_scheduler_stats()
        # Avoid log noise on an idle production system
        emit = logger.debug if self.engine_is_idle else logger.info

        # Each entry pairs a %-format fragment with the value it renders.
        summary = [
            ("Avg prompt throughput: %.1f tokens/s", self.last_prompt_throughput),
            (
                "Avg generation throughput: %.1f tokens/s",
                self.last_generation_throughput,
            ),
            ("Running: %d reqs", self.last_scheduler_stats.num_running_reqs),
            ("Waiting: %d reqs", self.last_scheduler_stats.num_waiting_reqs),
            (
                "GPU KV cache usage: %.1f%%",
                self.last_scheduler_stats.kv_cache_usage * 100,
            ),
            (
                "Prefix cache hit rate: %.1f%%",
                self.prefix_caching_metrics.hit_rate * 100,
            ),
        ]
        if not self.mm_caching_metrics.empty:
            summary.append(
                ("MM cache hit rate: %.1f%%", self.mm_caching_metrics.hit_rate * 100)
            )

        template = ", ".join(fragment for fragment, _ in summary)
        values = [value for _, value in summary]
        emit(self.log_prefix + template, *values)

        self.spec_decoding_logging.log(log_fn=emit)
        self.kv_connector_logging.log(log_fn=emit)

    def log_engine_initialized(self):
        """Log the configured number of GPU blocks, if it is set and non-zero."""
        num_gpu_blocks = self.vllm_config.cache_config.num_gpu_blocks
        if not num_gpu_blocks:
            return

        logger.info(
            "Engine %03d: vllm cache_config_info with initialization "
            "after num_gpu_blocks is: %d",
            self.engine_index,
            num_gpu_blocks,
        )
215

216

217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase):
    """Logging stat logger that reports one summary for several engines.

    The most recent `SchedulerStats` of each engine is kept separately and
    combined in `aggregate_scheduler_stats()` right before logging: request
    counts are summed and KV cache usage is averaged.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        engine_indexes: list[int],
    ):
        self.engine_indexes = engine_indexes
        # Latest scheduler snapshot per engine, seeded with empty stats.
        self.last_scheduler_stats_dict: dict[int, SchedulerStats] = {
            idx: SchedulerStats() for idx in self.engine_indexes
        }
        LoggingStatLogger.__init__(self, vllm_config, engine_index=-1)
        # Tells the base `record()` not to overwrite `last_scheduler_stats`;
        # it is recomputed from the per-engine snapshots instead.
        self.aggregated = True

    @property
    def log_prefix(self):
        """Prefix placed in front of the aggregated summary line."""
        return f"{len(self.engine_indexes)} Engines Aggregated: "

    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int = 0,
    ):
        """Record stats for one engine; unknown engine indexes are ignored.

        An unexpected `engine_idx` only produces a warning, and nothing is
        recorded for it.
        """
        if engine_idx not in self.engine_indexes:
            logger.warning("Unexpected engine_idx: %d", engine_idx)
            return

        LoggingStatLogger.record(
            self,
            scheduler_stats,
            iteration_stats,
            mm_cache_stats=mm_cache_stats,
            engine_idx=engine_idx,
        )
        if scheduler_stats is not None:
            self.last_scheduler_stats_dict[engine_idx] = scheduler_stats

    def aggregate_scheduler_stats(self):
        """Rebuild `last_scheduler_stats` from the per-engine snapshots."""
        combined = SchedulerStats()
        for engine_stats in self.last_scheduler_stats_dict.values():
            combined.num_waiting_reqs += engine_stats.num_waiting_reqs
            combined.num_running_reqs += engine_stats.num_running_reqs
            combined.num_corrupted_reqs += engine_stats.num_corrupted_reqs
            combined.kv_cache_usage += engine_stats.kv_cache_usage

        # Average the usage across engines. With no engines at all there is
        # nothing to average, so skip the division instead of raising
        # ZeroDivisionError from the logging path.
        num_engines = len(self.last_scheduler_stats_dict)
        if num_engines:
            combined.kv_cache_usage /= num_engines

        self.last_scheduler_stats = combined

    def log(self):
        """Emit the aggregated summary using the base logging routine."""
        LoggingStatLogger.log(self)

    def log_engine_initialized(self):
        """Log the configured number of GPU blocks, if it is set and non-zero."""
        num_gpu_blocks = self.vllm_config.cache_config.num_gpu_blocks
        if not num_gpu_blocks:
            return

        logger.info(
            "%d Engines: vllm cache_config_info with initialization "
            "after num_gpu_blocks is: %d",
            len(self.engine_indexes),
            num_gpu_blocks,
        )


class PerEngineStatLoggerAdapter(AggregateStatLoggerBase):
    """Presents a set of per-engine loggers through the aggregate interface.

    One logger is created per engine index via the supplied factory, and
    every call is forwarded to the matching logger (or to all of them).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        engine_indexes: list[int],
        per_engine_stat_logger_factory: PerEngineStatLoggerFactory,
    ) -> None:
        self.engine_indexes = engine_indexes
        self.per_engine_stat_loggers = {
            idx: per_engine_stat_logger_factory(vllm_config, idx)
            for idx in engine_indexes
        }

    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int = 0,
    ):
        """Forward the stats to the logger owning `engine_idx`.

        An unknown `engine_idx` only produces a warning.
        """
        try:
            target = self.per_engine_stat_loggers[engine_idx]
        except KeyError:
            logger.warning("Unexpected engine_idx: %d", engine_idx)
            return

        target.record(
            scheduler_stats,
            iteration_stats,
            mm_cache_stats=mm_cache_stats,
            engine_idx=engine_idx,
        )

    def log(self):
        """Ask every wrapped logger to emit its output."""
        for wrapped in self.per_engine_stat_loggers.values():
            wrapped.log()

    def log_engine_initialized(self):
        """Ask every wrapped logger to report engine initialization."""
        for wrapped in self.per_engine_stat_loggers.values():
            wrapped.log_engine_initialized()


class PrometheusStatLogger(AggregateStatLoggerBase):
325
326
327
    _gauge_cls = Gauge
    _counter_cls = Counter
    _histogram_cls = Histogram
328
    _spec_decoding_cls = SpecDecodingProm
329

330
    def __init__(
        self, vllm_config: VllmConfig, engine_indexes: list[int] | None = None
    ):
        """Create every Prometheus metric exposed by this logger.

        `unregister_vllm_metrics()` is called first, then each metric is
        created with the `model_name` and `engine` labels and bound once per
        engine index through `make_per_engine`.

        Args:
            vllm_config: Engine configuration. This method reads its
                `observability_config`, `model_config`, `speculative_config`
                and `lora_config` attributes.
            engine_indexes: Engine indexes to create labelled metrics for.
                Defaults to `[0]` when omitted.

        Raises:
            NotImplementedError: If `lora_config` is set and more than one
                engine index was given.
        """
        if engine_indexes is None:
            engine_indexes = [0]

        self.engine_indexes = engine_indexes

        unregister_vllm_metrics()
        self.vllm_config = vllm_config
        # Use this flag to hide metrics that were deprecated in
        # a previous release and which will be removed in a future one
        self.show_hidden_metrics = vllm_config.observability_config.show_hidden_metrics

        labelnames = ["model_name", "engine"]
        model_name = vllm_config.model_config.served_model_name
        max_model_len = vllm_config.model_config.max_model_len

        spec_decode_labelvalues: dict[int, list[str]] = {
            idx: [model_name, str(idx)] for idx in engine_indexes
        }

        self.spec_decoding_prom = self._spec_decoding_cls(
            vllm_config.speculative_config, labelnames, spec_decode_labelvalues
        )

        def per_engine(metric_cls, **metric_kwargs):
            # Create one metric with the common label names, then bind it
            # for every engine index.
            metric = metric_cls(**metric_kwargs, labelnames=labelnames)
            return make_per_engine(metric, engine_indexes, model_name)

        #
        # Scheduler state
        #
        self.gauge_scheduler_running = per_engine(
            self._gauge_cls,
            name="vllm:num_requests_running",
            documentation="Number of requests in model execution batches.",
            multiprocess_mode="mostrecent",
        )

        self.gauge_scheduler_waiting = per_engine(
            self._gauge_cls,
            name="vllm:num_requests_waiting",
            documentation="Number of requests waiting to be processed.",
            multiprocess_mode="mostrecent",
        )

        #
        # GPU cache
        #
        # The three metrics below were deprecated in 0.9.2 and renamed:
        #   vllm:gpu_cache_usage_perc     -> vllm:kv_cache_usage_perc
        #   vllm:gpu_prefix_cache_queries -> vllm:prefix_cache_queries
        #   vllm:gpu_prefix_cache_hits    -> vllm:prefix_cache_hits
        # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10
        # TODO: remove in 0.12.0
        if self.show_hidden_metrics:
            self.gauge_gpu_cache_usage = per_engine(
                self._gauge_cls,
                name="vllm:gpu_cache_usage_perc",
                documentation=(
                    "GPU KV-cache usage. 1 means 100 percent usage. "
                    "DEPRECATED: Use vllm:kv_cache_usage_perc instead."
                ),
                multiprocess_mode="mostrecent",
            )

            self.counter_gpu_prefix_cache_queries = per_engine(
                self._counter_cls,
                name="vllm:gpu_prefix_cache_queries",
                documentation=(
                    "GPU prefix cache queries, in terms of number of queried "
                    "tokens. DEPRECATED: Use vllm:prefix_cache_queries instead."
                ),
            )

            self.counter_gpu_prefix_cache_hits = per_engine(
                self._counter_cls,
                name="vllm:gpu_prefix_cache_hits",
                documentation=(
                    "GPU prefix cache hits, in terms of number of cached "
                    "tokens. DEPRECATED: Use vllm:prefix_cache_hits instead."
                ),
            )

        self.gauge_kv_cache_usage = per_engine(
            self._gauge_cls,
            name="vllm:kv_cache_usage_perc",
            documentation="KV-cache usage. 1 means 100 percent usage.",
        )

        self.counter_prefix_cache_queries = per_engine(
            self._counter_cls,
            name="vllm:prefix_cache_queries",
            documentation=(
                "Prefix cache queries, in terms of number of queried tokens."
            ),
        )

        self.counter_prefix_cache_hits = per_engine(
            self._counter_cls,
            name="vllm:prefix_cache_hits",
            documentation=("Prefix cache hits, in terms of number of cached tokens."),
        )

        #
        # Multi-modal cache
        #
        self.counter_mm_cache_queries = per_engine(
            self._counter_cls,
            name="vllm:mm_cache_queries",
            documentation=(
                "Multi-modal cache queries, in terms of number of queried items."
            ),
        )

        self.counter_mm_cache_hits = per_engine(
            self._counter_cls,
            name="vllm:mm_cache_hits",
            documentation=(
                "Multi-modal cache hits, in terms of number of cached items."
            ),
        )

        #
        # Counters
        #
        self.counter_num_preempted_reqs = per_engine(
            self._counter_cls,
            name="vllm:num_preemptions",
            documentation="Cumulative number of preemption from the engine.",
        )

        self.counter_prompt_tokens = per_engine(
            self._counter_cls,
            name="vllm:prompt_tokens",
            documentation="Number of prefill tokens processed.",
        )

        self.counter_generation_tokens = per_engine(
            self._counter_cls,
            name="vllm:generation_tokens",
            documentation="Number of generation tokens processed.",
        )

        # This counter carries an extra `finished_reason` label, so it is
        # bound per (finish reason, engine index) instead of per engine only.
        self.counter_request_success: dict[FinishReason, dict[int, Counter]] = {}
        counter_request_success_base = self._counter_cls(
            name="vllm:request_success",
            documentation="Count of successfully processed requests.",
            labelnames=labelnames + ["finished_reason"],
        )
        for reason in FinishReason:
            self.counter_request_success[reason] = {
                idx: counter_request_success_base.labels(
                    model_name, str(idx), str(reason)
                )
                for idx in engine_indexes
            }

        #
        # Histograms of counts
        #
        self.histogram_num_prompt_tokens_request = per_engine(
            self._histogram_cls,
            name="vllm:request_prompt_tokens",
            documentation="Number of prefill tokens processed.",
            buckets=build_1_2_5_buckets(max_model_len),
        )

        self.histogram_num_generation_tokens_request = per_engine(
            self._histogram_cls,
            name="vllm:request_generation_tokens",
            documentation="Number of generation tokens processed.",
            buckets=build_1_2_5_buckets(max_model_len),
        )

        # TODO: This metric might be incorrect in case of using multiple
        # api_server counts which uses prometheus mp.
        # See: https://github.com/vllm-project/vllm/pull/18053
        self.histogram_iteration_tokens = per_engine(
            self._histogram_cls,
            name="vllm:iteration_tokens_total",
            documentation="Histogram of number of tokens per engine_step.",
            buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        )

        self.histogram_max_num_generation_tokens_request = per_engine(
            self._histogram_cls,
            name="vllm:request_max_num_generation_tokens",
            documentation="Histogram of maximum number of requested generation tokens.",
            buckets=build_1_2_5_buckets(max_model_len),
        )

        self.histogram_n_request = per_engine(
            self._histogram_cls,
            name="vllm:request_params_n",
            documentation="Histogram of the n request parameter.",
            buckets=[1, 2, 5, 10, 20],
        )

        self.histogram_max_tokens_request = per_engine(
            self._histogram_cls,
            name="vllm:request_params_max_tokens",
            documentation="Histogram of the max_tokens request parameter.",
            buckets=build_1_2_5_buckets(max_model_len),
        )

        #
        # Histogram of timing intervals
        #
        self.histogram_time_to_first_token = per_engine(
            self._histogram_cls,
            name="vllm:time_to_first_token_seconds",
            documentation="Histogram of time to first token in seconds.",
            buckets=[
                0.001,
                0.005,
                0.01,
                0.02,
                0.04,
                0.06,
                0.08,
                0.1,
                0.25,
                0.5,
                0.75,
                1.0,
                2.5,
                5.0,
                7.5,
                10.0,
                20.0,
                40.0,
                80.0,
                160.0,
                640.0,
                2560.0,
            ],
        )

        # Bucket boundaries (seconds) shared by the three per-token latency
        # histograms below.
        token_latency_buckets = [
            0.01,
            0.025,
            0.05,
            0.075,
            0.1,
            0.15,
            0.2,
            0.3,
            0.4,
            0.5,
            0.75,
            1.0,
            2.5,
            5.0,
            7.5,
            10.0,
            20.0,
            40.0,
            80.0,
        ]

        # Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds
        # TODO: in 0.12, only enable if show_hidden_metrics=True
        self.histogram_time_per_output_token = per_engine(
            self._histogram_cls,
            name="vllm:time_per_output_token_seconds",
            documentation=(
                "Histogram of time per output token in seconds. "
                "DEPRECATED: Use vllm:inter_token_latency_seconds instead."
            ),
            buckets=token_latency_buckets,
        )

        self.histogram_inter_token_latency = per_engine(
            self._histogram_cls,
            name="vllm:inter_token_latency_seconds",
            documentation="Histogram of inter-token latency in seconds.",
            buckets=token_latency_buckets,
        )

        self.histogram_request_time_per_output_token = per_engine(
            self._histogram_cls,
            name="vllm:request_time_per_output_token_seconds",
            documentation="Histogram of time_per_output_token_seconds per request.",
            buckets=token_latency_buckets,
        )

        # Bucket boundaries (seconds) shared by the per-request phase
        # histograms below.
        request_latency_buckets = [
            0.3,
            0.5,
            0.8,
            1.0,
            1.5,
            2.0,
            2.5,
            5.0,
            10.0,
            15.0,
            20.0,
            30.0,
            40.0,
            50.0,
            60.0,
            120.0,
            240.0,
            480.0,
            960.0,
            1920.0,
            7680.0,
        ]

        self.histogram_e2e_time_request = per_engine(
            self._histogram_cls,
            name="vllm:e2e_request_latency_seconds",
            documentation="Histogram of e2e request latency in seconds.",
            buckets=request_latency_buckets,
        )

        self.histogram_queue_time_request = per_engine(
            self._histogram_cls,
            name="vllm:request_queue_time_seconds",
            documentation="Histogram of time spent in WAITING phase for request.",
            buckets=request_latency_buckets,
        )

        self.histogram_inference_time_request = per_engine(
            self._histogram_cls,
            name="vllm:request_inference_time_seconds",
            documentation="Histogram of time spent in RUNNING phase for request.",
            buckets=request_latency_buckets,
        )

        self.histogram_prefill_time_request = per_engine(
            self._histogram_cls,
            name="vllm:request_prefill_time_seconds",
            documentation="Histogram of time spent in PREFILL phase for request.",
            buckets=request_latency_buckets,
        )

        self.histogram_decode_time_request = per_engine(
            self._histogram_cls,
            name="vllm:request_decode_time_seconds",
            documentation="Histogram of time spent in DECODE phase for request.",
            buckets=request_latency_buckets,
        )

        #
        # LoRA metrics
        #

        # TODO: This metric might be incorrect in case of using multiple
        # api_server counts which uses prometheus mp.
        self.gauge_lora_info: Gauge | None = None
        if vllm_config.lora_config is not None:
            if len(self.engine_indexes) > 1:
                raise NotImplementedError("LoRA in DP mode is not supported yet.")
            self.labelname_max_lora = "max_lora"
            self.labelname_waiting_lora_adapters = "waiting_lora_adapters"
            self.labelname_running_lora_adapters = "running_lora_adapters"
            self.max_lora = vllm_config.lora_config.max_loras
            # Unlike the metrics above, this gauge is labelled by LoRA state
            # rather than by model name and engine index.
            self.gauge_lora_info = self._gauge_cls(
                name="vllm:lora_requests_info",
                documentation="Running stats on lora requests.",
                multiprocess_mode="sum",
                labelnames=[
                    self.labelname_max_lora,
                    self.labelname_waiting_lora_adapters,
                    self.labelname_running_lora_adapters,
                ],
            )
824

825
826
    def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
        """Expose a config object's metrics info as an info-style gauge.

        A single gauge is created whose label names are the keys of
        ``config_obj.metrics_info()`` plus an ``engine`` label. For every
        index in ``self.engine_indexes`` a labelled sample is set to 1.

        Only ``type == "cache_config"`` is recognised; any other value
        fails the assertion below.
        """
        # This first dict is only used to derive the gauge's label names.
        label_template = config_obj.metrics_info()
        label_template["engine"] = ""

        if type == "cache_config":
            name = "vllm:cache_config_info"
            documentation = "Information of the LLMEngine CacheConfig"
        else:
            name = documentation = None
        assert name is not None, f"Unknown metrics info type {type}"

        # Prometheus multiprocessing mode does not support the Info metric
        # type. An Info is just a gauge permanently set to 1, so emulate it
        # with a gauge here.
        info_gauge = self._gauge_cls(
            name=name,
            documentation=documentation,
            multiprocess_mode="mostrecent",
            labelnames=label_template.keys(),
        )
        for engine_index in self.engine_indexes:
            # Fetch a fresh dict per engine and stamp it with that engine's id.
            engine_labels = config_obj.metrics_info()
            engine_labels["engine"] = str(engine_index)
            info_gauge.labels(**engine_labels).set(1)

849
850
    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int = 0,
    ):
        """Push one batch of stats for a single engine into Prometheus.

        Every stats argument is optional and is skipped when ``None``.
        Scheduler stats and multimodal-cache stats are handled first; when
        ``iteration_stats`` is ``None`` the method returns right after them,
        so neither the per-iteration metrics nor the LoRA info gauge are
        touched.
        """
        idx = engine_idx

        if scheduler_stats is not None:
            self.gauge_scheduler_running[idx].set(scheduler_stats.num_running_reqs)
            self.gauge_scheduler_waiting[idx].set(scheduler_stats.num_waiting_reqs)

            # The gpu_* metrics are only written when hidden metrics are
            # shown; they receive the same values as the metrics that follow.
            if self.show_hidden_metrics:
                self.gauge_gpu_cache_usage[idx].set(scheduler_stats.kv_cache_usage)
            self.gauge_kv_cache_usage[idx].set(scheduler_stats.kv_cache_usage)

            prefix_stats = scheduler_stats.prefix_cache_stats
            if self.show_hidden_metrics:
                self.counter_gpu_prefix_cache_queries[idx].inc(prefix_stats.queries)
                self.counter_gpu_prefix_cache_hits[idx].inc(prefix_stats.hits)
            self.counter_prefix_cache_queries[idx].inc(prefix_stats.queries)
            self.counter_prefix_cache_hits[idx].inc(prefix_stats.hits)

            spec_stats = scheduler_stats.spec_decoding_stats
            if spec_stats is not None:
                self.spec_decoding_prom.observe(spec_stats, idx)

        if mm_cache_stats is not None:
            self.counter_mm_cache_queries[idx].inc(mm_cache_stats.queries)
            self.counter_mm_cache_hits[idx].inc(mm_cache_stats.hits)

        if iteration_stats is None:
            return

        # Per-iteration token and preemption totals.
        prompt_tokens = iteration_stats.num_prompt_tokens
        generation_tokens = iteration_stats.num_generation_tokens
        self.counter_num_preempted_reqs[idx].inc(iteration_stats.num_preempted_reqs)
        self.counter_prompt_tokens[idx].inc(prompt_tokens)
        self.counter_generation_tokens[idx].inc(generation_tokens)
        self.histogram_iteration_tokens[idx].observe(
            prompt_tokens + generation_tokens
        )

        # Per-sample observations collected during this iteration.
        max_gen_histogram = self.histogram_max_num_generation_tokens_request[idx]
        for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter:
            max_gen_histogram.observe(max_gen_tokens)
        for n_param in iteration_stats.n_params_iter:
            self.histogram_n_request[idx].observe(n_param)
        for ttft in iteration_stats.time_to_first_tokens_iter:
            self.histogram_time_to_first_token[idx].observe(ttft)
        for itl in iteration_stats.inter_token_latencies_iter:
            # The same latency sample feeds both histograms.
            self.histogram_inter_token_latency[idx].observe(itl)
            self.histogram_time_per_output_token[idx].observe(itl)

        for done in iteration_stats.finished_requests:
            self.counter_request_success[done.finish_reason][idx].inc()
            per_request = (
                (self.histogram_e2e_time_request, done.e2e_latency),
                (self.histogram_queue_time_request, done.queued_time),
                (self.histogram_prefill_time_request, done.prefill_time),
                (self.histogram_inference_time_request, done.inference_time),
                (self.histogram_decode_time_request, done.decode_time),
                (self.histogram_num_prompt_tokens_request, done.num_prompt_tokens),
                (
                    self.histogram_num_generation_tokens_request,
                    done.num_generation_tokens,
                ),
                (
                    self.histogram_request_time_per_output_token,
                    done.mean_time_per_output_token,
                ),
            )
            for histogram, value in per_request:
                histogram[idx].observe(value)
            # A falsy max_tokens_param (None or 0) is not observed.
            if done.max_tokens_param:
                self.histogram_max_tokens_request[idx].observe(done.max_tokens_param)

        lora_gauge = self.gauge_lora_info
        if lora_gauge is None:
            return

        # The adapter names are carried in the label values (comma-joined);
        # the gauge value itself is just the current time.
        running_names = ",".join(iteration_stats.running_lora_adapters.keys())
        waiting_names = ",".join(iteration_stats.waiting_lora_adapters.keys())
        lora_gauge.labels(
            **{
                self.labelname_running_lora_adapters: running_names,
                self.labelname_waiting_lora_adapters: waiting_names,
                self.labelname_max_lora: self.max_lora,
            }
        ).set_to_current_time()
967

968
969
970
    def log_engine_initialized(self):
        """Publish the cache config's metrics info via ``log_metrics_info``."""
        cache_config = self.vllm_config.cache_config
        self.log_metrics_info("cache_config", cache_config)

971

972
# Union of the prometheus_client metric classes that make_per_engine accepts
# and uses as the value type of the per-engine mapping it returns.
PromMetric: TypeAlias = Gauge | Counter | Histogram
973
974


975
976
977
def make_per_engine(
    metric: PromMetric, engine_idxs: list[int], model_name: str
) -> dict[int, PromMetric]:
    """Create one labelled child of ``metric`` for each engine index.

    The label values are passed positionally as ``(model_name, str(idx))``,
    so the metric is presumably declared with its label names in that order
    (model name first, engine index second) -- confirm against the caller.

    Returns a dict mapping each engine index to its labelled child.
    """
    per_engine: dict[int, PromMetric] = {}
    for engine_idx in engine_idxs:
        per_engine[engine_idx] = metric.labels(model_name, str(engine_idx))
    return per_engine


981
def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
    """
    Builds a list of buckets with increasing powers of 10 multiplied by
    mantissa values until the value exceeds the specified maximum.

    Values are generated in the order ``m * 10**0`` for every mantissa,
    then ``m * 10**1`` for every mantissa, and so on. Generation stops at
    the first value greater than ``max_value``; that value is not included.

    Returns an empty list when ``mantissa_lst`` is empty.

    Raises:
        ValueError: if no mantissa is positive and every value of the first
            pass is within ``max_value``. Such values never grow past
            ``max_value``, so the series would never terminate.
    """
    if not mantissa_lst:
        # Nothing to multiply: without this guard the loop below would spin
        # forever without ever producing a value.
        return []

    can_grow = any(m > 0 for m in mantissa_lst)
    exponent = 0
    buckets: list[int] = []
    while True:
        for m in mantissa_lst:
            value = m * 10**exponent
            if value <= max_value:
                buckets.append(value)
            else:
                return buckets
        if not can_grow:
            # Non-positive mantissas only shrink (or stay at 0) as the
            # exponent rises, so if the first pass did not exceed max_value
            # no later pass will either.
            raise ValueError(
                "mantissa_lst must contain at least one positive value, "
                f"got {mantissa_lst}"
            )
        exponent += 1
        exponent += 1


999
def build_1_2_5_buckets(max_value: int) -> list[int]:
    """
    Build buckets from the mantissas 1, 2 and 5 via ``build_buckets``.

    Example:
    >>> build_1_2_5_buckets(100)
    [1, 2, 5, 10, 20, 50, 100]
    """
    mantissas = [1, 2, 5]
    return build_buckets(mantissas, max_value)
1006
1007


1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
class StatLoggerManager:
    """
    StatLoggerManager:
        Stats are logged at the level of the EngineCore (per scheduler).
         * DP: >1 EngineCore per AsyncLLM - loggers for each EngineCore.
         * With Local Logger, just make N copies for N EngineCores.
         * With Prometheus, we need a single logger with N "labels"

        This class hides that implementation detail from the AsyncLLM,
        which only needs to call .record() and .log() on a simple
        interface.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        engine_idxs: list[int] | None = None,
        custom_stat_loggers: list[StatLoggerFactory] | None = None,
        enable_default_loggers: bool = True,
        aggregate_engine_logging: bool = False,
        client_count: int = 1,
    ):
        """Build the list of stat loggers from the given factories.

        Custom factories come first, followed by the default logging stat
        logger when it is enabled, INFO logging is on and ``client_count``
        is 1. A ``PrometheusStatLogger`` is appended at the end unless one
        of the factories already produced one.
        """
        # A missing or empty engine list falls back to a single engine 0.
        self.engine_indexes = engine_idxs or [0]
        self.stat_loggers: list[AggregateStatLoggerBase] = []

        factories: list[StatLoggerFactory] = list(custom_stat_loggers or [])
        if enable_default_loggers and logger.isEnabledFor(logging.INFO):
            if client_count > 1:
                logger.warning(
                    "AsyncLLM created with api_server_count more than 1; "
                    "disabling stats logging to avoid incomplete stats."
                )
            elif aggregate_engine_logging:
                factories.append(AggregatedLoggingStatLogger)
            else:
                factories.append(LoggingStatLogger)

        has_prometheus_logger = False
        for factory in factories:
            is_aggregate_cls = isinstance(factory, type) and issubclass(
                factory, AggregateStatLoggerBase
            )
            if is_aggregate_cls:
                # Aggregate loggers handle all engine indexes themselves.
                stat_logger = factory(
                    vllm_config=vllm_config,
                    engine_indexes=self.engine_indexes,
                )
                if isinstance(stat_logger, PrometheusStatLogger):
                    has_prometheus_logger = True
            else:
                # Per-engine factories are wrapped in an adapter.
                stat_logger = PerEngineStatLoggerAdapter(
                    vllm_config=vllm_config,
                    engine_indexes=self.engine_indexes,
                    per_engine_stat_logger_factory=factory,  # type: ignore[arg-type]
                )
            self.stat_loggers.append(stat_logger)

        if not has_prometheus_logger:
            self.stat_loggers.append(
                PrometheusStatLogger(vllm_config, self.engine_indexes)
            )

    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int | None = None,
    ):
        """Forward the stats to every logger; a None ``engine_idx`` means 0."""
        target_idx = 0 if engine_idx is None else engine_idx
        for stat_logger in self.stat_loggers:
            stat_logger.record(
                scheduler_stats,
                iteration_stats,
                mm_cache_stats=mm_cache_stats,
                engine_idx=target_idx,
            )

    def log(self):
        """Call ``log()`` on every registered stat logger."""
        for stat_logger in self.stat_loggers:
            stat_logger.log()

    def log_engine_initialized(self):
        """Call ``log_engine_initialized()`` on every registered stat logger."""
        for stat_logger in self.stat_loggers:
            stat_logger.log_engine_initialized()