// ConstantMergedTensorDescriptor.hip.hpp
// Committed by: Chao Liu
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"

5
6
7
8
9
// OriginalTensorDesc : ConstantTensorDescriptor<...>
//     it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
//     each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
10
11
struct ConstantMergedTensorDescriptor
{
12
13
    static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};

Chao Liu's avatar
Chao Liu committed
14
15
    static constexpr index_t nDim         = sizeof...(OriginalDimMergeSeqs);
    static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
Chao Liu's avatar
Chao Liu committed
16
17
18

    __host__ __device__ constexpr ConstantMergedTensorDescriptor()
    {
19
20
21
22
23
24
        static_assert(nDim <= nOriginalDim, "wrong!");

        // TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
        // OriginalTensorDesc::nDim number of dimensions

        // TODO: check OriginalDimMergeSeqs contains all original dimensions
Chao Liu's avatar
Chao Liu committed
25
26

        // TODO: check there is no duplication in OriginalDimMergeSeqs
Chao Liu's avatar
Chao Liu committed
27
28
    }

Chao Liu's avatar
Chao Liu committed
29
30
31
32
33
    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
    {
        return OriginalTensorDesc{};
    }

34
35
    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

36
37
    template <index_t IDim>
    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
38
    {
39
        return std::get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
40
    }
41
42
43

    template <index_t IDim>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
44
    {
Chao Liu's avatar
Chao Liu committed
45
        return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
46
    }
Chao Liu's avatar
Chao Liu committed
47

48
49
50
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
    {
Chao Liu's avatar
Chao Liu committed
51
        constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
52

53
54
55
56
57
58
59
60
        return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
    }

    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
    {
        static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                      "wrong! stride of a merged dimension is undefined");
Chao Liu's avatar
Chao Liu committed
61

Chao Liu's avatar
Chao Liu committed
62
        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
Chao Liu's avatar
Chao Liu committed
63

64
        return OriginalTensorDesc::GetStride(Number<idim_original>{});
Chao Liu's avatar
Chao Liu committed
65
66
    }

67
    __host__ __device__ static constexpr auto GetLengths()
Chao Liu's avatar
Chao Liu committed
68
    {
Chao Liu's avatar
Chao Liu committed
69
        return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
Chao Liu's avatar
Chao Liu committed
70
71
    }

72
    __host__ __device__ static constexpr index_t GetElementSize()
Chao Liu's avatar
Chao Liu committed
73
    {
74
        return OriginalTensorDesc::GetElementSize();
Chao Liu's avatar
Chao Liu committed
75
76
    }

Chao Liu's avatar
Chao Liu committed
77
78
#if 0
    __host__ __device__ static constexpr auto
79
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
80
    {
81
82
83
84
85
86
87
        Array<index_t, nOriginalDim> original_multi_id;

        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr index_t idim               = IDim.Get();
            constexpr auto original_dims_partial = std::get<idim>(mOriginalDimMergeSeqs);

            // get partial original-multi-id corresponding to this merged dimension
Chao Liu's avatar
Chao Liu committed
88
            const auto original_multi_id_partial =
89
90
91
                OriginalTensorDesc::Extract(original_dims_partial)
                    .GetMultiIndexFrom1dIndex(multi_id[idim]);

Chao Liu's avatar
Chao Liu committed
92
93
94
            static_for<0, original_dims_partial.GetSize(), 1>{}([&](auto I_) {
                constexpr auto I                = decltype(I_){};
                constexpr index_t idim_original = original_dims_partial.Get(I);
95

Chao Liu's avatar
Chao Liu committed
96
97
                original_multi_id[idim_original] = original_multi_id_partial[I.Get()];
            });
98
99
100
        });

        return original_multi_id;
Chao Liu's avatar
Chao Liu committed
101
    }
