// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <thread>
#include <string>

namespace ck_tile {

enum class naive_attention_layout_enum
{
    DEFAULT, // used when this tensor is not needed; any placeholder value works
    BSHD,    // [batch, seqlen, nhead, hdim]
    BHSD,    // [batch, nhead, seqlen, hdim]
    BS3HD,   // [batch, seqlen, 3, nhead, hdim], used when q/k/v are packed
    PHSD,    // [pages, nhead, page_size, hdim]
    // PHSDX, // [pages, nhead, page_size/x, hdim, x], where <# used pages>*page_size = seqlen
    PHDSX, // [pages, nhead, hdim/x, page_size, x], where <# used pages>*page_size = seqlen
    PHDS,  // [pages, nhead, hdim, page_size], where <# used pages>*page_size = seqlen

    // scale layout used for dynamic dequant
    SCALE_HS, // [nhead, tokens] or [nhead, tokens-per-group], used for KVCache quant
    SCALE_SH, // [tokens, nhead]
};
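
// Illustrative note on the PHDSX packing above (a sketch; the concrete sizes are assumed for
// illustration only): x = 16 / sizeof(T), i.e. 16 for an 8-bit KV cache, 8 for fp16/bf16.
// For an int8 K cache with nhead = 8, hdim = 128, page_size = 16, each page is stored as
// [8, 128/16, 16, 16], so logical element (head h, token s_in_page, dim d) lands at
// [h][d / 16][s_in_page][d % 16]; 16 consecutive hdim elements of one token therefore sit
// in a single contiguous 16-byte (dword x4) chunk.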

// will be used to specialize the kernel variation
enum class naive_attention_variation_enum
{
    FLASH_BATCHED = 0, // standard flash attention (xformers/SDPA style), used for training
    FLASH_GROUPED,
    DECODE_PAGED, // decode attn, where kv tokens come from a separate buffer called the kvcache
};

enum class naive_attention_quant_algo
{
    NO              = 0,
    KV_8BIT_PERHEAD = 1,
    // FP8/INT8 quant for KVCache, per-token quant
    // [num_tokens, nhead, hdim] -> [nhead, num_tokens]
    KV_8BIT_PERTOKEN = 2,
};
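
// Illustrative dequant relation for KV_8BIT_PERTOKEN (a sketch; the symbols below are not part
// of the API):
//   k_scale[h][t]  = max_d |k[t][h][d]| / 127.0 (int8 kvcache) or / 240.0 (fp8 kvcache)
//   k[t][h][d]    ~= k_quant[t][h][d] * k_scale[h][t]
// and likewise for v. The kernel folds k_scale into the softmax logits and v_scale into p before
// the second gemm, so only per-token scalars are needed at runtime.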

// TODO: for simplicity, this struct is used as both the host and the device argument
struct naive_attention_fwd_args
{
    void* q_ptr;
    void* k_ptr;
    void* v_ptr;
    void* o_ptr;
    void* context_len_ptr; // [batch] used when seqlen_kv comes from a pointer (each element is a
                           // per-batch count, not a cumulative sum)
    void* page_table_ptr;  // [batch, max_pages_per_seq] seqlen_kv is scattered across blocks (paged attn)
    void* kscale_ptr;      // [nhead, max_kv_tokens] used for kvcache dequant
    void* vscale_ptr;      // [nhead, max_kv_tokens] used for kvcache dequant
    float scale_s;
    int hdim;
    int hdim_v; // could be cross-attn, where V and Q/K hdim are different
    int batch_q;
    int batch_kv;
    int batch_ratio_kv; // batch_q / batch_kv
    int seqlen_q;       // in decode case, this should be 1
    int seqlen_kv;      // if context_len_ptr is not nullptr, ignore this field
    int nhead_q;
    int nhead_kv;
    int nhead_ratio_kv; // nhead_q / nhead_kv
    int page_size;      // if paged, the number of kv tokens held by each block
    int max_pages_per_seq;
    int max_kv_tokens; // used as stride to access kv scale ptr
};
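
// Hypothetical fill of the arguments for a plain batched (FLASH_BATCHED) case; every shape and
// buffer name below is made up for illustration only:
//   naive_attention_fwd_args a{};
//   a.q_ptr = q_dev; a.k_ptr = k_dev; a.v_ptr = v_dev; a.o_ptr = o_dev; // device buffers
//   a.scale_s = 1.0f / std::sqrt(128.0f); // softmax scale for hdim = 128
//   a.hdim = 128; a.hdim_v = 128;
//   a.batch_q = 2; a.batch_kv = 2; a.batch_ratio_kv = 1;
//   a.seqlen_q = 512; a.seqlen_kv = 512;
//   a.nhead_q = 16; a.nhead_kv = 16; a.nhead_ratio_kv = 1;
//   // context_len_ptr/page_table_ptr/kscale_ptr/vscale_ptr stay nullptr when not used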

// these are the traits for the host API
struct naive_attention_fwd_traits
{
    std::string q_type;
    std::string k_type;
    std::string v_type;
    std::string o_type;
    std::string q_layout;
    std::string k_layout;
    std::string v_layout;
    std::string o_layout;
    int variation;  // sync with naive_attention_variation_enum
    int quant_algo; // sync with naive_attention_quant_algo
};

// these are the traits for the kernel template
template <naive_attention_variation_enum variation_, naive_attention_quant_algo quant_algo_>
struct naive_attention_fwd_kernel_traits
{
    static constexpr naive_attention_variation_enum variation = variation_;
    static constexpr naive_attention_quant_algo quant_algo    = quant_algo_;
};

// for simplicity, please do not use const-reference type for the template type
template <typename QType,
          typename KType,
          typename VType,
          typename OType,
          typename AccType,
          typename KVScaleType,
          naive_attention_layout_enum QLayout,
          naive_attention_layout_enum KLayout,
          naive_attention_layout_enum VLayout,
          naive_attention_layout_enum OLayout,
          naive_attention_layout_enum KScaleLayout,
          naive_attention_layout_enum VScaleLayout,
          typename Traits>
struct naive_attention_fwd_kernel
{
    static constexpr bool is_kvcache_i8 =
        std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t>;
    static constexpr bool is_kvcache_fp8 =
        std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;

    static constexpr int v_per_token_quant_group_size = 64;

    // TODO: hardcode
    using SoftmaxType      = float; // always using float to do softmax compute
    using QuantComputeType = float; // used for quant/dequant scale compute
    using QCompute         = KType; // src A of gemm1, same type as K
    using PType            = VType; // src A of gemm2, same type as V
    using OAccType         = float; // always float, even in the int8 FA case

    using p_vec_type                = ext_vector_t<PType, 16 / sizeof(PType)>;
    static constexpr int p_vec_elem = vector_traits<p_vec_type>::vector_size;

    // clang-format off
    template <typename T_> struct scale_max { static constexpr float value = 1; /* dummy code */ };
    template <> struct scale_max<int8_t> { static constexpr float value = 127.0; };
    template <> struct scale_max<fp8_t> { static constexpr float value = 240.0; };
    // clang-format on
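
    // note (assumption): 127.0 is the int8 max, and 240.0 is the largest finite value of the
    // fp8 e4m3 (fnuz) format assumed here, so quantized values are scaled into [-max, +max]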

    __host__ __device__ naive_attention_fwd_kernel() {}

    template <typename T, naive_attention_layout_enum Layout>
    struct addresser
    {
        int b, s, h, d; // batch, seqlen, nhead, hdim
        T* base_ptr;
        __device__ addresser(int b_, int s_, int h_, int d_, void* base_ptr_)
            : b(b_), s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(base_ptr_))
        {
        }

        // TODO: all batch/nhead offsets are accumulated into the base pointer
        __device__ T* get_base(int i_b, int i_h)
        {
            if constexpr(Layout == naive_attention_layout_enum::BSHD)
                return base_ptr + i_b * s * h * d + i_h * d;
            else if constexpr(Layout == naive_attention_layout_enum::BHSD)
                return base_ptr + i_b * s * h * d + i_h * s * d;
        }

        __device__ int get_offset(int i_s, int i_d)
        {
            if constexpr(Layout == naive_attention_layout_enum::BSHD)
                return i_s * h * d + i_d;
            else if constexpr(Layout == naive_attention_layout_enum::BHSD)
                return i_s * d + i_d;
        }

        // the APIs below directly use the pointer stored inside this struct
        __device__ void init(int i_b, int i_h) { base_ptr = get_base(i_b, i_h); }
        __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
        __device__ void store(T value, int i_s, int i_d) { base_ptr[get_offset(i_s, i_d)] = value; }
    };
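
    // Worked example (values assumed for illustration): for BSHD with seqlen = 128, nhead = 8,
    // hdim = 64, element (b = 1, s = 3, h = 5, d = 7) resolves to
    //   get_base(1, 5) + get_offset(3, 7) = (1*128*8*64 + 5*64) + (3*8*64 + 7) = 67399
    // i.e. the usual row-major [batch, seqlen, nhead, hdim] linearization.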

    template <typename T, naive_attention_layout_enum Layout>
    struct page_addresser
    {
        int s, h, d;                             // page_size, nhead, hdim
        static constexpr int x = 16 / sizeof(T); // pack 4 dwords (16 bytes)
        T* base_ptr;
        int* page_table_ptr; // TODO: page table always int
        int i_h;             // store current head

        __device__ page_addresser(int s_, int h_, int d_, void* base_ptr_, void* pptr_)
            : s(s_),
              h(h_),
              d(d_),
              base_ptr(reinterpret_cast<T*>(base_ptr_)),
              page_table_ptr(reinterpret_cast<int*>(pptr_))
        {
        }

        __device__ int64_t get_phy_page_idx(int i_s)
        {
            // computing the page idx dynamically is simple but slow
            int page_idx = i_s / s;
            int phy      = page_table_ptr[page_idx];
            return static_cast<int64_t>(phy);
        }

        __device__ int get_phy_page_offset(int i_s)
        {
            // computing the page offset dynamically is simple but slow
            return i_s % s;
        }

        __device__ int64_t get_offset(int i_s, int i_d)
        {
            int page_offset  = get_phy_page_offset(i_s);
            int64_t page_idx = get_phy_page_idx(i_s);
            int64_t base_    = page_idx * h * s * d;
            if constexpr(Layout == naive_attention_layout_enum::PHSD)
                return static_cast<int64_t>(i_h * s * d + page_offset * d + i_d) + base_;
            else if constexpr(Layout == naive_attention_layout_enum::PHDSX)
            {
                int d_r = i_d / x;
                int d_x = i_d % x;
                return static_cast<int64_t>(i_h * d * s + d_r * s * x + page_offset * x + d_x) +
                       base_;
            }
            else if constexpr(Layout == naive_attention_layout_enum::PHDS)
            {
                return static_cast<int64_t>(i_h * d * s + i_d * s + page_offset) + base_;
            }
        }

        // the APIs below directly use the pointer stored inside this struct
        __device__ void init(int /*i_b*/, int i_h_) { i_h = i_h_; }
        __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
        __device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {}
    };
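
    // Worked example (values assumed for illustration): with page_size = 16 and
    // page_table = {7, 3, ...}, logical kv token i_s = 21 maps to physical page
    // page_table[21 / 16] = 3 with in-page offset 21 % 16 = 5; for the PHDS layout the final
    // element offset is 3*h*16*d + i_h*d*16 + i_d*16 + 5.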

    template <typename T, naive_attention_layout_enum Layout>
    struct kvscale_addresser
    {
        int s, h, d; // seqlen(tokens), nhead, hdim
        T* base_ptr;
        __device__ kvscale_addresser(int s_, int h_, int d_, void* p_)
            : s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
        {
        }
        __device__ int get_offset(int i_s, int i_h, int i_d)
        {
            if constexpr(Layout == naive_attention_layout_enum::SCALE_HS)
            {
                // [nhead, tokens]
                (void)i_d;
                return i_h * s + i_s;
            }
            else if constexpr(Layout == naive_attention_layout_enum::DEFAULT)
            {
                return 0;
            }
            // [h, 2, d]
            // return i_h * 2 * d + i_kv * d + i_d;
        }
        __device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; }
    };

    __device__ __host__ static constexpr int get_block_size() { return 256; }

    // for simplicity, 1 WG always computes 1 token along q and all tokens along kv;
    // it covers all of hdim from q and WG_SIZE elements of hdim from v
    // 1) in the prefill case, seqlen_q >= 1, seqlen_kv >= 1, batch_q = batch_kv
    // 2) in the decode case, seqlen_q = 1, batch_q is the input num-tokens, batch_kv is 1
    // 3) in the paged-attn case, we still use 1 WG to compute all of seqlen_kv for simplicity
    // TODO: could support split-kv to validate the intermediate logsum
    __host__ static dim3 get_grid_size(naive_attention_fwd_args args)
    {
        constexpr int wg_size = get_block_size();
        auto g =
            dim3((args.hdim_v + wg_size - 1) / wg_size, args.seqlen_q, args.batch_q * args.nhead_q);
        return g;
    }
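
    // e.g. (illustrative numbers): hdim_v = 128, seqlen_q = 1, batch_q = 4, nhead_q = 32 gives
    // grid = {(128 + 255)/256, 1, 4*32} = {1, 1, 128}, i.e. one 256-thread WG per (head, token).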

    // reduce a single element within a wave
    template <typename T, typename F>
    __device__ constexpr T wave_reduce(T local, F reduce_f)
    {
        // constexpr int wave_size = 64;
        constexpr int reduce_stage = 6; // 1<<6=64
        T v_local                  = local;
#pragma unroll
        for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
        {
            int src_lane = __lane_id() ^ (1 << i_stage);
            int32_t v_remote_tmp =
                __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
            T v_remote = bit_cast<T>(v_remote_tmp);
            v_local    = reduce_f(v_local, v_remote);
        }
        return v_local;
    }
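
    // usage sketch (assumes all lanes of the wave call it uniformly):
    //   float m = wave_reduce(x, [](float a, float b) { return max(a, b); });
    // after the xor-butterfly exchange every lane holds the same reduced value.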

    // Note: this function must be called after wave_reduce
    // Note: better not to use this inside divergent if...else... branches (it calls __syncthreads)
    template <typename T, typename F>
    __device__ constexpr T cross_wave_reduce(T local, F reduce_f, T* smem)
    {
        constexpr int waves     = 4;
        constexpr int wave_size = 64;
        int lane_id             = threadIdx.x % wave_size;

        __syncthreads();
        smem[threadIdx.x] = local;
        __syncthreads();

        // the data within a single wave is already the same,
        // but for simplicity, we still read the value from each lane.
        T v_local = smem[lane_id];
#pragma unroll
        for(int i_stage = 1; i_stage < waves; i_stage++)
        {
            T v_remote = smem[i_stage * wave_size + lane_id];
            v_local    = reduce_f(v_local, v_remote);
        }
        return v_local;
    }

    // kernel entry point
    __device__ void operator()(naive_attention_fwd_args args)
    {
        constexpr int wg_size = get_block_size();
        __shared__ char smem[wg_size * 4 * sizeof(float)];       // should be enough
        char* smem_quant_q = smem + wg_size * 2 * sizeof(float); // second half, should be enough
        int i_dv           = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v
        int i_sq           = blockIdx.y;                         // index of seqlen_q
        int i_batch        = blockIdx.z;                         // index of batch_q * nhead_q
        int i_bq           = i_batch / args.nhead_q;             // index of batch_q
        int i_hq           = i_batch % args.nhead_q;             // index of nhead_q

        int i_bk = i_bq / args.batch_ratio_kv;
        int i_hk = i_hq / args.nhead_ratio_kv;

        void* page_table_ptr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return reinterpret_cast<int*>(args.page_table_ptr) + i_bq * args.max_pages_per_seq;
            }
            else
            {
                return nullptr;
            }
        }();

        auto q_addr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return addresser<QType, QLayout>{
                    args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return addresser<QType, QLayout>{
                    args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
            }
        }();
        auto k_addr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return addresser<KType, KLayout>{
                    args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim, args.k_ptr};
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return page_addresser<KType, KLayout>{
                    args.page_size, args.nhead_kv, args.hdim, args.k_ptr, page_table_ptr};
            }
        }();
        auto v_addr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return addresser<VType, VLayout>{
                    args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim_v, args.v_ptr};
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return page_addresser<VType, VLayout>{
                    args.page_size, args.nhead_kv, args.hdim_v, args.v_ptr, page_table_ptr};
            }
        }();
        auto o_addr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return addresser<OType, OLayout>{
                    args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return addresser<OType, OLayout>{
                    args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
            }
        }();

        q_addr.init(i_bq, i_hq);
        k_addr.init(i_bk, i_hk);
        v_addr.init(i_bk, i_hk);
        o_addr.init(i_bq, i_hq);

        auto f_max        = [](auto x_, auto y_) { return max(x_, y_); };
        auto f_sum        = [](auto x_, auto y_) { return x_ + y_; };
        auto f_absmax_f32 = [](float v_0_, float v_1_) {
            // float rtn;
            // asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_));
            // return rtn;
            return max(abs(v_0_), abs(v_1_));
        };

        int seqlen_kv = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return args.seqlen_kv;
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return reinterpret_cast<int*>(args.context_len_ptr)[i_bq];
            }
        }();

        SoftmaxType row_max = -numeric<SoftmaxType>::infinity();
        SoftmaxType l{0};
        // AccType o_acc = {0};
        OAccType o_acc = {0};

        int sk_loops                     = (seqlen_kv + wg_size - 1) / wg_size;
        QuantComputeType q_dequant_scale = .0f;
        kvscale_addresser<KVScaleType, KScaleLayout> kscale_addr{
            args.max_kv_tokens, args.nhead_kv, args.hdim, args.kscale_ptr};
        kvscale_addresser<KVScaleType, VScaleLayout> vscale_addr{
            args.max_kv_tokens, args.nhead_kv, args.hdim_v, args.vscale_ptr};

        if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
        {
            // AccType is i32 now, seqlen_q = 1, hdim up to 256
            AccType q   = 0;
            AccType k_s = 0;
            if(static_cast<int>(threadIdx.x) < args.hdim)
            {
                q   = type_convert<AccType>(q_addr.load(0, threadIdx.x));
                k_s = type_convert<AccType>(kscale_addr.load(i_hk, threadIdx.x, 0));
            }
            // 1) we apply the k scale to q
            AccType q_forwarded = q * k_s;

            // 2) apply smooth-quant
            // find absmax
            AccType qf_max = wave_reduce(q_forwarded, f_absmax_f32);
            qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType*>(smem));

            // per-token scale
            q_dequant_scale = type_convert<QuantComputeType>(qf_max) / scale_max<QCompute>::value;

            // divide by scale
            q = q / q_dequant_scale;

            // fp32->i8
            QCompute quantized_q = static_cast<QCompute>(q);
            __syncthreads();
            reinterpret_cast<QCompute*>(smem)[threadIdx.x] = quantized_q;
            __syncthreads();

            // after the above process, we have 2 pieces of data:
            // 1) int8 q data stored in smem (no need to reload)
            // 2) per-token scale q_dequant_scale, to be multiplied in after the 1st gemm
        }
        else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN)
        {
            if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
            {
                // dynamically quantize q here
                float q = 0;
                if(static_cast<int>(threadIdx.x) < args.hdim)
                {
                    q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
                }

                // apply smooth-quant
                // find absmax
                float q_max = wave_reduce(q, f_absmax_f32);
                q_max = cross_wave_reduce(q_max, f_absmax_f32, reinterpret_cast<float*>(smem));

                // per-token scale
                q_dequant_scale =
                    type_convert<QuantComputeType>(q_max) / scale_max<QCompute>::value;

                // divide by scale
                q = q / q_dequant_scale;

                QCompute quantized_q = type_convert<QCompute>(q);
                __syncthreads();
                reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
                __syncthreads();

                // after the above process, we have 2 pieces of data:
                // 1) fp8 q data stored in smem (no need to reload from global)
                // 2) per-token scale q_dequant_scale, to be multiplied in after the 1st gemm
            }
        }

        for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
        {
            int i_sk = i_loop1 * wg_size + threadIdx.x;
            // gemm-1
            SoftmaxType s_softmax = -numeric<SoftmaxType>::infinity();
            if(i_sk < seqlen_kv)
            {
                AccType s_acc{0}; // clear for every loop
                for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
                {
                    auto q = [&]() {
                        if constexpr(Traits::quant_algo ==
                                         naive_attention_quant_algo::KV_8BIT_PERHEAD ||
                                     Traits::quant_algo ==
                                         naive_attention_quant_algo::KV_8BIT_PERTOKEN)
                        {
                            return reinterpret_cast<QCompute*>(smem_quant_q)[i_dq];
                        }
                        else
                            return q_addr.load(i_sq, i_dq); // q is loaded redundantly here
                    }();
                    auto k = [&]() { return k_addr.load(i_sk, i_dq); }();

                    s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
                }
                // scale
                s_softmax = type_convert<SoftmaxType>(s_acc);
                s_softmax *=
                    type_convert<SoftmaxType>(args.scale_s * ck_tile::log2e_v<SoftmaxType>);
                if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
                {
                    s_softmax *= q_dequant_scale; // post-scale by the per-token factor
                }
                else if constexpr(Traits::quant_algo ==
                                  naive_attention_quant_algo::KV_8BIT_PERTOKEN)
                {
                    SoftmaxType k_per_token_scale =
                        type_convert<SoftmaxType>(kscale_addr.load(i_sk, i_hk, 0));
                    s_softmax *= q_dequant_scale;
                    s_softmax *= k_per_token_scale;
                }
            }

            // s->p
            QuantComputeType p_dequant_scale = 1.;
            {
                // softmax, find max
                SoftmaxType old_max = row_max;
                SoftmaxType cur_max = wave_reduce(s_softmax, f_max);

                cur_max = cross_wave_reduce(cur_max, f_max, reinterpret_cast<SoftmaxType*>(smem));
                row_max = max(old_max, cur_max); // update row_max
                // softmax, exp(i_elem - max)
                SoftmaxType p_compute = __builtin_amdgcn_exp2f(s_softmax - row_max);

                // compute exp_sum
                SoftmaxType row_sum = wave_reduce(p_compute, f_sum);
                row_sum = cross_wave_reduce(row_sum, f_sum, reinterpret_cast<SoftmaxType*>(smem));

                // update l, pre-scale o_acc
                SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
                l               = tmp * l + row_sum;
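                // online-softmax bookkeeping (base-2, since scale_s was folded with log2e above):
                //   m_new = max(m_old, m_cur);  l_new = exp2(m_old - m_new) * l_old + sum_j exp2(s_j - m_new)
                // the partial o_acc must be rescaled by the same exp2(m_old - m_new) factor below.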
                o_acc           = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);

                // store p_compute into smem, so that every thread can read the same p_compute for
                // the 2nd gemm
                if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
                {
                    QuantComputeType v_s = 0;
                    if(static_cast<int>(threadIdx.x) < args.hdim_v)
                    {
                        v_s =
                            type_convert<QuantComputeType>(vscale_addr.load(i_hk, threadIdx.x, 1));
                    }

                    // 1) we apply the v scale to p
                    QuantComputeType p_forwarded = p_compute * v_s;

                    // 2) apply smooth-quant
                    // find absmax
                    QuantComputeType pf_max = wave_reduce(p_forwarded, f_absmax_f32);
                    pf_max                  = cross_wave_reduce(
                        pf_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));

                    // per-token scale
                    p_dequant_scale = pf_max / scale_max<PType>::value; // 127.0;

                    // divide by scale
                    p_compute = p_compute / p_dequant_scale;

                    // fp32->i8
                    PType quantized_p = static_cast<PType>(p_compute);
                    __syncthreads();
                    reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
                    __syncthreads();
                    // after the above process, we have 2 pieces of data:
                    // 1) int8 p data stored in smem (no need to reload)
                    // 2) per-token scale p_dequant_scale, to be multiplied in after the 2nd gemm
                }
                else if constexpr(Traits::quant_algo ==
                                  naive_attention_quant_algo::KV_8BIT_PERTOKEN)
                {
                    // forward-apply the v scale to p_compute; this is compute friendly
                    auto v_scale = type_convert<QuantComputeType>(vscale_addr.load(i_sk, i_hk, 0));
                    p_compute *= v_scale;
                    // smooth-quant
                    // find absmax
                    QuantComputeType p_max = wave_reduce(p_compute, f_absmax_f32);
                    p_max                  = cross_wave_reduce(
                        p_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));

                    // per-token scale
                    p_dequant_scale = p_max / scale_max<PType>::value; // 240.0;

                    // divide by scale
                    p_compute = p_compute / p_dequant_scale;

                    // fp32->i8
                    PType quantized_p = type_convert<PType>(p_compute);
                    __syncthreads();
                    reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
                    __syncthreads();
                    // after the above process, we have 2 pieces of data:
                    // 1) fp8_t p data stored in smem (no need to reload)
                    // 2) per-token scale p_dequant_scale, to be multiplied in after the 2nd gemm
                }
                else
                {
                    __syncthreads();
                    reinterpret_cast<PType*>(smem)[threadIdx.x] = type_convert<PType>(p_compute);
                    __syncthreads();
                }
            }

            // gemm-2, simple loop over vector by vector
            constexpr int gemm_2_loop = wg_size / p_vec_elem;
            {
                AccType o_acc_local = {0};
                int sk_start = i_loop1 * wg_size; // first seqlen_kv element of this outer loop
                for(int i_loop2 = 0; i_loop2 < gemm_2_loop; i_loop2++)
                {
                    p_vec_type p_vec = reinterpret_cast<p_vec_type*>(smem)[i_loop2];
#pragma unroll
                    for(int i_j = 0; i_j < p_vec_elem; i_j++)
                    {
                        int sv_offset = i_loop2 * p_vec_elem + i_j;
                        int i_sv      = sk_start + sv_offset;

                        VType v = 0;
                        if(i_dv < args.hdim_v && i_sv < seqlen_kv)
                        {
                            v = v_addr.load(i_sv, i_dv);
                        }

                        AccType v_compute = [&]() { return type_convert<AccType>(v); }();

                        o_acc_local += type_convert<AccType>(p_vec[i_j]) * v_compute;
                    }
                }

                OAccType post_scale_o_acc_local = [&]() {
                    if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
                    {
                        // apply the p dequant scale to the local acc
                        return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
                                                      p_dequant_scale);
                    }
                    else if constexpr(Traits::quant_algo ==
                                      naive_attention_quant_algo::KV_8BIT_PERTOKEN)
                    {
                        // apply the p dequant scale to the local acc
                        return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
                                                      p_dequant_scale);
                    }
                    else
                    {
                        return type_convert<OAccType>(o_acc_local);
                    }
                }();
                o_acc += post_scale_o_acc_local;
            }
        }

        // post scale o_acc
        {
            SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // guard against a fully masked row
            o_acc           = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
        }

        // store O
        if(i_dv < args.hdim_v)
            o_addr.store(type_convert<OType>(o_acc), i_sq, i_dv);
    }
};

#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_()                                                        \
    {                                                                                                       \
        using ktraits_ = naive_attention_fwd_kernel_traits<                                                 \
            static_cast<naive_attention_variation_enum>(variation_),                                        \
            static_cast<naive_attention_quant_algo>(quant_algo_)>;                                          \
        using k_   = naive_attention_fwd_kernel<q_type_,                                                    \
                                              k_type_,                                                    \
                                              v_type_,                                                    \
                                              o_type_,                                                    \
                                              acc_type_,                                                  \
                                              kvscale_type_,                                              \
                                              q_layout_,                                                  \
                                              k_layout_,                                                  \
                                              v_layout_,                                                  \
                                              o_layout_,                                                  \
                                              k_scale_layout_,                                            \
                                              v_scale_layout_,                                            \
                                              ktraits_>;                                                  \
        dim3 grids = k_::get_grid_size(a);                                                                  \
        r          = ck_tile::launch_kernel(s,                                                              \
                                   ck_tile::make_kernel(k_{}, grids, k_::get_block_size(), 0, a)); \
    }

#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_()                                                 \
    if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \
       t.o_layout == "bshd")                                                                       \
    {                                                                                              \
        constexpr auto q_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto k_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto v_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto o_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr int variation_       = 0;                                                        \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
    }                                                                                              \
    else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" &&                    \
            t.v_layout == "bhsd" && t.o_layout == "bhsd")                                          \
    {                                                                                              \
        constexpr auto q_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto v_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto o_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr int variation_       = 0;                                                        \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
    }                                                                                              \
    else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" &&                   \
            t.v_layout == "phds" && t.o_layout == "bhsd")                                          \
    {                                                                                              \
        constexpr auto q_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_layout_       = naive_attention_layout_enum::PHDSX;                       \
        constexpr auto v_layout_       = naive_attention_layout_enum::PHDS;                        \
        constexpr auto o_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_scale_layout_ = naive_attention_layout_enum::SCALE_HS;                    \
        constexpr auto v_scale_layout_ = naive_attention_layout_enum::SCALE_HS;                    \
        constexpr int variation_       = 2;                                                        \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
    }

//
CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
                                       naive_attention_fwd_args a,
                                       ck_tile::stream_config s)
{
    float r = -1;
    // TODO: do not explicitly create too many instances!
    if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16" &&
       t.quant_algo == 0)
    {
        using q_type_             = fp16_t;
        using k_type_             = fp16_t;
        using v_type_             = fp16_t;
        using o_type_             = fp16_t;
        using acc_type_           = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 0;
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16" &&
            t.quant_algo == 0)
    {
        using q_type_             = bf16_t;
        using k_type_             = bf16_t;
        using v_type_             = bf16_t;
        using o_type_             = bf16_t;
        using acc_type_           = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 0;
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    else if(t.q_type == "bf16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "bf16" &&
            t.quant_algo == 2)
    {
        using q_type_             = bf16_t;
        using k_type_             = fp8_t;
        using v_type_             = fp8_t;
        using o_type_             = bf16_t;
        using acc_type_           = float; // NOTE!
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 2;
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    else if(t.q_type == "fp16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "fp16" &&
            t.quant_algo == 2)
    {
        using q_type_             = fp16_t;
        using k_type_             = fp8_t;
        using v_type_             = fp8_t;
        using o_type_             = fp16_t;
        using acc_type_           = float; // NOTE!
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 2;
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16" &&
            t.quant_algo == 2)
    {
        using q_type_             = bf16_t;
        using k_type_             = int8_t;
        using v_type_             = int8_t;
        using o_type_             = bf16_t;
        using acc_type_           = int32_t; // NOTE!
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 2;
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    return r;
}
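
// Usage sketch (types/layouts must match one of the dispatched combinations above; buffer names
// and shapes are illustrative only, not part of this header):
//   naive_attention_fwd_traits t{};
//   t.q_type = "fp16"; t.k_type = "fp16"; t.v_type = "fp16"; t.o_type = "fp16";
//   t.q_layout = "bshd"; t.k_layout = "bshd"; t.v_layout = "bshd"; t.o_layout = "bshd";
//   t.variation = 0; t.quant_algo = 0;
//   naive_attention_fwd_args a{}; // fill pointers/shapes as described above
//   float r = naive_attention_fwd(t, a, stream_config{});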

#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_
#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_

} // namespace ck_tile