# stats.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, Any, Optional
7

8
9
from vllm.v1.spec_decode.metrics import SpecDecodingStats

10
if TYPE_CHECKING:
11
    from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
12
    from vllm.v1.engine.output_processor import RequestState
13
14


15
16
17
18
19
20
21
22
@dataclass
class PrefixCacheStats:
    """Stores prefix cache hit statistics."""
    # Whether reset_prefix_cache was invoked.
    reset: bool = False
    # The number of requests in this update.
    requests: int = 0
    # The number of queries in these requests. Note that "queries" here
    # means the number of tokens that were queried from the cache.
    queries: int = 0
    # The number of hits in these requests.
    hits: int = 0


29
30
31
32
33
34
35
@dataclass
class SchedulerStats:
    """Stats associated with the scheduler."""

    # Request counts as reported by the scheduler.
    num_running_reqs: int = 0
    num_waiting_reqs: int = 0

    # These are used for internal DP load-balancing.
    step_counter: int = 0
    current_wave: int = 0

    # NOTE(review): presumably the fraction of the KV cache in use —
    # confirm the unit with the producer of this value.
    kv_cache_usage: float = 0.0

    # Fresh PrefixCacheStats per instance (default_factory avoids sharing).
    prefix_cache_stats: PrefixCacheStats = field(
        default_factory=PrefixCacheStats)

    # Optional sub-stats; None when the feature is not reporting.
    spec_decoding_stats: Optional[SpecDecodingStats] = None
    kv_connector_stats: Optional[dict[str, Any]] = None

    num_corrupted_reqs: int = 0

50

51
52
@dataclass
class LoRAStats:
    """Request ids currently waiting / running for one LoRA adapter."""

    # Each instance gets its own sets via default_factory.
    waiting_requests: set[str] = field(default_factory=set)
    running_requests: set[str] = field(default_factory=set)
55
56


57
58
59
60
61
@dataclass
class RequestStateStats:
    """Stats that need to be tracked across delta updates."""

    # Running total of generated tokens for this request.
    num_generation_tokens: int = 0

    # This is an engine frontend timestamp (wall-clock)
    arrival_time: float = 0.0

    # These are engine core timestamps (monotonic)
    queued_ts: float = 0.0
    scheduled_ts: float = 0.0
    first_token_ts: float = 0.0
    last_token_ts: float = 0.0

    # first token latency
    first_token_latency: float = 0.0

75
76
77
78
79

@dataclass
class FinishedRequestStats:
    """Stats associated with a finished request."""

    finish_reason: "FinishReason"
    e2e_latency: float = 0.0

    # Token counts and the request's max_tokens setting (None if unset).
    num_prompt_tokens: int = 0
    num_generation_tokens: int = 0
    max_tokens_param: Optional[int] = None

    # Phase durations; presumably seconds, as derived from the timestamp
    # differences recorded in RequestStateStats — confirm.
    queued_time: float = 0.0
    prefill_time: float = 0.0
    inference_time: float = 0.0
    decode_time: float = 0.0
    mean_time_per_output_token: float = 0.0
90
91


