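// threadwise_4d_tensor_op.hip.hpp
// Threadwise helpers for 4-d tensors: pointwise unary/binary operations,
// zero-fill, dimension-reordering copy, and in-place shift along one dimension.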
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"
// Apply the unary functor `f` in place to every element of a 4-d tensor.
//
// Desc : tensor descriptor type; only its lengths (GetLength) and its
//        multi-index -> offset mapping (Get1dIndex) are used.
// p    : base pointer of the tensor; element (i0, i1, i2, i3) is
//        p[Desc{}.Get1dIndex(i0, i1, i2, i3)].
// f    : callable invoked once per element with that element as a Float lvalue.
//
// The calling thread visits all elements sequentially, dimension 3 fastest.
template <class Float, class Desc, class F>
__device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

    // Loop bounds are compile-time constants taken from the descriptor.
    constexpr index_t len0 = desc.GetLength(I0);
    constexpr index_t len1 = desc.GetLength(I1);
    constexpr index_t len2 = desc.GetLength(I2);
    constexpr index_t len3 = desc.GetLength(I3);

#if 0
    if(get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_op_unary: ");
    }
#endif

    for(index_t i0 = 0; i0 < len0; ++i0)
        for(index_t i1 = 0; i1 < len1; ++i1)
            for(index_t i2 = 0; i2 < len2; ++i2)
                for(index_t i3 = 0; i3 < len3; ++i3)
                {
                    f(p[desc.Get1dIndex(i0, i1, i2, i3)]);
                }
}

// TODO: memory access is not yet optimized per memory type; specialized
// versions of the function below still need to be written for that.
template <class SrcData,
          class DstData,
Chao Liu's avatar
Chao Liu committed
42
43
44
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
45
          class MapDst2Src,
Chao Liu's avatar
Chao Liu committed
46
          class F>
47
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_given_dst2src(
Chao Liu's avatar
Chao Liu committed
48
    SrcDesc,
49
    const SrcData* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
50
    DstDesc,
51
    DstData* __restrict__ p_dst,
Chao Liu's avatar
Chao Liu committed
52
    SrcOpLengths,
53
    MapDst2Src,
Chao Liu's avatar
Chao Liu committed
54
    F f)
Chao Liu's avatar
Chao Liu committed
55
{
Chao Liu's avatar
Chao Liu committed
56
57
58
59
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
Chao Liu's avatar
Chao Liu committed
60

61
62
63
64
    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
    constexpr index_t IR2 = MapDst2Src{}.Get(I2);
    constexpr index_t IR3 = MapDst2Src{}.Get(I3);
Chao Liu's avatar
Chao Liu committed
65

66
67
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
Chao Liu's avatar
Chao Liu committed
68
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
Chao Liu's avatar
Chao Liu committed
69

Chao Liu's avatar
Chao Liu committed
70
    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
Chao Liu's avatar
Chao Liu committed
71
    {
Chao Liu's avatar
Chao Liu committed
72
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
Chao Liu's avatar
Chao Liu committed
73
        {
Chao Liu's avatar
Chao Liu committed
74
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
Chao Liu's avatar
Chao Liu committed
75
            {
Chao Liu's avatar
Chao Liu committed
76
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
Chao Liu's avatar
Chao Liu committed
77
                {
Chao Liu's avatar
Chao Liu committed
78
                    const index_t aindex = src_desc.Get1dIndex(did0, did1, did2, did3);
79

Chao Liu's avatar
Chao Liu committed
80
                    const index_t did[4] = {did0, did1, did2, did3};
Chao Liu's avatar
Chao Liu committed
81

Chao Liu's avatar
Chao Liu committed
82
                    const index_t bindex =
Chao Liu's avatar
Chao Liu committed
83
                        dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
Chao Liu's avatar
Chao Liu committed
84

85
                    f(p_src[aindex], p_dst[bindex]);
86
87

#if 0
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
                    if(get_block_1d_id() == 0)
                    {
                        printf("tid %5u, "
                               "src did %u %u %u %u, "
                               "dst did %u %u %u %u, "
                               "aindex %5u, "
                               "bindex %5u\n",
                               get_thread_local_1d_id(),
                               did0,
                               did1,
                               did2,
                               did3,
                               did[IR0],
                               did[IR1],
                               did[IR2],
                               did[IR3],
                               aindex,
                               bindex);
                    }
#endif
108
109
110
111
112
                }
            }
        }
    }
}
// Overwrite every element of the 4-d tensor described by `Desc` with Data(0).
//
// Desc : tensor descriptor type (passed as a tag value).
// p    : base pointer of the tensor; modified in place.
//
// Thin wrapper around threadwise_4d_tensor_pointwise_operation_unary; the
// template arguments <Data, Desc, closure> are deduced from the call.
template <class Data, class Desc>
__device__ void threadwise_4d_tensor_set_zero(Desc, Data* __restrict__ p)
{
    threadwise_4d_tensor_pointwise_operation_unary(
        Desc{}, p, [](Data& elem) { elem = Data(0); });
}
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
128
          class MapDst2Src>
129
130
131
132
133
134
__device__ void threadwise_4d_tensor_copy_reorder_given_dst2src(SrcDesc,
                                                                const SrcData* __restrict__ p_src,
                                                                DstDesc,
                                                                DstData* __restrict__ p_dst,
                                                                SrcOpLengths,
                                                                MapDst2Src)
135
{
136
    auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };
137

138
    threadwise_4d_tensor_pointwise_operation_binary_reorder_given_dst2src(
139
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
140
141
}

// Shift the contents of a 4-d tensor "down" by NShift positions along
// dimension IDim, in place.
//
// Which elements are written:
//   - Every multi-index whose IDim coordinate i satisfies
//     0 <= i < GetLength(IDim) - nshift is overwritten.
//   - All other dimensions are traversed over their full length.
//   - The trailing nshift slices along IDim are not written and keep their
//     previous values.
//
// Where the new value comes from:
//   - Each written element is read from offset
//     Get1dIndex(...) + nshift * GetStride(IDim).
//   - Assuming Get1dIndex is the stride-weighted sum of the coordinates, that
//     is the element nshift positions further along IDim.
//
// Desc   : descriptor type of the tensor pointed to by p.
// p      : base pointer of the tensor; modified in place.
// IDim   : Number<k> with k in [0, 3]; the dimension to shift along.
// NShift : compile-time shift amount, read via NShift::mValue. It must not
//          exceed the length of dimension IDim (checked by static_assert).
//
// Coordinates are visited in ascending lexicographic order. With positive
// strides, every source element is therefore read before it is overwritten.
template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

#if 0
    if(get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
    }
#endif

    // Select the shifted dimension by its numeric value.
    //
    // The previous code used is_same<decltype(I0), IDim>, which never matched:
    //   - I0..I3 are constexpr objects, so decltype(I0) is `const Number<0>`.
    //   - IDim is deduced from a by-value parameter and is the cv-stripped
    //     Number<k>.
    //
    // Consequently no loop bound was shortened, and p[sindex] read up to nshift
    // slices past the end of the shifted dimension.
    constexpr index_t idim   = IDim::mValue;
    constexpr index_t nshift = NShift::mValue;

    static_assert(idim < 4, "threadwise_4d_tensor_shift_down: IDim must be Number<0..3>");
    static_assert(nshift <= desc.GetLength(IDim{}),
                  "threadwise_4d_tensor_shift_down: NShift exceeds the length of IDim");

    constexpr index_t did0_end = (idim == 0) ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
    constexpr index_t did1_end = (idim == 1) ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
    constexpr index_t did2_end = (idim == 2) ? desc.GetLength(I2) - nshift : desc.GetLength(I2);
    constexpr index_t did3_end = (idim == 3) ? desc.GetLength(I3) - nshift : desc.GetLength(I3);

    // Offset from a destination element to its source (loop invariant).
    const index_t shift_offset = nshift * desc.GetStride(IDim{});

    for(index_t did0 = 0; did0 < did0_end; ++did0)
    {
        for(index_t did1 = 0; did1 < did1_end; ++did1)
        {
            for(index_t did2 = 0; did2 < did2_end; ++did2)
            {
                for(index_t did3 = 0; did3 < did3_end; ++did3)
                {
                    const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);

                    p[dindex] = p[dindex + shift_offset];
                }
            }
        }
    }
}