blockwise_2d_tensor_op.hip.hpp 26.8 KB
Newer Older
1
#pragma once
Chao Liu's avatar
Chao Liu committed
2
#include "common.hip.hpp"
3
#include "ConstantTensorDescriptor.hip.hpp"
4

Chao Liu's avatar
Chao Liu committed
5
template <index_t BlockSize, class Float, class DstDesc, class F>
6
__device__ void
Chao Liu's avatar
Chao Liu committed
7
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
Chao Liu's avatar
Chao Liu committed
8
{
Chao Liu's avatar
Chao Liu committed
9
10
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
11

12
13
    constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
14
    constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());
Chao Liu's avatar
Chao Liu committed
15

16
#if 0
17
    if(get_thread_local_1d_id() == 0)
18
    {
Chao Liu's avatar
Chao Liu committed
19
20
        print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
        print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
21
22
23
    }
#endif

Chao Liu's avatar
Chao Liu committed
24
    constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
Chao Liu's avatar
Chao Liu committed
25

Chao Liu's avatar
Chao Liu committed
26
    for(index_t iloop = 0; iloop < NLoop; ++iloop)
Chao Liu's avatar
Chao Liu committed
27
    {
28
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;
Chao Liu's avatar
Chao Liu committed
29

Chao Liu's avatar
Chao Liu committed
30
        const index_t did0 = is / desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
31
32
33

        is -= did0 * desc.GetStride(I0);

Chao Liu's avatar
Chao Liu committed
34
        const index_t did1 = is / desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
35

36
        const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
37

Chao Liu's avatar
Chao Liu committed
38
        f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
39
40
41
42
43
44
    }

    constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
45
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
Chao Liu's avatar
Chao Liu committed
46
47
48

        if(is < desc.GetElementSize())
        {
Chao Liu's avatar
Chao Liu committed
49
            const index_t did0 = is / desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
50
51
52

            is -= did0 * desc.GetStride(I0);

Chao Liu's avatar
Chao Liu committed
53
            const index_t did1 = is / desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
54

55
            const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
56

Chao Liu's avatar
Chao Liu committed
57
            f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
58
59
60
        }
    }
}
Chao Liu's avatar
Chao Liu committed
61

62
// Function: p_dst[reorder[i0], reorder[i1]] = p_src[i0, i1]
63
64
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
Chao Liu's avatar
Chao Liu committed
65
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
66
          class Float,
67
68
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
69
          class SrcOpLengths,
70
          class MapDst2Src,
Chao Liu's avatar
Chao Liu committed
71
          class F>
Chao Liu's avatar
Chao Liu committed
72
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
Chao Liu's avatar
Chao Liu committed
73
    SrcDesc,
Chao Liu's avatar
Chao Liu committed
74
    const Float* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
75
76
77
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
78
    MapDst2Src,
Chao Liu's avatar
Chao Liu committed
79
    F f)
Chao Liu's avatar
Chao Liu committed
80
{
Chao Liu's avatar
Chao Liu committed
81
82
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
83

84
85
    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
Chao Liu's avatar
Chao Liu committed
86

87
88
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
Chao Liu's avatar
Chao Liu committed
89
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
Chao Liu's avatar
Chao Liu committed
90

Chao Liu's avatar
Chao Liu committed
91
    constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
Chao Liu's avatar
Chao Liu committed
92

Chao Liu's avatar
Chao Liu committed
93
    for(index_t iloop = 0; iloop < NLoop; ++iloop)
Chao Liu's avatar
Chao Liu committed
94
    {
95
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;
Chao Liu's avatar
Chao Liu committed
96

Chao Liu's avatar
Chao Liu committed
97
        index_t did[2];
Chao Liu's avatar
Chao Liu committed
98

99
        did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
100

101
        is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
102

103
        did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
104

105
        const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
Chao Liu's avatar
Chao Liu committed
106

107
        const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);
108
109

        f(p_src[aindex], p_dst[bindex]);
Chao Liu's avatar
Chao Liu committed
110
111
    }

112
    constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
Chao Liu's avatar
Chao Liu committed
113
114
115

    if(has_tail)
    {
116
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
Chao Liu's avatar
Chao Liu committed
117

118
        if(is < ref_desc.GetElementSize())
Chao Liu's avatar
Chao Liu committed
119
        {
Chao Liu's avatar
Chao Liu committed
120
            index_t did[2];
121
122

            did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
123

124
            is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
125

126
            did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
127

128
            const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
129

130
            const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);
131

132
            f(p_src[aindex], p_dst[bindex]);
133
134
135
136
        }
    }
}

