stats.py 14.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import defaultdict, deque
6
from dataclasses import dataclass, field
7
from typing import TYPE_CHECKING, Any
8

9
import vllm.envs as envs
10
from vllm.compilation.cuda_graph import CUDAGraphStat
11
12
from vllm.v1.spec_decode.metrics import SpecDecodingStats

13
if TYPE_CHECKING:
14
    from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
15
16


17
@dataclass
class BaseCacheStats:
    """Stores cache hit statistics."""

    reset: bool = False
    """Whether the cache was reset."""

    requests: int = 0
    """The number of requests in this update."""

    queries: int = 0
    """The number of queries in these requests."""

    hits: int = 0
    """The number of hits in these requests."""


class CachingMetrics:
    """Metrics for caching with a hit rate of the most recent N requests.

    Args:
        max_recent_requests: The number of the most recent requests to
            aggregate. Defaults to 1000.
    """

    def __init__(self, max_recent_requests: int = 1000) -> None:
        super().__init__()

        self.max_recent_requests = max_recent_requests
        # The current aggregated values.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0

        # A deque of (requests, queries, hits) for the most recent requests.
        self.query_queue = deque[tuple[int, int, int]]()

    def observe(self, stats: "BaseCacheStats") -> None:
        """Observe the cache statistics for a set of requests.

        When the aggregated number of requests exceeds
        `max_recent_requests`, the oldest entries are removed from the
        metrics until the limit is satisfied or only the newest entry
        remains.

        Args:
            stats: The cache stats for this update. Only its `reset`,
                `requests`, `queries` and `hits` attributes are read.
        """
        # The cache was reset before the current update, so drop everything
        # aggregated so far before taking the current stats into account.
        if stats.reset:
            self.reset()

        # Do NOT append empty stats, so that useful entries are not pushed
        # out of the sliding window by updates that carry no information.
        if stats.requests == 0:
            return

        # Update the metrics.
        self.query_queue.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Remove the oldest stats until the number of requests does not
        # exceed the limit.
        # NOTE: The most recently added stats are always preserved, even if
        # they alone exceed the limit.
        while (
            len(self.query_queue) > 1
            and self.aggregated_requests > self.max_recent_requests
        ):
            old_requests, old_queries, old_hits = self.query_queue.popleft()
            self.aggregated_requests -= old_requests
            self.aggregated_query_total -= old_queries
            self.aggregated_query_hit -= old_hits

    def reset(self) -> None:
        """Reset the metrics."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def empty(self) -> bool:
        """Return true if no requests have been observed."""
        return self.aggregated_requests == 0

    @property
    def hit_rate(self) -> float:
        """Calculate the hit rate for the past N requests.

        Returns 0.0 when no queries are currently aggregated.
        """
        if self.aggregated_query_total == 0:
            return 0.0
        return self.aggregated_query_hit / self.aggregated_query_total


@dataclass
class PrefixCacheStats(BaseCacheStats):
    """
    Stores prefix cache hit statistics.
    - `reset`: Whether `reset_prefix_cache` was invoked.
    - `queries`: Refers to the number of tokens that were queried.
    """

    preempted_requests: int = 0
    """The number of previously preempted requests in this update."""

    preempted_queries: int = 0
    """The `queries` number for preempted requests."""

    preempted_hits: int = 0
    """The `hits` number for preempted requests."""

    def record(self, num_tokens: int, num_hits: int, preempted: bool) -> None:
        """Aggregate request information into the stats.

        Args:
            num_tokens: Number of queried tokens for the request.
            num_hits: Number of those tokens that hit the cache.
            preempted: Whether the request was previously preempted; such
                requests are counted in the `preempted_*` fields instead
                of the regular ones.
        """
        if preempted:
            # Previously preempted request
            self.preempted_requests += 1
            self.preempted_queries += num_tokens
            self.preempted_hits += num_hits
        else:
            # New request
            self.requests += 1
            self.queries += num_tokens
            self.hits += num_hits

143
144
145
146
147
148
149
150
151

@dataclass
class MultiModalCacheStats(BaseCacheStats):
    """
    Cache hit statistics for the multi-modal cache.
    - `reset`: Set when `reset_mm_cache` was invoked.
    - `queries`: Counts the multi-modal data items that were queried.
    """
152
153


154
155
156
157
158
159
160
161
162
@dataclass
class KVCacheEvictionEvent:
    """One sample describing the eviction of a single KV cache block."""

    # All three fields are durations in seconds, per their names; the exact
    # measurement points are defined by whoever produces the event.
    lifetime_seconds: float
    idle_seconds: float
    # Zero or more gap durations, stored as an immutable tuple.
    reuse_gaps_seconds: tuple[float, ...]


163
164
165
166
167
168
169
@dataclass
class SchedulerStats:
    """Stats associated with the scheduler."""

    num_running_reqs: int = 0
    num_waiting_reqs: int = 0

    # These are used for internal DP load-balancing.
    step_counter: int = 0
    current_wave: int = 0

    kv_cache_usage: float = 0.0

    prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats)
    connector_prefix_cache_stats: PrefixCacheStats | None = None

    kv_cache_eviction_events: list[KVCacheEvictionEvent] = field(default_factory=list)

    spec_decoding_stats: SpecDecodingStats | None = None
    kv_connector_stats: dict[str, Any] | None = None

    # Per-LoRA-adapter request counts, keyed by adapter name.
    waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
    running_lora_adapters: dict[str, int] = field(default_factory=dict)

    cudagraph_stats: CUDAGraphStat | None = None

189

190
191
192
193
194
@dataclass
class RequestStateStats:
    """Stats that need to be tracked across delta updates."""

    num_generation_tokens: int = 0

    # This is an engine frontend timestamp (wall-clock)
    arrival_time: float = 0.0

    # These are engine core timestamps (monotonic)
    queued_ts: float = 0.0
    scheduled_ts: float = 0.0
    first_token_ts: float = 0.0
    last_token_ts: float = 0.0

    # first token latency
    first_token_latency: float = 0.0

    # Track if this request is corrupted (NaNs in logits)
    is_corrupted: bool = False

211
212
213
214
215

@dataclass
class FinishedRequestStats:
    """Stats associated with a finished request."""

216
    finish_reason: "FinishReason"
217
    e2e_latency: float = 0.0
218
219
    num_prompt_tokens: int = 0
    num_generation_tokens: int = 0
220
    max_tokens_param: int | None = None
221
222
    queued_time: float = 0.0
    prefill_time: float = 0.0
223
224
    inference_time: float = 0.0
    decode_time: float = 0.0
225
    mean_time_per_output_token: float = 0.0
226
    is_corrupted: bool = False
227
    num_cached_tokens: int = 0
228
229


230
231
232
class IterationStats:
    """Stats associated with a single set of EngineCoreOutputs."""

    def __init__(self):
        # Wall-clock timestamp taken when this stats object is created; used
        # as the reference point for `_time_since`.
        self.iteration_timestamp = time.time()
        self.num_generation_tokens = 0
        self.num_prompt_tokens = 0
        self.num_preempted_reqs = 0
        self.finished_requests: list[FinishedRequestStats] = []
        self.max_num_generation_tokens_iter: list[int] = []
        self.n_params_iter: list[int] = []
        self.time_to_first_tokens_iter: list[float] = []
        self.inter_token_latencies_iter: list[float] = []
        self.num_corrupted_reqs: int = 0

    def __repr__(self) -> str:
        field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items())
        return f"{self.__class__.__name__}({field_to_value_str})"

    def _time_since(self, start: float) -> float:
        """Calculate an interval relative to this iteration's timestamp."""
        return self.iteration_timestamp - start

    def update_from_output(
        self,
        output: "EngineCoreOutput",
        engine_core_timestamp: float,
        is_prefilling: bool,
        prompt_len: int,
        req_stats: RequestStateStats,
        lora_states: "LoRARequestStates",
        lora_name: str | None,
    ):
        """Fold one request's engine core output into the iteration stats.

        Updates both this object and the per-request `req_stats` in place.

        Args:
            output: The output to account for; its `new_token_ids`,
                `num_nans_in_logits`, `events` and `request_id` are read.
            engine_core_timestamp: Engine core timestamp used for the
                first/last token timestamps and inter-token latency.
            is_prefilling: Whether the request is still prefilling; when
                true the prompt tokens and time-to-first-token are recorded.
            prompt_len: Number of prompt tokens of the request.
            req_stats: Per-request state that is tracked across updates.
            lora_states: Per-LoRA request tracker notified of events.
            lora_name: Name of the LoRA used by the request, if any.
        """
        num_new_generation_tokens = len(output.new_token_ids)

        self.num_generation_tokens += num_new_generation_tokens
        if is_prefilling:
            self.num_prompt_tokens += prompt_len

            first_token_latency = self._time_since(req_stats.arrival_time)
            self.time_to_first_tokens_iter.append(first_token_latency)
            req_stats.first_token_latency = first_token_latency

        req_stats.num_generation_tokens += num_new_generation_tokens

        # Track if this request is corrupted (only check once per request)
        # Early exit if already marked as corrupted to avoid redundant checks
        if (
            envs.VLLM_COMPUTE_NANS_IN_LOGITS
            and not req_stats.is_corrupted
            and output.num_nans_in_logits > 0
        ):
            req_stats.is_corrupted = True

        # Process request-level engine core events
        if output.events is not None:
            self.update_from_events(
                output.request_id,
                output.events,
                is_prefilling,
                req_stats,
                lora_states,
                lora_name,
            )

        # Process the batch-level "new tokens" engine core event
        if is_prefilling:
            req_stats.first_token_ts = engine_core_timestamp
        else:
            itl = engine_core_timestamp - req_stats.last_token_ts
            self.inter_token_latencies_iter.append(itl)

        req_stats.last_token_ts = engine_core_timestamp

    def update_from_events(
        self,
        req_id: str,
        events: list["EngineCoreEvent"],
        is_prefilling: bool,
        req_stats: RequestStateStats,
        lora_states: "LoRARequestStates",
        lora_name: str | None,
    ):
        """Apply request-level engine core events to the stats.

        QUEUED and SCHEDULED events update the request timestamps,
        PREEMPTED events are counted, and `lora_states` is told whether the
        request is now waiting or running.
        """
        # Avoid circular dependency
        from vllm.v1.engine import EngineCoreEventType

        for event in events:
            if event.type == EngineCoreEventType.QUEUED:
                req_stats.queued_ts = event.timestamp
                lora_states.request_waiting(req_id, lora_name)
            elif event.type == EngineCoreEventType.SCHEDULED:
                if req_stats.scheduled_ts == 0.0:  # ignore preemptions
                    req_stats.scheduled_ts = event.timestamp
                lora_states.request_running(req_id, lora_name)
            elif event.type == EngineCoreEventType.PREEMPTED:
                self.num_preempted_reqs += 1
                lora_states.request_waiting(req_id, lora_name)

    def update_from_finished_request(
        self,
        finish_reason: "FinishReason",
        num_prompt_tokens: int,
        max_tokens_param: int | None,
        req_stats: RequestStateStats,
        num_cached_tokens: int = 0,
    ):
        """Record the final stats of a request that has finished.

        Derives the latency intervals from the timestamps in `req_stats`,
        appends a `FinishedRequestStats` entry to `finished_requests`, and
        counts the request as corrupted if it was marked so.
        """
        e2e_latency = self._time_since(req_stats.arrival_time)

        # Queued interval is from first QUEUED event to first SCHEDULED
        queued_time = req_stats.scheduled_ts - req_stats.queued_ts

        # Prefill interval is from first SCHEDULED to first NEW_TOKEN
        # Any preemptions during prefill are included in the interval
        prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts

        # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
        # Any preemptions during decode are included
        decode_time = req_stats.last_token_ts - req_stats.first_token_ts

        # Inference interval is from first SCHEDULED to last NEW_TOKEN
        # Any preemptions during prefill or decode are included
        inference_time = req_stats.last_token_ts - req_stats.scheduled_ts

        # Do not count the token generated by the prefill phase
        num_decode_tokens = req_stats.num_generation_tokens - 1
        # Fall back to a float zero so the field type stays consistent when
        # there were no decode tokens to average over.
        mean_time_per_output_token = (
            decode_time / num_decode_tokens if num_decode_tokens > 0 else 0.0
        )

        finished_req = FinishedRequestStats(
            finish_reason=finish_reason,
            e2e_latency=e2e_latency,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=req_stats.num_generation_tokens,
            max_tokens_param=max_tokens_param,
            queued_time=queued_time,
            prefill_time=prefill_time,
            inference_time=inference_time,
            decode_time=decode_time,
            mean_time_per_output_token=mean_time_per_output_token,
            is_corrupted=req_stats.is_corrupted,
            num_cached_tokens=num_cached_tokens,
        )
        self.finished_requests.append(finished_req)

        # Count corrupted requests when they finish (only once per request)
        if req_stats.is_corrupted:
            self.num_corrupted_reqs += 1

