// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
9
#include "ck/utility/type_convert.hpp"
10
11
12
13
14
15
16
17

namespace ck {

namespace reduce {

// Every binary operator used in reduction is represented by a templated functor class. Each functor
// class must provide at least three members:
// 1) GetIdentityValue() -- returns the "identity element" of the binary operator, i.e. the unique
//                          element of the algebraic space that leaves any other element unchanged
//                          when combined with it; the concept is similar to the zero vector in a
//                          vector space
//                          (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
// 2) IsCompatibleInMemoryDataOperation() -- returns true if the reduction task corresponding to this
//                          operator can use the given InMemoryDataOperation to finalize its result;
//                          otherwise it returns false.
// 3) operator() -- the first argument must be both an input and an output; the corresponding
//                  variable usually stores the accumulated result of many operator() calls. The
//                  second argument is only an input. For indexable binary operators, a second
//                  version of operator() takes a third argument (an output) that indicates whether
//                  the accumulated value (the first argument) has changed, in which case the
//                  recorded accumulated index also needs to be updated.

struct Add
{
38
39
40
41
42
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };
43

44
    __host__ __device__ static constexpr bool
45
46
47
48
49
50
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

51
52
53
54
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
55
                          is_same<T, int32_t>::value || is_same<T, half_t>::value,
56
57
58
59
                      "The data type is not supported by the Add accumulator!");

        a = a + b;
    }
60
61
62
63
64
65
66
67

    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
    {
        float a_ = type_convert<float>(a);
        float b_ = type_convert<float>(b);

        a = type_convert<bhalf_t>(a_ + b_);
    }
68
69
};

70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// Sum-of-squares reducer: a <- a + b * b. The identity element is zero.
struct SquaredAdd
{
    // Zero, converted to the requested type.
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    }

    // Accepts Set and AtomicAdd as ways to finalize the result.
    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        const bool is_set        = operation == InMemoryDataOperationEnum::Set;
        const bool is_atomic_add = operation == InMemoryDataOperationEnum::AtomicAdd;
        return is_set || is_atomic_add;
    }

    // Folds the square of b into a for float, double, half_t, int32_t and int8_t.
    template <typename Scalar>
    __host__ __device__ inline constexpr void operator()(Scalar& a, Scalar b) const
    {
        static_assert(is_same_v<Scalar, float> || is_same_v<Scalar, double> ||
                          is_same_v<Scalar, half_t> || is_same_v<Scalar, int32_t> ||
                          is_same_v<Scalar, int8_t>,
                      "The data type is not supported by the SquaredAdd accumulator!");

        // Kept as one expression so the square is not rounded/truncated to Scalar on its own.
        a = a + b * b;
    }
};

struct Mul
{
99
100
101
102
103
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(1.0f);
    };
104

105
    __host__ __device__ static constexpr bool
106
107
108
109
110
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::Set;
    };

111
112
113
114
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
115
                          is_same<T, int32_t>::value || is_same<T, half_t>::value,
116
117
118
119
                      "The data type is not supported by the Mul accumulator!");

        a = a * b;
    }
120
121
122
123
124
125
126
127

    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
    {
        float a_ = type_convert<float>(a);
        float b_ = type_convert<float>(b);

        a = type_convert<bhalf_t>(a_ * b_);
    }
128
129
130
131
};

struct Max
{
132
    template <typename T>
133
    __host__ __device__ static constexpr T GetIdentityValue()
134
    {
135
136
137
138
139
140
141
142
143
        if constexpr(is_same_v<T, bhalf_t>)
        {
            float val = NumericLimits<float>::Lowest();
            return type_convert<bhalf_t>(val);
        }
        else
        {
            return NumericLimits<T>::Lowest();
        }
144
    };
145

146
    __host__ __device__ static constexpr bool
147
148
149
150
151
152
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

153
    template <typename T>
154
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
155
    {
156
157
158
159
160
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

161
162
163
164
        if(a < b)
            a = b;
    }

165
166
167
168
169
170
171
172
173
    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
    {
        float a_ = type_convert<float>(a);
        float b_ = type_convert<float>(b);

        if(a_ < b_)
            a = b;
    }

174
    template <typename T>
175
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
176
    {
177
178
179
180
181
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

182
183
184
185
186
187
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
188
189
190
191
192
193
194
195
196
197
198
199

    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
    {
        float a_ = type_convert<float>(a);
        float b_ = type_convert<float>(b);

        if(a_ < b_)
        {
            a       = b;
            changed = true;
        }
    }
200
201
202
203
};

struct Min
{
204
205
206
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
207
208
209
210
211
212
213
214
215
        if constexpr(is_same_v<T, bhalf_t>)
        {
            float val = NumericLimits<float>::Max();
            return type_convert<bhalf_t>(val);
        }
        else
        {
            return NumericLimits<T>::Max();
        }
216
217
        return NumericLimits<T>::Max();
    };
218

219
    __host__ __device__ static constexpr bool
220
221
222
223
224
225
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_min to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

226
    template <typename T>
227
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
228
    {
229
230
231
232
233
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

234
235
236
237
        if(a > b)
            a = b;
    }

238
239
240
241
242
243
244
245
246
    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
    {
        float a_ = type_convert<float>(a);
        float b_ = type_convert<float>(b);

        if(a_ > b_)
            a = b;
    }

247
    template <typename T>
248
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
249
    {
250
251
252
253
254
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

255
256
257
258
259
260
        if(a > b)
        {
            a       = b;
            changed = true;
        }
    }
261
262
263
264
265
266
267
268
269
270
271
272

    __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
    {
        float a_ = type_convert<float>(a);
        float b_ = type_convert<float>(b);

        if(a_ > b_)
        {
            a       = b;
            changed = true;
        }
    }
273
274
};

275
struct AMax
276
{
277
278
279
280
281
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };
282

283
    __host__ __device__ static constexpr bool
284
285
286
287
288
289
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

290
    template <typename T>
291
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
292
    {
293
294
295
296
297
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the AMax accumulator!");

298
299
300
301
        if(a < b)
            a = b;
    }

302
    template <typename T>
303
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
304
    {
305
306
307
308
309
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the AMax accumulator!");

310
311
312
313
314
315
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
316
317
};

318
template <typename T>
319
constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
320
321
322
323
324
325
326
327
328
{
    T result = ck::type_convert<T>(0.0f);

    if(operation == InMemoryDataOperationEnum::AtomicMax)
        result = ck::NumericLimits<T>::Lowest();

    return (result);
};

329
template <InMemoryDataOperationEnum Operation, typename DataType>
330
struct InMemoryDataOperationSupportedOnDataType
331
332
333
334
335
{
    static constexpr bool value = false;
};

template <typename DataType>
336
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
337
338
339
340
341
342
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
343
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
344
345
346
347
348
349
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
350
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
351
352
353
354
355
356
357
358
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value ||
        is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
        is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
};

template <typename DataType>
359
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
360
361
362
363
364
365
366
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value ||
        is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
        is_same<DataType, int32_t>::value;
};

Chao Liu's avatar
Chao Liu committed
367
368
} // namespace reduce
} // namespace ck