blockwise_2d_tensor_op.hip.hpp 26.5 KB
Newer Older
1
#pragma once
Chao Liu's avatar
Chao Liu committed
2
#include "common.hip.hpp"
3
#include "ConstantTensorDescriptor.hip.hpp"
4

Chao Liu's avatar
Chao Liu committed
5
template <index_t BlockSize, class Float, class DstDesc, class F>
6
__device__ void
Chao Liu's avatar
Chao Liu committed
7
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
Chao Liu's avatar
Chao Liu committed
8
{
Chao Liu's avatar
Chao Liu committed
9
10
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
11

12
13
    constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
14
    constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());
Chao Liu's avatar
Chao Liu committed
15

16
#if 0
17
    if(get_thread_local_1d_id() == 0)
18
    {
Chao Liu's avatar
Chao Liu committed
19
20
        print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
        print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
21
22
23
    }
#endif

Chao Liu's avatar
Chao Liu committed
24
    constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
Chao Liu's avatar
Chao Liu committed
25

Chao Liu's avatar
Chao Liu committed
26
    for(index_t iloop = 0; iloop < NLoop; ++iloop)
Chao Liu's avatar
Chao Liu committed
27
    {
28
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;
Chao Liu's avatar
Chao Liu committed
29

Chao Liu's avatar
Chao Liu committed
30
        const index_t did0 = is / desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
31
32
33

        is -= did0 * desc.GetStride(I0);

Chao Liu's avatar
Chao Liu committed
34
        const index_t did1 = is / desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
35

Chao Liu's avatar
Chao Liu committed
36
        const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
37

Chao Liu's avatar
Chao Liu committed
38
        f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
39
40
41
42
43
44
    }

    constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
45
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
Chao Liu's avatar
Chao Liu committed
46
47
48

        if(is < desc.GetElementSize())
        {
Chao Liu's avatar
Chao Liu committed
49
            const index_t did0 = is / desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
50
51
52

            is -= did0 * desc.GetStride(I0);

Chao Liu's avatar
Chao Liu committed
53
            const index_t did1 = is / desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
54

Chao Liu's avatar
Chao Liu committed
55
            const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
56

Chao Liu's avatar
Chao Liu committed
57
            f(p_dst[dindex]);
Chao Liu's avatar
Chao Liu committed
58
59
60
        }
    }
}
Chao Liu's avatar
Chao Liu committed
61

62
// Function: p_dst[reorder[i0], reorder[i1]] = p_src[i0, i1]
63
64
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
Chao Liu's avatar
Chao Liu committed
65
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
66
          class Float,
67
68
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
69
          class SrcOpLengths,
70
          class MapDst2Src,
Chao Liu's avatar
Chao Liu committed
71
          class F>
Chao Liu's avatar
Chao Liu committed
72
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
Chao Liu's avatar
Chao Liu committed
73
    SrcDesc,
Chao Liu's avatar
Chao Liu committed
74
    const Float* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
75
76
77
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
78
    MapDst2Src,
Chao Liu's avatar
Chao Liu committed
79
    F f)
Chao Liu's avatar
Chao Liu committed
80
{
Chao Liu's avatar
Chao Liu committed
81
82
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
Chao Liu's avatar
Chao Liu committed
83

84
85
    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
Chao Liu's avatar
Chao Liu committed
86

87
88
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
Chao Liu's avatar
Chao Liu committed
89
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
Chao Liu's avatar
Chao Liu committed
90

Chao Liu's avatar
Chao Liu committed
91
    constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
Chao Liu's avatar
Chao Liu committed
92

Chao Liu's avatar
Chao Liu committed
93
    for(index_t iloop = 0; iloop < NLoop; ++iloop)
Chao Liu's avatar
Chao Liu committed
94
    {
95
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;
Chao Liu's avatar
Chao Liu committed
96

Chao Liu's avatar
Chao Liu committed
97
        index_t did[2];
Chao Liu's avatar
Chao Liu committed
98

99
        did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
100

101
        is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
102

103
        did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
104

Chao Liu's avatar
Chao Liu committed
105
        const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
Chao Liu's avatar
Chao Liu committed
106

Chao Liu's avatar
Chao Liu committed
107
        const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
108
109

        f(p_src[aindex], p_dst[bindex]);
Chao Liu's avatar
Chao Liu committed
110
111
    }

112
    constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
Chao Liu's avatar
Chao Liu committed
113
114
115

    if(has_tail)
    {
116
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
Chao Liu's avatar
Chao Liu committed
117

118
        if(is < ref_desc.GetElementSize())
Chao Liu's avatar
Chao Liu committed
119
        {
Chao Liu's avatar
Chao Liu committed
120
            index_t did[2];
121
122

            did[0] = is / ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
123

124
            is -= did[0] * ref_desc.GetStride(I0);
Chao Liu's avatar
Chao Liu committed
125

126
            did[1] = is / ref_desc.GetStride(I1);
Chao Liu's avatar
Chao Liu committed
127

Chao Liu's avatar
Chao Liu committed
128
            const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
129

Chao Liu's avatar
Chao Liu committed
130
            const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
131

132
            f(p_src[aindex], p_dst[bindex]);
133
134
135
136
        }
    }
}

