blockwise_2d_tensor_op.hip.hpp 26.5 KB
Newer Older
1
#pragma once
Chao Liu's avatar
Chao Liu committed
2
#include "common.hip.hpp"
3
#include "ConstantTensorDescriptor.hip.hpp"
4

Chao Liu's avatar
Chao Liu committed
5
template <index_t BlockSize, class Float, class DstDesc, class F>
6
__device__ void
Chao Liu's avatar
Chao Liu committed
7
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
Chao Liu's avatar
Chao Liu committed
8
{
Chao Liu's avatar
Chao Liu committed
9
10
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
11

12
13
    constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
14
    constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());
Chao Liu's avatar
Chao Liu committed
15

16
#if 0
17
    if(get_thread_local_1d_id() == 0)
18
    {
Chao Liu's avatar
Chao Liu committed
19
20
        print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
        print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
21
22
23
    }
#endif

Chao Liu's avatar
Chao Liu committed
24
    constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
Chao Liu's avatar
Chao Liu committed
25

Chao Liu's avatar
Chao Liu committed
26
    for(index_t iloop = 0; iloop < NLoop; ++iloop)
Chao Liu's avatar
Chao Liu committed
27
    {
28
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;
Chao Liu's avatar
Chao Liu committed
29

Chao Liu's avatar
Chao Liu committed
30
        const index_t did0 = is / desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
31
32
33

        is -= did0 * desc.GetStride(I0);

Chao Liu's avatar
Chao Liu committed
34
        const index_t did1 = is / desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
35

Chao Liu's avatar
Chao Liu committed
36
        const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
37

Chao Liu's avatar
Chao Liu committed
38
        f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
39
40
41
42
43
44
    }

    constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
45
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
Chao Liu's avatar
Chao Liu committed
46
47
48

        if(is < desc.GetElementSize())
        {
Chao Liu's avatar
Chao Liu committed
49
            const index_t did0 = is / desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
50
51
52

            is -= did0 * desc.GetStride(I0);

Chao Liu's avatar
Chao Liu committed
53
            const index_t did1 = is / desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
54

Chao Liu's avatar
Chao Liu committed
55
            const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
56

Chao Liu's avatar
Chao Liu committed
57
            f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
58
59
60
        }
    }
}
Chao Liu's avatar
Chao Liu committed
61

62
// Function: p_dst[reorder[i0], reorder[i1]] = p_src[i0, i1]
63
64
// TODO: to optimize memory access for different memory types,
// specialized versions need to be written
Chao Liu's avatar
Chao Liu committed
65
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
66
          class Float,
67
68
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
69
70
          class SrcOpLengths,
          class DstFromSrcReorder,
Chao Liu's avatar
Chao Liu committed
71
          class F>
Chao Liu's avatar
Chao Liu committed
72
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
Chao Liu's avatar
Chao Liu committed
73
    SrcDesc,
Chao Liu's avatar
Chao Liu committed
74
    const Float* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
75
76
77
78
79
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
    DstFromSrcReorder,
    F f)
Chao Liu's avatar
Chao Liu committed
80
{
Chao Liu's avatar
Chao Liu committed
81
82
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
83

Chao Liu's avatar
Chao Liu committed
84
85
    constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
    constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
Chao Liu's avatar
Chao Liu committed
86

87
88
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
Chao Liu's avatar
Chao Liu committed
89
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
Chao Liu's avatar
Chao Liu committed
90

Chao Liu's avatar
Chao Liu committed
91
    constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
Chao Liu's avatar
Chao Liu committed
92

Chao Liu's avatar
Chao Liu committed
93
    for(index_t iloop = 0; iloop < NLoop; ++iloop)
Chao Liu's avatar
Chao Liu committed
94
    {
95
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;
Chao Liu's avatar
Chao Liu committed
96

Chao Liu's avatar
Chao Liu committed
97
        index_t did[2];
Chao Liu's avatar
Chao Liu committed
98

99
        did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
100

101
        is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
102

103
        did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
104

Chao Liu's avatar
Chao Liu committed
105
        const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
Chao Liu's avatar
Chao Liu committed
106

Chao Liu's avatar
Chao Liu committed
107
        const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
108
109

        f(p_src[aindex], p_dst[bindex]);
Chao Liu's avatar
Chao Liu committed
110
111
    }

112
    constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
Chao Liu's avatar
Chao Liu committed
113
114
115

    if(has_tail)
    {
116
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
Chao Liu's avatar
Chao Liu committed
117

118
        if(is < ref_desc.GetElementSize())
Chao Liu's avatar
Chao Liu committed
119
        {
Chao Liu's avatar
Chao Liu committed
120
            index_t did[2];
121
122

            did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
123

124
            is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
125

126
            did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
127

Chao Liu's avatar
Chao Liu committed
128
            const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
129

Chao Liu's avatar
Chao Liu committed
130
            const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
131

132
            f(p_src[aindex], p_dst[bindex]);
133
134
135
136
        }
    }
}

