blockwise_2d_tensor_op.hip.hpp 26.6 KB
Newer Older
1
#pragma once
Chao Liu's avatar
Chao Liu committed
2
#include "common.hip.hpp"
3
#include "ConstantTensorDescriptor.hip.hpp"
4

Chao Liu's avatar
Chao Liu committed
5
template <index_t BlockSize, class Float, class DstDesc, class F>
6
__device__ void
Chao Liu's avatar
Chao Liu committed
7
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
Chao Liu's avatar
Chao Liu committed
8
{
Chao Liu's avatar
Chao Liu committed
9
10
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
11

12
13
    constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
14
    constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());
Chao Liu's avatar
Chao Liu committed
15

16
#if 0
17
    if(get_thread_local_1d_id() == 0)
18
    {
Chao Liu's avatar
Chao Liu committed
19
20
        print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
        print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
21
22
23
    }
#endif

Chao Liu's avatar
Chao Liu committed
24
    constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
Chao Liu's avatar
Chao Liu committed
25

Chao Liu's avatar
Chao Liu committed
26
    for(index_t iloop = 0; iloop < NLoop; ++iloop)
Chao Liu's avatar
Chao Liu committed
27
    {
28
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;
Chao Liu's avatar
Chao Liu committed
29

Chao Liu's avatar
Chao Liu committed
30
        const index_t did0 = is / desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
31
32
33

        is -= did0 * desc.GetStride(I0);

Chao Liu's avatar
Chao Liu committed
34
        const index_t did1 = is / desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
35

Chao Liu's avatar
Chao Liu committed
36
        const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
37

Chao Liu's avatar
Chao Liu committed
38
        f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
39
40
41
42
43
44
    }

    constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
45
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
Chao Liu's avatar
Chao Liu committed
46
47
48

        if(is < desc.GetElementSize())
        {
Chao Liu's avatar
Chao Liu committed
49
            const index_t did0 = is / desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
50
51
52

            is -= did0 * desc.GetStride(I0);

Chao Liu's avatar
Chao Liu committed
53
            const index_t did1 = is / desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
54

Chao Liu's avatar
Chao Liu committed
55
            const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
56

Chao Liu's avatar
Chao Liu committed
57
            f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
58
59
60
        }
    }
}
Chao Liu's avatar
Chao Liu committed
61

62
// Function: p_dst[reorder[i0], reorder[i1]] = p_src[i0, i1]
// TODO: in order to optimize memory access for different memory types,
// a specialized version needs to be written
Chao Liu's avatar
Chao Liu committed
65
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
66
          class Float,
67
68
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
69
70
          class SrcOpLengths,
          class DstFromSrcReorder,
Chao Liu's avatar
Chao Liu committed
71
          class F>
Chao Liu's avatar
Chao Liu committed
72
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
Chao Liu's avatar
Chao Liu committed
73
    SrcDesc,
Chao Liu's avatar
Chao Liu committed
74
    const Float* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
75
76
77
78
79
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
    DstFromSrcReorder,
    F f)
Chao Liu's avatar
Chao Liu committed
80
{
Chao Liu's avatar
Chao Liu committed
81
82
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
83

Chao Liu's avatar
Chao Liu committed
84
85
    constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
    constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
Chao Liu's avatar
Chao Liu committed
86

87
88
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
Chao Liu's avatar
Chao Liu committed
89
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
Chao Liu's avatar
Chao Liu committed
90

Chao Liu's avatar
Chao Liu committed
91
    constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
Chao Liu's avatar
Chao Liu committed
92

Chao Liu's avatar
Chao Liu committed
93
    for(index_t iloop = 0; iloop < NLoop; ++iloop)
Chao Liu's avatar
Chao Liu committed
94
    {
95
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;
Chao Liu's avatar
Chao Liu committed
96

Chao Liu's avatar
Chao Liu committed
97
        index_t did[2];
Chao Liu's avatar
Chao Liu committed
98

99
        did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
100

101
        is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
102

103
        did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
104

Chao Liu's avatar
Chao Liu committed
105
        const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
Chao Liu's avatar
Chao Liu committed
106

Chao Liu's avatar
Chao Liu committed
107
        const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
108
109

        f(p_src[aindex], p_dst[bindex]);
Chao Liu's avatar
Chao Liu committed
110
111
    }

112
    constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
Chao Liu's avatar
Chao Liu committed
113
114
115

    if(has_tail)
    {
116
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
Chao Liu's avatar
Chao Liu committed
117

118
        if(is < ref_desc.GetElementSize())
Chao Liu's avatar
Chao Liu committed
119
        {
Chao Liu's avatar
Chao Liu committed
120
            index_t did[2];
121
122

            did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
123

124
            is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
125

126
            did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
127

Chao Liu's avatar
Chao Liu committed
128
            const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
129

Chao Liu's avatar
Chao Liu committed
130
            const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
131

132
            f(p_src[aindex], p_dst[bindex]);
133
134
135
136
        }
    }
}

