// blockwise_tensor_slice_op.hip.hpp
#pragma once
#include "threadwise_tensor_slice_op.hip.hpp"

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcLengths,
          class SrcSubLengths,
          class SrcClusterLengths,
          class MapDst2Src,
          class MapThreadCluster2SrcCluster,
          index_t SrcDataPerRead,
          index_t DstDataPerWrite>
Chao Liu's avatar
Chao Liu committed
15
struct BlockwiseTensorSliceReorderCopy_v3
Chao Liu's avatar
Chao Liu committed
16
17
18
19
20
21
{
    static constexpr index_t nDim = SrcLengths::GetSize();

    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;

Chao Liu's avatar
Chao Liu committed
22
23
24
    __device__
    BlockwiseTensorSliceReorderCopy_v3(Array<index_t, nDim> src_block_data_multi_id_begin,
                                       Array<index_t, nDim> dst_block_data_multi_id_begin)
Chao Liu's avatar
Chao Liu committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    {
        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr auto src_lengths = SrcLengths{};

        constexpr auto map_dst2src = MapDst2Src{};

        constexpr auto src_sub_lengths = SrcSubLengths{};
        constexpr auto dst_sub_lengths = src_sub_lengths.ReorderGivenNew2Old(map_dst2src);

        constexpr auto map_thread_cluster_2_src_cluster = MapThreadCluster2SrcCluster{};

        constexpr auto src_cluster_lengths = SrcClusterLengths{};
        constexpr auto thread_cluster_lengths =
            src_cluster_lengths.ReorderGivenNew2Old(map_thread_cluster_2_src_cluster);

        constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(thread_cluster_lengths);

        // sanity check: data type
        static_assert(is_same<Float, float>::value, "wrong! only support float for now!\n");

        // sanity check: nDim
Chao Liu's avatar
Chao Liu committed
48
49
50
        static_assert(SrcDesc::GetNumOfDimension() == nDim &&
                          DstDesc::GetNumOfDimension() == nDim && SrcLengths::GetSize() == nDim &&
                          SrcSubLengths::GetSize() == nDim &&
Chao Liu's avatar
Chao Liu committed
51
52
53
54
55
56
57
58
59
60
61
                          SrcClusterLengths::GetSize() == nDim && MapDst2Src::GetSize() == nDim &&
                          MapThreadCluster2SrcCluster::GetSize() == nDim,
                      "wrong! nDim is not consistent\n");

        // sanity check: BlockSize
        constexpr index_t num_active_thread = thread_cluster_desc.GetElementSize();

        static_assert(BlockSize >= num_active_thread,
                      "wrong! BlockSize is not big enough for ThreadPerDims!");

        // sanity check: work division
Chao Liu's avatar
Chao Liu committed
62
        static_for<0, nDim, 1>{}([&](auto IDim) {
Chao Liu's avatar
Chao Liu committed
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
            constexpr auto I                  = decltype(IDim){};
            constexpr index_t src_len         = src_lengths.Get(I);
            constexpr index_t src_sub_len     = src_sub_lengths.Get(I);
            constexpr index_t src_cluster_len = src_cluster_lengths.Get(I);
            static_assert(src_len % (src_sub_len * src_cluster_len) == 0,
                          "wrong! cannot evenly divide Src tensor lengths");
        });

        // sanity check: src read
        static_assert(SrcDataPerRead == 1 || SrcDataPerRead == 2 || SrcDataPerRead == 4,
                      "wrong! only support SrcDataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDataPerRead == 1 || src_desc.GetStride(Number<nDim - 1>{}) == 1,
                      "wrong! only support src.stride(nDim-1) == 1 if SrcDataPerRead > 1!\n");

        static_assert(src_sub_lengths.Get(Number<nDim - 1>{}) % SrcDataPerRead == 0,
                      "wrong! src_sub_lengths[nDim-1] % SrcDataPerRead != 0\n");

        static_assert(src_desc.GetStride(Number<nDim - 2>{}) % SrcDataPerRead == 0,
                      "wrong! should satisfy src_desc.stride(nDim-2) % SrcDataPerRead == 0, to "
                      "keep alignment");

        // sanity check: dst write
        static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
                      "wrong! only support DstDataPerWrite == 1, 2 or 4!\n");

        static_assert(DstDataPerWrite == 1 || dst_desc.GetStride(Number<nDim - 1>{}) == 1,
                      "wrong! only support dst.stride(nDim-1) == 1 if DstDataPerWrite > 1!\n");

        static_assert(dst_sub_lengths.Get(Number<nDim - 1>{}) % DstDataPerWrite == 0,
                      "wrong! dst_sub_lengths[nDim-1] % DstDataPerWrite != 0\n");

        static_assert(dst_desc.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
                      "wrong! should satisfy dst_desc.stride(nDim-2) % DstDataPerWrite == 0, to "
                      "keep alignment");

        // start dividing work
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        const auto thread_multi_id = thread_cluster_desc.GetMultiIndex(get_thread_local_1d_id());

        // compiler: thread_multi_id, src_data_multi_id, dst_data_multi_id, will use separate
        // regsiters, or only one copy???
        auto src_data_multi_id =
            reorder_array_given_old2new(thread_multi_id, map_thread_cluster_2_src_cluster);

        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr auto I    = decltype(IDim){};
            constexpr index_t i = I.Get();
Chao Liu's avatar
Chao Liu committed
118
            // compiler: will it really compute index here, or be merged with Get1dIndex and
Chao Liu's avatar
Chao Liu committed
119
120
121
122
            // optimized away???
            src_data_multi_id[i] *= src_sub_lengths.Get(I);
        });