Chao Liu's avatar
Chao Liu committed
137
template <index_t BlockSize, class Float, class DstDesc>
Chao Liu's avatar
Chao Liu committed
138
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
139
{
Chao Liu's avatar
Chao Liu committed
140
    auto f_set_zero = [](Float& v) { v = Float(0); };
Chao Liu's avatar
Chao Liu committed
141

Chao Liu's avatar
Chao Liu committed
142
    blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
Chao Liu's avatar
Chao Liu committed
143
}
144

Chao Liu's avatar
Chao Liu committed
145
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
146
          class Float,
147
148
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
149
          class SrcOpLengths,
150
          class MapDst2Src>
Chao Liu's avatar
Chao Liu committed
151
__device__ void
Chao Liu's avatar
Chao Liu committed
152
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
Chao Liu's avatar
Chao Liu committed
153
                                                     const Float* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
154
155
156
                                                     DstDesc,
                                                     Float* __restrict__ p_dst,
                                                     SrcOpLengths,
157
                                                     MapDst2Src)
158
{
Chao Liu's avatar
Chao Liu committed
159
    auto f_copy = [](const Float& src, Float& dst) { dst = src; };
160

Chao Liu's avatar
Chao Liu committed
161
    blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
162
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
163
164
}

165
166
167
168
169
170
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
171
struct Blockwise2dTensorCopy1
Chao Liu's avatar
Chao Liu committed
172
{
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    __device__ constexpr Blockwise2dTensorCopy1()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride2 should be multiple of DataPerRead to keep alignment");

        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride0 is big enough,
        //   so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t L1          = CopyLengths{}.Get(I1);
195
        constexpr index_t read_per_d1 = mod_conv::integer_divide_ceil(L1, DataPerRead);
196
197
198
199
200

        static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");
    }

Chao Liu's avatar
Chao Liu committed
201
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
202
    {
203
204
205
206
207
208
209
210
211
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

212
        constexpr index_t read_per_d1 = mod_conv::integer_divide_ceil(L1, DataPerRead);
213

214
        constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});
215
216
217
218
219
220
221
222
223
224
225
226

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        auto f_copy = [&](index_t is) {
            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);

            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);

227
228
229
230
            const index_t src_index =
                src_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);
            const index_t dst_index =
                dst_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253

            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                *(reinterpret_cast<const vector_t*>(p_src + src_index));
        };

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;

            f_copy(is);
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                f_copy(is);
            }
        }
254
255
256
    }
};

257
258
// need to be aligned to float4 and float2
// stride1 need to be 1 for both source and destination
Chao Liu's avatar
Chao Liu committed
259
template <index_t BlockSize,
260
261
262
263
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
Chao Liu's avatar
Chao Liu committed
264
265
          index_t ThreadPerDim0,
          index_t ThreadPerDim1>
266
struct Blockwise2dTensorCopy2
267
{
Chao Liu's avatar
Chao Liu committed
268
269
    index_t mThreadId0;
    index_t mThreadId1;
270

271
    __device__ Blockwise2dTensorCopy2()
272
    {
273
274
275
276
277
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
                      "wrong! stride is not 1!\n");
Chao Liu's avatar
Chao Liu committed
278

279
280
281
282
        mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
        mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
    }

