loggers.py 47.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import logging
5
6
import time
from abc import ABC, abstractmethod
7
8
from collections.abc import Callable
from typing import TypeAlias
9

10
from prometheus_client import Counter, Gauge, Histogram
11

12
import vllm.envs as envs
13
from vllm.compilation.cuda_graph import CUDAGraphLogging
14
from vllm.config import SupportsMetricsInfo, VllmConfig
15
16
17
18
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
    KVConnectorLogging,
    KVConnectorPrometheus,
)
19
from vllm.logger import init_logger
wangxiyuan's avatar
wangxiyuan committed
20
from vllm.plugins import STAT_LOGGER_PLUGINS_GROUP, load_plugins_by_group
21
from vllm.v1.engine import FinishReason
22
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
23
24
25
26
27
28
from vllm.v1.metrics.stats import (
    CachingMetrics,
    IterationStats,
    MultiModalCacheStats,
    SchedulerStats,
)
29
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
30
31
32

# Module-level logger, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)

33
34
35
# Factory for a per-engine logger: called with (vllm_config, engine_index)
# and returns a StatLoggerBase instance.
PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
# Factory for an aggregating logger: the AggregateStatLoggerBase subclass itself.
AggregateStatLoggerFactory = type["AggregateStatLoggerBase"]
# Either kind of factory.
StatLoggerFactory = AggregateStatLoggerFactory | PerEngineStatLoggerFactory
36

37
38

class StatLoggerBase(ABC):
    """Abstract interface that every metrics logger implements.

    Custom loggers may be written against this interface. Be aware that
    `SchedulerStats` and `IterationStats` are not stable interfaces and may
    change between versions.
    """

    @abstractmethod
    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0) -> None:
        """Create a logger for the engine identified by ``engine_index``."""

    @abstractmethod
    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int = 0,
    ) -> None:
        """Consume one batch of stats reported for engine ``engine_idx``.

        Any of the stats arguments may be ``None``.
        """

    @abstractmethod
    def log_engine_initialized(self) -> None:
        """Report that the engine has been initialized."""

    def log(self) -> None:  # noqa
        """Emit accumulated stats; the base implementation does nothing."""
        return None

    def record_sleep_state(self, is_awake: int, level: int) -> None:  # noqa
        """Record the engine sleep state; the base implementation does nothing."""
        return None

67

68
69
70
def load_stat_logger_plugin_factories() -> list[StatLoggerFactory]:
    """Collect the stat logger classes registered as plugins.

    Plugins are looked up via ``load_plugins_by_group`` using
    ``STAT_LOGGER_PLUGINS_GROUP``; each loaded object must be a class that
    derives from ``StatLoggerBase``.

    Returns:
        The plugin classes, in the iteration order of the loaded mapping.

    Raises:
        TypeError: If a loaded plugin is not a class, or is a class that does
            not derive from ``StatLoggerBase``.
    """
    loaded_plugins = load_plugins_by_group(STAT_LOGGER_PLUGINS_GROUP)

    collected: list[StatLoggerFactory] = []
    for name, plugin_class in loaded_plugins.items():
        # issubclass() raises on non-class objects, so check isinstance first.
        is_stat_logger_cls = isinstance(plugin_class, type) and issubclass(
            plugin_class, StatLoggerBase
        )
        if is_stat_logger_cls:
            collected.append(plugin_class)
            continue

        raise TypeError(
            f"Stat logger plugin {name!r} must be a subclass of "
            f"StatLoggerBase (got {plugin_class!r})."
        )

    return collected


85
86
87
88
89
90
91
92
class AggregateStatLoggerBase(StatLoggerBase):
    """Abstract base for loggers that aggregate across multiple DP engines.

    Unlike a per-engine logger, an implementation is constructed with the
    list of all engine indexes it covers.
    """

    @abstractmethod
    def __init__(
        self, vllm_config: VllmConfig, engine_indexes: list[int]
    ) -> None:
        """Create a logger covering every engine in ``engine_indexes``."""


