reduction_operator.hpp 9.61 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

Chao Liu's avatar
Chao Liu committed
4
5
6
7
8
#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
9
10
11
12
13
14
15
16

namespace ck {

namespace reduce {

// Every binary operator used in reduction is represented by a templated functor class. Each functor
// class must provide at least three members:
// 1) GetIdentityValue() -- returns the "identity element" of the binary operator. The identity
//    element is the unique element in the algebraic space that does not change the value of any
//    other element when the operator is applied to the two of them; the concept is similar to the
//    zero vector in a vector space
//    (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
// 2) IsCompatibleInMemoryDataOperation() -- returns true if the reduction task corresponding to
//    this operator can use the given InMemoryDataOperation to finalize its result, and false
//    otherwise.
// 3) operator() -- the first argument must be both an input and an output; the corresponding
//    variable usually stores the accumulated result of many operator() calls. The second argument
//    is only an input. For indexable binary operators, a second overload of operator() takes a
//    third argument (an output) that indicates whether the accumulated value (the first argument)
//    has changed, in which case the recorded accumulated index also needs to be changed.

struct Add
{
37
38
39
40
41
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };
42

43
    __host__ __device__ static constexpr bool
44
45
46
47
48
49
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

50
51
52
53
54
55
56
57
58
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, int32_t>::value,
                      "The data type is not supported by the Add accumulator!");

        a = a + b;
    }
59
60
};

61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
// Sum-of-squares reduction functor: a = a + b * b.
// Shares Add's identity (zero) and its finalization options (Set / AtomicAdd).
struct SquaredAdd
{
    // Zero converted to T: contributes nothing to a running sum of squares.
    template <class T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    }

    // Accepts exactly two finalization modes: Set and AtomicAdd.
    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::Set ||
               operation == InMemoryDataOperationEnum::AtomicAdd;
    }

    // Adds the square of the incoming value b to the accumulator a.
    // The single expression `a + b * b` is kept on purpose so that rounding/contraction behaviour
    // is the same as a fused form would give.
    template <class T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the SquaredAdd accumulator!");

        a = a + b * b;
    }
};

88
89
struct Mul
{
90
91
92
93
94
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(1.0f);
    };
95

96
    __host__ __device__ static constexpr bool
97
98
99
100
101
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::Set;
    };

102
103
104
105
106
107
108
109
110
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, int32_t>::value,
                      "The data type is not supported by the Mul accumulator!");

        a = a * b;
    }
111
112
113
114
};

struct Max
{
115
    template <typename T>
116
    __host__ __device__ static constexpr T GetIdentityValue()
117
118
119
    {
        return NumericLimits<T>::Lowest();
    };
120

121
    __host__ __device__ static constexpr bool
122
123
124
125
126
127
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

128
    template <typename T>
129
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
130
    {
131
132
133
134
135
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

136
137
138
139
        if(a < b)
            a = b;
    }

140
    template <typename T>
141
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
142
    {
143
144
145
146
147
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

148
149
150
151
152
153
154
155
156
157
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
};

struct Min
{
158
159
160
161
162
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return NumericLimits<T>::Max();
    };
163

164
    __host__ __device__ static constexpr bool
165
166
167
168
169
170
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_min to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

171
    template <typename T>
172
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
173
    {
174
175
176
177
178
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

179
180
181
182
        if(a > b)
            a = b;
    }

183
    template <typename T>
184
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
185
    {
186
187
188
189
190
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

191
192
193
194
195
196
197
198
        if(a > b)
        {
            a       = b;
            changed = true;
        }
    }
};

199
struct AMax
200
{
201
202
203
204
205
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };
206

207
    __host__ __device__ static constexpr bool
208
209
210
211
212
213
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

214
    template <typename T>
215
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
216
    {
217
218
219
220
221
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the AMax accumulator!");

222
223
224
225
        if(a < b)
            a = b;
    }

226
    template <typename T>
227
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
228
    {
229
230
231
232
233
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the AMax accumulator!");

234
235
236
237
238
239
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
240
241
};

242
template <typename T>
243
constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
244
245
246
247
248
249
250
251
252
{
    T result = ck::type_convert<T>(0.0f);

    if(operation == InMemoryDataOperationEnum::AtomicMax)
        result = ck::NumericLimits<T>::Lowest();

    return (result);
};

253
template <InMemoryDataOperationEnum Operation, typename DataType>
254
struct InMemoryDataOperationSupportedOnDataType
255
256
257
258
259
{
    static constexpr bool value = false;
};

template <typename DataType>
260
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
261
262
263
264
265
266
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
267
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
268
269
270
271
272
273
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
274
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
275
276
277
278
279
280
281
282
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value ||
        is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
        is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
};

template <typename DataType>
283
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
284
285
286
287
288
289
290
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value ||
        is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
        is_same<DataType, int32_t>::value;
};

Chao Liu's avatar
Chao Liu committed
291
292
} // namespace reduce
} // namespace ck