import math
import random
import time
from itertools import product

import einops
import numpy as np
import pytest
import torch
from scipy.stats import norm

import bitsandbytes as bnb
from bitsandbytes import functional as F
from tests.helpers import (
    BOOLEAN_TUPLES,
    TRUE_FALSE,
    describe_dtype,
    get_test_dims,
    id_formatter,
)

torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)

k = 20

def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        if throw:
            print(f"Too many values not close: assert {sumval} < {count}")
            torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

    return sumval
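
# Usage sketch (values are illustrative): the helper returns the number of
# elements outside atol/rtol and only raises once that number exceeds
# `count`, which suits kernels where a handful of rounding mismatches are
# expected:
#
#   a = torch.randn(64, 64, device="cuda")
#   b = a + 1e-5 * torch.randn_like(a)
#   n_mismatched = assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=4)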


class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.fc1.weight)
            torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
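
# Minimal usage sketch (sizes are illustrative):
#
#   ffn = FFN(input_features=512, hidden_size=2048).cuda()
#   y = ffn(torch.randn(8, 512, device="cuda"))  # y.shape == (8, 512)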

class Timer:
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
            self.starts[name].record()
        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")

def setup():
    pass


def teardown():
    pass


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0
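
# Note: F.estimate_quantiles returns 256 values approximating the quantile
# midpoints (1/512, 3/512, ..., 511/512), which is why the uniform case above
# is compared against torch.linspace(1 / 512, 511 / 512, 256) rather than
# exact k / 256 quantiles.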


def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001


def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    print(sum(diffs) / len(diffs))
    print(sum(reldiffs) / len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
    # print('')
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
    abserr = sum(diffs) / len(diffs)
    relerr = sum(reldiffs) / len(reldiffs)
    # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
    # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
    assert abserr < 0.011
    assert relerr < 0.018
    assert A2.dtype == dtype

    diffs = []
    code = F.create_dynamic_map(signed=signed)
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        # torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
    abserr = sum(diffs) / len(diffs)
    relerr = sum(reldiffs) / len(reldiffs)
    if signed:
        assert abserr < 0.0035
        assert relerr < 0.015
    else:
        assert abserr < 0.00175
        assert relerr < 0.012
    assert A2.dtype == dtype
    # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
    # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
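
# Round-trip sketch for the API exercised above (the blocksize value is
# illustrative):
#
#   C, S = F.quantize_blockwise(A, blocksize=256)  # 8-bit codes + quant state
#   A2 = F.dequantize_blockwise(C, S)              # back to A's dtype/shape
#
# Each block is scaled by its own absmax, so smaller blocksizes should give a
# smaller expected error.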


@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)

        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_close(clip1, clip2)
        torch.testing.assert_close(gnorm1, gnorm2)


def quant(x):
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    return C.float() * (maxA / 127) * (maxB / 127)
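
# These helpers implement the usual symmetric int8 identity: with
# Ac = round(A / maxA * 127) and Bc = round(B / maxB * 127),
#
#   A @ B ~= (Ac @ Bc) * (maxA / 127) * (maxB / 127)
#
# which is exactly the scaling mm_dequant applies to the integer product.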


def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_minmax(A):
    minA = A.min()
    maxA = A.max()
    return minA, maxA


def mean(xx):
    return sum(xx) / float(len(xx))


methods = {
    "linear": (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    ),
    "vectorwise": (quant_multi, quant_multi, dequant, dequant, mm_dequant),
}
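
# Each entry is a 5-tuple: (quantize A, quantize B, dequantize A, dequantize B,
# matmul dequantize). "linear" uses a single absmax scale per tensor, while
# "vectorwise" keeps one scale per row/column along the reduction dimension.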


@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2"))
@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys())
@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched"))
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    # print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    # print(mean(errors))
    # print(mean(relerrors))


def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
        shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        elif transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.t().float(), B.float())
            out = F.igemm(A.t(), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.matmul(A.t().float(), B.t().float())
            out = F.igemm(A.t(), B.t())

        torch.testing.assert_close(out.float(), out2)

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())

        torch.testing.assert_close(out.float(), out2)


@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim"))
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8)
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_close(out.float(), out2)


@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale
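
    # min_max maps x onto roughly [-127, 127] around the interval midpoint:
    # x ~= q * scale / 127 + (minA + scale), so after the integer matmul the
    # constant term is restored below via offset = B.sum(0) * (minA + scale).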

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs = []
    relerrs = []
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)

        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3


@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)

        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A, B.permute([0, 2, 1]))
        elif transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
            out = F.igemm(A.permute([0, 2, 1]), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_close(out.float(), out2.float())


@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3"))
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        n = A1.numel()
        assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002))


@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose"))
@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims"))
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return
    try:
        func = F.get_transform_func(dtype, orderA, orderOut, transpose)
    except ValueError as ve:
        pytest.skip(str(ve))  # skip if not supported

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_close(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_close(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32)))
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_close(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
        torch.testing.assert_close(A, out2)


@pytest.mark.parametrize("dim1", get_test_dims(1, 256, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 512, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 1024, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 1024, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims"))
@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
        elif dims == 3:
            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_close(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_close(C1, C3.float())


@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        # print('')
        # print(output.flatten()[:10])
        # print(C1.flatten()[:10])
        # print(C2.flatten()[:10])

        # torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        # C1 = torch.matmul(A.float(), B.float())

        # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        # C3, S = F.transform(C2, 'row', state=SC)
        # torch.testing.assert_close(C1, C3.float())


@pytest.mark.parametrize(
    ("batch", "seq", "model", "hidden"),
    [
        pytest.param(2, 512, 4 * 1024, 3 * 4 * 1024, id="batch=2, seq=512, model=4k, hidden=12k"),
        pytest.param(2, 512, 5120, 3 * 5120, id="batch=2, seq=512, model=5k, hidden=15k"),
        pytest.param(2, 512, 12 * 1024, 4 * 12 * 1024, id="batch=2, seq=512, model=12k, hidden=48k"),
    ],
)
@pytest.mark.benchmark
def test_bench_8bit_training(batch, seq, model, hidden):
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")

    # torch.cuda.synchronize()
    ## warmup
    # for i in range(100):
    #    torch.matmul(A, w1.t())
    # torch.cuda.synchronize()

    dtype = torch.int8
    A = A.view(-1, A.shape[-1]).contiguous()
    grad = grad.view(-1, grad.shape[-1]).contiguous()
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out1 = torch.matmul(A, w1.t())  # fc1
        # out2 = torch.matmul(out1, w2.t())# fc2

        # d1 = torch.matmul(grad, w2) # delta1
        # d2 = torch.matmul(d1, w1) # delta2

        # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
        # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1

    torch.cuda.synchronize()
    t16 = time.time() - t0
    print(t16)

    # torch.cuda.empty_cache()

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)

    # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    # C32A, SA = F.transform2(CA, 'col32')
    ## fc1
    # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)

    ## fc2
    # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    # C32out1, Sout1 = F.transform2(Cout1, 'col32')
    # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)

    ## delta1
    # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)

    ## delta2
    # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    # C32d1, Sd1 = F.transform2(Cd1, 'col32')
    ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)

    ## grad1
    # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)

    ## grad2
    # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)

    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(k):
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)

    #    #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
    #    CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    #    #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    #    #CTw2, Sw2 = F.transform2(Cw2, formatB)
    #    #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)

    #    C32A, SA = F.transform2(CA, 'col32')

    #    # fc1
    #    out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    #    #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

    #    #print(coo_tensor.nnz)
    #    #out1sp = F.spmm_coo(coo_tensor, w1.t())
    #    #print(w1.t().shape)
    #    #out1 = out1dn + out1sp

    #    # fc2
    #    Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    #    C32out1, Sout1 = F.transform2(Cout1, 'col32')
    #    out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    #    #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)

    #    # delta1
    #    Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    #    C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    #    d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    #    #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)

    #    # delta2
    #    Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    #    C32d1, Sd1 = F.transform2(Cd1, 'col32')
    #    d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    #    #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)

    #    # grad1
    #    #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    #    #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    #    #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    #    #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)

    #    ## grad2
    #    #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    #    #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    #    #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    #    #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)

    # torch.cuda.synchronize()
    # t8 = time.time() - t0
    # print(t8)


@pytest.mark.parametrize("dim1", get_test_dims(64, 256, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim4", get_test_dims(64, 1024, n=2), ids=id_formatter("dim4"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("formatB", ["col_turing", "col_ampere"], ids=id_formatter("formatB"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
    inner = torch.randint(1, 128, size=(1,)).item()
    bias = None
    if has_bias:
        bias = torch.randn(dim4, device="cuda", dtype=torch.float16)
    formatB = F.get_special_format_str()
    for i in range(1):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())
        if has_bias:
            C1 += bias

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
        if has_bias:
            C4 += bias

        # TODO: is something wrong here? If so, the problem goes deeper
        # n = C1.numel()
        # p = 0.06
        std = C1.std(0).view(1, -1)
        C1 /= std
        C4 /= std
        # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
        # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
        # torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
        n = C5.numel()
        assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))


@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
def test_colrow_absmax(dim1, dim2, dims):
    for i in range(k):
        threshold = 3.0
        A = torch.randn(dim1, dim2, device="cuda").half()
        A_truncated = A.clone()
        A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
        if dims == 2:
            row_stats1, _ = torch.abs(A.float()).max(1)
            col_stats1, _ = torch.abs(A.float()).max(0)
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
        else:
            assert False

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)

        A_blocked = einops.rearrange(
            torch.abs(A),
            "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
            row_tiles=16,
            block_size=64 * 4,
        )
        nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(
            nnz_rows1_counts.shape[0] + 1,
            dtype=nnz_rows1_counts.dtype,
            device=nnz_rows1_counts.device,
        )
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_close(col_stats1_trunc, col_stats2)
        torch.testing.assert_close(row_stats1_trunc, row_stats2)
        torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)

        torch.testing.assert_close(col_stats1, col_stats2)
        torch.testing.assert_close(row_stats1, row_stats2)
        assert nnz_block_ptr2 is None


@pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2"))
def test_double_quant(dim1, dim2):
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()

        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}")
            assert False
        if num_not_close_rows > (min_error * n):
            print(f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}")
            assert False

        torch.testing.assert_close(Srow.flatten().float(), statsA)
        torch.testing.assert_close(Scol.flatten().float(), statsAt)
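
# In short: F.double_quant produces row-wise (CA) and column-wise (CAt) int8
# quantizations plus their absmax stats in a single pass, matching two
# separate F.vectorwise_quant calls up to +/-1 from rounding.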


@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    (
        pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
        for (dim1, dim4, inner) in zip(
            get_test_dims(1, 4 * 1024, n=4),
            get_test_dims(1, 4 * 1024, n=4),
            get_test_dims(1, 4 * 1024, n=4),
        )
    ),
)
def test_integrated_igemmlt(dim1, dim4, inner):
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        torch.testing.assert_close(maxA.flatten().float(), stats1a)
        torch.testing.assert_close(maxB.flatten().float(), stats2a)
        torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_close(C2a, B1, rtol=0, atol=1)

        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(C2a, "col_turing")
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
        assert err2 <= err1 * 1.025
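
# Pipeline exercised above: double_quant -> transform ("col32"/"col_turing")
# -> igemmlt -> dequant recovers the fp16 matmul. The assert checks that the
# manual vectorwise dequant (out3) is at most 2.5% worse than the fused
# F.mm_dequant path (out2).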


@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    (
        pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
        for (dim1, dim4, inner) in zip(
            get_test_dims(1, 4 * 1024, n=6),
            get_test_dims(1, 4 * 1024, n=6),
            get_test_dims(1, 4 * 1024, n=6),
        )
    ),
)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)

        # print(out1)
        # print(out2)
        # print(out3)
        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))


@pytest.mark.parametrize(
    ("dim1", "dim4", "inner"),
    [
        pytest.param(1024, 12288 * 4, 12288, id="1024, 12288*4, 12288"),
        pytest.param(2048, 4096 * 4, 4096, id="2048, 4096*4, 4096"),
    ],
)
@pytest.mark.skip("Row scale has some bugs for ampere")
@pytest.mark.benchmark
def test_row_scale_bench(dim1, dim4, inner):
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)
    # warmup
    for i in range(k):
        C1 = torch.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("16", time.time() - t0)

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
    CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
    A2, SA = F.nvidia_transform(C1a, "col32")
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

    c = 10.0 * inner * scale
    row_scale = maxA / c
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
    torch.cuda.synchronize()
    print("row-wise", time.time() - t0)

    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
    torch.cuda.synchronize()
    print("vector-wise", time.time() - t0)


@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims"))
@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
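    # F.transform must agree with the reference F.nvidia_transform for every
    # row -> {col32, col_turing, col_ampere} conversion, with and without
    # transposition; both the output data and the shape metadata are compared.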
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
        elif dims == 3:
            A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)

        A.view(-1)[-1] = -1
        if transpose:
            At = A.t().contiguous()
            out1, S1 = F.nvidia_transform(At, to_order=orderOut)
        else:
            out1, S1 = F.nvidia_transform(A, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_close(out1, out2)


def test_overflow():
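    # Smoke test: int8 igemmlt with int8 output on products that exceed the
    # int8 range; it only checks that the kernel runs (no assertions).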
    formatB = F.get_special_format_str()
    print(formatB)
    for i in range(2):
        a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)

        Ca, Sa = F.nvidia_transform(a, "col32")
        Cb, Sb = F.nvidia_transform(b, formatB)

        c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
        c2 = torch.matmul(a.float(), b.float().t())


@pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2"))
def test_coo_double_quant(dim1, dim2):
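    # double_quant with a threshold should return the outliers
    # (|A| >= threshold) as a COO tensor and quantize only the inliers.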
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)

        if coo_tensor is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A)
            A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
            torch.testing.assert_close(A1, A2)

            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)


@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B"))
def test_spmm_coo(dim1, dim2, transposed_B):
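    # Sparse (COO) x dense matmul must match the dense product of the
    # thresholded matrix for both transposed and non-transposed B.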
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
            B = torch.randn(dim3, dim2).cuda().half()
        else:
            B = torch.randn(dim2, dim3).cuda().half()

        idx = torch.abs(A) >= threshold
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
        A2 = A * idx

        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
            out1 = torch.matmul(A2, B.t())
        else:
            out2 = F.spmm_coo(cooA, B)
            out1 = torch.matmul(A2, B)

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)


@pytest.mark.benchmark
def test_spmm_bench():
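    # Benchmark: dense bnb.matmul vs. F.spmm_coo on a thresholded (sparse)
    # copy of A; prints the density and the sparse/dense timing ratio.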
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = bnb.matmul(A, B.t())
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)

    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time() - t0
    print(tsp, t8)
    print(tsp / t8)


@pytest.mark.parametrize("dim1", get_test_dims(256, 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(256, 1024, n=2), ids=id_formatter("dim2"))
def test_integrated_sparse_decomp(dim1, dim2):
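    # Int8 matmul plus sparse decomposition: routing the outliers through
    # spmm_coo and adding them back should reduce the error versus the plain
    # int8 matmul (err2 < err1).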
    threshold = 3.0
    formatB = "col_turing"
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        out1 = torch.matmul(A, w1.t())

        Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
        CTw1, Sw1 = F.transform(Cw1, formatB)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
        C32A, SA = F.transform(CA, "col32")

        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        assert coo_tensor is not None

        out4 = F.spmm_coo(coo_tensor, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1


def test_matmuls():
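    # Basic agreement of bnb.matmul and bnb.matmul_cublas with torch.matmul
    # in fp16 (mean absolute error below 0.2).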
    a = torch.randn(256, 512).half().cuda()
    b = torch.randn(256, 512).half().cuda()
    c1 = torch.matmul(a, b.t())
    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul_cublas(a, b.t())

    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2
    print(err1, err2)


@pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func"))
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
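    # spmm_coo_very_sparse must match the dense reference on an extremely
    # sparse A; `out` is pre-filled via zeros/ones to check that the kernel
    # accumulates into an existing output buffer.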
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))

    # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)

    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #   #out3 = F.spmm_coo(cooA, Bt.t())
    #   #out2 = F.spmm_coo(cooA, B)
    #   #out2 = F.spmm_coo_very_sparse(cooA, B)
    #   #out1 = torch.matmul(A, Bt.t())

    # torch.cuda.synchronize()
    # print(time.time() - t0)


def test_coo2csr():
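    # COO -> CSR conversion: row-pointer differences must equal the number
    # of nonzeros per row, and the values must round-trip unchanged.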
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_close(A2[idx], csrA.values)


def test_coo2csc():
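    # COO -> CSC conversion: column-pointer differences must equal the number
    # of nonzeros per column; values are compared in column-major order.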
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_close(A2.t()[idx], cscA.values)


@pytest.mark.parametrize("dim1", [1 * 2048])
@pytest.mark.parametrize("dim2", [2048])
@pytest.mark.parametrize("dtype", [torch.int8])
def test_spmm_coo_dequant(dim1, dim2, dtype):
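    # spmm_coo_very_sparse with int8 B and dequant_stats must match the same
    # product dequantized by hand; the remainder of the test benchmarks the
    # sparse kernels against cuSPARSE and dense matmuls.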
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    rowidx = torch.randint(0, A.shape[-1], size=(15,))

    A[:, rowidx] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)

    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(100):
    #   out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print('fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print("cusparse fp16", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print("int8", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print("int8+dequant", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print("matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1 + out2
    torch.cuda.synchronize()
    print("sparse+ matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print("bnb matmul", time.time() - t0)


@pytest.mark.parametrize(
    ("batch", "seq", "model", "hidden"),
    [pytest.param(1, 1, 6656, 4 * 6656, id="batch=1, seq=1, model=6656, hidden=26k")],
)
@pytest.mark.benchmark
def test_bench_matmul(batch, seq, model, hidden):
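    # Benchmark of matmul variants: fp16 reference and 4-bit NF4 with and
    # without double quantization; the commented-out sections time additional
    # fp4, int8 and Linear8bitLt paths.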
    iters = 1000
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device="cuda").half()
    B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    B_fp4, state = F.quantize_fp4(B)
    B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)

    B_nf4, state_nf4 = F.quantize_nf4(B)
    B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
    linear8bit.eval()

    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()
    # linearMixedBit.eval()

    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)

    # warmup
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print(
        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s",
    )

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
    # torch.cuda.synchronize()
    # print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
    # torch.cuda.synchronize()
    # print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
    torch.cuda.synchronize()
    print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
    torch.cuda.synchronize()
    print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    bnb.matmul(A, B)
    # torch.cuda.synchronize()
    # print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    bnb.matmul(A, B, threshold=6.0)
    # torch.cuda.synchronize()
    # print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    # C32A, SA = F.transform(CA, "col32")
    # CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    # CxB, SB = F.transform(CB, to_order=formatB)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    # torch.cuda.synchronize()
    # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # BA, statsB = F.vectorwise_quant(B, dim=1)
    # CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1)
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    # torch.cuda.synchronize()
    # print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
    # CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    A2 = A.view(-1, A.shape[-1]).contiguous()
    #    CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
    #    C32A, SA = F.nvidia_transform(CA, "col32")
    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #    out = Cout * statsB * statsA * (1.0 / (127 * 127))
    # torch.cuda.synchronize()
    # print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # linear8bit(A)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    linear8bit(A)
    # torch.cuda.synchronize()
    # print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # linearMixedBit(A)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    linearMixedBit(A)
    # torch.cuda.synchronize()
    # print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # linear8bit_train(A)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    linear8bit_train(A)
    # torch.cuda.synchronize()
    # print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # linear8bit_train_thresh(A)
    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    linear8bit_train_thresh(A)
    # torch.cuda.synchronize()
    # print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")


def test_zeropoint():
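    # Zero-point (asymmetric) quantization: a matmul of shifted operands can
    # be corrected exactly with row/column sums, since for A (m, k) and
    # B (k, n):
    #   (qa*A + zpa) @ (qb*B + zpb)
    #     = qa*qb*(A @ B) + zpa*qb*B.sum(0) + zpb*qa*A.sum(1) + zpa*zpb*k
    # C2..C7 below exercise special cases of this identity.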
    def quant_zp(x):
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        minx = x.min()
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()

    C1 = torch.matmul(A, B)
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    # print(ca.min(), ca.max())
    # print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb

    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

    # print("")
    # print(C0.flatten()[:10])
    # print(C1.flatten()[:10])
    # print(C2.flatten()[:10])
    # print(C3.flatten()[:10])
    # print(C5.flatten()[:10])
    # print(C6.flatten()[:10])
    # print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)


def test_extract_outliers():
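    # extract_outliers must gather the same columns from the tiled
    # col_turing/col_ampere layouts as plain row-major indexing does.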
    for i in range(k):
        shapeA = (4096, 4096 * 4)
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # idx = torch.Tensor([0]).int().cuda()
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        CA, SA = F.transform(A, "col_turing")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)

        assert outliers2.shape[0] == shapeA[0]
        assert outliers2.shape[1] == idx.numel()

        torch.testing.assert_close(outliers1, outliers2)


def test_blockwise_cpu_large():
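    # Blockwise quantization round-trip on CPU with large block sizes; the
    # mean absolute error must stay below 0.011.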
    diffs = []
    reldiffs = []
    batch = 128
    seq = 128
    for hidden in [128]:  # , 14336]:
        for blocksize in [4096, 16384]:
            for i in range(2):
                A1 = torch.randn(batch, seq, hidden, device="cpu")
                t0 = time.time()
                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
                print(time.time() - t0)
                diff = torch.abs(A1 - A2)
                reldiff = diff / torch.abs(A1 + 1e-8)
                diffs.append(diff.mean().item())
                reldiffs.append(reldiff.mean().item())
                assert diffs[-1] < 0.011
            # print(sum(diffs)/len(diffs))
            # print(sum(reldiffs)/len(reldiffs))


def test_fp8_quant():
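    # Round-trip with FP8 codebooks: each split of the 7 non-sign bits into
    # exponent (e_bits) and mantissa (p_bits) bits is checked on normal and
    # uniform data, and the default (dynamic) codebook is measured as a
    # baseline.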
    for e_bits in range(1, 7):
        p_bits = 7 - e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / torch.abs(A1 + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        # print(sum(abserr)/len(abserr))
        # print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.rand(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / torch.abs(A1 + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        # print(sum(abserr)/len(abserr))
        # print(sum(relerr)/len(relerr))

        abserr = []
        relerr = []
        for i in range(100):
            A1 = torch.randn(1024, 1024, device="cuda")
            C, SC = F.quantize_blockwise(A1)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / torch.abs(A1 + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
            # assert diff < 0.0075
        # print(3, sum(abserr)/len(abserr))
        # print(3, sum(relerr)/len(relerr))


def test_few_bit_quant():
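    # For 2..8-bit codebooks (linear, fp8, dynamic, quantile), blockwise
    # quantization must pick the nearest code value; a brute-force nearest-
    # neighbor search over the codebook serves as the reference.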
    # print('')
    for bits in range(2, 9):
        # print('='*30, bits, '='*30)
        for method in ["linear", "fp8", "dynamic", "quantile"]:
            abserrs = []
            relerrs = []
            code = None
            if method == "linear":
                code = F.create_linear_map(True, total_bits=bits).cuda()
            elif method == "fp8":
                ebits = math.ceil(bits / 2)
                pbits = bits - ebits - 1
                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
            elif method == "dynamic":
                code = F.create_dynamic_map(True, bits - 0, bits).cuda()
            elif method == "quantile":
                values = torch.randn(2048, 2048, device="cuda")
                code = F.create_quantile_map(values, bits).cuda()
            # for some data types we have no zero
            # for some data types we have one zero
            # for some data types we have two zeros
            assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}"
            # print(method, (code==0).sum())
            assert code.numel() == 256
            for i in range(10):
                values = torch.randn(1, 32, device="cuda")
                values /= values.abs().max()
                # values[values.abs() < 1e-6] += 1e-5

                q1 = []
                v1 = []
                for v in values[0]:
                    idx = torch.abs(v - code).argmin()
                    q1.append(idx.item())
                    v1.append(code[idx].item())

                q1 = torch.Tensor(q1).cuda()
                v1 = torch.Tensor(v1).cuda()

                q2, S2 = F.quantize_blockwise(values, code=code)
                v2 = F.dequantize_blockwise(q2, S2)

                idx = torch.isclose(q1.int(), q2.int())
                err2 = torch.abs(v2 - values)
                abserrs.append(err2.mean().item())
                relerrs.append((err2 / (1e-10 + values).abs()).mean().item())
                if idx.sum():
                    # some weird cases
                    err1 = torch.abs(v1 - values).mean()
                    # assert err2.mean() <= err1

                else:
                    torch.testing.assert_close(q1, q2)
            # print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
    # assert False


def test_kbit_quantile_estimation():
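    # estimate_quantiles on Gaussian data should approximate the exact normal
    # quantiles (scipy's norm.ppf), both with and without an offset.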
    for i in range(100):
        data = torch.randn(1024, 1024, device="cuda")
        for bits in range(2, 9):
            p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.038

    for i in range(100):
        data = torch.randn(1024, 1024, device="cuda")
        for bits in range(2, 4):
            total_values = 2**bits - 1
            p = np.linspace(0, 1, 2 * total_values + 1)
            idx = np.arange(1, 2 * total_values + 1, 2)
            p = p[idx]
            offset = 1 / (2 * total_values)
            p = np.linspace(offset, 1 - offset, total_values)
            val1 = torch.Tensor(norm.ppf(p)).cuda()
            val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.035


@pytest.mark.benchmark
def test_bench_dequantization():
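    # Benchmark of blockwise quantization; max_theoretical_mu below is
    # presumably a memory-bandwidth bound in microseconds (the 672 being the
    # assumed GB/s of the device).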
    a = torch.rand(1024, 1024, device="cuda").half()
    code = F.create_fp8_map(True, 3, 0, 4).cuda()
    qa, SA = F.quantize_blockwise(a, code=code)
    print(qa.max())

    max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000
    # print(max_theoretical_mu)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        qa, SA = F.quantize_blockwise(a)
    torch.cuda.synchronize()
    # print((time.time()-t0)/1e6)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
def test_fp4_quant(dtype):
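    # Builds the 16 FP4 code values by hand (1 sign, 2 exponent and 1 mantissa
    # bit, bias 3, subnormals at 0 and +-0.0625), then checks the
    # quantize_fp4/dequantize_fp4 round-trip error in the given dtype.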
    vals = list(product([0, 1], repeat=4))

    code = {}
    for bits in vals:
        result = 0
        bias = 3
        sign, e1, e2, p1 = bits
        idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1
        sign = -1.0 if sign else 1.0
        exp = e1 * 2 + e2 * 1
        if exp == 0:
            # sub-normal
            if p1 == 0:
                result = 0
            else:
                result = sign * 0.0625
        else:
            # normal
            exp = 2 ** (-exp + bias + 1)
            frac = 1.5 if p1 else 1.0
            result = sign * exp * frac
        code[idx] = result

    A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
    qa, SA = F.quantize_fp4(A1, blocksize=64)
    A2 = F.dequantize_fp4(qa, SA)

    err = (A1 - A2).abs().float()
    relerr = (err / (A1.abs().float() + 1e-8)).mean()
    idx = err > 1.0
    err = err.mean()

    assert A2.dtype == dtype
    assert err.item() < 0.1
    assert relerr.item() < 0.28


@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
def test_4bit_compressed_stats(quant_type):
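    # compress_statistics=True (double quantization of the absmax statistics)
    # must not push the 4-bit round-trip error beyond the usual tolerances.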
    for blocksize in [128, 64]:
        errs1 = []
        errs2 = []
        for i in range(10):
            A1 = torch.randn(1024, 1024, device="cuda").half()
            q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
            q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
            A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
            A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)

            err = (A1 - A2).abs().float()
            relerr = (err / (A1.abs().float() + 1e-15)).mean()
            err = err.mean()

            errs1.append(err.item())

            assert err.item() < 0.11
            assert relerr.item() < 0.28

            err = (A1 - A3).abs().float()
            relerr = (err / (A1.abs().float() + 1e-15)).mean()
            err = err.mean()

            errs2.append(err.item())

            assert err.item() < 0.11
            assert relerr.item() < 0.28

        # print(sum(errs1)/len(errs1), blocksize, quant_type)
        # print(sum(errs2)/len(errs2), blocksize, quant_type)


# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ["nf4"])
@pytest.mark.benchmark
def test_bench_4bit_dequant(quant_type):
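    # Benchmark of 4-bit dequantization; max_theoretical_s is presumably a
    # memory-bandwidth bound in seconds (assuming ~768 GB/s).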
    blocksize = 256
    a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half()
    qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)

    input_size = a.numel() / 2
    output_size = a.numel() * 2
    num_bytes = input_size + output_size
    GB = num_bytes / 1e9
    max_theoretical_s = GB / 768
    # print(max_theoretical_s*1e6)
    b = torch.randn(128, 1024 * 12, device="cuda").half()

    iters = 100
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
        # b.copy_(a)
    torch.cuda.synchronize()
    # print((time.time()-t0)/iters*1e6)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(iters):
    #    torch.matmul(b, a.t())
    # torch.cuda.synchronize()
    # print((time.time()-t0)/iters*1e6)


def test_normal_map_tree():
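    # Walks the NF4 codebook level by level and computes the midpoint pivots
    # a binary search over the code values would use (no assertions; the
    # debug prints are commented out).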
    code = F.create_normal_map()
    values = code[:8].tolist() + code[-8:].tolist()
    num_pivots = 1
    # print(values)
    while num_pivots < 16:
        idx = list(range(16 // num_pivots // 2, 16, 16 // num_pivots))
        # print(idx)
        num_pivots *= 2
        pivots = []
        for i in idx:
            pivots.append((values[i - 1] + values[i]) / 2)
        # print(pivots)


@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize(
    "quant_storage",
    [torch.uint8, torch.float16, torch.bfloat16, torch.float32],
    ids=describe_dtype,
)
def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
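    # gemv_4bit / matmul_4bit on batch-size-1 inputs shaped like transformer
    # fc1/fc2/attention projections: the inference (gemv) and training
    # (matmul) paths must agree with each other and with the fp reference
    # within dtype-dependent bounds.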
2061
    for dim in [128, 256, 512, 1024]:
Ruff's avatar
Ruff committed
2062
2063
        # for dim in [4*1024]:
        # for dim in [1*16]:
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
        errs1 = []
        errs2 = []
        errs3 = []
        relerrs1 = []
        relerrs2 = []
        relerrs3 = []
        max_errs1 = []
        max_errs2 = []
        max_errs3 = []

2074
        for i in range(100):
Ruff's avatar
Ruff committed
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
            if kind == "fc1":
                A = torch.randn(1, dim, dtype=dtype, device="cuda")
                B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
            elif kind == "fc2":
                A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
                B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
            elif kind == "attn":
                A = torch.randn(1, dim, dtype=dtype, device="cuda")
                B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
            elif kind == "attn_packed":
                A = torch.randn(1, dim, dtype=dtype, device="cuda")
                B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)

            qB, state = F.quantize_4bit(
                B,
                quant_type=storage_type,
                compress_statistics=double_quant,
                quant_storage=quant_storage,
            )
2094
            C3 = torch.matmul(A, B.t())
2095
            C2 = F.gemv_4bit(A, qB.t(), state=state)
2096
2097
            A.requires_grad = True
            C1 = bnb.matmul_4bit(A, qB.t(), state)
Tim Dettmers's avatar
Tim Dettmers committed
2098

Ruff's avatar
Ruff committed
2099
2100
2101
            err1 = (C1 - C2).abs().float()
            err2 = (C3 - C2).abs().float()
            err3 = (C3 - C1).abs().float()
2102

Ruff's avatar
Ruff committed
2103
2104
2105
            mag1 = torch.abs(C1).float() + 1e-5
            mag2 = torch.abs(C3).float() + 1e-5
            mag3 = torch.abs(C3).float() + 1e-5
2106

Ruff's avatar
Ruff committed
2107
2108
2109
            relerr1 = err1 / mag1
            relerr2 = err2 / mag2
            relerr3 = err3 / mag3
2110

2111
2112
2113
            max_err1 = err1.max()
            max_err2 = err2.max()
            max_err3 = err3.max()
Tim Dettmers's avatar
Tim Dettmers committed
2114

2115
2116
2117
            errs1.append(err1.mean().item())
            errs2.append(err2.mean().item())
            errs3.append(err3.mean().item())
Tim Dettmers's avatar
Tim Dettmers committed
2118

2119
2120
2121
            relerrs1.append(relerr1.mean().item())
            relerrs2.append(relerr2.mean().item())
            relerrs3.append(relerr3.mean().item())
2122

2123
2124
2125
            max_errs1.append(max_err1.item())
            max_errs2.append(max_err2.item())
            max_errs3.append(max_err3.item())
2126

Ruff's avatar
Ruff committed
2127
            c = int(C1.numel() * 0.0014 * (dim / 256)) + 1
Tim Dettmers's avatar
Tim Dettmers committed
2128

2129
            c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
Ruff's avatar
Ruff committed
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
        err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
        err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
        err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
        relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
        relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
        relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
        maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
        maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
        maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
        absratio = err2 / err3
        relratio = relerr2 / relerr3
        maxratio = relerr2 / relerr3
2142
2143
2144

        # for debugging if the tests fails
        #
Ruff's avatar
Ruff committed
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
        # print('='*80)
        # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
        # print(C1.flatten()[-20:])
        # print(C2.flatten()[-20:])
        # print(f'inference vs training abs: {err1}')
        # print(f'inference vs training rel: {relerr1}')
        # print(f'inference vs training max: {maxerr1}')
        # print(f'inference vs training vs torch err ratio abs: {absratio}')
        # print(f'inference vs training vs torch err ratio rel: {relratio}')
        # print(f'inference vs training vs torch err ratio max: {maxratio}')
2155
        if dtype == torch.float16:
2156
2157
2158
2159
2160
2161
2162
2163
2164
            if dim <= 512:
                assert err1 < 7e-5
                assert relerr1 < 0.0008
            else:
                assert err1 < 6e-5
                assert relerr1 < 2e-4
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
2165
        elif dtype == torch.float32:
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
            if dim <= 512:
                assert err1 < 5e-8
                assert relerr1 < 1e-6
                assert maxerr1 < 1e-7
            else:
                assert err1 < 5e-8
                assert relerr1 < 8e-6
                assert maxerr1 < 1e-7
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
2177
        elif dtype == torch.bfloat16:
2178
            if dim <= 512:
Tim Dettmers's avatar
Tim Dettmers committed
2179
                assert err1 < 6e-4
                assert relerr1 < 0.007
                assert maxerr1 < 0.015
            else:
                assert err1 < 2e-4
                assert relerr1 < 0.002
                assert maxerr1 < 0.0012
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.04 and relratio > 0.96
            assert maxratio < 1.02 and maxratio > 0.98


@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
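    # Paged tensors are backed by CUDA unified (managed) memory, so the driver
    # can migrate their pages between host and device on demand.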
    n = 32 * 10
    A = F.get_paged(n, n, dtype=torch.float32)
    B = F.get_paged(n, n, dtype=torch.uint8)
    B2 = F.get_paged(n, n, dtype=torch.float32)
    assert A.is_paged
    assert B.is_paged
    assert A.page_deviceid == 0
    assert B.page_deviceid == 0
    F.fill(A, 17.0)
    F.fill(B, 17)
    F.fill(B2, 2)
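
    # Paged tensors behave like ordinary CUDA tensors in arithmetic; the
    # elementwise product is a quick sanity check (17 * 17 == 289).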
    assert (A == 17).sum().item() == n * n
    assert (B == 17).sum().item() == n * n
    C = A * B.float()
    assert (C == 289).sum().item() == n * n
    F._mul(A, B2)
    F._mul(A, B2)
    F._mul(A, B2)
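
    # Three in-place multiplications by B2 (filled with 2) scale the original
    # fill value of 17 by 2**3.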
    assert (A == 17 * (2**3)).sum().item() == n * n


# F.prefetch_tensor(A)
# F.prefetch_tensor(B)
# F.fill(B2, 17.0)
# F._mul(A, B2)
# F.prefetch_tensor(A, to_cpu=True)
# F.prefetch_tensor(B, to_cpu=True)
# F.prefetch_tensor(B2, to_cpu=True)
# torch.cuda.synchronize()
# assert (A==17).sum().item() == n*n
# torch.testing.assert_close(A, torch.ones(A.shape)*289)


@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
    torch.random.manual_seed(np.random.randint(0, 412424242))
    dims = get_test_dims(0, 8192, n=10)
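    # Pad each dimension up to the next multiple of 64, the default blocksize
    # of the 4-bit quantization kernels.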
    dims = [dim + (64 - (dim % 64)) for dim in dims]
    # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
    for dim in dims:
        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
        B = torch.eye(dim, dtype=dtype, device="cuda")

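        # quantize_4bit packs B into 4-bit storage and returns the quantization
        # state (per-block absmax, codebook, original shape) that matmul_4bit
        # needs to dequantize on the fly.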
        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
        C3 = torch.matmul(A, B.t())
        C2 = bnb.matmul_4bit(A, qB.t(), state)
        A.requires_grad = True
        C1 = bnb.matmul_4bit(A, qB.t(), state)

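        # With B = I, every path should reproduce A exactly: 0 and 1 are
        # exactly representable in the nf4/fp4 codebooks, so the identity
        # quantizes losslessly.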
        torch.testing.assert_close(A, C3)
        torch.testing.assert_close(A, C1)
        torch.testing.assert_close(A, C2)
        # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
        # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)