# stats.py
1
2
# SPDX-License-Identifier: Apache-2.0

3
import time
4
from dataclasses import dataclass, field
5
from typing import TYPE_CHECKING, Optional
6
7

if TYPE_CHECKING:
8
    from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
9
    from vllm.v1.output_processor import RequestState
10
11


12
13
14
15
16
17
18
19
20
21
22
23
24
25
@dataclass
class PrefixCacheStats:
    """Prefix-cache hit counters gathered for one stats update.

    Every counter starts at zero. ``reset`` records whether the prefix cache
    was reset (via reset_prefix_cache) during the update.
    """

    # Set when reset_prefix_cache was invoked.
    reset: bool = field(default=False)
    # How many requests this update covers.
    requests: int = field(default=0)
    # How many cache lookups those requests made. A "query" here counts
    # blocks looked up in the cache, not requests or tokens.
    queries: int = field(default=0)
    # How many of those lookups hit the cache.
    hits: int = field(default=0)


26
27
28
29
30
31
32
@dataclass
class SchedulerStats:
    """Stats associated with the scheduler."""

    # Number of requests currently running in the scheduler.
    num_running_reqs: int = 0
    # Number of requests waiting to be scheduled.
    num_waiting_reqs: int = 0

    # GPU KV-cache usage. NOTE(review): presumably a fraction in [0, 1] —
    # confirm against the scheduler code that fills it in.
    gpu_cache_usage: float = 0.0

    # Prefix cache query/hit counters. default_factory gives every
    # SchedulerStats its own PrefixCacheStats instance (no shared state).
    prefix_cache_stats: PrefixCacheStats = field(
        default_factory=PrefixCacheStats)
37
38


39
40
@dataclass
class LoRAStats:
    """Ids of the waiting and running requests for a single LoRA adapter."""

    # Requests using this adapter that are queued (or were preempted) and
    # are not currently scheduled.
    waiting_requests: set[str] = field(default_factory=set)
    # Requests using this adapter that are currently scheduled.
    running_requests: set[str] = field(default_factory=set)
43
44


45
46
47
48
49
@dataclass
class RequestStateStats:
    """Stats that need to be tracked across delta updates."""

    # Running total of tokens generated for this request so far.
    num_generation_tokens: int = 0

    # This is an engine frontend timestamp (wall-clock)
    arrival_time: float = 0.0

    # These are engine core timestamps (monotonic). They stay at 0.0 until
    # the corresponding event has been observed for the request.
    # Time the request was (last) queued.
    queued_ts: float = 0.0
    # Time of the first SCHEDULED event (later re-schedules are ignored).
    scheduled_ts: float = 0.0
    # Time the first generated token was produced.
    first_token_ts: float = 0.0
    # Time the most recent generated token was produced.
    last_token_ts: float = 0.0
59
60
61
62
63
64


@dataclass
class FinishedRequestStats:
    """Stats associated with a finished request."""

    # Why the request finished.
    finish_reason: "FinishReason"
    # From request arrival to the iteration in which the request finished.
    e2e_latency: float = 0.0

    # Prompt length and number of tokens generated over the request's life.
    num_prompt_tokens: int = 0
    num_generation_tokens: int = 0
    # The request's max_tokens sampling parameter, if one was set.
    max_tokens_param: Optional[int] = None

    # From first QUEUED event to first SCHEDULED event.
    queued_time: float = 0.0
    # From first SCHEDULED event to first generated token (includes any
    # preemptions that happen during prefill).
    prefill_time: float = 0.0

    # From first SCHEDULED event to last generated token (prefill + decode).
    inference_time: float = 0.0
    # From first generated token to last generated token.
    decode_time: float = 0.0
74
75


