ConstantTensorDescriptor.hpp 18.2 KB
Newer Older
1
2
3
#ifndef CK_CONSTANT_TENSOR_DESCRIPTOR_HPP
#define CK_CONSTANT_TENSOR_DESCRIPTOR_HPP

Chao Liu's avatar
Chao Liu committed
4
#include "common.hpp"
Chao Liu's avatar
Chao Liu committed
5

6
7
namespace ck {

8
template <class Lengths>
Chao Liu's avatar
Chao Liu committed
9
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
10
{
Chao Liu's avatar
Chao Liu committed
11
    return reverse_inclusive_scan_sequence(
12
               Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
13
        .PushBack(Number<1>{});
14
15
}

16
template <class Lengths, index_t Align>
Chao Liu's avatar
Chao Liu committed
17
__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
Chao Liu's avatar
Chao Liu committed
18
{
19
    constexpr index_t L_back_align =
20
        Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);
Chao Liu's avatar
Chao Liu committed
21

Chao Liu's avatar
Chao Liu committed
22
    return calculate_tensor_strides_packed(
23
        Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
24
25
}

Chao Liu's avatar
Chao Liu committed
26
template <class Lengths, class Strides>
Chao Liu's avatar
Chao Liu committed
27
28
struct ConstantTensorDescriptor
{
Chao Liu's avatar
Chao Liu committed
29
30
    using Type = ConstantTensorDescriptor;

31
    static constexpr index_t nDim = Lengths::GetSize();
Chao Liu's avatar
Chao Liu committed
32
33
34

    __host__ __device__ constexpr ConstantTensorDescriptor()
    {
Chao Liu's avatar
Chao Liu committed
35
        static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
Chao Liu's avatar
Chao Liu committed
36
37
    }

38
39
40
41
42
43
44
45
    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor() { return Type{}; }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
    {
        return Sequence<IDim>{};
    }

Chao Liu's avatar
Chao Liu committed
46
    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
Chao Liu's avatar
Chao Liu committed
47

48
    __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
Chao Liu's avatar
Chao Liu committed
49

50
51
    __host__ __device__ static constexpr auto GetStrides() { return Strides{}; }

Chao Liu's avatar
Chao Liu committed
52
    template <index_t I>
53
    __host__ __device__ static constexpr index_t GetLength(Number<I>)
Chao Liu's avatar
Chao Liu committed
54
    {
Chao Liu's avatar
Chao Liu committed
55
        return Lengths::Get(Number<I>{});
Chao Liu's avatar
Chao Liu committed
56
57
    }

Chao Liu's avatar
Chao Liu committed
58
    template <index_t I>
59
    __host__ __device__ static constexpr index_t GetStride(Number<I>)
Chao Liu's avatar
Chao Liu committed
60
    {
Chao Liu's avatar
Chao Liu committed
61
        return Strides::Get(Number<I>{});
Chao Liu's avatar
Chao Liu committed
62
63
    }

Chao Liu's avatar
Chao Liu committed
64
    struct lambda_AreDimensionsContinuous
65
    {
Chao Liu's avatar
Chao Liu committed
66
        bool& is_continuous;
67

Chao Liu's avatar
Chao Liu committed
68
69
70
71
        __host__ __device__ constexpr lambda_AreDimensionsContinuous(bool& is_continuous_)
            : is_continuous(is_continuous_)
        {
        }
72

Chao Liu's avatar
Chao Liu committed
73
74
        template <index_t IDim_>
        __host__ __device__ constexpr void operator()(Number<IDim_>) const
Chao Liu's avatar
Chao Liu committed
75
        {
Chao Liu's avatar
Chao Liu committed
76
77
            constexpr auto IDim    = Number<IDim_>{};
            constexpr auto IDim_p1 = Number<IDim_ + 1>{};
Chao Liu's avatar
Chao Liu committed
78
79
80
81
82
83

            is_continuous =
                is_continuous && (GetStride(IDim) >= GetStride(IDim_p1) &&
                                  GetStride(IDim) == GetStride(IDim_p1) * GetLength(IDim_p1));
        }
    };
84

Chao Liu's avatar
Chao Liu committed
85
86
87
88
89
90
91
92
93
94
95
96
    __host__ __device__ static constexpr bool AreDimensionsContinuous()
    {
        bool is_continuous = true;

        static_for<0, nDim - 1, 1>{}(lambda_AreDimensionsContinuous(is_continuous));

        return is_continuous;
    }

    __host__ __device__ static constexpr bool IsPackedTensor()
    {
        return AreDimensionsContinuous() && GetStride(Number<nDim - 1>{}) == 1;
97
98
    }

Chao Liu's avatar
Chao Liu committed
99
100
101
102
103
104
    template <class T>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T)
    {
        return false;
    }

105
    __host__ __device__ static constexpr index_t GetElementSize()
Chao Liu's avatar
Chao Liu committed
106
    {
107
        return accumulate_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{});
108
    }
109

Chao Liu's avatar
Chao Liu committed
110
    template <class Align = Number<1>>
111
    __host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{})
