ConstantMatrixDescriptor.cuh 1.55 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#pragma once
#include "common.cuh"

// Compile-time descriptor of a 2-D matrix layout.
//
// Template parameters:
//   NRow      - number of rows
//   NCol      - number of columns
//   RowStride - distance, in elements, between the starts of consecutive rows
//
// Element (irow, icol) is mapped to the linear offset irow * RowStride + icol.
// The type carries no data members; every property is derived from the
// template arguments.
template <unsigned NRow, unsigned NCol, unsigned RowStride>
struct ConstantMatrixDescriptor
{
    // The constructor is constexpr so that the type is a literal type. Without
    // it, `ConstantMatrixDescriptor<...>{}` is never a constant expression and
    // the constexpr factories returning a descriptor silently lose constexpr.
    __host__ __device__ constexpr ConstantMatrixDescriptor()
    {
        // A row must fit inside its stride, otherwise consecutive rows overlap.
        static_assert(NCol <= RowStride, "wrong! NCol > RowStride!");
    }

    // Number of rows (NRow).
    __host__ __device__ constexpr unsigned GetNumberOfRow() const { return NRow; }

    // Number of columns (NCol).
    __host__ __device__ constexpr unsigned GetNumberOfColumn() const { return NCol; }

    // Element distance between the starts of two consecutive rows (RowStride).
    __host__ __device__ constexpr unsigned GetRowStride() const { return RowStride; }

    // Number of matrix elements: NRow * NCol.
    __host__ __device__ constexpr unsigned GetElementSize() const { return NRow * NCol; }

    // Number of element slots spanned by the layout: NRow * RowStride.
    // This is >= GetElementSize() because NCol <= RowStride.
    __host__ __device__ constexpr unsigned GetElementSpace() const { return NRow * RowStride; }

    // Linear offset of element (irow, icol): irow * RowStride + icol.
    // No bounds check is performed on irow or icol.
    __host__ __device__ constexpr unsigned Get1dIndex(unsigned irow, unsigned icol) const
    {
        return irow * RowStride + icol;
    }

    // Returns a descriptor of a SubNRow x SubNCol matrix that keeps this
    // descriptor's RowStride. The Number<> arguments only carry the sizes as
    // types. SubNRow and SubNCol are not checked against NRow and NCol here;
    // the only check is the returned type's own NCol <= RowStride assertion.
    template <unsigned SubNRow, unsigned SubNCol>
    __host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
                                                               Number<SubNCol>) const
    {
        return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride>{};
    }
};

// Factory for an NRow x NCol descriptor whose row stride equals the column
// count, i.e. there is no gap between the end of one row and the start of the
// next. The Number<> arguments are unnamed: only their types are used, to
// carry the sizes at compile time.
template <unsigned NRow, unsigned NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{
    using PackedDescriptor = ConstantMatrixDescriptor<NRow, NCol, NCol>;
    return PackedDescriptor{};
}

// Factory for an NRow x NCol descriptor with an explicitly supplied row
// stride. The Number<> arguments are unnamed: only their types are used, to
// carry the sizes and the stride at compile time.
template <unsigned NRow, unsigned NCol, unsigned RowStride>
__host__ __device__ constexpr auto
    make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
    using StridedDescriptor = ConstantMatrixDescriptor<NRow, NCol, RowStride>;
    return StridedDescriptor{};
}