import math
import random
import time
from itertools import product

import einops
import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F

torch.set_printoptions(
    precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20


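# asserts that a and b are elementwise close, tolerating up to `count`
# mismatched elements before falling back to a strict comparison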
def assert_all_approx_close(a, b, rtol, atol, count):
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_allclose(a, b, rtol, atol)


class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super(FFN, self).__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.fc1.weight)
            torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class Timer(object):
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
            self.starts[name].record()
        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")

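# Illustrative Timer usage (a sketch; assumes a CUDA device is available):
#   t = Timer()
#   t.tick("matmul")
#   torch.matmul(a, b)
#   t.tock("matmul")  # synchronizes, prints the aggregate seconds, returns ms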

def setup():
    pass


def teardown():
    pass


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

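    # for uniform inputs the 256 estimated quantiles should sit near the bin
    # midpoints 1/512, 3/512, ..., 511/512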
    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0


def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001


def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004


def test_dynamic_blockwise_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diffs[-1] < 0.011
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    diffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0033
        diffs.append(diff)
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
    # print(sum(diffs)/len(diffs))


def test_dynamic_blockwise_stochastic_quantization():
    diffs = []
    reldiffs = []
    rand = torch.rand(1024).cuda()
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C1, S1 = F.quantize_blockwise(A1, rand=rand)
        C2, S2 = F.quantize_blockwise(A1)
        # a maximum distance of 1 between quantized values
        torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
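        # with unbiased stochastic rounding, rounding up and rounding down
        # should occur about equally often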
        fraction_smaller = (C1 < C2).float().sum() / C1.numel()
        fraction_larger = (C1 > C2).float().sum() / C1.numel()
        torch.testing.assert_allclose(
            fraction_larger, fraction_smaller, atol=0.01, rtol=0
        )


@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
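    # F.percentile_clipping updates gnorm_vec2 in place with *squared* norms;
    # gnorm_vec1 tracks plain norms as the reference implementation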
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec2, step, percentile=percentile
        )
        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)

        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_allclose(clip1, clip2)
        torch.testing.assert_allclose(gnorm1, gnorm2)


def quant(x):
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    return C.float() * (maxA / 127) * (maxB / 127)
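

# Worked example of the absmax scheme above (illustrative): for
# x = [0.5, -1.0, 0.25], max1 = 1.0 and round(x / max1 * 127) = [64, -127, 32];
# dequant recovers [0.504, -1.0, 0.252]. An int8 matmul C = Ac @ Bc therefore
# carries a factor of 127 * 127, which mm_dequant divides back out.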


def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_minmax(A):
    minA = A.min()
    maxA = A.max()
    return minA, maxA


def mean(xx):
    return sum(xx) / float(len(xx))


# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
methods = [
    (lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant)
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
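# "linear" uses a single absmax for the whole tensor; "vectorwise" uses one
# absmax per row/column (selected by the dim argument)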
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals) for vals in values_names
]


@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_allclose(
            quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
        )
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    print(mean(errors))
    print(mean(relerrors))


def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
    "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (
            (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
        )
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        elif transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.t().float(), B.float())
            out = F.igemm(A.t(), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.matmul(A.t().float(), B.t().float())
            out = F.igemm(A.t(), B.t())

        torch.testing.assert_allclose(out.float(), out2)

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())

        torch.testing.assert_allclose(out.float(), out2)


n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = ["seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        ).to(torch.int8)
        B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(
            torch.int8
        )
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_allclose(out.float(), out2)


n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
    "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale
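
    # min_max centers A around zero before rounding, so the shift
    # (minA + scale) must be added back after the int8 matmul via
    # offset = B.sum(0) * (minA + scale)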

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs = []
    relerrs = []
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)

        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3


n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)

        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A, B.permute([0, 2, 1]))
        elif transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
            out = F.igemm(A.permute([0, 2, 1]), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_allclose(out.float(), out2.float())


n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1)


n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))

names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names
)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
    func = F.get_transform_func(dtype, orderA, orderOut, transpose)

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_allclose(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_allclose(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
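        # each tile holds 8 rows x 32 cols = 256 elements; element (row, col)
        # should land at 256 * tile_index + (row % 8) * 32 + (col % 32)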
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
            A.shape[1] + (32 - (A.shape[1] % 32))
        )
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
        torch.testing.assert_allclose(A, out2)


n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()

# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]

dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
                torch.int8
            )
        elif dims == 3:
            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
                torch.int8
            )
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())


dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
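        # reorder to the GPU tile layouts (col32 for A, formatB for B), run
        # the int8 matmul, then dequantize with the row/column absmax stats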
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        # print('')
        # print(output.flatten()[:10])
        # print(C1.flatten()[:10])
        # print(C2.flatten()[:10])

        # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        # C1 = torch.matmul(A.float(), B.float())

        # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        # C3, S = F.transform(C2, 'row', state=SC)
        # torch.testing.assert_allclose(C1, C3.float())


batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
    (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
    (batch_size, seqdim, 5120, 3 * 5120),
    (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]


# values = list(product(batch, seq, model, hidden))
names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")

    # torch.cuda.synchronize()
    ## warmup
    # for i in range(100):
    #    torch.matmul(A, w1.t())
    # torch.cuda.synchronize()

    dtype = torch.int8
    A = A.view(-1, A.shape[-1]).contiguous()
    grad = grad.view(-1, grad.shape[-1]).contiguous()
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):

        out1 = torch.matmul(A, w1.t())  # fc1
        # out2 = torch.matmul(out1, w2.t())# fc2

        # d1 = torch.matmul(grad, w2) # delta1
        # d2 = torch.matmul(d1, w1) # delta2

        # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
        # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1

    torch.cuda.synchronize()
    t16 = time.time() - t0
    print(t16)

    # torch.cuda.empty_cache()

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)

    # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    # C32A, SA = F.transform2(CA, 'col32')
    ## fc1
    # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)

    ## fc2
    # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    # C32out1, Sout1 = F.transform2(Cout1, 'col32')
    # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)

    ## delta1
    # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)

    ## delta2
    # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    # C32d1, Sd1 = F.transform2(Cd1, 'col32')
    ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)

    ## grad1
    # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)

    ## grad2
    # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)

    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(k):
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)

    #    #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
    #    CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    #    #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    #    #CTw2, Sw2 = F.transform2(Cw2, formatB)
    #    #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)

    #    C32A, SA = F.transform2(CA, 'col32')

    #    # fc1
    #    out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    #    #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

    #    #print(coo_tensor.nnz)
    #    #out1sp = F.spmm_coo(coo_tensor, w1.t())
    #    #print(w1.t().shape)
    #    #out1 = out1dn + out1sp

    #    # fc2
    #    Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    #    C32out1, Sout1 = F.transform2(Cout1, 'col32')
    #    out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    #    #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)

    #    # delta1
    #    Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    #    C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    #    d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    #    #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)

    #    # delta2
    #    Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    #    C32d1, Sd1 = F.transform2(Cd1, 'col32')
    #    d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    #    #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)

    #    # grad1
    #    #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    #    #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    #    #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    #    #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)

    #    ## grad2
    #    #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    #    #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    #    #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    #    #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)

    # torch.cuda.synchronize()
    # t8 = time.time() - t0
    # print(t8)


n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()

# dim1 = [2*1024]
# dim4 = [2*1024]

# dim1 = [4]
# dim4 = [4]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
formatB = ["col_turing", "col_ampere"]
values = list(product(dim1, dim4, dims, formatB))
names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB):
    inner = torch.randint(1, 128, size=(1,)).item()
    formatB = F.get_special_format_str()
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item()
        n = C1.numel()
        p = 0.06
        assert (
            count / n < p
        ), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten())
        torch.testing.assert_allclose(C5, C4)
        # print(C2)


n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
    for i in range(k):
        threshold = 3.0
        A = torch.randn(dim1, dim2, device="cuda").half()
        A_truncated = A.clone()
        A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
        if dims == 2:
            row_stats1, _ = torch.abs(A.float()).max(1)
            col_stats1, _ = torch.abs(A.float()).max(0)
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
        else:
            assert False

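        # reference: count outliers (|a| >= threshold) in each 256-element row
        # segment and build the prefix-sum block pointer that
        # F.get_colrow_absmax is expected to return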
        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=threshold
        )

        A_blocked = einops.rearrange(
            torch.abs(A),
            "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
            row_tiles=16,
            block_size=64 * 4,
        )
        nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(
            nnz_rows1_counts.shape[0] + 1,
            dtype=nnz_rows1_counts.dtype,
            device=nnz_rows1_counts.device,
        )
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
        torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
        torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)

        torch.testing.assert_allclose(col_stats1, col_stats2)
        torch.testing.assert_allclose(row_stats1, row_stats2)
        assert nnz_block_ptr2 is None


n = 2
# dim1 = [8*1024]
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()

        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
            )
            assert False
        if num_not_close_rows > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
            )
            assert False

        torch.testing.assert_allclose(Srow.flatten(), statsA)
        torch.testing.assert_allclose(Scol.flatten(), statsAt)


n = 4
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

dim1 = [6]
dim4 = [4]
inner = [8]

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        torch.testing.assert_allclose(maxA.flatten(), stats1a)
        torch.testing.assert_allclose(maxB.flatten(), stats2a)
        torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)

        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(C2a, "col_turing")
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
        assert err2 <= err1 * 1.01


n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)

        # print(out1)
        # print(out2)
        # print(out3)
        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))


dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)
    # warmup
    for i in range(k):
        C1 = torch.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("16", time.time() - t0)

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
    CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
    A2, SA = F.nvidia_transform(C1a, "col32")
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

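    # row_scale folds the per-row absmax of A into the int8 GEMM; the
    # 10x-inner headroom factor is a tuning constant of this benchmark.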
    c = 10.0 * inner * scale
    row_scale = maxA / c
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
    torch.cuda.synchronize()
    print("row-wise", time.time() - t0)

    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
    torch.cuda.synchronize()
    print("vector-wise", time.time() - t0)


n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]

dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
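    # F.transform should agree with the reference F.nvidia_transform for every
    # (layout, transpose) combination generated above, including the shape state.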
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
        elif dims == 3:
            A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)

        A.view(-1)[-1] = -1
        if transpose:
            At = A.t().contiguous()
            out1, S1 = F.nvidia_transform(At, to_order=orderOut)
        else:
            out1, S1 = F.nvidia_transform(A, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_allclose(out1, out2)


n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]

dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
    "dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
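    # Round trip row -> col_turing -> row; only shapes are verified here, the
    # strict element-wise check is left commented out below.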
    for i in range(1):
        A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)

        out2, S2 = F.transform(A, to_order=orderA)
        A2, S3 = F.transform(out2, from_order=orderA, to_order="row", state=S2)
        assert A2.shape[0] == A.shape[0]
        assert A2.shape[1] == A.shape[1]

        print("")
        print(A)
        print(out2)
        print(A2)

        # torch.testing.assert_allclose(A, A2)


def test_overflow():
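    # int8 x int8 with an int8 output can saturate/overflow; the fp32 matmul
    # serves as a reference for inspection (no assert here).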
    formatB = F.get_special_format_str()
    print(formatB)
    for i in range(2):
        a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)

        Ca, Sa = F.nvidia_transform(a, "col32")
        Cb, Sb = F.nvidia_transform(b, formatB)

        c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
        c2 = torch.matmul(a.float(), b.float().t())


n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]

values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
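    # With a threshold, double_quant should return the outliers (|A| >= threshold)
    # as a COO tensor and keep the remaining dense int8 part dequantizable.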
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)

        if coo_tensor is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A)
            A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
            torch.testing.assert_allclose(A1, A2)

            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_allclose(A1, A2, rtol=0.05, atol=1.5e-2)


n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
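    # spmm_coo multiplies the COO outliers of A with dense B; the masked dense
    # matmul (A * idx) @ B is the reference.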
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
            B = torch.randn(dim3, dim2).cuda().half()
        else:
            B = torch.randn(dim2, dim3).cuda().half()

        idx = torch.abs(A) >= threshold
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(
            A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
        )
        A2 = A * idx

        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
            out1 = torch.matmul(A2, B.t())
        else:
            out2 = F.spmm_coo(cooA, B)
            out1 = torch.matmul(A2, B)

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)


def test_spmm_bench():
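    # Timing sketch: dense bnb.matmul vs. the cusparse-backed F.spmm_coo at the
    # sparsity induced by `threshold` below.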
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = bnb.matmul(A, B)
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )

    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time() - t0
    print(tsp, t8)
    print(tsp / t8)


n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
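    # Sparse decomposition: int8 matmul on thresholded A plus spmm_coo on the
    # extracted outliers should have lower error than the plain int8 path.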
    threshold = 3.0
    formatB = "col_turing"
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        out1 = torch.matmul(A, w1.t())

        Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
        CTw1, Sw1 = F.transform(Cw1, formatB)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        assert coo_tensor is not None

        out4 = F.spmm_coo(coo_tensor, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1


def test_matmuls():
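    # Sanity check: bnb.matmul should stay close to the fp16 reference.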
    a = torch.randn(256, 256).half().cuda()
    b = torch.randn(256, 256).half().cuda()
    c1 = torch.matmul(a, b)
    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul(a, b)

    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = ["dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
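    # spmm_coo_very_sparse appears to accumulate into the preallocated `out`;
    # out_func ('zeros'/'ones') varies the initial contents to check this.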
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
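    # tolerate ~200 mismatches at the reference shape (2048 x 12288*4), scaled
    # to this output's element count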
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))

    # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)

    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #   #out3 = F.spmm_coo(cooA, Bt.t())
    #   #out2 = F.spmm_coo(cooA, B)
    #   #out2 = F.spmm_coo_very_sparse(cooA, B)
    #   #out1 = torch.matmul(A, Bt.t())

    # torch.cuda.synchronize()
    # print(time.time() - t0)


def test_layout():
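    # Peek at the tiled col_turing layout: transform a recognizable byte pattern
    # and print 32-element slices of the source and transformed buffers.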
    a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
    a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
    a2, s2 = F.transform(a1, "col_turing")
    print(a2.shape)

    print(a1.flatten()[8 * 64 : 8 * 64 + 32])
    for i in range(4):
        print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)


def test_coo2csr():
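    # coo2csr: consecutive row-pointer differences must equal the per-row
    # nonzero counts, with values preserved in row-major order.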
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_allclose(A2[idx], csrA.values)


def test_coo2csc():
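    # coo2csc: column-pointer differences must equal the per-column counts;
    # values are compared in column-major order via the transpose.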
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_allclose(A2.t()[idx], cscA.values)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
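    # spmm_coo_very_sparse with dequant_stats should match the manual
    # dequantization (int8 result * statsBt / 127); the rest benchmarks variants.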
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    rowidx = torch.randint(0, A.shape[-1], size=(15,))

    A[:, rowidx] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)

    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(100):
    #   out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print('fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print("cusparse fp16", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print("int8", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print("int8+dequant", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print("matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1 + out2
    torch.cuda.synchronize()
    print("sparse+ matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)


batch_size = 1
seqdim = 2048
values = []
values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 1024, 4*1024))
# values.append((batch_size, seqdim, 1536, 4*1536))
# values.append((batch_size, seqdim, 2048, 4*2048))
# values.append((batch_size, seqdim, 2560, 4*2560))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
# values.append((batch_size, seqdim, 12288, 4*12288))
names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
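    # Timing sketch across one transformer-sized shape: fp16 matmul, bnb.matmul,
    # raw igemmlt paths, and Linear8bitLt with and without an outlier threshold.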
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device="cuda").half()
    B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit.eval()

    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = (
        bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    )
    linearMixedBit.eval()

    # warmup
    for i in range(100):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print(
        f"pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        bnb.matmul(A, B)
    torch.cuda.synchronize()
    print(
        f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    C32A, SA = F.transform(CA, "col32")
    CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    CxB, SB = F.transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    torch.cuda.synchronize()
    print(
        f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    BA, statsB = F.vectorwise_quant(B, dim=1)
    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        A2 = A.view(-1, A.shape[-1]).contiguous()
        CA, statsA = F.vectorwise_quant(A2, dim=1)
        C32A, SA = F.nvidia_transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
        F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    torch.cuda.synchronize()
    print(
        f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        A2 = A.view(-1, A.shape[-1]).contiguous()
        CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
        C32A, SA = F.nvidia_transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
        out = Cout * statsB * statsA * (1.0 / (127 * 127))
    torch.cuda.synchronize()
    print(
        f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    linear8bit(A)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        linear8bit(A)
    torch.cuda.synchronize()
    print(
        f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    linearMixedBit(A)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        linearMixedBit(A)
    torch.cuda.synchronize()
    print(
        f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )


def test_zeropoint():
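    # min_max quantizes with a zero point: x_q = dyna*x - round(dyna*(minA + midpoint)),
    # mapping [minA, maxA] onto roughly [-126, 126].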
    def min_max(x):
        maxA = torch.amax(x, dim=1, keepdim=True)
        minA = torch.amin(x, dim=1, keepdim=True)
        midpoint = (maxA - minA) / 2.0
        dyna = 252 / (maxA - minA)
        # dyna *= 0.98
        x = dyna * x
        x = x - torch.round((dyna * (minA + midpoint)))
        return x.to(torch.int8), minA, midpoint, dyna

    batch = 2
    seq = 2
    model = 4
    hidden = 2 * model
    # batch = 4
    # seq = 2048
    # model = 1024
    # hidden = 8*model
    A = torch.randn(batch * seq, model, device="cuda").half() - 0.4
    B = torch.nn.Parameter(torch.randn(model, hidden, device="cuda").half())

    # A[0] = 0
    # B[:, 0] = 0
    # A = A*(A>0)
    # A[0, 0] = 0
    # A[0, 0] = 6.0

    Ac, minA, midpoint, dyna = min_max(A)
    # print(Ac[0, 0], 'zero')
    # print(Ac, Ac.min(), Ac.max())
    Bc, maxB = F.vectorwise_quant(B, quant_type="linear")
    out = F.igemm(Ac, Bc)
    out2 = torch.matmul(A, B)
    offset = B.sum(0) * torch.round(dyna * (minA + midpoint)) / dyna
    out = out.float()
    # print(out.shape, maxB.shape, scale.shape, offset.shape)
    norm1 = maxB / 127
    C4 = (out / dyna) * norm1 + offset

    B1 = torch.nn.Parameter(B.clone())
    B2 = torch.nn.Parameter(B.clone())
    B3 = torch.nn.Parameter(B.clone())
    B4 = torch.nn.Parameter(B.clone())

    C1 = torch.matmul(A, B1)
    C2 = bnb.matmul_cublas(A, B2, None, "linear")
    C3 = bnb.matmul_cublas(A, B3, None, "zeropoint")
    C4 = bnb.matmul_cublas(A, B4, None, "vector-zeropoint")

    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    print(err1, err2, err3)
    # assert err1 > err2

    loss1 = C1.mean()
    loss2 = C2.mean()
    loss3 = C3.mean()
    loss4 = C4.mean()

    loss1.backward()
    loss2.backward()
    loss3.backward()
    loss4.backward()

    print(B.grad)
    print(B1.grad)
    print(B2.grad)
    print(B3.grad)
    print(B4.grad)
    err1 = torch.abs(B1.grad - B2.grad).mean().item()
    err2 = torch.abs(B1.grad - B3.grad).mean().item()
    err3 = torch.abs(B1.grad - B4.grad).mean().item()
    print(err1, err2, err3)


def test_zp():
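    # Zero-point algebra exercised below:
    # (qa*A + zpa) @ (qb*B + zpb)
    #   = qa*qb*(A @ B) + qb*zpa*colsum(B) + qa*zpb*rowsum(A) + zpa*zpb*K
    # with K = A.shape[1], so the quantized matmul can be corrected exactly.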
    def quant_zp(x):
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        minx = x.min()
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()

    C1 = torch.matmul(A, B)
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    print(ca.min(), ca.max())
    print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb

    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

    print("")
    # print(C0.flatten()[:10])
    print(C1.flatten()[:10])
    print(C2.flatten()[:10])
    print(C3.flatten()[:10])
    print(C5.flatten()[:10])
    print(C6.flatten()[:10])
    print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)


def test_extract_outliers():
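    # extract_outliers should pull the columns in `idx` out of a col_turing /
    # col_ampere tensor exactly as plain row-major indexing does.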
    for i in range(k):
        shapeA = (4096, 4096 * 4)
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # idx = torch.Tensor([0]).int().cuda()
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        CA, SA = F.transform(A, "col_turing")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_allclose(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_allclose(outliers1, outliers2)