Chao Liu's avatar
Chao Liu committed
137
template <index_t BlockSize, class Float, class DstDesc>
Chao Liu's avatar
Chao Liu committed
138
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
139
{
Chao Liu's avatar
Chao Liu committed
140
    auto f_set_zero = [](Float& v) { v = Float(0); };
Chao Liu's avatar
Chao Liu committed
141

Chao Liu's avatar
Chao Liu committed
142
    blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
Chao Liu's avatar
Chao Liu committed
143
}
144

Chao Liu's avatar
Chao Liu committed
145
// Copies a SrcOpLengths-sized 2d window from p_src to p_dst with the destination indices
// permuted by DstFromSrcReorder, by forwarding an assignment functor to
// blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src.
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class DstFromSrcReorder>
__device__ void
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                     const Float* __restrict__ p_src,
                                                     DstDesc,
                                                     Float* __restrict__ p_dst,
                                                     SrcOpLengths,
                                                     DstFromSrcReorder)
{
    blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
        SrcDesc{},
        p_src,
        DstDesc{},
        p_dst,
        SrcOpLengths{},
        DstFromSrcReorder{},
        [](const Float& from, Float& to) { to = from; });
}

165
166
167
168
169
170
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
171
struct Blockwise2dTensorCopy1
Chao Liu's avatar
Chao Liu committed
172
{
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    __device__ constexpr Blockwise2dTensorCopy1()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride2 should be multiple of DataPerRead to keep alignment");

        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride0 is big enough,
        //   so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t L1          = CopyLengths{}.Get(I1);
        constexpr index_t read_per_d1 = integer_divide_ceil(L1, DataPerRead);

        static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");
    }

Chao Liu's avatar
Chao Liu committed
201
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
202
    {
203
204
205
206
207
208
209
210
211
212
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t read_per_d1 = integer_divide_ceil(L1, DataPerRead);
213

214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
        constexpr auto ref_desc =
            make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        auto f_copy = [&](index_t is) {
            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);

            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);

            const index_t src_index =
                src_desc.Get1dIndex(did[0], did[1] * DataPerRead);
            const index_t dst_index =
                dst_desc.Get1dIndex(did[0], did[1] * DataPerRead);

            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                *(reinterpret_cast<const vector_t*>(p_src + src_index));
        };

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;

            f_copy(is);
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                f_copy(is);
            }
        }
255
256
257
    }
};

258
259
// need to be aligned to float4 and float2
// stride1 need to be 1 for both source and destination
Chao Liu's avatar
Chao Liu committed
260
template <index_t BlockSize,
261
262
263
264
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
Chao Liu's avatar
Chao Liu committed
265
266
          index_t ThreadPerDim0,
          index_t ThreadPerDim1>
267
struct Blockwise2dTensorCopy2
268
{
Chao Liu's avatar
Chao Liu committed
269
270
    index_t mThreadId0;
    index_t mThreadId1;
271

272
    __device__ Blockwise2dTensorCopy2()
273
    {
274
275
276
277
278
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
                      "wrong! stride is not 1!\n");
Chao Liu's avatar
Chao Liu committed
279

280
281
282
283
        mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
        mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
    }

