layernorm_rms_benchmarks.py 5.17 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import pickle as pkl
import time
6
from collections.abc import Callable, Iterable
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from dataclasses import dataclass
from itertools import product

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm

import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm


@dataclass
class bench_params_t:
    """A single benchmark configuration (one point of the parameter sweep)."""

    # Number of rows of the benchmarked input tensor.
    num_tokens: int
    # Number of columns of the input tensor; also the RMSNorm layer size.
    hidden_size: int
    # When True, a residual tensor is passed alongside the input.
    add_residual: bool
    # dtype of the input tensor and of the RMSNorm layer.
    dtype: torch.dtype

    def description(self):
        """Return a one-line summary of this configuration.

        The result is used as the ``sub_label`` of the benchmark timers.
        """
        parts = (
            f"N {self.num_tokens}",
            f"D {self.hidden_size}",
            f"R {self.add_residual}",
            f"DT {self.dtype}",
        )
        return " x ".join(parts)
33
34


35
def get_bench_params() -> list[bench_params_t]:
    """Build every benchmark configuration to run.

    Returns:
        The cartesian product of token counts, hidden sizes, residual flags
        and dtypes, as a list of ``bench_params_t``. Token count varies
        slowest and dtype varies fastest.
    """
    # Powers of two: 1, 2, 4, ..., 1024.
    token_counts = [2**exponent for exponent in range(11)]
    # NOTE(review): the stop value 8129 excludes 8192, so the sweep ends at
    # 7168 -- confirm whether 8193 was intended.
    hidden_sizes = list(range(1024, 8129, 1024))
    residual_flags = [True, False]
    dtypes = [torch.bfloat16, torch.float]

    return [
        bench_params_t(num_tokens, hidden_size, add_residual, dtype)
        for num_tokens, hidden_size, add_residual, dtype in product(
            token_counts, hidden_sizes, residual_flags, dtypes
        )
    ]


# Reference impls
50
51
52
def unfused_int8_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
):
    """Reference path: RMSNorm followed by a separate int8 quantization call.

    Args:
        rms_norm_layer: Layer whose ``forward_cuda`` performs the norm.
        x: Input tensor.
        residual: Optional residual tensor forwarded to the layer.
        quant_dtype: Not read here; present so every benchmarked
            implementation can be called with the same four arguments.

    Nothing is returned; the function exists only to be timed.
    """
    # Norm
    normed = rms_norm_layer.forward_cuda(x, residual)
    if residual is not None:
        # With a residual the result is unpacked as a pair; only the first
        # element is carried forward.
        normed, _ = normed

    # Quant (the result is unpacked as a 3-tuple and then discarded)
    _quantized, _, _ = ops.scaled_int8_quant(normed)


67
68
69
def unfused_fp8_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
):
    """Reference path: RMSNorm followed by a separate fp8 quantization call.

    Args:
        rms_norm_layer: Layer whose ``forward_cuda`` performs the norm.
        x: Input tensor.
        residual: Optional residual tensor forwarded to the layer.
        quant_dtype: Not read here; present so every benchmarked
            implementation can be called with the same four arguments.

    Nothing is returned; the function exists only to be timed.
    """
    # Norm
    normed = rms_norm_layer.forward_cuda(x, residual)
    if residual is not None:
        # With a residual the result is unpacked as a pair; only the first
        # element is carried forward.
        normed, _ = normed

    # Quant (the result is unpacked as a pair and then discarded)
    _quantized, _ = ops.scaled_fp8_quant(normed)


def fused_impl(
    rms_norm_layer: RMSNorm,  # this stores the weights
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
):
    """Run norm + quantization through ``ops.rms_norm_dynamic_per_token_quant``.

    Args:
        rms_norm_layer: Only its ``weight`` attribute is read.
        x: Input tensor.
        residual: Optional residual tensor, passed by keyword to the op.
        quant_dtype: Target dtype forwarded to the op.

    Nothing is returned; the function exists only to be timed.
    """
    # The third positional argument is the literal 1e-6 (the same value
    # `bench` gives to the RMSNorm constructor) -- presumably the epsilon.
    op_args = (x, rms_norm_layer.weight, 1e-6, quant_dtype)
    _quantized, _ = ops.rms_norm_dynamic_per_token_quant(*op_args, residual=residual)
