"examples/backends/trtllm/mm_router_worker/README.md" did not exist on "8a098a66504647dbe0c4ab085c7c50794767e585"
metrics.py 7.67 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from dataclasses import dataclass, field
6
7

import numpy as np
8
import prometheus_client
9

10
from vllm.config import SpeculativeConfig
11
12
13
14
15
16
17
from vllm.logger import init_logger

# Module-level logger for this file; its ``info`` method is the default
# ``log_fn`` of SpecDecodingLogging.log().
logger = init_logger(__name__)


@dataclass
class SpecDecodingStats:
    """Per-step iteration decoding stats from scheduler.

    Each scheduler step, statistics on spec decoding performance are
    aggregated across requests by the scheduler and returned to the
    frontend in EngineCoreOutputs->SchedulerStats.
    """

    # Upper bound on accepted tokens per draft; also the length of
    # ``num_accepted_tokens_per_pos`` when built through ``new()``.
    num_spec_tokens: int
    # Number of drafts observed so far (one per ``observe_draft`` call).
    num_drafts: int = 0
    # Running total of drafted tokens across all observed drafts.
    num_draft_tokens: int = 0
    # Running total of accepted tokens across all observed drafts.
    num_accepted_tokens: int = 0
    # Entry ``i`` counts the drafts in which position ``i`` was accepted.
    num_accepted_tokens_per_pos: list[int] = field(default_factory=list)

    @classmethod
    def new(cls, num_spec_tokens: int) -> "SpecDecodingStats":
        """Build zeroed stats with one per-position slot per spec token.

        Prefer this over the bare constructor: the dataclass default for
        ``num_accepted_tokens_per_pos`` is an empty list, which cannot be
        indexed by ``observe_draft``.
        """
        return cls(
            num_spec_tokens=num_spec_tokens,
            num_accepted_tokens_per_pos=[0] * num_spec_tokens,
        )

    def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int) -> None:
        """Record the outcome of a single draft.

        Args:
            num_draft_tokens: Number of tokens proposed in this draft.
            num_accepted_tokens: Number of those tokens that were accepted;
                must not exceed ``num_spec_tokens``.

        Raises:
            AssertionError: If ``num_accepted_tokens`` exceeds
                ``num_spec_tokens``.
        """
        # Check the invariant before touching any counter so that a rejected
        # observation cannot leave the aggregates partially updated.
        assert num_accepted_tokens <= self.num_spec_tokens

        self.num_drafts += 1
        self.num_draft_tokens += num_draft_tokens
        self.num_accepted_tokens += num_accepted_tokens
        # The first ``num_accepted_tokens`` positions each get credit for
        # this draft.
        for pos in range(num_accepted_tokens):
            self.num_accepted_tokens_per_pos[pos] += 1

46

47
48
class SpecDecodingLogging:
    """Aggregate and log spec decoding metrics.

    LoggingStatLogger aggregates per-iteration metrics over a set
    time interval using observe() and then logs them using log()
    before resetting to zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop all buffered observations and restart the interval clock."""
        # One entry per observe() call in each list; they are summed in log().
        self.num_drafts: list[int] = []
        self.num_draft_tokens: list[int] = []
        self.num_accepted_tokens: list[int] = []
        self.accepted_tokens_per_pos_lists: list[list[int]] = []
        # Start of the interval used for the throughput figures.
        self.last_log_time = time.monotonic()

    def observe(self, spec_decoding_stats: SpecDecodingStats):
        """Buffer one set of per-step stats until the next log() call."""
        self.num_drafts.append(spec_decoding_stats.num_drafts)
        self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
        self.num_accepted_tokens.append(spec_decoding_stats.num_accepted_tokens)
        self.accepted_tokens_per_pos_lists.append(
            spec_decoding_stats.num_accepted_tokens_per_pos
        )

    def log(self, log_fn=logger.info):
        """Emit one summary line for the buffered observations, then reset.

        Does nothing (and does not reset the clock) when nothing has been
        observed since the last reset.

        Args:
            log_fn: Logging-style callable taking a %-format string followed
                by its arguments.
        """
        if not self.num_drafts:
            return

        num_drafts = np.sum(self.num_drafts)
        num_draft_tokens = np.sum(self.num_draft_tokens)
        num_accepted_tokens = np.sum(self.num_accepted_tokens)

        draft_throughput = 0
        accepted_throughput = 0

        elapsed_time = time.monotonic() - self.last_log_time
        if elapsed_time > 0:
            draft_throughput = num_draft_tokens / elapsed_time
            accepted_throughput = num_accepted_tokens / elapsed_time

        draft_acceptance_rate = (
            num_accepted_tokens / num_draft_tokens * 100
            if num_draft_tokens > 0
            else float("nan")
        )

        # Column-wise sum: total accepted tokens at each draft position.
        pos_matrix = np.array(self.accepted_tokens_per_pos_lists)
        accepted_per_pos = np.sum(pos_matrix, axis=0)

        # Observations may all report zero drafts; report NaN explicitly in
        # that case (as is done for the acceptance rate above) instead of
        # dividing by zero and triggering numpy runtime warnings.
        if num_drafts > 0:
            # Conventionally, mean acceptance length includes the bonus token
            mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts)
            acceptance_rates = accepted_per_pos / num_drafts
        else:
            mean_acceptance_length = float("nan")
            acceptance_rates = np.full(accepted_per_pos.shape, np.nan)
        rates_str = ", ".join(f"{p:.3f}" for p in acceptance_rates)

        log_fn(
            "SpecDecoding metrics: "
            "Mean acceptance length: %.2f, "
            "Accepted throughput: %.2f tokens/s, "
            "Drafted throughput: %.2f tokens/s, "
            "Accepted: %d tokens, "
            "Drafted: %d tokens, "
            "Per-position acceptance rate: %s, "
            "Avg Draft acceptance rate: %.1f%%",
            mean_acceptance_length,
            accepted_throughput,
            draft_throughput,
            num_accepted_tokens,
            num_draft_tokens,
            rates_str,
            draft_acceptance_rate,
        )
        self.reset()
