stats.py 13.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import deque
6
from dataclasses import dataclass, field
7
from typing import TYPE_CHECKING, Any
8

9
10
from vllm.v1.spec_decode.metrics import SpecDecodingStats

11
if TYPE_CHECKING:
12
    from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
13
    from vllm.v1.engine.output_processor import RequestState
14
15


16
@dataclass
class BaseCacheStats:
    """Cache hit counters gathered over a single update."""

    reset: bool = False
    """Set when the cache was reset before this update."""

    requests: int = 0
    """How many requests this update covers."""

    queries: int = 0
    """How many queries those requests issued."""

    hits: int = 0
    """How many hits those requests produced."""


class CachingMetrics:
    """Metrics for caching with a hit rate of the most recent N requests.

    Observations are aggregated over a sliding window. Once the aggregated
    request count exceeds `max_recent_requests`, the oldest observations are
    evicted. The most recent observation is always kept, even if it alone
    exceeds the limit.

    Args:
        max_recent_requests: The number of the most recent requests to
            aggregate. Defaults to 1000.
    """

    def __init__(self, max_recent_requests: int = 1000) -> None:
        super().__init__()

        self.max_recent_requests = max_recent_requests
        # The current aggregated values.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0

        # A deque of (requests, queries, hits) for the most recent requests.
        self.query_queue = deque[tuple[int, int, int]]()

    def observe(self, stats: BaseCacheStats):
        """Observe the cache statistics for a set of requests.

        Typically called with information gathered while new requests are
        being scheduled.

        When there are more than `max_recent_requests` requests, the oldest
        sets of requests are removed from the metrics.

        Args:
            stats: The cache stats for this update. If `stats.reset` is set,
                all previously aggregated metrics are discarded before
                `stats` is aggregated.
        """
        # The cache was reset before the current update: drop everything
        # aggregated so far before adding the current stats.
        if stats.reset:
            self.reset()

        # Do not append empty stats: they carry no information, and an empty
        # entry could cause useful entries to be evicted by the sliding window.
        if stats.requests == 0:
            return

        # Update the metrics.
        self.query_queue.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Remove the oldest stats until the number of requests does not
        # exceed the limit.
        # NOTE: We preserve the latest added stats regardless.
        while (
            len(self.query_queue) > 1
            and self.aggregated_requests > self.max_recent_requests
        ):
            old_requests, old_queries, old_hits = self.query_queue.popleft()
            self.aggregated_requests -= old_requests
            self.aggregated_query_total -= old_queries
            self.aggregated_query_hit -= old_hits

    def reset(self):
        """Reset the metrics."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def empty(self) -> bool:
        """Return true if no requests have been observed."""
        return self.aggregated_requests == 0

    @property
    def hit_rate(self) -> float:
        """Calculate the hit rate for the past N requests.

        Returns 0.0 when no queries have been aggregated.
        """
        if self.aggregated_query_total == 0:
            return 0.0
        return self.aggregated_query_hit / self.aggregated_query_total


@dataclass
class PrefixCacheStats(BaseCacheStats):
    """
    Prefix-cache hit counters for one update.
    - `reset`: Set when `reset_prefix_cache` was invoked.
    - `queries`: Counted in tokens looked up in the prefix cache.
    """

    preempted_requests: int = 0
    """How many requests in this update had previously been preempted."""

    preempted_queries: int = 0
    """The `queries` count attributed to preempted requests."""

    preempted_hits: int = 0
    """The `hits` count attributed to preempted requests."""


@dataclass
class MultiModalCacheStats(BaseCacheStats):
    """
    Multi-modal cache hit counters for one update.
    - `reset`: Set when `reset_mm_cache` was invoked.
    - `queries`: Counted in multi-modal data items looked up in the cache.
    """
138
139


140
141
142
143
144
145
146
@dataclass
class SchedulerStats:
    """Stats reported by the scheduler."""

    # Number of running / waiting requests.
    num_running_reqs: int = 0
    num_waiting_reqs: int = 0

    # Internal bookkeeping for data-parallel (DP) load balancing.
    step_counter: int = 0
    current_wave: int = 0

    kv_cache_usage: float = 0.0

    # A fresh PrefixCacheStats is created per instance.
    prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats)

    # Optional stats; None when not provided.
    spec_decoding_stats: SpecDecodingStats | None = None
    kv_connector_stats: dict[str, Any] | None = None

    num_corrupted_reqs: int = 0

160

161
162
@dataclass
class LoRAStats:
    """IDs of the requests currently waiting and running under one LoRA adapter."""

    # Each instance gets its own empty sets.
    waiting_requests: set[str] = field(default_factory=set)
    running_requests: set[str] = field(default_factory=set)
165
166


167
168
169
170
171
@dataclass
class RequestStateStats:
    """Per-request stats accumulated across incremental (delta) output updates."""

    num_generation_tokens: int = 0

    # Frontend timestamp, taken from the wall clock.
    arrival_time: float = 0.0

    # Engine core timestamps, taken from a monotonic clock.
    queued_ts: float = 0.0
    scheduled_ts: float = 0.0
    first_token_ts: float = 0.0
    last_token_ts: float = 0.0

    # Time to first token, measured from `arrival_time`.
    first_token_latency: float = 0.0

185
186
187
188
189

@dataclass
class FinishedRequestStats:
    """Stats associated with a finished request."""

190
    finish_reason: "FinishReason"
191
    e2e_latency: float = 0.0
192
193
    num_prompt_tokens: int = 0
    num_generation_tokens: int = 0
194
    max_tokens_param: int | None = None
195
196
    queued_time: float = 0.0
    prefill_time: float = 0.0
197
198
    inference_time: float = 0.0
    decode_time: float = 0.0
199
    mean_time_per_output_token: float = 0.0
200
201


202
203
204
class IterationStats:
    """Stats associated with a single set of EngineCoreOutputs."""

    def __init__(self):
        # Wall-clock (time.time()) reference point used by `_time_since`,
        # which measures against frontend wall-clock timestamps such as
        # `RequestStateStats.arrival_time`.
        self.iteration_timestamp = time.time()

        self.num_generation_tokens = 0
        self.num_prompt_tokens = 0
        self.num_preempted_reqs = 0
        self.finished_requests: list[FinishedRequestStats] = []
        self.max_num_generation_tokens_iter: list[int] = []
        self.n_params_iter: list[int] = []
        self.time_to_first_tokens_iter: list[float] = []
        self.inter_token_latencies_iter: list[float] = []
        # Per-adapter request counts keyed by LoRA name; populated by
        # `LoRARequestStates.update_iteration_stats()`.
        self.waiting_lora_adapters: dict[str, int] = {}
        self.running_lora_adapters: dict[str, int] = {}

    def __repr__(self) -> str:
        field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items())
        return f"{self.__class__.__name__}({field_to_value_str})"

    def _time_since(self, start: float) -> float:
        """Calculate an interval relative to this iteration's timestamp."""
        return self.iteration_timestamp - start

    def update_from_output(
        self,
        output: "EngineCoreOutput",
        engine_core_timestamp: float,
        is_prefilling: bool,
        prompt_len: int,
        req_stats: RequestStateStats,
        lora_stats: LoRAStats | None,
    ):
        """Accumulate stats from one request's output in this iteration.

        Args:
            output: The engine core output for one request; its
                `new_token_ids`, `request_id` and `events` are read.
            engine_core_timestamp: Engine core (monotonic) timestamp of the
                batch this output belongs to.
            is_prefilling: Whether the request is still prefilling, i.e. this
                output carries its first token(s). If so, prompt tokens are
                counted and time-to-first-token is recorded instead of an
                inter-token latency.
            prompt_len: Number of prompt tokens; only counted when
                `is_prefilling`.
            req_stats: Per-request stats carried across iterations; updated
                in place.
            lora_stats: The request's LoRA adapter stats, or None when the
                request has no adapter.
        """
        num_new_generation_tokens = len(output.new_token_ids)

        self.num_generation_tokens += num_new_generation_tokens
        if is_prefilling:
            self.num_prompt_tokens += prompt_len

            # TTFT is measured on the wall clock from the frontend arrival time.
            first_token_latency = self._time_since(req_stats.arrival_time)
            self.time_to_first_tokens_iter.append(first_token_latency)
            req_stats.first_token_latency = first_token_latency

        req_stats.num_generation_tokens += num_new_generation_tokens

        # Process request-level engine core events
        if output.events is not None:
            self.update_from_events(
                output.request_id, output.events, is_prefilling, req_stats, lora_stats
            )

        # Process the batch-level "new tokens" engine core event
        if is_prefilling:
            req_stats.first_token_ts = engine_core_timestamp
        else:
            # Inter-token latency uses engine core (monotonic) timestamps.
            itl = engine_core_timestamp - req_stats.last_token_ts
            self.inter_token_latencies_iter.append(itl)

        req_stats.last_token_ts = engine_core_timestamp

    def update_from_events(
        self,
        req_id: str,
        events: list["EngineCoreEvent"],
        is_prefilling: bool,
        req_stats: RequestStateStats,
        lora_stats: LoRAStats | None,
    ):
        """Apply request-level engine core events to request and LoRA stats.

        Args:
            req_id: ID of the request the events belong to.
            events: Events emitted by the engine core for this request.
            is_prefilling: Currently unused.
            req_stats: Per-request stats; `queued_ts` and `scheduled_ts` are
                updated in place.
            lora_stats: The request's LoRA adapter stats, or None.
        """
        # Avoid circular dependency
        from vllm.v1.engine import EngineCoreEventType

        for event in events:
            if event.type == EngineCoreEventType.QUEUED:
                req_stats.queued_ts = event.timestamp
                if lora_stats is not None:
                    lora_stats.waiting_requests.add(req_id)
            elif event.type == EngineCoreEventType.SCHEDULED:
                # Only the first SCHEDULED event sets the timestamp, so
                # re-scheduling after a preemption is ignored here.
                if req_stats.scheduled_ts == 0.0:
                    req_stats.scheduled_ts = event.timestamp
                LoRARequestStates.scheduled_request(lora_stats, req_id)
            elif event.type == EngineCoreEventType.PREEMPTED:
                self.num_preempted_reqs += 1
                LoRARequestStates.preempted_request(lora_stats, req_id)

    def update_from_finished_request(
        self,
        finish_reason: "FinishReason",
        num_prompt_tokens: int,
        max_tokens_param: int | None,
        req_stats: RequestStateStats,
    ):
        """Append a FinishedRequestStats entry for a completed request.

        Phase durations are derived from the timestamps accumulated in
        `req_stats`.

        NOTE(review): a timestamp that was never recorded keeps its 0.0
        default, which would make the derived intervals meaningless. This
        assumes the request was queued, scheduled and produced at least one
        token — confirm against callers.
        """
        e2e_latency = self._time_since(req_stats.arrival_time)

        # Queued interval is from the recorded QUEUED event to the first
        # SCHEDULED event
        queued_time = req_stats.scheduled_ts - req_stats.queued_ts

        # Prefill interval is from first SCHEDULED to first NEW_TOKEN
        # Any preemptions during prefill are included in the interval
        prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts

        # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
        # Any preemptions during decode are included
        decode_time = req_stats.last_token_ts - req_stats.first_token_ts

        # Inference interval is from first SCHEDULED to last NEW_TOKEN
        # Any preemptions during prefill or decode are included
        inference_time = req_stats.last_token_ts - req_stats.scheduled_ts

        # Do not count the token generated by the prefill phase. Fall back to
        # a float zero so the value matches the float-typed field it fills.
        num_decode_tokens = req_stats.num_generation_tokens - 1
        mean_time_per_output_token = (
            decode_time / num_decode_tokens if num_decode_tokens > 0 else 0.0
        )

        finished_req = FinishedRequestStats(
            finish_reason=finish_reason,
            e2e_latency=e2e_latency,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=req_stats.num_generation_tokens,
            max_tokens_param=max_tokens_param,
            queued_time=queued_time,
            prefill_time=prefill_time,
            inference_time=inference_time,
            decode_time=decode_time,
            mean_time_per_output_token=mean_time_per_output_token,
        )
        self.finished_requests.append(finished_req)