Chao Liu's avatar
Chao Liu committed
284
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
285
    {
286
287
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

Chao Liu's avatar
Chao Liu committed
288
289
290
        using Float4 = float4;
        using Float2 = float2;

291
292
293
294
295
296
297
298
299
        if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
            return;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
300
301
302
303
304
305
306
        // check alignment
        constexpr bool align_v4 =
            src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;

        constexpr bool align_v2 =
            src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;

Chao Liu's avatar
Chao Liu committed
307
308
        constexpr index_t L0 = SrcOpLengths{}.Get(I0);
        constexpr index_t L1 = SrcOpLengths{}.Get(I1);
309

Chao Liu's avatar
Chao Liu committed
310
311
        constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
        constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
312

Chao Liu's avatar
Chao Liu committed
313
        constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
Chao Liu's avatar
Chao Liu committed
314

Chao Liu's avatar
Chao Liu committed
315
        constexpr index_t Dim1V2Loop =
Chao Liu's avatar
Chao Liu committed
316
317
            align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;

Chao Liu's avatar
Chao Liu committed
318
        constexpr index_t Dim1V1Loop =
319
320
            (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
            ThreadPerDim1;
Chao Liu's avatar
Chao Liu committed
321

322
323
324
        constexpr bool d1_has_tail =
            (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));

Chao Liu's avatar
Chao Liu committed
325
        for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
326
        {
Chao Liu's avatar
Chao Liu committed
327
            index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;
328
329

            // v4
Chao Liu's avatar
Chao Liu committed
330
            for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
331
            {
Chao Liu's avatar
Chao Liu committed
332
                index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
333

Chao Liu's avatar
Chao Liu committed
334
335
                const index_t sindex = src_desc.Get1dIndex(did0, did1);
                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
336

Chao Liu's avatar
Chao Liu committed
337
                *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
338
                    *(reinterpret_cast<const Float4*>(p_src + sindex));
339
340
341
            }

            // v2
Chao Liu's avatar
Chao Liu committed
342
            for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
343
            {
Chao Liu's avatar
Chao Liu committed
344
                index_t did1 =
345
346
                    Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;

Chao Liu's avatar
Chao Liu committed
347
348
                const index_t sindex = src_desc.Get1dIndex(did0, did1);
                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
349

Chao Liu's avatar
Chao Liu committed
350
                *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
351
                    *(reinterpret_cast<const Float2*>(p_src + sindex));
352
353
354
            }

            // v1
Chao Liu's avatar
Chao Liu committed
355
            for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
356
            {
Chao Liu's avatar
Chao Liu committed
357
358
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               d1v1loop * ThreadPerDim1 + mThreadId1;
359

Chao Liu's avatar
Chao Liu committed
360
361
                const index_t sindex = src_desc.Get1dIndex(did0, did1);
                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
362
363
364
365
366
367
368

                p_dst[dindex] = p_src[sindex];
            }

            // dim-1 tail
            if(d1_has_tail)
            {
Chao Liu's avatar
Chao Liu committed
369
370
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               Dim1V1Loop * ThreadPerDim1 + mThreadId1;
371
372
373

                if(did1 < L1)
                {
Chao Liu's avatar
Chao Liu committed
374
375
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
376
377
378
379
380
381
382
383
384

                    p_dst[dindex] = p_src[sindex];
                }
            }
        }

        // dim-0 tail
        if(d0_has_tail)
        {
Chao Liu's avatar
Chao Liu committed
385
            index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
386
387
388
389
390

            if(did0 < L0)
            {

                // v4
Chao Liu's avatar
Chao Liu committed
391
                for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
392
                {
Chao Liu's avatar
Chao Liu committed
393
                    index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
394

Chao Liu's avatar
Chao Liu committed
395
396
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
397

Chao Liu's avatar
Chao Liu committed
398
                    *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
399
                        *(reinterpret_cast<const Float4*>(p_src + sindex));
400
401
402
                }

                // v2
Chao Liu's avatar
Chao Liu committed
403
                for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
404
                {
Chao Liu's avatar
Chao Liu committed
405
406
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
                                   2 * mThreadId1;
407

Chao Liu's avatar
Chao Liu committed
408
409
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
410

Chao Liu's avatar
Chao Liu committed
411
                    *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
412
                        *(reinterpret_cast<const Float2*>(p_src + sindex));
413
414
415
                }

                // v1
Chao Liu's avatar
Chao Liu committed
416
                for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
417
                {
Chao Liu's avatar
Chao Liu committed
418
419
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                   d1v1loop * ThreadPerDim1 + mThreadId1;
420

Chao Liu's avatar
Chao Liu committed
421
422
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
423
424
425
426
427
428
429

                    p_dst[dindex] = p_src[sindex];
                }

                // tail
                if(d1_has_tail)
                {
Chao Liu's avatar
Chao Liu committed
430
431
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                   Dim1V1Loop * ThreadPerDim1 + mThreadId1;
432
433
434

                    if(did1 < L1)
                    {
Chao Liu's avatar
Chao Liu committed
435
436
                        const index_t sindex = src_desc.Get1dIndex(did0, did1);
                        const index_t dindex = dst_desc.Get1dIndex(did0, did1);
437
438
439
440
441
442
443
444

                        p_dst[dindex] = p_src[sindex];
                    }
                }
            }
        }
    }
};
Chao Liu's avatar
Chao Liu committed
445