76
77
78
class IterationStats:
    """Stats associated with a single set of EngineCoreOutputs."""

    def __init__(self):
        # Wall-clock time this stats object was created; _time_since()
        # measures intervals relative to it.
        self.iteration_timestamp = time.time()

        # Token / request counters for this iteration.
        self.num_generation_tokens = 0
        self.num_prompt_tokens = 0
        self.num_preempted_reqs = 0
        self.finished_requests: list[FinishedRequestStats] = []

        # Per-iteration samples. NOTE(review): the two lists below are not
        # written by this class; presumably filled in by a caller.
        self.max_num_generation_tokens_iter: list[int] = []
        self.n_params_iter: list[int] = []

        # Latency samples collected during this iteration.
        self.time_to_first_tokens_iter: list[float] = []
        self.time_per_output_tokens_iter: list[float] = []

        # LoRA adapter name -> number of waiting / running requests.
        self.waiting_lora_adapters: dict[str, int] = {}
        self.running_lora_adapters: dict[str, int] = {}

    def _time_since(self, start: float) -> float:
        """Calculate an interval relative to this iteration's timestamp."""
        return self.iteration_timestamp - start

    def update_from_output(self, output: "EngineCoreOutput",
                           engine_core_timestamp: float, is_prefilling: bool,
                           prompt_len: int, req_stats: "RequestStateStats",
                           lora_stats: Optional["LoRAStats"]):
        """Fold one request's EngineCoreOutput into this iteration's stats.

        Args:
            output: Engine core output for a single request; its
                ``new_token_ids``, ``events`` and ``request_id`` are read.
            engine_core_timestamp: Engine core (monotonic) timestamp of the
                batch this output belongs to.
            is_prefilling: Whether the request was still prefilling when
                this output was produced.
            prompt_len: Number of prompt tokens of the request.
            req_stats: Per-request stats carried across delta updates;
                updated in place.
            lora_stats: Stats for the request's LoRA adapter, or None when
                the request uses no adapter.
        """
        num_new_generation_tokens = len(output.new_token_ids)
        # TODO(andy): we used to assert that num_new_generation_tokens
        # > 0 with an invariant that EngineCore does not stream outputs
        # for partially completed prefills (scheduler.update_from_output
        # makes EngineCoreOutput iff num_computed_tokens == num_tokens).
        # When prompt logprobs are enabled, we currently stream out the
        # partially completed prompt.
        # This will be reverted in a follow up PR and we should re-enable
        # this assertion / invariant.
        # Until then, every token-timing update below is gated on this flag
        # so that token-less outputs (prefill or decode) leave no samples.
        has_new_tokens = num_new_generation_tokens > 0

        self.num_generation_tokens += num_new_generation_tokens
        if is_prefilling and has_new_tokens:
            # Prefill completed in this step: count the prompt once and
            # record time-to-first-token (wall-clock, since arrival).
            self.num_prompt_tokens += prompt_len

            first_token_latency = self._time_since(req_stats.arrival_time)
            self.time_to_first_tokens_iter.append(first_token_latency)

        req_stats.num_generation_tokens += num_new_generation_tokens

        # Process request-level engine core events
        if output.events is not None:
            self.update_from_events(output.request_id, output.events,
                                    is_prefilling, req_stats, lora_stats)

        # Process the batch-level "new tokens" engine core event. An output
        # without new tokens produced no token, so it must not yield a TPOT
        # sample nor move last_token_ts.
        if has_new_tokens:
            if is_prefilling:
                req_stats.first_token_ts = engine_core_timestamp
            else:
                tpot = engine_core_timestamp - req_stats.last_token_ts
                self.time_per_output_tokens_iter.append(tpot)
            req_stats.last_token_ts = engine_core_timestamp

    def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
                           is_prefilling: bool, req_stats: "RequestStateStats",
                           lora_stats: Optional["LoRAStats"]):
        """Apply request-level engine core events to the request/LoRA stats.

        Handles QUEUED (records queued_ts, marks the request waiting),
        SCHEDULED (records the first scheduled_ts, moves waiting -> running)
        and PREEMPTED (counts the preemption, moves running -> waiting).
        ``is_prefilling`` is accepted for interface symmetry but unused.
        """
        # Avoid circular dependency
        from vllm.v1.engine import EngineCoreEventType
        for event in events:
            if event.type == EngineCoreEventType.QUEUED:
                req_stats.queued_ts = event.timestamp
                if lora_stats is not None:
                    lora_stats.waiting_requests.add(req_id)
            elif event.type == EngineCoreEventType.SCHEDULED:
                if req_stats.scheduled_ts == 0.0:  # ignore preemptions
                    req_stats.scheduled_ts = event.timestamp
                LoRARequestStates.scheduled_request(lora_stats, req_id)
            elif event.type == EngineCoreEventType.PREEMPTED:
                self.num_preempted_reqs += 1
                LoRARequestStates.preempted_request(lora_stats, req_id)

    def update_from_finished_request(self, finish_reason: "FinishReason",
                                     num_prompt_tokens: int,
                                     max_tokens_param: Optional[int],
                                     req_stats: "RequestStateStats"):
        """Record a FinishedRequestStats entry for a request that finished.

        Derives the queued / prefill / decode / inference intervals from the
        engine core timestamps in ``req_stats`` and the end-to-end latency
        from the frontend arrival time.
        """
        e2e_latency = self._time_since(req_stats.arrival_time)

        # Queued interval is from first QUEUED event to first SCHEDULED
        queued_time = req_stats.scheduled_ts - req_stats.queued_ts

        # Prefill interval is from first SCHEDULED to first NEW_TOKEN
        # Any preemptions during prefill is included in the interval
        prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts

        # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
        # Any preemptions during decode are included
        decode_time = req_stats.last_token_ts - req_stats.first_token_ts

        # Inference interval is from first SCHEDULED to last NEW_TOKEN
        # Any preemptions during prefill or decode are included
        inference_time = req_stats.last_token_ts - req_stats.scheduled_ts

        finished_req = \
            FinishedRequestStats(finish_reason=finish_reason,
                                 e2e_latency=e2e_latency,
                                 num_prompt_tokens=num_prompt_tokens,
                                 num_generation_tokens=req_stats.num_generation_tokens,
                                 max_tokens_param=max_tokens_param,
                                 queued_time=queued_time,
                                 prefill_time=prefill_time,
                                 inference_time=inference_time,
                                 decode_time=decode_time)
        self.finished_requests.append(finished_req)