Chao Liu's avatar
Chao Liu committed
112
    {
Chao Liu's avatar
Chao Liu committed
113
114
        // This is WRONG! align shouldbe applied to the last memory rank, not the last tensor
        // dimension
Chao Liu's avatar
Chao Liu committed
115
        constexpr index_t element_space_unaligned = accumulate_on_sequence(
116
            (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
Chao Liu's avatar
Chao Liu committed
117
118

        return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
Chao Liu's avatar
Chao Liu committed
119
    }
Chao Liu's avatar
Chao Liu committed
120

Chao Liu's avatar
Chao Liu committed
121
    // emulate constexpr lambda
122
    template <index_t NSize>
Chao Liu's avatar
Chao Liu committed
123
    struct lambda_GetOffsetFromMultiIndex
Chao Liu's avatar
Chao Liu committed
124
    {
Chao Liu's avatar
Chao Liu committed
125
126
        Array<index_t, NSize>& multi_id;
        index_t& offset;
Chao Liu's avatar
Chao Liu committed
127

Chao Liu's avatar
Chao Liu committed
128
129
130
131
        __host__
            __device__ constexpr lambda_GetOffsetFromMultiIndex(Array<index_t, NSize>& multi_id_,
                                                                index_t& offset_)
            : multi_id(multi_id_), offset(offset_)
Chao Liu's avatar
Chao Liu committed
132
133
134
        {
        }

Chao Liu's avatar
Chao Liu committed
135
136
        template <class X>
        __host__ __device__ constexpr void operator()(X IDim) const
Chao Liu's avatar
Chao Liu committed
137
        {
Chao Liu's avatar
Chao Liu committed
138
            offset += multi_id[IDim] * Type::GetStride(IDim);
Chao Liu's avatar
Chao Liu committed
139
140
141
142
143
144
145
146
147
148
149
        }
    };

    template <index_t NSize>
    __host__ __device__ static constexpr index_t
    GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
    {
        static_assert(NSize == nDim, "wrong! Dimension not consistent");

        index_t offset = 0;

Chao Liu's avatar
Chao Liu committed
150
        static_for<0, nDim, 1>{}(lambda_GetOffsetFromMultiIndex<NSize>(multi_id, offset));
Chao Liu's avatar
Chao Liu committed
151
152
153

        return offset;
    }
154

155
    template <class... Is>
Chao Liu's avatar
Chao Liu committed
156
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
157
    {
158
        return GetOffsetFromMultiIndex(Array<index_t, sizeof...(Is)>{is...});
159
160
    }

161
    template <index_t... Is>
162
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
163
164
165
    {
        static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");

Chao Liu's avatar
Chao Liu committed
166
167
        constexpr auto multi_id = Sequence<Is...>{};

168
        return accumulate_on_sequence(multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{});
169
170
    }

Chao Liu's avatar
Chao Liu committed
171
172
173
    // emulate constexpr lambda
    template <class PackedStrides>
    struct lambda_GetMultiIndexFrom1dIndex
Chao Liu's avatar
Chao Liu committed
174
    {
Chao Liu's avatar
Chao Liu committed
175
176
        index_t& id;
        Array<index_t, nDim>& multi_id;
Chao Liu's avatar
Chao Liu committed
177

Chao Liu's avatar
Chao Liu committed
178
179
180
181
        __host__
            __device__ constexpr lambda_GetMultiIndexFrom1dIndex(index_t& id_,
                                                                 Array<index_t, nDim>& multi_id_)
            : id(id_), multi_id(multi_id_)
Chao Liu's avatar
Chao Liu committed
182
183
184
        {
        }

Chao Liu's avatar
Chao Liu committed
185
186
        template <class IDim_>
        __host__ __device__ constexpr void operator()(IDim_) const
Chao Liu's avatar
Chao Liu committed
187
        {
Chao Liu's avatar
Chao Liu committed
188
            constexpr auto IDim      = IDim_{};
Chao Liu's avatar
Chao Liu committed
189
190
191
            constexpr index_t stride = PackedStrides::Get(IDim);
            multi_id.Set(IDim, id / stride);
            id -= multi_id[IDim] * stride;
Chao Liu's avatar
Chao Liu committed
192
193
194
195
196
197
198
        }
    };

    __host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
    {
        Array<index_t, nDim> multi_id;

Chao Liu's avatar
Chao Liu committed
199
        using PackedStrides = decltype(calculate_tensor_strides_packed(GetLengths()));
Chao Liu's avatar
Chao Liu committed
200
201

        // calculate index in each of the dimensions in the order of their dimension
Chao Liu's avatar
Chao Liu committed
202
        static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));
Chao Liu's avatar
Chao Liu committed
203

Chao Liu's avatar
Chao Liu committed
204
        multi_id.Set(Number<nDim - 1>{}, id / PackedStrides::Get(Number<nDim - 1>{}));
Chao Liu's avatar
Chao Liu committed
205
206
207

        return multi_id;
    }
