reduction_operator.hpp 11.1 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
// SPDX-License-Identifier: MIT
Illia Silin's avatar
Illia Silin committed
2
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
Chao Liu's avatar
Chao Liu committed
3

Chao Liu's avatar
Chao Liu committed
4
5
6
7
8
#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
9
#include "ck/utility/type_convert.hpp"
10
11
12
13
14
15
16
17

namespace ck {

namespace reduce {

// Every binary operator used in reduction is represented by a templated functor class. Each
// functor class must provide at least three members:
// 1) GetIdentityValue() -- returns the "identity element" of the binary operator, i.e. the
//    unique element of the algebraic space that does not change the value of any other element
//    when combined with it; the concept is similar to the zero vector in a vector space
//    (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
// 2) IsCompatibleInMemoryDataOperation() -- returns true if the reduction task corresponding to
//    this operator can use the given InMemoryDataOperation to finalize its result, and false
//    otherwise.
// 3) operator() -- the first argument must be both an input and an output; it usually stores
//    the accumulated result of many operator() calls. The second argument is input only. For an
//    indexable binary operator, a second version of operator() takes a third argument (an
//    output) that indicates whether the accumulated value (the first argument) has changed, in
//    which case the recorded accumulated index also needs to be updated.

struct Add
{
38
39
40
41
42
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };
43

44
    __host__ __device__ static constexpr bool
45
46
47
48
49
50
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

51
52
53
54
55
56
57
58
59
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, int32_t>::value,
                      "The data type is not supported by the Add accumulator!");

        a = a + b;
    }
60
61
};

62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// Sum-of-squares reduction: a <- a + b * b.
struct SquaredAdd
{
    // The accumulator starts from zero.
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    }

    // Partial sums of squares may be combined by a plain store or by an atomic add.
    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        const bool is_set        = operation == InMemoryDataOperationEnum::Set;
        const bool is_atomic_add = operation == InMemoryDataOperationEnum::AtomicAdd;
        return is_set || is_atomic_add;
    }

    // Adds the square of b to a. Supported types: float, double, half_t, int32_t, int8_t.
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        constexpr bool supported = is_same<T, int8_t>::value || is_same<T, int32_t>::value ||
                                   is_same<T, half_t>::value || is_same<T, float>::value ||
                                   is_same<T, double>::value;
        static_assert(supported, "The data type is not supported by the SquaredAdd accumulator!");

        a += b * b;
    }
};

89
90
struct Mul
{
91
92
93
94
95
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(1.0f);
    };
96

97
    __host__ __device__ static constexpr bool
98
99
100
101
102
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::Set;
    };

103
104
105
106
107
108
109
110
111
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, int32_t>::value,
                      "The data type is not supported by the Mul accumulator!");

        a = a * b;
    }
112
113
114
115
};

struct Max
{
116
    template <typename T>
117
    __host__ __device__ static constexpr T GetIdentityValue()
118
    {
119
120
121
122
123
124
125
126
127
        if constexpr(is_same_v<T, bhalf_t>)
        {
            float val = NumericLimits<float>::Lowest();
            return type_convert<bhalf_t>(val);
        }
        else
        {
            return NumericLimits<T>::Lowest();
        }
128
    };
129

130
    __host__ __device__ static constexpr bool
131
132
133
134
135
136
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

137
    template <typename T>
138
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
139
    {
140
141
142
143
144
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

145
146
147
148
        if(a < b)
            a = b;
    }

149
150
151
152
153
154
155
156
157
    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
    {
        float a_ = type_convert<float>(a);
        float b_ = type_convert<float>(b);

        if(a_ < b_)
            a = b;
    }

158
    template <typename T>
159
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
160
    {
161
162
163
164
165
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

166
167
168
169
170
171
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
172
173
174
175
176
177
178
179
180
181
182
183

    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
    {
        float a_ = type_convert<float>(a);
        float b_ = type_convert<float>(b);

        if(a_ < b_)
        {
            a       = b;
            changed = true;
        }
    }
184
185
186
187
};

struct Min
{
188
189
190
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
191
192
193
194
195
196
197
198
199
        if constexpr(is_same_v<T, bhalf_t>)
        {
            float val = NumericLimits<float>::Max();
            return type_convert<bhalf_t>(val);
        }
        else
        {
            return NumericLimits<T>::Max();
        }
200
201
        return NumericLimits<T>::Max();
    };
202

203
    __host__ __device__ static constexpr bool
204
205
206
207
208
209
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_min to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

210
    template <typename T>
211
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
212
    {
213
214
215
216
217
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

218
219
220
221
        if(a > b)
            a = b;
    }

222
223
224
225
226
227
228
229
230
    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
    {
        float a_ = type_convert<float>(a);
        float b_ = type_convert<float>(b);

        if(a_ > b_)
            a = b;
    }

231
    template <typename T>
232
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
233
    {
234
235
236
237
238
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

239
240
241
242
243
244
        if(a > b)
        {
            a       = b;
            changed = true;
        }
    }
245
246
247
248
249
250
251
252
253
254
255
256

    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
    {
        float a_ = type_convert<float>(a);
        float b_ = type_convert<float>(b);

        if(a_ > b_)
        {
            a       = b;
            changed = true;
        }
    }
257
258
};

259
struct AMax
260
{
261
262
263
264
265
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };
266

267
    __host__ __device__ static constexpr bool
268
269
270
271
272
273
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

274
    template <typename T>
275
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
276
    {
277
278
279
280
281
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the AMax accumulator!");

282
283
284
285
        if(a < b)
            a = b;
    }

286
    template <typename T>
287
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
288
    {
289
290
291
292
293
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the AMax accumulator!");

294
295
296
297
298
299
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
300
301
};

302
template <typename T>
303
constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
304
305
306
307
308
309
310
311
312
{
    T result = ck::type_convert<T>(0.0f);

    if(operation == InMemoryDataOperationEnum::AtomicMax)
        result = ck::NumericLimits<T>::Lowest();

    return (result);
};

313
template <InMemoryDataOperationEnum Operation, typename DataType>
314
struct InMemoryDataOperationSupportedOnDataType
315
316
317
318
319
{
    static constexpr bool value = false;
};

template <typename DataType>
320
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
321
322
323
324
325
326
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
327
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
328
329
330
331
332
333
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
334
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
335
336
337
338
339
340
341
342
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value ||
        is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
        is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
};

template <typename DataType>
343
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
344
345
346
347
348
349
350
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value ||
        is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
        is_same<DataType, int32_t>::value;
};

Chao Liu's avatar
Chao Liu committed
351
352
} // namespace reduce
} // namespace ck