"vscode:/vscode.git/clone" did not exist on "0b861c4879bfc511666421384709c58cc8450ea5"
threadwise_4d_tensor_op.hip.hpp 8.98 KB
Newer Older
1
#pragma once
2
#include "ConstantTensorDescriptor.hip.hpp"
3

Chao Liu's avatar
Chao Liu committed
4
5
// Apply a unary functor in place to every element of a 4-d tensor.
//
// Desc : compile-time tensor descriptor; only its type is used. Loop extents
//        come from Desc{}.GetLength(Ik) and element offsets from
//        Desc{}.Get1dIndex(d0, d1, d2, d3).
// p    : base pointer of the tensor data.
// f    : functor invoked as f(p[offset]) once per coordinate. Whether it can
//        modify the element depends on the parameter type it declares (e.g.
//        Float& to write).
//
// Coordinates are visited in nested order with d3 varying fastest.
template <class Float, class Desc, class F>
__device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)
{
    static_assert(Desc{}.GetDimension() == 4, "wrong! should be 4 dimension");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_op_unary: ");
    }
#endif

    for(unsigned did0 = 0; did0 < desc.GetLength(I0); ++did0)
    {
        for(unsigned did1 = 0; did1 < desc.GetLength(I1); ++did1)
        {
            for(unsigned did2 = 0; did2 < desc.GetLength(I2); ++did2)
            {
                for(unsigned did3 = 0; did3 < desc.GetLength(I3); ++did3)
                {
                    const unsigned dindex = desc.Get1dIndex(did0, did1, did2, did3);

                    f(p[dindex]);
                }
            }
        }
    }
}

38
39
// Apply a binary functor element-wise from a 4-d source tensor to a 4-d
// destination tensor whose dimension order may be permuted.
//
// TODO: to optimize memory access for different memory types, specialized
// versions of this function need to be written.
//
// SrcDesc / DstDesc : compile-time descriptors of source and destination;
//                     only their types are used.
// p_src / p_dst     : data pointers of source and destination.
// SrcOpLengths      : compile-time lengths (presumably a Sequence) of the four
//                     extents iterated, in source-dimension order.
// DstFromSrcReorder : compile-time sequence <IR0, IR1, IR2, IR3>. For source
//                     coordinate did = (did0, did1, did2, did3), the paired
//                     destination coordinate is
//                     (did[IR0], did[IR1], did[IR2], did[IR3]).
//                     Sequence<0, 1, 2, 3> means no reordering.
// f                 : functor invoked as f(p_src[src_offset], p_dst[dst_offset]).
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class DstFromSrcReorder,
          class F>
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
    SrcDesc,
    const SrcData* __restrict__ p_src,
    DstDesc,
    DstData* __restrict__ p_dst,
    SrcOpLengths,
    DstFromSrcReorder,
    F f)
{
    static_assert(SrcDesc{}.GetDimension() == 4 && DstDesc{}.GetDimension() == 4 &&
                      SrcOpLengths::nDim == 4 && DstFromSrcReorder::nDim == 4,
                  "wrong! should be 4 dimension");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
    constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
    constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2);
    constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3);

    // The reorder indices select entries of the 4-element did[] array below;
    // reject out-of-range values at compile time instead of reading past it.
    static_assert(IR0 < 4 && IR1 < 4 && IR2 < 4 && IR3 < 4,
                  "wrong! reorder index out of range");

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
    {
        for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
        {
            for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
            {
                for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
                {
                    const unsigned aindex = src_desc.Get1dIndex(did0, did1, did2, did3);

                    const unsigned did[4] = {did0, did1, did2, did3};

                    const unsigned bindex =
                        dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

                    f(p_src[aindex], p_dst[bindex]);
                }
            }
        }
    }
}
Chao Liu's avatar
Chao Liu committed
91

