metrics.py 7.69 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
6
from dataclasses import dataclass, field
from typing import Optional
7
8

import numpy as np
9
import prometheus_client
10

11
from vllm.config import SpeculativeConfig
12
13
14
15
16
17
18
from vllm.logger import init_logger

# Module-level logger; SpecDecodingLogging.log() defaults to logger.info.
logger = init_logger(__name__)


@dataclass
class SpecDecodingStats:
    """Per-step iteration decoding stats from scheduler.

    Each scheduler step, statistics on spec decoding performance are
    aggregated across requests by the scheduler and returned to the
    frontend in EngineCoreOutputs->SchedulerStats.

    Attributes:
        num_spec_tokens: Upper bound on accepted tokens per draft; also the
            length of ``num_accepted_tokens_per_pos``.
        num_drafts: Number of drafts observed.
        num_draft_tokens: Total draft tokens across observed drafts.
        num_accepted_tokens: Total accepted tokens across observed drafts.
        num_accepted_tokens_per_pos: Entry ``i`` is the number of drafts
            that accepted more than ``i`` tokens.
    """

    num_spec_tokens: int
    num_drafts: int = 0
    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0
    num_accepted_tokens_per_pos: list[int] = field(default_factory=list)

    def __post_init__(self):
        # Constructing the dataclass directly (instead of via new()) would
        # otherwise leave this list empty, making observe_draft() raise
        # IndexError on the first accepted token.
        if not self.num_accepted_tokens_per_pos:
            self.num_accepted_tokens_per_pos = [0] * self.num_spec_tokens

    @classmethod
    def new(cls, num_spec_tokens: int) -> "SpecDecodingStats":
        """Return zeroed stats with one per-position slot per spec token."""
        return cls(num_spec_tokens=num_spec_tokens,
                   num_accepted_tokens_per_pos=[0] * num_spec_tokens)

    def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
        """Record a single draft.

        Args:
            num_draft_tokens: Number of tokens proposed in this draft.
            num_accepted_tokens: Number of those tokens that were accepted.

        Raises:
            AssertionError: If ``num_accepted_tokens`` exceeds
                ``num_spec_tokens``. The check runs before any counter is
                updated, so a rejected draft leaves the stats unchanged.
        """
        assert num_accepted_tokens <= self.num_spec_tokens
        self.num_drafts += 1
        self.num_draft_tokens += num_draft_tokens
        self.num_accepted_tokens += num_accepted_tokens
        # Every position up to the accepted count was accepted in this draft.
        for i in range(num_accepted_tokens):
            self.num_accepted_tokens_per_pos[i] += 1

45

46
47
class SpecDecodingLogging:
    """Aggregate and log spec decoding metrics.

    This class aggregates per-iteration metrics over a time interval
    using observe() and then logs them using log(), which also resets
    the aggregated state to zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all aggregated stats and restart the timing window."""
        self.num_drafts: list[int] = []
        self.num_draft_tokens: list[int] = []
        self.num_accepted_tokens: list[int] = []
        self.accepted_tokens_per_pos_lists: list[list[int]] = []
        self.last_log_time = time.monotonic()

    def observe(self, spec_decoding_stats: SpecDecodingStats):
        """Append one step's stats; they are summed when log() runs."""
        self.num_drafts.append(spec_decoding_stats.num_drafts)
        self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
        self.num_accepted_tokens.append(
            spec_decoding_stats.num_accepted_tokens)
        self.accepted_tokens_per_pos_lists.append(
            spec_decoding_stats.num_accepted_tokens_per_pos)

    def log(self, log_fn=logger.info):
        """Log metrics aggregated since the last reset, then reset.

        Does nothing (and does not reset) if observe() has not been called
        since the last reset.

        Args:
            log_fn: Callable taking a %-style format string followed by its
                arguments; defaults to this module's ``logger.info``.
        """
        if not self.num_drafts:
            return
        num_drafts = np.sum(self.num_drafts)
        num_draft_tokens = np.sum(self.num_draft_tokens)
        num_accepted_tokens = np.sum(self.num_accepted_tokens)

        # Throughputs cover the wall-clock window since the last reset();
        # they stay 0 if no time has elapsed.
        draft_throughput = 0
        accepted_throughput = 0
        elapsed_time = time.monotonic() - self.last_log_time
        if elapsed_time > 0:
            draft_throughput = num_draft_tokens / elapsed_time
            accepted_throughput = num_accepted_tokens / elapsed_time

        draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens *
                                 100 if num_draft_tokens > 0 else float("nan"))

        # One row per observed step, one column per draft position.
        pos_matrix = np.array(self.accepted_tokens_per_pos_lists)
        accepted_per_pos = np.sum(pos_matrix, axis=0)
        if num_drafts > 0:
            # Conventionally, mean acceptance length includes the bonus token
            mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts)
            acceptance_rates = accepted_per_pos / num_drafts
        else:
            # Steps were observed but none contained a draft: report NaN
            # explicitly instead of dividing by a numpy zero (which warns).
            mean_acceptance_length = float("nan")
            acceptance_rates = np.full(accepted_per_pos.shape, np.nan)
        rates_str = ", ".join(f"{p:.3f}" for p in acceptance_rates)

        log_fn(
            "SpecDecoding metrics: "
            "Mean acceptance length: %.2f, "
            "Accepted throughput: %.2f tokens/s, "
            "Drafted throughput: %.2f tokens/s, "
            "Accepted: %d tokens, "
            "Drafted: %d tokens, "
            "Per-position acceptance rate: %s, "
            "Avg Draft acceptance rate: %.1f%%",
            mean_acceptance_length,
            accepted_throughput,
            draft_throughput,
            num_accepted_tokens,
            num_draft_tokens,
            rates_str,
            draft_acceptance_rate,
        )
        self.reset()