446
447
// starting point need to be aligned to float4 or float2 or float
// stride1 need to be 1 for both source and destination
Chao Liu's avatar
Chao Liu committed
448
template <index_t BlockSize,
449
450
451
          class Float,
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
452
          class CopyLengths,
Chao Liu's avatar
Chao Liu committed
453
          index_t DataPerRead>
454
struct Blockwise2dTensorCopy3
Chao Liu's avatar
Chao Liu committed
455
{
456
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
Chao Liu's avatar
Chao Liu committed
457

Chao Liu's avatar
Chao Liu committed
458
459
    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;
Chao Liu's avatar
Chao Liu committed
460

461
    // Validates the template arguments at compile time, then caches the calling thread's
    // starting offsets into the source and destination buffers. Threads are arranged as
    // thread_per_d0 x thread_per_d1 with dim 1 varying fastest; threads beyond that
    // arrangement return early and leave both offsets unset.
    __device__ Blockwise2dTensorCopy3()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride should be multiple of DataPerRead to keep alignment");

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // threads needed to cover one line of dim 1 with DataPerRead-wide accesses,
        // and how many such lines fit into one block
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // the last access of a line may run past L1 in both src and dst; the dst stride
        // must be large enough that this overrun does not reach the next dst line
        static_assert(thread_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");

        static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // the first operand is a compile-time constant, so the thread id is only queried
        // when the block actually has surplus threads
        if(BlockSize > num_active_thread && get_thread_local_1d_id() >= num_active_thread)
        {
            return;
        }

        const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
        const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;

        // element index along dim 1 at which this thread's vector starts
        const index_t element_id_d1 = thread_id_d1 * DataPerRead;

        mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_id_d0, element_id_d1);
        mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_id_d0, element_id_d1);
    }

Chao Liu's avatar
Chao Liu committed
508
    // Copies the calling thread's share of the CopyLengths window directly from p_src to
    // p_dst, one vector_t per pass. Pass k moves the vector that lies thread_per_d0 * k
    // lines below the thread's starting offsets (mSrcMyThreadOffset / mDstMyThreadOffset).
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // surplus threads of the block take no part in the copy
        if(BlockSize > num_active_thread && get_thread_local_1d_id() >= num_active_thread)
        {
            return;
        }

        // number of passes in which every active thread has a line to copy
        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // distance, in elements, between the lines one thread handles in consecutive passes
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        const Float* const p_src_thread = p_src + mSrcMyThreadOffset;
        Float* const p_dst_thread       = p_dst + mDstMyThreadOffset;

        // move one vector for the given pass
        auto copy_pass = [&](index_t pass) {
            const vector_t* const p_from =
                reinterpret_cast<const vector_t*>(p_src_thread + pass * src_loop_stride);
            vector_t* const p_to =
                reinterpret_cast<vector_t*>(p_dst_thread + pass * dst_loop_stride);

            *p_to = *p_from;
        };

        for(index_t pass = 0; pass < nloop_d0; ++pass)
        {
            copy_pass(pass);
        }

        // leftover lines when L0 is not a multiple of thread_per_d0
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            // only the threads mapped onto the first tail_d0 lines have work left
            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                copy_pass(nloop_d0);
            }
        }
    }
557