Chao Liu's avatar
Chao Liu committed
208

Chao Liu's avatar
Chao Liu committed
209
    __host__ __device__ static constexpr auto
210
211
212
213
214
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        return multi_id;
    }

Chao Liu's avatar
Chao Liu committed
215
216
217
218
    // This function doesn't do carry check on the highest dimension for positive stepping (or
    // borrow check on the lowest dimension for negative stepping) , for performance reason. It is
    // the user's responsibility to make sure the result "new_mutli_id" is not out-of-bound on the
    // highest dimension for positive stepping (or on the lowest dimension for negative stepping)
219
    template <bool PositiveDirection>
220
221
    __host__ __device__ static Array<index_t, nDim>
    UpdateMultiIndexGivenStepSizeOf1dIndex(Array<index_t, nDim> old_multi_id,
222
223
                                           index_t step_size_of_1d_index,
                                           integral_constant<bool, PositiveDirection>)
224
    {
225
226
227
228
229
230
231
232
233
234
235
        Array<index_t, nDim> new_multi_id;

        const auto step_sizes = GetMultiIndexFrom1dIndex(step_size_of_1d_index);

        static_if<PositiveDirection>{}([&](auto) {
            new_multi_id = old_multi_id + step_sizes;

            bool carry = false;

            // do carry check in reversed order, starting from lowest dimension
            // don't check the highest dimension
Chao Liu's avatar
Chao Liu committed
236
            static_for<0, nDim, 1>{}([&](auto IDimReverse) {
237
238
239
240
241
                constexpr index_t idim = nDim - 1 - IDimReverse.Get();
                constexpr auto IDim    = Number<idim>{};

                if(carry)
                {
Chao Liu's avatar
Chao Liu committed
242
                    ++new_multi_id(idim);
243
244
245
246
247
248
                }

                carry = false;

                if(new_multi_id[idim] >= GetLength(IDim))
                {
Chao Liu's avatar
Chao Liu committed
249
                    new_multi_id(idim) -= GetLength(IDim);
250
251
252
253
254
255
256
257
258
259
260
261
                    carry = true;
                }
            });
        }).Else([&](auto) {
            // shift up multi-id to avoid unsigned integer underflow during intermediate
            // calculations. After the shift, should have new_multi_id[...] >= 1
            new_multi_id = old_multi_id + (GetLengths() - step_sizes);

            bool borrow = false;

            // do borrow check in reversed order, starting from lowest dimension
            // don't check the highest dimension
Chao Liu's avatar
Chao Liu committed
262
            static_for<0, nDim, 1>{}([&](auto IDimReverse) {
263
264
265
266
267
                constexpr index_t idim = nDim - 1 - IDimReverse.Get();
                constexpr auto IDim    = Number<idim>{};

                if(borrow)
                {
Chao Liu's avatar
Chao Liu committed
268
                    --new_multi_id(idim);
269
270
271
272
273
274
                }

                borrow = false;

                if(new_multi_id[idim] < GetLength(IDim))
                {
Chao Liu's avatar
Chao Liu committed
275
                    new_multi_id(idim) += GetLength(IDim);
276
277
278
279
280
281
282
                    borrow = true;
                }
            });

            // shift back down multi-id
            // here, should have new_multi_id[...] >= GetLengths()
            new_multi_id = new_multi_id - GetLengths();
283
284
285
286
287
        });

        return new_multi_id;
    }

