// attention.cuh
#pragma once

#include "gemm_base.cuh"

namespace nunchaku::kernels {

// M: Q tokens
// N: V HEAD_DIM
// K: K tokens
// D: QK HEAD_DIM
// Compile-time tile/shape configuration for the FP16 attention kernel.
// `bf16out` selects the epilogue (output) precision: __nv_bfloat16 when true,
// half when false. Q/K/V inputs stay in half precision (half_t/half2_t below).
template<bool bf16out>
struct AttentionFP16Config {
    static constexpr int HEAD_DIM = 128; // head dimension shared by QK (D) and V (N)

    static constexpr int BLOCK_M   = 128; // Q tokens processed per thread block

    static constexpr int WARP_SIZE = 32;
    static constexpr int NUM_WARPS = 8;

    static constexpr int WARP_K = 32; // K/V tokens consumed per main-loop iteration

    // MMA instruction tile shapes. A 16x16 MMA is built from two m16n8k16 halves
    // (see mma_f16xf16_f16 in the Attention class).
    static constexpr int INSN_M    = 16;
    static constexpr int INSN_N    = 16;
    static constexpr int INSN_K_QK = 16; // K-extent of the Q*K^T MMA (over head dim D)
    static constexpr int INSN_K_PV = 16; // K-extent of the P*V MMA (over K tokens)

    using half_t  = half;
    using half2_t = half2;

    // Output precision of the epilogue, selected by bf16out.
    using epilogue_half_t  = typename std::conditional_t<bf16out, __nv_bfloat16, half>;
    using epilogue_half2_t = typename std::conditional_t<bf16out, __nv_bfloat162, half2>;
};

using AttentionFP16Config_FP16 = AttentionFP16Config<false>;
using AttentionFP16Config_BF16 = AttentionFP16Config<true>;

template<typename AttentionConfig>
class Attention;

// For the IntelliSense parser only, instantiate with a concrete config so the IDE
// can resolve the inherited members; the real build compiles the generic template.
#ifndef __INTELLISENSE__
template<typename AttentionConfig>
class Attention : public AttentionConfig {
#else
template<>
class Attention<AttentionFP16Config_BF16> : public AttentionFP16Config_BF16 {
    using AttentionConfig = AttentionFP16Config_BF16;
#endif

public:
    // Re-export the configuration constants/types so they can be used unqualified below.
    using AttentionConfig::HEAD_DIM;
    using AttentionConfig::BLOCK_M;
    using AttentionConfig::WARP_SIZE;
    using AttentionConfig::NUM_WARPS;
    using AttentionConfig::WARP_K;
    using AttentionConfig::INSN_M;
    using AttentionConfig::INSN_N;
    using AttentionConfig::INSN_K_QK;
    using AttentionConfig::INSN_K_PV;
    using typename AttentionConfig::half_t;
    using typename AttentionConfig::half2_t;
    using typename AttentionConfig::epilogue_half_t;
    using typename AttentionConfig::epilogue_half2_t;

    // True in device compilation passes targeting SM80+ (Ampere or newer).
    // NOTE(review): __CUDA_ARCH__ is undefined in the host pass, so IS_SM80 is false
    // there; this flag is only consulted inside __device__ code, where the device
    // pass sets it correctly.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    static constexpr bool IS_SM80 = true;
#else
    static constexpr bool IS_SM80 = false;
#endif

69
    struct GEMMConfig {
Muyang Li's avatar
Muyang Li committed
70
71
        static constexpr int BLOCK_M   = AttentionConfig::BLOCK_M;
        static constexpr int BLOCK_N   = AttentionConfig::HEAD_DIM;
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
        static constexpr int WARP_SIZE = AttentionConfig::WARP_SIZE;
        static constexpr int NUM_WARPS = AttentionConfig::NUM_WARPS;

        static constexpr int INSN_M = AttentionConfig::INSN_M;
        static constexpr int INSN_N = AttentionConfig::INSN_N;
        static constexpr int INSN_K = AttentionConfig::INSN_K_PV;

        using half_t  = typename AttentionConfig::epilogue_half_t;
        using half2_t = typename AttentionConfig::epilogue_half2_t;
    };

    using GEMM = typename nunchaku::kernels::GEMMBase<GEMMConfig>;

    static constexpr int WARP_M = BLOCK_M / NUM_WARPS;
    static constexpr int WARP_N = HEAD_DIM;
    static constexpr int WARP_D = HEAD_DIM;

Muyang Li's avatar
Muyang Li committed
89
90
91
92
    static constexpr int WARP_M_TILES = WARP_M / INSN_M;
    static constexpr int WARP_N_TILES = WARP_N / INSN_N;
    static constexpr int WARP_K_TILES_QK =
        WARP_K / INSN_N; // when multiplying Q*K, K is on dimension of N in MMA instruction
93
94
95
96
97
98
    static constexpr int WARP_K_TILES_PV = WARP_K / INSN_K_PV;
    static constexpr int WARP_D_TILES    = WARP_D / INSN_K_QK;

