// ConstantMergedTensorDescriptor.hip.hpp
// Committed by Chao Liu
#pragma once
#include <tuple>
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"

5
6
7
8
9
// OriginalTensorDesc : ConstantTensorDescriptor<...>
//     it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
//     each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
Chao Liu's avatar
Chao Liu committed
10
11
struct ConstantMergedTensorDescriptor
{
12
13
14
15
    static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};

    static constexpr index_t nDim         = std::tuple_size<mOriginalDimMergeSeqs>::value;
    static constexpr index_t nOriginalDim = OriginalDesc::GetNumOfDimension();
Chao Liu's avatar
Chao Liu committed
16
17
18

    __host__ __device__ constexpr ConstantMergedTensorDescriptor()
    {
19
20
21
22
23
24
25
26
        static_assert(nDim <= nOriginalDim, "wrong!");

        // TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
        // OriginalTensorDesc::nDim number of dimensions

        // TODO: check there is no duplication in OriginalDimMergeSeqs

        // TODO: check OriginalDimMergeSeqs contains all original dimensions
Chao Liu's avatar
Chao Liu committed
27
28
    }

29
30
31
32
33
34
    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

    __host__ __device__ static constexpr index_t GetNumOfOriginalDimension() { return nOriginalDim }

    template <index_t IDim>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
Chao Liu's avatar
Chao Liu committed
35
    {
36
37
        return (std::Get<IDIM>(mOriginalDimMergeSeqs).GetSize() > 1);
    }
Chao Liu's avatar
Chao Liu committed
38

39
40
41
42
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
    {
        constexpr auto original_dims_partial = std::Get<IDim>(mOriginalDimMergeSeqs);
Chao Liu's avatar
Chao Liu committed
43

44
45
46
47
48
49
50
51
        return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
    }

    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
    {
        static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                      "wrong! stride of a merged dimension is undefined");
Chao Liu's avatar
Chao Liu committed
52

53
        constexpr auto idim_original = std::Get<IDim>(mOriginalDimMergeSeqs).Front();
Chao Liu's avatar
Chao Liu committed
54

55
        return OriginalTensorDesc::GetStride(Number<idim_original>{});
Chao Liu's avatar
Chao Liu committed
56
57
    }

58
    __host__ __device__ static constexpr auto GetLengths()
Chao Liu's avatar
Chao Liu committed
59
    {
60
        return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs).GetElementSize()...>{};
Chao Liu's avatar
Chao Liu committed
61
62
    }

63
    __host__ __device__ static constexpr index_t GetElementSize()
Chao Liu's avatar
Chao Liu committed
64
    {
65
        return OriginalTensorDesc::GetElementSize();
Chao Liu's avatar
Chao Liu committed
66
67
    }

68
69
    __host__ __device__ static auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
70
    {
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
        Array<index_t, nOriginalDim> original_multi_id;

        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr index_t idim               = IDim.Get();
            constexpr auto original_dims_partial = std::get<idim>(mOriginalDimMergeSeqs);

            // get partial original-multi-id corresponding to this merged dimension
            constexpr auto original_multi_id_partial =
                OriginalTensorDesc::Extract(original_dims_partial)
                    .GetMultiIndexFrom1dIndex(multi_id[idim]);

            // make sure compiler unroll this loop and propagate all the constants
            for(index_t i = 0; i < original_dims_partial.GetSize(); ++i)
            {
                index_t idim_original = original_dims_partial[i];

                original_multi_id[idim_original] = original_multi_id_partial[i]
            }
        });

        return original_multi_id;
Chao Liu's avatar
Chao Liu committed
92
93
    }

94
    __host__ __device__ static index_t GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
Chao Liu's avatar
Chao Liu committed
95
    {
96
97
98
        const auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

        return OriginalTensorDesc::GetOffsetFromMultiIndex(orginal_multi_id);
Chao Liu's avatar
Chao Liu committed
99
100
    }

101
102
    template <index_t... Is>
    __host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is)
Chao Liu's avatar
Chao Liu committed
103
    {
104
        return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
Chao Liu's avatar
Chao Liu committed
105
106
    }

107
    __host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
Chao Liu's avatar
Chao Liu committed
108
    {
109
110
111
        constexpr auto dummy_desc = make_packed_ConstantTensorDescriptor(GetLengths());

        return dummy_desc.GetMultiIndexFrom1dIndex(id);
Chao Liu's avatar
Chao Liu committed
112
113
114
    }
};

115
116
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc, OriginalDimMergeSeqs...)
Chao Liu's avatar
Chao Liu committed
117
{
118
    return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
Chao Liu's avatar
Chao Liu committed
119
}