test_triton.py 2.52 KB
Newer Older
Tim Dettmers's avatar
Tim Dettmers committed
1
2
3
import pytest
import torch

Mitchell Wortsman's avatar
Mitchell Wortsman committed
4
from bitsandbytes.nn import Linear8bitLt
Aarni Koskela's avatar
Aarni Koskela committed
5
6
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
from bitsandbytes.triton.triton_utils import is_triton_available
Aarni Koskela's avatar
Aarni Koskela committed
7
from tests.helpers import TRUE_FALSE
Aarni Koskela's avatar
Aarni Koskela committed
8

Tim Dettmers's avatar
Tim Dettmers committed
9

Ruff's avatar
Ruff committed
10
11
12
13
@pytest.mark.skipif(
    not (
        is_triton_available()
        and torch.cuda.is_available()
        and torch.cuda.get_device_capability()[0] >= 8
    ),
    reason="This test requires triton and a GPU with compute capability 8.0 or higher.",
)
@pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE)
def test_switchback(vector_wise_quantization):
    """Check SwitchBackLinear against Linear8bitLt, using torch.nn.Linear as reference.

    All three layers share the weights and bias of a half-precision
    ``torch.nn.Linear`` and are fed the same random input. For the forward
    output, the bias gradient, the weight gradient and the input gradient,
    the mean absolute error of ``SwitchBackLinear`` relative to the reference
    must be less than twice the corresponding error of ``Linear8bitLt``.

    Args:
        vector_wise_quantization: Forwarded to ``SwitchBackLinear`` as the
            keyword argument of the same name; parametrized over
            ``TRUE_FALSE``.
    """

    def mean_abs_diff(expected, actual):
        # Mean absolute element-wise difference between two tensors.
        return (expected - actual).abs().mean()

    def backprop_scaled_loss(output):
        # Backpropagate the mean absolute output multiplied by 2**10.
        scaled_loss = 2**10 * output.abs().mean()
        scaled_loss.backward()

    for dim, batch in [(83, 13)]:
        out_features = 4 * dim

        # Build the three layers in the same order as before so that any
        # random initialisation is consumed identically.
        reference = torch.nn.Linear(dim, out_features).cuda().half()
        candidate = SwitchBackLinear(dim, out_features, vector_wise_quantization=vector_wise_quantization)
        candidate = candidate.cuda().half()
        int8_baseline = Linear8bitLt(dim, out_features).cuda().half()

        # Give both layers under comparison the reference parameters.
        for layer in (candidate, int8_baseline):
            layer.weight.data.copy_(reference.weight)
            layer.bias.data.copy_(reference.bias)

        # One random input, plus two detached copies so that every layer
        # accumulates its input gradient on its own tensor.
        x_ref = torch.randn(batch, dim).cuda().half().requires_grad_(True)
        x_sb = x_ref.clone().detach().requires_grad_(True)
        x_base = x_ref.clone().detach().requires_grad_(True)

        out_ref = reference(x_ref)
        backprop_scaled_loss(out_ref)

        print(x_sb.dtype)
        out_sb = candidate(x_sb)
        backprop_scaled_loss(out_sb)

        out_base = int8_baseline(x_base)
        backprop_scaled_loss(out_base)

        # (label, reference value, SwitchBack value, Linear8bitLt value),
        # checked in this order: output, bias grad, weight grad, input grad.
        comparisons = (
            ("OUT", out_ref, out_sb, out_base),
            ("GW2", reference.bias.grad, candidate.bias.grad, int8_baseline.bias.grad),
            ("GW1", reference.weight.grad, candidate.weight.grad, int8_baseline.weight.grad),
            ("GX1", x_ref.grad, x_sb.grad, x_base.grad),
        )
        for label, expected, from_switchback, from_baseline in comparisons:
            err_sb = mean_abs_diff(expected, from_switchback)
            err_baseline = mean_abs_diff(expected, from_baseline)
            print(label, err_sb, err_baseline)
            assert err_sb < 2 * err_baseline