// reduction_operator.hpp
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2020 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_REDUCTION_OPERATOR_HPP
#define CK_REDUCTION_OPERATOR_HPP

#include "common_header.hpp"

namespace ck {

namespace reduce {

// Every binary operator used in reduction is represented by a templated functor class. Each
// functor class must provide at least three members:
// 1) GetReductionZeroVal() -- returns the "identity element" of the binary operator, i.e. the
//                    unique element of the algebraic space that leaves any other element
//                    unchanged when combined with it. The concept is similar to the zero vector
//                    of a vector space
//                    (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
// 2) indexable -- boolean value indicating whether the indices of the operated elements can be
//                 recorded. Min/Max-style operators usually need to record element indices;
//                 operators like Add/Mul do not.
// 3) operator() -- the first argument is both an input and an output, and usually holds the
//                  result accumulated over many operator() calls; the second argument is an
//                  input only. Indexable operators also provide a second overload of
//                  operator() whose third argument (an output) indicates whether the
//                  accumulated value (the first argument) has changed, in which case the
//                  recorded accumulated index must also be updated. That overload only ever
//                  sets the flag to true, so the caller must reset it to false beforehand.

template <class T>
struct Add
{
    using dataType = T;

63
    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
64

65
    __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
66
67
68
69
70
71
72
};

template <class T>
struct Mul
{
    using dataType = T;

73
    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
74

75
    __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
76
77
78
79
80
81
82
};

template <class T>
struct Max
{
    using dataType = T;

83
84
85
86
    __host__ __device__ static constexpr T GetReductionZeroVal()
    {
        return NumericLimits<T>::Lowest();
    };
87

88
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
89
90
91
92
93
    {
        if(a < b)
            a = b;
    }

94
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
95
96
97
98
99
100
101
102
103
104
105
106
107
108
    {
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
};

template <class T>
struct Min
{
    using dataType = T;

109
110
111
112
    __host__ __device__ static constexpr T GetReductionZeroVal()
    {
        return NumericLimits<T>::Max();
    };
113

114
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
115
116
117
118
119
    {
        if(a > b)
            a = b;
    }

120
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
121
122
123
124
125
126
127
128
129
    {
        if(a > b)
        {
            a       = b;
            changed = true;
        }
    }
};

130
131
template <class T>
struct AMax
132
{
133
    using dataType = T;
134

135
    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
136

137
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
138
139
140
141
142
    {
        if(a < b)
            a = b;
    }

143
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
144
145
146
147
148
149
150
    {
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
151
152
153
154
155
156
157
};

}; // end of namespace reduce

} // end of namespace ck

#endif