// blockwise_2d_tensor_op.hpp
#pragma once
Chao Liu's avatar
Chao Liu committed
2
3
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
4

Chao Liu's avatar
Chao Liu committed
5
template <index_t BlockSize, class Float, class DstDesc, class F>
6
__device__ void
Chao Liu's avatar
Chao Liu committed
7
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
Chao Liu's avatar
Chao Liu committed
8
{
Chao Liu's avatar
Chao Liu committed
9
10
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
11

12
13
    constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
14
    constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());
Chao Liu's avatar
Chao Liu committed
15

16
#if 0
17
    if(get_thread_local_1d_id() == 0)
18
    {
Chao Liu's avatar
Chao Liu committed
19
20
        print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
        print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
21
22
23
    }
#endif

Chao Liu's avatar
Chao Liu committed
24
    constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
Chao Liu's avatar
Chao Liu committed
25

Chao Liu's avatar
Chao Liu committed
26
    for(index_t iloop = 0; iloop < NLoop; ++iloop)
Chao Liu's avatar
Chao Liu committed
27
    {
28
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;
Chao Liu's avatar
Chao Liu committed
29

Chao Liu's avatar
Chao Liu committed
30
        const index_t did0 = is / desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
31
32
33

        is -= did0 * desc.GetStride(I0);

Chao Liu's avatar
Chao Liu committed
34
        const index_t did1 = is / desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
35

36
        const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
37

Chao Liu's avatar
Chao Liu committed
38
        f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
39
40
41
42
43
44
    }

    constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
45
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
Chao Liu's avatar
Chao Liu committed
46
47
48

        if(is < desc.GetElementSize())
        {
Chao Liu's avatar
Chao Liu committed
49
            const index_t did0 = is / desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
50
51
52

            is -= did0 * desc.GetStride(I0);

Chao Liu's avatar
Chao Liu committed
53
            const index_t did1 = is / desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
54

55
            const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
56

Chao Liu's avatar
Chao Liu committed
57
            f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
58
59
60
        }
    }
}
Chao Liu's avatar
Chao Liu committed
61

62
// Function: p_dst[reorder[i0], reorder[i1]] = p_src[i0, i1]
// TODO: in order to optimize memory access for different memory types,
// a specialized version needs to be written
Chao Liu's avatar
Chao Liu committed
65
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
66
          class Float,
67
68
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
69
          class SrcOpLengths,
70
          class MapDst2Src,
Chao Liu's avatar
Chao Liu committed
71
          class F>
Chao Liu's avatar
Chao Liu committed
72
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
Chao Liu's avatar
Chao Liu committed
73
    SrcDesc,
Chao Liu's avatar
Chao Liu committed
74
    const Float* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
75
76
77
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
78
    MapDst2Src,
Chao Liu's avatar
Chao Liu committed
79
    F f)
Chao Liu's avatar
Chao Liu committed
80
{
Chao Liu's avatar
Chao Liu committed
81
82
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
83

84
85
    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
Chao Liu's avatar
Chao Liu committed
86

87
88
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
Chao Liu's avatar
Chao Liu committed
89
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
Chao Liu's avatar
Chao Liu committed
90

Chao Liu's avatar
Chao Liu committed
91
    constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
Chao Liu's avatar
Chao Liu committed
92

Chao Liu's avatar
Chao Liu committed
93
    for(index_t iloop = 0; iloop < NLoop; ++iloop)
Chao Liu's avatar
Chao Liu committed
94
    {
95
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;
Chao Liu's avatar
Chao Liu committed
96

Chao Liu's avatar
Chao Liu committed
97
        index_t did[2];
Chao Liu's avatar
Chao Liu committed
98

99
        did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
100

101
        is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
102

103
        did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
104

105
        const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
Chao Liu's avatar
Chao Liu committed
106

107
        const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);
108
109

        f(p_src[aindex], p_dst[bindex]);
Chao Liu's avatar
Chao Liu committed
110
111
    }

112
    constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
Chao Liu's avatar
Chao Liu committed
113
114
115

    if(has_tail)
    {
116
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
Chao Liu's avatar
Chao Liu committed
117

118
        if(is < ref_desc.GetElementSize())
Chao Liu's avatar
Chao Liu committed
119
        {
Chao Liu's avatar
Chao Liu committed
120
            index_t did[2];
121
122

            did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
123

124
            is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
125

126
            did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
127

128
            const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
129

130
            const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);
131

132
            f(p_src[aindex], p_dst[bindex]);
133
134
135
136
        }
    }
}

