// threadwise_4d_tensor_op.hip.hpp
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"

Chao Liu's avatar
Chao Liu committed
4
5
template <class Float, class Desc, class F>
__device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)
6
{
Chao Liu's avatar
Chao Liu committed
7
8
9
10
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
11

Chao Liu's avatar
Chao Liu committed
12
    constexpr auto desc = Desc{};
13
14

#if 0
15
    if(get_thread_local_1d_id() == 0)
16
    {
Chao Liu's avatar
Chao Liu committed
17
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_op_unary: ");
18
19
20
    }
#endif

Chao Liu's avatar
Chao Liu committed
21
    for(index_t did0 = 0; did0 < desc.GetLength(I0); ++did0)
Chao Liu's avatar
Chao Liu committed
22
    {
Chao Liu's avatar
Chao Liu committed
23
        for(index_t did1 = 0; did1 < desc.GetLength(I1); ++did1)
Chao Liu's avatar
Chao Liu committed
24
        {
Chao Liu's avatar
Chao Liu committed
25
            for(index_t did2 = 0; did2 < desc.GetLength(I2); ++did2)
Chao Liu's avatar
Chao Liu committed
26
            {
Chao Liu's avatar
Chao Liu committed
27
                for(index_t did3 = 0; did3 < desc.GetLength(I3); ++did3)
Chao Liu's avatar
Chao Liu committed
28
                {
Chao Liu's avatar
Chao Liu committed
29
                    const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);
Chao Liu's avatar
Chao Liu committed
30

Chao Liu's avatar
Chao Liu committed
31
                    f(p[dindex]);
Chao Liu's avatar
Chao Liu committed
32
33
34
35
36
37
                }
            }
        }
    }
}

38
39
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
40
41
template <class SrcData,
          class DstData,
Chao Liu's avatar
Chao Liu committed
42
43
44
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
45
          class MapDst2Src,
Chao Liu's avatar
Chao Liu committed
46
          class F>
47
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_given_dst2src(
Chao Liu's avatar
Chao Liu committed
48
    SrcDesc,
49
    const SrcData* __restrict__ p_src,
Chao Liu's avatar
Chao Liu committed
50
    DstDesc,
51
    DstData* __restrict__ p_dst,
Chao Liu's avatar
Chao Liu committed
52
    SrcOpLengths,
53
    MapDst2Src,
Chao Liu's avatar
Chao Liu committed
54
    F f)
Chao Liu's avatar
Chao Liu committed
55
{
Chao Liu's avatar
Chao Liu committed
56
57
58
59
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
Chao Liu's avatar
Chao Liu committed
60

61
62
63
64
    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
    constexpr index_t IR2 = MapDst2Src{}.Get(I2);
    constexpr index_t IR3 = MapDst2Src{}.Get(I3);
Chao Liu's avatar
Chao Liu committed
65

66
67
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
Chao Liu's avatar
Chao Liu committed
68
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
Chao Liu's avatar
Chao Liu committed
69

Chao Liu's avatar
Chao Liu committed
70
    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
Chao Liu's avatar
Chao Liu committed
71
    {
Chao Liu's avatar
Chao Liu committed
72
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
Chao Liu's avatar
Chao Liu committed
73
        {
Chao Liu's avatar
Chao Liu committed
74
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
Chao Liu's avatar
Chao Liu committed
75
            {
Chao Liu's avatar
Chao Liu committed
76
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
Chao Liu's avatar
Chao Liu committed
77
                {
Chao Liu's avatar
Chao Liu committed
78
                    const index_t aindex = src_desc.Get1dIndex(did0, did1, did2, did3);
79

Chao Liu's avatar
Chao Liu committed
80
                    const index_t did[4] = {did0, did1, did2, did3};
Chao Liu's avatar
Chao Liu committed
81

Chao Liu's avatar
Chao Liu committed
82
                    const index_t bindex =
Chao Liu's avatar
Chao Liu committed
83
                        dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
Chao Liu's avatar
Chao Liu committed
84

85
                    f(p_src[aindex], p_dst[bindex]);
86
87

#if 0
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
                    if(get_block_1d_id() == 0)
                    {
                        printf("tid %5u, "
                               "src did %u %u %u %u, "
                               "dst did %u %u %u %u, "
                               "aindex %5u, "
                               "bindex %5u\n",
                               get_thread_local_1d_id(),
                               did0,
                               did1,
                               did2,
                               did3,
                               did[IR0],
                               did[IR1],
                               did[IR2],
                               did[IR3],
                               aindex,
                               bindex);
                    }
#endif
108
109
110
111
112
                }
            }
        }
    }
}
Chao Liu's avatar
Chao Liu committed
113