Chao Liu's avatar
Chao Liu committed
137
template <index_t BlockSize, class Float, class DstDesc>
Chao Liu's avatar
Chao Liu committed
138
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
139
{
Chao Liu's avatar
Chao Liu committed
140
    auto f_set_zero = [](Float& v) { v = Float(0); };
Chao Liu's avatar
Chao Liu committed
141

Chao Liu's avatar
Chao Liu committed
142
    blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
Chao Liu's avatar
Chao Liu committed
143
}
144

Chao Liu's avatar
Chao Liu committed
145
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
146
          class Float,
147
148
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
149
          class SrcOpLengths,
150
          class MapDst2Src>
Chao Liu's avatar
Chao Liu committed
151
__device__ void
Chao Liu's avatar
Chao Liu committed
152
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
Chao Liu's avatar
Chao Liu committed
153
                                                     const Float* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
154
155
156
                                                     DstDesc,
                                                     Float* __restrict__ p_dst,
                                                     SrcOpLengths,
157
                                                     MapDst2Src)
158
{
Chao Liu's avatar
Chao Liu committed
159
    auto f_copy = [](const Float& src, Float& dst) { dst = src; };
160

Chao Liu's avatar
Chao Liu committed
161
    blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
162
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
163
164
}

165
166
167
168
169
170
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
171
struct Blockwise2dTensorCopy1
Chao Liu's avatar
Chao Liu committed
172
{
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    __device__ constexpr Blockwise2dTensorCopy1()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride2 should be multiple of DataPerRead to keep alignment");

        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride0 is big enough,
        //   so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t L1          = CopyLengths{}.Get(I1);
195
        constexpr index_t read_per_d1 = mod_conv::integer_divide_ceil(L1, DataPerRead);
196
197
198
199
200

        static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");
    }

Chao Liu's avatar
Chao Liu committed
201
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
202
    {
203
204
205
206
207
208
209
210
211
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

212
        constexpr index_t read_per_d1 = mod_conv::integer_divide_ceil(L1, DataPerRead);
213

214
        constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});
215
216
217
218
219
220
221
222
223
224
225
226

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        auto f_copy = [&](index_t is) {
            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);

            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);

227
228
            const index_t src_index = src_desc.Get1dIndex(did[0], did[1] * DataPerRead);
            const index_t dst_index = dst_desc.Get1dIndex(did[0], did[1] * DataPerRead);
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251

            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                *(reinterpret_cast<const vector_t*>(p_src + src_index));
        };

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;

            f_copy(is);
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                f_copy(is);
            }
        }
252
253
254
    }
};

255
256
// need to be aligned to float4 and float2
// stride1 need to be 1 for both source and destination
Chao Liu's avatar
Chao Liu committed
257
template <index_t BlockSize,
258
259
260
261
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
Chao Liu's avatar
Chao Liu committed
262
263
          index_t ThreadPerDim0,
          index_t ThreadPerDim1>
264
struct Blockwise2dTensorCopy2
265
{
Chao Liu's avatar
Chao Liu committed
266
267
    index_t mThreadId0;
    index_t mThreadId1;
268

269
    __device__ Blockwise2dTensorCopy2()
270
    {
271
272
273
274
275
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
                      "wrong! stride is not 1!\n");
Chao Liu's avatar
Chao Liu committed
276

277
278
279
280
        mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
        mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
    }

