// mmq.cuh
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
// Generic tiled matrix multiplication between a quantized matrix x (blocks of type
// block_q_t) and a q8_1-quantized matrix y (blocks of type block_q8_1).
//
// Each thread block produces an mmq_y x mmq_x tile of dst whose first element is at
// row blockIdx.x*mmq_y and column blockIdx.y*mmq_x; results are stored as
// dst[col*nrows_dst + row]. The code indexes threads as threadIdx.x in
// [0, WARP_SIZE_GGUF) and threadIdx.y in [0, nwarps), so it must be launched with
// blockDim = (WARP_SIZE_GGUF, nwarps, 1).
//
// Template parameters:
//   scalar_t        element type of dst
//   qk, qr, qi      per-format constants of block_q_t (qk values per block)
//   need_sum        if true the y tile keeps the full half2 `ds` of every q8_1 block;
//                   if false only its low half is stored, pre-converted to float
//   mmq_x, mmq_y    tile extent along dst columns / rows
//   nwarps          number of warps (threadIdx.y extent) per thread block
//   allocate_tiles  provides the storage for the x tile pointers
//   load_tiles      loads the x tile for the current block offset
//   vdr, vec_dot    step of the inner k loop and the per-element dot product routine
//
// Arguments:
//   vx, vy          device pointers to the quantized x and y data
//   dst             device pointer to the output
//   ncols_x,nrows_x extent of x (ncols_x is divided by qk to get blocks per row)
//   ncols_y,nrows_y extent of y (nrows_y is divided by QK8_1 to get blocks per column)
//   nrows_dst       number of rows of dst, also used as its column stride
template <typename scalar_t, int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
              allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void mul_mat_q(
    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    // The accumulator array and the tile loops below index with i/WARP_SIZE_GGUF and
    // j/nwarps, which is only in bounds when the tile sizes divide evenly.
    static_assert(mmq_y % WARP_SIZE_GGUF == 0, "mmq_y must be a multiple of WARP_SIZE_GGUF");
    static_assert(mmq_x % nwarps == 0, "mmq_x must be a multiple of nwarps");

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
    const int blocks_per_warp = WARP_SIZE_GGUF / qi;

    const int & ncols_dst = ncols_y;

    // first dst row / x row handled by this thread block
    const auto row_dst_0 = blockIdx.x*mmq_y;
    const int row_x_0 = row_dst_0;

    // first dst column / y column handled by this thread block
    const auto col_dst_0 = blockIdx.y*mmq_x;
    const int col_y_0 = col_dst_0;

    int   * tile_x_ql = nullptr;
    half2 * tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    int   * tile_x_sc = nullptr;

    allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);

    // y tile: quantized values and per-block scale data, shared by the whole thread block
    __shared__ int    tile_y_qs[mmq_x * WARP_SIZE_GGUF];
    __shared__ half2  tile_y_ds[mmq_x * WARP_SIZE_GGUF/QI8_1];

    // per-thread partial results: one entry per (row, column) pair this thread owns
    float sum[mmq_y/WARP_SIZE_GGUF][mmq_x/nwarps] = {{0.0f}};

    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {

        load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
                   threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);

#pragma unroll
        for (int ir = 0; ir < qr && ib0 + ir * blocks_per_warp/qr < blocks_per_row_x; ++ir) {
            const auto kqs = ir*WARP_SIZE_GGUF + threadIdx.x;
            const int kbxd = kqs / QI8_1;

            // load the quantized values of the y tile
#pragma unroll
            for (int i = 0; i < mmq_x; i += nwarps) {
                const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
                const int index_y = (threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF;
                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
            }

            // load the `ds` field of the y blocks
#pragma unroll
            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE_GGUF/QI8_1)) % mmq_x;
                const auto kby = threadIdx.x % (WARP_SIZE_GGUF/QI8_1);
                const int col_y_eff = min(col_y_0 + ids, ncols_y-1); // clamp, as above

                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
                const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE_GGUF/QI8_1) + kby].ds;
                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE_GGUF/QI8_1) + kby];
                if (need_sum) {
                    *dsi_dst = *dsi_src;
                } else {
                    // reuse the 4-byte half2 slot to hold a single float
                    float * dfi_dst = (float *) dsi_dst;
                    *dfi_dst = __low2float(*dsi_src);
                }
            }

            // the shared tiles must be fully written before any thread reads them
            __syncthreads();