    using packed_q_t = uint4;
    using packed_k_t = uint4;
    using packed_v_t = uint4;
Muyang Li's avatar
Muyang Li committed
99
100
101
    using q_warp     = std::array<packed_q_t, WARP_M_TILES * WARP_D_TILES>;
    using k_warp     = std::array<packed_k_t, WARP_K_TILES_QK * WARP_D_TILES>;
    using v_warp     = std::array<packed_v_t, WARP_K_TILES_PV * WARP_N_TILES>;
102
103

    using packed_p_t = uint4;
Muyang Li's avatar
Muyang Li committed
104
    using p_warp     = std::array<packed_v_t, WARP_M_TILES * WARP_K_TILES_PV>;
105

Muyang Li's avatar
Muyang Li committed
106
    using packed_fpsum_t   = uint4;
107
108
109
110
111
112
113
114
115
    using packed_f32psum_t = typename GEMM::packed_f32psum_t;

    using qk_warp = std::array<packed_f32psum_t, WARP_M_TILES * WARP_K_TILES_QK>;
    // using o_warp = std::array<packed_f32psum_t, WARP_M_TILES * WARP_N_TILES>;
    using o_warp = typename GEMM::f32psum_warp;

    using rowval_warp = std::array<float2, WARP_M_TILES>;

    struct BlockInfo {
Muyang Li's avatar
Muyang Li committed
116
117
118
119
        int bm;    // M: Q tokens, bm: block id of M
        int head;  // H: head
        int batch; // B: batch
        int numBlocksM;
120
121
122
123
        int numHeads;
        int numBatch;
    };

    // Convert a packed fp32 accumulator fragment (8 floats) to a packed fp16
    // fragment (4 half2 values), pairing adjacent floats into half2 lanes.
    __device__ __forceinline__ static packed_fpsum_t packed_fp32_to_fp16(packed_f32psum_t input) {
        std::array<half2_t, 4> results;
        for (int i = 0; i < 4; i++) {
            results[i] = float22half2<half2_t>(float2(input.data[i * 2], input.data[i * 2 + 1]));
        }
        return kernels::bit_cast<packed_fpsum_t>(results);
    }

    // Inverse of packed_fp32_to_fp16: widen a packed fp16 fragment back to fp32,
    // preserving the pairwise lane layout.
    __device__ __forceinline__ static packed_f32psum_t packed_fp16_to_fp32(packed_fpsum_t input) {
        auto arr = kernels::bit_cast<std::array<half2_t, 4>>(input);
        packed_f32psum_t results;
        for (int i = 0; i < 4; i++) {
            float2 tmp              = half22float2(arr[i]);
            results.data[i * 2]     = tmp.x;
            results.data[i * 2 + 1] = tmp.y;
        }
        return results;
    }

    // q: [batch, head, bm, NUM_WARPS, WARP_M_TILES, WARP_D_TILES, WARP_SIZE] of packed_q_t
    // Load this warp's Q fragments from global memory into registers.
    // `pred` gates every load via load_pred (nothing is read when false).
    __device__ __forceinline__ static void load_q(const packed_q_t *ptr, q_warp &out, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        // Start of this warp's slice; each lane reads its own element of every tile.
        const packed_q_t *base = &ptr[((warpId * WARP_M_TILES + 0) * WARP_D_TILES + 0) * WARP_SIZE + laneId];

        unrolled_loop<WARP_M_TILES>([&]<int m>() {
            unrolled_loop<WARP_D_TILES>([&]<int d>() {
                out[m * WARP_D_TILES + d] = load_pred(&base[(m * WARP_D_TILES + d) * WARP_SIZE], pred);
            });
        });
    }

    // k: [batch, head, ktile, WARP_K_TILES_QK, WARP_D_TILES, WARP_SIZE] of packed_k_t
    // Load the K fragments of one WARP_K-token tile (`ktile`) into registers.
    // `pred` predicates all loads (used to squash the out-of-range prefetch after
    // the last tile). Note: all warps read the same tile, so warpId is unused here.
    __device__ __forceinline__ static void load_k(const packed_k_t *ptr, int ktile, k_warp &out, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        const packed_k_t *base = &ptr[((ktile * WARP_K_TILES_QK + 0) * WARP_D_TILES + 0) * WARP_SIZE + laneId];

        unrolled_loop<WARP_K_TILES_QK>([&]<int k>() {
            unrolled_loop<WARP_D_TILES>([&]<int d>() {
                out[k * WARP_D_TILES + d] = load_pred(&base[(k * WARP_D_TILES + d) * WARP_SIZE], pred);
            });
        });
    }

