reduction_operator.hpp 9.65 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
// SPDX-License-Identifier: MIT
Illia Silin's avatar
Illia Silin committed
2
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
Chao Liu's avatar
Chao Liu committed
3

Chao Liu's avatar
Chao Liu committed
4
5
6
7
8
#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
9
#include "ck/utility/type_convert.hpp"
10
11
12
13
14
15
16
17

namespace ck {

namespace reduce {

// Every binary operator used in reduction is represented by a templated functor class. Each functor
// class must provide at least three members:
// 1) GetIdentityValue() -- the interface to return the "identity element" for the binary
//                    operator. The "identity element" is the unique element in the algebraic
//                    space that doesn't affect the value of other elements when operated against
//                    them; the concept is similar to the zero vector in a vector space
//                    (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
// 2) IsCompatibleInMemoryDataOperation() -- returns true if the reduction task corresponding to
//                    this operator can use the given InMemoryDataOperation to finalize; otherwise
//                    it returns false.
// 3) operator() -- the first argument of the operator must be both an input & output, and the
//                  corresponding variable usually stores the accumulated result of many
//                  operator() calls; the second argument is only an input. For an indexable
//                  binary operator, the second version of operator() has a third argument (which
//                  is an output) to indicate whether the accumulated value (the first argument)
//                  has changed, in which case the recorded accumulated index also needs to be
//                  changed.

struct Add
{
38
39
40
41
42
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };
43

44
    __host__ __device__ static constexpr bool
45
46
47
48
49
50
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

51
52
53
54
55
56
57
58
59
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, int32_t>::value,
                      "The data type is not supported by the Add accumulator!");

        a = a + b;
    }
60
61
};

62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// Sum-of-squares reduction functor: each call adds the square of the incoming value to the
// accumulator.
struct SquaredAdd
{
    // Identity element: zero, converted to T through type_convert.
    template <class T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    }

    // Tells whether `operation` is accepted for this operator.
    // Only AtomicAdd and Set yield true; every other enumerator yields false.
    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        const bool is_atomic_add = (operation == InMemoryDataOperationEnum::AtomicAdd);
        const bool is_set        = (operation == InMemoryDataOperationEnum::Set);

        return is_atomic_add || is_set;
    }

    // Accumulate step: `a` is read and overwritten with a + b * b, `b` is input only.
    // Instantiation is rejected at compile time unless T is float, double, half_t, int32_t
    // or int8_t.
    template <class T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, int8_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, half_t>::value || is_same<T, float>::value ||
                          is_same<T, double>::value,
                      "The data type is not supported by the SquaredAdd accumulator!");

        a += b * b;
    }
};

89
90
struct Mul
{
91
92
93
94
95
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(1.0f);
    };
96

97
    __host__ __device__ static constexpr bool
98
99
100
101
102
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::Set;
    };

103
104
105
106
107
108
109
110
111
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, int32_t>::value,
                      "The data type is not supported by the Mul accumulator!");

        a = a * b;
    }
112
113
114
115
};

struct Max
{
116
    template <typename T>
117
    __host__ __device__ static constexpr T GetIdentityValue()
118
119
120
    {
        return NumericLimits<T>::Lowest();
    };
121

122
    __host__ __device__ static constexpr bool
123
124
125
126
127
128
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

129
    template <typename T>
130
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
131
    {
132
133
134
135
136
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

137
138
139
140
        if(a < b)
            a = b;
    }

141
    template <typename T>
142
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
143
    {
144
145
146
147
148
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

149
150
151
152
153
154
155
156
157
158
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
};

struct Min
{
159
160
161
162
163
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return NumericLimits<T>::Max();
    };
164

165
    __host__ __device__ static constexpr bool
166
167
168
169
170
171
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_min to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

172
    template <typename T>
173
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
174
    {
175
176
177
178
179
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

180
181
182
183
        if(a > b)
            a = b;
    }

184
    template <typename T>
185
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
186
    {
187
188
189
190
191
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

192
193
194
195
196
197
198
199
        if(a > b)
        {
            a       = b;
            changed = true;
        }
    }
};

200
struct AMax
201
{
202
203
204
205
206
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };
207

208
    __host__ __device__ static constexpr bool
209
210
211
212
213
214
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

215
    template <typename T>
216
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
217
    {
218
219
220
221
222
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the AMax accumulator!");

223
224
225
226
        if(a < b)
            a = b;
    }

227
    template <typename T>
228
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
229
    {
230
231
232
233
234
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the AMax accumulator!");

235
236
237
238
239
240
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
241
242
};

243
template <typename T>
244
constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
245
246
247
248
249
250
251
252
253
{
    T result = ck::type_convert<T>(0.0f);

    if(operation == InMemoryDataOperationEnum::AtomicMax)
        result = ck::NumericLimits<T>::Lowest();

    return (result);
};

254
template <InMemoryDataOperationEnum Operation, typename DataType>
255
struct InMemoryDataOperationSupportedOnDataType
256
257
258
259
260
{
    static constexpr bool value = false;
};

template <typename DataType>
261
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
262
263
264
265
266
267
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
268
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
269
270
271
272
273
274
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
275
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
276
277
278
279
280
281
282
283
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value ||
        is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
        is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
};

template <typename DataType>
284
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
285
286
287
288
289
290
291
{
    static constexpr bool value =
        is_same<DataType, float>::value || is_same<DataType, double>::value ||
        is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
        is_same<DataType, int32_t>::value;
};

Chao Liu's avatar
Chao Liu committed
292
293
} // namespace reduce
} // namespace ck