exllama.py

from text_generation_server.layers.gptq import GPTQWeight
import torch
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")


def ext_make_q4(qweight, qzeros, scales, g_idx, device):
    """Construct Q4Matrix, return handle"""
    return make_q4(
        qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device
    )


def ext_q4_matmul(x, q4, q4_width):
    """Matrix multiplication, returns x @ q4"""
    outshape = x.shape[:-1] + (q4_width,)
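    # Collapse any leading batch dimensions to a 2D matrix for the kernel;
    # outshape restores them on the way out.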
    x = x.view(-1, x.shape[-1])
    output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)

    q4_matmul(x, q4, output)

    return output.view(outshape)


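# Running maxima gathered while constructing Ex4bitLinear layers;
# create_exllama_buffers uses them to size the scratch buffers.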
MAX_DQ = 1
MAX_INNER = 1
ACT_ORDER = False
DEVICE = None

TEMP_STATE = None
TEMP_DQ = None


def set_device(device):
    global DEVICE
    DEVICE = device


def create_exllama_buffers(max_total_tokens: int):
    global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ

    assert DEVICE is not None, "call set_device first"

    if not ACT_ORDER:
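        # Without act-order the kernel never reorders X, so a single-row
        # temp_state buffer is enough.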
        max_total_tokens = 1

    # This temp_state buffer is required to reorder X in the act-order case.
    temp_state = torch.zeros(
        (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
    )
    # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
    temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)

    prepare_buffers(DEVICE, temp_state, temp_dq)

    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

    TEMP_STATE, TEMP_DQ = temp_state, temp_dq


class Ex4bitLinear(torch.nn.Module):
    """Linear layer implementation with per-group 4-bit quantization of the weights"""

    def __init__(self, weight: GPTQWeight, bias):
        super().__init__()
        global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
        assert weight.bits == 4

        self.device = weight.qweight.device
        self.qweight = weight.qweight
        self.qzeros = weight.qzeros
        self.scales = weight.scales
        self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None
        self.bias = bias

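        # A g_idx that is all zeros, or that equals the sequential mapping
        # i // groupsize, encodes no act-order permutation, so drop it.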
        if self.g_idx is not None and (
            (self.g_idx == 0).all()
            or torch.equal(
                weight.g_idx.cpu(),
                torch.tensor(
                    [i // weight.groupsize for i in range(weight.g_idx.shape[0])],
                    dtype=torch.int32,
                ),
            )
        ):
            self.empty_g_idx = True
            self.g_idx = None

        assert self.device.type == "cuda"
        assert self.device.index is not None

        self.q4 = ext_make_q4(
            self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
        )

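        # qweight packs eight 4-bit values into each 32-bit word along the
        # input dimension, hence the factor of 8.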
        self.height = weight.qweight.shape[0] * 8
        self.width = weight.qweight.shape[1]

        # Infer groupsize from height of qzeros
        self.groupsize = None
        if self.qzeros.shape[0] > 1:
            self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])

        if self.groupsize is not None:
            assert weight.groupsize == self.groupsize

        # Handle act-order matrix
        if self.g_idx is not None:
            if self.groupsize is None:
                raise ValueError("Found a group index but no groupsize; cannot handle act-order weights.")
            self.act_order = True
        else:
            self.act_order = False

        DEVICE = self.qweight.device

        MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)

        if self.act_order:
            MAX_INNER = max(MAX_INNER, self.height, self.width)

            ACT_ORDER = True

    def forward(self, x):
        out = ext_q4_matmul(x, self.q4, self.width)

        if self.bias is not None:
            out.add_(self.bias)
        return out
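

# A minimal usage sketch (illustrative, not part of the module): layers are
# constructed first so the module-level maxima are known, and the scratch
# buffers are allocated before any forward pass. `gptq_weight` stands in for
# a real 4-bit GPTQWeight loaded elsewhere.
#
#     layer = Ex4bitLinear(gptq_weight, bias=None)
#     set_device(layer.device)
#     create_exllama_buffers(max_total_tokens=2048)
#     x = torch.randn(1, layer.height, dtype=torch.float16, device=layer.device)
#     y = layer(x)  # shape (1, layer.width)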