// constant_tensor_descriptor.cuh
#pragma once
#include "common.cuh"

// Compile-time constant: carries the value N (of type T) in a distinct type.
// The struct has no non-static members, so an instance can be passed around as
// a tag and the value recovered from its type via mValue.
template <class T, T N>
struct Constant
{
    // The wrapped value.
    static const T mValue = N;
};

// Number<Value> is shorthand for Constant<unsigned, Value>: an unsigned value
// carried in a type, used in this file as a compile-time index argument.
template <unsigned Value>
using Number = Constant<unsigned, Value>;

// Compile-time list of unsigned values.
//
// The values are the template arguments Is...; nDim is how many there are and
// mData holds a copy of them so that Get() can index the list.
template <unsigned... Is>
struct Sequence
{
    static constexpr unsigned nDim = sizeof...(Is);

    const unsigned mData[nDim] = {Is...};

    // Returns the I-th value of the sequence. I must be less than nDim; this is
    // enforced at compile time.
    template <unsigned I>
    __host__ __device__ constexpr unsigned Get(Number<I>) const
    {
        static_assert(I < nDim, "Sequence index out of range");
        return mData[I];
    }

    // Returns Sequence<Get(Js)...>: a sequence whose k-th value is the value
    // this sequence holds at position Js[k]. Accepts any number of positions
    // (the 2-, 3- and 4-argument calls behave as before).
    template <unsigned... Js>
    __host__ __device__ constexpr auto Reorder(Number<Js>...) const
    {
        // The picked values become template arguments, so they must be constant
        // expressions. `this` is not usable in a constant expression inside a
        // member function body, so index a constexpr default-constructed copy,
        // which holds exactly Is... .
        constexpr Sequence<Is...> self{};

        return Sequence<self.Get(Number<Js>{})...>{};
    }

    // Same as above, with the positions supplied as a Sequence<Js...>.
    template <unsigned... Js>
    __host__ __device__ constexpr auto Reorder(Sequence<Js...>) const
    {
        constexpr Sequence<Is...> self{};

        return Sequence<self.Get(Number<Js>{})...>{};
    }
};

// Describes a tensor layout whose per-dimension lengths and strides are
// compile-time constants.
//
// Lengths and Strides must each provide a static nDim and a constexpr
// Get(Number<I>) member (as Sequence does), and must agree on nDim.
// The descriptor has no non-static data members.
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
    static constexpr unsigned nDim = Lengths::nDim;
    using NDimConstant             = Number<nDim>;

    // Checked at class scope so a mismatch is diagnosed as soon as the
    // descriptor type is instantiated, not only when its constructor is used.
    static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");

    __host__ __device__ constexpr ConstantTensorDescriptor() {}

    // Number of dimensions.
    __host__ __device__ constexpr unsigned GetDimension() const { return nDim; }

    // Default-constructed instance of the Lengths type.
    __host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }

    // Default-constructed instance of the Strides type.
    __host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }

    // Length of dimension I.
    template <unsigned I>
    __host__ __device__ constexpr unsigned GetLength(Number<I>) const
    {
        return Lengths{}.Get(Number<I>{});
    }

    // Stride of dimension I.
    template <unsigned I>
    __host__ __device__ constexpr unsigned GetStride(Number<I>) const
    {
        return Strides{}.Get(Number<I>{});
    }

    // Product of the lengths of all dimensions. Works for any nDim; for a 4-d
    // descriptor this is GetLength(0) * GetLength(1) * GetLength(2) * GetLength(3).
    __host__ __device__ constexpr unsigned GetElementSize() const
    {
        return LengthProductFrom(Number<0>{});
    }

    // Sum over all dimensions of (length - 1) * stride, plus 1. Works for any
    // nDim; for a 4-d descriptor the result is unchanged.
    __host__ __device__ constexpr unsigned GetElementSpace() const
    {
        return MaxOffsetFrom(Number<0>{}) + 1;
    }

    // Linear offset of the element at (i0, i1, i2, i3): the sum of each index
    // multiplied by the stride of its dimension. Only available for 4-d
    // descriptors, since the signature takes exactly four indices.
    __host__ __device__ unsigned
    Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3) const
    {
        static_assert(nDim == 4, "nDim is not 4");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3);
    }

    private:
    // Recursion terminators, selected once the index reaches nDim. Being
    // non-templates, they win overload resolution over the templates below.
    __host__ __device__ static constexpr unsigned LengthProductFrom(NDimConstant) { return 1; }

    __host__ __device__ static constexpr unsigned MaxOffsetFrom(NDimConstant) { return 0; }

    // Product of the lengths of dimensions I, I+1, ..., nDim-1.
    template <unsigned I>
    __host__ __device__ static constexpr unsigned LengthProductFrom(Number<I>)
    {
        return Lengths{}.Get(Number<I>{}) * LengthProductFrom(Number<I + 1>{});
    }

    // Sum of (length - 1) * stride over dimensions I, I+1, ..., nDim-1.
    template <unsigned I>
    __host__ __device__ static constexpr unsigned MaxOffsetFrom(Number<I>)
    {
        return (Lengths{}.Get(Number<I>{}) - 1) * Strides{}.Get(Number<I>{}) +
               MaxOffsetFrom(Number<I + 1>{});
    }
};

