// reduce_example_common.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <tuple>

#include "ck/ck.hpp"

/**
 * @brief Return the dimension indices of a rank-`Rank` tensor that are NOT in
 *        `reduceDims`, in ascending order.
 *
 * @tparam Rank          Total number of dimensions.
 * @tparam NumReduceDim  Number of entries in `reduceDims`.
 * @param  reduceDims    Indices of the reduced dimensions; order does not matter.
 *                       Expected to be distinct and within [0, Rank).
 * @return The `Rank - NumReduceDim` remaining dimension indices, ascending.
 *
 * Out-of-range entries in `reduceDims` are ignored rather than causing
 * undefined behavior. If `reduceDims` contains duplicates, more dimensions
 * remain than the result can hold. In that case only the first
 * `Rank - NumReduceDim` are stored and nothing is written past the end.
 * Any slot that stays unfilled is 0.
 */
template <int Rank, int NumReduceDim>
inline std::array<int, Rank - NumReduceDim>
get_invariant_dims(const std::array<int, NumReduceDim>& reduceDims)
{
    // Flag the reduced dimensions. A per-dimension bool array, rather than an
    // int bitmask, avoids undefined shifts for negative indices and lifts the
    // implicit Rank <= 31 limit of `1 << dim`.
    std::array<bool, Rank> isReduced{};
    for(const int d : reduceDims)
    {
        if(d >= 0 && d < Rank)
            isReduced[d] = true;
    }

    // Collect the invariant dimensions in ascending order. The result is
    // value-initialised so no slot is left indeterminate. The write is
    // bounds-checked so duplicate reduce dims cannot overflow the array.
    constexpr int numInvariantDim = Rank - NumReduceDim;
    std::array<int, Rank - NumReduceDim> invariantDims{};
    int dim = 0;
    for(int i = 0; i < Rank; ++i)
    {
        if(!isReduced[i] && dim < numInvariantDim)
            invariantDims[dim++] = i;
    }

    return invariantDims;
}

/// Compile-time tag describing a reduction problem shape: the tensor rank and
/// how many of its dimensions are reduced. Both template arguments are
/// re-exported as static constants so they can be read back from the type.
template <ck::index_t Rank, ck::index_t NumReduceDim>
struct ReduceShape
{
    static constexpr ck::index_t NumReduceDim_{NumReduceDim};
    static constexpr ck::index_t Rank_{Rank};
};

// Every (Rank, NumReduceDim) combination, as a std::tuple of ReduceShape tags,
// for Rank in {3, 4, 5} and NumReduceDim in [1, Rank - 1]. At least one
// dimension is therefore always left un-reduced.
// NOTE(review): presumably iterated over at compile time by the reduce
// examples/tests to instantiate each shape; confirm against the callers.
using reduce_shape_instances = std::tuple<ReduceShape<3, 1>,
                                          ReduceShape<3, 2>,
                                          ReduceShape<4, 1>,
                                          ReduceShape<4, 2>,
                                          ReduceShape<4, 3>,
                                          ReduceShape<5, 1>,
                                          ReduceShape<5, 2>,
                                          ReduceShape<5, 3>,
                                          ReduceShape<5, 4>>;