114
115
template <class Data, class Desc>
__device__ void threadwise_4d_tensor_set_zero(Desc, Data* __restrict__ p)
Chao Liu's avatar
Chao Liu committed
116
{
117
    auto f_set_zero = [](Data& v) { v = Data(0); };
Chao Liu's avatar
Chao Liu committed
118

119
    threadwise_4d_tensor_pointwise_operation_unary<Data, Desc, decltype(f_set_zero)>(
Chao Liu's avatar
Chao Liu committed
120
        Desc{}, p, f_set_zero);
Chao Liu's avatar
Chao Liu committed
121
}
Chao Liu's avatar
Chao Liu committed
122

123
124
125
126
127
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
128
          class MapDst2Src>
129
130
131
132
133
134
__device__ void threadwise_4d_tensor_copy_reorder_given_dst2src(SrcDesc,
                                                                const SrcData* __restrict__ p_src,
                                                                DstDesc,
                                                                DstData* __restrict__ p_dst,
                                                                SrcOpLengths,
                                                                MapDst2Src)
135
{
136
    auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };
137

138
    threadwise_4d_tensor_pointwise_operation_binary_reorder_given_dst2src(
139
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
140
141
}

142
template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths>
Chao Liu's avatar
Chao Liu committed
143
__device__ void threadwise_4d_tensor_copy(
144
    SrcDesc, const SrcData* __restrict__ p_src, DstDesc, DstData* __restrict__ p_dst, SrcOpLengths)
Chao Liu's avatar
Chao Liu committed
145
{
Chao Liu's avatar
Chao Liu committed
146
    auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
Chao Liu's avatar
Chao Liu committed
147

148
    threadwise_4d_tensor_copy_reorder_given_dst2src(
Chao Liu's avatar
Chao Liu committed
149
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
Chao Liu's avatar
Chao Liu committed
150
151
}

152
// need to assume src and dst is aligned
Chao Liu's avatar
Chao Liu committed
153
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
154
155
156
157
158
159
160
161
__device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
                                             const Float* __restrict__ p_src,
                                             DstDesc,
                                             Float* __restrict__ p_dst,
                                             SrcOpLengths,
                                             Number<DataPerRead>)
{
    static_assert(SrcDesc{}.GetDimension() == 4 && DstDesc{}.GetDimension() == 4 &&
162
                      SrcOpLengths::GetSize() == 4,
163
164
                  "wrong! should be 4 dimension");

165
166
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    static_assert(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1,
                  "wrong! only support stride3 == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(I2) % DataPerRead == 0,
                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

Chao Liu's avatar
Chao Liu committed
186
    constexpr index_t L3 = SrcOpLengths{}.Get(I3);
187
188
189

    static_assert(L3 % DataPerRead == 0, "wrong! L3 should be evenly divided by DataPerRead");

Chao Liu's avatar
Chao Liu committed
190
    constexpr index_t nloop_d3 = L3 / DataPerRead;
191

Chao Liu's avatar
Chao Liu committed
192
    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
193
    {
Chao Liu's avatar
Chao Liu committed
194
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
195
        {
Chao Liu's avatar
Chao Liu committed
196
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
197
            {
Chao Liu's avatar
Chao Liu committed
198
                for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
199
                {
Chao Liu's avatar
Chao Liu committed
200
                    const index_t src_index =
201
202
                        src_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);

Chao Liu's avatar
Chao Liu committed
203
                    const index_t dst_index =
204
205
                        dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);

206
207
                    *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
                        *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
208
209
210
211
212
213
                }
            }
        }
    }
}

