util_ptx.cuh 9.18 KB
Newer Older
zhuwenwen's avatar
zhuwenwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
/******************************************************************************
 * Copyright (c) 2010-2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#ifndef HIPCUB_ROCPRIM_UTIL_PTX_HPP_
#define HIPCUB_ROCPRIM_UTIL_PTX_HPP_

#include <cstdint>
#include <type_traits>

#include "config.hpp"

#include <cub/rocprim/intrinsics/warp_shuffle.hpp>

#define HIPCUB_WARP_THREADS ::rocprim::warp_size()
#define HIPCUB_DEVICE_WARP_THREADS ::rocprim::device_warp_size()
#define HIPCUB_HOST_WARP_THREADS ::rocprim::host_warp_size()
#define HIPCUB_ARCH 1 // ignored with rocPRIM backend


BEGIN_HIPCUB_NAMESPACE

// Missing compared to CUB:
// * ThreadExit - not supported
// * ThreadTrap - not supported
// * FFMA_RZ, FMUL_RZ - not in CUB public API
// * WARP_SYNC - not supported, not CUB public API
// * CTA_SYNC_AND - not supported, not CUB public API
// * MatchAny - not in CUB public API
//
// Differences:
// * Warp thread masks (when used) are 64-bit unsigned integers
// * member_mask argument is ignored in WARP_[ALL|ANY|BALLOT] funcs
// * Arguments first_lane, last_lane, and member_mask are ignored
// in Shuffle* funcs
// * count in BAR is ignored, BAR works like CTA_SYNC

// ID functions etc.

HIPCUB_DEVICE inline
int RowMajorTid(int block_dim_x, int block_dim_y, int block_dim_z)
{
    return ((block_dim_z == 1) ? 0 : (threadIdx.z * block_dim_x * block_dim_y))
        + ((block_dim_y == 1) ? 0 : (threadIdx.y * block_dim_x))
        + threadIdx.x;
}

HIPCUB_DEVICE inline
unsigned int LaneId()
{
    return ::rocprim::lane_id();
}

HIPCUB_DEVICE inline
unsigned int WarpId()
{
    return ::rocprim::warp_id();
}

template <int LOGICAL_WARP_THREADS, int /* ARCH */ = 0>
HIPCUB_DEVICE inline 
uint64_t WarpMask(unsigned int warp_id) {
    constexpr bool is_pow_of_two = ::rocprim::detail::is_power_of_two(LOGICAL_WARP_THREADS);
    constexpr bool is_arch_warp =
        LOGICAL_WARP_THREADS == ::rocprim::device_warp_size();

    uint64_t member_mask = uint64_t(-1) >> (64 - LOGICAL_WARP_THREADS);

    if (is_pow_of_two && !is_arch_warp) {
        member_mask <<= warp_id * LOGICAL_WARP_THREADS;
    }

    return member_mask;
}

// Returns the warp lane mask of all lanes less than the calling thread
HIPCUB_DEVICE inline
uint64_t LaneMaskLt()
{
    return (uint64_t(1) << LaneId()) - 1;
}

// Returns the warp lane mask of all lanes less than or equal to the calling thread
HIPCUB_DEVICE inline
uint64_t LaneMaskLe()
{
    return ((uint64_t(1) << LaneId()) << 1) - 1;
}

// Returns the warp lane mask of all lanes greater than the calling thread
HIPCUB_DEVICE inline
uint64_t LaneMaskGt()
{
    return uint64_t(-1)^LaneMaskLe();
}

// Returns the warp lane mask of all lanes greater than or equal to the calling thread
HIPCUB_DEVICE inline
uint64_t LaneMaskGe()
{
    return uint64_t(-1)^LaneMaskLt();
}

// Shuffle funcs

template <
    int LOGICAL_WARP_THREADS,
    typename T
>
HIPCUB_DEVICE inline
T ShuffleUp(T input,
            int src_offset,
            int first_thread,
            unsigned int member_mask)
{
    // Not supported in rocPRIM.
    (void) first_thread;
    // Member mask is not supported in rocPRIM, because it's
    // not supported in ROCm.
    (void) member_mask;
    return ::rocprim::warp_shuffle_up(
        input, src_offset, LOGICAL_WARP_THREADS
    );
}

template <
    int LOGICAL_WARP_THREADS,
    typename T
>
HIPCUB_DEVICE inline
T ShuffleDown(T input,
              int src_offset,
              int last_thread,
              unsigned int member_mask)
{
    // Not supported in rocPRIM.
    (void) last_thread;
    // Member mask is not supported in rocPRIM, because it's
    // not supported in ROCm.
    (void) member_mask;
    return ::rocprim::warp_shuffle_down(
        input, src_offset, LOGICAL_WARP_THREADS
    );
}

