stats.py 18 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import defaultdict, deque
6
from dataclasses import dataclass, field
7
from typing import TYPE_CHECKING, Any
8

9
import vllm.envs as envs
10
from vllm.compilation.cuda_graph import CUDAGraphStat
11
from vllm.v1.metrics.perf import PerfStats
12
13
from vllm.v1.spec_decode.metrics import SpecDecodingStats

14
if TYPE_CHECKING:
15
    from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
16
17


18
@dataclass
class BaseCacheStats:
    """Stores cache hit statistics.

    Concrete subclasses accumulate these counters via their `record`
    methods, and `CachingMetrics.observe` aggregates them over a sliding
    window of recent requests.
    """

    reset: bool = False
    """Whether the cache was reset."""

    requests: int = 0
    """The number of requests in this update."""

    queries: int = 0
    """The number of queries in these requests."""

    hits: int = 0
    """The number of hits in these requests."""


class CachingMetrics:
    """Metrics for caching with a hit rate of the most recent N requests.

    Args:
        max_recent_requests: The number of the most recent requests to
            aggregate. Defaults to 1000. The most recently observed update is
            always retained, even if it alone exceeds this limit.
    """

    def __init__(self, max_recent_requests: int = 1000) -> None:
        super().__init__()

        self.max_recent_requests = max_recent_requests
        # The current aggregated values over the sliding window.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0

        # A deque of (requests, queries, hits) for the most recent requests.
        self.query_queue = deque[tuple[int, int, int]]()

    def observe(self, stats: "BaseCacheStats"):
        """Observe cache statistics for a set of requests.

        This function is called with information gathered when new requests
        are being scheduled and are looking up cached data.

        When there are more than `max_recent_requests` requests, the oldest
        updates are removed from the metrics.

        Args:
            stats: The cache stats for this update (e.g. prefix cache or
                multi-modal cache stats).
        """
        # A cache reset was invoked before the current update.
        # Reset the metrics before aggregating the current stats.
        if stats.reset:
            self.reset()

        # Do not append empty updates: they carry no information and would
        # only occupy slots in the sliding window.
        if stats.requests == 0:
            return

        # Update the metrics.
        self.query_queue.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Remove the oldest stats until number of requests does not exceed
        # the limit.
        # NOTE: We preserve the latest added stats regardless.
        while (
            len(self.query_queue) > 1
            and self.aggregated_requests > self.max_recent_requests
        ):
            old_requests, old_queries, old_hits = self.query_queue.popleft()
            self.aggregated_requests -= old_requests
            self.aggregated_query_total -= old_queries
            self.aggregated_query_hit -= old_hits

    def reset(self):
        """Reset the metrics."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def empty(self) -> bool:
        """Return true if no requests have been observed."""
        return self.aggregated_requests == 0

    @property
    def hit_rate(self) -> float:
        """Calculate the hit rate for the past N requests.

        Returns 0.0 when no queries are in the window.
        """
        if self.aggregated_query_total == 0:
            return 0.0
        return self.aggregated_query_hit / self.aggregated_query_total


@dataclass
class PrefixCacheStats(BaseCacheStats):
    """
    Stores prefix cache hit statistics.
    - `reset`: Whether `reset_prefix_cache` was invoked.
    - `queries`: Refers to the number of tokens that were queried.

    Previously preempted requests are counted in the separate `preempted_*`
    fields rather than in `requests`/`queries`/`hits`.
    """

    preempted_requests: int = 0
    """The number of previously preempted requests in this update."""

    preempted_queries: int = 0
    """The `queries` number for preempted requests."""

    preempted_hits: int = 0
    """The `hits` number for preempted requests."""

    def record(self, num_tokens: int, num_hits: int, preempted: bool) -> None:
        """Aggregate request information into the stats.

        Args:
            num_tokens: Number of tokens queried for this request.
            num_hits: Number of those tokens found in the prefix cache.
            preempted: Whether the request was previously preempted; selects
                which set of counters is incremented.
        """
        if preempted:
            # Previously preempted request
            self.preempted_requests += 1
            self.preempted_queries += num_tokens
            self.preempted_hits += num_hits
        else:
            # New request
            self.requests += 1
            self.queries += num_tokens
            self.hits += num_hits

144
145
146
147
148
149
150
151
152

@dataclass
class MultiModalCacheStats(BaseCacheStats):
    """
    Stores multi-modal cache hit statistics.
    - `reset`: Whether `reset_mm_cache` was invoked.
    - `queries`: Refers to the number of multi-modal data items
      that were queried.
    """

    def record(self, num_queries: int, num_hits: int) -> None:
        """Aggregate request information into the stats.

        Each call counts as one request that queried `num_queries`
        multi-modal items, `num_hits` of which were cache hits.
        """
        self.requests += 1
        self.queries += num_queries
        self.hits += num_hits

160

161
162
163
164
165
166
167
168
169
@dataclass
class KVCacheEvictionEvent:
    """One sample recorded when a KV cache block is evicted.

    All three fields are durations expressed in seconds, as their names
    indicate.
    """

    # Presumably the time the block existed before eviction — confirm
    # against the producer of these events.
    lifetime_seconds: float
    # Presumably the time since the block was last used — confirm.
    idle_seconds: float
    # Presumably the intervals between successive reuses of the block.
    reuse_gaps_seconds: tuple[float, ...]


170
171
172
173
174
@dataclass
class SchedulerStats:
    """Stats associated with the scheduler."""

    num_running_reqs: int = 0
    num_waiting_reqs: int = 0  # length of the "waiting" request queue
    num_skipped_waiting_reqs: int = 0  # length of the "skipped waiting" queue

    # These are used for internal DP load-balancing.
    step_counter: int = 0
    current_wave: int = 0

    kv_cache_usage: float = 0.0
    encoder_cache_usage: float = 0.0

    prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats)
    connector_prefix_cache_stats: PrefixCacheStats | None = None

    kv_cache_eviction_events: list[KVCacheEvictionEvent] = field(default_factory=list)

    spec_decoding_stats: SpecDecodingStats | None = None
    kv_connector_stats: dict[str, Any] | None = None

    # Per-LoRA request counts keyed by LoRA name; populated by
    # `LoRARequestStates.update_scheduler_stats`.
    waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
    running_lora_adapters: dict[str, int] = field(default_factory=dict)

    cudagraph_stats: CUDAGraphStat | None = None

    perf_stats: PerfStats | None = None

201

202
203
204
205
206
@dataclass
class RequestStateStats:
    """Stats that need to be tracked across delta updates."""

    num_generation_tokens: int = 0

    # This is an engine frontend timestamp (wall-clock)
    arrival_time: float = 0.0

    # These are engine core timestamps (monotonic)
    queued_ts: float = 0.0
    scheduled_ts: float = 0.0
    first_token_ts: float = 0.0
    last_token_ts: float = 0.0

    # First token latency: wall-clock interval from `arrival_time` to the
    # iteration that produced the first token.
    first_token_latency: float = 0.0

    # Track if this request is corrupted (NaNs in logits)
    is_corrupted: bool = False

223
224
225
226
227

@dataclass
class FinishedRequestStats:
    """Stats associated with a finished request."""

228
    finish_reason: "FinishReason"
229
    e2e_latency: float = 0.0
230
231
    num_prompt_tokens: int = 0
    num_generation_tokens: int = 0
232
    max_tokens_param: int | None = None
233
234
    queued_time: float = 0.0
    prefill_time: float = 0.0
235
236
    inference_time: float = 0.0
    decode_time: float = 0.0
237
    mean_time_per_output_token: float = 0.0
238
    is_corrupted: bool = False
239
    num_cached_tokens: int = 0
240
241


242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
@dataclass
class PrefillStats:
    """How the tokens of one scheduled prefill are obtained.

    Fields:
        num_prompt_tokens: All tokens that the prefill covers.
        num_computed_tokens: Tokens that must actually be computed locally.
        num_cached_tokens: Tokens obtained without compute (local + external).
        num_local_cached_tokens: Tokens served by the local prefix cache.
        num_external_cached_tokens: Tokens served by external KV transfer.
    """

    num_prompt_tokens: int = 0
    num_computed_tokens: int = 0
    num_cached_tokens: int = 0
    num_local_cached_tokens: int = 0
    num_external_cached_tokens: int = 0

    def set(
        self,
        num_prompt_tokens: int,
        num_local_cached_tokens: int,
        num_external_cached_tokens: int,
    ):
        """Fill in every field from the prompt length and cache-hit counts.

        Raises:
            AssertionError: (when assertions are enabled) if the combined
                cached tokens exceed `num_prompt_tokens`. No field is
                modified in that case.
        """
        cached = num_local_cached_tokens + num_external_cached_tokens
        assert cached <= num_prompt_tokens

        self.num_local_cached_tokens = num_local_cached_tokens
        self.num_external_cached_tokens = num_external_cached_tokens
        self.num_cached_tokens = cached
        self.num_prompt_tokens = num_prompt_tokens
        # Whatever is not served from a cache has to be computed.
        self.num_computed_tokens = num_prompt_tokens - cached


276
277
278
279
280
281
282
283
284
285
286
287
@dataclass
class PromptTokenStats:
    """Breakdown of prompt tokens by source.

    Fields:
        computed: Tokens prefilled locally (actual compute work).
        local_cache_hit: Tokens from local prefix cache.
        external_kv_transfer: Tokens from external KV transfer.
        cached_tokens: Tokens skipped during prefill (from scheduler).
        total: Total prompt tokens.

    Invariants:
        computed + local_cache_hit + external_kv_transfer = total
        local_cache_hit + external_kv_transfer = cached_tokens
    """

    # NOTE(review): because this is annotated without `ClassVar`, dataclass
    # treats it as an instance field (and the first `__init__` parameter).
    # Kept as-is for interface compatibility.
    ALL_SOURCES: tuple[str, ...] = (
        "local_compute",
        "local_cache_hit",
        "external_kv_transfer",
    )

    computed: int = 0
    local_cache_hit: int = 0
    external_kv_transfer: int = 0
    cached_tokens: int = 0
    total: int = 0

    def update_from_output(self, prefill_stats: "PrefillStats") -> None:
        """Accumulate one prefill's token breakdown into the totals."""
        self.computed += prefill_stats.num_computed_tokens
        self.cached_tokens += prefill_stats.num_cached_tokens
        self.total += prefill_stats.num_prompt_tokens

        self.local_cache_hit += prefill_stats.num_local_cached_tokens
        self.external_kv_transfer += prefill_stats.num_external_cached_tokens

    def get_by_source(self, source: str) -> int:
        """Get token count by source label.

        Args:
            source: One of the labels in `ALL_SOURCES`.

        Raises:
            ValueError: If `source` is not a known label.
        """
        source_map = {
            "local_compute": self.computed,
            "local_cache_hit": self.local_cache_hit,
            "external_kv_transfer": self.external_kv_transfer,
        }
        if source not in source_map:
            raise ValueError(f"Unknown source: {source}")
        return source_map[source]