93
class LoggingStatLogger(StatLoggerBase):
    """Stat logger that writes a one-line summary through the module logger.

    `record()` accumulates stats; `log()` turns what was accumulated since the
    previous `log()` call into throughput figures and emits a single message.
    The message is logged at DEBUG level while the engine is idle and at INFO
    level otherwise.
    """

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        """Set up counters and helper loggers for engine ``engine_index``."""
        self.engine_index = engine_index
        self.vllm_config = vllm_config
        self._reset(time.monotonic())

        # Snapshots of the interval counters, taken by _update_stats() right
        # before it resets them, so that log() can still report their values.
        self.last_num_preemptions: int = 0
        self.last_num_corrupted_reqs: int = 0

        self.last_scheduler_stats = SchedulerStats()

        # Caching metrics. This cannot be reset.
        # TODO: Make the interval configurable.
        self.prefix_caching_metrics = CachingMetrics()
        self.connector_prefix_caching_metrics = CachingMetrics()
        self.mm_caching_metrics = CachingMetrics()

        self.spec_decoding_logging = SpecDecodingLogging()
        kv_transfer_config = self.vllm_config.kv_transfer_config
        self.kv_connector_logging = KVConnectorLogging(kv_transfer_config)

        # CUDA graph logging is only created when enabled in the config.
        self.cudagraph_logging = None
        if self.vllm_config.observability_config.cudagraph_metrics:
            self.cudagraph_logging = CUDAGraphLogging(
                self.vllm_config.compilation_config.cudagraph_mode,
                self.vllm_config.compilation_config.cudagraph_capture_sizes,
            )

        self.last_prompt_throughput: float = 0.0
        self.last_generation_throughput: float = 0.0
        self.engine_is_idle = False
        # When True, record() leaves `last_scheduler_stats` untouched.
        self.aggregated = False

    def _reset(self, now):
        """Start a new logging interval beginning at time ``now``."""
        self.last_log_time = now

        # Tracked stats over current local logging interval.
        self.num_prompt_tokens: int = 0
        self.num_generation_tokens: int = 0
        self.num_corrupted_reqs: int = 0
        self.num_preemptions: int = 0

    def _track_iteration_stats(self, iteration_stats: IterationStats):
        """Add the counts from ``iteration_stats`` to the interval counters."""
        self.num_prompt_tokens += iteration_stats.num_prompt_tokens
        self.num_generation_tokens += iteration_stats.num_generation_tokens
        self.num_corrupted_reqs += iteration_stats.num_corrupted_reqs
        self.num_preemptions += iteration_stats.num_preempted_reqs

    def _get_throughput(self, tracked_stats: int, now: float) -> float:
        """Return ``tracked_stats`` per second since the last log time.

        Returns 0.0 when no time has elapsed, which avoids dividing by zero.
        """
        delta_time = now - self.last_log_time
        if delta_time <= 0.0:
            return 0.0
        return float(tracked_stats / delta_time)

    @property
    def log_prefix(self):
        """Prefix placed in front of the summary message."""
        return f"Engine {self.engine_index:03d}: "

    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int = 0,
    ):
        """Accumulate the given stats; nothing is logged here.

        Args:
            scheduler_stats: Scheduler snapshot; its cache, spec-decoding,
                KV-connector and CUDA-graph stats are fed to the matching
                helpers when present.
            iteration_stats: Token / request counts added to the interval
                counters.
            mm_cache_stats: Multi-modal cache stats, if any.
            engine_idx: Not used by this implementation.
        """
        if iteration_stats:
            self._track_iteration_stats(iteration_stats)

        if scheduler_stats is not None:
            self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)

            if scheduler_stats.connector_prefix_cache_stats is not None:
                self.connector_prefix_caching_metrics.observe(
                    scheduler_stats.connector_prefix_cache_stats
                )

            if scheduler_stats.spec_decoding_stats is not None:
                self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats)
            if kv_connector_stats := scheduler_stats.kv_connector_stats:
                self.kv_connector_logging.observe(kv_connector_stats)
            if (
                self.cudagraph_logging is not None
                and scheduler_stats.cudagraph_stats is not None
            ):
                self.cudagraph_logging.observe(scheduler_stats.cudagraph_stats)
            if not self.aggregated:
                self.last_scheduler_stats = scheduler_stats

        if mm_cache_stats:
            self.mm_caching_metrics.observe(mm_cache_stats)

    def _update_stats(self):
        """Close the current interval and compute its throughput figures."""
        now = time.monotonic()
        prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
        generation_throughput = self._get_throughput(self.num_generation_tokens, now)

        # _reset() zeroes the interval counters, so capture the ones that
        # log() reports before they are lost.
        self.last_num_preemptions = self.num_preemptions
        self.last_num_corrupted_reqs = self.num_corrupted_reqs

        self._reset(now)

        # Idle only if both this interval and the previous one had zero
        # prompt and generation throughput.
        self.engine_is_idle = not any(
            (
                prompt_throughput,
                generation_throughput,
                self.last_prompt_throughput,
                self.last_generation_throughput,
            )
        )
        self.last_generation_throughput = generation_throughput
        self.last_prompt_throughput = prompt_throughput

    def aggregate_scheduler_stats(self):
        """Hook called by log(); a no-op for per-engine loggers."""
        return

    def log(self):
        """Emit the summary line for the interval that just ended."""
        self._update_stats()
        self.aggregate_scheduler_stats()
        # Avoid log noise on an idle production system
        log_fn = logger.debug if self.engine_is_idle else logger.info

        # Format and print output. `log_parts` and `log_args` must stay in
        # lock-step: each part contains exactly one %-placeholder.
        log_parts = [
            "Avg prompt throughput: %.1f tokens/s",
            "Avg generation throughput: %.1f tokens/s",
            "Running: %d reqs",
            "Waiting: %d reqs",
        ]
        log_args = [
            self.last_prompt_throughput,
            self.last_generation_throughput,
            self.last_scheduler_stats.num_running_reqs,
            self.last_scheduler_stats.num_waiting_reqs,
        ]

        # Use the snapshot: the live counter was already reset above.
        if self.last_num_preemptions > 0:
            log_parts.append("Preemptions: %d")
            log_args.append(self.last_num_preemptions)

        log_parts.extend(
            [
                "GPU KV cache usage: %.1f%%",
                "Prefix cache hit rate: %.1f%%",
            ]
        )
        log_args.extend(
            [
                self.last_scheduler_stats.kv_cache_usage * 100,
                self.prefix_caching_metrics.hit_rate * 100,
            ]
        )

        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            log_parts.append("Corrupted: %d reqs")
            log_args.append(self.last_num_corrupted_reqs)
        if not self.connector_prefix_caching_metrics.empty:
            log_parts.append("External prefix cache hit rate: %.1f%%")
            log_args.append(self.connector_prefix_caching_metrics.hit_rate * 100)
        if not self.mm_caching_metrics.empty:
            log_parts.append("MM cache hit rate: %.1f%%")
            log_args.append(self.mm_caching_metrics.hit_rate * 100)

        log_fn(
            self.log_prefix + ", ".join(log_parts),
            *log_args,
        )

        self.spec_decoding_logging.log(log_fn=log_fn)
        self.kv_connector_logging.log(log_fn=log_fn)
        if self.cudagraph_logging is not None:
            self.cudagraph_logging.log(log_fn=log_fn)

    def log_engine_initialized(self):
        """Log the configured number of GPU blocks at DEBUG level, if set."""
        if self.vllm_config.cache_config.num_gpu_blocks:
            logger.debug(
                "Engine %03d: vllm cache_config_info with initialization "
                "after num_gpu_blocks is: %d",
                self.engine_index,
                self.vllm_config.cache_config.num_gpu_blocks,
            )
266

267