92
93
94
class IterationStats:
    """Stats associated with a single set of EngineCoreOutputs.

    ``iteration_timestamp`` is captured with ``time.time()`` at construction
    and is the "now" used for wall-clock latencies (time to first token,
    end-to-end latency) measured from a request's ``arrival_time``.
    """

    def __init__(self):
        # Wall-clock reference point for this iteration.
        self.iteration_timestamp = time.time()

        # Aggregate token counters for this iteration.
        self.num_generation_tokens = 0
        self.num_prompt_tokens = 0
        # Incremented once per PREEMPTED event seen in update_from_events().
        self.num_preempted_reqs = 0
        self.finished_requests: list[FinishedRequestStats] = []

        # Per-request samples. NOTE(review): the first two lists are not
        # written by this class; presumably populated by callers.
        self.max_num_generation_tokens_iter: list[int] = []
        self.n_params_iter: list[int] = []
        self.time_to_first_tokens_iter: list[float] = []
        self.inter_token_latencies_iter: list[float] = []

        # LoRA name -> number of waiting / running requests; filled in by
        # LoRARequestStates.update_iteration_stats().
        self.waiting_lora_adapters: dict[str, int] = {}
        self.running_lora_adapters: dict[str, int] = {}

    def _time_since(self, start: float) -> float:
        """Calculate an interval relative to this iteration's timestamp."""
        return self.iteration_timestamp - start

    def update_from_output(self, output: "EngineCoreOutput",
                           engine_core_timestamp: float, is_prefilling: bool,
                           prompt_len: int, req_stats: RequestStateStats,
                           lora_stats: Optional[LoRAStats]):
        """Fold one request's EngineCoreOutput into this iteration's stats.

        Args:
            output: The engine core output for a single request.
            engine_core_timestamp: Engine-core (monotonic) time at which the
                batch containing ``output`` was produced.
            is_prefilling: True if this output carries the request's first
                generated token(s).
            prompt_len: Number of prompt tokens for the request.
            req_stats: Per-request state, updated in place.
            lora_stats: The request's LoRA tracking entry, or None.
        """
        num_new_generation_tokens = len(output.new_token_ids)
        self.num_generation_tokens += num_new_generation_tokens

        if is_prefilling:
            # Prompt tokens and time-to-first-token are recorded on the
            # prefilling output. TTFT is wall-clock: arrival -> now.
            self.num_prompt_tokens += prompt_len
            first_token_latency = self._time_since(req_stats.arrival_time)
            self.time_to_first_tokens_iter.append(first_token_latency)
            req_stats.first_token_latency = first_token_latency

        req_stats.num_generation_tokens += num_new_generation_tokens

        # Process request-level engine core events
        if output.events is not None:
            self.update_from_events(output.request_id, output.events,
                                    is_prefilling, req_stats, lora_stats)

        # Process the batch-level "new tokens" engine core event
        if is_prefilling:
            req_stats.first_token_ts = engine_core_timestamp
        else:
            # Inter-token latency uses engine-core (monotonic) timestamps.
            itl = engine_core_timestamp - req_stats.last_token_ts
            self.inter_token_latencies_iter.append(itl)

        req_stats.last_token_ts = engine_core_timestamp

    def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
                           is_prefilling: bool, req_stats: RequestStateStats,
                           lora_stats: Optional[LoRAStats]):
        """Apply request-level engine core events to ``req_stats``.

        - QUEUED: records ``queued_ts`` and marks the request as waiting for
          its LoRA adapter (if any).
        - SCHEDULED: records ``scheduled_ts`` only the first time, so later
          re-schedules after preemption do not move it; the request moves
          from waiting to running for its LoRA adapter.
        - PREEMPTED: counts a preemption and moves the request back to
          waiting for its LoRA adapter.

        ``is_prefilling`` is not used here; it is kept for interface
        compatibility.
        """
        # Avoid circular dependency
        from vllm.v1.engine import EngineCoreEventType
        for event in events:
            if event.type == EngineCoreEventType.QUEUED:
                req_stats.queued_ts = event.timestamp
                if lora_stats is not None:
                    lora_stats.waiting_requests.add(req_id)
            elif event.type == EngineCoreEventType.SCHEDULED:
                if req_stats.scheduled_ts == 0.0:  # ignore preemptions
                    req_stats.scheduled_ts = event.timestamp
                LoRARequestStates.scheduled_request(lora_stats, req_id)
            elif event.type == EngineCoreEventType.PREEMPTED:
                self.num_preempted_reqs += 1
                LoRARequestStates.preempted_request(lora_stats, req_id)

    def update_from_finished_request(self, finish_reason: "FinishReason",
                                     num_prompt_tokens: int,
                                     max_tokens_param: Optional[int],
                                     req_stats: RequestStateStats):
        """Append a FinishedRequestStats entry for a completed request.

        Args:
            finish_reason: Why the request finished.
            num_prompt_tokens: Number of prompt tokens in the request.
            max_tokens_param: The request's max_tokens setting, if any.
            req_stats: Accumulated per-request timestamps and counters.
        """
        # End-to-end latency is wall-clock: arrival -> this iteration.
        e2e_latency = self._time_since(req_stats.arrival_time)

        # The intervals below use engine-core (monotonic) timestamps.
        # NOTE(review): if no SCHEDULED event was ever seen, scheduled_ts is
        # still its 0.0 default and these intervals are meaningless; confirm
        # callers only finish requests that were scheduled.

        # Queued interval is from the QUEUED event to the first SCHEDULED
        queued_time = req_stats.scheduled_ts - req_stats.queued_ts

        # Prefill interval is from first SCHEDULED to first NEW_TOKEN
        # Any preemptions during prefill is included in the interval
        prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts

        # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
        # Any preemptions during decode are included
        decode_time = req_stats.last_token_ts - req_stats.first_token_ts

        # Inference interval is from first SCHEDULED to last NEW_TOKEN
        # Any preemptions during prefill or decode are included
        inference_time = req_stats.last_token_ts - req_stats.scheduled_ts

        # Do not count the token generated by the prefill phase
        num_decode_tokens = req_stats.num_generation_tokens - 1
        mean_time_per_output_token = (decode_time / num_decode_tokens
                                      if num_decode_tokens > 0 else 0.0)

        finished_req = FinishedRequestStats(
            finish_reason=finish_reason,
            e2e_latency=e2e_latency,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=req_stats.num_generation_tokens,
            max_tokens_param=max_tokens_param,
            queued_time=queued_time,
            prefill_time=prefill_time,
            inference_time=inference_time,
            decode_time=decode_time,
            mean_time_per_output_token=mean_time_per_output_token)
        self.finished_requests.append(finished_req)
