various.hpp 9.28 KB
Newer Older
zhuwenwen's avatar
zhuwenwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#ifndef ROCPRIM_DETAIL_VARIOUS_HPP_
#define ROCPRIM_DETAIL_VARIOUS_HPP_

#include <type_traits>

#include "../config.hpp"
#include "../types.hpp"
#include "../type_traits.hpp"

// TODO: Refactor when it gets crowded

BEGIN_ROCPRIM_NAMESPACE
namespace detail
{

// A struct with no data members.
// NOTE(review): presumably a placeholder for code paths that require a storage
// type but have nothing to store — confirm at the use sites.
struct empty_storage_type {};

// Tells whether x is a power of two, i.e. strictly positive with exactly one
// bit set. Zero and negative values yield false.
// T is required to be an integral type (checked at compile time).
template<class T>
ROCPRIM_HOST_DEVICE inline
constexpr bool is_power_of_two(const T x)
{
    static_assert(::rocprim::is_integral<T>::value, "T must be integer type");
    if(!(x > 0))
    {
        return false;
    }
    // Clearing the lowest set bit leaves zero only when a single bit was set.
    return (x & (x - 1)) == 0;
}

// Returns the first value of the sequence acc, 2*acc, 4*acc, ... that is not
// less than x. With the default acc of 1 this is the smallest power of two
// that is >= x (1 for x == 0).
// T is required to be an unsigned type (checked at compile time).
// Caution: if x is larger than every value the doubling can reach in T, the
// doubling wraps around to zero and the search does not terminate.
template<class T>
ROCPRIM_HOST_DEVICE inline
constexpr T next_power_of_two(const T x, const T acc = 1)
{
    static_assert(::rocprim::is_unsigned<T>::value, "T must be unsigned type");
    T candidate = acc;
    while(candidate < x)
    {
        candidate = static_cast<T>(2 * candidate);
    }
    return candidate;
}

// Divides a by b and rounds the quotient up whenever the remainder a % b is
// positive. Participates in overload resolution only for an integral T and
// an unsigned U.
template <
    typename T,
    typename U,
    std::enable_if_t<::rocprim::is_integral<T>::value && ::rocprim::is_unsigned<U>::value, int> = 0>
ROCPRIM_HOST_DEVICE inline constexpr auto ceiling_div(const T a, const U b)
{
    const auto quotient  = a / b;
    const auto remainder = a % b;
    return remainder > 0 ? quotient + 1 : quotient;
}

// Rounds size up to the nearest multiple of alignment (256 by default).
// alignment is used as a divisor, so it must not be zero.
ROCPRIM_HOST_DEVICE inline
size_t align_size(size_t size, size_t alignment = 256)
{
    const size_t chunk_count = ceiling_div(size, alignment);
    return chunk_count * alignment;
}

// TODO: Put the block algorithms with warp size variables at device side with macro.
// Temporary workaround: simply returns the warp_size it is given, unchanged.
template<class T>
ROCPRIM_HOST_DEVICE inline
constexpr T warp_size_in_class(const T warp_size)
{
    return warp_size;
}

// Picks the warp size to use for a block of block_size threads:
// max_warp_size when the block is at least that large, otherwise the smallest
// power of two that is not less than block_size. Useful for blocks smaller
// than the maximal warp size.
// T is required to be an unsigned type (checked at compile time).
template<class T>
ROCPRIM_HOST_DEVICE inline
constexpr T get_min_warp_size(const T block_size, const T max_warp_size)
{
    static_assert(::rocprim::is_unsigned<T>::value, "T must be unsigned type");
    if(block_size >= max_warp_size)
    {
        return max_warp_size;
    }
    return next_power_of_two(block_size);
}

// Compile-time trait: value is true exactly when WarpSize is a power of two,
// as reported by detail::is_power_of_two.
// NOTE(review): the name suggests this gates shuffle-based warp algorithms —
// confirm at the use sites.
template<unsigned int WarpSize>
struct is_warpsize_shuffleable {
    // constexpr (instead of plain const) matches the other traits in this file
    // (match_vector_type::size, select_type_case::value). From C++17 on it also
    // makes the member implicitly inline, so ODR-using it (e.g. binding it to a
    // const reference) does not require a separate out-of-class definition.
    static constexpr bool value = detail::is_power_of_two(WarpSize);
};

// Chooses a vector type suitable for moving N consecutive elements of type T.
// The scalar element of the vector depends on sizeof(T); the vector width
// (4, 2 or 1 elements) depends on the total byte size of the N elements.
template<class T, unsigned int N>
struct match_vector_type
{
    // Total number of bytes occupied by N elements of T.
    static constexpr unsigned int size = sizeof(T) * N;