Chao Liu's avatar
Chao Liu committed
558
    // Returns the number of Float elements a per-thread clipboard must hold for
    // RunLoadRegisterClipboard / RunStoreRegisterClipboard: DataPerRead elements for each
    // pass over dim 0, counting the partial tail pass.
    //
    // The previous expression, DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0,
    // multiplied before dividing. It could therefore return more than
    // DataPerRead * ceil(L0 / thread_per_d0), and a value that is not a multiple of
    // DataPerRead (e.g. DataPerRead = 4, L0 = 2, thread_per_d0 = 3 gave 5 instead of 4).
    // The clipboard methods index p_clipboard[iloop * DataPerRead] with iloop at most
    // nloop_d0 (tail pass), so the value returned here covers every slot they touch.
    __device__ constexpr index_t GetRegisterClipboardSize() const
    {
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // passes one thread makes over dim 0, rounded up to include the tail pass
        constexpr index_t nloop_d0_ceil = (L0 + thread_per_d0 - 1) / thread_per_d0;

        return DataPerRead * nloop_d0_ceil;
    }

    // Loads the calling thread's share of the CopyLengths window from p_src into
    // p_clipboard. Pass k reads one vector_t from the source and stores it at
    // p_clipboard[k * DataPerRead].
    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // surplus threads of the block take no part in the copy
        if(BlockSize > num_active_thread && get_thread_local_1d_id() >= num_active_thread)
        {
            return;
        }

        // number of passes in which every active thread has a line to load
        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        // fetch one vector for the given pass into its clipboard slot
        auto load_pass = [&](index_t pass) {
            const index_t src_offset  = mSrcMyThreadOffset + pass * src_loop_stride;
            const index_t clip_offset = pass * DataPerRead;

            const vector_t* const p_from = reinterpret_cast<const vector_t*>(p_src + src_offset);
            vector_t* const p_to         = reinterpret_cast<vector_t*>(p_clipboard + clip_offset);

            *p_to = *p_from;
        };

        for(index_t pass = 0; pass < nloop_d0; ++pass)
        {
            load_pass(pass);
        }

        // leftover lines when L0 is not a multiple of thread_per_d0
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            // only the threads mapped onto the first tail_d0 lines have work left
            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                load_pass(nloop_d0);
            }
        }
    }

    // Stores the contents of p_clipboard into the calling thread's share of the
    // CopyLengths window in p_dst. Pass k reads one vector_t from
    // p_clipboard[k * DataPerRead] and writes it to the destination.
    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // surplus threads of the block take no part in the copy
        if(BlockSize > num_active_thread && get_thread_local_1d_id() >= num_active_thread)
        {
            return;
        }

        // number of passes in which every active thread has a line to store
        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        // write one clipboard slot to the destination for the given pass
        auto store_pass = [&](index_t pass) {
            const index_t dst_offset  = mDstMyThreadOffset + pass * dst_loop_stride;
            const index_t clip_offset = pass * DataPerRead;

            const vector_t* const p_from =
                reinterpret_cast<const vector_t*>(p_clipboard + clip_offset);
            vector_t* const p_to = reinterpret_cast<vector_t*>(p_dst + dst_offset);

            *p_to = *p_from;
        };

        for(index_t pass = 0; pass < nloop_d0; ++pass)
        {
            store_pass(pass);
        }

        // leftover lines when L0 is not a multiple of thread_per_d0
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            // only the threads mapped onto the first tail_d0 lines have work left
            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                store_pass(nloop_d0);
            }
        }
    }

