// fmha_fwd_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/tile_program/tile/tile_window.hpp"

// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q]
// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k]

// log2(e). Multiplying a natural-log-domain value by this constant converts it
// to the base-2 domain (so exp(x) == exp2(x * C_LOG2E)). Kept as a macro so it
// can be used by code that includes this header.
#define C_LOG2E 1.44269504088896340736 // log2(e)

template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
carlushuang's avatar
carlushuang committed
17
18
struct FmhaFwdKernel
{
19
    using TilePartitioner                   = ck::remove_cvref_t<TilePartitioner_>;
carlushuang's avatar
carlushuang committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
    using FmhaPipeline                      = ck::remove_cvref_t<FmhaPipeline_>;
    using EpiloguePipeline                  = ck::remove_cvref_t<EpiloguePipeline_>;
    static constexpr ck::index_t kBlockSize = FmhaPipeline::kBlockSize;

    using QDataType = ck::remove_cvref_t<typename FmhaPipeline::QDataType>;
    using KDataType = ck::remove_cvref_t<typename FmhaPipeline::KDataType>;
    using VDataType = ck::remove_cvref_t<typename FmhaPipeline::VDataType>;
    using ODataType = ck::remove_cvref_t<typename FmhaPipeline::ODataType>;

    struct Kargs
    {
        const void* q_ptr;
        const void* k_ptr;
        const void* v_ptr;
        void* o_ptr;
        ck::index_t seqlen_q;
        ck::index_t seqlen_k;
        ck::index_t hdim_q;
        ck::index_t hdim_v;
39
40
41

        float scale;

carlushuang's avatar
carlushuang committed
42
43
44
45
        ck::index_t stride_q;
        ck::index_t stride_k;
        ck::index_t stride_v;
        ck::index_t stride_o;
46
47
48
49
50
51

        ck::index_t nhead_stride_q;
        ck::index_t nhead_stride_k;
        ck::index_t nhead_stride_v;
        ck::index_t nhead_stride_o;

carlushuang's avatar
carlushuang committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
        ck::index_t batch_stride_q;
        ck::index_t batch_stride_k;
        ck::index_t batch_stride_v;
        ck::index_t batch_stride_o;
    };

    __host__ static constexpr Kargs MakeKargs(const void* q_ptr,
                                              const void* k_ptr,
                                              const void* v_ptr,
                                              void* o_ptr,
                                              ck::index_t seqlen_q,
                                              ck::index_t seqlen_k,
                                              ck::index_t hdim_q,
                                              ck::index_t hdim_v,
66
                                              float scale,
carlushuang's avatar
carlushuang committed
67
68
69
70
                                              ck::index_t stride_q,
                                              ck::index_t stride_k,
                                              ck::index_t stride_v,
                                              ck::index_t stride_o,
71
72
73
74
                                              ck::index_t nhead_stride_q,
                                              ck::index_t nhead_stride_k,
                                              ck::index_t nhead_stride_v,
                                              ck::index_t nhead_stride_o,
carlushuang's avatar
carlushuang committed
75
76
77
78
79
                                              ck::index_t batch_stride_q,
                                              ck::index_t batch_stride_k,
                                              ck::index_t batch_stride_v,
                                              ck::index_t batch_stride_o)
    {
80
81
82
83
        return Kargs{q_ptr,          k_ptr,          v_ptr,          o_ptr,          seqlen_q,
                     seqlen_k,       hdim_q,         hdim_v,         scale,          stride_q,
                     stride_k,       stride_v,       stride_o,       nhead_stride_q, nhead_stride_k,
                     nhead_stride_v, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v,
carlushuang's avatar
carlushuang committed
84
85
86
                     batch_stride_o};
    }

87
88
89
90
    __host__ static constexpr auto GridSize(ck::index_t batch_size_,
                                            ck::index_t nhead_,
                                            ck::index_t seqlen_q_,
                                            ck::index_t hdim_v_)
carlushuang's avatar
carlushuang committed
91
    {
92
        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
carlushuang's avatar
carlushuang committed
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    }

