ConstantMatrixDescriptor.hpp 1.95 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
#pragma once
Chao Liu's avatar
Chao Liu committed
2
#include "common.hpp"
Chao Liu's avatar
Chao Liu committed
3

Chao Liu's avatar
Chao Liu committed
4
template <index_t NRow_, index_t NCol_, index_t RowStride_>
Chao Liu's avatar
Chao Liu committed
5
6
struct ConstantMatrixDescriptor
{
Chao Liu's avatar
Chao Liu committed
7
    __host__ __device__ constexpr ConstantMatrixDescriptor()
Chao Liu's avatar
Chao Liu committed
8
    {
Chao Liu's avatar
Chao Liu committed
9
        static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
Chao Liu's avatar
Chao Liu committed
10
11
    }

Chao Liu's avatar
Chao Liu committed
12
    __host__ __device__ static constexpr index_t NRow() { return NRow_; }
Chao Liu's avatar
Chao Liu committed
13

Chao Liu's avatar
Chao Liu committed
14
    __host__ __device__ static constexpr index_t NCol() { return NCol_; }
Chao Liu's avatar
Chao Liu committed
15

Chao Liu's avatar
Chao Liu committed
16
    __host__ __device__ static constexpr index_t RowStride() { return RowStride_; }
Chao Liu's avatar
Chao Liu committed
17

Chao Liu's avatar
Chao Liu committed
18
    __host__ __device__ static constexpr auto GetLengths() { return Sequence<NRow_, NCol_>{}; }
Chao Liu's avatar
Chao Liu committed
19

Chao Liu's avatar
Chao Liu committed
20
    __host__ __device__ static constexpr index_t GetElementSize() { return NRow_ * NCol_; }
Chao Liu's avatar
Chao Liu committed
21

Chao Liu's avatar
Chao Liu committed
22
    __host__ __device__ static constexpr index_t GetElementSpace() { return NRow_ * RowStride_; }
Chao Liu's avatar
Chao Liu committed
23

Chao Liu's avatar
Chao Liu committed
24
    __host__ __device__ static index_t GetOffsetFromMultiIndex(index_t irow, index_t icol)
Chao Liu's avatar
Chao Liu committed
25
    {
Chao Liu's avatar
Chao Liu committed
26
        return irow * RowStride_ + icol;
Chao Liu's avatar
Chao Liu committed
27
28
    }

Chao Liu's avatar
Chao Liu committed
29
    template <index_t SubNRow, index_t SubNCol>
Chao Liu's avatar
Chao Liu committed
30
31
    __host__ __device__ static constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
                                                                      Number<SubNCol>)
Chao Liu's avatar
Chao Liu committed
32
    {
Chao Liu's avatar
Chao Liu committed
33
        return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride_>{};
Chao Liu's avatar
Chao Liu committed
34
35
36
    }
};

Chao Liu's avatar
Chao Liu committed
37
template <index_t NRow, index_t NCol>
Chao Liu's avatar
Chao Liu committed
38
39
40
41
42
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{
    return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}

Chao Liu's avatar
Chao Liu committed
43
template <index_t NRow, index_t NCol, index_t RowStride>
Chao Liu's avatar
Chao Liu committed
44
45
46
47
48
__host__ __device__ constexpr auto
    make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
    return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}
Chao Liu's avatar
Chao Liu committed
49
50
51
52
53
54
55
56
57
58
59

template <class TDesc>
// Prints "<s> NRow <n> NCol <n> RowStride <n>" for the descriptor type TDesc.
// The argument value is unused; only its type is inspected. TDesc is expected
// to provide static NRow(), NCol() and RowStride() accessors.
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{
    // The accessors are static, so no descriptor object is needed. Cast each
    // value explicitly so the arguments match "%u" whatever index_t's
    // signedness is (values are expected to be non-negative).
    printf("%s NRow %u NCol %u RowStride %u\n",
           s,
           static_cast<unsigned>(TDesc::NRow()),
           static_cast<unsigned>(TDesc::NCol()),
           static_cast<unsigned>(TDesc::RowStride()));
}