reduction_operator.hpp 6.16 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2020 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_REDUCTION_OPERATOR_HPP
#define CK_REDUCTION_OPERATOR_HPP

29
30
#include "config.hpp"
#include "data_type.hpp"
31
32
33
34
35
36
37
38

namespace ck {

namespace reduce {

// Every binary operator used in reduction is represented by a templated functor class. Each functor
// class must provide at least three members:
// 1) GetIdentityValue() -- returns the "identity element" of the binary operator, i.e. the unique
//    element of the algebraic space that does not change the value of other elements when
//    operated against them. The concept is similar to the zero vector in a vector space
//    (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
// 2) IsCompatibleInMemoryDataOperation() -- returns true if the reduction task corresponding to
//    this operator can use the given InMemoryDataOperation to finalize, and false otherwise.
// 3) operator() -- the first argument must be both an input and an output; the corresponding
//    variable usually stores the accumulated result of many operator() calls. The second argument
//    is only an input. For indexable binary operators, a second overload of operator() takes a
//    third argument (an output) that indicates whether the accumulated value (the first argument)
//    has changed, in which case the recorded accumulated index also needs to be changed.

// Summation functor: accumulates a = a + b. Its identity element is zero.
template <class T>
struct Add
{
    using dataType = T;

    // Identity element of addition (zero). ck::type_convert is used, as elsewhere in this file,
    // rather than static_cast, so that types whose float conversion is specialized in type_convert
    // receive a correctly encoded value.
    __host__ __device__ static constexpr T GetIdentityValue() { return ck::type_convert<T>(0.0f); }

    // A sum can be finalized either by overwriting the destination (Set) or by atomically adding
    // partial sums into it (AtomicAdd). Callable from host and device, like the other members.
    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    }

    // Accumulates b into a.
    __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
};

// Product functor: accumulates a = a * b. Its identity element is one.
template <class T>
struct Mul
{
    using dataType = T;

    // Identity element of multiplication (one). ck::type_convert is used instead of static_cast:
    // for a storage-only type whose float conversion is specialized in type_convert (NOTE(review):
    // CK's bf16 type is presumably such a type -- confirm in data_type.hpp), static_cast<T>(1.0f)
    // would yield the integer 1 rather than the encoding of 1.0.
    __host__ __device__ static constexpr T GetIdentityValue() { return ck::type_convert<T>(1.0f); }

    // Only a plain Set can finalize a product reduction; there is no atomic multiply path.
    // Callable from host and device, like the other members.
    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::Set;
    }

    // Multiplies the accumulated value a by b.
    __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
};

// Maximum functor: keeps the larger of the accumulated value and the incoming value. It is an
// indexable operator: the three-argument operator() reports whether the accumulated value was
// replaced, so a caller can track the index of the maximum.
template <class T>
struct Max
{
    using dataType = T;

    // Identity element of max: the lowest value reported by NumericLimits<T>::Lowest().
    __host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Lowest(); }

    // Only a plain Set can finalize a max reduction for now. Callable from host and device,
    // like the other members.
    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    }

    // a = max(a, b). Because the comparison is strict, a is left untouched when the values are
    // equal, or when they compare unordered (e.g. NaN for floating-point T).
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        if(a < b)
            a = b;
    }

    // Same as the two-argument form, but additionally sets `changed` to true when a is replaced.
    // `changed` is never reset to false here, so the caller must initialize it before the call.
    // Ties do not set `changed`, so a previously recorded index is kept for equal values.
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
    {
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
};

// Minimum functor: keeps the smaller of the accumulated value and the incoming value. It is an
// indexable operator: the three-argument operator() reports whether the accumulated value was
// replaced, so a caller can track the index of the minimum.
template <class T>
struct Min
{
    using dataType = T;

    // Identity element of min: the largest value reported by NumericLimits<T>::Max().
    __host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Max(); }

    // Only a plain Set can finalize a min reduction for now. Callable from host and device,
    // like the other members.
    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_min to be added
        return operation == InMemoryDataOperationEnum::Set;
    }

    // a = min(a, b). Because the comparison is strict, a is left untouched when the values are
    // equal, or when they compare unordered (e.g. NaN for floating-point T).
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        if(a > b)
            a = b;
    }

    // Same as the two-argument form, but additionally sets `changed` to true when a is replaced.
    // `changed` is never reset to false here, so the caller must initialize it before the call.
    // Ties do not set `changed`, so a previously recorded index is kept for equal values.
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
    {
        if(a > b)
        {
            a       = b;
            changed = true;
        }
    }
};

// Absolute-maximum functor. Note that operator() itself performs a plain max and takes no
// absolute value. NOTE(review): it presumably relies on inputs having been made non-negative
// beforehand (e.g. by an abs elementwise pre-operation) -- confirm with callers. Under that
// assumption, zero is a valid identity element.
template <class T>
struct AMax
{
    using dataType = T;

    // Identity element (zero), valid only for non-negative inputs. ck::type_convert is used, as
    // elsewhere in this file, so types with a specialized float conversion are encoded correctly.
    __host__ __device__ static constexpr T GetIdentityValue() { return ck::type_convert<T>(0.0f); }

    // Only a plain Set can finalize an amax reduction for now. Callable from host and device,
    // like the other members.
    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    }

    // a = max(a, b) using a strict comparison; ties leave a untouched.
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        if(a < b)
            a = b;
    }

    // Same as the two-argument form, but additionally sets `changed` to true when a is replaced.
    // `changed` is never reset to false here, so the caller must initialize it before the call.
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
    {
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
};

// Returns the initial value associated with an in-memory data operation:
// NumericLimits<T>::Lowest() for AtomicMax, and zero (via ck::type_convert) for every other
// operation. NOTE(review): presumably used to pre-fill a destination buffer before partial
// results are combined into it with `operation` -- confirm with callers.
template <typename T>
__host__ __device__ T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
    return operation == InMemoryDataOperationEnum::AtomicMax ? ck::NumericLimits<T>::Lowest()
                                                              : ck::type_convert<T>(0.0f);
}

// Original (misspelled) name, kept as a thin forwarder so existing callers keep compiling.
// Prefer GetIdentityValueForInMemoryDataOperation in new code.
template <typename T>
__host__ __device__ T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
    return GetIdentityValueForInMemoryDataOperation<T>(operation);
}
}; // end of namespace reduce

} // end of namespace ck

#endif