// #pragma unroll // unrolling this loop causes too much register pressure
            for (int k = ir*WARP_SIZE_GGUF/qr; k < (ir+1)*WARP_SIZE_GGUF/qr; k += vdr) {
#pragma unroll
                for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
                    for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
                        sum[i/WARP_SIZE_GGUF][j/nwarps] += vec_dot(
                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
                            threadIdx.x + i, threadIdx.y + j, k);
                    }
                }
            }

            // all reads must finish before the tiles are overwritten in the next iteration
            __syncthreads();
        }
    }

    // write back the accumulated tile, skipping elements outside of dst
#pragma unroll
    for (int j = 0; j < mmq_x; j += nwarps) {
        const auto col_dst = col_dst_0 + j + threadIdx.y;
        if (col_dst >= ncols_dst) {
            // columns only grow with j, so nothing further is in range
            return;
        }

#pragma unroll
        for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
            const auto row_dst = row_dst_0 + threadIdx.x + i;
            if (row_dst >= nrows_dst) {
                continue;
            }
            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE_GGUF][j/nwarps];
        }
    }
}

// Tile configuration of the Q4_0 x Q8_1 kernel: every thread block computes an
// MMQ_Y_Q4_0 x MMQ_X_Q4_0 (rows x columns) tile of dst using NWARPS_Q4_0 warps.
#if defined(USE_ROCM)
#define  MMQ_X_Q4_0  64
#define  MMQ_Y_Q4_0  128
#define NWARPS_Q4_0  8
#else
#define  MMQ_X_Q4_0 4
#define  MMQ_Y_Q4_0 32
#define NWARPS_Q4_0 4
#endif

