ConstantTensorDescriptor.cuh 7.47 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
#pragma once
2
#include "common.cuh"
Chao Liu's avatar
Chao Liu committed
3

Chao Liu's avatar
Chao Liu committed
4
5
6
7
8
9
10
// Packed ("default") strides for a tensor with the given lengths: the last dimension is
// contiguous (stride 1) and every other stride is the product of all lengths to its right.
// Overloads exist for 2d, 3d and 4d lengths; the first length never affects the strides.

// 2d: {L1, 1}
template <unsigned L0, unsigned L1>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
{
    return Sequence<L1, 1>{};
}

// 3d: {L1 * L2, L2, 1}
// ConstantTensorDescriptor handles nDim == 3 (GetElementSize, Get1dIndex, ...), so this
// overload lets make_ConstantTensorDescriptor(lengths) and Condense() work for 3d as well.
template <unsigned L0, unsigned L1, unsigned L2>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2>)
{
    return Sequence<L1 * L2, L2, 1>{};
}

// 4d: {L1 * L2 * L3, L2 * L3, L3, 1}
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{
    return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
}

Chao Liu's avatar
Chao Liu committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
// Like calculate_default_strides, but the innermost (last) dimension is padded up to a
// multiple of Align elements before the outer strides are computed. The innermost stride
// stays 1. Align must be non-zero. Overloads exist for 2d, 3d and 4d lengths.

// 2d: {round_up(L1, Align), 1}
template <unsigned L0, unsigned L1, unsigned Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>,
                                                                     Number<Align>)
{
    static_assert(Align > 0, "Align must be positive");
    constexpr unsigned L1_align = Align * ((L1 + Align - 1) / Align);
    return Sequence<L1_align, 1>{};
}

// 3d: {L1 * round_up(L2, Align), round_up(L2, Align), 1}
template <unsigned L0, unsigned L1, unsigned L2, unsigned Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2>,
                                                                     Number<Align>)
{
    static_assert(Align > 0, "Align must be positive");
    constexpr unsigned L2_align = Align * ((L2 + Align - 1) / Align);
    return Sequence<L1 * L2_align, L2_align, 1>{};
}

// 4d: {L1 * L2 * round_up(L3, Align), L2 * round_up(L3, Align), round_up(L3, Align), 1}
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>,
                                                                     Number<Align>)
{
    static_assert(Align > 0, "Align must be positive");
    constexpr unsigned L3_align = Align * ((L3 + Align - 1) / Align);
    return Sequence<L1 * L2 * L3_align, L2 * L3_align, L3_align, 1>{};
}