    // v: [batch, head, ktile, WARP_K_TILES_PV, WARP_N_TILES, WARP_SIZE] of packed_v_t
    // Load the V fragments of one WARP_K-token tile. Memory is k-major but the
    // register layout is n-major (out[n * WARP_K_TILES_PV + k]) to match the operand
    // order compute_pv indexes. All warps read the same tile (warpId unused).
    __device__ __forceinline__ static void load_v(const packed_v_t *ptr, int ktile, v_warp &out, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        const packed_v_t *base = &ptr[((ktile * WARP_K_TILES_PV + 0) * WARP_N_TILES + 0) * WARP_SIZE + laneId];

        unrolled_loop<WARP_K_TILES_PV>([&]<int k>() {
            unrolled_loop<WARP_N_TILES>([&]<int n>() {
                out[n * WARP_K_TILES_PV + k] = load_pred(&base[(k * WARP_N_TILES + n) * WARP_SIZE], pred);
            });
        });
    }

    // One 16x16 fp16 MMA with fp16 accumulation, built from two m16n8k16 halves:
    // b and psum are split into their low (x,y) and high (z,w) uint2 pairs.
    __device__ __forceinline__ static packed_fpsum_t
    mma_f16xf16_f16(packed_fpsum_t a, packed_fpsum_t b, packed_fpsum_t psum) {
        uint2 out1 = mma_m16n8k16_f16f16f16f16(a, uint2(b.x, b.y), uint2(psum.x, psum.y));
        uint2 out2 = mma_m16n8k16_f16f16f16f16(a, uint2(b.z, b.w), uint2(psum.z, psum.w));
        return packed_fpsum_t{out1.x, out1.y, out2.x, out2.y};
    }

    // Replace NaN lanes with -inf so a NaN can never win a subsequent max.
    // (Per IEEE-754, fmin/fmax with exactly one NaN operand return the non-NaN
    // operand, which is why downstream fmaxf-based reductions stay clean.)
    __device__ __forceinline__ static half2_t fix_nan(half2_t input) {
        const half_t neg_inf = __float2half(-std::numeric_limits<float>::infinity());

        const half_t lo = __low2half(input);
        const half_t hi = __high2half(input);

        return __halves2half2(__hisnan(lo) ? neg_inf : lo, __hisnan(hi) ? neg_inf : hi);
    }

Muyang Li's avatar
Muyang Li committed
216
    __device__ __forceinline__ static float fix_nan(float input) {
217
218
219
220
        static constexpr float neginf = -std::numeric_limits<float>::infinity();
        return fmaxf(input, neginf);
    }

Muyang Li's avatar
Muyang Li committed
221
    __device__ __forceinline__ static packed_fpsum_t fix_nan(packed_fpsum_t input) {
sxtyzhangzk's avatar
sxtyzhangzk committed
222
223
224
225
        input.x = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.x)));
        input.y = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.y)));
        input.z = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.z)));
        input.w = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.w)));
226
227
228
        return input;
    }