// this is ugly, only for 4d
// Strides derived from 4 lengths: each stride is the product of all lengths to
// its right, and the last stride is 1. L0 does not affect the result.
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{
    // Accumulate from the last dimension towards the first.
    constexpr unsigned stride3 = 1;
    constexpr unsigned stride2 = L3 * stride3;
    constexpr unsigned stride1 = L2 * stride2;
    constexpr unsigned stride0 = L1 * stride1;

    return Sequence<stride0, stride1, stride2, stride3>{};
}

// this is ugly, only for 4d
// Lengths derived from 4 strides: {1, S0 / S1, S1 / S2, S2 / S3}.
// Each stride must be an exact multiple of the next one, and the divisors
// S1, S2, S3 must be non-zero; both conditions are checked at compile time.
template <unsigned S0, unsigned S1, unsigned S2, unsigned S3>
__host__ __device__ constexpr auto calculate_full_lengths(Sequence<S0, S1, S2, S3>)
{
    // Report a zero divisor with a readable message instead of leaving it to
    // the "not a constant expression" error from the modulo/division below.
    static_assert(S1 != 0 && S2 != 0 && S3 != 0, "strides S1, S2, S3 must be non-zero");

    static_assert((S0 % S1 == 0) && (S1 % S2 == 0) && (S2 % S3 == 0), "cannot be evenly divided!");

    return Sequence<1, S0 / S1, S1 / S2, S2 / S3>{};
}

// Builds a descriptor for the given lengths; the strides type is whatever
// calculate_default_strides(Lengths{}) returns.
template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths)
{
    return ConstantTensorDescriptor<Lengths, decltype(calculate_default_strides(Lengths{}))>{};
}

// Builds a descriptor from explicitly supplied lengths and strides types.
// Only the argument types are used; the argument values are ignored.
template <class Lengths, class Strides>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
{
    using Descriptor = ConstantTensorDescriptor<Lengths, Strides>;

    return Descriptor{};
}

// this is ugly, only for 4d
// Prints one line describing a 4-d descriptor: the caller-supplied label s,
// the dimension count, the four lengths and the four strides.
// TDesc must be default-constructible in a constant expression and report
// GetDimension() == 4 (checked at compile time). Only the type of the first
// argument is used.
template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{
    constexpr auto desc = TDesc{};

    static_assert(desc.GetDimension() == 4, "dim is not 4");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
           s,
           desc.GetDimension(),
           desc.GetLength(I0),
           desc.GetLength(I1),
           desc.GetLength(I2),
           desc.GetLength(I3),
           desc.GetStride(I0),
           desc.GetStride(I1),
           desc.GetStride(I2),
           desc.GetStride(I3));
}