// ConstantMergedTensorDescriptor.hpp
#ifndef CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_HPP
#define CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_HPP

Chao Liu's avatar
Chao Liu committed
4
5
#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"
Chao Liu's avatar
Chao Liu committed
6

7
8
namespace ck {

9
10
11
12
13
// OriginalTensorDesc : ConstantTensorDescriptor<...>
//     it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
//     each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
14
15
struct ConstantMergedTensorDescriptor
{
Chao Liu's avatar
Chao Liu committed
16
17
    using Type = ConstantMergedTensorDescriptor;

18
19
    static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};

Chao Liu's avatar
Chao Liu committed
20
21
    static constexpr index_t nDim         = sizeof...(OriginalDimMergeSeqs);
    static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
Chao Liu's avatar
Chao Liu committed
22
23
24

    __host__ __device__ constexpr ConstantMergedTensorDescriptor()
    {
25
26
27
28
29
30
        static_assert(nDim <= nOriginalDim, "wrong!");

        // TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
        // OriginalTensorDesc::nDim number of dimensions

        // TODO: check OriginalDimMergeSeqs contains all original dimensions
Chao Liu's avatar
Chao Liu committed
31
32

        // TODO: check there is no duplication in OriginalDimMergeSeqs
Chao Liu's avatar
Chao Liu committed
33
34
    }

Chao Liu's avatar
Chao Liu committed
35
36
37
38
39
    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
    {
        return OriginalTensorDesc{};
    }

40
41
    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

42
43
    template <index_t IDim>
    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
44
    {
45
        return std::get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
46
    }
47
48
49

    template <index_t IDim>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
50
    {
Chao Liu's avatar
Chao Liu committed
51
        return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
52
    }
Chao Liu's avatar
Chao Liu committed
53

54
55
56
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
    {
Chao Liu's avatar
Chao Liu committed
57
        constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
58

59
60
61
62
63
64
65
66
        return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
    }

    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
    {
        static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                      "wrong! stride of a merged dimension is undefined");
Chao Liu's avatar
Chao Liu committed
67

Chao Liu's avatar
Chao Liu committed
68
        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
Chao Liu's avatar
Chao Liu committed
69

70
        return OriginalTensorDesc::GetStride(Number<idim_original>{});
Chao Liu's avatar
Chao Liu committed
71
72
    }

73
    __host__ __device__ static constexpr auto GetLengths()
Chao Liu's avatar
Chao Liu committed
74
    {
Chao Liu's avatar
Chao Liu committed
75
        return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
Chao Liu's avatar
Chao Liu committed
76
77
    }

78
    __host__ __device__ static constexpr index_t GetElementSize()
Chao Liu's avatar
Chao Liu committed
79
    {
80
        return OriginalTensorDesc::GetElementSize();
Chao Liu's avatar
Chao Liu committed
81
82
    }

Chao Liu's avatar
Chao Liu committed
83
    template <class OriginalDimsPartial>
Chao Liu's avatar
Chao Liu committed
84
    struct lambda_1_GetOriginalMultiIndexFromMultiIndex
Chao Liu's avatar
Chao Liu committed
85
    {
Chao Liu's avatar
Chao Liu committed
86
87
88
89
90
91
92
93
        const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial;
        Array<index_t, nOriginalDim>& original_multi_id;

        __host__ __device__ constexpr lambda_1_GetOriginalMultiIndexFromMultiIndex(
            const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial_,
            Array<index_t, nOriginalDim>& original_multi_id_)
            : original_multi_id_partial(original_multi_id_partial_),
              original_multi_id(original_multi_id_)
Chao Liu's avatar
Chao Liu committed
94
95
96
97
        {
        }

        template <index_t I>
Chao Liu's avatar
Chao Liu committed
98
        __host__ __device__ constexpr void operator()(Number<I>) const
Chao Liu's avatar
Chao Liu committed
99
100
101
        {
            constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});

Chao Liu's avatar
Chao Liu committed
102
            index_t itmp = original_multi_id_partial[I];
Chao Liu's avatar
Chao Liu committed
103

Chao Liu's avatar
Chao Liu committed
104
            original_multi_id.Set(Number<idim_original>{}, itmp);
Chao Liu's avatar
Chao Liu committed
105
106
107
        }
    };

