#pragma once
#include "ConstantTensorDescriptor.hip.hpp"

template <unsigned BlockSize, class Float, class DstDesc, class F>
5
__device__ void
Chao Liu's avatar
Chao Liu committed
6
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
Chao Liu's avatar
Chao Liu committed
7
{
Chao Liu's avatar
Chao Liu committed
8
9
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
10

11
12
    constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
13
    constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());
Chao Liu's avatar
Chao Liu committed
14

15
16
17
#if 0
    if(threadIdx.x == 0)
    {
Chao Liu's avatar
Chao Liu committed
18
19
        print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
        print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
20
21
22
    }
#endif

Chao Liu's avatar
Chao Liu committed
23
24
    constexpr unsigned NLoop = desc.GetElementSize() / BlockSize;

Chao Liu's avatar
faster  
Chao Liu committed
25
    for(unsigned iloop = 0; iloop < NLoop; ++iloop)
Chao Liu's avatar
Chao Liu committed
26
27
28
29
30
31
32
33
34
    {
        unsigned is = threadIdx.x + iloop * BlockSize;

        const unsigned did0 = is / desc.GetStride(I0);

        is -= did0 * desc.GetStride(I0);

        const unsigned did1 = is / desc.GetStride(I1);

Chao Liu's avatar
Chao Liu committed
35
        const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
36

Chao Liu's avatar
Chao Liu committed
37
        f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
    }

    constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
        unsigned is = threadIdx.x + NLoop * BlockSize;

        if(is < desc.GetElementSize())
        {
            const unsigned did0 = is / desc.GetStride(I0);

            is -= did0 * desc.GetStride(I0);

            const unsigned did1 = is / desc.GetStride(I1);

Chao Liu's avatar
Chao Liu committed
54
            const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
55

Chao Liu's avatar
Chao Liu committed
56
            f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
57
58
59
        }
    }
}
Chao Liu's avatar
Chao Liu committed
60

Chao Liu's avatar
Chao Liu committed
61
// Function: for every index (i0, i1) within SrcOpLengths, calls
//   f(p_src[i0, i1], p_dst[j0, j1]) with j0 = i[IR0], j1 = i[IR1],
//   where (IR0, IR1) are the entries of DstFromSrcReorder
// TODO: in order to optimize memory access for different memory types,
// a specialized version needs to be written
template <unsigned BlockSize,
Chao Liu's avatar
Chao Liu committed
65
          class Float,
66
67
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
68
69
          class SrcOpLengths,
          class DstFromSrcReorder,
Chao Liu's avatar
Chao Liu committed
70
          class F>
Chao Liu's avatar
Chao Liu committed
71
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
Chao Liu's avatar
Chao Liu committed
72
    SrcDesc,
Chao Liu's avatar
Chao Liu committed
73
    const Float* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
74
75
76
77
78
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
    DstFromSrcReorder,
    F f)
Chao Liu's avatar
Chao Liu committed
79
{
Chao Liu's avatar
Chao Liu committed
80
81
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
82

Chao Liu's avatar
Chao Liu committed
83
84
    constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
    constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
Chao Liu's avatar
Chao Liu committed
85

86
87
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
Chao Liu's avatar
Chao Liu committed
88
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
Chao Liu's avatar
Chao Liu committed
89

90
    constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
Chao Liu's avatar
Chao Liu committed
91
92
93
94
95

    for(unsigned iloop = 0; iloop < NLoop; ++iloop)
    {
        unsigned is = threadIdx.x + iloop * BlockSize;

Chao Liu's avatar
Chao Liu committed
96
        unsigned did[2];
Chao Liu's avatar
Chao Liu committed
97

98
        did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
99

100
        is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
101

102
        did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
103

Chao Liu's avatar
Chao Liu committed
104
        const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
Chao Liu's avatar
Chao Liu committed
105

Chao Liu's avatar
Chao Liu committed
106
        const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
107
108

        f(p_src[aindex], p_dst[bindex]);
Chao Liu's avatar
Chao Liu committed
109
110
    }

111
    constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
Chao Liu's avatar
Chao Liu committed
112
113
114
115
116

    if(has_tail)
    {
        unsigned is = threadIdx.x + NLoop * BlockSize;

117
        if(is < ref_desc.GetElementSize())
Chao Liu's avatar
Chao Liu committed
118
        {
Chao Liu's avatar
Chao Liu committed
119
            unsigned did[2];
120
121

            did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
122

123
            is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
124

125
            did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
126

Chao Liu's avatar
Chao Liu committed
127
            const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
128

Chao Liu's avatar
Chao Liu committed
129
            const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
130

131
            f(p_src[aindex], p_dst[bindex]);
132
133
134
135
        }
    }
}