330
331
332
333
334
335


class LoRARequestStates:
    """Tracks, per LoRA adapter name, which requests are waiting or running."""

    def __init__(self):
        self.lora_name_to_stats: dict[str, LoRAStats] = {}

    def get_stats(self, req_state: "RequestState") -> LoRAStats | None:
        """Return the stats for the request's adapter, creating them lazily.

        Returns None when the request has no LoRA name.
        """
        name = req_state.lora_name
        if name is None:
            return None
        stats = self.lora_name_to_stats.get(name)
        if stats is None:
            stats = LoRAStats()
            self.lora_name_to_stats[name] = stats
        return stats

    def add_request(self, req_state: "RequestState"):
        """Record a newly added request as waiting on its adapter."""
        lora_stats = self.get_stats(req_state)
        if lora_stats is None:
            return
        lora_stats.waiting_requests.add(req_state.request_id)

    def finish_request(self, req_state: "RequestState"):
        """Remove a finished request from its adapter's running set.

        Raises KeyError if the adapter or the running entry is unknown.
        """
        name = req_state.lora_name
        if name is not None:
            running = self.lora_name_to_stats[name].running_requests
            running.remove(req_state.request_id)

    def abort_request(self, req_state: "RequestState"):
        """Forget an aborted request, whether it was waiting or running."""
        name = req_state.lora_name
        if name is None:
            return
        stats = self.lora_name_to_stats[name]
        for tracked in (stats.waiting_requests, stats.running_requests):
            tracked.discard(req_state.request_id)

    # Unlike the lifecycle methods above, these two take the LoRAStats
    # directly (static) so IterationStats.update_from_events() can call them.
    @staticmethod
    def scheduled_request(lora_stats: LoRAStats | None, request_id: str):
        """Move a request from its adapter's waiting set to the running set."""
        if lora_stats is not None:
            lora_stats.waiting_requests.remove(request_id)
            lora_stats.running_requests.add(request_id)

    @staticmethod
    def preempted_request(lora_stats: LoRAStats | None, request_id: str):
        """Move a preempted request from the running set back to waiting."""
        if lora_stats is not None:
            lora_stats.running_requests.remove(request_id)
            lora_stats.waiting_requests.add(request_id)

    def update_iteration_stats(self, iteration_stats: IterationStats | None):
        """Copy non-zero per-adapter waiting/running counts into the stats."""
        if iteration_stats is None:
            return
        waiting_counts = iteration_stats.waiting_lora_adapters
        running_counts = iteration_stats.running_lora_adapters
        for name, stats in self.lora_name_to_stats.items():
            if num_waiting := len(stats.waiting_requests):
                waiting_counts[name] = num_waiting
            if num_running := len(stats.running_requests):
                running_counts[name] = num_running