aqlm.py 13.2 KB
Newer Older
James Fleming's avatar
James Fleming committed
1
2
3
4
5
6
7
8
9
10
# Supports AQLM compression, see https://github.com/Vahe1994/AQLM
# and https://arxiv.org/pdf/2401.06118.pdf

import math
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

11
from vllm import _custom_ops as ops
12
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
James Fleming's avatar
James Fleming committed
13
14
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
15
from vllm.model_executor.utils import set_weight_attrs
James Fleming's avatar
James Fleming committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97


def get_int_dtype(nbits: int) -> torch.dtype:
    """Return the narrowest signed torch integer dtype with at least ``nbits``.

    Candidates are tried from narrowest (int8) to widest (int64).

    Raises:
        ValueError: if ``nbits`` is larger than 64.
    """
    candidates = (
        (8, torch.int8),
        (16, torch.int16),
        (32, torch.int32),
        (64, torch.int64),
    )
    for max_bits, dtype in candidates:
        if nbits <= max_bits:
            return dtype
    raise ValueError(f"No dtype available for {nbits}-bit codebooks")


@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
    """Reinterpret signed stored codes as unsigned ``nbits``-bit values.

    Codes live in signed integer tensors (see ``get_int_dtype``), so a code
    larger than the signed maximum reads back negative. Widening to int64 and
    reducing modulo ``2**nbits`` recovers the unsigned value.
    """
    widened = data.to(torch.int64)
    modulus = 2**nbits
    return widened % modulus