Chao Liu's avatar
Chao Liu committed
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
#else
    template <class OriginalDimsPartial>
    struct GetOriginalMultiIndexFromMultiIndex_impl1
    {
        const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial_ref;
        Array<index_t, nOriginalDim>& original_multi_id_ref;

        __host__ __device__ constexpr GetOriginalMultiIndexFromMultiIndex_impl1(
            const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial,
            Array<index_t, nOriginalDim>& original_multi_id)
            : original_multi_id_partial_ref(original_multi_id_partial),
              original_multi_id_ref(original_multi_id)
        {
        }

        template <index_t I>
        constexpr __host__ __device__ bool operator()(Number<I>) const
        {
            constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});

            index_t itmp = original_multi_id_partial_ref.Get(Number<I>{});

            original_multi_id_ref.Set(Number<idim_original>{}, itmp);

            return true;
        }
    };

    struct GetOriginalMultiIndexFromMultiIndex_impl0
    {
        const Array<index_t, nDim>& multi_id_ref;
        Array<index_t, nOriginalDim>& original_multi_id_ref;

        __host__ __device__ constexpr GetOriginalMultiIndexFromMultiIndex_impl0(
            const Array<index_t, nDim>& multi_id, Array<index_t, nOriginalDim>& original_multi_id)
            : multi_id_ref(multi_id), original_multi_id_ref(original_multi_id)
        {
        }

        template <index_t IDim>
        constexpr __host__ __device__ bool operator()(Number<IDim>) const
        {
            constexpr auto original_dims_partial =
                std::get<IDim>(std::tuple<OriginalDimMergeSeqs...>{});

            // get partial original-multi-id corresponding to this merged dimension
            const auto original_multi_id_partial =
                OriginalTensorDesc::Extract(original_dims_partial)
                    .GetMultiIndexFrom1dIndex(multi_id_ref[IDim]);

            static_for<0, original_dims_partial.GetSize(), 1>{}(
                GetOriginalMultiIndexFromMultiIndex_impl1<decltype(original_dims_partial)>(
                    original_multi_id_partial, original_multi_id_ref));

            return true;
        }
    };

    __host__ __device__ static constexpr auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        Array<index_t, nOriginalDim> original_multi_id;

        static_for<0, nDim, 1>{}(
            GetOriginalMultiIndexFromMultiIndex_impl0(multi_id, original_multi_id));

        return original_multi_id;
    }

    template <index_t... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
    {
        constexpr auto multi_id = sequence2array(Sequence<Is...>{});

        constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
    }
#endif

#if 0
    // return type is Sequence<...>
    template <index_t... Is>
    __host__ __device__ static constexpr auto GetOriginalMultiIndexFromMultiIndex(Sequence<Is...>)
    {
        // not implemented
        return Sequence<>{};
    }
#endif
Chao Liu's avatar
Chao Liu committed
191

Chao Liu's avatar
Chao Liu committed
192
193
    __host__ __device__ static constexpr index_t
    GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
194
    {
Chao Liu's avatar
Chao Liu committed
195
        auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
196

Chao Liu's avatar
Chao Liu committed
197
        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
Chao Liu's avatar
Chao Liu committed
198
199
    }

Chao Liu's avatar
Chao Liu committed
200
    template <class... Is>
Chao Liu's avatar
Chao Liu committed
201
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
Chao Liu's avatar
Chao Liu committed
202
    {
203
        return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
Chao Liu's avatar
Chao Liu committed
204
205
    }

Chao Liu's avatar
Chao Liu committed
206
    __host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
Chao Liu's avatar
Chao Liu committed
207
    {
Chao Liu's avatar
Chao Liu committed
208
        constexpr auto dummy_desc = make_ConstantTensorDescriptor_packed(GetLengths());
209
210

        return dummy_desc.GetMultiIndexFrom1dIndex(id);
Chao Liu's avatar
Chao Liu committed
211
212
213
    }
};

214
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
215
216
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
                                                                       OriginalDimMergeSeqs...)
Chao Liu's avatar
Chao Liu committed
217
{
218
    return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
Chao Liu's avatar
Chao Liu committed
219
}
Chao Liu's avatar
Chao Liu committed
220
221

template <class TDesc>
Chao Liu's avatar
Chao Liu committed
222
__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc)
Chao Liu's avatar
Chao Liu committed
223
{
Chao Liu's avatar
Chao Liu committed
224
    print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor());
Chao Liu's avatar
Chao Liu committed
225
}