dynamic_buffer.hpp 14.7 KB
Newer Older
1
2
3
4
5
#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP

#include "amd_buffer_addressing.hpp"
#include "c_style_pointer_cast.hpp"
6
#include "config.hpp"
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include "enable_if.hpp"

namespace ck {

template <AddressSpaceEnum_t BufferAddressSpace,
          typename T,
          typename ElementSpaceSize,
          bool InvalidElementUseNumericalZeroValue>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;
    ElementSpaceSize element_space_size_;
    T invalid_element_value_ = T{0};

    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
    }

    __host__ __device__ constexpr DynamicBuffer(T* p_data,
                                                ElementSpaceSize element_space_size,
                                                T invalid_element_value)
        : p_data_{p_data},
          element_space_size_{element_space_size},
          invalid_element_value_{invalid_element_value}
    {
    }

    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
    {
        return BufferAddressSpace;
    }

42
43
44
45
    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

46
    template <typename X,
Chao Liu's avatar
Chao Liu committed
47
48
49
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
50
51
52
    __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
    {
        // X contains multiple T
Chao Liu's avatar
Chao Liu committed
53
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
54

Chao Liu's avatar
Chao Liu committed
55
        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
56
57
58
59

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

Jianfeng Yan's avatar
Jianfeng Yan committed
60
#if CK_USE_AMD_BUFFER_LOAD
61
62
63
64
65
66
67
68
69
70
71
        bool constexpr use_amd_buffer_addressing = true;
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            if constexpr(InvalidElementUseNumericalZeroValue)
            {
Jianfeng Yan's avatar
Jianfeng Yan committed
72
                return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>, t_per_x>(
Chao Liu's avatar
Chao Liu committed
73
                    p_data_, i, is_valid_element, element_space_size_);
74
75
76
            }
            else
            {
Chao Liu's avatar
Chao Liu committed
77
78
                return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
                                                                               t_per_x>(
79
80
81
82
83
84
85
                    p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
            }
        }
        else
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
86
87
88
89
90
91
92
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp;

                __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

                return is_valid_element ? tmp : X{0};
#else
93
                return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
94
#endif
95
96
97
            }
            else
            {
98
99
100
101
102
103
104
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp;

                __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

                return is_valid_element ? tmp : X{invalid_element_value_};
#else
105
106
                return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
                                        : X{invalid_element_value_};
107
#endif
108
109
110
111
            }
        }
    }

112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
    template <InMemoryDataOperationEnum_t Op,
              typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void Update(index_t i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == InMemoryDataOperationEnum_t::Set)
        {
            this->template Set<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum_t::AtomicAdd)
        {
            this->template AtomicAdd<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum_t::Add)
        {
            auto tmp = this->template Get<X>(i, is_valid_element);
            this->template Set<X>(i, is_valid_element, x + tmp);
            // tmp += x;
            // this->template Set<X>(i, is_valid_element, tmp);
        }
    }

136
    template <typename X,
Chao Liu's avatar
Chao Liu committed
137
138
139
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
140
141
142
    __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
Chao Liu's avatar
Chao Liu committed
143
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
144

Chao Liu's avatar
Chao Liu committed
145
        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
146
147
148
149
150
151

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
        {
Jianfeng Yan's avatar
Jianfeng Yan committed
152
#if CK_USE_AMD_BUFFER_STORE
153
154
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

Chao Liu's avatar
Chao Liu committed
155
            amd_buffer_store<remove_cvref_t<T>, t_per_x>(
156
157
158
159
                x, p_data_, i, is_valid_element, element_space_size_);
#else
            if(is_valid_element)
            {
160
161
162
163
164
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp = x;

                __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
165
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
166
#endif
167
168
169
170
171
172
173
174
            }
#endif
        }
        else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
        {
            if(is_valid_element)
            {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
175
176
177
178
179
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp = x;

                __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
180
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
181
#endif
182
183
184
185
186
187
#else
                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
                // inefficient
                // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
                // ds_write_b128
                // TODO: remove this after compiler fix
Chao Liu's avatar
Chao Liu committed
188
                if constexpr(is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value)
189
                {
Chao Liu's avatar
Chao Liu committed
190
191
192
193
194
195
                    static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8_t>::value) ||
                                      (is_same<remove_cvref_t<T>, int8_t>::value &&
                                       is_same<remove_cvref_t<X>, int8x2_t>::value) ||
                                      (is_same<remove_cvref_t<T>, int8_t>::value &&
                                       is_same<remove_cvref_t<X>, int8x4_t>::value) ||
196
197
                                      (is_same<remove_cvref_t<T>, int8_t>::value &&
                                       is_same<remove_cvref_t<X>, int8x8_t>::value) ||
198
199
                                      (is_same<remove_cvref_t<T>, int8_t>::value &&
                                       is_same<remove_cvref_t<X>, int8x16_t>::value) ||
Chao Liu's avatar
Chao Liu committed
200
201
202
203
204
205
206
207
                                      (is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                       is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                                      (is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                       is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                                      (is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                       is_same<remove_cvref_t<X>, int8x16_t>::value),
                                  "wrong! not implemented for this combination, please add "
                                  "implementation");
208

Chao Liu's avatar
Chao Liu committed
209
210
                    if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                 is_same<remove_cvref_t<X>, int8_t>::value)
211
212
213
214
215
216
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int8_t*>(&x);
                    }
Chao Liu's avatar
Chao Liu committed
217
218
                    else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                      is_same<remove_cvref_t<X>, int8x2_t>::value)
