# SPDX-License-Identifier: Apache-2.0

3
4
5
6
import time
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from enum import IntEnum
7
from typing import ClassVar, Optional
8
9
10
11
12
13
14

import msgspec
from msgspec import field as msgspec_field

from vllm.sampling_params import SamplingParams


15
16
17
18
19
class RequestStatsUpdate(
        msgspec.Struct,  # type: ignore
        array_like=True,
        omit_defaults=True,
        gc=False):
    """
    An update to the request stats.

    This represents a stats update at a specific timestamp with metadata
    associated with the update.

    NOTE: since there might be multiple processes generating updates at
    different parts of the engine (e.g. input processor, scheduler, engine
    core, etc.), every update carries a monotonic timestamp that is used to
    compute intervals between events. It is not a wall-clock time; an
    explicit wall-clock timestamp should be used wherever an absolute point
    in time is needed.

    WARNING: This assumes stats are generated in a single machine. If there are
    potentially multiple machines, one should always generate the stats updates
    on one single machine or use something else.
    """

    class Type(IntEnum):
        """See `RequestStats` for the lifecycle of a request."""

        # Request arrived at the engine frontend.
        ARRIVED = 0
        # Input processed by the input processor.
        INPUT_PROCESSED = 1
        # Queued on the engine core.
        QUEUED = 2
        # Scheduled running prefill by the scheduler.
        # A request could be running a new prefill on the prompt tokens or
        # a resumed prefill on the original prefill tokens + generated output
        # tokens before preemption.
        PREFILLING = 3
        # Preempted by the scheduler.
        PREEMPTED = 4
        # Output token is generated by the engine core.
        DECODING = 5
        # Token detokenized by the detokenizer.
        # We will record the timestamp for each output token, as well as the
        # finish reason.
        DETOKENIZED = 6
        # Request finishes (or aborts).
        FINISHED = 7

    """
    Valid state updates:
    ARRIVED
    │
    ├──────► INPUT_PROCESSED ──────► QUEUED ──────► PREFILLING ◄────┐
    │              │                   │              │             │
    │              │                   │              ▼             │
    │              │                   │       -──► DECODING        │
    │              │                   │       |      │             │
    │              │                   │       |      ▼             │
    │              │                   │       └─ DETOKENIZED       │
    │              │                   │              │             │
    │              │                   │              ▼             │
    │              ▼                   ▼           PREEMPTED ◄──────┘
    │              │                   │              │
    └──────────────┴───────────────────┴──────────────┴
                   │
                   ▼
                FINISHED (All could go to FINISHED)
    """
    # Maps the type of the previous update to the set of update types that
    # may legally follow it. FINISHED is terminal.
    _VALID_TRANSITIONS: ClassVar[dict[Type, set[Type]]] = {
        Type.ARRIVED: {
            Type.INPUT_PROCESSED,
            Type.FINISHED,
        },
        Type.INPUT_PROCESSED: {
            Type.QUEUED,
            Type.FINISHED,
        },
        Type.QUEUED: {
            Type.PREFILLING,
            Type.FINISHED,
        },
        Type.PREFILLING: {
            Type.DECODING,
            Type.PREEMPTED,
            Type.FINISHED,
        },
        Type.DECODING: {
            Type.DETOKENIZED,
            Type.FINISHED,
        },
        Type.DETOKENIZED: {
            Type.DECODING,
            Type.PREEMPTED,
            Type.FINISHED,
        },
        Type.PREEMPTED: {Type.PREFILLING, Type.FINISHED},
        Type.FINISHED: set(),
    }

    request_id: str

    type: Type

    # Timestamp when the update is recorded. This is used to record time
    # intervals between events rather than wall clock time.
    # The lambda resolves `time.monotonic` each time a default is produced.
    monotonic_ts_s: float = msgspec_field(
        default_factory=lambda: time.monotonic())

    ############################################################
    # Metadata associated with the update.
    ############################################################
    # Required for INPUT_PROCESSED. Metadata needed for stats logging.
    num_prompt_tokens: Optional[int] = None
    sampling_params: Optional[SamplingParams] = None

    # Required for PREFILLING.
    # Number of tokens computed when scheduled to run.
    num_computed_tokens: Optional[int] = None
    # Number of cached tokens when scheduled to run.
    num_cached_tokens: Optional[int] = None

    # Required for DETOKENIZED.
    # The number of new output tokens generated.
    num_new_tokens: Optional[int] = None

    # Required for FINISHED.
    # Finished reason.
    finish_reason: Optional[str] = None

    # Non-optional fields for each update type. Update types that are not
    # listed here have no required metadata.
    _REQUIRED_FIELDS: ClassVar[dict[Type, list[str]]] = {
        Type.INPUT_PROCESSED: ["num_prompt_tokens", "sampling_params"],
        Type.PREFILLING: ["num_computed_tokens", "num_cached_tokens"],
        Type.DETOKENIZED: ["num_new_tokens"],
        Type.FINISHED: ["finish_reason"],
    }

    def __post_init__(self):
        """Check that the metadata required by `self.type` is present.

        Raises:
            ValueError: if a field listed in `_REQUIRED_FIELDS` for this
                update type is None.
        """
        required_fields = self._REQUIRED_FIELDS.get(self.type, [])
        for field in required_fields:
            if getattr(self, field) is None:
                raise ValueError(
                    f"Field {field} is required for update type {self.type}.")

    @staticmethod
    def check_valid_update(
        update: "RequestStatsUpdate",
        last_update_type: Optional[Type],
        last_updated_ts_s: Optional[float],
    ):
        """Assert that `update` may follow the previously applied update.

        Args:
            update: The update about to be applied.
            last_update_type: Type of the previously applied update, or None
                if no update has been applied yet.
            last_updated_ts_s: Monotonic timestamp of the previously applied
                update, or None if there is none.

        Raises:
            AssertionError: if the first update is not ARRIVED, if the
                transition is not listed in `_VALID_TRANSITIONS`, or if the
                update's timestamp is older than `last_updated_ts_s`.
                These checks are `assert` statements, so they are skipped
                when Python runs with assertions disabled (-O).
        """
        if last_update_type is None:
            # The very first update of a request must be its arrival.
            assert update.type == RequestStatsUpdate.Type.ARRIVED, (
                f"First update must be ARRIVED, but got {update.type}.")
        else:
            valid_cur_update_types = RequestStatsUpdate._VALID_TRANSITIONS[
                last_update_type]
            assert update.type in valid_cur_update_types, (
                f"Invalid update type: {update.type} for last_update_type: "
                f"{last_update_type}.")

        if last_updated_ts_s is not None:
            # Equal timestamps are accepted; only going backwards is an error.
            assert update.monotonic_ts_s >= last_updated_ts_s, (
                "Update timestamp must be monotonically increasing, but "
                f"last_updated_ts_s={last_updated_ts_s} and "
                f"update.monotonic_ts_s={update.monotonic_ts_s}.")


