// blockwise_tensor_slice_copy.hpp
#ifndef CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP

#include "threadwise_tensor_slice_copy.hpp"

namespace ck {

// Blockwise copy of an nDim tensor slice of lengths SrcLengths from a tensor laid out by
// SrcDesc into a tensor laid out by DstDesc. Dst multi-indices are the src multi-indices
// reordered by MapDst2Src (ReorderGivenNew2Old), i.e. the copy also permutes dimensions.
//
// Work division:
//   - threads form a packed cluster of lengths
//     SrcClusterLengths.ReorderGivenNew2Old(MapThreadCluster2SrcCluster); a thread's 1-d id is
//     decomposed over that cluster and reordered back into src-cluster order;
//   - each thread copies sub-tensors of lengths SrcSubLengths, repeated
//     SrcLengths / (SrcSubLengths * SrcClusterLengths) times per dimension;
//   - threads whose 1-d id is >= the cluster size are inactive and copy nothing.
//
// The copy is staged through a per-thread clipboard buffer (a local Float array in Run):
// RunLoadRegisterClipboard reads from src, RunStoreRegisterClipboard writes to dst in dst order.
//
// Preconditions enforced with static_assert in the constructor: Float is float, every template
// sequence has nDim entries, BlockSize >= cluster size, SrcLengths is evenly divided by
// SrcSubLengths * SrcClusterLengths, and the vector widths are 1, 2 or 4 with matching
// stride/length alignment. The alignment checks index dimension nDim - 2, so nDim >= 2 is assumed.
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcLengths,
          class SrcSubLengths,
          class SrcClusterLengths,
          class MapDst2Src,
          class MapThreadCluster2SrcCluster,
          index_t SrcDataPerRead,
          index_t DstDataPerWrite>
struct BlockwiseTensorSliceReorderCopy_v3
{
    static constexpr index_t nDim = SrcLengths::GetSize();

    // Element offsets of this thread's first sub-tensor in src and dst (block begin included).
    index_t mThreadSrcOffset;
    index_t mThreadDstOffset;

    // Number of threads that take part in the copy: the element count of the packed thread
    // cluster. Threads with a larger 1-d id are inactive.
    __device__ static constexpr index_t GetNumOfActiveThread()
    {
        return make_ConstantTensorDescriptor_packed(
                   SrcClusterLengths{}.ReorderGivenNew2Old(MapThreadCluster2SrcCluster{}))
            .GetElementSize();
    }

    // Whether the calling thread participates in the copy. The BlockSize comparison is a
    // compile-time constant, so the compiler can drop the thread-id test when all threads are
    // active.
    __device__ static bool IsActiveThread()
    {
        return BlockSize == GetNumOfActiveThread() ||
               get_thread_local_1d_id() < GetNumOfActiveThread();
    }

    // src_block_data_multi_id_begin / dst_block_data_multi_id_begin: multi-index of this
    // block's slice origin in src / dst. Validates the template configuration at compile time
    // and computes the per-thread src/dst offsets.
    __device__
    BlockwiseTensorSliceReorderCopy_v3(Array<index_t, nDim> src_block_data_multi_id_begin,
                                       Array<index_t, nDim> dst_block_data_multi_id_begin)
    {
        // Give every thread well-defined offsets, including inactive threads that return
        // early below; previously those threads kept uninitialized offsets.
        mThreadSrcOffset = 0;
        mThreadDstOffset = 0;

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr auto src_lengths = SrcLengths{};

        constexpr auto map_dst2src = MapDst2Src{};

        constexpr auto src_sub_lengths = SrcSubLengths{};
        constexpr auto dst_sub_lengths = src_sub_lengths.ReorderGivenNew2Old(map_dst2src);

        constexpr auto map_thread_cluster_2_src_cluster = MapThreadCluster2SrcCluster{};

        constexpr auto src_cluster_lengths = SrcClusterLengths{};
        constexpr auto thread_cluster_lengths =
            src_cluster_lengths.ReorderGivenNew2Old(map_thread_cluster_2_src_cluster);

        constexpr auto thread_cluster_desc =
            make_ConstantTensorDescriptor_packed(thread_cluster_lengths);

        // sanity check: data type
        static_assert(is_same<Float, float>::value, "wrong! only support float for now!\n");

        // sanity check: nDim
        static_assert(SrcDesc::GetNumOfDimension() == nDim &&
                          DstDesc::GetNumOfDimension() == nDim && SrcLengths::GetSize() == nDim &&
                          SrcSubLengths::GetSize() == nDim &&
                          SrcClusterLengths::GetSize() == nDim && MapDst2Src::GetSize() == nDim &&
                          MapThreadCluster2SrcCluster::GetSize() == nDim,
                      "wrong! nDim is not consistent\n");

        // sanity check: BlockSize
        constexpr index_t num_active_thread = thread_cluster_desc.GetElementSize();

        static_assert(BlockSize >= num_active_thread,
                      "wrong! BlockSize is not big enough for ThreadPerDims!");

        // sanity check: work division
        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr auto I                  = decltype(IDim){};
            constexpr index_t src_len         = src_lengths.Get(I);
            constexpr index_t src_sub_len     = src_sub_lengths.Get(I);
            constexpr index_t src_cluster_len = src_cluster_lengths.Get(I);
            static_assert(src_len % (src_sub_len * src_cluster_len) == 0,
                          "wrong! cannot evenly divide Src tensor lengths");
        });