Muyang Li's avatar
Muyang Li committed
229
230
    __device__ __forceinline__ static packed_f32psum_t fix_nan(packed_f32psum_t input) {
#pragma unroll
231
232
233
234
235
236
        for (int i = 0; i < 8; i++) {
            input.data[i] = fix_nan(input.data[i]);
        }
        return input;
    }

    // Score tile S = Q * K^T (contracted over the head dimension) for one
    // WARP_K-token slice. The active path accumulates all WARP_D_TILES MMA steps
    // in fp16 and widens to fp32 once per tile; NaNs are flushed to -inf so they
    // can never win the subsequent row max.
    __device__ __forceinline__ static qk_warp compute_qk(q_warp Q, k_warp K) {
        qk_warp QK;
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {

#if 0
                // Disabled reference path: widen after every d-step and accumulate in fp32.
                // NOTE(review): this path accumulates into QK without zero-initializing it.
#pragma unroll
                for (int d = 0; d < WARP_D_TILES; d++) {
                    packed_fpsum_t psum = make_uint4(0, 0, 0, 0);
                    psum = mma_f16xf16_f16(Q[m * WARP_D_TILES + d], K[k * WARP_D_TILES + d], psum);
                    auto f32psum = packed_fp16_to_fp32(psum);
#pragma unroll
                    for (int i = 0; i < 8; i++) {
                        QK[m * WARP_K_TILES_QK + k].data[i] += f32psum.data[i];
                    }
                }

#else
                packed_fpsum_t psum = make_uint4(0, 0, 0, 0);
#pragma unroll
                for (int d = 0; d < WARP_D_TILES; d++) {
                    psum = mma_f16xf16_f16(Q[m * WARP_D_TILES + d], K[k * WARP_D_TILES + d], psum);
                }

                // On SM80+ sanitize the fp16 fragment before widening; otherwise widen
                // first and sanitize in fp32.
                if constexpr (IS_SM80) {
                    psum                        = fix_nan(psum);
                    QK[m * WARP_K_TILES_QK + k] = packed_fp16_to_fp32(psum);
                } else {
                    QK[m * WARP_K_TILES_QK + k] = fix_nan(packed_fp16_to_fp32(psum));
                }
#endif
            }
        }
        return QK;
    }

    // Fold the current score tile into the running *scaled* row maximum.
    // Each float2 tracks two row groups per M tile: .x aggregates data[0,1,4,5] and
    // .y aggregates data[2,3,6,7] of every fragment. The XOR-1/XOR-2 butterfly then
    // combines the four lanes that share those rows.
    // NOTE(review): mask-less __shfl_xor is the HIP spelling (this file also uses
    // __builtin_amdgcn_wave_barrier); a CUDA Volta+ build would need __shfl_xor_sync.
    __device__ __forceinline__ static rowval_warp compute_rowmax(qk_warp QK, rowval_warp rowmax, float scale) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 maxv;
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
                packed_f32psum_t &val = QK[m * WARP_K_TILES_QK + k];
                float x               = fmaxf(fmaxf(val.data[0], val.data[1]), fmaxf(val.data[4], val.data[5]));
                float y               = fmaxf(fmaxf(val.data[2], val.data[3]), fmaxf(val.data[6], val.data[7]));
                if (k == 0) {
                    maxv = make_float2(x, y);
                } else {
                    maxv.x = fmaxf(maxv.x, x);
                    maxv.y = fmaxf(maxv.y, y);
                }
            }
#pragma unroll
            for (int mask = 1; mask <= 2; mask *= 2) {
                maxv.x = fmaxf(maxv.x, __shfl_xor(maxv.x, mask));
                maxv.y = fmaxf(maxv.y, __shfl_xor(maxv.y, mask));
            }
            // rowmax stores scaled values: merge the scaled tile max with the running max.
            rowmax[m].x = fmaxf(rowmax[m].x, maxv.x * scale);
            rowmax[m].y = fmaxf(rowmax[m].y, maxv.y * scale);
        }
        return rowmax;
    }

    // Turn scaled scores into unnormalized softmax weights:
    //   val = exp2(val * scale - rowmax_scaled)
    // rowmax_scaled comes from compute_rowmax (already multiplied by scale and
    // includes the current tile), so every exponent is <= 0 and cannot overflow.
    // NOTE(review): exp2 is used, so `scale` presumably folds in the log2(e)
    // factor at the call site -- confirm in the launcher.
    __device__ __forceinline__ static qk_warp softmax(qk_warp QK, rowval_warp rowmax_scaled, float scale) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 shift = rowmax_scaled[m];
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
                packed_f32psum_t &val = QK[m * WARP_K_TILES_QK + k];
                // data[0,1,4,5] belong to the .x row group; data[2,3,6,7] to .y.
                val.data[0]           = cuda_exp2(fmaf(val.data[0], scale, -shift.x));
                val.data[1]           = cuda_exp2(fmaf(val.data[1], scale, -shift.x));
                val.data[4]           = cuda_exp2(fmaf(val.data[4], scale, -shift.x));
                val.data[5]           = cuda_exp2(fmaf(val.data[5], scale, -shift.x));
                val.data[2]           = cuda_exp2(fmaf(val.data[2], scale, -shift.y));
                val.data[3]           = cuda_exp2(fmaf(val.data[3], scale, -shift.y));
                val.data[6]           = cuda_exp2(fmaf(val.data[6], scale, -shift.y));
                val.data[7]           = cuda_exp2(fmaf(val.data[7], scale, -shift.y));
            }
        }
        return QK;
    }

    // Row-wise sum of the softmax weights: the increment for the running
    // denominator L. Uses the same data-index split and 4-lane XOR butterfly as
    // compute_rowmax (mask-less __shfl_xor is the HIP spelling).
    __device__ __forceinline__ static rowval_warp compute_rowsum(qk_warp QK) {
        rowval_warp rowsum;
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 sumv = make_float2(0.0f, 0.0f);
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
                packed_f32psum_t &val = QK[m * WARP_K_TILES_QK + k];
                sumv.x += val.data[0] + val.data[1] + val.data[4] + val.data[5];
                sumv.y += val.data[2] + val.data[3] + val.data[6] + val.data[7];
            }
