fmha_fwd_kernel.hpp 8 KB
Newer Older
carlushuang's avatar
carlushuang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/tile_program/tile/tile_window.hpp"

// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q]^T
// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k]^T
// (K and V are stored in the layouts shown, so both products use their transposes.)

// log2(e). This is a double-precision literal: in float arithmetic, convert it
// explicitly at the use site to avoid an accidental promotion to double.
#define C_LOG2E    1.44269504088896340736   // log2(e)
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
carlushuang's avatar
carlushuang committed
17
18
struct FmhaFwdKernel
{
19
    using TilePartitioner                   = ck::remove_cvref_t<TilePartitioner_>;
carlushuang's avatar
carlushuang committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
    using FmhaPipeline                      = ck::remove_cvref_t<FmhaPipeline_>;
    using EpiloguePipeline                  = ck::remove_cvref_t<EpiloguePipeline_>;
    static constexpr ck::index_t kBlockSize = FmhaPipeline::kBlockSize;

    using QDataType = ck::remove_cvref_t<typename FmhaPipeline::QDataType>;
    using KDataType = ck::remove_cvref_t<typename FmhaPipeline::KDataType>;
    using VDataType = ck::remove_cvref_t<typename FmhaPipeline::VDataType>;
    using ODataType = ck::remove_cvref_t<typename FmhaPipeline::ODataType>;

    struct Kargs
    {
        const void* q_ptr;
        const void* k_ptr;
        const void* v_ptr;
        void* o_ptr;
        ck::index_t seqlen_q;
        ck::index_t seqlen_k;
        ck::index_t hdim_q;
        ck::index_t hdim_v;
39
40
41

        float scale;

carlushuang's avatar
carlushuang committed
42
43
44
45
        ck::index_t stride_q;
        ck::index_t stride_k;
        ck::index_t stride_v;
        ck::index_t stride_o;
46
47
48
49
50
51

        ck::index_t nhead_stride_q;
        ck::index_t nhead_stride_k;
        ck::index_t nhead_stride_v;
        ck::index_t nhead_stride_o;

carlushuang's avatar
carlushuang committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
        ck::index_t batch_stride_q;
        ck::index_t batch_stride_k;
        ck::index_t batch_stride_v;
        ck::index_t batch_stride_o;
    };

    __host__ static constexpr Kargs MakeKargs(const void* q_ptr,
                                              const void* k_ptr,
                                              const void* v_ptr,
                                              void* o_ptr,
                                              ck::index_t seqlen_q,
                                              ck::index_t seqlen_k,
                                              ck::index_t hdim_q,
                                              ck::index_t hdim_v,
66
                                              float scale,
carlushuang's avatar
carlushuang committed
67
68
69
70
                                              ck::index_t stride_q,
                                              ck::index_t stride_k,
                                              ck::index_t stride_v,
                                              ck::index_t stride_o,
71
72
73
74
                                              ck::index_t nhead_stride_q,
                                              ck::index_t nhead_stride_k,
                                              ck::index_t nhead_stride_v,
                                              ck::index_t nhead_stride_o,
carlushuang's avatar
carlushuang committed
75
76
77
78
79
                                              ck::index_t batch_stride_q,
                                              ck::index_t batch_stride_k,
                                              ck::index_t batch_stride_v,
                                              ck::index_t batch_stride_o)
    {
80
81
82
83
        return Kargs{q_ptr,          k_ptr,          v_ptr,          o_ptr,          seqlen_q,
                     seqlen_k,       hdim_q,         hdim_v,         scale,          stride_q,
                     stride_k,       stride_v,       stride_o,       nhead_stride_q, nhead_stride_k,
                     nhead_stride_v, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v,
carlushuang's avatar
carlushuang committed
84
85
86
                     batch_stride_o};
    }

87
88
89
90
    __host__ static constexpr auto GridSize(ck::index_t batch_size_,
                                            ck::index_t nhead_,
                                            ck::index_t seqlen_q_,
                                            ck::index_t hdim_v_)
carlushuang's avatar
carlushuang committed
91
    {
92
        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
carlushuang's avatar
carlushuang committed
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    }

    __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

    __host__ __device__ static constexpr ck::index_t GetSmemSize()
    {
        return ck::math::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    __device__ void operator()(Kargs kargs) const
    {
        using namespace ck;
        using namespace ck::tile_program;
        using namespace ck::tile_program::block;

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        // divide problem
112
113
        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
            TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
carlushuang's avatar
carlushuang committed
114

115
116
        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
        const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
carlushuang's avatar
carlushuang committed
117
118

        // for simplicity, batch stride we just modify the pointer
119
120
121
122
123
124
125
126
        const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
                                 i_nhead * kargs.nhead_stride_q + i_batch * kargs.batch_stride_q;
        const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
                                 i_nhead * kargs.nhead_stride_k + i_batch * kargs.batch_stride_k;
        const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
                                 i_nhead * kargs.nhead_stride_v + i_batch * kargs.batch_stride_v;
        ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
                           i_nhead * kargs.nhead_stride_o + i_batch * kargs.batch_stride_o;
carlushuang's avatar
carlushuang committed
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166

        // Q/K/V DRAM and DRAM window
        // FIXME: assume layout Q[seqlen_q, hdim_q], K[seqlen_k, hdim_q], V[hdim_v, seqlen_k],
        const auto q_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            q_ptr,
            make_tuple(kargs.seqlen_q, kargs.hdim_q),
            make_tuple(kargs.stride_q, 1),
            Number<32>{},
            Number<1>{});

        const auto k_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            k_ptr,
            make_tuple(kargs.seqlen_k, kargs.hdim_q),
            make_tuple(kargs.stride_k, 1),
            Number<32>{},
            Number<1>{});

        const auto v_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            v_ptr,
            make_tuple(kargs.hdim_v, kargs.seqlen_k),
            make_tuple(kargs.stride_v, 1),
            Number<32>{},
            Number<1>{});

        auto q_dram_window =
            make_tile_window(q_dram,
                             make_tuple(Number<FmhaPipeline::kM0>{}, Number<FmhaPipeline::kK0>{}),
                             {i_m0, 0});

        auto k_dram_window = make_tile_window(
            k_dram, make_tuple(Number<FmhaPipeline::kN0>{}, Number<FmhaPipeline::kK0>{}), {0, 0});

        auto v_dram_window =
            make_tile_window(v_dram,
                             make_tuple(Number<FmhaPipeline::kN1>{}, Number<FmhaPipeline::kK1>{}),
                             {i_n1, 0});

        auto o_acc_tile = FmhaPipeline{}(q_dram_window,
                                         k_dram_window,
                                         v_dram_window,
167
                                         kargs.scale,
carlushuang's avatar
carlushuang committed
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
                                         kargs.seqlen_k / FmhaPipeline::kN0,
                                         kargs.hdim_q / FmhaPipeline::kK0,
                                         smem_ptr);

        // O DRAM and O DRAM window
        auto o_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            o_ptr,
            make_tuple(kargs.seqlen_q, kargs.hdim_v),
            make_tuple(kargs.stride_o, 1),
            Number<32>{},
            Number<1>{});

        auto o_dram_window =
            make_tile_window(o_dram,
                             make_tuple(Number<FmhaPipeline::kM0>{}, Number<FmhaPipeline::kN1>{}),
                             {i_m0, i_n1});

        EpiloguePipeline{}(o_dram_window, o_acc_tile);
    }
};