blockwise_2d_tensor_op.hip.hpp 19.8 KB
Newer Older
1
#pragma once
2
#include "ConstantTensorDescriptor.hip.hpp"
3

Chao Liu's avatar
Chao Liu committed
4
template <index_t BlockSize, class Float, class DstDesc, class F>
5
__device__ void
Chao Liu's avatar
Chao Liu committed
6
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
Chao Liu's avatar
Chao Liu committed
7
{
Chao Liu's avatar
Chao Liu committed
8
9
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
10

11
12
    constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
13
    constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());
Chao Liu's avatar
Chao Liu committed
14

15
16
17
#if 0
    if(threadIdx.x == 0)
    {
Chao Liu's avatar
Chao Liu committed
18
19
        print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
        print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
20
21
22
    }
#endif

Chao Liu's avatar
Chao Liu committed
23
    constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
Chao Liu's avatar
Chao Liu committed
24

Chao Liu's avatar
Chao Liu committed
25
    for(index_t iloop = 0; iloop < NLoop; ++iloop)
Chao Liu's avatar
Chao Liu committed
26
    {
Chao Liu's avatar
Chao Liu committed
27
        index_t is = threadIdx.x + iloop * BlockSize;
Chao Liu's avatar
Chao Liu committed
28

Chao Liu's avatar
Chao Liu committed
29
        const index_t did0 = is / desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
30
31
32

        is -= did0 * desc.GetStride(I0);

Chao Liu's avatar
Chao Liu committed
33
        const index_t did1 = is / desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
34

Chao Liu's avatar
Chao Liu committed
35
        const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
36

Chao Liu's avatar
Chao Liu committed
37
        f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
38
39
40
41
42
43
    }

    constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
Chao Liu's avatar
Chao Liu committed
44
        index_t is = threadIdx.x + NLoop * BlockSize;
Chao Liu's avatar
Chao Liu committed
45
46
47

        if(is < desc.GetElementSize())
        {
Chao Liu's avatar
Chao Liu committed
48
            const index_t did0 = is / desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
49
50
51

            is -= did0 * desc.GetStride(I0);

Chao Liu's avatar
Chao Liu committed
52
            const index_t did1 = is / desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
53

Chao Liu's avatar
Chao Liu committed
54
            const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
55

Chao Liu's avatar
Chao Liu committed
56
            f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
57
58
59
        }
    }
}
Chao Liu's avatar
Chao Liu committed
60

Chao Liu's avatar
Chao Liu committed
61
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
62
63
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
Chao Liu's avatar
Chao Liu committed
64
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
65
          class Float,
66
67
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
68
69
          class SrcOpLengths,
          class DstFromSrcReorder,
Chao Liu's avatar
Chao Liu committed
70
          class F>
Chao Liu's avatar
Chao Liu committed
71
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
Chao Liu's avatar
Chao Liu committed
72
    SrcDesc,
Chao Liu's avatar
Chao Liu committed
73
    const Float* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
74
75
76
77
78
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
    DstFromSrcReorder,
    F f)