    // Scalar element: int for sizeof(T) >= 4, short for sizeof(T) >= 2, char otherwise.
    using vector_base_type = std::conditional_t<
        (sizeof(T) >= 4),
        int,
        std::conditional_t<(sizeof(T) >= 2), short, char>>;

    // Candidate vector types, from widest to narrowest.
    using vector_4 = typename make_vector_type<vector_base_type, 4>::type;
    using vector_2 = typename make_vector_type<vector_base_type, 2>::type;
    using vector_1 = typename make_vector_type<vector_base_type, 1>::type;

    // Widest candidate whose size evenly divides the total byte size.
    using type = std::conditional_t<
        (size % sizeof(vector_4) == 0),
        vector_4,
        std::conditional_t<(size % sizeof(vector_2) == 0), vector_2, vector_1>>;
};

// True when Items is even and sizeof(T) is smaller than the size of the
// vector type that match_vector_type selects for Items elements of T.
template<class T, unsigned int Items>
struct is_vectorizable : std::integral_constant<bool, (Items % 2 == 0) &&(sizeof(T) < sizeof(typename match_vector_type<T, Items>::type))> {};

// Returns the number of LDS (local data share) banks.
ROCPRIM_HOST_DEVICE
constexpr unsigned int get_lds_banks_no()
{
    // Currently all devices supported by ROCm have 32 banks (4 bytes each)
    constexpr unsigned int lds_bank_count = 32;
    return lds_bank_count;
}

// Selects the widest of unsigned long long / unsigned int / unsigned short /
// unsigned char for which sizeof(T) is a multiple of 8 / 4 / 2 / 1 respectively.
// Only sizeof(T) is inspected.
template<class T>
struct match_fundamental_type
{
    using type = std::conditional_t<
        (sizeof(T) % 8 == 0),
        unsigned long long,
        std::conditional_t<
            (sizeof(T) % 4 == 0),
            unsigned int,
            std::conditional_t<(sizeof(T) % 2 == 0), unsigned short, unsigned char>>>;
};

// Stores value to *output through a volatile-qualified lvalue.
// This overload is enabled (via the trailing return type) only when T is a
// fundamental type; it returns void.
template<class T>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto store_volatile(T * output, T value)
    -> typename std::enable_if<std::is_fundamental<T>::value>::type
{
    // TODO: check GCC
    // error: binding reference of type ‘const half_float::half&’ to ‘volatile half_float::half’ discards qualifiers
    // Because of that error, builds with both __HIP_CPU_RT__ and __GNUC__
    // defined fall back to a plain (non-volatile) store.
#if !(defined(__HIP_CPU_RT__ ) && defined(__GNUC__))
    *const_cast<volatile T*>(output) = value;
#else
    *output = value;
#endif
}

// Stores value to *output for a non-fundamental T (the overload is enabled via
// the trailing return type; it returns void).
// The object is copied as n separate words of match_fundamental_type<T>::type,
// each read and written through a volatile-qualified pointer.
template<class T>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto store_volatile(T * output, T value)
    -> typename std::enable_if<!std::is_fundamental<T>::value>::type
{
    using fundamental_type = typename match_fundamental_type<T>::type;
    // sizeof(T) is a multiple of sizeof(fundamental_type) by construction,
    // so n words cover the whole object.
    constexpr unsigned int n = sizeof(T) / sizeof(fundamental_type);

    // View both the local copy and the destination as arrays of words.
    auto input_ptr = reinterpret_cast<volatile fundamental_type*>(&value);
    auto output_ptr = reinterpret_cast<volatile fundamental_type*>(output);

    // ROCPRIM_UNROLL is defined outside this file; presumably an unroll hint — verify in config.hpp.
    ROCPRIM_UNROLL
    for(unsigned int i = 0; i < n; i++)
    {
        output_ptr[i] = input_ptr[i];
    }
}

// Loads and returns *input through a volatile-qualified lvalue.
// This overload is enabled (via the trailing return type) only when T is a
// fundamental type.
template<class T>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto load_volatile(T * input)
    -> typename std::enable_if<std::is_fundamental<T>::value, T>::type
{
    // TODO: check GCC
    // error: binding reference of type ‘const half_float::half&’ to ‘volatile half_float::half’ discards qualifiers
    // Because of that error, builds with both __HIP_CPU_RT__ and __GNUC__
    // defined fall back to a plain (non-volatile) load.
#if !(defined(__HIP_CPU_RT__ ) && defined(__GNUC__))
    T retval = *const_cast<volatile T*>(input);
    return retval;
#else
    return *input;
#endif
}

// Loads and returns *input for a non-fundamental T (the overload is enabled
// via the trailing return type).
// The object is copied as n separate words of match_fundamental_type<T>::type,
// each read and written through a volatile-qualified pointer.
template<class T>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto load_volatile(T * input)
    -> typename std::enable_if<!std::is_fundamental<T>::value, T>::type
{
    using fundamental_type = typename match_fundamental_type<T>::type;
    // sizeof(T) is a multiple of sizeof(fundamental_type) by construction,
    // so n words cover the whole object.
    constexpr unsigned int n = sizeof(T) / sizeof(fundamental_type);

    // retval is default-initialized and then overwritten word by word below.
    T retval;
    auto output_ptr = reinterpret_cast<volatile fundamental_type*>(&retval);
    auto input_ptr = reinterpret_cast<volatile fundamental_type*>(input);

    // ROCPRIM_UNROLL is defined outside this file; presumably an unroll hint — verify in config.hpp.
    ROCPRIM_UNROLL
    for(unsigned int i = 0; i < n; i++)
    {
        output_ptr[i] = input_ptr[i];
    }
    return retval;
}

// A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions
template <typename T>
struct raw_storage
{
    // Biggest memory-access word that T is a whole multiple of and is not larger than the alignment of T
    typedef typename detail::match_fundamental_type<T>::type device_word;