Chao Liu's avatar
Chao Liu committed
136
template <unsigned BlockSize, class Float, class DstDesc>
Chao Liu's avatar
Chao Liu committed
137
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
138
{
Chao Liu's avatar
Chao Liu committed
139
    auto f_set_zero = [](Float& v) { v = Float(0); };
Chao Liu's avatar
Chao Liu committed
140

Chao Liu's avatar
Chao Liu committed
141
    blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
Chao Liu's avatar
Chao Liu committed
142
}
143

Chao Liu's avatar
Chao Liu committed
144
// Copies a 2d region from p_src to p_dst, placing each source element at the
// destination index obtained by permuting the source index with DstFromSrcReorder.
// Thin wrapper: the binary reorder operation is invoked with an assignment functor.
template <unsigned BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class DstFromSrcReorder>
__device__ void
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                     const Float* __restrict__ p_src,
                                                     DstDesc,
                                                     Float* __restrict__ p_dst,
                                                     SrcOpLengths,
                                                     DstFromSrcReorder)
{
    auto f_copy = [](const Float& src, Float& dst) { dst = src; };

    blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
}

template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
165
struct Blockwise2dTensorCopy1
Chao Liu's avatar
Chao Liu committed
166
{
Chao Liu's avatar
Chao Liu committed
167
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
168
169
    {
        constexpr auto dst_from_src_reorder = Sequence<0, 1>{};
170

171
172
173
174
175
        blockwise_2d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
            SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
    }
};

176
177
// Requirements: data needs to be aligned for float4 and float2 access;
// stride1 needs to be 1 for both source and destination
template <unsigned BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          unsigned ThreadPerDim0,
          unsigned ThreadPerDim1>
185
struct Blockwise2dTensorCopy2
186
187
188
189
{
    unsigned mThreadId0;
    unsigned mThreadId1;

190
    __device__ Blockwise2dTensorCopy2()
191
    {
192
193
194
195
196
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
                      "wrong! stride is not 1!\n");
Chao Liu's avatar
Chao Liu committed
197

198
199
200
201
        mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
        mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
    }

Chao Liu's avatar
Chao Liu committed
202
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
203
    {
204
205
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

Chao Liu's avatar
Chao Liu committed
206
207
208
        using Float4 = float4;
        using Float2 = float2;

209
210
211
212
213
214
215
216
217
        if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
            return;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
218
219
220
221
222
223
224
        // check alignment
        constexpr bool align_v4 =
            src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;

        constexpr bool align_v2 =
            src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;

225
226
227
228
229
230
        constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
        constexpr unsigned L1 = SrcOpLengths{}.Get(I1);

        constexpr unsigned Dim0Loop = L0 / ThreadPerDim0;
        constexpr bool d0_has_tail  = (L0 > ThreadPerDim0 * Dim0Loop);

Chao Liu's avatar
Chao Liu committed
231
232
        constexpr unsigned Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;

233
        constexpr unsigned Dim1V2Loop =
Chao Liu's avatar
Chao Liu committed
234
235
            align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;

236
237
238
        constexpr unsigned Dim1V1Loop =
            (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
            ThreadPerDim1;
Chao Liu's avatar
Chao Liu committed
239

240
241
242
243
244
245
246
247
248
249
250
        constexpr bool d1_has_tail =
            (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));

        for(unsigned d0loop = 0; d0loop < Dim0Loop; ++d0loop)
        {
            unsigned did0 = d0loop * ThreadPerDim0 + mThreadId0;

            // v4
            for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
            {
                unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
251

Chao Liu's avatar
Chao Liu committed
252
253
                const unsigned sindex = src_desc.Get1dIndex(did0, did1);
                const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
254

Chao Liu's avatar
Chao Liu committed
255
                *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
256
                    *(reinterpret_cast<const Float4*>(p_src + sindex));
257
258
259
260
261
262
263
264
            }

            // v2
            for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
            {
                unsigned did1 =
                    Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;

Chao Liu's avatar
Chao Liu committed
265
266
267
                const unsigned sindex = src_desc.Get1dIndex(did0, did1);
                const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

Chao Liu's avatar
Chao Liu committed
268
                *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
269
                    *(reinterpret_cast<const Float2*>(p_src + sindex));
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
            }

            // v1
            for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
            {
                unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                d1v1loop * ThreadPerDim1 + mThreadId1;

                const unsigned sindex = src_desc.Get1dIndex(did0, did1);
                const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

                p_dst[dindex] = p_src[sindex];
            }

            // dim-1 tail
            if(d1_has_tail)
            {
                unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                Dim1V1Loop * ThreadPerDim1 + mThreadId1;

                if(did1 < L1)
                {
                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

                    p_dst[dindex] = p_src[sindex];
                }
            }
        }

        // dim-0 tail
        if(d0_has_tail)
        {
            unsigned did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;

            if(did0 < L0)
            {

                // v4
                for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
                {
                    unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;

Chao Liu's avatar
Chao Liu committed
313
314
315
                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

Chao Liu's avatar
Chao Liu committed
316
                    *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
317
                        *(reinterpret_cast<const Float4*>(p_src + sindex));
318
319
320
321
322
323
324
325
                }

                // v2
                for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
                {
                    unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
                                    2 * mThreadId1;

Chao Liu's avatar
Chao Liu committed
326
327
328
                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

Chao Liu's avatar
Chao Liu committed
329
                    *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
330
                        *(reinterpret_cast<const Float2*>(p_src + sindex));
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
                }

                // v1
                for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
                {
                    unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
                                    Dim1V2Loop * 2 * ThreadPerDim1 + d1v1loop * ThreadPerDim1 +
                                    mThreadId1;

                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

                    p_dst[dindex] = p_src[sindex];
                }

                // tail
                if(d1_has_tail)
                {
                    unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
                                    Dim1V2Loop * 2 * ThreadPerDim1 + Dim1V1Loop * ThreadPerDim1 +
                                    mThreadId1;

                    if(did1 < L1)
                    {
                        const unsigned sindex = src_desc.Get1dIndex(did0, did1);
                        const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

                        p_dst[dindex] = p_src[sindex];
                    }
                }
            }
        }
    }
};
Chao Liu's avatar
Chao Liu committed
365