325
326
327
class IterationStats:
    """Stats associated with a single set of EngineCoreOutputs."""

    def __init__(self):
        # Wall-clock reference point for frontend intervals (time to first
        # token, e2e latency); see `_time_since`.
        self.iteration_timestamp = time.time()
        self.num_generation_tokens = 0
        self.prompt_token_stats = PromptTokenStats()
        self.num_preempted_reqs = 0
        self.finished_requests: list[FinishedRequestStats] = []
        # NOTE(review): the next two lists are not populated by this class;
        # presumably the caller fills them — confirm.
        self.max_num_generation_tokens_iter: list[int] = []
        self.n_params_iter: list[int] = []
        self.time_to_first_tokens_iter: list[float] = []
        self.inter_token_latencies_iter: list[float] = []
        self.num_corrupted_reqs: int = 0

    def __repr__(self) -> str:
        field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items())
        return f"{self.__class__.__name__}({field_to_value_str})"

    @property
    def num_prompt_tokens(self) -> int:
        """Total prompt tokens (for backward compatibility)."""
        return self.prompt_token_stats.total

    def _time_since(self, start: float) -> float:
        """Calculate an interval relative to this iteration's timestamp."""
        return self.iteration_timestamp - start

    def update_from_output(
        self,
        output: "EngineCoreOutput",
        engine_core_timestamp: float,
        is_prefilling: bool,
        req_stats: RequestStateStats,
        lora_states: "LoRARequestStates",
        lora_name: str | None,
    ):
        """Fold one request's engine core output into this iteration's stats.

        Args:
            output: The engine core output for a single request.
            engine_core_timestamp: Engine core timestamp for this output; used
                for the request's first/last token timestamps and the
                inter-token latency.
            is_prefilling: Whether the request was still prefilling; if so,
                this output is treated as carrying its first token.
            req_stats: Per-request stats, updated in place.
            lora_states: LoRA bookkeeping, updated from request events.
            lora_name: Name of the request's LoRA, or None.
        """
        num_new_generation_tokens = len(output.new_token_ids)

        self.num_generation_tokens += num_new_generation_tokens
        if is_prefilling:
            if output.prefill_stats is not None:
                self.prompt_token_stats.update_from_output(output.prefill_stats)

            # Time to first token, measured on the frontend (wall-clock)
            # from the request's arrival.
            first_token_latency = self._time_since(req_stats.arrival_time)
            self.time_to_first_tokens_iter.append(first_token_latency)
            req_stats.first_token_latency = first_token_latency

        req_stats.num_generation_tokens += num_new_generation_tokens

        # Track if this request is corrupted (only check once per request)
        # Early exit if already marked as corrupted to avoid redundant checks
        if (
            envs.VLLM_COMPUTE_NANS_IN_LOGITS
            and not req_stats.is_corrupted
            and output.num_nans_in_logits > 0
        ):
            req_stats.is_corrupted = True

        # Process request-level engine core events
        if output.events is not None:
            self.update_from_events(
                output.request_id,
                output.events,
                is_prefilling,
                req_stats,
                lora_states,
                lora_name,
            )

        # Process the batch-level "new tokens" engine core event
        if is_prefilling:
            req_stats.first_token_ts = engine_core_timestamp
        else:
            itl = engine_core_timestamp - req_stats.last_token_ts
            self.inter_token_latencies_iter.append(itl)

        req_stats.last_token_ts = engine_core_timestamp

    def update_from_events(
        self,
        req_id: str,
        events: list["EngineCoreEvent"],
        is_prefilling: bool,
        req_stats: RequestStateStats,
        lora_states: "LoRARequestStates",
        lora_name: str | None,
    ):
        """Apply request-level QUEUED / SCHEDULED / PREEMPTED events.

        Records event timestamps on `req_stats`, counts preemptions, and
        moves the request between the waiting and running LoRA sets.
        `is_prefilling` is not used here.
        """
        # Avoid circular dependency
        from vllm.v1.engine import EngineCoreEventType

        for event in events:
            if event.type == EngineCoreEventType.QUEUED:
                req_stats.queued_ts = event.timestamp
                lora_states.request_waiting(req_id, lora_name)
            elif event.type == EngineCoreEventType.SCHEDULED:
                if req_stats.scheduled_ts == 0.0:  # ignore preemptions
                    req_stats.scheduled_ts = event.timestamp
                lora_states.request_running(req_id, lora_name)
            elif event.type == EngineCoreEventType.PREEMPTED:
                self.num_preempted_reqs += 1
                lora_states.request_waiting(req_id, lora_name)

    def update_from_finished_request(
        self,
        finish_reason: "FinishReason",
        num_prompt_tokens: int,
        max_tokens_param: int | None,
        req_stats: RequestStateStats,
        num_cached_tokens: int = 0,
    ):
        """Record a finished request and derive its latency breakdown.

        Appends a `FinishedRequestStats` to `finished_requests` and, if the
        request was flagged as corrupted, increments `num_corrupted_reqs`.
        """
        e2e_latency = self._time_since(req_stats.arrival_time)

        # Queued interval is from the (most recent) QUEUED event to the
        # first SCHEDULED event
        queued_time = req_stats.scheduled_ts - req_stats.queued_ts

        # Prefill interval is from first SCHEDULED to first NEW_TOKEN
        # Any preemptions during prefill is included in the interval
        prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts

        # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
        # Any preemptions during decode are included
        decode_time = req_stats.last_token_ts - req_stats.first_token_ts

        # Inference interval is from first SCHEDULED to last NEW_TOKEN
        # Any preemptions during prefill or decode are included
        inference_time = req_stats.last_token_ts - req_stats.scheduled_ts

        # Do not count the token generated by the prefill phase. Fall back to
        # 0.0 (a float, matching the field type) when there are no decode
        # tokens.
        mean_time_per_output_token = (
            decode_time / (req_stats.num_generation_tokens - 1)
            if req_stats.num_generation_tokens > 1
            else 0.0
        )

        finished_req = FinishedRequestStats(
            finish_reason=finish_reason,
            e2e_latency=e2e_latency,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=req_stats.num_generation_tokens,
            max_tokens_param=max_tokens_param,
            queued_time=queued_time,
            prefill_time=prefill_time,
            inference_time=inference_time,
            decode_time=decode_time,
            mean_time_per_output_token=mean_time_per_output_token,
            is_corrupted=req_stats.is_corrupted,
            num_cached_tokens=num_cached_tokens,
        )
        self.finished_requests.append(finished_req)

        # Count corrupted requests when they finish (only once per request)
        if req_stats.is_corrupted:
            self.num_corrupted_reqs += 1