Chao Liu's avatar
Chao Liu committed
79
{
Chao Liu's avatar
Chao Liu committed
80
81
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
82

Chao Liu's avatar
Chao Liu committed
83
84
    constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
    constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
Chao Liu's avatar
Chao Liu committed
85

86
87
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
Chao Liu's avatar
Chao Liu committed
88
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
Chao Liu's avatar
Chao Liu committed
89

Chao Liu's avatar
Chao Liu committed
90
    constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
Chao Liu's avatar
Chao Liu committed
91

Chao Liu's avatar
Chao Liu committed
92
    for(index_t iloop = 0; iloop < NLoop; ++iloop)
Chao Liu's avatar
Chao Liu committed
93
    {
Chao Liu's avatar
Chao Liu committed
94
        index_t is = threadIdx.x + iloop * BlockSize;
Chao Liu's avatar
Chao Liu committed
95

Chao Liu's avatar
Chao Liu committed
96
        index_t did[2];
Chao Liu's avatar
Chao Liu committed
97

98
        did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
99

100
        is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
101

102
        did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
103

Chao Liu's avatar
Chao Liu committed
104
        const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
Chao Liu's avatar
Chao Liu committed
105

Chao Liu's avatar
Chao Liu committed
106
        const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
107
108

        f(p_src[aindex], p_dst[bindex]);
Chao Liu's avatar
Chao Liu committed
109
110
    }

111
    constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
Chao Liu's avatar
Chao Liu committed
112
113
114

    if(has_tail)
    {
Chao Liu's avatar
Chao Liu committed
115
        index_t is = threadIdx.x + NLoop * BlockSize;
Chao Liu's avatar
Chao Liu committed
116

117
        if(is < ref_desc.GetElementSize())
Chao Liu's avatar
Chao Liu committed
118
        {
Chao Liu's avatar
Chao Liu committed
119
            index_t did[2];
120
121

            did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
122

123
            is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
124

125
            did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
126

Chao Liu's avatar
Chao Liu committed
127
            const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
128

Chao Liu's avatar
Chao Liu committed
129
            const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
130

131
            f(p_src[aindex], p_dst[bindex]);
132
133
134
135
        }
    }
}

Chao Liu's avatar
Chao Liu committed
136
template <index_t BlockSize, class Float, class DstDesc>
Chao Liu's avatar
Chao Liu committed
137
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
138
{
Chao Liu's avatar
Chao Liu committed
139
    auto f_set_zero = [](Float& v) { v = Float(0); };
Chao Liu's avatar
Chao Liu committed
140

Chao Liu's avatar
Chao Liu committed
141
    blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
Chao Liu's avatar
Chao Liu committed
142
}
143

Chao Liu's avatar
Chao Liu committed
144
// Copies a 2d tensor from p_src to p_dst while permuting the indices according to
// DstFromSrcReorder. Thin wrapper that forwards to
// blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src with a
// plain assignment (dst = src) as the element-wise operation.
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class DstFromSrcReorder>
__device__ void
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                     const Float* __restrict__ p_src,
                                                     DstDesc,
                                                     Float* __restrict__ p_dst,
                                                     SrcOpLengths,
                                                     DstFromSrcReorder)
{
    // element-wise operation: overwrite the destination with the source value
    auto assign = [](const Float& from, Float& to) { to = from; };

    blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, assign);
}

Chao Liu's avatar
Chao Liu committed
164
template <index_t BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
165
struct Blockwise2dTensorCopy1
Chao Liu's avatar
Chao Liu committed
166
{
Chao Liu's avatar
Chao Liu committed
167
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
168
169
    {
        constexpr auto dst_from_src_reorder = Sequence<0, 1>{};
170

171
172
173
174
175
        blockwise_2d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
            SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
    }
};

176
177
// need to be aligned to float4 and float2
// stride1 need to be 1 for both source and destination
Chao Liu's avatar
Chao Liu committed
178
template <index_t BlockSize,
179
180
181
182
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
Chao Liu's avatar
Chao Liu committed
183
184
          index_t ThreadPerDim0,
          index_t ThreadPerDim1>
185
struct Blockwise2dTensorCopy2
186
{
Chao Liu's avatar
Chao Liu committed
187
188
    index_t mThreadId0;
    index_t mThreadId1;
189

190
    __device__ Blockwise2dTensorCopy2()
191
    {
192
193
194
195
196
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
                      "wrong! stride is not 1!\n");
Chao Liu's avatar
Chao Liu committed
197

198
199
200
201
        mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
        mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
    }