268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase):
    """Logging stat logger that reports one combined line for several engines.

    Token counters and cache metrics are accumulated across all engines by the
    inherited `record()`. The most recent `SchedulerStats` of each engine is
    kept separately and combined when `log()` runs.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        engine_indexes: list[int],
    ):
        """Create the logger for the engines listed in ``engine_indexes``."""
        self.engine_indexes = engine_indexes
        # Latest scheduler stats per engine, combined in
        # aggregate_scheduler_stats().
        self.last_scheduler_stats_dict: dict[int, SchedulerStats] = {
            idx: SchedulerStats() for idx in self.engine_indexes
        }
        LoggingStatLogger.__init__(self, vllm_config, engine_index=-1)
        # Makes the inherited record() leave `last_scheduler_stats` alone.
        self.aggregated = True

    @property
    def log_prefix(self):
        """Prefix placed in front of the summary message."""
        return f"{len(self.engine_indexes)} Engines Aggregated: "

    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int = 0,
    ):
        """Accumulate stats reported by engine ``engine_idx``.

        Stats for an engine index this logger was not created with are
        dropped after logging a warning.
        """
        if engine_idx not in self.engine_indexes:
            logger.warning("Unexpected engine_idx: %d", engine_idx)
            return
        LoggingStatLogger.record(
            self,
            scheduler_stats,
            iteration_stats,
            mm_cache_stats=mm_cache_stats,
            engine_idx=engine_idx,
        )
        if scheduler_stats is not None:
            self.last_scheduler_stats_dict[engine_idx] = scheduler_stats

    def aggregate_scheduler_stats(self):
        """Combine the per-engine scheduler stats into `last_scheduler_stats`.

        Running and waiting request counts are summed; KV cache usage is
        averaged over the engines.
        """
        combined = SchedulerStats()
        for engine_stats in self.last_scheduler_stats_dict.values():
            combined.num_waiting_reqs += engine_stats.num_waiting_reqs
            combined.num_running_reqs += engine_stats.num_running_reqs
            combined.kv_cache_usage += engine_stats.kv_cache_usage

        # With no engines there is nothing to average; skipping the division
        # avoids a ZeroDivisionError and leaves the default usage value.
        num_engines = len(self.last_scheduler_stats_dict)
        if num_engines > 0:
            combined.kv_cache_usage /= num_engines

        self.last_scheduler_stats = combined

    def log(self):
        """Emit the aggregated summary line."""
        LoggingStatLogger.log(self)

    def log_engine_initialized(self):
        """Log the configured number of GPU blocks at INFO level, if set."""
        if self.vllm_config.cache_config.num_gpu_blocks:
            logger.info(
                "%d Engines: vllm cache_config_info with initialization "
                "after num_gpu_blocks is: %d",
                len(self.engine_indexes),
                self.vllm_config.cache_config.num_gpu_blocks,
            )


class PerEngineStatLoggerAdapter(AggregateStatLoggerBase):
    """Presents a set of per-engine loggers as one aggregate logger.

    One logger is created per engine index using the supplied factory, and
    every call is forwarded to the logger(s) it concerns.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        engine_indexes: list[int],
        per_engine_stat_logger_factory: PerEngineStatLoggerFactory,
    ) -> None:
        """Create one per-engine logger for each entry of ``engine_indexes``."""
        self.engine_indexes = engine_indexes
        self.per_engine_stat_loggers: dict[int, StatLoggerBase] = {
            engine_index: per_engine_stat_logger_factory(vllm_config, engine_index)
            for engine_index in engine_indexes
        }

    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int = 0,
    ):
        """Forward the stats to the logger owned by engine ``engine_idx``.

        Stats for an unknown engine index are dropped after logging a warning.
        """
        if engine_idx not in self.per_engine_stat_loggers:
            logger.warning("Unexpected engine_idx: %d", engine_idx)
            return
        self.per_engine_stat_loggers[engine_idx].record(
            scheduler_stats,
            iteration_stats,
            mm_cache_stats=mm_cache_stats,
            engine_idx=engine_idx,
        )

    def log(self):
        """Call `log()` on every per-engine logger."""
        for per_engine_stat_logger in self.per_engine_stat_loggers.values():
            per_engine_stat_logger.log()

    def log_engine_initialized(self):
        """Call `log_engine_initialized()` on every per-engine logger."""
        for per_engine_stat_logger in self.per_engine_stat_loggers.values():
            per_engine_stat_logger.log_engine_initialized()

    def record_sleep_state(self, is_awake: int, level: int):  # noqa
        """Forward the sleep state to every per-engine logger.

        Without this override the adapter would inherit the base-class no-op
        and the wrapped loggers would never be told about sleep-state changes.
        """
        for per_engine_stat_logger in self.per_engine_stat_loggers.values():
            per_engine_stat_logger.record_sleep_state(is_awake, level)


class PrometheusStatLogger(AggregateStatLoggerBase):
373
374
375
    _gauge_cls = Gauge
    _counter_cls = Counter
    _histogram_cls = Histogram
376
    _spec_decoding_cls = SpecDecodingProm
377
    _kv_connector_cls = KVConnectorPrometheus
378

379
    def __init__(
380
        self, vllm_config: VllmConfig, engine_indexes: list[int] | None = None
381
    ):
382
383
        if engine_indexes is None:
            engine_indexes = [0]
384

385
        self.engine_indexes = engine_indexes
386
387

        unregister_vllm_metrics()
388
        self.vllm_config = vllm_config
389
390
        # Use this flag to hide metrics that were deprecated in
        # a previous release and which will be removed in a future release
391
        self.show_hidden_metrics = vllm_config.observability_config.show_hidden_metrics
392
393
394
        self.kv_cache_metrics_enabled = (
            vllm_config.observability_config.kv_cache_metrics
        )
395

396
        labelnames = ["model_name", "engine"]
397
        model_name = vllm_config.model_config.served_model_name
398
        max_model_len = vllm_config.model_config.max_model_len
399

400
        per_engine_labelvalues: dict[int, list[object]] = {
401
            idx: [model_name, str(idx)] for idx in engine_indexes
402
403
        }

404
        self.spec_decoding_prom = self._spec_decoding_cls(
405
406
407
408
            vllm_config.speculative_config, labelnames, per_engine_labelvalues
        )
        self.kv_connector_prom = self._kv_connector_cls(
            vllm_config, labelnames, per_engine_labelvalues
409
        )
410

411
412
413
        #
        # Scheduler state
        #
414
        gauge_scheduler_running = self._gauge_cls(
415
416
            name="vllm:num_requests_running",
            documentation="Number of requests in model execution batches.",
417
            multiprocess_mode="mostrecent",
418
419
420
421
422
            labelnames=labelnames,
        )
        self.gauge_scheduler_running = make_per_engine(
            gauge_scheduler_running, engine_indexes, model_name
        )
423

424
        gauge_scheduler_waiting = self._gauge_cls(
425
426
            name="vllm:num_requests_waiting",
            documentation="Number of requests waiting to be processed.",
427
            multiprocess_mode="mostrecent",
428
429
430
431
432
            labelnames=labelnames,
        )
        self.gauge_scheduler_waiting = make_per_engine(
            gauge_scheduler_waiting, engine_indexes, model_name
        )
433

434
435
436
437
438
439
440
441
442
443
444
445
446
447
        gauge_engine_sleep_state = self._gauge_cls(
            name="vllm:engine_sleep_state",
            documentation=(
                "Engine sleep state; awake = 0 means engine is sleeping; "
                "awake = 1 means engine is awake; "
                "weights_offloaded = 1 means sleep level 1; "
                "discard_all = 1 means sleep level 2."
            ),
            labelnames=labelnames + ["sleep_state"],
            multiprocess_mode="mostrecent",
        )

        self.gauge_engine_sleep_state = {}
        sleep_state = ["awake", "weights_offloaded", "discard_all"]
448

449
450
451
452
453
454
455
        for s in sleep_state:
            self.gauge_engine_sleep_state[s] = {
                idx: gauge_engine_sleep_state.labels(
                    engine=idx, model_name=model_name, sleep_state=s
                )
                for idx in engine_indexes
            }
