attention.cuh 25.9 KB
Newer Older
fengzch-das's avatar
fengzch-das committed
1
#include "hip/hip_runtime.h"
2
3
4
5
6
7
#pragma once

#include "gemm_base.cuh"

namespace nunchaku::kernels {

Muyang Li's avatar
Muyang Li committed
8
// M: Q tokens
9
10
11
12
13
14
15
// N: V HEAD_DIM
// K: K tokens
// D: QK HEAD_DIM
template<bool bf16out>
struct AttentionFP16Config {
    static constexpr int HEAD_DIM = 128;

Muyang Li's avatar
Muyang Li committed
16
    static constexpr int BLOCK_M   = 128;
17
18
19
20
21
    static constexpr int WARP_SIZE = 32;
    static constexpr int NUM_WARPS = 8;

    static constexpr int WARP_K = 32;

Muyang Li's avatar
Muyang Li committed
22
23
    static constexpr int INSN_M    = 16;
    static constexpr int INSN_N    = 16;
24
25
26
27
28
29
    static constexpr int INSN_K_QK = 16;
    static constexpr int INSN_K_PV = 16;

    using half_t  = half;
    using half2_t = half2;

fengzch-das's avatar
fengzch-das committed
30
31
    using epilogue_half_t  = typename std::conditional_t<bf16out, __hip_bfloat16, half>;
    using epilogue_half2_t = typename std::conditional_t<bf16out, __hip_bfloat162, half2>;
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
};

// Named instantiations of the config: fp16 epilogue (bf16out = false) and bf16 epilogue (bf16out = true).
using AttentionFP16Config_FP16 = AttentionFP16Config<false>;
using AttentionFP16Config_BF16 = AttentionFP16Config<true>;

// Declaration of the primary template; its definition follows immediately.
template<typename AttentionConfig>
class Attention;

// Regular builds define the class as a template over AttentionConfig.
// When __INTELLISENSE__ is defined, the same body is instead presented as an explicit
// specialization for the BF16 config, with AttentionConfig introduced as a plain alias
// (presumably so an IDE parser sees concrete, non-dependent types).
#ifndef __INTELLISENSE__
template<typename AttentionConfig>
class Attention : public AttentionConfig {
#else
template<>
class Attention<AttentionFP16Config_BF16> : public AttentionFP16Config_BF16 {
    using AttentionConfig = AttentionFP16Config_BF16;
#endif

public:
    // Bring the configuration's constants and types into this class's scope so they can be
    // used unqualified; AttentionConfig is a dependent base in the template branch, so its
    // members are not found by unqualified lookup without these declarations.
    using AttentionConfig::HEAD_DIM;
    using AttentionConfig::BLOCK_M;
    using AttentionConfig::WARP_SIZE;
    using AttentionConfig::NUM_WARPS;
    using AttentionConfig::WARP_K;
    using AttentionConfig::INSN_M;
    using AttentionConfig::INSN_N;
    using AttentionConfig::INSN_K_QK;
    using AttentionConfig::INSN_K_PV;
    using typename AttentionConfig::half_t;
    using typename AttentionConfig::half2_t;
    using typename AttentionConfig::epilogue_half_t;
    using typename AttentionConfig::epilogue_half2_t;

fengzch-das's avatar
fengzch-das committed
64
#if defined(__DTK_ARCH__)
    // True when the current compilation pass reports an architecture value of 800 or higher.
    static constexpr bool IS_SM80 = (__DTK_ARCH__ >= 800);
#else
    // __DTK_ARCH__ is not defined in this compilation pass.
    static constexpr bool IS_SM80 = false;
#endif

70
    /**
     * Configuration passed to GEMMBase (see the GEMM alias of this class).
     * It mirrors the attention configuration, with BLOCK_N set to HEAD_DIM, INSN_K set to
     * INSN_K_PV, and the element types set to the epilogue types.
     */
    struct GEMMConfig {
        // Element types: the epilogue types of the attention configuration.
        using half_t  = typename AttentionConfig::epilogue_half_t;
        using half2_t = typename AttentionConfig::epilogue_half2_t;

        // Block tile: BLOCK_M rows by HEAD_DIM columns.
        static constexpr int BLOCK_M = AttentionConfig::BLOCK_M;
        static constexpr int BLOCK_N = AttentionConfig::HEAD_DIM;

        // Thread organisation.
        static constexpr int WARP_SIZE = AttentionConfig::WARP_SIZE;
        static constexpr int NUM_WARPS = AttentionConfig::NUM_WARPS;