        // sanity check: src read
        static_assert(SrcDataPerRead == 1 || SrcDataPerRead == 2 || SrcDataPerRead == 4,
                      "wrong! only support SrcDataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDataPerRead == 1 || src_desc.GetStride(Number<nDim - 1>{}) == 1,
                      "wrong! only support src.stride(nDim-1) == 1 if SrcDataPerRead > 1!\n");

        static_assert(src_sub_lengths.Get(Number<nDim - 1>{}) % SrcDataPerRead == 0,
                      "wrong! src_sub_lengths[nDim-1] % SrcDataPerRead != 0\n");

        static_assert(src_desc.GetStride(Number<nDim - 2>{}) % SrcDataPerRead == 0,
                      "wrong! should satisfy src_desc.stride(nDim-2) % SrcDataPerRead == 0, to "
                      "keep alignment");

        // sanity check: dst write
        static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
                      "wrong! only support DstDataPerWrite == 1, 2 or 4!\n");

        static_assert(DstDataPerWrite == 1 || dst_desc.GetStride(Number<nDim - 1>{}) == 1,
                      "wrong! only support dst.stride(nDim-1) == 1 if DstDataPerWrite > 1!\n");

        static_assert(dst_sub_lengths.Get(Number<nDim - 1>{}) % DstDataPerWrite == 0,
                      "wrong! dst_sub_lengths[nDim-1] % DstDataPerWrite != 0\n");

        static_assert(dst_desc.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
                      "wrong! should satisfy dst_desc.stride(nDim-2) % DstDataPerWrite == 0, to "
                      "keep alignment");

        // start dividing work: threads outside the thread cluster do not take part
        if(!IsActiveThread())
        {
            return;
        }

        const auto thread_multi_id =
            thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());

        // NOTE: open question whether thread_multi_id, src_data_multi_id and dst_data_multi_id
        // end up in separate registers or share one copy.
        // Map the thread-cluster multi-index back into src-cluster order.
        auto src_data_multi_id =
            reorder_array_given_old2new(thread_multi_id, map_thread_cluster_2_src_cluster);

        // Scale cluster coordinates to data coordinates: each thread owns SrcSubLengths elements
        // per dimension. NOTE: open question whether this is folded into GetOffsetFromMultiIndex.
        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr auto I    = decltype(IDim){};
            constexpr index_t i = I.Get();

            src_data_multi_id(i) *= src_sub_lengths.Get(I);
        });

        // The dst coordinate is the src coordinate with dimensions permuted by MapDst2Src.
        const auto dst_data_multi_id = reorder_array_given_new2old(src_data_multi_id, map_dst2src);

        mThreadSrcOffset =
            src_desc.GetOffsetFromMultiIndex(src_data_multi_id + src_block_data_multi_id_begin);

        mThreadDstOffset =
            dst_desc.GetOffsetFromMultiIndex(dst_data_multi_id + dst_block_data_multi_id_begin);

#if 0
        if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(thread_cluster_desc, "thread_cluster_desc: ");
        }

        if(get_block_1d_id() == 0)
        {
            printf("id %5u %5u: "
                   "thread_multi_id: %u %u, "
                   "src_block_data_multi_id_begin: %u %u, "
                   "src_data_multi_id: %u %u, "
                   "mThreadSrcOffset %u, mThreadDstOffset %u \n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   thread_multi_id[0],
                   thread_multi_id[1],
                   src_block_data_multi_id_begin[0],
                   src_block_data_multi_id_begin[1],
                   src_data_multi_id[0],
                   src_data_multi_id[1],
                   mThreadSrcOffset,
                   mThreadDstOffset);
        }