Chao Liu's avatar
Chao Liu committed
283
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
284
    {
285
286
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

Chao Liu's avatar
Chao Liu committed
287
288
289
        using Float4 = float4;
        using Float2 = float2;

290
291
292
293
294
295
296
297
298
        if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
            return;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
299
300
301
302
303
304
305
        // check alignment
        constexpr bool align_v4 =
            src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;

        constexpr bool align_v2 =
            src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;

Chao Liu's avatar
Chao Liu committed
306
307
        constexpr index_t L0 = SrcOpLengths{}.Get(I0);
        constexpr index_t L1 = SrcOpLengths{}.Get(I1);
308

Chao Liu's avatar
Chao Liu committed
309
310
        constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
        constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
311

Chao Liu's avatar
Chao Liu committed
312
        constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
Chao Liu's avatar
Chao Liu committed
313

Chao Liu's avatar
Chao Liu committed
314
        constexpr index_t Dim1V2Loop =
Chao Liu's avatar
Chao Liu committed
315
316
            align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;

Chao Liu's avatar
Chao Liu committed
317
        constexpr index_t Dim1V1Loop =
318
319
            (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
            ThreadPerDim1;
Chao Liu's avatar
Chao Liu committed
320

321
322
323
        constexpr bool d1_has_tail =
            (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));

Chao Liu's avatar
Chao Liu committed
324
        for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
325
        {
Chao Liu's avatar
Chao Liu committed
326
            index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;
327
328

            // v4
Chao Liu's avatar
Chao Liu committed
329
            for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
330
            {
Chao Liu's avatar
Chao Liu committed
331
                index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
332

333
334
                const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
335

Chao Liu's avatar
Chao Liu committed
336
                *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
337
                    *(reinterpret_cast<const Float4*>(p_src + sindex));
338
339
340
            }

            // v2
Chao Liu's avatar
Chao Liu committed
341
            for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
342
            {
Chao Liu's avatar
Chao Liu committed
343
                index_t did1 =
344
345
                    Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;

346
347
                const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
348

Chao Liu's avatar
Chao Liu committed
349
                *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
350
                    *(reinterpret_cast<const Float2*>(p_src + sindex));
351
352
353
            }

            // v1
Chao Liu's avatar
Chao Liu committed
354
            for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
355
            {
Chao Liu's avatar
Chao Liu committed
356
357
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               d1v1loop * ThreadPerDim1 + mThreadId1;
358

359
360
                const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
361
362
363
364
365
366
367

                p_dst[dindex] = p_src[sindex];
            }

            // dim-1 tail
            if(d1_has_tail)
            {
Chao Liu's avatar
Chao Liu committed
368
369
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               Dim1V1Loop * ThreadPerDim1 + mThreadId1;
370
371
372

                if(did1 < L1)
                {
373
374
                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
375
376
377
378
379
380
381
382
383

                    p_dst[dindex] = p_src[sindex];
                }
            }
        }

        // dim-0 tail
        if(d0_has_tail)
        {
Chao Liu's avatar
Chao Liu committed
384
            index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
385
386
387
388
389

            if(did0 < L0)
            {

                // v4
Chao Liu's avatar
Chao Liu committed
390
                for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
391
                {
Chao Liu's avatar
Chao Liu committed
392
                    index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
393

394
395
                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
396

Chao Liu's avatar
Chao Liu committed
397
                    *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
398
                        *(reinterpret_cast<const Float4*>(p_src + sindex));
399
400
401
                }

                // v2
Chao Liu's avatar
Chao Liu committed
402
                for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
403
                {
Chao Liu's avatar
Chao Liu committed
404
405
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
                                   2 * mThreadId1;
406

407
408
                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
409

Chao Liu's avatar
Chao Liu committed
410
                    *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
411
                        *(reinterpret_cast<const Float2*>(p_src + sindex));
412
413
414
                }

                // v1
Chao Liu's avatar
Chao Liu committed
415
                for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
416
                {
Chao Liu's avatar
Chao Liu committed
417
418
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                   d1v1loop * ThreadPerDim1 + mThreadId1;
419

420
421
                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
422
423
424
425
426
427
428

                    p_dst[dindex] = p_src[sindex];
                }

                // tail
                if(d1_has_tail)
                {
Chao Liu's avatar
Chao Liu committed
429
430
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                   Dim1V1Loop * ThreadPerDim1 + mThreadId1;
431
432
433

                    if(did1 < L1)
                    {
434
435
                        const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                        const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
436
437
438
439
440
441
442
443

                        p_dst[dindex] = p_src[sindex];
                    }
                }
            }
        }
    }
};
Chao Liu's avatar
Chao Liu committed
444

445
446
// starting point need to be aligned to float4 or float2 or float
// stride1 need to be 1 for both source and destination
Chao Liu's avatar
Chao Liu committed
447
template <index_t BlockSize,
448
449
450
          class Float,
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
451
          class CopyLengths,
Chao Liu's avatar
Chao Liu committed
452
          index_t DataPerRead>