        // Instruction tile; K follows the P*V product.
        static constexpr int INSN_M = AttentionConfig::INSN_M;
        static constexpr int INSN_N = AttentionConfig::INSN_N;
        static constexpr int INSN_K = AttentionConfig::INSN_K_PV;
    };

    using GEMM = typename nunchaku::kernels::GEMMBase<GEMMConfig>;

    static constexpr int WARP_M = BLOCK_M / NUM_WARPS;
    static constexpr int WARP_N = HEAD_DIM;
    static constexpr int WARP_D = HEAD_DIM;

Muyang Li's avatar
Muyang Li committed
90
91
92
93
    static constexpr int WARP_M_TILES = WARP_M / INSN_M;
    static constexpr int WARP_N_TILES = WARP_N / INSN_N;
    static constexpr int WARP_K_TILES_QK =
        WARP_K / INSN_N; // when multiplying Q*K, K is on dimension of N in MMA instruction
94
95
96
97
98
99
    static constexpr int WARP_K_TILES_PV = WARP_K / INSN_K_PV;
    static constexpr int WARP_D_TILES    = WARP_D / INSN_K_QK;

    using packed_q_t = uint4;
    using packed_k_t = uint4;
    using packed_v_t = uint4;
Muyang Li's avatar
Muyang Li committed
100
101
102
    using q_warp     = std::array<packed_q_t, WARP_M_TILES * WARP_D_TILES>;
    using k_warp     = std::array<packed_k_t, WARP_K_TILES_QK * WARP_D_TILES>;
    using v_warp     = std::array<packed_v_t, WARP_K_TILES_PV * WARP_N_TILES>;
103
104

    using packed_p_t = uint4;
Muyang Li's avatar
Muyang Li committed
105
    using p_warp     = std::array<packed_v_t, WARP_M_TILES * WARP_K_TILES_PV>;
106

Muyang Li's avatar
Muyang Li committed
107
    using packed_fpsum_t   = uint4;
108
109
110
111
112
113
114
115
116
    using packed_f32psum_t = typename GEMM::packed_f32psum_t;

    using qk_warp = std::array<packed_f32psum_t, WARP_M_TILES * WARP_K_TILES_QK>;
    // using o_warp = std::array<packed_f32psum_t, WARP_M_TILES * WARP_N_TILES>;
    using o_warp = typename GEMM::f32psum_warp;

    using rowval_warp = std::array<float2, WARP_M_TILES>;

    /**
     * Position of the current block and the extents of the launch.
     * Field order matters: the struct is filled with designated initializers.
     */
    struct BlockInfo {
        // --- indices of this block ---
        int bm;    // M: Q tokens, bm: block id of M
        int head;  // H: head
        int batch; // B: batch

        // --- extents along the same three axes ---
        int numBlocksM;
        int numHeads;
        int numBatch;
    };

Muyang Li's avatar
Muyang Li committed
125
    /**
     * Narrows the first eight fp32 values of `input` to 16-bit.
     * Consecutive elements (2j, 2j+1) are converted together into one half2_t, and the four
     * resulting pairs are reinterpreted as a packed_fpsum_t.
     */
    __device__ __forceinline__ static packed_fpsum_t packed_fp32_to_fp16(packed_f32psum_t input) {
        std::array<half2_t, 4> pairs;
        for (int j = 0; j < 4; ++j) {
            const float first  = input.data[2 * j];
            const float second = input.data[2 * j + 1];
            pairs[j]           = float22half2<half2_t>(float2(first, second));
        }
        return kernels::bit_cast<packed_fpsum_t>(pairs);
    }

Muyang Li's avatar
Muyang Li committed
133
    /**
     * Widens a packed_fpsum_t (viewed as four half2_t pairs) to eight fp32 values.
     * Pair j is written to elements 2j (its .x) and 2j+1 (its .y) of the result.
     */
    __device__ __forceinline__ static packed_f32psum_t packed_fp16_to_fp32(packed_fpsum_t input) {
        auto pairs = kernels::bit_cast<std::array<half2_t, 4>>(input);
        packed_f32psum_t widened;
        for (int j = 0; j < 4; ++j) {
            const float2 f          = half22float2(pairs[j]);
            widened.data[2 * j]     = f.x;
            widened.data[2 * j + 1] = f.y;
        }
        return widened;
    }

