import math
import random
import time
from itertools import product

import einops
import pytest
import torch
import numpy as np

import bitsandbytes as bnb
from bitsandbytes import functional as F
from scipy.stats import norm

torch.set_printoptions(
    precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20


def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0):
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_allclose(a, b, rtol, atol)
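

# Illustrative sketch (not part of the original suite): assert_all_approx_close
# tolerates up to `count` mismatching elements before falling back to the strict
# torch.testing.assert_allclose, which raises with a full report.
def _demo_assert_all_approx_close():
    a = torch.randn(16, device="cuda")
    b = a.clone()
    b[0] += 1.0  # exactly one element beyond the tolerances
    assert_all_approx_close(a, b, count=1)  # passes: one mismatch is tolerated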


class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.fc1.weight)
            torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class Timer:
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
            self.starts[name].record()
        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")


def setup():
    pass


def teardown():
    pass


@pytest.mark.parametrize(
    "dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0


def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001
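

# Illustrative sketch (not part of the original suite): quantize_no_absmax is
# codebook quantization. Conceptually each value maps to its nearest entry in
# the sorted 256-value `code`, and dequantization is a table lookup.
def _demo_codebook_roundtrip(x, code):
    # nearest-entry lookup; O(numel * 256) memory, so keep x small
    idx = torch.argmin(torch.abs(x.unsqueeze(-1) - code), dim=-1)
    return code[idx]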


def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004



@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
def test_dynamic_blockwise_quantization(nested, blocksize):
    #print('')
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
    abserr = sum(diffs)/len(diffs)
    relerr = sum(reldiffs)/len(reldiffs)
    assert abserr < 0.011
    assert relerr < 0.018
    print('nested=', nested, 'randn', blocksize, abserr)
    print('nested=', nested, 'randn', blocksize, relerr)

    diffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        #torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
    abserr = sum(diffs)/len(diffs)
    relerr = sum(reldiffs)/len(reldiffs)
    assert abserr < 0.0035
    assert relerr < 0.015
    print('nested=', nested, 'rand', blocksize, abserr)
    print('nested=', nested, 'rand', blocksize, relerr)
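

# Illustrative sketch (not part of the original suite): the core idea behind
# blockwise quantization, in pure PyTorch. Each contiguous block of `blocksize`
# values gets its own absmax scale, so an outlier only degrades precision inside
# its own block. F.quantize_blockwise additionally maps the scaled values through
# a nonlinear dynamic code; this plain linear int8 variant is a simplification.
def _demo_blockwise_absmax(x, blocksize=4096):
    blocks = x.flatten().view(-1, blocksize)  # assumes numel divisible by blocksize
    absmax = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    q = torch.round(blocks / absmax * 127).to(torch.int8)
    return (q.float() / 127 * absmax).view(x.shape)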


def test_dynamic_blockwise_stochastic_quantization():
    diffs = []
    reldiffs = []
    rand = torch.rand(1024).cuda()
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C1, S1 = F.quantize_blockwise(A1, rand=rand)
        C2, S2 = F.quantize_blockwise(A1)
        # a maximum distance of 1 between quantized values
        torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
        fraction_smaller = (C1 < C2).float().sum() / C1.numel()
        fraction_larger = (C1 > C2).float().sum() / C1.numel()
        torch.testing.assert_allclose(
            fraction_larger, fraction_smaller, atol=0.01, rtol=0
        )


@pytest.mark.parametrize(
    "gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec2, step, percentile=percentile
        )
        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)

        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_allclose(clip1, clip2)
        torch.testing.assert_allclose(gnorm1, gnorm2)


def quant(x):
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    return C.float() * (maxA / 127) * (maxB / 127)
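

# Illustrative sketch (not part of the original suite): round-tripping through
# quant/dequant above is lossy; for this symmetric absmax scheme the per-element
# error is bounded by max|x| / (2 * 127).
def _demo_quant_roundtrip():
    x = torch.randn(1024, device="cuda")
    max1, x8 = quant(x)
    x_hat = dequant(x8, max1)
    assert torch.abs(x - x_hat).max() <= max1 / 254 + 1e-6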


def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_minmax(A):
    minA = A.min()
    maxA = A.max()
    return minA, maxA


def mean(xx):
    return sum(xx) / float(len(xx))


# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
methods = [
    (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    )
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
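# each entry is a 5-tuple: (quantize A, quantize B, dequantize A, dequantize B,
# dequantize matmul output); test_approx_igemm below indexes these positionally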
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    "dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals)
    for vals in values_names
]


@pytest.mark.parametrize(
    "dim1, dim2, quant_methods, batched", values, ids=names
)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_allclose(
            quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
        )
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    print(mean(errors))
    print(mean(relerrors))


def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
    "hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (
            (batch_dim, hidden_dim)
            if not transpose[0]
            else (hidden_dim, batch_dim)
        )
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        elif transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.t().float(), B.float())
            out = F.igemm(A.t(), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.matmul(A.t().float(), B.t().float())
            out = F.igemm(A.t(), B.t())

        torch.testing.assert_allclose(out.float(), out2)

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())

        torch.testing.assert_allclose(out.float(), out2)


n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [
    "seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        ).to(torch.int8)
        B = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
        ).to(torch.int8)
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(
            A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
        )
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_allclose(out.float(), out2)


n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
    "seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale
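
    # min_max above is an asymmetric scheme: values are shifted by (minA + scale)
    # so they are centred before rounding to int8; the shift is compensated after
    # the integer matmul via the `offset` terms computed below.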

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs = []
    relerrs = []
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(
            0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        )
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)

        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3


n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)

        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A, B.permute([0, 2, 1]))
        elif transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
            out = F.igemm(A.permute([0, 2, 1]), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.bmm(
                A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
            )
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_allclose(out.float(), out2.float())


n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        n = A1.numel()
        assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002))




n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))

names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
    func = F.get_transform_func(dtype, orderA, orderOut, transpose)

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
            dtype
        )

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_allclose(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_allclose(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
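        # col32 pads the last dimension to 32-column tiles; the expected element
        # counts below assume that dimension is not already a multiple of 32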
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = (
                A.shape[0]
                * A.shape[1]
                * (A.shape[2] + (32 - (A.shape[2] % 32)))
            )
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
            A.shape[1] + (32 - (A.shape[1] % 32))
        )
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = (
                    (row // 8) + (1 if row % 8 != 0 else 0)
                ) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(
            out, from_order=orderOut, to_order="row", state=S
        )
        torch.testing.assert_allclose(A, out2)


n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()

# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]

dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
                torch.int8
            )
        elif dims == 3:
            A = torch.randint(
                -128, 127, size=(dim1, dim2, dim3), device="cuda"
            ).to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())
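

# Recap of the igemmlt pipeline exercised above (not part of the original suite):
# quantize to int8, F.transform A to "col32" and B to "col_turing", multiply with
# F.igemmlt (int32 accumulation), then F.nvidia_transform the result back to
# "row" order before comparing against the float matmul.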


dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(
                0, 0.5, size=(dim1, dim2, dim3), device="cuda"
            ).half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        # print('')
        # print(output.flatten()[:10])
        # print(C1.flatten()[:10])
        # print(C2.flatten()[:10])

        # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        # C1 = torch.matmul(A.float(), B.float())

        # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        # C3, S = F.transform(C2, 'row', state=SC)
        # torch.testing.assert_allclose(C1, C3.float())


batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
    (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
    (batch_size, seqdim, 5120, 3 * 5120),
    (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]


# values = list(product(batch, seq, model, hidden))
names = [
    "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")

    # torch.cuda.synchronize()
    ## warmup
    # for i in range(100):
    #    torch.matmul(A, w1.t())
    # torch.cuda.synchronize()

    dtype = torch.int8
    A = A.view(-1, A.shape[-1]).contiguous()
    grad = grad.view(-1, grad.shape[-1]).contiguous()
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):

        out1 = torch.matmul(A, w1.t())  # fc1
        # out2 = torch.matmul(out1, w2.t())# fc2

        # d1 = torch.matmul(grad, w2) # delta1
        # d2 = torch.matmul(d1, w1) # delta2

        # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
        # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1

    torch.cuda.synchronize()
    t16 = time.time() - t0
    print(t16)

    # torch.cuda.empty_cache()

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)

    # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    # C32A, SA = F.transform2(CA, 'col32')
    ## fc1
    # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)

    ## fc2
    # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    # C32out1, Sout1 = F.transform2(Cout1, 'col32')
    # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)

    ## delta1
    # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)

    ## delta2
    # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    # C32d1, Sd1 = F.transform2(Cd1, 'col32')
    ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)

    ## grad1
    # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)

    ## grad2
    # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)

    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(k):
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)

    #    #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
    #    CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    #    #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    #    #CTw2, Sw2 = F.transform2(Cw2, formatB)
    #    #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)

    #    C32A, SA = F.transform2(CA, 'col32')

    #    # fc1
    #    out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    #    #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

    #    #print(coo_tensor.nnz)
    #    #out1sp = F.spmm_coo(coo_tensor, w1.t())
    #    #print(w1.t().shape)
    #    #out1 = out1dn + out1sp

    #    # fc2
    #    Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    #    C32out1, Sout1 = F.transform2(Cout1, 'col32')
    #    out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    #    #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)

    #    # delta1
    #    Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    #    C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    #    d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    #    #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)

    #    # delta2
    #    Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    #    C32d1, Sd1 = F.transform2(Cd1, 'col32')
    #    d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    #    #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)

    #    # grad1
    #    #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    #    #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    #    #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    #    #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)

    #    ## grad2
    #    #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    #    #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    #    #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    #    #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)

    # torch.cuda.synchronize()
    # t8 = time.time() - t0
    # print(t8)


n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()

#dim1 = [2*1024]
#dim4 = [2*1024]

#dim1 = [4]
#dim4 = [4]

dims = (2,)
formatB = ["col_turing", "col_ampere"]
has_bias = [True, False]
values = list(product(dim1, dim4, dims, formatB, has_bias))
names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
    inner = torch.randint(1, 128, size=(1,)).item()
    bias = None
    if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16)
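    # note: the parametrized formatB is overridden with the device-specific format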
    formatB = F.get_special_format_str()
    for i in range(1):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())
        if has_bias: C1 += bias

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
        if has_bias: C4 += bias

        # TODO: is something wrong here? If so, the problem goes deeper
        #n = C1.numel()
        #p = 0.06
        std = C1.std(0).view(1, -1)
        C1 /= std
        C4 /= std
        #assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
        #assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
        #torch.testing.assert_allclose(C5, C4, atol=0.015, rtol=0.1)
        n = C5.numel()
        assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n))


n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
    for i in range(k):
        threshold = 3.0
        A = torch.randn(dim1, dim2, device="cuda").half()
        A_truncated = A.clone()
        A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
        if dims == 2:
            row_stats1, _ = torch.abs(A.float()).max(1)
            col_stats1, _ = torch.abs(A.float()).max(0)
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
        else:
            assert False

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=threshold
        )

        A_blocked = einops.rearrange(
            torch.abs(A),
            "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
            row_tiles=16,
            block_size=64 * 4,
        )
        nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(
            nnz_rows1_counts.shape[0] + 1,
            dtype=nnz_rows1_counts.dtype,
            device=nnz_rows1_counts.device,
        )
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
        torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
        torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=0.0
        )

        torch.testing.assert_allclose(col_stats1, col_stats2)
        torch.testing.assert_allclose(row_stats1, row_stats2)
        assert nnz_block_ptr2 is None


n = 2
# dim1 = [8*1024]
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (
            (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        )
        num_not_close_cols = (
            (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
        )

        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
            )
            assert False
        if num_not_close_rows > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
            )
            assert False

        torch.testing.assert_allclose(Srow.flatten(), statsA)
        torch.testing.assert_allclose(Scol.flatten(), statsAt)
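

# Summary (not part of the original suite): F.double_quant returns both the
# row-wise (CA, statsA) and column-wise (CAt, statsAt) int8 quantizations in a
# single pass; the test above checks each against F.vectorwise_quant along the
# matching dimension, tolerating off-by-one rounding differences.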


n = 4
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        torch.testing.assert_allclose(maxA.flatten(), stats1a)
        torch.testing.assert_allclose(maxB.flatten(), stats2a)
        torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)

        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(C2a, "col_turing")
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
        assert err2 <= err1 * 1.025


n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)

        # print(out1)
        # print(out2)
        # print(out3)
        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))


dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]

values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    formatB = F.get_special_format_str()
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)
    # warmup
    for i in range(k):
        C1 = torch.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("16", time.time() - t0)

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
    CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
    A2, SA = F.nvidia_transform(C1a, "col32")
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

    c = 10.0 * inner * scale
    row_scale = maxA / c
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
    torch.cuda.synchronize()
    print("row-wise", time.time() - t0)

    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
    torch.cuda.synchronize()
    print("vector-wise", time.time() - t0)


n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]

dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(
    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
    "dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
    values,
    ids=names,
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
                dtype
            )
        elif dims == 3:
            A = torch.randint(
                10, 99, size=(dim1, dim2, dim3), device="cuda"
            ).to(dtype)

        A.view(-1)[-1] = -1
        if transpose:
            At = A.t().contiguous()
            out1, S1 = F.nvidia_transform(At, to_order=orderOut)
        else:
            out1, S1 = F.nvidia_transform(A, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_allclose(out1, out2)


n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]

dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
    "dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals)
    for vals in values
]


def test_overflow():
    formatB = F.get_special_format_str()
    print(formatB)
    for i in range(2):
        a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)

        Ca, Sa = F.nvidia_transform(a, "col32")
        Cb, Sb = F.nvidia_transform(b, formatB)

        c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
        c2 = torch.matmul(a.float(), b.float().t())


n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]

values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )

        if coo_tensor is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A)
            A2[
                coo_tensor.rowidx.long(), coo_tensor.colidx.long()
            ] = coo_tensor.values
            torch.testing.assert_allclose(A1, A2)

            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_allclose(
                A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
            )


n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
            B = torch.randn(dim3, dim2).cuda().half()
        else:
            B = torch.randn(dim2, dim3).cuda().half()

        idx = torch.abs(A) >= threshold
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(
            A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
        )
        A2 = A * idx

        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
            out1 = torch.matmul(A2, B.t())
        else:
            out2 = F.spmm_coo(cooA, B)
            out1 = torch.matmul(A2, B)

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
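

# Illustrative reference, not a bitsandbytes API: under the conventions of
# test_spmm_coo above, F.spmm_coo is expected to agree with scattering the
# COO triplets into a dense matrix and running a plain matmul.
def _reference_spmm_coo_dense(cooA, B):
    # cooA carries (rows, cols, rowidx, colidx, values) as constructed with
    # F.COOSparseTensor in the tests above.
    A = torch.zeros(cooA.rows, cooA.cols, dtype=B.dtype, device=B.device)
    A[cooA.rowidx.long(), cooA.colidx.long()] = cooA.values
    return torch.matmul(A, B)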


def test_spmm_bench():
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = bnb.matmul(A, B.t())
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )

    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time() - t0
    print(tsp, t8)
    print(tsp / t8)


n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
    threshold = 3.0
    formatB = "col_turing"
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        out1 = torch.matmul(A, w1.t())

        Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
        CTw1, Sw1 = F.transform(Cw1, formatB)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        assert coo_tensor is not None

        out4 = F.spmm_coo(coo_tensor, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1


def test_matmuls():
    a = torch.randn(256, 512).half().cuda()
    b = torch.randn(256, 512).half().cuda()
    c1 = torch.matmul(a, b.t())
    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul_cublas(a, b.t())

    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2
    print(err1, err2)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = [
    "dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
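    # The very-sparse kernel is not bitwise deterministic, so the check below
    # uses a mismatch budget instead of exact equality: p corresponds to
    # roughly 200 tolerated outliers for a [2048, 12288*4] output, rescaled
    # here to the actual number of elements.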
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(
        out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
    )
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))

    # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)

    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #   #out3 = F.spmm_coo(cooA, Bt.t())
    #   #out2 = F.spmm_coo(cooA, B)
    #   #out2 = F.spmm_coo_very_sparse(cooA, B)
    #   #out1 = torch.matmul(A, Bt.t())

    # torch.cuda.synchronize()
    # print(time.time() - t0)


def test_coo2csr():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    csrA = F.coo2csr(cooA)
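    # CSR keeps one offset per row: rowptr[i+1] - rowptr[i] is the number of
    # nonzeros in row i, which is compared against the dense count below.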
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_allclose(A2[idx], csrA.values)


def test_coo2csc():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_allclose(A2.t()[idx], cscA.values)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    rowidx = torch.randint(0, A.shape[-1], size=(15,))

    A[:, rowidx] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)

    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(100):
    #   out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print('fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print("cusparse fp16", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print("int8", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print("int8+dequant", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print("matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1 + out2
    torch.cuda.synchronize()
    print("sparse + matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print("bnb matmul", time.time() - t0)


batch_size = 2
seqdim = 2048
values = []
values.append((batch_size, seqdim, 768, 4 * 768))
values.append((batch_size, seqdim, 1024, 4 * 1024))
values.append((batch_size, seqdim, 1536, 4 * 1536))
values.append((batch_size, seqdim, 2048, 4 * 2048))
values.append((batch_size, seqdim, 2560, 4 * 2560))
values.append((batch_size, seqdim, 4096, 4 * 4096))
values.append((batch_size, seqdim, 5140, 4 * 5140))
values.append((batch_size, seqdim, 12288, 4 * 12288))
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
    iters = 1
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device="cuda").half()
    B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    B_fp4, state = F.quantize_fp4(B)
    B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)

    B_nf4, state_nf4 = F.quantize_nf4(B)
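    # fp4 and nf4 share the same packed 4-bit storage; they differ only in the
    # 16-entry codebook (fp4 is a tiny e2m1 float grid, nf4 uses quantiles of
    # a standard normal), so the 4-bit timings below mostly compare codebooks.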

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit.eval()

    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    linearMixedBit.eval()

    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()

    # warmup
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print(f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
    torch.cuda.synchronize()
    print(f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_fp4_c.t(), quant_state=state_c)
    torch.cuda.synchronize()
    print(f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
    torch.cuda.synchronize()
    print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul(A, B)
    #torch.cuda.synchronize()
    #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul(A, B, threshold=6.0)
    #torch.cuda.synchronize()
    #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    #C32A, SA = F.transform(CA, "col32")
    #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    #CxB, SB = F.transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #torch.cuda.synchronize()
    #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #BA, statsB = F.vectorwise_quant(B, dim=1)
    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1)
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    #torch.cuda.synchronize()
    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    out = Cout * statsB * statsA * (1.0 / (127 * 127))
    #torch.cuda.synchronize()
    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linearMixedBit(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linearMixedBit(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit_train(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit_train(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit_train_thresh(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit_train(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

def test_zeropoint():
    def quant_zp(x):
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        minx = x.min()
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx
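
    # quant_zp maps x onto an 8-bit grid affinely: x_q = qx * x + zpx with
    # scale qx = 254 / (max - min) and a zero point derived from the minimum;
    # the variants below then remove the zero-point terms analytically instead
    # of dequantizing first.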

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()

    C1 = torch.matmul(A, B)
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    print(ca.min(), ca.max())
    print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb
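    # Why this recovers A @ B: with Aq = qa*A + zpa and Bq = qb*B + zpb,
    #   (Aq @ Bq)[i, j] = qa*qb*(A@B)[i, j] + qa*zpb*A.sum(1)[i]
    #                   + qb*zpa*B.sum(0)[j] + zpa*zpb*K,  where K = A.shape[1],
    # so subtracting the two rank-1 correction terms and the constant, then
    # dividing by qa*qb, isolates A @ B exactly.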

    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

    print("")
    # print(C0.flatten()[:10])
    print(C1.flatten()[:10])
    print(C2.flatten()[:10])
    print(C3.flatten()[:10])
    print(C5.flatten()[:10])
    print(C6.flatten()[:10])
    print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)


def test_extract_outliers():
    for i in range(k):
        shapeA = (4096, 4096 * 4)
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # idx = torch.Tensor([0]).int().cuda()
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        CA, SA = F.transform(A, "col_turing")
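        # col_turing is a tiled GPU layout; extract_outliers has to gather
        # whole columns of the row-major A out of that tiled representation,
        # so the result is checked against plain column indexing above.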

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_allclose(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_allclose(outliers1, outliers2)



def test_blockwise_cpu_large():
    diffs = []
    reldiffs = []
    batch = 128
    seq = 128
    for hidden in [128]:  # , 14336]:
        for blocksize in [4096, 16384]:
            for i in range(2):
                A1 = torch.randn(batch, seq, hidden, device='cpu')
                t0 = time.time()
                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
                print(time.time() - t0)
                diff = torch.abs(A1 - A2)
                reldiff = diff / torch.abs(A1 + 1e-8)
                diffs.append(diff.mean().item())
                reldiffs.append(reldiff.mean().item())
                assert diffs[-1] < 0.011
            # print(sum(diffs)/len(diffs))
            # print(sum(reldiffs)/len(reldiffs))
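

# A minimal sketch (an assumption for illustration, not the library's kernel)
# of the blockwise absmax quantization semantics that quantize_blockwise /
# dequantize_blockwise approximate with a nonlinear 8-bit codebook.
def _reference_blockwise_quant(x, blocksize=4096):
    # Flatten, pad to whole blocks, and scale each block by its absolute
    # maximum so every block maps onto the symmetric int8 range [-127, 127].
    flat = x.flatten().float()
    pad = (-flat.numel()) % blocksize
    flat = torch.nn.functional.pad(flat, (0, pad))
    blocks = flat.view(-1, blocksize)
    absmax = blocks.abs().max(dim=1, keepdim=True).values.clamp(min=1e-12)
    q = torch.round(blocks / absmax * 127).to(torch.int8)
    return q, absmax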



def test_fp8_quant():
    for e_bits in range(1, 7):
        p_bits = 7-e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()
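        # One sign bit plus e_bits exponent bits plus p_bits mantissa bits
        # fills the 8-bit code: e_bits + p_bits = 7 in every configuration.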

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(sum(abserr)/len(abserr))
        #print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.rand(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(sum(abserr)/len(abserr))
        #print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(3, sum(abserr)/len(abserr))
        #print(3, sum(relerr)/len(relerr))


def test_few_bit_quant():

    #print('')
    for bits in range(2, 9):
        #print('='*30, bits, '='*30)
        for method in ['linear', 'fp8', 'dynamic', 'quantile']:
            abserrs = []
            relerrs = []
            code = None
            if method == 'linear':
                code = F.create_linear_map(True, total_bits=bits).cuda()
            elif method == 'fp8':
                ebits = math.ceil(bits/2)
                pbits = bits-ebits-1
                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
            elif method == 'dynamic':
                code = F.create_dynamic_map(True, bits-0, bits).cuda()
            elif method == 'quantile':
                values = torch.randn(2048, 2048, device='cuda')
                code = F.create_quantile_map(values, bits).cuda()
            # for some data types we have no zero
            # for some data types we have one zero
            # for some data types we have two zeros
            assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}'
            #print(method, (code==0).sum())
            assert code.numel() == 256
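            # Codebooks are padded to 256 entries so that k-bit codes can live
            # in ordinary 8-bit storage; only 2**bits (or 2**bits - 1, see the
            # assert above) of those entries are distinct.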
            for i in range(10):

                values = torch.randn(1, 32, device='cuda')
                values /= values.abs().max()
                #values[values.abs() < 1e-6] += 1e-5

                q1 = []
                v1 = []
                for v in values[0]:
                    idx = torch.abs(v-code).argmin()
                    q1.append(idx.item())
                    v1.append(code[idx].item())

                q1 = torch.Tensor(q1).cuda()
                v1 = torch.Tensor(v1).cuda()

                q2, S2 = F.quantize_blockwise(values, code=code)
                v2 = F.dequantize_blockwise(q2, S2)

                idx = torch.isclose(q1.int(), q2.int())
                err2 = torch.abs(v2-values)
                abserrs.append(err2.mean().item())
                relerrs.append((err2/(1e-10+values).abs()).mean().item())
                if idx.sum():
                    # some weird cases
                    err1 = torch.abs(v1-values).mean()
                    #assert err2.mean() <= err1

                else:
                    torch.testing.assert_allclose(q1, q2)
            #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
    #assert False


def test_kbit_quantile_estimation():
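    # Sanity check for estimate_quantiles: quantiles estimated from N(0, 1)
    # samples should track the analytic Gaussian quantile function norm.ppf
    # evaluated at evenly spaced (offset) probabilities.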
    for i in range(100):
        data = torch.randn(1024, 1024, device='cuda')
        for bits in range(2, 9):
            p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
            err = torch.abs(val1-val2).mean()
            assert err < 0.038

    for i in range(100):
        data = torch.randn(1024, 1024, device='cuda')
        for bits in range(2, 4):
            total_values = 2**bits-1
            p = np.linspace(0, 1, 2*total_values+1)
            idx = np.arange(1, 2*total_values+1, 2)
            p = p[idx]
            offset = 1/(2*total_values)
            p = np.linspace(offset, 1-offset, total_values)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1)
            err = torch.abs(val1-val2).mean()
            assert err < 0.035


def test_bench_dequantization():
    a = torch.rand(1024, 1024, device='cuda').half()
    qa, SA = F.quantize_blockwise(a)

    max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
    #print(max_theoretical_mu)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        qa, SA = F.quantize_blockwise(a)
    torch.cuda.synchronize()
    #print((time.time()-t0)/1e6)



def test_fp4_quant():
    vals = list(product([0, 1], repeat=4))

    code = {}
    for bits in vals:
        result = 0
        bias = 3
        sign, e1, e2, p1 = bits
        idx = sign*8 + e1*4 + e2*2 + p1*1
        sign = -1.0 if sign else 1.0
        exp = e1*2 + e2*1
        if exp == 0:
            # sub-normal
            if p1 == 0: result = 0
            else: result = sign*0.0625
        else:
            # normal
            exp = 2**(-exp + bias + 1)
            frac = 1.5 if p1 else 1.0
            result = sign*exp*frac
        code[idx] = result
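    # The loop above decodes fp4 as e2m1: 1 sign, 2 exponent and 1 mantissa
    # bit with bias 3; exp == 0 is the subnormal case (0 or +-0.0625),
    # otherwise the value is sign * 2**(bias + 1 - exp) * (1.5 if p1 else 1.0).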

    A1 = torch.randn(1024, 1024, device='cuda').half()
    qa, SA = F.quantize_fp4(A1, blocksize=64)
    A2 = F.dequantize_fp4(qa, SA)

    err = (A1 - A2).abs().float()
    relerr = (err/A1.abs().float()).mean()
    idx = err > 1.0
    err = err.mean()


    assert err.item() < 0.1
    assert relerr.item() < 0.28


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
def test_4bit_compressed_stats(quant_type):
    for blocksize in [128, 64]:
        errs1 = []
        errs2 = []
        for i in range(10):
            A1 = torch.randn(1024, 1024, device='cuda').half()
            q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
            q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
            A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
            A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)
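            # compress_statistics additionally quantizes the per-block absmax
            # statistics themselves (QLoRA-style double quantization), trading
            # a little extra error in the q3/SA3 path for less memory.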


            err = (A1 - A2).abs().float()
            relerr = (err/(A1.abs().float()+1e-15)).mean()
            err = err.mean()

            errs1.append(err.item())


            assert err.item() < 0.11
            assert relerr.item() < 0.28

            err = (A1 - A3).abs().float()
            relerr = (err/(A1.abs().float()+1e-15)).mean()
            err = err.mean()

            errs2.append(err.item())

            assert err.item() < 0.11
            assert relerr.item() < 0.28

        #print(sum(errs1)/len(errs1), blocksize, quant_type)
        #print(sum(errs2)/len(errs2), blocksize, quant_type)




@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
def test_bench_4bit_dequant(quant_type):
    blocksize = 256
    a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
    qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)

    input_size = a.numel()/2
    output_size = a.numel()*2
    num_bytes = input_size+output_size
    GB = num_bytes/1e9
    max_theoretical_s = GB/768
    #print(max_theoretical_s*1e6)
    b = torch.randn(128, 1024*12, device='cuda').half()

    iters = 5
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
        #b.copy_(a)
    torch.cuda.synchronize()
    #print((time.time()-t0)/iters*1e6)

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    torch.matmul(b, a.t())
    #torch.cuda.synchronize()
    #print((time.time()-t0)/iters*1e6)



def test_normal_map_tree():
    code = F.create_normal_map()
    values = code[:8].tolist() + code[-8:].tolist()
    num_pivots = 1
    print(values)
    while num_pivots < 16:
        idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
        print(idx)
        num_pivots *= 2
        pivots = []
        for i in idx:
            pivots.append((values[i-1]+values[i])/2)
        print(pivots)
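    # The midpoints between neighboring codebook values are the pivot levels a
    # balanced binary search would use to bucket a float into the 16-entry
    # normal map; each round doubles the number of pivots (1, 2, 4, 8).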


#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_cutlass3_gemm(dtype):
    for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
        errs = []
        relerrs = []
        max_err = 0
        max_relerr = 0
        for i in range(100):
            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
            B = torch.randn(4*496, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)

            #print('')
            #print(A)
            #print(B.t())
            #A[:, :-3] = 0
            #B[:, :-3] = 0


            C1 = torch.matmul(A, B.t())
            C2 = F.cutlass3_gemm(A, B.t())

            # tensor cores are non-deterministic
            # so we need to analyze errors around the mean
            # to test our implementation
            err = torch.abs(C1-C2)
            mag = torch.abs(C1)+1e-8
            relerr = err/mag
            max_err = max(err.max(), max_err)
            max_relerr = max(relerr.max(), max_relerr)
            err = err.mean().item()
            relerr = relerr.mean().item()

            errs.append(err)
            relerrs.append(relerr)

            #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
            #    print('')
            #    print(i, err, mag.item(), relerr.item())
            #    print(A.flatten()[-6:])
            #    print(B.flatten()[-6:])
            #    out = A.flatten()[-6:]*B.flatten()[-6:]
            #    print(out)
            #    print(out[:-1].sum())
            #    print('='*80)
            #    print(C1.flatten()[-6:])
            #    print(C2.flatten()[-6:])
            #    #assert False, 'ERROR'

            c = int(C1.numel()*0.00125*(dim/256))+1
            assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)
        print('')
        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
        print(dim, (max_err.item(), max_relerr.item()))

#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_gemm_4bit(dtype):
    for i in range(1):
        #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
        #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
        #torch.random.manual_seed(17)
        A = torch.rand(1, 4096, dtype=dtype, device='cuda')
        B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')

        #print('')
        #print(A)
        #print(B)

        qB, state = F.quantize_nf4(B)
        F.dequantize_nf4(qB, state)


        C1 = torch.matmul(A, B.t())
        #C1 = bnb.matmul_4bit(A, qB.t(), state)
        C2 = F.cutlass3_gemm(A, qB.t(), state=state)
        #print(C1)
        #print(C2)

        #torch.testing.assert_close(C1, C2, atol=1e-5, rtol=0.005)


def test_pipeline_func():
    a = torch.rand(2, 4).cuda()
    out = F.pipeline_test(a, 2)
    print(a)
    print(out)