// ConstantMergedTensorDescriptor.hip.hpp (committed by Chao Liu)
#pragma once
#include <tuple>
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"

5
6
7
8
9
// OriginalTensorDesc : ConstantTensorDescriptor<...>
//     it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
//     each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
10
11
struct ConstantMergedTensorDescriptor
{
12
13
    static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};

Chao Liu's avatar
Chao Liu committed
14
15
    static constexpr index_t nDim         = sizeof...(OriginalDimMergeSeqs);
    static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
Chao Liu's avatar
Chao Liu committed
16
17
18

    __host__ __device__ constexpr ConstantMergedTensorDescriptor()
    {
19
20
21
22
23
24
        static_assert(nDim <= nOriginalDim, "wrong!");

        // TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
        // OriginalTensorDesc::nDim number of dimensions

        // TODO: check OriginalDimMergeSeqs contains all original dimensions
Chao Liu's avatar
Chao Liu committed
25
26

        // TODO: check there is no duplication in OriginalDimMergeSeqs
Chao Liu's avatar
Chao Liu committed
27
28
    }

Chao Liu's avatar
Chao Liu committed
29
30
31
32
33
    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
    {
        return OriginalTensorDesc{};
    }

34
35
    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

36
37
    template <index_t IDim>
    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
38
    {
39
        return std::get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
40
    }
41
42
43

    template <index_t IDim>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
44
    {
Chao Liu's avatar
Chao Liu committed
45
        return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
46
    }
Chao Liu's avatar
Chao Liu committed
47

48
49
50
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
    {
Chao Liu's avatar
Chao Liu committed
51
        constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
52

53
54
55
56
57
58
59
60
        return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
    }

    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
    {
        static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                      "wrong! stride of a merged dimension is undefined");
Chao Liu's avatar
Chao Liu committed
61

Chao Liu's avatar
Chao Liu committed
62
        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
Chao Liu's avatar
Chao Liu committed
63

64
        return OriginalTensorDesc::GetStride(Number<idim_original>{});
Chao Liu's avatar
Chao Liu committed
65
66
    }

67
    __host__ __device__ static constexpr auto GetLengths()
Chao Liu's avatar
Chao Liu committed
68
    {
Chao Liu's avatar
Chao Liu committed
69
        return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
Chao Liu's avatar
Chao Liu committed
70
71
    }

72
    __host__ __device__ static constexpr index_t GetElementSize()
Chao Liu's avatar
Chao Liu committed
73
    {
74
        return OriginalTensorDesc::GetElementSize();
Chao Liu's avatar
Chao Liu committed
75
76
    }

Chao Liu's avatar
Chao Liu committed
77
78
#if 0
    __host__ __device__ static constexpr auto
79
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
80
    {
81
82
83
84
85
86
87
        Array<index_t, nOriginalDim> original_multi_id;

        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr index_t idim               = IDim.Get();
            constexpr auto original_dims_partial = std::get<idim>(mOriginalDimMergeSeqs);

            // get partial original-multi-id corresponding to this merged dimension
Chao Liu's avatar
Chao Liu committed
88
            const auto original_multi_id_partial =
89
90
91
                OriginalTensorDesc::Extract(original_dims_partial)
                    .GetMultiIndexFrom1dIndex(multi_id[idim]);

Chao Liu's avatar
Chao Liu committed
92
93
94
            static_for<0, original_dims_partial.GetSize(), 1>{}([&](auto I_) {
                constexpr auto I                = decltype(I_){};
                constexpr index_t idim_original = original_dims_partial.Get(I);
95

Chao Liu's avatar
Chao Liu committed
96
97
                original_multi_id[idim_original] = original_multi_id_partial[I.Get()];
            });
98
99
100
        });

        return original_multi_id;
Chao Liu's avatar
Chao Liu committed
101
    }
