stats.py 10.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, Any, Optional
7

8
9
from vllm.v1.spec_decode.metrics import SpecDecodingStats

10
if TYPE_CHECKING:
11
    from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
12
    from vllm.v1.engine.output_processor import RequestState
13
14


15
16
17
18
19
@dataclass
class PrefixCacheStats:
    """Stores prefix cache hit statistics.

    The counters cover a single stats update. The ``preempted_*`` fields
    hold the same counters for requests that had previously been
    preempted.
    """
    # Whether reset_prefix_cache was invoked.
    reset: bool = False
    # The number of new requests in this update.
    requests: int = 0
    # The number of queries in these requests. Note that "queries" here
    # means the number of tokens that were queried from the cache.
    queries: int = 0
    # The number of hits in these requests (presumably counted in tokens,
    # like `queries` -- confirm with the producer of these stats).
    hits: int = 0
    # The number of previously preempted requests in this update.
    preempted_requests: int = 0
    # The `queries` number for preempted requests.
    preempted_queries: int = 0
    # The `hits` number for preempted requests.
    preempted_hits: int = 0


@dataclass
class SchedulerStats:
    """Stats associated with the scheduler."""

    # Request counts (presumably the sizes of the scheduler's running and
    # waiting queues -- confirm with the producer).
    num_running_reqs: int = 0
    num_waiting_reqs: int = 0

    # These are used for internal DP load-balancing.
    step_counter: int = 0
    current_wave: int = 0

    # KV cache utilization; presumably a fraction in [0, 1] -- verify
    # against the code that fills it in.
    kv_cache_usage: float = 0.0

    # Each SchedulerStats gets its own fresh PrefixCacheStats instance.
    prefix_cache_stats: PrefixCacheStats = field(
        default_factory=PrefixCacheStats)

    # Optional sub-stats; left as None unless the producer supplies them.
    spec_decoding_stats: Optional[SpecDecodingStats] = None
    kv_connector_stats: Optional[dict[str, Any]] = None

    # Number of requests reported as corrupted (meaning defined by the
    # producer of these stats -- verify).
    num_corrupted_reqs: int = 0


@dataclass
class LoRAStats:
    """Request ids currently waiting and running for one LoRA adapter.

    The sets are mutated by LoRARequestStates (add / schedule / preempt /
    finish / abort) and by IterationStats on QUEUED events. Each instance
    gets its own fresh sets.
    """

    waiting_requests: set[str] = field(default_factory=set)
    running_requests: set[str] = field(default_factory=set)


@dataclass
class RequestStateStats:
    """Stats that need to be tracked across delta updates."""

    # Running total of tokens generated for this request so far.
    num_generation_tokens: int = 0

    # This is an engine frontend timestamp (wall-clock)
    arrival_time: float = 0.0

    # These are engine core timestamps (monotonic)
    queued_ts: float = 0.0
    scheduled_ts: float = 0.0
    first_token_ts: float = 0.0
    last_token_ts: float = 0.0

    # Time to first token: wall-clock interval from arrival_time to the
    # timestamp of the iteration in which the prefill output was seen.
    first_token_latency: float = 0.0


@dataclass
class FinishedRequestStats:
    """Stats associated with a finished request.

    Durations are differences of timestamps (presumably seconds -- the
    frontend uses time.time()). The intervals are defined in
    IterationStats.update_from_finished_request.
    """

    finish_reason: "FinishReason"
    # Wall-clock interval from the request's arrival to the iteration in
    # which it finished.
    e2e_latency: float = 0.0
    num_prompt_tokens: int = 0
    num_generation_tokens: int = 0
    # Presumably the request's max_tokens sampling parameter, if any.
    max_tokens_param: Optional[int] = None
    # Engine-core intervals: QUEUED->SCHEDULED, SCHEDULED->first token,
    # SCHEDULED->last token, and first token->last token respectively.
    queued_time: float = 0.0
    prefill_time: float = 0.0
    inference_time: float = 0.0
    decode_time: float = 0.0
    # decode_time averaged over generated tokens excluding the first one.
    mean_time_per_output_token: float = 0.0