Chao Liu's avatar
Chao Liu committed
108
    struct lambda_0_GetOriginalMultiIndexFromMultiIndex
Chao Liu's avatar
Chao Liu committed
109
    {
Chao Liu's avatar
Chao Liu committed
110
111
        const Array<index_t, nDim>& multi_id;
        Array<index_t, nOriginalDim>& original_multi_id;
Chao Liu's avatar
Chao Liu committed
112

Chao Liu's avatar
Chao Liu committed
113
114
115
        __host__ __device__ constexpr lambda_0_GetOriginalMultiIndexFromMultiIndex(
            const Array<index_t, nDim>& multi_id_, Array<index_t, nOriginalDim>& original_multi_id_)
            : multi_id(multi_id_), original_multi_id(original_multi_id_)
Chao Liu's avatar
Chao Liu committed
116
117
118
119
        {
        }

        template <index_t IDim>
Chao Liu's avatar
Chao Liu committed
120
        __host__ __device__ constexpr void operator()(Number<IDim>) const
Chao Liu's avatar
Chao Liu committed
121
        {
Chao Liu's avatar
Chao Liu committed
122
            constexpr auto original_dims_partial = std::get<IDim>(Type::mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
123
124
125
126

            // get partial original-multi-id corresponding to this merged dimension
            const auto original_multi_id_partial =
                OriginalTensorDesc::Extract(original_dims_partial)
Chao Liu's avatar
Chao Liu committed
127
                    .GetMultiIndexFrom1dIndex(multi_id[IDim]);
Chao Liu's avatar
Chao Liu committed
128
129

            static_for<0, original_dims_partial.GetSize(), 1>{}(
Chao Liu's avatar
Chao Liu committed
130
131
                lambda_1_GetOriginalMultiIndexFromMultiIndex<decltype(original_dims_partial)>(
                    original_multi_id_partial, original_multi_id));
Chao Liu's avatar
Chao Liu committed
132
133
134
        }
    };

Chao Liu's avatar
Chao Liu committed
135
    // return type is Array<...>
Chao Liu's avatar
Chao Liu committed
136
137
138
139
140
141
    __host__ __device__ static constexpr auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        Array<index_t, nOriginalDim> original_multi_id;

        static_for<0, nDim, 1>{}(
Chao Liu's avatar
Chao Liu committed
142
            lambda_0_GetOriginalMultiIndexFromMultiIndex(multi_id, original_multi_id));
Chao Liu's avatar
Chao Liu committed
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158

        return original_multi_id;
    }

    template <index_t... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
    {
        constexpr auto multi_id = sequence2array(Sequence<Is...>{});

        constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
    }

    __host__ __device__ static constexpr index_t
    GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
159
    {
Chao Liu's avatar
Chao Liu committed
160
        auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
161

Chao Liu's avatar
Chao Liu committed
162
        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
Chao Liu's avatar
Chao Liu committed
163
164
    }

Chao Liu's avatar
Chao Liu committed
165
    template <class... Is>
Chao Liu's avatar
Chao Liu committed
166
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
Chao Liu's avatar
Chao Liu committed
167
    {
168
        return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
Chao Liu's avatar
Chao Liu committed
169
170
    }

Chao Liu's avatar
Chao Liu committed
171
    __host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
Chao Liu's avatar
Chao Liu committed
172
    {
Chao Liu's avatar
Chao Liu committed
173
        constexpr auto packed_desc = make_ConstantTensorDescriptor_packed(GetLengths());
174

Chao Liu's avatar
Chao Liu committed
175
        return packed_desc.GetMultiIndexFrom1dIndex(id);
Chao Liu's avatar
Chao Liu committed
176
177
178
    }
};

179
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
180
181
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
                                                                       OriginalDimMergeSeqs...)
Chao Liu's avatar
Chao Liu committed
182
{
183
    return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
Chao Liu's avatar
Chao Liu committed
184
}
Chao Liu's avatar
Chao Liu committed
185
186

template <class TDesc>
Chao Liu's avatar
Chao Liu committed
187
__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc)
Chao Liu's avatar
Chao Liu committed
188
{
Chao Liu's avatar
Chao Liu committed
189
    print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor());
Chao Liu's avatar
Chao Liu committed
190
}
191
192
193

} // namespace ck
#endif