// threadwise_4d_tensor_op.hpp
#ifndef CK_THREADWISE_4D_TENSOR_OP_HPP
#define CK_THREADWISE_4D_TENSOR_OP_HPP

#include "ConstantTensorDescriptor.hpp"

namespace ck {

Chao Liu's avatar
Chao Liu committed
8
9
template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
Chao Liu's avatar
Chao Liu committed
10
11
12
13
14
15
16
17
18
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

#if 0
19
    if(get_thread_local_1d_id() == 0)
Chao Liu's avatar
Chao Liu committed
20
21
22
23
24
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
    }
#endif

Chao Liu's avatar
Chao Liu committed
25
    constexpr index_t nshift = NShift::mValue;
Chao Liu's avatar
Chao Liu committed
26

Chao Liu's avatar
Chao Liu committed
27
    constexpr index_t did0_end =
Chao Liu's avatar
Chao Liu committed
28
        is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
Chao Liu's avatar
Chao Liu committed
29

Chao Liu's avatar
Chao Liu committed
30
    constexpr index_t did1_end =
Chao Liu's avatar
Chao Liu committed
31
        is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
Chao Liu's avatar
Chao Liu committed
32

Chao Liu's avatar
Chao Liu committed
33
    constexpr index_t did2_end =
Chao Liu's avatar
Chao Liu committed
34
        is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - nshift : desc.GetLength(I2);
Chao Liu's avatar
Chao Liu committed
35

Chao Liu's avatar
Chao Liu committed
36
    constexpr index_t did3_end =
Chao Liu's avatar
Chao Liu committed
37
        is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - nshift : desc.GetLength(I3);
Chao Liu's avatar
Chao Liu committed
38

Chao Liu's avatar
Chao Liu committed
39
    for(index_t did0 = 0; did0 < did0_end; ++did0)
Chao Liu's avatar
Chao Liu committed
40
    {
Chao Liu's avatar
Chao Liu committed
41
        for(index_t did1 = 0; did1 < did1_end; ++did1)
Chao Liu's avatar
Chao Liu committed
42
        {
Chao Liu's avatar
Chao Liu committed
43
            for(index_t did2 = 0; did2 < did2_end; ++did2)
Chao Liu's avatar
Chao Liu committed
44
            {
Chao Liu's avatar
Chao Liu committed
45
                for(index_t did3 = 0; did3 < did3_end; ++did3)
Chao Liu's avatar
Chao Liu committed
46
                {
47
                    const index_t dindex = desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);
Chao Liu's avatar
Chao Liu committed
48

Chao Liu's avatar
Chao Liu committed
49
                    const index_t sindex = dindex + nshift * desc.GetStride(IDim{});
Chao Liu's avatar
Chao Liu committed
50
51
52
53
54
55

                    p[dindex] = p[sindex];
                }
            }
        }
    }
Chao Liu's avatar
Chao Liu committed
56
}
57
58
59

} // namespace ck
#endif