Chao Liu's avatar
Chao Liu committed
281
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
282
    {
283
284
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

Chao Liu's avatar
Chao Liu committed
285
286
287
        using Float4 = float4;
        using Float2 = float2;

288
289
290
291
292
293
294
295
296
        if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
            return;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

Chao Liu's avatar
Chao Liu committed
297
298
299
300
301
302
303
        // check alignment
        constexpr bool align_v4 =
            src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;

        constexpr bool align_v2 =
            src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;

Chao Liu's avatar
Chao Liu committed
304
305
        constexpr index_t L0 = SrcOpLengths{}.Get(I0);
        constexpr index_t L1 = SrcOpLengths{}.Get(I1);
306

Chao Liu's avatar
Chao Liu committed
307
308
        constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
        constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
309

Chao Liu's avatar
Chao Liu committed
310
        constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
Chao Liu's avatar
Chao Liu committed
311

Chao Liu's avatar
Chao Liu committed
312
        constexpr index_t Dim1V2Loop =
Chao Liu's avatar
Chao Liu committed
313
314
            align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;

Chao Liu's avatar
Chao Liu committed
315
        constexpr index_t Dim1V1Loop =
316
317
            (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
            ThreadPerDim1;
Chao Liu's avatar
Chao Liu committed
318

319
320
321
        constexpr bool d1_has_tail =
            (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));

Chao Liu's avatar
Chao Liu committed
322
        for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
323
        {
Chao Liu's avatar
Chao Liu committed
324
            index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;
325
326

            // v4
Chao Liu's avatar
Chao Liu committed
327
            for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
328
            {
Chao Liu's avatar
Chao Liu committed
329
                index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
330

Chao Liu's avatar
Chao Liu committed
331
332
                const index_t sindex = src_desc.Get1dIndex(did0, did1);
                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
333

Chao Liu's avatar
Chao Liu committed
334
                *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
335
                    *(reinterpret_cast<const Float4*>(p_src + sindex));
336
337
338
            }

            // v2
Chao Liu's avatar
Chao Liu committed
339
            for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
340
            {
Chao Liu's avatar
Chao Liu committed
341
                index_t did1 =
342
343
                    Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;

Chao Liu's avatar
Chao Liu committed
344
345
                const index_t sindex = src_desc.Get1dIndex(did0, did1);
                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
346

Chao Liu's avatar
Chao Liu committed
347
                *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
348
                    *(reinterpret_cast<const Float2*>(p_src + sindex));
349
350
351
            }

            // v1
Chao Liu's avatar
Chao Liu committed
352
            for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
353
            {
Chao Liu's avatar
Chao Liu committed
354
355
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               d1v1loop * ThreadPerDim1 + mThreadId1;
356

Chao Liu's avatar
Chao Liu committed
357
358
                const index_t sindex = src_desc.Get1dIndex(did0, did1);
                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
359
360
361
362
363
364
365

                p_dst[dindex] = p_src[sindex];
            }

            // dim-1 tail
            if(d1_has_tail)
            {
Chao Liu's avatar
Chao Liu committed
366
367
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               Dim1V1Loop * ThreadPerDim1 + mThreadId1;
368
369
370

                if(did1 < L1)
                {
Chao Liu's avatar
Chao Liu committed
371
372
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
373
374
375
376
377
378
379
380
381

                    p_dst[dindex] = p_src[sindex];
                }
            }
        }

        // dim-0 tail
        if(d0_has_tail)
        {
Chao Liu's avatar
Chao Liu committed
382
            index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
383
384
385
386
387

            if(did0 < L0)
            {

                // v4
Chao Liu's avatar
Chao Liu committed
388
                for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
389
                {
Chao Liu's avatar
Chao Liu committed
390
                    index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
391

Chao Liu's avatar
Chao Liu committed
392
393
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
394

Chao Liu's avatar
Chao Liu committed
395
                    *(reinterpret_cast<Float4*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
396
                        *(reinterpret_cast<const Float4*>(p_src + sindex));
397
398
399
                }

                // v2
Chao Liu's avatar
Chao Liu committed
400
                for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
401
                {
Chao Liu's avatar
Chao Liu committed
402
403
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
                                   2 * mThreadId1;
404

Chao Liu's avatar
Chao Liu committed
405
406
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
Chao Liu's avatar
Chao Liu committed
407

Chao Liu's avatar
Chao Liu committed
408
                    *(reinterpret_cast<Float2*>(p_dst + dindex)) =
Chao Liu's avatar
Chao Liu committed
409
                        *(reinterpret_cast<const Float2*>(p_src + sindex));
410
411
412
                }

                // v1
Chao Liu's avatar
Chao Liu committed
413
                for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
414
                {
Chao Liu's avatar
Chao Liu committed
415
416
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                   d1v1loop * ThreadPerDim1 + mThreadId1;
417

Chao Liu's avatar
Chao Liu committed
418
419
                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
420
421
422
423
424
425
426

                    p_dst[dindex] = p_src[sindex];
                }

                // tail
                if(d1_has_tail)
                {
Chao Liu's avatar
Chao Liu committed
427
428
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                   Dim1V1Loop * ThreadPerDim1 + mThreadId1;
429
430
431

                    if(did1 < L1)
                    {
Chao Liu's avatar
Chao Liu committed
432
433
                        const index_t sindex = src_desc.Get1dIndex(did0, did1);
                        const index_t dindex = dst_desc.Get1dIndex(did0, did1);
434
435
436
437
438
439
440
441

                        p_dst[dindex] = p_src[sindex];
                    }
                }
            }
        }
    }
};
Chao Liu's avatar
Chao Liu committed
442