Chao Liu's avatar
Chao Liu committed
123
        // compiler: will it really compute index here, or be merged with Get1dIndex and
Chao Liu's avatar
Chao Liu committed
124
125
126
        // optimized away???
        const auto dst_data_multi_id = reorder_array_given_new2old(src_data_multi_id, map_dst2src);

Chao Liu's avatar
Chao Liu committed
127
128
        mSrcMyThreadOffset = src_desc.Get1dIndex(src_data_multi_id + src_block_data_multi_id_begin);
        mDstMyThreadOffset = dst_desc.Get1dIndex(dst_data_multi_id + dst_block_data_multi_id_begin);
Chao Liu's avatar
Chao Liu committed
129
130
131
132
133
134
135
    }

    __device__ static constexpr index_t GetRegisterClipboardSize()
    {
        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

        constexpr auto src_data_per_cluster_per_dims = transform_sequences(
Chao Liu's avatar
Chao Liu committed
136
            std::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
Chao Liu's avatar
Chao Liu committed
137
138
139
140
141
142
143

        constexpr auto repeat_lengths =
            transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
                                SrcLengths{},
                                src_data_per_cluster_per_dims);

        constexpr auto thread_tensor_lengths = transform_sequences(
Chao Liu's avatar
Chao Liu committed
144
            std::multiplies<index_t>{}, thread_sub_tensor_lengths, repeat_lengths);
Chao Liu's avatar
Chao Liu committed
145
146
147
148
149
150
151
152
153
154
155
156

        constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);

        return thread_tensor_desc.GetElementSpace();
    }

    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

        constexpr auto src_data_per_cluster_per_dims = transform_sequences(
Chao Liu's avatar
Chao Liu committed
157
            std::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
Chao Liu's avatar
Chao Liu committed
158
159
160
161
162
163
164

        constexpr auto repeat_lengths =
            transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
                                SrcLengths{},
                                src_data_per_cluster_per_dims);

        constexpr auto thread_tensor_lengths = transform_sequences(
Chao Liu's avatar
Chao Liu committed
165
            std::multiplies<index_t>{}, thread_sub_tensor_lengths, repeat_lengths);
Chao Liu's avatar
Chao Liu committed
166
167
168
169
170
171
172

        constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);

        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
            constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};

            constexpr auto src_data_multi_id = transform_sequences(
Chao Liu's avatar
Chao Liu committed
173
                std::multiplies<index_t>{}, repeat_multi_id, src_data_per_cluster_per_dims);
Chao Liu's avatar
Chao Liu committed
174
175

            constexpr auto clipboard_data_multi_id = transform_sequences(
Chao Liu's avatar
Chao Liu committed
176
                std::multiplies<index_t>{}, repeat_multi_id, thread_sub_tensor_lengths);
