#pragma once
#include "common.hip.hpp"

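// Strides of a fully packed (row-major, contiguous) tensor: the stride of each dimension is the
// product of all lengths to its right. For example, Lengths = Sequence<2, 3, 4> yields strides
// Sequence<12, 4, 1>.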
template <class Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
{
    return reverse_inclusive_scan_sequence(
               Lengths{}.PopFront(), mod_conv::multiplies<index_t>{}, Number<1>{})
        .PushBack(Number<1>{});
}

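// Same as calculate_tensor_strides_packed, except that the last (fastest-varying) length is first
// rounded up to a multiple of Align, so each "row" is padded to the alignment. For example,
// Lengths = Sequence<2, 3, 5> with Align = 4 is treated as if the last length were 8, yielding
// strides Sequence<24, 8, 1>.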
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
{
    constexpr index_t L_back_align =
        Align * mod_conv::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);

    return calculate_tensor_strides_packed(
        Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
}

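// Compile-time tensor descriptor. Lengths and Strides are Sequence<...> types, so the whole
// layout is known at compile time and every query below is constexpr. Lengths give the extent of
// each dimension; Strides give the distance (in elements) between neighboring indices of that
// dimension.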
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
    using Type = ConstantTensorDescriptor;

    static constexpr index_t nDim = Lengths::GetSize();

    __host__ __device__ constexpr ConstantTensorDescriptor()
    {
        static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
    }

    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor() { return Type{}; }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
    {
        return Sequence<IDim>{};
    }

    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

    __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }

    __host__ __device__ static constexpr auto GetStrides() { return Strides{}; }

    template <index_t I>
    __host__ __device__ static constexpr index_t GetLength(Number<I>)
    {
        return Lengths{}.Get(Number<I>{});
    }

    template <index_t I>
    __host__ __device__ static constexpr index_t GetStride(Number<I>)
    {
        return Strides{}.Get(Number<I>{});
    }

    __host__ __device__ static constexpr bool AreStridesNonAscending()
    {
        bool flag = true;

        static_for<0, nDim - 1, 1>{}([&](auto IDim) {
            constexpr auto IDim_p1 = Number<IDim.Get() + 1>{};

            flag = flag && (GetStride(IDim) >= GetStride(IDim_p1));
        });

        return flag;
    }

    template <class T>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T)
    {
        return false;
    }

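    // GetElementSize() is the number of elements (the product of the lengths), while
    // GetElementSpace() is the memory footprint implied by the strides,
    // 1 + sum_i (length_i - 1) * stride_i, optionally rounded up to a multiple of the given
    // alignment. The two differ whenever the layout is padded or otherwise non-packed.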
    __host__ __device__ static constexpr index_t GetElementSize()
    {
        return accumulate_on_sequence(Lengths{}, mod_conv::multiplies<index_t>{}, Number<1>{});
    }

    template <class Align = Number<1>>
    __host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{})
    {
        // This is WRONG! The alignment should be applied to the last memory rank, not to the
        // last tensor dimension.
        constexpr index_t element_space_unaligned = accumulate_on_sequence(
            (GetLengths() - Number<1>{}) * GetStrides(), mod_conv::plus<index_t>{}, Number<1>{});

        return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
    }

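    // GetOffsetFromMultiIndex computes the linear offset sum_i multi_id[i] * Strides[i].
    // Illustrative example: with strides {12, 4, 1}, the multi-index (1, 2, 3) maps to
    // offset 1 * 12 + 2 * 4 + 3 = 23.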
#if 0
    template <index_t NSize>
    __host__ __device__ static constexpr index_t
    GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
    {
        static_assert(NSize == nDim, "wrong! Dimension not consistent");

        index_t offset = 0;

        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr index_t idim = IDim.Get();
            offset += multi_id[idim] * GetStride(IDim);
        });

        return offset;
    }
#else
    template <index_t NSize>
    struct GetOffsetFromMultiIndex_impl
    {
        Array<index_t, NSize>& multi_id_ref;
        index_t& offset_ref;

        __host__ __device__ constexpr GetOffsetFromMultiIndex_impl(Array<index_t, NSize>& multi_id,
                                                                   index_t& offset)
            : multi_id_ref(multi_id), offset_ref(offset)
        {
        }

        template <index_t IDim>
        __host__ __device__ constexpr bool operator()(Number<IDim>) const
        {
            offset_ref += multi_id_ref.Get(Number<IDim>{}) * Type::GetStride(Number<IDim>{});
            return true;
        }
    };

    template <index_t NSize>
    __host__ __device__ static constexpr index_t
    GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
    {
        static_assert(NSize == nDim, "wrong! Dimension not consistent");

        index_t offset = 0;

        static_for<0, nDim, 1>{}(GetOffsetFromMultiIndex_impl<NSize>(multi_id, offset));

        return offset;
    }
#endif

    template <class... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
    {
        return GetOffsetFromMultiIndex(Array<index_t, sizeof...(Is)>{is...});
    }

    template <index_t... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
    {
        static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");

        constexpr auto multi_id = Sequence<Is...>{};

        return accumulate_on_sequence(
            multi_id * GetStrides(), mod_conv::plus<index_t>{}, Number<0>{});
    }

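    // GetMultiIndexFrom1dIndex interprets "id" as a row-major (packed) index over Lengths,
    // regardless of the actual Strides. Illustrative example: for Lengths = {2, 3, 4}, id = 17 is
    // decomposed with packed strides {12, 4, 1} into the multi-index {1, 1, 1}.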
#if 0
    __host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
    {
        Array<index_t, nDim> multi_id;

        constexpr auto dummy_strides = calculate_tensor_strides_packed(GetLengths());

        // calculate the index in each dimension, starting from the highest (slowest-varying) one
        static_for<0, nDim - 1, 1>{}([&](auto IDim) {
            constexpr index_t idim   = IDim.Get();
            constexpr index_t stride = dummy_strides.Get(Number<idim>{});
            multi_id[idim]           = id / stride;
            id -= multi_id[idim] * stride;
        });

        multi_id[nDim - 1] = id / dummy_strides.Get(Number<nDim - 1>{});

        return multi_id;
    }
#else
    struct GetMultiIndexFrom1dIndex_impl
    {
        using DummyStrides = decltype(calculate_tensor_strides_packed(GetLengths()));

        index_t& id_ref;
        Array<index_t, nDim>& multi_id_ref;

        __host__ __device__ constexpr GetMultiIndexFrom1dIndex_impl(index_t& id,
                                                                    Array<index_t, nDim>& multi_id)
            : id_ref(id), multi_id_ref(multi_id)
        {
        }

        template <index_t IDim>
        __host__ __device__ constexpr bool operator()(Number<IDim>) const
        {
            constexpr index_t stride = DummyStrides::Get(Number<IDim>{});
            multi_id_ref.Set(Number<IDim>{}, id_ref / stride);
            id_ref -= multi_id_ref.Get(Number<IDim>{}) * stride;

            return true;
        }
    };

    __host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
    {
        Array<index_t, nDim> multi_id;

        constexpr auto dummy_strides = calculate_tensor_strides_packed(GetLengths());

        // calculate the index in each dimension, starting from the highest (slowest-varying) one
        static_for<0, nDim - 1, 1>{}(GetMultiIndexFrom1dIndex_impl(id, multi_id));

        index_t itmp = id / dummy_strides.Get(Number<nDim - 1>{});

        multi_id.Set(Number<nDim - 1>{}, itmp);

        return multi_id;
    }
#endif

#if 0
    // return type is Sequence<...>
    template<index_t Id>
    __host__ __device__ static constexpr auto GetMultiIndexFrom1dIndex(Number<Id>)
    {
        return inclusive_scan_sequence(f_impl, GetStrides(), Number<Id>{});
    }
#endif

    __host__ __device__ static constexpr auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        return multi_id;
    }

    // This function doesn't do a carry check on the highest dimension, for performance reasons.
    // It is the user's responsibility to make sure the result "new_multi_id" is not out-of-bound
    // on the highest dimension.
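    // Illustrative example: for Lengths = {2, 3, 4}, old_multi_id = {0, 2, 3} and a positive 1d
    // step of 1, the packed step is {0, 0, 1}; adding it gives {0, 2, 4}, and carry propagation
    // turns that into {1, 0, 0}.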
    template <bool PositiveDirection>
    __host__ __device__ static Array<index_t, nDim>
    UpdateMultiIndexGivenStepSizeOf1dIndex(Array<index_t, nDim> old_multi_id,
                                           index_t step_size_of_1d_index,
                                           integral_constant<bool, PositiveDirection>)
    {
        Array<index_t, nDim> new_multi_id;

        const auto step_sizes = GetMultiIndexFrom1dIndex(step_size_of_1d_index);

        static_if<PositiveDirection>{}([&](auto) {
            new_multi_id = old_multi_id + step_sizes;

            bool carry = false;

            // do the carry check in reverse order, starting from the lowest dimension;
            // a carry out of the highest dimension is not handled
            static_for<0, nDim, 1>{}([&](auto IDimReverse) {
                constexpr index_t idim = nDim - 1 - IDimReverse.Get();
                constexpr auto IDim    = Number<idim>{};

                if(carry)
                {
                    ++new_multi_id[idim];
                }

                carry = false;

                if(new_multi_id[idim] >= GetLength(IDim))
                {
                    new_multi_id[idim] -= GetLength(IDim);
                    carry = true;
                }
            });
        }).Else([&](auto) {
            // shift up multi-id to avoid unsigned integer underflow during intermediate
            // calculations. After the shift, should have new_multi_id[...] >= 1
            new_multi_id = old_multi_id + (GetLengths() - step_sizes);

            bool borrow = false;

            // do the borrow check in reverse order, starting from the lowest dimension;
            // a borrow out of the highest dimension is not handled
            static_for<0, nDim, 1>{}([&](auto IDimReverse) {
                constexpr index_t idim = nDim - 1 - IDimReverse.Get();
                constexpr auto IDim    = Number<idim>{};

                if(borrow)
                {
                    --new_multi_id[idim];
                }

                borrow = false;

                if(new_multi_id[idim] < GetLength(IDim))
                {
                    new_multi_id[idim] += GetLength(IDim);
                    borrow = true;
                }
            });

            // shift back down multi-id
            // here, should have new_multi_id[...] >= GetLengths()
            new_multi_id = new_multi_id - GetLengths();
        });

        return new_multi_id;
    }

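    // Extract builds a lower-rank descriptor that keeps only the listed dimensions, with their
    // original lengths and strides. For example, Extract(Number<0>{}, Number<2>{}) on a 3-d
    // descriptor returns a 2-d descriptor over dimensions 0 and 2 of the same buffer.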
    template <index_t... IDims>
    __host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
    {
        static_assert(sizeof...(IDims) <= GetNumOfDimension(),
                      "wrong! too many dimensions to be extracted");

        using extract_lengths = decltype(Lengths::Extract(extract_dims...));
        using extract_strides = decltype(Strides::Extract(extract_dims...));

        return ConstantTensorDescriptor<extract_lengths, extract_strides>{};
    }

    template <index_t... IDims>
    __host__ __device__ static constexpr auto Extract(Sequence<IDims...>)
    {
        return Extract(Number<IDims>{}...);
    }

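    // Embed concatenates this descriptor with another one: the result has this descriptor's
    // dimensions followed by all dimensions of the given descriptor.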
    template <class... Ts>
    __host__ __device__ static constexpr auto Embed(ConstantTensorDescriptor<Ts...>)
    {
        using leaf_tensor = ConstantTensorDescriptor<Ts...>;

        return ConstantTensorDescriptor<decltype(GetLengths().Append(leaf_tensor::GetLengths())),
                                        decltype(GetStrides().Append(leaf_tensor::GetStrides()))>{};
    }

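    // Slice keeps the strides but replaces the length of dimension IDim with SliceLen, so the
    // result describes a sub-tensor of the original buffer.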
    template <index_t IDim, index_t SliceLen>
    __host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
    {
        using slice_lengths = decltype(Lengths{}.Modify(Number<IDim>{}, Number<SliceLen>{}));

        return ConstantTensorDescriptor<slice_lengths, Strides>{};
    }

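    // Fold splits dimension IDim into 1 + sizeof...(FoldIntervals) dimensions while keeping the
    // underlying memory layout. Illustrative example: folding a dimension of length 12 and
    // stride 1 with fold intervals <3, 2> replaces it with lengths {2, 3, 2} and strides
    // {6, 2, 1}.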
    template <index_t IDim, index_t... FoldIntervals>
    __host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
    {
        constexpr auto fold_intervals = Sequence<FoldIntervals...>{};

        constexpr index_t fold_intervals_product =
            accumulate_on_sequence(fold_intervals, mod_conv::multiplies<index_t>{}, Number<1>{});

        constexpr auto unfold_length = GetLength(Number<IDim>{});
        constexpr auto unfold_stride = GetStride(Number<IDim>{});

        // the length of the dimension to be folded needs to be divisible by
        // fold_intervals_product; otherwise folding is invalid
        static_assert(unfold_length % fold_intervals_product == 0,
                      "wrong! length on the dimension to be folded cannot be evenly divided!");

        // folded lengths
        constexpr auto fold_lengths =
            Sequence<unfold_length / fold_intervals_product>{}.Append(fold_intervals);

        // folded strides
        constexpr auto fold_strides =
            Number<unfold_stride>{} *
            reverse_inclusive_scan_sequence(
                fold_intervals.PushBack(Number<1>{}), mod_conv::multiplies<index_t>{}, Number<1>{});

        // left and right
        constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::SeqType{};
        constexpr auto right =
            typename arithmetic_sequence_gen<IDim + 1, GetNumOfDimension(), 1>::SeqType{};

        constexpr auto new_lengths =
            GetLengths().Extract(left).Append(fold_lengths).Append(GetLengths().Extract(right));
        constexpr auto new_strides =
            GetStrides().Extract(left).Append(fold_strides).Append(GetStrides().Extract(right));

        return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
    }

    template <index_t Threshold, index_t Delta>
    struct f_unfold_impl
    {
        __host__ __device__ constexpr index_t operator()(index_t x) const
        {
            return x > Threshold ? x - Delta : x;
        }
    };

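    // Unfold is the inverse of Fold: it merges dimensions FirstUnfoldDim..LastUnfoldDim
    // (inclusive) into a single dimension whose length is the product of the merged lengths and
    // whose stride is that of the last (fastest-varying) merged dimension. This is only
    // meaningful when the merged dimensions are packed with respect to each other.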
    template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
    __host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
    {
        static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim &&
                          FirstUnfoldDim <= LastUnfoldDim,
                      "wrong! should have FirstUnfoldDim <= LastUnfoldDim!");

#if 0 // cannot compile: compiler complains about constexpr
        // dimensions to be unfolded need to be in descending order (w.r.t. strides) and need to
        // be packed in memory; otherwise unfolding is invalid
        static_for<FirstUnfoldDim, LastUnfoldDim, 1>{}([&](auto IDim_) {
            constexpr auto IDim    = decltype(IDim_){};
            constexpr auto IDim_p1 = IDim + Number<1>{};

            // check stride
            static_assert(
                GetStride(IDim) >= GetStride(IDim_p1),
                "wrong! dimensions to be unfolded need to be in descending order w.r.t strides");

            // check if packed
            static_assert(GetStride(IDim_p1) * GetLength(IDim_p1) == GetStride(IDim),
                          "wrong! dimensions to be unfolded need to be packed");
        });
#endif

        // left, middle, and right dimension ranges
        constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::SeqType{};
        constexpr auto middle =
            typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::SeqType{};
        constexpr auto right =
            typename arithmetic_sequence_gen<LastUnfoldDim + 1, GetNumOfDimension(), 1>::SeqType{};

        // unfolded length, stride
        constexpr index_t unfold_length = accumulate_on_sequence(
            GetLengths().Extract(middle), mod_conv::multiplies<index_t>{}, Number<1>{});

        constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});

        // new lengths, strides
        constexpr auto new_lengths = GetLengths()
                                         .Extract(left)
                                         .PushBack(Number<unfold_length>{})
                                         .Append(GetLengths().Extract(right));

        constexpr auto new_strides = GetStrides()
                                         .Extract(left)
                                         .PushBack(Number<unfold_stride>{})
                                         .Append(GetStrides().Extract(right));

        return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
    }

    template <class MapNew2Old>
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
    {
        return ConstantTensorDescriptor<decltype(Lengths{}.ReorderGivenNew2Old(MapNew2Old{})),
                                        decltype(Strides{}.ReorderGivenNew2Old(MapNew2Old{}))>{};
    }

#if 0 // requires sequence_sort, which is not implemented yet
    template <class MapOld2New>
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
        return ConstantTensorDescriptor<decltype(Lengths{}.ReorderGivenOld2New(MapOld2New{})),
                                        decltype(Strides{}.ReorderGivenOld2New(MapOld2New{}))>{};
    }
#endif
};

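// Construction helpers. Illustrative usage (values are examples only):
//   constexpr auto desc_a = make_ConstantTensorDescriptor_packed(Sequence<2, 3, 4>{});
//   // desc_a: lengths {2, 3, 4}, strides {12, 4, 1}
//   constexpr auto desc_b = make_ConstantTensorDescriptor_aligned(Sequence<2, 3, 5>{}, Number<4>{});
//   // desc_b: lengths {2, 3, 5}, strides {24, 8, 1}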
template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_packed(Lengths)
{
    using Strides = decltype(calculate_tensor_strides_packed(Lengths{}));
    return ConstantTensorDescriptor<Lengths, Strides>{};
}

template <class Lengths, class Strides>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
{
    return ConstantTensorDescriptor<Lengths, Strides>{};
}

template <class Lengths, index_t Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{
    using Strides = decltype(calculate_tensor_strides_aligned(Lengths{}, Number<Align>{}));
    return ConstantTensorDescriptor<Lengths, Strides>{};
}

template <index_t... Lengths, index_t... Strides>
__host__ __device__ void
print_ConstantTensorDescriptor(const char* s,
                               ConstantTensorDescriptor<Sequence<Lengths...>, Sequence<Strides...>>)
{
    constexpr index_t ndim = sizeof...(Lengths);

    static_assert(ndim > 0 && ndim <= 10, "wrong!");

    static_if<ndim == 1>{}([&](auto) {
        printf("%s dim %u, lengths {%u}, strides {%u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 2>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 3>{}([&](auto) {
        printf(
            "%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 4>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });

    static_if<ndim == 5>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });

    static_if<ndim == 6>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });

    static_if<ndim == 7>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });

    static_if<ndim == 8>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });

    static_if<ndim == 9>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
               "%u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });

    static_if<ndim == 10>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
               "%u %u %u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });
}