Chao Liu's avatar
Chao Liu committed
137
template <index_t BlockSize, class Float, class DstDesc>
Chao Liu's avatar
Chao Liu committed
138
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
139
{
Chao Liu's avatar
Chao Liu committed
140
    auto f_set_zero = [](Float& v) { v = Float(0); };
Chao Liu's avatar
Chao Liu committed
141

Chao Liu's avatar
Chao Liu committed
142
    blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
Chao Liu's avatar
Chao Liu committed
143
}
144

Chao Liu's avatar
Chao Liu committed
145
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
146
          class Float,
147
148
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
149
          class SrcOpLengths,
150
          class MapDst2Src>
Chao Liu's avatar
Chao Liu committed
151
__device__ void
Chao Liu's avatar
Chao Liu committed
152
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
Chao Liu's avatar
Chao Liu committed
153
                                                     const Float* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
154
155
156
                                                     DstDesc,
                                                     Float* __restrict__ p_dst,
                                                     SrcOpLengths,
157
                                                     MapDst2Src)
158
{
Chao Liu's avatar
Chao Liu committed
159
    auto f_copy = [](const Float& src, Float& dst) { dst = src; };
160

Chao Liu's avatar
Chao Liu committed
161
    blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
162
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
163
164
}

165
166
167
168
169
170
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
171
struct Blockwise2dTensorCopy1
Chao Liu's avatar
Chao Liu committed
172
{
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    __device__ constexpr Blockwise2dTensorCopy1()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride2 should be multiple of DataPerRead to keep alignment");

        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride0 is big enough,
        //   so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t L1          = CopyLengths{}.Get(I1);
195
        constexpr index_t read_per_d1 = mod_conv::integer_divide_ceil(L1, DataPerRead);
196
197
198
199
200

        static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");
    }

Chao Liu's avatar
Chao Liu committed
201
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
202
    {
203
204
205
206
207
208
209
210
211
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

212
        constexpr index_t read_per_d1 = mod_conv::integer_divide_ceil(L1, DataPerRead);
213

214
        constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});
215
216
217
218
219
220
221
222
223
224
225
226

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        auto f_copy = [&](index_t is) {
            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);

            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);

227
228
229
230
            const index_t src_index =
                src_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);
            const index_t dst_index =
                dst_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253

            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                *(reinterpret_cast<const vector_t*>(p_src + src_index));
        };

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;

            f_copy(is);
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                f_copy(is);
            }
        }
254
255
256
    }
};

257
258
// need to be aligned to float4 and float2
// stride1 need to be 1 for both source and destination
Chao Liu's avatar
Chao Liu committed
259
template <index_t BlockSize,
260
261
262
263
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
Chao Liu's avatar
Chao Liu committed
264
265
          index_t ThreadPerDim0,
          index_t ThreadPerDim1>
266
struct Blockwise2dTensorCopy2
267
{
Chao Liu's avatar
Chao Liu committed
268
269
    index_t mThreadId0;
    index_t mThreadId1;
270

271
    __device__ Blockwise2dTensorCopy2()
272
    {
273
274
275
276
277
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
                      "wrong! stride is not 1!\n");
Chao Liu's avatar
Chao Liu committed
278

279
280
281
282
        mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
        mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
    }

