// threadwise_4d_tensor_op.hip.hpp
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"

Chao Liu's avatar
Chao Liu committed
4
5
template <class Float, class Desc, class F>
__device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)
6
{
Chao Liu's avatar
Chao Liu committed
7
8
9
10
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
11

Chao Liu's avatar
Chao Liu committed
12
    constexpr auto desc = Desc{};
13
14

#if 0
15
    if(get_thread_local_1d_id() == 0)
16
    {
Chao Liu's avatar
Chao Liu committed
17
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_op_unary: ");
18
19
20
    }
#endif

Chao Liu's avatar
Chao Liu committed
21
    for(index_t did0 = 0; did0 < desc.GetLength(I0); ++did0)
Chao Liu's avatar
Chao Liu committed
22
    {
Chao Liu's avatar
Chao Liu committed
23
        for(index_t did1 = 0; did1 < desc.GetLength(I1); ++did1)
Chao Liu's avatar
Chao Liu committed
24
        {
Chao Liu's avatar
Chao Liu committed
25
            for(index_t did2 = 0; did2 < desc.GetLength(I2); ++did2)
Chao Liu's avatar
Chao Liu committed
26
            {
Chao Liu's avatar
Chao Liu committed
27
                for(index_t did3 = 0; did3 < desc.GetLength(I3); ++did3)
Chao Liu's avatar
Chao Liu committed
28
                {
Chao Liu's avatar
Chao Liu committed
29
                    const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);
Chao Liu's avatar
Chao Liu committed
30

Chao Liu's avatar
Chao Liu committed
31
                    f(p[dindex]);
Chao Liu's avatar
Chao Liu committed
32
33
34
35
36
37
                }
            }
        }
    }
}

38
39
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
40
41
template <class SrcData,
          class DstData,
Chao Liu's avatar
Chao Liu committed
42
43
44
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
45
          class MapDst2Src,
Chao Liu's avatar
Chao Liu committed
46
47
48
          class F>
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
    SrcDesc,
49
    const SrcData* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
50
    DstDesc,
51
    DstData* __restrict__ p_dst,
Chao Liu's avatar
Chao Liu committed
52
    SrcOpLengths,
53
    MapDst2Src,
Chao Liu's avatar
Chao Liu committed
54
    F f)
Chao Liu's avatar
Chao Liu committed
55
{
Chao Liu's avatar
Chao Liu committed
56
57
58
59
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
Chao Liu's avatar
Chao Liu committed
60

61
62
63
64
    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
    constexpr index_t IR2 = MapDst2Src{}.Get(I2);
    constexpr index_t IR3 = MapDst2Src{}.Get(I3);
Chao Liu's avatar
Chao Liu committed
65

66
67
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
Chao Liu's avatar
Chao Liu committed
68
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
Chao Liu's avatar
Chao Liu committed
69

Chao Liu's avatar
Chao Liu committed
70
    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
Chao Liu's avatar
Chao Liu committed
71
    {
Chao Liu's avatar
Chao Liu committed
72
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
Chao Liu's avatar
Chao Liu committed
73
        {
Chao Liu's avatar
Chao Liu committed
74
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
Chao Liu's avatar
Chao Liu committed
75
            {
Chao Liu's avatar
Chao Liu committed
76
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
Chao Liu's avatar
Chao Liu committed
77
                {
Chao Liu's avatar
Chao Liu committed
78
                    const index_t aindex = src_desc.Get1dIndex(did0, did1, did2, did3);
79

Chao Liu's avatar
Chao Liu committed
80
                    const index_t did[4] = {did0, did1, did2, did3};
Chao Liu's avatar
Chao Liu committed
81

Chao Liu's avatar
Chao Liu committed
82
                    const index_t bindex =
Chao Liu's avatar
Chao Liu committed
83
                        dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
Chao Liu's avatar
Chao Liu committed
84

85
#if 1
86
                    f(p_src[aindex], p_dst[bindex]);
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#else
                    if(get_block_1d_id() == 0)
                    {
                        printf("tid %5u, "
                               "src did %u %u %u %u, "
                               "dst did %u %u %u %u, "
                               "aindex %5u, "
                               "bindex %5u\n",
                               get_thread_local_1d_id(),
                               did0,
                               did1,
                               did2,
                               did3,
                               did[IR0],
                               did[IR1],
                               did[IR2],
                               did[IR3],
                               aindex,
                               bindex);
                    }
#endif
108
109
110
111
112
                }
            }
        }
    }
}
Chao Liu's avatar
Chao Liu committed
113

114
115
template <class Data, class Desc>
__device__ void threadwise_4d_tensor_set_zero(Desc, Data* __restrict__ p)
Chao Liu's avatar
Chao Liu committed
116
{
117
    auto f_set_zero = [](Data& v) { v = Data(0); };
Chao Liu's avatar
Chao Liu committed
118

119
    threadwise_4d_tensor_pointwise_operation_unary<Data, Desc, decltype(f_set_zero)>(
Chao Liu's avatar
Chao Liu committed
120
        Desc{}, p, f_set_zero);
Chao Liu's avatar
Chao Liu committed
121
}
Chao Liu's avatar
Chao Liu committed
122

123
124
125
126
127
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
128
          class MapDst2Src>
Chao Liu's avatar
Chao Liu committed
129
130
__device__ void
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
131
                                                      const SrcData* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