    // q: [batch, head, bm, NUM_WARPS, WARP_M_TILES, WARP_D_TILES, WARP_SIZE] of packed_q_t
Muyang Li's avatar
Muyang Li committed
145
    __device__ __forceinline__ static void load_q(const packed_q_t *ptr, q_warp &out, bool pred) {
146
147
148
149
150
151
152
153
154
155
156
157
158
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        const packed_q_t *base = &ptr[((warpId * WARP_M_TILES + 0) * WARP_D_TILES + 0) * WARP_SIZE + laneId];

        unrolled_loop<WARP_M_TILES>([&]<int m>() {
            unrolled_loop<WARP_D_TILES>([&]<int d>() {
                out[m * WARP_D_TILES + d] = load_pred(&base[(m * WARP_D_TILES + d) * WARP_SIZE], pred);
            });
        });
    }

    // k: [batch, head, ktile, WARP_K_TILES_QK, WARP_D_TILES, WARP_SIZE] of packed_k_t
Muyang Li's avatar
Muyang Li committed
159
    __device__ __forceinline__ static void load_k(const packed_k_t *ptr, int ktile, k_warp &out, bool pred) {
160
161
162
163
164
165
166
167
168
169
170
171
172
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        const packed_k_t *base = &ptr[((ktile * WARP_K_TILES_QK + 0) * WARP_D_TILES + 0) * WARP_SIZE + laneId];

        unrolled_loop<WARP_K_TILES_QK>([&]<int k>() {
            unrolled_loop<WARP_D_TILES>([&]<int d>() {
                out[k * WARP_D_TILES + d] = load_pred(&base[(k * WARP_D_TILES + d) * WARP_SIZE], pred);
            });
        });
    }

    // v: [batch, head, ktile, WARP_K_TILES_PV, WARP_N_TILES, WARP_SIZE] of packed_v_t
Muyang Li's avatar
Muyang Li committed
173
    __device__ __forceinline__ static void load_v(const packed_v_t *ptr, int ktile, v_warp &out, bool pred) {
174
175
176
177
178
179
180
181
182
183
184
185
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        const packed_v_t *base = &ptr[((ktile * WARP_K_TILES_PV + 0) * WARP_N_TILES + 0) * WARP_SIZE + laneId];

        unrolled_loop<WARP_K_TILES_PV>([&]<int k>() {
            unrolled_loop<WARP_N_TILES>([&]<int n>() {
                out[n * WARP_K_TILES_PV + k] = load_pred(&base[(k * WARP_N_TILES + n) * WARP_SIZE], pred);
            });
        });
    }

Muyang Li's avatar
Muyang Li committed
186
187
    /**
     * Multiply-accumulate of one instruction tile with a 16-bit packed accumulator.
     * Implemented as two mma_m16n8k16_f16f16f16f16 calls that share `a`: the first uses the
     * (x, y) words of `b` and `psum`, the second the (z, w) words. The two results are
     * concatenated in that order.
     */
    __device__ __forceinline__ static packed_fpsum_t
    mma_f16xf16_f16(packed_fpsum_t a, packed_fpsum_t b, packed_fpsum_t psum) {
        const uint2 first  = mma_m16n8k16_f16f16f16f16(a, uint2(b.x, b.y), uint2(psum.x, psum.y));
        const uint2 second = mma_m16n8k16_f16f16f16f16(a, uint2(b.z, b.w), uint2(psum.z, psum.w));
        return packed_fpsum_t{first.x, first.y, second.x, second.y};
    }

    // set nan values to -inf
Muyang Li's avatar
Muyang Li committed
194
    __device__ __forceinline__ static half2_t fix_nan(half2_t input) {
195
196
        static constexpr float neginf = -std::numeric_limits<float>::infinity();
        /**
Muyang Li's avatar
Muyang Li committed
197
198
199
         * In accordance to the IEEE-754R standard,
         * if one of the input parameters to fminf(), fmin(), fmaxf(), or fmax() is NaN,
         * but not the other,
200
201
202
203
204
         * the result is the non-NaN parameter.
         */
        return __hmax2(input, half2_t(neginf, neginf));
    }

Muyang Li's avatar
Muyang Li committed
205
    __device__ __forceinline__ static float fix_nan(float input) {
206
207
208
209
        static constexpr float neginf = -std::numeric_limits<float>::infinity();
        return fmaxf(input, neginf);
    }

Muyang Li's avatar
Muyang Li committed
210
    __device__ __forceinline__ static packed_fpsum_t fix_nan(packed_fpsum_t input) {
sxtyzhangzk's avatar
sxtyzhangzk committed
211
212
213
214
        input.x = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.x)));
        input.y = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.y)));
        input.z = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.z)));
        input.w = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.w)));
215
216
217
        return input;
    }

