quant_utils.py
import torch
from typing import List


Q_BITS = 4
STORAGE_BITS = 32
PACK_NUM = STORAGE_BITS // Q_BITS

ORDINAL_PACK_ORDER = [0, 1, 2, 3, 4, 5, 6, 7]
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
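# AWQ stores the eight 4-bit values of each packed int32 in the interleaved
# order given by AWQ_PACK_ORDER; REVERSE_AWQ_PACK_ORDER is the inverse
# permutation that undoes it (both are meant to be used with apply_order()).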


def pack(imatrix: torch.Tensor, direction: str = "column"):
    """
    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.

    Args:
        imatrix (torch.Tensor): matrix of integers
        direction (str): direction of packing, either "column" or "row"

    Returns:
        qmatrix (torch.Tensor): packed matrix of integers
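
    Note:
        The first value of each group of PACK_NUM elements is stored in the
        least-significant nibble of the packed int32, and the last value in
        the most-significant nibble.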
    """
    shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=imatrix.device)

    imatrix = imatrix.to(torch.int8)
    imatrix = torch.bitwise_and(imatrix, 0x0F)  # mask to 4 bits to correct any overflow

    if direction == "column":
        imatrix = imatrix.view(-1, imatrix.shape[1] // PACK_NUM, PACK_NUM)
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)

    elif direction == "row":
        imatrix = imatrix.view(imatrix.shape[0] // PACK_NUM, PACK_NUM, -1)
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)

    qmatrix = qmatrix.to(torch.int32)

    return qmatrix


def unpack(qmatrix: torch.Tensor, direction: str = "column"):
    """
    Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.

    Args:
        qmatrix (torch.Tensor): matrix of packed integers
        direction (str): direction of unpacking, either "column" or "row"

    Returns:
        imatrix (torch.Tensor): matrix of integers
    """
    shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=qmatrix.device)

    if direction == "column":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, :, None], shifts[None, None, :]
        ).view(qmatrix.shape[0], -1)

    elif direction == "row":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, None, :], shifts[None, :, None]
        ).view(-1, qmatrix.shape[-1])

    imatrix = imatrix.to(torch.int8) & 0x0F  # mask to 4 bits to correct any overflow

    return imatrix
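
# Illustrative round trip (shapes assumed): eight 4-bit values per row pack into
# a single int32 column, and unpacking recovers the original nibbles.
#
#   imatrix = torch.randint(0, 16, (2, 8), dtype=torch.int8)
#   qmatrix = pack(imatrix, direction="column")   # shape (2, 1), dtype torch.int32
#   assert torch.equal(unpack(qmatrix, direction="column"), imatrix)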


def quantize(fmatrix: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int):
    """
    Quantizes a matrix of 16-bit floats into a matrix of 4-bit integers.

    Args:
        fmatrix (torch.Tensor): matrix of 16-bit floats
        scales (torch.Tensor): matrix of 16-bit floats
        zeros (torch.Tensor): matrix of 4-bit integers
        group_size (int): group size

    Returns:
        imatrix (torch.Tensor): matrix of 4-bit integers
    """
    zeros = zeros.to(torch.int8) & 0x0F

    imatrix = torch.round(
        (
            fmatrix / scales.repeat_interleave(group_size, dim=0)
            + zeros.repeat_interleave(group_size, dim=0)
        )
    )

    imatrix = imatrix.to(torch.int8) & 0x0F

    return imatrix
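
# Per element, quantize computes round(x / scale + zero) and keeps only the low
# 4 bits; the scale and zero for row i come from group i // group_size.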


def dequantize(imatrix: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, group_size: int):
    """
    Dequantizes a 4-bit integer matrix into a float matrix.

    Args:
        imatrix (torch.Tensor): matrix of 4-bit integers
        scales (torch.Tensor): matrix of 16-bit floats
        zeros (torch.Tensor): matrix of 4-bit integers
        group_size (int): group size

    Returns:
        fmatrix (torch.Tensor): matrix of 16-bit floats
    """
    zeros = zeros.to(torch.int8) & 0x0F
    imatrix = imatrix.to(torch.int8) & 0x0F

    fmatrix = (
        imatrix - zeros.repeat_interleave(group_size, dim=0)
    ) * scales.repeat_interleave(group_size, dim=0)

    fmatrix = fmatrix.to(torch.float16)

    return fmatrix
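
# Dequantize applies the inverse mapping, (q - zero) * scale, broadcast over the
# same group_size-sized row blocks, and returns float16.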


def apply_order(
    imatrix: torch.Tensor,
    direction: str = "column",
    order: List[int] = ORDINAL_PACK_ORDER,
):
    """
    Applies the order to a 4-bit integer matrix.

    Args:
        imatrix (torch.Tensor): matrix of integers
        direction (str): direction of applying order, either "column" or "row"
        order (List[int]): order to apply, default is ordinal packing order

    Returns:
        imatrix (torch.Tensor): matrix of integers
    """
    if direction == "column":
        imatrix = imatrix.view(-1, PACK_NUM)[:, order].view(imatrix.shape)
    elif direction == "row":
        imatrix = imatrix.view(PACK_NUM, -1)[order, :].view(imatrix.shape)

    return imatrix
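
# AWQ_PACK_ORDER and REVERSE_AWQ_PACK_ORDER are inverse permutations, so applying
# one and then the other restores the original layout (assuming the reordered
# dimension is a multiple of PACK_NUM), e.g.:
#
#   reordered = apply_order(imatrix, order=AWQ_PACK_ORDER)
#   assert torch.equal(apply_order(reordered, order=REVERSE_AWQ_PACK_ORDER), imatrix)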


def awq_to_exllama(qweight: torch.Tensor, qzeros: torch.Tensor):
    # awq uses column packing for both weights and zeros
    izeros = unpack(qzeros, direction="column")
    iweights = unpack(qweight, direction="column")

    # Reverse the nibble order of the iweights and izeros tensors
    izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
    iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
    # Subtract 1 from the izeros tensor (exllama adds 1 during inference)
    izeros = izeros - 1
    # exllama uses row packing for weights and column packing for zeros
    qzeros = pack(izeros, direction="column")
    qweight = pack(iweights, direction="row")

    return qweight, qzeros
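

# Minimal round-trip sanity check (an illustrative sketch; the shapes, group
# size, and scale/zero choices below are assumptions, not part of the library).
if __name__ == "__main__":
    torch.manual_seed(0)
    group_size = 128

    # One quantization group of 128 rows per column.
    fmatrix = torch.randn(group_size, 8).to(torch.float16)
    scales = (fmatrix.abs().amax(dim=0, keepdim=True) / 7).to(torch.float16)
    zeros = torch.full((1, 8), 8, dtype=torch.int8)

    imatrix = quantize(fmatrix, scales, zeros, group_size)
    recovered = dequantize(imatrix, scales, zeros, group_size)
    print("max abs quantization error:", (fmatrix - recovered).abs().max().item())

    # Packing then unpacking should reproduce the 4-bit matrix exactly.
    qmatrix = pack(imatrix, direction="column")
    assert torch.equal(unpack(qmatrix, direction="column"), imatrix)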