benchmark_aqlm.py 8.97 KB
Newer Older
zhuwenwen's avatar
zhuwenwen committed
1
2
3
4
5
6
7
8
9
10
11
import os
import sys
from typing import Optional

import torch
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import (
    dequantize_weight, generic_dequantize_gemm, get_int_dtype,
    optimized_dequantize_gemm)
laibao's avatar
laibao committed
12
from vllm.utils import FlexibleArgumentParser
zhuwenwen's avatar
zhuwenwen committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88

# Expose only device 0 to CUDA for this process; the test and benchmark code
# below always allocate on 'cuda:0'.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def torch_mult(
        input: torch.Tensor,  #  [..., in_features]
        weights: torch.Tensor,
        scales: torch.Tensor,  #  [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
    """Baseline for the benchmark: a plain dense linear layer.

    Args:
        input: Activations to multiply.
        weights: Dense weight matrix passed straight to ``F.linear``.
        scales: Accepted so the call shape resembles the quantized methods;
            it is not used here.

    Returns:
        ``F.linear(input, weights)`` with no bias.
    """
    return F.linear(input, weights)


def dequant_out_scale(
    input: torch.Tensor,  #  [..., in_features]
    codes: torch.IntTensor,  #  [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,  #  [num_codebooks, codebook_size,
    #  out_group_size, in_group_size]
    scales: torch.Tensor,  #  [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Dequantize with ``ops.aqlm_dequant`` and apply the linear layer.

    When ``bias`` is None the scales are multiplied into the matmul output;
    otherwise they are multiplied into the dequantized weights before the
    matmul.

    Args:
        input: Activations fed to ``F.linear``.
        codes: Quantized codes handed to ``ops.aqlm_dequant``.
        codebooks: Codebooks handed to ``ops.aqlm_dequant``.
        scales: Scale factors, reshaped below to broadcast over outputs
            (bias is None) or over weight rows (bias given).
        output_partition_sizes: Partition sizes handed to
            ``ops.aqlm_dequant``.
        bias: Optional bias forwarded to ``F.linear``.

    Returns:
        The scaled linear output.
    """
    dense = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    if bias is not None:
        # Scaling the output of F.linear would scale the bias term as well,
        # so in this branch the scales are folded into the weights instead.
        row_scales = scales.view(scales.shape[:-3] + (-1, ))
        dense *= row_scales.expand(-1, dense.shape[1])
        return F.linear(input, dense, bias)

    result = F.linear(input, dense, bias)
    full_shape = result.shape
    # Collapse every leading dim so one scale row can be broadcast across all
    # output rows; the in-place multiply writes through to ``result``.
    rows = result.view(-1, result.size(-1))
    rows *= scales.view(-1, scales.shape[0]).expand(rows.shape[0], -1)
    return rows.view(full_shape)


def dequant_weight_scale(
    input: torch.Tensor,  #  [..., in_features]
    codes: torch.IntTensor,  #  [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,  #  [num_codebooks, codebook_size,
    #  out_group_size, in_group_size]
    scales: torch.Tensor,  #  [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Dequantize, scale the weights in place, then apply the linear layer.

    Args:
        input: Activations fed to ``F.linear``.
        codes: Quantized codes handed to ``ops.aqlm_dequant``.
        codebooks: Codebooks handed to ``ops.aqlm_dequant``.
        scales: Scale factors, flattened and broadcast across each row of the
            dequantized weight matrix.
        output_partition_sizes: Partition sizes handed to
            ``ops.aqlm_dequant``.
        bias: Optional bias forwarded to ``F.linear``.

    Returns:
        ``F.linear(input, scaled_weights, bias)``.
    """
    dense = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    # Drop the three trailing singleton dims of ``scales`` and stretch the
    # result to the weight matrix's column count.
    per_row = scales.view(scales.shape[:-3] + (-1, ))
    dense *= per_row.expand(-1, dense.shape[1])

    return F.linear(input, dense, bias)


def dequant_no_scale(
    input: torch.Tensor,  #  [..., in_features]
    codes: torch.IntTensor,  #  [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,  #  [num_codebooks, codebook_size,
    #  out_group_size, in_group_size]
    scales: torch.Tensor,  #  [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Dequantize and apply the linear layer without applying any scales.

    ``scales`` is part of the signature so this function can be called like
    the other benchmarked methods, but it is not used.

    Returns:
        ``F.linear`` of ``input`` against the dequantized weights and ``bias``.
    """
    return F.linear(
        input,
        ops.aqlm_dequant(codes, codebooks, output_partition_sizes),
        bias,
    )


# Compare the optimized 1x16 and 2x8 CUDA decompression/dequant kernels against
# the generic PyTorch version.
# This is a visual comparison only: the results are printed, not asserted.
laibao's avatar
laibao committed
89
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
    """Print the reference and custom-op dequantized weights side by side.

    Random codes and codebooks are created on ``cuda:0``. A recognisable ramp
    is then written into the first 16 entries of the first ``nbooks``
    codebooks, and into the first/last 16 code positions of output row 0, so
    the printed values can be compared by eye.

    Args:
        k: Input feature count; codes get ``k // 8`` input groups.
        parts: Output partition sizes; their sum is the number of code rows.
        nbooks: Number of codebooks per partition.
        bits: Bits per code; sets the code value range and codebook size.
    """
    out_rows = int(parts.sum().item())
    gpu = torch.device('cuda:0')

    half_range = (1 << bits) // 2
    in_group_size = 8

    codes = torch.randint(
        -half_range,
        half_range,
        size=(out_rows, k // in_group_size, nbooks),
        dtype=get_int_dtype(bits),
        device=gpu,
    )

    codebooks = torch.randn(
        size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
        dtype=torch.float16,
        device=gpu,
    )

    # Entry ``e``, lane ``l`` receives the running counter e * 8 + l, scaled
    # by a power of ten per codebook so the books are distinguishable.
    for entry in range(16):
        for lane in range(8):
            ramp = entry * 8 + lane
            for book in range(nbooks):
                codebooks[book, entry, 0, lane] = ramp * (10**book)

    print("codes shape", codes.shape)

    # Point the first and the last 16 input groups of row 0 at known entries.
    for position in range(16):
        for book in range(nbooks):
            codes[0, position, book] = position
            codes[0, -position, book] = position

    weights = dequantize_weight(codes, codebooks, None)
    weights2 = ops.aqlm_dequant(codes, codebooks, parts)

    report = (
        ("weights shape:", weights.shape),
        ("weights2 shape:", weights2.shape),
        ("weights are:", weights),
        ("weights2 are:", weights2),
        ("first 128 weights are", weights[0, :128].to(torch.int32)),
        ("first 128 weights2 are:", weights2[0, :128].to(torch.int32)),
        ("last 128 weights are", weights[0, -128:]),
        ("last 128 weights2 are:", weights2[0, -128:]),
    )
    for label, value in report:
        print(label, value)


def _str_to_bool(value: str) -> bool:
    """Interpret a command-line string as a boolean.

    Explicit "false" spellings (and the empty string) map to False; every
    other value maps to True, which keeps invocations such as ``--test True``
    working as before.
    """
    return value.strip().lower() not in ("", "0", "false", "f", "no", "n",
                                         "off")


def main():
    """Parse the command line, then run the dequant tester or the benchmark.

    With ``--test`` set, ``dequant_test`` is run once and the function
    returns. Otherwise every method is timed over a grid of problem sizes and
    the results are written, pipe-separated, to
    ``./aqlm_benchmark_{nbooks}x{bits}.csv``. Progress (the current ``m``) is
    echoed to the real stdout while the file is being written.
    """
    parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")

    # Add arguments
    parser.add_argument("--nbooks",
                        type=int,
                        default=1,
                        help="Number of codebooks (default: 1)")
    parser.add_argument("--bits",
                        type=int,
                        default=16,
                        help="Number of bits per code element (default: 16)")
    # ``type=bool`` would turn any non-empty string (including "False") into
    # True, so parse the value explicitly. ``nargs="?"`` with ``const=True``
    # additionally allows a bare ``--test``.
    parser.add_argument(
        "--test",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Run the decompression/dequant tester rather than benchmarking "
        "(default: False)")

    # Parse the arguments
    args = parser.parse_args()

    # Extract values
    nbooks = args.nbooks
    bits = args.bits

    if args.test:
        dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
        return

    # Otherwise, benchmark.
    methods = [
        ops.aqlm_gemm,
        dequant_out_scale,
        generic_dequantize_gemm,
        optimized_dequantize_gemm,
        dequant_weight_scale,
        torch_mult,
        dequant_no_scale,
    ]

    filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
    print(f"writing benchmarks to file {filename}")
    # The header contains a non-ASCII character, so pin the encoding instead
    # of relying on the platform default.
    with open(filename, "w", encoding="utf-8") as f:
        # run_grid() reports through print(), so stdout is pointed at the
        # file for the duration of the benchmark and always restored.
        sys.stdout = f
        try:
            print('m | k | n | n parts', end='')
            for method in methods:
                print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
            print('')

            # These are reasonable prefill sizes.
            ks_and_partitions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
                                 (4096, (11008, 11008)), (11008, (4096, )))

            # reasonable ranges for m.
            for m in [
                    1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96,
                    112, 128, 256, 512, 1024, 1536, 2048, 3072, 4096
            ]:
                print(f'{m}', file=sys.__stdout__)
                for k, partition_sizes in ks_and_partitions:
                    run_grid(m, k, torch.tensor(partition_sizes), nbooks,
                             bits, methods)
        finally:
            sys.stdout = sys.__stdout__


laibao's avatar
laibao committed
207
def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
             methods):
    """Time every method for one problem size and print one result row.

    The row is written with ``print`` as
    ``m | k | n | parts | t_0 | t_1 | ...`` where ``n`` is ``parts.sum()``
    and each ``t_i`` is the best (smallest) per-call time in microseconds
    over the timed trials of ``methods[i]``.

    Args:
        m: Number of input rows passed to ``run_timing``.
        k: Input feature count passed to ``run_timing``.
        parts: Tensor of output partition sizes.
        nbooks: Number of codebooks.
        bits: Bits per code.
        methods: Iterable of callables to benchmark, in column order.
    """
    # I didn't see visible improvements from increasing these, but feel free :)
    num_warmup_trials = 1
    num_trials = 1

    num_calls = 100

    timing_args = dict(
        num_calls=num_calls,
        m=m,
        k=k,
        parts=parts,
        nbooks=nbooks,
        bits=bits,
    )

    # warmup.
    for method in methods:
        for _ in range(num_warmup_trials):
            run_timing(method=method, **timing_args)

    n = parts.sum().item()
    print(f'{m} | {k} | {n} | {parts.tolist()}', end='')

    for method in methods:
        best_time_us = 1e20
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(method=method, **timing_args)
            best_time_us = min(best_time_us, 1000 * kernel_dur_ms)

        # Report the best trial, not whichever trial happened to run last.
        print(f' | {best_time_us:.0f}', end='')

    print('')


laibao's avatar
laibao committed
255
def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
               nbooks: int, bits: int, method) -> float:
    """Return the average time per call of ``method``, from CUDA events.

    Random inputs, codes, codebooks, scales and a dense weight matrix are
    created on ``cuda:0``; ``method`` is then invoked ``num_calls`` times
    between two CUDA events.

    Args:
        num_calls: Number of back-to-back invocations to time.
        m: Number of input rows; the input has shape ``(1, m, k)``.
        k: Input feature count.
        parts: Output partition sizes; their sum is the output feature count.
        nbooks: Number of codebooks per partition.
        bits: Bits per code; sets the code value range and codebook size.
        method: Callable to time. ``torch_mult`` is called with the dense
            weights; anything else is called with the quantized tensors and
            a ``None`` bias.

    Returns:
        Elapsed time between the two events divided by ``num_calls``.
    """
    out_features = int(parts.sum().item())
    gpu = torch.device('cuda:0')

    activations = torch.randn((1, m, k), dtype=torch.float16, device=gpu)

    half_range = (1 << bits) // 2
    in_group_size = 8

    codes = torch.randint(
        -half_range,
        half_range,
        size=(out_features, k // in_group_size, nbooks),
        dtype=get_int_dtype(bits),
        device=gpu,
    )

    codebooks = torch.randn(
        size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
        dtype=torch.float16,
        device=gpu,
    )

    scales = torch.randn(size=(out_features, 1, 1, 1),
                         dtype=torch.float16,
                         device=gpu)

    # for comparison to just a pytorch mult.
    dense_weights = torch.randn((out_features, k),
                                dtype=torch.float16,
                                device=gpu)

    # The dense baseline has a different call signature from the quantized
    # methods, so pick the argument tuple once, outside the timed loop.
    if method is torch_mult:
        call_args = (activations, dense_weights, scales)
    else:
        call_args = (activations, codes, codebooks, scales, parts, None)

    timer_begin = torch.cuda.Event(enable_timing=True)
    timer_end = torch.cuda.Event(enable_timing=True)

    timer_begin.record()
    for _ in range(num_calls):
        method(*call_args)
    timer_end.record()
    # Wait for the end event so elapsed_time() reflects completed work.
    timer_end.synchronize()

    return timer_begin.elapsed_time(timer_end) / num_calls


# Script entry point. main() returns None on both of its paths, so a run that
# finishes without raising exits with status 0.
if __name__ == "__main__":
    sys.exit(main())