Chao Liu's avatar
Chao Liu committed
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#else
    template <class OriginalDimsPartial>
    struct GetOriginalMultiIndexFromMultiIndex_impl1
    {
        const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial_ref;
        Array<index_t, nOriginalDim>& original_multi_id_ref;

        __host__ __device__ constexpr GetOriginalMultiIndexFromMultiIndex_impl1(
            const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial,
            Array<index_t, nOriginalDim>& original_multi_id)
            : original_multi_id_partial_ref(original_multi_id_partial),
              original_multi_id_ref(original_multi_id)
        {
        }

        template <index_t I>
Chao Liu's avatar
Chao Liu committed
118
        __host__ __device__ constexpr void operator()(Number<I>) const
Chao Liu's avatar
Chao Liu committed
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
        {
            constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});

            index_t itmp = original_multi_id_partial_ref.Get(Number<I>{});

            original_multi_id_ref.Set(Number<idim_original>{}, itmp);
        }
    };

    struct GetOriginalMultiIndexFromMultiIndex_impl0
    {
        const Array<index_t, nDim>& multi_id_ref;
        Array<index_t, nOriginalDim>& original_multi_id_ref;

        __host__ __device__ constexpr GetOriginalMultiIndexFromMultiIndex_impl0(
            const Array<index_t, nDim>& multi_id, Array<index_t, nOriginalDim>& original_multi_id)
            : multi_id_ref(multi_id), original_multi_id_ref(original_multi_id)
        {
        }

        template <index_t IDim>
Chao Liu's avatar
Chao Liu committed
140
        __host__ __device__ constexpr void operator()(Number<IDim>) const
Chao Liu's avatar
Chao Liu committed
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
        {
            constexpr auto original_dims_partial =
                std::get<IDim>(std::tuple<OriginalDimMergeSeqs...>{});

            // get partial original-multi-id corresponding to this merged dimension
            const auto original_multi_id_partial =
                OriginalTensorDesc::Extract(original_dims_partial)
                    .GetMultiIndexFrom1dIndex(multi_id_ref[IDim]);

            static_for<0, original_dims_partial.GetSize(), 1>{}(
                GetOriginalMultiIndexFromMultiIndex_impl1<decltype(original_dims_partial)>(
                    original_multi_id_partial, original_multi_id_ref));
        }
    };

Chao Liu's avatar
Chao Liu committed
156
    // return type is Array<...>
Chao Liu's avatar
Chao Liu committed
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
    __host__ __device__ static constexpr auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        Array<index_t, nOriginalDim> original_multi_id;

        static_for<0, nDim, 1>{}(
            GetOriginalMultiIndexFromMultiIndex_impl0(multi_id, original_multi_id));

        return original_multi_id;
    }

    template <index_t... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
    {
        constexpr auto multi_id = sequence2array(Sequence<Is...>{});

        constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
    }
#endif

    __host__ __device__ static constexpr index_t
    GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
181
    {
Chao Liu's avatar
Chao Liu committed
182
        auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
183

Chao Liu's avatar
Chao Liu committed
184
        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
Chao Liu's avatar
Chao Liu committed
185
186
    }

Chao Liu's avatar
Chao Liu committed
187
    template <class... Is>
Chao Liu's avatar
Chao Liu committed
188
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
Chao Liu's avatar
Chao Liu committed
189
    {
190
        return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
Chao Liu's avatar
Chao Liu committed
191
192
    }

Chao Liu's avatar
Chao Liu committed
193
    __host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
Chao Liu's avatar
Chao Liu committed
194
    {
Chao Liu's avatar
Chao Liu committed
195
        constexpr auto dummy_desc = make_ConstantTensorDescriptor_packed(GetLengths());
196
197

        return dummy_desc.GetMultiIndexFrom1dIndex(id);
Chao Liu's avatar
Chao Liu committed
198
199
200
    }
};

201
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
202
203
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
                                                                       OriginalDimMergeSeqs...)
Chao Liu's avatar
Chao Liu committed
204
{
205
    return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
Chao Liu's avatar
Chao Liu committed
206
}
Chao Liu's avatar
Chao Liu committed
207
208

template <class TDesc>
Chao Liu's avatar
Chao Liu committed
209
__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc)
Chao Liu's avatar
Chao Liu committed
210
{
Chao Liu's avatar
Chao Liu committed
211
    print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor());
Chao Liu's avatar
Chao Liu committed
212
}