blockwise_2d_tensor_op.hpp 27.2 KB
Newer Older
1
2
3
#ifndef CK_BLOCKWISE_2D_TENSOR_OP_HPP
#define CK_BLOCKWISE_2D_TENSOR_OP_HPP

Chao Liu's avatar
Chao Liu committed
4
5
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
6

7
8
namespace ck {

Chao Liu's avatar
Chao Liu committed
9
// Apply the unary functor `f` in place to every element of the 2D tensor described by
// DstDesc, cooperatively across the threads of the block.
//
// Elements are enumerated through a packed descriptor built from DstDesc's lengths. Each
// linear position is decoded back into (row, col) and translated through DstDesc's own
// strides, so DstDesc itself does not need to be packed.
// Thread `tid` visits linear positions tid, tid + BlockSize, tid + 2 * BlockSize, ...;
// positions past the element count are skipped in the final, partial round.
//
// Template parameters:
//   BlockSize - number of threads assumed to call this function together.
//   F         - callable invoked as f(Float&) on each element.
template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr auto dst_desc    = DstDesc{};
    constexpr auto packed_desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());

    constexpr index_t element_count = packed_desc.GetElementSize();
    constexpr index_t full_passes   = element_count / BlockSize;

    // decode a packed linear position into (row, col) and apply f to that element of p_dst
    auto visit = [&](index_t linear_id) {
        const index_t row = linear_id / packed_desc.GetStride(I0);
        const index_t col =
            (linear_id - row * packed_desc.GetStride(I0)) / packed_desc.GetStride(I1);

        f(p_dst[dst_desc.GetOffsetFromMultiIndex(row, col)]);
    };

    for(index_t pass = 0; pass < full_passes; ++pass)
    {
        visit(get_thread_local_1d_id() + pass * BlockSize);
    }

    // leftover elements that do not fill a whole pass of BlockSize threads
    if(element_count > full_passes * BlockSize)
    {
        const index_t linear_id = get_thread_local_1d_id() + full_passes * BlockSize;

        if(linear_id < element_count)
        {
            visit(linear_id);
        }
    }
}
Chao Liu's avatar
Chao Liu committed
65

66
// Function: p_dst[reorder[i0], reorder[i1]] = p_src[i0, i1]
// More precisely, for every (i0, i1) inside SrcOpLengths, with idx = {i0, i1}:
//   f(p_src[i0, i1], p_dst[idx[MapDst2Src[0]], idx[MapDst2Src[1]]])
// i.e. destination dimension d takes its index from source dimension MapDst2Src[d].
// Work is dealt out to the BlockSize threads round-robin over the row-major linear index
// of SrcOpLengths; the final, partial round is bounds-checked.
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src,
          class F>
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
    SrcDesc,
    const Float* __restrict__ p_src,
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
    MapDst2Src,
    F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    // source dimension that feeds destination dimension 0 / 1
    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    constexpr index_t element_count = ref_desc.GetElementSize();
    constexpr index_t full_passes   = element_count / BlockSize;

    // decode a linear position over SrcOpLengths and apply f to the matching src/dst pair
    auto apply = [&](index_t linear_id) {
        index_t src_multi_id[2];

        src_multi_id[0] = linear_id / ref_desc.GetStride(I0);
        src_multi_id[1] =
            (linear_id - src_multi_id[0] * ref_desc.GetStride(I0)) / ref_desc.GetStride(I1);

        const index_t src_offset =
            src_desc.GetOffsetFromMultiIndex(src_multi_id[0], src_multi_id[1]);
        const index_t dst_offset =
            dst_desc.GetOffsetFromMultiIndex(src_multi_id[IR0], src_multi_id[IR1]);

        f(p_src[src_offset], p_dst[dst_offset]);
    };

    for(index_t pass = 0; pass < full_passes; ++pass)
    {
        apply(get_thread_local_1d_id() + pass * BlockSize);
    }

    // leftover elements that do not fill a whole pass of BlockSize threads
    if(element_count > full_passes * BlockSize)
    {
        const index_t linear_id = get_thread_local_1d_id() + full_passes * BlockSize;

        if(linear_id < element_count)
        {
            apply(linear_id);
        }
    }
}