366
367
368
369
370
371
// the starting point needs to be aligned to float4, float2 or float
// stride1 needs to be 1 for both source and destination
template <unsigned BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
372
          class CopyLengths,
373
374
          unsigned DataPerRead>
struct Blockwise2dTensorCopy3
Chao Liu's avatar
Chao Liu committed
375
{
376
377
    unsigned mSrcMyThreadOffset;
    unsigned mDstMyThreadOffset;
Chao Liu's avatar
Chao Liu committed
378

379
    __device__ Blockwise2dTensorCopy3()
Chao Liu's avatar
Chao Liu committed
380
    {
381
382
383
384
385
386
387
388
389
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
                      "wrong! only support stride1 == 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

Chao Liu's avatar
Chao Liu committed
390
391
392
        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride should be multiple of DataPerRead to keep alignment");
393

Chao Liu's avatar
Chao Liu committed
394
395
        constexpr unsigned L0 = CopyLengths{}.Get(I0);
        constexpr unsigned L1 = CopyLengths{}.Get(I1);
396

Chao Liu's avatar
Chao Liu committed
397
        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
398
        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
Chao Liu's avatar
Chao Liu committed
399

Chao Liu's avatar
Chao Liu committed
400
401
        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride is big enough,
Chao Liu's avatar
Chao Liu committed
402
        //   so that the out-of-bound write won't contaminate next line in dst
Chao Liu's avatar
Chao Liu committed
403
        static_assert(thread_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
Chao Liu's avatar
Chao Liu committed
404
                      "wrong! out-of-bound write will contaminate next line!\n");
Chao Liu's avatar
Chao Liu committed
405

Chao Liu's avatar
Chao Liu committed
406
407
408
409
410
411
412
413
414
415
416
        static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");

        constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }
Chao Liu's avatar
Chao Liu committed
417

418
419
420
421
422
        const unsigned thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
        const unsigned thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;

        mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
        mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
Chao Liu's avatar
Chao Liu committed
423
424
    }

