from math import prod

import pytest
import torch

import bitsandbytes
from bitsandbytes.cextension import HIP_ENVIRONMENT
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu

# torch.library.opcheck is only available in torch 2.4 and later.
# When testing with older versions, we replace it with a no-op.
if torch.__version__ >= (2, 4):
    opcheck = torch.library.opcheck
else:
    opcheck = lambda *args, **kwargs: None


class TestLLMInt8Ops:
    @pytest.mark.parametrize("device", get_available_devices())
    def test_int8_linear_matmul(self, device):
        A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
        B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
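        # int8 GEMM of A against B^T: (10, 20) x (30, 20)^T -> (10, 30), accumulated in int32.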
        out = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)

        assert out.shape == (10, 30)
        assert out.dtype == torch.int32
        assert out.device == A.device

        opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B))

    @pytest.mark.parametrize("device", get_available_devices())
    def test_int8_linear_matmul_out(self, device):
        A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
        B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)

        out = torch.empty((10, 30), dtype=torch.int32, device=device)
        torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out)

        assert out.shape == (10, 30)
        assert out.dtype == torch.int32
        assert out.device == A.device

        opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out))

    @pytest.mark.parametrize("threshold", [0.0, 6.0])
    @pytest.mark.parametrize("device", get_available_devices())
    def test_int8_vectorwise_quant(self, threshold, device):
        A = torch.randn(10, 20, dtype=torch.float16, device=device)
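        # Plant a large outlier so outlier column detection has something to report when threshold > 0.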
        A[1][0] = 1000.0

        out_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant(A, threshold=threshold)

        assert out_row.shape == (10, 20)
        assert out_row.dtype == torch.int8
        assert out_row.device == A.device
        assert row_stats.shape == (10,)
        assert row_stats.dtype == torch.float32
        assert row_stats.device == A.device

        if threshold > 0.0:
            assert outlier_cols is not None
            assert outlier_cols.dim() == 1
            assert outlier_cols.shape[0] <= A.shape[1]
            assert outlier_cols.device == A.device
        else:
            assert outlier_cols is None

        opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A,))
        opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold))

    @pytest.mark.parametrize("device", get_available_devices())
    def test_int8_mm_dequant(self, device):
        A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device)
        row_stats = torch.randn(256, dtype=torch.float32, device=device)
        col_stats = torch.randn(256, dtype=torch.float32, device=device)
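        # Dequantize the int32 accumulator with the per-row and per-column scales; default output dtype is float16.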
        out = torch.ops.bitsandbytes.int8_mm_dequant(A, row_stats, col_stats)

        assert out.shape == A.shape
        assert out.dtype == torch.float16
        assert out.device == A.device

        opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("has_bias", TRUE_FALSE)
    def test_int8_scaled_mm(self, device, dtype, has_bias):
        A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
        B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
        row_stats = torch.randn(10, dtype=torch.float32, device=device)
        col_stats = torch.randn(30, dtype=torch.float32, device=device)
        bias = torch.randn(30, dtype=dtype, device=device) if has_bias else None
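        # Fused int8 matmul + dequantization with row/column scales, optional bias, and caller-chosen output dtype.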
        out = torch.ops.bitsandbytes.int8_scaled_mm(A, B, row_stats, col_stats, bias=bias, dtype=dtype)

        assert out.shape == (10, 30)
        assert out.dtype == dtype
        assert out.device == A.device

        opcheck(torch.ops.bitsandbytes.int8_scaled_mm, (A, B, row_stats, col_stats, bias, dtype))


class TestInt8BlockwiseQuantOps:
    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
    def test_quantize_blockwise(self, device, dtype, blocksize):
        if device == "cpu":
            if dtype != torch.float32:
                pytest.skip("CPU implementation is only available for float32")

            if blocksize != 256:
                pytest.skip("CPU implementation is slow; only test blocksize=256")

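        # The dynamic 8-bit codebook (256 values) is the lookup table for blockwise quantization.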
        code = bitsandbytes.functional.create_dynamic_map().to(device)
        A = torch.randn(1024, 1024, dtype=dtype, device=device)
        out, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize)

        assert out.shape == A.shape
        assert out.dtype == torch.uint8
        assert out.device == A.device

        assert absmax.device == A.device
        assert absmax.dtype == torch.float32

        opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
    def test_dequantize_blockwise(self, device, dtype, blocksize):
        if device == "cpu" and dtype != torch.float32:
            pytest.skip("CPU implementation is only available for float32")

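        # Random codes and absmax values: this checks shape/dtype/device plumbing, not numerical accuracy.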
        A = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8, device=device)
        code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)

        n = A.numel()
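        # Ceiling division: one absmax scale per blocksize-element block.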
        blocks = -(n // -blocksize)
        absmax = torch.randn((blocks,), device=device, dtype=torch.float32)

        out = torch.ops.bitsandbytes.dequantize_blockwise.default(A, absmax, code, blocksize, dtype)

        assert out.shape == A.shape
        assert out.dtype == dtype
        assert out.device == A.device

        opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))


class Test4bitBlockwiseQuantOps:
    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
    def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
        if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
            pytest.skip("This configuration is not supported on HPU.")

        A = torch.randn(1024, 1024, dtype=dtype, device=device)

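        # quantize_4bit packs two 4-bit values per byte, returning the packed data in the
        # requested storage dtype together with per-block absmax scales.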
        out, absmax = torch.ops.bitsandbytes.quantize_4bit.default(A, blocksize, quant_type, storage_dtype)

        assert out.device == A.device
        assert out.dtype == storage_dtype

        assert absmax.device == A.device
        assert absmax.dtype == torch.float32

        if storage_dtype != torch.uint8:
            pytest.xfail("opcheck fails for storage_dtype != torch.uint8")

        opcheck(torch.ops.bitsandbytes.quantize_4bit.default, (A, blocksize, quant_type, storage_dtype))

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
    def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
        if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
            pytest.skip("This configuration is not supported on HPU.")

        shape = (128, 128)

        n = prod(shape)
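        # Ceiling division: one absmax scale per blocksize-element block.
        # 4-bit packing stores two values per byte, so the buffer holds (n + 1) // 2 bytes,
        # reinterpreted as a single column of `storage_dtype` elements.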
        blocks = -(n // -blocksize)
        quantized_shape = ((n + 1) // (storage_dtype.itemsize * 2), 1)

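        # Random packed bytes suffice here, since only shape/dtype/device plumbing is checked.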
        A = (
            torch.randint(0, 255, ((n + 1) // 2,), dtype=torch.uint8, device=device)
            .view(storage_dtype)
            .reshape(quantized_shape)
            .contiguous()
        )

        absmax = torch.randn((blocks,), dtype=torch.float32, device=device)

        out = torch.ops.bitsandbytes.dequantize_4bit.default(A, absmax, blocksize, quant_type, shape, dtype)

        assert out.device == A.device
        assert out.shape == shape

        opcheck(
            torch.ops.bitsandbytes.dequantize_4bit.default,
            (A, absmax, blocksize, quant_type, shape, dtype),
        )

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
    @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
    def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
        if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
            pytest.skip("This configuration is not supported on HPU.")

        out_features = 1024
        in_features = 256

        A = torch.randn((1, 1, in_features), dtype=dtype, device=device)
        B = torch.randn((out_features, in_features), dtype=dtype, device=A.device)
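        # Quantize B to 4-bit and fetch the 16-entry fp4/nf4 codebook used to dequantize it inside the kernel.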
        B_q, absmax = torch.ops.bitsandbytes.quantize_4bit(B, blocksize, quant_type, storage_dtype)
        code = bitsandbytes.functional.get_4bit_type(quant_type, device=A.device, blocksize=blocksize)

        out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)

        assert out.device == A.device
        assert out.dtype == dtype
        assert out.shape == (1, 1, out_features)
        assert out.isreal().all()

        opcheck(torch.ops.bitsandbytes.gemv_4bit.default, (A, B_q, B.shape, absmax, code, blocksize))