// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <thread>
#include <string>

namespace ck_tile {

enum class naive_attention_layout_enum
{
    DEFAULT, // maybe this tensor is not used, set some irrelevant value
    BSHD,    // [batch, seqlen, nhead, hdim]
    BHSD,    // [batch, nhead, seqlen, hdim]
    BS3HD,   // [batch, seqlen, 3, nhead, hdim], used when qkv are packed
    PHSD,    // [pages, nhead, page_size, hdim]
    // PHSDX, // [pages, nhead, page_size/x, hdim, x], where <# used pages>*page_size = seqlen
    PHDSX, // [pages, nhead, hdim/x, page_size, x], where <# used pages>*page_size = seqlen
    PHDS,  // [pages, nhead, hdim, page_size], where <# used pages>*page_size = seqlen

    // scale layout used for dynamic dequant
    SCALE_HS, // [nhead, tokens] or [nhead, tokens-per-group], for KVCache quant
    SCALE_SH, // [tokens, nhead]
};

// will be used to specialize kernel variations
enum class naive_attention_variation_enum
{
    FLASH_BATCHED = 0, // standard flash attention, or xformer/sdpa, used for training
    FLASH_GROUPED,
    DECODE_PAGED, // decode attn, where kv token from another buffer called kvcache
};

enum class naive_attention_quant_algo
{
    NO              = 0,
    KV_8BIT_PERHEAD = 1,
    // FP8/INT8 quant for KVCache, per-token quant
    // [num_tokens, nhead, hdim] -> [nhead, num_tokens]
    KV_8BIT_PERTOKEN = 2,
    // same as 8bit per token quant but 4 bit
    KV_4BIT_PERTOKEN = 3,
};

// TODO: for simplicity, this will be used as host/device arg
struct naive_attention_fwd_args
{
    void* q_ptr;
    void* k_ptr;
    void* v_ptr;
    void* o_ptr;
    void* context_len_ptr; // [batch] used when seqlen_kv comes from a pointer (each element is a
                           // length, not a cumsum)
    void* page_table_ptr;  // [batch, max_pages_per_seq] seqlen_kv is scattered across blocks (paged attn)
    void* kscale_ptr;      // [nhead, max_kv_tokens] used for kvcache dequant
    void* vscale_ptr;      // [nhead, max_kv_tokens] used for kvcache dequant
    float scale_s;
    int hdim;
    int hdim_v; // could be cross-attn, where V and Q/K hdim are different
    int batch_q;
    int batch_kv;
    int batch_ratio_kv; // batch_q / batch_kv
    int seqlen_q;       // in decode case, this should be 1
    int seqlen_kv;      // if context_len_ptr is not nullptr, ignore this field
    int nhead_q;
    int nhead_kv;
    int nhead_ratio_kv; // nhead_q / nhead_kv
    int page_size;      // if paged, the seqlen-kv per each block
    int max_pages_per_seq;
    int max_kv_tokens; // used as stride to access kv scale ptr
};
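
// Illustrative sketch of filling the args for a plain batched, non-paged, non-quant case
// (hypothetical device buffers d_q/d_k/d_v/d_o and arbitrary sizes; needs <cmath> for std::sqrt):
//
//   naive_attention_fwd_args a{};
//   a.q_ptr = d_q; a.k_ptr = d_k; a.v_ptr = d_v; a.o_ptr = d_o;
//   a.context_len_ptr = nullptr;                     // use a.seqlen_kv directly
//   a.page_table_ptr  = nullptr;                     // no paged kv-cache
//   a.kscale_ptr = nullptr; a.vscale_ptr = nullptr;  // no dequant scales
//   a.hdim = 128; a.hdim_v = 128;
//   a.scale_s = 1.0f / std::sqrt(static_cast<float>(a.hdim)); // typical softmax scale
//   a.batch_q = 2; a.batch_kv = 2; a.batch_ratio_kv = 1;
//   a.seqlen_q = 512; a.seqlen_kv = 512;
//   a.nhead_q = 8; a.nhead_kv = 8; a.nhead_ratio_kv = 1;
//   a.page_size = 0; a.max_pages_per_seq = 0; a.max_kv_tokens = 0; // unused in this case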

// this is the trait for the host API
struct naive_attention_fwd_traits
{
    std::string q_type;
    std::string k_type;
    std::string v_type;
    std::string o_type;
    std::string q_layout;
    std::string k_layout;
    std::string v_layout;
    std::string o_layout;
    int variation;  // sync with naive_attention_variation_enum
    int quant_algo; // sync with naive_attention_quant_algo
};
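
// Illustrative host-side sketch of pairing the traits with an args struct `a` filled as in the
// sketch above (hypothetical default stream; only the type/layout/quant combinations dispatched
// inside naive_attention_fwd() at the bottom of this file are actually instantiated):
//
//   naive_attention_fwd_traits t{};
//   t.q_type = "fp16"; t.k_type = "fp16"; t.v_type = "fp16"; t.o_type = "fp16";
//   t.q_layout = "bshd"; t.k_layout = "bshd"; t.v_layout = "bshd"; t.o_layout = "bshd";
//   t.variation  = 0; // naive_attention_variation_enum::FLASH_BATCHED
//   t.quant_algo = 0; // naive_attention_quant_algo::NO
//
//   float elapsed_ms = naive_attention_fwd(t, a, ck_tile::stream_config{});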

// this is the trait for the kernel template
template <naive_attention_variation_enum variation_, naive_attention_quant_algo quant_algo_>
struct naive_attention_fwd_kernel_traits
{
    static constexpr naive_attention_variation_enum variation = variation_;
    static constexpr naive_attention_quant_algo quant_algo    = quant_algo_;
};

// for simplicity, please do not use const-reference type for the template type
template <typename QType,
          typename KType,
          typename VType,
          typename OType,
          typename AccType_I, // i.e. input of mfma
          typename AccType_O, // i.e. results of mfma
          typename KVScaleType,
          naive_attention_layout_enum QLayout,
          naive_attention_layout_enum KLayout,
          naive_attention_layout_enum VLayout,
          naive_attention_layout_enum OLayout,
          naive_attention_layout_enum KScaleLayout,
          naive_attention_layout_enum VScaleLayout,
          typename Traits>
struct naive_attention_fwd_kernel
{
    // TODO: hardcode
    using SoftmaxType      = float; // always using float to do softmax compute
    using QuantComputeType = float; // used for quant/dequant scale compute
    using QCompute = AccType_I; // src A of gemm1, may differ from K, e.g. for int4 we use i8 GEMM now
    using PType    = AccType_I; // src A of gemm2, may differ from V, e.g. for int4 we use i8 GEMM now
    using TailType = float;     // always float, in case of int8 FA

    static constexpr int gemm1_vec_size = min(16 / sizeof(QCompute), 16 / sizeof(KType));
    using q_vec_type                    = ext_vector_t<QCompute, gemm1_vec_size>;
    using p_vec_type                    = ext_vector_t<PType, 16 / sizeof(PType)>;
    static constexpr int p_vec_elem     = vector_traits<p_vec_type>::vector_size;

    // clang-format off
    template <typename T_> struct scale_max { static constexpr float value = 1; /* dummy code */ };
    template <> struct scale_max<int8_t> { static constexpr float value = 127.0; };
    template <> struct scale_max<fp8_t> { static constexpr float value = 240.0; };
    // clang-format on
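
    // Worked example of how scale_max is used below (a sketch of the per-token smooth-quant step,
    // assuming int8 compute): for a row x with absmax(x) = 3.2, the dequant scale is 3.2 / 127.0,
    // the quantized value is x_q = x / scale (so |x_q| <= 127), and the scale is multiplied back
    // into the accumulator after the corresponding gemm.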

    __host__ __device__ naive_attention_fwd_kernel() {}

    template <typename T, naive_attention_layout_enum Layout>
    struct addresser
    {
        int b, s, h, d; // batch, seqlen, nhead, hdim
        T* base_ptr;
        __device__ addresser(int b_, int s_, int h_, int d_, void* base_ptr_)
            : b(b_), s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(base_ptr_))
        {
        }

        // TODO: all the batch/nhead offsets are accumulated into the base pointer
        __device__ T* get_base(int i_b, int i_h)
        {
            if constexpr(Layout == naive_attention_layout_enum::BSHD)
                return base_ptr + i_b * s * h * d + i_h * d;
            else if constexpr(Layout == naive_attention_layout_enum::BHSD)
                return base_ptr + i_b * s * h * d + i_h * s * d;
        }

        __device__ int get_offset(int i_s, int i_d)
        {
            if constexpr(Layout == naive_attention_layout_enum::BSHD)
                return i_s * h * d + i_d;
            else if constexpr(Layout == naive_attention_layout_enum::BHSD)
                return i_s * d + i_d;
        }

        // the APIs below directly use the pointer stored inside this struct
        __device__ void init(int i_b, int i_h) { base_ptr = get_base(i_b, i_h); }
        __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
        __device__ void store(T value, int i_s, int i_d) { base_ptr[get_offset(i_s, i_d)] = value; }
        template <int vec_size>
        __device__ ext_vector_t<T, vec_size> load_vector(int i_s, int i_d)
        {
            return reinterpret_cast<ext_vector_t<T, vec_size>*>(base_ptr +
                                                                get_offset(i_s, i_d * vec_size))[0];
        }
    };
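
    // Example of the addressing above (hypothetical sizes): for a BSHD tensor with s = 512, h = 8,
    // d = 128, get_base(i_b, i_h) skips i_b * 512 * 8 * 128 + i_h * 128 elements and
    // get_offset(i_s, i_d) adds i_s * 8 * 128 + i_d, i.e. element [i_b, i_s, i_h, i_d] of the
    // [batch, seqlen, nhead, hdim] tensor.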

    template <typename T, naive_attention_layout_enum Layout>
    struct page_addresser
    {
        int s, h, d;                             // page_size, nhead, hdim
        static constexpr int x = 16 / sizeof(T); // pack 4 dword
        T* base_ptr;
        int* page_table_ptr; // TODO: page table always int
        int i_h;             // store current head

        __device__ page_addresser(int s_, int h_, int d_, void* base_ptr_, void* pptr_)
            : s(s_),
              h(h_),
              d(d_),
              base_ptr(reinterpret_cast<T*>(base_ptr_)),
              page_table_ptr(reinterpret_cast<int*>(pptr_))
        {
        }

        __device__ int64_t get_phy_page_idx(int i_s)
        {
            // dynamically computing the page idx is simple but slow
            int page_idx = i_s / s;
            int phy      = page_table_ptr[page_idx];
            return static_cast<int64_t>(phy);
        }

        __device__ int get_phy_page_offset(int i_s)
        {
            // dynamically computing the page offset is simple but slow
            return i_s % s;
        }

        __device__ int64_t get_offset(int i_s, int i_d)
        {
            int page_offset  = get_phy_page_offset(i_s);
            int64_t page_idx = get_phy_page_idx(i_s);
            int64_t base_    = page_idx * h * s * d;
            if constexpr(Layout == naive_attention_layout_enum::PHSD)
                return static_cast<int64_t>(i_h * s * d + page_offset * d + i_d) + base_;
            else if constexpr(Layout == naive_attention_layout_enum::PHDSX)
            {
                int d_r = i_d / x;
                int d_x = i_d % x;
                return static_cast<int64_t>(i_h * d * s + d_r * s * x + page_offset * x + d_x) +
                       base_;
            }
            else if constexpr(Layout == naive_attention_layout_enum::PHDS)
            {
                return static_cast<int64_t>(i_h * d * s + i_d * s + page_offset) + base_;
            }
        }

        // the APIs below directly use the pointer stored inside this struct
        __device__ void init(int /*i_b*/, int i_h_) { i_h = i_h_; }
        __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
        __device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {}
        template <int vec_size>
        __device__ ext_vector_t<T, vec_size> load_vector(int i_s, int i_d)
        {
            return reinterpret_cast<ext_vector_t<T, vec_size>*>(base_ptr +
                                                                get_offset(i_s, i_d * vec_size))[0];
        }
    };
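
    // Example of the paged lookup above (hypothetical page_size s = 16): for logical token
    // i_s = 37, page_idx = 37 / 16 = 2 and page_offset = 37 % 16 = 5, so the element lives in
    // physical page page_table_ptr[2] and is addressed inside that page according to the
    // PHSD / PHDSX / PHDS layout.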

    template <typename T, naive_attention_layout_enum Layout>
    struct kvscale_addresser
    {
        int s, h, d; // seqlen(tokens), nhead, hdim
        T* base_ptr;
        __device__ kvscale_addresser(int s_, int h_, int d_, void* p_)
            : s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
        {
        }
        __device__ int get_offset(int i_s, int i_h, int i_d)
        {
            if constexpr(Layout == naive_attention_layout_enum::SCALE_HS)
            {
                // [nhead, tokens]
                (void)i_d;
                return i_h * s + i_s;
            }
            else if constexpr(Layout == naive_attention_layout_enum::DEFAULT)
            {
                return 0;
            }
            // [h, 2, d]
            // return i_h * 2 * d + i_kv * d + i_d;
        }
        __device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; }
    };

    __device__ __host__ static constexpr int get_block_size() { return 256; }

    // for simplicity, 1 WG always computes 1 token along q and all tokens along kv,
    // covering all of hdim from q and WG_SIZE elements of hdim from v
    // 1) in the prefill case, seqlen_q >= 1, seqlen_kv >= 1, batch_q = batch_kv
    // 2) in the decode case, seqlen_q = 1, batch_q is the input num-tokens, batch_kv is 1
    // 3) in the paged-attn case, we still use 1 WG to compute all of seqlen-kv for simplicity
    // TODO: could support split-kv to validate intermediate logsum
    __host__ static dim3 get_grid_size(naive_attention_fwd_args args)
    {
        constexpr int wg_size = get_block_size();
        auto g =
            dim3((args.hdim_v + wg_size - 1) / wg_size, args.seqlen_q, args.batch_q * args.nhead_q);
        return g;
    }
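
    // e.g. for a hypothetical decode shape hdim_v = 128, seqlen_q = 1, batch_q = 64, nhead_q = 8,
    // this gives grid = dim3(1, 1, 512) with 256 threads per block, so one workgroup covers all
    // of hdim_v for one (batch, head) pair.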

    // reduce single pixel within a wave
    template <typename T, typename F>
    __device__ constexpr T wave_reduce(T local, F reduce_f)
    {
        // constexpr int wave_size = 64;
        constexpr int reduce_stage = 6; // 1<<6=64
        T v_local                  = local;
#pragma unroll
        for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
        {
            int src_lane = __lane_id() ^ (1 << i_stage);
            int32_t v_remote_tmp =
                __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
            T v_remote = bit_cast<T>(v_remote_tmp);
            v_local    = reduce_f(v_local, v_remote);
        }
        return v_local;
    }
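
    // e.g. with reduce_f = max: stage 0 exchanges with lane ^ 1, stage 1 with lane ^ 2, and so on,
    // so after stage i every lane holds the reduction over its 2^(i+1)-lane group; after all 6
    // stages every lane of the 64-wide wave holds the same fully reduced value.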

    // Note: this function must be called after wave_reduce
    // Note: better not to use this inside divergent if...else... branches (it calls __syncthreads)
    template <typename T, typename F>
    __device__ constexpr T cross_wave_reduce(T local, F reduce_f, T* smem)
    {
        constexpr int waves     = 4;
        constexpr int wave_size = 64;
        int lane_id             = threadIdx.x % wave_size;

        __syncthreads();
        smem[threadIdx.x] = local;
        __syncthreads();

        // the data within a single wave is the same,
        // but for simplicity we still read the data from each lane.
        T v_local = smem[lane_id];
#pragma unroll
        for(int i_stage = 1; i_stage < waves; i_stage++)
        {
            T v_remote = smem[i_stage * wave_size + lane_id];
            v_local    = reduce_f(v_local, v_remote);
        }
        return v_local;
    }

    // kernel entry point
    __device__ void operator()(naive_attention_fwd_args args)
    {
        constexpr int wg_size = get_block_size();
        __shared__ char smem[wg_size * 4 * sizeof(float)];       // should be enough
        char* smem_quant_q = smem + wg_size * 2 * sizeof(float); // second half, should be enough
        int i_dv           = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v
        int i_sq           = blockIdx.y;                         // index of seqlen_q
        int i_batch        = blockIdx.z;                         // index of batch_q * nhead_q
        int i_bq           = i_batch / args.nhead_q;             // index of batch_q
        int i_hq           = i_batch % args.nhead_q;             // index of nhead_q

        int i_bk = i_bq / args.batch_ratio_kv;
        int i_hk = i_hq / args.nhead_ratio_kv;

        void* page_table_ptr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return reinterpret_cast<int*>(args.page_table_ptr) + i_bq * args.max_pages_per_seq;
            }
            else
            {
                return nullptr;
            }
        }();

        auto q_addr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return addresser<QType, QLayout>{
                    args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return addresser<QType, QLayout>{
                    args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
            }
        }();
        auto k_addr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return addresser<KType, KLayout>{
                    args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim, args.k_ptr};
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return page_addresser<KType, KLayout>{
                    args.page_size, args.nhead_kv, args.hdim, args.k_ptr, page_table_ptr};
            }
        }();
        auto v_addr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return addresser<VType, VLayout>{
                    args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim_v, args.v_ptr};
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return page_addresser<VType, VLayout>{
                    args.page_size, args.nhead_kv, args.hdim_v, args.v_ptr, page_table_ptr};
            }
        }();
        auto o_addr = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return addresser<OType, OLayout>{
                    args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return addresser<OType, OLayout>{
                    args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
            }
        }();

        q_addr.init(i_bq, i_hq);
        k_addr.init(i_bk, i_hk);
        v_addr.init(i_bk, i_hk);
        o_addr.init(i_bq, i_hq);

        auto f_max        = [](auto x_, auto y_) { return max(x_, y_); };
        auto f_sum        = [](auto x_, auto y_) { return x_ + y_; };
        auto f_absmax_f32 = [](float v_0_, float v_1_) {
            // float rtn;
            // asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_));
            // return rtn;
            return max(abs(v_0_), abs(v_1_));
        };

        int seqlen_kv = [&]() {
            if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
            {
                return args.seqlen_kv;
            }
            else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
            {
                return reinterpret_cast<int*>(args.context_len_ptr)[i_bq];
            }
        }();

        SoftmaxType row_max = -numeric<SoftmaxType>::infinity();
        SoftmaxType l{0};
        // AccType_O o_acc = {0};
        TailType o_acc = {0};

        int sk_loops                     = (seqlen_kv + wg_size - 1) / wg_size;
        QuantComputeType q_dequant_scale = .0f;
        kvscale_addresser<KVScaleType, KScaleLayout> kscale_addr{
            args.max_kv_tokens, args.nhead_kv, args.hdim, args.kscale_ptr};
        kvscale_addresser<KVScaleType, VScaleLayout> vscale_addr{
            args.max_kv_tokens, args.nhead_kv, args.hdim_v, args.vscale_ptr};

        if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
        {
            // AccType_O is i32 now, seqlen_q = 1, hdim up to 256
            AccType_O q   = 0;
            AccType_O k_s = 0;
            if(static_cast<int>(threadIdx.x) < args.hdim)
            {
                q   = type_convert<AccType_O>(q_addr.load(0, threadIdx.x));
                k_s = type_convert<AccType_O>(kscale_addr.load(i_hk, threadIdx.x, 0));
            }
            // 1) we apply the k scale to q
            AccType_O q_forwarded = q * k_s;

            // 2) apply smooth-quant
            // find absmax
            AccType_O qf_max = wave_reduce(q_forwarded, f_absmax_f32);
            qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType_O*>(smem));

            // per-token scale
            q_dequant_scale = type_convert<QuantComputeType>(qf_max) / scale_max<QCompute>::value;

            // divide by scale
            q = q / q_dequant_scale;

            // fp32->i8
            QCompute quantized_q = static_cast<QCompute>(q);
            __syncthreads();
            reinterpret_cast<QCompute*>(smem)[threadIdx.x] = quantized_q;
            __syncthreads();

            // after the above process, we have two results:
            // 1) int8 q data stored in smem (no need to reload)
            // 2) per-token scale q_dequant_scale, to be multiplied in after the 1st gemm
        }
        else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN)
        {
            if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
            {
                // dynamic quant q here
                float q = 0;
                if(static_cast<int>(threadIdx.x) < args.hdim)
                {
                    q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
                }

                // apply smooth-quant
                // find absmax
                float q_max = wave_reduce(q, f_absmax_f32);
                q_max = cross_wave_reduce(q_max, f_absmax_f32, reinterpret_cast<float*>(smem));

                // per-token scale
                q_dequant_scale =
                    type_convert<QuantComputeType>(q_max) / scale_max<QCompute>::value;

                // divide by scale
                q = q / q_dequant_scale;

                QCompute quantized_q = type_convert<QCompute>(q);
                __syncthreads();
                reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
                __syncthreads();

                // after the above process, we have two results:
                // 1) fp8 q data stored in smem (no need to reload from global)
                // 2) per-token scale q_dequant_scale, to be multiplied in after the 1st gemm
            }
        }
        else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_4BIT_PERTOKEN)
        {
            // currently the same as KV_8BIT_PERTOKEN, since we use 8-bit mfma
            if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
            {
                // dynamic quant q here
                float q = 0;
                if(static_cast<int>(threadIdx.x) < args.hdim)
                {
                    q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
                }

                // apply smooth-quant
                // find absmax
                float q_max = wave_reduce(q, f_absmax_f32);
                q_max = cross_wave_reduce(q_max, f_absmax_f32, reinterpret_cast<float*>(smem));

                // per-token scale
                q_dequant_scale =
                    type_convert<QuantComputeType>(q_max) / scale_max<QCompute>::value;

                // divide by scale
                q = q / q_dequant_scale;

                QCompute quantized_q = type_convert<QCompute>(q);
                __syncthreads();
                reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
                __syncthreads();

                // after the above process, we have two results:
                // 1) fp8 q data stored in smem (no need to reload from global)
                // 2) per-token scale q_dequant_scale, to be multiplied in after the 1st gemm
            }
        }

        for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
        {
            int i_sk = i_loop1 * wg_size + threadIdx.x;
            // gemm-1
            SoftmaxType s_softmax = -numeric<SoftmaxType>::infinity();
            if(i_sk < seqlen_kv)
            {
                AccType_O s_acc{0}; // clear for every loop
                int gemm_1_loop = args.hdim / gemm1_vec_size;
                for(auto i_loop = 0; i_loop < gemm_1_loop; i_loop++)
                {
                    auto q = [&]() {
                        if constexpr(Traits::quant_algo ==
                                         naive_attention_quant_algo::KV_8BIT_PERHEAD ||
                                     Traits::quant_algo ==
                                         naive_attention_quant_algo::KV_8BIT_PERTOKEN ||
                                     Traits::quant_algo ==
                                         naive_attention_quant_algo::KV_4BIT_PERTOKEN)
                        {
                            return reinterpret_cast<q_vec_type*>(smem_quant_q)[i_loop];
                        }
                        else
                        {
                            return q_addr.template load_vector<gemm1_vec_size>(i_sq, i_loop);
                        }
                    }();
                    auto k = [&]() {
                        return k_addr.template load_vector<gemm1_vec_size>(i_sk, i_loop);
                    }();

                    for(int i = 0; i < gemm1_vec_size; i++)
                    {
                        s_acc += type_convert<AccType_O>(q[i]) * type_convert<AccType_O>(k[i]);
                    }
                }
                // scale
                s_softmax = type_convert<SoftmaxType>(s_acc);
                s_softmax *=
                    type_convert<SoftmaxType>(args.scale_s * ck_tile::log2e_v<SoftmaxType>);
                if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
                {
                    s_softmax *= q_dequant_scale; // post scale the per-token factor
                }
                else if constexpr(Traits::quant_algo ==
                                      naive_attention_quant_algo::KV_8BIT_PERTOKEN ||
                                  Traits::quant_algo ==
                                      naive_attention_quant_algo::KV_4BIT_PERTOKEN)
                {
                    SoftmaxType k_per_token_scale =
                        type_convert<SoftmaxType>(kscale_addr.load(i_sk, i_hk, 0));
                    s_softmax *= q_dequant_scale;
                    s_softmax *= k_per_token_scale;
                }
            }

            // s->p
            QuantComputeType p_dequant_scale = 1.;
            {
                // softmax, find max
                SoftmaxType old_max = row_max;
                SoftmaxType cur_max = wave_reduce(s_softmax, f_max);

                cur_max = cross_wave_reduce(cur_max, f_max, reinterpret_cast<SoftmaxType*>(smem));
                row_max = max(old_max, cur_max); // update row_max
                // softmax, exp(i_elem - max)
                SoftmaxType p_compute = __builtin_amdgcn_exp2f(s_softmax - row_max);

                // compute exp_sum
                SoftmaxType row_sum = wave_reduce(p_compute, f_sum);
                row_sum = cross_wave_reduce(row_sum, f_sum, reinterpret_cast<SoftmaxType*>(smem));

                // l, pre-scale o_acc
                SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
                l               = tmp * l + row_sum;
                o_acc           = type_convert<TailType>(type_convert<SoftmaxType>(o_acc) * tmp);

                // store p_compute into smem, so that every thread reads the same p_compute
                // for the 2nd gemm
                if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
                {
                    QuantComputeType v_s = 0;
                    if(static_cast<int>(threadIdx.x) < args.hdim_v)
                    {
                        v_s =
                            type_convert<QuantComputeType>(vscale_addr.load(i_hk, threadIdx.x, 1));
                    }

                    // 1) we apply the v scale to p
                    QuantComputeType p_forwarded = p_compute * v_s;

                    // 2) apply smooth-quant
                    // find absmax
                    QuantComputeType pf_max = wave_reduce(p_forwarded, f_absmax_f32);
                    pf_max                  = cross_wave_reduce(
                        pf_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));

                    // per-token scale
                    p_dequant_scale = pf_max / scale_max<PType>::value; // 127.0;

                    // divide by scale
                    p_compute = p_compute / p_dequant_scale;

                    // fp32->i8
                    PType quantized_p = static_cast<PType>(p_compute);
                    __syncthreads();
                    reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
                    __syncthreads();
                    // after the above process, we have two results:
                    // 1) int8 p data stored in smem (no need to reload)
                    // 2) per-token scale p_dequant_scale, to be multiplied in after the 2nd gemm
                }
                else if constexpr(Traits::quant_algo ==
                                  naive_attention_quant_algo::KV_8BIT_PERTOKEN)
                {
                    // forward apply the v scale to p_compute, this is compute friendly
                    auto v_scale = type_convert<QuantComputeType>(vscale_addr.load(i_sk, i_hk, 0));
                    p_compute *= v_scale;
                    // smooth-quant
                    // find absmax
                    QuantComputeType p_max = wave_reduce(p_compute, f_absmax_f32);
                    p_max                  = cross_wave_reduce(
                        p_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));

                    // per-token scale
                    p_dequant_scale = p_max / scale_max<PType>::value; // 240.0;

                    // divide by scale
                    p_compute = p_compute / p_dequant_scale;

                    // fp32->i8
                    PType quantized_p = type_convert<PType>(p_compute);
                    __syncthreads();
                    reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
                    __syncthreads();
                    // after the above process, we have two results:
                    // 1) fp8_t p data stored in smem (no need to reload)
                    // 2) per-token scale p_dequant_scale, to be multiplied in after the 2nd gemm
                }
                else
                {
                    __syncthreads();
                    reinterpret_cast<PType*>(smem)[threadIdx.x] = type_convert<PType>(p_compute);
                    __syncthreads();
                }
            }

            // gemm-2, simple loop over vector by vector
            constexpr int gemm_2_loop = wg_size / p_vec_elem;
            {
                AccType_O o_acc_local = {0};
                int sk_start = i_loop1 * wg_size; // we start from the first seqlen_kv element
                for(int i_loop2 = 0; i_loop2 < gemm_2_loop; i_loop2++)
                {
                    p_vec_type p_vec = reinterpret_cast<p_vec_type*>(smem)[i_loop2];
#pragma unroll
                    for(int i_j = 0; i_j < p_vec_elem; i_j++)
                    {
                        int sv_offset = i_loop2 * p_vec_elem + i_j;
                        int i_sv      = sk_start + sv_offset;

                        VType v = 0;
                        if(i_dv < args.hdim_v && i_sv < seqlen_kv)
                        {
                            v = v_addr.load(i_sv, i_dv);
                        }

                        AccType_O v_compute = [&]() { return type_convert<AccType_O>(v); }();

                        o_acc_local += type_convert<AccType_O>(p_vec[i_j]) * v_compute;
                    }
                }

                TailType post_scale_o_acc_local = [&]() {
                    if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
                    {
                        // apply the p dequant scale to the local acc
                        return type_convert<TailType>(type_convert<QuantComputeType>(o_acc_local) *
                                                      p_dequant_scale);
                    }
                    else if constexpr(Traits::quant_algo ==
                                      naive_attention_quant_algo::KV_8BIT_PERTOKEN)
                    {
                        // apply the p dequant scale to the local acc
                        return type_convert<TailType>(type_convert<QuantComputeType>(o_acc_local) *
                                                      p_dequant_scale);
                    }
                    else
                    {
                        return type_convert<TailType>(o_acc_local);
                    }
                }();
                o_acc += post_scale_o_acc_local;
            }
        }

        // post scale o_acc
        {
            SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case of masking
            o_acc           = type_convert<TailType>(type_convert<SoftmaxType>(o_acc) * tmp);
        }

        // store O
        if(i_dv < args.hdim_v)
            o_addr.store(type_convert<OType>(o_acc), i_sq, i_dv);
    }
};

#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_()                                                        \
    {                                                                                                       \
        using ktraits_ = naive_attention_fwd_kernel_traits<                                                 \
            static_cast<naive_attention_variation_enum>(variation_),                                        \
            static_cast<naive_attention_quant_algo>(quant_algo_)>;                                          \
        using k_   = naive_attention_fwd_kernel<q_type_,                                                    \
                                                k_type_,                                                    \
                                                v_type_,                                                    \
                                                o_type_,                                                    \
                                                acc_type_i_,                                                \
                                                acc_type_o_,                                                \
                                                kvscale_type_,                                              \
                                                q_layout_,                                                  \
                                                k_layout_,                                                  \
                                                v_layout_,                                                  \
                                                o_layout_,                                                  \
                                                k_scale_layout_,                                            \
                                                v_scale_layout_,                                            \
                                                ktraits_>;                                                  \
        dim3 grids = k_::get_grid_size(a);                                                                  \
        r          = ck_tile::launch_kernel(s,                                                              \
                                   ck_tile::make_kernel(k_{}, grids, k_::get_block_size(), 0, a)); \
    }

#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_()                                                 \
    if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \
       t.o_layout == "bshd")                                                                       \
    {                                                                                              \
        constexpr auto q_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto k_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto v_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto o_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr int variation_       = 0;                                                        \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
    }                                                                                              \
    else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" &&                    \
            t.v_layout == "bhsd" && t.o_layout == "bhsd")                                          \
    {                                                                                              \
        constexpr auto q_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto v_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto o_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr int variation_       = 0;                                                        \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
    }                                                                                              \
    else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" &&                   \
            t.v_layout == "phds" && t.o_layout == "bhsd")                                          \
    {                                                                                              \
        constexpr auto q_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_layout_       = naive_attention_layout_enum::PHDSX;                       \
        constexpr auto v_layout_       = naive_attention_layout_enum::PHDS;                        \
        constexpr auto o_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_scale_layout_ = naive_attention_layout_enum::SCALE_HS;                    \
        constexpr auto v_scale_layout_ = naive_attention_layout_enum::SCALE_HS;                    \
        constexpr int variation_       = 2;                                                        \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
    }

//
CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
                                       naive_attention_fwd_args a,
                                       ck_tile::stream_config s)
{
    float r = -1;
    // TODO: do not explicitly create too many instances!
    if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16" &&
       t.quant_algo == 0)
    {
        using q_type_             = fp16_t;
        using k_type_             = fp16_t;
        using v_type_             = fp16_t;
        using o_type_             = fp16_t;
        using acc_type_i_         = fp16_t;
        using acc_type_o_         = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 0;
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16" &&
            t.quant_algo == 0)
    {
        using q_type_             = bf16_t;
        using k_type_             = bf16_t;
        using v_type_             = bf16_t;
        using o_type_             = bf16_t;
        using acc_type_i_         = fp16_t;
        using acc_type_o_         = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 0;
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    else if(t.q_type == "bf16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "bf16" &&
            t.quant_algo == 2)
    {
        using q_type_             = bf16_t;
        using k_type_             = fp8_t;
        using v_type_             = fp8_t;
        using o_type_             = bf16_t;
        using acc_type_i_         = fp8_t;
        using acc_type_o_         = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 2;
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    else if(t.q_type == "fp16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "fp16" &&
            t.quant_algo == 2)
    {
        using q_type_             = fp16_t;
        using k_type_             = fp8_t;
        using v_type_             = fp8_t;
        using o_type_             = fp16_t;
        using acc_type_i_         = fp8_t;
        using acc_type_o_         = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 2;
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16" &&
            t.quant_algo == 2)
    {
        using q_type_             = bf16_t;
        using k_type_             = int8_t;
        using v_type_             = int8_t;
        using o_type_             = bf16_t;
        using acc_type_i_         = int8_t;
        using acc_type_o_         = int32_t;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 2;
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
    }
    return r;
}

#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_
#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_

} // namespace ck_tile