Chao Liu's avatar
Chao Liu committed
36
37
38
39
// Tensor descriptor whose lengths and strides are encoded entirely in the template types
// Lengths and Strides. Both are read through ::nDim and Get(Number<I>), so the object itself
// carries no data. Strides are in units of elements.
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
    static constexpr unsigned nDim = Lengths::nDim;
    using NDimConstant             = Number<nDim>;

    __host__ __device__ constexpr ConstantTensorDescriptor()
    {
        static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
    }

    __host__ __device__ constexpr unsigned GetDimension() const { return nDim; }

    __host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }

    __host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }

    // Length of dimension I.
    template <unsigned I>
    __host__ __device__ constexpr unsigned GetLength(Number<I>) const
    {
        return Lengths{}.Get(Number<I>{});
    }

    // Stride of dimension I.
    template <unsigned I>
    __host__ __device__ constexpr unsigned GetStride(Number<I>) const
    {
        return Strides{}.Get(Number<I>{});
    }

    // Number of logical elements: the product of all lengths.
    // Computed by compile-time recursion over the dimensions, so it works for any nDim and
    // always has a return value (the previous if-chain fell off the end of the function).
    __host__ __device__ constexpr unsigned GetElementSize() const
    {
        return LengthProductFrom(Number<0>{});
    }

    // Size of the memory span covered by this layout, plus padding:
    //   sum_i (length_i - 1) * stride_i + align.Get()
    // With the default Align = Number<1>, this is the offset of the last element plus one.
    template <class Align = Number<1>>
    __host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
    {
        return MaxOffsetFrom(Number<0>{}) + align.Get();
    }

    // Offset of element (i0, i1); only valid for 2d descriptors.
    __host__ __device__ constexpr unsigned Get1dIndex(unsigned i0, unsigned i1) const
    {
        static_assert(nDim == 2, "nDim is not 2");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        return i0 * GetStride(I0) + i1 * GetStride(I1);
    }

    // Offset of element (i0, i1, i2); only valid for 3d descriptors.
    __host__ __device__ constexpr unsigned Get1dIndex(unsigned i0, unsigned i1, unsigned i2) const
    {
        static_assert(nDim == 3, "nDim is not 3");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2);
    }

    // Offset of element (i0, i1, i2, i3); only valid for 4d descriptors.
    __host__ __device__ constexpr unsigned
    Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3) const
    {
        static_assert(nDim == 4, "nDim is not 4");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3);
    }

    // Descriptor with the same lengths but packed default strides
    // (as computed by calculate_default_strides).
    __host__ __device__ constexpr auto Condense() const
    {
        constexpr auto default_strides = calculate_default_strides(Lengths{});
        return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
    }

    // Helpers for GetElementSize: the product of the lengths of dimensions [I, nDim).
    // The non-template overload for I == nDim is preferred by overload resolution and ends
    // the recursion.
    __host__ __device__ static constexpr unsigned LengthProductFrom(Number<nDim>) { return 1; }

    template <unsigned I>
    __host__ __device__ static constexpr unsigned LengthProductFrom(Number<I>)
    {
        return Lengths{}.Get(Number<I>{}) * LengthProductFrom(Number<I + 1>{});
    }

    // Helpers for GetElementSpace: sum over dimensions [I, nDim) of (length - 1) * stride,
    // i.e. the offset of the last element restricted to those dimensions.
    __host__ __device__ static constexpr unsigned MaxOffsetFrom(Number<nDim>) { return 0; }

    template <unsigned I>
    __host__ __device__ static constexpr unsigned MaxOffsetFrom(Number<I>)
    {
        return (Lengths{}.Get(Number<I>{}) - 1) * Strides{}.Get(Number<I>{}) +
               MaxOffsetFrom(Number<I + 1>{});
    }
};
Chao Liu's avatar
Chao Liu committed
172
173
174
175
176
177
178
179
180
181
182
183
184
185

// Builds a descriptor from explicitly supplied lengths and strides.
template <class Lengths, class Strides>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
{
    return ConstantTensorDescriptor<Lengths, Strides>{};
}

// Builds a descriptor from lengths alone, using the packed strides produced by
// calculate_default_strides. Strides is deduced from the by-value argument, so the result
// type is ConstantTensorDescriptor<Lengths, <default strides>> with no cv-qualification.
template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths)
{
    return make_ConstantTensorDescriptor(Lengths{}, calculate_default_strides(Lengths{}));
}

Chao Liu's avatar
Chao Liu committed
186
187
188
189
190
191
192
// Builds a descriptor whose strides come from calculate_default_strides_aligned, i.e. the
// innermost dimension is padded up to a multiple of Align elements.
template <class Lengths, unsigned Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{
    typedef decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{})) PaddedStrides;
    typedef ConstantTensorDescriptor<Lengths, PaddedStrides> Descriptor;

    return Descriptor{};
}

Chao Liu's avatar
Chao Liu committed
193
// Prints "<s> dim <n>, lengths {...}, strides {...}" for a 2d, 3d or 4d descriptor type.
// Uses printf, so it is callable from host and device code; in device code each calling
// thread emits its own line.
// The static_assert has always admitted 3d descriptors; a 3d branch is now provided so they
// no longer print nothing.
template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{
    constexpr auto desc     = TDesc{};
    constexpr unsigned ndim = desc.GetDimension();

    static_assert(ndim >= 2 && ndim <= 4, "wrong!");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    if(ndim == 2)
    {
        printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n",
               s,
               desc.GetDimension(),
               desc.GetLength(I0),
               desc.GetLength(I1),
               desc.GetStride(I0),
               desc.GetStride(I1));
    }
    else if(ndim == 3)
    {
        printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n",
               s,
               desc.GetDimension(),
               desc.GetLength(I0),
               desc.GetLength(I1),
               desc.GetLength(I2),
               desc.GetStride(I0),
               desc.GetStride(I1),
               desc.GetStride(I2));
    }
    else if(ndim == 4)
    {
        printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
               s,
               desc.GetDimension(),
               desc.GetLength(I0),
               desc.GetLength(I1),
               desc.GetLength(I2),
               desc.GetLength(I3),
               desc.GetStride(I0),
               desc.GetStride(I1),
               desc.GetStride(I2),
               desc.GetStride(I3));
    }
}