// ConstantMergedTensorDescriptor.hip.hpp
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"

5
6
7
8
9
// OriginalTensorDesc : ConstantTensorDescriptor<...>
//     it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
//     each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
10
11
struct ConstantMergedTensorDescriptor
{
12
13
    static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};

Chao Liu's avatar
Chao Liu committed
14
15
    static constexpr index_t nDim         = sizeof...(OriginalDimMergeSeqs);
    static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
Chao Liu's avatar
Chao Liu committed
16
17
18

    __host__ __device__ constexpr ConstantMergedTensorDescriptor()
    {
19
20
21
22
23
24
        static_assert(nDim <= nOriginalDim, "wrong!");

        // TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
        // OriginalTensorDesc::nDim number of dimensions

        // TODO: check OriginalDimMergeSeqs contains all original dimensions
Chao Liu's avatar
Chao Liu committed
25
26

        // TODO: check there is no duplication in OriginalDimMergeSeqs
Chao Liu's avatar
Chao Liu committed
27
28
    }

Chao Liu's avatar
Chao Liu committed
29
30
31
32
33
    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
    {
        return OriginalTensorDesc{};
    }

34
35
    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

36
37
    template <index_t IDim>
    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
38
    {
39
        return std::get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
40
    }
41
42
43

    template <index_t IDim>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
44
    {
Chao Liu's avatar
Chao Liu committed
45
        return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
46
    }
Chao Liu's avatar
Chao Liu committed
47

48
49
50
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
    {
Chao Liu's avatar
Chao Liu committed
51
        constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
52

53
54
55
56
57
58
59
60
        return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
    }

    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
    {
        static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                      "wrong! stride of a merged dimension is undefined");
Chao Liu's avatar
Chao Liu committed
61

Chao Liu's avatar
Chao Liu committed
62
        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
Chao Liu's avatar
Chao Liu committed
63

64
        return OriginalTensorDesc::GetStride(Number<idim_original>{});
Chao Liu's avatar
Chao Liu committed
65
66
    }

67
    __host__ __device__ static constexpr auto GetLengths()
Chao Liu's avatar
Chao Liu committed
68
    {
Chao Liu's avatar
Chao Liu committed
69
        return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
Chao Liu's avatar
Chao Liu committed
70
71
    }

72
    __host__ __device__ static constexpr index_t GetElementSize()
Chao Liu's avatar
Chao Liu committed
73
    {
74
        return OriginalTensorDesc::GetElementSize();
Chao Liu's avatar
Chao Liu committed
75
76
    }

77
78
    __host__ __device__ static auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
79
    {
80
81
82
83
84
85
86
        Array<index_t, nOriginalDim> original_multi_id;

        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr index_t idim               = IDim.Get();
            constexpr auto original_dims_partial = std::get<idim>(mOriginalDimMergeSeqs);

            // get partial original-multi-id corresponding to this merged dimension
Chao Liu's avatar
Chao Liu committed
87
            const auto original_multi_id_partial =
88
89
90
                OriginalTensorDesc::Extract(original_dims_partial)
                    .GetMultiIndexFrom1dIndex(multi_id[idim]);

Chao Liu's avatar
Chao Liu committed
91
92
93
            static_for<0, original_dims_partial.GetSize(), 1>{}([&](auto I_) {
                constexpr auto I                = decltype(I_){};
                constexpr index_t idim_original = original_dims_partial.Get(I);
94

Chao Liu's avatar
Chao Liu committed
95
96
                original_multi_id[idim_original] = original_multi_id_partial[I.Get()];
            });
97
98
99
        });

        return original_multi_id;
Chao Liu's avatar
Chao Liu committed
100
101
    }

102
103
104
105
106
107
108
109
110
#if 0 // not needed
    __host__ __device__ static index_t
    GetOffsetFromOriginalMultiIndex(Array<index_t, nOriginalDim> original_multi_id)
    {
        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
    }
#endif

    __host__ __device__ static index_t GetOffsetFromMultiIndexA(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
111
    {
112
113
        const auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

Chao Liu's avatar
Chao Liu committed
114
        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
Chao Liu's avatar
Chao Liu committed
115
116
    }

Chao Liu's avatar
Chao Liu committed
117
    template <class... Is>
118
    __host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is)
Chao Liu's avatar
Chao Liu committed
119
    {
120
        return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
Chao Liu's avatar
Chao Liu committed
121
122
    }

123
    __host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
Chao Liu's avatar
Chao Liu committed
124
    {
Chao Liu's avatar
Chao Liu committed
125
        constexpr auto dummy_desc = make_ConstantTensorDescriptor_default_rank_packed(GetLengths());
126
127

        return dummy_desc.GetMultiIndexFrom1dIndex(id);
Chao Liu's avatar
Chao Liu committed
128
    }
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160

#if 0 // not needed
    template <index_t IDim>
    __host__ __device__ static index_t GetNewOriginalMultiIndexAfterMovingAlongOneDimension(
        Array<index_t, nOriginalDim> old_original_multi_id, Number<IDim>, index_t step_size)
    {
        auto new_original_multi_id = old_original_multi_id;

        // get partial-original-multi-id corresponding to this merged dimension
        constexpr auto original_partial_dims = std::get<IDim>(mOriginalDimMergeSeqs);

        constexpr auto original_partial_tensor_desc =
            OriginalTensorDesc::Extract(original_partial_dims);

        auto old_original_partial_multi_id =
            extract_array(old_original_mutli_id, original_paritial_dims);

        auto new_original_partial_multi_id =
            original_partial_tensor_desc.GetNewMultiIndexGivenStepSizeOf1dIndex(
                old_original_partial_multi_id, step_size);

        // update original-mutli-id
        static_for<0, original_dims_partial.GetSize(), 1>{}([&](auto I_) {
            constexpr auto I                = decltype(I_){};
            constexpr index_t idim_original = original_dims_partial.Get(I);

            new_original_multi_id[idim_original] = original_multi_id_partial[I.Get()];
        });

        return new_original_multi_id;
    }
#endif
Chao Liu's avatar
Chao Liu committed
161
162
};

163
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
164
165
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
                                                                       OriginalDimMergeSeqs...)
Chao Liu's avatar
Chao Liu committed
166
{
167
    return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
Chao Liu's avatar
Chao Liu committed
168
}
Chao Liu's avatar
Chao Liu committed
169
170
171
172
173
174

// Prints a merged tensor descriptor under the label `s`.
//
// Only the underlying original descriptor is passed to
// print_ConstantTensorDescriptor; the descriptor argument's value is unused
// (only its type TDesc matters).
template <class TDesc>
__host__ __device__ void print_ConstantMergedTensorDescriptor(TDesc, const char* s)
{
    const auto original_desc = TDesc::GetOriginalTensorDescriptor();

    print_ConstantTensorDescriptor(original_desc, s);
}