Chao Liu's avatar
Chao Liu committed
141
// Set every element of the 2D tensor described by DstDesc to Float(0), cooperatively across
// the BlockSize threads of the block (delegates to the blockwise unary pointwise operation).
template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
    blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(
        DstDesc{}, p_dst, [](Float& v) { v = Float(0); });
}
148

Chao Liu's avatar
Chao Liu committed
149
// Blockwise copy with a dimension permutation: for every (i0, i1) in SrcOpLengths, with
// idx = {i0, i1}, p_dst[idx[MapDst2Src[0]], idx[MapDst2Src[1]]] = p_src[i0, i1].
// Implemented via the binary reorder pointwise operation with a plain assignment functor.
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                     const Float* __restrict__ p_src,
                                                     DstDesc,
                                                     Float* __restrict__ p_dst,
                                                     SrcOpLengths,
                                                     MapDst2Src)
{
    blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
        SrcDesc{},
        p_src,
        DstDesc{},
        p_dst,
        SrcOpLengths{},
        MapDst2Src{},
        [](const Float& src, Float& dst) { dst = src; });
}

169
170
171
172
173
174
// Blockwise copy of a CopyLengths-sized 2D region from p_src (layout SrcDesc) to p_dst
// (layout DstDesc), moving DataPerRead contiguous elements per access.
//
// Each row is split into read_per_d1 = ceil(L1 / DataPerRead) vector accesses. All
// (row, vector) pairs are enumerated row-major and dealt out to the BlockSize threads
// round-robin; the final, partial round is bounds-checked.
// When L1 is not a multiple of DataPerRead, the last vector of a row reads past L1 in src
// and writes past L1 in dst. The constructor static_asserts that dst's row stride leaves
// room for this, so the next dst row is not overwritten.
//
// Compile-time preconditions (checked in the constructor):
//   - DataPerRead is 1, 2 or 4;
//   - if DataPerRead > 1, both descriptors have unit stride in dimension 1;
//   - both dimension-0 strides are multiples of DataPerRead.
// NOTE(review): vector accesses also assume p_src / p_dst are themselves aligned to
// DataPerRead elements; that is not checked here.
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
struct Blockwise2dTensorCopy1
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    // Only performs compile-time checks; there is no runtime state.
    __device__ constexpr Blockwise2dTensorCopy1()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        // a 2D tensor has strides 0 and 1; the row (dimension-0) stride is checked here
        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride0 should be multiple of DataPerRead to keep alignment");

        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride0 is big enough,
        //   so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t L1          = CopyLengths{}.Get(I1);
        constexpr index_t read_per_d1 = math::integer_divide_ceil(L1, DataPerRead);

        static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");
    }

    // Copy the region. Thread `tid` handles vector ids tid, tid + BlockSize, ... over the
    // row-major (L0 x read_per_d1) vector grid.
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t read_per_d1 = math::integer_divide_ceil(L1, DataPerRead);

        // one "element" of ref_desc is one DataPerRead-wide vector access
        constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        auto f_copy = [&](index_t is) {
            // 2D tensor: only two index components are needed
            index_t did[2];

            did[0] = is / ref_desc.GetStride(I0);

            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);

            // column index in elements = vector index * vector width
            const index_t src_index =
                src_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);
            const index_t dst_index =
                dst_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);

            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                *(reinterpret_cast<const vector_t*>(p_src + src_index));
        };

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;

            f_copy(is);
        }

        // leftover vectors that do not fill a whole round of BlockSize threads
        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                f_copy(is);
            }
        }
    }
};

261
262
// need to be aligned to float4 and float2
// stride1 need to be 1 for both source and destination
//
// Blockwise copy of a SrcOpLengths-sized 2D float region using a fixed
// ThreadPerDim0 x ThreadPerDim1 thread grid. Thread (t0, t1) copies rows
// t0, t0 + ThreadPerDim0, ... For each row, the columns are covered first with float4
// accesses (if both row strides are multiples of 4), then float2 (if multiples of 2),
// then scalar accesses, and finally a bounds-checked scalar tail.
// Threads with id >= ThreadPerDim0 * ThreadPerDim1 do nothing.
// NOTE(review): the float4/float2 accesses also assume p_src / p_dst are 16-/8-byte
// aligned; only the row strides are checked.
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          index_t ThreadPerDim0,
          index_t ThreadPerDim1>
struct Blockwise2dTensorCopy2
{
    index_t mThreadId0;
    index_t mThreadId1;