Chao Liu's avatar
Chao Liu committed
288
    template <index_t... IDims>
Chao Liu's avatar
Chao Liu committed
289
    __host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
Chao Liu's avatar
Chao Liu committed
290
    {
Chao Liu's avatar
Chao Liu committed
291
292
        static_assert(sizeof...(IDims) <= GetNumOfDimension(),
                      "wrong! too many number of dimensions to be extracted");
Chao Liu's avatar
Chao Liu committed
293

Chao Liu's avatar
Chao Liu committed
294
295
        using extract_lengths = decltype(Lengths::Extract(extract_dims...));
        using extract_strides = decltype(Strides::Extract(extract_dims...));
296

Chao Liu's avatar
Chao Liu committed
297
        return ConstantTensorDescriptor<extract_lengths, extract_strides>{};
Chao Liu's avatar
Chao Liu committed
298
299
    }

Chao Liu's avatar
Chao Liu committed
300
301
302
303
304
305
    template <index_t... IDims>
    __host__ __device__ static constexpr auto Extract(Sequence<IDims...>)
    {
        return Extract(Number<IDims>{}...);
    }

306
    template <class... Ts>
307
    __host__ __device__ static constexpr auto Embed(ConstantTensorDescriptor<Ts...>)
308
309
310
311
    {
        using leaf_tensor = ConstantTensorDescriptor<Ts...>;

        return ConstantTensorDescriptor<decltype(GetLengths().Append(leaf_tensor::GetLengths())),
Chao Liu's avatar
Chao Liu committed
312
                                        decltype(GetStrides().Append(leaf_tensor::GetStrides()))>{};
313
314
    }

Chao Liu's avatar
Chao Liu committed
315
316
317
    template <index_t IDim, index_t SliceLen>
    __host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
    {
318
319
        using slice_lengths = decltype(Lengths{}.Modify(Number<IDim>{}, Number<SliceLen>{}));

Chao Liu's avatar
Chao Liu committed
320
        return ConstantTensorDescriptor<slice_lengths, Strides>{};
Chao Liu's avatar
Chao Liu committed
321
322
    }

Chao Liu's avatar
Chao Liu committed
323
    template <index_t IDim, index_t... FoldIntervals>
Chao Liu's avatar
Chao Liu committed
324
    __host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
Chao Liu's avatar
Chao Liu committed
325
    {
Chao Liu's avatar
Chao Liu committed
326
327
        constexpr auto fold_intervals = Sequence<FoldIntervals...>{};

Chao Liu's avatar
Chao Liu committed
328
        constexpr index_t fold_intervals_product =
329
            accumulate_on_sequence(fold_intervals, math::multiplies<index_t>{}, Number<1>{});
Chao Liu's avatar
Chao Liu committed
330
331
332
333
334
335

        constexpr auto unfold_length = GetLength(Number<IDim>{});
        constexpr auto unfold_stride = GetStride(Number<IDim>{});

        // length of the dimension to be folded needs to be dividable by fold_interval_product,
        // otherwise, folding is invalid
Chao Liu's avatar
Chao Liu committed
336
        static_assert(unfold_length % fold_intervals_product == 0,
Chao Liu's avatar
Chao Liu committed
337
338
339
340
                      "wrong! length on the dimension to be folded cannot be evenly divided!");

        // folded lengths
        constexpr auto fold_lengths =
Chao Liu's avatar
Chao Liu committed
341
            Sequence<unfold_length / fold_intervals_product>{}.Append(fold_intervals);
Chao Liu's avatar
Chao Liu committed
342
343

        // folded strides
Chao Liu's avatar
Chao Liu committed
344
345
        constexpr auto fold_strides =
            Number<unfold_stride>{} *
Chao Liu's avatar
Chao Liu committed
346
            reverse_inclusive_scan_sequence(
347
                fold_intervals.PushBack(Number<1>{}), math::multiplies<index_t>{}, Number<1>{});
Chao Liu's avatar
Chao Liu committed
348

349
350
351
352
353
354
355
356
357
358
        // left and right
        constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::SeqType{};
        constexpr auto right =
            typename arithmetic_sequence_gen<IDim + 1, GetNumOfDimension(), 1>::SeqType{};

        constexpr auto new_lengths =
            GetLengths().Extract(left).Append(fold_lengths).Append(GetLengths().Extract(right));
        constexpr auto new_strides =
            GetStrides().Extract(left).Append(fold_strides).Append(GetStrides().Extract(right));

Chao Liu's avatar
Chao Liu committed
359
        return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
Chao Liu's avatar
Chao Liu committed
360
361
    }