456

457
458
        # Setting default values
        self.record_sleep_state()
459

460
        gauge_kv_cache_usage = self._gauge_cls(
461
462
            name="vllm:kv_cache_usage_perc",
            documentation="KV-cache usage. 1 means 100 percent usage.",
463
            multiprocess_mode="mostrecent",
464
465
466
467
468
            labelnames=labelnames,
        )
        self.gauge_kv_cache_usage = make_per_engine(
            gauge_kv_cache_usage, engine_indexes, model_name
        )
469

470
471
472
473
474
475
476
477
478
479
480
481
482
        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            counter_corrupted_requests = self._counter_cls(
                name="vllm:corrupted_requests",
                documentation=(
                    "Corrupted requests, in terms of total number of requests "
                    "with NaNs in logits."
                ),
                labelnames=labelnames,
            )
            self.counter_corrupted_requests = make_per_engine(
                counter_corrupted_requests, engine_indexes, model_name
            )

483
        counter_prefix_cache_queries = self._counter_cls(
484
485
            name="vllm:prefix_cache_queries",
            documentation=(
486
487
488
489
                "Prefix cache queries, in terms of number of queried tokens."
            ),
            labelnames=labelnames,
        )
490
        self.counter_prefix_cache_queries = make_per_engine(
491
492
            counter_prefix_cache_queries, engine_indexes, model_name
        )
493

494
        counter_prefix_cache_hits = self._counter_cls(
495
            name="vllm:prefix_cache_hits",
496
497
498
            documentation=("Prefix cache hits, in terms of number of cached tokens."),
            labelnames=labelnames,
        )
499
        self.counter_prefix_cache_hits = make_per_engine(
500
501
            counter_prefix_cache_hits, engine_indexes, model_name
        )
502

503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
        #
        # External - KV connector prefix cache
        #

        counter_connector_prefix_cache_queries = self._counter_cls(
            name="vllm:external_prefix_cache_queries",
            documentation=(
                "External prefix cache queries from KV connector "
                "cross-instance cache sharing, in terms of number of queried tokens."
            ),
            labelnames=labelnames,
        )
        self.counter_connector_prefix_cache_queries = make_per_engine(
            counter_connector_prefix_cache_queries, engine_indexes, model_name
        )

        counter_connector_prefix_cache_hits = self._counter_cls(
            name="vllm:external_prefix_cache_hits",
            documentation=(
                "External prefix cache hits from KV connector "
                "cross-instance cache sharing, in terms of number of cached tokens."
            ),
            labelnames=labelnames,
        )
        self.counter_connector_prefix_cache_hits = make_per_engine(
            counter_connector_prefix_cache_hits, engine_indexes, model_name
        )

531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
        #
        # Multi-modal cache
        #

        counter_mm_cache_queries = self._counter_cls(
            name="vllm:mm_cache_queries",
            documentation=(
                "Multi-modal cache queries, in terms of number of queried items."
            ),
            labelnames=labelnames,
        )
        self.counter_mm_cache_queries = make_per_engine(
            counter_mm_cache_queries, engine_indexes, model_name
        )

        counter_mm_cache_hits = self._counter_cls(
            name="vllm:mm_cache_hits",
            documentation=(
                "Multi-modal cache hits, in terms of number of cached items."
            ),
            labelnames=labelnames,
        )
        self.counter_mm_cache_hits = make_per_engine(
            counter_mm_cache_hits, engine_indexes, model_name
        )

557
558
559
        #
        # Counters
        #
560
        counter_num_preempted_reqs = self._counter_cls(
561
            name="vllm:num_preemptions",
562
            documentation="Cumulative number of preemption from the engine.",
563
564
            labelnames=labelnames,
        )
565
        self.counter_num_preempted_reqs = make_per_engine(
566
567
            counter_num_preempted_reqs, engine_indexes, model_name
        )
568

569
        counter_prompt_tokens = self._counter_cls(
570
            name="vllm:prompt_tokens",
571
            documentation="Number of prefill tokens processed.",
572
573
574
575
576
            labelnames=labelnames,
        )
        self.counter_prompt_tokens = make_per_engine(
            counter_prompt_tokens, engine_indexes, model_name
        )
577

578
        counter_generation_tokens = self._counter_cls(
579
            name="vllm:generation_tokens",
580
            documentation="Number of generation tokens processed.",
581
582
            labelnames=labelnames,
        )
583
        self.counter_generation_tokens = make_per_engine(
584
585
            counter_generation_tokens, engine_indexes, model_name
        )
586

587
        self.counter_request_success: dict[FinishReason, dict[int, Counter]] = {}
588
        counter_request_success_base = self._counter_cls(
589
            name="vllm:request_success",
590
            documentation="Count of successfully processed requests.",
591
592
            labelnames=labelnames + ["finished_reason"],
        )
593
        for reason in FinishReason:
594
            self.counter_request_success[reason] = {
595
596
597
                idx: counter_request_success_base.labels(
                    model_name, str(idx), str(reason)
                )
598
599
                for idx in engine_indexes
            }
600

601
602
603
        #
        # Histograms of counts
        #
604
605
606
607
        histogram_num_prompt_tokens_request = self._histogram_cls(
            name="vllm:request_prompt_tokens",
            documentation="Number of prefill tokens processed.",
            buckets=build_1_2_5_buckets(max_model_len),
608
609
            labelnames=labelnames,
        )
610
        self.histogram_num_prompt_tokens_request = make_per_engine(
611
612
            histogram_num_prompt_tokens_request, engine_indexes, model_name
        )
613
614
615
616
617

        histogram_num_generation_tokens_request = self._histogram_cls(
            name="vllm:request_generation_tokens",
            documentation="Number of generation tokens processed.",
            buckets=build_1_2_5_buckets(max_model_len),
618
619
            labelnames=labelnames,
        )
620
        self.histogram_num_generation_tokens_request = make_per_engine(
621
622
            histogram_num_generation_tokens_request, engine_indexes, model_name
        )
623

624
625
626
        # TODO: This metric might be incorrect in case of using multiple
        # api_server counts which uses prometheus mp.
        # See: https://github.com/vllm-project/vllm/pull/18053
627
628
629
        histogram_iteration_tokens = self._histogram_cls(
            name="vllm:iteration_tokens_total",
            documentation="Histogram of number of tokens per engine_step.",
630
631
632
            buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
            labelnames=labelnames,
        )
633
        self.histogram_iteration_tokens = make_per_engine(
634
635
            histogram_iteration_tokens, engine_indexes, model_name
        )
