# "vllm/vscode:/vscode.git/clone" did not exist on "ffb08379d8870a1a81ba82b72797f196838d0c86"
# layernorm_rms_benchmarks.py 5.2 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import pickle as pkl
import time
6
from collections.abc import Iterable
7
8
from dataclasses import dataclass
from itertools import product
9
from typing import Callable, Optional
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm

import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm


@dataclass
class bench_params_t:
    """One configuration in the benchmark sweep.

    Attributes are documented with comments below; ``description`` renders
    them as the sub-label shown in the comparison tables.
    """

    # Number of rows of the input tensor built by ``bench``.
    num_tokens: int
    # Row width; also the hidden size the RMSNorm layer is built with.
    hidden_size: int
    # Whether a residual tensor is created and passed to every impl.
    add_residual: bool
    # dtype of the input tensor and of the RMSNorm layer.
    dtype: torch.dtype

    def description(self) -> str:
        """Return a compact, human-readable label for this configuration."""
        return (
            f"N {self.num_tokens} "
            f"x D {self.hidden_size} "
            f"x R {self.add_residual} "
            f"x DT {self.dtype}"
        )


36
def get_bench_params() -> list[bench_params_t]:
    """Build every benchmark configuration in the sweep.

    Returns:
        The Cartesian product of token counts, hidden sizes, residual on/off
        and dtypes, in ``itertools.product`` order (token count varies
        slowest, dtype fastest).
    """
    ## Test Fixtures
    NUM_TOKENS = [2**x for x in range(11)]  # 1, 2, 4, ..., 1024
    # The end bound is exclusive: this yields 1024, 2048, ..., 7168.
    HIDDEN_SIZES = list(range(1024, 8192, 1024))
    ADD_RESIDUAL = [True, False]
    DTYPES = [torch.bfloat16, torch.float]

    combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
    return [bench_params_t(*combo) for combo in combinations]


# Reference impls
def unfused_int8_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: Optional[torch.Tensor],
    quant_dtype: torch.dtype,
):
    """Baseline: RMSNorm followed by a separate int8 quantization op.

    ``quant_dtype`` is unused here; it is accepted only so every impl shares
    the ``fn(rms_norm_layer, x, residual, quant_dtype)`` call signature that
    ``bench_fn`` times.
    """
    # Norm. With a residual, forward_cuda returns a tuple whose first element
    # is the normalized output.
    if residual is None:
        torch_out = rms_norm_layer.forward_cuda(x, residual)
    else:
        torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

    # Quant. Called without an explicit scale — presumably dynamic scaling
    # computed from the data; confirm against ops.scaled_int8_quant.
    torch_out, _, _ = ops.scaled_int8_quant(torch_out)


68
69
70
71
72
73
def unfused_fp8_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: Optional[torch.Tensor],
    quant_dtype: torch.dtype,
):
    """Baseline: RMSNorm followed by a separate dynamic fp8 quantization op.

    ``quant_dtype`` is unused here; it is accepted only so every impl shares
    the ``fn(rms_norm_layer, x, residual, quant_dtype)`` call signature that
    ``bench_fn`` times.
    """
    # Norm. With a residual, forward_cuda returns a tuple whose first element
    # is the normalized output.
    if residual is None:
        torch_out = rms_norm_layer.forward_cuda(x, residual)
    else:
        torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

    # Quant. Ask for per-token scales so this baseline does the same work as
    # the fused "dynamic per-token" kernel it is compared against; per vLLM's
    # scaled_fp8_quant API, a scale-less call otherwise computes a single
    # per-tensor scale.
    torch_out, _ = ops.scaled_fp8_quant(torch_out, use_per_token_if_dynamic=True)