Muyang Li's avatar
Muyang Li committed
218
219
    __device__ __forceinline__ static packed_f32psum_t fix_nan(packed_f32psum_t input) {
#pragma unroll
220
221
222
223
224
225
        for (int i = 0; i < 8; i++) {
            input.data[i] = fix_nan(input.data[i]);
        }
        return input;
    }

Muyang Li's avatar
Muyang Li committed
226
    /**
     * Computes the Q*K products of one warp tile.
     *
     * For every (m, k) tile the product is accumulated over all WARP_D_TILES in a 16-bit
     * packed accumulator, then NaN-fixed and widened to fp32. The NaN fix is applied before
     * the widening when IS_SM80 is true and after it otherwise.
     *
     * @return one fp32 accumulator per (m, k) tile, indexed [m * WARP_K_TILES_QK + k]
     */
    __device__ __forceinline__ static qk_warp compute_qk(q_warp Q, k_warp K) {
        qk_warp QK;
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
                packed_fpsum_t acc = make_uint4(0, 0, 0, 0);
#pragma unroll
                for (int d = 0; d < WARP_D_TILES; d++) {
                    acc = mma_f16xf16_f16(Q[m * WARP_D_TILES + d], K[k * WARP_D_TILES + d], acc);
                }

                packed_f32psum_t &dst = QK[m * WARP_K_TILES_QK + k];
                if constexpr (IS_SM80) {
                    dst = packed_fp16_to_fp32(fix_nan(acc));
                } else {
                    dst = fix_nan(packed_fp16_to_fp32(acc));
                }
            }
        }
        return QK;
    }

Muyang Li's avatar
Muyang Li committed
264
265
    /**
     * Updates the running, scaled row maxima with the scores of one K tile.
     *
     * In each accumulator, elements {0, 1, 4, 5} feed the .x row value and {2, 3, 6, 7} the
     * .y row value. The per-thread maxima are combined across the lanes whose ids differ in
     * the two lowest bits (xor masks 1 and 2), multiplied by `scale`, and merged into `rowmax`.
     *
     * NOTE(review): the maximum is taken before multiplying by `scale`, which equals the
     * maximum of the scaled scores only for a non-negative scale.
     *
     * @return the updated running maxima
     */
    __device__ __forceinline__ static rowval_warp compute_rowmax(qk_warp QK, rowval_warp rowmax, float scale) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 best;
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
                const packed_f32psum_t &tile = QK[m * WARP_K_TILES_QK + k];

                const float vx = fmaxf(fmaxf(tile.data[0], tile.data[1]), fmaxf(tile.data[4], tile.data[5]));
                const float vy = fmaxf(fmaxf(tile.data[2], tile.data[3]), fmaxf(tile.data[6], tile.data[7]));

                // The first tile seeds the maximum; later tiles are folded in.
                best.x = (k == 0) ? vx : fmaxf(best.x, vx);
                best.y = (k == 0) ? vy : fmaxf(best.y, vy);
            }
#pragma unroll
            for (int mask = 1; mask <= 2; mask <<= 1) {
                best.x = fmaxf(best.x, __shfl_xor_sync(~0, best.x, mask));
                best.y = fmaxf(best.y, __shfl_xor_sync(~0, best.y, mask));
            }
            rowmax[m].x = fmaxf(rowmax[m].x, best.x * scale);
            rowmax[m].y = fmaxf(rowmax[m].y, best.y * scale);
        }
        return rowmax;
    }

Muyang Li's avatar
Muyang Li committed
291
292
    /**
     * Replaces every score s with cuda_exp2(s * scale - rowmax), using the scaled row maximum
     * of the row the element belongs to.
     *
     * Elements {0, 1, 4, 5} use the .x row value and {2, 3, 6, 7} the .y row value, i.e. the
     * row is selected by bit 1 of the element index.
     *
     * NOTE(review): the exponential is base 2 (per the helper's name), so `scale` is
     * presumably pre-multiplied by log2(e) by the caller — confirm.
     */
    __device__ __forceinline__ static qk_warp softmax(qk_warp QK, rowval_warp rowmax_scaled, float scale) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            const float2 rowShift = rowmax_scaled[m];
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
                packed_f32psum_t &tile = QK[m * WARP_K_TILES_QK + k];
#pragma unroll
                for (int i = 0; i < 8; i++) {
                    const float shift = (i & 2) ? rowShift.y : rowShift.x;
                    tile.data[i]      = cuda_exp2(fmaf(tile.data[i], scale, -shift));
                }
            }
        }
        return QK;
    }

