reduction_operator.hpp 9.76 KB
Newer Older
Umang Yadav's avatar
Umang Yadav committed
1
2
3

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
Chao Liu's avatar
Chao Liu committed
4
// SPDX-License-Identifier: MIT
Illia Silin's avatar
Illia Silin committed
5
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
Chao Liu's avatar
Chao Liu committed
6

Chao Liu's avatar
Chao Liu committed
7
8
9
10
11
#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
12
#include "ck/utility/type_convert.hpp"
13
14
15
16
17
18
19
20

namespace ck {

namespace reduce {

// Every binary operator used in reduction is represented by a functor class. Each functor class
// must provide at least three members:
// 1) GetIdentityValue() -- returns the "identity element" of the binary operator. The identity
//    element is the unique element of the algebraic space that does not affect the value of other
//    elements when operated against them; the concept is similar to the zero vector in a vector
//    space (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
// 2) IsCompatibleInMemoryDataOperation() -- returns true if the reduction task corresponding to
//    this operator can use the given InMemoryDataOperation to finalize; otherwise it returns false.
// 3) operator() -- the first argument of the operator must be both an input and an output; the
//    corresponding variable usually stores the accumulated result of many operator() calls. The
//    second argument is only an input. For an indexable binary operator, a second version of
//    operator() has a third argument (an output) that indicates whether the accumulated value
//    (the first argument) has changed, in which case the recorded accumulated index also needs to
//    be changed.


struct Add
{
41
42
43
44
45
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };
46

47
    __host__ __device__ static constexpr bool
48
49
50
51
52
53
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

54
55
56
57
58
59
60
61
62
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, int32_t>::value,
                      "The data type is not supported by the Add accumulator!");

        a = a + b;
    }
63
64
};

65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// Sum-of-squares reduction functor: accumulates with a = a + b * b.
struct SquaredAdd
{
    // Identity element: 0.0f converted to T.
    template <class T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    }

    // Returns true only for AtomicAdd and Set; every other operation yields false.
    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        switch(operation)
        {
        case InMemoryDataOperationEnum::AtomicAdd:
        case InMemoryDataOperationEnum::Set: return true;
        default: return false;
        }
    }

    // Adds the square of b into the accumulator a (a is both input and output).
    // Instantiation is rejected at compile time unless T is float, double, half_t,
    // int32_t or int8_t.
    template <class T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        constexpr bool is_supported_type =
            is_same<T, float>::value || is_same<T, double>::value || is_same<T, half_t>::value ||
            is_same<T, int32_t>::value || is_same<T, int8_t>::value;

        static_assert(is_supported_type,
                      "The data type is not supported by the SquaredAdd accumulator!");

        a += b * b;
    }
};

92
93
struct Mul
{
94
95
96
97
98
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(1.0f);
    };
99

100
    __host__ __device__ static constexpr bool
101
102
103
104
105
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::Set;
    };

106
107
108
109
110
111
112
113
114
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, int32_t>::value,
                      "The data type is not supported by the Mul accumulator!");

        a = a * b;
    }
115
116
117
118
};

struct Max
{
119
    template <typename T>
120
    __host__ __device__ static constexpr T GetIdentityValue()
121
122
123
    {
        return NumericLimits<T>::Lowest();
    };
124

125
    __host__ __device__ static constexpr bool
126
127
128
129
130
131
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

132
    template <typename T>
133
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
134
    {
135
136
137
138
139
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

140
141
142
143
        if(a < b)
            a = b;
    }

144
    template <typename T>
145
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
146
    {
147
148
149
150
151
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

152
153
154
155
156
157
158
159
160
161
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
};

struct Min
{
162
163
164
165
166
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return NumericLimits<T>::Max();
    };
167

168
    __host__ __device__ static constexpr bool
169
170
171
172
173
174
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_min to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

175
    template <typename T>
176
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
177
    {
178
179
180
181
182
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

183
184
185
186
        if(a > b)
            a = b;
    }

187
    template <typename T>
188
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
189
    {
190
191
192
193
194
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

195
196
197
198
199
200
201
202
        if(a > b)
        {
            a       = b;
            changed = true;
        }
    }
};

203
struct AMax
204
{
205
206
207
208
209
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };
210

211
    __host__ __device__ static constexpr bool
212
213
214
215
216
217
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

218
    template <typename T>
219
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
220
    {
221
222
223
224
225
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the AMax accumulator!");

226
227
228
229
        if(a < b)
            a = b;
    }

230
    template <typename T>
231
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
232
    {
233
234
235
236
237
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the AMax accumulator!");

238
239
240
241
242
243
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
244
245
};

246
template <typename T>
247
constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
248
249
250
251
252
253
254
255
256
{
    T result = ck::type_convert<T>(0.0f);

    if(operation == InMemoryDataOperationEnum::AtomicMax)
        result = ck::NumericLimits<T>::Lowest();

    return (result);
};

257
template <InMemoryDataOperationEnum Operation, typename DataType>
258
struct InMemoryDataOperationSupportedOnDataType
259
260
261
262
263
{
    static constexpr bool value = false;
};

template <typename DataType>
264
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
265
266
267
268
269
270
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
271
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
272
273
274
275
276
277
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
278
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
279
280
281
282
283
284
285
286
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value ||
        is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
        is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
};

template <typename DataType>
287
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
288
289
290
291
292
293
294
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value ||
        is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
        is_same<DataType, int32_t>::value;
};

Chao Liu's avatar
Chao Liu committed
295
296
} // namespace reduce
} // namespace ck
Umang Yadav's avatar
Umang Yadav committed
297
298

#pragma clang diagnostic pop