#pragma unroll
            for (int mask = 1; mask <= 2; mask *= 2) {
                sumv.x += __shfl_xor(sumv.x, mask);
                sumv.y += __shfl_xor(sumv.y, mask);
            }
            rowsum[m] = sumv;
        }
        return rowsum;
    }

Muyang Li's avatar
Muyang Li committed
343
    __device__ __forceinline__ static rowval_warp compute_rescale(rowval_warp rowmax0, rowval_warp rowmax1) {
344
        rowval_warp rescale;
Muyang Li's avatar
Muyang Li committed
345
#pragma unroll
346
347
348
349
350
351
352
        for (int m = 0; m < WARP_M_TILES; m++) {
            rescale[m].x = cuda_exp2(rowmax0[m].x - rowmax1[m].x);
            rescale[m].y = cuda_exp2(rowmax0[m].y - rowmax1[m].y);
        }
        return rescale;
    }

    // O = O * rescale + P * V for every (m, n) output tile.
    // The P*V product for one tile is accumulated in fp16 across the
    // WARP_K_TILES_PV k-steps, widened to fp32 once, then merged into the fp32 O
    // accumulator with the online-softmax rescale factor (.x/.y cover the two row
    // groups, matching compute_rowmax's data-index split).
    __device__ __forceinline__ static o_warp compute_pv(p_warp P, v_warp V, o_warp O, rowval_warp rescale) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
            for (int n = 0; n < WARP_N_TILES; n++) {
                packed_fpsum_t psum = make_uint4(0, 0, 0, 0);
#pragma unroll
                for (int k = 0; k < WARP_K_TILES_PV; k++) {
                    // V is stored n-major (see load_v), hence V[n * WARP_K_TILES_PV + k].
                    psum = mma_f16xf16_f16(P[m * WARP_K_TILES_PV + k], V[n * WARP_K_TILES_PV + k], psum);
                }

                packed_f32psum_t pv    = packed_fp16_to_fp32(psum);
                packed_f32psum_t &oval = O[m * WARP_N_TILES + n];
                oval.data[0]           = oval.data[0] * rescale[m].x + pv.data[0];
                oval.data[1]           = oval.data[1] * rescale[m].x + pv.data[1];
                oval.data[4]           = oval.data[4] * rescale[m].x + pv.data[4];
                oval.data[5]           = oval.data[5] * rescale[m].x + pv.data[5];
                oval.data[2]           = oval.data[2] * rescale[m].y + pv.data[2];
                oval.data[3]           = oval.data[3] * rescale[m].y + pv.data[3];
                oval.data[6]           = oval.data[6] * rescale[m].y + pv.data[6];
                oval.data[7]           = oval.data[7] * rescale[m].y + pv.data[7];
            }
        }
        return O;
    }

Muyang Li's avatar
Muyang Li committed
379
380
    __device__ __forceinline__ static rowval_warp compute_l(rowval_warp L, rowval_warp rescale, rowval_warp rowsum) {
#pragma unroll
381
382
383
384
385
386
387
        for (int m = 0; m < WARP_M_TILES; m++) {
            L[m].x = fmaf(L[m].x, rescale[m].x, rowsum[m].x);
            L[m].y = fmaf(L[m].y, rescale[m].y, rowsum[m].y);
        }
        return L;
    }

