reduction_operator.hpp 10.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2020 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_REDUCTION_OPERATOR_HPP
#define CK_REDUCTION_OPERATOR_HPP

29
30
#include "config.hpp"
#include "data_type.hpp"
31
#include "type.hpp"
32
33
34
35
36
37
38
39

namespace ck {

namespace reduce {

// Every binary operator used in reduction is represented by a templated functor class. Each functor
// class must provide at least
// three members:
40
// 1) GetIdentityValue() -- the interface to return the "identity element" for the binary
41
// operator, "identity element" is the unique
42
//                    element in the algebraic space that doesn't affect the value of other elements
43
44
45
//                    when operated against them, and the concept is similar to zero vector in
//                    vector space
//                    (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
46
47
48
49
// 2) IsCompatibleInMemoryDataOperation() -- return true if the reduction task corresponding to this
// operator can use the InMemoryDataOperation to finalize, or else it return false 3) operator() --
// the first argument of the operator must be both an input & output, and the corresponding variable
// usually stores
50
51
52
53
54
55
56
57
58
59
//                  the accumulated result of many operator() calls; the second argument is only an
//                  input. For indexable binary
//                  operator, the second version of operator() has third argument (which is an
//                  output) to indicate whether the
//                  accumulated value (the first argument) has changed, in which case the recorded
//                  accumulated index also need be
//                  changed.

struct Add
{
60
61
62
63
64
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };
65

66
    __host__ __device__ static constexpr bool
67
68
69
70
71
72
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

73
74
75
76
77
78
79
80
81
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, int32_t>::value,
                      "The data type is not supported by the Add accumulator!");

        a = a + b;
    }
82
83
};

84
85
struct SquaredAdd
{
86
    template <class T>
87
    __host__ __device__ static constexpr T GetIdentityValue() { return type_convert<T>(0.0f); };
88

89
    __host__ __device__ static constexpr bool
Anthony Chang's avatar
amend  
Anthony Chang committed
90
91
92
93
94
95
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

96
    template <class T>
97
98
99
100
101
102
103
104
105
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

        a = a + b * b;
    }
106
107
};

108
109
struct Mul
{
110
111
112
113
114
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(1.0f);
    };
115

116
    __host__ __device__ static constexpr bool
117
118
119
120
121
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::Set;
    };

122
123
124
125
126
127
128
129
130
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, int32_t>::value,
                      "The data type is not supported by the Mul accumulator!");

        a = a * b;
    }
131
132
133
134
};

struct Max
{
135
    template <typename T>
136
    __host__ __device__ static constexpr T GetIdentityValue()
137
138
139
    {
        return NumericLimits<T>::Lowest();
    };
140

141
    __host__ __device__ static constexpr bool
142
143
144
145
146
147
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

148
    template <typename T>
149
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
150
    {
151
152
153
154
155
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

156
157
158
159
        if(a < b)
            a = b;
    }

160
    template <typename T>
161
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
162
    {
163
164
165
166
167
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

168
169
170
171
172
173
174
175
176
177
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
};

struct Min
{
178
179
180
181
182
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return NumericLimits<T>::Max();
    };
183

184
    __host__ __device__ static constexpr bool
185
186
187
188
189
190
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_min to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

191
    template <typename T>
192
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
193
    {
194
195
196
197
198
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

199
200
201
202
        if(a > b)
            a = b;
    }

203
    template <typename T>
204
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
205
    {
206
207
208
209
210
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

211
212
213
214
215
216
217
218
        if(a > b)
        {
            a       = b;
            changed = true;
        }
    }
};

219
struct AMax
220
{
221
222
223
224
225
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };
226

227
    __host__ __device__ static constexpr bool
228
229
230
231
232
233
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

234
    template <typename T>
235
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
236
    {
237
238
239
240
241
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the AMax accumulator!");

242
243
244
245
        if(a < b)
            a = b;
    }

246
    template <typename T>
247
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
248
    {
249
250
251
252
253
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the AMax accumulator!");

254
255
256
257
258
259
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
260
261
};

262
template <typename T>
263
constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
264
265
266
267
268
269
270
271
272
{
    T result = ck::type_convert<T>(0.0f);

    if(operation == InMemoryDataOperationEnum::AtomicMax)
        result = ck::NumericLimits<T>::Lowest();

    return (result);
};

273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
template <InMemoryDataOperationEnum Operation, typename DataType>
struct InMemoryDataOperatonSupportedOnDataType
{
    static constexpr bool value = false;
};

template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value ||
        is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
        is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
};

template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value ||
        is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
        is_same<DataType, int32_t>::value;
};

311
312
313
314
315
}; // end of namespace reduce

} // end of namespace ck

#endif