Chao Liu's avatar
Chao Liu committed
137
template <index_t BlockSize, class Float, class DstDesc>
Chao Liu's avatar
Chao Liu committed
138
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
139
{
Chao Liu's avatar
Chao Liu committed
140
    auto f_set_zero = [](Float& v) { v = Float(0); };
Chao Liu's avatar
Chao Liu committed
141

Chao Liu's avatar
Chao Liu committed
142
    blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
Chao Liu's avatar
Chao Liu committed
143
}
144

Chao Liu's avatar
Chao Liu committed
145
// Copies the SrcOpLengths region of the source tensor into the destination
// tensor, with the destination coordinates permuted by DstFromSrcReorder.
// Thin wrapper: forwards to the binary reorder operation with a plain
// assignment functor.
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class DstFromSrcReorder>
__device__ void
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                     const Float* __restrict__ p_src,
                                                     DstDesc,
                                                     Float* __restrict__ p_dst,
                                                     SrcOpLengths,
                                                     DstFromSrcReorder)
{
    blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
        SrcDesc{},
        p_src,
        DstDesc{},
        p_dst,
        SrcOpLengths{},
        DstFromSrcReorder{},
        [](const Float& from, Float& to) { to = from; });
}

165
166
167
168
169
170
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
171
struct Blockwise2dTensorCopy1
Chao Liu's avatar
Chao Liu committed
172
{
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    __device__ constexpr Blockwise2dTensorCopy1()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride2 should be multiple of DataPerRead to keep alignment");

        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride0 is big enough,
        //   so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t L1          = CopyLengths{}.Get(I1);
        constexpr index_t read_per_d1 = integer_divide_ceil(L1, DataPerRead);

        static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");
    }

Chao Liu's avatar
Chao Liu committed
201
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
202
    {
203
204
205
206
207
208
209
210
211
212
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t read_per_d1 = integer_divide_ceil(L1, DataPerRead);
213

214
        constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});
215
216
217
218
219
220
221
222
223
224
225
226

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        auto f_copy = [&](index_t is) {
            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);

            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);

227
228
            const index_t src_index = src_desc.Get1dIndex(did[0], did[1] * DataPerRead);
            const index_t dst_index = dst_desc.Get1dIndex(did[0], did[1] * DataPerRead);
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251

            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                *(reinterpret_cast<const vector_t*>(p_src + src_index));
        };

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;

            f_copy(is);
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                f_copy(is);
            }
        }
252
253
254
    }
};

255
256
// needs to be aligned to float4 and float2
// stride1 needs to be 1 for both source and destination
Chao Liu's avatar
Chao Liu committed
257
template <index_t BlockSize,
258
259
260
261
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
Chao Liu's avatar
Chao Liu committed
262
263
          index_t ThreadPerDim0,
          index_t ThreadPerDim1>
264
struct Blockwise2dTensorCopy2
265
{
Chao Liu's avatar
Chao Liu committed
266
267
    index_t mThreadId0;
    index_t mThreadId1;
268

269
    __device__ Blockwise2dTensorCopy2()
270
    {
271
272
273
274
275
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
                      "wrong! stride is not 1!\n");
Chao Liu's avatar
Chao Liu committed
276

277
278
279
280
        mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
        mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
    }

