// dynamic_buffer.hpp
// (blame annotation from the source listing: 9.26 KB; Chao Liu, commit "rename")
#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP

// (blame annotation: Chao Liu, commit "tidy")
#include "amd_buffer_addressing.hpp"
// (blame annotation: Chao Liu)
#include "c_style_pointer_cast.hpp"

namespace ck {
8

9
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
10
11
12
13
14
15
16
17
18
19
20
21
struct DynamicBuffer
{
    using type = T;

    T* p_data_;
    ElementSpaceSize element_space_size_;

    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
    }

22
    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
23
24
25
26
27
28
29
30
31
32
33
34
35
    {
        return BufferAddressSpace;
    }

    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
36
    __host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const
37
38
39
40
41
42
43
44
45
46
47
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector =
            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;

        constexpr index_t scalar_per_x_vector =
            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

48
        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
49
50
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
Chao Liu's avatar
Chao Liu committed
51
52
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

53
54
55
            return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
                p_data_, i, is_valid_offset, element_space_size_);
#else
Chao Liu's avatar
Chao Liu committed
56
            return is_valid_offset ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
57
58
59
60
#endif
        }
        else
        {
Chao Liu's avatar
Chao Liu committed
61
            return is_valid_offset ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
        }
    }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector =
            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;

        constexpr index_t scalar_per_x_vector =
            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

82
        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
83
84
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
Chao Liu's avatar
Chao Liu committed
85
86
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

87
88
89
90
91
            amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
                x, p_data_, i, is_valid_offset, element_space_size_);
#else
            if(is_valid_offset)
            {
Chao Liu's avatar
Chao Liu committed
92
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
93
94
95
            }
#endif
        }
96
        else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
97
98
99
100
        {
            if(is_valid_offset)
            {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
Chao Liu's avatar
Chao Liu committed
101
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
102
#else
103
104
                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
                // inefficient
Chao Liu's avatar
Chao Liu committed
105
                // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
106
107
108
109
110
111
                // ds_write_b128
                // TODO: remove this after compiler fix
                if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
                                     int8_t>::value)
                {
                    static_assert(
Chao Liu's avatar
Chao Liu committed
112
113
114
115
116
117
118
119
                        (is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) ||
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) ||
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
120
121
122
123
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
124
125
                        "wrong! not implemented for this combination, please add "
                        "implementation");
126

Chao Liu's avatar
Chao Liu committed
127
128
129
130
131
                    if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                                 is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
Chao Liu's avatar
Chao Liu committed
132
133
                        *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int8_t*>(&x);
Chao Liu's avatar
Chao Liu committed
134
135
136
137
138
139
                    }
                    else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                                      is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
Chao Liu's avatar
Chao Liu committed
140
141
                        *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int16_t*>(&x);
Chao Liu's avatar
Chao Liu committed
142
143
144
145
146
147
                    }
                    else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                                      is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
Chao Liu's avatar
Chao Liu committed
148
149
                        *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32_t*>(&x);
Chao Liu's avatar
Chao Liu committed
150
151
152
153
                    }
                    else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
                                              int8x4_t>::value &&
                                      is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
154
155
156
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
Chao Liu's avatar
Chao Liu committed
157
158
                        *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32_t*>(&x);
159
                    }
Chao Liu's avatar
Chao Liu committed
160
161
162
                    else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
                                              int8x8_t>::value &&
                                      is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
163
164
165
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
Chao Liu's avatar
Chao Liu committed
166
167
                        *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32x2_t*>(&x);
168
                    }
Chao Liu's avatar
Chao Liu committed
169
170
171
                    else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
                                              int8x16_t>::value &&
                                      is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
172
173
174
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
Chao Liu's avatar
Chao Liu committed
175
176
                        *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                            *c_style_pointer_cast<const int32x4_t*>(&x);
177
178
179
180
                    }
                }
                else
                {
Chao Liu's avatar
Chao Liu committed
181
                    *c_style_pointer_cast<X*>(&p_data_[i]) = x;
182
183
184
185
186
187
188
189
                }
#endif
            }
        }
        else
        {
            if(is_valid_offset)
            {
Chao Liu's avatar
Chao Liu committed
190
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
191
192
193
194
195
196
197
198
199
            }
        }
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

200
template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
201
202
203
204
205
206
207
208
209
          typename T,
          typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
}

} // namespace ck
#endif