test_triton.py 2.16 KB
Newer Older
Tim Dettmers's avatar
Tim Dettmers committed
1
2
3
import pytest
import torch

Mitchell Wortsman's avatar
Mitchell Wortsman committed
4
5
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
from bitsandbytes.nn import Linear8bitLt
Tim Dettmers's avatar
Tim Dettmers committed
6
7


Mitchell Wortsman's avatar
Mitchell Wortsman committed
8
9
def _switchback_requirements_met():
    """Return True when the runtime pieces ``test_switchback`` relies on exist.

    The test moves every module and input to the GPU with ``.cuda()``, so a
    CUDA device is mandatory.
    """
    import importlib.util

    # NOTE(review): Triton is assumed to be required because SwitchBackLinear
    # is imported from ``triton_based_modules`` -- confirm against that module.
    return torch.cuda.is_available() and importlib.util.find_spec("triton") is not None


def _mean_abs_diff(reference, candidate):
    """Return the mean absolute element-wise difference between two tensors."""
    return (reference - candidate).abs().mean()


@pytest.mark.skipif(
    not _switchback_requirements_met(),
    reason="SwitchBackLinear test requires a CUDA device and Triton",
)
@pytest.mark.parametrize("vectorrize", [False, True])
def test_switchback(vectorrize):
    """Compare SwitchBackLinear against Linear8bitLt, using fp16 nn.Linear as reference.

    For several (dim, batch) combinations the three layers are given the same
    weights, bias and input. The forward output, the bias gradient, the weight
    gradient and the input gradient of SwitchBackLinear must each have a mean
    absolute error (measured against the ``torch.nn.Linear`` reference) that is
    less than twice the error of Linear8bitLt.

    Args:
        vectorrize: Forwarded to ``SwitchBackLinear`` as its ``vectorize``
            keyword argument.
    """
    for dim in [83, 17, 128]:
        for batch in [13, 128, 256]:
            # Construction order is kept as-is: each layer's random
            # initialisation consumes the global RNG.
            standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
            print('vectorrize', vectorrize)
            switchback = SwitchBackLinear(dim, 4 * dim, vectorize=vectorrize).cuda().half()
            baseline = Linear8bitLt(dim, 4 * dim).cuda().half()

            # Give both layers under test the reference parameters so that
            # any difference comes from the layer implementation alone.
            for layer in (switchback, baseline):
                layer.weight.data.copy_(standard.weight)
                layer.bias.data.copy_(standard.bias)

            # One independent leaf tensor per layer so that input gradients
            # are accumulated separately.
            x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True)
            x2 = x1.clone().detach().requires_grad_(True)
            x3 = x1.clone().detach().requires_grad_(True)

            out_standard = standard(x1)
            (2**10 * out_standard.abs().mean()).backward()

            out_sb = switchback(x2)
            (2**10 * out_sb.abs().mean()).backward()

            out_baseline = baseline(x3)
            (2**10 * out_baseline.abs().mean()).backward()

            # (printed label, reference, SwitchBack value, baseline value)
            checks = (
                ('OUT', out_standard, out_sb, out_baseline),
                ('GW2', standard.bias.grad, switchback.bias.grad, baseline.bias.grad),
                ('GW1', standard.weight.grad, switchback.weight.grad, baseline.weight.grad),
                ('GX1', x1.grad, x2.grad, x3.grad),
            )
            for label, reference, sb_value, baseline_value in checks:
                err_sb = _mean_abs_diff(reference, sb_value)
                err_baseline = _mean_abs_diff(reference, baseline_value)
                print(label, err_sb, err_baseline)
                assert err_sb < 2 * err_baseline, (
                    f"{label}: SwitchBack error {err_sb} is not below twice the "
                    f"baseline error {err_baseline} (dim={dim}, batch={batch})"
                )
Tim Dettmers's avatar
Tim Dettmers committed
57