Chao Liu's avatar
Chao Liu committed
281
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
282
    {
283
284
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

Chao Liu's avatar
Chao Liu committed
285
286
287
        using Float4 = float4;
        using Float2 = float2;

288
289
290
291
292
293
294
295
296
        if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
            return;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
297
298
299
300
301
302
303
        // check alignment
        constexpr bool align_v4 =
            src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;

        constexpr bool align_v2 =
            src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;

Chao Liu's avatar
Chao Liu committed
304
305
        constexpr index_t L0 = SrcOpLengths{}.Get(I0);
        constexpr index_t L1 = SrcOpLengths{}.Get(I1);
306

Chao Liu's avatar
Chao Liu committed
307
308
        constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
        constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
309

Chao Liu's avatar
Chao Liu committed
310
        constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
Chao Liu's avatar
Chao Liu committed
311

Chao Liu's avatar
Chao Liu committed
312
        constexpr index_t Dim1V2Loop =
Chao Liu's avatar
Chao Liu committed
313
314
            align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;

Chao Liu's avatar
Chao Liu committed
315
        constexpr index_t Dim1V1Loop =
316
317
            (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
            ThreadPerDim1;
Chao Liu's avatar
Chao Liu committed
318

319
320
321
        constexpr bool d1_has_tail =
            (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));

Chao Liu's avatar
Chao Liu committed
322
        for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
323
        {
Chao Liu's avatar
Chao Liu committed
324
            index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;
325
326

            // v4
Chao Liu's avatar
Chao Liu committed
327
            for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
328
            {
Chao Liu's avatar
Chao Liu committed
329
                index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
330

Chao Liu's avatar
Chao Liu committed
331
332
                const index_t sindex = src_desc.Get1dIndex(did0, did1);
                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
333

Chao Liu's avatar
Chao Liu committed
334
                *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
335
                    *(reinterpret_cast<const Float4*>(p_src + sindex));
336
337
338
            }

            // v2
Chao Liu's avatar
Chao Liu committed
339
            for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
340
            {
Chao Liu's avatar
Chao Liu committed
341
                index_t did1 =
342
343
                    Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;

Chao Liu's avatar
Chao Liu committed
344
345
                const index_t sindex = src_desc.Get1dIndex(did0, did1);
                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
346

Chao Liu's avatar
Chao Liu committed
347
                *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
348
                    *(reinterpret_cast<const Float2*>(p_src + sindex));
349
350
351
            }

            // v1
Chao Liu's avatar
Chao Liu committed
352
            for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
353
            {
Chao Liu's avatar
Chao Liu committed
354
355
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               d1v1loop * ThreadPerDim1 + mThreadId1;
356

Chao Liu's avatar
Chao Liu committed
357
358
                const index_t sindex = src_desc.Get1dIndex(did0, did1);
                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
359
360
361
362
363
364
365

                p_dst[dindex] = p_src[sindex];
            }

            // dim-1 tail
            if(d1_has_tail)
            {
Chao Liu's avatar
Chao Liu committed
366
367
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               Dim1V1Loop * ThreadPerDim1 + mThreadId1;
368
369
370

                if(did1 < L1)
                {
Chao Liu's avatar
Chao Liu committed
371
372
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
373
374
375
376
377
378
379
380
381

                    p_dst[dindex] = p_src[sindex];
                }
            }
        }

        // dim-0 tail
        if(d0_has_tail)
        {
Chao Liu's avatar
Chao Liu committed
382
            index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
383
384
385
386
387

            if(did0 < L0)
            {

                // v4
Chao Liu's avatar
Chao Liu committed
388
                for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
389
                {
Chao Liu's avatar
Chao Liu committed
390
                    index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
391

Chao Liu's avatar
Chao Liu committed
392
393
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
394

Chao Liu's avatar
Chao Liu committed
395
                    *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
396
                        *(reinterpret_cast<const Float4*>(p_src + sindex));
397
398
399
                }

                // v2
Chao Liu's avatar
Chao Liu committed
400
                for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
401
                {
Chao Liu's avatar
Chao Liu committed
402
403
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
                                   2 * mThreadId1;
404

Chao Liu's avatar
Chao Liu committed
405
406
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
407

Chao Liu's avatar
Chao Liu committed
408
                    *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
409
                        *(reinterpret_cast<const Float2*>(p_src + sindex));
410
411
412
                }

                // v1
Chao Liu's avatar
Chao Liu committed
413
                for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
414
                {
Chao Liu's avatar
Chao Liu committed
415
416
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                   d1v1loop * ThreadPerDim1 + mThreadId1;
417

Chao Liu's avatar
Chao Liu committed
418
419
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
420
421
422
423
424
425
426

                    p_dst[dindex] = p_src[sindex];
                }

                // tail
                if(d1_has_tail)
                {
Chao Liu's avatar
Chao Liu committed
427
428
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                   Dim1V1Loop * ThreadPerDim1 + mThreadId1;
429
430
431

                    if(did1 < L1)
                    {
Chao Liu's avatar
Chao Liu committed
432
433
                        const index_t sindex = src_desc.Get1dIndex(did0, did1);
                        const index_t dindex = dst_desc.Get1dIndex(did0, did1);
434
435
436
437
438
439
440
441

                        p_dst[dindex] = p_src[sindex];
                    }
                }
            }
        }
    }
};
Chao Liu's avatar
Chao Liu committed
442