636
637
638

        histogram_max_num_generation_tokens_request = self._histogram_cls(
            name="vllm:request_max_num_generation_tokens",
639
            documentation="Histogram of maximum number of requested generation tokens.",
640
            buckets=build_1_2_5_buckets(max_model_len),
641
642
            labelnames=labelnames,
        )
643
        self.histogram_max_num_generation_tokens_request = make_per_engine(
644
645
            histogram_max_num_generation_tokens_request, engine_indexes, model_name
        )
646
647
648
649
650

        histogram_n_request = self._histogram_cls(
            name="vllm:request_params_n",
            documentation="Histogram of the n request parameter.",
            buckets=[1, 2, 5, 10, 20],
651
652
653
654
655
            labelnames=labelnames,
        )
        self.histogram_n_request = make_per_engine(
            histogram_n_request, engine_indexes, model_name
        )
656
657
658
659
660

        histogram_max_tokens_request = self._histogram_cls(
            name="vllm:request_params_max_tokens",
            documentation="Histogram of the max_tokens request parameter.",
            buckets=build_1_2_5_buckets(max_model_len),
661
662
            labelnames=labelnames,
        )
663
        self.histogram_max_tokens_request = make_per_engine(
664
665
            histogram_max_tokens_request, engine_indexes, model_name
        )
666
667
668
669

        #
        # Histogram of timing intervals
        #
670
671
672
673
        histogram_time_to_first_token = self._histogram_cls(
            name="vllm:time_to_first_token_seconds",
            documentation="Histogram of time to first token in seconds.",
            buckets=[
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
                0.001,
                0.005,
                0.01,
                0.02,
                0.04,
                0.06,
                0.08,
                0.1,
                0.25,
                0.5,
                0.75,
                1.0,
                2.5,
                5.0,
                7.5,
                10.0,
                20.0,
                40.0,
                80.0,
                160.0,
                640.0,
                2560.0,
696
            ],
697
698
            labelnames=labelnames,
        )
699
        self.histogram_time_to_first_token = make_per_engine(
700
701
            histogram_time_to_first_token, engine_indexes, model_name
        )
702

703
        # Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
        # With 0.12.x you can enable with --show-hidden-metrics-for-version=0.11
        # TODO: remove in 0.13.0
        if self.show_hidden_metrics:
            histogram_time_per_output_token = self._histogram_cls(
                name="vllm:time_per_output_token_seconds",
                documentation=(
                    "Histogram of time per output token in seconds."
                    "DEPRECATED: Use vllm:inter_token_latency_seconds instead."
                ),
                buckets=[
                    0.01,
                    0.025,
                    0.05,
                    0.075,
                    0.1,
                    0.15,
                    0.2,
                    0.3,
                    0.4,
                    0.5,
                    0.75,
                    1.0,
                    2.5,
                    5.0,
                    7.5,
                    10.0,
                    20.0,
                    40.0,
                    80.0,
                ],
                labelnames=labelnames,
            )
            self.histogram_time_per_output_token = make_per_engine(
                histogram_time_per_output_token, engine_indexes, model_name
            )
739

740
741
742
743
        histogram_inter_token_latency = self._histogram_cls(
            name="vllm:inter_token_latency_seconds",
            documentation="Histogram of inter-token latency in seconds.",
            buckets=[
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
                0.01,
                0.025,
                0.05,
                0.075,
                0.1,
                0.15,
                0.2,
                0.3,
                0.4,
                0.5,
                0.75,
                1.0,
                2.5,
                5.0,
                7.5,
                10.0,
                20.0,
                40.0,
                80.0,
763
            ],
764
765
            labelnames=labelnames,
        )
766
        self.histogram_inter_token_latency = make_per_engine(
767
768
            histogram_inter_token_latency, engine_indexes, model_name
        )
769

770
771
        histogram_request_time_per_output_token = self._histogram_cls(
            name="vllm:request_time_per_output_token_seconds",
772
            documentation="Histogram of time_per_output_token_seconds per request.",
773
            buckets=[
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
                0.01,
                0.025,
                0.05,
                0.075,
                0.1,
                0.15,
                0.2,
                0.3,
                0.4,
                0.5,
                0.75,
                1.0,
                2.5,
                5.0,
                7.5,
                10.0,
                20.0,
                40.0,
                80.0,
793
            ],
794
795
            labelnames=labelnames,
        )
796
        self.histogram_request_time_per_output_token = make_per_engine(
797
798
            histogram_request_time_per_output_token, engine_indexes, model_name
        )
799

800
        request_latency_buckets = [
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
            0.3,
            0.5,
            0.8,
            1.0,
            1.5,
            2.0,
            2.5,
            5.0,
            10.0,
            15.0,
            20.0,
            30.0,
            40.0,
            50.0,
            60.0,
            120.0,
            240.0,
            480.0,
            960.0,
            1920.0,
            7680.0,
822
        ]
823
824
825
826
        histogram_e2e_time_request = self._histogram_cls(
            name="vllm:e2e_request_latency_seconds",
            documentation="Histogram of e2e request latency in seconds.",
            buckets=request_latency_buckets,
827
828
            labelnames=labelnames,
        )
829
        self.histogram_e2e_time_request = make_per_engine(
830
831
            histogram_e2e_time_request, engine_indexes, model_name
        )
832
833
834

        histogram_queue_time_request = self._histogram_cls(
            name="vllm:request_queue_time_seconds",
835
            documentation="Histogram of time spent in WAITING phase for request.",
836
            buckets=request_latency_buckets,
837
838
            labelnames=labelnames,
        )
839
        self.histogram_queue_time_request = make_per_engine(
840
841
            histogram_queue_time_request, engine_indexes, model_name
        )
842
843
844

        histogram_inference_time_request = self._histogram_cls(
            name="vllm:request_inference_time_seconds",
845
            documentation="Histogram of time spent in RUNNING phase for request.",
846
            buckets=request_latency_buckets,
847
848
            labelnames=labelnames,
        )
849
        self.histogram_inference_time_request = make_per_engine(
850
851
            histogram_inference_time_request, engine_indexes, model_name
        )
852
853
854

        histogram_prefill_time_request = self._histogram_cls(
            name="vllm:request_prefill_time_seconds",
855
            documentation="Histogram of time spent in PREFILL phase for request.",
856
            buckets=request_latency_buckets,
857
858
            labelnames=labelnames,
        )
859
        self.histogram_prefill_time_request = make_per_engine(
860
861
            histogram_prefill_time_request, engine_indexes, model_name
        )
862
863
864

        histogram_decode_time_request = self._histogram_cls(
            name="vllm:request_decode_time_seconds",
865
            documentation="Histogram of time spent in DECODE phase for request.",
866
            buckets=request_latency_buckets,
867
868
            labelnames=labelnames,
        )