Chao Liu's avatar
Chao Liu committed
283
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
284
    {
285
286
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

Chao Liu's avatar
Chao Liu committed
287
288
289
        using Float4 = float4;
        using Float2 = float2;

290
291
292
293
294
295
296
297
298
        if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
            return;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
299
300
301
302
303
304
305
        // check alignment
        constexpr bool align_v4 =
            src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;

        constexpr bool align_v2 =
            src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;

Chao Liu's avatar
Chao Liu committed
306
307
        constexpr index_t L0 = SrcOpLengths{}.Get(I0);
        constexpr index_t L1 = SrcOpLengths{}.Get(I1);
308

Chao Liu's avatar
Chao Liu committed
309
310
        constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
        constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
311

Chao Liu's avatar
Chao Liu committed
312
        constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
Chao Liu's avatar
Chao Liu committed
313

Chao Liu's avatar
Chao Liu committed
314
        constexpr index_t Dim1V2Loop =
Chao Liu's avatar
Chao Liu committed
315
316
            align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;

Chao Liu's avatar
Chao Liu committed
317
        constexpr index_t Dim1V1Loop =
318
319
            (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
            ThreadPerDim1;
Chao Liu's avatar
Chao Liu committed
320

321
322
323
        constexpr bool d1_has_tail =
            (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));

Chao Liu's avatar
Chao Liu committed
324
        for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
325
        {
Chao Liu's avatar
Chao Liu committed
326
            index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;
327
328

            // v4
Chao Liu's avatar
Chao Liu committed
329
            for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
330
            {
Chao Liu's avatar
Chao Liu committed
331
                index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
332

333
334
                const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
335

Chao Liu's avatar
Chao Liu committed
336
                *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
337
                    *(reinterpret_cast<const Float4*>(p_src + sindex));
338
339
340
            }

            // v2
Chao Liu's avatar
Chao Liu committed
341
            for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
342
            {
Chao Liu's avatar
Chao Liu committed
343
                index_t did1 =
344
345
                    Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;

346
347
                const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
348

Chao Liu's avatar
Chao Liu committed
349
                *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
350
                    *(reinterpret_cast<const Float2*>(p_src + sindex));
351
352
353
            }

            // v1
Chao Liu's avatar
Chao Liu committed
354
            for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
355
            {
Chao Liu's avatar
Chao Liu committed
356
357
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               d1v1loop * ThreadPerDim1 + mThreadId1;
358

359
360
                const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
361
362
363
364
365
366
367

                p_dst[dindex] = p_src[sindex];
            }

            // dim-1 tail
            if(d1_has_tail)
            {
Chao Liu's avatar
Chao Liu committed
368
369
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               Dim1V1Loop * ThreadPerDim1 + mThreadId1;
370
371
372

                if(did1 < L1)
                {
373
374
                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
375
376
377
378
379
380
381
382
383

                    p_dst[dindex] = p_src[sindex];
                }
            }
        }

        // dim-0 tail
        if(d0_has_tail)
        {
Chao Liu's avatar
Chao Liu committed
384
            index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
385
386
387
388
389

            if(did0 < L0)
            {

                // v4
Chao Liu's avatar
Chao Liu committed
390
                for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
391
                {
Chao Liu's avatar
Chao Liu committed
392
                    index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
393

394
395
                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
396

Chao Liu's avatar
Chao Liu committed
397
                    *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
398
                        *(reinterpret_cast<const Float4*>(p_src + sindex));
399
400
401
                }

                // v2
Chao Liu's avatar
Chao Liu committed
402
                for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
403
                {
Chao Liu's avatar
Chao Liu committed
404
405
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
                                   2 * mThreadId1;
406

407
408
                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
409

Chao Liu's avatar
Chao Liu committed
410
                    *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
411
                        *(reinterpret_cast<const Float2*>(p_src + sindex));
412
413
414
                }

                // v1
Chao Liu's avatar
Chao Liu committed
415
                for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
416
                {
Chao Liu's avatar
Chao Liu committed
417
418
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                   d1v1loop * ThreadPerDim1 + mThreadId1;
419

420
421
                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
422
423
424
425
426
427
428

                    p_dst[dindex] = p_src[sindex];
                }

                // tail
                if(d1_has_tail)
                {
Chao Liu's avatar
Chao Liu committed
429
430
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                   Dim1V1Loop * ThreadPerDim1 + mThreadId1;
431
432
433

                    if(did1 < L1)
                    {
434
435
                        const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                        const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
436
437
438
439
440
441
442
443

                        p_dst[dindex] = p_src[sindex];
                    }
                }
            }
        }
    }
};
Chao Liu's avatar
Chao Liu committed
444

445
446
// starting point need to be aligned to float4 or float2 or float
// stride1 need to be 1 for both source and destination
Chao Liu's avatar
Chao Liu committed
447
template <index_t BlockSize,
448
449
450
          class Float,
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
451
          class CopyLengths,
Chao Liu's avatar
Chao Liu committed
452
          index_t DataPerRead>
