threadwise_4d_tensor_op.hip.hpp 11.2 KB
Newer Older
1
#pragma once
2
#include "ConstantTensorDescriptor.hip.hpp"
3

Chao Liu's avatar
Chao Liu committed
4
5
// Applies the unary functor f, in place, to every element of a 4-d tensor whose layout is
// given by the compile-time descriptor Desc.
// Elements are visited in lexicographic order of the multi-index (dim 0 outermost, dim 3
// innermost). Each element's offset into p comes from Desc's Get1dIndex.
// f receives p[offset] as an lvalue, so it may modify the element.
template <class Float, class Desc, class F>
__device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)
{
    constexpr auto desc = Desc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

#if 0
    if(get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_op_unary: ");
    }
#endif

    for(index_t i0 = 0; i0 < desc.GetLength(I0); ++i0)
        for(index_t i1 = 0; i1 < desc.GetLength(I1); ++i1)
            for(index_t i2 = 0; i2 < desc.GetLength(I2); ++i2)
                for(index_t i3 = 0; i3 < desc.GetLength(I3); ++i3)
                {
                    f(p[desc.Get1dIndex(i0, i1, i2, i3)]);
                }
}

38
39
// Applies the binary functor f(src_element, dst_element) to every element of a 4-d region
// while permuting dimensions between source and destination.
//
//   SrcOpLengths : extent of the region, in *source* dimension order. The loops walk the
//                  region in this order, dim 0 outermost.
//   MapDst2Src   : component k of the destination multi-index is taken from component
//                  MapDst2Src[k] of the source multi-index.
//
// TODO: to optimize memory access for different memory types, specialized versions of this
// routine still need to be written.
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src,
          class F>
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_given_dst2src(
    SrcDesc,
    const SrcData* __restrict__ p_src,
    DstDesc,
    DstData* __restrict__ p_dst,
    SrcOpLengths,
    MapDst2Src,
    F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    // Built from SrcOpLengths; only its lengths are used, as the loop bounds.
    constexpr auto op_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    // Destination dimension k reads source index component map_k.
    constexpr index_t map_0 = MapDst2Src{}.Get(I0);
    constexpr index_t map_1 = MapDst2Src{}.Get(I1);
    constexpr index_t map_2 = MapDst2Src{}.Get(I2);
    constexpr index_t map_3 = MapDst2Src{}.Get(I3);

    for(index_t s0 = 0; s0 < op_desc.GetLength(I0); ++s0)
    {
        for(index_t s1 = 0; s1 < op_desc.GetLength(I1); ++s1)
        {
            for(index_t s2 = 0; s2 < op_desc.GetLength(I2); ++s2)
            {
                for(index_t s3 = 0; s3 < op_desc.GetLength(I3); ++s3)
                {
                    // Source multi-index held in an array so it can be indexed by
                    // dimension number for the permutation below.
                    const index_t src_id[4] = {s0, s1, s2, s3};

                    const index_t src_offset =
                        src_desc.Get1dIndex(src_id[0], src_id[1], src_id[2], src_id[3]);

                    const index_t dst_offset = dst_desc.Get1dIndex(
                        src_id[map_0], src_id[map_1], src_id[map_2], src_id[map_3]);

                    f(p_src[src_offset], p_dst[dst_offset]);

#if 0
                    if(get_block_1d_id() == 0)
                    {
                        printf("tid %5u, "
                               "src did %u %u %u %u, "
                               "dst did %u %u %u %u, "
                               "aindex %5u, "
                               "bindex %5u\n",
                               get_thread_local_1d_id(),
                               s0,
                               s1,
                               s2,
                               s3,
                               src_id[map_0],
                               src_id[map_1],
                               src_id[map_2],
                               src_id[map_3],
                               src_offset,
                               dst_offset);
                    }
#endif
                }
            }
        }
    }
}
Chao Liu's avatar
Chao Liu committed
113

114
115
// Writes Data(0) into every element of the 4-d tensor described by Desc, in place.
// The traversal is done by threadwise_4d_tensor_pointwise_operation_unary.
template <class Data, class Desc>
__device__ void threadwise_4d_tensor_set_zero(Desc, Data* __restrict__ p)
{
    auto zero_fill = [](Data& x) { x = Data(0); };

    // Template arguments are deduced: Float = Data, Desc = Desc, F = decltype(zero_fill).
    threadwise_4d_tensor_pointwise_operation_unary(Desc{}, p, zero_fill);
}
Chao Liu's avatar
Chao Liu committed
122