#if DEVICE_BACKEND_HIP
    // Copies this thread's share of the 2D tile from p_src into p_clipboard, one vector_t at a
    // time, issuing each read through the global_load helper instead of a plain vector
    // assignment.
    //
    // Thread layout (all compile-time):
    //   thread_per_d1 = ceil(L1 / DataPerRead)      threads along dimension 1
    //   thread_per_d0 = BlockSize / thread_per_d1   threads along dimension 0
    // Threads whose id is >= thread_per_d0 * thread_per_d1 return without touching memory.
    //
    // p_src:       source data; slot iloop is read starting at element
    //              mSrcMyThreadOffset + iloop * SrcDesc stride(0) * thread_per_d0.
    // p_clipboard: staging buffer; slot iloop is written to elements
    //              [iloop * DataPerRead, (iloop + 1) * DataPerRead). It needs room for
    //              L0 / thread_per_d0 slots, plus one more when L0 is not a multiple of
    //              thread_per_d0.
    //
    // Only instantiable with Float == float and DataPerRead == 4 (see the static_assert).
    // NOTE(review): judging by the names, p_src is presumably global memory and p_clipboard a
    // per-thread register buffer, and both addresses presumably have to be aligned for
    // vector_t -- confirm against the callers and the global_load helper.
    __device__ void RunLoadRegisterClipboard_asm(const Float* __restrict__ p_src,
                                                 Float* p_clipboard) const
    {
        static_assert(is_same<float, Float>::value && DataPerRead == 4,
                      "global_load is only for float4");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // one thread handles DataPerRead consecutive elements along dimension 1
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // BlockSize need not be a multiple of thread_per_d1; the leftover threads do nothing.
        // The first operand is a compile-time constant, so the thread-id test is only needed
        // when there actually are leftover threads.
        if(BlockSize > num_active_thread && get_thread_local_1d_id() >= num_active_thread)
        {
            return;
        }

        // number of full passes over dimension 0, and the source step between two passes
        constexpr index_t nloop_d0        = L0 / thread_per_d0;
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;

        // Loads one vector_t for pass `ipass`. The disabled plain-C++ alternative to this call
        // was:
        //   *reinterpret_cast<vector_t*>(&p_clipboard[ipass * DataPerRead]) =
        //       *reinterpret_cast<const vector_t*>(
        //           &p_src[mSrcMyThreadOffset + ipass * src_loop_stride]);
        auto load_one = [&](index_t ipass) {
            global_load(reinterpret_cast<vector_t&>(p_clipboard[ipass * DataPerRead]),
                        reinterpret_cast<const vector_t*>(
                            &p_src[mSrcMyThreadOffset + ipass * src_loop_stride]));
        };

        for(index_t ipass = 0; ipass < nloop_d0; ++ipass)
        {
            load_one(ipass);
        }

        // rows of dimension 0 left over after the full passes
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            // only the threads mapped onto the remaining tail_d0 rows take part
            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                load_one(nloop_d0);
            }
        }
    }

    // Writes this thread's share of the 2D tile from p_clipboard to p_dst, one vector_t at a
    // time, issuing each write through the ds_write_b128 helper instead of a plain vector
    // assignment.
    //
    // Thread layout (all compile-time):
    //   thread_per_d1 = ceil(L1 / DataPerRead)      threads along dimension 1
    //   thread_per_d0 = BlockSize / thread_per_d1   threads along dimension 0
    // Threads whose id is >= thread_per_d0 * thread_per_d1 return without touching memory.
    //
    // p_clipboard: staging buffer; slot iloop is read from elements
    //              [iloop * DataPerRead, (iloop + 1) * DataPerRead). It is read for
    //              L0 / thread_per_d0 slots, plus one more when L0 is not a multiple of
    //              thread_per_d0.
    // p_dst:       destination; slot iloop is written starting at element
    //              mDstMyThreadOffset + iloop * DstDesc stride(0) * thread_per_d0.
    //
    // Only instantiable with Float == float and DataPerRead == 4 (see the static_assert).
    // NOTE(review): judging by the names, p_clipboard is presumably a per-thread register
    // buffer and p_dst LDS/shared memory, and both addresses presumably have to be aligned for
    // vector_t -- confirm against the callers and the ds_write_b128 helper.
    __device__ void RunStoreRegisterClipboard_asm(const Float* __restrict__ p_clipboard,
                                                  Float* __restrict__ p_dst) const
    {
        static_assert(is_same<float, Float>::value && DataPerRead == 4,
                      "ds_write_b128 is only for float4");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // one thread handles DataPerRead consecutive elements along dimension 1
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // BlockSize need not be a multiple of thread_per_d1; the leftover threads do nothing.
        // The first operand is a compile-time constant, so the thread-id test is only needed
        // when there actually are leftover threads.
        if(BlockSize > num_active_thread && get_thread_local_1d_id() >= num_active_thread)
        {
            return;
        }

        // number of full passes over dimension 0, and the destination step between two passes
        constexpr index_t nloop_d0        = L0 / thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        // Stores one vector_t for pass `ipass`. The disabled plain-C++ alternative to this call
        // was:
        //   *reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + ipass * dst_loop_stride]) =
        //       *reinterpret_cast<const vector_t*>(&p_clipboard[ipass * DataPerRead]);
        auto store_one = [&](index_t ipass) {
            ds_write_b128(reinterpret_cast<const vector_t&>(p_clipboard[ipass * DataPerRead]),
                          &p_dst[mDstMyThreadOffset + ipass * dst_loop_stride]);
        };

        for(index_t ipass = 0; ipass < nloop_d0; ++ipass)
        {
            store_one(ipass);
        }

        // rows of dimension 0 left over after the full passes
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            // only the threads mapped onto the remaining tail_d0 rows take part
            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                store_one(nloop_d0);
            }
        }
    }
793
#endif
Chao Liu's avatar
Chao Liu committed
794
};