453
struct Blockwise2dTensorCopy3
Chao Liu's avatar
Chao Liu committed
454
{
455
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
Chao Liu's avatar
Chao Liu committed
456

Chao Liu's avatar
Chao Liu committed
457
458
    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;
Chao Liu's avatar
Chao Liu committed
459

Chao Liu's avatar
Chao Liu committed
460
461
    // Compute this thread's starting offsets into the source and destination data.
    //
    // Threads are laid out as thread_per_d0 x thread_per_d1, where
    // thread_per_d1 = ceil(L1 / DataPerRead) covers one row with DataPerRead-wide
    // accesses and thread_per_d0 = BlockSize / thread_per_d1 rows are handled at once.
    //
    //   src_block_data_multi_id_begin - 2-d index of the window origin in the source
    //   dst_block_data_multi_id_begin - 2-d index of the window origin in the destination
    //
    // Threads with a local id >= thread_per_d0 * thread_per_d1 return before the
    // offsets are assigned, so mSrcMyThreadOffset / mDstMyThreadOffset stay unset for them.
    __device__ Blockwise2dTensorCopy3(Array<index_t, 2> src_block_data_multi_id_begin,
                                      Array<index_t, 2> dst_block_data_multi_id_begin)
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride should be multiple of DataPerRead to keep alignment");

        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride is big enough,
        //   so that the out-of-bound write won't contaminate next line in dst
        static_assert(thread_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");

        static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        // position of this thread in the thread_per_d0 x thread_per_d1 layout
        const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
        const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;

        mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(
            src_block_data_multi_id_begin +
            Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});

        mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(
            dst_block_data_multi_id_begin +
            Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});
    }

Chao Liu's avatar
Chao Liu committed
512
    // Copy the window directly from p_src to p_dst.
    //
    // Starting from the per-thread offsets stored by the constructor, each active
    // thread moves one vector_t per pass and advances by thread_per_d0 rows between
    // passes; a final guarded pass handles the rows left over when L0 is not a
    // multiple of thread_per_d0. Threads with a local id >= thread_per_d0 * thread_per_d1
    // return without doing anything.
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        // number of full passes, each covering thread_per_d0 rows
        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // element distance between two consecutive passes of the same thread
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
                *(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
                                                    iloop * src_loop_stride));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            // rows still to be copied after the full passes
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            // only the threads assigned to the first tail_d0 rows take part
            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }
561

