"mmdet3d/datasets/vscode:/vscode.git/clone" did not exist on "f2b01720792e67bdc91f41b33e783dcab79324e5"
softmax.hpp 8.89 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_window.hpp"
#include "ck/tile_program/tile/load_tile.hpp"
#include "ck/tile_program/tile/store_tile.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"

template <typename ADataType,
          typename AccDataType,
          typename BDataType,
          ck::index_t kBlockSize,
          ck::index_t kMPerBlock,
          ck::index_t kNPerBlock>
struct Softmax
{
    __device__ static constexpr auto MakeABlockTileDistribution()
    {
        using namespace ck;
        using namespace ck::tile_program;

        // 4x1 wave
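        // The encoding below hard-codes the block tile shape: the M factors give
        // 1 * 4 * 4 * 2 * 4 = 128 and the N factors 4 * 1 * 32 = 128, so this
        // distribution appears to assume kMPerBlock == 128, kNPerBlock == 128 and
        // kBlockSize == 256 (4 waves of 64 lanes).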
        return make_static_tile_distribution(
            StaticTileDistributionEncoding<Sequence<>,
                                           Tuple<Sequence<1, 4, 4, 2, 4>, Sequence<4, 1, 32>>,
                                           Tuple<Sequence<1, 2>, Sequence<1, 2>>,
                                           Tuple<Sequence<1, 1>, Sequence<3, 2>>,
                                           Sequence<1, 2, 1, 1>,
                                           Sequence<0, 0, 2, 4>>{});
    }

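    // Multi-pass variant: a row may span several kNPerBlock-wide tiles, so each row of
    // A is streamed from DRAM three times to compute a numerically stable softmax:
    //   pass 1: m = rowmax(A)
    //   pass 2: l = rowsum(exp(A - m))
    //   pass 3: B = exp(A - m) / l, stored tile by tile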
    __device__ void
    MultiPassSoftmax(const ADataType* p_a, BDataType* p_b, ck::index_t M, ck::index_t N) const
    {
        using namespace ck;
        using namespace ck::tile_program;
        using namespace ck::tile_program::block;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // A DRAM tensor view
        const auto a_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_a, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});

        const auto iM = get_block_id() * kMPerBlock;

        // A DRAM window
        auto a_dram_window =
            make_tile_window(a_dram,
                             make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}),
                             {iM, 0},
                             MakeABlockTileDistribution());

        // f_max: reduction operator for m = rowmax(A)
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };

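        // m's distributed-tensor type is deduced from a row-wise (Sequence<1>) reduction
        // of an A tile; the decltype operand is unevaluated, so no tile is loaded here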
        auto m = decltype(block_tile_reduce<AccDataType>(
            load_tile(a_dram_window), Sequence<1>{}, f_max, ADataType{})){};

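        // initialize m with the lowest representable ADataType value (the identity of
        // max), so the per-tile reductions below yield the true row maximum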
        tile_elementwise_inout(
            [&](auto& e) { e = type_convert<AccDataType>(NumericLimits<ADataType>::Lowest()); }, m);

        index_t iN = 0;

        do
        {
            // load A tile from DRAM
            const auto a = load_tile(a_dram_window);

            // m = rowmax(A)
            block_tile_reduce(m, a, Sequence<1>{}, f_max);

            move_tile_window(a_dram_window, {0, kNPerBlock});

            iN += kNPerBlock;

        } while(iN < N);

        // cross lane reduce: max
        block_tile_reduce_sync(m, f_max);

        // reset window location
        iN = 0;
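        // Note: stepping back by -N returns the window to column 0 only when N is a
        // multiple of kNPerBlock, which the fixed kNPerBlock stepping assumes (the same
        // holds for the reset after the second pass)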
        move_tile_window(a_dram_window, {0, -N});

        // l = rowsum(exp(A - m))
        auto l = make_static_distributed_tensor<AccDataType>(m.GetTileDistribution());

        tile_elementwise_inout([&](auto& e) { e = 0; }, l);

        do
        {
            // load A tile from DRAM
            const auto a = load_tile(a_dram_window);

            constexpr auto a_spans = decltype(a)::GetDistributedSpans();

            sweep_tile_span(a_spans[I0], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);

                sweep_tile_span(a_spans[I1], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);

                    // l = rowsum(exp(A - m))
                    l(i_idx) += math::exp(a[i_j_idx] - m[i_idx]);
                });
            });

            move_tile_window(a_dram_window, {0, kNPerBlock});

            iN += kNPerBlock;

        } while(iN < N);

        // cross lane reduce: sum
        block_tile_reduce_sync(l, [](auto e0, auto e1) { return e0 + e1; });

        // reset window location
        iN = 0;
        move_tile_window(a_dram_window, {0, -N});

        // B DRAM tensor view
        const auto b_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_b, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});

        // B DRAM window
        auto b_dram_window = make_tile_window(
            b_dram, make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}), {iM, 0});

        // B = exp(A - m) / l
        do
        {
            // load A tile from DRAM
            const auto a = load_tile(a_dram_window);

            constexpr auto a_spans = decltype(a)::GetDistributedSpans();

            auto b = make_static_distributed_tensor<BDataType>(a.GetTileDistribution());

            sweep_tile_span(a_spans[I0], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);

                sweep_tile_span(a_spans[I1], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);

                    // B = exp(A - m) / l
                    b(i_j_idx) =
                        type_convert<BDataType>(math::exp(a[i_j_idx] - m[i_idx]) / l[i_idx]);
                });
            });

            // store B tile
            store_tile(b_dram_window, b);

            move_tile_window(a_dram_window, {0, kNPerBlock});
            move_tile_window(b_dram_window, {0, kNPerBlock});

            iN += kNPerBlock;

        } while(iN < N);
    }

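    // Single-pass variant: the whole row fits in one kMPerBlock x kNPerBlock tile
    // (N <= kNPerBlock), so A is loaded once and the row max, the exp-sum and the
    // normalized output are all computed on the in-register tile.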
    __device__ void
    SinglePassSoftmax(const ADataType* p_a, BDataType* p_b, ck::index_t M, ck::index_t N) const
    {
        using namespace ck;
        using namespace ck::tile_program;
        using namespace ck::tile_program::block;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // A DRAM tensor view
        const auto a_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_a, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});

        const auto iM = get_block_id() * kMPerBlock;

        // A DRAM window
        auto a_dram_window =
            make_tile_window(a_dram,
                             make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}),
                             {iM, 0},
                             MakeABlockTileDistribution());

        // f_max
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };

        // m = rowmax(A)
        auto m = decltype(block_tile_reduce<AccDataType>(
            load_tile(a_dram_window), Sequence<1>{}, f_max, ADataType{})){};

        tile_elementwise_inout(
            [&](auto& e) { e = type_convert<AccDataType>(NumericLimits<ADataType>::Lowest()); }, m);

        // l = rowsum(exp(A - m))
        auto l = make_static_distributed_tensor<AccDataType>(m.GetTileDistribution());

        tile_elementwise_inout([&](auto& e) { e = 0; }, l);

        // load A tile from DRAM
        const auto a = load_tile(a_dram_window);

        constexpr auto a_spans = decltype(a)::GetDistributedSpans();

        // m = rowmax(A)
        block_tile_reduce(m, a, Sequence<1>{}, f_max);

        // cross lane reduce: max
        block_tile_reduce_sync(m, f_max);

        // l = rowsum(exp(A - m))
        sweep_tile_span(a_spans[I0], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);

            sweep_tile_span(a_spans[I1], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);

                l(i_idx) += math::exp(a[i_j_idx] - m[i_idx]);
            });
        });

        // cross lane reduce: sum
        block_tile_reduce_sync(l, [](auto e0, auto e1) { return e0 + e1; });

        auto b = make_static_distributed_tensor<BDataType>(a.GetTileDistribution());

        // B = exp(A - m) / l
        sweep_tile_span(a_spans[I0], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);

            sweep_tile_span(a_spans[I1], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);

                b(i_j_idx) = type_convert<BDataType>(math::exp(a[i_j_idx] - m[i_idx]) / l[i_idx]);
            });
        });

        // B DRAM tensor view
        const auto b_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_b, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});

        // B DRAM window
        auto b_dram_window = make_tile_window(
            b_dram, make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}), {iM, 0});

        // store B tile
        store_tile(b_dram_window, b);
    }

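    // Dispatch on the row length: rows wider than a single kNPerBlock tile take the
    // three-pass streaming path, otherwise the single-load path is used.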
    __device__ void
    operator()(const ADataType* p_a, BDataType* p_b, ck::index_t M, ck::index_t N) const
    {
        if(N > kNPerBlock)
        {
            MultiPassSoftmax(p_a, p_b, M, N);
        }
        else
        {
            SinglePassSoftmax(p_a, p_b, M, N);
        }
    }
};
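
// ----------------------------------------------------------------------------
// Usage sketch (not part of the original interface): one hypothetical way to
// wrap this functor in a HIP kernel. The kernel name, data types and tile sizes
// below are illustrative; 256 threads (4 waves of 64 lanes) with a 128x128
// block tile match the hard-coded tile distribution above, and one workgroup is
// launched per kMPerBlock rows.
//
//   using SoftmaxOp = Softmax<ck::half_t, float, ck::half_t,
//                             /*kBlockSize*/ 256,
//                             /*kMPerBlock*/ 128,
//                             /*kNPerBlock*/ 128>;
//
//   __global__ void softmax_kernel(const ck::half_t* p_a,
//                                  ck::half_t* p_b,
//                                  ck::index_t M,
//                                  ck::index_t N)
//   {
//       SoftmaxOp{}(p_a, p_b, M, N);
//   }
//
//   // host side, e.g.: softmax_kernel<<<(M + 127) / 128, 256>>>(p_a, p_b, M, N);
// ----------------------------------------------------------------------------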