Muyang Li's avatar
Muyang Li committed
311
    /**
     * Sums the values of one K tile per row.
     *
     * Per thread, elements {0, 1, 4, 5} are added into .x and {2, 3, 6, 7} into .y; the
     * partial sums are then combined across the lanes whose ids differ in the two lowest
     * bits (xor masks 1 and 2).
     */
    __device__ __forceinline__ static rowval_warp compute_rowsum(qk_warp QK) {
        rowval_warp rowsum;
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 partial = make_float2(0.0f, 0.0f);
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
                const packed_f32psum_t &tile = QK[m * WARP_K_TILES_QK + k];
                // The addition order is kept as written; float addition is not associative.
                partial.x += tile.data[0] + tile.data[1] + tile.data[4] + tile.data[5];
                partial.y += tile.data[2] + tile.data[3] + tile.data[6] + tile.data[7];
            }
#pragma unroll
            for (int mask = 1; mask <= 2; mask <<= 1) {
                partial.x += __shfl_xor_sync(~0, partial.x, mask);
                partial.y += __shfl_xor_sync(~0, partial.y, mask);
            }
            rowsum[m] = partial;
        }
        return rowsum;
    }

Muyang Li's avatar
Muyang Li committed
332
    /**
     * Factor by which previously accumulated values must be multiplied when the running row
     * maximum moves from rowmax0 to rowmax1: cuda_exp2(rowmax0 - rowmax1), per row.
     */
    __device__ __forceinline__ static rowval_warp compute_rescale(rowval_warp rowmax0, rowval_warp rowmax1) {
        rowval_warp factor;
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            const float2 before = rowmax0[m];
            const float2 after  = rowmax1[m];
            factor[m].x         = cuda_exp2(before.x - after.x);
            factor[m].y         = cuda_exp2(before.y - after.y);
        }
        return factor;
    }

Muyang Li's avatar
Muyang Li committed
342
343
    /**
     * Accumulates P*V of one K tile into the output: O = O * rescale + P*V.
     *
     * For every (m, n) tile the product is accumulated over WARP_K_TILES_PV in a 16-bit
     * packed accumulator and widened to fp32. Elements {0, 1, 4, 5} are rescaled with the
     * .x row factor and {2, 3, 6, 7} with the .y row factor (selected by bit 1 of the index).
     *
     * @param V  operands indexed [n * WARP_K_TILES_PV + k]
     * @return the updated output accumulators
     */
    __device__ __forceinline__ static o_warp compute_pv(p_warp P, v_warp V, o_warp O, rowval_warp rescale) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
            for (int n = 0; n < WARP_N_TILES; n++) {
                packed_fpsum_t acc = make_uint4(0, 0, 0, 0);
#pragma unroll
                for (int k = 0; k < WARP_K_TILES_PV; k++) {
                    acc = mma_f16xf16_f16(P[m * WARP_K_TILES_PV + k], V[n * WARP_K_TILES_PV + k], acc);
                }

                const packed_f32psum_t pv = packed_fp16_to_fp32(acc);
                packed_f32psum_t &out     = O[m * WARP_N_TILES + n];
#pragma unroll
                for (int i = 0; i < 8; i++) {
                    const float factor = (i & 2) ? rescale[m].y : rescale[m].x;
                    out.data[i]        = out.data[i] * factor + pv.data[i];
                }
            }
        }
        return O;
    }

Muyang Li's avatar
Muyang Li committed
368
369
    /**
     * Updates the running row sums: L = L * rescale + rowsum, evaluated with fmaf per row.
     */
    __device__ __forceinline__ static rowval_warp compute_l(rowval_warp L, rowval_warp rescale, rowval_warp rowsum) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 &acc = L[m];
            acc.x       = fmaf(acc.x, rescale[m].x, rowsum[m].x);
            acc.y       = fmaf(acc.y, rescale[m].y, rowsum[m].y);
        }
        return L;
    }

Muyang Li's avatar
Muyang Li committed
377
    /**
     * Narrows the fp32 softmax values to the packed 16-bit P operand, tile by tile.
     * Requires the QK and PV tilings along K to match, so tile (m, k) has the same flat index
     * in both arrays.
     */
    __device__ __forceinline__ static p_warp qk_to_p(qk_warp QK) {
        static_assert(WARP_K_TILES_QK == WARP_K_TILES_PV);
        p_warp P;
#pragma unroll
        for (int t = 0; t < WARP_M_TILES * WARP_K_TILES_PV; t++) {
            // With equal tilings, m * WARP_K_TILES_PV + k == m * WARP_K_TILES_QK + k == t.
            P[t] = packed_fp32_to_fp16(QK[t]);
        }
        return P;
    }

    // __device__ __forceinline__
    // static void compute(q_warp Q, k_warp K, v_warp V, o_warp &O, rowval_warp &M, rowval_warp &L, float scale) {
    //     qk_warp qk = compute_qk(Q, K);
    //     rowval_warp M1 = compute_rowmax(qk, M, scale);
    //     qk = softmax(qk, M1, scale);
    //     rowval_warp rowsum = compute_rowsum(qk);
    //     p_warp P = qk_to_p(qk);
    //     rowval_warp rescale = compute_rescale(M, M1);
    //     M = M1;
    //     L = compute_l(L, rescale, rowsum);
    //     O = compute_pv(P, V, O, rescale);
    // }