Chao Liu's avatar
Chao Liu committed
562
    // Number of Float elements one thread needs in its clipboard buffer for
    // RunLoadRegisterClipboard / RunStoreRegisterClipboard:
    //   DataPerRead * ceil(L0 / thread_per_d0)
    // i.e. one DataPerRead-wide slot per pass over dimension 0, including the tail pass.
    __device__ constexpr index_t GetRegisterClipboardSize() const
    {
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // the ceil-division must be parenthesized: without the parentheses the
        // multiplication by DataPerRead happens before the division and the result
        // is no longer a whole number of DataPerRead-wide slots
        return DataPerRead * ((L0 + thread_per_d0 - 1) / thread_per_d0);
    }

    // Load this thread's share of the source window into p_clipboard.
    //
    // Pass number iloop reads one vector_t from
    //   p_src[mSrcMyThreadOffset + iloop * src_loop_stride]
    // and stores it at p_clipboard[iloop * DataPerRead]. A final guarded pass handles
    // the rows left over when L0 is not a multiple of thread_per_d0. Threads with a
    // local id >= thread_per_d0 * thread_per_d1 return without doing anything.
    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        // number of full passes, each covering thread_per_d0 rows
        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // element distance in the source between two consecutive passes
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
                *(reinterpret_cast<const vector_t*>(
                    &p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            // rows still to be loaded after the full passes
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            // only the threads assigned to the first tail_d0 rows take part
            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    // Store this thread's clipboard contents into the destination window.
    //
    // Pass number iloop reads one vector_t from p_clipboard[iloop * DataPerRead] and
    // writes it to p_dst[mDstMyThreadOffset + iloop * dst_loop_stride], mirroring the
    // slot layout used by RunLoadRegisterClipboard. A final guarded pass handles the
    // rows left over when L0 is not a multiple of thread_per_d0. Threads with a local
    // id >= thread_per_d0 * thread_per_d1 return without doing anything.
    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        // number of full passes, each covering thread_per_d0 rows
        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // element distance in the destination between two consecutive passes
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            // rows still to be stored after the full passes
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            // only the threads assigned to the first tail_d0 rows take part
            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

Chao Liu's avatar
Chao Liu committed
679
#if USE_AMD_INLINE_ASM
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
    // Copies this thread's share of a 2D tile (CopyLengths = L0 x L1) from p_src into
    // p_clipboard, issuing each vector transfer through the global_load helper rather
    // than a plain vector_t assignment.
    //
    // Work split (all computed at compile time):
    //   - every transfer moves one vector_t, i.e. DataPerRead consecutive elements;
    //   - thread_per_d1 = ceil(L1 / DataPerRead) threads span dimension 1;
    //   - thread_per_d0 = BlockSize / thread_per_d1 steps along dimension 0 are covered
    //     per iteration, so the source advances by GetStride(I0) * thread_per_d0 each time;
    //   - nloop_d0 full iterations are followed by one tail iteration when L0 is not a
    //     multiple of thread_per_d0.
    //
    // p_src:       source buffer, read starting at mSrcMyThreadOffset (a member set
    //              outside this function).
    // p_clipboard: destination staging buffer; iteration i fills elements
    //              [i * DataPerRead, (i + 1) * DataPerRead).
    //              NOTE(review): indexed without any thread offset, so presumably a
    //              thread-private buffer - confirm against the caller.
    //
    // Instantiating the active branch requires Float == float and DataPerRead == 4
    // (enforced by static_assert).
    __device__ void RunLoadRegisterClipboard_asm(const Float* __restrict__ p_src,
                                                 Float* p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // When BlockSize is not a multiple of thread_per_d1, the surplus threads have
        // no slot in the thread_per_d0 x thread_per_d1 layout and do nothing.
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;

        // One vector transfer for iteration `iloop`.
        auto f_copy = [&](index_t iloop) {
#if 0
            // Plain vectorized assignment, kept as the reference form of the transfer.
            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
                *(reinterpret_cast<const vector_t*>(&p_src[mSrcMyThreadOffset +
                                                    iloop * src_loop_stride]));
#else
            static_assert(is_same<float, Float>::value && DataPerRead == 4,
                          "global_load is only for float4");

            global_load(reinterpret_cast<vector_t&>(p_clipboard[iloop * DataPerRead]),
                        reinterpret_cast<const vector_t*>(
                            &p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
#endif
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // Tail: the last (L0 - nloop_d0 * thread_per_d0) steps of dimension 0 need only
        // the first tail_d0 * thread_per_d1 threads.
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    // Copies this thread's share of a 2D tile (CopyLengths = L0 x L1) from p_clipboard
    // into p_dst, issuing each vector transfer through the ds_write_b128 helper rather
    // than a plain vector_t assignment.
    //
    // Work split (all computed at compile time):
    //   - every transfer moves one vector_t, i.e. DataPerRead consecutive elements;
    //   - thread_per_d1 = ceil(L1 / DataPerRead) threads span dimension 1;
    //   - thread_per_d0 = BlockSize / thread_per_d1 steps along dimension 0 are covered
    //     per iteration, so the destination advances by GetStride(I0) * thread_per_d0
    //     each time;
    //   - nloop_d0 full iterations are followed by one tail iteration when L0 is not a
    //     multiple of thread_per_d0.
    //
    // p_clipboard: source staging buffer; iteration i reads elements
    //              [i * DataPerRead, (i + 1) * DataPerRead).
    // p_dst:       destination buffer, written starting at mDstMyThreadOffset (a member
    //              set outside this function).
    //              NOTE(review): ds_write_b128 presumably emits the GCN LDS 128-bit
    //              store, which would require p_dst to point into LDS/shared memory -
    //              confirm against the helper's definition.
    //
    // Instantiating the active branch requires Float == float and DataPerRead == 4
    // (enforced by static_assert).
    __device__ void RunStoreRegisterClipboard_asm(const Float* __restrict__ p_clipboard,
                                                  Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // When BlockSize is not a multiple of thread_per_d1, the surplus threads have
        // no slot in the thread_per_d0 x thread_per_d1 layout and do nothing.
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        // One vector transfer for iteration `iloop`.
        auto f_copy = [&](index_t iloop) {
#if 0
            // Plain vectorized assignment, kept as the reference form of the transfer.
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
#else
            static_assert(is_same<float, Float>::value && DataPerRead == 4,
                          "ds_write_b128 is only for float4");

            ds_write_b128(reinterpret_cast<const vector_t&>(p_clipboard[iloop * DataPerRead]),
                          &p_dst[mDstMyThreadOffset + iloop * dst_loop_stride]);
#endif
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // Tail: the last (L0 - nloop_d0 * thread_per_d0) steps of dimension 0 need only
        // the first tail_d0 * thread_per_d1 threads.
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }
797
#endif
Chao Liu's avatar
Chao Liu committed
798
};