869
        self.histogram_decode_time_request = make_per_engine(
870
871
            histogram_decode_time_request, engine_indexes, model_name
        )
872

873
874
875
876
877
878
879
880
881
882
883
884
885
        histogram_prefill_kv_computed_request = self._histogram_cls(
            name="vllm:request_prefill_kv_computed_tokens",
            documentation=(
                "Histogram of new KV tokens computed during prefill "
                "(excluding cached tokens)."
            ),
            buckets=build_1_2_5_buckets(max_model_len),
            labelnames=labelnames,
        )
        self.histogram_prefill_kv_computed_request = make_per_engine(
            histogram_prefill_kv_computed_request, engine_indexes, model_name
        )

886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
        #
        # KV Cache residency metrics
        #
        if self.kv_cache_metrics_enabled:
            kv_cache_residency_buckets = [
                0.001,
                0.002,
                0.005,
                0.01,
                0.02,
                0.05,
                0.1,
                0.2,
                0.5,
                1,
                2,
                5,
                10,
                20,
                30,
                60,
                120,
                300,
                600,
                1200,
                1800,
            ]

            histogram_kv_block_lifetime = self._histogram_cls(
                name="vllm:kv_block_lifetime_seconds",
                documentation=(
                    "Histogram of KV cache block lifetime from allocation to eviction. "
                    "Sampled metrics (controlled by --kv-cache-metrics-sample)."
                ),
                buckets=kv_cache_residency_buckets,
                labelnames=labelnames,
            )
            self.histogram_kv_block_lifetime = make_per_engine(
                histogram_kv_block_lifetime, engine_indexes, model_name
            )

            histogram_kv_block_idle_before_evict = self._histogram_cls(
                name="vllm:kv_block_idle_before_evict_seconds",
                documentation=(
                    "Histogram of idle time before KV cache block eviction. "
                    "Sampled metrics (controlled by --kv-cache-metrics-sample)."
                ),
                buckets=kv_cache_residency_buckets,
                labelnames=labelnames,
            )
            self.histogram_kv_block_idle_before_evict = make_per_engine(
                histogram_kv_block_idle_before_evict, engine_indexes, model_name
            )

            histogram_kv_block_reuse_gap = self._histogram_cls(
                name="vllm:kv_block_reuse_gap_seconds",
                documentation=(
                    "Histogram of time gaps between consecutive KV cache block "
                    "accesses. Only the most recent accesses are recorded "
                    "(ring buffer). Sampled metrics (controlled by "
                    "--kv-cache-metrics-sample)."
                ),
                buckets=kv_cache_residency_buckets,
                labelnames=labelnames,
            )
            self.histogram_kv_block_reuse_gap = make_per_engine(
                histogram_kv_block_reuse_gap, engine_indexes, model_name
            )
        else:
            self.histogram_kv_block_lifetime = {}
            self.histogram_kv_block_idle_before_evict = {}
            self.histogram_kv_block_reuse_gap = {}

959
960
961
        #
        # LoRA metrics
        #
962
963
964

        # TODO: This metric might be incorrect in case of using multiple
        # api_server counts which uses prometheus mp.
965
        self.gauge_lora_info: Gauge | None = None
966
        if vllm_config.lora_config is not None:
967
            if len(self.engine_indexes) > 1:
968
969
970
971
                logger.warning(
                    "vllm:lora_requests_info prometheus metrics may be "
                    "incorrect/misleading with data parallel deployments."
                )
972
973
974
975
            self.labelname_max_lora = "max_lora"
            self.labelname_waiting_lora_adapters = "waiting_lora_adapters"
            self.labelname_running_lora_adapters = "running_lora_adapters"
            self.max_lora = vllm_config.lora_config.max_loras
976
977
978
979
980
981
982
983
984
985
            self.gauge_lora_info = self._gauge_cls(
                name="vllm:lora_requests_info",
                documentation="Running stats on lora requests.",
                multiprocess_mode="sum",
                labelnames=[
                    self.labelname_max_lora,
                    self.labelname_waiting_lora_adapters,
                    self.labelname_running_lora_adapters,
                ],
            )
986

987
988
    def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
        """Publish a config object's info mapping as an info-style gauge.

        Info type metrics are syntactic sugar for a gauge permanently set
        to 1. Since prometheus multiprocessing mode does not support Info,
        it is emulated here with a gauge whose labels carry the information,
        one label set per engine index.

        Args:
            type: Kind of config being published. Only ``"cache_config"`` is
                recognised.
            config_obj: Object whose ``metrics_info()`` returns the mapping of
                label name to label value to publish.

        Raises:
            AssertionError: If ``type`` is not a known metrics info type.
        """
        # Copy once so the mapping handed back by the config object is never
        # mutated, and so it is not recomputed for every engine.
        base_info = dict(config_obj.metrics_info())

        name, documentation = None, None
        if type == "cache_config":
            name = "vllm:cache_config_info"
            documentation = "Information of the LLMEngine CacheConfig"
        assert name is not None, f"Unknown metrics info type {type}"

        # The "engine" label is always present; the placeholder value only
        # serves to put the key into the label-name list.
        labelnames = list({**base_info, "engine": ""})
        info_gauge = self._gauge_cls(
            name=name,
            documentation=documentation,
            multiprocess_mode="mostrecent",
            labelnames=labelnames,
        )
        for engine_index in self.engine_indexes:
            engine_labels = {**base_info, "engine": str(engine_index)}
            info_gauge.labels(**engine_labels).set(1)

