ConstantMatrixDescriptor.hip.hpp 1.94 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
#pragma once
2
#include "common.hip.hpp"
Chao Liu's avatar
Chao Liu committed
3

Chao Liu's avatar
Chao Liu committed
4
template <unsigned NRow_, unsigned NCol_, unsigned RowStride_>
Chao Liu's avatar
Chao Liu committed
5
6
struct ConstantMatrixDescriptor
{
Chao Liu's avatar
Chao Liu committed
7
    __host__ __device__ constexpr ConstantMatrixDescriptor()
Chao Liu's avatar
Chao Liu committed
8
    {
Chao Liu's avatar
Chao Liu committed
9
        static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
Chao Liu's avatar
Chao Liu committed
10
11
    }

Chao Liu's avatar
Chao Liu committed
12
13
14
    __host__ __device__ constexpr unsigned NRow() const { return NRow_; }

    __host__ __device__ constexpr unsigned NCol() const { return NCol_; }
Chao Liu's avatar
Chao Liu committed
15

Chao Liu's avatar
Chao Liu committed
16
    __host__ __device__ constexpr unsigned RowStride() const { return RowStride_; }
Chao Liu's avatar
Chao Liu committed
17

Chao Liu's avatar
Chao Liu committed
18
    __host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; }
Chao Liu's avatar
Chao Liu committed
19

Chao Liu's avatar
Chao Liu committed
20
    __host__ __device__ constexpr unsigned GetElementSize() const { return NRow_ * NCol_; }
Chao Liu's avatar
Chao Liu committed
21

Chao Liu's avatar
Chao Liu committed
22
    __host__ __device__ constexpr unsigned GetElementSpace() const { return NRow_ * RowStride_; }
Chao Liu's avatar
Chao Liu committed
23
24
25

    __host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const
    {
Chao Liu's avatar
Chao Liu committed
26
        return irow * RowStride_ + icol;
Chao Liu's avatar
Chao Liu committed
27
28
29
30
31
32
    }

    template <unsigned SubNRow, unsigned SubNCol>
    __host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
                                                               Number<SubNCol>) const
    {
Chao Liu's avatar
Chao Liu committed
33
        return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride_>{};
Chao Liu's avatar
Chao Liu committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
    }
};

// Factory for a matrix descriptor whose row stride equals its column count,
// i.e. rows follow each other with no gap. The Number<> arguments are unnamed;
// only their template arguments are used.
template <unsigned NRow, unsigned NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{
    using PackedDescriptor = ConstantMatrixDescriptor<NRow, NCol, NCol>;

    return PackedDescriptor{};
}

// Factory for a matrix descriptor with an explicitly chosen row stride.
// The Number<> arguments are unnamed; only their template arguments are used.
template <unsigned NRow, unsigned NCol, unsigned RowStride>
__host__ __device__ constexpr auto
    make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
    using StridedDescriptor = ConstantMatrixDescriptor<NRow, NCol, RowStride>;

    return StridedDescriptor{};
}
Chao Liu's avatar
Chao Liu committed
49
50
51
52
53
54
55
56
57
58
59

// Debug helper: prints the label `s` followed by the row count, column count and
// row stride reported by a descriptor of type TDesc.
//
// The first argument is unnamed and its value is ignored; a fresh TDesc is
// default-constructed instead, so TDesc must be default-constructible and provide
// NRow(), NCol() and RowStride() returning values printable with %u.
template <class TDesc>
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{
    const auto desc = TDesc{};

    // The unused Number<0>/Number<1> locals of the earlier version were dropped:
    // nothing read them and they only produced unused-variable warnings.
    printf("%s NRow %u NCol %u RowStride %u\n", s, desc.NRow(), desc.NCol(), desc.RowStride());
}