    __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

    __host__ __device__ static constexpr ck::index_t GetSmemSize()
    {
        return ck::math::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    __device__ void operator()(Kargs kargs) const
    {
        using namespace ck;
        using namespace ck::tile_program;
        using namespace ck::tile_program::block;

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        // divide problem
112
113
        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
            TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
carlushuang's avatar
carlushuang committed
114

115
116
        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
        const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
carlushuang's avatar
carlushuang committed
117
118

        // for simplicity, batch stride we just modify the pointer
119
120
121
122
123
124
125
126
        const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
                                 i_nhead * kargs.nhead_stride_q + i_batch * kargs.batch_stride_q;
        const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
                                 i_nhead * kargs.nhead_stride_k + i_batch * kargs.batch_stride_k;
        const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
                                 i_nhead * kargs.nhead_stride_v + i_batch * kargs.batch_stride_v;
        ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
                           i_nhead * kargs.nhead_stride_o + i_batch * kargs.batch_stride_o;
carlushuang's avatar
carlushuang committed
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150

        // Q/K/V DRAM and DRAM window
        // FIXME: assume layout Q[seqlen_q, hdim_q], K[seqlen_k, hdim_q], V[hdim_v, seqlen_k],
        const auto q_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            q_ptr,
            make_tuple(kargs.seqlen_q, kargs.hdim_q),
            make_tuple(kargs.stride_q, 1),
            Number<32>{},
            Number<1>{});

        const auto k_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            k_ptr,
            make_tuple(kargs.seqlen_k, kargs.hdim_q),
            make_tuple(kargs.stride_k, 1),
            Number<32>{},
            Number<1>{});

        const auto v_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            v_ptr,
            make_tuple(kargs.hdim_v, kargs.seqlen_k),
            make_tuple(kargs.stride_v, 1),
            Number<32>{},
            Number<1>{});

151
152
153
154
155
156
157
158
159
160
        auto q_dram_window = make_tile_window(
            q_dram,
            [&]() {
                if constexpr(FmhaPipeline::kQLoadOnce)
                    return make_tuple(Number<FmhaPipeline::kM0>{},
                                      Number<FmhaPipeline::kK0BlockLength>{});
                else
                    return make_tuple(Number<FmhaPipeline::kM0>{}, Number<FmhaPipeline::kK0>{});
            }(),
            {i_m0, 0});
carlushuang's avatar
carlushuang committed
161
162
163
164
165
166
167
168
169
170
171
172

        auto k_dram_window = make_tile_window(
            k_dram, make_tuple(Number<FmhaPipeline::kN0>{}, Number<FmhaPipeline::kK0>{}), {0, 0});

        auto v_dram_window =
            make_tile_window(v_dram,
                             make_tuple(Number<FmhaPipeline::kN1>{}, Number<FmhaPipeline::kK1>{}),
                             {i_n1, 0});

        auto o_acc_tile = FmhaPipeline{}(q_dram_window,
                                         k_dram_window,
                                         v_dram_window,
173
                                         kargs.scale,
carlushuang's avatar
carlushuang committed
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
                                         kargs.seqlen_k / FmhaPipeline::kN0,
                                         kargs.hdim_q / FmhaPipeline::kK0,
                                         smem_ptr);

        // O DRAM and O DRAM window
        auto o_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            o_ptr,
            make_tuple(kargs.seqlen_q, kargs.hdim_v),
            make_tuple(kargs.stride_o, 1),
            Number<32>{},
            Number<1>{});

        auto o_dram_window =
            make_tile_window(o_dram,
                             make_tuple(Number<FmhaPipeline::kM0>{}, Number<FmhaPipeline::kN1>{}),
                             {i_m0, i_n1});

        EpiloguePipeline{}(o_dram_window, o_acc_tile);
    }
};