199
200
201
202
203
204


class LoRARequestStates:
    """Per-LoRA request state stats.

    Tracks, for each LoRA adapter name, which request ids are waiting and
    which are running.
    """

    def __init__(self):
        # LoRA name -> waiting/running request-id sets for that adapter.
        self.lora_name_to_stats: dict[str, LoRAStats] = {}

    def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]:
        """Return the LoRAStats for the request's adapter.

        An entry is created on first use. Returns None when the request
        has no LoRA adapter (``lora_name is None``).
        """
        lora_name = req_state.lora_name
        if lora_name is None:
            return None
        if lora_name not in self.lora_name_to_stats:
            self.lora_name_to_stats[lora_name] = LoRAStats()
        return self.lora_name_to_stats[lora_name]

    def add_request(self, req_state: 'RequestState'):
        """Mark a newly added LoRA request as waiting."""
        if (lora_stats := self.get_stats(req_state)) is not None:
            lora_stats.waiting_requests.add(req_state.request_id)

    def finish_request(self, req_state: 'RequestState'):
        """Drop a finished LoRA request from its adapter's running set.

        Raises KeyError if the request is not currently running (strict
        ``set.remove``), unlike abort_request which tolerates either state.
        """
        if req_state.lora_name is None:
            return
        lora_stats = self.lora_name_to_stats[req_state.lora_name]
        lora_stats.running_requests.remove(req_state.request_id)

    def abort_request(self, req_state: 'RequestState'):
        """Drop an aborted LoRA request, whether waiting or running."""
        if req_state.lora_name is None:
            return
        lora_stats = self.lora_name_to_stats[req_state.lora_name]
        # discard(): an aborted request may be in either set (or neither).
        lora_stats.waiting_requests.discard(req_state.request_id)
        lora_stats.running_requests.discard(req_state.request_id)

    # Break the pattern for this lifecycle methods so we can
    # call this from IterationStats.update_from_events()
    @staticmethod
    def scheduled_request(lora_stats: Optional[LoRAStats], request_id: str):
        """Move a request from waiting to running (no-op without LoRA).

        Raises KeyError if ``request_id`` is not in the waiting set.
        """
        if lora_stats is None:
            return
        lora_stats.waiting_requests.remove(request_id)
        lora_stats.running_requests.add(request_id)

    @staticmethod
    def preempted_request(lora_stats: Optional[LoRAStats], request_id: str):
        """Move a request from running back to waiting (no-op without LoRA).

        Raises KeyError if ``request_id`` is not in the running set.
        """
        if lora_stats is None:
            return
        lora_stats.running_requests.remove(request_id)
        lora_stats.waiting_requests.add(request_id)

    def update_iteration_stats(self,
                               iteration_stats: Optional[IterationStats]):
        """Publish per-adapter waiting/running counts into iteration_stats.

        Only adapters with a non-empty set get an entry in the corresponding
        dict; nothing happens when ``iteration_stats`` is None.
        """
        if iteration_stats is None:
            return
        for lora_name, stats in self.lora_name_to_stats.items():
            if stats.waiting_requests:
                iteration_stats.waiting_lora_adapters[lora_name] = \
                    len(stats.waiting_requests)
            if stats.running_requests:
                iteration_stats.running_lora_adapters[lora_name] = \
                    len(stats.running_requests)