187
188
189
190
191
192


class LoRARequestStates:
    """Per-LoRA request state stats."""

    def __init__(self):
        # LoRA adapter name -> waiting/running request ids for that adapter.
        # Entries are created lazily by get_stats() and never removed here.
        self.lora_name_to_stats: dict[str, LoRAStats] = {}

    def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]:
        """Return the LoRAStats for the request's adapter.

        Creates the entry on first use. Returns None when the request has
        no LoRA adapter (``lora_name is None``).
        """
        if req_state.lora_name is None:
            return None
        if req_state.lora_name not in self.lora_name_to_stats:
            self.lora_name_to_stats[req_state.lora_name] = LoRAStats()
        return self.lora_name_to_stats[req_state.lora_name]

    def add_request(self, req_state: 'RequestState'):
        """Mark a newly added request as waiting for its adapter (if any)."""
        if (lora_stats := self.get_stats(req_state)) is not None:
            lora_stats.waiting_requests.add(req_state.request_id)

    def finish_request(self, req_state: 'RequestState'):
        """Remove a finished request from its adapter's running set.

        Raises KeyError if the adapter has no entry yet or the request is
        not currently in the running set.
        """
        if req_state.lora_name is None:
            return
        lora_stats = self.lora_name_to_stats[req_state.lora_name]
        lora_stats.running_requests.remove(req_state.request_id)

    def abort_request(self, req_state: 'RequestState'):
        """Drop an aborted request from both sets of its adapter.

        Uses discard(), so the request may be absent from either set.
        """
        if req_state.lora_name is None:
            return
        lora_stats = self.lora_name_to_stats[req_state.lora_name]
        lora_stats.waiting_requests.discard(req_state.request_id)
        lora_stats.running_requests.discard(req_state.request_id)

    # Break the pattern for this lifecycle methods so we can
    # call this from IterationStats.update_from_events()
    @staticmethod
    def scheduled_request(lora_stats: Optional[LoRAStats], request_id: str):
        """Move a request from waiting to running (no-op without LoRA).

        Raises KeyError if the request is not in the waiting set.
        """
        if lora_stats is None:
            return
        lora_stats.waiting_requests.remove(request_id)
        lora_stats.running_requests.add(request_id)

    @staticmethod
    def preempted_request(lora_stats: Optional[LoRAStats], request_id: str):
        """Move a request from running back to waiting (no-op without LoRA).

        Raises KeyError if the request is not in the running set.
        """
        if lora_stats is None:
            return
        lora_stats.running_requests.remove(request_id)
        lora_stats.waiting_requests.add(request_id)

    def update_iteration_stats(self,
                               iteration_stats: Optional[IterationStats]):
        """Publish per-adapter waiting/running request counts.

        Writes into ``iteration_stats.waiting_lora_adapters`` and
        ``running_lora_adapters``; adapters with an empty set are omitted
        from the corresponding dict. No-op when iteration_stats is None.
        """
        if iteration_stats is None:
            return
        for lora_name, stats in self.lora_name_to_stats.items():
            if stats.waiting_requests:
                iteration_stats.waiting_lora_adapters[lora_name] = \
                    len(stats.waiting_requests)
            if stats.running_requests:
                iteration_stats.running_lora_adapters[lora_name] = \
                    len(stats.running_requests)