import math
from itertools import product

import einops
import pytest
import torch
from torch import nn

import bitsandbytes as bnb


class MockArgs(object):
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])

class MLP8bit(torch.nn.Module):
    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
        super(MLP8bit, self).__init__()
        self.fc1 = bnb.nn.Linear8bitLt(
            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
        )
        self.fc2 = bnb.nn.Linear8bitLt(
            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
        )

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def get_args():
    args = MockArgs([])
    args.quant_type = "vector"
    args.use_8bit_training = "full"
    args.clip_freq = 9999
    return args

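# Helper used throughout this module: like torch.isclose, but tolerates up to
# `count` mismatched elements before falling back to a strict elementwise assert.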
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_allclose(a, b, rtol, atol)


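# LinearFunction below is a reference implementation of 8-bit linear training
# used by the Linear8bit module and tests further down. args.use_8bit_training
# selects how much of the computation is quantized: "off" (fp16 matmuls),
# "forward+wgrad" (8-bit forward and weight gradient), or "full" (8-bit forward,
# weight gradient, and input gradient).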
class LinearFunction(torch.autograd.Function):
    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std * trim_value
        x = x / max1 * 127
        x = round_func(x)
        x[x > 127] = 127
        x[x < -127] = -127
        x = x / 127 * max1

        return x

    @staticmethod
    def quant(x, quant_type, dim=1):
        if quant_type == "linear":
            max1 = torch.abs(x).max().float()
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "vector":
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "min-max":
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA - minA) / 2.0
            xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else:
            return None

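    # Rough sketch of the vector-wise reconstruction dequant() applies to an
    # int32 igemm result (S1, S2 are the absmax scales returned by quant()):
    #
    #   C_int32 = igemm(A_int8, B_int8.t())
    #   C_out  ~= C_int32 * (S1 / 127) * (S2 / 127)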
    @staticmethod
    def dequant(xq, S1, S2, dtype, quant_type):
        if quant_type == "linear":
            norm = S1 * S2 / (127 * 127)
            # double cast needed to prevent overflows
            return (xq.float() * norm).to(dtype)
        elif quant_type == "vector":
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3:
                S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3:
                S2 = S2.squeeze(0)
            # print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t() / 127
            else:
                x *= S1 / 127
            x *= S2 / 127
            return x.to(dtype)
        else:
            return None

    @staticmethod
    def dequant_min_max(xq, A, B, SA, SB, dtype):
        offset = B.float().t().sum(0) * (SA[0] + SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3:
            SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3:
            SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t() / 127
        else:
            x *= SB / 127
        x *= SA[1] / 127
        x += offset
        return x.to(dtype)

    @staticmethod
    def get_8bit_linear(x, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
        # x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
        x = round_func(x) / 127 * max1
        return x

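    # Stochastic rounding: round the magnitude up with probability equal to its
    # fractional part, so the rounding error is zero in expectation.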
    @staticmethod
    def round_stoachastic(x):
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx - torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))

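    # The fake_8bit_storage* helpers emulate 8-bit weight storage: the fp16
    # weights are blockwise quantized, immediately dequantized, and written back
    # in place, presumably so training proceeds on weights that have been
    # round-tripped through 8-bit precision.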
    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        # C = bnb.functional.quantize_no_absmax(code, w)
        # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        # print(out)
        # out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_stoachstic(w):
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out

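    # Variant that keeps the `topk` largest-magnitude values of each 256-element
    # block in fp16 (steps 1-4 below), presumably so that outliers neither lose
    # precision nor inflate the per-block absmax scale.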
    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
        blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize
        # 3. write back max values
        # 4. copy matrix back to weight

        values = blocked_w[mask]
        blocked_w[mask] = 0

        code = bnb.functional.create_dynamic_map()
        code = code.to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w

    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
        if args.use_8bit_training != "off":
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
            # if torch.rand(1) < 0.01:
            # output32 = torch.matmul(x, weight.t())
            # err = torch.abs(output-output32).float()
            # relerr = err/(torch.abs(output32).float()+1e-8)
            # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            # output = torch.matmul(x, weight.t())
            output = torch.einsum("bsi,oi->bso", x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == "forward+wgrad":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == "full":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(
                grad_input8, S1, S3, grad_output.dtype, args.quant_type
            )

        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)

        return grad_input, grad_weight, grad_bias, None

class Linear8bit(nn.Module):
    def __init__(self, input_features, output_features, bias=True, args=None):
        super(Linear8bit, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter("bias", None)

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        self.args.training = self.training

        return LinearFunction.apply(x, self.weight, self.bias, self.args)


def test_linear8bit():
    l0 = torch.nn.Linear(32, 64).cuda().half()
    l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
    l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
    l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()

    l0.weight.data = l2.weight.data.clone()
    l0.bias.data = l2.bias.data.clone()

    l1.weight.data = l2.weight.data.clone()
    l1.bias.data = l2.bias.data.clone()

    l3.weight.data = l2.weight.data.clone()
    l3.bias.data = l2.bias.data.clone()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        t = torch.randn(16, 8, 64, device="cuda").half()
        b2 = b1.clone()
        b3 = b1.clone()
        b0 = b1.clone()

        o0 = l0(b0)
        o1 = l1(b1)
        o2 = l2(b2)
        o3 = l3(b3)

        assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1)
        assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1)

        loss0 = torch.nn.functional.mse_loss(o0, t)
        loss1 = torch.nn.functional.mse_loss(o1, t)
        loss2 = torch.nn.functional.mse_loss(o2, t)
        loss3 = torch.nn.functional.mse_loss(o3, t)

        loss0.backward()
        loss1.backward()
        loss2.backward()
        loss3.backward()

        assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
        assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
        assert_all_approx_close(
            l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
        )
        assert_all_approx_close(
            l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
        )

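        # sanity check: the weight-gradient errors of the three 8-bit layers
        # relative to the fp16 reference l0 should be of comparable magnitude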
        err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item()
        err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item()
        err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item()

        assert err1 * 0.8 < err2
        assert err2 * 0.8 < err3
        assert err3 * 0.8 < err1

        l0.weight.grad = None
        l1.weight.grad = None
        l2.weight.grad = None
        l3.weight.grad = None
        l0.bias.grad = None
        l1.bias.grad = None
        l2.bias.grad = None
        l3.bias.grad = None


threshold = [0.0, 3.0]
values = threshold
names = ["threshold_{0}".format(vals) for vals in values]


@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == "cuda"
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
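        # after the first eval-mode forward pass the layer should have cached its
        # quantized weight buffer (state.CxB), which the check below relies on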
        if i == 1:
            assert l1.state.CxB is not None


def test_linear8bitlt_accumulated_gradient():
    l1 = torch.nn.Sequential(
        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
    )
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)

    acc_steps = 10

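    # NOTE: with acc_steps == 10 and only 10 iterations, the optimizer-step branch
    # below is never taken (i > 0 and i % acc_steps == 0 never holds), so this test
    # effectively only compares the accumulated gradients in the else branch.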
    for i in range(10):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CxB is not None
            assert l1[1].state.CxB is not None

        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(
                l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
            )
            assert_all_approx_close(
                l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
            )
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
        else:
            torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
            torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)


threshold = [0.0, 2.0]
values = threshold
names = ["threshold_{0}".format(vals) for vals in values]


@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold):
    l1 = (
        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
        .cuda()
        .half()
    )
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda")

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .to(torch.float16)
        .to("cuda")
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"