118
119
120
121
122
123
124
125
126
127


class SpecDecodingProm:
    """Record spec decoding metrics in Prometheus.

    The series names below carry the ``_total`` suffix that
    prometheus_client appends to counters when exposing them.

    The acceptance rate can be calculated using a PromQL query:

      rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
      rate(vllm:spec_decode_num_draft_tokens_total[$interval])

    The mean acceptance length (conventionally including bonus tokens)
    can be calculated using:

      1 + (
      rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
      rate(vllm:spec_decode_num_drafts_total[$interval]))

    A per-position acceptance rate vector can be computed using

      rate(vllm:spec_decode_num_accepted_tokens_per_pos_total[$interval]) /
      ignoring(position) group_left
      rate(vllm:spec_decode_num_drafts_total[$interval])
    """

    # Counter factory; kept as a class attribute so it can be overridden.
    _counter_cls = prometheus_client.Counter

    def __init__(
        self,
        speculative_config: SpeculativeConfig | None,
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[str]],
    ):
        """Create the spec decoding counters for every engine.

        Args:
            speculative_config: Speculative decoding configuration, or None
                when spec decoding is disabled. When None, no counters are
                created and observe() is a no-op.
            labelnames: Label names shared by all counters.
            per_engine_labelvalues: Label values for each engine, keyed by
                the index later passed to observe() as ``engine_idx``.
        """
        self.spec_decoding_enabled = speculative_config is not None
        # Testing the argument itself (rather than the flag) lets type
        # checkers narrow ``speculative_config`` for the code below.
        if speculative_config is None:
            return

        counter_drafts = self._counter_cls(
            name="vllm:spec_decode_num_drafts",
            documentation="Number of spec decoding drafts.",
            labelnames=labelnames,
        )
        self.counter_spec_decode_num_drafts = make_per_engine(
            counter_drafts, per_engine_labelvalues
        )

        counter_draft_tokens = self._counter_cls(
            name="vllm:spec_decode_num_draft_tokens",
            documentation="Number of draft tokens.",
            labelnames=labelnames,
        )
        self.counter_spec_decode_num_draft_tokens = make_per_engine(
            counter_draft_tokens, per_engine_labelvalues
        )

        counter_accepted_tokens = self._counter_cls(
            name="vllm:spec_decode_num_accepted_tokens",
            documentation="Number of accepted tokens.",
            labelnames=labelnames,
        )
        self.counter_spec_decode_num_accepted_tokens = make_per_engine(
            counter_accepted_tokens, per_engine_labelvalues
        )

        num_spec_tokens = speculative_config.num_speculative_tokens
        # The per-position counter carries one extra label on top of the
        # shared ones, holding the draft position as a string.
        pos_labelnames = labelnames + ["position"]
        base_counter = self._counter_cls(
            name="vllm:spec_decode_num_accepted_tokens_per_pos",
            documentation="Accepted tokens per draft position.",
            labelnames=pos_labelnames,
        )
        # For each engine: a list of labelled counters indexed by position.
        self.counter_spec_decode_num_accepted_tokens_per_pos: dict[
            int, list[prometheus_client.Counter]
        ] = {
            idx: [base_counter.labels(*lv, str(pos)) for pos in range(num_spec_tokens)]
            for idx, lv in per_engine_labelvalues.items()
        }

    def observe(self, spec_decoding_stats: SpecDecodingStats, engine_idx: int = 0):
        """Add one set of stats to the counters of engine ``engine_idx``.

        Does nothing when spec decoding is disabled.
        """
        if not self.spec_decoding_enabled:
            return
        self.counter_spec_decode_num_drafts[engine_idx].inc(
            spec_decoding_stats.num_drafts
        )
        self.counter_spec_decode_num_draft_tokens[engine_idx].inc(
            spec_decoding_stats.num_draft_tokens
        )
        self.counter_spec_decode_num_accepted_tokens[engine_idx].inc(
            spec_decoding_stats.num_accepted_tokens
        )
        for pos, counter in enumerate(
            self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]
        ):
            counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
215
216


217
218
219
def make_per_engine(
    counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[str]]
) -> dict[int, prometheus_client.Counter]:
    """Create a counter for each label value.

    Args:
        counter: Parent counter to bind label values on.
        per_engine_labelvalues: Label values to bind, keyed by index.

    Returns:
        A dict with the same keys as ``per_engine_labelvalues``, each mapped
        to ``counter.labels(*labelvalues)`` for that key's label values.
    """
    return {
        idx: counter.labels(*labelvalues)
        for idx, labelvalues in per_engine_labelvalues.items()
    }