exllama.py
from text_generation_server.layers.gptq import GPTQWeight
import torch
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")


def ext_make_q4(qweight, qzeros, scales, g_idx, device):
    """Construct Q4Matrix, return handle"""
    return make_q4(
        qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device
    )


def ext_q4_matmul(x, q4, q4_width):
    """Matrix multiplication, returns x @ q4"""
    outshape = x.shape[:-1] + (q4_width,)
    x = x.view(-1, x.shape[-1])
    output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)

    q4_matmul(x, q4, output)

    return output.view(outshape)
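
# Shape note (illustration, not part of the original kernels): for x of shape
# (batch, seq_len, in_features) and q4_width == out_features, ext_q4_matmul returns
# a tensor of shape (batch, seq_len, out_features); the kernel itself only ever sees
# the flattened 2-D view of shape (batch * seq_len, in_features).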


# Global state shared by all Ex4bitLinear layers: the maxima below are raised as each
# layer is constructed and are consumed by create_exllama_buffers().
MAX_DQ = 1
MAX_INNER = 1
ACT_ORDER = False
DEVICE = None

TEMP_STATE = None
TEMP_DQ = None


def set_device(device):
    global DEVICE
    DEVICE = device


def create_exllama_buffers(max_total_tokens: int):
    global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ

    assert DEVICE is not None, "call set_device first"

    if not ACT_ORDER:
        max_total_tokens = 1

    # This temp_state buffer is required to reorder X in the act-order case.
    temp_state = torch.zeros(
        (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
    )
    # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
    temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)

    prepare_buffers(DEVICE, temp_state, temp_dq)

    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

    TEMP_STATE, TEMP_DQ = temp_state, temp_dq
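
# Expected call order (a sketch inferred from the globals above, not part of the
# original module): every Ex4bitLinear should be constructed before the shared
# buffers are sized, since __init__ raises MAX_DQ / MAX_INNER / ACT_ORDER to match
# the largest layer. A hypothetical caller, with `quantized_weights` standing in for
# a list of (GPTQWeight, bias) pairs, would therefore do roughly:
#
#   set_device(torch.device("cuda:0"))
#   layers = [Ex4bitLinear(w, b) for w, b in quantized_weights]
#   create_exllama_buffers(max_total_tokens=4096)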


class Ex4bitLinear(torch.nn.Module):
    """Linear layer implementation with per-group 4-bit quantization of the weights"""

    def __init__(self, weight: GPTQWeight, bias):
        super().__init__()
        global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
        assert weight.bits == 4

        self.device = weight.qweight.device
        self.qweight = weight.qweight
        self.qzeros = weight.qzeros
        self.scales = weight.scales
        self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None
        self.bias = bias

        if self.g_idx is not None and (
            (self.g_idx == 0).all()
            or torch.equal(
                weight.g_idx.cpu(),
                torch.tensor(
                    [i // weight.groupsize for i in range(weight.g_idx.shape[0])],
                    dtype=torch.int32,
                ),
            )
        ):
            self.empty_g_idx = True
            self.g_idx = None

        assert self.device.type == "cuda"
        assert self.device.index is not None

        self.q4 = ext_make_q4(
            self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
        )

        # qweight packs eight 4-bit values per int32 along dim 0, so the unpacked
        # height (input features) is shape[0] * 8.
        self.height = weight.qweight.shape[0] * 8
        self.width = weight.qweight.shape[1]

        # Infer groupsize from height of qzeros
        self.groupsize = None
        if self.qzeros.shape[0] > 1:
            self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])

        if self.groupsize is not None:
            assert weight.groupsize == self.groupsize

        # Handle act-order matrix
        if self.g_idx is not None:
            if self.groupsize is None:
                raise ValueError(
                    "Found a group index (g_idx) but could not infer a groupsize from qzeros."
                )
            self.act_order = True
        else:
            self.act_order = False

        # Track the per-layer maxima so the shared buffers can be sized for the
        # largest quantized layer.
        DEVICE = self.qweight.device

        MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)

        if self.act_order:
            MAX_INNER = max(MAX_INNER, self.height, self.width)

            ACT_ORDER = True

    def forward(self, x):
        out = ext_q4_matmul(x, self.q4, self.width)

        if self.bias is not None:
            out.add_(self.bias)
        return out
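
# Minimal usage sketch (hypothetical; assumes `weight` is a GPTQWeight whose 4-bit
# tensors already live on a CUDA device, e.g. as produced by the surrounding
# text_generation_server GPTQ loader):
#
#   layer = Ex4bitLinear(weight, bias=None)
#   create_exllama_buffers(max_total_tokens=2048)   # after all layers are built
#   x = torch.randn(4, layer.height, dtype=torch.float16, device=layer.device)
#   y = layer(x)                                    # shape: (4, layer.width)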