// ConstantMergedTensorDescriptor.hip.hpp
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"

// OriginalTensorDesc : ConstantTensorDescriptor<...>
//     the descriptor of the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
//     each one is a sequence of original dimensions (of OriginalTensorDesc) that are merged
//     into a single dimension
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
10
11
struct ConstantMergedTensorDescriptor
{
12
13
    static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};

Chao Liu's avatar
Chao Liu committed
14
15
    static constexpr index_t nDim         = sizeof...(OriginalDimMergeSeqs);
    static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
Chao Liu's avatar
Chao Liu committed
16
17
18

    __host__ __device__ constexpr ConstantMergedTensorDescriptor()
    {
19
20
21
22
23
24
        static_assert(nDim <= nOriginalDim, "wrong!");

        // TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
        // OriginalTensorDesc::nDim number of dimensions

        // TODO: check OriginalDimMergeSeqs contains all original dimensions
Chao Liu's avatar
Chao Liu committed
25
26

        // TODO: check there is no duplication in OriginalDimMergeSeqs
Chao Liu's avatar
Chao Liu committed
27
28
    }

Chao Liu's avatar
Chao Liu committed
29
30
31
32
33
    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
    {
        return OriginalTensorDesc{};
    }

34
35
    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

36
37
    template <index_t IDim>
    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
38
    {
39
        return std::get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
40
    }
41
42
43

    template <index_t IDim>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
44
    {
Chao Liu's avatar
Chao Liu committed
45
        return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
46
    }
Chao Liu's avatar
Chao Liu committed
47

48
49
50
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
    {
Chao Liu's avatar
Chao Liu committed
51
        constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
52

53
54
55
56
57
58
59
60
        return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
    }

    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
    {
        static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                      "wrong! stride of a merged dimension is undefined");
Chao Liu's avatar
Chao Liu committed
61

Chao Liu's avatar
Chao Liu committed
62
        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
Chao Liu's avatar
Chao Liu committed
63

64
        return OriginalTensorDesc::GetStride(Number<idim_original>{});
Chao Liu's avatar
Chao Liu committed
65
66
    }

67
    __host__ __device__ static constexpr auto GetLengths()
Chao Liu's avatar
Chao Liu committed
68
    {
Chao Liu's avatar
Chao Liu committed
69
        return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
Chao Liu's avatar
Chao Liu committed
70
71
    }

72
    __host__ __device__ static constexpr index_t GetElementSize()
Chao Liu's avatar
Chao Liu committed
73
    {
74
        return OriginalTensorDesc::GetElementSize();
Chao Liu's avatar
Chao Liu committed
75
76
    }

77
78
    __host__ __device__ static auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
79
    {
80
81
82
83
84
85
86
        Array<index_t, nOriginalDim> original_multi_id;

        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr index_t idim               = IDim.Get();
            constexpr auto original_dims_partial = std::get<idim>(mOriginalDimMergeSeqs);

            // get partial original-multi-id corresponding to this merged dimension
Chao Liu's avatar
Chao Liu committed
87
            const auto original_multi_id_partial =
88
89
90
                OriginalTensorDesc::Extract(original_dims_partial)
                    .GetMultiIndexFrom1dIndex(multi_id[idim]);

Chao Liu's avatar
Chao Liu committed
91
92
93
            static_for<0, original_dims_partial.GetSize(), 1>{}([&](auto I_) {
                constexpr auto I                = decltype(I_){};
                constexpr index_t idim_original = original_dims_partial.Get(I);
94

Chao Liu's avatar
Chao Liu committed
95
96
                original_multi_id[idim_original] = original_multi_id_partial[I.Get()];
            });
97
98
99
        });

        return original_multi_id;
Chao Liu's avatar
Chao Liu committed
100
101
    }

102
    __host__ __device__ static index_t GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
103
    {
104
105
        const auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

Chao Liu's avatar
Chao Liu committed
106
        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
Chao Liu's avatar
Chao Liu committed
107
108
    }

Chao Liu's avatar
Chao Liu committed
109
    template <class... Is>
110
    __host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is)
Chao Liu's avatar
Chao Liu committed
111
    {
112
        return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
Chao Liu's avatar
Chao Liu committed
113
114
    }

115
    __host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
Chao Liu's avatar
Chao Liu committed
116
    {
Chao Liu's avatar
Chao Liu committed
117
        constexpr auto dummy_desc = make_ConstantTensorDescriptor_packed(GetLengths());
118
119

        return dummy_desc.GetMultiIndexFrom1dIndex(id);
Chao Liu's avatar
Chao Liu committed
120
121
122
    }
};

123
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
124
125
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
                                                                       OriginalDimMergeSeqs...)
Chao Liu's avatar
Chao Liu committed
126
{
127
    return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
Chao Liu's avatar
Chao Liu committed
128
}
Chao Liu's avatar
Chao Liu committed
129
130

template <class TDesc>
Chao Liu's avatar
Chao Liu committed
131
__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc)
Chao Liu's avatar
Chao Liu committed
132
{
Chao Liu's avatar
Chao Liu committed
133
    print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor());
Chao Liu's avatar
Chao Liu committed
134
}