132
                                                      DstDesc,
133
                                                      DstData* __restrict__ p_dst,
Chao Liu's avatar
Chao Liu committed
134
                                                      SrcOpLengths,
135
                                                      MapDst2Src)
136
{
137
    auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };
138

Chao Liu's avatar
Chao Liu committed
139
    threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
140
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
141
142
}

143
template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths>
Chao Liu's avatar
Chao Liu committed
144
__device__ void threadwise_4d_tensor_copy(
145
    SrcDesc, const SrcData* __restrict__ p_src, DstDesc, DstData* __restrict__ p_dst, SrcOpLengths)
Chao Liu's avatar
Chao Liu committed
146
{
Chao Liu's avatar
Chao Liu committed
147
    auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
Chao Liu's avatar
Chao Liu committed
148

Chao Liu's avatar
Chao Liu committed
149
150
    threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
Chao Liu's avatar
Chao Liu committed
151
152
}

153
// need to assume src and dst is aligned
Chao Liu's avatar
Chao Liu committed
154
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
155
156
157
158
159
160
161
162
__device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
                                             const Float* __restrict__ p_src,
                                             DstDesc,
                                             Float* __restrict__ p_dst,
                                             SrcOpLengths,
                                             Number<DataPerRead>)
{
    static_assert(SrcDesc{}.GetDimension() == 4 && DstDesc{}.GetDimension() == 4 &&
163
                      SrcOpLengths::GetSize() == 4,
164
165
                  "wrong! should be 4 dimension");

166
167
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    static_assert(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1,
                  "wrong! only support stride3 == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(I2) % DataPerRead == 0,
                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

Chao Liu's avatar
Chao Liu committed
187
    constexpr index_t L3 = SrcOpLengths{}.Get(I3);
188
189
190

    static_assert(L3 % DataPerRead == 0, "wrong! L3 should be evenly divided by DataPerRead");

Chao Liu's avatar
Chao Liu committed
191
    constexpr index_t nloop_d3 = L3 / DataPerRead;
192

Chao Liu's avatar
Chao Liu committed
193
    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
194
    {
Chao Liu's avatar
Chao Liu committed
195
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
196
        {
Chao Liu's avatar
Chao Liu committed
197
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
198
            {
Chao Liu's avatar
Chao Liu committed
199
                for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
200
                {
Chao Liu's avatar
Chao Liu committed
201
                    const index_t src_index =
202
203
                        src_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);

Chao Liu's avatar
Chao Liu committed
204
                    const index_t dst_index =
205
206
                        dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);

207
208
                    *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
                        *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
209
210
211
212
213
214
                }
            }
        }
    }
}

Chao Liu's avatar
Chao Liu committed
215
216
template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
Chao Liu's avatar
Chao Liu committed
217
218
219
220
221
222
223
224
225
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

#if 0
226
    if(get_thread_local_1d_id() == 0)
Chao Liu's avatar
Chao Liu committed
227
228
229
230
231
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
    }
#endif

Chao Liu's avatar
Chao Liu committed
232
    constexpr index_t nshift = NShift::mValue;
Chao Liu's avatar
Chao Liu committed
233

Chao Liu's avatar
Chao Liu committed
234
    constexpr index_t did0_end =
Chao Liu's avatar
Chao Liu committed
235
        is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
Chao Liu's avatar
Chao Liu committed
236

Chao Liu's avatar
Chao Liu committed
237
    constexpr index_t did1_end =
Chao Liu's avatar
Chao Liu committed
238
        is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
Chao Liu's avatar
Chao Liu committed
239

Chao Liu's avatar
Chao Liu committed
240
    constexpr index_t did2_end =
Chao Liu's avatar
Chao Liu committed
241
        is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - nshift : desc.GetLength(I2);
Chao Liu's avatar
Chao Liu committed
242

Chao Liu's avatar
Chao Liu committed
243
    constexpr index_t did3_end =
Chao Liu's avatar
Chao Liu committed
244
        is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - nshift : desc.GetLength(I3);
Chao Liu's avatar
Chao Liu committed
245

Chao Liu's avatar
Chao Liu committed
246
    for(index_t did0 = 0; did0 < did0_end; ++did0)
Chao Liu's avatar
Chao Liu committed
247
    {
Chao Liu's avatar
Chao Liu committed
248
        for(index_t did1 = 0; did1 < did1_end; ++did1)
Chao Liu's avatar
Chao Liu committed
249
        {
Chao Liu's avatar
Chao Liu committed
250
            for(index_t did2 = 0; did2 < did2_end; ++did2)
Chao Liu's avatar
Chao Liu committed
251
            {
Chao Liu's avatar
Chao Liu committed
252
                for(index_t did3 = 0; did3 < did3_end; ++did3)
Chao Liu's avatar
Chao Liu committed
253
                {
Chao Liu's avatar
Chao Liu committed
254
                    const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);
Chao Liu's avatar
Chao Liu committed
255

Chao Liu's avatar
Chao Liu committed
256
                    const index_t sindex = dindex + nshift * desc.GetStride(IDim{});
Chao Liu's avatar
Chao Liu committed
257
258
259
260
261
262

                    p[dindex] = p[sindex];
                }
            }
        }
    }
Chao Liu's avatar
Chao Liu committed
263
}