import math
import random
import time
from itertools import product

import einops
import pytest
import torch
import numpy as np

import bitsandbytes as bnb
from bitsandbytes import functional as F
from scipy.stats import norm

torch.set_printoptions(
    precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20


def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
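    # count elements outside the atol/rtol tolerance and only fail once more
    # than `count` elements mismatch, deferring to assert_close for the message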
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        if throw:
            print(f"Too many values not close: assert {sumval} < {count}")
            torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

    return sumval


class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.fc1.weight)
            torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class Timer:
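    # small CUDA benchmarking helper: one cuda.Event pair per name, with the
    # elapsed milliseconds aggregated across tick()/tock() calls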
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
            self.starts[name].record()
        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")


def setup():
    pass


def teardown():
    pass


@pytest.mark.parametrize(
    "dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_estimate_quantiles(dtype):
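    # for U(0, 1) inputs the 256 estimated quantiles should match the evenly
    # spaced CDF grid; for N(0, 1) they are checked against torch.quantile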
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0


def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001


def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    print(sum(diffs)/len(diffs))
    print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
@pytest.mark.parametrize("signed", [True, False], ids=['signed_True', 'signed_False'])
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
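    # round-trip (quantize -> dequantize) error bounds for the default dynamic
    # code on normal data, then for a custom dynamic map on uniform data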
    #print('')
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
    abserr = sum(diffs)/len(diffs)
    relerr = sum(reldiffs)/len(reldiffs)
    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
    assert abserr < 0.011
    assert relerr < 0.018
    assert A2.dtype == dtype

    diffs = []
    code = F.create_dynamic_map(signed=signed)
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
    abserr = sum(diffs)/len(diffs)
    relerr = sum(reldiffs)/len(reldiffs)
    if signed:
        assert abserr < 0.0035
        assert relerr < 0.015
    else:
        assert abserr < 0.00175
        assert relerr < 0.012
    assert A2.dtype == dtype
    #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
    #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))


@pytest.mark.parametrize(
    "gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_percentile_clipping(gtype):
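    # compare the fused percentile_clipping kernel against a Python reference
    # that keeps a ring buffer of the last 100 gradient norms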
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec2, step, percentile=percentile
        )
        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)

        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_close(clip1, clip2)
        torch.testing.assert_close(gnorm1, gnorm2)


def quant(x):
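    # reference: symmetric (absmax) tensor-wise quantization to int8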
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    return C.float() * (maxA / 127) * (maxB / 127)


def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
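    # like quant_multi, but the absmax statistic is shared across chunks of the
    # tensor and broadcast back to the full shape of x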
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_minmax(A):
    minA = A.min()
    maxA = A.max()
    # returning the computed bounds (assumed intent; this helper is unused below)
    return minA, maxA


def mean(xx):
    return sum(xx) / float(len(xx))


# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
methods = [
    (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    )
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    "dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals)
    for vals in values_names
]


@pytest.mark.parametrize(
    "dim1, dim2, quant_methods, batched", values, ids=names
)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
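    # quantize A and B, multiply (F.igemm, or emulated via float bmm in the
    # batched case), dequantize, and track error against the fp32 product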
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    #print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_close(
            quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
        )
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    #print(mean(errors))
    #print(mean(relerrors))


def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
    "hidden_dim_{}_batch_dim_{}_transpose_{}_seq_dim_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (
            (batch_dim, hidden_dim)
            if not transpose[0]
            else (hidden_dim, batch_dim)
        )
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        elif transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.t().float(), B.float())
            out = F.igemm(A.t(), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.matmul(A.t().float(), B.t().float())
            out = F.igemm(A.t(), B.t())

        torch.testing.assert_close(out.float(), out2)

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())

        torch.testing.assert_close(out.float(), out2)


n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [
    "seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
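    # igemm with a preallocated int32 output, reducing over the batch and
    # sequence dimensions ("bsi,bso->io") against an einsum reference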
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        ).to(torch.int8)
        B = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
        ).to(torch.int8)
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(
            A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
        )
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_close(out.float(), out2)


n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
    "seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
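    # asymmetric (min/max) quantization of A requires an offset correction after
    # the int8 matmul; compare its error against plain vector-wise quantization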
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs = []
    relerrs = []
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(
            0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        )
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)

        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3


n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)

        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A, B.permute([0, 2, 1]))
        elif transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
            out = F.igemm(A.permute([0, 2, 1]), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.bmm(
                A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
            )
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_close(out.float(), out2.float())


n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        n = A1.numel()
        assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002))


n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
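    # layout transforms: "row"/"col" are exact permutation checks, while the
    # tiled "col32"/"col_turing" layouts are checked via their padded sizes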
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
    func = F.get_transform_func(dtype, orderA, orderOut, transpose)

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
            dtype
        )

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_close(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_close(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = (
                A.shape[0]
                * A.shape[1]
                * (A.shape[2] + (32 - (A.shape[2] % 32)))
            )
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
            A.shape[1] + (32 - (A.shape[1] % 32))
        )
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = (
                    (row // 8) + (1 if row % 8 != 0 else 0)
                ) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_close(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(
            out, from_order=orderOut, to_order="row", state=S
        )
        torch.testing.assert_close(A, out2)


n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()

# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]

dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
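    # int8 igemmlt in col32 x col_turing layouts against an fp32 matmul
    # reference, including the transposed-B path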
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
                torch.int8
            )
        elif dims == 3:
            A = torch.randint(
                -128, 127, size=(dim1, dim2, dim3), device="cuda"
            ).to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_close(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_close(C1, C3.float())


dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
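    # full fp16 pipeline: double_quant -> transform -> igemmlt -> mm_dequant,
    # compared only informally here (the assert below is commented out)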
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(
                0, 0.5, size=(dim1, dim2, dim3), device="cuda"
            ).half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        # print('')
        # print(output.flatten()[:10])
        # print(C1.flatten()[:10])
        # print(C2.flatten()[:10])

        # torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        # C1 = torch.matmul(A.float(), B.float())

        # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        # C3, S = F.transform(C2, 'row', state=SC)
        # torch.testing.assert_close(C1, C3.float())


batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
    (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
    (batch_size, seqdim, 5120, 3 * 5120),
    (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]


# values = list(product(batch, seq, model, hidden))
names = [
    "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")

    # torch.cuda.synchronize()
    ## warmup
    # for i in range(100):
    #    torch.matmul(A, w1.t())
    # torch.cuda.synchronize()

    dtype = torch.int8
    A = A.view(-1, A.shape[-1]).contiguous()
    grad = grad.view(-1, grad.shape[-1]).contiguous()
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):

        out1 = torch.matmul(A, w1.t())  # fc1
        # out2 = torch.matmul(out1, w2.t())# fc2

        # d1 = torch.matmul(grad, w2) # delta1
        # d2 = torch.matmul(d1, w1) # delta2

        # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
        # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1

    torch.cuda.synchronize()
    t16 = time.time() - t0
    print(t16)

    # torch.cuda.empty_cache()

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)

    # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    # C32A, SA = F.transform2(CA, 'col32')
    ## fc1
    # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)

    ## fc2
    # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    # C32out1, Sout1 = F.transform2(Cout1, 'col32')
    # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)

    ## delta1
    # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)

    ## delta2
    # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    # C32d1, Sd1 = F.transform2(Cd1, 'col32')
    ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)

    ## grad1
    # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)

    ## grad2
    # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)

    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(k):
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)

    #    #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
    #    CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    #    #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    #    #CTw2, Sw2 = F.transform2(Cw2, formatB)
    #    #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)

    #    C32A, SA = F.transform2(CA, 'col32')

    #    # fc1
    #    out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    #    #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

    #    #print(coo_tensor.nnz)
    #    #out1sp = F.spmm_coo(coo_tensor, w1.t())
    #    #print(w1.t().shape)
    #    #out1 = out1dn + out1sp

    #    # fc2
    #    Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    #    C32out1, Sout1 = F.transform2(Cout1, 'col32')
    #    out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    #    #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)

    #    # delta1
    #    Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    #    C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    #    d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    #    #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)

    #    # delta2
    #    Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    #    C32d1, Sd1 = F.transform2(Cd1, 'col32')
    #    d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    #    #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)

    #    # grad1
    #    #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    #    #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    #    #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    #    #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)

    #    ## grad2
    #    #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    #    #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    #    #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    #    #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)

    # torch.cuda.synchronize()
    # t8 = time.time() - t0
    # print(t8)


n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()

#dim1 = [2*1024]
#dim4 = [2*1024]

#dim1 = [4]
#dim4 = [4]

dims = (2,)
formatB = ["col_turing", "col_ampere"]
has_bias = [True, False]
values = list(product(dim1, dim4, dims, formatB, has_bias))
names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
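    # the fused F.mm_dequant output (with optional bias) should match the
    # Python vectorwise dequantization of the int32 igemmlt result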
    inner = torch.randint(1, 128, size=(1,)).item()
    bias = None
    if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16)
    formatB = F.get_special_format_str()
    for i in range(1):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())
        if has_bias: C1 += bias

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
        if has_bias: C4 += bias

        # TODO: is something wrong here? If so, the problem goes deeper
        #n = C1.numel()
        #p = 0.06
        std = C1.std(0).view(1, -1)
        C1 /= std
        C4 /= std
        #assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
        #assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
        #torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
        n = C5.numel()
        assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n))


n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
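    # row/col absmax statistics with an outlier threshold are checked against a
    # reference computed on a copy where |x| >= threshold is zeroed out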
    for i in range(k):
        threshold = 3.0
        A = torch.randn(dim1, dim2, device="cuda").half()
        A_truncated = A.clone()
        A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
        if dims == 2:
            row_stats1, _ = torch.abs(A.float()).max(1)
            col_stats1, _ = torch.abs(A.float()).max(0)
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
        else:
            assert False

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=threshold
        )

        A_blocked = einops.rearrange(
            torch.abs(A),
            "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
            row_tiles=16,
            block_size=64 * 4,
        )
        nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(
            nnz_rows1_counts.shape[0] + 1,
            dtype=nnz_rows1_counts.dtype,
            device=nnz_rows1_counts.device,
        )
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_close(col_stats1_trunc, col_stats2)
        torch.testing.assert_close(row_stats1_trunc, row_stats2)
        torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=0.0
        )

        torch.testing.assert_close(col_stats1, col_stats2)
        torch.testing.assert_close(row_stats1, row_stats2)
        assert nnz_block_ptr2 is None


n = 2
# dim1 = [8*1024]
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
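    # F.double_quant should match separate row-wise and column-wise
    # vectorwise_quant up to +/-1 from rounding, with few outliers allowed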
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (
            (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        )
        num_not_close_cols = (
            (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
        )

        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
            )
            assert False
        if num_not_close_rows > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
            )
            assert False

        torch.testing.assert_close(Srow.flatten().float(), statsA)
        torch.testing.assert_close(Scol.flatten().float(), statsAt)


n = 4
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
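    # the double_quant-based int8 path must not be more than ~2.5% worse than
    # the vectorwise-quantization reference path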
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        torch.testing.assert_close(maxA.flatten().float(), stats1a)
        torch.testing.assert_close(maxB.flatten().float(), stats2a)
        torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_close(C2a, B1, rtol=0, atol=1)

        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(C2a, "col_turing")
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
        assert err2 <= err1 * 1.025


n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)

        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))


dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]

values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
    formatB = F.get_special_format_str()  # was missing; nvidia_transform below needs it
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
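    # CUDA kernels launch asynchronously, so every timed region below is
    # bracketed by torch.cuda.synchronize(); the first k matmuls only warm up
    # the GPU and cuBLAS handles and are not timed.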
    torch.nn.init.xavier_uniform_(B)
    # warmup
    for i in range(k):
        C1 = torch.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("16", time.time() - t0)

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
    CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
    A2, SA = F.nvidia_transform(C1a, "col32")
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

    c = 10.0 * inner * scale
    row_scale = maxA / c
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
    torch.cuda.synchronize()
    print("row-wise", time.time() - t0)

    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
    torch.cuda.synchronize()
    print("vector-wise", time.time() - t0)


n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]

dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(
    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
    "dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
    values,
    ids=names,
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
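    # col32/col_turing/col_ampere are the tiled cuBLASLt memory layouts used by
    # igemmlt; F.transform is checked against the reference F.nvidia_transform
    # for identical output tensors and matching layout state (shape metadata in
    # S[0]).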
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
                dtype
            )
        elif dims == 3:
            A = torch.randint(
                10, 99, size=(dim1, dim2, dim3), device="cuda"
            ).to(dtype)

        A.view(-1)[-1] = -1
        if transpose:
            At = A.t().contiguous()
            out1, S1 = F.nvidia_transform(At, to_order=orderOut)
        else:
            out1, S1 = F.nvidia_transform(A, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_close(out1, out2)


n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]

dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
    "dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals)
    for vals in values
]


def test_overflow():
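    # int8 x int8 -> int8 igemmlt can saturate its output; this only exercises
    # the kernel against a float matmul without asserting on the results.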
    formatB = F.get_special_format_str()
    print(formatB)
    for i in range(2):
        a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)

        Ca, Sa = F.nvidia_transform(a, "col32")
        Cb, Sb = F.nvidia_transform(b, formatB)

        c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
        c2 = torch.matmul(a.float(), b.float().t())


n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]

values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
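    # double_quant with a threshold implements the mixed-precision
    # decomposition: entries with |A| >= threshold are returned as an fp16 COO
    # tensor and zeroed in the int8 matrix, so
    # A ~ coo_tensor (outliers) + CA * statsA / 127 (dense part).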
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )

        if coo_tensor is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A)
            A2[
                coo_tensor.rowidx.long(), coo_tensor.colidx.long()
            ] = coo_tensor.values
            torch.testing.assert_close(A1, A2)

            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_close(
                A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
            )


n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
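    # spmm_coo multiplies the sparse outlier part (COO layout) with a dense B;
    # the dense reference uses A2 = A * idx, i.e. A with all sub-threshold
    # entries zeroed out.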
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
            B = torch.randn(dim3, dim2).cuda().half()
        else:
            B = torch.randn(dim2, dim3).cuda().half()

        idx = torch.abs(A) >= threshold
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(
            A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
        )
        A2 = A * idx

        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
            out1 = torch.matmul(A2, B.t())
        else:
            out2 = F.spmm_coo(cooA, B)
            out1 = torch.matmul(A2, B)

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)


def test_spmm_bench():
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = bnb.matmul(A, B.t())
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )

    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time() - t0
    print(tsp, t8)
    print(tsp / t8)


n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
    threshold = 3.0
    formatB = "col_turing"
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        out1 = torch.matmul(A, w1.t())

        Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
        CTw1, Sw1 = F.transform(Cw1, formatB)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        assert coo_tensor is not None

        out4 = F.spmm_coo(coo_tensor, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1


def test_matmuls():
    a = torch.randn(256, 512).half().cuda()
    b = torch.randn(256, 512).half().cuda()
    c1 = torch.matmul(a, b.t())
    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul_cublas(a, b.t())

    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2
    print(err1, err2)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = [
    "dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
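    # Mismatch budget: p is the allowed fraction of outlier mismatches,
    # calibrated as 200 elements at the default 2048 x (12288*4) output size,
    # then scaled to the actual number of output elements n.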
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(
        out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
    )
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))

    # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)

    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #   #out3 = F.spmm_coo(cooA, Bt.t())
    #   #out2 = F.spmm_coo(cooA, B)
    #   #out2 = F.spmm_coo_very_sparse(cooA, B)
    #   #out1 = torch.matmul(A, Bt.t())

    # torch.cuda.synchronize()
    # print(time.time() - t0)


def test_coo2csr():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_close(A2[idx], csrA.values)


def test_coo2csc():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_close(A2.t()[idx], cscA.values)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    rowidx = torch.randint(0, A.shape[-1], size=(15,))

    A[:, rowidx] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)

    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(100):
    #   out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print('fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print("cusparse fp16", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print("int8", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print("int8+dequant", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print("matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1 + out2
    torch.cuda.synchronize()
    print("sparse+ matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)


batch_size = 1
seqdim = 1
values = []
#values.append((batch_size, seqdim, 768, 4 * 768))
#values.append((batch_size, seqdim, 1024, 4*1024))
#values.append((batch_size, seqdim, 1536, 4*1536))
#values.append((batch_size, seqdim, 2048, 4*2048))
#values.append((batch_size, seqdim, 2560, 4*2560))
#values.append((batch_size, seqdim, 4096, 4*4096))
#values.append((batch_size, seqdim, 5120, 4*5120))
values.append((batch_size, seqdim, 6656, 4*6656))
#values.append((batch_size, seqdim, 8192, 4*8192))
#values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
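    # Benchmarks the fp16 baseline against the 4-bit inference matmuls; most
    # variants (fp4, int8 paths, Linear8bitLt) are kept but commented out, so
    # only the pytorch fp16 and nf4 timings are printed by default.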
    iters = 1000
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device="cuda").half()
    B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    B_fp4, state = F.quantize_fp4(B)
    B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)

    B_nf4, state_nf4 = F.quantize_nf4(B)
    B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
    linear8bit.eval()

    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half())
    #linearMixedBit.eval()

    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)

    # warmup
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
    #torch.cuda.synchronize()
    #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
    #torch.cuda.synchronize()
    #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
    torch.cuda.synchronize()
    print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
    torch.cuda.synchronize()
    print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )


    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul(A, B)
    #torch.cuda.synchronize()
    #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul(A, B, threshold=6.0)
    #torch.cuda.synchronize()
    #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    #C32A, SA = F.transform(CA, "col32")
    #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    #CxB, SB = F.transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #torch.cuda.synchronize()
    #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #BA, statsB = F.vectorwise_quant(B, dim=1)
    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1)
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    #torch.cuda.synchronize()
    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    out = Cout * statsB * statsA * (1.0 / (127 * 127))
    #torch.cuda.synchronize()
    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linearMixedBit(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linearMixedBit(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit_train(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit_train(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit_train_thresh(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit_train_thresh(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

def test_zeropoint():
    def quant_zp(x):
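        # Asymmetric (zero-point) quantization: map [min(x), max(x)] onto
        # roughly [-127, 127] with scale qx = 254 / (max - min) and integer
        # offset zpx = round(min(x) * qx) - 127, so x_q = qx * x + zpx.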
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        minx = x.min()
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()

    C1 = torch.matmul(A, B)
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    print(ca.min(), ca.max())
    print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
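    # Zero-point expansion: (qa*A + zpa) @ (qb*B + zpb)
    #   = qa*qb*(A @ B) + qa*zpb*A.sum(1) + qb*zpa*B.sum(0) + zpa*zpb*K,
    # with K = A.shape[1], so A @ B is recovered by subtracting the three
    # correction terms and dividing by qa*qb (C6/C7 below).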
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb

    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

    print("")
    # print(C0.flatten()[:10])
    print(C1.flatten()[:10])
    print(C2.flatten()[:10])
    print(C3.flatten()[:10])
    print(C5.flatten()[:10])
    print(C6.flatten()[:10])
    print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)


def test_extract_outliers():
    for i in range(k):
        shapeA = (4096, 4096 * 4)
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # idx = torch.Tensor([0]).int().cuda()
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        CA, SA = F.transform(A, "col_turing")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)



def test_blockwise_cpu_large():
    diffs = []
    reldiffs = []
    batch = 128
    seq = 128
    for hidden in [128]:#, 14336]:
        for blocksize in [4096, 16384]:
            for i in range(2):
                A1 = torch.randn(batch, seq, hidden, device='cpu')
                t0 = time.time()
                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
                print(time.time() - t0)
                diff = torch.abs(A1 - A2)
                reldiff = diff / torch.abs(A1 + 1e-8)
                diffs.append(diff.mean().item())
                reldiffs.append(reldiff.mean().item())
                assert diffs[-1] < 0.011
            # print(sum(diffs)/len(diffs))
            # print(sum(reldiffs)/len(reldiffs))



def test_fp8_quant():
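    # Sweep the fp8 formats with e exponent bits and p = 7 - e precision bits
    # (plus 1 sign bit); each map is used as a blockwise quantization codebook
    # and compared on normal, uniform, and default (dynamic) inputs.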
    for e_bits in range(1, 7):
        p_bits = 7-e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(sum(abserr)/len(abserr))
        #print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.rand(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(sum(abserr)/len(abserr))
        #print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(3, sum(abserr)/len(abserr))
        #print(3, sum(relerr)/len(relerr))


def test_few_bit_quant():
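    # For each bit width, build a codebook (linear, fp8, dynamic, or quantile),
    # padded out to 256 entries (checked by the assert below), and compare
    # blockwise quantization against a brute-force nearest-code assignment
    # over the same codebook.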

    #print('')
    for bits in range(2, 9):
        #print('='*30, bits, '='*30)
        for method in ['linear', 'fp8', 'dynamic', 'quantile']:
            abserrs = []
            relerrs = []
            code = None
            if method == 'linear':
                code = F.create_linear_map(True, total_bits=bits).cuda()
            elif method == 'fp8':
                ebits = math.ceil(bits/2)
                pbits = bits-ebits-1
                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
            elif method == 'dynamic':
                code = F.create_dynamic_map(True, bits-0, bits).cuda()
            elif method == 'quantile':
                values = torch.randn(2048, 2048, device='cuda')
                code = F.create_quantile_map(values, bits).cuda()
            # for some data types we have no zero
            # for some data types we have one zero
            # for some data types we have two zeros
            assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}'
            #print(method, (code==0).sum())
            assert code.numel() == 256
            for i in range(10):

                values = torch.randn(1, 32, device='cuda')
                values /= values.abs().max()
                #values[values.abs() < 1e-6] += 1e-5

                q1 = []
                v1 = []
                for v in values[0]:
                    idx = torch.abs(v-code).argmin()
                    q1.append(idx.item())
                    v1.append(code[idx].item())

                q1 = torch.Tensor(q1).cuda()
                v1 = torch.Tensor(v1).cuda()

                q2, S2 = F.quantize_blockwise(values, code=code)
                v2 = F.dequantize_blockwise(q2, S2)

                idx = torch.isclose(q1.int(), q2.int())
                err2 = torch.abs(v2-values)
                abserrs.append(err2.mean().item())
                relerrs.append((err2/(1e-10+values).abs()).mean().item())
                if idx.sum():
                    # some weird cases
                    err1 = torch.abs(v1-values).mean()
                    #assert err2.mean() <= err1

                else:
                    torch.testing.assert_close(q1, q2)
            #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
    #assert False


def test_kbit_quantile_estimation():
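    # For N(0, 1) data the estimated quantiles should match the analytic normal
    # quantiles norm.ppf(p); the small offset keeps p away from 0 and 1, where
    # the quantile function diverges.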
    for i in range(100):
        data = torch.randn(1024, 1024, device='cuda')
        for bits in range(2, 9):
            p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
            err = torch.abs(val1-val2).mean()
            assert err < 0.038

    for i in range(100):
        data = torch.randn(1024, 1024, device='cuda')
        for bits in range(2, 4):
            total_values = 2**bits-1
            p = np.linspace(0, 1, 2*total_values+1)
            idx = np.arange(1, 2*total_values+1, 2)
            p = p[idx]
            offset = 1/(2*total_values)
            p = np.linspace(offset, 1-offset, total_values)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1)
            err = torch.abs(val1-val2).mean()
            assert err < 0.035


def test_bench_dequantization():
    a = torch.rand(1024, 1024, device='cuda').half()
    code = F.create_fp8_map(True, 3, 0, 4).cuda()
    qa, SA = F.quantize_blockwise(a, code=code)
    print(qa.max())

    max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
    #print(max_theoretical_mu)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        qa, SA = F.quantize_blockwise(a)
    torch.cuda.synchronize()
    #print((time.time()-t0)/1e6)



@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_fp4_quant(dtype):
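    # Build the reference FP4 (E2M1) decode table: 1 sign bit, 2 exponent bits,
    # 1 mantissa bit with bias 3; exp == 0 encodes the subnormals 0 and
    # +-0.0625.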
    vals = list(product([0, 1], repeat=4))

    code = {}
    for bits in vals:
        result = 0
        bias = 3
        sign, e1, e2, p1 = bits
        idx = sign*8 + e1*4 + e2*2 + p1*1
        sign = -1.0 if sign else 1.0
        exp = e1*2 + e2*1
        if exp == 0:
            # sub-normal
            if p1 == 0: result = 0
            else: result = sign*0.0625
        else:
            # normal
            exp = 2**(-exp + bias + 1)
            frac = 1.5 if p1 else 1.0
            result = sign*exp*frac
        code[idx] = result

    A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype)
    qa, SA = F.quantize_fp4(A1, blocksize=64)
    A2 = F.dequantize_fp4(qa, SA)

    err = (A1 - A2).abs().float()
    relerr = (err/(A1.abs().float()+1e-8)).mean()
    idx = err > 1.0
    err = err.mean()

    assert A2.dtype == dtype
    assert err.item() < 0.1
    assert relerr.item() < 0.28


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
def test_4bit_compressed_stats(quant_type):
    for blocksize in [128, 64]:
        errs1 = []
        errs2 = []
        for i in range(10):
            A1 = torch.randn(1024, 1024, device='cuda').half()
            q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
            q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
            A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
            A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)


            err = (A1 - A2).abs().float()
            relerr = (err/(A1.abs().float()+1e-15)).mean()
            err = err.mean()

            errs1.append(err.item())


            assert err.item() < 0.11
            assert relerr.item() < 0.28

            err = (A1 - A3).abs().float()
            relerr = (err/(A1.abs().float()+1e-15)).mean()
            err = err.mean()

            errs2.append(err.item())

            assert err.item() < 0.11
            assert relerr.item() < 0.28

        #print(sum(errs1)/len(errs1), blocksize, quant_type)
        #print(sum(errs2)/len(errs2), blocksize, quant_type)




@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ['nf4'])
def test_bench_4bit_dequant(quant_type):
    blocksize = 256
    a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
    qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
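    # Rough bandwidth model: dequantization reads numel/2 bytes of packed 4-bit
    # weights and writes numel*2 bytes of fp16, so the time floor is
    # num_bytes / bandwidth; the 768 (GB/s) below is an assumed device
    # bandwidth figure, not a measured one.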

    input_size = a.numel()/2
    output_size = a.numel()*2
    num_bytes = input_size+output_size
    GB = num_bytes/1e9
    max_theoretical_s = GB/768
    #print(max_theoretical_s*1e6)
    b = torch.randn(128, 1024*12, device='cuda').half()

    iters = 100
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
        #b.copy_(a)
    torch.cuda.synchronize()
    #print((time.time()-t0)/iters*1e6)

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    torch.matmul(b, a.t())
    #torch.cuda.synchronize()
    #print((time.time()-t0)/iters*1e6)



def test_normal_map_tree():
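    # Walk the implicit binary search tree over the 16 NF4 code values: at each
    # level the pivots are midpoints between adjacent code values, doubling the
    # pivot count per level until every leaf (code value) is separated.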
    code = F.create_normal_map()
    values = code[:8].tolist() + code[-8:].tolist()
    num_pivots = 1
    print(values)
    while num_pivots < 16:
        idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
        print(idx)
        num_pivots *= 2
        pivots = []
        for i in idx:
            pivots.append((values[i-1]+values[i])/2)
        print(pivots)


@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
def test_gemv_4bit(dtype, storage_type, double_quant, kind):
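    # Compares three paths per dim/kind: C1 = bnb.matmul_4bit (training path,
    # A requires grad), C2 = F.gemv_4bit (inference kernel), and C3 = fp16
    # torch.matmul as reference; mean/relative/max errors are aggregated over
    # 100 draws and normalized by sqrt(dim).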
    for dim in [128, 256, 512, 1024]:
    #for dim in [4*1024]:
    #for dim in [1*16]:
        errs1 = []
        errs2 = []
        errs3 = []
        relerrs1 = []
        relerrs2 = []
        relerrs3 = []
        max_errs1 = []
        max_errs2 = []
        max_errs3 = []


        for i in range(100):
            if kind == 'fc1':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            elif kind == 'fc2':
                A = torch.randn(1, 4*dim, dtype=dtype, device='cuda')
                B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            elif kind == 'attn':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            elif kind == 'attn_packed':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
            C3 = torch.matmul(A, B.t())
            C2 = F.gemv_4bit(A, qB.t(), quant_state=state)
            A.requires_grad = True
            C1 = bnb.matmul_4bit(A, qB.t(), state)

            err1 = (C1-C2).abs().float()
            err2 = (C3-C2).abs().float()
            err3 = (C3-C1).abs().float()

            mag1 = torch.abs(C1).float()+1e-5
            mag2 = torch.abs(C3).float()+1e-5
            mag3 = torch.abs(C3).float()+1e-5

            relerr1 = err1/mag1
            relerr2 = err2/mag2
            relerr3 = err3/mag3

            max_err1 = err1.max()
            max_err2 = err2.max()
            max_err3 = err3.max()

            errs1.append(err1.mean().item())
            errs2.append(err2.mean().item())
            errs3.append(err3.mean().item())

            relerrs1.append(relerr1.mean().item())
            relerrs2.append(relerr2.mean().item())
            relerrs3.append(relerr3.mean().item())

            max_errs1.append(max_err1.item())
            max_errs2.append(max_err2.item())
            max_errs3.append(max_err3.item())

            c = int(C1.numel()*0.0014*(dim/256))+1

            c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
        err1 = sum(errs1)/len(errs1)/math.sqrt(dim)
        err2 = sum(errs2)/len(errs2)/math.sqrt(dim)
        err3 = sum(errs3)/len(errs3)/math.sqrt(dim)
        relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim)
        relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim)
        relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim)
        maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim)
        maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim)
        maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim)
        absratio = err2/err3
        relratio = relerr2/relerr3
        maxratio = maxerr2/maxerr3

        # for debugging if the test fails
        #
        #print('='*80)
        #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
        print(C1.flatten()[-20:])
        print(C2.flatten()[-20:])
        print(f'inference vs training abs: {err1}')
        print(f'inference vs training rel: {relerr1}')
        print(f'inference vs training max: {maxerr1}')
        #print(f'inference vs training vs torch err ratio abs: {absratio}')
        #print(f'inference vs training vs torch err ratio rel: {relratio}')
        #print(f'inference vs training vs torch err ratio max: {maxratio}')
        if dtype == torch.float16:
            if dim <= 512:
                assert err1 < 7e-5
                assert relerr1 < 0.0008
            else:
                assert err1 < 6e-5
                assert relerr1 < 2e-4
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.float32:
            if dim <= 512:
                assert err1 < 5e-8
                assert relerr1 < 1e-6
                assert maxerr1 < 1e-7
            else:
                assert err1 < 5e-8
                assert relerr1 < 8e-6
                assert maxerr1 < 1e-7
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.bfloat16:
            if dim <= 512:
                assert err1 < 6e-4
                assert relerr1 < 0.007
                assert maxerr1 < 0.015
            else:
                assert err1 < 2e-4
                assert relerr1 < 0.002
                assert maxerr1 < 0.0012
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.04 and relratio > 0.96
            assert maxratio < 1.02 and maxratio > 0.98

@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
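    # get_paged allocates CUDA unified ("paged") memory that the driver can
    # migrate between host and device on demand; fill and _mul operate on these
    # paged tensors in place.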
    n = 32*10
    A = F.get_paged(n, n, dtype=torch.float32)
    B = F.get_paged(n, n, dtype=torch.uint8)
    B2 = F.get_paged(n, n, dtype=torch.float32)
    assert A.is_paged
    assert B.is_paged
    assert A.page_deviceid==0
    assert B.page_deviceid==0
    F.fill(A, 17.0)
    F.fill(B, 17)
    F.fill(B2, 2)
    assert (A==17).sum().item() == n*n
    assert (B==17).sum().item() == n*n
    C = A*B.float()
    assert (C==289).sum().item() == n*n
    F._mul(A, B2)
    F._mul(A, B2)
    F._mul(A, B2)
    assert (A==17*(2**3)).sum().item() == n*n
   # F.prefetch_tensor(A)
   # F.prefetch_tensor(B)


   # F.fill(B2, 17.0)
   # F._mul(A, B2)

   # F.prefetch_tensor(A, to_cpu=True)
   # F.prefetch_tensor(B, to_cpu=True)
   # F.prefetch_tensor(B2, to_cpu=True)
   # torch.cuda.synchronize()

   # assert (A==17).sum().item() == n*n

   # torch.testing.assert_close(A, torch.ones(A.shape)*289)


@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True'])
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
    dims = 10
    torch.random.manual_seed(np.random.randint(0, 412424242))
    dims = torch.randint(0, 8192, size=(dims,)).tolist()
    dims = [dim + (64-(dim % 64)) for dim in dims]
    #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
    for dim in dims:
        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda')
        B = torch.eye(dim, dtype=dtype, device='cuda')

        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
        C3 = torch.matmul(A, B.t())
        C2 = bnb.matmul_4bit(A, qB.t(), state)
        A.requires_grad = True
        C1 = bnb.matmul_4bit(A, qB.t(), state)

        torch.testing.assert_close(A, C3)
        torch.testing.assert_close(A, C1)
        torch.testing.assert_close(A, C2)
        #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
        #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)