import math
import random
import time
from itertools import product

import einops
import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F

torch.set_printoptions(
    precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20


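# Check that at most `count` elements of a and b differ beyond the given
# tolerances; only when more elements mismatch do we fall back to the strict
# elementwise comparison (which then raises).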
def assert_all_approx_close(a, b, rtol, atol, count):
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_allclose(a, b, rtol, atol)


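# Minimal two-layer ReLU feed-forward module with Xavier-initialized weights.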
class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super(FFN, self).__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.fc1.weight)
            torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


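# Benchmarking helper built on CUDA events: tick() records a start event for
# a name, tock() records the matching end event, synchronizes, and
# accumulates the elapsed milliseconds under that name.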
class Timer(object):
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
            self.starts[name].record()
        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")


def setup():
    pass


def teardown():
    pass


@pytest.mark.parametrize(
    "dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0


def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001


def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004


def test_dynamic_blockwise_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diffs[-1] < 0.011
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    diffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0033
        diffs.append(diff)
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
    # print(sum(diffs)/len(diffs))


def test_dynamic_blockwise_stochastic_quantization():
    diffs = []
    reldiffs = []
    rand = torch.rand(1024).cuda()
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C1, S1 = F.quantize_blockwise(A1, rand=rand)
        C2, S2 = F.quantize_blockwise(A1)
        # a maximum distance of 1 between quantized values
        torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
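        # stochastic rounding should round up and down in roughly equal
        # proportion relative to deterministic quantization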
        fraction_smaller = (C1 < C2).float().sum() / C1.numel()
        fraction_larger = (C1 > C2).float().sum() / C1.numel()
        torch.testing.assert_allclose(
            fraction_larger, fraction_smaller, atol=0.01, rtol=0
        )


@pytest.mark.parametrize(
    "gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec2, step, percentile=percentile
        )
        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)

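        # reference implementation: keep a rolling window of the last 100
        # gradient norms and clip at the requested percentile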
        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_allclose(clip1, clip2)
        torch.testing.assert_allclose(gnorm1, gnorm2)


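# Reference quantization helpers: symmetric absmax int8 quantization, its
# inverse, and dequantization of an int32 matmul result.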
def quant(x):
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    return C.float() * (maxA / 127) * (maxB / 127)


def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


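# Variant of quant_multi whose absmax scales are computed over
# einops-rearranged chunks of the tensor instead of full rows/columns.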
def quant_multi_chunk(x, dim, chunk_size=32):
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


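# NOTE: incomplete stub; computes the min/max of A but returns nothing.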
def quant_minmax(A):
    minA = A.min()
    maxA = A.max()


def mean(xx):
    return sum(xx) / float(len(xx))


# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
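# Each method is a tuple:
# (quantize A, quantize B, dequantize A, dequantize B, matmul dequant).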
methods = [
    (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    )
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
    for vals in values_names
]


@pytest.mark.parametrize(
    "dim1, dim2, quant_methods, batched", values, ids=names
)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_allclose(
            quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
        )
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    print(mean(errors))
    print(mean(relerrors))


def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
    "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (
            (batch_dim, hidden_dim)
            if not transpose[0]
            else (hidden_dim, batch_dim)
        )
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        elif transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.t().float(), B.float())
            out = F.igemm(A.t(), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.matmul(A.t().float(), B.t().float())
            out = F.igemm(A.t(), B.t())

        torch.testing.assert_allclose(out.float(), out2)

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())

        torch.testing.assert_allclose(out.float(), out2)


n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [
    "seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        ).to(torch.int8)
        B = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
        ).to(torch.int8)
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(
            A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
        )
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_allclose(out.float(), out2)


n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
    "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs = []
    relerrs = []
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(
            0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        )
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)

        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3


n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
    for vals in values
]


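# Batched int8 matmul: F.igemm on 3D tensors against torch.bmm in float32.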
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)

        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A, B.permute([0, 2, 1]))
        elif transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
            out = F.igemm(A.permute([0, 2, 1]), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.bmm(
                A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
            )
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_allclose(out.float(), out2.float())


n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1)


n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(
    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)

names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
    values,
    ids=names,
)
def test_nvidia_transform(
    dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose
):
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
    func = F.get_transform_func(dtype, orderA, orderOut, transpose)

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
            dtype
        )

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_allclose(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_allclose(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = (
                A.shape[0]
                * A.shape[1]
                * (A.shape[2] + (32 - (A.shape[2] % 32)))
            )
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
            A.shape[1] + (32 - (A.shape[1] % 32))
        )
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = (
                    (row // 8) + (1 if row % 8 != 0 else 0)
                ) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(
            out, from_order=orderOut, to_order="row", state=S
        )
        torch.testing.assert_allclose(A, out2)


n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()

# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]

dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
                torch.int8
            )
        elif dims == 3:
            A = torch.randint(
                -128, 127, size=(dim1, dim2, dim3), device="cuda"
            ).to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())


dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(
                0, 0.5, size=(dim1, dim2, dim3), device="cuda"
            ).half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        # print('')
        # print(output.flatten()[:10])
        # print(C1.flatten()[:10])
        # print(C2.flatten()[:10])

        # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        # C1 = torch.matmul(A.float(), B.float())

        # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        # C3, S = F.transform(C2, 'row', state=SC)
        # torch.testing.assert_allclose(C1, C3.float())


batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
    (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
    (batch_size, seqdim, 5120, 3 * 5120),
    (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]


# values = list(product(batch, seq, model, hidden))
names = [
    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")

    # torch.cuda.synchronize()
    ## warmup
    # for i in range(100):
    #    torch.matmul(A, w1.t())
    # torch.cuda.synchronize()

    dtype = torch.int8
    A = A.view(-1, A.shape[-1]).contiguous()
    grad = grad.view(-1, grad.shape[-1]).contiguous()
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):

        out1 = torch.matmul(A, w1.t())  # fc1
        # out2 = torch.matmul(out1, w2.t())# fc2

        # d1 = torch.matmul(grad, w2) # delta1
        # d2 = torch.matmul(d1, w1) # delta2

        # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
        # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1

    torch.cuda.synchronize()
    t16 = time.time() - t0
    print(t16)

    # torch.cuda.empty_cache()

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)

    # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    # C32A, SA = F.transform2(CA, 'col32')
    ## fc1
    # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)

    ## fc2
    # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    # C32out1, Sout1 = F.transform2(Cout1, 'col32')
    # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)

    ## delta1
    # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)

    ## delta2
    # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    # C32d1, Sd1 = F.transform2(Cd1, 'col32')
    ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)

    ## grad1
    # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)

    ## grad2
    # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)

    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(k):
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)

    #    #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
    #    CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    #    #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    #    #CTw2, Sw2 = F.transform2(Cw2, formatB)
    #    #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)

    #    C32A, SA = F.transform2(CA, 'col32')

    #    # fc1
    #    out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    #    #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

    #    #print(coo_tensor.nnz)
    #    #out1sp = F.spmm_coo(coo_tensor, w1.t())
    #    #print(w1.t().shape)
    #    #out1 = out1dn + out1sp

    #    # fc2
    #    Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    #    C32out1, Sout1 = F.transform2(Cout1, 'col32')
    #    out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    #    #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)

    #    # delta1
    #    Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    #    C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    #    d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    #    #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)

    #    # delta2
    #    Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    #    C32d1, Sd1 = F.transform2(Cd1, 'col32')
    #    d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    #    #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)

    #    # grad1
    #    #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    #    #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    #    #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    #    #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)

    #    ## grad2
    #    #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    #    #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    #    #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    #    #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)

    # torch.cuda.synchronize()
    # t8 = time.time() - t0
    # print(t8)


n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()

# dim1 = [2*1024]
# dim4 = [2*1024]

# dim1 = [4]
# dim4 = [4]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
formatB = ["col_turing", "col_ampere"]
has_bias = [True, False]
values = list(product(dim1, dim4, dims, formatB, has_bias))
names = [
    "dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
    inner = torch.randint(1, 128, size=(1,)).item()
    bias = None
    if has_bias:
        bias = torch.randn(dim4, device="cuda", dtype=torch.float16)
    formatB = F.get_special_format_str()
    for i in range(1):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())
        if has_bias:
            C1 += bias

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
        if has_bias:
            C4 += bias

        count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item()
        n = C1.numel()
        p = 0.06
        # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
        torch.testing.assert_allclose(C5, C4)


n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
    for i in range(k):
        threshold = 3.0
        A = torch.randn(dim1, dim2, device="cuda").half()
        A_truncated = A.clone()
        A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
        if dims == 2:
            row_stats1, _ = torch.abs(A.float()).max(1)
            col_stats1, _ = torch.abs(A.float()).max(0)
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
        else:
            assert False

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=threshold
        )

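        # reference nnz pointer: count elements >= threshold per 256-column
        # segment of each row, then cumulative-sum into CSR-style offsets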
        A_blocked = einops.rearrange(
            torch.abs(A),
            "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
            row_tiles=16,
            block_size=64 * 4,
        )
        nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(
            nnz_rows1_counts.shape[0] + 1,
            dtype=nnz_rows1_counts.dtype,
            device=nnz_rows1_counts.device,
        )
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
        torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
        torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=0.0
        )

        torch.testing.assert_allclose(col_stats1, col_stats2)
        torch.testing.assert_allclose(row_stats1, row_stats2)
        assert nnz_block_ptr2 is None


n = 2
# dim1 = [8*1024]
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (
            (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        )
        num_not_close_cols = (
            (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
        )

        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
            )
            assert False
        if num_not_close_rows > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
            )
            assert False

        torch.testing.assert_allclose(Srow.flatten(), statsA)
        torch.testing.assert_allclose(Scol.flatten(), statsAt)


n = 4
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

dim1 = [6]
dim4 = [4]
inner = [8]

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        torch.testing.assert_allclose(maxA.flatten(), stats1a)
        torch.testing.assert_allclose(maxB.flatten(), stats2a)
        torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)

        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(C2a, "col_turing")
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
        assert err2 <= err1 * 1.01


n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)

        # print(out1)
        # print(out2)
        # print(out3)
        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))


dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]

values = list(zip(dim1, dim4, inner))
1238
1239
1240
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


Tim Dettmers's avatar
Tim Dettmers committed
1241
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
1242
@pytest.mark.skip("Row scale has some bugs for ampere")
Tim Dettmers's avatar
Tim Dettmers committed
1243
1244
1245
1246
def test_row_scale_bench(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)
    # warmup
    for i in range(k):
        C1 = torch.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("16", time.time() - t0)

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
    CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
    A2, SA = F.nvidia_transform(C1a, "col32")
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

    c = 10.0 * inner * scale
    row_scale = maxA / c
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
    torch.cuda.synchronize()
    print("row-wise", time.time() - t0)

    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
    torch.cuda.synchronize()
    print("vector-wise", time.time() - t0)
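

# Event-based timing sketch (a hypothetical helper, not called by any test):
# the synchronize -> time.time() pattern used in these benchmarks can also be
# expressed with CUDA events, mirroring the Timer class at the top of this file.
def _cuda_time_sketch(fn, iters=100):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0  # elapsed_time reports milliseconds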


n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]

dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(
    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
    values,
    ids=names,
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
                dtype
            )
        elif dims == 3:
            A = torch.randint(
                10, 99, size=(dim1, dim2, dim3), device="cuda"
            ).to(dtype)

        A.view(-1)[-1] = -1
        if transpose:
            At = A.t().contiguous()
            out1, S1 = F.nvidia_transform(At, to_order=orderOut)
        else:
            out1, S1 = F.nvidia_transform(A, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_allclose(out1, out2)
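
        # Layout note (hedged, from the cuBLASLt docs): "col32" is the 32-column
        # tiled order CUBLASLT_ORDER_COL32, while "col_turing" and "col_ampere"
        # correspond to the COL4_4R2_8C and COL32_2R_4R4 tiled B layouts used
        # by igemmlt on sm_75 and sm_80 devices, respectively.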


n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]

dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
    "dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dtype, orderA, orderOut", values, ids=names
)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
    for i in range(1):
        A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)

        out2, S2 = F.transform(A, to_order=orderA)
        A2, S3 = F.transform(out2, from_order=orderA, to_order="row", state=S2)
        assert A2.shape[0] == A.shape[0]
        assert A2.shape[1] == A.shape[1]

        print("")
        print(A)
        print(out2)
        print(A2)

        # torch.testing.assert_allclose(A, A2)


def test_overflow():
    formatB = F.get_special_format_str()
    print(formatB)
    for i in range(2):
        a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)

        Ca, Sa = F.nvidia_transform(a, "col32")
        Cb, Sb = F.nvidia_transform(b, formatB)

        c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
        c2 = torch.matmul(a.float(), b.float().t())
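
        # a and b are column vectors holding 5..14, so c2 is their outer
        # product with entries up to 14 * 14 = 196, which cannot fit in int8
        # (max 127). The int8-output igemmlt call therefore saturates or wraps;
        # this test only exercises that path and deliberately compares nothing.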


n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]

values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )

        if coo_tensor is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A)
            A2[
                coo_tensor.rowidx.long(), coo_tensor.colidx.long()
            ] = coo_tensor.values
            torch.testing.assert_allclose(A1, A2)

            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_allclose(
                A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
            )
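
            # Dequant identity exercised above: double_quant stores the per-row
            # absmax of A in statsA, so the sub-threshold part of A comes back
            # (up to int8 rounding) as CA * statsA[:, None] / 127.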


n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
            B = torch.randn(dim3, dim2).cuda().half()
        else:
            B = torch.randn(dim2, dim3).cuda().half()

        idx = torch.abs(A) >= threshold
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(
            A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
        )
        A2 = A * idx

        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
            out1 = torch.matmul(A2, B.t())
        else:
            out2 = F.spmm_coo(cooA, B)
            out1 = torch.matmul(A2, B)

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
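
        # Construction recap: COOSparseTensor takes (rows, cols, nnz, rowidx,
        # colidx, values) with int32 indices from torch.where(idx) and values in
        # A's dtype; spmm_coo multiplies only those nnz entries against the
        # dense operand, which is why out2 matches the matmul on the masked A2.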


def test_spmm_bench():
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = bnb.matmul(A, B)
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )

    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time() - t0
    print(tsp, t8)
    print(tsp / t8)
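
    # tsp / t8 is the sparse-to-dense runtime ratio: with threshold=4 only a
    # tiny fraction of A survives (its density is printed above), so a ratio
    # well below 1.0 means the COO kernel beats the dense fp16 matmul here.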


n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
    threshold = 3.0
    formatB = "col_turing"
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        out1 = torch.matmul(A, w1.t())

        Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
        CTw1, Sw1 = F.transform(Cw1, formatB)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        assert coo_tensor is not None

        out4 = F.spmm_coo(coo_tensor, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1
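

# The test above is the LLM.int8() forward in miniature. A compact sketch using
# only functions already exercised in this file (the helper name is hypothetical
# and nothing below calls it):
def _sparse_decomposed_matmul_sketch(A, W, threshold=6.0, formatB="col_turing"):
    # Split A into int8 inliers (CA) plus an fp16 COO tensor of outliers.
    CA, CAt, statsA, statsAt, coo = F.double_quant(A, threshold=threshold)
    CW, CWt, statsW, statsWt, _ = F.double_quant(W)
    C32A, SA = F.transform(CA, "col32")
    CxW, SW = F.transform(CW, formatB)
    # Dense int8 GEMM on the inliers, dequantized back to fp16.
    out32, Sout32 = F.igemmlt(C32A, CxW, SA, SW)
    out = F.mm_dequant(out32, Sout32, statsA, statsW)
    # Sparse fp16 matmul on the outlier columns, added back in.
    if coo is not None:
        out += F.spmm_coo(coo, W.t())
    return out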


def test_matmuls():
    a = torch.randn(256, 256).half().cuda()
    b = torch.randn(256, 256).half().cuda()
    c1 = torch.matmul(a, b)
    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul(a, b)

    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = [
    "dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(
        out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
    )
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))

    # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)

    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #   #out3 = F.spmm_coo(cooA, Bt.t())
    #   #out2 = F.spmm_coo(cooA, B)
    #   #out2 = F.spmm_coo_very_sparse(cooA, B)
    #   #out1 = torch.matmul(A, Bt.t())

    # torch.cuda.synchronize()
    # print(time.time() - t0)


def test_layout():
    a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
    a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
    a2, s2 = F.transform(a1, "col_turing")
    print(a2.shape)

    print(a1.flatten()[8 * 64 : 8 * 64 + 32])
    for i in range(4):
        print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)


def test_coo2csr():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_allclose(A2[idx], csrA.values)
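
    # CSR recap: rowptr has length nrows + 1, and rowptr[i + 1] - rowptr[i] is
    # the number of stored values in row i, which is exactly what the counts
    # comparison against (A2 != 0).sum(1) verifies.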


def test_coo2csc():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_allclose(A2.t()[idx], cscA.values)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    rowidx = torch.randint(0, A.shape[-1], size=(15,))

    A[:, rowidx] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)

    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(100):
    #   out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print('fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print("cusparse fp16", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print("int8", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print("int8+dequant", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print("matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1 + out2
    torch.cuda.synchronize()
    print("sparse+ matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print("only matmul", time.time() - t0)


batch_size = 1
seqdim = 1
values = []
values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 1024, 4*1024))
# values.append((batch_size, seqdim, 1536, 4*1536))
# values.append((batch_size, seqdim, 2048, 4*2048))
# values.append((batch_size, seqdim, 2560, 4*2560))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
# values.append((batch_size, seqdim, 12288, 4*12288))
names = [
    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
    iters = 128
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device="cuda").half()
    B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit.eval()

    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = (
        bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    )
    linearMixedBit.eval()

    # warmup
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print(
        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul(A, B)
    torch.cuda.synchronize()
    print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul(A, B, threshold=6.0)
    torch.cuda.synchronize()
    print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    C32A, SA = F.transform(CA, "col32")
    CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    CxB, SB = F.transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    torch.cuda.synchronize()
    print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    BA, statsB = F.vectorwise_quant(B, dim=1)
    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        A2 = A.view(-1, A.shape[-1]).contiguous()
        CA, statsA = F.vectorwise_quant(A2, dim=1)
        C32A, SA = F.nvidia_transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
        F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    torch.cuda.synchronize()
    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        A2 = A.view(-1, A.shape[-1]).contiguous()
        CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
        C32A, SA = F.nvidia_transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
        out = Cout * statsB * statsA * (1.0 / (127 * 127))
    torch.cuda.synchronize()
    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    linear8bit(A)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        linear8bit(A)
    torch.cuda.synchronize()
    print(
        f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    linearMixedBit(A)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        linearMixedBit(A)
    torch.cuda.synchronize()
    print(
        f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )


def test_zeropoint():
    def min_max(x):
        maxA = torch.amax(x, dim=1, keepdim=True)
        minA = torch.amin(x, dim=1, keepdim=True)
        midpoint = (maxA - minA) / 2.0
        dyna = 252 / (maxA - minA)
        # dyna *= 0.98
        x = dyna * x
        x = x - torch.round((dyna * (minA + midpoint)))
        return x.to(torch.int8), minA, midpoint, dyna

    batch = 2
    seq = 2
    model = 4
    hidden = 2 * model
    # batch = 4
    # seq = 2048
    # model = 1024
    # hidden = 8*model
    A = torch.randn(batch * seq, model, device="cuda").half() - 0.4
    B = torch.nn.Parameter(torch.randn(model, hidden, device="cuda").half())

    # A[0] = 0
    # B[:, 0] = 0
    # A = A*(A>0)
    # A[0, 0] = 0
    # A[0, 0] = 6.0

    Ac, minA, midpoint, dyna = min_max(A)
    # print(Ac[0, 0], 'zero')
    # print(Ac, Ac.min(), Ac.max())
    Bc, maxB = F.vectorwise_quant(B, quant_type="linear")
    out = F.igemm(Ac, Bc)
    out2 = torch.matmul(A, B)
    offset = B.sum(0) * torch.round(dyna * (minA + midpoint)) / dyna
    out = out.float()
    # print(out.shape, maxB.shape, scale.shape, offset.shape)
    norm1 = maxB / 127
    C4 = (out / dyna) * norm1 + offset

    B1 = torch.nn.Parameter(B.clone())
    B2 = torch.nn.Parameter(B.clone())
    B3 = torch.nn.Parameter(B.clone())
    B4 = torch.nn.Parameter(B.clone())

    C1 = torch.matmul(A, B1)
    C2 = bnb.matmul_cublas(A, B2, None, "linear")
    C3 = bnb.matmul_cublas(A, B3, None, "zeropoint")
    C4 = bnb.matmul_cublas(A, B4, None, "vector-zeropoint")

    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    print(err1, err2, err3)
    # assert err1 > err2

    loss1 = C1.mean()
    loss2 = C2.mean()
    loss3 = C3.mean()
    loss4 = C4.mean()

    loss1.backward()
    loss2.backward()
    loss3.backward()
    loss4.backward()

    print(B.grad)
    print(B1.grad)
    print(B2.grad)
    print(B3.grad)
    print(B4.grad)
    err1 = torch.abs(B1.grad - B2.grad).mean().item()
    err2 = torch.abs(B1.grad - B3.grad).mean().item()
    err3 = torch.abs(B1.grad - B4.grad).mean().item()
    print(err1, err2, err3)


def test_zp():
    def quant_zp(x):
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        minx = x.min()
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()

    C1 = torch.matmul(A, B)
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    print(ca.min(), ca.max())
    print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb

    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

    print("")
    # print(C0.flatten()[:10])
    print(C1.flatten()[:10])
    print(C2.flatten()[:10])
    print(C3.flatten()[:10])
    print(C5.flatten()[:10])
    print(C6.flatten()[:10])
    print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)
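
    # Zero-point algebra behind C6/C7, written out (it matches quant_zp above):
    #   CA = qa*A + zpa  and  CB = qb*B + zpb,  so with K = A.shape[1]
    #   CA @ CB = qa*qb*(A @ B) + zpa*qb*B.sum(0) + zpb*qa*A.sum(1)[:, None]
    #             + zpa*zpb*K.
    # Subtracting the three correction terms and dividing by qa*qb recovers
    # A @ B, which is exactly how C6 and C7 are computed.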


def test_extract_outliers():
    for i in range(k):
        shapeA = (4096, 4096 * 4)
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # idx = torch.Tensor([0]).int().cuda()
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        CA, SA = F.transform(A, "col_turing")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_allclose(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_allclose(outliers1, outliers2)
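
        # extract_outliers has to invert the tiled col_turing / col_ampere
        # layouts to gather whole columns, which is why its output is checked
        # against plain row-major indexing A[:, idx] for both formats.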


def test_blockwise_cpu_large():
    diffs = []
    reldiffs = []
    batch = 128
    seq = 128
    for hidden in [128, 14336]:
        for blocksize in [4096, 16384]:
            for i in range(2):
                A1 = torch.randn(batch, seq, hidden, device='cpu')
                t0 = time.time()
                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
                print(time.time() - t0)
                diff = torch.abs(A1 - A2)
                reldiff = diff / torch.abs(A1 + 1e-8)
                diffs.append(diff.mean().item())
                reldiffs.append(reldiff.mean().item())
                assert diffs[-1] < 0.011
            # print(sum(diffs)/len(diffs))
            # print(sum(reldiffs)/len(reldiffs))
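

# Minimal round-trip sketch for the blockwise API exercised above (illustrative
# only; the 0.011 bound mirrors the assertion in the test):
def _blockwise_roundtrip_sketch():
    A = torch.randn(64, 128)  # CPU tensor; CUDA tensors follow the same path
    C, S = F.quantize_blockwise(A, blocksize=4096)   # int8 codes + quant state
    A2 = F.dequantize_blockwise(C, S, blocksize=4096)
    assert (A - A2).abs().mean().item() < 0.011
    return A2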