443
444
// starting point needs to be aligned to float4 or float2 or float
// stride1 needs to be 1 for both source and destination
Chao Liu's avatar
Chao Liu committed
445
template <index_t BlockSize,
446
447
448
          class Float,
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
449
          class CopyLengths,
Chao Liu's avatar
Chao Liu committed
450
          index_t DataPerRead>
451
struct Blockwise2dTensorCopy3
Chao Liu's avatar
Chao Liu committed
452
{
453
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
Chao Liu's avatar
Chao Liu committed
454

Chao Liu's avatar
Chao Liu committed
455
456
    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;
Chao Liu's avatar
Chao Liu committed
457

458
    // Validates the copy configuration at compile time and precomputes this
    // thread's starting offsets into src and dst.
    //
    // Threads are laid out as thread_per_d0 x thread_per_d1, where
    // thread_per_d1 = ceil(L1 / DataPerRead) threads cover one row. Threads
    // beyond thread_per_d0 * thread_per_d1 return early and leave both offsets
    // unset.
    __device__ Blockwise2dTensorCopy3()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // vector access needs unit stride in dimension 1
        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        // every row has to start on a vector boundary
        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride should be multiple of DataPerRead to keep alignment");

        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // The last vector of a row may reach past L1 (out-of-bound read from src
        // in dimension 1 is tolerated). The matching write must still stay inside
        // the dst row pitch, so it cannot spill into the next dst row.
        static_assert(thread_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");

        static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        const index_t tid = get_thread_local_1d_id();

        // surplus threads take no part in the copy
        if(BlockSize > num_active_thread && tid >= num_active_thread)
        {
            return;
        }

        // 2d position of this thread: row, and first element column of its vector
        const index_t row = tid / thread_per_d1;
        const index_t col = (tid - row * thread_per_d1) * DataPerRead;

        mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(row, col);
        mDstMyThreadOffset = DstDesc{}.Get1dIndex(row, col);
    }

Chao Liu's avatar
Chao Liu committed
505
    // Copies the CopyLengths region from p_src to p_dst, one vector_t per
    // thread per pass. Between passes a thread moves down by thread_per_d0 rows.
    //   p_src : base pointer of the source tensor (laid out per SrcDesc)
    //   p_dst : base pointer of the destination tensor (laid out per DstDesc)
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        const index_t tid = get_thread_local_1d_id();

        // surplus threads take no part in the copy
        if(BlockSize > num_active_thread && tid >= num_active_thread)
        {
            return;
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // element distance between the rows a thread touches in consecutive passes
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        // move this thread's vector for the given pass
        auto copy_pass = [&](index_t pass) {
            const vector_t* p_from = reinterpret_cast<const vector_t*>(
                p_src + mSrcMyThreadOffset + pass * src_loop_stride);

            vector_t* p_to =
                reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + pass * dst_loop_stride);

            *p_to = *p_from;
        };

        // passes in which all active threads have a row
        for(index_t pass = 0; pass < nloop_d0; ++pass)
        {
            copy_pass(pass);
        }

        // leftover rows: only the threads mapped to the first tail_d0 rows take part
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(tid < tail_d0 * thread_per_d1)
            {
                copy_pass(nloop_d0);
            }
        }
    }
554