def fused_impl(
    rms_norm_layer: RMSNorm,  # this stores the weights
    x: torch.Tensor,
    residual: Optional[torch.Tensor],
    quant_dtype: torch.dtype,
):
    """Fused RMSNorm + dynamic per-token quantization in a single op.

    The norm weight and epsilon both come from ``rms_norm_layer``. Reading
    the layer's epsilon, instead of hard-coding 1e-6, makes the fused path
    follow whatever eps the layer was built with, the same layer the unfused
    baselines normalize through.
    """
    out, _ = ops.rms_norm_dynamic_per_token_quant(
        x,
        rms_norm_layer.weight,
        rms_norm_layer.variance_epsilon,
        quant_dtype,
        residual=residual,
    )


# Bench functions
def bench_fn(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: Optional[torch.Tensor],
    quant_dtype: torch.dtype,
    label: str,
    sub_label: str,
    fn: Callable,
    description: str,
    min_run_time: float = 1,
) -> TMeasurement:
    """Time ``fn(rms_norm_layer, x, residual, quant_dtype)``.

    Args:
        rms_norm_layer: Layer holding the norm weights, forwarded to ``fn``.
        x: Input tensor, forwarded to ``fn``.
        residual: Optional residual tensor (``None`` when disabled).
        quant_dtype: Target quantization dtype, forwarded to ``fn``.
        label: Table title for ``TBenchmark.Compare``.
        sub_label: Row label (the configuration).
        fn: Implementation under test.
        description: Column label (the implementation name).
        min_run_time: Minimum total seconds passed to ``blocked_autorange``.

    Returns:
        The measurement produced by ``blocked_autorange``.
    """
    # Names visible to the timed statement; not called `globals` to avoid
    # shadowing the builtin.
    timer_globals = {
        "rms_norm_layer": rms_norm_layer,
        "x": x,
        "residual": residual,
        "quant_dtype": quant_dtype,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
        globals=timer_globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


125
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
    """Benchmark all norm + quant implementations for one configuration.

    Builds an RMSNorm layer and random CUDA inputs from ``params``, then times
    the unfused and fused impls for int8 and fp8. It prints a comparison
    table for this configuration and returns the measurements in the order
    they were taken.
    """
    # Make the layer on CUDA explicitly, so it matches the inputs below even
    # if the caller has not set CUDA as the default device.
    layer = RMSNorm(params.hidden_size, 1e-6).to(device="cuda", dtype=params.dtype)
    # Make weights: perturb around 1.0 so the weight multiply is not trivial.
    layer.weight.data.normal_(mean=1.0, std=0.1)

    # Make inputs, scaled down by 1 / hidden_size.
    scale = 1 / params.hidden_size
    x = (
        torch.randn(
            params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda"
        )
        * scale
    )
    # randn_like already allocates on x's device (CUDA).
    residual = torch.randn_like(x) * scale if params.add_residual else None

    # NOTE(review): with a residual, the norm ops may update x / residual in
    # place, so later impls could see inputs modified by earlier ones. This
    # should not affect timing, but confirm before comparing outputs.

    # (quant dtype, implementation, column label); order matches the tables.
    impls = (
        (torch.int8, unfused_int8_impl, "unfused_int8_impl"),
        (torch.float8_e4m3fn, unfused_fp8_impl, "unfused_fp8_impl"),
        (torch.int8, fused_impl, "fused_int8_impl"),
        (torch.float8_e4m3fn, fused_impl, "fused_fp8_impl"),
    )
    timers = [
        bench_fn(layer, x, residual, quant_dtype, label, sub_label, fn, desc)
        for quant_dtype, fn, desc in impls
    ]

    print_timers(timers)

    return timers


# Runner: result reporting and the script entry point
def print_timers(timers: Iterable[TMeasurement]):
    """Print the given measurements as a single torch.utils.benchmark table."""
    TBenchmark.Compare(timers).print()


def main():
    """Run the whole sweep, print one combined table and pickle the results."""
    # Make tensors and modules created without an explicit device land on GPU.
    torch.set_default_device("cuda")

    bench_params = get_bench_params()

    timers = []
    for bp in tqdm(bench_params):
        timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))

    print_timers(timers)

    # Pickle all the results; the file name carries the current Unix time.
    timestamp = int(time.time())
    with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f:
        pkl.dump(timers, f)


227
# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()