template <
    int LOGICAL_WARP_THREADS,
    typename T
>
HIPCUB_DEVICE inline
T ShuffleIndex(T input,
               int src_lane,
               unsigned int member_mask)
{
    // Member mask is not supported in rocPRIM, because it's
    // not supported in ROCm.
    (void) member_mask;
    return ::rocprim::warp_shuffle(
        input, src_lane, LOGICAL_WARP_THREADS
    );
}

// Other

HIPCUB_DEVICE inline
unsigned int SHR_ADD(unsigned int x,
                     unsigned int shift,
                     unsigned int addend)
{
    return (x >> shift) + addend;
}

HIPCUB_DEVICE inline
unsigned int SHL_ADD(unsigned int x,
                     unsigned int shift,
                     unsigned int addend)
{
    return (x << shift) + addend;
}

namespace detail {

template <typename UnsignedBits>
HIPCUB_DEVICE inline
auto unsigned_bit_extract(UnsignedBits source,
                          unsigned int bit_start,
                          unsigned int num_bits)
    -> typename std::enable_if<sizeof(UnsignedBits) == 8, unsigned int>::type
{
    #ifdef __CUDACC__
        return __bitextract_u64(source, bit_start, num_bits);
    #else
        return (source << (64 - bit_start - num_bits)) >> (64 - num_bits);
    #endif // __HIP_PLATFORM_AMD__
}

template <typename UnsignedBits>
HIPCUB_DEVICE inline
auto unsigned_bit_extract(UnsignedBits source,
                          unsigned int bit_start,
                          unsigned int num_bits)
    -> typename std::enable_if<sizeof(UnsignedBits) < 8, unsigned int>::type
{
    #ifdef __CUDACC__
        return __bitextract_u32(source, bit_start, num_bits);
    #else
        return (static_cast<unsigned int>(source) << (32 - bit_start - num_bits)) >> (32 - num_bits);
    #endif // __HIP_PLATFORM_AMD__
}

} // end namespace detail

// Bitfield-extract.
// Extracts \p num_bits from \p source starting at bit-offset \p bit_start.
// The input \p source may be an 8b, 16b, 32b, or 64b unsigned integer type.
template <typename UnsignedBits>
HIPCUB_DEVICE inline
unsigned int BFE(UnsignedBits source,
                 unsigned int bit_start,
                 unsigned int num_bits)
{
    static_assert(std::is_unsigned<UnsignedBits>::value, "UnsignedBits must be unsigned");
    return detail::unsigned_bit_extract(source, bit_start, num_bits);
}

// Bitfield insert.
// Inserts the \p num_bits least significant bits of \p y into \p x at bit-offset \p bit_start.
HIPCUB_DEVICE inline
void BFI(unsigned int &ret,
         unsigned int x,
         unsigned int y,
         unsigned int bit_start,
         unsigned int num_bits)
{
    #ifdef __CUDACC__
        ret = __bitinsert_u32(x, y, bit_start, num_bits);
    #else
        x <<= bit_start;
        unsigned int MASK_X = ((1 << num_bits) - 1) << bit_start;
        unsigned int MASK_Y = ~MASK_X;
        ret = (y & MASK_Y) | (x & MASK_X);
    #endif // __HIP_PLATFORM_AMD__
}

HIPCUB_DEVICE inline
unsigned int IADD3(unsigned int x, unsigned int y, unsigned int z)
{
    return x + y + z;
}

HIPCUB_DEVICE inline
int PRMT(unsigned int a, unsigned int b, unsigned int index)
{
    return ::__byte_perm(a, b, index);
}

HIPCUB_DEVICE inline
void BAR(int count)
{
    (void) count;
    __syncthreads();
}

HIPCUB_DEVICE inline
void CTA_SYNC()
{
    __syncthreads();
}

HIPCUB_DEVICE inline
void WARP_SYNC(unsigned int member_mask)
{
    (void) member_mask;
    ::rocprim::wave_barrier();
}

HIPCUB_DEVICE inline
int WARP_ANY(int predicate, uint64_t member_mask)
{
    (void) member_mask;
    return ::__any(predicate);
}

HIPCUB_DEVICE inline
int WARP_ALL(int predicate, uint64_t member_mask)
{
    (void) member_mask;
    return ::__all(predicate);
}

HIPCUB_DEVICE inline
int64_t WARP_BALLOT(int predicate, uint64_t member_mask)
{
    (void) member_mask;
    return __ballot(predicate);
}

END_HIPCUB_NAMESPACE

#endif // HIPCUB_ROCPRIM_UTIL_PTX_HPP_