Chao Liu's avatar
Chao Liu committed
177
178
179
180
181

            constexpr index_t src_offset = SrcDesc{}.Get1dIndex(src_data_multi_id);
            constexpr index_t clipboard_offset =
                thread_tensor_desc.Get1dIndex(clipboard_data_multi_id);

Chao Liu's avatar
Chao Liu committed
182
183
184
185
186
187
            threadwise_tensor_slice_copy(SrcDesc{},
                                         p_src + src_offset + mSrcMyThreadOffset,
                                         thread_tensor_desc,
                                         p_clipboard + clipboard_offset,
                                         thread_sub_tensor_lengths,
                                         Number<SrcDataPerRead>{});
Chao Liu's avatar
Chao Liu committed
188
189
190
191
192
193
194
195
196
        });
    }

    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

        constexpr auto src_data_per_cluster_per_dims = transform_sequences(
Chao Liu's avatar
Chao Liu committed
197
            std::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
Chao Liu's avatar
Chao Liu committed
198
199
200
201
202
203
204

        constexpr auto repeat_lengths =
            transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
                                SrcLengths{},
                                src_data_per_cluster_per_dims);

        constexpr auto thread_tensor_lengths = transform_sequences(
Chao Liu's avatar
Chao Liu committed
205
            std::multiplies<index_t>{}, thread_sub_tensor_lengths, repeat_lengths);
Chao Liu's avatar
Chao Liu committed
206
207
208
209
210
211
212

        constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);

        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
            constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};

            constexpr auto clipboard_data_multi_id = transform_sequences(
Chao Liu's avatar
Chao Liu committed
213
                std::multiplies<index_t>{}, repeat_multi_id, thread_sub_tensor_lengths);
Chao Liu's avatar
Chao Liu committed
214
215

            constexpr auto src_data_multi_id = transform_sequences(
Chao Liu's avatar
Chao Liu committed
216
                std::multiplies<index_t>{}, repeat_multi_id, src_data_per_cluster_per_dims);
Chao Liu's avatar
Chao Liu committed
217
218
219
220
221
222
223
224
225

            // reorder src_data_multi_id to get dst_data_multi_id
            constexpr auto dst_data_multi_id = src_data_multi_id.ReorderGivenNew2Old(MapDst2Src{});

            constexpr index_t clipboard_offset =
                thread_tensor_desc.Get1dIndex(clipboard_data_multi_id);

            constexpr index_t dst_offset = DstDesc{}.Get1dIndex(dst_data_multi_id);

Chao Liu's avatar
Chao Liu committed
226
// write in the order of dst
Chao Liu's avatar
Chao Liu committed
227
#if 1
Chao Liu's avatar
Chao Liu committed
228
229
230
231
232
233
234
            threadwise_tensor_slice_copy_reorder_given_dst2src_v2(thread_tensor_desc,
                                                                  p_clipboard + clipboard_offset,
                                                                  DstDesc{},
                                                                  p_dst + dst_offset +
                                                                      mDstMyThreadOffset,
                                                                  thread_sub_tensor_lengths,
                                                                  MapDst2Src{});
Chao Liu's avatar
Chao Liu committed
235
#else
Chao Liu's avatar
Chao Liu committed
236
237
238
239
240
241
242
243
            threadwise_tensor_slice_copy_reorder_given_dst2src_v3(thread_tensor_desc,
                                                                  p_clipboard + clipboard_offset,
                                                                  DstDesc{},
                                                                  p_dst + dst_offset +
                                                                      mDstMyThreadOffset,
                                                                  thread_sub_tensor_lengths,
                                                                  MapDst2Src{},
                                                                  Number<DstDataPerWrite>{});
Chao Liu's avatar
Chao Liu committed
244
245
246
247
248
249
250
251
252
253
254
255
#endif
        });
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        Float p_clipboard[GetRegisterClipboardSize()];

        RunLoadRegisterClipboard(p_src, p_clipboard);
        RunStoreRegisterClipboard(p_clipboard, p_dst);
    }
};