    __device__ Blockwise2dTensorCopy2()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
                      "wrong! stride is not 1!\n");

        // Threads with id >= BlockSize do not exist, so a larger thread grid would leave
        // some rows/columns of the region uncopied without any error.
        static_assert(ThreadPerDim0 * ThreadPerDim1 <= BlockSize,
                      "wrong! ThreadPerDim0 * ThreadPerDim1 > BlockSize, part of the tensor "
                      "would never be copied!\n");

        mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
        mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

        if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
            return;

        constexpr auto I0 = Number<0>{};

        constexpr index_t L0 = SrcOpLengths{}.Get(I0);

        constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
        constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);

        for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
        {
            CopyRow(d0loop * ThreadPerDim0 + mThreadId0, p_src, p_dst);
        }

        // dim-0 tail: rows that do not fill a whole pass of ThreadPerDim0 threads
        if(d0_has_tail)
        {
            const index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;

            if(did0 < L0)
            {
                CopyRow(did0, p_src, p_dst);
            }
        }
    }

    private:
    // Copy this thread's columns of row `did0`: float4, then float2, then scalar passes,
    // then a bounds-checked scalar tail for the remaining < ThreadPerDim1 columns.
    __device__ void
    CopyRow(index_t did0, const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        using Float4 = float4;
        using Float2 = float2;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        // check alignment
        constexpr bool align_v4 =
            src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;

        constexpr bool align_v2 =
            src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;

        constexpr index_t L1 = SrcOpLengths{}.Get(I1);

        constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;

        constexpr index_t Dim1V2Loop =
            align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;

        constexpr index_t Dim1V1Loop =
            (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
            ThreadPerDim1;

        constexpr bool d1_has_tail =
            (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));

        // v4
        for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
        {
            const index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;

            const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
            const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

            *(reinterpret_cast<Float4*>(p_dst + dindex)) =
                *(reinterpret_cast<const Float4*>(p_src + sindex));
        }

        // v2
        for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
        {
            const index_t did1 =
                Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;

            const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
            const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

            *(reinterpret_cast<Float2*>(p_dst + dindex)) =
                *(reinterpret_cast<const Float2*>(p_src + sindex));
        }

        // v1
        for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
        {
            const index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                 d1v1loop * ThreadPerDim1 + mThreadId1;

            const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
            const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

            p_dst[dindex] = p_src[sindex];
        }

        // dim-1 tail
        if(d1_has_tail)
        {
            const index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                                 Dim1V1Loop * ThreadPerDim1 + mThreadId1;

            if(did1 < L1)
            {
                const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                p_dst[dindex] = p_src[sindex];
            }
        }
    }
};
Chao Liu's avatar
Chao Liu committed
448

449
450
// starting point need to be aligned to float4 or float2 or float
// stride1 need to be 1 for both source and destination
Chao Liu's avatar
Chao Liu committed
451
template <index_t BlockSize,
452
453
454
          class Float,
          class SrcDesc,
          class DstDesc,
Chao Liu's avatar
Chao Liu committed
455
          class CopyLengths,
Chao Liu's avatar
Chao Liu committed
456
          index_t DataPerRead>
457
struct Blockwise2dTensorCopy3
Chao Liu's avatar
Chao Liu committed
458
{
459
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
Chao Liu's avatar
Chao Liu committed
460

Chao Liu's avatar
Chao Liu committed
461
462
    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;
Chao Liu's avatar
Chao Liu committed
463

Chao Liu's avatar
Chao Liu committed
464
465
    // Precompute this thread's starting offsets into src and dst.
    //
    // Thread layout: each row of CopyLengths is covered by thread_per_d1 = ceil(L1 / DataPerRead)
    // threads, each owning one DataPerRead-wide vector; thread_per_d0 = BlockSize / thread_per_d1
    // rows are handled per pass. Threads with id >= thread_per_d0 * thread_per_d1 get no work
    // and keep both offsets at 0 (the offsets are initialized so they are never indeterminate).
    //
    // src_block_data_multi_id_begin / dst_block_data_multi_id_begin: (row, col) origin of the
    // copied region in SrcDesc / DstDesc coordinates.
    __device__ Blockwise2dTensorCopy3(Array<index_t, 2> src_block_data_multi_id_begin,
                                      Array<index_t, 2> dst_block_data_multi_id_begin)
        : mSrcMyThreadOffset(0), mDstMyThreadOffset(0)
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride should be multiple of DataPerRead to keep alignment");

        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride is big enough,
        //   so that the out-of-bound write won't contaminate next line in dst
        static_assert(thread_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");

        static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // idle threads: offsets stay at their initialized value of 0
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
        const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;

        mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(
            src_block_data_multi_id_begin +
            Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});

        mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(
            dst_block_data_multi_id_begin +
            Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});
    }