Muyang Li's avatar
Muyang Li committed
403
404
405
406
407
408
409
    /**
     * One softmax step for a single K tile: scores, running maximum, exponentials, running sum.
     *
     * @param M  running scaled row maxima; replaced by the updated maxima
     * @param L  running row sums; rescaled and extended by this tile's row sums
     * @return {P, rescale}: the packed 16-bit probabilities of this tile and the factor the
     *         caller has to apply to the output accumulated so far
     */
    __device__ __forceinline__ static std::tuple<p_warp, rowval_warp>
    compute(q_warp Q, k_warp K, rowval_warp &M, rowval_warp &L, float scale) {
        qk_warp scores = compute_qk(Q, K);

        const rowval_warp newMax = compute_rowmax(scores, M, scale);
        scores                   = softmax(scores, newMax, scale);

        const rowval_warp tileSum = compute_rowsum(scores);
        p_warp P                  = qk_to_p(scores);

        // The rescale factor needs the old maximum, so derive it before overwriting M.
        const rowval_warp rescale = compute_rescale(M, newMax);
        M                         = newMax;
        L                         = compute_l(L, rescale, tileSum);

        return {P, rescale};
    }

Muyang Li's avatar
Muyang Li committed
416
417
    /**
     * Final normalisation: multiplies every output element by cuda_frcp(L) of its row
     * (a reciprocal, per the helper's name). Elements {0, 1, 4, 5} use the .x row value and
     * {2, 3, 6, 7} the .y row value (selected by bit 1 of the index).
     */
    __device__ __forceinline__ static o_warp compute_o(o_warp O, rowval_warp L) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            const float invX = cuda_frcp(L[m].x);
            const float invY = cuda_frcp(L[m].y);
#pragma unroll
            for (int n = 0; n < WARP_N_TILES; n++) {
                packed_f32psum_t &out = O[m * WARP_N_TILES + n];
#pragma unroll
                for (int i = 0; i < 8; i++) {
                    out.data[i] = out.data[i] * ((i & 2) ? invY : invX);
                }
            }
        }
        return O;
    }