// Kernel entry point for x in Q4_0 format multiplied with y in Q8_1 format.
// Instantiates mul_mat_q with the Q4_0 constants, tile helpers and need_sum = true,
// and forwards all arguments unchanged.
// Expected block size: (WARP_SIZE_GGUF, NWARPS_Q4_0, 1).
// need_check is forwarded to load_tiles_q4_0; the host launcher in this file enables
// it when nrows_x is not a multiple of MMQ_Y_Q4_0.
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_0, 2)
#endif
mul_mat_q4_0(
    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
    constexpr int mmq_x  =  MMQ_X_Q4_0;
    constexpr int mmq_y  =  MMQ_Y_Q4_0;
    constexpr int nwarps = NWARPS_Q4_0;

    mul_mat_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
        load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

// Host-side launcher of the Q4_0 x Q8_1 matrix multiplication kernel.
//
// Enqueues mul_mat_q4_0 on `stream` with a grid of
// ceil(nrows_x / MMQ_Y_Q4_0) x ceil(ncols_y / MMQ_X_Q4_0) thread blocks of
// WARP_SIZE_GGUF x NWARPS_Q4_0 threads and no dynamic shared memory.
// All pointer and size arguments are forwarded to the kernel unchanged.
// No launch error checking is done here. Returns without launching when there are
// no rows of x or no columns of y.
template<typename scalar_t>
static void ggml_mul_mat_q4_0_q8_1_cuda(
    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    constexpr int mmq_x  =  MMQ_X_Q4_0;
    constexpr int mmq_y  =  MMQ_Y_Q4_0;
    constexpr int nwarps = NWARPS_Q4_0;

    // An empty problem would yield a zero-sized grid, which is an invalid launch
    // configuration; there is nothing to compute in that case.
    if (nrows_x <= 0 || ncols_y <= 0) {
        return;
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

    // need_check is only enabled when the last row tile is partial
    if (nrows_x % mmq_y == 0) {
        constexpr bool need_check = false;
        mul_mat_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        constexpr bool need_check = true;
        mul_mat_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}

// Tile configuration of the Q4_1 x Q8_1 kernel: every thread block computes an
// MMQ_Y_Q4_1 x MMQ_X_Q4_1 (rows x columns) tile of dst using NWARPS_Q4_1 warps.
#if defined(USE_ROCM)
#define  MMQ_X_Q4_1 64
#define  MMQ_Y_Q4_1 128
#define NWARPS_Q4_1 8
#else
#define  MMQ_X_Q4_1 4
#define  MMQ_Y_Q4_1 32
#define NWARPS_Q4_1 4
#endif

// Kernel entry point for x in Q4_1 format multiplied with y in Q8_1 format.
// Instantiates mul_mat_q with the Q4_1 constants, tile helpers and need_sum = true,
// and forwards all arguments unchanged.
// Expected block size: (WARP_SIZE_GGUF, NWARPS_Q4_1, 1).
// need_check is forwarded to load_tiles_q4_1; the host launcher in this file enables
// it when nrows_x is not a multiple of MMQ_Y_Q4_1.
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_1, 2)
#endif
mul_mat_q4_1(
    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
    constexpr int mmq_x  =  MMQ_X_Q4_1;
    constexpr int mmq_y  =  MMQ_Y_Q4_1;
    constexpr int nwarps = NWARPS_Q4_1;

    mul_mat_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
        load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

// Host-side launcher of the Q4_1 x Q8_1 matrix multiplication kernel.
//
// Enqueues mul_mat_q4_1 on `stream` with a grid of
// ceil(nrows_x / MMQ_Y_Q4_1) x ceil(ncols_y / MMQ_X_Q4_1) thread blocks of
// WARP_SIZE_GGUF x NWARPS_Q4_1 threads and no dynamic shared memory.
// All pointer and size arguments are forwarded to the kernel unchanged.
// No launch error checking is done here. Returns without launching when there are
// no rows of x or no columns of y.
template<typename scalar_t>
static void ggml_mul_mat_q4_1_q8_1_cuda(
    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    constexpr int mmq_x  =  MMQ_X_Q4_1;
    constexpr int mmq_y  =  MMQ_Y_Q4_1;
    constexpr int nwarps = NWARPS_Q4_1;

    // An empty problem would yield a zero-sized grid, which is an invalid launch
    // configuration; there is nothing to compute in that case.
    if (nrows_x <= 0 || ncols_y <= 0) {
        return;
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

    // need_check is only enabled when the last row tile is partial
    if (nrows_x % mmq_y == 0) {
        constexpr bool need_check = false;
        mul_mat_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        constexpr bool need_check = true;
        mul_mat_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}

// Tile configuration of the Q5_0 x Q8_1 kernel: every thread block computes an
// MMQ_Y_Q5_0 x MMQ_X_Q5_0 (rows x columns) tile of dst using NWARPS_Q5_0 warps.
#if defined(USE_ROCM)
#define  MMQ_X_Q5_0 64
#define  MMQ_Y_Q5_0 128
#define NWARPS_Q5_0 8
#else
#define  MMQ_X_Q5_0 4
#define  MMQ_Y_Q5_0 32
#define NWARPS_Q5_0 4
#endif

// Kernel entry point for x in Q5_0 format multiplied with y in Q8_1 format.
// Instantiates mul_mat_q with the Q5_0 constants, tile helpers and need_sum = false,
// and forwards all arguments unchanged.
// Expected block size: (WARP_SIZE_GGUF, NWARPS_Q5_0, 1).
// need_check is forwarded to load_tiles_q5_0; the host launcher in this file enables
// it when nrows_x is not a multiple of MMQ_Y_Q5_0.
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_0, 2)
#endif
mul_mat_q5_0(
    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
    constexpr int mmq_x  =  MMQ_X_Q5_0;
    constexpr int mmq_y  =  MMQ_Y_Q5_0;
    constexpr int nwarps = NWARPS_Q5_0;

    mul_mat_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
        load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

// Host-side launcher of the Q5_0 x Q8_1 matrix multiplication kernel.
//
// Enqueues mul_mat_q5_0 on `stream` with a grid of
// ceil(nrows_x / MMQ_Y_Q5_0) x ceil(ncols_y / MMQ_X_Q5_0) thread blocks of
// WARP_SIZE_GGUF x NWARPS_Q5_0 threads and no dynamic shared memory.
// All pointer and size arguments are forwarded to the kernel unchanged.
// No launch error checking is done here. Returns without launching when there are
// no rows of x or no columns of y.
template<typename scalar_t>
static void ggml_mul_mat_q5_0_q8_1_cuda(
    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    constexpr int mmq_x  =  MMQ_X_Q5_0;
    constexpr int mmq_y  =  MMQ_Y_Q5_0;
    constexpr int nwarps = NWARPS_Q5_0;

    // An empty problem would yield a zero-sized grid, which is an invalid launch
    // configuration; there is nothing to compute in that case.
    if (nrows_x <= 0 || ncols_y <= 0) {
        return;
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

    // need_check is only enabled when the last row tile is partial
    if (nrows_x % mmq_y == 0) {
        constexpr bool need_check = false;
        mul_mat_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        constexpr bool need_check = true;
        mul_mat_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}

// Tile configuration of the Q5_1 x Q8_1 kernel: every thread block computes an
// MMQ_Y_Q5_1 x MMQ_X_Q5_1 (rows x columns) tile of dst using NWARPS_Q5_1 warps.
#if defined(USE_ROCM)
#define  MMQ_X_Q5_1 64
#define  MMQ_Y_Q5_1 128
#define NWARPS_Q5_1 8
#else
#define  MMQ_X_Q5_1 4
#define  MMQ_Y_Q5_1 32
#define NWARPS_Q5_1 4
#endif

// Kernel entry point for x in Q5_1 format multiplied with y in Q8_1 format.
// Instantiates mul_mat_q with the Q5_1 constants, tile helpers and need_sum = true,
// and forwards all arguments unchanged.
// Expected block size: (WARP_SIZE_GGUF, NWARPS_Q5_1, 1).
// need_check is forwarded to load_tiles_q5_1; the host launcher in this file enables
// it when nrows_x is not a multiple of MMQ_Y_Q5_1.
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_1, 2)
#endif
mul_mat_q5_1(
    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
    constexpr int mmq_x  =  MMQ_X_Q5_1;
    constexpr int mmq_y  =  MMQ_Y_Q5_1;
    constexpr int nwarps = NWARPS_Q5_1;

    mul_mat_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
        load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

// Host-side launcher of the Q5_1 x Q8_1 matrix multiplication kernel.
//
// Enqueues mul_mat_q5_1 on `stream` with a grid of
// ceil(nrows_x / MMQ_Y_Q5_1) x ceil(ncols_y / MMQ_X_Q5_1) thread blocks of
// WARP_SIZE_GGUF x NWARPS_Q5_1 threads and no dynamic shared memory.
// All pointer and size arguments are forwarded to the kernel unchanged.
// No launch error checking is done here. Returns without launching when there are
// no rows of x or no columns of y.
template<typename scalar_t>
static void ggml_mul_mat_q5_1_q8_1_cuda(
    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    constexpr int mmq_x  =  MMQ_X_Q5_1;
    constexpr int mmq_y  =  MMQ_Y_Q5_1;
    constexpr int nwarps = NWARPS_Q5_1;

    // An empty problem would yield a zero-sized grid, which is an invalid launch
    // configuration; there is nothing to compute in that case.
    if (nrows_x <= 0 || ncols_y <= 0) {
        return;
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

    // need_check is only enabled when the last row tile is partial
    if (nrows_x % mmq_y == 0) {
        constexpr bool need_check = false;
        mul_mat_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        constexpr bool need_check = true;
        mul_mat_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}

// Tile configuration of the Q8_0 x Q8_1 kernel: every thread block computes an
// MMQ_Y_Q8_0 x MMQ_X_Q8_0 (rows x columns) tile of dst using NWARPS_Q8_0 warps.
#if defined(USE_ROCM)
#define  MMQ_X_Q8_0 64
#define  MMQ_Y_Q8_0 128
#define NWARPS_Q8_0 8
#else
#define  MMQ_X_Q8_0 4
#define  MMQ_Y_Q8_0 32
#define NWARPS_Q8_0 4
#endif

// Kernel entry point for x in Q8_0 format multiplied with y in Q8_1 format.
// Instantiates mul_mat_q with the Q8_0 constants, tile helpers and need_sum = false,
// and forwards all arguments unchanged.
// Expected block size: (WARP_SIZE_GGUF, NWARPS_Q8_0, 1).
// need_check is forwarded to load_tiles_q8_0; the host launcher in this file enables
// it when nrows_x is not a multiple of MMQ_Y_Q8_0.
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q8_0, 2)
#endif
mul_mat_q8_0(
    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
    constexpr int mmq_x  =  MMQ_X_Q8_0;
    constexpr int mmq_y  =  MMQ_Y_Q8_0;
    constexpr int nwarps = NWARPS_Q8_0;

    mul_mat_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
        load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

// Host-side launcher of the Q8_0 x Q8_1 matrix multiplication kernel.
//
// Enqueues mul_mat_q8_0 on `stream` with a grid of
// ceil(nrows_x / MMQ_Y_Q8_0) x ceil(ncols_y / MMQ_X_Q8_0) thread blocks of
// WARP_SIZE_GGUF x NWARPS_Q8_0 threads and no dynamic shared memory.
// All pointer and size arguments are forwarded to the kernel unchanged.
// No launch error checking is done here. Returns without launching when there are
// no rows of x or no columns of y.
template<typename scalar_t>
static void ggml_mul_mat_q8_0_q8_1_cuda(
    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    constexpr int mmq_x  =  MMQ_X_Q8_0;
    constexpr int mmq_y  =  MMQ_Y_Q8_0;
    constexpr int nwarps = NWARPS_Q8_0;

    // An empty problem would yield a zero-sized grid, which is an invalid launch
    // configuration; there is nothing to compute in that case.
    if (nrows_x <= 0 || ncols_y <= 0) {
        return;
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

    // need_check is only enabled when the last row tile is partial
    if (nrows_x % mmq_y == 0) {
        constexpr bool need_check = false;
        mul_mat_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        constexpr bool need_check = true;
        mul_mat_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}

// Tile configuration of the Q2_K x Q8_1 kernel: every thread block computes an
// MMQ_Y_Q2_K x MMQ_X_Q2_K (rows x columns) tile of dst using NWARPS_Q2_K warps.
#if defined(USE_ROCM)
#define  MMQ_X_Q2_K 64
#define  MMQ_Y_Q2_K 128
#define NWARPS_Q2_K 8
#else
#define  MMQ_X_Q2_K 4
#define  MMQ_Y_Q2_K 32
#define NWARPS_Q2_K 4
#endif

// Kernel entry point for x in Q2_K format multiplied with y in Q8_1 format.
// Instantiates mul_mat_q with the Q2_K constants, tile helpers and need_sum = false,
// and forwards all arguments unchanged.
// Expected block size: (WARP_SIZE_GGUF, NWARPS_Q2_K, 1).
// need_check is forwarded to load_tiles_q2_K; the host launcher in this file enables
// it when nrows_x is not a multiple of MMQ_Y_Q2_K.
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q2_K, 2)
#endif
mul_mat_q2_K(
    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
    constexpr int mmq_x  =  MMQ_X_Q2_K;
    constexpr int mmq_y  =  MMQ_Y_Q2_K;
    constexpr int nwarps = NWARPS_Q2_K;

    mul_mat_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
        load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

// Host-side launcher of the Q2_K x Q8_1 matrix multiplication kernel.
//
// Enqueues mul_mat_q2_K on `stream` with a grid of
// ceil(nrows_x / MMQ_Y_Q2_K) x ceil(ncols_y / MMQ_X_Q2_K) thread blocks of
// WARP_SIZE_GGUF x NWARPS_Q2_K threads and no dynamic shared memory.
// All pointer and size arguments are forwarded to the kernel unchanged.
// No launch error checking is done here. Returns without launching when there are
// no rows of x or no columns of y.
template<typename scalar_t>
static void ggml_mul_mat_q2_K_q8_1_cuda(
    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    constexpr int mmq_x  =  MMQ_X_Q2_K;
    constexpr int mmq_y  =  MMQ_Y_Q2_K;
    constexpr int nwarps = NWARPS_Q2_K;

    // An empty problem would yield a zero-sized grid, which is an invalid launch
    // configuration; there is nothing to compute in that case.
    if (nrows_x <= 0 || ncols_y <= 0) {
        return;
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

    // need_check is only enabled when the last row tile is partial
    if (nrows_x % mmq_y == 0) {
        constexpr bool need_check = false;
        mul_mat_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        constexpr bool need_check = true;
        mul_mat_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}

// Tile configuration of the Q3_K x Q8_1 kernel: every thread block computes an
// MMQ_Y_Q3_K x MMQ_X_Q3_K (rows x columns) tile of dst using NWARPS_Q3_K warps.
#if defined(USE_ROCM)
#define  MMQ_X_Q3_K 64
#define  MMQ_Y_Q3_K 128
#define NWARPS_Q3_K 8
#else
#define  MMQ_X_Q3_K 4
#define  MMQ_Y_Q3_K 32
#define NWARPS_Q3_K 4
#endif

// Kernel entry point for x in Q3_K format multiplied with y in Q8_1 format.
// Instantiates mul_mat_q with the Q3_K constants, tile helpers and need_sum = false,
// and forwards all arguments unchanged.
// Expected block size: (WARP_SIZE_GGUF, NWARPS_Q3_K, 1).
// need_check is forwarded to load_tiles_q3_K; the host launcher in this file enables
// it when nrows_x is not a multiple of MMQ_Y_Q3_K.
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q3_K, 2)
#endif
mul_mat_q3_K(
    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

    constexpr int mmq_x  =  MMQ_X_Q3_K;
    constexpr int mmq_y  =  MMQ_Y_Q3_K;
    constexpr int nwarps = NWARPS_Q3_K;

    mul_mat_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
        load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

// Host-side launcher of the Q3_K x Q8_1 matrix multiplication kernel.
//
// Enqueues mul_mat_q3_K on `stream` with a grid of
// ceil(nrows_x / MMQ_Y_Q3_K) x ceil(ncols_y / MMQ_X_Q3_K) thread blocks of
// WARP_SIZE_GGUF x NWARPS_Q3_K threads and no dynamic shared memory.
// All pointer and size arguments are forwarded to the kernel unchanged.
// No launch error checking is done here. Returns without launching when there are
// no rows of x or no columns of y.
template<typename scalar_t>
static void ggml_mul_mat_q3_K_q8_1_cuda(
    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    constexpr int mmq_x  =  MMQ_X_Q3_K;
    constexpr int mmq_y  =  MMQ_Y_Q3_K;
    constexpr int nwarps = NWARPS_Q3_K;

    // An empty problem would yield a zero-sized grid, which is an invalid launch
    // configuration; there is nothing to compute in that case.
    if (nrows_x <= 0 || ncols_y <= 0) {
        return;
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

    // need_check is only enabled when the last row tile is partial
    if (nrows_x % mmq_y == 0) {
        constexpr bool need_check = false;
        mul_mat_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        constexpr bool need_check = true;
        mul_mat_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}

// Tile configuration of the Q4_K x Q8_1 kernel: every thread block computes an
// MMQ_Y_Q4_K x MMQ_X_Q4_K (rows x columns) tile of dst using NWARPS_Q4_K warps.
#if defined(USE_ROCM)
#define  MMQ_X_Q4_K 64
#define  MMQ_Y_Q4_K 128
#define NWARPS_Q4_K 8
#else
#define  MMQ_X_Q4_K 4
#define  MMQ_Y_Q4_K 32
#define NWARPS_Q4_K 4
#endif

// Kernel entry point for x in Q4_K format multiplied with y in Q8_1 format.
// Instantiates mul_mat_q with the Q4_K constants, tile helpers and need_sum = true,
// and forwards all arguments unchanged.
// Expected block size: (WARP_SIZE_GGUF, NWARPS_Q4_K, 1).
// need_check is forwarded to load_tiles_q4_K; the host launcher in this file enables
// it when nrows_x is not a multiple of MMQ_Y_Q4_K.
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_K, 2)
#endif
mul_mat_q4_K(
    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
    constexpr int mmq_x  =  MMQ_X_Q4_K;
    constexpr int mmq_y  =  MMQ_Y_Q4_K;
    constexpr int nwarps = NWARPS_Q4_K;

    mul_mat_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
        load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

// Host-side launcher of the Q4_K x Q8_1 matrix multiplication kernel.
//
// Enqueues mul_mat_q4_K on `stream` with a grid of
// ceil(nrows_x / MMQ_Y_Q4_K) x ceil(ncols_y / MMQ_X_Q4_K) thread blocks of
// WARP_SIZE_GGUF x NWARPS_Q4_K threads and no dynamic shared memory.
// All pointer and size arguments are forwarded to the kernel unchanged.
// No launch error checking is done here. Returns without launching when there are
// no rows of x or no columns of y.
template<typename scalar_t>
static void ggml_mul_mat_q4_K_q8_1_cuda(
    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    constexpr int mmq_x  =  MMQ_X_Q4_K;
    constexpr int mmq_y  =  MMQ_Y_Q4_K;
    constexpr int nwarps = NWARPS_Q4_K;

    // An empty problem would yield a zero-sized grid, which is an invalid launch
    // configuration; there is nothing to compute in that case.
    if (nrows_x <= 0 || ncols_y <= 0) {
        return;
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

    // need_check is only enabled when the last row tile is partial
    if (nrows_x % mmq_y == 0) {
        constexpr bool need_check = false;
        mul_mat_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        constexpr bool need_check = true;
        mul_mat_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}

// Tile configuration of the Q5_K x Q8_1 kernel: every thread block computes an
// MMQ_Y_Q5_K x MMQ_X_Q5_K (rows x columns) tile of dst using NWARPS_Q5_K warps.
#if defined(USE_ROCM)
#define  MMQ_X_Q5_K 64
#define  MMQ_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
#define  MMQ_X_Q5_K 4
#define  MMQ_Y_Q5_K 32
#define NWARPS_Q5_K 4
#endif

// Kernel entry point for x in Q5_K format multiplied with y in Q8_1 format.
// Instantiates mul_mat_q with the Q5_K constants, tile helpers and need_sum = true,
// and forwards all arguments unchanged.
// Expected block size: (WARP_SIZE_GGUF, NWARPS_Q5_K, 1).
// need_check is forwarded to load_tiles_q5_K; the host launcher in this file enables
// it when nrows_x is not a multiple of MMQ_Y_Q5_K.
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_K, 2)
#endif
mul_mat_q5_K(
    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
    constexpr int mmq_x  =  MMQ_X_Q5_K;
    constexpr int mmq_y  =  MMQ_Y_Q5_K;
    constexpr int nwarps = NWARPS_Q5_K;

    mul_mat_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
        load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

// Host-side launcher of the Q5_K x Q8_1 matrix multiplication kernel.
//
// Enqueues mul_mat_q5_K on `stream` with a grid of
// ceil(nrows_x / MMQ_Y_Q5_K) x ceil(ncols_y / MMQ_X_Q5_K) thread blocks of
// WARP_SIZE_GGUF x NWARPS_Q5_K threads and no dynamic shared memory.
// All pointer and size arguments are forwarded to the kernel unchanged.
// No launch error checking is done here. Returns without launching when there are
// no rows of x or no columns of y.
template<typename scalar_t>
static void ggml_mul_mat_q5_K_q8_1_cuda(
    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    constexpr int mmq_x  =  MMQ_X_Q5_K;
    constexpr int mmq_y  =  MMQ_Y_Q5_K;
    constexpr int nwarps = NWARPS_Q5_K;

    // An empty problem would yield a zero-sized grid, which is an invalid launch
    // configuration; there is nothing to compute in that case.
    if (nrows_x <= 0 || ncols_y <= 0) {
        return;
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

    // need_check is only enabled when the last row tile is partial
    if (nrows_x % mmq_y == 0) {
        constexpr bool need_check = false;
        mul_mat_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        constexpr bool need_check = true;
        mul_mat_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}

// Tile configuration of the Q6_K x Q8_1 kernel: every thread block computes an
// MMQ_Y_Q6_K x MMQ_X_Q6_K (rows x columns) tile of dst using NWARPS_Q6_K warps.
#if defined(USE_ROCM)
#define  MMQ_X_Q6_K 64
#define  MMQ_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
#define  MMQ_X_Q6_K 4
#define  MMQ_Y_Q6_K 32
#define NWARPS_Q6_K 4
#endif

// Kernel entry point for x in Q6_K format multiplied with y in Q8_1 format.
// Instantiates mul_mat_q with the Q6_K constants, tile helpers and need_sum = false,
// and forwards all arguments unchanged.
// Expected block size: (WARP_SIZE_GGUF, NWARPS_Q6_K, 1).
// need_check is forwarded to load_tiles_q6_K; the host launcher in this file enables
// it when nrows_x is not a multiple of MMQ_Y_Q6_K.
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q6_K, 2)
#endif
mul_mat_q6_K(
    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
    constexpr int mmq_x  =  MMQ_X_Q6_K;
    constexpr int mmq_y  =  MMQ_Y_Q6_K;
    constexpr int nwarps = NWARPS_Q6_K;

    mul_mat_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
        load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}

// Host-side launcher of the Q6_K x Q8_1 matrix multiplication kernel.
//
// Enqueues mul_mat_q6_K on `stream` with a grid of
// ceil(nrows_x / MMQ_Y_Q6_K) x ceil(ncols_y / MMQ_X_Q6_K) thread blocks of
// WARP_SIZE_GGUF x NWARPS_Q6_K threads and no dynamic shared memory.
// All pointer and size arguments are forwarded to the kernel unchanged.
// No launch error checking is done here. Returns without launching when there are
// no rows of x or no columns of y.
template<typename scalar_t>
static void ggml_mul_mat_q6_K_q8_1_cuda(
    const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

    constexpr int mmq_x  =  MMQ_X_Q6_K;
    constexpr int mmq_y  =  MMQ_Y_Q6_K;
    constexpr int nwarps = NWARPS_Q6_K;

    // An empty problem would yield a zero-sized grid, which is an invalid launch
    // configuration; there is nothing to compute in that case.
    if (nrows_x <= 0 || ncols_y <= 0) {
        return;
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

    // need_check is only enabled when the last row tile is partial
    if (nrows_x % mmq_y == 0) {
        constexpr bool need_check = false;
        mul_mat_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    } else {
        constexpr bool need_check = true;
        mul_mat_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
    }
}