Chao Liu's avatar
Chao Liu committed
202
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
203
    {
204
205
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

Chao Liu's avatar
Chao Liu committed
206
207
208
        using Float4 = float4;
        using Float2 = float2;

209
210
211
212
213
214
215
216
217
        if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
            return;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
218
219
220
221
222
223
224
        // check alignment
        constexpr bool align_v4 =
            src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;

        constexpr bool align_v2 =
            src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;

Chao Liu's avatar
Chao Liu committed
225
226
        constexpr index_t L0 = SrcOpLengths{}.Get(I0);
        constexpr index_t L1 = SrcOpLengths{}.Get(I1);
227

Chao Liu's avatar
Chao Liu committed
228
229
        constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
        constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
230

Chao Liu's avatar
Chao Liu committed
231
        constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
Chao Liu's avatar
Chao Liu committed
232

Chao Liu's avatar
Chao Liu committed
233
        constexpr index_t Dim1V2Loop =
Chao Liu's avatar
Chao Liu committed
234
235
            align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;

Chao Liu's avatar
Chao Liu committed
236
        constexpr index_t Dim1V1Loop =
237
238
            (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
            ThreadPerDim1;
Chao Liu's avatar
Chao Liu committed
239

240
241
242
        constexpr bool d1_has_tail =
            (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));

Chao Liu's avatar
Chao Liu committed
243
        for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
244
        {
Chao Liu's avatar
Chao Liu committed
245
            index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;
246
247

            // v4
Chao Liu's avatar
Chao Liu committed
248
            for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
249
            {
Chao Liu's avatar
Chao Liu committed
250
                index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
251

Chao Liu's avatar
Chao Liu committed
252
253
                const index_t sindex = src_desc.Get1dIndex(did0, did1);
                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
254

Chao Liu's avatar
Chao Liu committed
255
                *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
256
                    *(reinterpret_cast<const Float4*>(p_src + sindex));
257
258
259
            }

            // v2
Chao Liu's avatar
Chao Liu committed
260
            for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
261
            {
Chao Liu's avatar
Chao Liu committed
262
                index_t did1 =
263
264
                    Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;

Chao Liu's avatar
Chao Liu committed
265
266
                const index_t sindex = src_desc.Get1dIndex(did0, did1);
                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
267

Chao Liu's avatar
Chao Liu committed
268
                *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
269
                    *(reinterpret_cast<const Float2*>(p_src + sindex));
270
271
272
            }

            // v1
Chao Liu's avatar
Chao Liu committed
273
            for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
274
            {
Chao Liu's avatar
Chao Liu committed
275
276
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               d1v1loop * ThreadPerDim1 + mThreadId1;
277

Chao Liu's avatar
Chao Liu committed
278
279
                const index_t sindex = src_desc.Get1dIndex(did0, did1);
                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
280
281
282
283
284
285
286

                p_dst[dindex] = p_src[sindex];
            }

            // dim-1 tail
            if(d1_has_tail)
            {
Chao Liu's avatar
Chao Liu committed
287
288
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               Dim1V1Loop * ThreadPerDim1 + mThreadId1;
289
290
291

                if(did1 < L1)
                {
Chao Liu's avatar
Chao Liu committed
292
293
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
294
295
296
297
298
299
300
301
302

                    p_dst[dindex] = p_src[sindex];
                }
            }
        }

        // dim-0 tail
        if(d0_has_tail)
        {
Chao Liu's avatar
Chao Liu committed
303
            index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
304
305
306
307
308

            if(did0 < L0)
            {

                // v4
Chao Liu's avatar
Chao Liu committed
309
                for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
310
                {
Chao Liu's avatar
Chao Liu committed
311
                    index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
312

Chao Liu's avatar
Chao Liu committed
313
314
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
315

Chao Liu's avatar
Chao Liu committed
316
                    *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
317
                        *(reinterpret_cast<const Float4*>(p_src + sindex));
318
319
320
                }

                // v2
Chao Liu's avatar
Chao Liu committed
321
                for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
322
                {
Chao Liu's avatar
Chao Liu committed
323
324
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
                                   2 * mThreadId1;
325

Chao Liu's avatar
Chao Liu committed
326
327
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
328

Chao Liu's avatar
Chao Liu committed
329
                    *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
330
                        *(reinterpret_cast<const Float2*>(p_src + sindex));
331
332
333
                }

                // v1
Chao Liu's avatar
Chao Liu committed
334
                for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
335
                {
Chao Liu's avatar
Chao Liu committed
336
337
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                   d1v1loop * ThreadPerDim1 + mThreadId1;
338

Chao Liu's avatar
Chao Liu committed
339
340
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
341
342
343
344
345
346
347

                    p_dst[dindex] = p_src[sindex];
                }

                // tail
                if(d1_has_tail)
                {
Chao Liu's avatar
Chao Liu committed
348
349
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                   Dim1V1Loop * ThreadPerDim1 + mThreadId1;
350
351
352

                    if(did1 < L1)
                    {
Chao Liu's avatar
Chao Liu committed
353
354
                        const index_t sindex = src_desc.Get1dIndex(did0, did1);
                        const index_t dindex = dst_desc.Get1dIndex(did0, did1);
355
356
357
358
359
360
361
362

                        p_dst[dindex] = p_src[sindex];
                    }
                }
            }
        }
    }
};
Chao Liu's avatar
Chao Liu committed
363