214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void
threadwise_4d_tensor_copy_reorder_given_dst2src_v2(SrcDesc,
                                                   const SrcData* __restrict__ p_src,
                                                   DstDesc,
                                                   DstData* __restrict__ p_dst,
                                                   SrcOpLengths,
                                                   MapDst2Src)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
    constexpr index_t IR2 = MapDst2Src{}.Get(I2);
    constexpr index_t IR3 = MapDst2Src{}.Get(I3);

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    // ref_desc has dst_desc's ordering
    constexpr auto ref_desc =
        make_ConstantTensorDescriptor(SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{}));

    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
    {
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
        {
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
            {
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
                {
                    const auto dst_multi_id = Array<index_t, 4>{did0, did1, did2, did3};

                    const auto src_multi_id =
                        reorder_array_given_old2new(dst_multi_id, MapDst2Src{});

                    const index_t dst_index = dst_desc.Get1dIndex(dst_multi_id);

                    const index_t src_index = src_desc.Get1dIndex(src_multi_id);

                    p_dst[dst_index] = p_src[src_index];
                }
            }
        }
    }
}

Chao Liu's avatar
Chao Liu committed
269
270
template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
Chao Liu's avatar
Chao Liu committed
271
272
273
274
275
276
277
278
279
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

#if 0
280
    if(get_thread_local_1d_id() == 0)
Chao Liu's avatar
Chao Liu committed
281
282
283
284
285
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
    }
#endif

Chao Liu's avatar
Chao Liu committed
286
    constexpr index_t nshift = NShift::mValue;
Chao Liu's avatar
Chao Liu committed
287

Chao Liu's avatar
Chao Liu committed
288
    constexpr index_t did0_end =
Chao Liu's avatar
Chao Liu committed
289
        is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
Chao Liu's avatar
Chao Liu committed
290

Chao Liu's avatar
Chao Liu committed
291
    constexpr index_t did1_end =
Chao Liu's avatar
Chao Liu committed
292
        is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
Chao Liu's avatar
Chao Liu committed
293

Chao Liu's avatar
Chao Liu committed
294
    constexpr index_t did2_end =
Chao Liu's avatar
Chao Liu committed
295
        is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - nshift : desc.GetLength(I2);
Chao Liu's avatar
Chao Liu committed
296

Chao Liu's avatar
Chao Liu committed
297
    constexpr index_t did3_end =
Chao Liu's avatar
Chao Liu committed
298
        is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - nshift : desc.GetLength(I3);
Chao Liu's avatar
Chao Liu committed
299

Chao Liu's avatar
Chao Liu committed
300
    for(index_t did0 = 0; did0 < did0_end; ++did0)
Chao Liu's avatar
Chao Liu committed
301
    {
Chao Liu's avatar
Chao Liu committed
302
        for(index_t did1 = 0; did1 < did1_end; ++did1)
Chao Liu's avatar
Chao Liu committed
303
        {
Chao Liu's avatar
Chao Liu committed
304
            for(index_t did2 = 0; did2 < did2_end; ++did2)
Chao Liu's avatar
Chao Liu committed
305
            {
Chao Liu's avatar
Chao Liu committed
306
                for(index_t did3 = 0; did3 < did3_end; ++did3)
Chao Liu's avatar
Chao Liu committed
307
                {
Chao Liu's avatar
Chao Liu committed
308
                    const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);
Chao Liu's avatar
Chao Liu committed
309

Chao Liu's avatar
Chao Liu committed
310
                    const index_t sindex = dindex + nshift * desc.GetStride(IDim{});
Chao Liu's avatar
Chao Liu committed
311
312
313
314
315
316

                    p[dindex] = p[sindex];
                }
            }
        }
    }
Chao Liu's avatar
Chao Liu committed
317
}