219
220
221
222
223
224
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int16_t*>(&x);
                    }
Chao Liu's avatar
Chao Liu committed
225
226
                    else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                      is_same<remove_cvref_t<X>, int8x4_t>::value)
227
228
229
230
231
232
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32_t*>(&x);
                    }
233
234
235
236
237
238
239
240
                    else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                      is_same<remove_cvref_t<X>, int8x8_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32x2_t*>(&x);
                    }
241
242
243
244
245
246
247
248
                    else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                      is_same<remove_cvref_t<X>, int8x16_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32x4_t*>(&x);
                    }
Chao Liu's avatar
Chao Liu committed
249
250
                    else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                      is_same<remove_cvref_t<X>, int8x4_t>::value)
251
252
253
254
255
256
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32_t*>(&x);
                    }
Chao Liu's avatar
Chao Liu committed
257
258
                    else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                      is_same<remove_cvref_t<X>, int8x8_t>::value)
259
260
261
262
263
264
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32x2_t*>(&x);
                    }
Chao Liu's avatar
Chao Liu committed
265
266
                    else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                      is_same<remove_cvref_t<X>, int8x16_t>::value)
267
268
269
270
271
272
273
274
275
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32x4_t*>(&x);
                    }
                }
                else
                {
276
277
278
279
280
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                    X tmp = x;

                    __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
281
                    *c_style_pointer_cast<X*>(&p_data_[i]) = x;
282
#endif
283
284
285
286
287
288
289
290
                }
#endif
            }
        }
        else
        {
            if(is_valid_element)
            {
291
292
293
294
295
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp = x;

                __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
296
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
297
#endif
298
299
300
301
            }
        }
    }

zjing14's avatar
zjing14 committed
302
    template <typename X,
Chao Liu's avatar
Chao Liu committed
303
304
305
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
zjing14's avatar
zjing14 committed
306
307
308
    __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
Chao Liu's avatar
Chao Liu committed
309
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
zjing14's avatar
zjing14 committed
310

Chao Liu's avatar
Chao Liu committed
311
        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
zjing14's avatar
zjing14 committed
312
313
314
315
316
317

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

        static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem");

Jianfeng Yan's avatar
Jianfeng Yan committed
318
#if CK_USE_AMD_BUFFER_ATOMIC_ADD
zjing14's avatar
zjing14 committed
319
320
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

Chao Liu's avatar
Chao Liu committed
321
        amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
zjing14's avatar
zjing14 committed
322
323
324
325
326
327
328
329
330
            x, p_data_, i, is_valid_element, element_space_size_);
#else
        if(is_valid_element)
        {
            atomicAdd(&p_data_[i], x);
        }
#endif
    }

331
332
333
334
335
336
337
338
339
340
341
    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

// Builds a DynamicBuffer viewing p, configured so that reads flagged invalid
// yield numerical zero (InvalidElementUseNumericalZeroValue = true).
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    using buffer_type = DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>;

    return buffer_type{p, element_space_size};
}

342
343
344
345
346
347
template <
    AddressSpaceEnum_t BufferAddressSpace,
    typename T,
    typename ElementSpaceSize,
    typename X,
    typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
348
__host__ __device__ constexpr auto
349
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value)
350
351
352
353
354
355
356
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
        p, element_space_size, invalid_element_value};
}

} // namespace ck
#endif