443
444
// starting point need to be aligned to float4 or float2 or float
// stride1 need to be 1 for both source and destination
Chao Liu's avatar
Chao Liu committed
445
template <index_t BlockSize,
446
447
448
          class Float,
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
449
          class CopyLengths,
Chao Liu's avatar
Chao Liu committed
450
          index_t DataPerRead>
451
struct Blockwise2dTensorCopy3
Chao Liu's avatar
Chao Liu committed
452
{
453
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
Chao Liu's avatar
Chao Liu committed
454

Chao Liu's avatar
Chao Liu committed
455
456
    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;
Chao Liu's avatar
Chao Liu committed
457

458
    // Check the compile-time preconditions and record this thread's starting offsets
    // into the source and destination. Threads with id >= num_active_thread return
    // before the offsets are written.
    __device__ Blockwise2dTensorCopy3()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride should be multiple of DataPerRead to keep alignment");

        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // threads needed to cover one line, and lines covered per pass
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride is big enough,
        //   so that the out-of-bound write won't contaminate next line in dst
        static_assert(thread_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");

        static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        const index_t tid = get_thread_local_1d_id();

        // surplus threads keep their offsets unset
        if(BlockSize > num_active_thread && tid >= num_active_thread)
        {
            return;
        }

        const index_t line   = tid / thread_per_d1;
        const index_t vec_id = tid - line * thread_per_d1;

        mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(line, vec_id * DataPerRead);
        mDstMyThreadOffset = DstDesc{}.Get1dIndex(line, vec_id * DataPerRead);
    }

Chao Liu's avatar
Chao Liu committed
504
    // Copy the region from p_src to p_dst. Each active thread moves one vector_t per
    // pass, starting at its stored offsets and advancing by thread_per_d0 lines
    // between passes.
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        const index_t tid = get_thread_local_1d_id();

        // surplus threads take no part in the copy
        if(BlockSize > num_active_thread && tid >= num_active_thread)
        {
            return;
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // element distance between the lines one thread handles in consecutive passes
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        const Float* const src_base = p_src + mSrcMyThreadOffset;
        Float* const dst_base       = p_dst + mDstMyThreadOffset;

        auto copy_pass = [&](index_t pass) {
            *(reinterpret_cast<vector_t*>(dst_base + pass * dst_loop_stride)) =
                *(reinterpret_cast<const vector_t*>(src_base + pass * src_loop_stride));
        };

        for(index_t pass = 0; pass < nloop_d0; ++pass)
        {
            copy_pass(pass);
        }

        // leftover lines when L0 is not a multiple of thread_per_d0
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(tid < tail_d0 * thread_per_d1)
            {
                copy_pass(nloop_d0);
            }
        }
    }
553