Muyang Li's avatar
Muyang Li committed
388
    __device__ __forceinline__ static p_warp qk_to_p(qk_warp QK) {
389
390
        static_assert(WARP_K_TILES_QK == WARP_K_TILES_PV);
        p_warp P;
Muyang Li's avatar
Muyang Li committed
391
#pragma unroll
392
        for (int m = 0; m < WARP_M_TILES; m++) {
Muyang Li's avatar
Muyang Li committed
393
#pragma unroll
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
            for (int k = 0; k < WARP_K_TILES_PV; k++) {
                P[m * WARP_K_TILES_PV + k] = packed_fp32_to_fp16(QK[m * WARP_K_TILES_QK + k]);
            }
        }
        return P;
    }

    // __device__ __forceinline__
    // static void compute(q_warp Q, k_warp K, v_warp V, o_warp &O, rowval_warp &M, rowval_warp &L, float scale) {
    //     qk_warp qk = compute_qk(Q, K);
    //     rowval_warp M1 = compute_rowmax(qk, M, scale);
    //     qk = softmax(qk, M1, scale);
    //     rowval_warp rowsum = compute_rowsum(qk);
    //     p_warp P = qk_to_p(qk);
    //     rowval_warp rescale = compute_rescale(M, M1);
    //     M = M1;
    //     L = compute_l(L, rescale, rowsum);
    //     O = compute_pv(P, V, O, rescale);
    // }

    // One online-softmax step over a WARP_K-token slice of K:
    //   S = Q*K^T; M1 = max(M, rowmax(S)); P = exp2(S*scale - M1);
    //   L = L * exp2(M - M1) + rowsum(P).
    // Updates the running max M and denominator L in place and returns the fp16
    // probability fragments P plus the rescale factor for the O accumulator
    // (applied later by compute_pv).
    __device__ __forceinline__ static std::tuple<p_warp, rowval_warp>
    compute(q_warp Q, k_warp K, rowval_warp &M, rowval_warp &L, float scale) {
        qk_warp qk          = compute_qk(Q, K);
        rowval_warp M1      = compute_rowmax(qk, M, scale);
        qk                  = softmax(qk, M1, scale);
        rowval_warp rowsum  = compute_rowsum(qk);
        p_warp P            = qk_to_p(qk);
        rowval_warp rescale = compute_rescale(M, M1); // must read the old M, before the update below
        M                   = M1;
        L                   = compute_l(L, rescale, rowsum);
        return {P, rescale};
    }

    // Final softmax normalization: O[row] /= L[row], via a fast reciprocal (cuda_frcp).
    __device__ __forceinline__ static o_warp compute_o(o_warp O, rowval_warp L) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 inv;
            inv.x = cuda_frcp(L[m].x);
            inv.y = cuda_frcp(L[m].y);
#pragma unroll
            for (int n = 0; n < WARP_N_TILES; n++) {
                packed_f32psum_t &oval = O[m * WARP_N_TILES + n];
                // .x / .y follow the same data-index split as compute_rowmax.
                oval.data[0]           = oval.data[0] * inv.x;
                oval.data[1]           = oval.data[1] * inv.x;
                oval.data[4]           = oval.data[4] * inv.x;
                oval.data[5]           = oval.data[5] * inv.x;
                oval.data[2]           = oval.data[2] * inv.y;
                oval.data[3]           = oval.data[3] * inv.y;
                oval.data[6]           = oval.data[6] * inv.y;
                oval.data[7]           = oval.data[7] * inv.y;
            }
        }
        return O;
    }

#if 0
    // Disabled double-buffered (NUM_STAGES = 2) variant, kept for reference.
    // NOTE(review): stale -- it calls compute(Q, K[k2], V[k2], M, L, scale), which no
    // longer matches the current 5-argument compute() signature, so it would not
    // compile if re-enabled.
    template<typename Epilogue>
    __device__ __forceinline__
    static void attention_fp16_block(
        const BlockInfo binfo,
        const packed_q_t *ptr_q,
        const packed_k_t *ptr_k,
        const packed_v_t *ptr_v,
        float scale,
        int ntokens_q,
        int ntokens_kv,
        Epilogue::Arguments epilogueArgs,
        bool alwaysfalse)
    {
        constexpr int NUM_STAGES = 2;

        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        q_warp Q;   // 32
        k_warp K[NUM_STAGES];  // 32 * 2
        v_warp V[NUM_STAGES];  // 32 * 2
        o_warp O;   // 64
        rowval_warp L;  // 2
        rowval_warp M;  // 2

        load_q(ptr_q, Q, true);
        for (int k = 0; k < NUM_STAGES - 1; k++) {
            load_k(ptr_k, k, K[k], true);
            load_v(ptr_v, k, V[k], true);
        }

#pragma unroll
        for (auto &pack : O) {
#pragma unroll
            for (int i = 0; i < 8; i++) {
                pack.data[i] = 0;
            }
        }

        static constexpr float neginf = -std::numeric_limits<float>::infinity();
        L.fill(make_float2(0.0f, 0.0f));
        M.fill(make_float2(neginf, neginf));

        __shared__ q_warp Q_shmem[NUM_WARPS];
#pragma unroll
        for (int i = 0; i < Q.size(); i++) {
            store<true>(&Q_shmem[warpId][i], Q[i]);
        }

        int dummy = 0;

        // TODO: mask tokens in last block
        for (int k1 = 0; k1 < ntokens_kv / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
            for (int k2 = 0; k2 < NUM_STAGES; k2++) {
#pragma unroll
                for (int i = 0; i < Q.size(); i++) {
                    Q[i] = load<true>(&Q_shmem[warpId][i]);
                }

                int nextk = k1 + k2 + NUM_STAGES - 1;
                int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
                bool pred = nextk < ntokens_kv / WARP_K;
                load_k(ptr_k, nextk, K[idx], pred);
                load_v(ptr_v, nextk, V[idx], pred);

                // __syncthreads();
                // if (alwaysfalse) {
                //     dummy = clock();
                // }
                auto [P, rescale] = compute(Q, K[k2], V[k2], M, L, scale);
                O = compute_pv(P, V[idx], O, rescale);

                if (alwaysfalse) {
                    dummy = clock();
                }

                asm volatile ("membar.cta;");

            }
        }

        unused_var(dummy, alwaysfalse);

        O = compute_o(O, L);

        auto f16psum = GEMM::packed_fp32_to_fp16(O);

        Epilogue()(typename GEMM::BlockInfo{
            .bm = binfo.batch * binfo.numBlocksM + binfo.bm,
            .bn = binfo.head,
            .numBlocksM = binfo.numBatch * binfo.numBlocksM,
            .numBlocksN = binfo.numHeads,
        }, f16psum, binfo.numBatch * binfo.numBlocksM * BLOCK_M, binfo.numHeads * HEAD_DIM, 0, epilogueArgs);
    }
