"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "14f45075d8a55f613b8169414b8453fae6651199"
#pragma once
#include "common.cuh"

// this is ugly, only for 4d
// Derives packed strides for a 4-d length sequence: the last dimension gets
// stride 1 and each earlier dimension's stride is the product of all lengths
// that follow it. L0 does not contribute to any stride. Only the type of the
// argument is used; the result is returned as a Sequence of the four strides.
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{
    // Accumulate from the innermost dimension outwards.
    constexpr unsigned stride3 = 1;
    constexpr unsigned stride2 = L3 * stride3;
    constexpr unsigned stride1 = L2 * stride2;
    constexpr unsigned stride0 = L1 * stride1;

    return Sequence<stride0, stride1, stride2, stride3>{};
}

// this is ugly, only for 4d
// Turns four strides into four lengths by dividing each stride by the next
// one. The leading length is fixed at 1, since nothing in the stride sequence
// determines it. Compilation fails unless every stride is an exact multiple
// of the stride that follows it.
template <unsigned S0, unsigned S1, unsigned S2, unsigned S3>
__host__ __device__ constexpr auto calculate_full_lengths(Sequence<S0, S1, S2, S3>)
{
    constexpr bool evenly_divisible = (S0 % S1 == 0) && (S1 % S2 == 0) && (S2 % S3 == 0);
    static_assert(evenly_divisible, "cannot be evenly divided!");

    constexpr unsigned length1 = S0 / S1;
    constexpr unsigned length2 = S1 / S2;
    constexpr unsigned length3 = S2 / S3;

    return Sequence<1, length1, length2, length3>{};
}

// Tensor descriptor whose lengths and strides are carried entirely in its
// template arguments, so the type holds no data members.
//
// Lengths and Strides must both expose a static `nDim` and a
// `Get(Number<I>)` accessor; the two `nDim` values have to match.
// GetElementSize, GetElementSpace and Get1dIndex are written for exactly four
// dimensions and static_assert on that.
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
    static constexpr unsigned nDim = Lengths::nDim;
    using NDimConstant             = Number<nDim>;

    __host__ __device__ constexpr ConstantTensorDescriptor()
    {
        static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
    }

    // Number of dimensions, taken from Lengths.
    __host__ __device__ constexpr unsigned GetDimension() const { return nDim; }

    // A value of the Lengths type.
    __host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }

    // A value of the Strides type.
    __host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }

    // Length of dimension I.
    template <unsigned I>
    __host__ __device__ constexpr unsigned GetLength(Number<I>) const
    {
        return Lengths{}.Get(Number<I>{});
    }

    // Stride of dimension I.
    template <unsigned I>
    __host__ __device__ constexpr unsigned GetStride(Number<I>) const
    {
        return Strides{}.Get(Number<I>{});
    }

    // this is ugly, only for 4d
    // Product of the four lengths.
    __host__ __device__ constexpr unsigned GetElementSize() const
    {
        static_assert(nDim == 4, "nDim is not 4");

        return GetLength(Number<0>{}) * GetLength(Number<1>{}) * GetLength(Number<2>{}) *
               GetLength(Number<3>{});
    }

    // this is ugly, only for 4d
    // Offset of the element whose index is (length - 1) in every dimension,
    // plus one. For a strided layout this can be larger than GetElementSize().
    __host__ __device__ constexpr unsigned GetElementSpace() const
    {
        static_assert(nDim == 4, "nDim is not 4");

        return (GetLength(Number<0>{}) - 1) * GetStride(Number<0>{}) +
               (GetLength(Number<1>{}) - 1) * GetStride(Number<1>{}) +
               (GetLength(Number<2>{}) - 1) * GetStride(Number<2>{}) +
               (GetLength(Number<3>{}) - 1) * GetStride(Number<3>{}) + 1;
    }

    // this is ugly, only for 4d
    // Linear offset of element (i0, i1, i2, i3): each index multiplied by the
    // stride of its dimension, summed. The indices are not range-checked.
    // Marked constexpr like the other accessors, so it may also be evaluated at
    // compile time when the indices are constants.
    __host__ __device__ constexpr unsigned
    Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3) const
    {
        static_assert(nDim == 4, "nDim is not 4");

        return i0 * GetStride(Number<0>{}) + i1 * GetStride(Number<1>{}) +
               i2 * GetStride(Number<2>{}) + i3 * GetStride(Number<3>{});
    }

    // Descriptor with the same lengths but with the strides that
    // calculate_default_strides derives from those lengths.
    __host__ __device__ constexpr auto Condense() const
    {
        // Take the type from the call expression, not from a constexpr local:
        // decltype of a constexpr variable is const-qualified, which would make
        // the result a different type from make_ConstantTensorDescriptor(Lengths{}).
        using DefaultStrides = decltype(calculate_default_strides(Lengths{}));
        return ConstantTensorDescriptor<Lengths, DefaultStrides>{};
    }
};

// Builds a descriptor from lengths alone; the stride type is whatever
// calculate_default_strides returns for those lengths. Only the type of the
// argument is used.
template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths)
{
    return ConstantTensorDescriptor<Lengths, decltype(calculate_default_strides(Lengths{}))>{};
}

// Builds a descriptor from explicitly supplied length and stride types. Only
// the types of the two arguments are used.
template <class Lengths, class Strides>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
{
    using Descriptor = ConstantTensorDescriptor<Lengths, Strides>;
    return Descriptor{};
}

// this is ugly, only for 4d
// Prints the label `s` followed by the dimension count, the four lengths and
// the four strides of a descriptor type on one line via printf. Only the type
// of the first argument is used; a descriptor value is default-constructed
// from it. Compilation fails unless the descriptor has exactly four dimensions.
template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{
    constexpr auto desc = TDesc{};

    // Check the dimension count before any per-dimension accessor is used.
    static_assert(desc.GetDimension() == 4, "dim is not 4");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
           s,
           desc.GetDimension(),
           desc.GetLength(I0),
           desc.GetLength(I1),
           desc.GetLength(I2),
           desc.GetLength(I3),
           desc.GetStride(I0),
           desc.GetStride(I1),
           desc.GetStride(I2),
           desc.GetStride(I3));
}