"tests/vscode:/vscode.git/clone" did not exist on "2669a0d7b518371bb1d950425bd64a320010733f"
layernorm_rms_benchmarks.py 7.18 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import pickle as pkl
import time
6
from collections.abc import Callable, Iterable
7
8
9
10
11
12
13
14
15
16
from dataclasses import dataclass
from itertools import product

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm

import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm
17
18
19
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8,
)
20
21
22
23
24
25
26
27


@dataclass
class bench_params_t:
    """One benchmark configuration for the RMSNorm + quantization kernels."""

    # Number of rows of the randomly generated input tensor.
    num_tokens: int
    # Number of columns of the input tensor; also the RMSNorm layer size.
    hidden_size: int
    # When True a residual tensor is generated and passed to every impl.
    add_residual: bool
    # Element type used for the layer weights and the input tensors.
    dtype: torch.dtype
    # Quantization group shape; the unfused group-wise path reads index 1.
    group_size: list[int]

    def description(self) -> str:
        """Return a one-line, human-readable summary of this configuration.

        Every field is separated by a single space so the result reads
        ``N <tokens> x D <hidden> x R <bool> x DT <dtype> x GS <group>``.
        """
        return (
            f"N {self.num_tokens} "
            f"x D {self.hidden_size} "
            f"x R {self.add_residual} "
            # Trailing space keeps the dtype from running into "x GS".
            f"x DT {self.dtype} "
            f"x GS {self.group_size}"
        )
38
39


40
def get_bench_params() -> list[bench_params_t]:
    """Build every benchmark configuration to run.

    Returns:
        One ``bench_params_t`` per element of the cross product of the token
        counts, hidden sizes, residual flags, dtypes and group sizes below,
        in ``itertools.product`` order.
    """
    ## Test Fixtures
    NUM_TOKENS = [2**x for x in range(11)]  # 1, 2, 4, ..., 1024
    # NOTE(review): with stop=8129 this yields 1024..7168 and excludes 8192;
    # if 8192 was meant to be covered the stop is likely a typo - confirm.
    HIDDEN_SIZES = list(range(1024, 8129, 1024))
    ADD_RESIDUAL = [True, False]
    DTYPES = [torch.bfloat16, torch.float]
    GROUP_SIZES = [[1, 64], [1, 128]]

    combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES)
    # Each combination is a 5-tuple in the same order as the dataclass fields.
    return [bench_params_t(*combo) for combo in combinations]


# Reference impls
56
57
58
def unfused_int8_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
):
    """Reference path: RMSNorm followed by a separate int8 quantization call.

    ``quant_dtype`` and ``group_size`` are not read here; they are accepted so
    that every benchmarked implementation can be invoked with the same
    argument list. Nothing is returned - the function exists only to be timed.
    """
    # Norm
    if residual is None:
        torch_out = rms_norm_layer.forward_cuda(x, residual)
    else:
        # With a residual the result is unpacked as a pair; only the first
        # element is carried forward to quantization.
        torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

    # Quant
    torch_out, _, _ = ops.scaled_int8_quant(torch_out)


74
75
76
def unfused_fp8_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
):
    """Reference path: RMSNorm followed by a separate fp8 quantization call.

    ``quant_dtype`` and ``group_size`` are not read here; they are accepted so
    that every benchmarked implementation can be invoked with the same
    argument list. Nothing is returned - the function exists only to be timed.
    """
    # Norm
    if residual is None:
        torch_out = rms_norm_layer.forward_cuda(x, residual)
    else:
        # With a residual the result is unpacked as a pair; only the first
        # element is carried forward to quantization.
        torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

    # Quant
    torch_out, _ = ops.scaled_fp8_quant(torch_out)


92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
def unfused_groupwise_fp8_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
):
    """Reference path: RMSNorm, then a separate per-token-group fp8 quant.

    Only ``group_size[1]`` is forwarded to the quantization helper;
    ``quant_dtype`` is unused. Nothing is returned - the call is only timed.
    """
    # Normalization step.
    normed = rms_norm_layer.forward_cuda(x, residual)
    if residual is not None:
        # The residual variant's result is unpacked as a pair; keep the first.
        normed, _ = normed

    # Quantization step.
    quantized, _ = per_token_group_quant_fp8(
        normed, group_size=group_size[1], use_ue8m0=False
    )


112
def fused_impl(
    rms_norm_layer: RMSNorm,  # this stores the weights
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
):
    """Fused path: a single ``rms_norm_dynamic_per_token_quant`` op call.

    Only the layer's ``weight`` is taken from ``rms_norm_layer``; the epsilon
    is passed as the literal ``1e-6``. ``group_size`` is not read and exists
    only to keep the signature uniform across implementations. Nothing is
    returned - the function exists only to be timed.
    """
    out, _ = ops.rms_norm_dynamic_per_token_quant(
        x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
    )