class IterationStats:
    """Stats associated with a single set of EngineCoreOutputs."""

    def __init__(self):
        # Wall-clock reference point for this iteration; TTFT and e2e
        # latency are measured from a request's arrival_time to this value.
        self.iteration_timestamp = time.time()

        # Token counters accumulated over the outputs of this iteration.
        self.num_generation_tokens = 0
        self.num_prompt_tokens = 0
        # Number of PREEMPTED events seen in this iteration.
        self.num_preempted_reqs = 0
        # One entry per request that finished in this iteration.
        self.finished_requests: list[FinishedRequestStats] = []

        # Not populated by this class; presumably filled in by the
        # caller -- verify.
        self.max_num_generation_tokens_iter: list[int] = []
        self.n_params_iter: list[int] = []

        # Latency samples collected per output.
        self.time_to_first_tokens_iter: list[float] = []
        self.inter_token_latencies_iter: list[float] = []

        # Per-adapter request counts, filled in by
        # LoRARequestStates.update_iteration_stats().
        self.waiting_lora_adapters: dict[str, int] = {}
        self.running_lora_adapters: dict[str, int] = {}

    def __repr__(self) -> str:
        field_to_value_str = ", ".join(f"{k}={v}"
                                       for k, v in vars(self).items())
        return f"{self.__class__.__name__}({field_to_value_str})"

    def _time_since(self, start: float) -> float:
        """Calculate an interval relative to this iteration's timestamp."""
        return self.iteration_timestamp - start

    def update_from_output(self, output: "EngineCoreOutput",
                           engine_core_timestamp: float, is_prefilling: bool,
                           prompt_len: int, req_stats: RequestStateStats,
                           lora_stats: Optional[LoRAStats]):
        """Accumulate stats from one request's EngineCoreOutput.

        Args:
            output: The engine core output for a single request.
            engine_core_timestamp: Engine core (monotonic) timestamp of the
                step that produced ``output``.
            is_prefilling: Whether ``output`` comes from the request's
                prefill. Prompt tokens, TTFT and first_token_ts are only
                recorded in that case. Otherwise an inter-token latency
                sample is recorded.
            prompt_len: Number of prompt tokens, counted when prefilling.
            req_stats: Per-request stats, updated in place.
            lora_stats: The request's LoRA stats, or None without LoRA.
        """
        num_new_generation_tokens = len(output.new_token_ids)

        self.num_generation_tokens += num_new_generation_tokens
        if is_prefilling:
            self.num_prompt_tokens += prompt_len

            first_token_latency = self._time_since(req_stats.arrival_time)
            self.time_to_first_tokens_iter.append(first_token_latency)
            req_stats.first_token_latency = first_token_latency

        req_stats.num_generation_tokens += num_new_generation_tokens

        # Process request-level engine core events
        if output.events is not None:
            self.update_from_events(output.request_id, output.events,
                                    is_prefilling, req_stats, lora_stats)

        # Process the batch-level "new tokens" engine core event
        if is_prefilling:
            req_stats.first_token_ts = engine_core_timestamp
        else:
            itl = engine_core_timestamp - req_stats.last_token_ts
            self.inter_token_latencies_iter.append(itl)

        req_stats.last_token_ts = engine_core_timestamp

    def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
                           is_prefilling: bool, req_stats: RequestStateStats,
                           lora_stats: Optional[LoRAStats]):
        """Apply request-level engine core events to the stats.

        QUEUED sets queued_ts and marks the request waiting for its LoRA.
        SCHEDULED sets scheduled_ts only the first time, so rescheduling
        after a preemption does not reset it. It also moves the request
        to the running set. PREEMPTED increments num_preempted_reqs and
        moves the request back to the waiting set.
        """
        # Avoid circular dependency
        from vllm.v1.engine import EngineCoreEventType
        for event in events:
            if event.type == EngineCoreEventType.QUEUED:
                req_stats.queued_ts = event.timestamp
                if lora_stats is not None:
                    lora_stats.waiting_requests.add(req_id)
            elif event.type == EngineCoreEventType.SCHEDULED:
                if req_stats.scheduled_ts == 0.0:  # ignore preemptions
                    req_stats.scheduled_ts = event.timestamp
                LoRARequestStates.scheduled_request(lora_stats, req_id)
            elif event.type == EngineCoreEventType.PREEMPTED:
                self.num_preempted_reqs += 1
                LoRARequestStates.preempted_request(lora_stats, req_id)

    def update_from_finished_request(self, finish_reason: "FinishReason",
                                     num_prompt_tokens: int,
                                     max_tokens_param: Optional[int],
                                     req_stats: RequestStateStats):
        """Build a FinishedRequestStats for a request and record it.

        The intervals use the engine core timestamps stored in
        ``req_stats``. The e2e latency instead runs from the frontend
        wall-clock ``arrival_time`` to this iteration's timestamp.
        """
        e2e_latency = self._time_since(req_stats.arrival_time)

        # Queued interval is from first QUEUED event to first SCHEDULED
        queued_time = req_stats.scheduled_ts - req_stats.queued_ts

        # Prefill interval is from first SCHEDULED to first NEW_TOKEN
        # Any preemptions during prefill is included in the interval
        prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts

        # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
        # Any preemptions during decode are included
        decode_time = req_stats.last_token_ts - req_stats.first_token_ts

        # Inference interval is from first SCHEDULED to last NEW_TOKEN
        # Any preemptions during prefill or decode are included
        inference_time = req_stats.last_token_ts - req_stats.scheduled_ts

        # Do not count the token generated by the prefill phase. Fall back
        # to 0.0 (not int 0) so the field stays a float.
        num_decode_tokens = req_stats.num_generation_tokens - 1
        mean_time_per_output_token = (decode_time / num_decode_tokens
                                      if num_decode_tokens > 0 else 0.0)

        finished_req = FinishedRequestStats(
            finish_reason=finish_reason,
            e2e_latency=e2e_latency,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=req_stats.num_generation_tokens,
            max_tokens_param=max_tokens_param,
            queued_time=queued_time,
            prefill_time=prefill_time,
            inference_time=inference_time,
            decode_time=decode_time,
            mean_time_per_output_token=mean_time_per_output_token)
        self.finished_requests.append(finished_req)