93
94
95


# Bench functions
96
97
98
99
100
101
102
103
104
105
def bench_fn(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    label: str,
    sub_label: str,
    fn: Callable,
    description: str,
) -> TMeasurement:
    """Time ``fn(rms_norm_layer, x, residual, quant_dtype)``.

    Args:
        rms_norm_layer: First argument handed to ``fn``.
        x: Second argument handed to ``fn``.
        residual: Third argument handed to ``fn``; may be ``None`` (``bench``
            passes ``None`` when no residual is requested).
        quant_dtype: Fourth argument handed to ``fn``.
        label: Forwarded to ``TBenchmark.Timer`` as ``label``.
        sub_label: Forwarded to ``TBenchmark.Timer`` as ``sub_label``.
        fn: The callable under test.
        description: Forwarded to ``TBenchmark.Timer`` as ``description``.

    Returns:
        The measurement produced by ``Timer.blocked_autorange``.
    """
    min_run_time = 1

    # Names visible to the timed statement. Kept in a local that does not
    # shadow the builtin `globals()`.
    timer_globals = {
        "rms_norm_layer": rms_norm_layer,
        "x": x,
        "residual": residual,
        "quant_dtype": quant_dtype,
        "fn": fn,
    }
    timer = TBenchmark.Timer(
        stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
        globals=timer_globals,
        label=label,
        sub_label=sub_label,
        description=description,
    )
    return timer.blocked_autorange(min_run_time=min_run_time)


124
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
    """Benchmark the unfused and fused norm+quant paths for one configuration.

    Args:
        params: Problem size, residual flag and dtype to benchmark.
        label: Label attached to every timer created here.
        sub_label: Sub-label attached to every timer created here.

    Returns:
        The four measurements, in the order: unfused int8, unfused fp8,
        fused int8, fused fp8. They are also printed before returning.
    """
    # Layer under test, cast to the requested dtype.
    norm_layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
    # Randomize the weights in place.
    norm_layer.weight.data.normal_(mean=1.0, std=0.1)

    # Inputs: random values multiplied by 1 / hidden_size.
    scale = 1 / params.hidden_size
    raw_input = torch.randn(
        params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda"
    )
    x = raw_input * scale

    residual = None
    if params.add_residual:
        residual = (torch.randn_like(x) * scale).to(device="cuda")

    # (quant dtype, implementation, timer description), in benchmark order.
    variants = (
        (torch.int8, unfused_int8_impl, "unfused_int8_impl"),
        (torch.float8_e4m3fn, unfused_fp8_impl, "unfused_fp8_impl"),
        (torch.int8, fused_impl, "fused_int8_impl"),
        (torch.float8_e4m3fn, fused_impl, "fused_fp8_impl"),
    )
    timers = [
        bench_fn(norm_layer, x, residual, quant_dtype, label, sub_label, impl, name)
        for quant_dtype, impl, name in variants
    ]

    print_timers(timers)

    return timers


# launch bench
# runner
def print_timers(timers: Iterable[TMeasurement]):
    """Wrap the measurements in ``TBenchmark.Compare`` and print the result."""
    TBenchmark.Compare(timers).print()


def main():
    """Run the whole sweep, print a combined comparison, and pickle the timers.

    The results are written to ``rms_norm_dpt_quant-<unix timestamp>.pkl`` in
    the current working directory.
    """
    torch.set_default_device("cuda")

    label = "rms-norm-dynamic-per-token-quant"
    all_timers = []
    for config in tqdm(get_bench_params()):
        all_timers += bench(config, label, config.description())

    print_timers(all_timers)

    # pickle all the results
    timestamp = int(time.time())
    out_path = f"rms_norm_dpt_quant-{timestamp}.pkl"
    with open(out_path, "wb") as f:
        pkl.dump(all_timers, f)


226
# Run the benchmark sweep only when this file is executed as a script.
if __name__ == "__main__":
    main()