Chao Liu's avatar
Chao Liu committed
362
    // this function unfold dimension [FirstUnfoldDim, ..., LastUnfoldDim] into 1 dimension
Chao Liu's avatar
Chao Liu committed
363
364
365
    template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
    __host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
    {
Chao Liu's avatar
Chao Liu committed
366
367
368
369
        static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim &&
                          FirstUnfoldDim <= LastUnfoldDim,
                      "wrong! should have FirstUnfoldDim <= LastUnfoldDim!");

Chao Liu's avatar
Chao Liu committed
370
        // left and right
371
372
373
374
375
376
        constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::SeqType{};
        constexpr auto middle =
            typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::SeqType{};
        constexpr auto right =
            typename arithmetic_sequence_gen<LastUnfoldDim + 1, GetNumOfDimension(), 1>::SeqType{};

Chao Liu's avatar
Chao Liu committed
377
378
379
        // dimensions to be unfolded need to be continuous
        static_assert(Type::Extract(middle).AreDimensionsContinuous(), "wrong! not unfoldable");

Chao Liu's avatar
Chao Liu committed
380
        // unfolded length, stride
Chao Liu's avatar
Chao Liu committed
381
        constexpr index_t unfold_length = accumulate_on_sequence(
382
            GetLengths().Extract(middle), math::multiplies<index_t>{}, Number<1>{});
Chao Liu's avatar
Chao Liu committed
383
384
385

        constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});

Chao Liu's avatar
Chao Liu committed
386
        // new lengths, strides
387
388
389
390
391
392
393
394
395
396
        constexpr auto new_lengths = GetLengths()
                                         .Extract(left)
                                         .PushBack(Number<unfold_length>{})
                                         .Append(GetLengths().Extract(right));

        constexpr auto new_strides = GetStrides()
                                         .Extract(left)
                                         .PushBack(Number<unfold_stride>{})
                                         .Append(GetStrides().Extract(right));

Chao Liu's avatar
Chao Liu committed
397
        return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
398
399
400
401
402
    }

    template <class MapNew2Old>
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
    {
Chao Liu's avatar
Chao Liu committed
403
404
        return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenNew2Old(MapNew2Old{})),
                                        decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
Chao Liu's avatar
Chao Liu committed
405
406
    }

407
408
409
#if 0 // require sequence_sort, which is not implemented yet
    template <class MapOld2New>
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
Chao Liu's avatar
Chao Liu committed
410
    {
Chao Liu's avatar
Chao Liu committed
411
412
        return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
                                        decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{}
Chao Liu's avatar
Chao Liu committed
413
    }
414
#endif
Chao Liu's avatar
Chao Liu committed
415
};
Chao Liu's avatar
Chao Liu committed
416
417

template <class Lengths>
Chao Liu's avatar
Chao Liu committed
418
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_packed(Lengths)
Chao Liu's avatar
Chao Liu committed
419
{
Chao Liu's avatar
Chao Liu committed
420
421
    using Strides = decltype(calculate_tensor_strides_packed(Lengths{}));
    return ConstantTensorDescriptor<Lengths, Strides>{};
Chao Liu's avatar
Chao Liu committed
422
423
424
}

template <class Lengths, class Strides>
Chao Liu's avatar
Chao Liu committed
425
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
Chao Liu's avatar
Chao Liu committed
426
{
Chao Liu's avatar
Chao Liu committed
427
    return ConstantTensorDescriptor<Lengths, Strides>{};
Chao Liu's avatar
Chao Liu committed
428
429
}