Chao Liu's avatar
Chao Liu committed
555
    // Returns the number of Float elements one thread needs in its clipboard
    // buffer for RunLoadRegisterClipboard / RunStoreRegisterClipboard.
    //
    // Each pass stores DataPerRead elements at clipboard index
    // pass * DataPerRead, and a thread runs at most ceil(L0 / thread_per_d0)
    // passes (the full passes plus one tail pass), so the required size is
    // DataPerRead * ceil(L0 / thread_per_d0).
    __device__ constexpr index_t GetRegisterClipboardSize() const
    {
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // The ceil-division must be evaluated before scaling by DataPerRead.
        // Without the parentheses, `*` and `/` associate left to right, so the
        // product is formed first and the result can exceed what is needed and
        // fail to be a multiple of DataPerRead.
        constexpr index_t pass_per_thread = (L0 + thread_per_d0 - 1) / thread_per_d0;

        return DataPerRead * pass_per_thread;
    }

    // Loads this thread's share of the CopyLengths region from p_src into
    // p_clipboard. Pass number `pass` fills clipboard elements
    // [pass * DataPerRead, pass * DataPerRead + DataPerRead) with one vector_t access.
    //   p_src       : base pointer of the source tensor (laid out per SrcDesc)
    //   p_clipboard : per-thread staging buffer
    //                 (presumably sized by GetRegisterClipboardSize() -- confirm with callers)
    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        const index_t tid = get_thread_local_1d_id();

        // surplus threads take no part in the copy
        if(BlockSize > num_active_thread && tid >= num_active_thread)
        {
            return;
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // element distance between the src rows a thread reads in consecutive passes
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;

        // fetch one vector from src into the clipboard slot of the given pass
        auto load_pass = [&](index_t pass) {
            const vector_t* p_from = reinterpret_cast<const vector_t*>(
                &p_src[mSrcMyThreadOffset + pass * src_loop_stride]);

            vector_t* p_slot = reinterpret_cast<vector_t*>(&p_clipboard[pass * DataPerRead]);

            *p_slot = *p_from;
        };

        // passes in which all active threads have a row
        for(index_t pass = 0; pass < nloop_d0; ++pass)
        {
            load_pass(pass);
        }

        // leftover rows: only the threads mapped to the first tail_d0 rows take part
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(tid < tail_d0 * thread_per_d1)
            {
                load_pass(nloop_d0);
            }
        }
    }

    // Writes this thread's clipboard contents to p_dst; the counterpart of
    // RunLoadRegisterClipboard. Pass number `pass` reads clipboard elements
    // [pass * DataPerRead, pass * DataPerRead + DataPerRead) with one vector_t access.
    //   p_clipboard : per-thread staging buffer holding the data to write
    //   p_dst       : base pointer of the destination tensor (laid out per DstDesc)
    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        const index_t tid = get_thread_local_1d_id();

        // surplus threads take no part in the copy
        if(BlockSize > num_active_thread && tid >= num_active_thread)
        {
            return;
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // element distance between the dst rows a thread writes in consecutive passes
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        // flush the clipboard slot of the given pass to dst
        auto store_pass = [&](index_t pass) {
            const vector_t* p_slot =
                reinterpret_cast<const vector_t*>(&p_clipboard[pass * DataPerRead]);

            vector_t* p_to =
                reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + pass * dst_loop_stride]);

            *p_to = *p_slot;
        };

        // passes in which all active threads have a row
        for(index_t pass = 0; pass < nloop_d0; ++pass)
        {
            store_pass(pass);
        }

        // leftover rows: only the threads mapped to the first tail_d0 rows take part
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(tid < tail_d0 * thread_per_d1)
            {
                store_pass(nloop_d0);
            }
        }
    }