480

481
482
class LoRAStats:
    """Tracks waiting and running request IDs for a single LoRA."""

    def __init__(self):
        self.waiting: set[str] = set()
        self.running: set[str] = set()

    def update(self, req_id: str, waiting: bool, running: bool):
        """Place `req_id` in the waiting set, the running set, or neither.

        Passing both flags as False removes the request from both sets. A
        request can never be both waiting and running.
        """
        assert not (waiting and running)
        if waiting:
            self.waiting.add(req_id)
        else:
            self.waiting.discard(req_id)

        if running:
            self.running.add(req_id)
        else:
            self.running.discard(req_id)

    @property
    def empty(self) -> bool:
        """Return True if no request is waiting or running."""
        return not (self.waiting or self.running)
503
504


505
506
507
508
509
510
511
512
513
514
515
class LoRARequestStates:
    """A per-LoRA count of running and waiting requests."""

    def __init__(self, log_stats: bool = False):
        self.log_stats = log_stats
        # LoRA name -> tracked request IDs; entries are removed once empty.
        self.requests: defaultdict[str, LoRAStats] = defaultdict(LoRAStats)

    def _request_update(
        self, req_id: str, lora_name: str | None, waiting: bool, running: bool
    ):
        # Nothing to track when stats logging is off or there is no LoRA name.
        if not self.log_stats or lora_name is None:
            return

        lora_stats = self.requests[lora_name]
        lora_stats.update(req_id, waiting, running)
        # Drop LoRAs with no tracked requests so they are no longer reported
        # by `update_scheduler_stats`.
        if lora_stats.empty:
            del self.requests[lora_name]

    def request_waiting(self, req_id: str, lora_name: str | None):
        """Mark a request as waiting under its LoRA."""
        self._request_update(req_id, lora_name, waiting=True, running=False)

    def request_running(self, req_id: str, lora_name: str | None):
        """Mark a request as running under its LoRA."""
        self._request_update(req_id, lora_name, waiting=False, running=True)

    def request_finished(self, req_id: str, lora_name: str | None):
        """Stop tracking a request under its LoRA."""
        self._request_update(req_id, lora_name, waiting=False, running=False)

    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
        """Write per-LoRA waiting/running counts into `scheduler_stats`."""
        if not self.log_stats or scheduler_stats is None:
            return

        for lora_name, stats in self.requests.items():
            scheduler_stats.waiting_lora_adapters[lora_name] = len(stats.waiting)
            scheduler_stats.running_lora_adapters[lora_name] = len(stats.running)