1011
1012
    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int = 0,
    ):
        """Log to prometheus.

        Pushes the supplied stats into the per-engine metric children stored
        under ``engine_idx``. Every stats argument is optional and is skipped
        when it is ``None``.

        Args:
            scheduler_stats: Scheduler-level stats (running/waiting request
                counts, KV cache usage, prefix cache counts and the optional
                connector, spec-decoding, KV-connector, eviction and LoRA
                data).
            iteration_stats: Per-iteration token counts, latencies and
                finished requests. When ``None`` the method returns after the
                scheduler and multi-modal cache stats have been handled.
            mm_cache_stats: Multi-modal cache query and hit counts.
            engine_idx: Key of the engine whose metric children are updated.
        """
        if scheduler_stats is not None:
            self.gauge_scheduler_running[engine_idx].set(
                scheduler_stats.num_running_reqs
            )
            self.gauge_scheduler_waiting[engine_idx].set(
                scheduler_stats.num_waiting_reqs
            )
            self.gauge_kv_cache_usage[engine_idx].set(scheduler_stats.kv_cache_usage)

            prefix_stats = scheduler_stats.prefix_cache_stats
            self.counter_prefix_cache_queries[engine_idx].inc(prefix_stats.queries)
            self.counter_prefix_cache_hits[engine_idx].inc(prefix_stats.hits)

            connector_stats = scheduler_stats.connector_prefix_cache_stats
            if connector_stats is not None:
                self.counter_connector_prefix_cache_queries[engine_idx].inc(
                    connector_stats.queries
                )
                self.counter_connector_prefix_cache_hits[engine_idx].inc(
                    connector_stats.hits
                )

            if scheduler_stats.spec_decoding_stats is not None:
                self.spec_decoding_prom.observe(
                    scheduler_stats.spec_decoding_stats, engine_idx
                )

            if scheduler_stats.kv_connector_stats is not None:
                self.kv_connector_prom.observe(
                    scheduler_stats.kv_connector_stats, engine_idx
                )

            if (
                self.kv_cache_metrics_enabled
                and scheduler_stats.kv_cache_eviction_events
            ):
                lifetime_hist = self.histogram_kv_block_lifetime[engine_idx]
                idle_hist = self.histogram_kv_block_idle_before_evict[engine_idx]
                reuse_hist = self.histogram_kv_block_reuse_gap[engine_idx]

                for event in scheduler_stats.kv_cache_eviction_events:
                    lifetime_hist.observe(event.lifetime_seconds)
                    idle_hist.observe(event.idle_seconds)
                    for gap in event.reuse_gaps_seconds:
                        reuse_hist.observe(gap)

            if self.gauge_lora_info is not None:
                # The adapter names are carried as comma-joined label values;
                # the gauge value itself is set to the current time.
                running_lora_adapters = ",".join(
                    scheduler_stats.running_lora_adapters.keys()
                )
                waiting_lora_adapters = ",".join(
                    scheduler_stats.waiting_lora_adapters.keys()
                )
                lora_info_labels = {
                    self.labelname_running_lora_adapters: running_lora_adapters,
                    self.labelname_waiting_lora_adapters: waiting_lora_adapters,
                    self.labelname_max_lora: self.max_lora,
                }
                self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()

        if mm_cache_stats is not None:
            self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
            self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)

        if iteration_stats is None:
            return

        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            self.counter_corrupted_requests[engine_idx].inc(
                iteration_stats.num_corrupted_reqs
            )
        self.counter_num_preempted_reqs[engine_idx].inc(
            iteration_stats.num_preempted_reqs
        )
        self.counter_prompt_tokens[engine_idx].inc(iteration_stats.num_prompt_tokens)
        self.counter_generation_tokens[engine_idx].inc(
            iteration_stats.num_generation_tokens
        )
        self.histogram_iteration_tokens[engine_idx].observe(
            iteration_stats.num_prompt_tokens + iteration_stats.num_generation_tokens
        )

        for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter:
            self.histogram_max_num_generation_tokens_request[engine_idx].observe(
                max_gen_tokens
            )
        for n_param in iteration_stats.n_params_iter:
            self.histogram_n_request[engine_idx].observe(n_param)
        for ttft in iteration_stats.time_to_first_tokens_iter:
            self.histogram_time_to_first_token[engine_idx].observe(ttft)
        for itl in iteration_stats.inter_token_latencies_iter:
            self.histogram_inter_token_latency[engine_idx].observe(itl)
            # The same sample also feeds the time-per-output-token histogram,
            # but only when hidden metrics are shown.
            if self.show_hidden_metrics:
                self.histogram_time_per_output_token[engine_idx].observe(itl)

        for finished_request in iteration_stats.finished_requests:
            self.counter_request_success[finished_request.finish_reason][
                engine_idx
            ].inc()
            self.histogram_e2e_time_request[engine_idx].observe(
                finished_request.e2e_latency
            )
            self.histogram_queue_time_request[engine_idx].observe(
                finished_request.queued_time
            )
            self.histogram_prefill_time_request[engine_idx].observe(
                finished_request.prefill_time
            )
            self.histogram_inference_time_request[engine_idx].observe(
                finished_request.inference_time
            )
            self.histogram_decode_time_request[engine_idx].observe(
                finished_request.decode_time
            )
            # Calculate prefill KV compute (excludes cached tokens). A
            # negative cached-token count is treated as zero.
            prefill_kv_computed = finished_request.num_prompt_tokens - max(
                finished_request.num_cached_tokens, 0
            )
            self.histogram_prefill_kv_computed_request[engine_idx].observe(
                prefill_kv_computed
            )
            self.histogram_num_prompt_tokens_request[engine_idx].observe(
                finished_request.num_prompt_tokens
            )
            self.histogram_num_generation_tokens_request[engine_idx].observe(
                finished_request.num_generation_tokens
            )
            self.histogram_request_time_per_output_token[engine_idx].observe(
                finished_request.mean_time_per_output_token
            )
            # Skipped when max_tokens_param is falsy (None or 0).
            if finished_request.max_tokens_param:
                self.histogram_max_tokens_request[engine_idx].observe(
                    finished_request.max_tokens_param
                )
1154

1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
    def record_sleep_state(self, sleep: int = 0, level: int = 0):
        """Set the three sleep-state gauges for every engine.

        ``awake`` is 1 unless ``sleep == 1``. While sleeping,
        ``weights_offloaded`` is 1 for ``level == 1`` and ``discard_all`` is 1
        for ``level == 2``; any other level leaves both at 0.

        Args:
            sleep: 1 when the engine is asleep, anything else means awake.
            level: Sleep level, only consulted when ``sleep == 1``.
        """
        asleep = sleep == 1
        # Insertion order is the order in which the gauges are written.
        state_flags = {
            "discard_all": int(asleep and level == 2),
            "weights_offloaded": int(asleep and level == 1),
            "awake": int(not asleep),
        }

        for engine_idx in self.engine_indexes:
            for state_name, flag in state_flags.items():
                self.gauge_engine_sleep_state[state_name][engine_idx].set(flag)

1174
1175
1176
    def log_engine_initialized(self):
        """Publish the cache config of ``self.vllm_config`` as an info metric."""
        cache_config = self.vllm_config.cache_config
        self.log_metrics_info("cache_config", cache_config)

1177

1178
# Union of the prometheus_client metric classes accepted and returned by
# `make_per_engine`.
PromMetric: TypeAlias = Gauge | Counter | Histogram
1179
1180