453
struct Blockwise2dTensorCopy3
Chao Liu's avatar
Chao Liu committed
454
{
455
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
Chao Liu's avatar
Chao Liu committed
456

Chao Liu's avatar
Chao Liu committed
457
458
    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;
Chao Liu's avatar
Chao Liu committed
459

460
    // Validates the copy configuration at compile time and records this thread's
    // starting offsets into the source and destination tensors.
    //
    // Threads are laid out as thread_per_d0 rows of thread_per_d1 threads, each
    // thread covering DataPerRead consecutive elements of a row. Threads beyond
    // that layout return before the offsets are assigned; the Run-style methods
    // visible in this struct apply the same early return, so they do not read them.
    __device__ Blockwise2dTensorCopy3()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        static_assert(DataPerRead == 1 ||
                          (src_desc.GetStride(I1) == 1 && dst_desc.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(src_desc.GetStride(I0) % DataPerRead == 0 &&
                          dst_desc.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride should be multiple of DataPerRead to keep alignment");

        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // threads needed to cover one row, and how many rows one pass covers
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride is big enough,
        //   so that the out-of-bound write won't contaminate next line in dst
        static_assert(thread_per_d1 * DataPerRead <= dst_desc.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");

        static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // surplus threads take no part in the copy
        if(BlockSize > num_active_thread && get_thread_local_1d_id() >= num_active_thread)
        {
            return;
        }

        const index_t tid = get_thread_local_1d_id();

        const index_t row_id = tid / thread_per_d1;
        const index_t col_id = tid - row_id * thread_per_d1;

        mSrcMyThreadOffset = src_desc.GetOffsetFromMultiIndex(row_id, col_id * DataPerRead);
        mDstMyThreadOffset = dst_desc.GetOffsetFromMultiIndex(row_id, col_id * DataPerRead);
    }

Chao Liu's avatar
Chao Liu committed
508
    // Copies the CopyLengths region from p_src to p_dst.
    //
    // Every active thread moves one vector_t per pass, starting at the offsets
    // stored by the constructor and advancing by thread_per_d0 rows each pass.
    // A final partial pass is executed only by the threads whose row still lies
    // inside the region.
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // threads needed to cover one row, and how many rows one pass covers
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // surplus threads take no part in the copy
        if(BlockSize > num_active_thread && get_thread_local_1d_id() >= num_active_thread)
        {
            return;
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // element distance between the rows a thread touches in consecutive passes
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        const Float* const p_src_thread = p_src + mSrcMyThreadOffset;
        Float* const p_dst_thread       = p_dst + mDstMyThreadOffset;

        // move this thread's vector for pass `ipass`
        auto copy_pass = [&](index_t ipass) {
            *(reinterpret_cast<vector_t*>(p_dst_thread + ipass * dst_loop_stride)) =
                *(reinterpret_cast<const vector_t*>(p_src_thread + ipass * src_loop_stride));
        };

        for(index_t ipass = 0; ipass < nloop_d0; ++ipass)
        {
            copy_pass(ipass);
        }

        // true when thread_per_d0 does not evenly divide the number of rows
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            // rows left over after the full passes
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                copy_pass(nloop_d0);
            }
        }
    }
557

Chao Liu's avatar
Chao Liu committed
558
    // Returns the number of Float elements one thread needs in its clipboard
    // buffer for RunLoadRegisterClipboard / RunStoreRegisterClipboard:
    // DataPerRead elements for each pass over dimension 0, including the
    // partial last pass.
    //
    // The number of passes is rounded up BEFORE multiplying by DataPerRead.
    // Multiplying first (as `DataPerRead * (L0 + t - 1) / t` does) can yield a
    // size that is larger than needed and not a multiple of DataPerRead.
    __device__ constexpr index_t GetRegisterClipboardSize() const
    {
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // threads needed to cover one row, and how many rows one pass covers
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // passes needed to cover all L0 rows, rounded up
        constexpr index_t num_pass_d0 = (L0 + thread_per_d0 - 1) / thread_per_d0;

        return DataPerRead * num_pass_d0;
    }

    // First half of a two-stage copy: reads this thread's share of the
    // CopyLengths region from p_src into p_clipboard.
    //
    // Pass k stores one vector_t at p_clipboard[k * DataPerRead], so the buffer
    // must hold at least GetRegisterClipboardSize() elements and be suitably
    // aligned for vector_t. p_clipboard is presumably a per-thread register
    // array (per the method name) - confirm against callers.
    // Threads outside the active thread layout return without touching memory.
    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // threads needed to cover one row, and how many rows one pass covers
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // element distance between the source rows a thread reads in consecutive passes
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;

        // load the vector for pass `iloop` into its clipboard slot
        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
                *(reinterpret_cast<const vector_t*>(
                    &p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // true when thread_per_d0 does not evenly divide the number of rows
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            // rows left over after the full passes
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    // Second half of a two-stage copy: writes the vectors held in p_clipboard
    // to this thread's share of the CopyLengths region in p_dst.
    //
    // Pass k reads one vector_t from p_clipboard[k * DataPerRead], i.e. the
    // layout produced by RunLoadRegisterClipboard.
    // Threads outside the active thread layout return without touching memory.
    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // threads needed to cover one row, and how many rows one pass covers
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // element distance between the destination rows a thread writes in consecutive passes
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        // store the clipboard vector for pass `iloop` to the destination
        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // true when thread_per_d0 does not evenly divide the number of rows
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            // rows left over after the full passes
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

Chao Liu's avatar
Chao Liu committed
675
#if USE_AMD_INLINE_ASM
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
    // Loads this thread's share of the CopyLengths tile from p_src into p_clipboard,
    // one vector_t at a time, through the global_load helper. Only Float == float with
    // DataPerRead == 4 is accepted (enforced by the static_assert below).
    //
    //   p_src       : source buffer; iteration iloop reads the vector that starts at
    //                 mSrcMyThreadOffset + iloop * src_loop_stride.
    //   p_clipboard : per-thread staging buffer; iteration iloop writes DataPerRead
    //                 consecutive Floats starting at iloop * DataPerRead. When a tail
    //                 iteration exists, index nloop_d0 is written as well, so the
    //                 buffer must hold (nloop_d0 + 1) * DataPerRead Floats.
    //
    // NOTE(review): judging by the helper's name, p_src is presumably global memory and
    // the reinterpret_casts presumably require vector_t-aligned addresses -- confirm
    // against global_load's definition and the callers.
    __device__ void RunLoadRegisterClipboard_asm(const Float* __restrict__ p_src,
                                                 Float* p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // Number of DataPerRead-wide vectors needed to span L1 (rounded up), and how
        // many such groups of threads fit into the block.
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // When BlockSize is not a multiple of thread_per_d1, the surplus threads have
        // nothing to copy and leave immediately.
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        // Number of full passes in which every active thread copies one vector.
        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // Source offset advance between two consecutive passes.
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
#if 0
            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
                *(reinterpret_cast<const vector_t*>(&p_src[mSrcMyThreadOffset +
                                                    iloop * src_loop_stride]));
#else
            static_assert(is_same<float, Float>::value && DataPerRead == 4,
                          "global_load is only for float4");

            global_load(reinterpret_cast<vector_t&>(p_clipboard[iloop * DataPerRead]),
                        reinterpret_cast<const vector_t*>(
                            &p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
#endif
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // L0 is not a multiple of thread_per_d0: the remaining tail_d0 entries along
        // dimension 0 are handled by the first tail_d0 * thread_per_d1 threads only.
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    // Writes this thread's staged vectors from p_clipboard out to p_dst, one vector_t
    // at a time, through the ds_write_b128 helper. Only Float == float with
    // DataPerRead == 4 is accepted (enforced by the static_assert below).
    //
    //   p_clipboard : per-thread staging buffer; iteration iloop reads DataPerRead
    //                 consecutive Floats starting at iloop * DataPerRead.
    //   p_dst       : destination buffer; iteration iloop writes the vector that starts
    //                 at mDstMyThreadOffset + iloop * dst_loop_stride.
    //
    // NOTE(review): judging by the helper's name, p_dst is presumably LDS / shared
    // memory and presumably must be 16-byte aligned -- confirm against ds_write_b128's
    // definition and the callers.
    __device__ void RunStoreRegisterClipboard_asm(const Float* __restrict__ p_clipboard,
                                                  Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // Number of DataPerRead-wide vectors needed to span L1 (rounded up), and how
        // many such groups of threads fit into the block.
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // When BlockSize is not a multiple of thread_per_d1, the surplus threads have
        // nothing to copy and leave immediately.
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        // Number of full passes in which every active thread copies one vector.
        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // Destination offset advance between two consecutive passes.
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
#if 0
            // Plain C++ reference path, kept for comparison with the asm helper.
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
#else
            static_assert(is_same<float, Float>::value && DataPerRead == 4,
                          "ds_write_b128 is only for float4");

            ds_write_b128(reinterpret_cast<const vector_t&>(p_clipboard[iloop * DataPerRead]),
                          &p_dst[mDstMyThreadOffset + iloop * dst_loop_stride]);
#endif
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // L0 is not a multiple of thread_per_d0: the remaining tail_d0 entries along
        // dimension 0 are handled by the first tail_d0 * thread_per_d1 threads only.
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }
793
#endif
Chao Liu's avatar
Chao Liu committed
794
};