114
115
116
117
118
119
120
121
122
123


class SpecDecodingProm:
    """Record spec decoding metrics in Prometheus.

    prometheus_client exposes every Counter with a ``_total`` suffix, so
    e.g. ``vllm:spec_decode_num_drafts`` is scraped as
    ``vllm:spec_decode_num_drafts_total``.

    The acceptance rate can be calculated using a PromQL query:

      rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
      rate(vllm:spec_decode_num_draft_tokens_total[$interval])

    The mean acceptance length (conventionally including bonus tokens)
    can be calculated using:

      1 + (
      rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
      rate(vllm:spec_decode_num_drafts_total[$interval]))

    A per-position acceptance rate vector can be computed using (the
    per-position counter carries an extra ``position`` label, hence the
    many-to-one vector matching):

      rate(vllm:spec_decode_num_accepted_tokens_per_pos_total[$interval])
      / ignoring(position) group_left
      rate(vllm:spec_decode_num_drafts_total[$interval])
    """

    _counter_cls = prometheus_client.Counter

    def __init__(
        self,
        speculative_config: Optional[SpeculativeConfig],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[str]],
    ):
        """Create the spec decoding counters, bound per engine.

        Args:
            speculative_config: Spec decoding config; ``None`` disables
                this recorder (no counters are created).
            labelnames: Label names shared by all counters.
            per_engine_labelvalues: Engine index -> label values, in the
                same order as ``labelnames``.
        """
        self.spec_decoding_enabled = speculative_config is not None
        # Test the config itself (not the flag) so type checkers narrow it
        # to SpeculativeConfig for the rest of the method.
        if speculative_config is None:
            return

        counter_drafts = self._counter_cls(
            name="vllm:spec_decode_num_drafts",
            documentation="Number of spec decoding drafts.",
            labelnames=labelnames)
        self.counter_spec_decode_num_drafts = make_per_engine(
            counter_drafts, per_engine_labelvalues)

        counter_draft_tokens = self._counter_cls(
            name="vllm:spec_decode_num_draft_tokens",
            documentation="Number of draft tokens.",
            labelnames=labelnames)
        self.counter_spec_decode_num_draft_tokens = make_per_engine(
            counter_draft_tokens, per_engine_labelvalues)

        counter_accepted_tokens = self._counter_cls(
            name="vllm:spec_decode_num_accepted_tokens",
            documentation="Number of accepted tokens.",
            labelnames=labelnames)
        self.counter_spec_decode_num_accepted_tokens = make_per_engine(
            counter_accepted_tokens, per_engine_labelvalues)

        # One child counter per (engine, draft position).
        num_spec_tokens = speculative_config.num_speculative_tokens
        pos_labelnames = labelnames + ["position"]
        base_counter = self._counter_cls(
            name="vllm:spec_decode_num_accepted_tokens_per_pos",
            documentation="Accepted tokens per draft position.",
            labelnames=pos_labelnames,
        )
        self.counter_spec_decode_num_accepted_tokens_per_pos: dict[
            int, list[prometheus_client.Counter]] = {
                idx: [
                    base_counter.labels(*lv, str(pos))
                    for pos in range(num_spec_tokens)
                ]
                for idx, lv in per_engine_labelvalues.items()
            }

    def observe(self,
                spec_decoding_stats: SpecDecodingStats,
                engine_idx: int = 0):
        """Add one step's stats to the counters of engine ``engine_idx``.

        No-op when spec decoding is disabled.
        """
        if not self.spec_decoding_enabled:
            return
        self.counter_spec_decode_num_drafts[engine_idx].inc(
            spec_decoding_stats.num_drafts)
        self.counter_spec_decode_num_draft_tokens[engine_idx].inc(
            spec_decoding_stats.num_draft_tokens)
        self.counter_spec_decode_num_accepted_tokens[engine_idx].inc(
            spec_decoding_stats.num_accepted_tokens)
        per_pos_counters = (
            self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx])
        for pos, counter in enumerate(per_pos_counters):
            counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
203
204
205
206
207
208
209
210
211


def make_per_engine(counter: prometheus_client.Counter,
                    per_engine_labelvalues: dict[int, list[str]]):
    """Bind ``counter`` to each engine's label values.

    Args:
        counter: The parent (labelled) Prometheus counter.
        per_engine_labelvalues: Engine index -> label values, passed
            positionally to ``counter.labels``.

    Returns:
        A dict mapping each engine index to its labelled child counter.
    """
    children = {}
    for engine_idx, values in per_engine_labelvalues.items():
        children[engine_idx] = counter.labels(*values)
    return children