#else
    // Process one BLOCK_M x HEAD_DIM attention block: stream over all K/V tiles,
    // maintaining the online-softmax state (M = running scaled row max, L = running
    // row sum, O = unnormalized output), then normalize O and hand it to Epilogue.
    //
    // ptr_q/ptr_k/ptr_v point at this (batch, head[, M-block]) slice in the packed
    // layouts documented at load_q/load_k/load_v. ntokens_q is currently unused here.
    // The main loop floors ntokens_kv / WARP_K (see the TODO about masking the last
    // block). `alwaysfalse` must be false at runtime: the `if (alwaysfalse)` branches
    // appear to be compiler-scheduling aids that keep the loads and compute from
    // being reordered -- NOTE(review): do not remove without re-measuring performance.
    template<typename Epilogue>
    __device__ __forceinline__ static void attention_fp16_block(const BlockInfo binfo,
                                                                const packed_q_t *ptr_q,
                                                                const packed_k_t *ptr_k,
                                                                const packed_v_t *ptr_v,
                                                                float scale,
                                                                int ntokens_q,
                                                                int ntokens_kv,
                                                                Epilogue::Arguments epilogueArgs,
                                                                bool alwaysfalse) {
        // constexpr int NUM_STAGES = 2;

        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        q_warp Q;      // 32
        k_warp K;      // 64
        v_warp V;      // 64
        o_warp O;      // 64
        rowval_warp L; // 2
        rowval_warp M; // 2

        load_q(ptr_q, Q, true);
        load_k(ptr_k, 0, K, true); // first K tile; V is loaded inside the loop

        // Zero the output accumulator.
#pragma unroll
        for (auto &pack : O) {
#pragma unroll
            for (int i = 0; i < 8; i++) {
                pack.data[i] = 0;
            }
        }

        static constexpr float neginf =
            -std::numeric_limits<float>::max(); // not real inf, to prevent nan during computation
        L.fill(make_float2(0.0f, 0.0f));
        M.fill(make_float2(neginf, neginf));

        // Keep the last SHMEM_TILES Q fragments in shared memory and re-read them
        // every iteration (presumably to cut register pressure; 4 tiles on SM80,
        // 7 otherwise -- TODO confirm the intent).
        static constexpr int SHMEM_TILES = IS_SM80 ? 4 : 7;
        static_assert(SHMEM_TILES <= Q.size());
        using q_shmem_t = packed_q_t[NUM_WARPS][SHMEM_TILES][WARP_SIZE];
        __shared__ q_shmem_t Q_shmem;

#pragma unroll
        for (int i = 0; i < SHMEM_TILES; i++) {
            store<true>(&Q_shmem[warpId][i][laneId], Q[Q.size() - 1 - i]);
        }
        // __syncwarp();
        __builtin_amdgcn_wave_barrier();

        int dummy = 0;

        // TODO: mask tokens in last block
        for (int k1 = 0; k1 < ntokens_kv / WARP_K; k1++) {
            // Never taken at runtime; ties ptr_v to the K data to constrain scheduling.
            if (alwaysfalse) {
                ptr_v += K[0].x;
            }

            // Restore the spilled Q fragments (each warp only touches its own slot,
            // so the wave barrier above suffices).
#pragma unroll
            for (int i = 0; i < SHMEM_TILES; i++) {
                Q[Q.size() - 1 - i] = load<true>(&Q_shmem[warpId][i][laneId]);
            }

            // NOTE(review): periodic block-wide sync on the non-SM80 path every other
            // iteration -- presumably a scheduling/occupancy aid; confirm intent.
            if constexpr (!IS_SM80) {
                if (k1 % 2 == 1) {
                    __syncthreads();
                }
            }

            if (alwaysfalse) {
                dummy = clock();
            }

            load_v(ptr_v, k1, V, true);

            if (alwaysfalse) {
                dummy = clock();
            }

            // Online-softmax step for this K tile; updates M and L in place.
            auto [P, rescale] = compute(Q, K, M, L, scale);

            if (alwaysfalse) {
                dummy = clock();
            }

            if (alwaysfalse) {
                ptr_k += V[0].x;
            }

            // if (alwaysfalse) {
            //     dummy = clock();
            // }

            // Prefetch the next K tile; the predicate squashes the out-of-range load
            // on the final iteration.
            load_k(ptr_k, k1 + 1, K, k1 + 1 < ntokens_kv / WARP_K);

            // if (alwaysfalse) {
            //     dummy = clock();
            // }

            O = compute_pv(P, V, O, rescale);

            if (alwaysfalse) {
                dummy = clock();
            }
        }

        unused_var(dummy, alwaysfalse);

        // Normalize by the softmax denominator and narrow to the epilogue precision.
        O = compute_o(O, L);

        auto f16psum = GEMM::packed_fp32_to_fp16(O);

        // The epilogue sees a (batch * numBlocksM) x numHeads grid of
        // BLOCK_M x HEAD_DIM output blocks.
        Epilogue()(
            typename GEMM::BlockInfo{
                .bm         = binfo.batch * binfo.numBlocksM + binfo.bm,
                .bn         = binfo.head,
                .numBlocksM = binfo.numBatch * binfo.numBlocksM,
                .numBlocksN = binfo.numHeads,
            },
            f16psum,
            binfo.numBatch * binfo.numBlocksM * BLOCK_M,
            binfo.numHeads * HEAD_DIM,
            0,
            epilogueArgs);
    }