Chao Liu's avatar
Chao Liu committed
425
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
Chao Liu's avatar
Chao Liu committed
426
    {
427
428
429
430
431
432
433
434
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

        using Float2 = float2;
        using Float4 = float4;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

Chao Liu's avatar
Chao Liu committed
435
436
        constexpr unsigned L0 = CopyLengths{}.Get(I0);
        constexpr unsigned L1 = CopyLengths{}.Get(I1);
Chao Liu's avatar
Chao Liu committed
437

Chao Liu's avatar
Chao Liu committed
438
        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
439
        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
Chao Liu's avatar
Chao Liu committed
440

441
442
443
        constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
Chao Liu's avatar
Chao Liu committed
444
        {
Chao Liu's avatar
Chao Liu committed
445
            if(get_thread_local_1d_id() >= num_active_thread)
446
447
448
            {
                return;
            }
Chao Liu's avatar
Chao Liu committed
449
450
        }

451
452
453
454
        constexpr unsigned nloop_d0 = L0 / thread_per_d0;

        constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
Chao Liu's avatar
Chao Liu committed
455

Chao Liu's avatar
Chao Liu committed
456
        auto f_copy = [&](unsigned iloop) {
457
458
459
460
461
462
463
464
            if(DataPerRead == 1)
            {
                p_dst[mDstMyThreadOffset + iloop * dst_loop_stride] =
                    p_src[mSrcMyThreadOffset + iloop * src_loop_stride];
            }
            else if(DataPerRead == 2)
            {
                *(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
Chao Liu's avatar
Chao Liu committed
465
466
                    *(reinterpret_cast<const Float2*>(p_src + mSrcMyThreadOffset +
                                                      iloop * src_loop_stride));
467
468
469
470
            }
            else if(DataPerRead == 4)
            {
                *(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
Chao Liu's avatar
Chao Liu committed
471
472
                    *(reinterpret_cast<const Float4*>(p_src + mSrcMyThreadOffset +
                                                      iloop * src_loop_stride));
473
474
475
476
477
            }
            else
            {
                assert(false);
            }
Chao Liu's avatar
Chao Liu committed
478
479
480
481
482
        };

        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
Chao Liu's avatar
Chao Liu committed
483
        }
Chao Liu's avatar
Chao Liu committed
484

Chao Liu's avatar
Chao Liu committed
485
486
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

Chao Liu's avatar
Chao Liu committed
487
488
489
490
491
492
        if(has_tail_d0)
        {
            constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
Chao Liu's avatar
Chao Liu committed
493
                f_copy(nloop_d0);
Chao Liu's avatar
Chao Liu committed
494
495
            }
        }
Chao Liu's avatar
Chao Liu committed
496
    }
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545

    __device__ constexpr unsigned GetRegisterClipboardSize() const
    {
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr unsigned L0 = CopyLengths{}.Get(I0);
        constexpr unsigned L1 = CopyLengths{}.Get(I1);

        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;

        return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
    }

    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* p_clipboard) const
    {
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

        using Float2 = float2;
        using Float4 = float4;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr unsigned L0 = CopyLengths{}.Get(I0);
        constexpr unsigned L1 = CopyLengths{}.Get(I1);

        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;

        constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr unsigned nloop_d0 = L0 / thread_per_d0;

        constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

Chao Liu's avatar
Chao Liu committed
546
        auto f_copy = [&](unsigned iloop) {
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
            if(DataPerRead == 1)
            {
                p_clipboard[iloop] = p_src[mSrcMyThreadOffset + iloop * src_loop_stride];
            }
            else if(DataPerRead == 2)
            {
                *(reinterpret_cast<Float2*>(p_clipboard + iloop * 2)) =
                    *(reinterpret_cast<const Float2*>(p_src + mSrcMyThreadOffset +
                                                      iloop * src_loop_stride));
            }
            else if(DataPerRead == 4)
            {
                *(reinterpret_cast<Float4*>(p_clipboard + iloop * 4)) =
                    *(reinterpret_cast<const Float4*>(p_src + mSrcMyThreadOffset +
                                                      iloop * src_loop_stride));
            }
            else
            {
                assert(false);
            }
Chao Liu's avatar
Chao Liu committed
567
568
569
570
571
        };

        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
572
573
574
575
576
577
578
579
580
581
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
Chao Liu's avatar
Chao Liu committed
582
                f_copy(nloop_d0);
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
            }
        }
    }

    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

        using Float2 = float2;
        using Float4 = float4;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr unsigned L0 = CopyLengths{}.Get(I0);
        constexpr unsigned L1 = CopyLengths{}.Get(I1);

        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;

        constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr unsigned nloop_d0 = L0 / thread_per_d0;

        constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

Chao Liu's avatar
Chao Liu committed
619
        auto f_copy = [&](unsigned iloop) {
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
            if(DataPerRead == 1)
            {
                p_dst[mDstMyThreadOffset + iloop * dst_loop_stride] = p_clipboard[iloop];
            }
            else if(DataPerRead == 2)
            {
                *(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
                    *(reinterpret_cast<const Float2*>(p_clipboard + iloop * 2));
            }
            else if(DataPerRead == 4)
            {
                *(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
                    *(reinterpret_cast<const Float4*>(p_clipboard + iloop * 4));
            }
            else
            {
                assert(false);
            }
Chao Liu's avatar
Chao Liu committed
638
639
640
641
642
        };

        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
643
644
645
646
647
648
649
650
651
652
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
Chao Liu's avatar
Chao Liu committed
653
                f_copy(nloop_d0);
654
655
656
            }
        }
    }
Chao Liu's avatar
Chao Liu committed
657
};