def dequantize_weight(codes: torch.Tensor,
                      codebooks: torch.Tensor,
                      scales: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Decode float weights from quantization codes. Differentiable.

    Each (out_group, in_group) cell of the weight is the sum, over all
    codebooks, of the codebook vector selected by that cell's code.

    :param codes: tensor of integer quantization codes, shape
        [*dims, num_out_groups, num_in_groups, num_codebooks]; each code must
        be a non-negative index smaller than codebook_size
    :param codebooks: tensor of vectors for each quantization code,
        [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: weight will be multiplied by this factor, must be
        broadcastable with
        [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
    :return: reconstructed weight tensor of shape
        [*dims, num_out_groups * out_group_size,
         num_in_groups * in_group_size]
    :raises ValueError: if the last dimension of ``codes`` differs from the
        number of codebooks
    """
    num_out_groups, num_in_groups, codes_per_group = codes.shape[-3:]
    num_codebooks, codebook_size, out_group_size, in_group_size = \
        codebooks.shape
    if codes_per_group != num_codebooks:
        # Without this check a size-1 trailing dimension would silently
        # broadcast against ``codebook_offsets`` and reuse the same code for
        # every codebook.
        raise ValueError(
            f"codes has {codes_per_group} codes per group but codebooks "
            f"holds {num_codebooks} codebooks")
    out_features = num_out_groups * out_group_size
    in_features = num_in_groups * in_group_size
    # All codebooks are flattened into one embedding table below, so shift
    # codebook k's codes by k * codebook_size to index its own rows.
    codebook_offsets = torch.arange(
        0, num_codebooks * codebook_size, codebook_size,
        device=codes.device)  # shape: [num_codebooks]
    # Each row of codes is one "bag"; mode="sum" adds the selected vectors of
    # all codebooks together.
    reconstructed_weight_flat = F.embedding_bag(
        codes.flatten(0, -2) + codebook_offsets,
        codebooks.flatten(0, 1).flatten(-2, -1),
        mode="sum"
    )  # [prod(dims) * num_out_groups * num_in_groups, out_group_size
    # * in_group_size]

    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
        list(codes.shape[:-3]) +
        [num_out_groups, num_in_groups, out_group_size, in_group_size])
    if scales is not None:
        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(
            scales)
    # Interleave group axes with in-group axes:
    # [..., nog, nig, ogs, igs] -> [..., nog, ogs, nig, igs] -> [..., out, in]
    return reconstructed_weight_groupwise.swapaxes(
        -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])


def dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,  # shape documented in the docstring
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Decode one AQLM weight shard and apply it as a linear layer.

    ``codebooks`` has shape
    ``[num_codebooks, codebook_size, out_group_size, in_group_size]``.
    The code bit width is taken as ``codebook_size.bit_length() - 1``, which
    is log2 of the size when it is a power of two (codebooks allocated by
    ``AQLMLinearMethod.create_weights`` have ``2**nbits`` entries).

    Returns ``F.linear(input, W, bias)`` for the decoded, scaled weight ``W``.
    """
    code_bits = codebooks.shape[1].bit_length() - 1
    unsigned_codes = unpack_int_data(codes, code_bits)
    weight = dequantize_weight(unsigned_codes, codebooks, scales)
    return F.linear(input, weight, bias)


# Generic dequantization, slow but flexible.
def generic_dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    # [num_partitions * num_codebooks, codebook_size, out_group_size,
    #  in_group_size]
    codebooks: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: List[int],
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Dequantize-and-multiply fallback that works for any AQLM format.

    The output features are split into consecutive partitions (presumably
    the shards of a merged layer such as QKV — confirm against callers).
    Each partition owns an equal, consecutive slice of ``codebooks`` along
    dim 0. Each partition is dequantized and multiplied separately, and the
    results are copied into one output tensor.

    Args:
        input: activations, last dimension is ``in_features``.
        codes: packed codes for all partitions, concatenated along dim 0.
        codebooks: codebooks for all partitions, concatenated along dim 0.
        scales: per-output-group scales, concatenated along dim 0.
        output_partition_sizes: size of each output partition. An empty
            list is treated as a single partition covering every output.
        bias: optional bias of length ``sum(output_partition_sizes)``.

    Returns:
        Tensor of shape ``[..., scales.shape[0]]`` with ``input``'s dtype.
    """
    if not output_partition_sizes:
        # ``apply`` passes [] when the codebooks carry no partition metadata;
        # that means the layer is not partitioned, so use one partition
        # instead of dividing by zero below.
        output_partition_sizes = [scales.shape[0]]

    output_shape = input.shape[:-1] + (scales.shape[0], )
    output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
    num_outputs = len(output_partition_sizes)

    # Split the inputs and codebooks per partition, then combine the outputs.
    # Empirically, this is faster than doing one dequantization per partition
    # followed by a single large multiply at the end.
    assert codebooks.shape[0] % num_outputs == 0
    num_codebooks = codebooks.shape[0] // num_outputs
    assert scales.shape[0] == codes.shape[0]
    assert sum(output_partition_sizes) == scales.shape[0]
    output_offset = 0
    codebooks_offset = 0
    for output_size in output_partition_sizes:
        shard_bias = (None if bias is None else bias.narrow(
            0, output_offset, output_size))
        shard_output = dequantize_gemm(
            input,
            codes.narrow(0, output_offset, output_size),
            codebooks.narrow(0, codebooks_offset, num_codebooks),
            scales.narrow(0, output_offset, output_size),
            shard_bias,
        )

        output_slice = output.narrow(-1, output_offset, output_size)
        assert output_slice.shape == shard_output.shape
        output_slice.copy_(shard_output)
        output_offset += output_size
        codebooks_offset += num_codebooks
    return output


# Optimized dequantize/decompression kernels. Supports the 1x16 and 2x8
# formats, running 6 and 9 times faster than the generic version above,
# respectively.
def optimized_dequantize_gemm(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: List[int],
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Decode the weight with ``ops.aqlm_dequant``, then run a dense GEMM.

    ``ops.aqlm_dequant`` returns an unscaled weight used as
    ``[out_features, in_features]`` by ``F.linear``. The per-output scales
    are applied in one of two ways:

    * no bias: the GEMM output columns are scaled in place, which is the
      cheaper option;
    * with bias: the scales are folded into the weight rows first, because
      scaling the output afterwards would also scale the bias.
    """
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    if bias is not None:
        # [num_out_groups, 1, 1, 1] -> [num_out_groups, 1] -> one row each.
        row_scales = scales.view(scales.shape[:-3] + (-1, ))
        weights *= row_scales.expand(-1, weights.shape[1])
        return F.linear(input, weights, bias)

    product = F.linear(input, weights, bias)
    product_shape = product.shape
    # Collapse leading dims so every row shares the same column scales.
    rows = product.view(-1, product.size(-1))
    column_scales = scales.view(-1, scales.shape[0])  # [1, num_out_groups]
    rows *= column_scales.expand(rows.shape[0], -1)
    return product.view(product_shape)


class AQLMConfig(QuantizationConfig):
    """Config class for AQLM.

    Reference: https://github.com/Vahe1994/AQLM
    """

    def __init__(
        self,
        in_group_size: int,
        nbits_per_codebook: int,
        num_codebooks: int,
        out_group_size: int,
    ) -> None:
        """Store the AQLM hyper-parameters.

        Raises:
            ValueError: if ``out_group_size`` is not 1 (the only supported
                value).
        """
        self.in_group_size = in_group_size
        self.nbits_per_codebook = nbits_per_codebook
        self.num_codebooks = num_codebooks
        self.out_group_size = out_group_size

        # out_group_size > 1 is untested, and probably won't work as-is.
        # Raise instead of asserting so the check survives ``python -O``.
        if self.out_group_size != 1:
            raise ValueError(
                "AQLM only supports out_group_size == 1, got "
                f"{self.out_group_size}.")
        # Number of weights covered by one code vector.
        self.pack_factor = (self.in_group_size * self.out_group_size)

    def __repr__(self) -> str:
        return (f"AQLMConfig(in_group_size={self.in_group_size}, "
                f"nbits_per_codebook={self.nbits_per_codebook}, "
                f"num_codebooks={self.num_codebooks}, "
                f"out_group_size={self.out_group_size})")

    @classmethod
    def get_name(cls) -> str:
        return "aqlm"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        # Only fp16 activations are supported.
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Presumably GPU compute capability 6.0 encoded as major * 10 + minor
        # — confirm against QuantizationConfig.
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
        """Build a config from the model's quantization config dict.

        Reads the keys ``in_group_size``, ``nbits_per_codebook``,
        ``num_codebooks`` and ``out_group_size``.
        """
        in_group_size = cls.get_from_keys(config, ["in_group_size"])
        nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
        num_codebooks = cls.get_from_keys(config, ["num_codebooks"])
        out_group_size = cls.get_from_keys(config, ["out_group_size"])
        return cls(in_group_size, nbits_per_codebook, num_codebooks,
                   out_group_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["AQLMLinearMethod"]:
        """Return an ``AQLMLinearMethod`` for linear layers, else None."""
        if isinstance(layer, LinearBase):
            return AQLMLinearMethod(self)
        return None
James Fleming's avatar
James Fleming committed
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287


class AQLMLinearMethod(LinearMethodBase):
    """Linear method for AQLM.

    Args:
        quant_config: The AQLM quantization config.
    """

    def __init__(self, quant_config: AQLMConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate and register the AQLM parameters on ``layer``.

        Registers, in this order:

        * ``codes``: ``[sum(output_partition_sizes),
          input_size_per_partition // pack_factor, num_codebooks]`` in the
          integer dtype from ``get_int_dtype(nbits_per_codebook)``;
        * ``codebooks``: ``[num_codebooks * len(output_partition_sizes),
          2**nbits_per_codebook, out_group_size, in_group_size]``, one set
          of codebooks per output partition;
        * ``scales``: ``[sum(output_partition_sizes) // out_group_size,
          1, 1, 1]``.

        Raises:
            ValueError: if ``params_dtype`` is not ``torch.half``, or if the
                input/output sizes are not multiples of the group sizes.
        """
        del output_size  # Unused.
        del input_size  # Unused.
        cfg = self.quant_config

        if params_dtype != torch.half:
            raise ValueError("Only half is currently supported by aqlm")
        if input_size_per_partition % cfg.in_group_size:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % cfg.out_group_size:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        # Only dim 1 (input) is packed, by pack_factor. Supporting
        # out_group_size > 1 would need a second pack factor along the output.
        # NOTE(review): the earlier comment said the output-side factor is the
        # one QKVLinear needs marked as "packed_dim", yet packed_dim is 1
        # (the input dim) here — confirm which is intended.
        code_dtype = get_int_dtype(cfg.nbits_per_codebook)
        codes = Parameter(
            torch.empty(
                output_size_per_partition,
                input_size_per_partition // cfg.pack_factor,
                cfg.num_codebooks,
                dtype=code_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(codes, {
            "input_dim": 1,
            "output_dim": 0,
            "packed_dim": 1,
            "pack_factor": cfg.pack_factor,
        })

        num_partitions = len(output_partition_sizes)
        codebooks = Parameter(
            torch.empty(
                cfg.num_codebooks * num_partitions,
                2**cfg.nbits_per_codebook,
                cfg.out_group_size,
                cfg.in_group_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(codebooks, {
            # metadata indicates fixed size concatenated along dim 0
            "is_metadata": True,
            "output_partition_sizes": output_partition_sizes,
        })

        num_out_groups = output_size_per_partition // cfg.out_group_size
        scales = Parameter(
            torch.empty((num_out_groups, 1, 1, 1), dtype=params_dtype),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "output_dim": 0,
            "packed_dim": 0,
            "pack_factor": cfg.out_group_size,
        })

        named_params = (
            ("codes", codes),
            ("codebooks", codebooks),
            ("scales", scales),
        )
        for param_name, param in named_params:
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the AQLM linear layer for activations ``x``.

        The 1x16 (one 65536-entry codebook) and 2x8 (two 256-entry codebooks)
        formats with in_group_size 8 and out_group_size 1 use dedicated
        kernels. Every other format uses ``generic_dequantize_gemm``.
        """
        codebooks = layer.codebooks
        codes = layer.codes
        scales = layer.scales
        output_partition_sizes = getattr(codebooks, "output_partition_sizes",
                                         [])

        num_codebooks = codes.shape[2]
        codebook_size = codebooks.shape[1]
        out_group_size = codebooks.shape[2]
        in_group_size = codebooks.shape[3]

        has_dedicated_kernel = (
            in_group_size == 8 and out_group_size == 1
            and (codebook_size, num_codebooks) in ((256, 2), (65536, 1)))

        if not has_dedicated_kernel:
            # fall back all unoptimized formats
            return generic_dequantize_gemm(
                x,
                codes,
                codebooks,
                scales,
                output_partition_sizes,
                bias,
            )

        # thresholds determined by timings on an A6000, one GPU
        num_tokens = math.prod(x.shape[:-1])
        if num_tokens <= 6:
            return ops.aqlm_gemm(
                x,
                codes,
                codebooks,
                scales,
                output_partition_sizes,
                bias,
            )
        return optimized_dequantize_gemm(
            x,
            codes,
            codebooks,
            scales,
            output_partition_sizes,
            bias,
        )