// threadwise_tensor_slice_op.hip.hpp
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"

// need to assume src and dst is aligned
Chao Liu's avatar
Chao Liu committed
5
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
Chao Liu's avatar
Chao Liu committed
6
7
8
9
10
11
__device__ void threadwise_tensor_slice_copy(SrcDesc,
                                             const Float* __restrict__ p_src,
                                             DstDesc,
                                             Float* __restrict__ p_dst,
                                             SrcOpLengths,
                                             Number<DataPerRead>)
12
{
Chao Liu's avatar
Chao Liu committed
13
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
14

15
    constexpr index_t nDim = SrcOpLengths::GetSize();
16

Chao Liu's avatar
Chao Liu committed
17
    static_assert(SrcDesc{}.GetNumOfDimension() == nDim && DstDesc{}.GetNumOfDimension() == nDim,
18
                  "wrong! dimension not consistent");
19
20
21

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
Chao Liu's avatar
Chao Liu committed
22
    constexpr auto ref_desc = make_ConstantTensorDescriptor_default_rank_packed(SrcOpLengths{});
23

24
25
#if 0
    if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
26
    {
27
28
29
        print_ConstantTensorDescriptor(src_desc, "src_desc");
        print_ConstantTensorDescriptor(dst_desc, "dst_desc");
        print_ConstantTensorDescriptor(ref_desc, "ref_desc");
30
    }
31
#endif
32

33
34
35
    static_assert(DataPerRead == 1 || (SrcDesc{}.GetStride(Number<nDim - 1>{}) == 1 &&
                                       DstDesc{}.GetStride(Number<nDim - 1>{}) == 1),
                  "wrong! only support stride[nDim-1] == 1!\n");
36
37
38
39

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

40
41
42
43
    static_assert(
        SrcDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0 &&
            DstDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0,
        "wrong! src and dst stride[nDim-2] should be multiple of DataPerRead to keep alignment");
44

45
    constexpr index_t L_Back = SrcOpLengths{}.Back();
46

47
48
    static_assert(L_Back % DataPerRead == 0,
                  "wrong! lengths[nDim-1] should be evenly divided by DataPerRead");
49

50
    constexpr index_t nRead = L_Back / DataPerRead;
51

52
    static_ford<decltype(ref_desc.GetLengths().PopBack())>{}([=](auto Ids) {
Chao Liu's avatar
Chao Liu committed
53
        static_for<0, nRead, 1>{}([&](auto IRead) {
54
            constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead.Get() * DataPerRead>{});
55

56
            const index_t src_index = src_desc.GetOffsetFromMultiIndex(multi_id);
57

58
            const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(multi_id);
59

60
61
62
63
            *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
                *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
        });
    });
64
}
Chao Liu's avatar
Chao Liu committed
65

Chao Liu's avatar
Chao Liu committed
66
// access in order of src
Chao Liu's avatar
Chao Liu committed
67
68
69
70
71
72
73
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void
Chao Liu's avatar
Chao Liu committed
74
75
76
77
78
79
threadwise_tensor_slice_copy_reorder_given_dst2src_v1(SrcDesc,
                                                      const SrcData* __restrict__ p_src,
                                                      DstDesc,
                                                      DstData* __restrict__ p_dst,
                                                      SrcOpLengths,
                                                      MapDst2Src)
Chao Liu's avatar
Chao Liu committed
80
81
82
83
84
85
86
{
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    ford<SrcOpLengths>{}([&](auto src_multi_id) {
        const auto dst_multi_id = reorder_array_given_new2old(src_multi_id, MapDst2Src{});

87
        const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
Chao Liu's avatar
Chao Liu committed
88

89
        const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
Chao Liu's avatar
Chao Liu committed
90
91
92
93
94

        p_dst[dst_index] = p_src[src_index];
    });
}

Chao Liu's avatar
Chao Liu committed
95
// access in order of dst
Chao Liu's avatar
Chao Liu committed
96
97
98
99
100
101
102
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void
Chao Liu's avatar
Chao Liu committed
103
104
105
106
107
108
threadwise_tensor_slice_copy_reorder_given_dst2src_v2(SrcDesc,
                                                      const SrcData* __restrict__ p_src,
                                                      DstDesc,
                                                      DstData* __restrict__ p_dst,
                                                      SrcOpLengths,
                                                      MapDst2Src)
Chao Liu's avatar
Chao Liu committed
109
110
111
112
113
114
115
116
117
{
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    constexpr auto dst_op_lengths = SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{});

    ford<decltype(dst_op_lengths)>{}([&](auto dst_multi_id) {
        const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});

118
        const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
Chao Liu's avatar
Chao Liu committed
119

120
        const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
Chao Liu's avatar
Chao Liu committed
121
122
123
124
125

        p_dst[dst_index] = p_src[src_index];
    });
}

Chao Liu's avatar
Chao Liu committed
126
127
// access in order of dst
// manually pack data into vector before write
Chao Liu's avatar
Chao Liu committed
128
129
130
131
132
133
template <class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src,
          index_t DstDataPerWrite>
