exllama.py 4.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121

import torch
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params

# The C++ extension has no way to receive Python ``None``, so this placeholder
# tensor on the "meta" device (which carries no data) is passed in place of a
# missing g_idx.
none_tensor = torch.empty(1, 1, device="meta")

def ext_make_q4(qweight, qzeros, scales, g_idx, device):
    """Build the extension-side Q4Matrix for these tensors and return its handle.

    ``g_idx`` may be ``None``. The C++ binding cannot accept ``None``, so the
    module-level ``none_tensor`` placeholder is forwarded in its place.
    """
    group_index = none_tensor if g_idx is None else g_idx
    return make_q4(qweight, qzeros, scales, group_index, device)

def ext_q4_matmul(x, q4, q4_width):
    """Matrix multiplication, returns ``x @ W`` where ``W`` is the matrix behind ``q4``.

    Args:
        x: input activations. All leading dimensions are flattened into a single
            batch dimension for the kernel call. Presumably must be float16 on
            the same CUDA device as ``q4`` -- TODO confirm against the kernel.
        q4: handle returned by ``ext_make_q4``.
        q4_width: number of output columns of the quantized matrix.

    Returns:
        A float16 tensor of shape ``x.shape[:-1] + (q4_width,)``.
    """
    outshape = x.shape[:-1] + (q4_width,)
    # NOTE(review): the C++ kernel is assumed to read ``x`` as a dense row-major
    # matrix. ``view`` can succeed on some strided inputs (e.g. ``x[::2]``) and
    # yield a non-contiguous 2-D view. Reshape, then force contiguity. Both calls
    # are no-ops when ``x`` is already contiguous.
    x = x.reshape(-1, x.shape[-1]).contiguous()
    output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)

    q4_matmul(x, q4, output)

    return output.view(outshape)

# Sizing state accumulated by every Ex4bitLinear constructor and read by
# create_exllama_buffers() when the shared scratch buffers are allocated.
# MAX_DQ: largest ``qweight.numel() * 8`` seen so far.
# MAX_INNER: largest height/width among act-order layers.
MAX_DQ, MAX_INNER = 1, 1
# ACT_ORDER: whether any layer uses act-order.
# DEVICE: device of the most recently built layer.
ACT_ORDER, DEVICE = False, None

# Scratch buffers assigned by create_exllama_buffers(); None until it runs.
TEMP_STATE = TEMP_DQ = None

def create_exllama_buffers(max_total_tokens=2048):
    """Allocate the shared exllama scratch buffers and register them with the kernels.

    The buffer sizes come from the module-level ``MAX_DQ``, ``MAX_INNER``,
    ``ACT_ORDER`` and ``DEVICE`` values that ``Ex4bitLinear`` constructors
    accumulate. Layers built after this call are therefore not accounted for.

    Args:
        max_total_tokens: number of rows reserved in the act-order reorder
            buffer. Used only when at least one layer uses act-order; otherwise
            a single row is allocated. Defaults to 2048.

    Raises:
        RuntimeError: if no ``Ex4bitLinear`` has been constructed yet, so there
            is no device to allocate on.
    """
    # Only these two globals are assigned; the sizing globals are just read.
    global TEMP_STATE, TEMP_DQ

    if DEVICE is None:
        raise RuntimeError(
            "create_exllama_buffers() called before any Ex4bitLinear layer was created"
        )

    # TODO: max_total_tokens should match the router-side `max_total_tokens`.
    # TGI does not expose that value to the python server at model-init time,
    # because it is handled by the client while the model is initialized by
    # the server. An alternative is to allocate these buffers during warmup.
    rows = max_total_tokens if ACT_ORDER else 1

    # temp_state is required to reorder X in the act-order case.
    temp_state = torch.zeros((rows, MAX_INNER), dtype=torch.float16, device=DEVICE)
    # temp_dq is required to dequantize weights when using cuBLAS, typically for the prefill.
    temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)

    prepare_buffers(DEVICE, temp_state, temp_dq)

    # Kernel tuning parameters passed to the extension.
    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

    TEMP_STATE, TEMP_DQ = temp_state, temp_dq

class Ex4bitLinear:
    """Linear layer implementation with per-group 4-bit quantization of the weights.

    Construction registers the packed weights with the exllama extension. It
    also updates the module-level sizing globals (``MAX_DQ``, ``MAX_INNER``,
    ``ACT_ORDER``, ``DEVICE``) that ``create_exllama_buffers`` later reads.
    """

    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
        """Wrap packed 4-bit weights for the exllama kernel.

        Args:
            qweight: packed weight tensor on a CUDA device. Each element along
                dim 0 holds eight 4-bit values (presumably int32 words -- TODO
                confirm), so the unpacked height is ``qweight.shape[0] * 8``.
            qzeros: packed zero points. The group size is inferred from its row
                count; a single row means no grouping.
            scales: quantization scales, passed through to the extension.
            g_idx: optional per-row group index, or ``None``. A trivial index
                (all zeros, or exactly ``i // groupsize``) is dropped.
            bias: optional tensor added to the output in ``forward``.
            bits: quantization bit width; only 4 is supported.
            groupsize: expected group size, checked against the value inferred
                from ``qzeros`` when grouping is present.

        Raises:
            AssertionError: on unsupported ``bits``, a non-CUDA or index-less
                device, or a ``groupsize`` mismatch.
            ValueError: if a non-trivial ``g_idx`` is given but ``qzeros``
                implies no grouping.
        """
        global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
        assert bits == 4

        self.device = qweight.device
        self.qweight = qweight
        self.qzeros = qzeros
        self.scales = scales
        self.g_idx = g_idx.cpu() if g_idx is not None else None
        self.bias = bias
        # Always defined; True only when a trivial g_idx was detected and dropped.
        self.empty_g_idx = False

        if self.g_idx is not None:
            # The `or` keeps the original short-circuit. The sequential index is
            # only built (and groupsize only used) when g_idx is not all zeros.
            is_trivial = bool((self.g_idx == 0).all()) or torch.equal(
                self.g_idx,
                torch.arange(self.g_idx.shape[0], dtype=torch.int32) // groupsize,
            )
            if is_trivial:
                self.empty_g_idx = True
                self.g_idx = None

        assert self.device.type == "cuda"
        assert self.device.index is not None

        self.q4 = ext_make_q4(
            self.qweight,
            self.qzeros,
            self.scales,
            self.g_idx,
            self.device.index,
        )

        self.height = qweight.shape[0] * 8
        self.width = qweight.shape[1]

        # Infer groupsize from the height of qzeros; a single row means no grouping.
        self.groupsize = None
        if self.qzeros.shape[0] > 1:
            self.groupsize = self.height // self.qzeros.shape[0]
            assert groupsize == self.groupsize

        # A g_idx that survived the triviality check selects the act-order path.
        if self.g_idx is not None:
            if self.groupsize is None:
                raise ValueError("Found group index but no groupsize. What do?")
            self.act_order = True
        else:
            self.act_order = False

        # Record sizing information consumed by create_exllama_buffers().
        DEVICE = self.qweight.device
        MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)
        if self.act_order:
            MAX_INNER = max(MAX_INNER, self.height, self.width)
            ACT_ORDER = True

    def forward(self, x):
        """Return ``x @ W`` via the exllama kernel, plus ``bias`` when one is set."""
        out = ext_q4_matmul(x, self.q4, self.width)

        if self.bias is not None:
            out.add_(self.bias)
        return out