#if 0
    template<typename Epilogue>
    __device__ __forceinline__
    static void attention_fp16_block(
        const BlockInfo binfo,
        const packed_q_t *ptr_q,
        const packed_k_t *ptr_k,
        const packed_v_t *ptr_v,
        float scale,
Muyang Li's avatar
Muyang Li committed
447
        int ntokens_q,
448
449
        int ntokens_kv,
        Epilogue::Arguments epilogueArgs,
Muyang Li's avatar
Muyang Li committed
450
        bool alwaysfalse)
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
    {
        constexpr int NUM_STAGES = 2;

        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        q_warp Q;   // 32
        k_warp K[NUM_STAGES];  // 32 * 2
        v_warp V[NUM_STAGES];  // 32 * 2
        o_warp O;   // 64
        rowval_warp L;  // 2
        rowval_warp M;  // 2

        load_q(ptr_q, Q, true);
        for (int k = 0; k < NUM_STAGES - 1; k++) {
            load_k(ptr_k, k, K[k], true);
            load_v(ptr_v, k, V[k], true);
        }

Muyang Li's avatar
Muyang Li committed
470
#pragma unroll
471
        for (auto &pack : O) {
Muyang Li's avatar
Muyang Li committed
472
#pragma unroll
473
474
475
476
477
478
479
480
481
482
            for (int i = 0; i < 8; i++) {
                pack.data[i] = 0;
            }
        }

        static constexpr float neginf = -std::numeric_limits<float>::infinity();
        L.fill(make_float2(0.0f, 0.0f));
        M.fill(make_float2(neginf, neginf));

        __shared__ q_warp Q_shmem[NUM_WARPS];
Muyang Li's avatar
Muyang Li committed
483
#pragma unroll
484
485
486
487
488
489
490
491
        for (int i = 0; i < Q.size(); i++) {
            store<true>(&Q_shmem[warpId][i], Q[i]);
        }

        int dummy = 0;

        // TODO: mask tokens in last block
        for (int k1 = 0; k1 < ntokens_kv / WARP_K; k1 += NUM_STAGES) {
Muyang Li's avatar
Muyang Li committed
492
#pragma unroll
493
            for (int k2 = 0; k2 < NUM_STAGES; k2++) {
Muyang Li's avatar
Muyang Li committed
494
#pragma unroll
495
496
497
498
499
500
501
502
503
                for (int i = 0; i < Q.size(); i++) {
                    Q[i] = load<true>(&Q_shmem[warpId][i]);
                }

                int nextk = k1 + k2 + NUM_STAGES - 1;
                int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
                bool pred = nextk < ntokens_kv / WARP_K;
                load_k(ptr_k, nextk, K[idx], pred);
                load_v(ptr_v, nextk, V[idx], pred);
Muyang Li's avatar
Muyang Li committed
504

505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
                // __syncthreads();
                // if (alwaysfalse) {
                //     dummy = clock();
                // }
                auto [P, rescale] = compute(Q, K[k2], V[k2], M, L, scale);
                O = compute_pv(P, V[idx], O, rescale);

                if (alwaysfalse) {
                    dummy = clock();
                }
                // asm volatile ("membar.cta;");
            }
        }

        unused_var(dummy, alwaysfalse);

        O = compute_o(O, L);

        auto f16psum = GEMM::packed_fp32_to_fp16(O);

        Epilogue()(typename GEMM::BlockInfo{
            .bm = binfo.batch * binfo.numBlocksM + binfo.bm,
            .bn = binfo.head,
            .numBlocksM = binfo.numBatch * binfo.numBlocksM,
            .numBlocksN = binfo.numHeads,
        }, f16psum, binfo.numBatch * binfo.numBlocksM * BLOCK_M, binfo.numHeads * HEAD_DIM, 0, epilogueArgs);
    }
#else
    template<typename Epilogue>
Muyang Li's avatar
Muyang Li committed
534
535
536
537
538
539
540
541
542
    __device__ __forceinline__ static void attention_fp16_block(const BlockInfo binfo,
                                                                const packed_q_t *ptr_q,
                                                                const packed_k_t *ptr_k,
                                                                const packed_v_t *ptr_v,
                                                                float scale,
                                                                int ntokens_q,
                                                                int ntokens_kv,
                                                                Epilogue::Arguments epilogueArgs,
                                                                bool alwaysfalse) {
543
544
545
546
547
        // constexpr int NUM_STAGES = 2;

        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

Muyang Li's avatar
Muyang Li committed
548
549
550
551
552
553
        q_warp Q;      // 32
        k_warp K;      // 64
        v_warp V;      // 64
        o_warp O;      // 64
        rowval_warp L; // 2
        rowval_warp M; // 2
554
555
556
557

        load_q(ptr_q, Q, true);
        load_k(ptr_k, 0, K, true);

Muyang Li's avatar
Muyang Li committed
558
#pragma unroll
559
        for (auto &pack : O) {
Muyang Li's avatar
Muyang Li committed
560
#pragma unroll
561
562
563
564
565
            for (int i = 0; i < 8; i++) {
                pack.data[i] = 0;
            }
        }

Muyang Li's avatar
Muyang Li committed
566
567
        static constexpr float neginf =
            -std::numeric_limits<float>::max(); // not real inf, to prevent nan during computation
568
569
570
        L.fill(make_float2(0.0f, 0.0f));
        M.fill(make_float2(neginf, neginf));

571
        static constexpr int SHMEM_TILES = IS_SM80 ? 4 : 7;
572
573
574
575
        static_assert(SHMEM_TILES <= Q.size());
        using q_shmem_t = packed_q_t[NUM_WARPS][SHMEM_TILES][WARP_SIZE];
        __shared__ q_shmem_t Q_shmem;

Muyang Li's avatar
Muyang Li committed
576
#pragma unroll
577
578
579
580
581
582
583
584
        for (int i = 0; i < SHMEM_TILES; i++) {
            store<true>(&Q_shmem[warpId][i][laneId], Q[Q.size() - 1 - i]);
        }
        __syncwarp();

        int dummy = 0;

        // TODO: mask tokens in last block
Muyang Li's avatar
Muyang Li committed
585
        for (int k1 = 0; k1 < ntokens_kv / WARP_K; k1++) {
586
587
588
589
            if (alwaysfalse) {
                ptr_v += K[0].x;
            }

Muyang Li's avatar
Muyang Li committed
590
#pragma unroll
591
592
593
594
            for (int i = 0; i < SHMEM_TILES; i++) {
                Q[Q.size() - 1 - i] = load<true>(&Q_shmem[warpId][i][laneId]);
            }

595
596
597
598
599
600
            if constexpr (!IS_SM80) {
                if (k1 % 2 == 1) {
                    __syncthreads();
                }
            }

601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
            if (alwaysfalse) {
                dummy = clock();
            }

            load_v(ptr_v, k1, V, true);

            if (alwaysfalse) {
                dummy = clock();
            }

            auto [P, rescale] = compute(Q, K, M, L, scale);

            if (alwaysfalse) {
                dummy = clock();
            }

            if (alwaysfalse) {
                ptr_k += V[0].x;
            }

            // if (alwaysfalse) {
            //     dummy = clock();
            // }

Muyang Li's avatar
Muyang Li committed
625
            load_k(ptr_k, k1 + 1, K, k1 + 1 < ntokens_kv / WARP_K);
626

627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
            // if (alwaysfalse) {
            //     dummy = clock();
            // }

            O = compute_pv(P, V, O, rescale);

            if (alwaysfalse) {
                dummy = clock();
            }
        }

        unused_var(dummy, alwaysfalse);

        O = compute_o(O, L);

        auto f16psum = GEMM::packed_fp32_to_fp16(O);

Muyang Li's avatar
Muyang Li committed
644
645
646
647
648
649
650
651
652
653
654
655
        Epilogue()(
            typename GEMM::BlockInfo{
                .bm         = binfo.batch * binfo.numBlocksM + binfo.bm,
                .bn         = binfo.head,
                .numBlocksM = binfo.numBatch * binfo.numBlocksM,
                .numBlocksN = binfo.numHeads,
            },
            f16psum,
            binfo.numBatch * binfo.numBlocksM * BLOCK_M,
            binfo.numHeads * HEAD_DIM,
            0,
            epilogueArgs);
656
657
658
659
660
    }