364
365
// starting point need to be aligned to float4 or float2 or float
// stride1 need to be 1 for both source and destination
Chao Liu's avatar
Chao Liu committed
366
template <index_t BlockSize,
367
368
369
          class Float,
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
370
          class CopyLengths,
Chao Liu's avatar
Chao Liu committed
371
          index_t DataPerRead>
372
struct Blockwise2dTensorCopy3
Chao Liu's avatar
Chao Liu committed
373
{
374
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
Chao Liu's avatar
Chao Liu committed
375

Chao Liu's avatar
Chao Liu committed
376
377
    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;
Chao Liu's avatar
Chao Liu committed
378

379
    __device__ Blockwise2dTensorCopy3()
Chao Liu's avatar
Chao Liu committed
380
    {
381
382
383
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

Chao Liu's avatar
Chao Liu committed
384
385
386
        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");
387
388
389
390

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

Chao Liu's avatar
Chao Liu committed
391
392
393
        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride should be multiple of DataPerRead to keep alignment");
394

Chao Liu's avatar
Chao Liu committed
395
396
        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
397

Chao Liu's avatar
Chao Liu committed
398
399
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
Chao Liu's avatar
Chao Liu committed
400

Chao Liu's avatar
Chao Liu committed
401
402
        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride is big enough,
Chao Liu's avatar
Chao Liu committed
403
        //   so that the out-of-bound write won't contaminate next line in dst
Chao Liu's avatar
Chao Liu committed
404
        static_assert(thread_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
Chao Liu's avatar
Chao Liu committed
405
                      "wrong! out-of-bound write will contaminate next line!\n");
Chao Liu's avatar
Chao Liu committed
406

Chao Liu's avatar
Chao Liu committed
407
408
        static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");

Chao Liu's avatar
Chao Liu committed
409
        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
Chao Liu's avatar
Chao Liu committed
410
411
412
413
414
415
416
417

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }
Chao Liu's avatar
Chao Liu committed
418

Chao Liu's avatar
Chao Liu committed
419
420
        const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
        const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
421
422
423

        mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
        mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
Chao Liu's avatar
Chao Liu committed
424
425
    }

Chao Liu's avatar
Chao Liu committed
426
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
Chao Liu's avatar
Chao Liu committed
427
    {
428
429
430
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

Chao Liu's avatar
Chao Liu committed
431
432
        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
Chao Liu's avatar
Chao Liu committed
433

Chao Liu's avatar
Chao Liu committed
434
435
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
Chao Liu's avatar
Chao Liu committed
436

Chao Liu's avatar
Chao Liu committed
437
        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
438
439

        if(BlockSize > num_active_thread)
Chao Liu's avatar
Chao Liu committed
440
        {
Chao Liu's avatar
Chao Liu committed
441
            if(get_thread_local_1d_id() >= num_active_thread)
442
443
444
            {
                return;
            }
Chao Liu's avatar
Chao Liu committed
445
446
        }

Chao Liu's avatar
Chao Liu committed
447
        constexpr index_t nloop_d0 = L0 / thread_per_d0;
448

Chao Liu's avatar
Chao Liu committed
449
450
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
Chao Liu's avatar
Chao Liu committed
451

Chao Liu's avatar
Chao Liu committed
452
        auto f_copy = [&](index_t iloop) {
Chao Liu's avatar
Chao Liu committed
453
454
455
            *(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
                *(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
                                                    iloop * src_loop_stride));
Chao Liu's avatar
Chao Liu committed
456
457
        };

Chao Liu's avatar
Chao Liu committed
458
        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
Chao Liu's avatar
Chao Liu committed
459
460
        {
            f_copy(iloop);
Chao Liu's avatar
Chao Liu committed
461
        }
Chao Liu's avatar
Chao Liu committed
462

Chao Liu's avatar
Chao Liu committed
463
464
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

Chao Liu's avatar
Chao Liu committed
465
466
        if(has_tail_d0)
        {
Chao Liu's avatar
Chao Liu committed
467
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
Chao Liu's avatar
Chao Liu committed
468
469
470

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
Chao Liu's avatar
Chao Liu committed
471
                f_copy(nloop_d0);
Chao Liu's avatar
Chao Liu committed
472
473
            }
        }