    // Backing storage
    device_word storage[sizeof(T) / sizeof(device_word)];

    // Alias
    ROCPRIM_HOST_DEVICE T& get()
    {
        return reinterpret_cast<T&>(*this);
    }
};

// Checks if two iterators have the same type and value.
//
// Overload for arguments of different types: such iterators are never
// considered equal, so the arguments are not even inspected.
template<class Iterator1, class Iterator2>
inline bool are_iterators_equal(Iterator1 /*iter1*/, Iterator2 /*iter2*/)
{
    return false;
}

// Overload for arguments of one and the same type: the result is whatever
// operator== reports for the two values.
template<class Iterator>
inline bool are_iterators_equal(Iterator iter1, Iterator iter2)
{
    const bool same_position = (iter1 == iter2);
    return same_position;
}

// Maps any list of types to void. Used in this file to detect, via partial
// specialization, whether an expression such as `typename T::type` is well-formed.
template<typename... Ts>
using void_t = void;

// Wraps a type: the nested member `type` is T itself.
template<typename T>
struct type_identity
{
    typedef T type;
};

// Recursively unwraps nested `type` members.
// Primary template: T has no usable member named `type`, so the result is T itself.
template<class T, class = void>
struct extract_type_impl : type_identity<T> { };

// Specialization selected when `typename T::type` is well-formed: the search
// continues with T::type, so the final result is the first type in the chain
// T, T::type, T::type::type, ... that has no member `type`.
template<class T>
struct extract_type_impl<T, void_t<typename T::type> > : extract_type_impl<typename T::type> { };

// Shorthand for the result of extract_type_impl<T>: T with all nested `type`
// members unwrapped.
template <typename T>
using extract_type = typename extract_type_impl<T>::type;

// One (condition, type) pair for select_type: `value` carries the condition
// and `type` the type associated with it.
template<bool Value, class T>
struct select_type_case
{
    typedef T type;
    static constexpr bool value = Value;
};

// Compile-time "first matching case" selection.
// General step (more than one argument): if the first case's condition holds,
// the result is extract_type of that case's type; otherwise the selection
// continues with the remaining cases.
template<class Case, class... OtherCases>
struct select_type_impl
    : std::conditional<
        Case::value,
        type_identity<extract_type<typename Case::type>>,
        select_type_impl<OtherCases...>
    >::type { };

// Last argument is a case whose condition is true: it is selected.
template<class T>
struct select_type_impl<select_type_case<true, T>> : type_identity<extract_type<T>> { };

// Last argument is a case whose condition is false: nothing can be selected,
// which is reported as a compile-time error.
template<class T>
struct select_type_impl<select_type_case<false, T>>
{
    // The condition depends on T, so the assertion fires only when this
    // specialization is actually instantiated.
    static_assert(
        sizeof(T) == 0,
        "Cannot select any case. "
        "The last case must have true condition or be a fallback type."
    );
};

// Last argument is not a select_type_case: it is treated as the fallback type
// and selected unconditionally (after extract_type unwrapping).
template<class Fallback>
struct select_type_impl<Fallback> : type_identity<extract_type<Fallback>> { };

// Result of select_type_impl for the given list: the type of the first
// select_type_case whose condition is true, or the trailing fallback type.
template <typename... Cases>
using select_type = typename select_type_impl<Cases...>::type;

// Shorthand for std::integral_constant<bool, Flag>.
template<bool Flag>
using bool_constant = std::integral_constant<bool, Flag>;

} // end namespace detail
END_ROCPRIM_NAMESPACE

#endif // ROCPRIM_DETAIL_VARIOUS_HPP_