// q4_matmul.cu
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
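// 4-bit GPTQ matmul: out = x @ w, where w is stored as packed 4-bit weights with per-group
// fp16 scales and 4-bit zero points.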

#include "q4_matmul.cuh"
#include "column_remap.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
#include "../cu_compat.cuh"
#include "../cuda_buffers.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif

const int THREADS_X = 32;       // Block size and thread count along columns in w and out
const int THREADS_Y = 1;        // Block size and thread count along rows in x and out
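
// Grid layout: blockIdx.x covers the output columns (width), blockIdx.y the output rows
// (height), and blockIdx.z splits the inner dimension (dim) into slices of block_size_z
// columns. Each block accumulates a partial dot product and adds it into out with atomicAdd.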

typedef void (*fp_q4_matmul_kernel)
(
    const half*,
    const uint32_t*,
    half*,
    const half*,
    const uint32_t*,
    const int,
    const int,
    const int,
    const int,
    const int,
    const uint32_t*,
    bool
);
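
// Template parameters:
//   use_half2     - accumulate two fp16 values at a time with half2 intrinsics
//   use_groupsize - the quantization group size divides block_size_z, so every z-slice starts
//                   on a group boundary
//   use_x_map     - gather input columns through x_map (GPTQ act-order models)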

template<bool use_half2, bool use_groupsize, bool use_x_map>
__global__ void q4_matmul_kernel
(
    const half* __restrict__ x,
    const uint32_t* __restrict__ w,
    half* __restrict__ out,
    const half* __restrict__ w_scales,
    const uint32_t* __restrict__ w_zeros,
    const int height,
    const int dim,
    const int width,
    const int groupsize,
    const int block_size_z,
    const uint32_t* __restrict__ x_map,
    bool no_zero
)
{
    // Start of block

    int x_column = block_size_z * blockIdx.z;
    int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));

    int w_column = THREADS_X * blockIdx.x + threadIdx.x;
    int x_row = THREADS_Y * blockIdx.y + threadIdx.y;

    int iterations = (x_column_end - x_column) / 8;
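    // The inner loops below advance 8 input columns at a time (one packed uint32_t of eight
    // 4-bit weights per output column), hence the division by 8.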

    // Views

    MatrixView_half x_(x, height, dim);
    MatrixView_half w_scales_(w_scales, dim / groupsize, width);
    MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
    MatrixView_q4_column w_(w, dim, width);
    MatrixView_half_rw out_(out, height, width);

    // Zero output
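    // Only blocks in the first z-slice (blockIdx.z == 0) clear the output; each thread with an
    // even x index writes one 32-bit zero, clearing two adjacent half values at once.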

    if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
    {
        *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
        __syncthreads();
    }

    // Loop over part of x row (and w column)

    half2 acc = {};
    half acc_h = {};
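
    // Note: the packed 4-bit zero points follow the GPTQ convention of storing (zero - 1),
    // hence the "+ 1" when they are read below.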

    if constexpr (use_groupsize)
    {
        // For quantized matrices where the group size divides block_size_z we always start on a
        // group boundary, so this path can be slightly faster

        for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
        {
            if constexpr (use_half2)
            {
                half2 w_scale = w_scales_.item_half2half2(group, w_column);
                uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0f;

                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
                else                     acc = dot_product_8      (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
            }
            else
            {
                half w_scale = w_scales_.item(group, w_column);
                uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0f;

                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
                else                     acc_h = dot_product_8_h      (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
            }
        }
    }
    else
    {
        // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache

        for (int k = x_column; k < x_column + iterations * 8; k += 8)
        {
            if constexpr (use_half2)
            {
                int group = k / groupsize;
                half2 w_scale = w_scales_.item_half2half2(group, w_column);
                uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0f;

                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
                else                     acc = dot_product_8      (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
            }
            else
            {
                int group = k / groupsize;
                half w_scale = w_scales_.item(group, w_column);
                uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0f;

                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
                else                     acc_h = dot_product_8_h      (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
            }
        }
    }

    // Add to block result

    if constexpr (use_half2)
    {
        half result = __hadd(__low2half(acc), __high2half(acc));
        atomicAdd(out_.item_ptr(x_row, w_column), result);
    }
    else
    {
        atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
    }
}

fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
{
    // <bool use_half2, bool use_groupsize, bool use_x_map>
    if (tuningParams->matmul_no_half2) {
        if (block_size_z % groupsize == 0) {
            if (x_map) return q4_matmul_kernel<false, true,  true >;
            else       return q4_matmul_kernel<false, true,  false>;
        } else {
            if (x_map) return q4_matmul_kernel<false, false, true >;
            else       return q4_matmul_kernel<false, false, false>;
        }
    } else {
        if (block_size_z % groupsize == 0)
        {
            if (x_map) return q4_matmul_kernel<true,  true,  true >;
            else       return q4_matmul_kernel<true,  true,  false>;
        } else {
            if (x_map) return q4_matmul_kernel<true,  false, true >;
            else       return q4_matmul_kernel<true,  false, false>;
        }
    }
}

// Compute out = x @ w using the quantized weights directly (no reconstruction)

void q4_matmul_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    const Q4Matrix* w,
    half* out,
    bool no_zero,
    cudaStream_t alt_stream
)
{
    int height = x_height;
    int dim = w->height;
    int width = w->width;

    cudaSetDevice(w->device);

    uint32_t* x_map = w->cuda_x_map;
    const half* x_mapped = x;
    if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
    {
        CudaBuffers* buffers = get_buffers(w->device);
        column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
        x_mapped = buffers->temp_state;
        x_map = NULL;
    }
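
    // block_size_z sets how many columns of the inner dimension each z-block accumulates; the
    // values below are heuristics per Llama weight shape (see the 7B/13B/33B comments).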

    int block_size_z;
    if (w->width == 4096) block_size_z = 384;           // 7B
    else if (w->width == 11008) block_size_z = 256;
    else if (w->width == 5120) block_size_z = 384;      // 13B
    else if (w->width == 13824) block_size_z = 256;
    else if (w->width == 6656) block_size_z = 256;      // 33B
    else if (w->width == 17920) block_size_z = 128;
    else block_size_z = 256;

    //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
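    // (Zeroing is instead done inside the kernel on the first z-slice when no_zero is false.)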

    dim3 threads(THREADS_X, THREADS_Y, 1);

    dim3 blocks
    (
        (width + threads.x - 1) / threads.x,
        (height + threads.y - 1) / threads.y,
        (dim + block_size_z - 1) / block_size_z
    );

    fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);

    kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
}
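
// Compute out = x @ w by reconstructing w to fp16 and running a cuBLAS GEMM. The caller
// typically picks this path for larger input heights, where dequantize-then-GEMM is faster
// than the direct quantized kernel.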

void q4_matmul_recons_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    Q4Matrix* w,
    half* out,
    const cublasHandle_t handle,
    bool no_zero
)
{
    int height = x_height;
    int dim = w->height;
    int width = w->width;

    cudaSetDevice(w->device);
    CudaBuffers* buffers = get_buffers(w->device);

    const half* x_mapped = x;
    if (w->cuda_x_map)
    {
        TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend for GPTQ with act-order. Please call the exllama_set_max_input_length function to increase the buffer size for a sequence length >=", x_height, ":\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, max_input_length=", x_height, ")");
        column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
        x_mapped = buffers->temp_state;
    }

    w->reconstruct(buffers->temp_dq);
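
    // Note: __CUDA_ARCH__ is only defined during device compilation, so for this host function
    // the cublasHgemm branch below is the path that normally gets compiled.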

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
    const float alpha = 1.0f;
    const float beta = no_zero ? 1.0f : 0.0f;
    cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
                  x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
#else
    const half alpha = __float2half(1.0f);
    const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
    cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
#endif
}