Chao Liu's avatar
Chao Liu committed
430
template <class Lengths, index_t Align>
Chao Liu's avatar
Chao Liu committed
431
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
Chao Liu's avatar
Chao Liu committed
432
{
Chao Liu's avatar
Chao Liu committed
433
434
    using Strides = decltype(calculate_tensor_strides_aligned(Lengths{}, Number<Align>{}));
    return ConstantTensorDescriptor<Lengths, Strides>{};
Chao Liu's avatar
Chao Liu committed
435
436
}

Chao Liu's avatar
Chao Liu committed
437
438
439
440
template <index_t... Lengths, index_t... Strides>
__host__ __device__ void
print_ConstantTensorDescriptor(const char* s,
                               ConstantTensorDescriptor<Sequence<Lengths...>, Sequence<Strides...>>)
Chao Liu's avatar
Chao Liu committed
441
{
Chao Liu's avatar
Chao Liu committed
442
    constexpr index_t ndim = sizeof...(Lengths);
443

Chao Liu's avatar
Chao Liu committed
444
    static_assert(ndim > 0 && ndim <= 10, "wrong!");
445

Chao Liu's avatar
Chao Liu committed
446
447
    static_if<ndim == 1>{}([&](auto) {
        printf("%s dim %u, lengths {%u}, strides {%u}\n", s, ndim, Lengths..., Strides...);
448
    });
Chao Liu's avatar
Chao Liu committed
449

Chao Liu's avatar
Chao Liu committed
450
451
    static_if<ndim == 2>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, ndim, Lengths..., Strides...);
Chao Liu's avatar
Chao Liu committed
452
453
    });

Chao Liu's avatar
Chao Liu committed
454
455
456
    static_if<ndim == 3>{}([&](auto) {
        printf(
            "%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, ndim, Lengths..., Strides...);
Chao Liu's avatar
Chao Liu committed
457
458
    });

Chao Liu's avatar
Chao Liu committed
459
460
    static_if<ndim == 4>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
Chao Liu's avatar
Chao Liu committed
461
               s,
Chao Liu's avatar
Chao Liu committed
462
463
464
               ndim,
               Lengths...,
               Strides...);
Chao Liu's avatar
Chao Liu committed
465
466
    });

Chao Liu's avatar
Chao Liu committed
467
468
    static_if<ndim == 5>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
469
               s,
Chao Liu's avatar
Chao Liu committed
470
471
472
               ndim,
               Lengths...,
               Strides...);
Chao Liu's avatar
Chao Liu committed
473
474
    });

Chao Liu's avatar
Chao Liu committed
475
476
    static_if<ndim == 6>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
477
               s,
Chao Liu's avatar
Chao Liu committed
478
479
480
               ndim,
               Lengths...,
               Strides...);
Chao Liu's avatar
Chao Liu committed
481
482
    });

Chao Liu's avatar
Chao Liu committed
483
484
    static_if<ndim == 7>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
485
               s,
Chao Liu's avatar
Chao Liu committed
486
487
488
               ndim,
               Lengths...,
               Strides...);
Chao Liu's avatar
Chao Liu committed
489
490
    });

Chao Liu's avatar
Chao Liu committed
491
492
    static_if<ndim == 8>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
493
               s,
Chao Liu's avatar
Chao Liu committed
494
495
496
               ndim,
               Lengths...,
               Strides...);
Chao Liu's avatar
Chao Liu committed
497
498
    });

Chao Liu's avatar
Chao Liu committed
499
    static_if<ndim == 9>{}([&](auto) {
Chao Liu's avatar
tidy yp  
Chao Liu committed
500
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
Chao Liu's avatar
Chao Liu committed
501
               "%u}\n",
Chao Liu's avatar
Chao Liu committed
502
               s,
Chao Liu's avatar
Chao Liu committed
503
504
505
               ndim,
               Lengths...,
               Strides...);
Chao Liu's avatar
Chao Liu committed
506
507
    });

Chao Liu's avatar
Chao Liu committed
508
    static_if<ndim == 10>{}([&](auto) {
Chao Liu's avatar
tidy yp  
Chao Liu committed
509
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
Chao Liu's avatar
Chao Liu committed
510
               "%u %u %u}\n",
Chao Liu's avatar
Chao Liu committed
511
               s,
Chao Liu's avatar
Chao Liu committed
512
513
514
               ndim,
               Lengths...,
               Strides...);
Chao Liu's avatar
Chao Liu committed
515
    });
Chao Liu's avatar
Chao Liu committed
516
}
517
518
519

} // namespace ck
#endif