// ConstantMergedTensorDescriptor.hpp
#pragma once
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
// OriginalTensorDesc : ConstantTensorDescriptor<...>
//     it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
//     each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
10
11
struct ConstantMergedTensorDescriptor
{
Chao Liu's avatar
Chao Liu committed
12
13
    using Type = ConstantMergedTensorDescriptor;

14
15
    static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};

Chao Liu's avatar
Chao Liu committed
16
17
    static constexpr index_t nDim         = sizeof...(OriginalDimMergeSeqs);
    static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
Chao Liu's avatar
Chao Liu committed
18
19
20

    __host__ __device__ constexpr ConstantMergedTensorDescriptor()
    {
21
22
23
24
25
26
        static_assert(nDim <= nOriginalDim, "wrong!");

        // TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
        // OriginalTensorDesc::nDim number of dimensions

        // TODO: check OriginalDimMergeSeqs contains all original dimensions
Chao Liu's avatar
Chao Liu committed
27
28

        // TODO: check there is no duplication in OriginalDimMergeSeqs
Chao Liu's avatar
Chao Liu committed
29
30
    }

Chao Liu's avatar
Chao Liu committed
31
32
33
34
35
    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
    {
        return OriginalTensorDesc{};
    }

36
37
    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

38
39
    template <index_t IDim>
    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
40
    {
41
        return std::get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
42
    }
43
44
45

    template <index_t IDim>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
46
    {
Chao Liu's avatar
Chao Liu committed
47
        return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
48
    }
Chao Liu's avatar
Chao Liu committed
49

50
51
52
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
    {
Chao Liu's avatar
Chao Liu committed
53
        constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
54

55
56
57
58
59
60
61
62
        return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
    }

    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
    {
        static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                      "wrong! stride of a merged dimension is undefined");
Chao Liu's avatar
Chao Liu committed
63

Chao Liu's avatar
Chao Liu committed
64
        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
Chao Liu's avatar
Chao Liu committed
65

66
        return OriginalTensorDesc::GetStride(Number<idim_original>{});
Chao Liu's avatar
Chao Liu committed
67
68
    }

69
    __host__ __device__ static constexpr auto GetLengths()
Chao Liu's avatar
Chao Liu committed
70
    {
Chao Liu's avatar
Chao Liu committed
71
        return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
Chao Liu's avatar
Chao Liu committed
72
73
    }

74
    __host__ __device__ static constexpr index_t GetElementSize()
Chao Liu's avatar
Chao Liu committed
75
    {
76
        return OriginalTensorDesc::GetElementSize();
Chao Liu's avatar
Chao Liu committed
77
78
    }

Chao Liu's avatar
Chao Liu committed
79
    template <class OriginalDimsPartial>
Chao Liu's avatar
Chao Liu committed
80
    struct lambda_1_GetOriginalMultiIndexFromMultiIndex
Chao Liu's avatar
Chao Liu committed
81
    {
Chao Liu's avatar
Chao Liu committed
82
83
84
85
86
87
88
89
        const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial;
        Array<index_t, nOriginalDim>& original_multi_id;

        __host__ __device__ constexpr lambda_1_GetOriginalMultiIndexFromMultiIndex(
            const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial_,
            Array<index_t, nOriginalDim>& original_multi_id_)
            : original_multi_id_partial(original_multi_id_partial_),
              original_multi_id(original_multi_id_)
Chao Liu's avatar
Chao Liu committed
90
91
92
93
        {
        }

        template <index_t I>
Chao Liu's avatar
Chao Liu committed
94
        __host__ __device__ constexpr void operator()(Number<I>) const
Chao Liu's avatar
Chao Liu committed
95
96
97
        {
            constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});

Chao Liu's avatar
Chao Liu committed
98
            index_t itmp = original_multi_id_partial[I];
Chao Liu's avatar
Chao Liu committed
99

Chao Liu's avatar
Chao Liu committed
100
            original_multi_id.Set(Number<idim_original>{}, itmp);
Chao Liu's avatar
Chao Liu committed
101
102
103
        }
    };