@dataclass
class RequestStats:
    """Stats associated with a request (`Request`).

    The stats are built up incrementally by feeding `RequestStatsUpdate`
    objects to `update_from`. All `*_ts_s` values are taken from the
    updates' `monotonic_ts_s`, so they are only meaningful for computing
    intervals.
    """

    ############################################################
    # Metadata
    ############################################################
    request_id: str
    sampling_params: Optional[SamplingParams] = None
    num_prompt_tokens: Optional[int] = None

    ############################################################
    # Metrics and Stats
    ############################################################
    # Timestamp when the request was last updated.
    last_updated_ts_s: Optional[float] = None

    # Last update stats type.
    last_update_type: Optional[RequestStatsUpdate.Type] = None

    # Timestamp when the request arrived at the llm engine.
    arrival_ts_s: Optional[float] = None

    # Number of tokens cached. When part of the request prefix is cached,
    # this will be set.
    num_cached_tokens: int = 0

    # Number of tokens computed.
    num_computed_tokens: int = 0

    # The timestamp when the request become waiting in the queue.
    queued_ts_s: Optional[float] = None

    # When the input processor is completed.
    input_processor_end_ts_s: Optional[float] = None

    # A sorted list of timestamps when the request was scheduled to prefill.
    # This could be when:
    # 1. the request is newly scheduled, so it's a new prefill.
    # 2. the request was preempted and resumed. It is equivalent to running
    #    a prefill of the original prefill tokens + generated output tokens
    #    before preemption.
    prefill_start_ts_s_lst: list[float] = dataclass_field(default_factory=list)

    # A list of timestamps when a token is decoded by the engine core.
    decoding_ts_s_lst: list[float] = dataclass_field(default_factory=list)

    # A sorted list of timestamps for each output token.
    output_token_ts_s_lst: list[float] = dataclass_field(default_factory=list)

    # First token's timestamp.
    first_token_ts_s: Optional[float] = None

    # TODO(rickyx): we need model runner to surface these.
    model_forward_duration_s: float = 0.0
    # Includes model forward, block/sync across workers, cpu-gpu sync time
    # and sampling time.
    model_execute_duration_s: float = 0.0

    # A sorted list of timestamps when the request was preempted at the
    # scheduler.
    # TODO(rickyx): right now, we don't actually have a good high-level
    # metric to measure the impact of preemption other than observation of
    # large P99 TPOT. Ideally we could quantify the impact of preemption by
    # measuring the number of tokens re-computed due to preemption.
    preempted_ts_s_lst: list[float] = dataclass_field(default_factory=list)

    # Timestamp when the request was finished at the engine core.
    finished_ts_s: Optional[float] = None

    # Finish reason.
    finish_reason: Optional[str] = None

    ############################################################
    # Derived properties.
    ############################################################
    @property
    def prefill_ts_s(self) -> Optional[float]:
        """The timestamp when the request started prefilling.
        Since a request could be preempted in decoding and later resumed
        to prefill the decoded tokens, we use the first prefill start timestamp.
        """
        return (self.prefill_start_ts_s_lst[0]
                if self.prefill_start_ts_s_lst else None)

    @property
    def e2e_latency_s(self) -> Optional[float]:
        """Time from arrival to finish, or None if either is not recorded."""
        if self.finished_ts_s is None or self.arrival_ts_s is None:
            return None
        assert self.finished_ts_s >= self.arrival_ts_s
        return self.finished_ts_s - self.arrival_ts_s

    @property
    def queue_duration_s(self) -> Optional[float]:
        """How long the request was waiting to run."""
        if self.queued_ts_s is None or self.prefill_ts_s is None:
            # Either not queued or not running yet.
            return None
        assert self.queued_ts_s <= self.prefill_ts_s
        return self.prefill_ts_s - self.queued_ts_s

    @property
    def inference_latency_s(self) -> Optional[float]:
        """How long the request was running inference
        (prefill and decode)."""
        if self.finished_ts_s is None or self.prefill_ts_s is None:
            return None
        assert self.finished_ts_s >= self.prefill_ts_s
        return self.finished_ts_s - self.prefill_ts_s

    @property
    def first_token_latency_s(self) -> Optional[float]:
        """Time from arrival to the first output token, or None if either
        timestamp is not recorded."""
        if self.first_token_ts_s is None or self.arrival_ts_s is None:
            return None
        assert self.first_token_ts_s >= self.arrival_ts_s
        return self.first_token_ts_s - self.arrival_ts_s

    @property
    def prefill_latency_s(self) -> Optional[float]:
        """Time from the first prefill start to the first output token, or
        None if either timestamp is not recorded."""
        if self.first_token_ts_s is None or self.prefill_ts_s is None:
            return None
        assert self.first_token_ts_s >= self.prefill_ts_s
        return self.first_token_ts_s - self.prefill_ts_s

    @property
    def decode_latency_s(self) -> Optional[float]:
        """End-to-end latency minus first-token latency, or None if either
        of the two is not available."""
        if self.e2e_latency_s is None or self.first_token_latency_s is None:
            return None
        assert self.e2e_latency_s >= self.first_token_latency_s
        return self.e2e_latency_s - self.first_token_latency_s

    @property
    def output_token_latency_s_lst(self) -> list[float]:
        """Gaps between consecutive output token timestamps.

        The result has one entry fewer than `output_token_ts_s_lst` and is
        empty when fewer than two tokens have been recorded.
        """
        ts_lst = self.output_token_ts_s_lst
        latency_s_lst = []
        # Walk consecutive (previous, current) timestamp pairs.
        for prev_ts_s, cur_ts_s in zip(ts_lst, ts_lst[1:]):
            assert cur_ts_s >= prev_ts_s
            latency_s_lst.append(cur_ts_s - prev_ts_s)
        return latency_s_lst

    @property
    def num_output_tokens(self) -> int:
        """Number of output tokens recorded so far."""
        return len(self.output_token_ts_s_lst)

    @property
    def is_finished(self) -> bool:
        """Whether a FINISHED update has been applied."""
        return self.finished_ts_s is not None

    def update_from(self, update: "RequestStatsUpdate"):
        """Validate `update` and fold it into these stats.

        Args:
            update: The next update for this request.

        Raises:
            AssertionError: if `RequestStatsUpdate.check_valid_update`
                rejects the update given the last applied update.
            ValueError: if the update type is not handled here.
        """
        RequestStatsUpdate.check_valid_update(update, self.last_update_type,
                                              self.last_updated_ts_s)
        ts = update.monotonic_ts_s
        # Bookkeeping for validating the next update.
        self.last_updated_ts_s = ts
        self.last_update_type = update.type
        if update.type == RequestStatsUpdate.Type.ARRIVED:
            self.arrival_ts_s = ts
        elif update.type == RequestStatsUpdate.Type.INPUT_PROCESSED:
            self.input_processor_end_ts_s = ts
            self.sampling_params = update.sampling_params
            self.num_prompt_tokens = update.num_prompt_tokens
        elif update.type == RequestStatsUpdate.Type.QUEUED:
            self.queued_ts_s = ts
        elif update.type == RequestStatsUpdate.Type.PREFILLING:
            self.prefill_start_ts_s_lst.append(ts)
            # The fields are Optional on the update; fall back to 0.
            self.num_cached_tokens = update.num_cached_tokens or 0
            self.num_computed_tokens = update.num_computed_tokens or 0
        elif update.type == RequestStatsUpdate.Type.PREEMPTED:
            self._reset_for_preemption(ts)
        elif update.type == RequestStatsUpdate.Type.DECODING:
            self.decoding_ts_s_lst.append(ts)
        elif update.type == RequestStatsUpdate.Type.DETOKENIZED:
            self._record_detokenized_output(
                ts,
                update.num_new_tokens or 0,
            )
        elif update.type == RequestStatsUpdate.Type.FINISHED:
            self.finished_ts_s = ts
            self.finish_reason = update.finish_reason
        else:
            raise ValueError(f"Unknown update type: {update.type}")

    def _record_detokenized_output(
        self,
        ts_s: float,
        num_new_tokens: int,
    ):
        """Record `num_new_tokens` output tokens detokenized at `ts_s`.

        The first time at least one token is recorded, `first_token_ts_s`
        is set to `ts_s`. An update carrying zero new tokens leaves the
        stats unchanged, so a first-token timestamp never exists without a
        matching entry in `output_token_ts_s_lst`.
        """
        # Update if first output token is generated.
        if num_new_tokens > 0 and not self.output_token_ts_s_lst:
            # Check before mutating so a failed check leaves no partial state.
            assert (
                self.prefill_ts_s is not None
            ), "Request must be running before generating output tokens."
            self.first_token_ts_s = ts_s

        # Some X new tokens were generated at the ts.
        self.output_token_ts_s_lst.extend([ts_s] * num_new_tokens)

    def _reset_for_preemption(self, ts_s: float):
        """Record a preemption at `ts_s` and clear the per-run counters."""
        self.preempted_ts_s_lst.append(ts_s)
        # Reset the computed tokens since it might restart the prefill.
        self.num_computed_tokens = 0
        # Cached token count might also change when resumed.
        self.num_cached_tokens = 0
        # These stats are left untouched by a preemption:
        # - arrival_ts_s
        # - input_processor_end_ts_s
        # - sampling_params
        # - num_prompt_tokens
        # - first_token_ts_s
        #
        # These stats are accumulated over preemptions:
        # - output_token_ts_s_lst
        # - prefill_start_ts_s_lst (after preemption, it will prefill the
        #   original prefill tokens and any output tokens generated before
        #   preemption.)


