/****************************************************************************** * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ #ifndef HIPCUB_ROCPRIM_UTIL_PTX_HPP_ #define HIPCUB_ROCPRIM_UTIL_PTX_HPP_ #include #include #include "config.hpp" #include #define HIPCUB_WARP_THREADS ::rocprim::warp_size() #define HIPCUB_DEVICE_WARP_THREADS ::rocprim::device_warp_size() #define HIPCUB_HOST_WARP_THREADS ::rocprim::host_warp_size() #define HIPCUB_ARCH 1 // ignored with rocPRIM backend BEGIN_HIPCUB_NAMESPACE // Missing compared to CUB: // * ThreadExit - not supported // * ThreadTrap - not supported // * FFMA_RZ, FMUL_RZ - not in CUB public API // * WARP_SYNC - not supported, not CUB public API // * CTA_SYNC_AND - not supported, not CUB public API // * MatchAny - not in CUB public API // // Differences: // * Warp thread masks (when used) are 64-bit unsigned integers // * member_mask argument is ignored in WARP_[ALL|ANY|BALLOT] funcs // * Arguments first_lane, last_lane, and member_mask are ignored // in Shuffle* funcs // * count in BAR is ignored, BAR works like CTA_SYNC // ID functions etc. HIPCUB_DEVICE inline int RowMajorTid(int block_dim_x, int block_dim_y, int block_dim_z) { return ((block_dim_z == 1) ? 0 : (threadIdx.z * block_dim_x * block_dim_y)) + ((block_dim_y == 1) ? 0 : (threadIdx.y * block_dim_x)) + threadIdx.x; } HIPCUB_DEVICE inline unsigned int LaneId() { return ::rocprim::lane_id(); } HIPCUB_DEVICE inline unsigned int WarpId() { return ::rocprim::warp_id(); } template HIPCUB_DEVICE inline uint64_t WarpMask(unsigned int warp_id) { constexpr bool is_pow_of_two = ::rocprim::detail::is_power_of_two(LOGICAL_WARP_THREADS); constexpr bool is_arch_warp = LOGICAL_WARP_THREADS == ::rocprim::device_warp_size(); uint64_t member_mask = uint64_t(-1) >> (64 - LOGICAL_WARP_THREADS); if (is_pow_of_two && !is_arch_warp) { member_mask <<= warp_id * LOGICAL_WARP_THREADS; } return member_mask; } // Returns the warp lane mask of all lanes less than the calling thread HIPCUB_DEVICE inline uint64_t LaneMaskLt() { return (uint64_t(1) << LaneId()) - 1; } // Returns the warp lane mask of all lanes less than or equal to the calling thread HIPCUB_DEVICE inline uint64_t LaneMaskLe() { return ((uint64_t(1) << LaneId()) << 1) - 1; } // Returns the warp lane mask of all lanes greater than the calling thread HIPCUB_DEVICE inline uint64_t LaneMaskGt() { return uint64_t(-1)^LaneMaskLe(); } // Returns the warp lane mask of all lanes greater than or equal to the calling thread HIPCUB_DEVICE inline uint64_t LaneMaskGe() { return uint64_t(-1)^LaneMaskLt(); } // Shuffle funcs template < int LOGICAL_WARP_THREADS, typename T > HIPCUB_DEVICE inline T ShuffleUp(T input, int src_offset, int first_thread, unsigned int member_mask) { // Not supported in rocPRIM. (void) first_thread; // Member mask is not supported in rocPRIM, because it's // not supported in ROCm. (void) member_mask; return ::rocprim::warp_shuffle_up( input, src_offset, LOGICAL_WARP_THREADS ); } template < int LOGICAL_WARP_THREADS, typename T > HIPCUB_DEVICE inline T ShuffleDown(T input, int src_offset, int last_thread, unsigned int member_mask) { // Not supported in rocPRIM. (void) last_thread; // Member mask is not supported in rocPRIM, because it's // not supported in ROCm. (void) member_mask; return ::rocprim::warp_shuffle_down( input, src_offset, LOGICAL_WARP_THREADS ); } template < int LOGICAL_WARP_THREADS, typename T > HIPCUB_DEVICE inline T ShuffleIndex(T input, int src_lane, unsigned int member_mask) { // Member mask is not supported in rocPRIM, because it's // not supported in ROCm. (void) member_mask; return ::rocprim::warp_shuffle( input, src_lane, LOGICAL_WARP_THREADS ); } // Other HIPCUB_DEVICE inline unsigned int SHR_ADD(unsigned int x, unsigned int shift, unsigned int addend) { return (x >> shift) + addend; } HIPCUB_DEVICE inline unsigned int SHL_ADD(unsigned int x, unsigned int shift, unsigned int addend) { return (x << shift) + addend; } namespace detail { template HIPCUB_DEVICE inline auto unsigned_bit_extract(UnsignedBits source, unsigned int bit_start, unsigned int num_bits) -> typename std::enable_if::type { #ifdef __CUDACC__ return __bitextract_u64(source, bit_start, num_bits); #else return (source << (64 - bit_start - num_bits)) >> (64 - num_bits); #endif // __HIP_PLATFORM_AMD__ } template HIPCUB_DEVICE inline auto unsigned_bit_extract(UnsignedBits source, unsigned int bit_start, unsigned int num_bits) -> typename std::enable_if::type { #ifdef __CUDACC__ return __bitextract_u32(source, bit_start, num_bits); #else return (static_cast(source) << (32 - bit_start - num_bits)) >> (32 - num_bits); #endif // __HIP_PLATFORM_AMD__ } } // end namespace detail // Bitfield-extract. // Extracts \p num_bits from \p source starting at bit-offset \p bit_start. // The input \p source may be an 8b, 16b, 32b, or 64b unsigned integer type. template HIPCUB_DEVICE inline unsigned int BFE(UnsignedBits source, unsigned int bit_start, unsigned int num_bits) { static_assert(std::is_unsigned::value, "UnsignedBits must be unsigned"); return detail::unsigned_bit_extract(source, bit_start, num_bits); } // Bitfield insert. // Inserts the \p num_bits least significant bits of \p y into \p x at bit-offset \p bit_start. HIPCUB_DEVICE inline void BFI(unsigned int &ret, unsigned int x, unsigned int y, unsigned int bit_start, unsigned int num_bits) { #ifdef __CUDACC__ ret = __bitinsert_u32(x, y, bit_start, num_bits); #else x <<= bit_start; unsigned int MASK_X = ((1 << num_bits) - 1) << bit_start; unsigned int MASK_Y = ~MASK_X; ret = (y & MASK_Y) | (x & MASK_X); #endif // __HIP_PLATFORM_AMD__ } HIPCUB_DEVICE inline unsigned int IADD3(unsigned int x, unsigned int y, unsigned int z) { return x + y + z; } HIPCUB_DEVICE inline int PRMT(unsigned int a, unsigned int b, unsigned int index) { return ::__byte_perm(a, b, index); } HIPCUB_DEVICE inline void BAR(int count) { (void) count; __syncthreads(); } HIPCUB_DEVICE inline void CTA_SYNC() { __syncthreads(); } HIPCUB_DEVICE inline void WARP_SYNC(unsigned int member_mask) { (void) member_mask; ::rocprim::wave_barrier(); } HIPCUB_DEVICE inline int WARP_ANY(int predicate, uint64_t member_mask) { (void) member_mask; return ::__any(predicate); } HIPCUB_DEVICE inline int WARP_ALL(int predicate, uint64_t member_mask) { (void) member_mask; return ::__all(predicate); } HIPCUB_DEVICE inline int64_t WARP_BALLOT(int predicate, uint64_t member_mask) { (void) member_mask; return __ballot(predicate); } END_HIPCUB_NAMESPACE #endif // HIPCUB_ROCPRIM_UTIL_PTX_HPP_