"""test_ops.py: tests for the custom ops under ``torch.ops.bitsandbytes``.

Each test checks the metadata (shape, dtype, device) of an op's outputs and then
runs ``torch.library.opcheck`` on the same arguments.
"""

from math import prod

import pytest
import torch

import bitsandbytes
from tests.helpers import TRUE_FALSE, id_formatter


class TestLLMInt8Ops:
    """Smoke tests for the int8 ops exposed under ``torch.ops.bitsandbytes``.

    Every test calls the op once, checks shape/dtype/device of the result and then
    passes the same arguments to ``torch.library.opcheck``.
    """

    # "cuda" cases are reported as skipped (not as errors) on machines without a usable CUDA device.
    _DEVICES = [
        "cpu",
        pytest.param(
            "cuda",
            marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available"),
        ),
    ]

    @pytest.mark.parametrize("device", _DEVICES)
    def test_int8_linear_matmul(self, device):
        """``int8_linear_matmul`` on (10, 20) and (30, 20) int8 inputs yields a (10, 30) int32 tensor."""
        # torch.randint's upper bound is exclusive: 128 is required for the value 127 to be drawn.
        A = torch.randint(-128, 128, (10, 20), dtype=torch.int8, device=device)
        B = torch.randint(-128, 128, (30, 20), dtype=torch.int8, device=device)
        out = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)

        assert out.shape == (10, 30)
        assert out.dtype == torch.int32
        assert out.device == A.device

        torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B))

    @pytest.mark.parametrize("device", _DEVICES)
    def test_int8_linear_matmul_out(self, device):
        """The ``.out`` overload of ``int8_linear_matmul`` accepts a preallocated (10, 30) int32 tensor."""
        # Upper bound 128 (exclusive) so that the full int8 range is covered.
        A = torch.randint(-128, 128, (10, 20), dtype=torch.int8, device=device)
        B = torch.randint(-128, 128, (30, 20), dtype=torch.int8, device=device)

        out = torch.empty((10, 30), dtype=torch.int32, device=device)
        torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out)

        # The op must leave the metadata of the caller-provided buffer intact.
        assert out.shape == (10, 30)
        assert out.dtype == torch.int32
        assert out.device == A.device

        torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out))

    @pytest.mark.parametrize("threshold", [0.0, 6.0])
    @pytest.mark.parametrize("device", _DEVICES)
    def test_int8_vectorwise_quant(self, threshold, device):
        """``int8_vectorwise_quant`` returns int8 data, float32 row stats and optional outlier columns.

        With ``threshold == 0.0`` the third output is expected to be ``None``; with a
        positive threshold it is expected to be a 1-D tensor with at most one entry
        per column of the input.
        """
        if device == "cpu":
            pytest.skip("CPU implementation is not available")

        A = torch.randn(10, 20, dtype=torch.float16, device=device)
        # Plant one value far above the tested threshold of 6.0.
        A[1][0] = 1000.0

        out_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant(A, threshold=threshold)

        assert out_row.shape == (10, 20)
        assert out_row.dtype == torch.int8
        assert out_row.device == A.device
        assert row_stats.shape == (10,)
        assert row_stats.dtype == torch.float32
        assert row_stats.device == A.device

        if threshold > 0.0:
            assert outlier_cols is not None
            assert outlier_cols.dim() == 1
            assert outlier_cols.shape[0] <= A.shape[1]
            assert outlier_cols.device == A.device
        else:
            assert outlier_cols is None

        # opcheck is run both with the op's default threshold and with the parametrized one.
        torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A,))

        torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold))

    @pytest.mark.parametrize("device", _DEVICES)
    def test_int8_mm_dequant(self, device):
        """``int8_mm_dequant`` maps an int32 matrix plus row/column stats to float16 of the same shape."""
        # Upper bound 128 (exclusive) so that 127 can appear in the input.
        A = torch.randint(-128, 128, (256, 256), dtype=torch.int32, device=device)
        row_stats = torch.randn(256, dtype=torch.float32, device=device)
        col_stats = torch.randn(256, dtype=torch.float32, device=device)
        out = torch.ops.bitsandbytes.int8_mm_dequant(A, row_stats, col_stats)

        assert out.shape == A.shape
        assert out.dtype == torch.float16
        assert out.device == A.device

        torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))

    @pytest.mark.parametrize("device", _DEVICES)
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("has_bias", TRUE_FALSE)
    def test_int8_scaled_mm(self, device, dtype, has_bias):
        """``int8_scaled_mm`` returns a (10, 30) tensor of the requested ``dtype``, with or without bias."""
        # Upper bound 128 (exclusive) so that the full int8 range is covered.
        A = torch.randint(-128, 128, (10, 20), dtype=torch.int8, device=device)
        B = torch.randint(-128, 128, (30, 20), dtype=torch.int8, device=device)
        row_stats = torch.randn(10, dtype=torch.float32, device=device)
        col_stats = torch.randn(30, dtype=torch.float32, device=device)
        bias = torch.randn(30, dtype=dtype, device=device) if has_bias else None
        out = torch.ops.bitsandbytes.int8_scaled_mm(A, B, row_stats, col_stats, bias=bias, dtype=dtype)

        assert out.shape == (10, 30)
        assert out.dtype == dtype
        assert out.device == A.device

        torch.library.opcheck(torch.ops.bitsandbytes.int8_scaled_mm, (A, B, row_stats, col_stats, bias, dtype))


