import math
import random
import time
from itertools import product

import einops
import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F

torch.set_printoptions(
    precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20


def assert_all_approx_close(a, b, rtol, atol, count):
    # tolerate up to `count` mismatched elements before failing with a full
    # element-wise comparison
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_allclose(a, b, rtol, atol)


class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super(FFN, self).__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.fc1.weight)
            torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class Timer(object):
    # aggregates per-label GPU timings using CUDA events
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
            self.starts[name].record()
        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")
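
# Typical usage of Timer (a sketch; the benchmarks below mostly time with
# time.time() around torch.cuda.synchronize() instead):
#
#     t = Timer()
#     t.tick("matmul")   # records a CUDA start event for this label
#     torch.matmul(a, b)
#     t.tock("matmul")   # records the end event, syncs, prints total seconds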


def setup():
    pass


def teardown():
    pass


@pytest.mark.parametrize(
    "dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0
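

# The 256 estimated quantiles act as a codebook: quantize_no_absmax maps each
# element to the index of its nearest codebook entry and dequantize_no_absmax
# maps indices back to values. A minimal sketch of that lookup (illustrative
# only, not part of the test suite):
def _nearest_code_demo(x, code):
    # index of the closest entry in `code` for every element of x
    return (x.unsqueeze(-1) - code).abs().argmin(dim=-1).to(torch.uint8)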


def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001


def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004


def test_dynamic_blockwise_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diffs[-1] < 0.011
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    diffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0033
        diffs.append(diff)
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
    # print(sum(diffs)/len(diffs))
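
    # Blockwise quantization computes one absmax scale per fixed-size block of
    # values, so a single outlier only degrades its own block rather than the
    # whole tensor; that is why these bounds are tighter than in
    # test_dynamic_quantization above.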


def test_dynamic_blockwise_stochastic_quantization():
    diffs = []
    reldiffs = []
    rand = torch.rand(1024).cuda()
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C1, S1 = F.quantize_blockwise(A1, rand=rand)
        C2, S2 = F.quantize_blockwise(A1)
        # a maximum distance of 1 between the quantized values
        torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
        fraction_smaller = (C1 < C2).float().sum() / C1.numel()
        fraction_larger = (C1 > C2).float().sum() / C1.numel()
        torch.testing.assert_allclose(
            fraction_larger, fraction_smaller, atol=0.01, rtol=0
        )
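
        # with stochastic rounding, each value rounds up with probability
        # proportional to its distance from the lower code, so the fractions
        # of up- and down-rounded elements should roughly balance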


@pytest.mark.parametrize(
    "gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec2, step, percentile=percentile
        )
        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)

        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_allclose(clip1, clip2)
        torch.testing.assert_allclose(gnorm1, gnorm2)
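
    # reference semantics verified above: percentile_clipping keeps a ring
    # buffer of the last 100 squared gradient norms (gnorm_vec2) and clips
    # the current norm to the `percentile`-th smallest entry of that history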


def quant(x):
    # absmax (symmetric linear) quantization to int8
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    # undo both operand scales on an integer matmul result
    return C.float() * (maxA / 127) * (maxB / 127)
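

# A minimal CPU round trip through the helpers above (an illustrative sketch,
# not exercised by the tests):
def _absmax_roundtrip_demo():
    x = torch.randn(4, 4)
    maxx, xq = quant(x)        # int8 codes in [-127, 127] plus one fp scale
    x_hat = dequant(xq, maxx)  # max reconstruction error is maxx / 254
    return (x - x_hat).abs().max()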


def quant_multi(x, dim):
    # one absmax scale per slice along `dim`
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
    # like quant_multi, but each chunk of `chunk_size` rows/columns shares one scale
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_minmax(A):
    # return the dynamic range of A (reference helper)
    minA = A.min()
    maxA = A.max()
    return minA, maxA


def mean(xx):
    return sum(xx) / float(len(xx))


# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
methods = [
    (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    )
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
    for vals in values_names
]


@pytest.mark.parametrize(
    "dim1, dim2, quant_methods, batched", values, ids=names
)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_allclose(
            quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
        )
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    print(mean(errors))
    print(mean(relerrors))


def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
    "hidden_dim_{0}_batch_dim_{1}_transpose_{2}_seq_dim_{3}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
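    # igemm needs these alignment multiples, hence the rounding above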
    for i in range(k):
        shapeA = (
            (batch_dim, hidden_dim)
            if not transpose[0]
            else (hidden_dim, batch_dim)
        )
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        elif transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.t().float(), B.float())
            out = F.igemm(A.t(), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.matmul(A.t().float(), B.t().float())
            out = F.igemm(A.t(), B.t())

        torch.testing.assert_allclose(out.float(), out2)

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if transpose[0]:
            # a 3D A is never transposed; skip instead of reusing stale
            # results from the 2D loop above
            continue
        if not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        else:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())

        torch.testing.assert_allclose(out.float(), out2)


n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [
    "seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        ).to(torch.int8)
        B = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
        ).to(torch.int8)
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(
            A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
        )
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_allclose(out.float(), out2)


n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
    "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs = []
    relerrs = []
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(
            0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        )
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            # min_max shifted A by (minA + scale); add the shift back after
            # the integer matmul
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)

        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3


n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)

        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A, B.permute([0, 2, 1]))
        elif transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
            out = F.igemm(A.permute([0, 2, 1]), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.bmm(
                A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
            )
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_allclose(out.float(), out2.float())


n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1)


n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(
    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)

names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
    values,
    ids=names,
)
def test_nvidia_transform(
    dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose
):
    # only col32 is supported for 3D tensors and int32 inputs here
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
    func = F.get_transform_func(dtype, orderA, orderOut, transpose)

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
            dtype
        )

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_allclose(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_allclose(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = (
                A.shape[0]
                * A.shape[1]
                * (A.shape[2] + (32 - (A.shape[2] % 32)))
            )
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
            A.shape[1] + (32 - (A.shape[1] % 32))
        )
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = (
                    (row // 8) + (1 if row % 8 != 0 else 0)
                ) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(
            out, from_order=orderOut, to_order="row", state=S
        )
        torch.testing.assert_allclose(A, out2)
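

# Layout note (as exercised above): "col32" stores 32-column tiles with the
# last dimension padded out to a multiple of 32; "col_turing" uses 32x8 tiles.
# nvidia_transform with state=S converts such buffers back to row-major.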


n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()

# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]

dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
                torch.int8
            )
        elif dims == 3:
            A = torch.randint(
                -128, 127, size=(dim1, dim2, dim3), device="cuda"
            ).to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.t().float())

        # transform to tile layouts, int8 GEMM, then back to row-major
        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())
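

# The same int8 matmul recipe as above, collected into one helper (an
# illustrative sketch; assumes CUDA int8 inputs A8[m, k] and B8[n, k]):
def _int8_matmul_sketch(A8, B8):
    A32, SA = F.transform(A8, "col32")      # activations -> col32 layout
    Bx, SB = F.transform(B8, "col_turing")  # weights -> Turing tile layout
    C32, SC = F.igemmlt(A32, Bx, SA, SB)    # int8 x int8 -> int32 accumulators
    C, _ = F.nvidia_transform(C32, "row", state=SC)  # back to row-major
    return C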


dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(
                0, 0.5, size=(dim1, dim2, dim3), device="cuda"
            ).half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        # print('')
        # print(output.flatten()[:10])
        # print(C1.flatten()[:10])
        # print(C2.flatten()[:10])

        # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        # C1 = torch.matmul(A.float(), B.float())

        # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        # C3, S = F.transform(C2, 'row', state=SC)
        # torch.testing.assert_allclose(C1, C3.float())


batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
    (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
    (batch_size, seqdim, 5120, 3 * 5120),
    (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]


# values = list(product(batch, seq, model, hidden))
names = [
    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")

    # torch.cuda.synchronize()
    ## warmup
    # for i in range(100):
    #    torch.matmul(A, w1.t())
    # torch.cuda.synchronize()

    dtype = torch.int8
    A = A.view(-1, A.shape[-1]).contiguous()
    grad = grad.view(-1, grad.shape[-1]).contiguous()
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):

        out1 = torch.matmul(A, w1.t())  # fc1
        # out2 = torch.matmul(out1, w2.t())# fc2

        # d1 = torch.matmul(grad, w2) # delta1
        # d2 = torch.matmul(d1, w1) # delta2

        # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
        # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1

    torch.cuda.synchronize()
    t16 = time.time() - t0
    print(t16)

    # torch.cuda.empty_cache()

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)

    # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    # C32A, SA = F.transform2(CA, 'col32')
    ## fc1
    # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)

    ## fc2
    # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    # C32out1, Sout1 = F.transform2(Cout1, 'col32')
    # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)

    ## delta1
    # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)

    ## delta2
    # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    # C32d1, Sd1 = F.transform2(Cd1, 'col32')
    ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)

    ## grad1
    # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)

    ## grad2
    # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)

    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(k):
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)

    #    #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
    #    CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    #    #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    #    #CTw2, Sw2 = F.transform2(Cw2, formatB)
    #    #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)

    #    C32A, SA = F.transform2(CA, 'col32')

    #    # fc1
    #    out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    #    #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

    #    #print(coo_tensor.nnz)
    #    #out1sp = F.spmm_coo(coo_tensor, w1.t())
    #    #print(w1.t().shape)
    #    #out1 = out1dn + out1sp

    #    # fc2
    #    Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    #    C32out1, Sout1 = F.transform2(Cout1, 'col32')
    #    out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    #    #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)

    #    # delta1
    #    Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    #    C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    #    d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    #    #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)

    #    # delta2
    #    Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    #    C32d1, Sd1 = F.transform2(Cd1, 'col32')
    #    d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    #    #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)

    #    # grad1
    #    #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    #    #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    #    #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    #    #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)

    #    ## grad2
    #    #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    #    #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    #    #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    #    #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)

    # torch.cuda.synchronize()
    # t8 = time.time() - t0
    # print(t8)


n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()

# dim1 = [2*1024]
# dim4 = [2*1024]

# dim1 = [4]
# dim4 = [4]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
formatB = ["col_turing", "col_ampere"]
values = list(product(dim1, dim4, dims, formatB))
names = [
    "dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB):
    inner = torch.randint(1, 128, size=(1,)).item()
    formatB = F.get_special_format_str()
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item()
        n = C1.numel()
        p = 0.06
        assert (
            count / n < p
        ), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten())
        torch.testing.assert_allclose(C5, C4)
        # print(C2)
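
        # vectorwise_mm_dequant recovers fp values from the int32 product as
        # C[i, j] ~= C_int32[i, j] * maxA[i] * maxB[j] / (127 * 127): one
        # scale per output row and one per output column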


n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
    for i in range(k):
        threshold = 3.0
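        # with threshold > 0, get_colrow_absmax excludes outlier values
        # (|x| >= threshold) from the row/column absmax stats and returns
        # block pointers for the sparse outlier part; threshold=0.0 disables
        # that bookkeeping (nnz_block_ptr is None)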
        A = torch.randn(dim1, dim2, device="cuda").half()
        A_truncated = A.clone()
        A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
        if dims == 2:
            row_stats1, _ = torch.abs(A.float()).max(1)
            col_stats1, _ = torch.abs(A.float()).max(0)
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
        else:
            assert False

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=threshold
        )

        A_blocked = einops.rearrange(
            torch.abs(A),
            "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
            row_tiles=16,
            block_size=64 * 4,
        )
        nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(
            nnz_rows1_counts.shape[0] + 1,
            dtype=nnz_rows1_counts.dtype,
            device=nnz_rows1_counts.device,
        )
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
        torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
        torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=0.0
        )

        torch.testing.assert_allclose(col_stats1, col_stats2)
        torch.testing.assert_allclose(row_stats1, row_stats2)
        assert nnz_block_ptr2 is None


n = 2
# dim1 = [8*1024]
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (
            (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        )
        num_not_close_cols = (
            (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
        )

        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
            )
            assert False
        if num_not_close_rows > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
            )
            assert False

        torch.testing.assert_allclose(Srow.flatten(), statsA)
        torch.testing.assert_allclose(Scol.flatten(), statsAt)
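
# double_quant(A) fuses the row-wise and column-wise absmax quantization above
# into a single kernel, returning (CA, CAt, row_stats, col_stats, coo_tensor);
# the reference here builds the same result from two vectorwise_quant calls.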


n = 4
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

dim1 = [6]
dim4 = [4]
inner = [8]

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        torch.testing.assert_allclose(maxA.flatten(), stats1a)
        torch.testing.assert_allclose(maxB.flatten(), stats2a)
        torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)

        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(C2a, "col_turing")
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
        assert err2 <= err1 * 1.01


n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)

        # print(out1)
        # print(out2)
        # print(out3)
        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))


dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
    # device-dependent tile format, resolved as in test_igemmlt_row_scale
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
1245
1246
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)
    # warmup
    for i in range(k):
        C1 = torch.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("16", time.time() - t0)

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
    CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
    A2, SA = F.nvidia_transform(C1a, "col32")
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

    c = 10.0 * inner * scale
    row_scale = maxA / c
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
    torch.cuda.synchronize()
    print("row-wise", time.time() - t0)

    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
    torch.cuda.synchronize()
    print("vector-wise", time.time() - t0)


n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]

dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(
    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
    values,
    ids=names,
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
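    # F.transform should agree with F.nvidia_transform for row -> {col32,
    # col_turing, col_ampere} conversions, with and without transposition,
    # both in the layout metadata (S1/S2) and in the transformed values.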
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
                dtype
            )
        elif dims == 3:
            A = torch.randint(
                10, 99, size=(dim1, dim2, dim3), device="cuda"
            ).to(dtype)

        A.view(-1)[-1] = -1
        if transpose:
            At = A.t().contiguous()
            out1, S1 = F.nvidia_transform(At, to_order=orderOut)
        else:
            out1, S1 = F.nvidia_transform(A, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_allclose(out1, out2)


n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]

dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
    "dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dtype, orderA, orderOut", values, ids=names
)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
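    # Round-trips a tiled layout back to row-major; only the shape is
    # asserted here, the value check below is left commented out.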
    for i in range(1):
        A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)

        out2, S2 = F.transform(A, to_order=orderA)
        A2, S3 = F.transform(out2, from_order=orderA, to_order="row", state=S2)
        assert A2.shape[0] == A.shape[0]
        assert A2.shape[1] == A.shape[1]

        print("")
        print(A)
        print(out2)
        print(A2)

        # torch.testing.assert_allclose(A, A2)


def test_overflow():
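    # int8 GEMM with int8 output on values whose products (up to 14 * 14 =
    # 196) exceed the int8 range; nothing is asserted, this only exercises
    # the overflow/saturation path.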
    formatB = F.get_special_format_str()
    print(formatB)
    for i in range(2):
        a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        Ca, Sa = F.nvidia_transform(a, "col32")
        Cb, Sb = F.nvidia_transform(b, formatB)

        c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
        c2 = torch.matmul(a.float(), b.float().t())


n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]

values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
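    # double_quant with a threshold should emit an int8 matrix for the
    # inliers plus a COO tensor holding the outliers with |A| >= threshold;
    # both pieces are checked against masked copies of A below.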
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )

        if coo_tensor is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A)
            A2[
                coo_tensor.rowidx.long(), coo_tensor.colidx.long()
            ] = coo_tensor.values
            torch.testing.assert_allclose(A1, A2)

            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_allclose(
                A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
            )


n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
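    # Sparse (COO) x dense fp16 matmul: spmm_coo over the outliers of A must
    # match a dense matmul against the masked matrix A2 = A * idx.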
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
            B = torch.randn(dim3, dim2).cuda().half()
        else:
            B = torch.randn(dim2, dim3).cuda().half()

        idx = torch.abs(A) >= threshold
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(
            A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
        )
        A2 = A * idx

        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
            out1 = torch.matmul(A2, B.t())
        else:
            out2 = F.spmm_coo(cooA, B)
            out1 = torch.matmul(A2, B)

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)


def test_spmm_bench():
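    # Benchmark only: compares dense bnb.matmul against F.spmm_coo at the
    # sparsity induced by `threshold` and prints the timing ratio.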
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = bnb.matmul(A, B)
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )

    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time() - t0
    print(tsp, t8)
    print(tsp / t8)


n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
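    # Sparse decomposition: quantizing A with a threshold moves outliers into
    # a COO tensor; adding the fp16 sparse product (out4) back onto the int8
    # result (out3) should reduce the error vs. the plain int8 path
    # (err2 < err1 below).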
    threshold = 3.0
    formatB = "col_turing"
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        out1 = torch.matmul(A, w1.t())

        Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
        CTw1, Sw1 = F.transform(Cw1, formatB)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        assert coo_tensor is not None

        out4 = F.spmm_coo(coo_tensor, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1


def test_matmuls():
    a = torch.randn(256, 256).half().cuda()
    b = torch.randn(256, 256).half().cuda()
    c1 = torch.matmul(a, b)
    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul(a, b)

    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = [
    "dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
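    # spmm_coo_very_sparse accumulates into a preallocated output (zeros or
    # ones here), so the dense reference adds the same initial tensor before
    # comparison; a small number of mismatches is tolerated via `count`.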
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(
        out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
    )
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))

    # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #   #out3 = F.spmm_coo(cooA, Bt.t())
    #   #out2 = F.spmm_coo(cooA, B)
    #   #out2 = F.spmm_coo_very_sparse(cooA, B)
    #   #out1 = torch.matmul(A, Bt.t())

    # torch.cuda.synchronize()
    # print(time.time() - t0)


def test_layout():
    a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
    a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
    a2, s2 = F.transform(a1, "col_turing")
    print(a2.shape)

    print(a1.flatten()[8 * 64 : 8 * 64 + 32])
    for i in range(4):
        print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)


def test_coo2csr():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_allclose(A2[idx], csrA.values)


def test_coo2csc():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_allclose(A2.t()[idx], cscA.values)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
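    # Fused dequantization: spmm_coo_very_sparse(..., dequant_stats=statsBt)
    # must match the manual path (int8 product scaled by statsBt / 127);
    # the remainder of the test prints timings for the cuSPARSE fp16, int8,
    # and fused variants against dense matmuls.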
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    rowidx = torch.randint(0, A.shape[-1], size=(15,))

    A[:, rowidx] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)

    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(100):
    #   out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print('fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print("cusparse fp16", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print("int8", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print("int8+dequant", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print("matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1 + out2
    torch.cuda.synchronize()
    print("sparse + matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print("bnb.matmul only", time.time() - t0)


batch_size = 1
seqdim = 2048
values = []
values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 1024, 4*1024))
# values.append((batch_size, seqdim, 1536, 4*1536))
# values.append((batch_size, seqdim, 2048, 4*2048))
# values.append((batch_size, seqdim, 2560, 4*2560))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
# values.append((batch_size, seqdim, 12288, 4*12288))
names = [
    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
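    # Throughput comparison across backends: fp16 torch.matmul, bnb.matmul,
    # raw igemmlt (with and without explicit (de)quantization), and the
    # Linear8bitLt layers; prints timings only, no asserts.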
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device="cuda").half()
    B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit.eval()

    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = (
        bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    )
    linearMixedBit.eval()

    # warmup
    for i in range(100):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print(
        f"pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        bnb.matmul(A, B)
    torch.cuda.synchronize()
    print(
        f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    C32A, SA = F.transform(CA, "col32")
    CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    CxB, SB = F.transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    torch.cuda.synchronize()
    print(
        f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    BA, statsB = F.vectorwise_quant(B, dim=1)
    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        A2 = A.view(-1, A.shape[-1]).contiguous()
        CA, statsA = F.vectorwise_quant(A2, dim=1)
        C32A, SA = F.nvidia_transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
        F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    torch.cuda.synchronize()
    print(
        f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )
    BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        A2 = A.view(-1, A.shape[-1]).contiguous()
        CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
        C32A, SA = F.nvidia_transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
        out = Cout * statsB * statsA * (1.0 / (127 * 127))
    torch.cuda.synchronize()
    print(
        f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    linear8bit(A)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        linear8bit(A)
    torch.cuda.synchronize()
    print(
        f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    linearMixedBit(A)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        linearMixedBit(A)
    torch.cuda.synchronize()
    print(
        f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )


def test_zeropoint():
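    # min_max does affine (zero-point) quantization per row: scale by
    # dyna = 252 / (max - min), then subtract round(dyna * row_midpoint) so
    # the row maps into the int8 range; the shift has to be added back as
    # `offset` after the integer matmul.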
    def min_max(x):
        maxA = torch.amax(x, dim=1, keepdim=True)
        minA = torch.amin(x, dim=1, keepdim=True)
        midpoint = (maxA - minA) / 2.0
        dyna = 252 / (maxA - minA)
        # dyna *= 0.98
        x = dyna * x
        x = x - torch.round((dyna * (minA + midpoint)))
        return x.to(torch.int8), minA, midpoint, dyna

    batch = 2
    seq = 2
    model = 4
    hidden = 2 * model
    # batch = 4
    # seq = 2048
    # model = 1024
    # hidden = 8*model
    A = torch.randn(batch * seq, model, device="cuda").half() - 0.4
    B = torch.nn.Parameter(torch.randn(model, hidden, device="cuda").half())

    # A[0] = 0
    # B[:, 0] = 0
    # A = A*(A>0)
    # A[0, 0] = 0
    # A[0, 0] = 6.0

    Ac, minA, midpoint, dyna = min_max(A)
    # print(Ac[0, 0], 'zero')
    # print(Ac, Ac.min(), Ac.max())
    Bc, maxB = F.vectorwise_quant(B, quant_type="linear")
    out = F.igemm(Ac, Bc)
    out2 = torch.matmul(A, B)
    offset = B.sum(0) * torch.round(dyna * (minA + midpoint)) / dyna
    out = out.float()
    # print(out.shape, maxB.shape, scale.shape, offset.shape)
    norm1 = maxB / 127
    C4 = (out / dyna) * norm1 + offset
    B1 = torch.nn.Parameter(B.clone())
    B2 = torch.nn.Parameter(B.clone())
    B3 = torch.nn.Parameter(B.clone())
    B4 = torch.nn.Parameter(B.clone())

    C1 = torch.matmul(A, B1)
    C2 = bnb.matmul_cublas(A, B2, None, "linear")
    C3 = bnb.matmul_cublas(A, B3, None, "zeropoint")
    C4 = bnb.matmul_cublas(A, B4, None, "vector-zeropoint")
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    print(err1, err2, err3)
    # assert err1 > err2

    loss1 = C1.mean()
    loss2 = C2.mean()
    loss3 = C3.mean()
    loss4 = C4.mean()

    loss1.backward()
    loss2.backward()
    loss3.backward()
    loss4.backward()

    print(B.grad)
    print(B1.grad)
    print(B2.grad)
    print(B3.grad)
    print(B4.grad)
    err1 = torch.abs(B1.grad - B2.grad).mean().item()
    err2 = torch.abs(B1.grad - B3.grad).mean().item()
    err3 = torch.abs(B1.grad - B4.grad).mean().item()
    print(err1, err2, err3)


def test_zp():
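    # Zero-point algebra under test: with Ca = qa*A + zpa and Cb = qb*B + zpb,
    #   Ca @ Cb = qa*qb*(A @ B) + zpa*qb*colsum(B) + zpb*qa*rowsum(A)
    #             + zpa*zpb*K,   K = A.shape[1]
    # so A @ B is recovered by subtracting the correction terms and dividing
    # by qa*qb (see C6 and C7 below).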
    def quant_zp(x):
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        minx = x.min()
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()

    C1 = torch.matmul(A, B)
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    print(ca.min(), ca.max())
    print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb
    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb
    print("")
    # print(C0.flatten()[:10])
    print(C1.flatten()[:10])
    print(C2.flatten()[:10])
    print(C3.flatten()[:10])
    print(C5.flatten()[:10])
    print(C6.flatten()[:10])
    print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)


def test_extract_outliers():
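    # extract_outliers must pull whole columns (selected by idx) out of the
    # tiled col_turing / col_ampere layouts, matching plain column indexing
    # on the row-major original.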
    for i in range(k):
        shapeA = (4096, 4096 * 4)
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # idx = torch.Tensor([0]).int().cuda()
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        CA, SA = F.transform(A, "col_turing")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_allclose(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_allclose(outliers1, outliers2)