380

381
382
class LoRAStats:
    """Tracks waiting and running request IDs for a single LoRA."""

    def __init__(self) -> None:
        self.waiting: set[str] = set()
        self.running: set[str] = set()

    def update(self, req_id: str, waiting: bool, running: bool) -> None:
        """Set the waiting/running membership of a request.

        The request ID is added to each set whose flag is true and removed
        from each set whose flag is false, so passing both flags as false
        forgets the request entirely.

        Args:
            req_id: The request ID to update.
            waiting: Whether the request is now waiting.
            running: Whether the request is now running.
        """
        # A request cannot be waiting and running at the same time.
        assert not (waiting and running)
        if waiting:
            self.waiting.add(req_id)
        else:
            self.waiting.discard(req_id)

        if running:
            self.running.add(req_id)
        else:
            self.running.discard(req_id)

    @property
    def empty(self) -> bool:
        """Return true if no request is waiting or running."""
        return not (self.waiting or self.running)
403
404


405
406
407
408
409
410
411
412
413
414
415
class LoRARequestStates:
    """A per-LoRA count of running and waiting requests."""

    def __init__(self, log_stats: bool = False):
        self.log_stats = log_stats
        # Maps LoRA name -> stats; entries are created on first access.
        self.requests: defaultdict[str, LoRAStats] = defaultdict(LoRAStats)

    def _request_update(
        self, req_id: str, lora_name: str | None, waiting: bool, running: bool
    ):
        """Update the state of one request for its LoRA.

        Does nothing when stats logging is disabled or the request has no
        LoRA name.
        """
        if not self.log_stats or lora_name is None:
            return

        lora_stats = self.requests[lora_name]
        lora_stats.update(req_id, waiting, running)
        # Drop the entry once the LoRA has no waiting or running requests.
        if lora_stats.empty:
            del self.requests[lora_name]

    def request_waiting(self, req_id: str, lora_name: str | None):
        """Mark the request as waiting for the given LoRA."""
        self._request_update(req_id, lora_name, waiting=True, running=False)

    def request_running(self, req_id: str, lora_name: str | None):
        """Mark the request as running for the given LoRA."""
        self._request_update(req_id, lora_name, waiting=False, running=True)

    def request_finished(self, req_id: str, lora_name: str | None):
        """Remove the request from both the waiting and running sets."""
        self._request_update(req_id, lora_name, waiting=False, running=False)

    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
        """Write the per-LoRA waiting/running counts into `scheduler_stats`.

        Does nothing when stats logging is disabled or no stats object is
        given.
        """
        if not self.log_stats or scheduler_stats is None:
            return

        for lora_name, stats in self.requests.items():
            scheduler_stats.waiting_lora_adapters[lora_name] = len(stats.waiting)
            scheduler_stats.running_lora_adapters[lora_name] = len(stats.running)