class TestInt8BlockwiseQuantOps:
    """Smoke tests for ``quantize_blockwise`` / ``dequantize_blockwise`` under ``torch.ops.bitsandbytes``.

    Each test checks the metadata of the op outputs and then runs ``torch.library.opcheck``
    on the same arguments.
    """

    # "cuda" cases are reported as skipped (not as errors) on machines without a usable CUDA device.
    _DEVICES = [
        "cpu",
        pytest.param(
            "cuda",
            marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available"),
        ),
    ]

    @pytest.mark.parametrize("device", _DEVICES)
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_quantize_blockwise(self, device, dtype, blocksize):
        """``quantize_blockwise`` returns uint8 data shaped like the input plus float32 absmax values."""
        if device == "cpu" and dtype != torch.float32:
            pytest.skip("CPU implementation is only available for float32")

        code = bitsandbytes.functional.create_dynamic_map().to(device)
        A = torch.randn(1024, 1024, dtype=dtype, device=device)
        out, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize)

        assert out.shape == A.shape
        assert out.dtype == torch.uint8
        assert out.device == A.device

        assert absmax.device == A.device
        assert absmax.dtype == torch.float32

        torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))

    @pytest.mark.parametrize("device", _DEVICES)
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_dequantize_blockwise(self, device, dtype, blocksize):
        """``dequantize_blockwise`` returns a tensor shaped like the uint8 input in the requested ``dtype``."""
        if device == "cpu" and dtype != torch.float32:
            pytest.skip("CPU implementation is only available for float32")

        # torch.randint's upper bound is exclusive: 256 is required for byte value 255 to be drawn.
        A = torch.randint(0, 256, (1024, 1024), dtype=torch.uint8, device=device)
        code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)

        n = A.numel()
        # Ceiling division: one absmax entry per (possibly partial) block.
        blocks = -(n // -blocksize)
        absmax = torch.randn((blocks,), device=device, dtype=torch.float32)

        out = torch.ops.bitsandbytes.dequantize_blockwise.default(A, absmax, code, blocksize, dtype)

        assert out.shape == A.shape
        assert out.dtype == dtype
        assert out.device == A.device

        torch.library.opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))


class Test4bitBlockwiseQuantOps:
    """Smoke tests for the 4-bit ops (``quantize_4bit``, ``dequantize_4bit``, ``gemv_4bit``).

    Each test checks the metadata of the op outputs and then runs ``torch.library.opcheck``
    on the same arguments.
    """

    # "cuda" cases are reported as skipped (not as errors) on machines without a usable CUDA device.
    _DEVICES = [
        "cpu",
        pytest.param(
            "cuda",
            marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available"),
        ),
    ]

    @pytest.mark.parametrize("device", _DEVICES)
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
        """``quantize_4bit`` returns data in ``storage_dtype`` plus float32 absmax values."""
        if device == "cpu" and quant_type != "nf4":
            pytest.skip("CPU implementation is only available for nf4")

        A = torch.randn(1024, 1024, dtype=dtype, device=device)

        out, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, quant_type, storage_dtype)

        assert out.device == A.device
        assert out.dtype == storage_dtype

        assert absmax.device == A.device
        assert absmax.dtype == torch.float32

        torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))

    @pytest.mark.parametrize("device", _DEVICES)
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
        """``dequantize_4bit`` returns a tensor of the requested ``shape`` and ``dtype``."""
        if device == "cpu":
            pytest.skip("CPU implementation is not available")

        shape = (128, 128)

        n = prod(shape)
        # Ceiling division: one absmax entry per (possibly partial) block.
        blocks = -(n // -blocksize)
        # Two 4-bit values per byte, viewed as a column of `storage_dtype` elements.
        quantized_shape = ((n + 1) // (storage_dtype.itemsize * 2), 1)

        # torch.randint's upper bound is exclusive: 256 is required for byte value 255 to be drawn.
        A = (
            torch.randint(0, 256, ((n + 1) // 2,), dtype=torch.uint8, device=device)
            .view(storage_dtype)
            .reshape(quantized_shape)
            .contiguous()
        )

        absmax = torch.randn((blocks,), dtype=torch.float32, device=device)

        out = torch.ops.bitsandbytes.dequantize_4bit.default(A, absmax, blocksize, quant_type, shape, dtype)

        assert out.device == A.device
        assert out.dtype == dtype
        assert out.shape == shape

        torch.library.opcheck(
            torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype)
        )

    @pytest.mark.parametrize("device", _DEVICES)
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
    def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
        """``gemv_4bit`` on a (1, 1, in_features) input returns a finite (1, 1, out_features) tensor."""
        if device == "cpu":
            pytest.skip("CPU implementation is not available")

        out_features = 1024
        in_features = 256

        A = torch.randn((1, 1, in_features), dtype=dtype, device=device)
        B = torch.randn((out_features, in_features), dtype=dtype, device=A.device)
        # Quantize a real weight matrix so that B_q and absmax are consistent with each other.
        B_q, absmax = torch.ops.bitsandbytes.quantize_4bit(B, blocksize, quant_type, storage_dtype)
        code = bitsandbytes.functional.get_4bit_type(quant_type, device=A.device, blocksize=blocksize)

        out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)

        assert out.device == A.device
        assert out.dtype == dtype
        assert out.shape == (1, 1, out_features)
        # isreal() is True for every element of a real-valued tensor (NaN/Inf included),
        # so check finiteness to actually catch a broken result.
        assert torch.isfinite(out).all()

        torch.library.opcheck(torch.ops.bitsandbytes.gemv_4bit.default, (A, B_q, B.shape, absmax, code, blocksize))