#pragma once

#include "gemm_base.cuh"

namespace nunchaku::kernels {

// M: Q tokens
// N: V HEAD_DIM
// K: K tokens
// D: QK HEAD_DIM
template<bool bf16out>
struct AttentionFP16Config {
    static constexpr int HEAD_DIM = 128;

    static constexpr int BLOCK_M   = 128;
    static constexpr int WARP_SIZE = 32;
    static constexpr int NUM_WARPS = 8;

    static constexpr int WARP_K = 32;

    static constexpr int INSN_M    = 16;
    static constexpr int INSN_N    = 16;
    static constexpr int INSN_K_QK = 16;
    static constexpr int INSN_K_PV = 16;

    using half_t  = half;
    using half2_t = half2;

    using epilogue_half_t  = typename std::conditional_t<bf16out, __nv_bfloat16, half>;
    using epilogue_half2_t = typename std::conditional_t<bf16out, __nv_bfloat162, half2>;
};

using AttentionFP16Config_FP16 = AttentionFP16Config<false>;
using AttentionFP16Config_BF16 = AttentionFP16Config<true>;

template<typename AttentionConfig>
class Attention;

#ifndef __INTELLISENSE__
template<typename AttentionConfig>
class Attention : public AttentionConfig {
#else
template<>
class Attention<AttentionFP16Config_BF16> : public AttentionFP16Config_BF16 {
    using AttentionConfig = AttentionFP16Config_BF16;
#endif

public:
    using AttentionConfig::HEAD_DIM;
    using AttentionConfig::BLOCK_M;
    using AttentionConfig::WARP_SIZE;
    using AttentionConfig::NUM_WARPS;
    using AttentionConfig::WARP_K;
    using AttentionConfig::INSN_M;
    using AttentionConfig::INSN_N;
    using AttentionConfig::INSN_K_QK;
    using AttentionConfig::INSN_K_PV;
    using typename AttentionConfig::half_t;
    using typename AttentionConfig::half2_t;
    using typename AttentionConfig::epilogue_half_t;
    using typename AttentionConfig::epilogue_half2_t;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    static constexpr bool IS_SM80 = true;
#else
    static constexpr bool IS_SM80 = false;
#endif
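
    // IS_SM80 toggles a few scheduling choices below: where fix_nan runs in
    // compute_qk, and how many Q tiles are staged through shared memory in
    // attention_fp16_block.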

    struct GEMMConfig {
        static constexpr int BLOCK_M   = AttentionConfig::BLOCK_M;
        static constexpr int BLOCK_N   = AttentionConfig::HEAD_DIM;
        static constexpr int WARP_SIZE = AttentionConfig::WARP_SIZE;
        static constexpr int NUM_WARPS = AttentionConfig::NUM_WARPS;

        static constexpr int INSN_M = AttentionConfig::INSN_M;
        static constexpr int INSN_N = AttentionConfig::INSN_N;
        static constexpr int INSN_K = AttentionConfig::INSN_K_PV;

        using half_t  = typename AttentionConfig::epilogue_half_t;
        using half2_t = typename AttentionConfig::epilogue_half2_t;
    };

    using GEMM = typename nunchaku::kernels::GEMMBase<GEMMConfig>;
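
    // Reusing GEMMBase gives attention the same f32 accumulator layout
    // (packed_f32psum_t / f32psum_warp) and fp16 packing helpers as the GEMM
    // kernels, so a single Epilogue implementation can serve both paths.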

    static constexpr int WARP_M = BLOCK_M / NUM_WARPS;
    static constexpr int WARP_N = HEAD_DIM;
    static constexpr int WARP_D = HEAD_DIM;

    static constexpr int WARP_M_TILES = WARP_M / INSN_M;
    static constexpr int WARP_N_TILES = WARP_N / INSN_N;
    static constexpr int WARP_K_TILES_QK =
        WARP_K / INSN_N; // when computing Q*K^T, the K tokens map to the N dimension of the MMA instruction
    static constexpr int WARP_K_TILES_PV = WARP_K / INSN_K_PV;
    static constexpr int WARP_D_TILES    = WARP_D / INSN_K_QK;
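
    // Worked numbers for the default config (HEAD_DIM = 128, BLOCK_M = 128,
    // NUM_WARPS = 8, WARP_K = 32, 16x16x16 tiles): WARP_M = 128 / 8 = 16, so
    // WARP_M_TILES = 1; WARP_N_TILES = WARP_D_TILES = 128 / 16 = 8; and
    // WARP_K_TILES_QK = WARP_K_TILES_PV = 32 / 16 = 2.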

    using packed_q_t = uint4;
    using packed_k_t = uint4;
    using packed_v_t = uint4;
    using q_warp     = std::array<packed_q_t, WARP_M_TILES * WARP_D_TILES>;
    using k_warp     = std::array<packed_k_t, WARP_K_TILES_QK * WARP_D_TILES>;
    using v_warp     = std::array<packed_v_t, WARP_K_TILES_PV * WARP_N_TILES>;

    using packed_p_t = uint4;
    using p_warp     = std::array<packed_p_t, WARP_M_TILES * WARP_K_TILES_PV>;

    using packed_fpsum_t   = uint4;
    using packed_f32psum_t = typename GEMM::packed_f32psum_t;

    using qk_warp = std::array<packed_f32psum_t, WARP_M_TILES * WARP_K_TILES_QK>;
    // using o_warp = std::array<packed_f32psum_t, WARP_M_TILES * WARP_N_TILES>;
    using o_warp = typename GEMM::f32psum_warp;

    using rowval_warp = std::array<float2, WARP_M_TILES>;

    struct BlockInfo {
        int bm;    // M: Q tokens, bm: block id of M
        int head;  // H: head
        int batch; // B: batch
        int numBlocksM;
        int numHeads;
        int numBatch;
    };

    __device__ __forceinline__ static packed_fpsum_t packed_fp32_to_fp16(packed_f32psum_t input) {
        std::array<half2_t, 4> results;
        for (int i = 0; i < 4; i++) {
            results[i] = float22half2<half2_t>(make_float2(input.data[i * 2], input.data[i * 2 + 1]));
        }
        return kernels::bit_cast<packed_fpsum_t>(results);
    }

    __device__ __forceinline__ static packed_f32psum_t packed_fp16_to_fp32(packed_fpsum_t input) {
        auto arr = kernels::bit_cast<std::array<half2_t, 4>>(input);
        packed_f32psum_t results;
        for (int i = 0; i < 4; i++) {
            float2 tmp              = half22float2(arr[i]);
            results.data[i * 2]     = tmp.x;
            results.data[i * 2 + 1] = tmp.y;
        }
        return results;
    }

    // q: [batch, head, bm, NUM_WARPS, WARP_M_TILES, WARP_D_TILES, WARP_SIZE] of packed_q_t
    __device__ __forceinline__ static void load_q(const packed_q_t *ptr, q_warp &out, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        const packed_q_t *base = &ptr[((warpId * WARP_M_TILES + 0) * WARP_D_TILES + 0) * WARP_SIZE + laneId];

        unrolled_loop<WARP_M_TILES>([&]<int m>() {
            unrolled_loop<WARP_D_TILES>([&]<int d>() {
                out[m * WARP_D_TILES + d] = load_pred(&base[(m * WARP_D_TILES + d) * WARP_SIZE], pred);
            });
        });
    }
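
    // With the default config, load_q pulls WARP_M_TILES * WARP_D_TILES = 1 * 8
    // uint4 fragments per lane, i.e. 8 * 16 B * 32 lanes = 4 KB of Q per warp.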

    // k: [batch, head, ktile, WARP_K_TILES_QK, WARP_D_TILES, WARP_SIZE] of packed_k_t
    __device__ __forceinline__ static void load_k(const packed_k_t *ptr, int ktile, k_warp &out, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        const packed_k_t *base = &ptr[((ktile * WARP_K_TILES_QK + 0) * WARP_D_TILES + 0) * WARP_SIZE + laneId];

        unrolled_loop<WARP_K_TILES_QK>([&]<int k>() {
            unrolled_loop<WARP_D_TILES>([&]<int d>() {
                out[k * WARP_D_TILES + d] = load_pred(&base[(k * WARP_D_TILES + d) * WARP_SIZE], pred);
            });
        });
    }

    // v: [batch, head, ktile, WARP_K_TILES_PV, WARP_N_TILES, WARP_SIZE] of packed_v_t
    __device__ __forceinline__ static void load_v(const packed_v_t *ptr, int ktile, v_warp &out, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        const packed_v_t *base = &ptr[((ktile * WARP_K_TILES_PV + 0) * WARP_N_TILES + 0) * WARP_SIZE + laneId];
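
        // Registers are filled n-major (out[n * WARP_K_TILES_PV + k]) although
        // memory is k-major; the transposed layout matches the operand order
        // compute_pv consumes.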

        unrolled_loop<WARP_K_TILES_PV>([&]<int k>() {
            unrolled_loop<WARP_N_TILES>([&]<int n>() {
                out[n * WARP_K_TILES_PV + k] = load_pred(&base[(k * WARP_N_TILES + n) * WARP_SIZE], pred);
            });
        });
    }

    __device__ __forceinline__ static packed_fpsum_t
    mma_f16xf16_f16(packed_fpsum_t a, packed_fpsum_t b, packed_fpsum_t psum) {
        uint2 out1 = mma_m16n8k16_f16f16f16f16(a, make_uint2(b.x, b.y), make_uint2(psum.x, psum.y));
        uint2 out2 = mma_m16n8k16_f16f16f16f16(a, make_uint2(b.z, b.w), make_uint2(psum.z, psum.w));
        return packed_fpsum_t{out1.x, out1.y, out2.x, out2.y};
    }
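
    // The two m16n8k16 HMMA calls above tile one logical 16x16x16 MMA: b.{x,y}
    // carry the left 16x8 slice of B and b.{z,w} the right one.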

    // set nan values to -inf
    __device__ __forceinline__ static half2_t fix_nan(half2_t input) {
        static constexpr float neginf = -std::numeric_limits<float>::infinity();
        /**
         * In accordance with the IEEE-754R standard,
         * if one of the input parameters to fminf(), fmin(), fmaxf(), or fmax() is NaN,
         * but not the other,
         * the result is the non-NaN parameter.
         */
        return __hmax2(input, half2_t(neginf, neginf));
    }

    __device__ __forceinline__ static float fix_nan(float input) {
        static constexpr float neginf = -std::numeric_limits<float>::infinity();
        return fmaxf(input, neginf);
    }

    __device__ __forceinline__ static packed_fpsum_t fix_nan(packed_fpsum_t input) {
        input.x = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.x)));
        input.y = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.y)));
        input.z = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.z)));
        input.w = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.w)));
        return input;
    }

    __device__ __forceinline__ static packed_f32psum_t fix_nan(packed_f32psum_t input) {
#pragma unroll
        for (int i = 0; i < 8; i++) {
            input.data[i] = fix_nan(input.data[i]);
        }
        return input;
    }

    __device__ __forceinline__ static qk_warp compute_qk(q_warp Q, k_warp K) {
        qk_warp QK;
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {

#if 0
#pragma unroll
                for (int d = 0; d < WARP_D_TILES; d++) {
                    packed_fpsum_t psum = make_uint4(0, 0, 0, 0);
                    psum = mma_f16xf16_f16(Q[m * WARP_D_TILES + d], K[k * WARP_D_TILES + d], psum);
                    auto f32psum = packed_fp16_to_fp32(psum);
#pragma unroll
                    for (int i = 0; i < 8; i++) {
                        QK[m * WARP_K_TILES_QK + k].data[i] += f32psum.data[i];
                    }
                }

#else
                packed_fpsum_t psum = make_uint4(0, 0, 0, 0);
#pragma unroll
                for (int d = 0; d < WARP_D_TILES; d++) {
                    psum = mma_f16xf16_f16(Q[m * WARP_D_TILES + d], K[k * WARP_D_TILES + d], psum);
                }

                if constexpr (IS_SM80) {
                    psum                        = fix_nan(psum);
                    QK[m * WARP_K_TILES_QK + k] = packed_fp16_to_fp32(psum);
                } else {
                    QK[m * WARP_K_TILES_QK + k] = fix_nan(packed_fp16_to_fp32(psum));
                }
#endif
            }
        }
        return QK;
    }

    __device__ __forceinline__ static rowval_warp compute_rowmax(qk_warp QK, rowval_warp rowmax, float scale) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 maxv;
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
                packed_f32psum_t &val = QK[m * WARP_K_TILES_QK + k];
                float x               = fmaxf(fmaxf(val.data[0], val.data[1]), fmaxf(val.data[4], val.data[5]));
                float y               = fmaxf(fmaxf(val.data[2], val.data[3]), fmaxf(val.data[6], val.data[7]));
                if (k == 0) {
                    maxv = make_float2(x, y);
                } else {
                    maxv.x = fmaxf(maxv.x, x);
                    maxv.y = fmaxf(maxv.y, y);
                }
            }
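            // Each accumulator row is spread across a quad of four lanes (two
            // columns per lane per 16x8 half), so XOR shuffles with masks 1 and
            // 2 complete the per-row max.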
#pragma unroll
            for (int mask = 1; mask <= 2; mask *= 2) {
                maxv.x = fmaxf(maxv.x, __shfl_xor_sync(~0, maxv.x, mask));
                maxv.y = fmaxf(maxv.y, __shfl_xor_sync(~0, maxv.y, mask));
            }
            rowmax[m].x = fmaxf(rowmax[m].x, maxv.x * scale);
            rowmax[m].y = fmaxf(rowmax[m].y, maxv.y * scale);
        }
        return rowmax;
    }

    __device__ __forceinline__ static qk_warp softmax(qk_warp QK, rowval_warp rowmax_scaled, float scale) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 shift = rowmax_scaled[m];
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
                packed_f32psum_t &val = QK[m * WARP_K_TILES_QK + k];
                val.data[0]           = cuda_exp2(fmaf(val.data[0], scale, -shift.x));
                val.data[1]           = cuda_exp2(fmaf(val.data[1], scale, -shift.x));
                val.data[4]           = cuda_exp2(fmaf(val.data[4], scale, -shift.x));
                val.data[5]           = cuda_exp2(fmaf(val.data[5], scale, -shift.x));
                val.data[2]           = cuda_exp2(fmaf(val.data[2], scale, -shift.y));
                val.data[3]           = cuda_exp2(fmaf(val.data[3], scale, -shift.y));
                val.data[6]           = cuda_exp2(fmaf(val.data[6], scale, -shift.y));
                val.data[7]           = cuda_exp2(fmaf(val.data[7], scale, -shift.y));
            }
        }
        return QK;
    }

    __device__ __forceinline__ static rowval_warp compute_rowsum(qk_warp QK) {
        rowval_warp rowsum;
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 sumv = make_float2(0.0f, 0.0f);
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
                packed_f32psum_t &val = QK[m * WARP_K_TILES_QK + k];
                sumv.x += val.data[0] + val.data[1] + val.data[4] + val.data[5];
                sumv.y += val.data[2] + val.data[3] + val.data[6] + val.data[7];
            }
#pragma unroll
            for (int mask = 1; mask <= 2; mask *= 2) {
                sumv.x += __shfl_xor_sync(~0, sumv.x, mask);
                sumv.y += __shfl_xor_sync(~0, sumv.y, mask);
            }
            rowsum[m] = sumv;
        }
        return rowsum;
    }

    __device__ __forceinline__ static rowval_warp compute_rescale(rowval_warp rowmax0, rowval_warp rowmax1) {
        rowval_warp rescale;
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            rescale[m].x = cuda_exp2(rowmax0[m].x - rowmax1[m].x);
            rescale[m].y = cuda_exp2(rowmax0[m].y - rowmax1[m].y);
        }
        return rescale;
    }

    __device__ __forceinline__ static o_warp compute_pv(p_warp P, v_warp V, o_warp O, rowval_warp rescale) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
            for (int n = 0; n < WARP_N_TILES; n++) {
                packed_fpsum_t psum = make_uint4(0, 0, 0, 0);
#pragma unroll
                for (int k = 0; k < WARP_K_TILES_PV; k++) {
                    psum = mma_f16xf16_f16(P[m * WARP_K_TILES_PV + k], V[n * WARP_K_TILES_PV + k], psum);
                }

                packed_f32psum_t pv    = packed_fp16_to_fp32(psum);
                packed_f32psum_t &oval = O[m * WARP_N_TILES + n];
                oval.data[0]           = oval.data[0] * rescale[m].x + pv.data[0];
                oval.data[1]           = oval.data[1] * rescale[m].x + pv.data[1];
                oval.data[4]           = oval.data[4] * rescale[m].x + pv.data[4];
                oval.data[5]           = oval.data[5] * rescale[m].x + pv.data[5];
                oval.data[2]           = oval.data[2] * rescale[m].y + pv.data[2];
                oval.data[3]           = oval.data[3] * rescale[m].y + pv.data[3];
                oval.data[6]           = oval.data[6] * rescale[m].y + pv.data[6];
                oval.data[7]           = oval.data[7] * rescale[m].y + pv.data[7];
            }
        }
        return O;
    }

    __device__ __forceinline__ static rowval_warp compute_l(rowval_warp L, rowval_warp rescale, rowval_warp rowsum) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            L[m].x = fmaf(L[m].x, rescale[m].x, rowsum[m].x);
            L[m].y = fmaf(L[m].y, rescale[m].y, rowsum[m].y);
        }
        return L;
    }

    __device__ __forceinline__ static p_warp qk_to_p(qk_warp QK) {
        static_assert(WARP_K_TILES_QK == WARP_K_TILES_PV);
        p_warp P;
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_PV; k++) {
                P[m * WARP_K_TILES_PV + k] = packed_fp32_to_fp16(QK[m * WARP_K_TILES_QK + k]);
            }
        }
        return P;
    }

    // __device__ __forceinline__
    // static void compute(q_warp Q, k_warp K, v_warp V, o_warp &O, rowval_warp &M, rowval_warp &L, float scale) {
    //     qk_warp qk = compute_qk(Q, K);
    //     rowval_warp M1 = compute_rowmax(qk, M, scale);
    //     qk = softmax(qk, M1, scale);
    //     rowval_warp rowsum = compute_rowsum(qk);
    //     p_warp P = qk_to_p(qk);
    //     rowval_warp rescale = compute_rescale(M, M1);
    //     M = M1;
    //     L = compute_l(L, rescale, rowsum);
    //     O = compute_pv(P, V, O, rescale);
    // }

    __device__ __forceinline__ static std::tuple<p_warp, rowval_warp>
    compute(q_warp Q, k_warp K, rowval_warp &M, rowval_warp &L, float scale) {
        qk_warp qk          = compute_qk(Q, K);
        rowval_warp M1      = compute_rowmax(qk, M, scale);
        qk                  = softmax(qk, M1, scale);
        rowval_warp rowsum  = compute_rowsum(qk);
        p_warp P            = qk_to_p(qk);
        rowval_warp rescale = compute_rescale(M, M1);
        M                   = M1;
        L                   = compute_l(L, rescale, rowsum);
        return {P, rescale};
    }
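
    // One online-softmax step in math form (base-2 exponentials; callers are
    // expected to fold the log2(e) factor into `scale`):
    //   S     = scale * Q * K^T                            (compute_qk)
    //   M_new = max(M_old, rowmax(S))                      (compute_rowmax)
    //   P     = exp2(S - M_new)                            (softmax)
    //   L_new = exp2(M_old - M_new) * L_old + rowsum(P)    (compute_l)
    //   O_new = exp2(M_old - M_new) * O_old + P * V        (compute_pv, by the caller)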

    __device__ __forceinline__ static o_warp compute_o(o_warp O, rowval_warp L) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 inv;
            inv.x = cuda_frcp(L[m].x);
            inv.y = cuda_frcp(L[m].y);
#pragma unroll
            for (int n = 0; n < WARP_N_TILES; n++) {
                packed_f32psum_t &oval = O[m * WARP_N_TILES + n];
                oval.data[0]           = oval.data[0] * inv.x;
                oval.data[1]           = oval.data[1] * inv.x;
                oval.data[4]           = oval.data[4] * inv.x;
                oval.data[5]           = oval.data[5] * inv.x;
                oval.data[2]           = oval.data[2] * inv.y;
                oval.data[3]           = oval.data[3] * inv.y;
                oval.data[6]           = oval.data[6] * inv.y;
                oval.data[7]           = oval.data[7] * inv.y;
            }
        }
        return O;
    }
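
    // compute_o is the final flash-attention normalization: each output row is
    // divided by its softmax denominator L using one fast reciprocal per row.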

#if 0
    template<typename Epilogue>
    __device__ __forceinline__
    static void attention_fp16_block(
        const BlockInfo binfo,
        const packed_q_t *ptr_q,
        const packed_k_t *ptr_k,
        const packed_v_t *ptr_v,
        float scale,
        int ntokens_q,
        int ntokens_kv,
        Epilogue::Arguments epilogueArgs,
        bool alwaysfalse)
    {
        constexpr int NUM_STAGES = 2;

        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        q_warp Q;   // 32
        k_warp K[NUM_STAGES];  // 32 * 2
        v_warp V[NUM_STAGES];  // 32 * 2
        o_warp O;   // 64
        rowval_warp L;  // 2
        rowval_warp M;  // 2

        load_q(ptr_q, Q, true);
        for (int k = 0; k < NUM_STAGES - 1; k++) {
            load_k(ptr_k, k, K[k], true);
            load_v(ptr_v, k, V[k], true);
        }

#pragma unroll
        for (auto &pack : O) {
#pragma unroll
            for (int i = 0; i < 8; i++) {
                pack.data[i] = 0;
            }
        }

        static constexpr float neginf = -std::numeric_limits<float>::infinity();
        L.fill(make_float2(0.0f, 0.0f));
        M.fill(make_float2(neginf, neginf));

        __shared__ q_warp Q_shmem[NUM_WARPS];
#pragma unroll
        for (int i = 0; i < Q.size(); i++) {
            store<true>(&Q_shmem[warpId][i], Q[i]);
        }

        int dummy = 0;

        // TODO: mask tokens in last block
        for (int k1 = 0; k1 < ntokens_kv / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
            for (int k2 = 0; k2 < NUM_STAGES; k2++) {
#pragma unroll
                for (int i = 0; i < Q.size(); i++) {
                    Q[i] = load<true>(&Q_shmem[warpId][i]);
                }

                int nextk = k1 + k2 + NUM_STAGES - 1;
                int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
                bool pred = nextk < ntokens_kv / WARP_K;
                load_k(ptr_k, nextk, K[idx], pred);
                load_v(ptr_v, nextk, V[idx], pred);

                // __syncthreads();
                // if (alwaysfalse) {
                //     dummy = clock();
                // }
                auto [P, rescale] = compute(Q, K[k2], M, L, scale);
                O = compute_pv(P, V[k2], O, rescale);

                if (alwaysfalse) {
                    dummy = clock();
                }
                // asm volatile ("membar.cta;");
            }
        }

        unused_var(dummy, alwaysfalse);

        O = compute_o(O, L);

        auto f16psum = GEMM::packed_fp32_to_fp16(O);

        Epilogue()(typename GEMM::BlockInfo{
            .bm = binfo.batch * binfo.numBlocksM + binfo.bm,
            .bn = binfo.head,
            .numBlocksM = binfo.numBatch * binfo.numBlocksM,
            .numBlocksN = binfo.numHeads,
        }, f16psum, binfo.numBatch * binfo.numBlocksM * BLOCK_M, binfo.numHeads * HEAD_DIM, 0, epilogueArgs);
    }
#else
    template<typename Epilogue>
    __device__ __forceinline__ static void attention_fp16_block(const BlockInfo binfo,
                                                                const packed_q_t *ptr_q,
                                                                const packed_k_t *ptr_k,
                                                                const packed_v_t *ptr_v,
                                                                float scale,
                                                                int ntokens_q,
                                                                int ntokens_kv,
                                                                Epilogue::Arguments epilogueArgs,
                                                                bool alwaysfalse) {
        // constexpr int NUM_STAGES = 2;

        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        q_warp Q;      // 32
        k_warp K;      // 64
        v_warp V;      // 64
        o_warp O;      // 64
        rowval_warp L; // 2
        rowval_warp M; // 2

        load_q(ptr_q, Q, true);
        load_k(ptr_k, 0, K, true);

#pragma unroll
        for (auto &pack : O) {
#pragma unroll
            for (int i = 0; i < 8; i++) {
                pack.data[i] = 0;
            }
        }

        static constexpr float neginf =
            -std::numeric_limits<float>::max(); // not real inf, to prevent nan during computation
        L.fill(make_float2(0.0f, 0.0f));
        M.fill(make_float2(neginf, neginf));

        static constexpr int SHMEM_TILES = IS_SM80 ? 4 : 7;
        static_assert(SHMEM_TILES <= Q.size());
        using q_shmem_t = packed_q_t[NUM_WARPS][SHMEM_TILES][WARP_SIZE];
        __shared__ q_shmem_t Q_shmem;
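        // The tail of Q is deliberately staged through shared memory to cap
        // register pressure; those tiles are reloaded at the top of every k1
        // iteration below.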

#pragma unroll
        for (int i = 0; i < SHMEM_TILES; i++) {
            store<true>(&Q_shmem[warpId][i][laneId], Q[Q.size() - 1 - i]);
        }
        __syncwarp();

        int dummy = 0;
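
        // `alwaysfalse` is false at runtime; the guarded clock() reads and
        // pointer bumps below act only as compiler scheduling barriers that keep
        // the loads and the math interleaved as written.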

        // TODO: mask tokens in last block
        for (int k1 = 0; k1 < ntokens_kv / WARP_K; k1++) {
            if (alwaysfalse) {
                ptr_v += K[0].x;
            }

#pragma unroll
            for (int i = 0; i < SHMEM_TILES; i++) {
                Q[Q.size() - 1 - i] = load<true>(&Q_shmem[warpId][i][laneId]);
            }

            if constexpr (!IS_SM80) {
                if (k1 % 2 == 1) {
                    __syncthreads();
                }
            }

            if (alwaysfalse) {
                dummy = clock();
            }

            load_v(ptr_v, k1, V, true);

            if (alwaysfalse) {
                dummy = clock();
            }

            auto [P, rescale] = compute(Q, K, M, L, scale);

            if (alwaysfalse) {
                dummy = clock();
            }

            if (alwaysfalse) {
                ptr_k += V[0].x;
            }

            // if (alwaysfalse) {
            //     dummy = clock();
            // }

            load_k(ptr_k, k1 + 1, K, k1 + 1 < ntokens_kv / WARP_K);

            // if (alwaysfalse) {
            //     dummy = clock();
            // }

            O = compute_pv(P, V, O, rescale);

            if (alwaysfalse) {
                dummy = clock();
            }
        }

        unused_var(dummy, alwaysfalse);

        O = compute_o(O, L);

        auto f16psum = GEMM::packed_fp32_to_fp16(O);

        Epilogue()(
            typename GEMM::BlockInfo{
                .bm         = binfo.batch * binfo.numBlocksM + binfo.bm,
                .bn         = binfo.head,
                .numBlocksM = binfo.numBatch * binfo.numBlocksM,
                .numBlocksN = binfo.numHeads,
            },
            f16psum,
            binfo.numBatch * binfo.numBlocksM * BLOCK_M,
            binfo.numHeads * HEAD_DIM,
            0,
            epilogueArgs);
    }
#endif

    template<typename Epilogue>
    struct attention_fp16_kernel {
        static constexpr int MIN_ARCH   = std::is_same_v<epilogue_half_t, __nv_bfloat16> ? 800 : 750;
        static constexpr int SHMEM_SIZE = 0; // sizeof(q_shmem_t);

        __device__ void operator()(const packed_q_t *ptr_q,
                                   const packed_k_t *ptr_k,
                                   const packed_v_t *ptr_v,
                                   float scale,
                                   int ntokens_q,
                                   int ntokens_kv,
                                   Epilogue::Arguments epilogueArgs,
                                   bool alwaysfalse) {
            BlockInfo binfo = {
                .bm         = (int)blockIdx.x,
                .head       = (int)blockIdx.y,
                .batch      = (int)blockIdx.z,
                .numBlocksM = (int)gridDim.x,
                .numHeads   = (int)gridDim.y,
                .numBatch   = (int)gridDim.z,
            };

            // extern __shared__ uint8_t shmem[];
            // q_shmem_t *Q_shmem = reinterpret_cast<q_shmem_t *>(shmem);

            const int ktiles = ceilDiv(ntokens_kv, WARP_K);

            attention_fp16_block<Epilogue>(
                binfo,
                ptr_q + ((binfo.batch * binfo.numHeads + binfo.head) * binfo.numBlocksM + binfo.bm) * NUM_WARPS *
                            WARP_M_TILES * WARP_D_TILES * WARP_SIZE,
                ptr_k +
                    (binfo.batch * binfo.numHeads + binfo.head) * ktiles * WARP_K_TILES_QK * WARP_D_TILES * WARP_SIZE,
                ptr_v +
                    (binfo.batch * binfo.numHeads + binfo.head) * ktiles * WARP_K_TILES_PV * WARP_N_TILES * WARP_SIZE,
                scale,
                ntokens_q,
                ntokens_kv,
                // *Q_shmem,
                epilogueArgs,
                alwaysfalse);
        }
    };
};
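
// A minimal launch sketch (illustrative only: `MyEpilogue` and the kernel
// wrapper `launch_kernel` are hypothetical stand-ins for the dispatch machinery
// that lives elsewhere in this repo):
//
//   using Attn   = Attention<AttentionFP16Config_FP16>;
//   using Kernel = Attn::attention_fp16_kernel<MyEpilogue>;
//   dim3 grid(ceilDiv(ntokens_q, Attn::BLOCK_M), numHeads, batchSize);
//   dim3 block(Attn::NUM_WARPS * Attn::WARP_SIZE); // 256 threads per block
//   launch_kernel<Kernel>(grid, block, ptr_q, ptr_k, ptr_v, scale,
//                         ntokens_q, ntokens_kv, epilogueArgs, /*alwaysfalse=*/false);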

}; // namespace nunchaku::kernels