// threadwise_nd_tensor_op.hip.hpp
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"

// need to assume src and dst is aligned
Chao Liu's avatar
Chao Liu committed
5
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
6
7
8
9
10
11
12
__device__ void threadwise_6d_tensor_copy(SrcDesc,
                                          const Float* __restrict__ p_src,
                                          DstDesc,
                                          Float* __restrict__ p_dst,
                                          SrcOpLengths,
                                          Number<DataPerRead>)
{
Chao Liu's avatar
Chao Liu committed
13
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

    static_assert(SrcDesc{}.GetDimension() == 6 && DstDesc{}.GetDimension() == 6 &&
                      SrcOpLengths::nDim == 6,
                  "wrong! should be 6 dimension");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    static_assert(SrcDesc{}.GetStride(I5) == 1 && DstDesc{}.GetStride(I5) == 1,
                  "wrong! only support stride5 == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    static_assert(SrcDesc{}.GetStride(I4) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(I4) % DataPerRead == 0,
                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

Chao Liu's avatar
Chao Liu committed
40
    constexpr index_t L5 = SrcOpLengths{}.Get(I5);
41
42
43

    static_assert(L5 % DataPerRead == 0, "wrong! L5 should be evenly divided by DataPerRead");

Chao Liu's avatar
Chao Liu committed
44
    constexpr index_t nloop_d5 = L5 / DataPerRead;
45

Chao Liu's avatar
Chao Liu committed
46
    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
47
    {
Chao Liu's avatar
Chao Liu committed
48
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
49
        {
Chao Liu's avatar
Chao Liu committed
50
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
51
            {
Chao Liu's avatar
Chao Liu committed
52
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
53
                {
Chao Liu's avatar
Chao Liu committed
54
                    for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
55
                    {
Chao Liu's avatar
Chao Liu committed
56
                        for(index_t iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5)
57
                        {
Chao Liu's avatar
Chao Liu committed
58
                            const index_t src_index = src_desc.Get1dIndex(
59
60
                                did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);

Chao Liu's avatar
Chao Liu committed
61
                            const index_t dst_index = dst_desc.Get1dIndex(
62
63
                                did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);

Chao Liu's avatar
Chao Liu committed
64
65
                            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                                *(reinterpret_cast<const vector_t*>(p_src + src_index));
66
67
68
69
70
71
72
73
74
                        }
                    }
                }
            }
        }
    }
}

// need to assume src and dst is aligned
Chao Liu's avatar
Chao Liu committed
75
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
76
77
78
79
80
81
82
__device__ void threadwise_8d_tensor_copy(SrcDesc,
                                          const Float* __restrict__ p_src,
                                          DstDesc,
                                          Float* __restrict__ p_dst,
                                          SrcOpLengths,
                                          Number<DataPerRead>)
{
Chao Liu's avatar
Chao Liu committed
83
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111

    static_assert(SrcDesc{}.GetDimension() == 8 && DstDesc{}.GetDimension() == 8 &&
                      SrcOpLengths::nDim == 8,
                  "wrong! should be 8 dimension");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    static_assert(SrcDesc{}.GetStride(I7) == 1 && DstDesc{}.GetStride(I7) == 1,
                  "wrong! only support stride7 == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    static_assert(SrcDesc{}.GetStride(I6) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(I6) % DataPerRead == 0,
                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

Chao Liu's avatar
Chao Liu committed
112
    constexpr index_t L7 = SrcOpLengths{}.Get(I7);
113
114
115

    static_assert(L7 % DataPerRead == 0, "wrong! L7 should be evenly divided by DataPerRead");

Chao Liu's avatar
Chao Liu committed
116
    constexpr index_t nloop_d7 = L7 / DataPerRead;
117

Chao Liu's avatar
Chao Liu committed
118
    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
119
    {
Chao Liu's avatar
Chao Liu committed
120
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
121
        {
Chao Liu's avatar
Chao Liu committed
122
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
123
            {
Chao Liu's avatar
Chao Liu committed
124
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
125
                {
Chao Liu's avatar
Chao Liu committed
126
                    for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
127
                    {
Chao Liu's avatar
Chao Liu committed
128
                        for(index_t did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
129
                        {
Chao Liu's avatar
Chao Liu committed
130
                            for(index_t did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
131
                            {
Chao Liu's avatar
Chao Liu committed
132
                                for(index_t iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7)
133
                                {
Chao Liu's avatar
Chao Liu committed
134
                                    const index_t src_index =
135
136
137
138
139
140
141
142
143
                                        src_desc.Get1dIndex(did0,
                                                            did1,
                                                            did2,
                                                            did3,
                                                            did4,
                                                            did5,
                                                            did6,
                                                            iloop_d7 * DataPerRead);

Chao Liu's avatar
Chao Liu committed
144
                                    const index_t dst_index =
145
146
147
148
149
150
151
152
153
                                        dst_desc.Get1dIndex(did0,
                                                            did1,
                                                            did2,
                                                            did3,
                                                            did4,
                                                            did5,
                                                            did6,
                                                            iloop_d7 * DataPerRead);

Chao Liu's avatar
Chao Liu committed
154
155
                                    *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                                        *(reinterpret_cast<const vector_t*>(p_src + src_index));
156
157
158
159
160
161
162
163
164
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
165
166
167
168
169
170
171
172
173
174
175
176
177

// Copy a 10-d region of SrcOpLengths elements from p_src to p_dst.
//
// The calling thread copies the whole region by itself (no thread/block index is
// used here). The same multi-index is applied to both tensors; SrcDesc and DstDesc
// supply the (possibly different) strides used to turn it into a 1-d offset.
//
// Template parameters:
//   Float        - element type.
//   SrcDesc      - compile-time descriptor used to index p_src.
//   DstDesc      - compile-time descriptor used to index p_dst.
//   SrcOpLengths - lengths of the region to copy.
//   DataPerRead  - elements moved per load/store (1, 2 or 4). The last dimension
//                  is copied in chunks of DataPerRead elements through
//                  vector_type<Float, DataPerRead>::MemoryType.
//
// Compile-time checked preconditions:
//   - src, dst and SrcOpLengths are all 10-d;
//   - dim 9 has stride 1 in both src and dst;
//   - dim 8 strides of src and dst are multiples of DataPerRead;
//   - the length of dim 9 is a multiple of DataPerRead.
// NOT checked (caller's responsibility):
//   - p_src and p_dst must be aligned for vector_t access;
//   - strides of dims 0-7 are not checked against DataPerRead, so they must also
//     keep every vector access aligned.
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_10d_tensor_copy(SrcDesc,
                                           const Float* __restrict__ p_src,
                                           DstDesc,
                                           Float* __restrict__ p_dst,
                                           SrcOpLengths,
                                           Number<DataPerRead>)
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    static_assert(SrcDesc{}.GetDimension() == 10 && DstDesc{}.GetDimension() == 10 &&
                      SrcOpLengths::GetSize() == 10,
                  "wrong! should be 10 dimension");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};
    constexpr auto I8 = Number<8>{};
    constexpr auto I9 = Number<9>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    // packed descriptor over the copy lengths; only used for its loop bounds
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    // the last dimension here is dim 9 (the message previously said "stride7")
    static_assert(SrcDesc{}.GetStride(I9) == 1 && DstDesc{}.GetStride(I9) == 1,
                  "wrong! only support stride9 == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    static_assert(SrcDesc{}.GetStride(I8) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(I8) % DataPerRead == 0,
                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

    constexpr index_t L9 = SrcOpLengths{}.Get(I9);

    static_assert(L9 % DataPerRead == 0, "wrong! L9 should be evenly divided by DataPerRead");

    // dim 9 is contiguous (stride 1), so it is walked in DataPerRead-element chunks
    constexpr index_t nloop_d9 = L9 / DataPerRead;

#pragma unroll
    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
    {
#pragma unroll
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
        {
#pragma unroll
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
            {
#pragma unroll
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
                {
#pragma unroll
                    for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
                    {
#pragma unroll
                        for(index_t did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
                        {
#pragma unroll
                            for(index_t did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
                            {
#pragma unroll
                                for(index_t did7 = 0; did7 < ref_desc.GetLength(I7); ++did7)
                                {
#pragma unroll
                                    for(index_t did8 = 0; did8 < ref_desc.GetLength(I8); ++did8)
                                    {
#pragma unroll
                                        for(index_t iloop_d9 = 0; iloop_d9 < nloop_d9; ++iloop_d9)
                                        {
                                            const index_t src_index =
                                                src_desc.Get1dIndex(did0,
                                                                    did1,
                                                                    did2,
                                                                    did3,
                                                                    did4,
                                                                    did5,
                                                                    did6,
                                                                    did7,
                                                                    did8,
                                                                    iloop_d9 * DataPerRead);

                                            const index_t dst_index =
                                                dst_desc.Get1dIndex(did0,
                                                                    did1,
                                                                    did2,
                                                                    did3,
                                                                    did4,
                                                                    did5,
                                                                    did6,
                                                                    did7,
                                                                    did8,
                                                                    iloop_d9 * DataPerRead);

                                            // one vector_t load + one vector_t store
                                            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                                                *(reinterpret_cast<const vector_t*>(p_src +
                                                                                    src_index));
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}