// softmax.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_window.hpp"
#include "ck/tile_program/tile/load_tile.hpp"
#include "ck/tile_program/tile/store_tile.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"

template <typename ADataType,
          typename AccDataType,
          typename BDataType,
          ck::index_t kBlockSize,
          ck::index_t kMPerBlock,
          ck::index_t kNPerBlock>
struct Softmax
{
#if 0
Chao Liu's avatar
Chao Liu committed
27
     __device__ static constexpr auto MakeABlockTileDistribution()
Chao Liu's avatar
Chao Liu committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
    {
        using namespace ck;
        using namespace ck::tile_program;

        // 2x2 wave
        return make_static_tile_distribution(
            StaticTileDistributionEncoding<Sequence<>,
                                           Tuple<Sequence<2, 2, 4, 2, 4>, Sequence<2, 2, 32>>,
                                           Tuple<Sequence<1, 2>, Sequence<1, 2>>,
                                           Tuple<Sequence<1, 1>, Sequence<3, 2>>,
                                           Sequence<1, 2, 1, 1>,
                                           Sequence<0, 0, 2, 4>>{});
    }
#elif 0
Chao Liu's avatar
Chao Liu committed
42
    __device__ static constexpr auto MakeABlockTileDistribution()
Chao Liu's avatar
Chao Liu committed
43
44
45
46
47
48
49
50
51
52
53
54
55
56
    {
        using namespace ck;
        using namespace ck::tile_program;

        // 2x2 wave
        return make_static_tile_distribution(
            StaticTileDistributionEncoding<Sequence<>,
                                           Tuple<Sequence<2, 2, 32>, Sequence<2, 2, 4, 2, 4>>,
                                           Tuple<Sequence<2, 1>, Sequence<2, 1>>,
                                           Tuple<Sequence<1, 1>, Sequence<3, 2>>,
                                           Sequence<2, 1, 2, 2>,
                                           Sequence<0, 0, 2, 4>>{});
    }
#elif 1
Chao Liu's avatar
Chao Liu committed
57
    __device__ static constexpr auto MakeABlockTileDistribution()
Chao Liu's avatar
Chao Liu committed
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
    {
        using namespace ck;
        using namespace ck::tile_program;

        // 4x1 wave
        return make_static_tile_distribution(
            StaticTileDistributionEncoding<Sequence<>,
                                           Tuple<Sequence<1, 4, 4, 2, 4>, Sequence<4, 1, 32>>,
                                           Tuple<Sequence<1, 2>, Sequence<1, 2>>,
                                           Tuple<Sequence<1, 1>, Sequence<3, 2>>,
                                           Sequence<1, 2, 1, 1>,
                                           Sequence<0, 0, 2, 4>>{});
    }
#endif

Chao Liu's avatar
Chao Liu committed
73
74
    __device__ void
    operator()(const ADataType* p_a, BDataType* p_b, ck::index_t M, ck::index_t N) const
Chao Liu's avatar
Chao Liu committed
75
76
77
78
79
80
81
82
83
84
85
    {
        using namespace ck;
        using namespace ck::tile_program;
        using namespace ck::tile_program::block;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        const auto a_m_n = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_a, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});

Chao Liu's avatar
Chao Liu committed
86
        const auto iM = get_block_id() * kMPerBlock;
