ConstantMergedTensorDescriptor.hip.hpp 3.25 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"

// TensorDesc: ConstantTensorDescriptor<...>
// MergedDimRanges: Sequence<FirstMergedDim, LastMergedDim>
template <class TensorDesc, class... MergedDimRanges>
struct ConstantMergedTensorDescriptor
{
    static constexpr index_t nOriginalDim = GetNumOfOriginalDimension();
    static constexpr index_t nDim         = GetNumOfDimension();

    template <class... Is>
    __host__ __device__ constexpr ConstantMergedTensorDescriptor()
    {
        constexpr auto merged_dim_ranges = std::make_tuple(MergedDimRanges{}...);

        static_for<0, sizeof...(MergedDimRanges), 1>{}([&](auto I) {
            constexpr index_t i             = I.Get();
            constexpr auto merged_dim_range = std::get<i>(merged_dim_ranges);

            static_assert(merged_dim_range.GetSize() == 2,
                          "wrong! should specify first and last dimension to be merged");
            static_assert(merged_dim_range.Get(Number<0>{}) < GetNumOfUnmergedDimension(),
                          "wrong!");
            static_assert(merged_dim_range.Get(Number<1>{}) < GetNumOfUnmergedDimension(),
                          "wrong!");
            static_assert(merged_dim_range.Get(Number<0>{}) <= merged_dim_range.Get(Number<1>{}),
                          "wrong!");
        });
    }

    __host__ __device__ static constexpr index_t GetNumOfDimension()
    {
        constexpr auto merged_dim_ranges = std::make_tuple(MergedDimRanges...);

        struct f_calculate_num_of_lost_dim
        {
            __host__ __device__ constexpr index_t operator()(auto I) const
            {
                constexpr index_t i             = I.Get();
                constexpr auto merged_dim_range = std::get<i>(merged_dim_ranges);

                return merged_dim_range.Get(Number<1>{}) - merged_dim_range.Get(Number<0>{});
            }
        };

        constexpr index_t num_lost_dim = static_const_reduce_n<sizeof...(MergedDimRanges)>{}(
Chao Liu's avatar
Chao Liu committed
49
            f_calculate_num_of_lost_dim, std::plus<index_t>{});
Chao Liu's avatar
Chao Liu committed
50
51
52
53

        return TensorDesc::GetNumOfDimension() - num_lost_dim;
    }

Chao Liu's avatar
Chao Liu committed
54
55
56
57
58
    __host__ __device__ static constexpr index_t GetNumOfOriginalDimension()
    {
        return TensorDesc::GetNumOfDimension();
    }

Chao Liu's avatar
Chao Liu committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    template <index_t IDim>
    __host__ __device__ static constexpr bool IsMergedDimension(Number<IDim>)
    {
        // not implemented
    }

    template <index_t IDim>
    __host__ __device__ static constexpr bool GetLength(Number<IDim>)
    {
        // not implemented
    }

    template <index_t IDim>
    __host__ __device__ static constexpr bool GetStride(Number<IDim>)
    {
Chao Liu's avatar
Chao Liu committed
74
        static_assert(!IsMergedDimension(Number<IDim>{}, "wrong! stride of a merged dimension is undefined")
Chao Liu's avatar
Chao Liu committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
        // not implemented
    }

    template <class... Is>
    __host__ __device__ auto MultiIndex2OriginalMultiIndex(Is... is) const
    {
        // not implemented
    }

    template <class... Is>
    __host__ __device__ auto OriginalMultiIndex2MultiIndex(Is... is) const
    {
        // not implemented
    }
};

template <class TensorDesc, class... MergedDimRanges>
constexpr auto make_ConstantMergedTensorDescriptor(TensorDesc, MergedDimRanges...)
{
    return ConstantMergedTensorDescriptor<TensorDesc, MergedDimRanges...>{};
}