122
123


124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
def fused_groupwise_impl(
    rms_norm_layer: RMSNorm,  # this stores the weights
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
):
    """Fused path: a single ``rms_norm_per_block_quant`` op call.

    The layer contributes only its ``weight``; epsilon is the literal 1e-6 and
    the scale layout flag is always ``is_scale_transposed=True``. Nothing is
    returned - the function exists only to be timed.
    """
    weight = rms_norm_layer.weight
    eps = 1e-6
    quantized, _scales = ops.rms_norm_per_block_quant(
        x, weight, eps, quant_dtype, group_size,
        residual=residual,
        is_scale_transposed=True,
    )


142
# Bench functions
143
144
145
146
147
def bench_fn(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
    label: str,
    sub_label: str,
    fn: Callable,
    description: str,
    min_run_time: float = 1,
) -> TMeasurement:
    """Time ``fn(rms_norm_layer, x, residual, quant_dtype, group_size)``.

    Args:
        rms_norm_layer: Layer object handed to ``fn`` as its first argument.
        x: Input tensor handed to ``fn``.
        residual: Residual tensor handed to ``fn``, or ``None``.
        quant_dtype: Target quantization dtype handed to ``fn``.
        group_size: Quantization group shape handed to ``fn``.
        label: Label recorded on the measurement.
        sub_label: Sub-label recorded on the measurement.
        fn: The implementation to time; called with the five values above.
        description: Description recorded on the measurement.
        min_run_time: Forwarded to ``Timer.blocked_autorange``; defaults to
            the previously hard-coded value of 1.

    Returns:
        The measurement produced by ``blocked_autorange``.
    """
    # The statement string below resolves these names through this mapping,
    # so the keys must match the identifiers used in ``stmt`` exactly.
    timer_globals = {
        "rms_norm_layer": rms_norm_layer,
        "x": x,
        "residual": residual,
        "quant_dtype": quant_dtype,
        "group_size": group_size,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)",
        globals=timer_globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


173
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
    """Time every implementation variant for one benchmark configuration.

    Builds an RMSNorm layer and random inputs from ``params``, runs each
    variant through ``bench_fn``, prints a comparison of the results and
    returns them.

    Args:
        params: Configuration supplying sizes, dtype, residual flag and
            group size.
        label: Label attached to every measurement.
        sub_label: Sub-label attached to every measurement.

    Returns:
        The list of measurements, in the order the variants were run.
    """
    # Make layer
    layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
    # Make weights
    layer.weight.data.normal_(mean=1.0, std=0.1)
    # Make inputs
    scale = 1 / params.hidden_size
    x = (
        torch.randn(
            params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda"
        )
        * scale
    )
    residual = (
        (torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None
    )

    # (quant dtype, implementation, description) for each timed variant, in
    # the order they are run and reported.
    variants = [
        (torch.int8, unfused_int8_impl, "unfused_int8_impl"),
        (torch.float8_e4m3fn, unfused_fp8_impl, "unfused_fp8_impl"),
        (torch.int8, fused_impl, "fused_int8_impl"),
        (torch.float8_e4m3fn, fused_impl, "fused_fp8_impl"),
        (
            torch.float8_e4m3fn,
            unfused_groupwise_fp8_impl,
            "unfused_groupwise_fp8_impl",
        ),
        (torch.float8_e4m3fn, fused_groupwise_impl, "fused_groupwise_fp8_impl"),
    ]

    timers = [
        bench_fn(
            layer,
            x,
            residual,
            quant_dtype,
            params.group_size,
            label,
            sub_label,
            fn,
            description,
        )
        for quant_dtype, fn, description in variants
    ]

    print_timers(timers)

    return timers


# Result reporting and benchmark runner
def print_timers(timers: Iterable[TMeasurement]):
    """Print a comparison table built from the given measurements."""
    TBenchmark.Compare(timers).print()


def main():
    """Run the whole benchmark sweep, print the results and pickle them.

    Sets the default torch device to ``"cuda"``, times every configuration
    from ``get_bench_params()``, prints a combined comparison and writes the
    collected measurements to ``rms_norm_dpt_quant-<unix timestamp>.pkl`` in
    the current working directory.
    """
    torch.set_default_device("cuda")
    bench_params = get_bench_params()

    timers = []
    for bp in tqdm(bench_params):
        timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
    print_timers(timers)

    # pickle all the results
    timestamp = int(time.time())
    with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f:
        pkl.dump(timers, f)


309
# Run the benchmark sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()