1181
def make_per_engine(
    metric: PromMetric, engine_idxs: list[int], model_name: object
) -> dict[int, PromMetric]:
    """Bind a labelled metric to every engine index.

    Args:
        metric: Parent metric. Its ``labels()`` is called positionally with
            ``model_name`` followed by the engine index as a string.
        engine_idxs: Engine indexes to create children for.
        model_name: Value passed as the first label.

    Returns:
        Mapping from engine index to whatever ``metric.labels`` returned for
        it; empty when ``engine_idxs`` is empty.
    """
    return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs}


1187
def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
    """
    Builds a list of buckets with increasing powers of 10 multiplied by
    mantissa values until the value exceeds the specified maximum.

    Args:
        mantissa_lst: Mantissas applied, in order, at every power of 10.
        max_value: Inclusive upper bound for generated bucket values.

    Returns:
        The generated values in generation order, stopping at the first
        value greater than ``max_value``. Empty when ``mantissa_lst`` is
        empty.

    Raises:
        ValueError: If no mantissa can ever produce a value greater than
            ``max_value`` (every mantissa is <= min(0, max_value)), because
            generation would never stop.
    """
    if not mantissa_lst:
        # Nothing to scale; the loop below would otherwise spin forever.
        return []
    if all(m <= min(0, max_value) for m in mantissa_lst):
        # A non-positive mantissa never grows past max_value, so without at
        # least one value that can exceed it the loop below cannot terminate.
        raise ValueError(
            "build_buckets needs at least one mantissa that can exceed "
            f"max_value={max_value}; got {mantissa_lst}"
        )

    exponent = 0
    buckets: list[int] = []
    while True:
        for m in mantissa_lst:
            value = m * 10**exponent
            if value <= max_value:
                buckets.append(value)
            else:
                return buckets
        exponent += 1


1205
def build_1_2_5_buckets(max_value: int) -> list[int]:
    """
    Builds buckets following the 1-2-5 series (1, 2, 5, 10, 20, 50, ...) up
    to and including ``max_value`` by delegating to ``build_buckets``.

    Example:
    >>> build_1_2_5_buckets(100)
    [1, 2, 5, 10, 20, 50, 100]
    """
    return build_buckets([1, 2, 5], max_value)
1212
1213


1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
class StatLoggerManager:
    """
    StatLoggerManager:
        Logging happens at the level of the EngineCore (per scheduler).
         * DP: >1 EngineCore per AsyncLLM - loggers for each EngineCore.
         * With Local Logger, just make N copies for N EngineCores.
         * With Prometheus, we need a single logger with N "labels"

        This class abstracts away this implementation detail from
        the AsyncLLM, allowing the AsyncLLM to just call .record()
        and .log() to a simple interface.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        engine_idxs: list[int] | None = None,
        custom_stat_loggers: list[StatLoggerFactory] | None = None,
        enable_default_loggers: bool = True,
        aggregate_engine_logging: bool = False,
        client_count: int = 1,
    ):
        """Build the list of stat loggers that every call fans out to.

        Args:
            vllm_config: Config passed to every logger that is constructed.
            engine_idxs: Engine indexes to log for; ``[0]`` is used when this
                is ``None`` or empty.
            custom_stat_loggers: Extra logger factories. Subclasses of
                ``AggregateStatLoggerBase`` are instantiated directly; any
                other factory is wrapped in ``PerEngineStatLoggerAdapter``.
            enable_default_loggers: Add the default logging stat logger,
                provided the module logger is enabled for INFO and
                ``client_count`` is 1.
            aggregate_engine_logging: Use ``AggregatedLoggingStatLogger``
                rather than ``LoggingStatLogger`` as the default logger.
            client_count: When greater than 1 the default logger is skipped
                and a warning is emitted instead.
        """
        self.engine_indexes = engine_idxs if engine_idxs else [0]
        self.stat_loggers: list[AggregateStatLoggerBase] = []
        stat_logger_factories: list[StatLoggerFactory] = []
        if custom_stat_loggers is not None:
            stat_logger_factories.extend(custom_stat_loggers)
        if enable_default_loggers and logger.isEnabledFor(logging.INFO):
            if client_count > 1:
                logger.warning(
                    "AsyncLLM created with api_server_count more than 1; "
                    "disabling stats logging to avoid incomplete stats."
                )
            else:
                default_logger_factory = (
                    AggregatedLoggingStatLogger
                    if aggregate_engine_logging
                    else LoggingStatLogger
                )
                stat_logger_factories.append(default_logger_factory)
        custom_prometheus_logger: bool = False
        for stat_logger_factory in stat_logger_factories:
            if isinstance(stat_logger_factory, type) and issubclass(
                stat_logger_factory, AggregateStatLoggerBase
            ):
                global_stat_logger = stat_logger_factory(
                    vllm_config=vllm_config,
                    engine_indexes=self.engine_indexes,
                )
                if isinstance(global_stat_logger, PrometheusStatLogger):
                    custom_prometheus_logger = True
            else:
                # per engine logger
                global_stat_logger = PerEngineStatLoggerAdapter(
                    vllm_config=vllm_config,
                    engine_indexes=self.engine_indexes,
                    per_engine_stat_logger_factory=stat_logger_factory,  # type: ignore[arg-type]
                )
            self.stat_loggers.append(global_stat_logger)
        # A Prometheus logger is always present: add the stock one unless a
        # custom factory already produced a PrometheusStatLogger instance.
        if not custom_prometheus_logger:
            self.stat_loggers.append(
                PrometheusStatLogger(vllm_config, self.engine_indexes)
            )

    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int | None = None,
    ):
        """Forward one set of stats to every registered stat logger.

        Args:
            scheduler_stats: Passed through unchanged.
            iteration_stats: Passed through unchanged.
            mm_cache_stats: Passed through unchanged.
            engine_idx: Engine the stats belong to; ``None`` is treated as 0.
        """
        if engine_idx is None:
            engine_idx = 0
        # Named `stat_logger` so the module-level `logger` is not shadowed.
        for stat_logger in self.stat_loggers:
            stat_logger.record(
                scheduler_stats,
                iteration_stats,
                mm_cache_stats=mm_cache_stats,
                engine_idx=engine_idx,
            )

    def record_sleep_state(self, sleep: int = 0, level: int = 0):
        """Forward the sleep state to every registered stat logger."""
        for stat_logger in self.stat_loggers:
            stat_logger.record_sleep_state(sleep, level)

    def log(self):
        """Call ``log()`` on every registered stat logger."""
        for stat_logger in self.stat_loggers:
            stat_logger.log()

    def log_engine_initialized(self):
        """Call ``log_engine_initialized()`` on every registered stat logger."""
        for agg_logger in self.stat_loggers:
            agg_logger.log_engine_initialized()