Chao Liu's avatar
Chao Liu committed
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145

        // A window
        auto a_block_window =
            make_tile_window(a_m_n,
                             make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}),
                             {iM, 0},
                             MakeABlockTileDistribution());

        constexpr auto reduce_dims = Sequence<1>{};

        const auto f_max = [](auto v0, auto v1) { return max(v0, v1); };

        const ADataType max_reduce_init_value = NumericLimits<ADataType>::Lowest();

        // max = max(a)
        auto max_block_tensor = decltype(block_tile_reduce<AccDataType>(
            load_tile(a_block_window), reduce_dims, f_max, max_reduce_init_value)){};

        tile_elementwise_inout(
            [&](auto& max) { max = type_convert<AccDataType>(max_reduce_init_value); },
            max_block_tensor);

        index_t iN = 0;

        do
        {
            const auto a_block_tensor = load_tile(a_block_window);

            block_tile_reduce(max_block_tensor, a_block_tensor, reduce_dims, f_max);

            move_tile_window(a_block_window, {0, kNPerBlock});

            iN += kNPerBlock;

        } while(iN < N);

        // cross lane reduce: max
        block_tile_reduce_sync(max_block_tensor, f_max);

        // exp_sum = sum(exp(a - a_max))
        auto exp_sum_block_tensor =
            make_static_distributed_tensor<AccDataType>(max_block_tensor.GetTileDistribution());

        tile_elementwise_inout([&](auto& exp_sum) { exp_sum = 0; }, exp_sum_block_tensor);

        // reset window location
        iN = 0;
        move_tile_window(a_block_window, {0, -N});

        do
        {
            const auto a_block_tensor = load_tile(a_block_window);

            constexpr auto a_spans = decltype(a_block_tensor)::GetDistributedSpans();

            //
            sweep_tile_span(a_spans[I0], [&](auto idx0) {
                constexpr auto m_idx = make_tuple(idx0);

Chao Liu's avatar
Chao Liu committed
146
                const auto v_max = max_block_tensor[m_idx];
Chao Liu's avatar
Chao Liu committed
147

Chao Liu's avatar
Chao Liu committed
148
                AccDataType v_exp_sum = exp_sum_block_tensor[m_idx];
Chao Liu's avatar
Chao Liu committed
149
150
151
152

                sweep_tile_span(a_spans[I1], [&](auto idx1) {
                    constexpr auto m_n_idx = make_tuple(idx0, idx1);

Chao Liu's avatar
Chao Liu committed
153
                    const auto v_a = a_block_tensor[m_n_idx];
Chao Liu's avatar
Chao Liu committed
154
155
156
157
158

                    // exp and sum
                    v_exp_sum += math::exp(v_a - v_max);
                });

Chao Liu's avatar
Chao Liu committed
159
                exp_sum_block_tensor(m_idx) = v_exp_sum;
Chao Liu's avatar
Chao Liu committed
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
            });

            move_tile_window(a_block_window, {0, kNPerBlock});

            iN += kNPerBlock;

        } while(iN < N);

        // cross lane reduce: sum
        block_tile_reduce_sync(exp_sum_block_tensor, [](auto v0, auto v1) { return v0 + v1; });

        // B
        const auto b_m_n = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_b, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});

        // B window
        auto b_block_window = make_tile_window(
            b_m_n, make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}), {iM, 0});

        // reset window location
        iN = 0;
        move_tile_window(a_block_window, {0, -N});

        do
        {
            const auto a_block_tensor = load_tile(a_block_window);

            constexpr auto a_spans = decltype(a_block_tensor)::GetDistributedSpans();

            auto b_block_tensor =
                make_static_distributed_tensor<BDataType>(a_block_tensor.GetTileDistribution());

            //
            sweep_tile_span(a_spans[I0], [&](auto idx0) {
                constexpr auto m_idx = make_tuple(idx0);

Chao Liu's avatar
Chao Liu committed
196
                const auto v_max = max_block_tensor[m_idx];
Chao Liu's avatar
Chao Liu committed
197

Chao Liu's avatar
Chao Liu committed
198
                const auto v_exp_sum = exp_sum_block_tensor[m_idx];
Chao Liu's avatar
Chao Liu committed
199
200
201
202

                sweep_tile_span(a_spans[I1], [&](auto idx1) {
                    constexpr auto m_n_idx = make_tuple(idx0, idx1);

Chao Liu's avatar
Chao Liu committed
203
                    const auto v_a = a_block_tensor[m_n_idx];
Chao Liu's avatar
Chao Liu committed
204
205
206
207
208

                    // exp
                    const BDataType v_b =
                        type_convert<BDataType>(math::exp(v_a - v_max) / v_exp_sum);

Chao Liu's avatar
Chao Liu committed
209
                    b_block_tensor(m_n_idx) = v_b;
Chao Liu's avatar
Chao Liu committed
210
211
212
213
214
215
216
217
218
219
220
221
222
223
                });
            });

            // store B tile
            store_tile(b_block_window, b_block_tensor);

            move_tile_window(a_block_window, {0, kNPerBlock});
            move_tile_window(b_block_window, {0, kNPerBlock});

            iN += kNPerBlock;

        } while(iN < N);
    }
};