"include/vscode:/vscode.git/clone" did not exist on "919aeb1f52150737151f9271014025941125b56f"
ConstantMergedTensorDescriptor.hip.hpp 3.26 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"

// Compile-time descriptor of a tensor obtained by merging ranges of dimensions of an
// original tensor descriptor.
//   TensorDesc:      ConstantTensorDescriptor<...>
//   MergedDimRanges: Sequence<FirstMergedDim, LastMergedDim>, one per merged range. Both ends
//                    index dimensions of TensorDesc and are inclusive, so a range removes
//                    (LastMergedDim - FirstMergedDim) dimensions.
template <class TensorDesc, class... MergedDimRanges>
struct ConstantMergedTensorDescriptor
{
    // Validates every merged range at compile time; constructing the descriptor with an
    // ill-formed range is a compile error.
    __host__ __device__ constexpr ConstantMergedTensorDescriptor()
    {
        // Instantiating the checker once per range fires its static_asserts. The array only
        // gives the pack expansion somewhere to live; the leading element keeps it non-empty
        // when no range is given.
        const bool ranges_checked[] = {true, CheckMergedDimRange(MergedDimRanges{})...};
        (void)ranges_checked;
    }

    // Number of dimensions of the original (unmerged) tensor descriptor.
    __host__ __device__ static constexpr index_t GetNumOfOriginalDimension()
    {
        return TensorDesc::GetNumOfDimension();
    }

    // Number of dimensions after merging: the original count minus the dimensions lost by
    // collapsing each range [first, last] into a single dimension.
    __host__ __device__ static constexpr index_t GetNumOfDimension()
    {
        // one entry per range, holding (last - first); the leading 0 covers the empty pack
        const index_t num_lost_dims[] = {
            0,
            static_cast<index_t>(MergedDimRanges{}.Get(Number<1>{}) -
                                 MergedDimRanges{}.Get(Number<0>{}))...};

        index_t num_lost_dim = 0;

        for(const index_t lost : num_lost_dims)
        {
            num_lost_dim += lost;
        }

        return TensorDesc::GetNumOfDimension() - num_lost_dim;
    }

    // Declared after the functions they call: the in-class initializer of a static data member
    // can only name members that have already been declared.
    static constexpr index_t nOriginalDim = GetNumOfOriginalDimension();
    static constexpr index_t nDim         = GetNumOfDimension();

    // Not implemented yet: any use is rejected at compile time instead of flowing off the end
    // of a value-returning function.
    template <index_t IDim>
    __host__ __device__ static constexpr bool IsMergedDimension(Number<IDim>)
    {
        // the condition depends on IDim, so it is only evaluated when the function is used
        static_assert(sizeof(Number<IDim>) == 0, "wrong! IsMergedDimension is not implemented");
        return false;
    }

    // Not implemented yet: any use is rejected at compile time.
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
    {
        static_assert(sizeof(Number<IDim>) == 0, "wrong! GetLength is not implemented");
        return 0;
    }

    // Not implemented yet: any use is rejected at compile time. Once implemented it is only
    // meaningful for dimensions that are not merged.
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
    {
        static_assert(!IsMergedDimension(Number<IDim>{}),
                      "wrong! A merged dimension does not have uniform stride");
        static_assert(sizeof(Number<IDim>) == 0, "wrong! GetStride is not implemented");
        return 0;
    }

    // Not implemented yet: currently does nothing and its return type deduces to void.
    template <class... Is>
    __host__ __device__ auto MultiIndex2OriginalMultiIndex(Is... is) const
    {
        // not implemented
    }

    // Not implemented yet: currently does nothing and its return type deduces to void.
    template <class... Is>
    __host__ __device__ auto OriginalMultiIndex2MultiIndex(Is... is) const
    {
        // not implemented
    }

    private:
    // Compile-time sanity check of a single merged range; always returns true when it compiles.
    template <class MergedDimRange>
    __host__ __device__ static constexpr bool CheckMergedDimRange(MergedDimRange)
    {
        static_assert(MergedDimRange{}.GetSize() == 2,
                      "wrong! should specify first and last dimension to be merged");
        static_assert(MergedDimRange{}.Get(Number<0>{}) < GetNumOfOriginalDimension(),
                      "wrong!");
        static_assert(MergedDimRange{}.Get(Number<1>{}) < GetNumOfOriginalDimension(),
                      "wrong!");
        static_assert(MergedDimRange{}.Get(Number<0>{}) <= MergedDimRange{}.Get(Number<1>{}),
                      "wrong!");

        return true;
    }
};

template <class TensorDesc, class... MergedDimRanges>
constexpr auto make_ConstantMergedTensorDescriptor(TensorDesc, MergedDimRanges...)
{
    return ConstantMergedTensorDescriptor<TensorDesc, MergedDimRanges...>{};
}