Chao Liu's avatar
Chao Liu committed
554
    // Number of Float elements a thread's clipboard buffer must hold for
    // RunLoadRegisterClipboard / RunStoreRegisterClipboard: DataPerRead elements for
    // each of the ceil(L0 / thread_per_d0) passes over dimension 0.
    __device__ constexpr index_t GetRegisterClipboardSize() const
    {
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // Do the ceil-division first. Without the parentheses, * and / associate left
        // to right and the result becomes (DataPerRead * (L0 + t - 1)) / t, which can
        // exceed the required size (e.g. DataPerRead 4, L0 2, t 4 gives 5, not 4).
        return DataPerRead * ((L0 + thread_per_d0 - 1) / thread_per_d0);
    }

    // Load this thread's share of the region from p_src into p_clipboard. Pass k
    // writes one vector_t at p_clipboard[k * DataPerRead].
    // NOTE(review): p_clipboard is accessed through vector_t*, so it presumably has to
    //   be aligned for vector_t - confirm at the call sites.
    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // surplus threads take no part in the copy
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // element distance between the source lines one thread reads in consecutive
        // passes (the destination stride is not needed for a load)
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
                *(reinterpret_cast<const vector_t*>(
                    &p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // leftover lines when L0 is not a multiple of thread_per_d0
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    // Store this thread's share of the region from p_clipboard to p_dst. Pass k
    // reads one vector_t from p_clipboard[k * DataPerRead], the same layout that
    // RunLoadRegisterClipboard writes.
    // NOTE(review): p_clipboard is accessed through vector_t*, so it presumably has to
    //   be aligned for vector_t - confirm at the call sites.
    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // surplus threads take no part in the copy
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // element distance between the destination lines one thread writes in
        // consecutive passes (the source stride is not needed for a store)
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // leftover lines when L0 is not a multiple of thread_per_d0
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

#if DEVICE_BACKEND_HIP
    // Loads this thread's share of the 2-D copy from p_src into p_clipboard, using the
    // global_load helper instead of a plain vectorized assignment (HIP backend only).
    //
    // p_src:       source buffer; iteration iloop reads one vector_t at
    //              mSrcMyThreadOffset + iloop * src_loop_stride.
    // p_clipboard: per-thread staging buffer; iteration iloop fills DataPerRead
    //              consecutive Floats starting at iloop * DataPerRead.
    //
    // Only instantiable with Float == float and DataPerRead == 4 (see static_assert below).
    // NOTE(review): going by the helper and method names, p_src presumably lives in global
    // memory and p_clipboard in registers -- confirm against the callers.
    __device__ void RunLoadRegisterClipboard_asm(const Float* __restrict__ p_src,
                                                 Float* p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // number of threads needed to cover dimension 1, DataPerRead elements per thread
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;

        // number of dimension-0 rows the block can cover in one pass
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // thread_per_d0 is used as a divisor below; fail with a readable message instead of
        // a constexpr division-by-zero diagnostic
        static_assert(thread_per_d0 > 0,
                      "BlockSize is too small to cover one row of CopyLengths");

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // threads beyond the last full row of the thread layout take no part in the copy
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        // number of passes in which every active thread has a row to copy
        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // source offset advanced per pass: thread_per_d0 rows of the source descriptor
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
#if 0
            // reference implementation: plain vectorized load
            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
                *(reinterpret_cast<const vector_t*>(&p_src[mSrcMyThreadOffset +
                                                    iloop * src_loop_stride]));
#else
            static_assert(is_same<float, Float>::value && DataPerRead == 4,
                          "global_load is only for float4");

            global_load(reinterpret_cast<vector_t&>(p_clipboard[iloop * DataPerRead]),
                        reinterpret_cast<const vector_t*>(
                            &p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
#endif
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // rows left over when L0 is not a multiple of thread_per_d0
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            // only the threads mapped onto the remaining tail_d0 rows do one extra copy
            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    // Stores this thread's share of the 2-D copy from p_clipboard into p_dst, using the
    // ds_write_b128 helper instead of a plain vectorized assignment (HIP backend only).
    //
    // p_clipboard: per-thread staging buffer; iteration iloop reads DataPerRead
    //              consecutive Floats starting at iloop * DataPerRead.
    // p_dst:       destination buffer; iteration iloop writes at
    //              mDstMyThreadOffset + iloop * dst_loop_stride.
    //
    // Only instantiable with Float == float and DataPerRead == 4 (see static_assert below).
    // NOTE(review): going by the helper and method names, p_dst presumably lives in LDS
    // (shared memory) and p_clipboard in registers -- confirm against the callers.
    __device__ void RunStoreRegisterClipboard_asm(const Float* __restrict__ p_clipboard,
                                                  Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        // number of threads needed to cover dimension 1, DataPerRead elements per thread
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;

        // number of dimension-0 rows the block can cover in one pass
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // thread_per_d0 is used as a divisor below; fail with a readable message instead of
        // a constexpr division-by-zero diagnostic
        static_assert(thread_per_d0 > 0,
                      "BlockSize is too small to cover one row of CopyLengths");

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // threads beyond the last full row of the thread layout take no part in the copy
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        // number of passes in which every active thread has a row to copy
        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // destination offset advanced per pass: thread_per_d0 rows of the destination descriptor
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
#if 0
            // reference implementation: plain vectorized store
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
#else
            static_assert(is_same<float, Float>::value && DataPerRead == 4,
                          "ds_write_b128 is only for float4");

            ds_write_b128(reinterpret_cast<const vector_t&>(p_clipboard[iloop * DataPerRead]),
                          &p_dst[mDstMyThreadOffset + iloop * dst_loop_stride]);
#endif
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // rows left over when L0 is not a multiple of thread_per_d0
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            // only the threads mapped onto the remaining tail_d0 rows do one extra copy
            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }
789
#endif
Chao Liu's avatar
Chao Liu committed
790
};