class LoRARequestStates:
    """Per-LoRA request state stats.

    Keeps one LoRAStats (sets of waiting and running request ids) per LoRA
    adapter name and publishes their sizes into an IterationStats.
    """

    def __init__(self):
        # Entries are created lazily by get_stats() and never removed here.
        self.lora_name_to_stats: dict[str, LoRAStats] = {}

    def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]:
        """Return the stats for the request's LoRA adapter.

        Creates an empty LoRAStats the first time an adapter name is seen.
        Returns None when the request has no LoRA adapter.
        """
        if req_state.lora_name is None:
            return None
        if req_state.lora_name not in self.lora_name_to_stats:
            self.lora_name_to_stats[req_state.lora_name] = LoRAStats()
        return self.lora_name_to_stats[req_state.lora_name]

    def add_request(self, req_state: 'RequestState'):
        """Mark a newly added LoRA request as waiting."""
        if (lora_stats := self.get_stats(req_state)) is not None:
            lora_stats.waiting_requests.add(req_state.request_id)

    def finish_request(self, req_state: 'RequestState'):
        """Remove a finished LoRA request from the running set.

        Raises KeyError if the adapter has no stats entry or the request is
        not currently in the running set.
        """
        if req_state.lora_name is None:
            return
        lora_stats = self.lora_name_to_stats[req_state.lora_name]
        lora_stats.running_requests.remove(req_state.request_id)

    def abort_request(self, req_state: 'RequestState'):
        """Remove an aborted LoRA request from both sets.

        Absent request ids are tolerated. A missing adapter entry still
        raises KeyError.
        """
        if req_state.lora_name is None:
            return
        lora_stats = self.lora_name_to_stats[req_state.lora_name]
        lora_stats.waiting_requests.discard(req_state.request_id)
        lora_stats.running_requests.discard(req_state.request_id)

    # These two lifecycle methods are static and take the (possibly None)
    # LoRAStats directly, so IterationStats.update_from_events() can call
    # them without a LoRARequestStates instance.
    @staticmethod
    def scheduled_request(lora_stats: Optional[LoRAStats], request_id: str):
        """Move a request from waiting to running (KeyError if not waiting)."""
        if lora_stats is None:
            return
        lora_stats.waiting_requests.remove(request_id)
        lora_stats.running_requests.add(request_id)

    @staticmethod
    def preempted_request(lora_stats: Optional[LoRAStats], request_id: str):
        """Move a request from running to waiting (KeyError if not running)."""
        if lora_stats is None:
            return
        lora_stats.running_requests.remove(request_id)
        lora_stats.waiting_requests.add(request_id)

    def update_iteration_stats(self,
                               iteration_stats: Optional[IterationStats]):
        """Publish per-adapter waiting/running counts into iteration_stats.

        Adapters whose set is empty are omitted from the respective dict.
        Does nothing when iteration_stats is None.
        """
        if iteration_stats is None:
            return
        for lora_name, stats in self.lora_name_to_stats.items():
            if stats.waiting_requests:
                iteration_stats.waiting_lora_adapters[lora_name] = len(
                    stats.waiting_requests)
            if stats.running_requests:
                iteration_stats.running_lora_adapters[lora_name] = len(
                    stats.running_requests)