#endif

    // Device-side kernel functor. Expected launch: grid = (numBlocksM, numHeads,
    // numBatch) with NUM_WARPS * WARP_SIZE threads per block; each block handles
    // one (M-block, head, batch) triple.
    template<typename Epilogue>
    struct attention_fp16_kernel {
        // NOTE(review): half_t is always `half` for AttentionFP16Config, so MIN_ARCH
        // evaluates to 750 even when the epilogue outputs bf16; if the bf16 epilogue
        // requires SM80, this should probably test epilogue_half_t instead -- confirm.
        static constexpr int MIN_ARCH   = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
        static constexpr int SHMEM_SIZE = 0; // sizeof(q_shmem_t); (static shared memory is used instead)

        __device__ void operator()(const packed_q_t *ptr_q,
                                   const packed_k_t *ptr_k,
                                   const packed_v_t *ptr_v,
                                   float scale,
                                   int ntokens_q,
                                   int ntokens_kv,
                                   Epilogue::Arguments epilogueArgs,
                                   bool alwaysfalse) {
            // Map the launch grid onto (M block, head, batch).
            BlockInfo binfo = {
                .bm         = (int)blockIdx.x,
                .head       = (int)blockIdx.y,
                .batch      = (int)blockIdx.z,
                .numBlocksM = (int)gridDim.x,
                .numHeads   = (int)gridDim.y,
                .numBatch   = (int)gridDim.z,
            };

            // extern __shared__ uint8_t shmem[];
            // q_shmem_t *Q_shmem = reinterpret_cast<q_shmem_t *>(shmem);

            // K/V layouts are sized with the rounded-up tile count, even though the
            // main loop in attention_fp16_block floors ntokens_kv / WARP_K.
            const int ktiles = ceilDiv(ntokens_kv, WARP_K);

            // Advance each pointer to this block's (batch, head[, M-block]) slice of
            // the packed layouts documented at load_q/load_k/load_v.
            attention_fp16_block<Epilogue>(
                binfo,
                ptr_q + ((binfo.batch * binfo.numHeads + binfo.head) * binfo.numBlocksM + binfo.bm) * NUM_WARPS *
                            WARP_M_TILES * WARP_D_TILES * WARP_SIZE,
                ptr_k +
                    (binfo.batch * binfo.numHeads + binfo.head) * ktiles * WARP_K_TILES_QK * WARP_D_TILES * WARP_SIZE,
                ptr_v +
                    (binfo.batch * binfo.numHeads + binfo.head) * ktiles * WARP_K_TILES_PV * WARP_N_TILES * WARP_SIZE,
                scale,
                ntokens_q,
                ntokens_kv,
                // *Q_shmem,
                epilogueArgs,
                alwaysfalse);
        }
    };
};

}; // namespace nunchaku::kernels