from math import prod

import pytest
import torch

import bitsandbytes
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter

# torch.library.opcheck is only available in torch 2.4 and later.
# When testing with older versions, we will skip it as a no-op.
if torch.__version__ >= (2, 4):
    opcheck = torch.library.opcheck
else:
    opcheck = lambda *args, **kwargs: None


class TestLLMInt8Ops:
    @pytest.mark.parametrize("device", get_available_devices())
    def test_int8_linear_matmul(self, device):
        A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
        B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
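        # B is stored as (out_features, in_features), so the op computes A @ B.T with int32 accumulation.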
        out = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)

        assert out.shape == (10, 30)
        assert out.dtype == torch.int32
        assert out.device == A.device

        opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B))

    @pytest.mark.parametrize("device", get_available_devices())
    def test_int8_linear_matmul_out(self, device):
        A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
        B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)

        out = torch.empty((10, 30), dtype=torch.int32, device=device)
        torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out)

        assert out.shape == (10, 30)
        assert out.dtype == torch.int32
        assert out.device == A.device

        opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out))

    @pytest.mark.parametrize("threshold", [0.0, 6.0])
    @pytest.mark.parametrize("device", get_available_devices())
    def test_int8_vectorwise_quant(self, threshold, device):
        A = torch.randn(10, 20, dtype=torch.float16, device=device)
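        # Plant a large outlier so the threshold path has a column to flag.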
        A[1][0] = 1000.0

        out_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant(A, threshold=threshold)

        assert out_row.shape == (10, 20)
        assert out_row.dtype == torch.int8
        assert out_row.device == A.device
        assert row_stats.shape == (10,)
        assert row_stats.dtype == torch.float32
        assert row_stats.device == A.device

        if threshold > 0.0:
            assert outlier_cols is not None
            assert outlier_cols.dim() == 1
            assert outlier_cols.shape[0] <= A.shape[1]
            assert outlier_cols.device == A.device
        else:
            assert outlier_cols is None

        opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A,))
        opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold))

    @pytest.mark.parametrize("device", get_available_devices())
    def test_int8_mm_dequant(self, device):
        A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device)
        row_stats = torch.randn(256, dtype=torch.float32, device=device)
        col_stats = torch.randn(256, dtype=torch.float32, device=device)
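        # int8_mm_dequant rescales the int32 accumulator back to float16 using the per-row and per-column statistics.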
        out = torch.ops.bitsandbytes.int8_mm_dequant(A, row_stats, col_stats)

        assert out.shape == A.shape
        assert out.dtype == torch.float16
        assert out.device == A.device

        opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("has_bias", TRUE_FALSE)
    def test_int8_scaled_mm(self, device, dtype, has_bias):
        A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
        B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
        row_stats = torch.randn(10, dtype=torch.float32, device=device)
        col_stats = torch.randn(30, dtype=torch.float32, device=device)
        bias = torch.randn(30, dtype=dtype, device=device) if has_bias else None
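        # int8_scaled_mm fuses the int8 matmul with dequantization via the row/column scales and the optional bias, producing `dtype` output.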
        out = torch.ops.bitsandbytes.int8_scaled_mm(A, B, row_stats, col_stats, bias=bias, dtype=dtype)

        assert out.shape == (10, 30)
        assert out.dtype == dtype
        assert out.device == A.device

        opcheck(torch.ops.bitsandbytes.int8_scaled_mm, (A, B, row_stats, col_stats, bias, dtype))


class TestInt8BlockwiseQuantOps:
    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_quantize_blockwise(self, device, dtype, blocksize):
        if device == "cpu":
            if dtype != torch.float32:
                pytest.skip("CPU implementation is only available for float32")

            if blocksize != 256:
                pytest.skip("CPU implementation is slow; only test blocksize=256")

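        # create_dynamic_map builds the 256-entry codebook that each block is quantized against.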
        code = bitsandbytes.functional.create_dynamic_map().to(device)
        A = torch.randn(1024, 1024, dtype=dtype, device=device)
        out, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize)

        assert out.shape == A.shape
        assert out.dtype == torch.uint8
        assert out.device == A.device

        assert absmax.device == A.device
        assert absmax.dtype == torch.float32

        opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_dequantize_blockwise(self, device, dtype, blocksize):
        if device == "cpu" and dtype != torch.float32:
            pytest.skip("CPU implementation is only available for float32")

        A = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8, device=device)
        code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)

        n = A.numel()
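        # Ceiling division: the number of blocksize-sized blocks needed to cover all n elements.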
        blocks = -(n // -blocksize)
        absmax = torch.randn((blocks,), device=device, dtype=torch.float32)

        out = torch.ops.bitsandbytes.dequantize_blockwise.default(A, absmax, code, blocksize, dtype)

        assert out.shape == A.shape
        assert out.dtype == dtype
        assert out.device == A.device

        # TODO: Enable it
        if device == "xpu":
            pytest.skip("XPU implementation have torch.op inside torch.op, it will fail on op check")

        opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))


class Test4bitBlockwiseQuantOps:
    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
        A = torch.randn(1024, 1024, dtype=dtype, device=device)

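        # quantize_4bit packs two 4-bit codes per byte into `storage_dtype` and returns one absmax scale per block.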
        out, absmax = torch.ops.bitsandbytes.quantize_4bit.default(A, blocksize, quant_type, storage_dtype)

        assert out.device == A.device
        assert out.dtype == storage_dtype

        assert absmax.device == A.device
        assert absmax.dtype == torch.float32

        # TODO: Enable it
        if device in ("cpu", "xpu") and storage_dtype == torch.bfloat16:
            pytest.skip("CPU bf16 storage_dtype will fail on torch op check")

        opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
        shape = (128, 128)

        n = prod(shape)
        blocks = -(n // -blocksize)
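        # Two 4-bit codes are packed per byte, so n values occupy (n + 1) // 2 bytes; viewing that buffer as
        # `storage_dtype` divides the element count by its itemsize.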
        quantized_shape = ((n + 1) // (storage_dtype.itemsize * 2), 1)

        A = (
            torch.randint(0, 255, ((n + 1) // 2,), dtype=torch.uint8, device=device)
            .view(storage_dtype)
            .reshape(quantized_shape)
            .contiguous()
        )

        absmax = torch.randn((blocks,), dtype=torch.float32, device=device)

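        # dequantize_4bit unpacks the codes and rescales each block by its absmax, restoring the original `shape` in the requested dtype.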
        out = torch.ops.bitsandbytes.dequantize_4bit.default(A, absmax, blocksize, quant_type, shape, dtype)

        assert out.device == A.device
        assert out.shape == shape

        opcheck(
            torch.ops.bitsandbytes.dequantize_4bit.default,
            (A, absmax, blocksize, quant_type, shape, dtype),
        )

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
        out_features = 1024
        in_features = 256

        A = torch.randn((1, 1, in_features), dtype=dtype, device=device)
        B = torch.randn((out_features, in_features), dtype=dtype, device=A.device)
        B_q, absmax = torch.ops.bitsandbytes.quantize_4bit(B, blocksize, quant_type, storage_dtype)
        code = bitsandbytes.functional.get_4bit_type(quant_type, device=A.device, blocksize=blocksize)

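        # gemv_4bit multiplies the single-token activation A with the quantized weight, dequantizing B_q on the fly via `code` and `absmax`.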
        out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)

        assert out.device == A.device
        assert out.dtype == dtype
        assert out.shape == (1, 1, out_features)
        assert out.isreal().all()

        opcheck(torch.ops.bitsandbytes.gemv_4bit.default, (A, B_q, B.shape, absmax, code, blocksize))