from itertools import product
import math
import random
import time

import einops
import numpy as np
import pytest
from scipy.stats import norm
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F
from tests.helpers import (
    BOOLEAN_TUPLES,
    TRUE_FALSE,
    describe_dtype,
    get_test_dims,
    id_formatter,
)

torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
k = 20


def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        if throw:
            print(f"Too many values not close: assert {sumval} < {count}")
            torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

    return sumval
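

# Note: unlike torch.testing.assert_close, this helper tolerates up to `count`
# mismatched elements before failing. That suits the int8 kernels below, where a
# handful of off-by-one rounding differences against the fp32 reference are
# expected, e.g. assert_all_approx_close(a, b, count=int(0.01 * a.numel())).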


class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.fc1.weight)
            torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class Timer:
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
            self.starts[name].record()
        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")


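# Note: Timer times GPU work with paired CUDA events; tock() synchronizes before
# reading the elapsed time, so callers do not need their own
# torch.cuda.synchronize(). Sketch: t = Timer(); t.tick("mm"); ...; t.tock("mm").

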
def setup():
    pass


def teardown():
    pass


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
    abserr = sum(diffs) / len(diffs)
    relerr = sum(reldiffs) / len(reldiffs)
    # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
    # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
    assert abserr < 0.011
    assert relerr < 0.018
    assert A2.dtype == dtype

    diffs = []
    reldiffs = []
    code = F.create_dynamic_map(signed=signed)
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        # torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
    abserr = sum(diffs) / len(diffs)
    relerr = sum(reldiffs) / len(reldiffs)
    if signed:
        assert abserr < 0.0035
        assert relerr < 0.015
    else:
        assert abserr < 0.00175
        assert relerr < 0.012
    assert A2.dtype == dtype
    # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
    # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
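

# Reference sketch (illustration only, not the library kernel): blockwise
# quantization scales each contiguous block of `blocksize` values by its own
# absmax. F.quantize_blockwise additionally maps through a nonuniform 256-entry
# code; this linear-code simplification just shows the blocking idea.
def _blockwise_absmax_roundtrip(x: torch.Tensor, blocksize: int = 256) -> torch.Tensor:
    blocks = x.flatten().float().view(-1, blocksize)  # assumes numel % blocksize == 0
    absmax = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)  # one scale per block
    q = torch.round(blocks / absmax * 127)  # int8-style rounding into [-127, 127]
    return (q * absmax / 127).view_as(x)  # dequantized approximation of x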


def quant(x):
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    return C.float() * (maxA / 127) * (maxB / 127)

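# These helpers implement plain symmetric absmax quantization, q = round(127 * x / absmax),
# so an int8 product dequantizes as C_fp ~= C_int32 * (maxA / 127) * (maxB / 127).
# Worked example (sketch; exact values depend on rounding): x = 2.0 with absmax 4.0
# gives q = round(2.0 / 4.0 * 127) = 64, which dequantizes to 64 * 4.0 / 127 ~= 2.016.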

def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_minmax(A):
    minA = A.min()
    maxA = A.max()
    return minA, maxA


def mean(xx):
    return sum(xx) / float(len(xx))


methods = {
    "linear": (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    ),
    "vectorwise": (quant_multi, quant_multi, dequant, dequant, mm_dequant),
}


@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2"))
@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys())
@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched"))
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    # print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    # print(mean(errors))
    # print(mean(relerrors))


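# In the `methods` table above, "linear" uses a single absmax scale for the whole
# tensor while "vectorwise" keeps one scale per row/column, which is why the two
# quantization paths in test_approx_igemm can differ in accumulated error.

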
def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
        shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        elif transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.t().float(), B.float())
            out = F.igemm(A.t(), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.matmul(A.t().float(), B.t().float())
            out = F.igemm(A.t(), B.t())

        torch.testing.assert_close(out.float(), out2)

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())

        torch.testing.assert_close(out.float(), out2)


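# For int8 inputs F.igemm accumulates in int32, so the comparisons above can
# demand exact equality with the fp32 matmul: every partial product is at most
# 127 * 127, and sums over these small inner dimensions stay well below 2**24,
# the range where fp32 remains exact for integers.

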
@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim"))
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8)
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_close(out.float(), out2)


@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale
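
    # min_max maps x into [-127, 127] around the midpoint of its per-row range:
    # x ~= q * scale / 127 + (minA + scale). The affine part is restored after
    # the int8 matmul via offset = B.sum(0) * (minA + scale) in the branches below.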

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs = []
    relerrs = []
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)

        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3


@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)

        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A, B.permute([0, 2, 1]))
        elif transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
            out = F.igemm(A.permute([0, 2, 1]), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_close(out.float(), out2.float())


@pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims"))
@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))
def test_int8_linear_matmul(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
        elif dims == 3:
            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.t().float())

        C2 = F.int8_linear_matmul(A, B)
        torch.testing.assert_close(C1, C2.float())


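# F.int8_linear_matmul computes A @ B.T with int32 accumulation, so for random
# int8 inputs it can be compared exactly against the fp32 reference (each dot
# product here stays far below 2**31).

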
@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
def test_int8_linear_matmul_half(dim1, dim2, dim3, dim4, dims):
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

        CA, _, statsA, _, _ = F.int8_double_quant(A)
        CB, statsB, _ = F.int8_vectorwise_quant(B)
        output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB)

        torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)


@pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_dequant_mm(dim1, dim4, dims, has_bias):
    inner = 128
    bias = None
    if has_bias:
        bias = torch.randn(dim4, device="cuda", dtype=torch.float16)

    for i in range(1):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())
        if has_bias:
            C1 += bias

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        C2 = F.int8_linear_matmul(A1, B1)

        C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())
        if has_bias:
            C4 += bias

        # TODO: is something wrong here? If so, the problem goes deeper
        # n = C1.numel()
        # p = 0.06
        std = C1.std(0).view(1, -1)
        C1 /= std
        C4 /= std
        # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
        # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.int8_mm_dequant(C2, maxA, maxB, bias=bias)
        C5 /= std
        torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
        n = C5.numel()
        assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))


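# F.int8_mm_dequant fuses the dequantization C_int32 * (maxA * maxB.T) / 127**2
# with an optional fp16 bias add; that is why C5 is checked against the two-step
# vectorwise_mm_dequant result C4 above.

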
@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp"))
def test_colrow_absmax(dim1, dim2, dims, threshold):
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        assert dims == 2

        row_stats1, _ = torch.abs(A.float()).max(1)
        col_stats1, _ = torch.abs(A.float()).max(0)

        if threshold > 0.0:
            A_truncated = A.clone()
            A_truncated[torch.abs(A_truncated) >= threshold] = 0.0
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)

            row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)

            nnz_rows1_counts = (torch.abs(A) >= threshold).sum(1).flatten()
            nnz_block_ptr1 = torch.zeros(
                nnz_rows1_counts.shape[0] + 1,
                dtype=nnz_rows1_counts.dtype,
                device=nnz_rows1_counts.device,
            )
            nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

            torch.testing.assert_close(col_stats1_trunc, col_stats2)
            torch.testing.assert_close(row_stats1_trunc, row_stats2)
            # torch.testing.assert_close(nnz_block_ptr1, nnz_block_ptr2)
        else:
            row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
            assert nnz_block_ptr2 is None
            torch.testing.assert_close(col_stats1, col_stats2)
            torch.testing.assert_close(row_stats1, row_stats2)


@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
def test_int8_double_quant(dim1, dim2):
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, _ = F.int8_double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()

        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(f"Min error exceeded: {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}")
            assert False
        if num_not_close_rows > (min_error * n):
            print(f"Min error exceeded: {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}")
            assert False

        torch.testing.assert_close(Srow.flatten().float(), statsA)
        torch.testing.assert_close(Scol.flatten().float(), statsAt)


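# Reference sketch (illustration only): int8 "double quant" produces two int8
# views of A, one scaled per row and one per column, plus the matching absmax
# statistics, mirroring what the checks above expect from F.int8_double_quant.
def _double_quant_reference(A: torch.Tensor):
    row_absmax = A.float().abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    col_absmax = A.float().abs().amax(dim=0, keepdim=True).clamp(min=1e-8)
    C_row = torch.round(A.float() / row_absmax * 127).to(torch.int8)
    C_col = torch.round(A.float() / col_absmax * 127).to(torch.int8)
    return C_row, C_col, row_absmax.flatten(), col_absmax.flatten()

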
@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    (
        pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
        for (dim1, dim4, inner) in zip(
            (1, 8, 2048, 4096),
            (2, 128, 2048, 4096),
            (4, 256, 512, 4096),
        )
    ),
)
def test_integrated_int8_linear_matmul(dim1, dim4, inner):
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, stats1a, _ = F.int8_vectorwise_quant(A)
        C2a, stats2a, _ = F.int8_vectorwise_quant(B)
        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        torch.testing.assert_close(maxA.flatten().float(), stats1a)
        torch.testing.assert_close(maxB.flatten().float(), stats2a)
        torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_close(C2a, B1, rtol=0, atol=1)

        out2 = F.int8_linear_matmul(A1, B1)

        C2 = F.int8_linear_matmul(A1, B1)

        out3 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
        assert err2 <= err1 * 1.025


@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    (
        pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
        for (dim1, dim4, inner) in zip(
            get_test_dims(1, 4 * 1024, n=6),
            get_test_dims(1, 4 * 1024, n=6),
            get_test_dims(1, 4 * 1024, n=6),
        )
    ),
)
@pytest.mark.skip("Row scale has some bugs for ampere")
Tim Dettmers's avatar
Tim Dettmers committed
634
635
636
637
638
639
def test_igemmlt_row_scale(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
640
641
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
Tim Dettmers's avatar
Tim Dettmers committed
642
643
644
645
646
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

647
        C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A)
648
649
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
Tim Dettmers's avatar
Tim Dettmers committed
650
651
652
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

653
654
        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
655
656
657
        outC32 = F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale)
        # C3, S = F.nvidia_transform(outC32, "row", state=SC)
        C3 = outC32
Tim Dettmers's avatar
Tim Dettmers committed
658
659
660
661
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
662
663
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)
Tim Dettmers's avatar
Tim Dettmers committed
664
665
666
667
668

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
669
670
        outC32 = F.int8_linear_matmul(A2, B2)
        out2 = F.int8_mm_dequant(outC32, stats1a, stats2a)
Tim Dettmers's avatar
Tim Dettmers committed
671

672
673
        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")
Tim Dettmers's avatar
Tim Dettmers committed
674
675

        C = torch.matmul(CA.float(), CB.t().float())
676
677
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
Tim Dettmers's avatar
Tim Dettmers committed
678

679
680
681
682
        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)
Tim Dettmers's avatar
Tim Dettmers committed
683

684
685
686
687
688
689
        # print(out1)
        # print(out2)
        # print(out3)
        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())
Tim Dettmers's avatar
Tim Dettmers committed
690

691
692
693
694
695
        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))
Tim Dettmers's avatar
Tim Dettmers committed
696
697


@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_double_quant(dim1, dim2):
    threshold = 2.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)

        if outlier_cols is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A) + A1
            torch.testing.assert_close(A1, A2)

            A[:, outlier_cols] = 0
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2)


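# With a threshold, int8_vectorwise_quant also reports columns that contain
# values with |x| >= threshold as outlier_cols; zeroing those columns (as above)
# leaves exactly the part of A that the int8 representation should reconstruct.

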
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
def test_coo_int8_vectorwise_quant(dim1, dim2):
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)

        if outlier_cols is not None:
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            A[:, outlier_cols] = 0
            torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)


@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B"))
def test_spmm_coo(dim1, dim2, transposed_B):
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
            B = torch.randn(dim3, dim2).cuda().half()
        else:
            B = torch.randn(dim2, dim3).cuda().half()

        idx = torch.abs(A) >= threshold
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
        A2 = A * idx

        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
            out1 = torch.matmul(A2, B.t())
        else:
            out2 = F.spmm_coo(cooA, B)
            out1 = torch.matmul(A2, B)

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)


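# COO ("coordinate") format stores each nonzero as a (row, col, value) triple;
# only entries with |A| >= threshold survive here, so F.spmm_coo multiplies just
# the sparse outlier part of A against the dense B.

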
@pytest.mark.benchmark
def test_spmm_bench():
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = bnb.matmul(A, B.t())
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)

    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time() - t0
    print(tsp, t8)
    print(tsp / t8)


@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2"))
def test_integrated_sparse_decomp(dim1, dim2):
    threshold = 3.0
    for _ in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        out1 = torch.matmul(A, w1.t())

        Cw1, statsw1, _ = F.int8_vectorwise_quant(w1)
        CA, statsA, _ = F.int8_vectorwise_quant(A)

        out1_32 = F.int8_linear_matmul(CA, Cw1)
        out2 = F.int8_mm_dequant(out1_32, statsA, statsw1)

        # CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
        CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)

        out1_32 = F.int8_linear_matmul(CA, Cw1)
        out3 = F.int8_mm_dequant(out1_32, statsA, statsw1)

        assert coo_tensor is not None

        out4 = F.spmm_coo(coo_tensor, w1.t())
        # idx = torch.unique(coo_tensor._indices()[1]).long()
        # out4 = torch.matmul(A, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1


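# This mirrors the LLM.int8() decomposition: a dense int8 matmul for the bulk of
# A plus a sparse fp16 matmul (spmm_coo) over the outliers above the threshold;
# their sum is expected to beat the pure int8 result (err2 < err1).

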
def test_matmuls():
    a = torch.randn(256, 512).half().cuda()
    b = torch.randn(256, 512).half().cuda()
    c1 = torch.matmul(a, b.t())
    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul_cublas(a, b.t())

    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2
    print(err1, err2)


@pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func"))
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))

    # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)

    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #   #out3 = F.spmm_coo(cooA, Bt.t())
    #   #out2 = F.spmm_coo(cooA, B)
    #   #out2 = F.spmm_coo_very_sparse(cooA, B)
    #   #out1 = torch.matmul(A, Bt.t())

    # torch.cuda.synchronize()
    # print(time.time() - t0)


def test_coo2csr():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_close(A2[idx], csrA.values)
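
# CSR keeps one rowptr entry per row plus a final sentinel, so per-row nonzero
# counts fall out as rowptr[i + 1] - rowptr[i]; the values stay in row-major
# order, which the element-wise comparison above relies on.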


def test_coo2csc():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_close(A2.t()[idx], cscA.values)


@pytest.mark.parametrize("dim1", [1 * 2048])
@pytest.mark.parametrize("dim2", [2048])
@pytest.mark.parametrize("dtype", [torch.int8])
def test_spmm_coo_dequant(dim1, dim2, dtype):
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.int8_double_quant(B)

    rowidx = torch.randint(0, A.shape[-1], size=(15,))

    A[:, rowidx] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)

    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(100):
    #   out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print('fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print("cusparse fp16", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print("int8", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print("int8+dequant", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print("matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1 + out2
    torch.cuda.synchronize()
    print("sparse+ matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print("dense matmul", time.time() - t0)


def test_zeropoint():
    def quant_zp(x):
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        minx = x.min()
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx
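
    # quant_zp maps x onto roughly [-127, 127] via q = qx * x + zpx, an asymmetric
    # (zero-point) scheme: qx spreads the dynamic range over 254 steps and zpx
    # shifts the minimum to -127. The matmuls below undo the shift with correction
    # terms such as B.sum(0) * zpa, since (qa * A + zpa) @ B = qa * (A @ B) + zpa * B.sum(0).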

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()

    C1 = torch.matmul(A, B)
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    # print(ca.min(), ca.max())
    # print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb

    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

Ruff's avatar
Ruff committed
1122
    # print("")
1123
    # print(C0.flatten()[:10])
Ruff's avatar
Ruff committed
1124
1125
1126
1127
1128
1129
    # print(C1.flatten()[:10])
    # print(C2.flatten()[:10])
    # print(C3.flatten()[:10])
    # print(C5.flatten()[:10])
    # print(C6.flatten()[:10])
    # print(C7.flatten()[:10])
1130
1131
1132
1133
1134
1135
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
Tim Dettmers's avatar
Tim Dettmers committed
1136
    print(err1, err2, err3, err4, err5, err6)
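

# The identities exercised above come from expanding a zero-point-quantized
# product: with CA = qa * A + zpa and CB = qb * B + zpb (elementwise),
#   CA @ CB = qa*qb*(A @ B) + qa*zpb*rowsum(A) + qb*zpa*colsum(B) + zpa*zpb*k,
# where k = A.shape[1], so A @ B can be recovered exactly, as in C6/C7 above.
# A float64 sketch of that recovery (illustrative helper, not library API):
def _zeropoint_matmul_sketch():
    A = torch.randn(4, 8, dtype=torch.float64)
    B = torch.randn(8, 4, dtype=torch.float64)
    qa, zpa, qb, zpb = 2.0, 1.0, 3.0, -2.0
    CA = qa * A + zpa
    CB = qb * B + zpb
    C = torch.matmul(CA, CB)
    C -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C -= zpa * zpb * A.shape[1]
    C /= qa * qb
    torch.testing.assert_close(C, A @ B)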


@pytest.mark.deprecated
def test_extract_outliers():
    for i in range(k):
        shapeA = (4096, 4096 * 4)
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # idx = torch.Tensor([0]).int().cuda()
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        CA, SA = F.transform(A, "col_turing")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)
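

# In row-major layout the extraction above is just fancy indexing
# (A[:, idx]); the point of the test is that extract_outliers recovers the
# same columns directly from the tiled col_turing/col_ampere layouts without
# transforming the matrix back.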


def test_blockwise_cpu_large():
    diffs = []
    reldiffs = []
    batch = 128
    seq = 128
    for hidden in [128]:  # , 14336]:
        for blocksize in [4096, 16384]:
            for i in range(2):
                A1 = torch.randn(batch, seq, hidden, device="cpu")
                t0 = time.time()
                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
                print(time.time() - t0)
                diff = torch.abs(A1 - A2)
                reldiff = diff / torch.abs(A1 + 1e-8)
                diffs.append(diff.mean().item())
                reldiffs.append(reldiff.mean().item())
                assert diffs[-1] < 0.011
            # print(sum(diffs)/len(diffs))
            # print(sum(reldiffs)/len(reldiffs))
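

# Blockwise quantization, as exercised above, scales each contiguous block of
# values by its absolute maximum and stores one scale per block. A minimal
# sketch of the idea with plain symmetric int8 rounding (illustrative only;
# the library maps through an 8-bit codebook instead):
def _blockwise_absmax_sketch(x: torch.Tensor, blocksize: int = 64):
    flat = x.flatten()
    pad = (-flat.numel()) % blocksize
    flat = torch.nn.functional.pad(flat, (0, pad))
    blocks = flat.view(-1, blocksize)
    absmax = blocks.abs().max(dim=1, keepdim=True).values.clamp(min=1e-12)
    q = torch.round(blocks / absmax * 127).to(torch.int8)  # quantize
    deq = q.float() / 127 * absmax  # dequantize
    return q, absmax, deq.flatten()[: x.numel()].view_as(x)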


def test_fp8_quant():
    for e_bits in range(1, 7):
        p_bits = 7 - e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / torch.abs(A1 + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        # print(sum(abserr)/len(abserr))
        # print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.rand(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / torch.abs(A1 + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        # print(sum(abserr)/len(abserr))
        # print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / torch.abs(A1 + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        # print(3, sum(abserr)/len(abserr))
        # print(3, sum(relerr)/len(relerr))
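

# The fp8 codes above use one sign bit, e_bits exponent bits, and p_bits
# mantissa bits (1 + e_bits + p_bits = 8). A sketch of how a normal
# (non-subnormal) value decodes, assuming an IEEE-style bias of
# 2**(e_bits - 1); the exact bias convention inside create_fp8_map is an
# implementation detail:
def _fp8_decode_sketch(sign_bit: int, exponent: int, mantissa: int, e_bits: int, p_bits: int) -> float:
    bias = 2 ** (e_bits - 1)
    sign = -1.0 if sign_bit else 1.0
    frac = 1.0 + mantissa / 2**p_bits  # implicit leading one
    return sign * frac * 2.0 ** (exponent - bias)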


def test_few_bit_quant():
    # print('')
    for bits in range(2, 9):
        # print('='*30, bits, '='*30)
        for method in ["linear", "fp8", "dynamic", "quantile"]:
            abserrs = []
            relerrs = []
            code = None
            if method == "linear":
                code = F.create_linear_map(True, total_bits=bits).cuda()
            elif method == "fp8":
                ebits = math.ceil(bits / 2)
                pbits = bits - ebits - 1
                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
            elif method == "dynamic":
                code = F.create_dynamic_map(True, bits - 0, bits).cuda()
            elif method == "quantile":
                values = torch.randn(2048, 2048, device="cuda")
                code = F.create_quantile_map(values, bits).cuda()
            # for some data types we have no zero
            # for some data types we have one zero
            # for some data types we have two zeros
            assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}"
            # print(method, (code==0).sum())
            assert code.numel() == 256
            for i in range(10):
                values = torch.randn(1, 32, device="cuda")
                values /= values.abs().max()
                # values[values.abs() < 1e-6] += 1e-5

                q1 = []
                v1 = []
                for v in values[0]:
                    idx = torch.abs(v - code).argmin()
                    q1.append(idx.item())
                    v1.append(code[idx].item())

                q1 = torch.Tensor(q1).cuda()
                v1 = torch.Tensor(v1).cuda()

                q2, S2 = F.quantize_blockwise(values, code=code)
                v2 = F.dequantize_blockwise(q2, S2)

                idx = torch.isclose(q1.int(), q2.int())
                err2 = torch.abs(v2 - values)
                abserrs.append(err2.mean().item())
                relerrs.append((err2 / (1e-10 + values).abs()).mean().item())
                if idx.sum():
                    # some weird cases
                    err1 = torch.abs(v1 - values).mean()
                    # assert err2.mean() <= err1
                else:
                    torch.testing.assert_close(q1, q2)
            # print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
    # assert False
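

# Every code above is materialized as a 256-entry table so 8-bit indices work
# regardless of the nominal bit width; only 2**bits (or 2**bits - 1, when the
# code contains a single shared zero) entries are distinct. A vectorized
# version of the nearest-entry search the inner loop does by hand
# (hypothetical helper):
def _nearest_code_sketch(values: torch.Tensor, code: torch.Tensor):
    idx = (values.view(-1, 1) - code.view(1, -1)).abs().argmin(dim=1)
    return idx, code[idx].view_as(values)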


def test_kbit_quantile_estimation():
    for i in range(100):
        data = torch.randn(1024, 1024, device="cuda")
        for bits in range(2, 9):
            p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.038

    for i in range(100):
        data = torch.randn(1024, 1024, device="cuda")
        for bits in range(2, 4):
            total_values = 2**bits - 1
            p = np.linspace(0, 1, 2 * total_values + 1)
            idx = np.arange(1, 2 * total_values + 1, 2)
            p = p[idx]
            offset = 1 / (2 * total_values)
            p = np.linspace(offset, 1 - offset, total_values)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.035
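

# The reference values above are the ideal quantiles of N(0, 1): with k
# quantile bins and a half-bin offset, the targets are
# norm.ppf(linspace(1/(2k), 1 - 1/(2k), k)). The same estimate can be made
# with torch.quantile instead of the library's SRAM estimator (illustrative):
def _empirical_quantiles_sketch(data: torch.Tensor, k: int) -> torch.Tensor:
    offset = 1 / (2 * k)
    probs = torch.linspace(offset, 1 - offset, k, device=data.device)
    return torch.quantile(data.flatten().float(), probs)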


@pytest.mark.benchmark
def test_bench_dequantization():
    a = torch.rand(1024, 1024, device="cuda").half()
    code = F.create_fp8_map(True, 3, 0, 4).cuda()
    qa, SA = F.quantize_blockwise(a, code=code)
    print(qa.max())

    max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000
    # print(max_theoretical_mu)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        qa, SA = F.quantize_blockwise(a)
    torch.cuda.synchronize()
    # print((time.time()-t0)/1e6)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
def test_4bit_quant(dtype, quant_type, blocksize):
    vals = list(product([0, 1], repeat=4))

    code = {}
    for bits in vals:
        result = 0
        bias = 3
        sign, e1, e2, p1 = bits
        idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1
        sign = -1.0 if sign else 1.0
        exp = e1 * 2 + e2 * 1
        if exp == 0:
            # sub-normal
            if p1 == 0:
                result = 0
            else:
                result = sign * 0.0625
        else:
            # normal
            exp = 2 ** (-exp + bias + 1)
            frac = 1.5 if p1 else 1.0
            result = sign * exp * frac
        code[idx] = result

    A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
    qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
    A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)

    err = (A1 - A2).abs().float()
    relerr = (err / (A1.abs().float() + 1e-8)).mean()
    idx = err > 1.0
    err = err.mean()

    assert A2.dtype == dtype

    # With larger block sizes, we can expect this to blow up.
    # At blocksize>=1024, don't even bother looking at relerr.
    if blocksize <= 64:
        assert err.item() < 0.1
        assert relerr.item() < 0.28
    elif blocksize <= 256:
        assert err.item() < 0.11
        assert relerr.item() < 0.30
    elif blocksize <= 512:
        assert err.item() < 0.12
        assert relerr.item() < 0.31
    elif quant_type == "fp4":
        # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
        assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
    else:
        # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
        assert err.item() < math.log2(blocksize) * 8e-2
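

# quantize_4bit stores two 4-bit code indices per byte (or packs them into a
# wider quant_storage dtype). A sketch of the nibble round trip on uint8
# storage; the nibble order inside the real kernels is an implementation
# detail, assumed high-nibble-first here:
def _nibble_pack_sketch(idx: torch.Tensor) -> torch.Tensor:
    # idx: flat uint8 tensor of 4-bit code indices (0..15), even length
    pairs = idx.view(-1, 2)
    return ((pairs[:, 0] << 4) | pairs[:, 1]).to(torch.uint8)


def _nibble_unpack_sketch(packed: torch.Tensor) -> torch.Tensor:
    return torch.stack(((packed >> 4) & 0xF, packed & 0xF), dim=1).view(-1)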


@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
def test_4bit_compressed_stats(quant_type):
    for blocksize in [128, 64]:
        errs1 = []
        errs2 = []
        for i in range(10):
            A1 = torch.randn(1024, 1024, device="cuda").half()
            q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
            q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
            A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
            A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)

            err = (A1 - A2).abs().float()
            relerr = (err / (A1.abs().float() + 1e-15)).mean()
            err = err.mean()

            errs1.append(err.item())

            assert err.item() < 0.11
            assert relerr.item() < 0.28

            err = (A1 - A3).abs().float()
            relerr = (err / (A1.abs().float() + 1e-15)).mean()
            err = err.mean()

            errs2.append(err.item())

            assert err.item() < 0.11
            assert relerr.item() < 0.28

        # print(sum(errs1)/len(errs1), blocksize, quant_type)
        # print(sum(errs2)/len(errs2), blocksize, quant_type)
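

# compress_statistics enables "double quantization" (QLoRA): the per-block
# fp32 absmax constants are themselves quantized to 8 bits, blockwise. Rough
# overhead per weight, assuming 4-bit storage, blocksize 64, and a
# second-level blocksize of 256 (numbers as in the QLoRA paper):
#   plain:      32/64                 = 0.500 extra bits per parameter
#   compressed: 8/64 + 32/(64 * 256) ~= 0.127 extra bits per parameter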


# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ["nf4"])
@pytest.mark.benchmark
def test_bench_4bit_dequant(quant_type):
    blocksize = 256
    a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half()
    qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)

    input_size = a.numel() / 2
    output_size = a.numel() * 2
    num_bytes = input_size + output_size
    GB = num_bytes / 1e9
    max_theoretical_s = GB / 768
    # print(max_theoretical_s*1e6)
    b = torch.randn(128, 1024 * 12, device="cuda").half()

    iters = 100
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
        # b.copy_(a)
    torch.cuda.synchronize()
    # print((time.time()-t0)/iters*1e6)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    torch.matmul(b, a.t())
    # torch.cuda.synchronize()
    # print((time.time()-t0)/iters*1e6)
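
# Bandwidth model for the benchmark above: the 4-bit input reads a.numel()/2
# bytes, the fp16 output writes a.numel()*2 bytes, and 768 GB/s is the
# assumed peak memory bandwidth, so max_theoretical_s is the fastest a purely
# memory-bound dequantization kernel could run.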


def test_normal_map_tree():
    code = F.create_normal_map()
    values = code[:8].tolist() + code[-8:].tolist()
    num_pivots = 1
    # print(values)
    while num_pivots < 16:
        idx = list(range(16 // num_pivots // 2, 16, 16 // num_pivots))
        # print(idx)
        num_pivots *= 2
        pivots = []
        for i in idx:
            pivots.append((values[i - 1] + values[i]) / 2)
        # print(pivots)
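

# The pivots above are midpoints between adjacent codebook entries; a kernel
# can binary-search them to find the nearest code without comparing against
# all 16 values. A torch-only sketch of that lookup, assuming `code` is
# sorted ascending:
def _bucketize_quant_sketch(x: torch.Tensor, code: torch.Tensor) -> torch.Tensor:
    boundaries = (code[:-1] + code[1:]) / 2  # midpoint pivots
    idx = torch.bucketize(x, boundaries)  # binary search per element
    return code[idx]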


@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize(
    "quant_storage",
    [torch.uint8, torch.float16, torch.bfloat16, torch.float32],
    ids=describe_dtype,
)
def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
    for dim in [128, 256, 512, 1024]:
        # for dim in [4*1024]:
        # for dim in [1*16]:
        errs1 = []
        errs2 = []
        errs3 = []
        relerrs1 = []
        relerrs2 = []
        relerrs3 = []
        max_errs1 = []
        max_errs2 = []
        max_errs3 = []

        for i in range(100):
            if kind == "fc1":
                A = torch.randn(1, dim, dtype=dtype, device="cuda")
                B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
            elif kind == "fc2":
                A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
                B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
            elif kind == "attn":
                A = torch.randn(1, dim, dtype=dtype, device="cuda")
                B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
            elif kind == "attn_packed":
                A = torch.randn(1, dim, dtype=dtype, device="cuda")
                B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)

            qB, state = F.quantize_4bit(
                B,
                quant_type=storage_type,
                compress_statistics=double_quant,
                quant_storage=quant_storage,
            )
            C3 = torch.matmul(A, B.t())
            C2 = F.gemv_4bit(A, qB.t(), state=state)
            A.requires_grad = True
            C1 = bnb.matmul_4bit(A, qB.t(), state)

            err1 = (C1 - C2).abs().float()
            err2 = (C3 - C2).abs().float()
            err3 = (C3 - C1).abs().float()

            mag1 = torch.abs(C1).float() + 1e-5
            mag2 = torch.abs(C3).float() + 1e-5
            mag3 = torch.abs(C3).float() + 1e-5

            relerr1 = err1 / mag1
            relerr2 = err2 / mag2
            relerr3 = err3 / mag3

            max_err1 = err1.max()
            max_err2 = err2.max()
            max_err3 = err3.max()

            errs1.append(err1.mean().item())
            errs2.append(err2.mean().item())
            errs3.append(err3.mean().item())

            relerrs1.append(relerr1.mean().item())
            relerrs2.append(relerr2.mean().item())
            relerrs3.append(relerr3.mean().item())

            max_errs1.append(max_err1.item())
            max_errs2.append(max_err2.item())
            max_errs3.append(max_err3.item())

            c = int(C1.numel() * 0.0014 * (dim / 256)) + 1

            c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False)

        err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
        err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
        err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
        relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
        relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
        relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
        maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
        maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
        maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
        absratio = err2 / err3
        relratio = relerr2 / relerr3
        maxratio = maxerr2 / maxerr3

        # for debugging if the tests fail
        #
        # print('='*80)
        # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
        # print(C1.flatten()[-20:])
        # print(C2.flatten()[-20:])
        # print(f'inference vs training abs: {err1}')
        # print(f'inference vs training rel: {relerr1}')
        # print(f'inference vs training max: {maxerr1}')
        # print(f'inference vs training vs torch err ratio abs: {absratio}')
        # print(f'inference vs training vs torch err ratio rel: {relratio}')
        # print(f'inference vs training vs torch err ratio max: {maxratio}')
        if dtype == torch.float16:
            if dim <= 512:
                assert err1 < 7e-5
                assert relerr1 < 0.0008
            else:
                assert err1 < 6e-5
                assert relerr1 < 2e-4
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.float32:
            if dim <= 512:
                assert err1 < 5e-8
                assert relerr1 < 1e-6
                assert maxerr1 < 1e-7
            else:
                assert err1 < 5e-8
                assert relerr1 < 8e-6
                assert maxerr1 < 1e-7
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.bfloat16:
            if dim <= 512:
                assert err1 < 6e-4
                assert relerr1 < 0.007
                assert maxerr1 < 0.015
            else:
                assert err1 < 2e-4
                assert relerr1 < 0.002
                assert maxerr1 < 0.0012
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.04 and relratio > 0.96
            assert maxratio < 1.02 and maxratio > 0.98
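

# Errors above are normalized by sqrt(dim): for random inputs, the absolute
# error of a length-dim dot product grows roughly as sqrt(dim), so the
# normalization makes one threshold usable across all tested dims. A small
# CPU-only sketch of that scaling law (illustrative helper):
def _dot_error_scaling_sketch():
    for dim in (256, 1024, 4096):
        a = torch.randn(10000, dim, dtype=torch.float64)
        noise = 1e-3 * torch.randn_like(a)
        err = ((a + noise) * a).sum(1) - (a * a).sum(1)
        # mean absolute error divided by sqrt(dim) stays roughly constant
        print(dim, err.abs().mean().item() / math.sqrt(dim))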


@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
    n = 32 * 10
    A = F.get_paged(n, n, dtype=torch.float32)
    B = F.get_paged(n, n, dtype=torch.uint8)
    B2 = F.get_paged(n, n, dtype=torch.float32)
    assert A.is_paged
    assert B.is_paged
    assert A.page_deviceid == 0
    assert B.page_deviceid == 0
    F.fill(A, 17.0)
    F.fill(B, 17)
    F.fill(B2, 2)
    assert (A == 17).sum().item() == n * n
    assert (B == 17).sum().item() == n * n
    C = A * B.float()
    assert (C == 289).sum().item() == n * n
    F._mul(A, B2)
    F._mul(A, B2)
    F._mul(A, B2)
    assert (A == 17 * (2**3)).sum().item() == n * n


@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_False"])
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
    dims = 10
    torch.random.manual_seed(np.random.randint(0, 412424242))
    dims = get_test_dims(0, 8192, n=dims)
    dims = [dim + (64 - (dim % 64)) for dim in dims]
    # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
    for dim in dims:
        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
        B = torch.eye(dim, dtype=dtype, device="cuda")

        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
        C3 = torch.matmul(A, B.t())
        C2 = bnb.matmul_4bit(A, qB.t(), state)
        A.requires_grad = True
        C1 = bnb.matmul_4bit(A, qB.t(), state)

        torch.testing.assert_close(A, C3)
        torch.testing.assert_close(A, C1)
        torch.testing.assert_close(A, C2)
        # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
        # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
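

# An identity B is a fixed point of 4-bit quantization: each block's absmax
# is 1, and both 0 and 1 are exact entries of the normalized codebooks, so
# dequantization reproduces B exactly and all three products must equal A up
# to float rounding.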


@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3"))
@pytest.mark.deprecated
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        n = A1.numel()
        assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002))


@pytest.mark.deprecated
def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001
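

# Quantile quantization builds its codebook from the tensor's own empirical
# quantiles, so each of the 256 bins receives roughly the same number of
# values; quantize_no_absmax/dequantize_no_absmax then do only the codebook
# lookup, with no absmax rescaling.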


@pytest.mark.deprecated
def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    print(sum(diffs) / len(diffs))
    print(sum(reldiffs) / len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004


@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
@pytest.mark.deprecated
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)

        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_close(clip1, clip2)
        torch.testing.assert_close(gnorm1, gnorm2)
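

# Percentile clipping keeps a rolling window of the last 100 gradient norms
# (stored squared in gnorm_vec2, hence the sqrt in the reference check) and
# rescales the gradient whenever its norm exceeds the percentile-th smallest
# entry. A compact torch-only sketch of the same rule (hypothetical helper):
def _percentile_clip_sketch(g: torch.Tensor, gnorm_history: torch.Tensor, step: int, percentile: int = 5):
    gnorm = torch.norm(g.float())
    if step == 1:
        gnorm_history[:] = gnorm  # warm-start the window with the first norm
    else:
        gnorm_history[step % 100] = gnorm
    clip = torch.sort(gnorm_history).values[percentile]
    gnorm_scale = 1.0 if gnorm < clip else (clip / gnorm).item()
    return gnorm, clip, gnorm_scale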


@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims"))
@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
@pytest.mark.deprecated
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
        elif dims == 3:
            A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)

        A.view(-1)[-1] = -1
        if transpose:
            At = A.t().contiguous()
            out1, S1 = F.nvidia_transform(At, to_order=orderOut)
        else:
            out1, S1 = F.nvidia_transform(A, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_close(out1, out2)


@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose"))
@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims"))
@pytest.mark.deprecated
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
    try:
        func = F.get_transform_func(dtype, orderA, orderOut, transpose)
    except ValueError as ve:
        pytest.skip(str(ve))  # skip if not supported

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_close(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_close(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32)))
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_close(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
        torch.testing.assert_close(A, out2)
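

# Layout summary for the checks above: "row" is plain row-major, "col" is its
# transpose, and "col32" pads the inner dimension up to the next multiple of
# 32; "col_turing" stores 32-column x 8-row tiles, which is why the expected
# size rounds rows up to a multiple of 8 and columns to a multiple of 32.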