Chao Liu's avatar
Chao Liu committed
134
135
136
137
138
139
140
141
__device__ void
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
                                                      const Float* __restrict__ p_src,
                                                      DstDesc,
                                                      Float* __restrict__ p_dst,
                                                      SrcOpLengths,
                                                      MapDst2Src,
                                                      Number<DstDataPerWrite>)
Chao Liu's avatar
Chao Liu committed
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
{
    using vector_t = typename vector_type<Float, DstDataPerWrite>::MemoryType;

    constexpr index_t nDim = SrcOpLengths::GetSize();

    static_assert(DstDataPerWrite == 1 || DstDesc{}.GetStride(Number<nDim - 1>{}) == 1,
                  "wrong! only support dst.stride[nDim-1] == 1, if DstDataPerWrite != 1");

    static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
                  "wrong! only support DstDataPerWrite == 1, 2 or 4");

    static_assert(
        DstDesc{}.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
        "wrong! dst.stride[nDim-2] should be multiple of DstDataPerWrite to keep alignment");

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    constexpr auto dst_op_lengths = SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{});

    constexpr index_t L_Dst_Back = dst_op_lengths.Back();

    static_assert(L_Dst_Back % DstDataPerWrite == 0,
                  "wrong! dst.lengths[nDim-1] should be evenly divided by DstDataPerWrite");

    constexpr index_t nWrite = L_Dst_Back / DstDataPerWrite;

    ford<decltype(dst_op_lengths.PopBack())>{}([&](auto ids) {
        static_for<0, nWrite, 1>{}([&](auto IWrite) {
            vector_t dst_vec_data;

            // pack data
            static_for<0, DstDataPerWrite, 1>{}([&](auto IDstData) {
                const auto dst_multi_id =
                    ids.PushBack(IWrite.Get() * DstDataPerWrite + IDstData.Get());

                const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});

180
                const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
Chao Liu's avatar
Chao Liu committed
181
182
183
184
185
186
187
188

                vector_type<Float, DstDataPerWrite>::SetScalar(
                    dst_vec_data, p_src[src_index], IDstData);
            });

            // write data
            const auto dst_multi_id = ids.PushBack(IWrite.Get() * DstDataPerWrite);

189
            const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
Chao Liu's avatar
Chao Liu committed
190
191
192
193
194

            *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) = dst_vec_data;
        });
    });
}
Chao Liu's avatar
Chao Liu committed
195
196

template <class Float, class SrcDesc, class DstDesc, class SliceLengths, class DimAccessOrder>
Chao Liu's avatar
Chao Liu committed
197
__device__ void threadwise_generic_tensor_slice_copy(
Chao Liu's avatar
Chao Liu committed
198
199
200
201
202
203
204
205
    SrcDesc,
    const Float* __restrict__ p_src,
    Array<index_t, SrcDesc::GetNumOfDimension()> src_multi_id_begin,
    DstDesc,
    Float* __restrict__ p_dst,
    Array<index_t, DstDesc::GetNumOfDimension()> dst_multi_id_begin,
    SliceLengths,
    DimAccessOrder)
Chao Liu's avatar
Chao Liu committed
206
{
Chao Liu's avatar
Chao Liu committed
207
208
209
210
    constexpr index_t nDim = SrcDesc::GetNumOfDimension();

    static_assert(nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
                      nDim == SliceLengths::GetSize() && nDim == DimAccessOrder::GetSize(),
Chao Liu's avatar
Chao Liu committed
211
212
                  "wrong! # of dimensions not the same");

Chao Liu's avatar
Chao Liu committed
213
    static_assert(is_valid_sequence_map<DimAccessOrder>::value, "wrong! map is not valid");
214
215

    constexpr auto slice_lengths_in_access_order =
Chao Liu's avatar
Chao Liu committed
216
        SliceLengths::ReorderGivenNew2Old(DimAccessOrder{});
217

Chao Liu's avatar
Chao Liu committed
218
#if 1
219
220
221
222
223
    ford<decltype(slice_lengths_in_access_order)>{}([&](auto data_multi_id_in_access_order) {
        const auto data_multi_id =
            reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});

        const index_t src_index =
Chao Liu's avatar
Chao Liu committed
224
225
226
227
            SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);

        const index_t dst_index =
            DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);
228
229
230

        p_dst[dst_index] = p_src[src_index];
    });
Chao Liu's avatar
Chao Liu committed
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
#else
    static_ford<decltype(slice_lengths_in_access_order)>{}(
        [&](auto data_multi_id_in_access_order_) {
            constexpr auto data_multi_id_in_access_order =
                sequence2array(decltype(data_multi_id_in_access_order_){});

            const auto data_multi_id =
                reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});

            const index_t src_index =
                SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);

            const index_t dst_index =
                DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);

            p_dst[dst_index] = p_src[src_index];
        });
#endif
Chao Liu's avatar
Chao Liu committed
249
}