Chao Liu's avatar
Chao Liu committed
516
    // Copy the CopyLengths region from p_src to p_dst, one DataPerRead-wide vector per access,
    // starting from the per-thread offsets computed in the constructor.
    // Each pass advances thread_per_d0 rows; rows left over after the full passes are copied
    // only by the threads whose row falls inside L0. Idle threads return immediately.
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1     = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0     = BlockSize / thread_per_d1;
        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // threads outside the active thread grid have nothing to copy
        if(BlockSize > num_active_thread && get_thread_local_1d_id() >= num_active_thread)
        {
            return;
        }

        constexpr index_t full_passes   = L0 / thread_per_d0;
        constexpr index_t leftover_rows = L0 - full_passes * thread_per_d0;

        constexpr index_t src_pass_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_pass_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto copy_pass = [&](index_t pass) {
            const vector_t* src_vec = reinterpret_cast<const vector_t*>(
                p_src + mSrcMyThreadOffset + pass * src_pass_stride);
            vector_t* dst_vec =
                reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + pass * dst_pass_stride);

            *dst_vec = *src_vec;
        };

        for(index_t pass = 0; pass < full_passes; ++pass)
        {
            copy_pass(pass);
        }

        // only the first leftover_rows rows of the last pass exist
        if(leftover_rows > 0 && get_thread_local_1d_id() < leftover_rows * thread_per_d1)
        {
            copy_pass(full_passes);
        }
    }
565