#if DEVICE_BACKEND_HIP
    // Load this thread's share of the 2D copy window from p_src into p_clipboard,
    // issuing each DataPerRead-wide read through the global_load helper instead of
    // a plain vector load (the disabled branch below shows the plain equivalent).
    //
    // Thread layout (all compile-time):
    //   - thread_per_d1 threads cover dimension 1 of one row, DataPerRead elements
    //     per thread (rounded up);
    //   - thread_per_d0 = BlockSize / thread_per_d1 rows are handled per pass;
    //   - threads with id >= thread_per_d0 * thread_per_d1 do nothing.
    //
    // p_src       : source buffer; read at mSrcMyThreadOffset + pass * src_loop_stride.
    //               mSrcMyThreadOffset is a member initialized outside this method.
    // p_clipboard : per-thread staging buffer; pass `i` fills elements
    //               [i * DataPerRead, (i + 1) * DataPerRead). It must therefore hold
    //               at least (nloop_d0 + 1) * DataPerRead elements when L0 is not a
    //               multiple of thread_per_d0, and nloop_d0 * DataPerRead otherwise.
    //
    // Only instantiable for Float == float and DataPerRead == 4 (enforced below).
    // NOTE(review): both addresses are reinterpreted as vector_t, so they are
    // presumably required to be vector_t-aligned -- confirm against callers.
    __device__ void RunLoadRegisterClipboard_asm(const Float* __restrict__ p_src,
                                                 Float* p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // ceil(L1 / DataPerRead) threads are needed to span one row
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // BlockSize may not be a multiple of thread_per_d1; surplus threads exit early
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        // number of full passes, each covering thread_per_d0 rows
        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // distance in the source between the rows a thread reads in consecutive passes
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;

        // one DataPerRead-wide read for pass `iloop` into clipboard slot `iloop`
        auto f_copy = [&](index_t iloop) {
#if 0
            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
                *(reinterpret_cast<const vector_t*>(&p_src[mSrcMyThreadOffset +
                                                    iloop * src_loop_stride]));
#else
            static_assert(is_same<float, Float>::value && DataPerRead == 4,
                          "global_load is only for float4");

            global_load(reinterpret_cast<vector_t&>(p_clipboard[iloop * DataPerRead]),
                        reinterpret_cast<const vector_t*>(
                            &p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
#endif
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // rows left over when L0 is not a multiple of thread_per_d0
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            // only the threads mapped onto the remaining tail_d0 rows take part
            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    // Write this thread's staged data from p_clipboard out to p_dst, issuing each
    // DataPerRead-wide write through the ds_write_b128 helper instead of a plain
    // vector store (the disabled branch below shows the plain equivalent).
    //
    // Thread layout (all compile-time, mirrors the matching load):
    //   - thread_per_d1 threads cover dimension 1 of one row, DataPerRead elements
    //     per thread (rounded up);
    //   - thread_per_d0 = BlockSize / thread_per_d1 rows are handled per pass;
    //   - threads with id >= thread_per_d0 * thread_per_d1 do nothing.
    //
    // p_clipboard : per-thread staging buffer; pass `i` reads elements
    //               [i * DataPerRead, (i + 1) * DataPerRead).
    // p_dst       : destination buffer; written at
    //               mDstMyThreadOffset + pass * dst_loop_stride. mDstMyThreadOffset
    //               is a member initialized outside this method.
    //
    // Only instantiable for Float == float and DataPerRead == 4 (enforced below).
    // NOTE(review): ds_write_b128 is defined outside this view; by its name it
    // presumably targets LDS, which would require p_dst to be an LDS address --
    // confirm against the helper's definition and the callers.
    __device__ void RunStoreRegisterClipboard_asm(const Float* __restrict__ p_clipboard,
                                                  Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // ceil(L1 / DataPerRead) threads are needed to span one row
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // BlockSize may not be a multiple of thread_per_d1; surplus threads exit early
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        // number of full passes, each covering thread_per_d0 rows
        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // distance in the destination between the rows a thread writes in consecutive passes
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        // one DataPerRead-wide write for pass `iloop` from clipboard slot `iloop`
        auto f_copy = [&](index_t iloop) {
#if 0
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
#else
            static_assert(is_same<float, Float>::value && DataPerRead == 4,
                          "ds_write_b128 is only for float4");

            ds_write_b128(reinterpret_cast<const vector_t&>(p_clipboard[iloop * DataPerRead]),
                          &p_dst[mDstMyThreadOffset + iloop * dst_loop_stride]);
#endif
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // rows left over when L0 is not a multiple of thread_per_d0
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            // only the threads mapped onto the remaining tail_d0 rows take part
            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }
790
#endif
Chao Liu's avatar
Chao Liu committed
791
};