Chao Liu's avatar
Chao Liu committed
104
    struct lambda_0_GetOriginalMultiIndexFromMultiIndex
Chao Liu's avatar
Chao Liu committed
105
    {
Chao Liu's avatar
Chao Liu committed
106
107
        const Array<index_t, nDim>& multi_id;
        Array<index_t, nOriginalDim>& original_multi_id;
Chao Liu's avatar
Chao Liu committed
108

Chao Liu's avatar
Chao Liu committed
109
110
111
        __host__ __device__ constexpr lambda_0_GetOriginalMultiIndexFromMultiIndex(
            const Array<index_t, nDim>& multi_id_, Array<index_t, nOriginalDim>& original_multi_id_)
            : multi_id(multi_id_), original_multi_id(original_multi_id_)
Chao Liu's avatar
Chao Liu committed
112
113
114
115
        {
        }

        template <index_t IDim>
Chao Liu's avatar
Chao Liu committed
116
        __host__ __device__ constexpr void operator()(Number<IDim>) const
Chao Liu's avatar
Chao Liu committed
117
        {
Chao Liu's avatar
Chao Liu committed
118
            constexpr auto original_dims_partial = std::get<IDim>(Type::mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
119
120
121
122

            // get partial original-multi-id corresponding to this merged dimension
            const auto original_multi_id_partial =
                OriginalTensorDesc::Extract(original_dims_partial)
Chao Liu's avatar
Chao Liu committed
123
                    .GetMultiIndexFrom1dIndex(multi_id[IDim]);
Chao Liu's avatar
Chao Liu committed
124
125

            static_for<0, original_dims_partial.GetSize(), 1>{}(
Chao Liu's avatar
Chao Liu committed
126
127
                lambda_1_GetOriginalMultiIndexFromMultiIndex<decltype(original_dims_partial)>(
                    original_multi_id_partial, original_multi_id));
Chao Liu's avatar
Chao Liu committed
128
129
130
        }
    };

Chao Liu's avatar
Chao Liu committed
131
    // return type is Array<...>
Chao Liu's avatar
Chao Liu committed
132
133
134
135
136
137
    __host__ __device__ static constexpr auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        Array<index_t, nOriginalDim> original_multi_id;

        static_for<0, nDim, 1>{}(
Chao Liu's avatar
Chao Liu committed
138
            lambda_0_GetOriginalMultiIndexFromMultiIndex(multi_id, original_multi_id));
Chao Liu's avatar
Chao Liu committed
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154

        return original_multi_id;
    }

    template <index_t... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
    {
        constexpr auto multi_id = sequence2array(Sequence<Is...>{});

        constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
    }

    __host__ __device__ static constexpr index_t
    GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
155
    {
Chao Liu's avatar
Chao Liu committed
156
        auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
157

Chao Liu's avatar
Chao Liu committed
158
        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
Chao Liu's avatar
Chao Liu committed
159
160
    }

Chao Liu's avatar
Chao Liu committed
161
    template <class... Is>
Chao Liu's avatar
Chao Liu committed
162
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
Chao Liu's avatar
Chao Liu committed
163
    {
164
        return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
Chao Liu's avatar
Chao Liu committed
165
166
    }

Chao Liu's avatar
Chao Liu committed
167
    __host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
Chao Liu's avatar
Chao Liu committed
168
    {
Chao Liu's avatar
Chao Liu committed
169
        constexpr auto packed_desc = make_ConstantTensorDescriptor_packed(GetLengths());
170

Chao Liu's avatar
Chao Liu committed
171
        return packed_desc.GetMultiIndexFrom1dIndex(id);
Chao Liu's avatar
Chao Liu committed
172
173
174
    }
};

template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
176
177
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
                                                                       OriginalDimMergeSeqs...)
Chao Liu's avatar
Chao Liu committed
178
{
179
    return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
Chao Liu's avatar
Chao Liu committed
180
}
template <class TDesc>
Chao Liu's avatar
Chao Liu committed
183
__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc)
Chao Liu's avatar
Chao Liu committed
184
{
Chao Liu's avatar
Chao Liu committed
185
    print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor());
Chao Liu's avatar
Chao Liu committed
186
}