from itertools import product
import math
import random
import time

import einops
import numpy as np
import pytest
from scipy.stats import norm
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F

torch.set_printoptions(
    precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20


def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
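    # Tolerant version of torch.testing.assert_close: up to `count` elements
    # may violate the tolerances before the assertion is (optionally) raised.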
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        if throw:
            print(f"Too many values not close: assert {sumval} < {count}")
            torch.testing.assert_close(a, b, rtol, atol)

    return sumval


class FFN(torch.nn.Module):
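    # Simple two-layer ReLU MLP with Xavier-initialized weights.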
    def __init__(self, input_features, hidden_size, bias=True):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.fc1.weight)
            torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class Timer:
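    # GPU timer built on paired torch.cuda.Event records; elapsed milliseconds
    # are accumulated per name in self.agg.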
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
            self.starts[name].record()
        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")


def setup():
    pass


def teardown():
    pass


@pytest.mark.parametrize(
    "dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_estimate_quantiles(dtype):
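    # Uniform inputs: the 256 estimated quantiles should match evenly spaced
    # percentiles; normal inputs are then checked against torch.quantile.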
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0


def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001


def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    print(sum(diffs)/len(diffs))
    print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
@pytest.mark.parametrize("signed", [True, False], ids=['signed_True', 'signed_False'])
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
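    # Round-trip quantize/dequantize: first with the default dynamic map on
    # normal data, then with an explicit map from create_dynamic_map on uniform data.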
    #print('')
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
    abserr = sum(diffs)/len(diffs)
    relerr = sum(reldiffs)/len(reldiffs)
    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
    assert abserr < 0.011
    assert relerr < 0.018
    assert A2.dtype == dtype

    diffs = []
    reldiffs = []
    code = F.create_dynamic_map(signed=signed)
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
    abserr = sum(diffs)/len(diffs)
    relerr = sum(reldiffs)/len(reldiffs)
    if signed:
        assert abserr < 0.0035
        assert relerr < 0.015
    else:
        assert abserr < 0.00175
        assert relerr < 0.012
    assert A2.dtype == dtype
    #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
    #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))


@pytest.mark.parametrize(
    "gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_percentile_clipping(gtype):
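    # Compare F.percentile_clipping against a Python reference that keeps a
    # ring buffer of the last 100 gradient norms and clips at the 5th percentile.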
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec2, step, percentile=percentile
        )
        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)

        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_close(clip1, clip2)
        torch.testing.assert_close(gnorm1, gnorm2)


def quant(x):
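    # Reference symmetric absmax quantization: map [-max|x|, max|x|] onto
    # [-127, 127] and round to int8.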
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    return C.float() * (maxA / 127) * (maxB / 127)


def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
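    # Chunked variant of quant_multi: regroup x with einops, take the absmax,
    # and tile it back to x's shape before scaling. Only referenced by the
    # commented-out method entry below.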
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_minmax(A):
    minA = A.min()
    maxA = A.max()
    return minA, maxA


def mean(xx):
    return sum(xx) / float(len(xx))


# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
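# Each method tuple is (quantize A, quantize B, dequantize A, dequantize B,
# dequantize the accumulated matmul output).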
methods = [
    (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    )
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    "dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals)
    for vals in values_names
]


@pytest.mark.parametrize(
    "dim1, dim2, quant_methods, batched", values, ids=names
)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    #print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_close(
            quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
        )
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    #print(mean(errors))
    #print(mean(relerrors))


def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
    "hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
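    # int8 igemm vs. an fp32 matmul reference; dims are rounded down to the
    # multiples the kernels expect (32 for hidden, 16 for batch/seq).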
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (
            (batch_dim, hidden_dim)
            if not transpose[0]
            else (hidden_dim, batch_dim)
        )
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        elif transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.t().float(), B.float())
            out = F.igemm(A.t(), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.matmul(A.t().float(), B.t().float())
            out = F.igemm(A.t(), B.t())

        torch.testing.assert_close(out.float(), out2)

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())

        torch.testing.assert_close(out.float(), out2)


n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [
    "seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        ).to(torch.int8)
        B = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
        ).to(torch.int8)
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(
            A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
        )
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_close(out.float(), out2)


n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
    "seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
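    # Asymmetric (min/max) quantization of A: the zero-point shift is undone
    # after the int8 matmul via the `offset` term built from B's column sums.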
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs = []
    relerrs = []
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(
            0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        )
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)

        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3


n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
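    # Batched int8 matmul (F.igemm on 3D tensors) vs. a float torch.bmm
    # reference, for all four transpose combinations.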
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)

        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A, B.permute([0, 2, 1]))
        elif transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
            out = F.igemm(A.permute([0, 2, 1]), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.bmm(
                A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
            )
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_close(out.float(), out2.float())


n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        n = A1.numel()
        assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002))


n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))

names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and out_order != "col32":
        return
    if dtype == torch.int32 and out_order != "col32":
        return
    try:
        func = F.get_transform_func(dtype, orderA, orderOut, transpose)
    except ValueError as ve:
        pytest.skip(str(ve))  # skip if not supported

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
            dtype
        )

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_close(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_close(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = (
                A.shape[0]
                * A.shape[1]
                * (A.shape[2] + (32 - (A.shape[2] % 32)))
            )
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
            A.shape[1] + (32 - (A.shape[1] % 32))
        )
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = (
                    (row // 8) + (1 if row % 8 != 0 else 0)
                ) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_close(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(
            out, from_order=orderOut, to_order="row", state=S
        )
        torch.testing.assert_close(A, out2)


n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()

# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]

dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
                torch.int8
            )
        elif dims == 3:
            A = torch.randint(
                -128, 127, size=(dim1, dim2, dim3), device="cuda"
            ).to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_close(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_close(C1, C3.float())


dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(
                0, 0.5, size=(dim1, dim2, dim3), device="cuda"
            ).half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        # print('')
        # print(output.flatten()[:10])
        # print(C1.flatten()[:10])
        # print(C2.flatten()[:10])

        # torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        # C1 = torch.matmul(A.float(), B.float())

        # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        # C3, S = F.transform(C2, 'row', state=SC)
        # torch.testing.assert_close(C1, C3.float())


batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
    (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
    (batch_size, seqdim, 5120, 3 * 5120),
    (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]


# values = list(product(batch, seq, model, hidden))
names = [
    "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")

    # torch.cuda.synchronize()
    ## warmup
    # for i in range(100):
    #    torch.matmul(A, w1.t())
    # torch.cuda.synchronize()

    dtype = torch.int8
    A = A.view(-1, A.shape[-1]).contiguous()
    grad = grad.view(-1, grad.shape[-1]).contiguous()
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):

        out1 = torch.matmul(A, w1.t())  # fc1
        # out2 = torch.matmul(out1, w2.t())# fc2

        # d1 = torch.matmul(grad, w2) # delta1
        # d2 = torch.matmul(d1, w1) # delta2

        # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
        # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1

    torch.cuda.synchronize()
    t16 = time.time() - t0
    print(t16)

    # torch.cuda.empty_cache()

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)

    # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    # C32A, SA = F.transform2(CA, 'col32')
    ## fc1
    # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)

    ## fc2
    # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    # C32out1, Sout1 = F.transform2(Cout1, 'col32')
    # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)

    ## delta1
    # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)

    ## delta2
    # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    # C32d1, Sd1 = F.transform2(Cd1, 'col32')
    ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)

    ## grad1
    # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)

    ## grad2
    # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)

    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(k):
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)

    #    #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
    #    CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    #    #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    #    #CTw2, Sw2 = F.transform2(Cw2, formatB)
    #    #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)

    #    C32A, SA = F.transform2(CA, 'col32')

    #    # fc1
    #    out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    #    #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

    #    #print(coo_tensor.nnz)
    #    #out1sp = F.spmm_coo(coo_tensor, w1.t())
    #    #print(w1.t().shape)
    #    #out1 = out1dn + out1sp

    #    # fc2
    #    Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    #    C32out1, Sout1 = F.transform2(Cout1, 'col32')
    #    out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    #    #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)

    #    # delta1
    #    Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    #    C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    #    d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    #    #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)

    #    # delta2
    #    Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    #    C32d1, Sd1 = F.transform2(Cd1, 'col32')
    #    d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    #    #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)

    #    # grad1
    #    #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    #    #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    #    #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    #    #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)

    #    ## grad2
    #    #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    #    #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    #    #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    #    #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)

    # torch.cuda.synchronize()
    # t8 = time.time() - t0
    # print(t8)


n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()

#dim1 = [2*1024]
#dim4 = [2*1024]

#dim1 = [4]
#dim4 = [4]

dims = (2,)
formatB = ["col_turing", "col_ampere"]
has_bias = [True, False]
values = list(product(dim1, dim4, dims, formatB, has_bias))
names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
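    # Quantize, igemmlt, then dequantize (vectorwise path and fused
    # F.mm_dequant with optional bias) against an fp16 matmul reference.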
    inner = torch.randint(1, 128, size=(1,)).item()
    bias = None
    if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16)
    formatB = F.get_special_format_str()
    for i in range(1):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())
        if has_bias: C1 += bias

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
        if has_bias: C4 += bias

        # TODO: is something wrong here? If so, the problem goes deeper
        #n = C1.numel()
        #p = 0.06
        std = C1.std(0).view(1, -1)
        C1 /= std
        C4 /= std
        #assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
        #assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
        #torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
        n = C5.numel()
        assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n))


n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
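    # F.get_colrow_absmax should match torch row/col absmax references; with a
    # threshold, stats are computed as if above-threshold outliers were zeroed.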
    for i in range(k):
        threshold = 3.0
        A = torch.randn(dim1, dim2, device="cuda").half()
        A_truncated = A.clone()
        A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
        if dims == 2:
            row_stats1, _ = torch.abs(A.float()).max(1)
            col_stats1, _ = torch.abs(A.float()).max(0)
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
        else:
            assert False

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=threshold
        )

        A_blocked = einops.rearrange(
            torch.abs(A),
            "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
            row_tiles=16,
            block_size=64 * 4,
        )
        nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(
            nnz_rows1_counts.shape[0] + 1,
            dtype=nnz_rows1_counts.dtype,
            device=nnz_rows1_counts.device,
        )
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_close(col_stats1_trunc, col_stats2)
        torch.testing.assert_close(row_stats1_trunc, row_stats2)
        torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=0.0
        )

        torch.testing.assert_close(col_stats1, col_stats2)
        torch.testing.assert_close(row_stats1, row_stats2)
        assert nnz_block_ptr2 is None


n = 2
# dim1 = [8*1024]
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
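    # F.double_quant should agree with separate row-wise and column-wise
    # vectorwise quantization, up to off-by-one rounding differences.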
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (
            (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        )
        num_not_close_cols = (
            (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
        )

        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
            )
            assert False
        if num_not_close_rows > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
            )
            assert False

        torch.testing.assert_close(Srow.flatten().float(), statsA)
        torch.testing.assert_close(Scol.flatten().float(), statsAt)


n = 4
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        torch.testing.assert_close(maxA.flatten().float(), stats1a)
        torch.testing.assert_close(maxB.flatten().float(), stats2a)
        torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_close(C2a, B1, rtol=0, atol=1)

        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(C2a, "col_turing")
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
        assert err2 <= err1 * 1.025


n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)

        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))


dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]

values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
    formatB = F.get_special_format_str()  # was missing; used below for nvidia_transform
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)
    # warmup
    for i in range(k):
        C1 = torch.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("16", time.time() - t0)

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
    CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
    A2, SA = F.nvidia_transform(C1a, "col32")
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

    c = 10.0 * inner * scale
    row_scale = maxA / c
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
    torch.cuda.synchronize()
    print("row-wise", time.time() - t0)

    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
    torch.cuda.synchronize()
    print("vector-wise", time.time() - t0)


n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]

dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(
    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
    "dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
    values,
    ids=names,
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
                dtype
            )
        elif dims == 3:
            A = torch.randint(
                10, 99, size=(dim1, dim2, dim3), device="cuda"
            ).to(dtype)

        A.view(-1)[-1] = -1
        if transpose:
            At = A.t().contiguous()
            out1, S1 = F.nvidia_transform(At, to_order=orderOut)
        else:
            out1, S1 = F.nvidia_transform(A, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_close(out1, out2)


n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]

dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
    "dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals)
    for vals in values
]


def test_overflow():
    formatB = F.get_special_format_str()
    print(formatB)
    for i in range(2):
        a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)

        Ca, Sa = F.nvidia_transform(a, "col32")
        Cb, Sb = F.nvidia_transform(b, formatB)

        c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
        c2 = torch.matmul(a.float(), b.float().t())
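

# What the overflow test above exercises: a @ b.t() reaches 14 * 14 = 196,
# which does not fit into int8 (max 127). A plain-PyTorch illustration of
# wraparound vs. saturation (hypothetical helper, not used by the tests):
def _int8_overflow_demo():
    x = torch.tensor([196], dtype=torch.int32)
    wrapped = x.to(torch.int8)                     # two's-complement wrap -> -60
    saturated = x.clamp(-128, 127).to(torch.int8)  # clamp -> 127
    return wrapped.item(), saturated.item()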


n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]

values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )

        if coo_tensor is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A)
            A2[
                coo_tensor.rowidx.long(), coo_tensor.colidx.long()
            ] = coo_tensor.values
            torch.testing.assert_close(A1, A2)

            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_close(
                A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
            )


n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
            B = torch.randn(dim3, dim2).cuda().half()
        else:
            B = torch.randn(dim2, dim3).cuda().half()

        idx = torch.abs(A) >= threshold
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(
            A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
        )
        A2 = A * idx

        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
            out1 = torch.matmul(A2, B.t())
        else:
            out2 = F.spmm_coo(cooA, B)
            out1 = torch.matmul(A2, B)

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
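

# The mask -> coordinates -> COOSparseTensor construction above is repeated
# throughout this file; a minimal sketch of the pattern (hypothetical helper):
def _to_coo(A, threshold):
    idx = torch.abs(A) >= threshold  # outlier mask
    rows, cols = torch.where(idx)    # coordinates of the retained entries
    values = A[idx]                  # kept at full (fp16) precision
    return F.COOSparseTensor(
        A.shape[0], A.shape[1], idx.sum().item(), rows.int(), cols.int(), values
    )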


def test_spmm_bench():
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = bnb.matmul(A, B.t())
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )

    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time() - t0
    print(tsp, t8)
    print(tsp / t8)


n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
    threshold = 3.0
    formatB = "col_turing"
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        out1 = torch.matmul(A, w1.t())

        Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
        CTw1, Sw1 = F.transform(Cw1, formatB)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        assert coo_tensor is not None

        out4 = F.spmm_coo(coo_tensor, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1
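

# Sketch of the decomposition verified above: with a threshold t, double_quant
# keeps |a| < t in the int8 tensor and routes the outliers into a COO tensor,
# so the full-precision product is recovered as int8-matmul + sparse-matmul.
# Hypothetical helper mirroring the test body:
def _sparse_decomposed_matmul(A, w, threshold=3.0, formatB="col_turing"):
    Cw, _, statsw, _, _ = F.double_quant(w)
    CTw, Sw = F.transform(Cw, formatB)
    CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)
    C32A, SA = F.transform(CA, "col32")
    out32, Sout32 = F.igemmlt(C32A, CTw, SA, Sw)
    out = F.mm_dequant(out32, Sout32, statsA, statsw)
    if coo_tensor is not None:
        out += F.spmm_coo(coo_tensor, w.t())
    return out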


def test_matmuls():
    a = torch.randn(256, 512).half().cuda()
    b = torch.randn(256, 512).half().cuda()
    c1 = torch.matmul(a, b.t())
    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul_cublas(a, b.t())

    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2
    print(err1, err2)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = [
    "dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(
        out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
    )
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))

    # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)

    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #   #out3 = F.spmm_coo(cooA, Bt.t())
    #   #out2 = F.spmm_coo(cooA, B)
    #   #out2 = F.spmm_coo_very_sparse(cooA, B)
    #   #out1 = torch.matmul(A, Bt.t())

    # torch.cuda.synchronize()
    # print(time.time() - t0)


def test_coo2csr():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_close(A2[idx], csrA.values)
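

# CSR sketch: rowptr is a prefix sum of per-row nonzero counts with a leading
# zero, which is why rowptr[1:] - rowptr[:-1] recovers the counts checked
# above. A plain-PyTorch construction (hypothetical helper):
def _rowptr_from_coo(rowidx, nrows):
    counts = torch.bincount(rowidx.long(), minlength=nrows)
    rowptr = torch.zeros(nrows + 1, dtype=torch.long, device=rowidx.device)
    rowptr[1:] = counts.cumsum(0)
    return rowptr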


def test_coo2csc():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_close(A2.t()[idx], cscA.values)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    rowidx = torch.randint(0, A.shape[-1], size=(15,))

    A[:, rowidx] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)

    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(100):
    #   out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print('fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print("cusparse fp16", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print("int8", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print("int8+dequant", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print("matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1 + out2
    torch.cuda.synchronize()
    print("sparse+ matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)


batch_size = 1
seqdim = 1
values = []
#values.append((batch_size, seqdim, 768, 4 * 768))
#values.append((batch_size, seqdim, 1024, 4*1024))
#values.append((batch_size, seqdim, 1536, 4*1536))
#values.append((batch_size, seqdim, 2048, 4*2048))
#values.append((batch_size, seqdim, 2560, 4*2560))
#values.append((batch_size, seqdim, 4096, 4*4096))
#values.append((batch_size, seqdim, 5120, 4*5120))
values.append((batch_size, seqdim, 6656, 4*6656))
#values.append((batch_size, seqdim, 8192, 4*8192))
#values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
    iters = 1000
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device="cuda").half()
    B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    B_fp4, state = F.quantize_fp4(B)
    B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)

    B_nf4, state_nf4 = F.quantize_nf4(B)
    B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
    linear8bit.eval()

    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half())
    #linearMixedBit.eval()

    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)

    # warmup
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
    #torch.cuda.synchronize()
    #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
    #torch.cuda.synchronize()
    #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
    torch.cuda.synchronize()
    print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
    torch.cuda.synchronize()
    print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )


    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul(A, B)
    #torch.cuda.synchronize()
    #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul(A, B, threshold=6.0)
    #torch.cuda.synchronize()
    #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    #C32A, SA = F.transform(CA, "col32")
    #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    #CxB, SB = F.transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #torch.cuda.synchronize()
    #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #BA, statsB = F.vectorwise_quant(B, dim=1)
    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1)
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    #torch.cuda.synchronize()
    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    out = Cout * statsB * statsA * (1.0 / (127 * 127))
    #torch.cuda.synchronize()
    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linearMixedBit(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linearMixedBit(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit_train(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit_train(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit_train_thresh(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit_train_thresh(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

def test_zeropoint():
    def quant_zp(x):
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        minx = x.min()
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()

    C1 = torch.matmul(A, B)
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    #print(ca.min(), ca.max())
    #print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb

    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

    #print("")
    # print(C0.flatten()[:10])
    #print(C1.flatten()[:10])
    #print(C2.flatten()[:10])
    #print(C3.flatten()[:10])
    #print(C5.flatten()[:10])
    #print(C6.flatten()[:10])
    #print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)
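

# The correction terms for C6/C7 above follow from expanding the zero-point
# quantized product:
#   (qa*A + zpa) @ (qb*B + zpb)
#     = qa*qb*(A @ B) + zpa*qb*B.sum(0) + zpb*qa*A.sum(1) + zpa*zpb*k
# so subtracting the three cross terms and dividing by qa*qb recovers A @ B
# exactly. A small numeric check of the identity (hypothetical helper):
def _check_zeropoint_identity():
    A = torch.randn(4, 8, dtype=torch.float64)
    B = torch.randn(8, 3, dtype=torch.float64)
    qa, zpa, qb, zpb = 2.0, 1.0, 3.0, -2.0
    C = torch.matmul(qa * A + zpa, qb * B + zpb)
    C -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C -= zpa * zpb * A.shape[1]
    C /= qa * qb
    torch.testing.assert_close(C, torch.matmul(A, B))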


def test_extract_outliers():
    for i in range(k):
        shapeA = (4096, 4096 * 4)
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # idx = torch.Tensor([0]).int().cuda()
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        CA, SA = F.transform(A, "col_turing")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)



def test_blockwise_cpu_large():
    diffs = []
    reldiffs = []
    batch = 128
    seq = 128
    for hidden in [128]:#, 14336]:
        for blocksize in [4096, 16384]:
            for i in range(2):
                A1 = torch.randn(batch, seq, hidden, device='cpu')
                t0 = time.time()
                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
                print(time.time() - t0)
                diff = torch.abs(A1 - A2)
                reldiff = diff / torch.abs(A1 + 1e-8)
                diffs.append(diff.mean().item())
                reldiffs.append(reldiff.mean().item())
                assert diffs[-1] < 0.011
            # print(sum(diffs)/len(diffs))
            # print(sum(reldiffs)/len(reldiffs))
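

# Reference semantics for blockwise absmax quantization in plain PyTorch
# (the real kernels map onto an 8-bit codebook rather than a linear grid;
# this sketch assumes the element count is divisible by the blocksize):
def _blockwise_absmax_roundtrip(x, blocksize=4096):
    blocks = x.flatten().view(-1, blocksize)
    absmax = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    q = torch.round(blocks / absmax * 127).to(torch.int8)
    deq = q.float() / 127 * absmax
    return q, absmax, deq.view(x.shape)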



def test_fp8_quant():
    for e_bits in range(1, 7):
        p_bits = 7-e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(sum(abserr)/len(abserr))
        #print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.rand(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(sum(abserr)/len(abserr))
        #print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(3, sum(abserr)/len(abserr))
        #print(3, sum(relerr)/len(relerr))
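

# create_fp8_map(signed, exponent_bits, precision_bits) enumerates all 256
# values of a custom 8-bit float: 1 sign bit + e_bits exponent + (7 - e_bits)
# mantissa bits, which is why the loop above pairs e_bits with 7 - e_bits.
# A quick sanity sketch (hypothetical helper):
def _fp8_code_range(e_bits):
    code = F.create_fp8_map(True, e_bits, 7 - e_bits)
    assert code.numel() == 256
    return code.min().item(), code.max().item()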


def test_few_bit_quant():

    #print('')
    for bits in range(2, 9):
        #print('='*30, bits, '='*30)
        for method in ['linear', 'fp8', 'dynamic', 'quantile']:
            abserrs = []
            relerrs = []
            code = None
            if method == 'linear':
                code = F.create_linear_map(True, total_bits=bits).cuda()
            elif method == 'fp8':
                ebits = math.ceil(bits/2)
                pbits = bits-ebits-1
                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
            elif method == 'dynamic':
                code = F.create_dynamic_map(True, bits-0, bits).cuda()
            elif method == 'quantile':
                values = torch.randn(2048, 2048, device='cuda')
                code = F.create_quantile_map(values, bits).cuda()
            # for some data types we have no zero
            # for some data types we have one zero
            # for some data types we have two zeros
            assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}'
            #print(method, (code==0).sum())
            assert code.numel() == 256
            for i in range(10):

                values = torch.randn(1, 32, device='cuda')
                values /= values.abs().max()
                #values[values.abs() < 1e-6] += 1e-5

                q1 = []
                v1 = []
                for v in values[0]:
                    idx = torch.abs(v-code).argmin()
                    q1.append(idx.item())
                    v1.append(code[idx].item())

                q1 = torch.Tensor(q1).cuda()
                v1 = torch.Tensor(v1).cuda()

                q2, S2 = F.quantize_blockwise(values, code=code)
                v2 = F.dequantize_blockwise(q2, S2)

                idx = torch.isclose(q1.int(), q2.int())
                err2 = torch.abs(v2-values)
                abserrs.append(err2.mean().item())
                relerrs.append((err2/(1e-10+values).abs()).mean().item())
                if idx.sum():
                    # some weird cases
                    err1 = torch.abs(v1-values).mean()
                    #assert err2.mean() <= err1

                else:
                    torch.testing.assert_close(q1, q2)
            #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
    #assert False
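

# The scalar nearest-entry loop above can be vectorized: quantizing against a
# codebook is an argmin over |v - code| broadcast across all values. Sketch
# (hypothetical helper; fp32, no blocking):
def _quantize_with_code(values, code):
    dist = (values.view(-1, 1) - code.view(1, -1)).abs()
    idx = dist.argmin(dim=1).view(values.shape)
    return idx.to(torch.uint8), code[idx]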


def test_kbit_quantile_estimation():
    for i in range(100):
        data = torch.randn(1024, 1024, device='cuda')
        for bits in range(2, 9):
            p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
            err = torch.abs(val1-val2).mean()
            assert err < 0.038

    for i in range(100):
        data = torch.randn(1024, 1024, device='cuda')
        for bits in range(2, 4):
            total_values = 2**bits-1
            p = np.linspace(0, 1, 2*total_values+1)
            idx = np.arange(1, 2*total_values+1, 2)
            p = p[idx]
            offset = 1/(2*total_values)
            p = np.linspace(offset, 1-offset, total_values)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1)
            err = torch.abs(val1-val2).mean()
            assert err < 0.035
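

# The offsets above place the k quantile targets at the centers of k equally
# likely bins: p_i = (i + 0.5) / k, i.e. np.linspace(offset, 1 - offset, k)
# with offset = 1 / (2k); norm.ppf then maps them onto a standard normal.
# Hypothetical helper restating that derivation:
def _normal_quantile_centers(num_quantiles):
    offset = 1 / (2 * num_quantiles)
    p = np.linspace(offset, 1 - offset, num_quantiles)
    return torch.Tensor(norm.ppf(p))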


def test_bench_dequantization():
    a = torch.rand(1024, 1024, device='cuda').half()
    code = F.create_fp8_map(True, 3, 0, 4).cuda()
    qa, SA = F.quantize_blockwise(a, code=code)
    print(qa.max())

    max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
    #print(max_theoretical_mu)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        qa, SA = F.quantize_blockwise(a)
    torch.cuda.synchronize()
    #print((time.time()-t0)/1e6)



@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_fp4_quant(dtype):
    vals = list(product([0, 1], repeat=4))

    code = {}
    for bits in vals:
        result = 0
        bias = 3
        sign, e1, e2, p1 = bits
        idx = sign*8 + e1*4 + e2*2 + p1*1
        sign = -1.0 if sign else 1.0
        exp = e1*2 + e2*1
        if exp == 0:
            # sub-normal
            if p1 == 0: result = 0
            else: result = sign*0.0625
        else:
            # normal
            exp = 2**(-exp + bias + 1)
            frac = 1.5 if p1 else 1.0
            result = sign*exp*frac
        code[idx] = result

    A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype)
    qa, SA = F.quantize_fp4(A1, blocksize=64)
    A2 = F.dequantize_fp4(qa, SA)

    err = (A1 - A2).abs().float()
    relerr = (err/(A1.abs().float()+1e-8)).mean()
    idx = err > 1.0
    err = err.mean()

    assert A2.dtype == dtype
    assert err.item() < 0.1
    assert relerr.item() < 0.28


@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
def test_4bit_compressed_stats(quant_type):
    for blocksize in [128, 64]:
        errs1 = []
        errs2 = []
        for i in range(10):
            A1 = torch.randn(1024, 1024, device='cuda').half()
            q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
            q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
            A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
            A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)


            err = (A1 - A2).abs().float()
            relerr = (err/(A1.abs().float()+1e-15)).mean()
            err = err.mean()

            errs1.append(err.item())


            assert err.item() < 0.11
            assert relerr.item() < 0.28

            err = (A1 - A3).abs().float()
            relerr = (err/(A1.abs().float()+1e-15)).mean()
            err = err.mean()

            errs2.append(err.item())

            assert err.item() < 0.11
            assert relerr.item() < 0.28

        #print(sum(errs1)/len(errs1), blocksize, quant_type)
        #print(sum(errs2)/len(errs2), blocksize, quant_type)




#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ['nf4'])
def test_bench_4bit_dequant(quant_type):
    blocksize = 256
    a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
    qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)

    input_size = a.numel()/2
    output_size = a.numel()*2
    num_bytes = input_size+output_size
    GB = num_bytes/1e9
    max_theoretical_s = GB/768
    #print(max_theoretical_s*1e6)
    b = torch.randn(128, 1024*12, device='cuda').half()

    iters = 100
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
        #b.copy_(a)
    torch.cuda.synchronize()
    #print((time.time()-t0)/iters*1e6)

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    torch.matmul(b, a.t())
    #torch.cuda.synchronize()
    #print((time.time()-t0)/iters*1e6)



def test_normal_map_tree():
    code = F.create_normal_map()
    values = code[:8].tolist() + code[-8:].tolist()
    num_pivots = 1
    #print(values)
    while num_pivots < 16:
        idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
        #print(idx)
        num_pivots *= 2
        pivots = []
        for i in idx:
            pivots.append((values[i-1]+values[i])/2)
        #print(pivots)


@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=['uint8', 'fp16', 'bf16', 'fp32'])
def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
    for dim in [128, 256, 512, 1024]:
    #for dim in [4*1024]:
    #for dim in [1*16]:
        errs1 = []
        errs2 = []
        errs3 = []
        relerrs1 = []
        relerrs2 = []
        relerrs3 = []
        max_errs1 = []
        max_errs2 = []
        max_errs3 = []


        for i in range(100):
            if kind == 'fc1':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            elif kind == 'fc2':
                A = torch.randn(1, 4*dim, dtype=dtype, device='cuda')
                B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            elif kind == 'attn':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            elif kind == 'attn_packed':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage)
            C3 = torch.matmul(A, B.t())
            C2 = F.gemv_4bit(A, qB.t(), state=state)
            A.requires_grad = True
            C1 = bnb.matmul_4bit(A, qB.t(), state)

            err1 = (C1-C2).abs().float()
            err2 = (C3-C2).abs().float()
            err3 = (C3-C1).abs().float()

            mag1 = torch.abs(C1).float()+1e-5
            mag2 = torch.abs(C3).float()+1e-5
            mag3 = torch.abs(C3).float()+1e-5

            relerr1 = err1/mag1
            relerr2 = err2/mag2
            relerr3 = err3/mag3

            max_err1 = err1.max()
            max_err2 = err2.max()
            max_err3 = err3.max()

            errs1.append(err1.mean().item())
            errs2.append(err2.mean().item())
            errs3.append(err3.mean().item())

            relerrs1.append(relerr1.mean().item())
            relerrs2.append(relerr2.mean().item())
            relerrs3.append(relerr3.mean().item())

            max_errs1.append(max_err1.item())
            max_errs2.append(max_err2.item())
            max_errs3.append(max_err3.item())

            c = int(C1.numel()*0.0014*(dim/256))+1

            c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
        err1 = sum(errs1)/len(errs1)/math.sqrt(dim)
        err2 = sum(errs2)/len(errs2)/math.sqrt(dim)
        err3 = sum(errs3)/len(errs3)/math.sqrt(dim)
        relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim)
        relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim)
        relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim)
        maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim)
        maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim)
        maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim)
        absratio = err2/err3
        relratio = relerr2/relerr3
        maxratio = maxerr2/maxerr3  # was relerr2/relerr3, duplicating relratio

        # for debugging if the tests fails
        #
        #print('='*80)
        #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
        #print(C1.flatten()[-20:])
        #print(C2.flatten()[-20:])
        #print(f'inference vs training abs: {err1}')
        #print(f'inference vs training rel: {relerr1}')
        #print(f'inference vs training max: {maxerr1}')
        #print(f'inference vs training vs torch err ratio abs: {absratio}')
        #print(f'inference vs training vs torch err ratio rel: {relratio}')
        #print(f'inference vs training vs torch err ratio max: {maxratio}')
        if dtype == torch.float16:
            if dim <= 512:
                assert err1 < 7e-5
                assert relerr1 < 0.0008
            else:
                assert err1 < 6e-5
                assert relerr1 < 2e-4
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.float32:
            if dim <= 512:
                assert err1 < 5e-8
                assert relerr1 < 1e-6
                assert maxerr1 < 1e-7
            else:
                assert err1 < 5e-8
                assert relerr1 < 8e-6
                assert maxerr1 < 1e-7
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.bfloat16:
            if dim <= 512:
                assert err1 < 6e-4
                assert relerr1 < 0.007
                assert maxerr1 < 0.015
            else:
                assert err1 < 2e-4
                assert relerr1 < 0.002
                assert maxerr1 < 0.0012
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.04 and relratio > 0.96
            assert maxratio < 1.02 and maxratio > 0.98
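

# The division by math.sqrt(dim) above normalizes for dot-product length: the
# absolute error of a length-n inner product grows roughly with sqrt(n), so
# the thresholds stay dimension-independent. A quick empirical check
# (hypothetical helper):
def _dot_error_scaling():
    for n in [256, 1024, 4096]:
        a, b = torch.randn(1000, n), torch.randn(1000, n)
        exact = (a.double() * b.double()).sum(1)
        approx = (a.half().float() * b.half().float()).sum(1).double()
        print(n, (exact - approx).abs().mean().item() / math.sqrt(n))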

@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
    n = 32*10
    A = F.get_paged(n, n, dtype=torch.float32)
    B = F.get_paged(n, n, dtype=torch.uint8)
    B2 = F.get_paged(n, n, dtype=torch.float32)
    assert A.is_paged
    assert B.is_paged
    assert A.page_deviceid==0
    assert B.page_deviceid==0
    F.fill(A, 17.0)
    F.fill(B, 17)
    F.fill(B2, 2)
    assert (A==17).sum().item() == n*n
    assert (B==17).sum().item() == n*n
    C = A*B.float()
    assert (C==289).sum().item() == n*n
    F._mul(A, B2)
    F._mul(A, B2)
    F._mul(A, B2)
    assert (A==17*(2**3)).sum().item() == n*n
    # F.prefetch_tensor(A)
    # F.prefetch_tensor(B)

    # F.fill(B2, 17.0)
    # F._mul(A, B2)

    # F.prefetch_tensor(A, to_cpu=True)
    # F.prefetch_tensor(B, to_cpu=True)
    # F.prefetch_tensor(B2, to_cpu=True)
    # torch.cuda.synchronize()

    # assert (A==17).sum().item() == n*n

   # torch.testing.assert_close(A, torch.ones(A.shape)*289)


@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True'])
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
    dims = 10
    torch.random.manual_seed(np.random.randint(0, 412424242))
    dims = torch.randint(0, 8192, size=(dims,)).tolist()
    dims = [dim + (64-(dim % 64)) for dim in dims]
    #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
    for dim in dims:
        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda')
        B = torch.eye(dim, dtype=dtype, device='cuda')

        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
        C3 = torch.matmul(A, B.t())
        C2 = bnb.matmul_4bit(A, qB.t(), state)
        A.requires_grad = True
        C1 = bnb.matmul_4bit(A, qB.t(), state)

        torch.testing.assert_close(A, C3)
        torch.testing.assert_close(A, C1)
        torch.testing.assert_close(A, C2)
        #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
        #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
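

# Why exact equality holds above: both 4-bit codebooks contain 0.0 and +1.0
# (after absmax normalization), and every 64-entry block of torch.eye has
# absmax 1.0, so each weight maps onto an exact codebook point and
# dequantizes losslessly. Quick check of the claim for the nf4 codebook
# (hypothetical helper):
def _nf4_represents_eye_exactly():
    code = F.create_normal_map()
    return bool((code == 0).any()) and bool((code == 1).any())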