import math
import random
import time
from itertools import product

import einops
import pytest
import torch
import numpy as np

import bitsandbytes as bnb
from bitsandbytes import functional as F
from scipy.stats import norm

torch.set_printoptions(
    precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20


def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
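    """Assert a and b are element-wise close, tolerating up to `count` mismatches; returns the mismatch count."""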
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        if throw:
            print(f"Too many values not close: assert {sumval} < {count}")
            torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

    return sumval


class FFN(torch.nn.Module):
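    """Simple two-layer feed-forward block (fc1 -> ReLU -> fc2) with Xavier-initialized weights."""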
    def __init__(self, input_features, hidden_size, bias=True):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.fc1.weight)
            torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class Timer:
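    """Aggregating timer built on CUDA events.

    tick(name) records a start event; tock(name) records the matching end
    event, synchronizes, and accumulates the elapsed milliseconds under `name`.
    """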
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
            self.starts[name].record()
        else:
            self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")


def setup():
    pass


def teardown():
    pass


@pytest.mark.parametrize(
    "dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

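    # for U(0, 1) inputs the 256 estimated quantiles should match the bucket
    # midpoints (2i + 1) / 512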
    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0


def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001


def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
    #print('')
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
    abserr = sum(diffs)/len(diffs)
    relerr = sum(reldiffs)/len(reldiffs)
    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
    assert abserr < 0.011
    assert relerr < 0.018
    assert A2.dtype == dtype

    diffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
    abserr = sum(diffs)/len(diffs)
    relerr = sum(reldiffs)/len(reldiffs)
    assert abserr < 0.0035
    assert relerr < 0.015
    assert A2.dtype == dtype
    #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
    #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))



@pytest.mark.parametrize(
    "gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec2, step, percentile=percentile
        )
        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)

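        # reference implementation: keep a rolling window of the last 100
        # gradient norms and take the clip value at sorted index `percentile`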
        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_close(clip1, clip2)
        torch.testing.assert_close(gnorm1, gnorm2)


def quant(x):
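    # symmetric absmax quantization: map [-max|x|, max|x|] linearly onto the int8 range [-127, 127]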
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
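    # undo both absmax scales of an int32 matmul of two int8-quantized matrices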
    return C.float() * (maxA / 127) * (maxB / 127)


def quant_multi(x, dim):
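    # vector-wise absmax quantization: one scale per slice, reduced over `dim`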
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
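    # absmax quantization over chunked rearrangements of x (referenced only by
    # the commented-out `methods` entry below)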
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_minmax(A):
    minA = A.min()
    maxA = A.max()
    return minA, maxA


def mean(xx):
    return sum(xx) / float(len(xx))


# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
methods = [
    (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    )
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    "dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals)
    for vals in values_names
]


@pytest.mark.parametrize(
    "dim1, dim2, quant_methods, batched", values, ids=names
)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    #print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_close(
            quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
        )
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    #print(mean(errors))
    #print(mean(relerrors))


def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
    "hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (
            (batch_dim, hidden_dim)
            if not transpose[0]
            else (hidden_dim, batch_dim)
        )
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        elif transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.t().float(), B.float())
            out = F.igemm(A.t(), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.matmul(A.t().float(), B.t().float())
            out = F.igemm(A.t(), B.t())

        torch.testing.assert_close(out.float(), out2)

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())

        torch.testing.assert_close(out.float(), out2)


n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [
    "seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        ).to(torch.int8)
        B = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
        ).to(torch.int8)
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(
            A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
        )
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_close(out.float(), out2)


n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
    "seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
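        # asymmetric min-max quantization: shift by (minA + scale) to center the
        # range, then scale to int8; the shift is undone via `offset` below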
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs = []
    relerrs = []
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(
            0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        )
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
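            # Ac encodes x - (minA + scale), so the exact product needs the
            # rank-1 correction (minA + scale) * column-sums of B.t()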
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)

        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3


n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)

        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A, B.permute([0, 2, 1]))
        elif transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
            out = F.igemm(A.permute([0, 2, 1]), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.bmm(
                A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
            )
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_close(out.float(), out2.float())


n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        n = A1.numel()
        assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002))


n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
    func = F.get_transform_func(dtype, orderA, orderOut, transpose)

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
            dtype
        )

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_close(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_close(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = (
                A.shape[0]
                * A.shape[1]
                * (A.shape[2] + (32 - (A.shape[2] % 32)))
            )
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
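        # i.e. the buffer is padded out to whole 8-row x 32-column tiles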
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
            A.shape[1] + (32 - (A.shape[1] % 32))
        )
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = (
                    (row // 8) + (1 if row % 8 != 0 else 0)
                ) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_close(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(
            out, from_order=orderOut, to_order="row", state=S
        )
        torch.testing.assert_close(A, out2)


n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()

# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]

dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
                torch.int8
            )
        elif dims == 3:
            A = torch.randint(
                -128, 127, size=(dim1, dim2, dim3), device="cuda"
            ).to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_close(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_close(C1, C3.float())


dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(
                0, 0.5, size=(dim1, dim2, dim3), device="cuda"
            ).half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

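        # int8 pipeline: double_quant produces row-/col-wise int8 data plus
        # absmax stats, transform lays out the GPU tile formats, igemmlt runs
        # the int8 GEMM, and mm_dequant maps the int32 result back to floats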
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        # print('')
        # print(output.flatten()[:10])
        # print(C1.flatten()[:10])
        # print(C2.flatten()[:10])

        # torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        # C1 = torch.matmul(A.float(), B.float())

        # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        # C3, S = F.transform(C2, 'row', state=SC)
        # torch.testing.assert_close(C1, C3.float())


batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
    (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
    (batch_size, seqdim, 5120, 3 * 5120),
    (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]


# values = list(product(batch, seq, model, hidden))
names = [
    "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")

    # torch.cuda.synchronize()
    ## warmup
    # for i in range(100):
    #    torch.matmul(A, w1.t())
    # torch.cuda.synchronize()

    dtype = torch.int8
    A = A.view(-1, A.shape[-1]).contiguous()
    grad = grad.view(-1, grad.shape[-1]).contiguous()
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):

        out1 = torch.matmul(A, w1.t())  # fc1
        # out2 = torch.matmul(out1, w2.t())# fc2

        # d1 = torch.matmul(grad, w2) # delta1
        # d2 = torch.matmul(d1, w1) # delta2

        # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
        # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1

    torch.cuda.synchronize()
    t16 = time.time() - t0
    print(t16)

    # torch.cuda.empty_cache()

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)

    # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    # C32A, SA = F.transform2(CA, 'col32')
    ## fc1
    # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)

    ## fc2
    # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    # C32out1, Sout1 = F.transform2(Cout1, 'col32')
    # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)

    ## delta1
    # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)

    ## delta2
    # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    # C32d1, Sd1 = F.transform2(Cd1, 'col32')
    ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)

    ## grad1
    # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)

    ## grad2
    # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)

    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(k):
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)

    #    #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
    #    CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    #    #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    #    #CTw2, Sw2 = F.transform2(Cw2, formatB)
    #    #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)

    #    C32A, SA = F.transform2(CA, 'col32')

    #    # fc1
    #    out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    #    #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

    #    #print(coo_tensor.nnz)
    #    #out1sp = F.spmm_coo(coo_tensor, w1.t())
    #    #print(w1.t().shape)
    #    #out1 = out1dn + out1sp

    #    # fc2
    #    Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    #    C32out1, Sout1 = F.transform2(Cout1, 'col32')
    #    out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    #    #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)

    #    # delta1
    #    Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    #    C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    #    d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    #    #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)

    #    # delta2
    #    Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    #    C32d1, Sd1 = F.transform2(Cd1, 'col32')
    #    d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    #    #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)

    #    # grad1
    #    #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    #    #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    #    #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    #    #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)

    #    ## grad2
    #    #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    #    #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    #    #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    #    #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)

    # torch.cuda.synchronize()
    # t8 = time.time() - t0
    # print(t8)


n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()

#dim1 = [2*1024]
#dim4 = [2*1024]

#dim1 = [4]
#dim4 = [4]

dims = (2,)
formatB = ["col_turing", "col_ampere"]
has_bias = [True, False]
values = list(product(dim1, dim4, dims, formatB, has_bias))
names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
    inner = torch.randint(1, 128, size=(1,)).item()
    bias = None
    if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16)
    formatB = F.get_special_format_str()
    for i in range(1):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())
        if has_bias: C1 += bias

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
        if has_bias: C4 += bias

        # TODO: is something wrong here? If so, the problem goes deeper
        #n = C1.numel()
        #p = 0.06
        std = C1.std(0).view(1, -1)
        C1 /= std
        C4 /= std
        #assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
        #assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
        #torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
        n = C5.numel()
        assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n))


n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
    for i in range(k):
        threshold = 3.0
        A = torch.randn(dim1, dim2, device="cuda").half()
        A_truncated = A.clone()
        A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
        if dims == 2:
            row_stats1, _ = torch.abs(A.float()).max(1)
            col_stats1, _ = torch.abs(A.float()).max(0)
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
        else:
            assert False

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=threshold
        )

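        # reference for nnz_block_ptr: count elements >= threshold in each
        # 256-wide block of every row (tiled in groups of 16 rows), then prefix-sum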
        A_blocked = einops.rearrange(
            torch.abs(A),
            "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
            row_tiles=16,
            block_size=64 * 4,
        )
        nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(
            nnz_rows1_counts.shape[0] + 1,
            dtype=nnz_rows1_counts.dtype,
            device=nnz_rows1_counts.device,
        )
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_close(col_stats1_trunc, col_stats2)
        torch.testing.assert_close(row_stats1_trunc, row_stats2)
        torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=0.0
        )

        torch.testing.assert_close(col_stats1, col_stats2)
        torch.testing.assert_close(row_stats1, row_stats2)
        assert nnz_block_ptr2 is None


n = 2
# dim1 = [8*1024]
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (
            (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        )
        num_not_close_cols = (
            (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
        )

        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
            )
            assert False
        if num_not_close_rows > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
            )
            assert False

        torch.testing.assert_close(Srow.flatten().float(), statsA)
        torch.testing.assert_close(Scol.flatten().float(), statsAt)


n = 4
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        torch.testing.assert_close(maxA.flatten().float(), stats1a)
        torch.testing.assert_close(maxB.flatten().float(), stats2a)
        torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_close(C2a, B1, rtol=0, atol=1)

        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(C2a, "col_turing")
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
        assert err2 <= err1 * 1.025


n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)

        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))
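

# A minimal reference for the row-scale trick tested above (a sketch of the
# intended semantics, not of the CUDA kernel; the helper name is hypothetical).
# The commented-out assert in the loop compares C3 against essentially this:
# scale the int32 accumulator row-wise, round, and clip back into int8.
def _row_scale_matmul_reference(CA, CB, row_scale):
    # CA: int8 [m, k], CB: int8 [n, k], row_scale: float [m, 1]
    C32 = torch.matmul(CA.float(), CB.float().t())  # emulate int32 accumulate
    return torch.clip(torch.round(C32 * row_scale), -127, 127).to(torch.int8)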


dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]

values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    formatB = F.get_special_format_str()
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)
    # warmup
    for i in range(k):
        C1 = torch.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("16", time.time() - t0)

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
    CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
    A2, SA = F.nvidia_transform(C1a, "col32")
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

    c = 10.0 * inner * scale
    row_scale = maxA / c
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
    torch.cuda.synchronize()
    print("row-wise", time.time() - t0)

    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
    torch.cuda.synchronize()
    print("vector-wise", time.time() - t0)


n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]

dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(
    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
    "dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
    values,
    ids=names,
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
                dtype
            )
        elif dims == 3:
            A = torch.randint(
                10, 99, size=(dim1, dim2, dim3), device="cuda"
            ).to(dtype)

        A.view(-1)[-1] = -1
        if transpose:
            At = A.t().contiguous()
            out1, S1 = F.nvidia_transform(At, to_order=orderOut)
        else:
            out1, S1 = F.nvidia_transform(A, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_close(out1, out2)


n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]

dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
    "dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals)
    for vals in values
]


def test_overflow():
    formatB = F.get_special_format_str()
    print(formatB)
    for i in range(2):
        a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)

        Ca, Sa = F.nvidia_transform(a, "col32")
        Cb, Sb = F.nvidia_transform(b, formatB)

        c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
        c2 = torch.matmul(a.float(), b.float().t())


n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]

values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )

        if coo_tensor is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A)
            A2[
                coo_tensor.rowidx.long(), coo_tensor.colidx.long()
            ] = coo_tensor.values
            torch.testing.assert_close(A1, A2)

            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_close(A1, A2, rtol=0.05, atol=1.5e-2)

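# What the threshold path above checks, in short: double_quant(A, threshold=t)
# zeroes the outlier entries (|A| >= t) in the int8 matrix CA and returns them
# separately as a COO sparse tensor, so that, up to quantization error,
#     A ~= (CA.float() * statsA.unsqueeze(1) / 127) + dense(coo_tensor).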

n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
            B = torch.randn(dim3, dim2).cuda().half()
        else:
            B = torch.randn(dim2, dim3).cuda().half()

        idx = torch.abs(A) >= threshold
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(
            A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
        )
        A2 = A * idx

        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
            out1 = torch.matmul(A2, B.t())
        else:
            out2 = F.spmm_coo(cooA, B)
            out1 = torch.matmul(A2, B)

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)


def test_spmm_bench():
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = bnb.matmul(A, B.t())
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )

    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time() - t0
    print(tsp, t8)
    print(tsp / t8)


n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
    threshold = 3.0
    formatB = "col_turing"
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        out1 = torch.matmul(A, w1.t())

        Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
        CTw1, Sw1 = F.transform(Cw1, formatB)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        assert coo_tensor is not None

        out4 = F.spmm_coo(coo_tensor, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1


def test_matmuls():
    a = torch.randn(256, 512).half().cuda()
    b = torch.randn(256, 512).half().cuda()
    c1 = torch.matmul(a, b.t())
    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul_cublas(a, b.t())

    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2
    print(err1, err2)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = [
    "dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(
        out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
    )
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))

    # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)

    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #   #out3 = F.spmm_coo(cooA, Bt.t())
    #   #out2 = F.spmm_coo(cooA, B)
    #   #out2 = F.spmm_coo_very_sparse(cooA, B)
    #   #out1 = torch.matmul(A, Bt.t())

    # torch.cuda.synchronize()
    # print(time.time() - t0)


def test_coo2csr():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_close(A2[idx], csrA.values)
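

# A sketch of the COO -> CSR conversion exercised above, in plain torch ops
# (the helper name is hypothetical; F.coo2csr's kernel may differ, but the
# rowptr invariant checked by the test is the same):
def _coo_to_csr_reference(rowidx, colidx, values, rows):
    order = torch.argsort(rowidx)
    counts = torch.bincount(rowidx.long(), minlength=rows)
    rowptr = torch.zeros(rows + 1, dtype=torch.long, device=rowidx.device)
    rowptr[1:] = counts.cumsum(0)  # rowptr[i + 1] - rowptr[i] == nnz of row i
    return rowptr, colidx[order], values[order]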


def test_coo2csc():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_close(A2.t()[idx], cscA.values)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    rowidx = torch.randint(0, A.shape[-1], size=(15,))

    A[:, rowidx] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)

    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(100):
    #   out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print('fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print("cusparse fp16", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print("int8", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print("int8+dequant", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print("matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1 + out2
    torch.cuda.synchronize()
    print("sparse+ matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print("bnb matmul", time.time() - t0)


batch_size = 1
seqdim = 1
values = []
#values.append((batch_size, seqdim, 768, 4 * 768))
#values.append((batch_size, seqdim, 1024, 4*1024))
#values.append((batch_size, seqdim, 1536, 4*1536))
#values.append((batch_size, seqdim, 2048, 4*2048))
#values.append((batch_size, seqdim, 2560, 4*2560))
#values.append((batch_size, seqdim, 4096, 4*4096))
#values.append((batch_size, seqdim, 5120, 4*5120))
values.append((batch_size, seqdim, 6656, 4*6656))
#values.append((batch_size, seqdim, 8192, 4*8192))
#values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
    iters = 1000
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device="cuda").half()
    B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    B_fp4, state = F.quantize_fp4(B)
    B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)

    B_nf4, state_nf4 = F.quantize_nf4(B)
    B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
    linear8bit.eval()

    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half())
    #linearMixedBit.eval()

    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)

    # warmup
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
    #torch.cuda.synchronize()
    #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
    #torch.cuda.synchronize()
    #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
    torch.cuda.synchronize()
    print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
    torch.cuda.synchronize()
    print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )


    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul(A, B)
    #torch.cuda.synchronize()
    #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul(A, B, threshold=6.0)
    #torch.cuda.synchronize()
    #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    #C32A, SA = F.transform(CA, "col32")
    #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    #CxB, SB = F.transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #torch.cuda.synchronize()
    #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #BA, statsB = F.vectorwise_quant(B, dim=1)
    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1)
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    #torch.cuda.synchronize()
    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    out = Cout * statsB * statsA * (1.0 / (127 * 127))
    #torch.cuda.synchronize()
    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linearMixedBit(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linearMixedBit(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit_train(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit_train(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit_train_thresh(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit_train(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

def test_zeropoint():
    def quant_zp(x):
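        # Asymmetric-quantization sketch: spread x over ~254 integer steps
        # (scale qx) and shift by a zero-point zpx so x.min() lands near -127.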
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        minx = x.min()
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()

    C1 = torch.matmul(A, B)
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    print(ca.min(), ca.max())
    print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb
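    # Why the corrections above recover A @ B: with K = A.shape[1],
    #   sum_k (qa*A[i,k] + zpa) * (qb*B[k,j] + zpb)
    #     = qa*qb*(A@B)[i,j] + qa*zpb*A.sum(1)[i] + qb*zpa*B.sum(0)[j]
    #       + K*zpa*zpb,
    # so subtracting the three extra terms and dividing by qa*qb undoes the
    # zero-point shifts exactly.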

    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

    print("")
    # print(C0.flatten()[:10])
    print(C1.flatten()[:10])
    print(C2.flatten()[:10])
    print(C3.flatten()[:10])
    print(C5.flatten()[:10])
    print(C6.flatten()[:10])
    print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)


def test_extract_outliers():
    for i in range(k):
        shapeA = (4096, 4096 * 4)
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # idx = torch.Tensor([0]).int().cuda()
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        CA, SA = F.transform(A, "col_turing")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)


def test_blockwise_cpu_large():
    diffs = []
    reldiffs = []
    batch = 128
    seq = 128
    for hidden in [128]:#, 14336]:
        for blocksize in [4096, 16384]:
            for i in range(2):
                A1 = torch.randn(batch, seq, hidden, device='cpu')
                t0 = time.time()
                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
                print(time.time() - t0)
                diff = torch.abs(A1 - A2)
                reldiff = diff / torch.abs(A1 + 1e-8)
                diffs.append(diff.mean().item())
                reldiffs.append(reldiff.mean().item())
                assert diffs[-1] < 0.011
            # print(sum(diffs)/len(diffs))
            # print(sum(reldiffs)/len(reldiffs))


def test_fp8_quant():
    for e_bits in range(1, 7):
        p_bits = 7-e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()
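        # 1 sign bit + e_bits exponent bits + p_bits mantissa bits = 8 bits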

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(sum(abserr)/len(abserr))
        #print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.rand(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(sum(abserr)/len(abserr))
        #print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(3, sum(abserr)/len(abserr))
        #print(3, sum(relerr)/len(relerr))


def test_few_bit_quant():

    #print('')
    for bits in range(2, 9):
        #print('='*30, bits, '='*30)
        for method in ['linear', 'fp8', 'dynamic', 'quantile']:
            abserrs = []
            relerrs = []
            code = None
            if method == 'linear':
                code = F.create_linear_map(True, total_bits=bits).cuda()
            elif method == 'fp8':
                ebits = math.ceil(bits/2)
                pbits = bits-ebits-1
                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
            elif method == 'dynamic':
                code = F.create_dynamic_map(True, bits-0, bits).cuda()
            elif method == 'quantile':
                values = torch.randn(2048, 2048, device='cuda')
                code = F.create_quantile_map(values, bits).cuda()
            # for some data types we have no zero
            # for some data types we have one zero
            # for some data types we have two zeros
            assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}'
            #print(method, (code==0).sum())
            assert code.numel() == 256
            for i in range(10):

                values = torch.randn(1, 32, device='cuda')
                values /= values.abs().max()
                #values[values.abs() < 1e-6] += 1e-5

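                # Brute-force reference: quantize each value to the nearest
                # codebook entry by exhaustive search over `code`.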
                q1 = []
                v1 = []
                for v in values[0]:
                    idx = torch.abs(v-code).argmin()
                    q1.append(idx.item())
                    v1.append(code[idx].item())

                q1 = torch.Tensor(q1).cuda()
                v1 = torch.Tensor(v1).cuda()

                q2, S2 = F.quantize_blockwise(values, code=code)
                v2 = F.dequantize_blockwise(q2, S2)

                idx = torch.isclose(q1.int(), q2.int())
                err2 = torch.abs(v2-values)
                abserrs.append(err2.mean().item())
                relerrs.append((err2/(1e-10+values).abs()).mean().item())
                if idx.sum():
                    # some weird cases
                    err1 = torch.abs(v1-values).mean()
                    #assert err2.mean() <= err1

                else:
                    torch.testing.assert_close(q1, q2)
            #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
    #assert False


def test_kbit_quantile_estimation():
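    # estimate_quantiles on standard-normal data should track the theoretical
    # Gaussian quantiles norm.ppf(p), up to sampling noise.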
    for i in range(100):
        data = torch.randn(1024, 1024, device='cuda')
        for bits in range(2, 9):
            p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
            err = torch.abs(val1-val2).mean()
            assert err < 0.038

    for i in range(100):
        data = torch.randn(1024, 1024, device='cuda')
        for bits in range(2, 4):
            total_values = 2**bits-1
            p = np.linspace(0, 1, 2*total_values+1)
            idx = np.arange(1, 2*total_values+1, 2)
            p = p[idx]
            offset = 1/(2*total_values)
            p = np.linspace(offset, 1-offset, total_values)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1)
            err = torch.abs(val1-val2).mean()
            assert err < 0.035


def test_bench_dequantization():
    a = torch.rand(1024, 1024, device='cuda').half()
    code = F.create_fp8_map(True, 3, 0, 4).cuda()
    qa, SA = F.quantize_blockwise(a, code=code)
    print(qa.max())

    max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
    #print(max_theoretical_mu)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        qa, SA = F.quantize_blockwise(a)
    torch.cuda.synchronize()
    #print((time.time()-t0)/1e6)



@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_fp4_quant(dtype):
    vals = list(product([0, 1], repeat=4))
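    # The loop below enumerates all 16 FP4 (E2M1) bit patterns -- 1 sign bit,
    # 2 exponent bits, 1 mantissa bit, bias 3 -- to document the value map
    # that quantize_fp4 targets; `code` itself is not used by the asserts.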

    code = {}
    for bits in vals:
        result = 0
        bias = 3
        sign, e1, e2, p1 = bits
        idx = sign*8 + e1*4 + e2*2 + p1*1
        sign = -1.0 if sign else 1.0
        exp = e1*2 + e2*1
        if exp == 0:
            # sub-normal
            if p1 == 0: result = 0
            else: result = sign*0.0625
        else:
            # normal
            exp = 2**(-exp + bias + 1)
            frac = 1.5 if p1 else 1.0
            result = sign*exp*frac
        code[idx] = result

    A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype)
    qa, SA = F.quantize_fp4(A1, blocksize=64)
    A2 = F.dequantize_fp4(qa, SA)

    err = (A1 - A2).abs().float()
    relerr = (err/(A1.abs().float()+1e-8)).mean()
    idx = err > 1.0
    err = err.mean()

    assert A2.dtype == dtype
    assert err.item() < 0.1
    assert relerr.item() < 0.28


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
def test_4bit_compressed_stats(quant_type):
    for blocksize in [128, 64]:
        errs1 = []
        errs2 = []
        for i in range(10):
            A1 = torch.randn(1024, 1024, device='cuda').half()
            q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
            q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
            A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
            A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)

            err = (A1 - A2).abs().float()
            relerr = (err/(A1.abs().float()+1e-15)).mean()
            err = err.mean()

            errs1.append(err.item())


            assert err.item() < 0.11
            assert relerr.item() < 0.28

            err = (A1 - A3).abs().float()
            relerr = (err/(A1.abs().float()+1e-15)).mean()
            err = err.mean()

            errs2.append(err.item())

            assert err.item() < 0.11
            assert relerr.item() < 0.28

        #print(sum(errs1)/len(errs1), blocksize, quant_type)
        #print(sum(errs2)/len(errs2), blocksize, quant_type)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ['nf4'])
def test_bench_4bit_dequant(quant_type):
    blocksize = 256
    a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
    qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)

    input_size = a.numel()/2
    output_size = a.numel()*2
    num_bytes = input_size+output_size
    GB = num_bytes/1e9
    max_theoretical_s = GB/768
    #print(max_theoretical_s*1e6)
    b = torch.randn(128, 1024*12, device='cuda').half()

    iters = 100
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
        #b.copy_(a)
    torch.cuda.synchronize()
    #print((time.time()-t0)/iters*1e6)

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    torch.matmul(b, a.t())
    #torch.cuda.synchronize()
    #print((time.time()-t0)/iters*1e6)


def test_normal_map_tree():
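    # Prints the NF4 codebook midpoints level by level -- the pivots a
    # binary-search tree would use to map a value to its 4-bit index.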
    code = F.create_normal_map()
    values = code[:8].tolist() + code[-8:].tolist()
    num_pivots = 1
    print(values)
    while num_pivots <16:
        idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
        print(idx)
        num_pivots *= 2
        pivots = []
        for i in idx:
            pivots.append((values[i-1]+values[i])/2)
        print(pivots)


@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
def test_gemv_4bit(dtype, storage_type, double_quant, kind):
    for dim in [128, 256, 512, 1024]:
    #for dim in [4*1024]:
    #for dim in [1*16]:
        errs1 = []
        errs2 = []
        errs3 = []
        relerrs1 = []
        relerrs2 = []
        relerrs3 = []
        max_errs1 = []
        max_errs2 = []
        max_errs3 = []

        for i in range(100):
            if kind == 'fc1':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            elif kind == 'fc2':
                A = torch.randn(1, 4*dim, dtype=dtype, device='cuda')
                B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            elif kind == 'attn':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            elif kind == 'attn_packed':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
            C3 = torch.matmul(A, B.t())
            C2 = F.gemv_4bit(A, qB.t(), state=state)
            A.requires_grad = True
            C1 = bnb.matmul_4bit(A, qB.t(), state)
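            # C3: full-precision reference; C2: fused 4-bit inference GEMV;
            # C1: the training path (matmul_4bit with grad enabled). The
            # stats below compare inference vs training and both vs torch.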

            err1 = (C1-C2).abs().float()
            err2 = (C3-C2).abs().float()
            err3 = (C3-C1).abs().float()

            mag1 = torch.abs(C1).float()+1e-5
            mag2 = torch.abs(C3).float()+1e-5
            mag3 = torch.abs(C3).float()+1e-5

            relerr1 = err1/mag1
            relerr2 = err2/mag2
            relerr3 = err3/mag3

            max_err1 = err1.max()
            max_err2 = err2.max()
            max_err3 = err3.max()

            errs1.append(err1.mean().item())
            errs2.append(err2.mean().item())
            errs3.append(err3.mean().item())

            relerrs1.append(relerr1.mean().item())
            relerrs2.append(relerr2.mean().item())
            relerrs3.append(relerr3.mean().item())

            max_errs1.append(max_err1.item())
            max_errs2.append(max_err2.item())
            max_errs3.append(max_err3.item())

            c = int(C1.numel()*0.0014*(dim/256))+1

            c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
        err1 = sum(errs1)/len(errs1)/math.sqrt(dim)
        err2 = sum(errs2)/len(errs2)/math.sqrt(dim)
        err3 = sum(errs3)/len(errs3)/math.sqrt(dim)
        relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim)
        relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim)
        relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim)
        maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim)
        maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim)
        maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim)
        absratio = err2/err3
        relratio = relerr2/relerr3
        maxratio = maxerr2/maxerr3

        # for debugging if the test fails
        #
        #print('='*80)
        #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
        print(C1.flatten()[-20:])
        print(C2.flatten()[-20:])
        print(f'inference vs training abs: {err1}')
        print(f'inference vs training rel: {relerr1}')
        print(f'inference vs training max: {maxerr1}')
        #print(f'inference vs training vs torch err ratio abs: {absratio}')
        #print(f'inference vs training vs torch err ratio rel: {relratio}')
        #print(f'inference vs training vs torch err ratio max: {maxratio}')
        if dtype == torch.float16:
            if dim <= 512:
                assert err1 < 7e-5
                assert relerr1 < 0.0008
            else:
                assert err1 < 6e-5
                assert relerr1 < 2e-4
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.float32:
            if dim <= 512:
                assert err1 < 5e-8
                assert relerr1 < 1e-6
                assert maxerr1 < 1e-7
            else:
                assert err1 < 5e-8
                assert relerr1 < 8e-6
                assert maxerr1 < 1e-7
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.bfloat16:
            if dim <= 512:
                assert err1 < 6e-4
                assert relerr1 < 0.007
                assert maxerr1 < 0.015
            else:
                assert err1 < 2e-4
                assert relerr1 < 0.002
                assert maxerr1 < 0.0012
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.04 and relratio > 0.96
            assert maxratio < 1.02 and maxratio > 0.98

@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
    n = 32*10
    A = F.get_paged(n, n, dtype=torch.float32)
    B = F.get_paged(n, n, dtype=torch.uint8)
    B2 = F.get_paged(n, n, dtype=torch.float32)
    assert A.is_paged
    assert B.is_paged
    assert A.page_deviceid==0
    assert B.page_deviceid==0
    F.fill(A, 17.0)
    F.fill(B, 17)
    F.fill(B2, 2)
    assert (A==17).sum().item() == n*n
    assert (B==17).sum().item() == n*n
    C = A*B.float()
    assert (C==289).sum().item() == n*n
    F._mul(A, B2)
    F._mul(A, B2)
    F._mul(A, B2)
    assert (A==17*(2**3)).sum().item() == n*n
   # F.prefetch_tensor(A)
   # F.prefetch_tensor(B)


   # F.fill(B2, 17.0)
   # F._mul(A, B2)

   # F.prefetch_tensor(A, to_cpu=True)
   # F.prefetch_tensor(B, to_cpu=True)
   # F.prefetch_tensor(B2, to_cpu=True)
   # torch.cuda.synchronize()

   # assert (A==17).sum().item() == n*n

   # torch.testing.assert_close(A, torch.ones(A.shape)*289)


@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True'])
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
    dims = 10
    torch.random.manual_seed(np.random.randint(0, 412424242))
    dims = torch.randint(0, 8192, size=(dims,)).tolist()
    dims = [dim + (64-(dim % 64)) for dim in dims]
    #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
    for dim in dims:
        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda')
        B = torch.eye(dim, dtype=dtype, device='cuda')

        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
        C3 = torch.matmul(A, B.t())
        C2 = bnb.matmul_4bit(A, qB.t(), state)
        A.requires_grad = True
        C1 = bnb.matmul_4bit(A, qB.t(), state)

        torch.testing.assert_close(A, C3)
        torch.testing.assert_close(A, C1)
        torch.testing.assert_close(A, C2)
        #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
        #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)