threadwise_4d_tensor_op.hip.hpp 1.62 KB
Newer Older
1
#pragma once
2
#include "ConstantTensorDescriptor.hip.hpp"
3

Chao Liu's avatar
Chao Liu committed
4
5
// In-place "shift down" of a 4-d tensor along one dimension.
//
// For every multi-index (d0, d1, d2, d3) whose coordinate along the selected dimension is
// smaller than GetLength(IDim) - NShift, the element is overwritten by the element that
// lies NShift positions further along that dimension:
//
//     p[offset(d0, d1, d2, d3)] = p[offset(d0, d1, d2, d3) + NShift * GetStride(IDim)]
//
// Elements in the last NShift slices along the selected dimension are never written.
// The function contains no thread indexing: the calling thread walks the whole tensor.
//
// Parameters (all but `p` are compile-time tags, only their types are used):
//   Desc   - tensor descriptor type; default-constructed here and queried through
//            GetLength, GetStride and GetOffsetFromMultiIndex.
//   p      - pointer to the tensor data, read and written through the same pointer.
//   IDim   - tag type of the dimension to shift along; it is only recognised when it is
//            one of Number<0> .. Number<3> (optionally const-qualified).
//   NShift - tag type whose static member `mValue` is the shift distance.
template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

#if 0
    if(get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
    }
#endif

    constexpr index_t nshift = NShift::mValue;

    // decltype(I0) .. decltype(I3) are const-qualified (they name constexpr objects), whereas
    // IDim is deduced from a by-value argument and therefore has its top-level const stripped.
    // Compare against `const IDim` so the shifted dimension is actually recognised; otherwise
    // no loop bound would be shortened and the source index could run past the tensor.
    using ShiftDim = const IDim;

    // Along the shifted dimension only the first (length - nshift) entries have a source
    // element inside the tensor; every other dimension is traversed in full.
    constexpr index_t did0_end =
        is_same<decltype(I0), ShiftDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);

    constexpr index_t did1_end =
        is_same<decltype(I1), ShiftDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);

    constexpr index_t did2_end =
        is_same<decltype(I2), ShiftDim>::value ? desc.GetLength(I2) - nshift : desc.GetLength(I2);

    constexpr index_t did3_end =
        is_same<decltype(I3), ShiftDim>::value ? desc.GetLength(I3) - nshift : desc.GetLength(I3);

    // Distance (in elements) between a destination element and its source element;
    // it does not depend on the loop indices, so compute it once.
    const index_t shift_offset = nshift * desc.GetStride(IDim{});

    // Ascending traversal: each source element sits at a higher offset than its destination,
    // so it is read before it is itself overwritten.
    for(index_t did0 = 0; did0 < did0_end; ++did0)
    {
        for(index_t did1 = 0; did1 < did1_end; ++did1)
        {
            for(index_t did2 = 0; did2 < did2_end; ++did2)
            {
                for(index_t did3 = 0; did3 < did3_end; ++did3)
                {
                    const index_t dindex = desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);

                    const index_t sindex = dindex + shift_offset;

                    p[dindex] = p[sindex];
                }
            }
        }
    }
}