123
124
125
126
127
// Copies a 4-d region from p_src to p_dst while permuting dimensions.
// Destination dimension k corresponds to source dimension MapDst2Src[k], and SrcOpLengths
// is the region extent in source order.
// Each element is converted from SrcData to DstData with static_cast.
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void threadwise_4d_tensor_copy_reorder_given_dst2src(SrcDesc,
                                                                const SrcData* __restrict__ p_src,
                                                                DstDesc,
                                                                DstData* __restrict__ p_dst,
                                                                SrcOpLengths,
                                                                MapDst2Src)
{
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto lengths  = SrcOpLengths{};
    constexpr auto dst2src  = MapDst2Src{};

    auto convert_and_store = [](const SrcData& from, DstData& to) {
        to = static_cast<DstData>(from);
    };

    threadwise_4d_tensor_pointwise_operation_binary_reorder_given_dst2src(
        src_desc, p_src, dst_desc, p_dst, lengths, dst2src, convert_and_store);
}

142
#if 0 // replaced threadwise_nd_tensor_copy
143
// Plain 4-d copy with no dimension reordering.
// Forwards to the reordering copy with the identity map Sequence<0, 1, 2, 3>.
template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths>
__device__ void threadwise_4d_tensor_copy(
    SrcDesc, const SrcData* __restrict__ p_src, DstDesc, DstData* __restrict__ p_dst, SrcOpLengths)
{
    using IdentityMap = Sequence<0, 1, 2, 3>;

    threadwise_4d_tensor_copy_reorder_given_dst2src(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, IdentityMap{});
}

153
// Vectorized 4-d copy with the same dimension order on both sides.
// Each innermost row of SrcOpLengths[3] elements is moved DataPerRead elements at a time,
// using loads and stores of vector_type<Float, DataPerRead>::MemoryType.
//
// Checked at compile time below:
//   - descriptors and lengths are 4-d;
//   - dim-3 stride is 1;
//   - DataPerRead is 1, 2 or 4;
//   - dim-2 strides and the dim-3 length are multiples of DataPerRead.
// The base pointers p_src and p_dst are assumed to be aligned for vector access; this is
// NOT checked here.
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
                                             const Float* __restrict__ p_src,
                                             DstDesc,
                                             Float* __restrict__ p_dst,
                                             SrcOpLengths,
                                             Number<DataPerRead>)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    static_assert(SrcDesc{}.GetDimension() == 4 && DstDesc{}.GetDimension() == 4 &&
                      SrcOpLengths::GetSize() == 4,
                  "wrong! should be 4 dimension");

    static_assert(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1,
                  "wrong! only support stride3 == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(I2) % DataPerRead == 0,
                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

    constexpr index_t L3 = SrcOpLengths{}.Get(I3);

    static_assert(L3 % DataPerRead == 0, "wrong! L3 should be evenly divided by DataPerRead");

    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    // Number of vector transfers per innermost row.
    constexpr index_t nvec_d3 = L3 / DataPerRead;

    for(index_t i0 = 0; i0 < ref_desc.GetLength(I0); ++i0)
        for(index_t i1 = 0; i1 < ref_desc.GetLength(I1); ++i1)
            for(index_t i2 = 0; i2 < ref_desc.GetLength(I2); ++i2)
                for(index_t v = 0; v < nvec_d3; ++v)
                {
                    const index_t i3 = v * DataPerRead;

                    const index_t src_off = src_desc.Get1dIndex(i0, i1, i2, i3);
                    const index_t dst_off = dst_desc.Get1dIndex(i0, i1, i2, i3);

                    const vector_t chunk = *reinterpret_cast<const vector_t*>(p_src + src_off);
                    *reinterpret_cast<vector_t*>(p_dst + dst_off) = chunk;
                }
}
214
#endif
215

