"vscode:/vscode.git/clone" did not exist on "f8d693195e9d178ae12e7f39aba9d95007454041"
dynamic_buffer.hpp 14.8 KB
Newer Older
1
#pragma once
2
#include "config.hpp"
3
#include "enable_if.hpp"
4
5
6
#include "c_style_pointer_cast.hpp"
#include "amd_buffer_addressing.hpp"
#include "generic_memory_space_atomic_add.hpp"
7

8
#ifndef CK_NOGPU
9
10
namespace ck {

11
12
13
14
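// T is the buffer's element type; it may be a scalar or a vector type.
// X is the per-access type passed to Get/Set/Update/AtomicAdd; it may be a scalar or a vector type.
// T and X must have the same underlying scalar type.
// X must hold a whole number of T (X's scalar count is a multiple of T's scalar count).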
15
template <AddressSpaceEnum BufferAddressSpace,
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
          typename T,
          typename ElementSpaceSize,
          bool InvalidElementUseNumericalZeroValue>
struct DynamicBuffer
{
    // Element type of the buffer, re-exported for generic code.
    using type = T;

    // Base pointer of the buffer; every accessor indexes it as p_data_[i].
    T* p_data_;
    // Size of the buffer. It is forwarded to the amd_buffer_* helpers; the plain
    // (non-buffer-addressing) load/store paths in this struct never consult it.
    ElementSpaceSize element_space_size_;
    // Value Get() hands back for an invalid element when
    // InvalidElementUseNumericalZeroValue is false. Defaults to T{0}.
    T invalid_element_value_ = T{0};

    // Wrap an existing pointer and its element-space size.
    // invalid_element_value_ is left at its default member initializer (T{0}).
    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data},
          element_space_size_{element_space_size}
    {
        // nothing else to do: the buffer does not own or touch the memory here
    }

    // Wrap an existing pointer and its element-space size, and additionally record the
    // value that Get() should produce for invalid elements (used when
    // InvalidElementUseNumericalZeroValue is false).
    __host__ __device__ constexpr DynamicBuffer(T* p_data,
                                                ElementSpaceSize element_space_size,
                                                T invalid_element_value)
        : p_data_{p_data},
          element_space_size_{element_space_size},
          invalid_element_value_{invalid_element_value}
    {
        // nothing else to do: the buffer does not own or touch the memory here
    }

41
    // Compile-time query: the address space this buffer type was instantiated with.
    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace()
    {
        constexpr AddressSpaceEnum address_space = BufferAddressSpace;
        return address_space;
    }

46
47
48
49
    // Read-only access to element i. No validity or bounds check is performed.
    __host__ __device__ constexpr const T& operator[](index_t i) const
    {
        return *(p_data_ + i);
    }

    // Mutable access to element i. No validity or bounds check is performed.
    __host__ __device__ constexpr T& operator()(index_t i)
    {
        return *(p_data_ + i);
    }

50
    template <typename X,
Chao Liu's avatar
Chao Liu committed
51
52
53
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
54
55
56
    __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
    {
        // X contains multiple T
Chao Liu's avatar
Chao Liu committed
57
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
58

Chao Liu's avatar
Chao Liu committed
59
        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
60
61

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
62
                      "wrong! X should contain multiple T");
63

Jianfeng Yan's avatar
Jianfeng Yan committed
64
#if CK_USE_AMD_BUFFER_LOAD
65
66
67
68
69
        bool constexpr use_amd_buffer_addressing = true;
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

70
        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
71
72
73
74
75
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            if constexpr(InvalidElementUseNumericalZeroValue)
            {
Jianfeng Yan's avatar
Jianfeng Yan committed
76
                return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>, t_per_x>(
Chao Liu's avatar
Chao Liu committed
77
                    p_data_, i, is_valid_element, element_space_size_);
78
79
80
            }
            else
            {
Chao Liu's avatar
Chao Liu committed
81
82
                return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
                                                                               t_per_x>(
83
84
85
86
87
                    p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
            }
        }
        else
        {
88
            if(is_valid_element)
89
            {
90
91
92
93
94
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp;

                __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

95
                return tmp;
96
#else
97
                return *c_style_pointer_cast<const X*>(&p_data_[i]);
98
#endif
99
100
101
            }
            else
            {
102
103
104
105
106
107
108
109
                if constexpr(InvalidElementUseNumericalZeroValue)
                {
                    return X{0};
                }
                else
                {
                    return X{invalid_element_value_};
                }
110
111
112
113
            }
        }
    }

114
    template <InMemoryDataOperationEnum Op,
115
116
117
118
119
120
              typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void Update(index_t i, bool is_valid_element, const X& x)
    {
121
        if constexpr(Op == InMemoryDataOperationEnum::Set)
122
123
124
        {
            this->template Set<X>(i, is_valid_element, x);
        }
125
        else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd)
126
127
128
        {
            this->template AtomicAdd<X>(i, is_valid_element, x);
        }
129
        else if constexpr(Op == InMemoryDataOperationEnum::Add)
130
131
132
133
134
135
136
137
        {
            auto tmp = this->template Get<X>(i, is_valid_element);
            this->template Set<X>(i, is_valid_element, x + tmp);
            // tmp += x;
            // this->template Set<X>(i, is_valid_element, tmp);
        }
    }

138
    template <typename X,
Chao Liu's avatar
Chao Liu committed
139
140
141
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
142
143
144
    __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
Chao Liu's avatar
Chao Liu committed
145
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
146

Chao Liu's avatar
Chao Liu committed
147
        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
148
149

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
150
                      "wrong! X should contain multiple T");
151

Jianfeng Yan's avatar
Jianfeng Yan committed
152
#if CK_USE_AMD_BUFFER_STORE
153
        bool constexpr use_amd_buffer_addressing = true;
154
#else
155
156
        bool constexpr use_amd_buffer_addressing      = false;
#endif
157

158
159
#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
        bool constexpr workaround_int8_ds_write_issue = true;
160
#else
161
        bool constexpr workaround_int8_ds_write_issue = false;
162
#endif
163
164
165
166
167
168
169

        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_store<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_);
170
        }
171
172
173
        else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
                          is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
                          workaround_int8_ds_write_issue)
174
175
176
        {
            if(is_valid_element)
            {
177
                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
178
179
180
                // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
                // ds_write_b128
                // TODO: remove this after compiler fix
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
                static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
                               is_same<remove_cvref_t<X>, int8_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x2_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x16_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x16_t>::value),
                              "wrong! not implemented for this combination, please add "
                              "implementation");

                if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                             is_same<remove_cvref_t<X>, int8_t>::value)
202
                {
203
204
205
206
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int8_t*>(&x);
207
                }
208
209
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x2_t>::value)
210
                {
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int16_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
263
264
265
266
267
268
269
                }
            }
        }
        else
        {
            if(is_valid_element)
            {
270
271
272
273
274
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp = x;

                __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
275
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
276
#endif
277
278
279
280
            }
        }
    }

zjing14's avatar
zjing14 committed
281
    template <typename X,
Chao Liu's avatar
Chao Liu committed
282
283
284
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
zjing14's avatar
zjing14 committed
285
286
    __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x)
    {
287
288
        using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;

zjing14's avatar
zjing14 committed
289
        // X contains multiple T
Chao Liu's avatar
Chao Liu committed
290
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
zjing14's avatar
zjing14 committed
291

Chao Liu's avatar
Chao Liu committed
292
        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
zjing14's avatar
zjing14 committed
293
294

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
                      "wrong! X should contain multiple T");

        static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");

#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        bool constexpr use_amd_buffer_addressing =
            is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
            is_same_v<remove_cvref_t<scalar_t>, float> ||
            (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
        bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        bool constexpr use_amd_buffer_addressing =
            is_same_v<remove_cvref_t<scalar_t>, float> ||
            (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif
zjing14's avatar
zjing14 committed
313

314
315
316
        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
zjing14's avatar
zjing14 committed
317

318
319
320
321
            amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_);
        }
        else
zjing14's avatar
zjing14 committed
322
        {
323
324
            if(is_valid_element)
            {
325
                atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
326
            }
zjing14's avatar
zjing14 committed
327
328
329
        }
    }

330
331
332
333
334
    // Buffer-kind trait: always false for DynamicBuffer.
    __host__ __device__ static constexpr bool IsStaticBuffer()
    {
        return false;
    }

    // Buffer-kind trait: always true for DynamicBuffer.
    __host__ __device__ static constexpr bool IsDynamicBuffer()
    {
        return true;
    }
};

335
template <AddressSpaceEnum BufferAddressSpace, typename T, typename ElementSpaceSize>
336
337
338
339
340
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
}

341
template <
342
    AddressSpaceEnum BufferAddressSpace,
343
344
345
346
    typename T,
    typename ElementSpaceSize,
    typename X,
    typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
347
__host__ __device__ constexpr auto
348
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value)
349
350
351
352
353
354
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
        p, element_space_size, invalid_element_value};
}

} // namespace ck
355
#endif