import math
import random
import time
from itertools import product

import einops
import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F

torch.set_printoptions(
    precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20  # number of random trials most tests below run


def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0):
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_allclose(a, b, rtol, atol)
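
# Example (hypothetical tensors): allow up to 5 elements to violate the
# tolerances before failing:
#   assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=5)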


class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super(FFN, self).__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.fc1.weight)
            torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class Timer(object):
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
            self.starts[name].record()
        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")


def setup():
    pass


def teardown():
    pass


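# For U(0, 1) inputs the 256 estimated quantiles should land on the uniform
# midpoint grid 1/512, 3/512, ..., 511/512; for N(0, 1) they are compared
# against torch.quantile directly.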
@pytest.mark.parametrize(
    "dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0


def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001


def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004


def test_dynamic_blockwise_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diffs[-1] < 0.011
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    diffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0033
        diffs.append(diff)
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
    # print(sum(diffs)/len(diffs))


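# Stochastic rounding rounds up or down with probability proportional to the
# distance to each neighboring grid point, so codes can differ from
# deterministic rounding by at most 1, and the up/down fractions should
# roughly balance.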
def test_dynamic_blockwise_stochastic_quantization():
    diffs = []
    reldiffs = []
    rand = torch.rand(1024).cuda()
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C1, S1 = F.quantize_blockwise(A1, rand=rand)
        C2, S2 = F.quantize_blockwise(A1)
        # a maximum distance of 1 between stochastic and deterministic codes
        torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
        fraction_smaller = (C1 < C2).float().sum() / C1.numel()
        fraction_larger = (C1 > C2).float().sum() / C1.numel()
        torch.testing.assert_allclose(
            fraction_larger, fraction_smaller, atol=0.01, rtol=0
        )


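# Percentile clipping keeps a rolling window of the last 100 gradient norms
# (stored squared in gnorm_vec2) and rescales the gradient whenever its norm
# exceeds the `percentile` index of the sorted window.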
@pytest.mark.parametrize(
    "gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec2, step, percentile=percentile
        )
        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)

        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_allclose(clip1, clip2)
        torch.testing.assert_allclose(gnorm1, gnorm2)


def quant(x):
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    return C.float() * (maxA / 127) * (maxB / 127)
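
# Why this works: quant() maps A ~ Ac * maxA / 127 and B ~ Bc * maxB / 127,
# so A @ B ~ (Ac @ Bc) * (maxA / 127) * (maxB / 127), which is what
# mm_dequant computes from the int32 product C.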


def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_minmax(A):
    minA = A.min()
    maxA = A.max()
    return minA, maxA


def mean(xx):
    return sum(xx) / float(len(xx))


# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
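# each methods entry is a tuple:
# (quantize A, quantize B, dequantize A, dequantize B, matmul dequantize)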
methods = [
    (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    )
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
    for vals in values_names
]


@pytest.mark.parametrize(
    "dim1, dim2, quant_methods, batched", values, ids=names
)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_allclose(
            quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
        )
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    print(mean(errors))
    print(mean(relerrors))


def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
    "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (
            (batch_dim, hidden_dim)
            if not transpose[0]
            else (hidden_dim, batch_dim)
        )
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        elif transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.t().float(), B.float())
            out = F.igemm(A.t(), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.matmul(A.t().float(), B.t().float())
            out = F.igemm(A.t(), B.t())

        torch.testing.assert_allclose(out.float(), out2)

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        else:
            # the 3D A is never transposed here; skip instead of re-checking
            # stale results from the 2D loop above
            continue

        torch.testing.assert_allclose(out.float(), out2)


n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [
    "seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        ).to(torch.int8)
        B = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
        ).to(torch.int8)
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(
            A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
        )
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_allclose(out.float(), out2)


n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
    "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale
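
    # min_max maps A to Ac with A = Ac * scale / 127 + (minA + scale), so
    # A @ B = (Ac @ B) * scale / 127 + (minA + scale) * B.sum(0); that second
    # term is the `offset` added back after the integer matmul below.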

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs = []
    relerrs = []
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(
            0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        )
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)

        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3


n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)

        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A, B.permute([0, 2, 1]))
        elif transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
            out = F.igemm(A.permute([0, 2, 1]), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.bmm(
                A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
            )
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_allclose(out.float(), out2.float())


n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        n = A1.numel()
        assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002))




n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))

names = ["dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(*vals)for vals in values]

@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
    func = F.get_transform_func(dtype, orderA, orderOut, transpose)

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
            dtype
        )

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_allclose(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_allclose(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = (
                A.shape[0]
                * A.shape[1]
                * (A.shape[2] + (32 - (A.shape[2] % 32)))
            )
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
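        # col_turing stores each 8-row x 32-column tile contiguously
        # (32 * 8 = 256 elements), with rows padded to a multiple of 8 and
        # columns padded to a multiple of 32, giving the element count below.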
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
            A.shape[1] + (32 - (A.shape[1] % 32))
        )
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = (
                    (row // 8) + (1 if row % 8 != 0 else 0)
                ) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(
            out, from_order=orderOut, to_order="row", state=S
        )
        torch.testing.assert_allclose(A, out2)


n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()

# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]

dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
    for vals in values
]


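# igemmlt takes A in col32 layout and B in a GPU-specific tiled layout
# (col_turing here); the int32 result comes back in col32 and is transformed
# back to row-major before comparison.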
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
                torch.int8
            )
        elif dims == 3:
            A = torch.randint(
                -128, 127, size=(dim1, dim2, dim3), device="cuda"
            ).to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())


dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
    for vals in values
]


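# End-to-end 8-bit fp16 matmul: double_quant -> transform (col32 for A,
# formatB for B) -> igemmlt -> mm_dequant back to fp16.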
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(
                0, 0.5, size=(dim1, dim2, dim3), device="cuda"
            ).half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        # print('')
        # print(output.flatten()[:10])
        # print(C1.flatten()[:10])
        # print(C2.flatten()[:10])

        # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        # C1 = torch.matmul(A.float(), B.float())

        # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        # C3, S = F.transform(C2, 'row', state=SC)
        # torch.testing.assert_allclose(C1, C3.float())


batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
    (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
    (batch_size, seqdim, 5120, 3 * 5120),
    (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]


# values = list(product(batch, seq, model, hidden))
names = [
    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")

    # torch.cuda.synchronize()
    ## warmup
    # for i in range(100):
    #    torch.matmul(A, w1.t())
    # torch.cuda.synchronize()

    dtype = torch.int8
    A = A.view(-1, A.shape[-1]).contiguous()
    grad = grad.view(-1, grad.shape[-1]).contiguous()
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):

        out1 = torch.matmul(A, w1.t())  # fc1
        # out2 = torch.matmul(out1, w2.t())# fc2

        # d1 = torch.matmul(grad, w2) # delta1
        # d2 = torch.matmul(d1, w1) # delta2

        # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
        # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1

    torch.cuda.synchronize()
    t16 = time.time() - t0
    print(t16)

    # torch.cuda.empty_cache()

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)

    # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    # C32A, SA = F.transform2(CA, 'col32')
    ## fc1
    # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)

    ## fc2
    # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    # C32out1, Sout1 = F.transform2(Cout1, 'col32')
    # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)

    ## delta1
    # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)

    ## delta2
    # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    # C32d1, Sd1 = F.transform2(Cd1, 'col32')
    ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)

    ## grad1
    # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)

    ## grad2
    # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)

    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(k):
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)

    #    #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
    #    CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    #    #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    #    #CTw2, Sw2 = F.transform2(Cw2, formatB)
    #    #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)

    #    C32A, SA = F.transform2(CA, 'col32')

    #    # fc1
    #    out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    #    #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

    #    #print(coo_tensor.nnz)
    #    #out1sp = F.spmm_coo(coo_tensor, w1.t())
    #    #print(w1.t().shape)
    #    #out1 = out1dn + out1sp

    #    # fc2
    #    Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    #    C32out1, Sout1 = F.transform2(Cout1, 'col32')
    #    out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    #    #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)

    #    # delta1
    #    Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    #    C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    #    d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    #    #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)

    #    # delta2
    #    Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    #    C32d1, Sd1 = F.transform2(Cd1, 'col32')
    #    d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    #    #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)

    #    # grad1
    #    #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    #    #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    #    #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    #    #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)

    #    ## grad2
    #    #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    #    #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    #    #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    #    #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)

    # torch.cuda.synchronize()
    # t8 = time.time() - t0
    # print(t8)


n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()

#dim1 = [2*1024]
#dim4 = [2*1024]

#dim1 = [4]
#dim4 = [4]

dims = (2,)
formatB = ["col_turing", "col_ampere"]
has_bias = [True, False]
values = list(product(dim1, dim4, dims, formatB, has_bias))
names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
    inner = torch.randint(1, 128, size=(1,)).item()
    bias = None
    if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16)
    formatB = F.get_special_format_str()
    for i in range(1):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())
        if has_bias: C1 += bias

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
        if has_bias: C4 += bias

        # TODO: is something wrong here? If so, the problem goes deeper
        #n = C1.numel()
        #p = 0.06
        std = C1.std(0).view(1, -1)
        C1 /= std
        C4 /= std
        #assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
        #assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
        #torch.testing.assert_allclose(C5, C4, atol=0.015, rtol=0.1)
        n = C5.numel()
        assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n))


n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]


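# get_colrow_absmax computes per-row and per-column absmax statistics; with a
# nonzero threshold, outliers (|x| >= threshold) are excluded from the stats
# and accounted for via the returned nnz block pointer.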
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
    for i in range(k):
        threshold = 3.0
        A = torch.randn(dim1, dim2, device="cuda").half()
        A_truncated = A.clone()
        A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
        if dims == 2:
            row_stats1, _ = torch.abs(A.float()).max(1)
            col_stats1, _ = torch.abs(A.float()).max(0)
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
        else:
            assert False

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=threshold
        )

        A_blocked = einops.rearrange(
            torch.abs(A),
            "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
            row_tiles=16,
            block_size=64 * 4,
        )
        nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(
            nnz_rows1_counts.shape[0] + 1,
            dtype=nnz_rows1_counts.dtype,
            device=nnz_rows1_counts.device,
        )
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
        torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
        torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=0.0
        )

        torch.testing.assert_allclose(col_stats1, col_stats2)
        torch.testing.assert_allclose(row_stats1, row_stats2)
        assert nnz_block_ptr2 is None


n = 2
# dim1 = [8*1024]
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


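# F.double_quant quantizes A row-wise (CA, statsA) and column-wise (CAt,
# statsAt) in one pass; the reference below reproduces both with separate
# vectorwise_quant calls.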
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (
            (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        )
        num_not_close_cols = (
            (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
        )

        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
            )
            assert False
        if num_not_close_rows > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
            )
            assert False

        torch.testing.assert_allclose(Srow.flatten(), statsA)
        torch.testing.assert_allclose(Scol.flatten(), statsAt)


n = 4
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        torch.testing.assert_allclose(maxA.flatten(), stats1a)
        torch.testing.assert_allclose(maxB.flatten(), stats2a)
        torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)

        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(C2a, "col_turing")
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
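        # the Python reference path (out3) should be no more than 2.5% worse
        # than the fused mm_dequant path (out2)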
        assert err2 <= err1 * 1.025


n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)

        # print(out1)
        # print(out2)
        # print(out3)
        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))


dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
    formatB = F.get_special_format_str()  # special format for B, as in test_igemmlt_row_scale
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)
    # warmup
    for i in range(k):
        C1 = torch.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("16", time.time() - t0)

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
    CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
    A2, SA = F.nvidia_transform(C1a, "col32")
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

    c = 10.0 * inner * scale
    row_scale = maxA / c
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
    torch.cuda.synchronize()
    print("row-wise", time.time() - t0)

    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
    torch.cuda.synchronize()
    print("vector-wise", time.time() - t0)


n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]

dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(
    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
    values,
    ids=names,
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
                dtype
            )
        elif dims == 3:
            A = torch.randint(
                10, 99, size=(dim1, dim2, dim3), device="cuda"
            ).to(dtype)

        A.view(-1)[-1] = -1
        if transpose:
            At = A.t().contiguous()
            out1, S1 = F.nvidia_transform(At, to_order=orderOut)
        else:
            out1, S1 = F.nvidia_transform(A, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_allclose(out1, out2)
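

# Usage sketch (illustrative, not part of the original tests): round-trip a
# row-major int8 tensor through the col32 layout consumed by igemmlt and back,
# using the returned transform state to undo the tiling.
def _transform_roundtrip_sketch():
    A = torch.randint(-128, 127, size=(32, 32), device="cuda").to(torch.int8)
    A32, SA = F.nvidia_transform(A, "col32")  # tiled layout for the int8 GEMM
    back, _ = F.nvidia_transform(A32, "row", state=SA)  # undo the tiling
    torch.testing.assert_allclose(A, back)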

n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]

dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
    "dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
    for vals in values
]


def test_overflow():
    formatB = F.get_special_format_str()
    print(formatB)
    for i in range(2):
        a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)

        Ca, Sa = F.nvidia_transform(a, "col32")
        Cb, Sb = F.nvidia_transform(b, formatB)

        c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
        c2 = torch.matmul(a.float(), b.float().t())
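

# Plain-torch sketch (illustrative; not the bnb kernel): with an int8 output
# dtype the accumulator is cast back to int8, so products above 127 wrap
# around, which is what test_overflow probes with the values 5..14.
def _int8_overflow_sketch():
    a = torch.arange(5, 15, dtype=torch.int32).view(-1, 1)
    acc = a * a.t()               # exact int32 outer products, up to 14 * 14 = 196
    wrapped = acc.to(torch.int8)  # casting to int8 wraps values above 127
    return acc, wrapped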


n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]

values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )

        if coo_tensor is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A)
            A2[
                coo_tensor.rowidx.long(), coo_tensor.colidx.long()
            ] = coo_tensor.values
            torch.testing.assert_allclose(A1, A2)

            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_allclose(A1, A2, rtol=0.05, atol=1.5e-2)
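

# Sketch mirroring the assertions above (illustrative): with a threshold,
# double_quant keeps outliers |A| >= threshold out of the int8 tensor and
# returns them as a COO tensor, so A splits into a dequantized int8 part plus
# a sparse fp16 part.
def _sparse_decomposition_sketch(A, threshold=3.0):
    CA, CAt, statsA, statsAt, coo = F.double_quant(A, threshold=threshold)
    dense_part = (CA.float() * statsA.unsqueeze(1) / 127).half()
    sparse_part = torch.zeros_like(A)
    if coo is not None:
        sparse_part[coo.rowidx.long(), coo.colidx.long()] = coo.values
    return dense_part, sparse_part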


n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
            B = torch.randn(dim3, dim2).cuda().half()
        else:
            B = torch.randn(dim2, dim3).cuda().half()

        idx = torch.abs(A) >= threshold
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(
            A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
        )
        A2 = A * idx

        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
            out1 = torch.matmul(A2, B.t())
        else:
            out2 = F.spmm_coo(cooA, B)
            out1 = torch.matmul(A2, B)

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
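

# Usage sketch (illustrative): build a COOSparseTensor from the entries of A
# above a threshold and multiply it against a dense fp16 matrix; the result
# approximates (A * mask) @ B, which is what the loop above asserts.
def _spmm_coo_sketch(A, B, threshold=1.5):
    idx = torch.abs(A) >= threshold
    rows, cols = torch.where(idx)
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], idx.sum().item(), rows.int(), cols.int(), A[idx]
    )
    return F.spmm_coo(cooA, B)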


def test_spmm_bench():
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = bnb.matmul(A, B.t())
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )

    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time() - t0
    print(tsp, t8)
    print(tsp / t8)


n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
    threshold = 3.0
    formatB = "col_turing"
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        out1 = torch.matmul(A, w1.t())

        Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
        CTw1, Sw1 = F.transform(Cw1, formatB)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        assert coo_tensor is not None

        out4 = F.spmm_coo(coo_tensor, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1
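

# Sketch of the decomposition verified above (illustrative): the int8 matmul
# handles the clipped values and spmm_coo adds the fp16 outliers back in, so
# the sum tracks the fp16 reference more closely than the int8 path alone.
def _int8_plus_outliers_sketch(A, w1, threshold=3.0):
    Cw1, _, statsw1, _, _ = F.double_quant(w1)
    CTw1, Sw1 = F.transform(Cw1, "col_turing")
    CA, _, statsA, _, coo = F.double_quant(A, threshold=threshold)
    C32A, SA = F.transform(CA, "col32")
    out32, Sout32 = F.igemmlt(C32A, CTw1, SA, Sw1)
    out = F.mm_dequant(out32, Sout32, statsA, statsw1)
    if coo is not None:
        out += F.spmm_coo(coo, w1.t())
    return out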


def test_matmuls():
    a = torch.randn(256, 512).half().cuda()
    b = torch.randn(256, 512).half().cuda()
    c1 = torch.matmul(a, b.t())
    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul_cublas(a, b.t())

    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2
    print(err1, err2)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = [
    "dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
    p = 200 / (2048 * 12288 * 4)  # allows ~200 mismatched elements at this fixed problem size
    n = out1.numel()
    count = math.ceil(p * n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(
        out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
    )
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))

    # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)

    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #   #out3 = F.spmm_coo(cooA, Bt.t())
    #   #out2 = F.spmm_coo(cooA, B)
    #   #out2 = F.spmm_coo_very_sparse(cooA, B)
    #   #out1 = torch.matmul(A, Bt.t())

    # torch.cuda.synchronize()
    # print(time.time() - t0)
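

# Note as a sketch (illustrative): spmm_coo_very_sparse accumulates into `out`
# when one is passed, which is what the zeros/ones out_func parametrization
# above checks; `rows` is passed explicitly so the sketch only relies on the
# COOSparseTensor fields used in these tests.
def _spmm_very_sparse_accumulate_sketch(cooA, B, rows):
    out = torch.zeros((rows, B.shape[1]), dtype=torch.float16, device=B.device)
    return F.spmm_coo_very_sparse(cooA, B, out=out)  # result is added into out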


def test_layout():
    a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
    a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
    a2, s2 = F.transform(a1, "col_turing")
    print(a2.shape)

    print(a1.flatten()[8 * 64 : 8 * 64 + 32])
    for i in range(4):
        print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)


def test_coo2csr():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_allclose(A2[idx], csrA.values)
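

# Sketch (illustrative): in CSR the row pointer has length rows + 1, and
# consecutive differences give each row's nonzero count, which is exactly what
# the assertions above verify against the dense reference.
def _csr_row_counts_sketch(csrA):
    return csrA.rowptr[1:] - csrA.rowptr[:-1]  # nonzeros per row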


def test_coo2csc():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_allclose(A2.t()[idx], cscA.values)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    colidx = torch.randint(0, A.shape[-1], size=(15,))  # column indices to spike with outliers

    A[:, colidx] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)

    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(100):
    #   out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print('fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print("cusparse fp16", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print("int8", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print("int8+dequant", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print("matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1 + out2
    torch.cuda.synchronize()
    print("sparse+ matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, colidx], Bt.t()[colidx], out=out1)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)


batch_size = 1
seqdim = 1
values = []
values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 1024, 4*1024))
# values.append((batch_size, seqdim, 1536, 4*1536))
# values.append((batch_size, seqdim, 2048, 4*2048))
# values.append((batch_size, seqdim, 2560, 4*2560))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
names = [
    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
    iters = 128
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device="cuda").half()
    B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit.eval()

    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = (
        bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    )
    linearMixedBit.eval()

    # warmup
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print(
        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul(A, B)
    torch.cuda.synchronize()
    print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul(A, B, threshold=6.0)
    torch.cuda.synchronize()
    print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    C32A, SA = F.transform(CA, "col32")
    CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    CxB, SB = F.transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    torch.cuda.synchronize()
    print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    BA, statsB = F.vectorwise_quant(B, dim=1)
    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        A2 = A.view(-1, A.shape[-1]).contiguous()
        CA, statsA = F.vectorwise_quant(A2, dim=1)
        C32A, SA = F.nvidia_transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
        F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    torch.cuda.synchronize()
    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        A2 = A.view(-1, A.shape[-1]).contiguous()
        CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
        C32A, SA = F.nvidia_transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
        out = Cout * statsB * statsA * (1.0 / (127 * 127))
    torch.cuda.synchronize()
    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    linear8bit(A)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        linear8bit(A)
    torch.cuda.synchronize()
    print(
        f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    linearMixedBit(A)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        linearMixedBit(A)
    torch.cuda.synchronize()
    print(
        f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

def test_zeropoint():
    def quant_zp(x):
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        minx = x.min()
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()

    C1 = torch.matmul(A, B)
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    print(ca.min(), ca.max())
    print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb

    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

    print("")
    # print(C0.flatten()[:10])
    print(C1.flatten()[:10])
    print(C2.flatten()[:10])
    print(C3.flatten()[:10])
    print(C5.flatten()[:10])
    print(C6.flatten()[:10])
    print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)
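

# Worked sketch (illustrative) of the zero-point algebra used above: with
# CA = qa*A + zpa and CB = qb*B + zpb,
#   CA @ CB = qa*qb*(A @ B) + zpa*qb*B.sum(0) + zpb*qa*A.sum(1) + zpa*zpb*K,
# where K = A.shape[1], so A @ B is recovered by subtracting the correction
# terms and dividing by qa*qb (the C6/C7 computation in the test).
def _zeropoint_matmul_sketch(A, B, qa, qb, zpa, zpb):
    C = torch.matmul(qa * A + zpa, qb * B + zpb)
    C -= qb * B.sum(0).view(1, -1) * zpa + qa * A.sum(1).view(-1, 1) * zpb
    C -= zpa * zpb * A.shape[1]
    return C / (qa * qb)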


def test_extract_outliers():
    for i in range(k):
        shapeA = (4096, 4096 * 4)
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # idx = torch.Tensor([0]).int().cuda()
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        CA, SA = F.transform(A, "col_turing")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_allclose(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_allclose(outliers1, outliers2)


def test_blockwise_cpu_large():
    diffs = []
    reldiffs = []
    batch = 128
    seq = 128
    for hidden in [128]:#, 14336]:
        for blocksize in [4096, 16384]:
            for i in range(2):
                A1 = torch.randn(batch, seq, hidden, device='cpu')
                t0 = time.time()
                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
                print(time.time() - t0)
                diff = torch.abs(A1 - A2)
                reldiff = diff / torch.abs(A1 + 1e-8)
                diffs.append(diff.mean().item())
                reldiffs.append(reldiff.mean().item())
                assert diffs[-1] < 0.011
            # print(sum(diffs)/len(diffs))
            # print(sum(reldiffs)/len(reldiffs))
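

# Usage sketch (illustrative): the blockwise quantization round trip exercised
# above; the state S returned by quantize_blockwise carries what
# dequantize_blockwise needs to reconstruct the input.
def _blockwise_roundtrip_sketch():
    A1 = torch.randn(64, 64, device="cpu")
    C, S = F.quantize_blockwise(A1, blocksize=4096)
    A2 = F.dequantize_blockwise(C, S, blocksize=4096)
    return (A1 - A2).abs().mean()  # small reconstruction error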