216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
// Copies a 4-d region from p_src to p_dst while permuting dimensions, iterating in
// *destination* order. By contrast, threadwise_4d_tensor_copy_reorder_given_dst2src
// iterates in source order.
//
//   SrcOpLengths : extent of the region, in source dimension order.
//   MapDst2Src   : destination dimension k corresponds to source dimension MapDst2Src[k].
//
// Each element is converted with static_cast<DstData>, matching the non-_v2 copy. This
// also supports DstData types that are only explicitly convertible from SrcData.
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void
threadwise_4d_tensor_copy_reorder_given_dst2src_v2(SrcDesc,
                                                   const SrcData* __restrict__ p_src,
                                                   DstDesc,
                                                   DstData* __restrict__ p_dst,
                                                   SrcOpLengths,
                                                   MapDst2Src)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    // ref_desc has dst_desc's ordering: SrcOpLengths permuted by MapDst2Src.
    // Its lengths are the loop bounds.
    constexpr auto ref_desc =
        make_ConstantTensorDescriptor(SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{}));

    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
    {
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
        {
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
            {
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
                {
                    const auto dst_multi_id = Array<index_t, 4>{did0, did1, did2, did3};

                    // Map the destination multi-index back into source dimension order.
                    const auto src_multi_id =
                        reorder_array_given_old2new(dst_multi_id, MapDst2Src{});

                    const index_t dst_index = dst_desc.Get1dIndex(dst_multi_id);

                    const index_t src_index = src_desc.Get1dIndex(src_multi_id);

                    p_dst[dst_index] = static_cast<DstData>(p_src[src_index]);
                }
            }
        }
    }
}

Chao Liu's avatar
Chao Liu committed
271
272
// Shifts a 4-d tensor in place toward lower indices along dimension IDim by NShift
// positions.
//   - For i in [0, length(IDim) - nshift): p[..., i, ...] = p[..., i + nshift, ...].
//   - The last nshift positions along IDim are not written.
//   - If nshift >= length(IDim), nothing is written.
// Iteration is in ascending lexicographic order, so each source element is read before it
// is overwritten (assuming Desc maps distinct multi-indices to distinct offsets).
//
//   IDim   : Number<0> .. Number<3>, the dimension to shift along.
//   NShift : a Number<n> with n >= 0 (read through NShift::mValue).
template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

#if 0
    if(get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
    }
#endif

    constexpr index_t nshift = NShift::mValue;

    // Compare IDim against the plain Number<k> types, not decltype(Ik).
    // decltype(Ik) names the declared type of a `constexpr auto` local, which is
    // `const Number<k>`. IDim is deduced from a by-value parameter and carries no cv, so
    // comparing against decltype(Ik) never matched, and the shifted dimension was never
    // shortened (out-of-bounds reads).
    constexpr bool shift_d0 = is_same<Number<0>, IDim>::value;
    constexpr bool shift_d1 = is_same<Number<1>, IDim>::value;
    constexpr bool shift_d2 = is_same<Number<2>, IDim>::value;
    constexpr bool shift_d3 = is_same<Number<3>, IDim>::value;

    static_assert(shift_d0 || shift_d1 || shift_d2 || shift_d3,
                  "wrong! IDim should be one of Number<0>, Number<1>, Number<2>, Number<3>");

    // A negative shift would extend the loop bound past the length: out-of-bounds writes.
    static_assert(nshift >= 0, "wrong! NShift should be non-negative");

    constexpr index_t did0_end = shift_d0 ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
    constexpr index_t did1_end = shift_d1 ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
    constexpr index_t did2_end = shift_d2 ? desc.GetLength(I2) - nshift : desc.GetLength(I2);
    constexpr index_t did3_end = shift_d3 ? desc.GetLength(I3) - nshift : desc.GetLength(I3);

    for(index_t did0 = 0; did0 < did0_end; ++did0)
    {
        for(index_t did1 = 0; did1 < did1_end; ++did1)
        {
            for(index_t did2 = 0; did2 < did2_end; ++did2)
            {
                for(index_t did3 = 0; did3 < did3_end; ++did3)
                {
                    const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);

                    // Source element is nshift positions further along IDim.
                    const index_t sindex = dindex + nshift * desc.GetStride(IDim{});

                    p[dindex] = p[sindex];
                }
            }
        }
    }
}