92
93
// Set every element of a 4-d tensor to Data(0).
//
// Desc : compile-time tensor descriptor; only its type is used.
// p    : base pointer of the tensor data.
template <class Data, class Desc>
__device__ void threadwise_4d_tensor_set_zero(Desc, Data* __restrict__ p)
{
    auto f_set_zero = [](Data& v) { v = Data(0); };

    threadwise_4d_tensor_pointwise_operation_unary<Data, Desc, decltype(f_set_zero)>(
        Desc{}, p, f_set_zero);
}
Chao Liu's avatar
Chao Liu committed
100

101
102
103
104
105
106
// Copy a 4-d tensor element-wise from p_src to p_dst, converting each element
// with static_cast<DstData>, while permuting dimensions.
//
// SrcOpLengths      : the four extents iterated, in source-dimension order.
// DstFromSrcReorder : compile-time sequence <IR0, IR1, IR2, IR3>. Source
//                     coordinate did = (did0, did1, did2, did3) is copied to
//                     destination coordinate (did[IR0], did[IR1], did[IR2], did[IR3]).
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class DstFromSrcReorder>
__device__ void
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                      const SrcData* __restrict__ p_src,
                                                      DstDesc,
                                                      DstData* __restrict__ p_dst,
                                                      SrcOpLengths,
                                                      DstFromSrcReorder)
{
    auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };

    threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
}

121
// Copy a 4-d tensor element-wise from p_src to p_dst without permuting
// dimensions, converting each element to DstData.
//
// Forwards to threadwise_4d_tensor_copy_reorder_by_get_dst_from_src with the
// identity reorder Sequence<0, 1, 2, 3>. SrcOpLengths gives the four extents
// iterated.
template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths>
__device__ void threadwise_4d_tensor_copy(
    SrcDesc, const SrcData* __restrict__ p_src, DstDesc, DstData* __restrict__ p_dst, SrcOpLengths)
{
    auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};

    threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
}

131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
// Copy a 4-d tensor from p_src to p_dst, moving DataPerRead consecutive
// innermost-dimension elements per access with a scalar, float2 or float4
// load/store.
//
// Preconditions (compile-time checked unless noted):
//  - src, dst and SrcOpLengths are 4-d, and the innermost stride of src and
//    dst is 1;
//  - DataPerRead is 1, 2 or 4 and evenly divides SrcOpLengths[3];
//  - for DataPerRead > 1, sizeof(Float) == sizeof(float), because the vector
//    accesses reinterpret the data as float2 / float4;
//  - every stride that is stepped through (dims 0..2) is a multiple of
//    DataPerRead, so each vector access stays DataPerRead-element aligned
//    relative to the base pointer (assumes Get1dIndex is the usual sum of
//    coordinate * stride);
//  - NOT checkable here: p_src and p_dst themselves must be aligned to
//    DataPerRead * sizeof(Float) bytes.
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead>
__device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
                                             const Float* __restrict__ p_src,
                                             DstDesc,
                                             Float* __restrict__ p_dst,
                                             SrcOpLengths,
                                             Number<DataPerRead>)
{
    using Float2 = float2;
    using Float4 = float4;

    static_assert(SrcDesc{}.GetDimension() == 4 && DstDesc{}.GetDimension() == 4 &&
                      SrcOpLengths::nDim == 4,
                  "wrong! should be 4 dimension");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    static_assert(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1,
                  "wrong! only support stride3 == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    // Float2 / Float4 move 2 / 4 floats. That equals DataPerRead elements of
    // Float only when Float is float-sized; otherwise too few or too many bytes
    // would be copied.
    static_assert(DataPerRead == 1 || sizeof(Float) == sizeof(float),
                  "wrong! vectorized copy requires sizeof(Float) == sizeof(float)");

    static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(I2) % DataPerRead == 0,
                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

    // Dims 0 and 1 also add did * stride to every offset, so their strides
    // must preserve alignment too. They are irrelevant only when at most one
    // index is visited along that dimension.
    static_assert(SrcOpLengths{}.Get(I1) <= 1 ||
                      (SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
                       DstDesc{}.GetStride(I1) % DataPerRead == 0),
                  "wrong! src and dst stride1 should be multiple of DataPerRead to keep alignment");

    static_assert(SrcOpLengths{}.Get(I0) <= 1 ||
                      (SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                       DstDesc{}.GetStride(I0) % DataPerRead == 0),
                  "wrong! src and dst stride0 should be multiple of DataPerRead to keep alignment");

    constexpr unsigned L3 = SrcOpLengths{}.Get(I3);

    static_assert(L3 % DataPerRead == 0, "wrong! L3 should be evenly divided by DataPerRead");

    constexpr unsigned nloop_d3 = L3 / DataPerRead;

    for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
    {
        for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
        {
            for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
            {
                for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
                {
                    const unsigned src_index =
                        src_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);

                    const unsigned dst_index =
                        dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);

                    // DataPerRead is restricted to 1, 2 or 4 by the static_assert
                    // above, so exactly one branch applies.
                    if(DataPerRead == 1)
                    {
                        p_dst[dst_index] = p_src[src_index];
                    }
                    else if(DataPerRead == 2)
                    {
                        *(reinterpret_cast<Float2*>(p_dst + dst_index)) =
                            *(reinterpret_cast<const Float2*>(p_src + src_index));
                    }
                    else
                    {
                        *(reinterpret_cast<Float4*>(p_dst + dst_index)) =
                            *(reinterpret_cast<const Float4*>(p_src + src_index));
                    }
                }
            }
        }
    }
}

