// dynamic_buffer.hpp
#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP
3

Chao Liu's avatar
tidy  
Chao Liu committed
4
#include "amd_buffer_addressing.hpp"
Chao Liu's avatar
Chao Liu committed
5
6
7
#include "c_style_pointer_cast.hpp"

namespace ck {
8

9
10
11
12
template <AddressSpaceEnum_t BufferAddressSpace,
          typename T,
          typename ElementSpaceSize,
          bool InvalidElementUseNumericalZeroValue>
13
14
15
16
17
18
struct DynamicBuffer
{
    using type = T;

    T* p_data_;
    ElementSpaceSize element_space_size_;
19
    T invalid_element_value_ = T{0};
20
21
22
23
24
25

    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
    }

26
27
28
29
30
31
32
33
34
    __host__ __device__ constexpr DynamicBuffer(T* p_data,
                                                ElementSpaceSize element_space_size,
                                                T invalid_element_value)
        : p_data_{p_data},
          element_space_size_{element_space_size},
          invalid_element_value_{invalid_element_value}
    {
    }

35
    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
36
37
38
39
40
41
42
43
44
    {
        return BufferAddressSpace;
    }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
45
    __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
46
47
48
49
50
51
52
53
54
55
56
57
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector =
            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;

        constexpr index_t scalar_per_x_vector =
            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

#if CK_USE_AMD_BUFFER_ADDRESSING
58
        bool constexpr use_amd_buffer_addressing = true;
59
#else
60
        bool constexpr use_amd_buffer_addressing = false;
61
#endif
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79

        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return amd_buffer_load_invalid_element_return_return_zero<
                    remove_cv_t<remove_reference_t<T>>,
                    t_per_x>(p_data_, i, is_valid_element, element_space_size_);
            }
            else
            {
                return amd_buffer_load_invalid_element_return_customized_value<
                    remove_cv_t<remove_reference_t<T>>,
                    t_per_x>(
                    p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
            }
80
81
82
        }
        else
        {
83
84
85
86
87
88
89
90
91
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
            }
            else
            {
                return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
                                        : X{invalid_element_value_};
            }
92
93
94
95
96
97
98
99
        }
    }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
100
    __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
101
102
103
104
105
106
107
108
109
110
111
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector =
            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;

        constexpr index_t scalar_per_x_vector =
            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

112
        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
113
114
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
Chao Liu's avatar
Chao Liu committed
115
116
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

117
118
            amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_);
119
#else
120
            if(is_valid_element)
121
            {
Chao Liu's avatar
Chao Liu committed
122
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
123
124
125
            }
#endif
        }
126
        else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
127
        {
128
            if(is_valid_element)
129
130
            {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
Chao Liu's avatar
Chao Liu committed
131
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
132
#else
133
134
                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
                // inefficient
Chao Liu's avatar
Chao Liu committed
135
                // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
136
137
138
139
140
141
                // ds_write_b128
                // TODO: remove this after compiler fix
                if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
                                     int8_t>::value)
                {
                    static_assert(
Chao Liu's avatar
Chao Liu committed
142
143
144
145
146
147
148
149
                        (is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) ||
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) ||
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
150
151
152
153
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
154
155
                        "wrong! not implemented for this combination, please add "
                        "implementation");
156

Chao Liu's avatar
Chao Liu committed
157
158
159
160
161
                    if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                                 is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
Chao Liu's avatar
Chao Liu committed
162
163
                        *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int8_t*>(&x);
Chao Liu's avatar
Chao Liu committed
164
165
166
167
168
169
                    }
                    else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                                      is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
Chao Liu's avatar
Chao Liu committed
170
171
                        *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int16_t*>(&x);
Chao Liu's avatar
Chao Liu committed
172
173
174
175
176
177
                    }
                    else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                                      is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
Chao Liu's avatar
Chao Liu committed
178
179
                        *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32_t*>(&x);
Chao Liu's avatar
Chao Liu committed
180
181
182
183
                    }
                    else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
                                              int8x4_t>::value &&
                                      is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
184
185
186
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
Chao Liu's avatar
Chao Liu committed
187
188
                        *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32_t*>(&x);
189
                    }
Chao Liu's avatar
Chao Liu committed
190
191
192
                    else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
                                              int8x8_t>::value &&
                                      is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
193
194
195
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
Chao Liu's avatar
Chao Liu committed
196
197
                        *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32x2_t*>(&x);
198
                    }
Chao Liu's avatar
Chao Liu committed
199
200
201
                    else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
                                              int8x16_t>::value &&
                                      is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
202
203
204
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
Chao Liu's avatar
Chao Liu committed
205
206
                        *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32x4_t*>(&x);
207
208
209
210
                    }
                }
                else
                {
Chao Liu's avatar
Chao Liu committed
211
                    *c_style_pointer_cast<X*>(&p_data_[i]) = x;
212
213
214
215
216
217
                }
#endif
            }
        }
        else
        {
218
            if(is_valid_element)
219
            {
Chao Liu's avatar
Chao Liu committed
220
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
221
222
223
224
225
226
227
228
229
            }
        }
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

230
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
231
232
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
233
234
235
236
237
238
239
240
241
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
}

// Wraps `p` and `element_space_size` in a DynamicBuffer instantiated with
// InvalidElementUseNumericalZeroValue = false and stores
// `invalid_element_value` as the buffer's invalid-element value.
//
// `invalid_element_value` is declared through std::remove_cv so that it is a
// non-deduced context: T is deduced from `p` alone. Deducing T from both
// arguments would conflict for a pointer-to-const (T = const U from `p`, but
// T = U from the by-value argument) and make this overload unusable on const
// buffers.
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
__host__ __device__ constexpr auto
make_dynamic_buffer(T* p,
                    ElementSpaceSize element_space_size,
                    typename std::remove_cv<T>::type invalid_element_value)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
        p, element_space_size, invalid_element_value};
}

} // namespace ck
#endif