#endif

    template<typename Epilogue>
    struct attention_fp16_kernel {
fengzch-das's avatar
fengzch-das committed
661
        static constexpr int MIN_ARCH   = std::is_same_v<half_t, __hip_bfloat16> ? 800 : 750;
662
663
        static constexpr int SHMEM_SIZE = 0; // sizeof(q_shmem_t);

Muyang Li's avatar
Muyang Li committed
664
665
666
667
668
669
670
671
        __device__ void operator()(const packed_q_t *ptr_q,
                                   const packed_k_t *ptr_k,
                                   const packed_v_t *ptr_v,
                                   float scale,
                                   int ntokens_q,
                                   int ntokens_kv,
                                   Epilogue::Arguments epilogueArgs,
                                   bool alwaysfalse) {
672
            BlockInfo binfo = {
Muyang Li's avatar
Muyang Li committed
673
674
675
                .bm         = (int)blockIdx.x,
                .head       = (int)blockIdx.y,
                .batch      = (int)blockIdx.z,
676
                .numBlocksM = (int)gridDim.x,
Muyang Li's avatar
Muyang Li committed
677
678
                .numHeads   = (int)gridDim.y,
                .numBatch   = (int)gridDim.z,
679
680
681
682
683
684
685
686
687
            };

            // extern __shared__ uint8_t shmem[];
            // q_shmem_t *Q_shmem = reinterpret_cast<q_shmem_t *>(shmem);

            const int ktiles = ceilDiv(ntokens_kv, WARP_K);

            attention_fp16_block<Epilogue>(
                binfo,
Muyang Li's avatar
Muyang Li committed
688
689
690
691
692
693
                ptr_q + ((binfo.batch * binfo.numHeads + binfo.head) * binfo.numBlocksM + binfo.bm) * NUM_WARPS *
                            WARP_M_TILES * WARP_D_TILES * WARP_SIZE,
                ptr_k +
                    (binfo.batch * binfo.numHeads + binfo.head) * ktiles * WARP_K_TILES_QK * WARP_D_TILES * WARP_SIZE,
                ptr_v +
                    (binfo.batch * binfo.numHeads + binfo.head) * ktiles * WARP_K_TILES_PV * WARP_N_TILES * WARP_SIZE,
694
695
696
697
698
                scale,
                ntokens_q,
                ntokens_kv,
                // *Q_shmem,
                epilogueArgs,
Muyang Li's avatar
Muyang Li committed
699
                alwaysfalse);
700
701
702
703
        }
    };
};

Muyang Li's avatar
Muyang Li committed
704
}; // namespace nunchaku::kernels