#endif
    }

    // Number of Float elements each thread's clipboard must hold: the element space of a packed
    // descriptor of lengths SrcSubLengths * repeat_lengths.
    __device__ static constexpr index_t GetRegisterClipboardSize()
    {
        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

        constexpr auto src_data_per_cluster_per_dims =
            thread_sub_tensor_lengths * SrcClusterLengths{};

        constexpr auto repeat_lengths = transform_sequences(
            math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);

        constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;

        constexpr auto thread_tensor_desc =
            make_ConstantTensorDescriptor_packed(thread_tensor_lengths);

        return thread_tensor_desc.GetElementSpace();
    }

    // Read this thread's sub-tensors from p_src (offset by mThreadSrcOffset) into p_clipboard,
    // which must hold at least GetRegisterClipboardSize() elements. Each repeat copies one
    // SrcSubLengths sub-tensor with vector width SrcDataPerRead. Inactive threads do nothing.
    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
        // inactive threads own no slice; reading at their offset would be out of bounds
        if(!IsActiveThread())
        {
            return;
        }

        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

        constexpr auto src_data_per_cluster_per_dims =
            thread_sub_tensor_lengths * SrcClusterLengths{};

        constexpr auto repeat_lengths = transform_sequences(
            math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);

        constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;

        constexpr auto thread_tensor_desc =
            make_ConstantTensorDescriptor_packed(thread_tensor_lengths);

        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
            constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};

            // consecutive repeats are one whole cluster tile apart in src ...
            constexpr auto src_data_multi_id = repeat_multi_id * src_data_per_cluster_per_dims;

            // ... but packed back to back in the clipboard
            constexpr auto clipboard_data_multi_id = repeat_multi_id * thread_sub_tensor_lengths;

            constexpr index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(src_data_multi_id);
            constexpr index_t clipboard_offset =
                thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);

            threadwise_tensor_slice_copy(SrcDesc{},
                                         p_src + src_offset + mThreadSrcOffset,
                                         thread_tensor_desc,
                                         p_clipboard + clipboard_offset,
                                         thread_sub_tensor_lengths,
                                         Number<SrcDataPerRead>{});
        });
    }

    // Write the clipboard filled by RunLoadRegisterClipboard to p_dst (offset by
    // mThreadDstOffset), permuting dimensions by MapDst2Src and writing in dst order.
    // Inactive threads do nothing.
    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        // inactive threads own no slice; writing at their offset would corrupt dst
        if(!IsActiveThread())
        {
            return;
        }

        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

        constexpr auto src_data_per_cluster_per_dims =
            thread_sub_tensor_lengths * SrcClusterLengths{};

        constexpr auto repeat_lengths = transform_sequences(
            math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);

        constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;

        constexpr auto thread_tensor_desc =
            make_ConstantTensorDescriptor_packed(thread_tensor_lengths);

        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
            constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};

            constexpr auto clipboard_data_multi_id = repeat_multi_id * thread_sub_tensor_lengths;

            constexpr auto src_data_multi_id = repeat_multi_id * src_data_per_cluster_per_dims;

            // reorder src_data_multi_id to get dst_data_multi_id
            constexpr auto dst_data_multi_id = src_data_multi_id.ReorderGivenNew2Old(MapDst2Src{});

            constexpr index_t clipboard_offset =
                thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);

            constexpr index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id);

            // write in the order of dst
            // NOTE: the active (#if 1) v2 path does not take DstDataPerWrite; that template
            // parameter is only validated by the constructor's static_asserts on this path.
#if 1
            threadwise_tensor_slice_copy_reorder_given_dst2src_v2(thread_tensor_desc,
                                                                  p_clipboard + clipboard_offset,
                                                                  DstDesc{},
                                                                  p_dst + dst_offset +
                                                                      mThreadDstOffset,
                                                                  thread_sub_tensor_lengths,
                                                                  MapDst2Src{});
#else
            threadwise_tensor_slice_copy_reorder_given_dst2src_v3(thread_tensor_desc,
                                                                  p_clipboard + clipboard_offset,
                                                                  DstDesc{},
                                                                  p_dst + dst_offset +
                                                                      mThreadDstOffset,
                                                                  thread_sub_tensor_lengths,
                                                                  MapDst2Src{},
                                                                  Number<DstDataPerWrite>{});
#endif
        });
    }

    // Full copy src -> dst through a per-thread local clipboard array.
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        Float p_clipboard[GetRegisterClipboardSize()];

        RunLoadRegisterClipboard(p_src, p_clipboard);
        RunStoreRegisterClipboard(p_clipboard, p_dst);
    }

    // Shift this thread's src window along dimension IDim_ by StepSize elements: forward
    // (offset += StepSize * src_stride[IDim_]) when PositiveDirection is true, backward
    // otherwise. Only the src offset moves; the dst offset is left unchanged.
    // This function doesn't do a sanity check on whether the slicing window is out of the
    // boundary of the tensor being sliced.
    template <index_t IDim_, index_t StepSize, bool PositiveDirection>
    __device__ void MoveSlicingWindowOnSourceTensor(
        Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
    {
        constexpr auto IDim = Number<IDim_>{};

        static_if<PositiveDirection>{}([&](auto fwd) {
            mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
        }).Else([&](auto fwd) { mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim); });
    }
};

} // namespace ck
#endif