from itertools import product
import math
import random
import time

import einops
import numpy as np
import pytest
from scipy.stats import norm
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F
from tests.helpers import (
    BOOLEAN_TUPLES,
    TRUE_FALSE,
    describe_dtype,
    get_test_dims,
    id_formatter,
)

torch.set_printoptions(
    precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)

k = 20


def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
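    """Assert that `a` and `b` are elementwise close, tolerating up to `count` mismatches.

    Returns the number of elements outside the (rtol, atol) tolerance; if that number
    exceeds `count` and `throw` is True, a full `torch.testing.assert_close` failure is raised.
    """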
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        if throw:
            print(f"Too many values not close: assert {sumval} < {count}")
            torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

    return sumval


class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.fc1.weight)
            torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class Timer:
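    """Small CUDA-event based timer.

    `tick(name)` starts (or restarts) a measurement, `tock(name)` records the end event,
    synchronizes, and accumulates the elapsed milliseconds under `name`.
    """
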
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
            self.starts[name].record()
        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")

def setup():
    pass

def teardown():
    pass

@pytest.mark.parametrize(
    "dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_estimate_quantiles(dtype):
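    # F.estimate_quantiles returns 256 estimated quantiles: for uniform data they should match
    # the evenly spaced quantile positions, for Gaussian data torch.quantile within 5e-2.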
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0


def test_quantile_quantization():
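    # Round-trip through 8-bit quantile quantization (quantize_no_absmax / dequantize_no_absmax)
    # and bound the mean absolute reconstruction error for Gaussian and uniform inputs.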
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001


def test_dynamic_quantization():
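    # Round-trip through dynamic 8-bit quantization (F.quantize / F.dequantize) and bound the
    # mean absolute/relative error for Gaussian inputs and the absolute error for uniform inputs.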
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    print(sum(diffs)/len(diffs))
    print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
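    # Blockwise 8-bit quantization round-trip: check mean absolute/relative error for Gaussian
    # data with the default code, then for uniform data with a custom dynamic map (signed/unsigned).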
    #print('')
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
    abserr = sum(diffs)/len(diffs)
    relerr = sum(reldiffs)/len(reldiffs)
    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
    assert abserr < 0.011
    assert relerr < 0.018
    assert A2.dtype == dtype

    diffs = []
    code = F.create_dynamic_map(signed=signed)
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
    abserr = sum(diffs)/len(diffs)
    relerr = sum(reldiffs)/len(reldiffs)
    if signed:
        assert abserr < 0.0035
        assert relerr < 0.015
    else:
        assert abserr < 0.00175
        assert relerr < 0.012
    assert A2.dtype == dtype
    #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
    #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))


@pytest.mark.parametrize(
    "gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_percentile_clipping(gtype):
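    # Compare F.percentile_clipping against a reference that keeps a rolling window of the last
    # 100 gradient norms and clips at the requested percentile.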
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec2, step, percentile=percentile
        )
        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)

        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_close(clip1, clip2)
        torch.testing.assert_close(gnorm1, gnorm2)


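# Reference absmax quantization helpers for the igemm tests below: `quant` scales a whole tensor
# into int8 by its absolute maximum (round(x / absmax * 127)), `quant_multi` does the same per
# row/column, and `dequant`/`mm_dequant` undo those scales after an int8 matmul.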
def quant(x):
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    return C.float() * (maxA / 127) * (maxB / 127)


def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_minmax(A):
    minA = A.min()
    maxA = A.max()
    return minA, maxA

def mean(xx):
    return sum(xx) / float(len(xx))

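# Each entry is a (quantize_A, quantize_B, dequantize_A, dequantize_B, matmul_dequantize) tuple
# used by test_approx_igemm: "linear" uses one scale per tensor, "vectorwise" one per row/column.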
methods = {
    "linear": (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    ),
    "vectorwise": (quant_multi, quant_multi, dequant, dequant, mm_dequant),
}


@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2"))
@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys())
@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched"))
def test_approx_igemm(dim1, dim2, quant_methods, batched):
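    # Quantize A and B, run the integer matmul (F.igemm or an int bmm), dequantize the result,
    # and accumulate absolute/relative error against the float matmul.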
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    #print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_close(
            quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
        )
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    #print(mean(errors))
    #print(mean(relerrors))


def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
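    # int8 F.igemm vs. float matmul for random int8 inputs, covering all transpose combinations
    # of A and B; the second loop repeats the check with a 3D (batch, seq, hidden) A.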
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (
            (batch_dim, hidden_dim)
            if not transpose[0]
            else (hidden_dim, batch_dim)
        )
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        elif transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.t().float(), B.float())
            out = F.igemm(A.t(), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.matmul(A.t().float(), B.t().float())
            out = F.igemm(A.t(), B.t())
        torch.testing.assert_close(out.float(), out2)

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())

        torch.testing.assert_close(out.float(), out2)


@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim"))
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
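    # 3D x 3D contraction over batch and sequence ("bsi,bso->io"): F.igemm writing into an
    # int32 `out` buffer should match the float einsum exactly.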
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        ).to(torch.int8)
        B = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
        ).to(torch.int8)
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(
            A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
        )
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_close(out.float(), out2)

@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs = []
    relerrs = []
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(
            0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        )
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)

        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3

@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
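    # Batched int8 matmul (F.igemm on 3D tensors) vs. torch.bmm in float, for every transpose
    # combination of A and B (implemented by permuting the last two dimensions).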
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A, B.permute([0, 2, 1]))
        elif transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
            out = F.igemm(A.permute([0, 2, 1]), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.bmm(
                A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
            )
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_close(out.float(), out2.float())


@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3"))
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        n = A1.numel()
        assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002))


@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose"))
@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims"))
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
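    # Layout checks for F.nvidia_transform: "row"/"col" outputs are compared elementwise,
    # "col32" only by element count (columns padded up to a multiple of 32) and by
    # transforming back to row order.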
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
    try:
        func = F.get_transform_func(dtype, orderA, orderOut, transpose)
    except ValueError as ve:
        pytest.skip(str(ve))  # skip if not supported

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
            dtype
        )

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_close(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_close(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = (
                A.shape[0]
                * A.shape[1]
                * (A.shape[2] + (32 - (A.shape[2] % 32)))
            )
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
            A.shape[1] + (32 - (A.shape[1] % 32))
        )
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = (
                    (row // 8) + (1 if row % 8 != 0 else 0)
                ) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_close(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(
            out, from_order=orderOut, to_order="row", state=S
        )
        torch.testing.assert_close(A, out2)


@pytest.mark.parametrize("dim1", get_test_dims(1, 256, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 512, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 1024, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 1024, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims"))
@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
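    # int8 tensor-core matmul: A is transformed to "col32" and B to "col_turing", F.igemmlt
    # produces an int32 result that is transformed back to "row" and compared to the float matmul.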
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
                torch.int8
            )
        elif dims == 3:
            A = torch.randint(
                -128, 127, size=(dim1, dim2, dim3), device="cuda"
            ).to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_close(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_close(C1, C3.float())

Aarni Koskela's avatar
Aarni Koskela committed
659
660
661
662
663
@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
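    # 8-bit matmul pipeline on fp16 inputs: double_quant -> transform -> igemmlt -> mm_dequant;
    # the explicit comparison against torch.matmul / bnb.matmul is left commented out below.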
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(
                0, 0.5, size=(dim1, dim2, dim3), device="cuda"
            ).half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        # print('')
        # print(output.flatten()[:10])
        # print(C1.flatten()[:10])
        # print(C2.flatten()[:10])

        # torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        # C1 = torch.matmul(A.float(), B.float())

        # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        # C3, S = F.transform(C2, 'row', state=SC)
        # torch.testing.assert_close(C1, C3.float())

@pytest.mark.parametrize(
    ("batch", "seq", "model", "hidden"),
    [
        pytest.param(2, 512, 4 * 1024, 3 * 4 * 1024, id="batch=2, seq=512, model=4k, hidden=12k"),
        pytest.param(2, 512, 5120, 3 * 5120, id="batch=2, seq=512, model=5k, hidden=15k"),
        pytest.param(2, 512, 12 * 1024, 4 * 12 * 1024, id="batch=2, seq=512, model=12k, hidden=48k"),
    ],
)
@pytest.mark.benchmark
def test_bench_8bit_training(batch, seq, model, hidden):
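    # Benchmark scaffold: only the fp16 baseline matmul is timed; the full 8-bit training path
    # (double_quant/transform2/igemmlt for fc1, fc2 and the gradient matmuls) is kept below as
    # commented-out reference code.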
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")
    # torch.cuda.synchronize()
    ## warmup
    # for i in range(100):
    #    torch.matmul(A, w1.t())
    # torch.cuda.synchronize()

    dtype = torch.int8
    A = A.view(-1, A.shape[-1]).contiguous()
    grad = grad.view(-1, grad.shape[-1]).contiguous()
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):

        out1 = torch.matmul(A, w1.t())  # fc1
        # out2 = torch.matmul(out1, w2.t())# fc2

        # d1 = torch.matmul(grad, w2) # delta1
        # d2 = torch.matmul(d1, w1) # delta2

        # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
        # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1

    torch.cuda.synchronize()
    t16 = time.time() - t0
    print(t16)

    # torch.cuda.empty_cache()

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)

    # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    # C32A, SA = F.transform2(CA, 'col32')
    ## fc1
    # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)

    ## fc2
    # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    # C32out1, Sout1 = F.transform2(Cout1, 'col32')
    # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)

    ## delta1
    # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)

    ## delta2
    # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    # C32d1, Sd1 = F.transform2(Cd1, 'col32')
    ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)

    ## grad1
    # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)

    ## grad2
    # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)

    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(k):
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)

    #    #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
    #    CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    #    #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    #    #CTw2, Sw2 = F.transform2(Cw2, formatB)
    #    #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)

    #    C32A, SA = F.transform2(CA, 'col32')

    #    # fc1
    #    out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    #    #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

    #    #print(coo_tensor.nnz)
    #    #out1sp = F.spmm_coo(coo_tensor, w1.t())
    #    #print(w1.t().shape)
    #    #out1 = out1dn + out1sp

    #    # fc2
    #    Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    #    C32out1, Sout1 = F.transform2(Cout1, 'col32')
    #    out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    #    #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)

    #    # delta1
    #    Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    #    C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    #    d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    #    #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)

    #    # delta2
    #    Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    #    C32d1, Sd1 = F.transform2(Cd1, 'col32')
    #    d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    #    #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)

    #    # grad1
    #    #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    #    #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    #    #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    #    #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)

    #    ## grad2
    #    #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    #    #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    #    #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    #    #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)

    # torch.cuda.synchronize()
    # t8 = time.time() - t0
    # print(t8)


@pytest.mark.parametrize("dim1", get_test_dims(64, 256, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim4", get_test_dims(64, 1024, n=2), ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("formatB", ["col_turing", "col_ampere"], ids=id_formatter("formatB"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
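    # int8 matmul on transformed layouts, dequantized two ways: the reference
    # vectorwise_mm_dequant (C4, checked against the fp16 matmul with ~1% of elements allowed
    # to miss the tolerance) and the fused F.mm_dequant with optional bias (C5, not asserted here).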
    inner = torch.randint(1, 128, size=(1,)).item()
    bias = None
    if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16)
    formatB = F.get_special_format_str()
    for i in range(1):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())
        if has_bias: C1 += bias

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
        if has_bias: C4 += bias

        # TODO: is something wrong here? If so, the problem goes deeper
        # n = C1.numel()
        # p = 0.06
        std = C1.std(0).view(1, -1)
        C1 /= std
        C4 /= std
        # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
        # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
        # torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
        n = C5.numel()
        assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))


@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
def test_colrow_absmax(dim1, dim2, dims):
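    # get_colrow_absmax with a threshold should return row/column absolute maxima computed over
    # the elements whose magnitude is below the threshold, plus cumulative per-tile counts of
    # the above-threshold elements; with threshold=0.0 it returns plain maxima and no pointer.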
    for i in range(k):
        threshold = 3.0
        A = torch.randn(dim1, dim2, device="cuda").half()
        A_truncated = A.clone()
        A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
        if dims == 2:
            row_stats1, _ = torch.abs(A.float()).max(1)
            col_stats1, _ = torch.abs(A.float()).max(0)
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
        else:
            assert False

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=threshold
        )

        A_blocked = einops.rearrange(
            torch.abs(A),
            "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
            row_tiles=16,
            block_size=64 * 4,
        )
        nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(
            nnz_rows1_counts.shape[0] + 1,
            dtype=nnz_rows1_counts.dtype,
            device=nnz_rows1_counts.device,
        )
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_close(col_stats1_trunc, col_stats2)
        torch.testing.assert_close(row_stats1_trunc, row_stats2)
        torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)
        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=0.0
        )
        torch.testing.assert_close(col_stats1, col_stats2)
        torch.testing.assert_close(row_stats1, row_stats2)
        assert nnz_block_ptr2 is None


@pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2"))
def test_double_quant(dim1, dim2):
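    # F.double_quant should match separate row-wise and column-wise vectorwise_quant calls up to
    # +/-1 from rounding, with at most ~1/500 of elements differing, and return the same
    # row/column absmax statistics.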
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (
            (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        )
        num_not_close_cols = (
            (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
        )

        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
            )
            assert False
        if num_not_close_rows > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
            )
            assert False

        torch.testing.assert_close(Srow.flatten().float(), statsA)
        torch.testing.assert_close(Scol.flatten().float(), statsAt)


@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    (
        pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
        for (dim1, dim4, inner)
        in zip(
            get_test_dims(1, 4 * 1024, n=4),
            get_test_dims(1, 4 * 1024, n=4),
            get_test_dims(1, 4 * 1024, n=4),
        )
    )
)
def test_integrated_igemmlt(dim1, dim4, inner):
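    # Compare two int8 pipelines against the fp16 matmul: double_quant + mm_dequant (out2) and
    # the reference vectorwise quantization (out3); the vectorwise path's mean error may be at
    # most 2.5% larger than the fused path's.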
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        torch.testing.assert_close(maxA.flatten().float(), stats1a)
        torch.testing.assert_close(maxB.flatten().float(), stats2a)
        torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_close(C2a, B1, rtol=0, atol=1)

        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(C2a, "col_turing")
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
        assert err2 <= err1 * 1.025


@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    (
        pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
        for (dim1, dim4, inner)
        in zip(
            get_test_dims(1, 4 * 1024, n=6),
            get_test_dims(1, 4 * 1024, n=6),
            get_test_dims(1, 4 * 1024, n=6),
        )
    )
)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)

        # print(out1)
        # print(out2)
        # print(out3)
        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))


@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    [
        pytest.param(1024, 12288 * 4, 12288, id="1024, 12288*4, 12288"),
        pytest.param(2048, 4096 * 4, 4096, id="2048, 4096*4, 4096"),
    ],
)
@pytest.mark.skip("Row scale has some bugs for ampere")
@pytest.mark.benchmark
def test_row_scale_bench(dim1, dim4, inner):
    formatB = F.get_special_format_str()
Tim Dettmers's avatar
Tim Dettmers committed
1125
1126
1127
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
1128
1129
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
Tim Dettmers's avatar
Tim Dettmers committed
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
    torch.nn.init.xavier_uniform_(B)
    # warmpup
    for i in range(k):
        C1 = torch.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = torch.matmul(A, B.t())
    torch.cuda.synchronize()
1140
    print("16", time.time() - t0)
Tim Dettmers's avatar
Tim Dettmers committed
1141
1142

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
1143
1144
    CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
    A2, SA = F.nvidia_transform(C1a, "col32")
Tim Dettmers's avatar
Tim Dettmers committed
1145
1146
1147
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

1148
1149
    c = 10.0 * inner * scale
    row_scale = maxA / c
Tim Dettmers's avatar
Tim Dettmers committed
1150
1151
1152
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
1153
1154
1155
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
Tim Dettmers's avatar
Tim Dettmers committed
1156
    torch.cuda.synchronize()
1157
    print("row-wise", time.time() - t0)
Tim Dettmers's avatar
Tim Dettmers committed
1158
1159
1160
1161
1162
1163
1164
1165

    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
    torch.cuda.synchronize()
1166
    print("vector-wise", time.time() - t0)
Tim Dettmers's avatar
Tim Dettmers committed
1167
1168


@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims"))
@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
                dtype
            )
        elif dims == 3:
            A = torch.randint(
                10, 99, size=(dim1, dim2, dim3), device="cuda"
            ).to(dtype)

        A.view(-1)[-1] = -1
        if transpose:
            At = A.t().contiguous()
            out1, S1 = F.nvidia_transform(At, to_order=orderOut)
        else:
            out1, S1 = F.nvidia_transform(A, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_close(out1, out2)


def test_overflow():
    formatB = F.get_special_format_str()
    print(formatB)
    for i in range(2):
        a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)

        Ca, Sa = F.nvidia_transform(a, "col32")
        Cb, Sb = F.nvidia_transform(b, formatB)

        c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
        c2 = torch.matmul(a.float(), b.float().t())


Aarni Koskela's avatar
Aarni Koskela committed
1218
1219
@pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2"))
Tim Dettmers's avatar
Tim Dettmers committed
1220
1221
1222
def test_coo_double_quant(dim1, dim2):
    threshold = 3.00
    for i in range(k):
1223
        A = torch.randn(dim1, dim2, device="cuda").half()
Tim Dettmers's avatar
Tim Dettmers committed
1224

1225
        idx = torch.abs(A) >= threshold
Tim Dettmers's avatar
Tim Dettmers committed
1226
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
1227
1228
1229
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )
Tim Dettmers's avatar
Tim Dettmers committed
1230
1231

        if coo_tensor is not None:
1232
            A1 = A * idx
Tim Dettmers's avatar
Tim Dettmers committed
1233
            A2 = torch.zeros_like(A)
1234
1235
1236
            A2[
                coo_tensor.rowidx.long(), coo_tensor.colidx.long()
            ] = coo_tensor.values
1237
            torch.testing.assert_close(A1, A2)
Tim Dettmers's avatar
Tim Dettmers committed
1238

1239
1240
            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
1241
            torch.testing.assert_close(
1242
1243
                A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
            )
1244

Tim Dettmers's avatar
Tim Dettmers committed
1245

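
# With a threshold, double_quant splits A into a dense int8 part (sub-threshold
# values, quantized row-wise) plus a sparse COO part holding the fp16 outliers,
# so that A is approximately dequant(CA) + to_dense(coo_tensor), which is what
# the assertions above verify. A rough sketch of that decomposition
# (illustrative only; the library performs this on the GPU):
def _outlier_decomposition_reference(A, threshold):
    outlier_mask = A.abs() >= threshold
    dense = A.masked_fill(outlier_mask, 0)
    stats = dense.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8)
    CA = torch.round(dense / stats * 127).to(torch.int8)  # row-wise int8
    rows, cols = torch.where(outlier_mask)
    return CA, stats, (rows, cols, A[outlier_mask])

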
@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B"))
def test_spmm_coo(dim1, dim2, transposed_B):
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
            B = torch.randn(dim3, dim2).cuda().half()
        else:
            B = torch.randn(dim2, dim3).cuda().half()

        idx = torch.abs(A) >= threshold
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(
            A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
        )
        A2 = A * idx

        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
            out1 = torch.matmul(A2, B.t())
        else:
            out2 = F.spmm_coo(cooA, B)
            out1 = torch.matmul(A2, B)

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)


@pytest.mark.benchmark
def test_spmm_bench():
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = bnb.matmul(A, B.t())
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )

    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time() - t0
    print(tsp, t8)
    print(tsp / t8)


@pytest.mark.parametrize("dim1", get_test_dims(256, 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(256, 1024, n=2), ids=id_formatter("dim2"))
def test_integrated_sparse_decomp(dim1, dim2):
    threshold = 3.0
    formatB = "col_turing"
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        out1 = torch.matmul(A, w1.t())

        Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
        CTw1, Sw1 = F.transform(Cw1, formatB)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
            A, threshold=threshold
        )
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        assert coo_tensor is not None

        out4 = F.spmm_coo(coo_tensor, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1


def test_matmuls():
    a = torch.randn(256, 512).half().cuda()
    b = torch.randn(256, 512).half().cuda()
    c1 = torch.matmul(a, b.t())
    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul_cublas(a, b.t())

    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2
    print(err1, err2)


@pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func"))
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(
        out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
    )
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))

    # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)

    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #   #out3 = F.spmm_coo(cooA, Bt.t())
    #   #out2 = F.spmm_coo(cooA, B)
    #   #out2 = F.spmm_coo_very_sparse(cooA, B)
    #   #out1 = torch.matmul(A, Bt.t())

    # torch.cuda.synchronize()
    # print(time.time() - t0)


def test_coo2csr():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_close(A2[idx], csrA.values)


def test_coo2csc():
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_close(A2.t()[idx], cscA.values)


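# COO stores one (row, col, value) triple per nonzero; CSR compresses the row
# indices into a rowptr whose consecutive differences give the nonzero count
# per row, which is exactly what the two conversion tests above check.
# A tiny reference conversion in plain PyTorch (illustrative only):
def _coo_to_csr_reference(rowidx, colidx, values, num_rows):
    order = torch.argsort(rowidx)  # CSR requires entries sorted by row
    rowidx, colidx, values = rowidx[order], colidx[order], values[order]
    rowptr = torch.zeros(num_rows + 1, dtype=torch.long, device=rowidx.device)
    rowptr[1:] = torch.bincount(rowidx, minlength=num_rows).cumsum(0)
    return rowptr, colidx, values

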
@pytest.mark.parametrize("dim1", [1 * 2048])
@pytest.mark.parametrize("dim2", [2048])
@pytest.mark.parametrize("dtype", [torch.int8])
def test_spmm_coo_dequant(dim1, dim2, dtype):
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    rowidx = torch.randint(0, A.shape[-1], size=(15,))

    A[:, rowidx] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)

    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(100):
    #   out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print('fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print("cusparse fp16", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print("int8", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print("int8+dequant", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print("matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1 + out2
    torch.cuda.synchronize()
    print("sparse+ matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)


@pytest.mark.parametrize(
    ("batch", "seq", "model", "hidden"),
    [pytest.param(1, 1, 6656, 4*6656, id="batch=1, seq=1, model=6656, hidden=26k")],
)
@pytest.mark.benchmark
def test_bench_matmul(batch, seq, model, hidden):
    iters = 1000
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device="cuda").half()
    B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    B_fp4, state = F.quantize_fp4(B)
    B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)

    B_nf4, state_nf4 = F.quantize_nf4(B)
    B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
    linear8bit.eval()

    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half())
    #linearMixedBit.eval()

    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)

    # warmup
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
    #torch.cuda.synchronize()
    #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
    #torch.cuda.synchronize()
    #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
    torch.cuda.synchronize()
    print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
    torch.cuda.synchronize()
    print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )


    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul(A, B)
    #torch.cuda.synchronize()
    #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul(A, B, threshold=6.0)
    #torch.cuda.synchronize()
    #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    #C32A, SA = F.transform(CA, "col32")
    #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    #CxB, SB = F.transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #torch.cuda.synchronize()
    #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #BA, statsB = F.vectorwise_quant(B, dim=1)
    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1)
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    #torch.cuda.synchronize()
    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    out = Cout * statsB * statsA * (1.0 / (127 * 127))
    #torch.cuda.synchronize()
    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linearMixedBit(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linearMixedBit(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit_train(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit_train(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit_train_thresh(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit_train(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

def test_zeropoint():
    def quant_zp(x):
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        minx = x.min()
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()

    C1 = torch.matmul(A, B)
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    #print(ca.min(), ca.max())
    #print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb

    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

    #print("")
    # print(C0.flatten()[:10])
    #print(C1.flatten()[:10])
    #print(C2.flatten()[:10])
    #print(C3.flatten()[:10])
    #print(C5.flatten()[:10])
    #print(C6.flatten()[:10])
    #print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)


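# test_zeropoint relies on the zero-point algebra: with Aq = qa*A + zpa and
# Bq = qb*B + zpb, the product A @ B can be recovered from Aq @ Bq by removing
# the zero-point cross terms and dividing by the scales, which is what the
# C6/C7 corrections above compute. The same identity in one helper
# (illustrative only):
def _zeropoint_matmul_reference(A, B, qa, zpa, qb, zpb):
    Aq = qa * A + zpa
    Bq = qb * B + zpb
    C = Aq @ Bq
    C -= qb * B.sum(0).view(1, -1) * zpa + qa * A.sum(1).view(-1, 1) * zpb
    C -= zpa * zpb * A.shape[1]
    return C / (qa * qb)

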
def test_extract_outliers():
    for i in range(k):
        shapeA = (4096, 4096 * 4)
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # idx = torch.Tensor([0]).int().cuda()
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        CA, SA = F.transform(A, "col_turing")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)


def test_blockwise_cpu_large():
    diffs = []
    reldiffs = []
    batch = 128
    seq = 128
    for hidden in [128]:#, 14336]:
        for blocksize in [4096, 16384]:
            for i in range(2):
                A1 = torch.randn(batch, seq, hidden, device='cpu')
                t0 = time.time()
                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
                print(time.time() - t0)
                diff = torch.abs(A1 - A2)
                reldiff = diff / torch.abs(A1 + 1e-8)
                diffs.append(diff.mean().item())
                reldiffs.append(reldiff.mean().item())
                assert diffs[-1] < 0.011
            # print(sum(diffs)/len(diffs))
            # print(sum(reldiffs)/len(reldiffs))



def test_fp8_quant():
    for e_bits in range(1, 7):
        p_bits = 7-e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(sum(abserr)/len(abserr))
        #print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.rand(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(sum(abserr)/len(abserr))
        #print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff/torch.abs(A1+1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            #assert diff < 0.0075
        #print(3, sum(abserr)/len(abserr))
        #print(3, sum(relerr)/len(relerr))


def test_few_bit_quant():
    #print('')
    for bits in range(2, 9):
        #print('='*30, bits, '='*30)
        for method in ['linear', 'fp8', 'dynamic', 'quantile']:
            abserrs = []
            relerrs = []
            code = None
            if method == 'linear':
                code = F.create_linear_map(True, total_bits=bits).cuda()
            elif method == 'fp8':
                ebits = math.ceil(bits/2)
                pbits = bits-ebits-1
                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
            elif method == 'dynamic':
                code = F.create_dynamic_map(True, bits-0, bits).cuda()
            elif method == 'quantile':
                values = torch.randn(2048, 2048, device='cuda')
                code = F.create_quantile_map(values, bits).cuda()
            # for some data types we have no zero
            # for some data types we have one zero
            # for some data types we have two zeros
            assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}'
            #print(method, (code==0).sum())
            assert code.numel() == 256
            for i in range(10):

                values = torch.randn(1, 32, device='cuda')
                values /= values.abs().max()
                #values[values.abs() < 1e-6] += 1e-5

                q1 = []
                v1 = []
                for v in values[0]:
                    idx = torch.abs(v-code).argmin()
                    q1.append(idx.item())
                    v1.append(code[idx].item())

                q1 = torch.Tensor(q1).cuda()
                v1 = torch.Tensor(v1).cuda()

                q2, S2 = F.quantize_blockwise(values, code=code)
                v2 = F.dequantize_blockwise(q2, S2)

                idx = torch.isclose(q1.int(), q2.int())
                err2 = torch.abs(v2-values)
                abserrs.append(err2.mean().item())
                relerrs.append((err2/(1e-10+values).abs()).mean().item())
                if idx.sum():
                    # some weird cases
                    err1 = torch.abs(v1-values).mean()
                    #assert err2.mean() <= err1

                else:
                    torch.testing.assert_close(q1, q2)
            #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
    #assert False


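# The per-element loop above snaps each value to the closest codebook entry;
# the same nearest-entry lookup can be written as one vectorized expression,
# which is convenient when checking F.quantize_blockwise against a reference
# (illustrative helper, not part of the library API):
def _nearest_code_reference(values, code):
    idx = (values.unsqueeze(-1) - code).abs().argmin(dim=-1)
    return idx, code[idx]

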
def test_kbit_quantile_estimation():
    for i in range(100):
        data = torch.randn(1024, 1024, device='cuda')
        for bits in range(2, 9):
            p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
            err = torch.abs(val1-val2).mean()
            assert err < 0.038

    for i in range(100):
        data = torch.randn(1024, 1024, device='cuda')
        for bits in range(2, 4):
            total_values = 2**bits-1
            p = np.linspace(0, 1, 2*total_values+1)
            idx = np.arange(1, 2*total_values+1, 2)
            p = p[idx]
            offset = 1/(2*total_values)
            p = np.linspace(offset, 1-offset, total_values)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1)
            err = torch.abs(val1-val2).mean()
            assert err < 0.035


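# estimate_quantiles is compared above against the exact quantiles of a
# standard normal distribution: for n equally likely buckets the reference
# code values sit at the bucket mid-point probabilities
# norm.ppf(linspace(offset, 1 - offset, n)), with offset = 1/(2*n) in the
# second loop and a fixed 1.3e-4 in the first. A small helper building that
# reference grid (illustrative only):
def _normal_quantile_grid(n, offset=None):
    offset = 1 / (2 * n) if offset is None else offset
    p = np.linspace(offset, 1 - offset, n)
    return torch.tensor(norm.ppf(p), dtype=torch.float32)

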
@pytest.mark.benchmark
def test_bench_dequantization():
    a = torch.rand(1024, 1024, device='cuda').half()
    code = F.create_fp8_map(True, 3, 0, 4).cuda()
    qa, SA = F.quantize_blockwise(a, code=code)
    print(qa.max())

    max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
    #print(max_theoretical_mu)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        qa, SA = F.quantize_blockwise(a)
    torch.cuda.synchronize()
    #print((time.time()-t0)/1e6)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
def test_fp4_quant(dtype):
    vals = list(product([0, 1], repeat=4))

    code = {}
    for bits in vals:
        result = 0
        bias = 3
        sign, e1, e2, p1 = bits
        idx = sign*8 + e1*4 + e2*2 + p1*1
        sign = -1.0 if sign else 1.0
        exp = e1*2 + e2*1
        if exp == 0:
            # sub-normal
            if p1 == 0: result = 0
            else: result = sign*0.0625
        else:
            # normal
            exp = 2**(-exp + bias + 1)
            frac = 1.5 if p1 else 1.0
            result = sign*exp*frac
        code[idx] = result

    A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype)
    qa, SA = F.quantize_fp4(A1, blocksize=64)
    A2 = F.dequantize_fp4(qa, SA)

    err = (A1 - A2).abs().float()
    relerr = (err/(A1.abs().float()+1e-8)).mean()
    idx = err > 1.0
    err = err.mean()

    assert A2.dtype == dtype
    assert err.item() < 0.1
    assert relerr.item() < 0.28


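# quantize_fp4/dequantize_fp4 work blockwise: each block of `blocksize` values
# is divided by its absmax, snapped to the 4-bit e2m1 grid built above
# ({0, ±0.0625, ±2, ±3, ±4, ±6, ±8, ±12} before normalization), and
# dequantization multiplies back. A compact emulation of that round trip in
# plain PyTorch (illustrative only; assumes x.numel() is a multiple of
# blocksize and `code` is a 1-D tensor of normalized grid values):
def _blockwise_code_roundtrip_reference(x, code, blocksize=64):
    flat = x.float().reshape(-1, blocksize)
    absmax = flat.abs().max(dim=1, keepdim=True).values.clamp(min=1e-12)
    scaled = flat / absmax
    idx = (scaled.unsqueeze(-1) - code.view(1, 1, -1)).abs().argmin(dim=-1)
    return (code[idx] * absmax).reshape(x.shape)

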
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
def test_4bit_compressed_stats(quant_type):
    for blocksize in [128, 64]:
        errs1 = []
        errs2 = []
        for i in range(10):
            A1 = torch.randn(1024, 1024, device='cuda').half()
            q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
            q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
            A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
            A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)

            err = (A1 - A2).abs().float()
            relerr = (err/(A1.abs().float()+1e-15)).mean()
            err = err.mean()

            errs1.append(err.item())

            assert err.item() < 0.11
            assert relerr.item() < 0.28

            err = (A1 - A3).abs().float()
            relerr = (err/(A1.abs().float()+1e-15)).mean()
            err = err.mean()

            errs2.append(err.item())

            assert err.item() < 0.11
            assert relerr.item() < 0.28

        #print(sum(errs1)/len(errs1), blocksize, quant_type)
        #print(sum(errs2)/len(errs2), blocksize, quant_type)




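# compress_statistics=True ("double quantization") additionally quantizes the
# per-block fp32 absmax values with 8-bit blockwise quantization, keeping only
# a single fp32 offset, which trades the small extra error measured above for
# less memory per block. A rough sketch of the idea (illustrative only; the
# exact layout lives in F.quantize_4bit and its quantization state):
def _compress_absmax_reference(absmax, blocksize=256):
    offset = absmax.mean()                 # kept in fp32
    shifted = absmax - offset              # center before 8-bit quantization
    qabsmax, state = F.quantize_blockwise(shifted, blocksize=blocksize)
    return qabsmax, state, offset          # dequantize_blockwise(...) + offset

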
Tim Dettmers's avatar
Tim Dettmers committed
2106
2107
#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ['nf4'])
Aarni Koskela's avatar
Aarni Koskela committed
2108
@pytest.mark.benchmark
2109
def test_bench_4bit_dequant(quant_type):
2110
2111
    blocksize = 256
    a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
2112
    qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
2113
2114
2115
2116
2117
2118

    input_size = a.numel()/2
    output_size = a.numel()*2
    num_bytes = input_size+output_size
    GB = num_bytes/1e9
    max_theoretical_s =  GB/768
2119
    #print(max_theoretical_s*1e6)
2120
2121
    b = torch.randn(128, 1024*12, device='cuda').half()

Tim Dettmers's avatar
Tim Dettmers committed
2122
    iters = 100
2123
2124
2125
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
2126
        F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
2127
2128
        #b.copy_(a)
    torch.cuda.synchronize()
2129
2130
2131
2132
2133
2134
2135
2136
    #print((time.time()-t0)/iters*1e6)

    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    torch.matmul(b, a.t())
    #torch.cuda.synchronize()
    #print((time.time()-t0)/iters*1e6)
2137
2138
2139
2140
2141
2142
2143



def test_normal_map_tree():
    code = F.create_normal_map()
    values = code[:8].tolist() + code[-8:].tolist()
    num_pivots = 1
    #print(values)
    while num_pivots < 16:
        idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
        #print(idx)
        num_pivots *= 2
        pivots = []
        for i in idx:
            pivots.append((values[i-1]+values[i])/2)
        #print(pivots)


@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
    for dim in [128, 256, 512, 1024]:
    #for dim in [4*1024]:
    #for dim in [1*16]:
        errs1 = []
        errs2 = []
        errs3 = []
        relerrs1 = []
        relerrs2 = []
        relerrs3 = []
        max_errs1 = []
        max_errs2 = []
        max_errs3 = []

        for i in range(100):
            if kind == 'fc1':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            elif kind == 'fc2':
                A = torch.randn(1, 4*dim, dtype=dtype, device='cuda')
                B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            elif kind == 'attn':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            elif kind == 'attn_packed':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage)
            C3 = torch.matmul(A, B.t())
            C2 = F.gemv_4bit(A, qB.t(), state=state)
            A.requires_grad = True
            C1 = bnb.matmul_4bit(A, qB.t(), state)

            err1 = (C1-C2).abs().float()
            err2 = (C3-C2).abs().float()
            err3 = (C3-C1).abs().float()

            mag1 = torch.abs(C1).float()+1e-5
            mag2 = torch.abs(C3).float()+1e-5
            mag3 = torch.abs(C3).float()+1e-5

            relerr1 = err1/mag1
            relerr2 = err2/mag2
            relerr3 = err3/mag3

            max_err1 = err1.max()
            max_err2 = err2.max()
            max_err3 = err3.max()

            errs1.append(err1.mean().item())
            errs2.append(err2.mean().item())
            errs3.append(err3.mean().item())

            relerrs1.append(relerr1.mean().item())
            relerrs2.append(relerr2.mean().item())
            relerrs3.append(relerr3.mean().item())

            max_errs1.append(max_err1.item())
            max_errs2.append(max_err2.item())
            max_errs3.append(max_err3.item())

            c = int(C1.numel()*0.0014*(dim/256))+1

            c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
        err1 = sum(errs1)/len(errs1)/math.sqrt(dim)
        err2 = sum(errs2)/len(errs2)/math.sqrt(dim)
        err3 = sum(errs3)/len(errs3)/math.sqrt(dim)
        relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim)
        relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim)
        relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim)
        maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim)
        maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim)
        maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim)
        absratio = err2/err3
        relratio = relerr2/relerr3
        maxratio = relerr2/relerr3

        # for debugging if the tests fails
        #
        #print('='*80)
        #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
        #print(C1.flatten()[-20:])
        #print(C2.flatten()[-20:])
        #print(f'inference vs training abs: {err1}')
        #print(f'inference vs training rel: {relerr1}')
        #print(f'inference vs training max: {maxerr1}')
        #print(f'inference vs training vs torch err ratio abs: {absratio}')
        #print(f'inference vs training vs torch err ratio rel: {relratio}')
        #print(f'inference vs training vs torch err ratio max: {maxratio}')
        if dtype == torch.float16:
            if dim <= 512:
                assert err1 < 7e-5
                assert relerr1 < 0.0008
            else:
                assert err1 < 6e-5
                assert relerr1 < 2e-4
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.float32:
            if dim <= 512:
                assert err1 < 5e-8
                assert relerr1 < 1e-6
                assert maxerr1 < 1e-7
            else:
                assert err1 < 5e-8
                assert relerr1 < 8e-6
                assert maxerr1 < 1e-7
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.bfloat16:
            if dim <= 512:
                assert err1 < 6e-4
                assert relerr1 < 0.007
                assert maxerr1 < 0.015
            else:
                assert err1 < 2e-4
                assert relerr1 < 0.002
                assert maxerr1 < 0.0012
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.04 and relratio > 0.96
            assert maxratio < 1.02 and maxratio > 0.98

@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
    n = 32*10
    A = F.get_paged(n, n, dtype=torch.float32)
    B = F.get_paged(n, n, dtype=torch.uint8)
    B2 = F.get_paged(n, n, dtype=torch.float32)
    assert A.is_paged
    assert B.is_paged
    assert A.page_deviceid==0
    assert B.page_deviceid==0
    F.fill(A, 17.0)
    F.fill(B, 17)
    F.fill(B2, 2)
    assert (A==17).sum().item() == n*n
    assert (B==17).sum().item() == n*n
    C = A*B.float()
    assert (C==289).sum().item() == n*n
    F._mul(A, B2)
    F._mul(A, B2)
    F._mul(A, B2)
    assert (A==17*(2**3)).sum().item() == n*n
   # F.prefetch_tensor(A)
   # F.prefetch_tensor(B)


   # F.fill(B2, 17.0)
   # F._mul(A, B2)

   # F.prefetch_tensor(A, to_cpu=True)
   # F.prefetch_tensor(B, to_cpu=True)
   # F.prefetch_tensor(B2, to_cpu=True)
   # torch.cuda.synchronize()

   # assert (A==17).sum().item() == n*n

   # torch.testing.assert_close(A, torch.ones(A.shape)*289)


@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True'])
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
    dims = 10
    torch.random.manual_seed(np.random.randint(0, 412424242))
    dims = get_test_dims(0, 8192, n=dims)
    dims = [dim + (64-(dim % 64)) for dim in dims]
    #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
    for dim in dims:
        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda')
        B = torch.eye(dim, dtype=dtype, device='cuda')

        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
        C3 = torch.matmul(A, B.t())
        C2 = bnb.matmul_4bit(A, qB.t(), state)
        A.requires_grad = True
        C1 = bnb.matmul_4bit(A, qB.t(), state)

        torch.testing.assert_close(A, C3)
        torch.testing.assert_close(A, C1)
        torch.testing.assert_close(A, C2)
        #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
        #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)