Chao Liu's avatar
Chao Liu committed
474
    }
475

Chao Liu's avatar
Chao Liu committed
476
    __device__ constexpr index_t GetRegisterClipboardSize() const
477
478
479
480
481
482
    {
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

Chao Liu's avatar
Chao Liu committed
483
484
        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
485

Chao Liu's avatar
Chao Liu committed
486
487
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
488
489
490
491
492
493
494
495
496
497

        return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
    }

    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

Chao Liu's avatar
Chao Liu committed
498
499
        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
500

Chao Liu's avatar
Chao Liu committed
501
502
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
503

Chao Liu's avatar
Chao Liu committed
504
        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
505
506
507
508
509
510
511
512
513

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

Chao Liu's avatar
Chao Liu committed
514
        constexpr index_t nloop_d0 = L0 / thread_per_d0;
515

Chao Liu's avatar
Chao Liu committed
516
517
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
518

Chao Liu's avatar
Chao Liu committed
519
        auto f_copy = [&](index_t iloop) {
Chao Liu's avatar
Chao Liu committed
520
521
522
            *(reinterpret_cast<vector_t*>(p_clipboard + iloop * 4)) =
                *(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
                                                    iloop * src_loop_stride));
Chao Liu's avatar
Chao Liu committed
523
524
        };

Chao Liu's avatar
Chao Liu committed
525
        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
Chao Liu's avatar
Chao Liu committed
526
527
        {
            f_copy(iloop);
528
529
530
531
532
533
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
Chao Liu's avatar
Chao Liu committed
534
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
535
536
537

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
Chao Liu's avatar
Chao Liu committed
538
                f_copy(nloop_d0);
539
540
541
542
543
544
545
546
547
548
            }
        }
    }

    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

Chao Liu's avatar
Chao Liu committed
549
550
        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
551

Chao Liu's avatar
Chao Liu committed
552
553
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
554

Chao Liu's avatar
Chao Liu committed
555
        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
556
557
558
559
560
561
562
563
564

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

Chao Liu's avatar
Chao Liu committed
565
        constexpr index_t nloop_d0 = L0 / thread_per_d0;
566

Chao Liu's avatar
Chao Liu committed
567
568
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
569

Chao Liu's avatar
Chao Liu committed
570
        auto f_copy = [&](index_t iloop) {
Chao Liu's avatar
Chao Liu committed
571
572
            *(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
                *(reinterpret_cast<const vector_t*>(p_clipboard + iloop * 4));
Chao Liu's avatar
Chao Liu committed
573
574
        };

Chao Liu's avatar
Chao Liu committed
575
        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
Chao Liu's avatar
Chao Liu committed
576
577
        {
            f_copy(iloop);
578
579
580
581
582
583
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
Chao Liu's avatar
Chao Liu committed
584
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
585
586
587

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
Chao Liu's avatar
Chao Liu committed
588
                f_copy(nloop_d0);
589
590
591
            }
        }
    }
Chao Liu's avatar
Chao Liu committed
592
};