Chao Liu's avatar
Chao Liu committed
210
211
// Shift a 4-d tensor "down" in place by NShift positions along dimension IDim.
//
// For every coordinate whose IDim index is below Length(IDim) - NShift, the
// element is overwritten with the element NShift positions further along IDim.
// The last NShift slices along IDim are left unchanged.
//
// Desc   : compile-time tensor descriptor; only its type is used.
// p      : base pointer of the tensor data; both read and written.
// IDim   : Number<k> selecting the dimension, k in 0..3.
// NShift : Number<n> giving the shift, n <= Length(IDim).
//
// Coordinates are visited in increasing nested order. Each source element lies
// at a larger IDim index than its destination, so it is read before this loop
// overwrites it.
template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
    }
#endif

    static_assert(Desc{}.GetDimension() == 4, "wrong! should be 4 dimension");

    // Select the shifted dimension by value, not by comparing types with
    // is_same<decltype(I0), IDim>. I0 is declared constexpr, so decltype(I0) is
    // `const Number<0>`, whereas IDim is deduced from a by-value parameter and
    // has no top-level const. That type comparison never matches, which would
    // leave every loop at full length and make sindex read past the tensor.
    constexpr unsigned idim   = IDim::mValue;
    constexpr unsigned nshift = NShift::mValue;

    static_assert(idim < 4, "wrong! IDim should be one of Number<0> .. Number<3>");

    // Guard the unsigned subtraction below against wrap-around.
    static_assert(nshift <= desc.GetLength(IDim{}), "wrong! NShift exceeds length of IDim");

    constexpr unsigned did0_end = idim == 0 ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
    constexpr unsigned did1_end = idim == 1 ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
    constexpr unsigned did2_end = idim == 2 ? desc.GetLength(I2) - nshift : desc.GetLength(I2);
    constexpr unsigned did3_end = idim == 3 ? desc.GetLength(I3) - nshift : desc.GetLength(I3);

    for(unsigned did0 = 0; did0 < did0_end; ++did0)
    {
        for(unsigned did1 = 0; did1 < did1_end; ++did1)
        {
            for(unsigned did2 = 0; did2 < did2_end; ++did2)
            {
                for(unsigned did3 = 0; did3 < did3_end; ++did3)
                {
                    const unsigned dindex = desc.Get1dIndex(did0, did1, did2, did3);

                    // Source is the same coordinate advanced by nshift along IDim.
                    const unsigned sindex = dindex + nshift * desc.GetStride(IDim{});

                    p[dindex] = p[sindex];
                }
            }
        }
    }
}