Chao Liu's avatar
Chao Liu committed
566
    // Number of Float elements a thread's clipboard buffer needs for
    // RunLoadRegisterClipboard / RunStoreRegisterClipboard: DataPerRead elements for each row
    // pass the thread may take part in, i.e. DataPerRead * ceil(L0 / thread_per_d0).
    __device__ constexpr index_t GetRegisterClipboardSize() const
    {
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // The ceil-division must happen before the multiply. Without the inner parentheses,
        // `DataPerRead * (L0 + t - 1) / t` evaluates left-to-right as
        // (DataPerRead * (L0 + t - 1)) / t, which can exceed the needed size
        // (e.g. L0 = 5, t = 3, DataPerRead = 4 gives 9 instead of 8).
        constexpr index_t num_pass = (L0 + thread_per_d0 - 1) / thread_per_d0;

        return DataPerRead * num_pass;
    }

    // First half of a two-phase copy: read this thread's share of the src region into the
    // caller-provided per-thread buffer p_clipboard (the name suggests a register array)
    // instead of writing it to dst.
    // The vector for row pass k is stored at p_clipboard[k * DataPerRead], so the buffer needs
    // at least DataPerRead * ceil(L0 / thread_per_d0) elements; GetRegisterClipboardSize() is
    // meant to provide that size. Idle threads return without touching p_clipboard.
    // RunStoreRegisterClipboard on the same object writes the buffer out in the same order.
    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // only the src side is addressed here; dst is written by RunStoreRegisterClipboard
        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
                *(reinterpret_cast<const vector_t*>(
                    &p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // leftover rows: only threads whose row index is still < L0 participate
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    // Second half of a two-phase copy: write the vectors previously staged in p_clipboard
    // (vector for row pass k at p_clipboard[k * DataPerRead]) to this thread's locations in
    // p_dst. Uses the same thread/pass layout and tail condition as RunLoadRegisterClipboard,
    // so every slot read here is one that the load filled for the same thread.
    // Idle threads return without touching p_dst.
    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        // only the dst side is addressed here; src was read by RunLoadRegisterClipboard
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // leftover rows: only threads whose row index is still < L0 participate
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

#if CK_USE_AMD_INLINE_ASM
    // Load this thread's share of the source tile into the register clipboard
    // through the global_load helper instead of a plain vector load.
    //
    // The CopyLengths tile (L0 x L1) is covered by a grid of
    // thread_per_d0 x thread_per_d1 threads:
    //   - thread_per_d1 threads span dimension 1 (DataPerRead elements per
    //     thread, rounded up).
    //   - thread_per_d0 = BlockSize / thread_per_d1 such groups are stacked
    //     along dimension 0.
    // Threads outside that grid return immediately.
    //
    // Iteration i reads the vector at
    //   p_src[mSrcMyThreadOffset + i * SrcDesc stride(0) * thread_per_d0]
    // into p_clipboard[i * DataPerRead]. nloop_d0 full iterations run, plus one
    // tail iteration for threads whose row still lies inside L0.
    //
    // Only instantiable with Float == float and DataPerRead == 4 (static_assert
    // in the active branch).
    // NOTE(review): global_load is defined elsewhere, presumably as an inline-asm
    // wrapper for a 128-bit global load — confirm. The tail guard assumes
    // mSrcMyThreadOffset uses a d1-fastest thread mapping — confirm.
    __device__ void RunLoadRegisterClipboard_asm(const Float* __restrict__ p_src,
                                                 Float* p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        static_assert(thread_per_d0 > 0,
                      "BlockSize is too small to cover one row of CopyLengths");

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // BlockSize need not be a multiple of thread_per_d1; surplus threads idle
        if(BlockSize > num_active_thread && get_thread_local_1d_id() >= num_active_thread)
        {
            return;
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
#if 0
            // plain vector load, kept disabled as an alternative to global_load
            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
                *(reinterpret_cast<const vector_t*>(&p_src[mSrcMyThreadOffset +
                                                    iloop * src_loop_stride]));
#else
            static_assert(is_same<float, Float>::value && DataPerRead == 4,
                          "global_load is only for float4");

            global_load(reinterpret_cast<vector_t&>(p_clipboard[iloop * DataPerRead]),
                        reinterpret_cast<const vector_t*>(
                            &p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
#endif
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // Remainder when L0 is not a multiple of thread_per_d0: only the first
        // tail_d0 groups of thread_per_d1 threads still have a row to load.
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    // Store this thread's register clipboard into the destination tile through
    // the ds_write_b128 helper instead of a plain vector store.
    //
    // The CopyLengths tile (L0 x L1) is covered by a grid of
    // thread_per_d0 x thread_per_d1 threads:
    //   - thread_per_d1 threads span dimension 1 (DataPerRead elements per
    //     thread, rounded up).
    //   - thread_per_d0 = BlockSize / thread_per_d1 such groups are stacked
    //     along dimension 0.
    // Threads outside that grid return immediately.
    //
    // Iteration i writes the vector at p_clipboard[i * DataPerRead] to
    //   p_dst[mDstMyThreadOffset + i * DstDesc stride(0) * thread_per_d0].
    // nloop_d0 full iterations run, plus one tail iteration for threads whose
    // row still lies inside L0.
    //
    // Only instantiable with Float == float and DataPerRead == 4 (static_assert
    // in the active branch).
    // NOTE(review): ds_write_b128 is defined elsewhere, presumably as an
    // inline-asm 128-bit LDS store wrapper. If so, p_dst must point into LDS —
    // confirm. The tail guard assumes mDstMyThreadOffset uses a d1-fastest
    // thread mapping — confirm.
    __device__ void RunStoreRegisterClipboard_asm(const Float* __restrict__ p_clipboard,
                                                  Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        static_assert(thread_per_d0 > 0,
                      "BlockSize is too small to cover one row of CopyLengths");

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        // BlockSize need not be a multiple of thread_per_d1; surplus threads idle
        if(BlockSize > num_active_thread && get_thread_local_1d_id() >= num_active_thread)
        {
            return;
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
#if 0
            // plain vector store, kept disabled as an alternative to ds_write_b128
            // (parentheses balanced so this branch compiles when enabled)
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
#else
            static_assert(is_same<float, Float>::value && DataPerRead == 4,
                          "ds_write_b128 is only for float4");

            ds_write_b128(reinterpret_cast<const vector_t&>(p_clipboard[iloop * DataPerRead]),
                          &p_dst[mDstMyThreadOffset + iloop * dst_loop_stride]);
#endif
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        // Remainder when L0 is not a multiple of thread_per_d0: only the first
        // tail_d0 groups of thread_per_d1 threads still have a row to store.
        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }
#endif
};

} // namespace ck

#endif