@dataclass
class KVCacheStats:
    """Stats describing the KV cache. Both values default to 0.0."""

    # KV cache usage; the original comment gives the unit as "%".
    # NOTE(review): confirm whether producers report 0-100 or a 0-1 fraction.
    gpu_cache_usage_sys: float = 0.0
    # Prefix cache hit rate.
    gpu_prefix_cache_hit_rate: float = 0.0


@dataclass
class SchedulerStats:
    """Stats associated with the scheduler."""

    # How many requests are running right now.
    num_running_reqs: int = 0
    # How many requests are waiting right now.
    num_waiting_reqs: int = 0

    # KV cache stats; every instance gets its own fresh `KVCacheStats`.
    kv_cache_stats: KVCacheStats = dataclass_field(
        default_factory=KVCacheStats, )


@dataclass
class EngineCoreProcessStats:
    """Stats associated with the engine core process.

    Both queue sizes are None if the engine core is not running in
    multiprocess mode.
    """

    # Number of requests currently in the input queue.
    input_queue_size: Optional[int] = None
    # Number of outputs currently in the output queue.
    output_queue_size: Optional[int] = None


431
432
433
434
435
class EngineCoreStatsSnapshot(
        msgspec.Struct,  # type: ignore
        array_like=True,
        omit_defaults=True,
        gc=False):
    """
    A snapshot of the EngineCore's current stats over a period of time.

    Every field has a default factory, so a snapshot can be constructed
    without arguments; each instance then gets its own fresh objects.
    """

    # Snapshot of the scheduler stats.
    scheduler_stats: SchedulerStats = msgspec_field(
        default_factory=SchedulerStats)

    # Per request stats updates.
    requests_stats_updates: list[RequestStatsUpdate] = msgspec_field(
        default_factory=list)

    # Engine core's queue stats.
    engine_core_process_stats: EngineCoreProcessStats = msgspec_field(
        default_factory=EngineCoreProcessStats)

    # TODO(rickyx): Add other components' stats,
    # e.g. model runner/worker and etc.