diff --git a/3rdparty/cub b/3rdparty/cub deleted file mode 100644 index c915163a332209baa4de9255160071dd6d6629b6..0000000000000000000000000000000000000000 --- a/3rdparty/cub +++ /dev/null @@ -1 +0,0 @@ -/opt/dtk-23.04/cuda/include/cub \ No newline at end of file diff --git a/3rdparty/cub/block/block_adjacent_difference.cuh b/3rdparty/cub/block/block_adjacent_difference.cuh new file mode 100644 index 0000000000000000000000000000000000000000..337033dba0c666fc536fb5bd0331707cb9d71331 --- /dev/null +++ b/3rdparty/cub/block/block_adjacent_difference.cuh @@ -0,0 +1,303 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_ADJACENT_DIFFERENCE_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_ADJACENT_DIFFERENCE_HPP_ + +#include "../config.hpp" + +#include + +BEGIN_HIPCUB_NAMESPACE + +namespace detail +{ + // Trait checks if FlagOp can be called with 3 arguments (a, b, b_index) + template + struct WithBIndexArg + : std::false_type + { }; + + template + struct WithBIndexArg< + T, FlagOp, + typename std::conditional< + true, + void, + decltype(std::declval()(std::declval(), std::declval(), 0)) + >::type + > : std::true_type + { }; + +} + +template< + typename T, + int BLOCK_DIM_X, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int ARCH = HIPCUB_ARCH /* ignored */ +> +class BlockAdjacentDifference + : private ::rocprim::block_adjacent_difference< + T, + BLOCK_DIM_X, + BLOCK_DIM_Y, + BLOCK_DIM_Z + > +{ + static_assert( + BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0, + "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0" + ); + + using base_type = + typename ::rocprim::block_adjacent_difference< + T, + BLOCK_DIM_X, + BLOCK_DIM_Y, + BLOCK_DIM_Z + >; + + // Reference to temporary storage (usually shared memory) + typename base_type::storage_type& temp_storage_; + +public: + using TempStorage = typename base_type::storage_type; + + HIPCUB_DEVICE inline + BlockAdjacentDifference() : temp_storage_(private_storage()) + { + } + + HIPCUB_DEVICE inline + 
BlockAdjacentDifference(TempStorage& temp_storage) : temp_storage_(temp_storage) + { + } + + template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] + HIPCUB_DEVICE inline + void FlagHeads(FlagT (&head_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op) + { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") + base_type::flag_heads(head_flags, input, flag_op, temp_storage_); + HIPCUB_CLANG_SUPPRESS_WARNING_POP + } + + template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] + HIPCUB_DEVICE inline + void FlagHeads(FlagT (&head_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op, + T tile_predecessor_item) + { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") + base_type::flag_heads(head_flags, tile_predecessor_item, input, flag_op, temp_storage_); + HIPCUB_CLANG_SUPPRESS_WARNING_POP + } + + template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] + HIPCUB_DEVICE inline + void FlagTails(FlagT (&tail_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op) + { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") + base_type::flag_tails(tail_flags, input, flag_op, temp_storage_); + HIPCUB_CLANG_SUPPRESS_WARNING_POP + } + + template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] + HIPCUB_DEVICE inline + void FlagTails(FlagT (&tail_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op, + T tile_successor_item) + { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") + base_type::flag_tails(tail_flags, tile_successor_item, input, flag_op, temp_storage_); + HIPCUB_CLANG_SUPPRESS_WARNING_POP + } + + template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] + HIPCUB_DEVICE inline + void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp 
flag_op) + { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") + base_type::flag_heads_and_tails( + head_flags, tail_flags, input, + flag_op, temp_storage_ + ); + HIPCUB_CLANG_SUPPRESS_WARNING_POP + } + + template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] + HIPCUB_DEVICE inline + void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T tile_successor_item, + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op) + { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") + base_type::flag_heads_and_tails( + head_flags, tail_flags, tile_successor_item, input, + flag_op, temp_storage_ + ); + HIPCUB_CLANG_SUPPRESS_WARNING_POP + } + + template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] + HIPCUB_DEVICE inline + void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], + T tile_predecessor_item, + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op) + { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") + base_type::flag_heads_and_tails( + head_flags, tile_predecessor_item, tail_flags, input, + flag_op, temp_storage_ + ); + HIPCUB_CLANG_SUPPRESS_WARNING_POP + } + + template + [[deprecated("The Flags API of BlockAdjacentDifference is deprecated.")]] + HIPCUB_DEVICE inline + void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], + T tile_predecessor_item, + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T tile_successor_item, + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op) + { + HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") + base_type::flag_heads_and_tails( + head_flags, tile_predecessor_item, tail_flags, tile_successor_item, input, + flag_op, temp_storage_ + ); + HIPCUB_CLANG_SUPPRESS_WARNING_POP + } + + template + HIPCUB_DEVICE inline + void SubtractLeft(T (&input)[ITEMS_PER_THREAD], + OutputType (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op) + { + base_type::subtract_left( + 
input, output, difference_op, temp_storage_ + ); + } + + template + HIPCUB_DEVICE inline + void SubtractLeft(T (&input)[ITEMS_PER_THREAD], + OutputT (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op, + T tile_predecessor_item) + { + base_type::subtract_left( + input, output, difference_op, tile_predecessor_item, temp_storage_ + ); + } + + template + HIPCUB_DEVICE inline + void SubtractLeftPartialTile(T (&input)[ITEMS_PER_THREAD], + OutputType (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op, + int valid_items) + { + base_type::subtract_left_partial( + input, output, difference_op, valid_items, temp_storage_ + ); + } + + template + HIPCUB_DEVICE inline + void SubtractRight(T (&input)[ITEMS_PER_THREAD], + OutputT (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op) + { + base_type::subtract_right( + input, output, difference_op, temp_storage_ + ); + } + + template + HIPCUB_DEVICE inline + void SubtractRight(T (&input)[ITEMS_PER_THREAD], + OutputT (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op, + T tile_successor_item) + { + base_type::subtract_right( + input, output, difference_op, tile_successor_item, temp_storage_ + ); + } + + template + HIPCUB_DEVICE inline + void SubtractRightPartialTile(T (&input)[ITEMS_PER_THREAD], + OutputT (&output)[ITEMS_PER_THREAD], + DifferenceOpT difference_op, + int valid_items) + { + base_type::subtract_right_partial( + input, output, difference_op, valid_items, temp_storage_ + ); + } + +private: + HIPCUB_DEVICE inline + TempStorage& private_storage() + { + HIPCUB_SHARED_MEMORY TempStorage private_storage; + return private_storage; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_ADJACENT_DIFFERENCE_HPP_ diff --git a/3rdparty/cub/block/block_discontinuity.cuh b/3rdparty/cub/block/block_discontinuity.cuh new file mode 100644 index 0000000000000000000000000000000000000000..33e9ef8ca3cb876d6b14ad2f090d495d5773be5d --- /dev/null +++ b/3rdparty/cub/block/block_discontinuity.cuh @@ 
-0,0 +1,188 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_DISCONTINUITY_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_DISCONTINUITY_HPP_ + +#include "../config.hpp" + +#include + +BEGIN_HIPCUB_NAMESPACE + +template< + typename T, + int BLOCK_DIM_X, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int ARCH = HIPCUB_ARCH /* ignored */ +> +class BlockDiscontinuity + : private ::rocprim::block_discontinuity< + T, + BLOCK_DIM_X, + BLOCK_DIM_Y, + BLOCK_DIM_Z + > +{ + static_assert( + BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0, + "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0" + ); + + using base_type = + typename ::rocprim::block_discontinuity< + T, + BLOCK_DIM_X, + BLOCK_DIM_Y, + BLOCK_DIM_Z + >; + + // Reference to temporary storage (usually shared memory) + typename base_type::storage_type& temp_storage_; + +public: + using TempStorage = typename base_type::storage_type; + + HIPCUB_DEVICE inline + BlockDiscontinuity() : temp_storage_(private_storage()) + { + } + + HIPCUB_DEVICE inline + BlockDiscontinuity(TempStorage& temp_storage) : temp_storage_(temp_storage) + { + } + + template + HIPCUB_DEVICE inline + void FlagHeads(FlagT (&head_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op) + { + base_type::flag_heads(head_flags, input, flag_op, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void FlagHeads(FlagT (&head_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op, + T tile_predecessor_item) + { + base_type::flag_heads(head_flags, tile_predecessor_item, input, flag_op, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void FlagTails(FlagT (&tail_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op) + { + base_type::flag_tails(tail_flags, input, flag_op, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void FlagTails(FlagT (&tail_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp 
flag_op, + T tile_successor_item) + { + base_type::flag_tails(tail_flags, tile_successor_item, input, flag_op, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op) + { + base_type::flag_heads_and_tails( + head_flags, tail_flags, input, + flag_op, temp_storage_ + ); + } + + template + HIPCUB_DEVICE inline + void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T tile_successor_item, + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op) + { + base_type::flag_heads_and_tails( + head_flags, tail_flags, tile_successor_item, input, + flag_op, temp_storage_ + ); + } + + template + HIPCUB_DEVICE inline + void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], + T tile_predecessor_item, + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op) + { + base_type::flag_heads_and_tails( + head_flags, tile_predecessor_item, tail_flags, input, + flag_op, temp_storage_ + ); + } + + template + HIPCUB_DEVICE inline + void FlagHeadsAndTails(FlagT (&head_flags)[ITEMS_PER_THREAD], + T tile_predecessor_item, + FlagT (&tail_flags)[ITEMS_PER_THREAD], + T tile_successor_item, + T (&input)[ITEMS_PER_THREAD], + FlagOp flag_op) + { + base_type::flag_heads_and_tails( + head_flags, tile_predecessor_item, tail_flags, tile_successor_item, input, + flag_op, temp_storage_ + ); + } + +private: + HIPCUB_DEVICE inline + TempStorage& private_storage() + { + HIPCUB_SHARED_MEMORY TempStorage private_storage; + return private_storage; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_DISCONTINUITY_HPP_ diff --git a/3rdparty/cub/block/block_exchange.cuh b/3rdparty/cub/block/block_exchange.cuh new file mode 100644 index 0000000000000000000000000000000000000000..38792d6ae83b0fcbe0506f04f2e470b0f937a158 --- /dev/null +++ 
b/3rdparty/cub/block/block_exchange.cuh @@ -0,0 +1,229 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_EXCHANGE_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_EXCHANGE_HPP_ + +#include "../config.hpp" + +#include + +BEGIN_HIPCUB_NAMESPACE + +template< + typename InputT, + int BLOCK_DIM_X, + int ITEMS_PER_THREAD, + bool WARP_TIME_SLICING = false, /* ignored */ + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int ARCH = HIPCUB_ARCH /* ignored */ +> +class BlockExchange + : private ::rocprim::block_exchange< + InputT, + BLOCK_DIM_X, + ITEMS_PER_THREAD, + BLOCK_DIM_Y, + BLOCK_DIM_Z + > +{ + static_assert( + BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0, + "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0" + ); + + using base_type = + typename ::rocprim::block_exchange< + InputT, + BLOCK_DIM_X, + ITEMS_PER_THREAD, + BLOCK_DIM_Y, + BLOCK_DIM_Z + >; + + // Reference to temporary storage (usually shared memory) + typename base_type::storage_type& temp_storage_; + +public: + using TempStorage = typename base_type::storage_type; + + HIPCUB_DEVICE inline + BlockExchange() : temp_storage_(private_storage()) + { + } + + HIPCUB_DEVICE inline + BlockExchange(TempStorage& temp_storage) : temp_storage_(temp_storage) + { + } + + template + HIPCUB_DEVICE inline + void StripedToBlocked(InputT (&input_items)[ITEMS_PER_THREAD], + OutputT (&output_items)[ITEMS_PER_THREAD]) + { + base_type::striped_to_blocked(input_items, output_items, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void BlockedToStriped(InputT (&input_items)[ITEMS_PER_THREAD], + OutputT (&output_items)[ITEMS_PER_THREAD]) + { + base_type::blocked_to_striped(input_items, output_items, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void WarpStripedToBlocked(InputT (&input_items)[ITEMS_PER_THREAD], + OutputT (&output_items)[ITEMS_PER_THREAD]) + { + base_type::warp_striped_to_blocked(input_items, output_items, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void 
BlockedToWarpStriped(InputT (&input_items)[ITEMS_PER_THREAD], + OutputT (&output_items)[ITEMS_PER_THREAD]) + { + base_type::blocked_to_warp_striped(input_items, output_items, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void ScatterToBlocked(InputT (&input_items)[ITEMS_PER_THREAD], + OutputT (&output_items)[ITEMS_PER_THREAD], + OffsetT (&ranks)[ITEMS_PER_THREAD]) + { + base_type::scatter_to_blocked(input_items, output_items, ranks, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void ScatterToStriped(InputT (&input_items)[ITEMS_PER_THREAD], + OutputT (&output_items)[ITEMS_PER_THREAD], + OffsetT (&ranks)[ITEMS_PER_THREAD]) + { + base_type::scatter_to_striped(input_items, output_items, ranks, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void ScatterToStripedGuarded(InputT (&input_items)[ITEMS_PER_THREAD], + OutputT (&output_items)[ITEMS_PER_THREAD], + OffsetT (&ranks)[ITEMS_PER_THREAD]) + { + base_type::scatter_to_striped_guarded(input_items, output_items, ranks, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void ScatterToStripedFlagged(InputT (&input_items)[ITEMS_PER_THREAD], + OutputT (&output_items)[ITEMS_PER_THREAD], + OffsetT (&ranks)[ITEMS_PER_THREAD], + ValidFlag (&is_valid)[ITEMS_PER_THREAD]) + { + base_type::scatter_to_striped_flagged(input_items, output_items, ranks, is_valid, temp_storage_); + } + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + + HIPCUB_DEVICE inline void StripedToBlocked( + InputT (&items)[ITEMS_PER_THREAD]) ///< [in-out] Items to exchange, converting between striped and blocked arrangements. + { + StripedToBlocked(items, items); + } + + HIPCUB_DEVICE inline void BlockedToStriped( + InputT (&items)[ITEMS_PER_THREAD]) ///< [in-out] Items to exchange, converting between striped and blocked arrangements. 
+ { + BlockedToStriped(items, items); + } + + HIPCUB_DEVICE inline void WarpStripedToBlocked( + InputT (&items)[ITEMS_PER_THREAD]) ///< [in-out] Items to exchange, converting between striped and blocked arrangements. + { + WarpStripedToBlocked(items, items); + } + + HIPCUB_DEVICE inline void BlockedToWarpStriped( + InputT (&items)[ITEMS_PER_THREAD]) ///< [in-out] Items to exchange, converting between striped and blocked arrangements. + { + BlockedToWarpStriped(items, items); + } + + template + HIPCUB_DEVICE inline void ScatterToBlocked( + InputT (&items)[ITEMS_PER_THREAD], ///< [in-out] Items to exchange, converting between striped and blocked arrangements. + OffsetT (&ranks)[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks + { + ScatterToBlocked(items, items, ranks); + } + + template + HIPCUB_DEVICE inline void ScatterToStriped( + InputT (&items)[ITEMS_PER_THREAD], ///< [in-out] Items to exchange, converting between striped and blocked arrangements. + OffsetT (&ranks)[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks + { + ScatterToStriped(items, items, ranks); + } + + template + HIPCUB_DEVICE inline void ScatterToStripedGuarded( + InputT (&items)[ITEMS_PER_THREAD], ///< [in-out] Items to exchange, converting between striped and blocked arrangements. + OffsetT (&ranks)[ITEMS_PER_THREAD]) ///< [in] Corresponding scatter ranks + { + ScatterToStripedGuarded(items, items, ranks); + } + + template + HIPCUB_DEVICE inline void ScatterToStripedFlagged( + InputT (&items)[ITEMS_PER_THREAD], ///< [in-out] Items to exchange, converting between striped and blocked arrangements. 
+ OffsetT (&ranks)[ITEMS_PER_THREAD], ///< [in] Corresponding scatter ranks + ValidFlag (&is_valid)[ITEMS_PER_THREAD]) ///< [in] Corresponding flag denoting item validity + { + ScatterToStripedFlagged(items, items, ranks, is_valid); + } + +#endif // DOXYGEN_SHOULD_SKIP_THIS + +private: + HIPCUB_DEVICE inline + TempStorage& private_storage() + { + HIPCUB_SHARED_MEMORY TempStorage private_storage; + return private_storage; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_EXCHANGE_HPP_ diff --git a/3rdparty/cub/block/block_histogram.cuh b/3rdparty/cub/block/block_histogram.cuh new file mode 100644 index 0000000000000000000000000000000000000000..293dfa6454f8e316c37bdd7a095b49e3ad981388 --- /dev/null +++ b/3rdparty/cub/block/block_histogram.cuh @@ -0,0 +1,147 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_HISTOGRAM_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_HISTOGRAM_HPP_ + +#include + +#include + +BEGIN_HIPCUB_NAMESPACE + +namespace detail +{ + inline constexpr + typename std::underlying_type<::rocprim::block_histogram_algorithm>::type + to_BlockHistogramAlgorithm_enum(::rocprim::block_histogram_algorithm v) + { + using utype = std::underlying_type<::rocprim::block_histogram_algorithm>::type; + return static_cast(v); + } +} + +enum BlockHistogramAlgorithm +{ + BLOCK_HISTO_ATOMIC + = detail::to_BlockHistogramAlgorithm_enum(::rocprim::block_histogram_algorithm::using_atomic), + BLOCK_HISTO_SORT + = detail::to_BlockHistogramAlgorithm_enum(::rocprim::block_histogram_algorithm::using_sort) +}; + +template< + typename T, + int BLOCK_DIM_X, + int ITEMS_PER_THREAD, + int BINS, + BlockHistogramAlgorithm ALGORITHM = BLOCK_HISTO_SORT, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int ARCH = HIPCUB_ARCH /* ignored */ +> +class BlockHistogram + : private ::rocprim::block_histogram< + T, + BLOCK_DIM_X, + ITEMS_PER_THREAD, + BINS, + 
static_cast<::rocprim::block_histogram_algorithm>(ALGORITHM), + BLOCK_DIM_Y, + BLOCK_DIM_Z + > +{ + static_assert( + BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0, + "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0" + ); + + using base_type = + typename ::rocprim::block_histogram< + T, + BLOCK_DIM_X, + ITEMS_PER_THREAD, + BINS, + static_cast<::rocprim::block_histogram_algorithm>(ALGORITHM), + BLOCK_DIM_Y, + BLOCK_DIM_Z + >; + + // Reference to temporary storage (usually shared memory) + typename base_type::storage_type& temp_storage_; + +public: + using TempStorage = typename base_type::storage_type; + + HIPCUB_DEVICE inline + BlockHistogram() : temp_storage_(private_storage()) + { + } + + HIPCUB_DEVICE inline + BlockHistogram(TempStorage& temp_storage) : temp_storage_(temp_storage) + { + } + + template + HIPCUB_DEVICE inline + void InitHistogram(CounterT histogram[BINS]) + { + base_type::init_histogram(histogram); + } + + template + HIPCUB_DEVICE inline + void Composite(T (&items)[ITEMS_PER_THREAD], + CounterT histogram[BINS]) + { + base_type::composite(items, histogram, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void Histogram(T (&items)[ITEMS_PER_THREAD], + CounterT histogram[BINS]) + { + base_type::init_histogram(histogram); + CTA_SYNC(); + base_type::composite(items, histogram, temp_storage_); + } + +private: + HIPCUB_DEVICE inline + TempStorage& private_storage() + { + HIPCUB_SHARED_MEMORY TempStorage private_storage; + return private_storage; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_HISTOGRAM_HPP_ diff --git a/3rdparty/cub/block/block_load.cuh b/3rdparty/cub/block/block_load.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6c81163010139e211778f951f2eccb5882701167 --- /dev/null +++ b/3rdparty/cub/block/block_load.cuh @@ -0,0 +1,161 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. 
+ * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_LOAD_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_LOAD_HPP_ + +#include + +#include "../config.hpp" + +#include + +#include "block_load_func.cuh" + +BEGIN_HIPCUB_NAMESPACE + +namespace detail +{ + inline constexpr + typename std::underlying_type<::rocprim::block_load_method>::type + to_BlockLoadAlgorithm_enum(::rocprim::block_load_method v) + { + using utype = std::underlying_type<::rocprim::block_load_method>::type; + return static_cast(v); + } +} + +enum BlockLoadAlgorithm +{ + BLOCK_LOAD_DIRECT + = detail::to_BlockLoadAlgorithm_enum(::rocprim::block_load_method::block_load_direct), + BLOCK_LOAD_STRIPED + = detail::to_BlockLoadAlgorithm_enum(::rocprim::block_load_method::block_load_striped), + BLOCK_LOAD_VECTORIZE + = detail::to_BlockLoadAlgorithm_enum(::rocprim::block_load_method::block_load_vectorize), + BLOCK_LOAD_TRANSPOSE + = detail::to_BlockLoadAlgorithm_enum(::rocprim::block_load_method::block_load_transpose), + BLOCK_LOAD_WARP_TRANSPOSE + = detail::to_BlockLoadAlgorithm_enum(::rocprim::block_load_method::block_load_warp_transpose), + BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED + = detail::to_BlockLoadAlgorithm_enum(::rocprim::block_load_method::block_load_warp_transpose) +}; + +template< + typename T, + int BLOCK_DIM_X, + int ITEMS_PER_THREAD, + BlockLoadAlgorithm ALGORITHM = BLOCK_LOAD_DIRECT, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int ARCH = HIPCUB_ARCH /* ignored */ +> +class BlockLoad + : private ::rocprim::block_load< + T, + BLOCK_DIM_X, + ITEMS_PER_THREAD, + static_cast<::rocprim::block_load_method>(ALGORITHM), + BLOCK_DIM_Y, + BLOCK_DIM_Z + > +{ + static_assert( + BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0, + "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0" + ); + + using base_type = + typename ::rocprim::block_load< + T, + BLOCK_DIM_X, + ITEMS_PER_THREAD, + static_cast<::rocprim::block_load_method>(ALGORITHM), 
+ BLOCK_DIM_Y, + BLOCK_DIM_Z + >; + + // Reference to temporary storage (usually shared memory) + typename base_type::storage_type& temp_storage_; + +public: + using TempStorage = typename base_type::storage_type; + + HIPCUB_DEVICE inline + BlockLoad() : temp_storage_(private_storage()) + { + } + + HIPCUB_DEVICE inline + BlockLoad(TempStorage& temp_storage) : temp_storage_(temp_storage) + { + } + + template + HIPCUB_DEVICE inline + void Load(InputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD]) + { + base_type::load(block_iter, items, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void Load(InputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD], + int valid_items) + { + base_type::load(block_iter, items, valid_items, temp_storage_); + } + + template< + class InputIteratorT, + class Default + > + HIPCUB_DEVICE inline + void Load(InputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD], + int valid_items, + Default oob_default) + { + base_type::load(block_iter, items, valid_items, oob_default, temp_storage_); + } + +private: + HIPCUB_DEVICE inline + TempStorage& private_storage() + { + HIPCUB_SHARED_MEMORY TempStorage private_storage; + return private_storage; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_LOAD_HPP_ diff --git a/3rdparty/cub/block/block_load_func.cuh b/3rdparty/cub/block/block_load_func.cuh new file mode 100644 index 0000000000000000000000000000000000000000..273cf8f2dba3aca00b90b73789441a36eb022280 --- /dev/null +++ b/3rdparty/cub/block/block_load_func.cuh @@ -0,0 +1,205 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_LOAD_FUNC_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_LOAD_FUNC_HPP_ + +#include "../config.hpp" + +#include + +BEGIN_HIPCUB_NAMESPACE + +template< + typename T, + int ITEMS_PER_THREAD, + typename InputIteratorT +> +HIPCUB_DEVICE inline +void LoadDirectBlocked(int linear_id, + InputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD]) +{ + ::rocprim::block_load_direct_blocked( + linear_id, block_iter, items + ); +} + +template< + typename T, + int ITEMS_PER_THREAD, + typename InputIteratorT +> +HIPCUB_DEVICE inline +void LoadDirectBlocked(int linear_id, + InputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD], + int valid_items) +{ + ::rocprim::block_load_direct_blocked( + linear_id, block_iter, items, valid_items + ); +} + +template< + typename T, + typename Default, + int ITEMS_PER_THREAD, + typename InputIteratorT +> +HIPCUB_DEVICE inline +void LoadDirectBlocked(int linear_id, + InputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD], + int valid_items, + Default oob_default) +{ + ::rocprim::block_load_direct_blocked( + linear_id, block_iter, items, valid_items, oob_default + ); +} + +template < + typename T, + int ITEMS_PER_THREAD +> +HIPCUB_DEVICE inline +void LoadDirectBlockedVectorized(int linear_id, + T* block_iter, + T (&items)[ITEMS_PER_THREAD]) +{ + ::rocprim::block_load_direct_blocked_vectorized( + linear_id, block_iter, items + ); +} + +template< + int BLOCK_THREADS, + typename T, + int ITEMS_PER_THREAD, + typename InputIteratorT +> +HIPCUB_DEVICE inline +void LoadDirectStriped(int linear_id, + InputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD]) +{ + ::rocprim::block_load_direct_striped( + linear_id, block_iter, items + ); +} + +template< + int BLOCK_THREADS, + typename T, + int ITEMS_PER_THREAD, + typename InputIteratorT +> +HIPCUB_DEVICE inline +void LoadDirectStriped(int linear_id, + InputIteratorT block_iter, + T 
(&items)[ITEMS_PER_THREAD], + int valid_items) +{ + ::rocprim::block_load_direct_striped( + linear_id, block_iter, items, valid_items + ); +} + +template< + int BLOCK_THREADS, + typename T, + typename Default, + int ITEMS_PER_THREAD, + typename InputIteratorT +> +HIPCUB_DEVICE inline +void LoadDirectStriped(int linear_id, + InputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD], + int valid_items, + Default oob_default) +{ + ::rocprim::block_load_direct_striped( + linear_id, block_iter, items, valid_items, oob_default + ); +} + +template< + typename T, + int ITEMS_PER_THREAD, + typename InputIteratorT +> +HIPCUB_DEVICE inline +void LoadDirectWarpStriped(int linear_id, + InputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD]) +{ + ::rocprim::block_load_direct_warp_striped( + linear_id, block_iter, items + ); +} + +template< + typename T, + int ITEMS_PER_THREAD, + typename InputIteratorT +> +HIPCUB_DEVICE inline +void LoadDirectWarpStriped(int linear_id, + InputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD], + int valid_items) +{ + ::rocprim::block_load_direct_warp_striped( + linear_id, block_iter, items, valid_items + ); +} + +template< + typename T, + typename Default, + int ITEMS_PER_THREAD, + typename InputIteratorT +> +HIPCUB_DEVICE inline +void LoadDirectWarpStriped(int linear_id, + InputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD], + int valid_items, + Default oob_default) +{ + ::rocprim::block_load_direct_warp_striped( + linear_id, block_iter, items, valid_items, oob_default + ); +} + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_LOAD_FUNC_HPP_ diff --git a/3rdparty/cub/block/block_merge_sort.hpp b/3rdparty/cub/block/block_merge_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..50f101bfa6e898aa401c27919712ce103ba29f68 --- /dev/null +++ b/3rdparty/cub/block/block_merge_sort.hpp @@ -0,0 +1,808 @@ +/****************************************************************************** +* Copyright (c) 
2011-2021, NVIDIA CORPORATION. All rights reserved. +* Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* * Redistributions of source code must retain the above copyright +* notice, this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright +* notice, this list of conditions and the following disclaimer in the +* documentation and/or other materials provided with the distribution. +* * Neither the name of the NVIDIA CORPORATION nor the +* names of its contributors may be used to endorse or promote products +* derived from this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+* +******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_ + +#include "../thread/thread_sort.hpp" +#include "../util_math.cuh" +#include "../util_type.cuh" + +#include +#include + +BEGIN_HIPCUB_NAMESPACE + + +// Additional details of the Merge-Path Algorithm can be found in: +// S. Odeh, O. Green, Z. Mwassi, O. Shmueli, Y. Birk, " Merge Path - Parallel +// Merging Made Simple", Multithreaded Architectures and Applications (MTAAP) +// Workshop, IEEE 26th International Parallel & Distributed Processing +// Symposium (IPDPS), 2012 +template +HIPCUB_DEVICE __forceinline__ OffsetT MergePath(KeyIteratorT keys1, + KeyIteratorT keys2, + OffsetT keys1_count, + OffsetT keys2_count, + OffsetT diag, + BinaryPred binary_pred) +{ + OffsetT keys1_begin = diag < keys2_count ? 0 : diag - keys2_count; + OffsetT keys1_end = (::rocprim::min)(diag, keys1_count); + + while (keys1_begin < keys1_end) + { + OffsetT mid = cub::MidPoint(keys1_begin, keys1_end); + KeyT key1 = keys1[mid]; + KeyT key2 = keys2[diag - 1 - mid]; + bool pred = binary_pred(key2, key1); + + if (pred) + { + keys1_end = mid; + } + else + { + keys1_begin = mid + 1; + } + } + return keys1_begin; +} + +template +HIPCUB_DEVICE __forceinline__ void SerialMerge(KeyT *keys_shared, + int keys1_beg, + int keys2_beg, + int keys1_count, + int keys2_count, + KeyT (&output)[ITEMS_PER_THREAD], + int (&indices)[ITEMS_PER_THREAD], + CompareOp compare_op) +{ + int keys1_end = keys1_beg + keys1_count; + int keys2_end = keys2_beg + keys2_count; + + KeyT key1 = keys_shared[keys1_beg]; + KeyT key2 = keys_shared[keys2_beg]; + +#pragma unroll + for (int item = 0; item < ITEMS_PER_THREAD; ++item) + { + bool p = (keys2_beg < keys2_end) && + ((keys1_beg >= keys1_end) + || compare_op(key2, key1)); + + output[item] = p ? key2 : key1; + indices[item] = p ? 
keys2_beg++ : keys1_beg++; + + if (p) + { + key2 = keys_shared[keys2_beg]; + } + else + { + key1 = keys_shared[keys1_beg]; + } + } +} + +/** + * @brief Generalized merge sort algorithm + * + * This class is used to reduce code duplication. Warp and Block merge sort + * differ only in how they compute thread index and how they synchronize + * threads. Since synchronization might require access to custom data + * (like member mask), CRTP is used. + * + * @par + * The code snippet below illustrates the way this class can be used. + * @par + * @code + * #include // or equivalently + * + * constexpr int BLOCK_THREADS = 256; + * constexpr int ITEMS_PER_THREAD = 9; + * + * class BlockMergeSort : public BlockMergeSortStrategy + * { + * using BlockMergeSortStrategyT = + * BlockMergeSortStrategy; + * public: + * __device__ __forceinline__ explicit BlockMergeSort( + * typename BlockMergeSortStrategyT::TempStorage &temp_storage) + * : BlockMergeSortStrategyT(temp_storage, threadIdx.x) + * {} + * + * __device__ __forceinline__ void SyncImplementation() const + * { + * __syncthreads(); + * } + * }; + * @endcode + * + * @tparam KeyT + * KeyT type + * + * @tparam ValueT + * ValueT type. cub::NullType indicates a keys-only sort + * + * @tparam SynchronizationPolicy + * Provides a way of synchronizing threads. Should be derived from + * `BlockMergeSortStrategy`. 
+ */ +template +class BlockMergeSortStrategy +{ + static_assert(PowerOfTwo::VALUE, + "NUM_THREADS must be a power of two"); + +private: + + static constexpr int ITEMS_PER_TILE = ITEMS_PER_THREAD * NUM_THREADS; + + // Whether or not there are values to be trucked along with keys + static constexpr bool KEYS_ONLY = ::rocprim::Equals::VALUE; + + /// Shared memory type required by this thread block + union _TempStorage + { + KeyT keys_shared[ITEMS_PER_TILE + 1]; + ValueT items_shared[ITEMS_PER_TILE + 1]; + }; // union TempStorage + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Internal storage allocator + HIPCUB_DEVICE __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + const unsigned int linear_tid; + +public: + /// \smemstorage{BlockMergeSort} + struct TempStorage : Uninitialized<_TempStorage> {}; + + BlockMergeSortStrategy() = delete; + explicit HIPCUB_DEVICE __forceinline__ + BlockMergeSortStrategy(unsigned int linear_tid) + : temp_storage(PrivateStorage()) + , linear_tid(linear_tid) + {} + + HIPCUB_DEVICE __forceinline__ BlockMergeSortStrategy(TempStorage &temp_storage, + unsigned int linear_tid) + : temp_storage(temp_storage.Alias()) + , linear_tid(linear_tid) + {} + + HIPCUB_DEVICE __forceinline__ unsigned int get_linear_tid() const + { + return linear_tid; + } + + /** + * @brief Sorts items partitioned across a CUDA thread block using + * a merge sorting method. + * + * @par + * Sort is not guaranteed to be stable. That is, suppose that i and j are + * equivalent: neither one is less than the other. It is not guaranteed + * that the relative order of these two elements will be preserved by sort. + * + * @tparam CompareOp + * functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. + * `CompareOp` is a model of [Strict Weak Ordering]. 
+ * + * @param[in,out] keys + * Keys to sort + * + * @param[in] compare_op + * Comparison function object which returns true if the first argument is + * ordered before the second + * + * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order + */ + template + HIPCUB_DEVICE __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD], + CompareOp compare_op) + { + ValueT items[ITEMS_PER_THREAD]; + Sort(keys, items, compare_op, ITEMS_PER_TILE, keys[0]); + } + + /** + * @brief Sorts items partitioned across a CUDA thread block using + * a merge sorting method. + * + * @par + * - Sort is not guaranteed to be stable. That is, suppose that `i` and `j` + * are equivalent: neither one is less than the other. It is not guaranteed + * that the relative order of these two elements will be preserved by sort. + * - The value of `oob_default` is assigned to all elements that are out of + * `valid_items` boundaries. It's expected that `oob_default` is ordered + * after any value in the `valid_items` boundaries. The algorithm always + * sorts a fixed amount of elements, which is equal to + * `ITEMS_PER_THREAD * BLOCK_THREADS`. If there is a value that is ordered + * after `oob_default`, it won't be placed within `valid_items` boundaries. + * + * @tparam CompareOp + * functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. + * `CompareOp` is a model of [Strict Weak Ordering]. 
+ * + * @param[in,out] keys + * Keys to sort + * + * @param[in] compare_op + * Comparison function object which returns true if the first argument is + * ordered before the second + * + * @param[in] valid_items + * Number of valid items to sort + * + * @param[in] oob_default + * Default value to assign out-of-bound items + * + * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order + */ + template + HIPCUB_DEVICE __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD], + CompareOp compare_op, + int valid_items, + KeyT oob_default) + { + ValueT items[ITEMS_PER_THREAD]; + Sort(keys, items, compare_op, valid_items, oob_default); + } + + /** + * @brief Sorts items partitioned across a CUDA thread block using a merge sorting method. + * + * @par + * Sort is not guaranteed to be stable. That is, suppose that `i` and `j` are + * equivalent: neither one is less than the other. It is not guaranteed + * that the relative order of these two elements will be preserved by sort. + * + * @tparam CompareOp + * functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. + * `CompareOp` is a model of [Strict Weak Ordering]. + * + * @param[in,out] keys + * Keys to sort + * + * @param[in,out] items + * Values to sort + * + * @param[in] compare_op + * Comparison function object which returns true if the first argument is + * ordered before the second + * + * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order + */ + template + HIPCUB_DEVICE __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD], + ValueT (&items)[ITEMS_PER_THREAD], + CompareOp compare_op) + { + Sort(keys, items, compare_op, ITEMS_PER_TILE, keys[0]); + } + + /** + * @brief Sorts items partitioned across a CUDA thread block using + * a merge sorting method. + * + * @par + * - Sort is not guaranteed to be stable. That is, suppose that `i` and `j` + * are equivalent: neither one is less than the other. 
It is not guaranteed + * that the relative order of these two elements will be preserved by sort. + * - The value of `oob_default` is assigned to all elements that are out of + * `valid_items` boundaries. It's expected that `oob_default` is ordered + * after any value in the `valid_items` boundaries. The algorithm always + * sorts a fixed amount of elements, which is equal to + * `ITEMS_PER_THREAD * BLOCK_THREADS`. If there is a value that is ordered + * after `oob_default`, it won't be placed within `valid_items` boundaries. + * + * @tparam CompareOp + * functor type having member `bool operator()(KeyT lhs, KeyT rhs)` + * `CompareOp` is a model of [Strict Weak Ordering]. + * + * @tparam IS_LAST_TILE + * True if `valid_items` isn't equal to the `ITEMS_PER_TILE` + * + * @param[in,out] keys + * Keys to sort + * + * @param[in,out] items + * Values to sort + * + * @param[in] compare_op + * Comparison function object which returns true if the first argument is + * ordered before the second + * + * @param[in] valid_items + * Number of valid items to sort + * + * @param[in] oob_default + * Default value to assign out-of-bound items + * + * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order + */ + template + HIPCUB_DEVICE __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD], + ValueT (&items)[ITEMS_PER_THREAD], + CompareOp compare_op, + int valid_items, + KeyT oob_default) + { + if (IS_LAST_TILE) + { + // if last tile, find valid max_key + // and fill the remaining keys with it + // + KeyT max_key = oob_default; + + #pragma unroll + for (int item = 1; item < ITEMS_PER_THREAD; ++item) + { + if (ITEMS_PER_THREAD * static_cast(linear_tid) + item < valid_items) + { + max_key = compare_op(max_key, keys[item]) ? 
keys[item] : max_key; + } + else + { + keys[item] = max_key; + } + } + } + + // if first element of thread is in input range, stable sort items + // + if (!IS_LAST_TILE || ITEMS_PER_THREAD * static_cast(linear_tid) < valid_items) + { + StableOddEvenSort(keys, items, compare_op); + } + + // each thread has sorted keys + // merge sort keys in shared memory + // + #pragma unroll + for (int target_merged_threads_number = 2; + target_merged_threads_number <= NUM_THREADS; + target_merged_threads_number *= 2) + { + int merged_threads_number = target_merged_threads_number / 2; + int mask = target_merged_threads_number - 1; + + Sync(); + + // store keys in shmem + // + #pragma unroll + for (int item = 0; item < ITEMS_PER_THREAD; ++item) + { + int idx = ITEMS_PER_THREAD * linear_tid + item; + temp_storage.keys_shared[idx] = keys[item]; + } + + Sync(); + + int indices[ITEMS_PER_THREAD]; + + int first_thread_idx_in_thread_group_being_merged = ~mask & linear_tid; + int start = ITEMS_PER_THREAD * first_thread_idx_in_thread_group_being_merged; + int size = ITEMS_PER_THREAD * merged_threads_number; + + int thread_idx_in_thread_group_being_merged = mask & linear_tid; + + int diag = + (::rocprim::min)(valid_items, + ITEMS_PER_THREAD * thread_idx_in_thread_group_being_merged); + + int keys1_beg = (::rocprim::min)(valid_items, start); + int keys1_end = (::rocprim::min)(valid_items, keys1_beg + size); + int keys2_beg = keys1_end; + int keys2_end = (::rocprim::min)(valid_items, keys2_beg + size); + + int keys1_count = keys1_end - keys1_beg; + int keys2_count = keys2_end - keys2_beg; + + int partition_diag = MergePath(&temp_storage.keys_shared[keys1_beg], + &temp_storage.keys_shared[keys2_beg], + keys1_count, + keys2_count, + diag, + compare_op); + + int keys1_beg_loc = keys1_beg + partition_diag; + int keys1_end_loc = keys1_end; + int keys2_beg_loc = keys2_beg + diag - partition_diag; + int keys2_end_loc = keys2_end; + int keys1_count_loc = keys1_end_loc - keys1_beg_loc; + int 
keys2_count_loc = keys2_end_loc - keys2_beg_loc; + SerialMerge(&temp_storage.keys_shared[0], + keys1_beg_loc, + keys2_beg_loc, + keys1_count_loc, + keys2_count_loc, + keys, + indices, + compare_op); + + if (!KEYS_ONLY) + { + Sync(); + + // store keys in shmem + // + #pragma unroll + for (int item = 0; item < ITEMS_PER_THREAD; ++item) + { + int idx = ITEMS_PER_THREAD * linear_tid + item; + temp_storage.items_shared[idx] = items[item]; + } + + Sync(); + + // gather items from shmem + // + #pragma unroll + for (int item = 0; item < ITEMS_PER_THREAD; ++item) + { + items[item] = temp_storage.items_shared[indices[item]]; + } + } + } + } // func block_merge_sort + + /** + * @brief Sorts items partitioned across a CUDA thread block using + * a merge sorting method. + * + * @par + * StableSort is stable: it preserves the relative ordering of equivalent + * elements. That is, if `x` and `y` are elements such that `x` precedes `y`, + * and if the two elements are equivalent (neither `x < y` nor `y < x`) then + * a postcondition of StableSort is that `x` still precedes `y`. + * + * @tparam CompareOp + * functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. + * `CompareOp` is a model of [Strict Weak Ordering]. + * + * @param[in,out] keys + * Keys to sort + * + * @param[in] compare_op + * Comparison function object which returns true if the first argument is + * ordered before the second + * + * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order + */ + template + HIPCUB_DEVICE __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD], + CompareOp compare_op) + { + Sort(keys, compare_op); + } + + /** + * @brief Sorts items partitioned across a CUDA thread block using + * a merge sorting method. + * + * @par + * StableSort is stable: it preserves the relative ordering of equivalent + * elements. 
That is, if `x` and `y` are elements such that `x` precedes `y`, + * and if the two elements are equivalent (neither `x < y` nor `y < x`) then + * a postcondition of StableSort is that `x` still precedes `y`. + * + * @tparam CompareOp + * functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. + * `CompareOp` is a model of [Strict Weak Ordering]. + * + * @param[in,out] keys + * Keys to sort + * + * @param[in,out] items + * Values to sort + * + * @param[in] compare_op + * Comparison function object which returns true if the first argument is + * ordered before the second + * + * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order + */ + template + HIPCUB_DEVICE __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD], + ValueT (&items)[ITEMS_PER_THREAD], + CompareOp compare_op) + { + Sort(keys, items, compare_op); + } + + /** + * @brief Sorts items partitioned across a CUDA thread block using + * a merge sorting method. + * + * @par + * - StableSort is stable: it preserves the relative ordering of equivalent + * elements. That is, if `x` and `y` are elements such that `x` precedes + * `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`) + * then a postcondition of StableSort is that `x` still precedes `y`. + * - The value of `oob_default` is assigned to all elements that are out of + * `valid_items` boundaries. It's expected that `oob_default` is ordered + * after any value in the `valid_items` boundaries. The algorithm always + * sorts a fixed amount of elements, which is equal to + * `ITEMS_PER_THREAD * BLOCK_THREADS`. + * If there is a value that is ordered after `oob_default`, it won't be + * placed within `valid_items` boundaries. + * + * @tparam CompareOp + * functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. + * `CompareOp` is a model of [Strict Weak Ordering]. 
+ * + * @param[in,out] keys + * Keys to sort + * + * @param[in] compare_op + * Comparison function object which returns true if the first argument is + * ordered before the second + * + * @param[in] valid_items + * Number of valid items to sort + * + * @param[in] oob_default + * Default value to assign out-of-bound items + * + * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order + */ + template + HIPCUB_DEVICE __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD], + CompareOp compare_op, + int valid_items, + KeyT oob_default) + { + Sort(keys, compare_op, valid_items, oob_default); + } + + /** + * @brief Sorts items partitioned across a CUDA thread block using + * a merge sorting method. + * + * @par + * - StableSort is stable: it preserves the relative ordering of equivalent + * elements. That is, if `x` and `y` are elements such that `x` precedes + * `y`, and if the two elements are equivalent (neither `x < y` nor `y < x`) + * then a postcondition of StableSort is that `x` still precedes `y`. + * - The value of `oob_default` is assigned to all elements that are out of + * `valid_items` boundaries. It's expected that `oob_default` is ordered + * after any value in the `valid_items` boundaries. The algorithm always + * sorts a fixed amount of elements, which is equal to + * `ITEMS_PER_THREAD * BLOCK_THREADS`. If there is a value that is ordered + * after `oob_default`, it won't be placed within `valid_items` boundaries. + * + * @tparam CompareOp + * functor type having member `bool operator()(KeyT lhs, KeyT rhs)`. + * `CompareOp` is a model of [Strict Weak Ordering]. 
+ * + * @tparam IS_LAST_TILE + * True if `valid_items` isn't equal to the `ITEMS_PER_TILE` + * + * @param[in,out] keys + * Keys to sort + * + * @param[in,out] items + * Values to sort + * + * @param[in] compare_op + * Comparison function object which returns true if the first argument is + * ordered before the second + * + * @param[in] valid_items + * Number of valid items to sort + * + * @param[in] oob_default + * Default value to assign out-of-bound items + * + * [Strict Weak Ordering]: https://en.cppreference.com/w/cpp/concepts/strict_weak_order + */ + template + HIPCUB_DEVICE __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD], + ValueT (&items)[ITEMS_PER_THREAD], + CompareOp compare_op, + int valid_items, + KeyT oob_default) + { + Sort(keys, + items, + compare_op, + valid_items, + oob_default); + } + +private: + HIPCUB_DEVICE __forceinline__ void Sync() const + { + static_cast(this)->SyncImplementation(); + } +}; + + +/** + * @brief The BlockMergeSort class provides methods for sorting items + * partitioned across a CUDA thread block using a merge sorting method. + * @ingroup BlockModule + * + * @tparam KeyT + * KeyT type + * + * @tparam BLOCK_DIM_X + * The thread block length in threads along the X dimension + * + * @tparam ITEMS_PER_THREAD + * The number of items per thread + * + * @tparam ValueT + * **[optional]** ValueT type (default: `cub::NullType`, which indicates + * a keys-only sort) + * + * @tparam BLOCK_DIM_Y + * **[optional]** The thread block length in threads along the Y dimension + * (default: 1) + * + * @tparam BLOCK_DIM_Z + * **[optional]** The thread block length in threads along the Z dimension + * (default: 1) + * + * @par Overview + * BlockMergeSort arranges items into ascending order using a comparison + * functor with less-than semantics. Merge sort can handle arbitrary types + * and comparison functors, but is slower than BlockRadixSort when sorting + * arithmetic types into ascending/descending order. 
+ * + * @par A Simple Example + * @blockcollective{BlockMergeSort} + * @par + * The code snippet below illustrates a sort of 512 integer keys that are + * partitioned across 128 threads * where each thread owns 4 consecutive items. + * @par + * @code + * #include // or equivalently + * + * struct CustomLess + * { + * template + * __device__ bool operator()(const DataType &lhs, const DataType &rhs) + * { + * return lhs < rhs; + * } + * }; + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockMergeSort for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockMergeSort BlockMergeSort; + * + * // Allocate shared memory for BlockMergeSort + * __shared__ typename BlockMergeSort::TempStorage temp_storage_shuffle; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_keys[4]; + * ... + * + * BlockMergeSort(temp_storage_shuffle).Sort(thread_keys, CustomLess()); + * ... + * } + * @endcode + * @par + * Suppose the set of input `thread_keys` across the block of threads is + * `{ [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }`. + * The corresponding output `thread_keys` in those threads will be + * `{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }`. + * + * @par Re-using dynamically allocating shared memory + * The following example under the examples/block folder illustrates usage of + * dynamically shared memory with BlockReduce and how to re-purpose + * the same memory region: + * example_block_reduce_dyn_smem.cu + * + * This example can be easily adapted to the storage required by BlockMergeSort. 
+ */ +template +class BlockMergeSort + : public BlockMergeSortStrategy> +{ +private: + // The thread block size in threads + static constexpr int BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z; + static constexpr int ITEMS_PER_TILE = ITEMS_PER_THREAD * BLOCK_THREADS; + + using BlockMergeSortStrategyT = + BlockMergeSortStrategy; + +public: + HIPCUB_DEVICE __forceinline__ BlockMergeSort() + : BlockMergeSortStrategyT( + RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + HIPCUB_DEVICE __forceinline__ explicit BlockMergeSort( + typename BlockMergeSortStrategyT::TempStorage &temp_storage) + : BlockMergeSortStrategyT( + temp_storage, + RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + +private: + HIPCUB_DEVICE __forceinline__ void SyncImplementation() const + { + CTA_SYNC(); + } + + friend BlockMergeSortStrategyT; +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_ diff --git a/3rdparty/cub/block/block_radix_rank.cuh b/3rdparty/cub/block/block_radix_rank.cuh new file mode 100644 index 0000000000000000000000000000000000000000..99001958d045f10593aaf8a79d7a6298bd7a0c7a --- /dev/null +++ b/3rdparty/cub/block/block_radix_rank.cuh @@ -0,0 +1,703 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block + */ + + #ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_ + #define HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_ + +#include + +#include "../config.hpp" +#include "../util_type.cuh" +#include "../util_ptx.cuh" + +#include "../thread/thread_reduce.cuh" +#include "../thread/thread_scan.cuh" +#include "../block/block_scan.cuh" +#include "../block/radix_rank_sort_operations.hpp" + +BEGIN_HIPCUB_NAMESPACE + + + +/** + * \brief BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block. 
+ * \ingroup BlockModule
+ *
+ * \tparam BLOCK_DIM_X  The thread block length in threads along the X dimension
+ * \tparam RADIX_BITS   The number of radix bits per digit place
+ * \tparam IS_DESCENDING    Whether or not the sorted-order is high-to-low
+ * \tparam MEMOIZE_OUTER_SCAN   [optional] Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure (default: true for architectures SM35 and newer, false otherwise).  See BlockScanAlgorithm::BLOCK_SCAN_RAKING_MEMOIZE for more details.
+ * \tparam INNER_SCAN_ALGORITHM    [optional] The cub::BlockScanAlgorithm algorithm to use (default: cub::BLOCK_SCAN_WARP_SCANS)
+ * \tparam SMEM_CONFIG  [optional] Shared memory bank mode (default: \p cudaSharedMemBankSizeFourByte)
+ * \tparam BLOCK_DIM_Y  [optional] The thread block length in threads along the Y dimension (default: 1)
+ * \tparam BLOCK_DIM_Z  [optional] The thread block length in threads along the Z dimension (default: 1)
+ * \tparam ARCH     [optional] \ptxversion
+ *
+ * \par Overview
+ * BlockRadixRank computes, for each key in the tile, the number of keys that
+ * precede it when the tile is ordered by the digit at the specified digit place.
+ * - Keys must be in a form suitable for radix ranking (i.e., unsigned bits).
+ * - \blocked
+ *
+ * \par Performance Considerations
+ * - \granularity
+ *
+ * \par Examples
+ * \par
+ * - Example 1:  Simple radix rank of 32-bit integer keys
+ *      \code
+ *      #include <hipcub/hipcub.hpp>
+ *
+ *      template <int BLOCK_THREADS>
+ *      __global__ void ExampleKernel(...)
+ *      {
+ *
+ *      \endcode
+ */
+template <
+    int                     BLOCK_DIM_X,
+    int                     RADIX_BITS,
+    bool                    IS_DESCENDING,
+    bool                    MEMOIZE_OUTER_SCAN      = false,
+    BlockScanAlgorithm      INNER_SCAN_ALGORITHM    = BLOCK_SCAN_WARP_SCANS,
+    cudaSharedMemConfig     SMEM_CONFIG             = cudaSharedMemBankSizeFourByte,
+    int                     BLOCK_DIM_Y             = 1,
+    int                     BLOCK_DIM_Z             = 1,
+    int                     ARCH                    = HIPCUB_ARCH /* ignored */>
+class BlockRadixRank
+{
+private:
+
+    /******************************************************************************
+     * Type definitions and constants
+     ******************************************************************************/
+
+    // Integer type for digit counters (to be packed into words of type PackedCounters)
+    typedef unsigned short DigitCounter;
+
+    // Integer type for packing DigitCounters into columns of shared memory banks
+    typedef typename std::conditional<(SMEM_CONFIG == cudaSharedMemBankSizeEightByte),
+        unsigned long long,
+        unsigned int>::type PackedCounter;
+
+    enum
+    {
+        // The thread block size in threads
+        BLOCK_THREADS               = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
+
+        RADIX_DIGITS                = 1 << RADIX_BITS,
+
+        // NOTE(review): the Log2<> / maximum<> template arguments below were stripped
+        // from this patch; restored here — confirm against upstream hipCUB.
+        LOG_WARP_THREADS            = Log2<HIPCUB_DEVICE_WARP_THREADS>::VALUE,
+        WARP_THREADS                = 1 << LOG_WARP_THREADS,
+        WARPS                       = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,
+
+        BYTES_PER_COUNTER           = sizeof(DigitCounter),
+        LOG_BYTES_PER_COUNTER       = Log2<BYTES_PER_COUNTER>::VALUE,
+
+        PACKING_RATIO               = sizeof(PackedCounter) / sizeof(DigitCounter),
+        LOG_PACKING_RATIO           = Log2<PACKING_RATIO>::VALUE,
+
+        LOG_COUNTER_LANES           = rocprim::maximum<int>()((int(RADIX_BITS) - int(LOG_PACKING_RATIO)), 0),                // Always at least one lane
+        COUNTER_LANES               = 1 << LOG_COUNTER_LANES,
+
+        // The number of packed counters per thread (plus one for padding)
+        PADDED_COUNTER_LANES        = COUNTER_LANES + 1,
+        RAKING_SEGMENT              = PADDED_COUNTER_LANES,
+    };
+
+public:
+
+    enum
+    {
+        /// Number of bin-starting offsets tracked per thread
+        BINS_TRACKED_PER_THREAD = rocprim::maximum<int>()(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
+    };
+
+private:
+
+
+    /// BlockScan type
+    typedef BlockScan<
+        PackedCounter,
BLOCK_DIM_X, + INNER_SCAN_ALGORITHM, + BLOCK_DIM_Y, + BLOCK_DIM_Z, + ARCH> + BlockScan; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + /// Shared memory storage layout type for BlockRadixRank + struct __align__(16) _TempStorage + { + union Aliasable + { + DigitCounter digit_counters[PADDED_COUNTER_LANES * BLOCK_THREADS * PACKING_RATIO]; + PackedCounter raking_grid[BLOCK_THREADS * RAKING_SEGMENT]; + + } aliasable; + + // Storage for scanning local ranks + typename BlockScan::TempStorage block_scan; + }; + +#endif + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + unsigned int linear_tid; + + /// Copy of raking segment, promoted to registers + PackedCounter cached_segment[RAKING_SEGMENT]; + + + /****************************************************************************** + * Utility methods + ******************************************************************************/ + + /** + * Internal storage allocator + */ + HIPCUB_DEVICE inline _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + + /** + * Performs upsweep raking reduction, returning the aggregate + */ + HIPCUB_DEVICE inline PackedCounter Upsweep() + { + PackedCounter *smem_raking_ptr = &temp_storage.aliasable.raking_grid[linear_tid * RAKING_SEGMENT]; + PackedCounter *raking_ptr; + + if (MEMOIZE_OUTER_SCAN) + { + // Copy data into registers + #pragma unroll + for (int i = 0; i < RAKING_SEGMENT; i++) + { + cached_segment[i] = smem_raking_ptr[i]; + } + raking_ptr = cached_segment; + } + else + { + raking_ptr = smem_raking_ptr; + } + + return internal::ThreadReduce(raking_ptr, Sum()); + } + + + /// Performs exclusive downsweep raking scan + HIPCUB_DEVICE inline void ExclusiveDownsweep( + PackedCounter raking_partial) + { 
+ PackedCounter *smem_raking_ptr = &temp_storage.aliasable.raking_grid[linear_tid * RAKING_SEGMENT]; + + PackedCounter *raking_ptr = (MEMOIZE_OUTER_SCAN) ? + cached_segment : + smem_raking_ptr; + + // Exclusive raking downsweep scan + internal::ThreadScanExclusive(raking_ptr, raking_ptr, Sum(), raking_partial); + + if (MEMOIZE_OUTER_SCAN) + { + // Copy data back to smem + #pragma unroll + for (int i = 0; i < RAKING_SEGMENT; i++) + { + smem_raking_ptr[i] = cached_segment[i]; + } + } + } + + + /** + * Reset shared memory digit counters + */ + HIPCUB_DEVICE inline void ResetCounters() + { + // Reset shared memory digit counters + #pragma unroll + for (int LANE = 0; LANE < PADDED_COUNTER_LANES; LANE++) + { + #pragma unroll + for (int SUB_COUNTER = 0; SUB_COUNTER < PACKING_RATIO; SUB_COUNTER++) + { + temp_storage.aliasable.digit_counters[(LANE * BLOCK_THREADS + linear_tid) * PACKING_RATIO + SUB_COUNTER] = 0; + } + } + } + + + /** + * Block-scan prefix callback + */ + struct PrefixCallBack + { + HIPCUB_DEVICE inline PackedCounter operator()(PackedCounter block_aggregate) + { + PackedCounter block_prefix = 0; + + // Propagate totals in packed fields + #pragma unroll + for (int PACKED = 1; PACKED < PACKING_RATIO; PACKED++) + { + block_prefix += block_aggregate << (sizeof(DigitCounter) * 8 * PACKED); + } + + return block_prefix; + } + }; + + + /** + * Scan shared memory digit counters. 
+ */ + HIPCUB_DEVICE inline void ScanCounters() + { + // Upsweep scan + PackedCounter raking_partial = Upsweep(); + + // Compute exclusive sum + PackedCounter exclusive_partial; + PrefixCallBack prefix_call_back; + BlockScan(temp_storage.block_scan).ExclusiveSum(raking_partial, exclusive_partial, prefix_call_back); + + // Downsweep scan with exclusive partial + ExclusiveDownsweep(exclusive_partial); + } + +public: + + /// \smemstorage{BlockScan} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + /** + * \brief Collective constructor using a private static allocation of shared memory as temporary storage. + */ + HIPCUB_DEVICE inline BlockRadixRank() + : + temp_storage(PrivateStorage()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. + */ + HIPCUB_DEVICE inline BlockRadixRank( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + //@} end member group + /******************************************************************//** + * \name Raking + *********************************************************************/ + //@{ + + /** + * \brief Rank keys. 
+ */ + template < + typename UnsignedBits, + int KEYS_PER_THREAD, + typename DigitExtractorT> + HIPCUB_DEVICE inline void RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile + int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile + DigitExtractorT digit_extractor) ///< [in] The digit extractor + { + DigitCounter thread_prefixes[KEYS_PER_THREAD]; // For each key, the count of previous keys in this tile having the same digit + DigitCounter* digit_counters[KEYS_PER_THREAD]; // For each key, the byte-offset of its corresponding digit counter in smem + + // Reset shared memory digit counters + ResetCounters(); + + #pragma unroll + for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM) + { + // Get digit + unsigned int digit = digit_extractor.Digit(keys[ITEM]); + + // Get sub-counter + unsigned int sub_counter = digit >> LOG_COUNTER_LANES; + + // Get counter lane + unsigned int counter_lane = digit & (COUNTER_LANES - 1); + + if (IS_DESCENDING) + { + sub_counter = PACKING_RATIO - 1 - sub_counter; + counter_lane = COUNTER_LANES - 1 - counter_lane; + } + + // Pointer to smem digit counter + digit_counters[ITEM] = &temp_storage.aliasable.digit_counters[counter_lane * BLOCK_THREADS * PACKING_RATIO + linear_tid * PACKING_RATIO + sub_counter]; + + // Load thread-exclusive prefix + thread_prefixes[ITEM] = *digit_counters[ITEM]; + + // Store inclusive prefix + *digit_counters[ITEM] = thread_prefixes[ITEM] + 1; + } + + ::rocprim::syncthreads(); + + // Scan shared memory counters + ScanCounters(); + + ::rocprim::syncthreads(); + + // Extract the local ranks of each key + #pragma unroll + for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM) + { + // Add in thread block exclusive prefix + ranks[ITEM] = thread_prefixes[ITEM] + *digit_counters[ITEM]; + } + } + + + /** + * \brief Rank keys. For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread. 
+ */ + template < + typename UnsignedBits, + int KEYS_PER_THREAD, + typename DigitExtractorT> + HIPCUB_DEVICE inline void RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile + int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile (out parameter) + DigitExtractorT digit_extractor, ///< [in] The digit extractor + int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD]) ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1] + { + // Rank keys + RankKeys(keys, ranks, digit_extractor); + + // Get the inclusive and exclusive digit totals corresponding to the calling thread. + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track; + + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + { + if (IS_DESCENDING) + bin_idx = RADIX_DIGITS - bin_idx - 1; + + // Obtain ex/inclusive digit counts. (Unfortunately these all reside in the + // first counter column, resulting in unavoidable bank conflicts.) 
+                unsigned int counter_lane = (bin_idx & (COUNTER_LANES - 1));
+                unsigned int sub_counter = bin_idx >> (LOG_COUNTER_LANES);
+
+                // FIX: the shared-memory member is `digit_counters` (declared in
+                // _TempStorage::Aliasable); `digit_counter` did not compile.
+                // Index [counter_lane][thread 0][sub_counter] in flattened form.
+                exclusive_digit_prefix[track] = temp_storage.aliasable.digit_counters[counter_lane * BLOCK_THREADS * PACKING_RATIO + sub_counter];
+            }
+        }
+    }
+};
+
+
+
+
+
+/**
+ * Radix-rank using match.any
+ */
+template <
+    int                     BLOCK_DIM_X,
+    int                     RADIX_BITS,
+    bool                    IS_DESCENDING,
+    BlockScanAlgorithm      INNER_SCAN_ALGORITHM    = BLOCK_SCAN_WARP_SCANS,
+    int                     BLOCK_DIM_Y             = 1,
+    int                     BLOCK_DIM_Z             = 1,
+    int                     ARCH                    = HIPCUB_ARCH>
+class BlockRadixRankMatch
+{
+private:
+
+    /******************************************************************************
+     * Type definitions and constants
+     ******************************************************************************/
+
+    typedef int32_t    RankT;
+    typedef int32_t    DigitCounterT;
+
+    enum
+    {
+        // The thread block size in threads
+        BLOCK_THREADS               = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
+
+        RADIX_DIGITS                = 1 << RADIX_BITS,
+
+        // NOTE(review): Log2<> template argument was stripped from this patch;
+        // restored here — confirm against upstream hipCUB.
+        LOG_WARP_THREADS            = Log2<HIPCUB_DEVICE_WARP_THREADS>::VALUE,
+        WARP_THREADS                = 1 << LOG_WARP_THREADS,
+        WARPS                       = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,
+
+        // Pad to an odd number of warps so per-digit counter rows land in different banks
+        PADDED_WARPS            = ((WARPS & 0x1) == 0) ?
+                                    WARPS + 1 :
+                                    WARPS,
+
+        COUNTERS                = PADDED_WARPS * RADIX_DIGITS,
+        RAKING_SEGMENT          = (COUNTERS + BLOCK_THREADS - 1) / BLOCK_THREADS,
+        PADDED_RAKING_SEGMENT   = ((RAKING_SEGMENT & 0x1) == 0) ?
+ RAKING_SEGMENT + 1 : + RAKING_SEGMENT, + }; + +public: + + enum + { + /// Number of bin-starting offsets tracked per thread + BINS_TRACKED_PER_THREAD = rocprim::maximum()(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS), + }; + +private: + + /// BlockScan type + typedef BlockScan< + DigitCounterT, + BLOCK_THREADS, + INNER_SCAN_ALGORITHM, + BLOCK_DIM_Y, + BLOCK_DIM_Z, + ARCH> + BlockScanT; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + /// Shared memory storage layout type for BlockRadixRank + struct __align__(16) _TempStorage + { + typename BlockScanT::TempStorage block_scan; + + union __align__(16) Aliasable + { + volatile DigitCounterT warp_digit_counters[RADIX_DIGITS * PADDED_WARPS]; + DigitCounterT raking_grid[BLOCK_THREADS * PADDED_RAKING_SEGMENT]; + + } aliasable; + }; +#endif + + /****************************************************************************** + * Thread fields + ******************************************************************************/ + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + unsigned int linear_tid; + + + +public: + + /// \smemstorage{BlockScan} + struct TempStorage : Uninitialized<_TempStorage> {}; + + + /******************************************************************//** + * \name Collective constructors + *********************************************************************/ + //@{ + + + /** + * \brief Collective constructor using the specified memory allocation as temporary storage. + */ + HIPCUB_DEVICE inline BlockRadixRankMatch( + TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : + temp_storage(temp_storage.Alias()), + linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + + //@} end member group + /******************************************************************//** + * \name Raking + *********************************************************************/ + //@{ + + /** + * \brief Rank keys. 
+ */ + template < + typename UnsignedBits, + int KEYS_PER_THREAD, + typename DigitExtractorT> + __device__ __forceinline__ void RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile + int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile + DigitExtractorT digit_extractor) ///< [in] The digit extractor + { + // Initialize shared digit counters + + #pragma unroll + for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM) + temp_storage.aliasable.raking_grid[linear_tid * PADDED_RAKING_SEGMENT + ITEM] = 0; + + ::rocprim::syncthreads(); + + // Each warp will strip-mine its section of input, one strip at a time + + volatile DigitCounterT *digit_counters[KEYS_PER_THREAD]; + uint32_t warp_id = linear_tid >> LOG_WARP_THREADS; + uint32_t lane_mask_lt = LaneMaskLt(); + + #pragma unroll + for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM) + { + // My digit + uint32_t digit = digit_extractor.Digit(keys[ITEM]); + + if (IS_DESCENDING) + digit = RADIX_DIGITS - digit - 1; + + // Mask of peers who have same digit as me + uint32_t peer_mask = rocprim::MatchAny(digit); + + // Pointer to smem digit counter for this key + digit_counters[ITEM] = &temp_storage.aliasable.warp_digit_counters[digit * PADDED_WARPS + warp_id]; + + // Number of occurrences in previous strips + DigitCounterT warp_digit_prefix = *digit_counters[ITEM]; + + // Warp-sync + WARP_SYNC(0xFFFFFFFF); + + // Number of peers having same digit as me + int32_t digit_count = __popc(peer_mask); + + // Number of lower-ranked peers having same digit seen so far + int32_t peer_digit_prefix = __popc(peer_mask & lane_mask_lt); + + if (peer_digit_prefix == 0) + { + // First thread for each digit updates the shared warp counter + *digit_counters[ITEM] = DigitCounterT(warp_digit_prefix + digit_count); + } + + // Warp-sync + WARP_SYNC(0xFFFFFFFF); + + // Number of prior keys having same digit + ranks[ITEM] = warp_digit_prefix + DigitCounterT(peer_digit_prefix); + } + + 
::rocprim::syncthreads(); + + // Scan warp counters + + DigitCounterT scan_counters[PADDED_RAKING_SEGMENT]; + + #pragma unroll + for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM) + scan_counters[ITEM] = temp_storage.aliasable.raking_grid[linear_tid * PADDED_RAKING_SEGMENT + ITEM]; + + BlockScanT(temp_storage.block_scan).ExclusiveSum(scan_counters, scan_counters); + + #pragma unroll + for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM) + temp_storage.aliasable.raking_grid[linear_tid * PADDED_RAKING_SEGMENT + ITEM] = scan_counters[ITEM]; + + ::rocprim::syncthreads(); + + // Seed ranks with counter values from previous warps + #pragma unroll + for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM) + ranks[ITEM] += *digit_counters[ITEM]; + } + + + /** + * \brief Rank keys. For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread. + */ + template < + typename UnsignedBits, + int KEYS_PER_THREAD, + typename DigitExtractorT> + __device__ __forceinline__ void RankKeys( + UnsignedBits (&keys)[KEYS_PER_THREAD], ///< [in] Keys for this tile + int (&ranks)[KEYS_PER_THREAD], ///< [out] For each key, the local rank within the tile (out parameter) + DigitExtractorT digit_extractor, ///< [in] The digit extractor + int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD]) ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... 
(threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1] + { + RankKeys(keys, ranks, digit_extractor); + + // Get exclusive count for each digit + #pragma unroll + for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) + { + int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track; + + if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) + { + if (IS_DESCENDING) + bin_idx = RADIX_DIGITS - bin_idx - 1; + + exclusive_digit_prefix[track] = temp_storage.aliasable.warp_digit_counters[bin_idx * PADDED_WARPS]; + } + } + } +}; + + + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_ diff --git a/3rdparty/cub/block/block_radix_sort.cuh b/3rdparty/cub/block/block_radix_sort.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e58257068517d13b5968bf6ca3483e05b28d28fd --- /dev/null +++ b/3rdparty/cub/block/block_radix_sort.cuh @@ -0,0 +1,177 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_ + +#include "../config.hpp" + +#include "../util_type.cuh" + +#include +#include + +#include "block_scan.cuh" + +BEGIN_HIPCUB_NAMESPACE + +template< + typename KeyT, + int BLOCK_DIM_X, + int ITEMS_PER_THREAD, + typename ValueT = NullType, + int RADIX_BITS = 4, /* ignored */ + bool MEMOIZE_OUTER_SCAN = true, /* ignored */ + BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS, /* ignored */ + cudaSharedMemConfig SMEM_CONFIG = cudaSharedMemBankSizeFourByte, /* ignored */ + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int PTX_ARCH = HIPCUB_ARCH /* ignored */ +> +class BlockRadixSort + : private ::rocprim::block_radix_sort< + KeyT, + BLOCK_DIM_X, + ITEMS_PER_THREAD, + ValueT, + BLOCK_DIM_Y, + BLOCK_DIM_Z + > +{ + static_assert( + BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0, + "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0" + ); + + using base_type = + typename ::rocprim::block_radix_sort< + KeyT, + BLOCK_DIM_X, + ITEMS_PER_THREAD, + ValueT, + BLOCK_DIM_Y, + BLOCK_DIM_Z + >; + + // 
Reference to temporary storage (usually shared memory) + typename base_type::storage_type& temp_storage_; + +public: + using TempStorage = typename base_type::storage_type; + + HIPCUB_DEVICE inline + BlockRadixSort() : temp_storage_(private_storage()) + { + } + + HIPCUB_DEVICE inline + BlockRadixSort(TempStorage& temp_storage) : temp_storage_(temp_storage) + { + } + + HIPCUB_DEVICE inline + void Sort(KeyT (&keys)[ITEMS_PER_THREAD], + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8) + { + base_type::sort(keys, temp_storage_, begin_bit, end_bit); + } + + HIPCUB_DEVICE inline + void Sort(KeyT (&keys)[ITEMS_PER_THREAD], + ValueT (&values)[ITEMS_PER_THREAD], + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8) + { + base_type::sort(keys, values, temp_storage_, begin_bit, end_bit); + } + + HIPCUB_DEVICE inline + void SortDescending(KeyT (&keys)[ITEMS_PER_THREAD], + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8) + { + base_type::sort_desc(keys, temp_storage_, begin_bit, end_bit); + } + + HIPCUB_DEVICE inline + void SortDescending(KeyT (&keys)[ITEMS_PER_THREAD], + ValueT (&values)[ITEMS_PER_THREAD], + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8) + { + base_type::sort_desc(keys, values, temp_storage_, begin_bit, end_bit); + } + + HIPCUB_DEVICE inline + void SortBlockedToStriped(KeyT (&keys)[ITEMS_PER_THREAD], + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8) + { + base_type::sort_to_striped(keys, temp_storage_, begin_bit, end_bit); + } + + HIPCUB_DEVICE inline + void SortBlockedToStriped(KeyT (&keys)[ITEMS_PER_THREAD], + ValueT (&values)[ITEMS_PER_THREAD], + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8) + { + base_type::sort_to_striped(keys, values, temp_storage_, begin_bit, end_bit); + } + + HIPCUB_DEVICE inline + void SortDescendingBlockedToStriped(KeyT (&keys)[ITEMS_PER_THREAD], + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8) + { + base_type::sort_desc_to_striped(keys, temp_storage_, begin_bit, end_bit); + } + + HIPCUB_DEVICE inline 
+ void SortDescendingBlockedToStriped(KeyT (&keys)[ITEMS_PER_THREAD], + ValueT (&values)[ITEMS_PER_THREAD], + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8) + { + base_type::sort_desc_to_striped(keys, values, temp_storage_, begin_bit, end_bit); + } + +private: + HIPCUB_DEVICE inline + TempStorage& private_storage() + { + HIPCUB_SHARED_MEMORY TempStorage private_storage; + return private_storage; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_ diff --git a/3rdparty/cub/block/block_raking_layout.cuh b/3rdparty/cub/block/block_raking_layout.cuh new file mode 100644 index 0000000000000000000000000000000000000000..40892b669078ac9f040239a074aea10959978c5b --- /dev/null +++ b/3rdparty/cub/block/block_raking_layout.cuh @@ -0,0 +1,145 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::BlockRakingLayout provides a conflict-free shared memory layout abstraction for warp-raking across thread block data. + */ + + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_RAKING_LAYOUT_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_RAKING_LAYOUT_HPP_ + +#include + +#include "../config.hpp" + +#include +#include + +BEGIN_HIPCUB_NAMESPACE + +/** + * \brief BlockRakingLayout provides a conflict-free shared memory layout abstraction for 1D raking across thread block data. + * \ingroup BlockModule + * + * \par Overview + * This type facilitates a shared memory usage pattern where a block of CUDA + * threads places elements into shared memory and then reduces the active + * parallelism to one "raking" warp of threads for serially aggregating consecutive + * sequences of shared items. Padding is inserted to eliminate bank conflicts + * (for most data types). + * + * \tparam T The data type to be exchanged. + * \tparam BLOCK_THREADS The thread block size in threads. 
+ * \tparam PTX_ARCH [optional] \ptxversion + */ +template < + typename T, + int BLOCK_THREADS, + int ARCH = HIPCUB_ARCH /* ignored */ +> +struct block_raking_layout +{ + //--------------------------------------------------------------------- + // Constants and type definitions + //--------------------------------------------------------------------- + + enum + { + /// The total number of elements that need to be cooperatively reduced + SHARED_ELEMENTS = BLOCK_THREADS, + + /// Maximum number of warp-synchronous raking threads + MAX_RAKING_THREADS = ::rocprim::detail::get_min_warp_size(BLOCK_THREADS, HIPCUB_DEVICE_WARP_THREADS), + + /// Number of raking elements per warp-synchronous raking thread (rounded up) + SEGMENT_LENGTH = (SHARED_ELEMENTS + MAX_RAKING_THREADS - 1) / MAX_RAKING_THREADS, + + /// Never use a raking thread that will have no valid data (e.g., when BLOCK_THREADS is 62 and SEGMENT_LENGTH is 2, we should only use 31 raking threads) + RAKING_THREADS = (SHARED_ELEMENTS + SEGMENT_LENGTH - 1) / SEGMENT_LENGTH, + + /// Pad each segment length with one element if segment length is not relatively prime to warp size and can't be optimized as a vector load + USE_SEGMENT_PADDING = ((SEGMENT_LENGTH & 1) == 0) && (SEGMENT_LENGTH > 2), + + /// Total number of elements in the raking grid + GRID_ELEMENTS = RAKING_THREADS * (SEGMENT_LENGTH + USE_SEGMENT_PADDING), + + /// Whether or not we need bounds checking during raking (the number of reduction elements is not a multiple of the number of raking threads) + UNGUARDED = (SHARED_ELEMENTS % RAKING_THREADS == 0), + }; + + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + /** + * \brief Shared memory storage type + */ + struct __align__(16) _TempStorage + { + T buff[BlockRakingLayout::GRID_ELEMENTS]; + }; + +#endif + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + /** + * \brief Returns the location for the calling thread to place data into the grid + */ 
+ static HIPCUB_DEVICE inline T* PlacementPtr( + TempStorage &temp_storage, + unsigned int linear_tid) + { + // Offset for partial + unsigned int offset = linear_tid; + + // Add in one padding element for every segment + if (USE_SEGMENT_PADDING > 0) + { + offset += offset / SEGMENT_LENGTH; + } + + // Incorporating a block of padding partials every shared memory segment + return temp_storage.Alias().buff + offset; + } + + /** + * \brief Returns the location for the calling thread to begin sequential raking + */ + static HIPCUB_DEVICE inline T* RakingPtr( + TempStorage &temp_storage, + unsigned int linear_tid) + { + return temp_storage.Alias().buff + (linear_tid * (SEGMENT_LENGTH + USE_SEGMENT_PADDING)); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_RAKING_LAYOUT_HPP_ diff --git a/3rdparty/cub/block/block_reduce.cuh b/3rdparty/cub/block/block_reduce.cuh new file mode 100644 index 0000000000000000000000000000000000000000..27ae1c6522a636ad20699c80c4b95bbff291ad94 --- /dev/null +++ b/3rdparty/cub/block/block_reduce.cuh @@ -0,0 +1,166 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_REDUCE_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_REDUCE_HPP_ + +#include + +#include + +BEGIN_HIPCUB_NAMESPACE + +namespace detail +{ + inline constexpr + typename std::underlying_type<::rocprim::block_reduce_algorithm>::type + to_BlockReduceAlgorithm_enum(::rocprim::block_reduce_algorithm v) + { + using utype = std::underlying_type<::rocprim::block_reduce_algorithm>::type; + return static_cast(v); + } +} + +enum BlockReduceAlgorithm +{ + BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY + = detail::to_BlockReduceAlgorithm_enum(::rocprim::block_reduce_algorithm::raking_reduce_commutative_only), + BLOCK_REDUCE_RAKING + = detail::to_BlockReduceAlgorithm_enum(::rocprim::block_reduce_algorithm::raking_reduce), + BLOCK_REDUCE_WARP_REDUCTIONS + = detail::to_BlockReduceAlgorithm_enum(::rocprim::block_reduce_algorithm::using_warp_reduce) +}; + +template< + typename T, + int BLOCK_DIM_X, + 
BlockReduceAlgorithm ALGORITHM = BLOCK_REDUCE_WARP_REDUCTIONS, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int ARCH = HIPCUB_ARCH /* ignored */ +> +class BlockReduce + : private ::rocprim::block_reduce< + T, + BLOCK_DIM_X, + static_cast<::rocprim::block_reduce_algorithm>(ALGORITHM), + BLOCK_DIM_Y, + BLOCK_DIM_Z + > +{ + static_assert( + BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0, + "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0" + ); + + using base_type = + typename ::rocprim::block_reduce< + T, + BLOCK_DIM_X, + static_cast<::rocprim::block_reduce_algorithm>(ALGORITHM), + BLOCK_DIM_Y, + BLOCK_DIM_Z + >; + + // Reference to temporary storage (usually shared memory) + typename base_type::storage_type& temp_storage_; + +public: + using TempStorage = typename base_type::storage_type; + + HIPCUB_DEVICE inline + BlockReduce() : temp_storage_(private_storage()) + { + } + + HIPCUB_DEVICE inline + BlockReduce(TempStorage& temp_storage) : temp_storage_(temp_storage) + { + } + + HIPCUB_DEVICE inline + T Sum(T input) + { + base_type::reduce(input, input, temp_storage_); + return input; + } + + HIPCUB_DEVICE inline + T Sum(T input, int valid_items) + { + base_type::reduce(input, input, valid_items, temp_storage_); + return input; + } + + template + HIPCUB_DEVICE inline + T Sum(T(&input)[ITEMS_PER_THREAD]) + { + T output; + base_type::reduce(input, output, temp_storage_); + return output; + } + + template + HIPCUB_DEVICE inline + T Reduce(T input, ReduceOp reduce_op) + { + base_type::reduce(input, input, temp_storage_, reduce_op); + return input; + } + + template + HIPCUB_DEVICE inline + T Reduce(T input, ReduceOp reduce_op, int valid_items) + { + base_type::reduce(input, input, valid_items, temp_storage_, reduce_op); + return input; + } + + template + HIPCUB_DEVICE inline + T Reduce(T(&input)[ITEMS_PER_THREAD], ReduceOp reduce_op) + { + T output; + base_type::reduce(input, output, temp_storage_, reduce_op); + return output; + } + +private: + HIPCUB_DEVICE 
inline + TempStorage& private_storage() + { + HIPCUB_SHARED_MEMORY TempStorage private_storage; + return private_storage; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_REDUCE_HPP_ diff --git a/3rdparty/cub/block/block_run_length_decode.hpp b/3rdparty/cub/block/block_run_length_decode.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fb485b79f3b031b0016f938cb24e82bc67aaf269 --- /dev/null +++ b/3rdparty/cub/block/block_run_length_decode.hpp @@ -0,0 +1,393 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_RUN_LENGTH_DECODE_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_RUN_LENGTH_DECODE_HPP_ + +#include "../config.hpp" +#include "../thread/thread_search.cuh" +#include "../util_math.cuh" +#include "../util_ptx.cuh" +#include "../util_type.cuh" +#include "block_scan.cuh" +#include +#include + +BEGIN_HIPCUB_NAMESPACE + +/** + * \brief The BlockRunLengthDecode class supports decoding a run-length encoded array of items. That is, given + * the two arrays run_value[N] and run_lengths[N], run_value[i] is repeated run_lengths[i] many times in the output + * array. + * Due to the nature of the run-length decoding algorithm ("decompression"), the output size of the run-length decoded + * array is runtime-dependent and potentially without any upper bound. To address this, BlockRunLengthDecode allows + * retrieving a "window" from the run-length decoded array. The window's offset can be specified and BLOCK_THREADS * + * DECODED_ITEMS_PER_THREAD (i.e., referred to as window_size) decoded items from the specified window will be returned. + * + * \note: Trailing runs of length 0 are supported (i.e., they may only appear at the end of the run_lengths array). + * A run of length zero may not be followed by a run length that is not zero. + * + * \par + * \code + * __global__ void ExampleKernel(...) 
+ * { + * // Specialising BlockRunLengthDecode to run-length decode items of type uint64_t + * using RunItemT = uint64_t; + * // Type large enough to index into the run-length decoded array + * using RunLengthT = uint32_t; + * + * // Specialising BlockRunLengthDecode for a 1D block of 128 threads + * constexpr int BLOCK_DIM_X = 128; + * // Specialising BlockRunLengthDecode to have each thread contribute 2 run-length encoded runs + * constexpr int RUNS_PER_THREAD = 2; + * // Specialising BlockRunLengthDecode to have each thread hold 4 run-length decoded items + * constexpr int DECODED_ITEMS_PER_THREAD = 4; + * + * // Specialize BlockRadixSort for a 1D block of 128 threads owning 4 integer items each + * using BlockRunLengthDecodeT = + * cub::BlockRunLengthDecode; + * + * // Allocate shared memory for BlockRunLengthDecode + * __shared__ typename BlockRunLengthDecodeT::TempStorage temp_storage; + * + * // The run-length encoded items and how often they shall be repeated in the run-length decoded output + * RunItemT run_values[RUNS_PER_THREAD]; + * RunLengthT run_lengths[RUNS_PER_THREAD]; + * ... + * + * // Initialize the BlockRunLengthDecode with the runs that we want to run-length decode + * uint32_t total_decoded_size = 0; + * BlockRunLengthDecodeT block_rld(temp_storage, run_values, run_lengths, total_decoded_size); + * + * // Run-length decode ("decompress") the runs into a window buffer of limited size. This is repeated until all runs + * // have been decoded. 
+ * uint32_t decoded_window_offset = 0U; + * while (decoded_window_offset < total_decoded_size) + * { + * RunLengthT relative_offsets[DECODED_ITEMS_PER_THREAD]; + * RunItemT decoded_items[DECODED_ITEMS_PER_THREAD]; + * + * // The number of decoded items that are valid within this window (aka pass) of run-length decoding + * uint32_t num_valid_items = total_decoded_size - decoded_window_offset; + * block_rld.RunLengthDecode(decoded_items, relative_offsets, decoded_window_offset); + * + * decoded_window_offset += BLOCK_DIM_X * DECODED_ITEMS_PER_THREAD; + * + * ... + * } + * } + * \endcode + * \par + * Suppose the set of input \p run_values across the block of threads is + * { [0, 1], [2, 3], [4, 5], [6, 7], ..., [254, 255] } and + * \p run_lengths is { [1, 2], [3, 4], [5, 1], [2, 3], ..., [5, 1] }. + * The corresponding output \p decoded_items in those threads will be { [0, 1, 1, 2], [2, 2, 3, 3], [3, 3, 4, 4], + * [4, 4, 4, 5], ..., [169, 169, 170, 171] } and \p relative_offsets will be { [0, 0, 1, 0], [1, 2, 0, 1], [2, + * 3, 0, 1], [2, 3, 4, 0], ..., [3, 4, 0, 0] } during the first iteration of the while loop. 
+ * + * \tparam ItemT The data type of the items being run-length decoded + * \tparam BLOCK_DIM_X The thread block length in threads along the X dimension + * \tparam RUNS_PER_THREAD The number of consecutive runs that each thread contributes + * \tparam DECODED_ITEMS_PER_THREAD The maximum number of decoded items that each thread holds + * \tparam DecodedOffsetT Type used to index into the block's decoded items (large enough to hold the sum over all the + * runs' lengths) + * \tparam BLOCK_DIM_Y The thread block length in threads along the Y dimension + * \tparam BLOCK_DIM_Z The thread block length in threads along the Z dimension + */ +template +class BlockRunLengthDecode +{ + //--------------------------------------------------------------------- + // CONFIGS & TYPE ALIASES + //--------------------------------------------------------------------- +private: + /// The thread block size in threads + static constexpr int BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z; + + /// The number of runs that the block decodes (out-of-bounds items may be padded with run lengths of '0') + static constexpr int BLOCK_RUNS = BLOCK_THREADS * RUNS_PER_THREAD; + + /// BlockScan used to determine the beginning of each run (i.e., prefix sum over the runs' length) + using RunOffsetScanT = BlockScan; + + /// Type used to index into the block's runs + using RunOffsetT = uint32_t; + + /// Shared memory type required by this thread block + union _TempStorage + { + typename RunOffsetScanT::TempStorage offset_scan; + struct + { + ItemT run_values[BLOCK_RUNS]; + DecodedOffsetT run_offsets[BLOCK_RUNS]; + } runs; + }; // union TempStorage + + /// Internal storage allocator (used when the user does not provide pre-allocated shared memory) + HIPCUB_DEVICE __forceinline__ _TempStorage &PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + uint32_t linear_tid; + 
+public: + struct TempStorage : Uninitialized<_TempStorage> + { + }; + + //--------------------------------------------------------------------- + // CONSTRUCTOR + //--------------------------------------------------------------------- + + /** + * \brief Constructor specialised for user-provided temporary storage, initializing using the runs' lengths. The + * algorithm's temporary storage may not be repurposed between the constructor call and subsequent + * RunLengthDecode calls. + */ + template + HIPCUB_DEVICE __forceinline__ BlockRunLengthDecode(TempStorage &temp_storage, + ItemT (&run_values)[RUNS_PER_THREAD], + RunLengthT (&run_lengths)[RUNS_PER_THREAD], + TotalDecodedSizeT &total_decoded_size) + : temp_storage(temp_storage.Alias()), linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + { + InitWithRunLengths(run_values, run_lengths, total_decoded_size); + } + + /** + * \brief Constructor specialised for user-provided temporary storage, initializing using the runs' offsets. The + * algorithm's temporary storage may not be repurposed between the constructor call and subsequent + * RunLengthDecode calls. + */ + template + HIPCUB_DEVICE __forceinline__ BlockRunLengthDecode(TempStorage &temp_storage, + ItemT (&run_values)[RUNS_PER_THREAD], + UserRunOffsetT (&run_offsets)[RUNS_PER_THREAD]) + : temp_storage(temp_storage.Alias()), linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + { + InitWithRunOffsets(run_values, run_offsets); + } + + /** + * \brief Constructor specialised for static temporary storage, initializing using the runs' lengths. 
+ */ + template + HIPCUB_DEVICE __forceinline__ BlockRunLengthDecode(ItemT (&run_values)[RUNS_PER_THREAD], + RunLengthT (&run_lengths)[RUNS_PER_THREAD], + TotalDecodedSizeT &total_decoded_size) + : temp_storage(PrivateStorage()), linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + { + InitWithRunLengths(run_values, run_lengths, total_decoded_size); + } + + /** + * \brief Constructor specialised for static temporary storage, initializing using the runs' offsets. + */ + template + HIPCUB_DEVICE __forceinline__ BlockRunLengthDecode(ItemT (&run_values)[RUNS_PER_THREAD], + UserRunOffsetT (&run_offsets)[RUNS_PER_THREAD]) + : temp_storage(PrivateStorage()), linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + { + InitWithRunOffsets(run_values, run_offsets); + } + +private: + /** + * \brief Returns the offset of the first value within \p input which compares greater than \p val. This version takes + * \p MAX_NUM_ITEMS, an upper bound of the array size, which will be used to determine the number of binary search + * iterations at compile time. 
+ */ + template + HIPCUB_DEVICE __forceinline__ OffsetT StaticUpperBound(InputIteratorT input, ///< [in] Input sequence + OffsetT num_items, ///< [in] Input sequence length + T val) ///< [in] Search key + { + OffsetT lower_bound = 0; + OffsetT upper_bound = num_items; + #pragma unroll + for (int i = 0; i <= Log2::VALUE; i++) + { + OffsetT mid = cub::MidPoint(lower_bound, upper_bound); + mid = (rocprim::min)(mid, num_items - 1); + + if (val < input[mid]) + { + upper_bound = mid; + } + else + { + lower_bound = mid + 1; + } + } + + return lower_bound; + } + + template + HIPCUB_DEVICE __forceinline__ void InitWithRunOffsets(ItemT (&run_values)[RUNS_PER_THREAD], + RunOffsetT (&run_offsets)[RUNS_PER_THREAD]) + { + // Keep the runs' items and the offsets of each run's beginning in the temporary storage + RunOffsetT thread_dst_offset = static_cast(linear_tid) * static_cast(RUNS_PER_THREAD); + #pragma unroll + for (int i = 0; i < RUNS_PER_THREAD; i++) + { + temp_storage.runs.run_values[thread_dst_offset] = run_values[i]; + temp_storage.runs.run_offsets[thread_dst_offset] = run_offsets[i]; + thread_dst_offset++; + } + + // Ensure run offsets and run values have been writen to shared memory + CTA_SYNC(); + } + + template + HIPCUB_DEVICE __forceinline__ void InitWithRunLengths(ItemT (&run_values)[RUNS_PER_THREAD], + RunLengthT (&run_lengths)[RUNS_PER_THREAD], + TotalDecodedSizeT &total_decoded_size) + { + // Compute the offset for the beginning of each run + DecodedOffsetT run_offsets[RUNS_PER_THREAD]; + #pragma unroll + for (int i = 0; i < RUNS_PER_THREAD; i++) + { + run_offsets[i] = static_cast(run_lengths[i]); + } + DecodedOffsetT decoded_size_aggregate; + RunOffsetScanT(this->temp_storage.offset_scan).ExclusiveSum(run_offsets, run_offsets, decoded_size_aggregate); + total_decoded_size = static_cast(decoded_size_aggregate); + + // Ensure the prefix scan's temporary storage can be reused (may be superfluous, but depends on scan implementation) + CTA_SYNC(); + + 
InitWithRunOffsets(run_values, run_offsets); + } + +public: + /** + * \brief Run-length decodes the runs previously passed via a call to Init(...) and returns the run-length decoded + * items in a blocked arrangement to \p decoded_items. If the number of run-length decoded items exceeds the + * run-length decode buffer (i.e., DECODED_ITEMS_PER_THREAD * BLOCK_THREADS), only the items that fit within + * the buffer are returned. Subsequent calls to RunLengthDecode adjusting \p from_decoded_offset can be + * used to retrieve the remaining run-length decoded items. Calling __syncthreads() between any two calls to + * RunLengthDecode is not required. + * \p item_offsets can be used to retrieve each run-length decoded item's relative index within its run. E.g., the + * run-length encoded array of `3, 1, 4` with the respective run lengths of `2, 1, 3` would yield the run-length + * decoded array of `3, 3, 1, 4, 4, 4` with the relative offsets of `0, 1, 0, 0, 1, 2`. + * \smemreuse + * + * \param[out] decoded_items The run-length decoded items to be returned in a blocked arrangement + * \param[out] item_offsets The run-length decoded items' relative offset within the run they belong to + * \param[in] from_decoded_offset If invoked with from_decoded_offset that is larger than total_decoded_size results + * in undefined behavior. 
+ */ + template + HIPCUB_DEVICE __forceinline__ void RunLengthDecode(ItemT (&decoded_items)[DECODED_ITEMS_PER_THREAD], + RelativeOffsetT (&item_offsets)[DECODED_ITEMS_PER_THREAD], + DecodedOffsetT from_decoded_offset = 0) + { + // The (global) offset of the first item decoded by this thread + DecodedOffsetT thread_decoded_offset = from_decoded_offset + linear_tid * DECODED_ITEMS_PER_THREAD; + + // The run that the first decoded item of this thread belongs to + // If this thread's is already beyond the total decoded size, it will be assigned to the + // last run + RunOffsetT assigned_run = + StaticUpperBound(temp_storage.runs.run_offsets, BLOCK_RUNS, thread_decoded_offset) - + static_cast(1U); + + DecodedOffsetT assigned_run_begin = temp_storage.runs.run_offsets[assigned_run]; + + // If this thread is getting assigned the last run, we make sure it will not fetch any other run after this + DecodedOffsetT assigned_run_end = (assigned_run == BLOCK_RUNS - 1) + ? thread_decoded_offset + DECODED_ITEMS_PER_THREAD + : temp_storage.runs.run_offsets[assigned_run + 1]; + + ItemT val = temp_storage.runs.run_values[assigned_run]; + + #pragma unroll + for (DecodedOffsetT i = 0; i < DECODED_ITEMS_PER_THREAD; i++) + { + decoded_items[i] = val; + item_offsets[i] = thread_decoded_offset - assigned_run_begin; + if (thread_decoded_offset == assigned_run_end - 1) + { + // We make sure that a thread is not re-entering this conditional when being assigned to the last run already by + // extending the last run's length to all the thread's item + assigned_run++; + assigned_run_begin = temp_storage.runs.run_offsets[assigned_run]; + + // If this thread is getting assigned the last run, we make sure it will not fetch any other run after this + assigned_run_end = (assigned_run == BLOCK_RUNS - 1) ? 
thread_decoded_offset + DECODED_ITEMS_PER_THREAD + : temp_storage.runs.run_offsets[assigned_run + 1]; + val = temp_storage.runs.run_values[assigned_run]; + } + thread_decoded_offset++; + } + } + + /** + * \brief Run-length decodes the runs previously passed via a call to Init(...) and returns the run-length decoded + * items in a blocked arrangement to \p decoded_items. If the number of run-length decoded items exceeds the + * run-length decode buffer (i.e., DECODED_ITEMS_PER_THREAD * BLOCK_THREADS), only the items that fit within + * the buffer are returned. Subsequent calls to RunLengthDecode adjusting \p from_decoded_offset can be + * used to retrieve the remaining run-length decoded items. Calling __syncthreads() between any two calls to + * RunLengthDecode is not required. + * + * \param[out] decoded_items The run-length decoded items to be returned in a blocked arrangement + * \param[in] from_decoded_offset If invoked with from_decoded_offset that is larger than total_decoded_size results + * in undefined behavior. + */ + HIPCUB_DEVICE __forceinline__ void RunLengthDecode(ItemT (&decoded_items)[DECODED_ITEMS_PER_THREAD], + DecodedOffsetT from_decoded_offset = 0) + { + DecodedOffsetT item_offsets[DECODED_ITEMS_PER_THREAD]; + RunLengthDecode(decoded_items, item_offsets, from_decoded_offset); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_RUN_LENGTH_DECODE_HPP_ diff --git a/3rdparty/cub/block/block_scan.cuh b/3rdparty/cub/block/block_scan.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d0677170f4467754350859d864a073d461517a3c --- /dev/null +++ b/3rdparty/cub/block/block_scan.cuh @@ -0,0 +1,317 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_SCAN_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_SCAN_HPP_ + +#include + +#include "../config.hpp" + +#include "../thread/thread_operators.cuh" + +#include + +BEGIN_HIPCUB_NAMESPACE + +namespace detail +{ + inline constexpr + typename std::underlying_type<::rocprim::block_scan_algorithm>::type + to_BlockScanAlgorithm_enum(::rocprim::block_scan_algorithm v) + { + using utype = std::underlying_type<::rocprim::block_scan_algorithm>::type; + return static_cast(v); + } +} + +enum BlockScanAlgorithm +{ + BLOCK_SCAN_RAKING + = detail::to_BlockScanAlgorithm_enum(::rocprim::block_scan_algorithm::reduce_then_scan), + BLOCK_SCAN_RAKING_MEMOIZE + = detail::to_BlockScanAlgorithm_enum(::rocprim::block_scan_algorithm::reduce_then_scan), + BLOCK_SCAN_WARP_SCANS + = detail::to_BlockScanAlgorithm_enum(::rocprim::block_scan_algorithm::using_warp_scan) +}; + +template< + typename T, + int BLOCK_DIM_X, + BlockScanAlgorithm ALGORITHM = BLOCK_SCAN_RAKING, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int ARCH = HIPCUB_ARCH /* ignored */ +> +class BlockScan + : private ::rocprim::block_scan< + T, + BLOCK_DIM_X, + static_cast<::rocprim::block_scan_algorithm>(ALGORITHM), + BLOCK_DIM_Y, + BLOCK_DIM_Z + > +{ + static_assert( + BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0, + "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0" + ); + + using base_type = + typename ::rocprim::block_scan< + T, + BLOCK_DIM_X, + static_cast<::rocprim::block_scan_algorithm>(ALGORITHM), + BLOCK_DIM_Y, + BLOCK_DIM_Z + >; + + // Reference to temporary storage (usually shared memory) + typename base_type::storage_type& temp_storage_; + +public: + using TempStorage = typename base_type::storage_type; + + HIPCUB_DEVICE inline + BlockScan() : temp_storage_(private_storage()) + { + } + + HIPCUB_DEVICE inline + BlockScan(TempStorage& temp_storage) : temp_storage_(temp_storage) + { + } + + 
HIPCUB_DEVICE inline + void InclusiveSum(T input, T& output) + { + base_type::inclusive_scan(input, output, temp_storage_); + } + + HIPCUB_DEVICE inline + void InclusiveSum(T input, T& output, T& block_aggregate) + { + base_type::inclusive_scan(input, output, block_aggregate, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void InclusiveSum(T input, T& output, BlockPrefixCallbackOp& block_prefix_callback_op) + { + base_type::inclusive_scan( + input, output, temp_storage_, block_prefix_callback_op, ::cub::Sum() + ); + } + + template + HIPCUB_DEVICE inline + void InclusiveSum(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD]) + { + base_type::inclusive_scan(input, output, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void InclusiveSum(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD], + T& block_aggregate) + { + base_type::inclusive_scan(input, output, block_aggregate, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void InclusiveSum(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD], + BlockPrefixCallbackOp& block_prefix_callback_op) + { + base_type::inclusive_scan( + input, output, temp_storage_, block_prefix_callback_op, ::cub::Sum() + ); + } + + template + HIPCUB_DEVICE inline + void InclusiveScan(T input, T& output, ScanOp scan_op) + { + base_type::inclusive_scan(input, output, temp_storage_, scan_op); + } + + template + HIPCUB_DEVICE inline + void InclusiveScan(T input, T& output, ScanOp scan_op, T& block_aggregate) + { + base_type::inclusive_scan(input, output, block_aggregate, temp_storage_, scan_op); + } + + template + HIPCUB_DEVICE inline + void InclusiveScan(T input, T& output, ScanOp scan_op, BlockPrefixCallbackOp& block_prefix_callback_op) + { + base_type::inclusive_scan( + input, output, temp_storage_, block_prefix_callback_op, scan_op + ); + } + + template + HIPCUB_DEVICE inline + void InclusiveScan(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD], ScanOp scan_op) + { + 
base_type::inclusive_scan(input, output, temp_storage_, scan_op); + } + + template + HIPCUB_DEVICE inline + void InclusiveScan(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD], + ScanOp scan_op, T& block_aggregate) + { + base_type::inclusive_scan(input, output, block_aggregate, temp_storage_, scan_op); + } + + template + HIPCUB_DEVICE inline + void InclusiveScan(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD], + ScanOp scan_op, BlockPrefixCallbackOp& block_prefix_callback_op) + { + base_type::inclusive_scan( + input, output, temp_storage_, block_prefix_callback_op, scan_op + ); + } + + HIPCUB_DEVICE inline + void ExclusiveSum(T input, T& output) + { + base_type::exclusive_scan(input, output, T(0), temp_storage_); + } + + HIPCUB_DEVICE inline + void ExclusiveSum(T input, T& output, T& block_aggregate) + { + base_type::exclusive_scan(input, output, T(0), block_aggregate, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void ExclusiveSum(T input, T& output, BlockPrefixCallbackOp& block_prefix_callback_op) + { + base_type::exclusive_scan( + input, output, temp_storage_, block_prefix_callback_op, ::cub::Sum() + ); + } + + template + HIPCUB_DEVICE inline + void ExclusiveSum(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD]) + { + base_type::exclusive_scan(input, output, T(0), temp_storage_); + } + + template + HIPCUB_DEVICE inline + void ExclusiveSum(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD], + T& block_aggregate) + { + base_type::exclusive_scan(input, output, T(0), block_aggregate, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void ExclusiveSum(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD], + BlockPrefixCallbackOp& block_prefix_callback_op) + { + base_type::exclusive_scan( + input, output, temp_storage_, block_prefix_callback_op, ::cub::Sum() + ); + } + + template + HIPCUB_DEVICE inline + void ExclusiveScan(T input, T& output, T initial_value, ScanOp scan_op) + { + 
base_type::exclusive_scan(input, output, initial_value, temp_storage_, scan_op); + } + + template + HIPCUB_DEVICE inline + void ExclusiveScan(T input, T& output, T initial_value, + ScanOp scan_op, T& block_aggregate) + { + base_type::exclusive_scan( + input, output, initial_value, block_aggregate, temp_storage_, scan_op + ); + } + + template + HIPCUB_DEVICE inline + void ExclusiveScan(T input, T& output, ScanOp scan_op, + BlockPrefixCallbackOp& block_prefix_callback_op) + { + base_type::exclusive_scan( + input, output, temp_storage_, block_prefix_callback_op, scan_op + ); + } + + template + HIPCUB_DEVICE inline + void ExclusiveScan(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD], + T initial_value, ScanOp scan_op) + { + base_type::exclusive_scan(input, output, initial_value, temp_storage_, scan_op); + } + + template + HIPCUB_DEVICE inline + void ExclusiveScan(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD], + T initial_value, ScanOp scan_op, T& block_aggregate) + { + base_type::exclusive_scan( + input, output, initial_value, block_aggregate, temp_storage_, scan_op + ); + } + + template + HIPCUB_DEVICE inline + void ExclusiveScan(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD], + ScanOp scan_op, BlockPrefixCallbackOp& block_prefix_callback_op) + { + base_type::exclusive_scan( + input, output, temp_storage_, block_prefix_callback_op, scan_op + ); + } + +private: + HIPCUB_DEVICE inline + TempStorage& private_storage() + { + HIPCUB_SHARED_MEMORY TempStorage private_storage; + return private_storage; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_SCAN_HPP_ diff --git a/3rdparty/cub/block/block_shuffle.cuh b/3rdparty/cub/block/block_shuffle.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e3613761240f5e4bcf5b6c621248ca64c156c3d3 --- /dev/null +++ b/3rdparty/cub/block/block_shuffle.cuh @@ -0,0 +1,191 @@ +/****************************************************************************** + * 
Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_SHUFFLE_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_SHUFFLE_HPP_ + +#include + +#include "../config.hpp" + +#include "../thread/thread_operators.cuh" + +#include + +BEGIN_HIPCUB_NAMESPACE + + + +template < + typename T, + int BLOCK_DIM_X, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int ARCH = HIPCUB_ARCH> +class BlockShuffle : public ::rocprim::block_shuffle< + T, + BLOCK_DIM_X, + BLOCK_DIM_Y, + BLOCK_DIM_Z> +{ + static_assert( + BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0, + "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0" + ); + + using base_type = + typename ::rocprim::block_shuffle< + T, + BLOCK_DIM_X, + BLOCK_DIM_Y, + BLOCK_DIM_Z + >; + + // Reference to temporary storage (usually shared memory) + typename base_type::storage_type& temp_storage_; + +public: + using TempStorage = typename base_type::storage_type; + + HIPCUB_DEVICE inline + BlockShuffle() : temp_storage_(private_storage()) + {} + + + HIPCUB_DEVICE inline + BlockShuffle(TempStorage &temp_storage) ///< [in] Reference to memory allocation having layout type TempStorage + : temp_storage_(temp_storage) + {} + + /** + * \brief Each threadi obtains the \p input provided by threadi+distance. The offset \p distance may be negative. + * + * \par + * - \smemreuse + */ + HIPCUB_DEVICE inline void Offset( + T input, ///< [in] The input item from the calling thread (threadi) + T& output, ///< [out] The \p input item from the successor (or predecessor) thread threadi+distance (may be aliased to \p input). This value is only updated for for threadi when 0 <= (i + \p distance) < BLOCK_THREADS-1 + int distance = 1) ///< [in] Offset distance (may be negative) + { + base_type::offset(input,output,distance); + } + + /** + * \brief Each threadi obtains the \p input provided by threadi+distance. 
+ * + * \par + * - \smemreuse + */ + HIPCUB_DEVICE inline void Rotate( + T input, ///< [in] The calling thread's input item + T& output, ///< [out] The \p input item from thread thread(i+distance>)% (may be aliased to \p input). This value is not updated for threadBLOCK_THREADS-1 + unsigned int distance = 1) ///< [in] Offset distance (0 < \p distance < BLOCK_THREADS) + { + base_type::rotate(input,output,distance); + } + /** + * \brief The thread block rotates its [blocked arrangement](index.html#sec5sec3) of \p input items, shifting it up by one item + * + * \par + * - \blocked + * - \granularity + * - \smemreuse + */ + template + HIPCUB_DEVICE inline void Up( + T (&input)[ITEMS_PER_THREAD], ///< [in] The calling thread's input items + T (&prev)[ITEMS_PER_THREAD]) ///< [out] The corresponding predecessor items (may be aliased to \p input). The item \p prev[0] is not updated for thread0. + { + base_type::up(input,prev); + } + + + /** + * \brief The thread block rotates its [blocked arrangement](index.html#sec5sec3) of \p input items, shifting it up by one item. All threads receive the \p input provided by threadBLOCK_THREADS-1. + * + * \par + * - \blocked + * - \granularity + * - \smemreuse + */ + template + HIPCUB_DEVICE inline void Up( + T (&input)[ITEMS_PER_THREAD], ///< [in] The calling thread's input items + T (&prev)[ITEMS_PER_THREAD], ///< [out] The corresponding predecessor items (may be aliased to \p input). The item \p prev[0] is not updated for thread0. 
+ T &block_suffix) ///< [out] The item \p input[ITEMS_PER_THREAD-1] from threadBLOCK_THREADS-1, provided to all threads + { + base_type::up(input,prev,block_suffix); + } + + /** + * \brief The thread block rotates its [blocked arrangement](index.html#sec5sec3) of \p input items, shifting it down by one item + * + * \par + * - \blocked + * - \granularity + * - \smemreuse + */ + template + HIPCUB_DEVICE inline void Down( + T (&input)[ITEMS_PER_THREAD], ///< [in] The calling thread's input items + T (&next)[ITEMS_PER_THREAD]) ///< [out] The corresponding predecessor items (may be aliased to \p input). The value \p next[0] is not updated for threadBLOCK_THREADS-1. + { + base_type::down(input,next); + } + + /** + * \brief The thread block rotates its [blocked arrangement](index.html#sec5sec3) of input items, shifting it down by one item. All threads receive \p input[0] provided by thread0. + * + * \par + * - \blocked + * - \granularity + * - \smemreuse + */ + template + HIPCUB_DEVICE inline void Down( + T (&input)[ITEMS_PER_THREAD], ///< [in] The calling thread's input items + T (&next)[ITEMS_PER_THREAD], ///< [out] The corresponding predecessor items (may be aliased to \p input). The value \p next[0] is not updated for threadBLOCK_THREADS-1. 
+ T &block_prefix) ///< [out] The item \p input[0] from thread0, provided to all threads + { + base_type::down(input,next,block_prefix); + } + +private: + HIPCUB_DEVICE inline + TempStorage& private_storage() + { + HIPCUB_SHARED_MEMORY TempStorage private_storage; + return private_storage; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_SHUFFLE_HPP_ diff --git a/3rdparty/cub/block/block_store.cuh b/3rdparty/cub/block/block_store.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b763a3aa1358a216c219feba49fbf7a297550f93 --- /dev/null +++ b/3rdparty/cub/block/block_store.cuh @@ -0,0 +1,148 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_STORE_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_STORE_HPP_ + +#include + +#include "../config.hpp" + +#include "block_store_func.hpp" + +#include + +BEGIN_HIPCUB_NAMESPACE + +namespace detail +{ + inline constexpr + typename std::underlying_type<::rocprim::block_store_method>::type + to_BlockStoreAlgorithm_enum(::rocprim::block_store_method v) + { + using utype = std::underlying_type<::rocprim::block_store_method>::type; + return static_cast(v); + } +} + +enum BlockStoreAlgorithm +{ + BLOCK_STORE_DIRECT + = detail::to_BlockStoreAlgorithm_enum(::rocprim::block_store_method::block_store_direct), + BLOCK_STORE_STRIPED + = detail::to_BlockStoreAlgorithm_enum(::rocprim::block_store_method::block_store_striped), + BLOCK_STORE_VECTORIZE + = detail::to_BlockStoreAlgorithm_enum(::rocprim::block_store_method::block_store_vectorize), + BLOCK_STORE_TRANSPOSE + = detail::to_BlockStoreAlgorithm_enum(::rocprim::block_store_method::block_store_transpose), + BLOCK_STORE_WARP_TRANSPOSE + = detail::to_BlockStoreAlgorithm_enum(::rocprim::block_store_method::block_store_warp_transpose), + BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED + = detail::to_BlockStoreAlgorithm_enum(::rocprim::block_store_method::block_store_warp_transpose) +}; + +template< + typename T, + int BLOCK_DIM_X, + int ITEMS_PER_THREAD, + BlockStoreAlgorithm ALGORITHM = BLOCK_STORE_DIRECT, + int 
BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1, + int ARCH = HIPCUB_ARCH /* ignored */ +> +class BlockStore + : private ::rocprim::block_store< + T, + BLOCK_DIM_X, + ITEMS_PER_THREAD, + static_cast<::rocprim::block_store_method>(ALGORITHM), + BLOCK_DIM_Y, + BLOCK_DIM_Z + > +{ + static_assert( + BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0, + "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0" + ); + + using base_type = + typename ::rocprim::block_store< + T, + BLOCK_DIM_X, + ITEMS_PER_THREAD, + static_cast<::rocprim::block_store_method>(ALGORITHM), + BLOCK_DIM_Y, + BLOCK_DIM_Z + >; + + // Reference to temporary storage (usually shared memory) + typename base_type::storage_type& temp_storage_; + +public: + using TempStorage = typename base_type::storage_type; + + HIPCUB_DEVICE inline + BlockStore() : temp_storage_(private_storage()) + { + } + + HIPCUB_DEVICE inline + BlockStore(TempStorage& temp_storage) : temp_storage_(temp_storage) + { + } + + template + HIPCUB_DEVICE inline + void Store(OutputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD]) + { + base_type::store(block_iter, items, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void Store(OutputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD], + int valid_items) + { + base_type::store(block_iter, items, valid_items, temp_storage_); + } + +private: + HIPCUB_DEVICE inline + TempStorage& private_storage() + { + HIPCUB_SHARED_MEMORY TempStorage private_storage; + return private_storage; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_STORE_HPP_ diff --git a/3rdparty/cub/block/block_store_func.hpp b/3rdparty/cub/block/block_store_func.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f6dbae952a8f0a7c6d280f928031fdda0cc0aa8a --- /dev/null +++ b/3rdparty/cub/block/block_store_func.hpp @@ -0,0 +1,150 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. 
+ * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_STORE_FUNC_HPP_ +#define HIPCUB_ROCPRIM_BLOCK_BLOCK_STORE_FUNC_HPP_ + +#include "../config.hpp" + +#include + +BEGIN_HIPCUB_NAMESPACE + +template< + typename T, + int ITEMS_PER_THREAD, + typename OutputIteratorT +> +HIPCUB_DEVICE inline +void StoreDirectBlocked(int linear_id, + OutputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD]) +{ + ::rocprim::block_store_direct_blocked( + linear_id, block_iter, items + ); +} + +template< + typename T, + int ITEMS_PER_THREAD, + typename OutputIteratorT +> +HIPCUB_DEVICE inline +void StoreDirectBlocked(int linear_id, + OutputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD], + int valid_items) +{ + ::rocprim::block_store_direct_blocked( + linear_id, block_iter, items, valid_items + ); +} + +template < + typename T, + int ITEMS_PER_THREAD +> +HIPCUB_DEVICE inline +void StoreDirectBlockedVectorized(int linear_id, + T* block_iter, + T (&items)[ITEMS_PER_THREAD]) +{ + ::rocprim::block_store_direct_blocked_vectorized( + linear_id, block_iter, items + ); +} + +template< + int BLOCK_THREADS, + typename T, + int ITEMS_PER_THREAD, + typename OutputIteratorT +> +HIPCUB_DEVICE inline +void StoreDirectStriped(int linear_id, + OutputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD]) +{ + ::rocprim::block_store_direct_striped( + linear_id, block_iter, items + ); +} + +template< + int BLOCK_THREADS, + typename T, + int ITEMS_PER_THREAD, + typename OutputIteratorT +> +HIPCUB_DEVICE inline +void StoreDirectStriped(int linear_id, + OutputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD], + int valid_items) +{ + ::rocprim::block_store_direct_striped( + linear_id, block_iter, items, valid_items + ); +} + +template< + typename T, + int ITEMS_PER_THREAD, + typename OutputIteratorT +> +HIPCUB_DEVICE inline +void StoreDirectWarpStriped(int linear_id, + OutputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD]) +{ + 
::rocprim::block_store_direct_warp_striped( + linear_id, block_iter, items + ); +} + +template< + typename T, + int ITEMS_PER_THREAD, + typename OutputIteratorT +> +HIPCUB_DEVICE inline +void StoreDirectWarpStriped(int linear_id, + OutputIteratorT block_iter, + T (&items)[ITEMS_PER_THREAD], + int valid_items) +{ + ::rocprim::block_store_direct_warp_striped( + linear_id, block_iter, items, valid_items + ); +} + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_STORE_FUNC_HPP_ diff --git a/3rdparty/cub/block/radix_rank_sort_operations.hpp b/3rdparty/cub/block/radix_rank_sort_operations.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4c48f515b3828eb103343c017ba50a5481e5c57d --- /dev/null +++ b/3rdparty/cub/block/radix_rank_sort_operations.hpp @@ -0,0 +1,152 @@ +/****************************************************************************** + * Copyright (c) 2011-2020, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * radix_rank_sort_operations.cuh contains common abstractions, definitions and + * operations used for radix sorting and ranking. + */ + + #ifndef HIPCUB_ROCPRIM_BLOCK_RADIX_RANK_SORT_OPERATIONS_HPP_ + #define HIPCUB_ROCPRIM_BLOCK_RADIX_RANK_SORT_OPERATIONS_HPP_ + +#include + +#include "../config.hpp" + + #include + #include + #include + +BEGIN_HIPCUB_NAMESPACE + +/** \brief Twiddling keys for radix sort. */ +template +struct RadixSortTwiddle +{ + typedef Traits TraitsT; + typedef typename TraitsT::UnsignedBits UnsignedBits; + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits In(UnsignedBits key) + { + key = TraitsT::TwiddleIn(key); + if (IS_DESCENDING) key = ~key; + return key; + } + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits Out(UnsignedBits key) + { + if (IS_DESCENDING) key = ~key; + key = TraitsT::TwiddleOut(key); + return key; + } + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits DefaultKey() + { + return Out(~UnsignedBits(0)); + } +}; + +/** \brief Base struct for digit extractor. Contains common code to provide + special handling for floating-point -0.0. 
+ + \note This handles correctly both the case when the keys are + bitwise-complemented after twiddling for descending sort (in onesweep) as + well as when the keys are not bit-negated, but the implementation handles + descending sort separately (in other implementations in CUB). Twiddling + alone maps -0.0f to 0x7fffffff and +0.0f to 0x80000000 for float, which are + subsequent bit patterns and bitwise complements of each other. For onesweep, + both -0.0f and +0.0f are mapped to the bit pattern of +0.0f (0x80000000) for + ascending sort, and to the pattern of -0.0f (0x7fffffff) for descending + sort. For all other sorting implementations in CUB, both are always mapped + to +0.0f. Since bit patterns for both -0.0f and +0.0f are next to each other + and only one of them is used, the sorting works correctly. For double, the + same applies, but with 64-bit patterns. +*/ + template + struct BaseDigitExtractor + { + typedef Traits TraitsT; + typedef typename TraitsT::UnsignedBits UnsignedBits; + + enum + { + FLOAT_KEY = TraitsT::CATEGORY == FLOATING_POINT, + }; + + static __device__ __forceinline__ UnsignedBits ProcessFloatMinusZero(UnsignedBits key) + { + if (!FLOAT_KEY) { + return key; + } else { + UnsignedBits TWIDDLED_MINUS_ZERO_BITS = + TraitsT::TwiddleIn(UnsignedBits(1) << UnsignedBits(8 * sizeof(UnsignedBits) - 1)); + UnsignedBits TWIDDLED_ZERO_BITS = TraitsT::TwiddleIn(0); + return key == TWIDDLED_MINUS_ZERO_BITS ? TWIDDLED_ZERO_BITS : key; + } + } + }; + +/** \brief A wrapper type to extract digits. Uses the BFE intrinsic to extract a + * key from a digit. 
*/ + template + struct BFEDigitExtractor : BaseDigitExtractor + { + using typename BaseDigitExtractor::UnsignedBits; + + uint32_t bit_start, num_bits; + explicit __device__ __forceinline__ BFEDigitExtractor( + uint32_t bit_start = 0, uint32_t num_bits = 0) + : bit_start(bit_start), num_bits(num_bits) + { } + + __device__ __forceinline__ uint32_t Digit(UnsignedBits key) + { + return BFE(this->ProcessFloatMinusZero(key), bit_start, num_bits); + } + }; + +/** \brief A wrapper type to extract digits. Uses a combination of shift and + * bitwise and to extract digits. */ + template + struct ShiftDigitExtractor : BaseDigitExtractor + { + using typename BaseDigitExtractor::UnsignedBits; + + uint32_t bit_start, mask; + explicit __device__ __forceinline__ ShiftDigitExtractor( + uint32_t bit_start = 0, uint32_t num_bits = 0) + : bit_start(bit_start), mask((1 << num_bits) - 1) + { } + + __device__ __forceinline__ uint32_t Digit(UnsignedBits key) + { + return uint32_t(this->ProcessFloatMinusZero(key) >> UnsignedBits(bit_start)) & mask; + } + }; + +END_HIPCUB_NAMESPACE + +#endif //HIPCUB_ROCPRIM_BLOCK_RADIX_RANK_SORT_OPERATIONS_HPP_ diff --git a/3rdparty/cub/config.hpp b/3rdparty/cub/config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f38e518f751c5d820fb6daf952149a621e6816d5 --- /dev/null +++ b/3rdparty/cub/config.hpp @@ -0,0 +1,122 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2019-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_CONFIG_HPP_ +#define HIPCUB_CONFIG_HPP_ + +#include + +#define HIPCUB_NAMESPACE cub + +#define BEGIN_HIPCUB_NAMESPACE \ + namespace cub { + +#define END_HIPCUB_NAMESPACE \ + } /* hipcub */ + +#ifndef HIPCUB_ARCH +#define HIPCUB_ARCH 1 +#endif + +#define CUB_DEVICE_WARP_THREADS 64 + +#ifdef __CUDACC__ + #define HIPCUB_ROCPRIM_API 1 + #define HIPCUB_RUNTIME_FUNCTION __host__ +#elif defined(__HIP_PLATFORM_NVIDIA__) + #define HIPCUB_CUB_API 1 + #define HIPCUB_RUNTIME_FUNCTION CUB_RUNTIME_FUNCTION + + #include + #define HIPCUB_WARP_THREADS CUB_PTX_WARP_THREADS + #define HIPCUB_DEVICE_WARP_THREADS CUB_PTX_WARP_THREADS + #define HIPCUB_HOST_WARP_THREADS CUB_PTX_WARP_THREADS + #define HIPCUB_ARCH CUB_PTX_ARCH + BEGIN_HIPCUB_NAMESPACE + using namespace cub; + END_HIPCUB_NAMESPACE +#endif + +/// Supported warp sizes +#define HIPCUB_WARP_SIZE_32 32u +#define HIPCUB_WARP_SIZE_64 64u +#define HIPCUB_MAX_WARP_SIZE HIPCUB_WARP_SIZE_64 + +#define HIPCUB_HOST __host__ +#define HIPCUB_DEVICE __device__ +#define HIPCUB_HOST_DEVICE __host__ __device__ +#define HIPCUB_SHARED_MEMORY __shared__ + +// Helper macros to disable warnings in clang +#ifdef __clang__ +#define HIPCUB_PRAGMA_TO_STR(x) _Pragma(#x) +#define HIPCUB_CLANG_SUPPRESS_WARNING_PUSH _Pragma("clang diagnostic push") +#define HIPCUB_CLANG_SUPPRESS_WARNING(w) HIPCUB_PRAGMA_TO_STR(clang diagnostic ignored w) +#define HIPCUB_CLANG_SUPPRESS_WARNING_POP _Pragma("clang diagnostic pop") +#define HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) \ + HIPCUB_CLANG_SUPPRESS_WARNING_PUSH HIPCUB_CLANG_SUPPRESS_WARNING(w) +#else // __clang__ +#define HIPCUB_CLANG_SUPPRESS_WARNING_PUSH +#define HIPCUB_CLANG_SUPPRESS_WARNING(w) +#define HIPCUB_CLANG_SUPPRESS_WARNING_POP +#define HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // __clang__ + +BEGIN_HIPCUB_NAMESPACE + +/// hipCUB error reporting macro (prints error messages to stderr) +#if 
(defined(DEBUG) || defined(_DEBUG)) && !defined(HIPCUB_STDERR) + #define HIPCUB_STDERR +#endif + +inline +cudaError_t Debug( + cudaError_t error, + const char* filename, + int line) +{ + (void)filename; + (void)line; +#ifdef HIPCUB_STDERR + if (error) + { + fprintf(stderr, "cuda error %d [%s, %d]: %s\n", error, filename, line, cudaGetErrorString(error)); + fflush(stderr); + } +#endif + return error; +} + +#ifndef cubDebug + #define cubDebug(e) cub::Debug((cudaError_t) (e), __FILE__, __LINE__) +#endif + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_CONFIG_HPP_ diff --git a/3rdparty/cub/cub.cuh b/3rdparty/cub/cub.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9df829f139d1118379ad264714b5471a494cc473 --- /dev/null +++ b/3rdparty/cub/cub.cuh @@ -0,0 +1,92 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_HIPCUB_HPP_ +#define HIPCUB_ROCPRIM_HIPCUB_HPP_ + +#include "config.hpp" +#include "version.cuh" + +#include "util_allocator.cuh" +#include "util_type.cuh" +#include "util_ptx.cuh" +#include "thread/thread_operators.cuh" + +// Iterator +#include "iterator/arg_index_input_iterator.cuh" +#include "iterator/cache_modified_input_iterator.cuh" +#include "iterator/cache_modified_output_iterator.cuh" +#include "iterator/constant_input_iterator.cuh" +#include "iterator/counting_input_iterator.cuh" +#include "iterator/discard_output_iterator.cuh" +#include "iterator/tex_obj_input_iterator.cuh" +#include "iterator/tex_ref_input_iterator.cuh" +#include "iterator/transform_input_iterator.cuh" + +// Warp +#include "warp/warp_exchange.hpp" +#include "warp/warp_load.hpp" +#include "warp/warp_merge_sort.hpp" +#include "warp/warp_reduce.cuh" +#include "warp/warp_scan.cuh" +#include "warp/warp_store.hpp" + +// Thread +#include "thread/thread_load.cuh" +#include "thread/thread_operators.cuh" +#include "thread/thread_reduce.cuh" +#include "thread/thread_scan.cuh" +#include "thread/thread_search.cuh" 
+#include "thread/thread_sort.hpp" +#include "thread/thread_store.cuh" + +// Block +#include "block/block_discontinuity.cuh" +#include "block/block_exchange.cuh" +#include "block/block_histogram.cuh" +#include "block/block_load.cuh" +#include "block/block_radix_sort.cuh" +#include "block/block_reduce.cuh" +#include "block/block_scan.cuh" +#include "block/block_store.cuh" + +// Device +#include "device/device_adjacent_difference.hpp" +#include "device/device_histogram.cuh" +#include "device/device_radix_sort.cuh" +#include "device/device_reduce.cuh" +#include "device/device_run_length_encode.cuh" +#include "device/device_scan.cuh" +#include "device/device_segmented_radix_sort.cuh" +#include "device/device_segmented_reduce.cuh" +#include "device/device_segmented_sort.hpp" +#include "device/device_select.cuh" +#include "device/device_partition.cuh" + +#endif // HIPCUB_ROCPRIM_HIPCUB_HPP_ diff --git a/3rdparty/cub/device/device_adjacent_difference.hpp b/3rdparty/cub/device/device_adjacent_difference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cc8d374ba32c8eed67c1649f0c2fbfbd7fd3304d --- /dev/null +++ b/3rdparty/cub/device/device_adjacent_difference.hpp @@ -0,0 +1,116 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2022, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ + +#include "../config.hpp" + +#include +#include + +BEGIN_HIPCUB_NAMESPACE + +struct DeviceAdjacentDifference +{ + template + static HIPCUB_RUNTIME_FUNCTION cudaError_t + SubtractLeftCopy(void *d_temp_storage, + std::size_t &temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT d_output, + std::size_t num_items, + DifferenceOpT difference_op = {}, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::adjacent_difference( + d_temp_storage, temp_storage_bytes, d_input, d_output, + num_items, difference_op, stream, debug_synchronous + ); + } + + template + static HIPCUB_RUNTIME_FUNCTION cudaError_t + SubtractLeft(void *d_temp_storage, + std::size_t &temp_storage_bytes, + RandomAccessIteratorT d_input, + std::size_t num_items, + 
DifferenceOpT difference_op = {}, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::adjacent_difference_inplace( + d_temp_storage, temp_storage_bytes, d_input, + num_items, difference_op, stream, debug_synchronous + ); + } + + template + static HIPCUB_RUNTIME_FUNCTION cudaError_t + SubtractRightCopy(void *d_temp_storage, + std::size_t &temp_storage_bytes, + InputIteratorT d_input, + OutputIteratorT d_output, + std::size_t num_items, + DifferenceOpT difference_op = {}, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::adjacent_difference_right( + d_temp_storage, temp_storage_bytes, d_input, d_output, + num_items, difference_op, stream, debug_synchronous + ); + } + + template + static HIPCUB_RUNTIME_FUNCTION cudaError_t + SubtractRight(void *d_temp_storage, + std::size_t &temp_storage_bytes, + RandomAccessIteratorT d_input, + std::size_t num_items, + DifferenceOpT difference_op = {}, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::adjacent_difference_right_inplace( + d_temp_storage, temp_storage_bytes, d_input, + num_items, difference_op, stream, debug_synchronous + ); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ diff --git a/3rdparty/cub/device/device_histogram.cuh b/3rdparty/cub/device/device_histogram.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9fa74838eac2b74586d27382562780bd8a3e6bd5 --- /dev/null +++ b/3rdparty/cub/device/device_histogram.cuh @@ -0,0 +1,294 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_HISTOGRAM_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_DEVICE_HISTOGRAM_HPP_ + +#include "../config.hpp" + +#include "../util_type.cuh" + +#include + +BEGIN_HIPCUB_NAMESPACE + +struct DeviceHistogram +{ + template< + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t HistogramEven(void * d_temp_storage, + size_t& temp_storage_bytes, + SampleIteratorT d_samples, + CounterT * d_histogram, + int num_levels, + LevelT lower_level, + LevelT upper_level, + OffsetT num_samples, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::histogram_even( + d_temp_storage, temp_storage_bytes, + d_samples, num_samples, + d_histogram, + num_levels, lower_level, upper_level, + stream, debug_synchronous + ); + } + + template< + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t HistogramEven(void * d_temp_storage, + size_t& temp_storage_bytes, + SampleIteratorT d_samples, + CounterT * d_histogram, + int num_levels, + LevelT lower_level, + LevelT upper_level, + OffsetT num_row_samples, + OffsetT num_rows, + size_t row_stride_bytes, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::histogram_even( + d_temp_storage, temp_storage_bytes, + d_samples, num_row_samples, num_rows, row_stride_bytes, + d_histogram, + num_levels, lower_level, upper_level, + stream, debug_synchronous + ); + } + + template< + int NUM_CHANNELS, + int NUM_ACTIVE_CHANNELS, + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t MultiHistogramEven(void * d_temp_storage, + size_t& temp_storage_bytes, + SampleIteratorT d_samples, + CounterT * 
d_histogram[NUM_ACTIVE_CHANNELS], + int num_levels[NUM_ACTIVE_CHANNELS], + LevelT lower_level[NUM_ACTIVE_CHANNELS], + LevelT upper_level[NUM_ACTIVE_CHANNELS], + OffsetT num_pixels, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + unsigned int levels[NUM_ACTIVE_CHANNELS]; + for(unsigned int channel = 0; channel < NUM_ACTIVE_CHANNELS; channel++) + { + levels[channel] = num_levels[channel]; + } + return (cudaError_t)::rocprim::multi_histogram_even( + d_temp_storage, temp_storage_bytes, + d_samples, num_pixels, + d_histogram, + levels, lower_level, upper_level, + stream, debug_synchronous + ); + } + + template< + int NUM_CHANNELS, + int NUM_ACTIVE_CHANNELS, + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t MultiHistogramEven(void * d_temp_storage, + size_t& temp_storage_bytes, + SampleIteratorT d_samples, + CounterT * d_histogram[NUM_ACTIVE_CHANNELS], + int num_levels[NUM_ACTIVE_CHANNELS], + LevelT lower_level[NUM_ACTIVE_CHANNELS], + LevelT upper_level[NUM_ACTIVE_CHANNELS], + OffsetT num_row_pixels, + OffsetT num_rows, + size_t row_stride_bytes, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + unsigned int levels[NUM_ACTIVE_CHANNELS]; + for(unsigned int channel = 0; channel < NUM_ACTIVE_CHANNELS; channel++) + { + levels[channel] = num_levels[channel]; + } + return (cudaError_t)::rocprim::multi_histogram_even( + d_temp_storage, temp_storage_bytes, + d_samples, num_row_pixels, num_rows, row_stride_bytes, + d_histogram, + levels, lower_level, upper_level, + stream, debug_synchronous + ); + } + + template< + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t HistogramRange(void * d_temp_storage, + size_t& temp_storage_bytes, + SampleIteratorT d_samples, + CounterT * d_histogram, + int num_levels, + LevelT * d_levels, + OffsetT num_samples, + cudaStream_t stream = 0, + 
bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::histogram_range( + d_temp_storage, temp_storage_bytes, + d_samples, num_samples, + d_histogram, + num_levels, d_levels, + stream, debug_synchronous + ); + } + + template< + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t HistogramRange(void * d_temp_storage, + size_t& temp_storage_bytes, + SampleIteratorT d_samples, + CounterT * d_histogram, + int num_levels, + LevelT * d_levels, + OffsetT num_row_samples, + OffsetT num_rows, + size_t row_stride_bytes, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::histogram_range( + d_temp_storage, temp_storage_bytes, + d_samples, num_row_samples, num_rows, row_stride_bytes, + d_histogram, + num_levels, d_levels, + stream, debug_synchronous + ); + } + + template< + int NUM_CHANNELS, + int NUM_ACTIVE_CHANNELS, + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t MultiHistogramRange(void * d_temp_storage, + size_t& temp_storage_bytes, + SampleIteratorT d_samples, + CounterT * d_histogram[NUM_ACTIVE_CHANNELS], + int num_levels[NUM_ACTIVE_CHANNELS], + LevelT * d_levels[NUM_ACTIVE_CHANNELS], + OffsetT num_pixels, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + unsigned int levels[NUM_ACTIVE_CHANNELS]; + for(unsigned int channel = 0; channel < NUM_ACTIVE_CHANNELS; channel++) + { + levels[channel] = num_levels[channel]; + } + return (cudaError_t)::rocprim::multi_histogram_range( + d_temp_storage, temp_storage_bytes, + d_samples, num_pixels, + d_histogram, + levels, d_levels, + stream, debug_synchronous + ); + } + + template< + int NUM_CHANNELS, + int NUM_ACTIVE_CHANNELS, + typename SampleIteratorT, + typename CounterT, + typename LevelT, + typename OffsetT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t MultiHistogramRange(void 
* d_temp_storage, + size_t& temp_storage_bytes, + SampleIteratorT d_samples, + CounterT * d_histogram[NUM_ACTIVE_CHANNELS], + int num_levels[NUM_ACTIVE_CHANNELS], + LevelT * d_levels[NUM_ACTIVE_CHANNELS], + OffsetT num_row_pixels, + OffsetT num_rows, + size_t row_stride_bytes, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + unsigned int levels[NUM_ACTIVE_CHANNELS]; + for(unsigned int channel = 0; channel < NUM_ACTIVE_CHANNELS; channel++) + { + levels[channel] = num_levels[channel]; + } + return (cudaError_t)::rocprim::multi_histogram_range( + d_temp_storage, temp_storage_bytes, + d_samples, num_row_pixels, num_rows, row_stride_bytes, + d_histogram, + levels, d_levels, + stream, debug_synchronous + ); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_DEVICE_DEVICE_HISTOGRAM_HPP_ diff --git a/3rdparty/cub/device/device_merge_sort.hpp b/3rdparty/cub/device/device_merge_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e5934d58cdffc3464a6369bc1caef7dc1580a30a --- /dev/null +++ b/3rdparty/cub/device/device_merge_sort.hpp @@ -0,0 +1,176 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_MERGE_SORT_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_DEVICE_MERGE_SORT_HPP_ + +#include "../config.hpp" + +#include "../util_type.cuh" + +#include + +BEGIN_HIPCUB_NAMESPACE + +struct DeviceMergeSort +{ + template + HIPCUB_RUNTIME_FUNCTION static cudaError_t SortPairs(void * d_temp_storage, + std::size_t & temp_storage_bytes, + KeyIteratorT d_keys, + ValueIteratorT d_items, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::merge_sort(d_temp_storage, + temp_storage_bytes, + d_keys, + d_keys, + d_items, + d_items, + num_items, + compare_op, + stream, + debug_synchronous); + } + + template + HIPCUB_RUNTIME_FUNCTION static cudaError_t SortPairsCopy(void * d_temp_storage, + std::size_t & temp_storage_bytes, + KeyInputIteratorT d_input_keys, + ValueInputIteratorT d_input_items, + KeyIteratorT 
d_output_keys, + ValueIteratorT d_output_items, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::merge_sort(d_temp_storage, + temp_storage_bytes, + d_input_keys, + d_output_keys, + d_input_items, + d_output_items, + num_items, + compare_op, + stream, + debug_synchronous); + } + + template + HIPCUB_RUNTIME_FUNCTION static cudaError_t SortKeys(void * d_temp_storage, + std::size_t & temp_storage_bytes, + KeyIteratorT d_keys, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::merge_sort( + d_temp_storage, temp_storage_bytes, + d_keys, d_keys, num_items, + compare_op, stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static cudaError_t SortKeysCopy(void * d_temp_storage, + std::size_t & temp_storage_bytes, + KeyInputIteratorT d_input_keys, + KeyIteratorT d_output_keys, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0, + bool debug_synchronous = false) + + { + return (cudaError_t)::rocprim::merge_sort( + d_temp_storage, temp_storage_bytes, + d_input_keys, d_output_keys, num_items, + compare_op, stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static cudaError_t + StableSortPairs(void *d_temp_storage, + std::size_t &temp_storage_bytes, + KeyIteratorT d_keys, + ValueIteratorT d_items, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::merge_sort(d_temp_storage, + temp_storage_bytes, + d_keys, + d_keys, + d_items, + d_items, + num_items, + compare_op, + stream, + debug_synchronous); + } + + template + HIPCUB_RUNTIME_FUNCTION static cudaError_t StableSortKeys(void * d_temp_storage, + std::size_t & temp_storage_bytes, + KeyIteratorT d_keys, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream = 0, + bool 
debug_synchronous = false) + { + return (cudaError_t)::rocprim::merge_sort( + d_temp_storage, temp_storage_bytes, + d_keys, d_keys, num_items, + compare_op, stream, debug_synchronous + ); + } + +}; +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_DEVICE_DEVICE_MERGE_SORT_HPP_ diff --git a/3rdparty/cub/device/device_partition.cuh b/3rdparty/cub/device/device_partition.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f6cb63351a6cf36619b60c0103186d24f55dcd53 --- /dev/null +++ b/3rdparty/cub/device/device_partition.cuh @@ -0,0 +1,139 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_PARTITION_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_PARTITION_HPP_ + +#include "../config.hpp" + +#include + +BEGIN_HIPCUB_NAMESPACE + +struct DevicePartition +{ + template < + typename InputIteratorT, + typename FlagIterator, + typename OutputIteratorT, + typename NumSelectedIteratorT> + HIPCUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t Flagged( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + FlagIterator d_flags, ///< [in] Pointer to the input sequence of selection flags + OutputIteratorT d_out, ///< [out] Pointer to the output sequence of partitioned data items + NumSelectedIteratorT d_num_selected_out, ///< [out] Pointer to the output total number of items selected (i.e., the offset of the unselected partition) + int num_items, ///< [in] Total number of items to select from + cudaStream_t stream = 0, ///< [in] [optional] hip stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. 
May cause significant slowdown. Default is \p false. + { + return (cudaError_t)rocprim::partition( + d_temp_storage, + temp_storage_bytes, + d_in, + d_flags, + d_out, + d_num_selected_out, + num_items, + stream, + debug_synchronous); + } + + template < + typename InputIteratorT, + typename OutputIteratorT, + typename NumSelectedIteratorT, + typename SelectOp> + HIPCUB_RUNTIME_FUNCTION __forceinline__ + static cudaError_t If( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items + OutputIteratorT d_out, ///< [out] Pointer to the output sequence of partitioned data items + NumSelectedIteratorT d_num_selected_out, ///< [out] Pointer to the output total number of items selected (i.e., the offset of the unselected partition) + int num_items, ///< [in] Total number of items to select from + SelectOp select_op, ///< [in] Unary selection operator + cudaStream_t stream = 0, ///< [in] [optional] hip stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. 
+ { + return (cudaError_t)rocprim::partition( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + d_num_selected_out, + num_items, + select_op, + stream, + debug_synchronous); + } + + template + HIPCUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t + If(void *d_temp_storage, + std::size_t &temp_storage_bytes, + InputIteratorT d_in, + FirstOutputIteratorT d_first_part_out, + SecondOutputIteratorT d_second_part_out, + UnselectedOutputIteratorT d_unselected_out, + NumSelectedIteratorT d_num_selected_out, + int num_items, + SelectFirstPartOp select_first_part_op, + SelectSecondPartOp select_second_part_op, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)rocprim::partition_three_way( + d_temp_storage, + temp_storage_bytes, + d_in, + d_first_part_out, + d_second_part_out, + d_unselected_out, + d_num_selected_out, + num_items, + select_first_part_op, + select_second_part_op, + stream, + debug_synchronous + ); + } +}; + +END_HIPCUB_NAMESPACE + +#endif diff --git a/3rdparty/cub/device/device_radix_sort.cuh b/3rdparty/cub/device/device_radix_sort.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6abd09f8bfea3eded523f856e64fe834d0264a93 --- /dev/null +++ b/3rdparty/cub/device/device_radix_sort.cuh @@ -0,0 +1,224 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_RADIX_SORT_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_DEVICE_RADIX_SORT_HPP_ + +#include "../config.hpp" + +#include "../util_type.cuh" + +#include + +BEGIN_HIPCUB_NAMESPACE + +struct DeviceRadixSort +{ + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortPairs(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + const ValueT * d_values_in, + ValueT * d_values_out, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::radix_sort_pairs( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, num_items, + begin_bit, end_bit, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortPairs(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + DoubleBuffer& d_values, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + ::rocprim::double_buffer d_keys_db = detail::to_double_buffer(d_keys); + ::rocprim::double_buffer d_values_db = detail::to_double_buffer(d_values); + cudaError_t error = (cudaError_t)::rocprim::radix_sort_pairs( + d_temp_storage, temp_storage_bytes, + d_keys_db, d_values_db, num_items, + begin_bit, end_bit, + stream, debug_synchronous + ); + detail::update_double_buffer(d_keys, d_keys_db); + detail::update_double_buffer(d_values, d_values_db); + return error; + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortPairsDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + const ValueT * d_values_in, + ValueT * d_values_out, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + 
cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::radix_sort_pairs_desc( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, num_items, + begin_bit, end_bit, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortPairsDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + DoubleBuffer& d_values, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + ::rocprim::double_buffer d_keys_db = detail::to_double_buffer(d_keys); + ::rocprim::double_buffer d_values_db = detail::to_double_buffer(d_values); + cudaError_t error = (cudaError_t)::rocprim::radix_sort_pairs_desc( + d_temp_storage, temp_storage_bytes, + d_keys_db, d_values_db, num_items, + begin_bit, end_bit, + stream, debug_synchronous + ); + detail::update_double_buffer(d_keys, d_keys_db); + detail::update_double_buffer(d_values, d_values_db); + return error; + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortKeys(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::radix_sort_keys( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, num_items, + begin_bit, end_bit, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortKeys(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + ::rocprim::double_buffer d_keys_db = detail::to_double_buffer(d_keys); + cudaError_t error = (cudaError_t)::rocprim::radix_sort_keys( 
+ d_temp_storage, temp_storage_bytes, + d_keys_db, num_items, + begin_bit, end_bit, + stream, debug_synchronous + ); + detail::update_double_buffer(d_keys, d_keys_db); + return error; + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortKeysDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::radix_sort_keys_desc( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, num_items, + begin_bit, end_bit, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortKeysDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + NumItemsT num_items, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + ::rocprim::double_buffer d_keys_db = detail::to_double_buffer(d_keys); + cudaError_t error = (cudaError_t)::rocprim::radix_sort_keys_desc( + d_temp_storage, temp_storage_bytes, + d_keys_db, num_items, + begin_bit, end_bit, + stream, debug_synchronous + ); + detail::update_double_buffer(d_keys, d_keys_db); + return error; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_DEVICE_DEVICE_RADIX_SORT_HPP_ diff --git a/3rdparty/cub/device/device_reduce.cuh b/3rdparty/cub/device/device_reduce.cuh new file mode 100644 index 0000000000000000000000000000000000000000..17a85909254949d427b90e0100346e78670efbbd --- /dev/null +++ b/3rdparty/cub/device/device_reduce.cuh @@ -0,0 +1,297 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_REDUCE_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_DEVICE_REDUCE_HPP_ + +#include +#include + +#include // __half +#include // hip_bfloat16 + +#include "../config.hpp" +#include "../iterator/arg_index_input_iterator.cuh" +#include "../thread/thread_operators.cuh" + +#include +#include + +BEGIN_HIPCUB_NAMESPACE +namespace detail +{ + +template +inline +T get_lowest_value() +{ + return std::numeric_limits::lowest(); +} + +template<> +inline +__half get_lowest_value<__half>() +{ + unsigned short lowest_half = 0xfbff; + __half lowest_value = *reinterpret_cast<__half*>(&lowest_half); + return lowest_value; +} + +template<> +inline +cuda_bfloat16 get_lowest_value() +{ + return cuda_bfloat16(-3.38953138925e+38f); +} + +template +inline +T get_max_value() +{ + return std::numeric_limits::max(); +} + +template<> +inline +__half get_max_value<__half>() +{ + unsigned short max_half = 0x7bff; + __half max_value = *reinterpret_cast<__half*>(&max_half); + return max_value; +} + +template<> +inline +cuda_bfloat16 get_max_value() +{ + return cuda_bfloat16(3.38953138925e+38f); +} + +} // end detail namespace + +class DeviceReduce +{ +public: + template < + typename InputIteratorT, + typename OutputIteratorT, + typename ReduceOpT, + typename T + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t Reduce(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_items, + ReduceOpT reduction_op, + T init, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::reduce( + d_temp_storage, temp_storage_bytes, + d_in, d_out, init, num_items, + ::cub::detail::convert_result_type(reduction_op), + stream, debug_synchronous + ); + } + + template < + typename InputIteratorT, + typename OutputIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t Sum(void *d_temp_storage, + size_t 
&temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + using T = typename std::iterator_traits::value_type; + return Reduce( + d_temp_storage, temp_storage_bytes, + d_in, d_out, num_items, ::cub::Sum(), T(0), + stream, debug_synchronous + ); + } + + template < + typename InputIteratorT, + typename OutputIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t Min(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + using T = typename std::iterator_traits::value_type; + return Reduce( + d_temp_storage, temp_storage_bytes, + d_in, d_out, num_items, ::cub::Min(), detail::get_max_value(), + stream, debug_synchronous + ); + } + + template < + typename InputIteratorT, + typename OutputIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t ArgMin(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + using OffsetT = int; + using T = typename std::iterator_traits::value_type; + using O = typename std::iterator_traits::value_type; + using OutputTupleT = + typename std::conditional< + std::is_same::value, + KeyValuePair, + O + >::type; + + using OutputValueT = typename OutputTupleT::Value; + using IteratorT = ArgIndexInputIterator; + + IteratorT d_indexed_in(d_in); + OutputTupleT init(1, detail::get_max_value()); + + return Reduce( + d_temp_storage, temp_storage_bytes, + d_indexed_in, d_out, num_items, ::cub::ArgMin(), init, + stream, debug_synchronous + ); + } + + template < + typename InputIteratorT, + typename OutputIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t Max(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_items, + cudaStream_t 
stream = 0, + bool debug_synchronous = false) + { + using T = typename std::iterator_traits::value_type; + return Reduce( + d_temp_storage, temp_storage_bytes, + d_in, d_out, num_items, ::cub::Max(), detail::get_lowest_value(), + stream, debug_synchronous + ); + } + + template < + typename InputIteratorT, + typename OutputIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t ArgMax(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + using OffsetT = int; + using T = typename std::iterator_traits::value_type; + using O = typename std::iterator_traits::value_type; + using OutputTupleT = + typename std::conditional< + std::is_same::value, + KeyValuePair, + O + >::type; + + using OutputValueT = typename OutputTupleT::Value; + using IteratorT = ArgIndexInputIterator; + + IteratorT d_indexed_in(d_in); + OutputTupleT init(1, detail::get_lowest_value()); + + return Reduce( + d_temp_storage, temp_storage_bytes, + d_indexed_in, d_out, num_items, ::cub::ArgMax(), init, + stream, debug_synchronous + ); + } + + template< + typename KeysInputIteratorT, + typename UniqueOutputIteratorT, + typename ValuesInputIteratorT, + typename AggregatesOutputIteratorT, + typename NumRunsOutputIteratorT, + typename ReductionOpT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t ReduceByKey(void * d_temp_storage, + size_t& temp_storage_bytes, + KeysInputIteratorT d_keys_in, + UniqueOutputIteratorT d_unique_out, + ValuesInputIteratorT d_values_in, + AggregatesOutputIteratorT d_aggregates_out, + NumRunsOutputIteratorT d_num_runs_out, + ReductionOpT reduction_op, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + using key_compare_op = + ::rocprim::equal_to::value_type>; + return (cudaError_t)::rocprim::reduce_by_key( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, num_items, + d_unique_out, d_aggregates_out, 
d_num_runs_out, + ::cub::detail::convert_result_type(reduction_op), + key_compare_op(), + stream, debug_synchronous + ); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_DEVICE_DEVICE_REDUCE_HPP_ diff --git a/3rdparty/cub/device/device_run_length_encode.cuh b/3rdparty/cub/device/device_run_length_encode.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e4692e4242659b18219fbd5cd3008f5fabffe438 --- /dev/null +++ b/3rdparty/cub/device/device_run_length_encode.cuh @@ -0,0 +1,95 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_RUN_LENGTH_ENCODE_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_DEVICE_RUN_LENGTH_ENCODE_HPP_ + +#include "../config.hpp" + +#include + +BEGIN_HIPCUB_NAMESPACE + +class DeviceRunLengthEncode +{ +public: + template< + typename InputIteratorT, + typename UniqueOutputIteratorT, + typename LengthsOutputIteratorT, + typename NumRunsOutputIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t Encode(void * d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + UniqueOutputIteratorT d_unique_out, + LengthsOutputIteratorT d_counts_out, + NumRunsOutputIteratorT d_num_runs_out, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::run_length_encode( + d_temp_storage, temp_storage_bytes, + d_in, num_items, + d_unique_out, d_counts_out, d_num_runs_out, + stream, debug_synchronous + ); + } + + template< + typename InputIteratorT, + typename OffsetsOutputIteratorT, + typename LengthsOutputIteratorT, + typename NumRunsOutputIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t NonTrivialRuns(void * d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + OffsetsOutputIteratorT d_offsets_out, + LengthsOutputIteratorT d_lengths_out, + NumRunsOutputIteratorT d_num_runs_out, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + 
return (cudaError_t)::rocprim::run_length_encode_non_trivial_runs( + d_temp_storage, temp_storage_bytes, + d_in, num_items, + d_offsets_out, d_lengths_out, d_num_runs_out, + stream, debug_synchronous + ); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_DEVICE_DEVICE_RUN_LENGTH_ENCODE_HPP_ diff --git a/3rdparty/cub/device/device_scan.cuh b/3rdparty/cub/device/device_scan.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d948b56805c3fcfc72ccf24124affbbb41aa3c6f --- /dev/null +++ b/3rdparty/cub/device/device_scan.cuh @@ -0,0 +1,272 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_SCAN_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_DEVICE_SCAN_HPP_ + +#include +#include "../config.hpp" + +#include "../thread/thread_operators.cuh" + +#include +#include + +BEGIN_HIPCUB_NAMESPACE + +class DeviceScan +{ +public: + template < + typename InputIteratorT, + typename OutputIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t InclusiveSum(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + size_t num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return InclusiveScan( + d_temp_storage, temp_storage_bytes, + d_in, d_out, ::cub::Sum(), num_items, + stream, debug_synchronous + ); + } + + template < + typename InputIteratorT, + typename OutputIteratorT, + typename ScanOpT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t InclusiveScan(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + size_t num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::inclusive_scan( + d_temp_storage, temp_storage_bytes, + d_in, d_out, num_items, + scan_op, + stream, debug_synchronous + ); + } + + template < + typename InputIteratorT, + typename OutputIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t ExclusiveSum(void *d_temp_storage, + 
size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + size_t num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + using T = typename std::iterator_traits::value_type; + return ExclusiveScan( + d_temp_storage, temp_storage_bytes, + d_in, d_out, ::cub::Sum(), T(0), num_items, + stream, debug_synchronous + ); + } + + template < + typename InputIteratorT, + typename OutputIteratorT, + typename ScanOpT, + typename InitValueT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t ExclusiveScan(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + InitValueT init_value, + size_t num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::exclusive_scan( + d_temp_storage, temp_storage_bytes, + d_in, d_out, init_value, num_items, + scan_op, + stream, debug_synchronous + ); + } + + template < + typename InputIteratorT, + typename OutputIteratorT, + typename ScanOpT, + typename InitValueT, + typename InitValueIterT = InitValueT* + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t ExclusiveScan(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + FutureValue init_value, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::exclusive_scan( + d_temp_storage, temp_storage_bytes, + d_in, d_out, init_value, num_items, + scan_op, + stream, debug_synchronous + ); + } + + template < + typename KeysInputIteratorT, + typename ValuesInputIteratorT, + typename ValuesOutputIteratorT, + typename EqualityOpT = ::cub::Equality + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t ExclusiveSumByKey(void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + int num_items, + EqualityOpT equality_op = EqualityOpT(), 
+ cudaStream_t stream = 0, + bool debug_synchronous = false) + { + using in_value_type = typename std::iterator_traits::value_type; + + return (cudaError_t)::rocprim::exclusive_scan_by_key( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, d_values_out, + static_cast(0), static_cast(num_items), + ::cub::Sum(), equality_op, stream, debug_synchronous + ); + } + + template < + typename KeysInputIteratorT, + typename ValuesInputIteratorT, + typename ValuesOutputIteratorT, + typename ScanOpT, + typename InitValueT, + typename EqualityOpT = ::cub::Equality + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t ExclusiveScanByKey(void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + ScanOpT scan_op, + InitValueT init_value, + int num_items, + EqualityOpT equality_op = EqualityOpT(), + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::exclusive_scan_by_key( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, d_values_out, + init_value, static_cast(num_items), + scan_op, equality_op, stream, debug_synchronous + ); + } + + template < + typename KeysInputIteratorT, + typename ValuesInputIteratorT, + typename ValuesOutputIteratorT, + typename EqualityOpT = ::cub::Equality + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t InclusiveSumByKey(void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + int num_items, + EqualityOpT equality_op = EqualityOpT(), + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::inclusive_scan_by_key( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, d_values_out, + static_cast(num_items), ::cub::Sum(), + equality_op, stream, debug_synchronous + ); + } + + template < + typename KeysInputIteratorT, + typename 
ValuesInputIteratorT, + typename ValuesOutputIteratorT, + typename ScanOpT, + typename EqualityOpT = ::cub::Equality + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t InclusiveScanByKey(void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + ScanOpT scan_op, + int num_items, + EqualityOpT equality_op = EqualityOpT(), + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::inclusive_scan_by_key( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_values_in, d_values_out, + static_cast(num_items), scan_op, + equality_op, stream, debug_synchronous + ); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_DEVICE_DEVICE_SCAN_HPP_ diff --git a/3rdparty/cub/device/device_segmented_radix_sort.cuh b/3rdparty/cub/device/device_segmented_radix_sort.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9fd67580bcbf91636186f7bca93753c32eddeaef --- /dev/null +++ b/3rdparty/cub/device/device_segmented_radix_sort.cuh @@ -0,0 +1,256 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_HPP_ + +#include "../config.hpp" + +#include "../util_type.cuh" + +#include + +BEGIN_HIPCUB_NAMESPACE + +struct DeviceSegmentedRadixSort +{ + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortPairs(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + const ValueT * d_values_in, + ValueT * d_values_out, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::segmented_radix_sort_pairs( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, num_items, + num_segments, d_begin_offsets, d_end_offsets, + 
begin_bit, end_bit, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortPairs(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + DoubleBuffer& d_values, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + ::rocprim::double_buffer d_keys_db = detail::to_double_buffer(d_keys); + ::rocprim::double_buffer d_values_db = detail::to_double_buffer(d_values); + cudaError_t error = (cudaError_t)::rocprim::segmented_radix_sort_pairs( + d_temp_storage, temp_storage_bytes, + d_keys_db, d_values_db, num_items, + num_segments, d_begin_offsets, d_end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); + detail::update_double_buffer(d_keys, d_keys_db); + detail::update_double_buffer(d_values, d_values_db); + return error; + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortPairsDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + const ValueT * d_values_in, + ValueT * d_values_out, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::segmented_radix_sort_pairs_desc( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, num_items, + num_segments, d_begin_offsets, d_end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortPairsDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + DoubleBuffer& d_values, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT 
d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + ::rocprim::double_buffer d_keys_db = detail::to_double_buffer(d_keys); + ::rocprim::double_buffer d_values_db = detail::to_double_buffer(d_values); + cudaError_t error = (cudaError_t)::rocprim::segmented_radix_sort_pairs_desc( + d_temp_storage, temp_storage_bytes, + d_keys_db, d_values_db, num_items, + num_segments, d_begin_offsets, d_end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); + detail::update_double_buffer(d_keys, d_keys_db); + detail::update_double_buffer(d_values, d_values_db); + return error; + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortKeys(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::segmented_radix_sort_keys( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, num_items, + num_segments, d_begin_offsets, d_end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortKeys(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + ::rocprim::double_buffer d_keys_db = detail::to_double_buffer(d_keys); + cudaError_t error = (cudaError_t)::rocprim::segmented_radix_sort_keys( + d_temp_storage, temp_storage_bytes, + d_keys_db, num_items, + num_segments, d_begin_offsets, d_end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); + 
detail::update_double_buffer(d_keys, d_keys_db); + return error; + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortKeysDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::segmented_radix_sort_keys_desc( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, num_items, + num_segments, d_begin_offsets, d_end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortKeysDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + int begin_bit = 0, + int end_bit = sizeof(KeyT) * 8, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + ::rocprim::double_buffer d_keys_db = detail::to_double_buffer(d_keys); + cudaError_t error = (cudaError_t)::rocprim::segmented_radix_sort_keys_desc( + d_temp_storage, temp_storage_bytes, + d_keys_db, num_items, + num_segments, d_begin_offsets, d_end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); + detail::update_double_buffer(d_keys, d_keys_db); + return error; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_HPP_ diff --git a/3rdparty/cub/device/device_segmented_reduce.cuh b/3rdparty/cub/device/device_segmented_reduce.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cc16847b1df32339056657ef9126226a9c341d61 --- /dev/null +++ b/3rdparty/cub/device/device_segmented_reduce.cuh @@ -0,0 +1,241 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, 
Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_SEGMENTED_REDUCE_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_DEVICE_SEGMENTED_REDUCE_HPP_ + +#include +#include + +#include "../config.hpp" + +#include "../thread/thread_operators.cuh" +#include "../iterator/arg_index_input_iterator.cuh" + +#include + +BEGIN_HIPCUB_NAMESPACE + +struct DeviceSegmentedReduce +{ + template< + typename InputIteratorT, + typename OutputIteratorT, + typename OffsetIteratorT, + typename ReductionOp, + typename T + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t Reduce(void * d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + ReductionOp reduction_op, + T initial_value, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::segmented_reduce( + d_temp_storage, temp_storage_bytes, + d_in, d_out, + num_segments, d_begin_offsets, d_end_offsets, + ::cub::detail::convert_result_type(reduction_op), + initial_value, + stream, debug_synchronous + ); + } + + template< + typename InputIteratorT, + typename OutputIteratorT, + typename OffsetIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t Sum(void * d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + using input_type = typename std::iterator_traits::value_type; + + return Reduce( + d_temp_storage, temp_storage_bytes, + d_in, d_out, + num_segments, d_begin_offsets, d_end_offsets, + ::cub::Sum(), input_type(), + stream, debug_synchronous + ); + } + + template< + typename InputIteratorT, + typename OutputIteratorT, + typename OffsetIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t Min(void * 
d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + using input_type = typename std::iterator_traits::value_type; + + return Reduce( + d_temp_storage, temp_storage_bytes, + d_in, d_out, + num_segments, d_begin_offsets, d_end_offsets, + ::cub::Min(), std::numeric_limits::max(), + stream, debug_synchronous + ); + } + + template< + typename InputIteratorT, + typename OutputIteratorT, + typename OffsetIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t ArgMin(void * d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + using OffsetT = int; + using T = typename std::iterator_traits::value_type; + using O = typename std::iterator_traits::value_type; + using OutputTupleT = typename std::conditional< + std::is_same::value, + KeyValuePair, + O + >::type; + + using OutputValueT = typename OutputTupleT::Value; + using IteratorT = ArgIndexInputIterator; + + IteratorT d_indexed_in(d_in); + const OutputTupleT init(1, std::numeric_limits::max()); + + return Reduce( + d_temp_storage, temp_storage_bytes, + d_indexed_in, d_out, + num_segments, d_begin_offsets, d_end_offsets, + ::cub::ArgMin(), init, + stream, debug_synchronous + ); + } + + template< + typename InputIteratorT, + typename OutputIteratorT, + typename OffsetIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t Max(void * d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + using input_type = typename 
std::iterator_traits::value_type; + + return Reduce( + d_temp_storage, temp_storage_bytes, + d_in, d_out, + num_segments, d_begin_offsets, d_end_offsets, + ::cub::Max(), std::numeric_limits::lowest(), + stream, debug_synchronous + ); + } + + template< + typename InputIteratorT, + typename OutputIteratorT, + typename OffsetIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t ArgMax(void * d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + using OffsetT = int; + using T = typename std::iterator_traits::value_type; + using O = typename std::iterator_traits::value_type; + using OutputTupleT = typename std::conditional< + std::is_same::value, + KeyValuePair, + O + >::type; + + using OutputValueT = typename OutputTupleT::Value; + using IteratorT = ArgIndexInputIterator; + + IteratorT d_indexed_in(d_in); + const OutputTupleT init(1, std::numeric_limits::lowest()); + + return Reduce( + d_temp_storage, temp_storage_bytes, + d_indexed_in, d_out, + num_segments, d_begin_offsets, d_end_offsets, + ::cub::ArgMax(), init, + stream, debug_synchronous + ); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_DEVICE_DEVICE_SEGMENTED_REDUCE_HPP_ diff --git a/3rdparty/cub/device/device_segmented_sort.hpp b/3rdparty/cub/device/device_segmented_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c7770fef2bee4eee766518dd75ab793da3d24bbf --- /dev/null +++ b/3rdparty/cub/device/device_segmented_sort.hpp @@ -0,0 +1,410 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_SEGMENTED_SORT_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_DEVICE_SEGMENTED_SORT_HPP_ + +#include "../config.hpp" + +#include "../util_type.cuh" + +#include + +BEGIN_HIPCUB_NAMESPACE + +struct DeviceSegmentedSort +{ + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortPairs(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + const ValueT * d_values_in, + ValueT * d_values_out, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::segmented_radix_sort_pairs( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, num_items, + num_segments, d_begin_offsets, d_end_offsets, + 0, sizeof(KeyT) * 8, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortPairs(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + DoubleBuffer& d_values, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + ::rocprim::double_buffer d_keys_db = detail::to_double_buffer(d_keys); + ::rocprim::double_buffer d_values_db = detail::to_double_buffer(d_values); + cudaError_t error = (cudaError_t)::rocprim::segmented_radix_sort_pairs( + d_temp_storage, temp_storage_bytes, + d_keys_db, d_values_db, num_items, + num_segments, d_begin_offsets, d_end_offsets, + 0, sizeof(KeyT) * 8, + stream, debug_synchronous + ); + detail::update_double_buffer(d_keys, d_keys_db); + detail::update_double_buffer(d_values, d_values_db); + return error; + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortPairsDescending(void * d_temp_storage, + size_t& 
temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + const ValueT * d_values_in, + ValueT * d_values_out, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::segmented_radix_sort_pairs_desc( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, num_items, + num_segments, d_begin_offsets, d_end_offsets, + 0, sizeof(KeyT) * 8, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortPairsDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + DoubleBuffer& d_values, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + ::rocprim::double_buffer d_keys_db = detail::to_double_buffer(d_keys); + ::rocprim::double_buffer d_values_db = detail::to_double_buffer(d_values); + cudaError_t error = (cudaError_t)::rocprim::segmented_radix_sort_pairs_desc( + d_temp_storage, temp_storage_bytes, + d_keys_db, d_values_db, num_items, + num_segments, d_begin_offsets, d_end_offsets, + 0, sizeof(KeyT) * 8, + stream, debug_synchronous + ); + detail::update_double_buffer(d_keys, d_keys_db); + detail::update_double_buffer(d_values, d_values_db); + return error; + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortKeys(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::segmented_radix_sort_keys( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, num_items, + num_segments, d_begin_offsets, d_end_offsets, + 0, sizeof(KeyT) * 8, + 
stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortKeys(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + ::rocprim::double_buffer d_keys_db = detail::to_double_buffer(d_keys); + cudaError_t error = (cudaError_t)::rocprim::segmented_radix_sort_keys( + d_temp_storage, temp_storage_bytes, + d_keys_db, num_items, + num_segments, d_begin_offsets, d_end_offsets, + 0, sizeof(KeyT) * 8, + stream, debug_synchronous + ); + detail::update_double_buffer(d_keys, d_keys_db); + return error; + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortKeysDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::segmented_radix_sort_keys_desc( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, num_items, + num_segments, d_begin_offsets, d_end_offsets, + 0, sizeof(KeyT) * 8, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t SortKeysDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + ::rocprim::double_buffer d_keys_db = detail::to_double_buffer(d_keys); + cudaError_t error = (cudaError_t)::rocprim::segmented_radix_sort_keys_desc( + d_temp_storage, temp_storage_bytes, + d_keys_db, num_items, + num_segments, d_begin_offsets, d_end_offsets, + 0, sizeof(KeyT) * 8, + stream, debug_synchronous + ); + 
detail::update_double_buffer(d_keys, d_keys_db); + return error; + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t StableSortPairs(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + const ValueT * d_values_in, + ValueT * d_values_out, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return SortPairs( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, num_items, + num_segments, d_begin_offsets, d_end_offsets, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t StableSortPairs(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + DoubleBuffer& d_values, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return SortPairs( + d_temp_storage, temp_storage_bytes, + d_keys, d_values, num_items, + num_segments, d_begin_offsets, d_end_offsets, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t StableSortPairsDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + const ValueT * d_values_in, + ValueT * d_values_out, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return SortPairsDescending( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, d_values_in, d_values_out, num_items, + num_segments, d_begin_offsets, d_end_offsets, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t StableSortPairsDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + 
DoubleBuffer& d_values, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return SortPairsDescending( + d_temp_storage, temp_storage_bytes, + d_keys, d_values, num_items, + num_segments, d_begin_offsets, d_end_offsets, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t StableSortKeys(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return SortKeys( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, num_items, + num_segments, d_begin_offsets, d_end_offsets, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t StableSortKeys(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return SortKeys( + d_temp_storage, temp_storage_bytes, + d_keys, num_items, + num_segments, d_begin_offsets, d_end_offsets, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t StableSortKeysDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + const KeyT * d_keys_in, + KeyT * d_keys_out, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return SortKeysDescending( + d_temp_storage, temp_storage_bytes, + d_keys_in, d_keys_out, num_items, + num_segments, d_begin_offsets, d_end_offsets, + stream, debug_synchronous + ); + } + + template + HIPCUB_RUNTIME_FUNCTION static + cudaError_t 
StableSortKeysDescending(void * d_temp_storage, + size_t& temp_storage_bytes, + DoubleBuffer& d_keys, + int num_items, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return SortKeysDescending( + d_temp_storage, temp_storage_bytes, + d_keys, num_items, + num_segments, d_begin_offsets, d_end_offsets, + stream, debug_synchronous + ); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_DEVICE_DEVICE_SEGMENTED_SORT_HPP_ diff --git a/3rdparty/cub/device/device_select.cuh b/3rdparty/cub/device/device_select.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0ba9e3c55313e8eff03e66964d38509b24016277 --- /dev/null +++ b/3rdparty/cub/device/device_select.cuh @@ -0,0 +1,145 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_SELECT_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_DEVICE_SELECT_HPP_ + +#include "../config.hpp" + +#include "../thread/thread_operators.cuh" + +#include + +BEGIN_HIPCUB_NAMESPACE + +class DeviceSelect +{ +public: + template < + typename InputIteratorT, + typename FlagIterator, + typename OutputIteratorT, + typename NumSelectedIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t Flagged(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + FlagIterator d_flags, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::select( + d_temp_storage, temp_storage_bytes, + d_in, d_flags, d_out, d_num_selected_out, num_items, + stream, debug_synchronous + ); + } + + template < + typename InputIteratorT, + typename OutputIteratorT, + typename NumSelectedIteratorT, + typename SelectOp + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t If(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + 
NumSelectedIteratorT d_num_selected_out, + int num_items, + SelectOp select_op, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::select( + d_temp_storage, temp_storage_bytes, + d_in, d_out, d_num_selected_out, num_items, select_op, + stream, debug_synchronous + ); + } + + template < + typename InputIteratorT, + typename OutputIteratorT, + typename NumSelectedIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t Unique(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + NumSelectedIteratorT d_num_selected_out, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::unique( + d_temp_storage, temp_storage_bytes, + d_in, d_out, d_num_selected_out, num_items, cub::Equality(), + stream, debug_synchronous + ); + } + + template < + typename KeyIteratorT, + typename ValueIteratorT, + typename OutputKeyIteratorT, + typename OutputValueIteratorT, + typename NumSelectedIteratorT + > + HIPCUB_RUNTIME_FUNCTION static + cudaError_t UniqueByKey(void *d_temp_storage, + size_t &temp_storage_bytes, + KeyIteratorT d_keys_input, + ValueIteratorT d_values_input, + OutputKeyIteratorT d_keys_output, + OutputValueIteratorT d_values_output, + NumSelectedIteratorT d_num_selected_out, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) + { + return (cudaError_t)::rocprim::unique_by_key( + d_temp_storage, temp_storage_bytes, + d_keys_input, d_values_input, + d_keys_output, d_values_output, + d_num_selected_out, num_items, cub::Equality(), + stream, debug_synchronous + ); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_DEVICE_DEVICE_SELECT_HPP_ diff --git a/3rdparty/cub/device/device_spmv.cuh b/3rdparty/cub/device/device_spmv.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b5be1aef3b0571729bd9e2b157fe6a0960234505 --- /dev/null +++ b/3rdparty/cub/device/device_spmv.cuh @@ 
-0,0 +1,153 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_SPMV_HPP_ +#define HIPCUB_ROCPRIM_DEVICE_DEVICE_SPMV_HPP_ + +#include "../config.hpp" + +#include "../iterator/tex_ref_input_iterator.cuh" + +BEGIN_HIPCUB_NAMESPACE + +class DeviceSpmv +{ + +public: + +template < + typename ValueT, ///< Matrix and vector value type + typename OffsetT> ///< Signed integer type for sequence offsets +struct SpmvParams +{ + ValueT* d_values; ///< Pointer to the array of \p num_nonzeros values of the corresponding nonzero elements of matrix A. + OffsetT* d_row_end_offsets; ///< Pointer to the array of \p m offsets demarcating the end of every row in \p d_column_indices and \p d_values + OffsetT* d_column_indices; ///< Pointer to the array of \p num_nonzeros column-indices of the corresponding nonzero elements of matrix A. (Indices are zero-valued.) + ValueT* d_vector_x; ///< Pointer to the array of \p num_cols values corresponding to the dense input vector x + ValueT* d_vector_y; ///< Pointer to the array of \p num_rows values corresponding to the dense output vector y + int num_rows; ///< Number of rows of matrix A. + int num_cols; ///< Number of columns of matrix A. + int num_nonzeros; ///< Number of nonzero elements of matrix A. + ValueT alpha; ///< Alpha multiplicand + ValueT beta; ///< Beta addend-multiplicand + + ::cub::TexRefInputIterator t_vector_x; +}; + +static constexpr uint32_t CsrMVKernel_MaxThreads = 256; + +template +static __global__ void +CsrMVKernel(SpmvParams spmv_params) +{ + __shared__ ValueT partial; + + const int32_t row_id = hipBlockIdx_x; + + if(threadIdx.x == 0) + { + partial = spmv_params.beta * spmv_params.d_vector_y[row_id]; + } + __syncthreads(); + + int32_t row_offset = (row_id == 0) ? 
(0) : (spmv_params.d_row_end_offsets[row_id - 1]); + for(uint32_t thread_offset = 0; thread_offset < spmv_params.num_cols / blockDim.x; thread_offset++) + { + int32_t offset = row_offset + thread_offset * blockDim.x + threadIdx.x; + + if(offset < spmv_params.d_row_end_offsets[row_id]) + { + ValueT t_value = + spmv_params.alpha * + spmv_params.d_values[offset] * + spmv_params.d_vector_x[spmv_params.d_column_indices[offset]]; + + atomicAdd(&partial, t_value); + + __syncthreads(); + + iif(threadIdx.x == 0) + { + spmv_params.d_vector_y[row_id] = partial; + } + } + } +} + +template + HIPCUB_RUNTIME_FUNCTION + static cudaError_t CsrMV( + void* d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + ValueT* d_values, ///< [in] Pointer to the array of \p num_nonzeros values of the corresponding nonzero elements of matrix A. + int* d_row_offsets, ///< [in] Pointer to the array of \p m + 1 offsets demarcating the start of every row in \p d_column_indices and \p d_values (with the final entry being equal to \p num_nonzeros) + int* d_column_indices, ///< [in] Pointer to the array of \p num_nonzeros column-indices of the corresponding nonzero elements of matrix A. (Indices are zero-valued.) + ValueT* d_vector_x, ///< [in] Pointer to the array of \p num_cols values corresponding to the dense input vector x + ValueT* d_vector_y, ///< [out] Pointer to the array of \p num_rows values corresponding to the dense output vector y + int num_rows, ///< [in] number of rows of matrix A. + int num_cols, ///< [in] number of columns of matrix A. + int num_nonzeros, ///< [in] number of nonzero elements of matrix A. + cudaStream_t stream = 0, ///< [in] [optional] hip stream to launch kernels within. Default is stream0. 
+ bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. May cause significant slowdown. Default is \p false. + { + SpmvParams spmv_params; + spmv_params.d_values = d_values; + spmv_params.d_row_end_offsets = d_row_offsets + 1; + spmv_params.d_column_indices = d_column_indices; + spmv_params.d_vector_x = d_vector_x; + spmv_params.d_vector_y = d_vector_y; + spmv_params.num_rows = num_rows; + spmv_params.num_cols = num_cols; + spmv_params.num_nonzeros = num_nonzeros; + spmv_params.alpha = 1.0; + spmv_params.beta = 0.0; + + cudaError_t status; + if(d_temp_storage == nullptr) + { + // Make sure user won't try to allocate 0 bytes memory, because + // hipMalloc will return nullptr when size is zero. + temp_storage_bytes = 4; + return cudaError_t(0); + } + else + { + size_t block_size = min(num_cols, DeviceSpmv::CsrMVKernel_MaxThreads); + size_t grid_size = num_rows; + CsrMVKernel<<>>(spmv_params); + status = hipGetLastError(); + } + return status; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_CUB_DEVICE_DEVICE_SELECT_HPP_ + diff --git a/3rdparty/cub/grid/grid_barrier.cuh b/3rdparty/cub/grid/grid_barrier.cuh new file mode 100644 index 0000000000000000000000000000000000000000..79b031fc360953caa98f97f19ef744b3c84f4e4d --- /dev/null +++ b/3rdparty/cub/grid/grid_barrier.cuh @@ -0,0 +1,202 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 *
 ******************************************************************************/

#ifndef HIPCUB_ROCPRIM_GRID_GRID_BARRIER_HPP_
#define HIPCUB_ROCPRIM_GRID_GRID_BARRIER_HPP_

// NOTE(review): the include target was lost in patch mangling (the line is a
// bare "#include"); presumably <hip/hip_runtime.h> — confirm against upstream.
#include

#include "../config.hpp"

#include "../thread/thread_load.cuh"

BEGIN_HIPCUB_NAMESPACE

/**
 * \addtogroup GridModule
 * @{
 */


/**
 * \brief GridBarrier implements a software global barrier among thread blocks within a cuda grid
 */
class GridBarrier
{
protected :

    typedef unsigned int SyncFlag;

    // Counters in global device memory
    SyncFlag* d_sync;

public:

    /**
     * Constructor
     */
    GridBarrier() : d_sync(NULL) {}

    /**
     * \brief Synchronize: block until every thread block in the grid has
     * reached this barrier.
     *
     * Block 0 coordinates: it raises its own flag in d_sync, waits until every
     * block's flag is set, then clears all flags to release the waiters.
     * Every other block raises its flag and spins until block 0 clears it.
     * (The previous comment here documented a typedef and was wrong.)
     *
     * NOTE(review): d_sync must have been zero-initialized for gridDim.x
     * entries (see GridBarrierLifetime::Setup) — a correct launch depends on
     * all blocks being co-resident, which the caller must guarantee.
     */
    __device__ __forceinline__ void Sync() const
    {
        volatile SyncFlag *d_vol_sync = d_sync;

        // Threadfence and syncthreads to make sure global writes are visible before
        // thread-0 reports in with its sync counter
        __threadfence();
        __syncthreads();

        if (blockIdx.x == 0)
        {
            // Report in ourselves
            if (threadIdx.x == 0)
            {
                d_vol_sync[blockIdx.x] = 1;
            }

            __syncthreads();

            // Wait for everyone else to report in
            for (uint32_t peer_block = threadIdx.x; peer_block < gridDim.x; peer_block += blockDim.x)
            {
                while (ThreadLoad(d_sync + peer_block) == 0)
                {
                    __threadfence_block();
                }
            }

            __syncthreads();

            // Let everyone know it's safe to proceed
            for (uint32_t peer_block = threadIdx.x; peer_block < gridDim.x; peer_block += blockDim.x)
            {
                d_vol_sync[peer_block] = 0;
            }
        }
        else
        {
            if (threadIdx.x == 0)
            {
                // Report in
                d_vol_sync[blockIdx.x] = 1;

                // Wait for acknowledgment
                while (ThreadLoad(d_sync + blockIdx.x) == 1)
                {
                    __threadfence_block();
                }
            }

            __syncthreads();
        }
    }
};


/**
 * \brief GridBarrierLifetime extends GridBarrier to provide lifetime management of the temporary device storage needed for cooperation.
+ * + * Uses RAII for lifetime, i.e., device resources are reclaimed when + * the destructor is called. + */ +class GridBarrierLifetime : public GridBarrier +{ +protected: + + // Number of bytes backed by d_sync + size_t sync_bytes; + +public: + + /** + * Constructor + */ + GridBarrierLifetime() : GridBarrier(), sync_bytes(0) {} + + + /** + * DeviceFrees and resets the progress counters + */ + cudaError_t HostReset() + { + cudaError_t retval = cudaSuccess; + if (d_sync) + { + retval = cudaFree(d_sync); + d_sync = NULL; + } + sync_bytes = 0; + return retval; + } + + + /** + * Destructor + */ + virtual ~GridBarrierLifetime() + { + HostReset(); + } + + + /** + * Sets up the progress counters for the next kernel launch (lazily + * allocating and initializing them if necessary) + */ + cudaError_t Setup(int sweep_grid_size) + { + cudaError_t retval = cudaSuccess; + do { + size_t new_sync_bytes = sweep_grid_size * sizeof(SyncFlag); + if (new_sync_bytes > sync_bytes) + { + if (d_sync) + { + if ((retval = cudaFree(d_sync))) break; + } + + sync_bytes = new_sync_bytes; + + // Allocate and initialize to zero + if ((retval = cudaMalloc((void**) &d_sync, sync_bytes))) break; + if ((retval = cudaMemset(d_sync, 0, new_sync_bytes))) break; + } + } while (0); + + return retval; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_GRID_GRID_BARRIER_HPP_ diff --git a/3rdparty/cub/grid/grid_even_share.cuh b/3rdparty/cub/grid/grid_even_share.cuh new file mode 100644 index 0000000000000000000000000000000000000000..76b9d9f14d91080a90ee014a4dae9059a719961b --- /dev/null +++ b/3rdparty/cub/grid/grid_even_share.cuh @@ -0,0 +1,214 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_GRID_GRID_EVEN_SHARE_HPP_ +#define HIPCUB_ROCPRIM_GRID_GRID_EVEN_SHARE_HPP_ + +#include + +#include "../config.hpp" +#include "grid_mapping.cuh" +#include "../util_type.cuh" + +BEGIN_HIPCUB_NAMESPACE + +/** + * \addtogroup GridModule + * @{ + */ + + +/** + * \brief GridEvenShare is a descriptor utility for distributing input among + * CUDA thread blocks in an "even-share" fashion. 
Each thread block gets roughly + * the same number of input tiles. + * + * \par Overview + * Each thread block is assigned a consecutive sequence of input tiles. To help + * preserve alignment and eliminate the overhead of guarded loads for all but the + * last thread block, to GridEvenShare assigns one of three different amounts of + * work to a given thread block: "big", "normal", or "last". The "big" workloads + * are one scheduling grain larger than "normal". The "last" work unit for the + * last thread block may be partially-full if the input is not an even multiple of + * the scheduling grain size. + * + * \par + * Before invoking a child grid, a parent thread will typically construct an + * instance of GridEvenShare. The instance can be passed to child thread blocks + * which can initialize their per-thread block offsets using \p BlockInit(). + */ +template +struct GridEvenShare +{ +private: + + int total_tiles; + int big_shares; + OffsetT big_share_items; + OffsetT normal_share_items; + OffsetT normal_base_offset; + +public: + + /// Total number of input items + OffsetT num_items; + + /// Grid size in thread blocks + int grid_size; + + /// OffsetT into input marking the beginning of the owning thread block's segment of input tiles + OffsetT block_offset; + + /// OffsetT into input of marking the end (one-past) of the owning thread block's segment of input tiles + OffsetT block_end; + + /// Stride between input tiles + OffsetT block_stride; + + + /** + * \brief Constructor. + */ + __host__ __device__ __forceinline__ GridEvenShare() : + total_tiles(0), + big_shares(0), + big_share_items(0), + normal_share_items(0), + normal_base_offset(0), + num_items(0), + grid_size(0), + block_offset(0), + block_end(0), + block_stride(0) + {} + + + /** + * \brief Dispatch initializer. To be called prior to kernel launch. 
+ */ + __host__ __device__ __forceinline__ void DispatchInit( + OffsetT num_items_, ///< Total number of input items + int max_grid_size, ///< Maximum grid size allowable (actual grid size may be less if not warranted by the the number of input items) + int tile_items) ///< Number of data items per input tile + { + this->block_offset = num_items_; // Initialize past-the-end + this->block_end = num_items_; // Initialize past-the-end + this->num_items = num_items_; + this->total_tiles = static_cast(cub::DivideAndRoundUp(num_items_, tile_items)); + this->grid_size = min(total_tiles, max_grid_size); + int avg_tiles_per_block = total_tiles / grid_size; + // leftover grains go to big blocks: + this->big_shares = total_tiles - (avg_tiles_per_block * grid_size); + this->normal_share_items = avg_tiles_per_block * tile_items; + this->normal_base_offset = big_shares * tile_items; + this->big_share_items = normal_share_items + tile_items; + } + + + /** + * \brief Initializes ranges for the specified thread block index. Specialized + * for a "raking" access pattern in which each thread block is assigned a + * consecutive sequence of input tiles. + */ + template + __device__ __forceinline__ void BlockInit( + int block_id, + Int2Type /*strategy_tag*/) + { + block_stride = TILE_ITEMS; + if (block_id < big_shares) + { + // This thread block gets a big share of grains (avg_tiles_per_block + 1) + block_offset = (block_id * big_share_items); + block_end = block_offset + big_share_items; + } + else if (block_id < total_tiles) + { + // This thread block gets a normal share of grains (avg_tiles_per_block) + block_offset = normal_base_offset + (block_id * normal_share_items); + block_end = min(num_items, block_offset + normal_share_items); + } + // Else default past-the-end + } + + + /** + * \brief Block-initialization, specialized for a "raking" access + * pattern in which each thread block is assigned a consecutive sequence + * of input tiles. 
+ */ + template + __device__ __forceinline__ void BlockInit( + int block_id, + Int2Type /*strategy_tag*/) + { + block_stride = grid_size * TILE_ITEMS; + block_offset = (block_id * TILE_ITEMS); + block_end = num_items; + } + + + /** + * \brief Block-initialization, specialized for "strip mining" access + * pattern in which the input tiles assigned to each thread block are + * separated by a stride equal to the the extent of the grid. + */ + template < + int TILE_ITEMS, + GridMappingStrategy STRATEGY> + __device__ __forceinline__ void BlockInit() + { + BlockInit(blockIdx.x, Int2Type()); + } + + + /** + * \brief Block-initialization, specialized for a "raking" access + * pattern in which each thread block is assigned a consecutive sequence + * of input tiles. + */ + template + __device__ __forceinline__ void BlockInit( + OffsetT block_offset, ///< [in] Threadblock begin offset (inclusive) + OffsetT block_end) ///< [in] Threadblock end offset (exclusive) + { + this->block_offset = block_offset; + this->block_end = block_end; + this->block_stride = TILE_ITEMS; + } + + +}; + + +/** @} */ // end group GridModule + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_GRID_GRID_EVEN_SHARE_HPP_ diff --git a/3rdparty/cub/grid/grid_mapping.cuh b/3rdparty/cub/grid/grid_mapping.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0c3409f5e08036212b699f4853f48a3d8b45f718 --- /dev/null +++ b/3rdparty/cub/grid/grid_mapping.cuh @@ -0,0 +1,108 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_GRID_GRID_MAPPING_HPP_ +#define HIPCUB_ROCPRIM_GRID_GRID_MAPPING_HPP_ + +#include + +#include "../config.hpp" +#include "../thread/thread_load.cuh" + +BEGIN_HIPCUB_NAMESPACE + + +/** + * \addtogroup GridModule + * @{ + */ + + +/****************************************************************************** + * Mapping policies + *****************************************************************************/ + + +/** + * \brief cub::GridMappingStrategy enumerates alternative strategies for mapping constant-sized tiles of device-wide data onto a grid of CUDA thread blocks. 
+ */ +enum GridMappingStrategy +{ + /** + * \brief An a "raking" access pattern in which each thread block is + * assigned a consecutive sequence of input tiles + * + * \par Overview + * The input is evenly partitioned into \p p segments, where \p p is + * constant and corresponds loosely to the number of thread blocks that may + * actively reside on the target device. Each segment is comprised of + * consecutive tiles, where a tile is a small, constant-sized unit of input + * to be processed to completion before the thread block terminates or + * obtains more work. The kernel invokes \p p thread blocks, each + * of which iteratively consumes a segment of n/p elements + * in tile-size increments. + */ + GRID_MAPPING_RAKE, + + /** + * \brief An a "strip mining" access pattern in which the input tiles assigned + * to each thread block are separated by a stride equal to the the extent of + * the grid. + * + * \par Overview + * The input is evenly partitioned into \p p sets, where \p p is + * constant and corresponds loosely to the number of thread blocks that may + * actively reside on the target device. Each set is comprised of + * data tiles separated by stride \p tiles, where a tile is a small, + * constant-sized unit of input to be processed to completion before the + * thread block terminates or obtains more work. The kernel invokes \p p + * thread blocks, each of which iteratively consumes a segment of + * n/p elements in tile-size increments. + */ + GRID_MAPPING_STRIP_MINE, + + /** + * \brief A dynamic "queue-based" strategy for assigning input tiles to thread blocks. + * + * \par Overview + * The input is treated as a queue to be dynamically consumed by a grid of + * thread blocks. Work is atomically dequeued in tiles, where a tile is a + * unit of input to be processed to completion before the thread block + * terminates or obtains more work. 
The grid size \p p is constant, + * loosely corresponding to the number of thread blocks that may actively + * reside on the target device. + */ + GRID_MAPPING_DYNAMIC, +}; + + +/** @} */ // end group GridModule + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_GRID_GRID_MAPPING_HPP_ diff --git a/3rdparty/cub/grid/grid_queue.cuh b/3rdparty/cub/grid/grid_queue.cuh new file mode 100644 index 0000000000000000000000000000000000000000..68235b99ac9dcc9f0a453632caaff94c09ce8745 --- /dev/null +++ b/3rdparty/cub/grid/grid_queue.cuh @@ -0,0 +1,235 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_GRID_GRID_QUEUE_HPP_ +#define HIPCUB_ROCPRIM_GRID_GRID_QUEUE_HPP_ + +#include + +#include "../config.hpp" + +BEGIN_HIPCUB_NAMESPACE + +/** + * \addtogroup GridModule + * @{ + */ + + +/** + * \brief GridQueue is a descriptor utility for dynamic queue management. + * + * \par Overview + * GridQueue descriptors provides abstractions for "filling" or + * "draining" globally-shared vectors. + * + * \par + * A "filling" GridQueue works by atomically-adding to a zero-initialized counter, + * returning a unique offset for the calling thread to write its items. + * The GridQueue maintains the total "fill-size". The fill counter must be reset + * using GridQueue::ResetFill by the host or kernel instance prior to the kernel instance that + * will be filling. + * + * \par + * Similarly, a "draining" GridQueue works by works by atomically-incrementing a + * zero-initialized counter, returning a unique offset for the calling thread to + * read its items. Threads can safely drain until the array's logical fill-size is + * exceeded. The drain counter must be reset using GridQueue::ResetDrain or + * GridQueue::FillAndResetDrain by the host or kernel instance prior to the kernel instance that + * will be filling. (For dynamic work distribution of existing data, the corresponding fill-size + * is simply the number of elements in the array.) 
+ * + * \par + * Iterative work management can be implemented simply with a pair of flip-flopping + * work buffers, each with an associated set of fill and drain GridQueue descriptors. + * + * \tparam OffsetT Signed integer type for global offsets + */ +template +class GridQueue +{ +private: + + /// Counter indices + enum + { + FILL = 0, + DRAIN = 1, + }; + + /// Pair of counters + OffsetT *d_counters; + +public: + + /// Returns the device allocation size in bytes needed to construct a GridQueue instance + __host__ __device__ __forceinline__ + static size_t AllocationSize() + { + return sizeof(OffsetT) * 2; + } + + + /// Constructs an invalid GridQueue descriptor + __host__ __device__ __forceinline__ GridQueue() + : + d_counters(NULL) + {} + + + /// Constructs a GridQueue descriptor around the device storage allocation + __host__ __device__ __forceinline__ GridQueue( + void *d_storage) ///< Device allocation to back the GridQueue. Must be at least as big as AllocationSize(). + : + d_counters((OffsetT*) d_storage) + {} + + + /// This operation sets the fill-size and resets the drain counter, preparing the GridQueue for draining in the next kernel instance. To be called by the host or by a kernel prior to that which will be draining. + HIPCUB_DEVICE cudaError_t FillAndResetDrain( + OffsetT fill_size, + cudaStream_t stream = 0) + { + cudaError_t result = cudaErrorUnknown; + (void)stream; + d_counters[FILL] = fill_size; + d_counters[DRAIN] = 0; + result = cudaSuccess; + return result; + } + + HIPCUB_HOST cudaError_t FillAndResetDrain( + OffsetT fill_size, + cudaStream_t stream = 0) + { + cudaError_t result = cudaErrorUnknown; + OffsetT counters[2]; + counters[FILL] = fill_size; + counters[DRAIN] = 0; + result = CubDebug(cudaMemcpyAsync(d_counters, counters, sizeof(OffsetT) * 2, cudaMemcpyHostToDevice, stream)); + return result; + } + + /// This operation resets the drain so that it may advance to meet the existing fill-size. 
To be called by the host or by a kernel prior to that which will be draining. + HIPCUB_DEVICE cudaError_t ResetDrain(cudaStream_t stream = 0) + { + cudaError_t result = cudaErrorUnknown; + (void)stream; + d_counters[DRAIN] = 0; + result = cudaSuccess; + return result; + } + + HIPCUB_HOST cudaError_t ResetDrain(cudaStream_t stream = 0) + { + cudaError_t result = cudaErrorUnknown; + result = CubDebug(cudaMemsetAsync(d_counters + DRAIN, 0, sizeof(OffsetT), stream)); + return result; + } + + + /// This operation resets the fill counter. To be called by the host or by a kernel prior to that which will be filling. + HIPCUB_DEVICE cudaError_t ResetFill(cudaStream_t stream = 0) + { + cudaError_t result = cudaErrorUnknown; + (void)stream; + d_counters[FILL] = 0; + result = cudaSuccess; + return result; + } + + HIPCUB_HOST cudaError_t ResetFill(cudaStream_t stream = 0) + { + cudaError_t result = cudaErrorUnknown; + result = CubDebug(cudaMemsetAsync(d_counters + FILL, 0, sizeof(OffsetT), stream)); + return result; + } + + + /// Returns the fill-size established by the parent or by the previous kernel. + HIPCUB_DEVICE cudaError_t FillSize( + OffsetT &fill_size, + cudaStream_t stream = 0) + { + cudaError_t result = cudaErrorUnknown; + (void)stream; + fill_size = d_counters[FILL]; + result = cudaSuccess; + return result; + } + + HIPCUB_HOST cudaError_t FillSize( + OffsetT &fill_size, + cudaStream_t stream = 0) + { + cudaError_t result = cudaErrorUnknown; + result = CubDebug(cudaMemcpyAsync(&fill_size, d_counters + FILL, sizeof(OffsetT), cudaMemcpyDeviceToHost, stream)); + return result; + } + + + /// Drain \p num_items from the queue. Returns offset from which to read items. To be called from cuda kernel. + HIPCUB_DEVICE OffsetT Drain(OffsetT num_items) + { + return atomicAdd(d_counters + DRAIN, num_items); + } + + + /// Fill \p num_items into the queue. Returns offset from which to write items. To be called from cuda kernel. 
+ HIPCUB_DEVICE OffsetT Fill(OffsetT num_items) + { + return atomicAdd(d_counters + FILL, num_items); + } +}; + + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + +/** + * Reset grid queue (call with 1 block of 1 thread) + */ +template +__global__ void FillAndResetDrainKernel( + GridQueue grid_queue, + OffsetT num_items) +{ + grid_queue.FillAndResetDrain(num_items); +} + + + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + +/** @} */ // end group GridModule + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_GRID_GRID_QUEUE_HPP_ diff --git a/3rdparty/cub/iterator/arg_index_input_iterator.cuh b/3rdparty/cub/iterator/arg_index_input_iterator.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7762d64b39068b076395a4f5bc34e06e74c41320 --- /dev/null +++ b/3rdparty/cub/iterator/arg_index_input_iterator.cuh @@ -0,0 +1,61 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_ITERATOR_ARG_INDEX_INPUT_ITERATOR_HPP_ +#define HIPCUB_ROCPRIM_ITERATOR_ARG_INDEX_INPUT_ITERATOR_HPP_ + +#include +#include + +#include "../config.hpp" + +#include + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + +BEGIN_HIPCUB_NAMESPACE + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template< + typename InputIterator, + typename Difference = std::ptrdiff_t, + typename Value = typename std::iterator_traits::value_type +> +using ArgIndexInputIterator = ::rocprim::arg_index_iterator; + +#endif + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_ITERATOR_ARG_INDEX_INPUT_ITERATOR_HPP_ diff --git a/3rdparty/cub/iterator/cache_modified_input_iterator.cuh b/3rdparty/cub/iterator/cache_modified_input_iterator.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ab9fca56a49993367356eb4ca15673332dec5f56 --- /dev/null +++ b/3rdparty/cub/iterator/cache_modified_input_iterator.cuh @@ -0,0 +1,173 @@ 
+/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_ITERATOR_CACHE_MODIFIED_INPUT_ITERATOR_HPP_ +#define HIPCUB_ROCPRIM_ITERATOR_CACHE_MODIFIED_INPUT_ITERATOR_HPP_ + +#include +#include + +#include "../thread/thread_load.cuh" +#include "../util_type.cuh" + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + +BEGIN_HIPCUB_NAMESPACE + +template < + CacheLoadModifier MODIFIER, + typename ValueType, + typename OffsetT = ptrdiff_t> +class CacheModifiedInputIterator +{ +public: + + // Required iterator traits + typedef CacheModifiedInputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of subtracting one iterator from another + typedef ValueType value_type; ///< The type of the element the iterator can point to + typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to + typedef ValueType reference; ///< The type of a reference to an element the iterator can point to + typedef std::random_access_iterator_tag iterator_category; ///< The iterator category + +public: + + /// Wrapped native pointer + ValueType* ptr; + + /// Constructor + __host__ __device__ __forceinline__ CacheModifiedInputIterator( + ValueType* ptr) ///< Native pointer to wrap + : + ptr(const_cast::type *>(ptr)) + {} + + /// Postfix increment + __host__ __device__ __forceinline__ self_type operator++(int) + { + self_type retval = *this; + ptr++; + return retval; + } + + /// Prefix increment + __host__ __device__ __forceinline__ self_type operator++() + { + ptr++; + return *this; + } + + /// Indirection + __device__ __forceinline__ reference operator*() const + { + return ThreadLoad(ptr); + } + + /// Addition + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const + { + self_type retval(ptr + n); + return retval; + } + + /// Addition 
assignment + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) + { + ptr += n; + return *this; + } + + /// Subtraction + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const + { + self_type retval(ptr - n); + return retval; + } + + /// Subtraction assignment + template + __host__ __device__ __forceinline__ self_type& operator-=(Distance n) + { + ptr -= n; + return *this; + } + + /// Distance + __host__ __device__ __forceinline__ difference_type operator-(self_type other) const + { + return ptr - other.ptr; + } + + /// Array subscript + template + __device__ __forceinline__ reference operator[](Distance n) const + { + return ThreadLoad(ptr + n); + } + + /// Structure dereference + __device__ __forceinline__ pointer operator->() + { + return &ThreadLoad(ptr); + } + + /// Equal to + __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) + { + return (ptr == rhs.ptr); + } + + /// Not equal to + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) + { + return (ptr != rhs.ptr); + } + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + /// ostream operator + friend std::ostream& operator<<(std::ostream& os, const self_type& /*itr*/) + { + return os; + } + +#endif + +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_ITERATOR_CACHE_MODIFIED_INPUT_ITERATOR_HPP_ diff --git a/3rdparty/cub/iterator/cache_modified_output_iterator.cuh b/3rdparty/cub/iterator/cache_modified_output_iterator.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6be4e10aeea01a9dff0b5913c155fc61e6fb0086 --- /dev/null +++ b/3rdparty/cub/iterator/cache_modified_output_iterator.cuh @@ -0,0 +1,190 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 
+ * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_ITERATOR_CACHE_MODIFIED_OUTPUT_ITERATOR_HPP_ +#define HIPCUB_ROCPRIM_ITERATOR_CACHE_MODIFIED_OUTPUT_ITERATOR_HPP_ + +#include +#include + +#include "../thread/thread_load.cuh" +#include "../thread/thread_store.cuh" +#include "../util_type.cuh" + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + + +BEGIN_HIPCUB_NAMESPACE + +template < + CacheStoreModifier MODIFIER, + typename ValueType, + typename OffsetT = ptrdiff_t> +class CacheModifiedOutputIterator +{ +private: + + // Proxy object + struct Reference + { + ValueType* ptr; + + /// Constructor + __host__ __device__ __forceinline__ Reference(ValueType* ptr) : ptr(ptr) {} + + /// Assignment + __device__ __forceinline__ ValueType operator =(ValueType val) + { + ThreadStore(ptr, val); + return val; + } + }; + +public: + + // Required iterator traits + typedef CacheModifiedOutputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of subtracting one iterator from another + typedef void value_type; ///< The type of the element the iterator can point to + typedef void pointer; ///< The type of a pointer to an element the iterator can point to + typedef Reference reference; ///< The type of a reference to an element the iterator can point to + typedef std::random_access_iterator_tag iterator_category; ///< The iterator category + +private: + + ValueType* ptr; + +public: + + /// Constructor + template + __host__ __device__ __forceinline__ CacheModifiedOutputIterator( + QualifiedValueType* ptr) ///< Native pointer to wrap + : + ptr(const_cast::type *>(ptr)) + {} + + /// Postfix increment + __host__ __device__ __forceinline__ self_type operator++(int) + { + self_type retval = *this; + ptr++; + return retval; + } + + + /// Prefix increment + __host__ __device__ 
__forceinline__ self_type operator++() + { + ptr++; + return *this; + } + + /// Indirection + __host__ __device__ __forceinline__ reference operator*() const + { + return Reference(ptr); + } + + /// Addition + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const + { + self_type retval(ptr + n); + return retval; + } + + /// Addition assignment + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) + { + ptr += n; + return *this; + } + + /// Subtraction + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const + { + self_type retval(ptr - n); + return retval; + } + + /// Subtraction assignment + template + __host__ __device__ __forceinline__ self_type& operator-=(Distance n) + { + ptr -= n; + return *this; + } + + /// Distance + __host__ __device__ __forceinline__ difference_type operator-(self_type other) const + { + return ptr - other.ptr; + } + + /// Array subscript + template + __host__ __device__ __forceinline__ reference operator[](Distance n) const + { + return Reference(ptr + n); + } + + /// Equal to + __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) + { + return (ptr == rhs.ptr); + } + + /// Not equal to + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) + { + return (ptr != rhs.ptr); + } + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + /// ostream operator + friend std::ostream& operator<<(std::ostream& os, const self_type& itr) + { + (void)itr; + return os; + } + +#endif +}; + +END_HIPCUB_NAMESPACE + +#endif diff --git a/3rdparty/cub/iterator/constant_input_iterator.cuh b/3rdparty/cub/iterator/constant_input_iterator.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3619594b262137f08d15ac5be931780a53af18dc --- /dev/null +++ b/3rdparty/cub/iterator/constant_input_iterator.cuh @@ -0,0 +1,60 @@ +/****************************************************************************** + * Copyright 
(c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_ITERATOR_CONSTANT_INPUT_ITERATOR_HPP_ +#define HIPCUB_ROCPRIM_ITERATOR_CONSTANT_INPUT_ITERATOR_HPP_ + +#include +#include + +#include "../config.hpp" + +#include + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + +BEGIN_HIPCUB_NAMESPACE + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template< + typename ValueType, + typename OffsetT = std::ptrdiff_t +> +using ConstantInputIterator = ::rocprim::constant_iterator; + +#endif + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_ITERATOR_CONSTANT_INPUT_ITERATOR_HPP_ diff --git a/3rdparty/cub/iterator/counting_input_iterator.cuh b/3rdparty/cub/iterator/counting_input_iterator.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b625c533b94fd55b73ef31eb67c76677cce7f730 --- /dev/null +++ b/3rdparty/cub/iterator/counting_input_iterator.cuh @@ -0,0 +1,60 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_ITERATOR_COUNTING_INPUT_ITERATOR_HPP_ +#define HIPCUB_ROCPRIM_ITERATOR_COUNTING_INPUT_ITERATOR_HPP_ + +#include +#include + +#include "../config.hpp" + +#include + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + +BEGIN_HIPCUB_NAMESPACE + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template< + typename ValueType, + typename OffsetT = std::ptrdiff_t +> +using CountingInputIterator = ::rocprim::counting_iterator; + +#endif + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_ITERATOR_COUNTING_INPUT_ITERATOR_HPP_ diff --git a/3rdparty/cub/iterator/discard_output_iterator.cuh b/3rdparty/cub/iterator/discard_output_iterator.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e444925dac4ee3e6678b58656bf50baab4d77ee3 --- /dev/null +++ 
b/3rdparty/cub/iterator/discard_output_iterator.cuh @@ -0,0 +1,231 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_ITERATOR_DISCARD_OUTPUT_ITERATOR_HPP_ +#define HIPCUB_ROCPRIM_ITERATOR_DISCARD_OUTPUT_ITERATOR_HPP_ + +#include +#include + +#include "../config.hpp" + +BEGIN_HIPCUB_NAMESPACE +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + +/** + * \addtogroup UtilIterator + * @{ + */ + + +/** + * \brief A discard iterator + */ +template +class DiscardOutputIterator +{ +public: + + // Required iterator traits + typedef DiscardOutputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of subtracting one iterator from another + typedef void value_type; ///< The type of the element the iterator can point to + typedef void pointer; ///< The type of a pointer to an element the iterator can point to + typedef void reference; ///< The type of a reference to an element the iterator can point to + +#if (THRUST_VERSION >= 100700) + // Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods + typedef typename thrust::detail::iterator_facade_category< + thrust::any_system_tag, + thrust::random_access_traversal_tag, + value_type, + reference + >::type iterator_category; ///< The iterator category +#else + typedef std::random_access_iterator_tag iterator_category; ///< The iterator category +#endif // THRUST_VERSION + +private: + + OffsetT offset; + +public: + + /// Constructor + __host__ __device__ __forceinline__ DiscardOutputIterator( + OffsetT offset = 0) ///< Base offset + : + offset(offset) + {} + + /** + * @typedef self_type + * @brief Postfix increment + */ + __host__ __device__ __forceinline__ self_type operator++(int) + { + self_type retval = *this; + offset++; + return retval; + } + + /** + * @typedef self_type + * @brief Postfix increment + */ + __host__ __device__ __forceinline__ 
self_type operator++() + { + offset++; + return *this; + } + + /** + * @typedef self_type + * @brief Indirection + */ + __host__ __device__ __forceinline__ self_type& operator*() + { + // return self reference, which can be assigned to anything + return *this; + } + + /** + * @typedef self_type + * @brief Addition + */ + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const + { + self_type retval(offset + n); + return retval; + } + + /** + * @typedef self_type + * @brief Addition assignment + */ + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) + { + offset += n; + return *this; + } + + /** + * @typedef self_type + * @brief Subtraction + */ + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const + { + self_type retval(offset - n); + return retval; + } + + /** + * @typedef self_type + * @brief Subtraction assignment + */ + template + __host__ __device__ __forceinline__ self_type& operator-=(Distance n) + { + offset -= n; + return *this; + } + + /** + * @typedef self_type + * @brief Distance + */ + __host__ __device__ __forceinline__ difference_type operator-(self_type other) const + { + return offset - other.offset; + } + + /** + * @typedef self_type + * @brief Array subscript + */ + template + __host__ __device__ __forceinline__ self_type& operator[](Distance) + { + // return self reference, which can be assigned to anything + return *this; + } + + /// Structure dereference + __host__ __device__ __forceinline__ pointer operator->() + { + return; + } + + /// Assignment to anything else (no-op) + template + __host__ __device__ __forceinline__ void operator=(T const&) + {} + + /// Cast to void* operator + __host__ __device__ __forceinline__ operator void*() const { return NULL; } + + /** + * @typedef self_type + * @brief Equal to + */ + __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) + { + return (offset == rhs.offset); + } + + /** + * 
@typedef self_type + * @brief Not equal to + */ + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) + { + return (offset != rhs.offset); + } + + /** + * @typedef self_type + * @brief ostream operator + */ + friend std::ostream& operator<<(std::ostream& os, const self_type& itr) + { + os << "[" << itr.offset << "]"; + return os; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_ITERATOR_DISCARD_OUTPUT_ITERATOR_HPP_ diff --git a/3rdparty/cub/iterator/tex_obj_input_iterator.cuh b/3rdparty/cub/iterator/tex_obj_input_iterator.cuh new file mode 100644 index 0000000000000000000000000000000000000000..a87a16c517bb9cd7363c6865ecfa5b6f9fa72b32 --- /dev/null +++ b/3rdparty/cub/iterator/tex_obj_input_iterator.cuh @@ -0,0 +1,88 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_ITERATOR_TEX_OBJ_INPUT_ITERATOR_HPP_ +#define HIPCUB_ROCPRIM_ITERATOR_TEX_OBJ_INPUT_ITERATOR_HPP_ + +#include +#include + +#include "../config.hpp" + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + + +#include + +BEGIN_HIPCUB_NAMESPACE + +template< + typename T, + typename OffsetT = std::ptrdiff_t +> +class TexObjInputIterator : public ::rocprim::texture_cache_iterator +{ + public: + template + inline + cudaError_t BindTexture(Qualified* ptr, + size_t bytes = size_t(-1), + size_t texture_offset = 0) + { + return (cudaError_t)::rocprim::texture_cache_iterator::bind_texture(ptr, bytes, texture_offset); + } + + inline cudaError_t UnbindTexture() + { + return (cudaError_t)::rocprim::texture_cache_iterator::unbind_texture(); + } + + HIPCUB_HOST_DEVICE inline + ~TexObjInputIterator() = default; + + HIPCUB_HOST_DEVICE inline + TexObjInputIterator() : ::rocprim::texture_cache_iterator() + { + } + + HIPCUB_HOST_DEVICE inline + TexObjInputIterator(const ::rocprim::texture_cache_iterator other) + : 
::rocprim::texture_cache_iterator(other) + { + } + +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_ITERATOR_TEX_OBJ_INPUT_ITERATOR_HPP_ diff --git a/3rdparty/cub/iterator/tex_ref_input_iterator.cuh b/3rdparty/cub/iterator/tex_ref_input_iterator.cuh new file mode 100644 index 0000000000000000000000000000000000000000..4c3aeb5227848f87cdd626925380a7e240009fc7 --- /dev/null +++ b/3rdparty/cub/iterator/tex_ref_input_iterator.cuh @@ -0,0 +1,87 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_ITERATOR_TEX_REF_INPUT_ITERATOR_HPP_ +#define HIPCUB_ROCPRIM_ITERATOR_TEX_REF_INPUT_ITERATOR_HPP_ + +#include +#include + +#include "../config.hpp" + +#if (THRUST_VERSION >= 100700) // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + +#include + +BEGIN_HIPCUB_NAMESPACE + +template< + typename T, + int UNIQUE_ID, // Unused parameter for compatibility with original definition in cub + typename OffsetT = std::ptrdiff_t +> +class TexRefInputIterator : public ::rocprim::texture_cache_iterator +{ + public: + template + inline + cudaError_t BindTexture(Qualified* ptr, + size_t bytes = size_t(-1), + size_t texture_offset = 0) + { + return (cudaError_t)::rocprim::texture_cache_iterator::bind_texture(ptr, bytes, texture_offset); + } + + inline cudaError_t UnbindTexture() + { + return (cudaError_t)::rocprim::texture_cache_iterator::unbind_texture(); + } + + HIPCUB_HOST_DEVICE inline + ~TexRefInputIterator() = default; + + HIPCUB_HOST_DEVICE inline + TexRefInputIterator() : ::rocprim::texture_cache_iterator() + { + } + + HIPCUB_HOST_DEVICE inline + TexRefInputIterator(const ::rocprim::texture_cache_iterator other) + : ::rocprim::texture_cache_iterator(other) + { + } + +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_ITERATOR_TEX_REF_INPUT_ITERATOR_HPP_ diff --git 
a/3rdparty/cub/iterator/transform_input_iterator.cuh b/3rdparty/cub/iterator/transform_input_iterator.cuh new file mode 100644 index 0000000000000000000000000000000000000000..344cd62edc62a87e47e94adf6c5325576be4aaee --- /dev/null +++ b/3rdparty/cub/iterator/transform_input_iterator.cuh @@ -0,0 +1,63 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_ITERATOR_TRANSFORM_INPUT_ITERATOR_HPP_ +#define HIPCUB_ROCPRIM_ITERATOR_TRANSFORM_INPUT_ITERATOR_HPP_ + +#include +#include + +#include "../config.hpp" + +#include + +#if (THRUST_VERSION >= 100700) + // This iterator is compatible with Thrust API 1.7 and newer + #include + #include +#endif // THRUST_VERSION + + +BEGIN_HIPCUB_NAMESPACE + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template< + typename ValueType, + typename ConversionOp, + typename InputIteratorT, + typename OffsetT = std::ptrdiff_t // ignored +> +using TransformInputIterator = ::rocprim::transform_iterator; + +#endif + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_ITERATOR_TRANSFORM_INPUT_ITERATOR_HPP_ diff --git a/3rdparty/cub/rocprim/block/block_adjacent_difference.hpp b/3rdparty/cub/rocprim/block/block_adjacent_difference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..173a73a2d2f3d218b3434a7611992c87b66868b0 --- /dev/null +++ b/3rdparty/cub/rocprim/block/block_adjacent_difference.hpp @@ -0,0 +1,1155 @@ +/****************************************************************************** +* Copyright (c) 2011, Duane Merrill. All rights reserved. +* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. +* Modifications Copyright (c) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* * Redistributions of source code must retain the above copyright +* notice, this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright +* notice, this list of conditions and the following disclaimer in the +* documentation and/or other materials provided with the distribution. +* * Neither the name of the NVIDIA CORPORATION nor the +* names of its contributors may be used to endorse or promote products +* derived from this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+* +******************************************************************************/ + +#ifndef ROCPRIM_BLOCK_BLOCK_ADJACENT_DIFFERENCE_HPP_ +#define ROCPRIM_BLOCK_BLOCK_ADJACENT_DIFFERENCE_HPP_ + + +#include "detail/block_adjacent_difference_impl.hpp" + +#include "../config.hpp" +#include "../detail/various.hpp" + + + +/// \addtogroup blockmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief The \p block_adjacent_difference class is a block level parallel primitive which provides +/// methods for applying binary functions for pairs of consecutive items partition across a thread +/// block. +/// +/// \tparam T - the input type. +/// \tparam BlockSize - the number of threads in a block. +/// +/// \par Overview +/// * There are two types of flags: +/// * Head flags. +/// * Tail flags. +/// * The above flags are used to differentiate items from their predecessors or successors. +/// * E.g. Head flags are convenient for differentiating disjoint data segments as part of a +/// segmented reduction/scan. +/// +/// \par Examples +/// \parblock +/// In the examples discontinuity operation is performed on block of 128 threads, using type +/// \p int. +/// +/// \code{.cpp} +/// __global__ void example_kernel(...) +/// { +/// // specialize discontinuity for int and a block of 128 threads +/// using block_adjacent_difference_int = rocprim::block_adjacent_difference; +/// // allocate storage in shared memory +/// __shared__ block_adjacent_difference_int::storage_type storage; +/// +/// // segment of consecutive items to be used +/// int input[8]; +/// ... +/// int head_flags[8]; +/// block_adjacent_difference_int b_discontinuity; +/// using flag_op_type = typename rocprim::greater; +/// b_discontinuity.flag_heads(head_flags, input, flag_op_type(), storage); +/// ... 
+/// } +/// \endcode +/// \endparblock +template< + class T, + unsigned int BlockSizeX, + unsigned int BlockSizeY = 1, + unsigned int BlockSizeZ = 1 +> +class block_adjacent_difference +#ifndef DOXYGEN_SHOULD_SKIP_THIS // hide implementation detail from documentation + : private detail::block_adjacent_difference_impl +#endif // DOXYGEN_SHOULD_SKIP_THIS +{ + using base_type = detail::block_adjacent_difference_impl; + + static constexpr unsigned BlockSize = base_type::BlockSize; + // Struct used for creating a raw_storage object for this primitive's temporary storage. + struct storage_type_ + { + typename base_type::storage_type left; + typename base_type::storage_type right; + }; + +public: + + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by related parallel primitive. + /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords __shared__. It can be aliased to + /// an externally allocated memory, or be a part of a union type with other storage types + /// to increase shared memory reusability. + #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen + using storage_type = detail::raw_storage; + #else + using storage_type = storage_type_; + #endif + + /// \brief Tags \p head_flags that indicate discontinuities between items partitioned + /// across the thread block, where the first item has no reference and is always + /// flagged. + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use subtract_left() or block_discontinuity::flag_heads() instead. + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. 
+ /// + /// \param [out] head_flags - array that contains the head flags. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reuse + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; + /// // allocate storage in shared memory + /// __shared__ block_adjacent_difference_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// ... + /// int head_flags[8]; + /// block_adjacent_difference_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_heads(head_flags, input, flag_op_type(), storage); + /// ... + /// } + /// \endcode + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." 
+ "Use subtract_left or block_discontinuity.flag_heads instead.")]] + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_heads(Flag (&head_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = true; + static constexpr auto with_predecessor = false; + base_type::template apply_left( + input, head_flags, flag_op, input[0] /* predecessor */, storage.get().left); + } + + /// \overload + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use subtract_left() or block_discontinuity::flag_heads() instead. + /// This overload does not take a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." + "Use subtract_left or block_discontinuity.flag_heads instead.")]] + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_heads(Flag (&head_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_heads(head_flags, input, flag_op, storage); + } + + /// \brief Tags \p head_flags that indicate discontinuities between items partitioned + /// across the thread block, where the first item of the first thread is compared against + /// a \p tile_predecessor_item. + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use subtract_left() or block_discontinuity::flag_heads() instead. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] head_flags - array that contains the head flags. 
+ /// \param [in] tile_predecessor_item - first tile item from thread to be compared + /// against. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reuse + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; + /// // allocate storage in shared memory + /// __shared__ block_adjacent_difference_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// int tile_item = 0; + /// if (threadIdx.x == 0) + /// { + /// tile_item = ... + /// } + /// ... + /// int head_flags[8]; + /// block_adjacent_difference_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_heads(head_flags, tile_item, input, flag_op_type(), + /// storage); + /// ... + /// } + /// \endcode + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." 
+ "Use subtract_left or block_discontinuity.flag_heads instead.")]] + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_heads(Flag (&head_flags)[ItemsPerThread], + T tile_predecessor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = true; + static constexpr auto with_predecessor = true; + base_type::template apply_left( + input, head_flags, flag_op, tile_predecessor_item, storage.get().left); + } + + /// \overload + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use subtract_left() or block_discontinuity::flag_heads() instead. + /// + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." + "Use subtract_left or block_discontinuity.flag_heads instead.")]] + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_heads(Flag (&head_flags)[ItemsPerThread], + T tile_predecessor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_heads(head_flags, tile_predecessor_item, input, flag_op, storage); + } + + /// \brief Tags \p tail_flags that indicate discontinuities between items partitioned + /// across the thread block, where the last item has no reference and is always + /// flagged. + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use subtract_right() or block_discontinuity::flag_tails() instead. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. 
+ /// + /// \param [out] tail_flags - array that contains the tail flags. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reuse + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; + /// // allocate storage in shared memory + /// __shared__ block_adjacent_difference_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// ... + /// int tail_flags[8]; + /// block_adjacent_difference_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_tails(tail_flags, input, flag_op_type(), storage); + /// ... + /// } + /// \endcode + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." 
+ "Use subtract_right or block_discontinuity.flag_tails instead.")]] + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_tails(Flag (&tail_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = true; + static constexpr auto with_successor = false; + base_type::template apply_right( + input, tail_flags, flag_op, input[0] /* successor */, storage.get().right); + } + + /// \overload + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use subtract_right() or block_discontinuity::flag_tails() instead. + /// + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." + "Use subtract_right or block_discontinuity.flag_tails instead.")]] + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_tails(Flag (&tail_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_tails(tail_flags, input, flag_op, storage); + } + + /// \brief Tags \p tail_flags that indicate discontinuities between items partitioned + /// across the thread block, where the last item of the last thread is compared against + /// a \p tile_successor_item. + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use subtract_right() or block_discontinuity::flag_tails() instead. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] tail_flags - array that contains the tail flags. 
+ /// \param [in] tile_successor_item - last tile item from thread to be compared + /// against. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reuse + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; + /// // allocate storage in shared memory + /// __shared__ block_adjacent_difference_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// int tile_item = 0; + /// if (threadIdx.x == 0) + /// { + /// tile_item = ... + /// } + /// ... + /// int tail_flags[8]; + /// block_adjacent_difference_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_tails(tail_flags, tile_item, input, flag_op_type(), + /// storage); + /// ... + /// } + /// \endcode + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." 
+ "Use subtract_right or block_discontinuity.flag_tails instead.")]] + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_tails(Flag (&tail_flags)[ItemsPerThread], + T tile_successor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = true; + static constexpr auto with_successor = true; + base_type::template apply_right( + input, tail_flags, flag_op, tile_successor_item, storage.get().right); + } + + /// \overload + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use subtract_right() or block_discontinuity::flag_tails() instead. + /// + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." + "Use subtract_right or block_discontinuity.flag_tails instead.")]] + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_tails(Flag (&tail_flags)[ItemsPerThread], + T tile_successor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_tails(tail_flags, tile_successor_item, input, flag_op, storage); + } + + /// \brief Tags both \p head_flags and\p tail_flags that indicate discontinuities + /// between items partitioned across the thread block. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] head_flags - array that contains the head flags. + /// \param [out] tail_flags - array that contains the tail flags. + /// \param [in] input - array that data is loaded from. 
+ /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reuse + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; + /// // allocate storage in shared memory + /// __shared__ block_adjacent_difference_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// ... + /// int head_flags[8]; + /// int tail_flags[8]; + /// block_adjacent_difference_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_heads_and_tails(head_flags, tail_flags, input, + /// flag_op_type(), storage); + /// ... + /// } + /// \endcode + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." 
+ "Use block_discontinuity.flag_heads_and_tails instead.")]] + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + Flag (&tail_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = true; + static constexpr auto with_predecessor = false; + static constexpr auto with_successor = false; + + // Copy items in case head_flags is aliased with input + T items[ItemsPerThread]; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) { + items[i] = input[i]; + } + + base_type::template apply_left( + items, head_flags, flag_op, items[0] /*predecessor*/, storage.get().left); + + base_type::template apply_right( + items, tail_flags, flag_op, items[0] /*successor*/, storage.get().right); + } + + /// \overload + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use block_discontinuity::flag_heads_and_tails() instead. + /// + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." + "Use block_discontinuity.flag_heads_and_tails instead.")]] + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + Flag (&tail_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_heads_and_tails(head_flags, tail_flags, input, flag_op, storage); + } + + /// \brief Tags both \p head_flags and\p tail_flags that indicate discontinuities + /// between items partitioned across the thread block, where the last item of the + /// last thread is compared against a \p tile_successor_item. 
+ /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use block_discontinuity::flag_heads_and_tails() instead. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] head_flags - array that contains the head flags. + /// \param [out] tail_flags - array that contains the tail flags. + /// \param [in] tile_successor_item - last tile item from thread to be compared + /// against. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reuse + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; + /// // allocate storage in shared memory + /// __shared__ block_adjacent_difference_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// int tile_item = 0; + /// if (threadIdx.x == 0) + /// { + /// tile_item = ... + /// } + /// ... 
+ /// int head_flags[8]; + /// int tail_flags[8]; + /// block_adjacent_difference_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_heads_and_tails(head_flags, tail_flags, tile_item, + /// input, flag_op_type(), + /// storage); + /// ... + /// } + /// \endcode + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." + "Use block_discontinuity.flag_heads_and_tails instead.")]] + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + Flag (&tail_flags)[ItemsPerThread], + T tile_successor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = true; + static constexpr auto with_predecessor = false; + static constexpr auto with_successor = true; + + // Copy items in case head_flags is aliased with input + T items[ItemsPerThread]; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) { + items[i] = input[i]; + } + + base_type::template apply_left( + items, head_flags, flag_op, items[0] /*predecessor*/, storage.get().left); + + base_type::template apply_right( + items, tail_flags, flag_op, tile_successor_item, storage.get().right); + } + + /// \overload + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use block_discontinuity::flag_heads_and_tails() instead. + /// + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." 
+ "Use block_discontinuity.flag_heads_and_tails instead.")]] + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + Flag (&tail_flags)[ItemsPerThread], + T tile_successor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_heads_and_tails(head_flags, tail_flags, tile_successor_item, input, flag_op, storage); + } + + /// \brief Tags both \p head_flags and\p tail_flags that indicate discontinuities + /// between items partitioned across the thread block, where the first item of the + /// first thread is compared against a \p tile_predecessor_item. + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use block_discontinuity::flag_heads_and_tails() instead. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] head_flags - array that contains the head flags. + /// \param [in] tile_predecessor_item - first tile item from thread to be compared + /// against. + /// \param [out] tail_flags - array that contains the tail flags. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reuse + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). 
+ /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; + /// // allocate storage in shared memory + /// __shared__ block_adjacent_difference_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// int tile_item = 0; + /// if (threadIdx.x == 0) + /// { + /// tile_item = ... + /// } + /// ... + /// int head_flags[8]; + /// int tail_flags[8]; + /// block_adjacent_difference_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_heads_and_tails(head_flags, tile_item, tail_flags, + /// input, flag_op_type(), + /// storage); + /// ... + /// } + /// \endcode + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." + "Use block_discontinuity.flag_heads_and_tails instead.")]] + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + T tile_predecessor_item, + Flag (&tail_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = true; + static constexpr auto with_predecessor = true; + static constexpr auto with_successor = false; + + // Copy items in case head_flags is aliased with input + T items[ItemsPerThread]; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) { + items[i] = input[i]; + } + + base_type::template apply_left( + items, head_flags, flag_op, tile_predecessor_item, storage.get().left); + + base_type::template apply_right( + items, tail_flags, flag_op, items[0] /*successor*/, storage.get().right); + } + + /// \overload + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use block_discontinuity::flag_heads_and_tails() instead. 
+ /// + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." + "Use block_discontinuity.flag_heads_and_tails instead.")]] + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + T tile_predecessor_item, + Flag (&tail_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_heads_and_tails(head_flags, tile_predecessor_item, tail_flags, input, flag_op, storage); + } + + /// \brief Tags both \p head_flags and\p tail_flags that indicate discontinuities + /// between items partitioned across the thread block, where the first and last items of + /// the first and last thread is compared against a \p tile_predecessor_item and + /// a \p tile_successor_item. + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use block_discontinuity::flag_heads_and_tails() instead. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] head_flags - array that contains the head flags. + /// \param [in] tile_predecessor_item - first tile item from thread to be compared + /// against. + /// \param [out] tail_flags - array that contains the tail flags. + /// \param [in] tile_successor_item - last tile item from thread to be compared + /// against. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. 
+ /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reuse + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; + /// // allocate storage in shared memory + /// __shared__ block_adjacent_difference_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// int tile_predecessor_item = 0; + /// int tile_successor_item = 0; + /// if (threadIdx.x == 0) + /// { + /// tile_predecessor_item = ... + /// tile_successor_item = ... + /// } + /// ... + /// int head_flags[8]; + /// int tail_flags[8]; + /// block_adjacent_difference_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_heads_and_tails(head_flags, tile_predecessor_item, + /// tail_flags, tile_successor_item, + /// input, flag_op_type(), + /// storage); + /// ... + /// } + /// \endcode + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." 
+ "Use block_discontinuity.flag_heads_and_tails instead.")]] + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + T tile_predecessor_item, + Flag (&tail_flags)[ItemsPerThread], + T tile_successor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = true; + static constexpr auto with_predecessor = true; + static constexpr auto with_successor = true; + + // Copy items in case head_flags is aliased with input + T items[ItemsPerThread]; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) { + items[i] = input[i]; + } + + base_type::template apply_left( + items, head_flags, flag_op, tile_predecessor_item, storage.get().left); + + base_type::template apply_right( + items, tail_flags, flag_op, tile_successor_item, storage.get().right); + } + + /// \overload + /// \deprecated The flags API of block_adjacent_difference is deprecated, + /// use block_discontinuity::flag_heads_and_tails() instead. + /// + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. + template + [[deprecated("The flags API of block_adjacent_difference is deprecated." 
+ "Use block_discontinuity.flag_heads_and_tails instead.")]] + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + T tile_predecessor_item, + Flag (&tail_flags)[ItemsPerThread], + T tile_successor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_heads_and_tails( + head_flags, tile_predecessor_item, tail_flags, tile_successor_item, + input, flag_op, storage + ); + } + + /// \brief Apply a function to each consecutive pair of elements partitioned across threads in + /// the block and write the output to the position of the left item. + /// + /// The first item in the first thread is copied from the input then for the rest the following + /// code applies. + /// \code + /// // For each i in [1, block_size * ItemsPerThread) across threads in a block + /// output[i] = op(input[i], input[i-1]); + /// \endcode + /// + /// \tparam Output - [inferred] the type of output, must be assignable from the result of `op` + /// \tparam ItemsPerThread - [inferred] the number of items processed by each thread + /// \tparam BinaryFunction - [inferred] the type of the function to apply + /// \param [in] input - array that data is loaded from partitioned across the threads in the block + /// \param [out] output - array where the result of function application will be written to + /// \param [in] op - binary function applied to the items. + /// The signature of the function should be equivalent to the following: + /// `bool f(const T &a, const T &b)` The signature does not need to have + /// `const &` but the function object must not modify the objects passed to it. + /// \param storage reference to a temporary storage object of type #storage_type + /// \par Storage reuse + /// Synchronization barrier should be placed before `storage` is reused + /// or repurposed: `__syncthreads()` or \link syncthreads() rocprim::syncthreads() \endlink. 
+ template + ROCPRIM_DEVICE ROCPRIM_INLINE void subtract_left(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + const BinaryFunction op, + storage_type& storage) + { + static constexpr auto as_flags = false; + static constexpr auto reversed = true; + static constexpr auto with_predecessor = false; + + base_type::template apply_left( + input, output, op, input[0] /* predecessor */, storage.get().left); + } + + /// \brief Apply a function to each consecutive pair of elements partitioned across threads in + /// the block and write the output to the position of the left item, with an explicit item before + /// the tile. + /// + /// \code + /// // For the first item on the first thread use the tile predecessor + /// output[0] = op(input[0], tile_predecessor) + /// // For other items, i in [1, block_size * ItemsPerThread) across threads in a block + /// output[i] = op(input[i], input[i-1]); + /// \endcode + /// + /// \tparam Output - [inferred] the type of output, must be assignable from the result of `op` + /// \tparam ItemsPerThread - [inferred] the number of items processed by each thread + /// \tparam BinaryFunction - [inferred] the type of the function to apply + /// \param [in] input - array that data is loaded from partitioned across the threads in the block + /// \param [out] output - array where the result of function application will be written to + /// \param [in] op - binary function applied to the items. + /// The signature of the function should be equivalent to the following: + /// `bool f(const T &a, const T &b)` The signature does not need to have + /// `const &` but the function object must not modify the objects passed to it. 
+ /// \param [in] tile_predecessor - the item before the tile, will be used as the input + /// of the first application of `op` + /// \param storage - reference to a temporary storage object of type #storage_type + /// \par Storage reuse + /// Synchronization barrier should be placed before `storage` is reused + /// or repurposed: `__syncthreads()` or \link syncthreads() rocprim::syncthreads() \endlink. + template + ROCPRIM_DEVICE ROCPRIM_INLINE void subtract_left(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + const BinaryFunction op, + const T tile_predecessor, + storage_type& storage) + { + static constexpr auto as_flags = false; + static constexpr auto reversed = true; + static constexpr auto with_predecessor = true; + + base_type::template apply_left( + input, output, op, tile_predecessor, storage.get().left); + } + + /// \brief Apply a function to each consecutive pair of elements partitioned across threads in + /// the block and write the output to the position of the left item, in a partial tile. + /// + /// \code + /// output[0] = input[0] + /// // For each item i in [1, valid_items) across threads in a block + /// output[i] = op(input[i], input[i-1]); + /// // Just copy "invalid" items in [valid_items, block_size * ItemsPerThread) + /// output[i] = input[i] + /// \endcode + /// + /// \tparam Output - [inferred] the type of output, must be assignable from the result of `op` + /// \tparam ItemsPerThread - [inferred] the number of items processed by each thread + /// \tparam BinaryFunction - [inferred] the type of the function to apply + /// \param [in] input - array that data is loaded from partitioned across the threads in the block + /// \param [out] output - array where the result of function application will be written to + /// \param [in] op - binary function applied to the items. 
+ /// The signature of the function should be equivalent to the following: + /// `bool f(const T &a, const T &b)` The signature does not need to have + /// `const &` but the function object must not modify the objects passed to it. + /// \param [in] valid_items - number of items in the block which are considered "valid" and will + /// be used. Must be less or equal to `BlockSize` * `ItemsPerThread` + /// \param storage - reference to a temporary storage object of type #storage_type + /// \par Storage reuse + /// Synchronization barrier should be placed before `storage` is reused + /// or repurposed: `__syncthreads()` or \link syncthreads() rocprim::syncthreads() \endlink. + template + ROCPRIM_DEVICE ROCPRIM_INLINE void subtract_left_partial(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + const BinaryFunction op, + const unsigned int valid_items, + storage_type& storage) + { + static constexpr auto as_flags = false; + static constexpr auto reversed = true; + static constexpr auto with_predecessor = false; + + base_type::template apply_left_partial( + input, output, op, input[0] /* predecessor */, valid_items, storage.get().left); + } + + /// \brief Apply a function to each consecutive pair of elements partitioned across threads in + /// the block and write the output to the position of the left item, in a partial tile with a + /// predecessor. + /// + /// This combines subtract_left_partial() with a tile predecessor. + /// \tparam Output - [inferred] the type of output, must be assignable from the result of `op` + /// \tparam ItemsPerThread - [inferred] the number of items processed by each thread + /// \tparam BinaryFunction - [inferred] the type of the function to apply + /// \param [in] input - array that data is loaded from partitioned across the threads in the block + /// \param [out] output - array where the result of function application will be written to + /// \param [in] op - binary function applied to the items. 
+ /// The signature of the function should be equivalent to the following: + /// `bool f(const T &a, const T &b)` The signature does not need to have + /// `const &` but the function object must not modify the objects passed to it. + /// \param [in] tile_predecessor - the item before the tile, will be used as the input + /// of the first application of `op` + /// \param [in] valid_items - number of items in the block which are considered "valid" and will + /// be used. Must be less or equal to `BlockSize` * `ItemsPerThread` + /// \param storage - reference to a temporary storage object of type #storage_type + /// \par Storage reuse + /// Synchronization barrier should be placed before `storage` is reused + /// or repurposed: `__syncthreads()` or \link syncthreads() rocprim::syncthreads() \endlink. + template + ROCPRIM_DEVICE ROCPRIM_INLINE void subtract_left_partial(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + const BinaryFunction op, + const T tile_predecessor, + const unsigned int valid_items, + storage_type& storage) + { + static constexpr auto as_flags = false; + static constexpr auto reversed = true; + static constexpr auto with_predecessor = true; + + base_type::template apply_left_partial( + input, output, op, tile_predecessor, valid_items, storage.get().left); + } + + /// \brief Apply a function to each consecutive pair of elements partitioned across threads in + /// the block and write the output to the position of the right item. + /// + /// The last item in the last thread is copied from the input then for the rest the following + /// code applies. 
+ /// \code + /// // For each i in [0, block_size * ItemsPerThread - 1) across threads in a block + /// output[i] = op(input[i], input[i+1]); + /// \endcode + /// + /// \tparam Output - [inferred] the type of output, must be assignable from the result of `op` + /// \tparam ItemsPerThread - [inferred] the number of items processed by each thread + /// \tparam BinaryFunction - [inferred] the type of the function to apply + /// \param [in] input - array that data is loaded from partitioned across the threads in the block + /// \param [out] output - array where the result of function application will be written to + /// \param [in] op - binary function applied to the items. + /// The signature of the function should be equivalent to the following: + /// `bool f(const T &a, const T &b)` The signature does not need to have + /// `const &` but the function object must not modify the objects passed to it. + /// \param storage - reference to a temporary storage object of type #storage_type + /// \par Storage reuse + /// Synchronization barrier should be placed before `storage` is reused + /// or repurposed: `__syncthreads()` or \link syncthreads() rocprim::syncthreads() \endlink. + template + ROCPRIM_DEVICE ROCPRIM_INLINE void subtract_right(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + const BinaryFunction op, + storage_type& storage) + { + static constexpr auto as_flags = false; + static constexpr auto reversed = false; + static constexpr auto with_successor = false; + + base_type::template apply_right( + input, output, op, input[0] /* successor */, storage.get().right); + } + + /// \brief Apply a function to each consecutive pair of elements partitioned across threads in + /// the block and write the output to the position of the right item, with an explicit item after + /// the tile. 
+ /// + /// \code + /// // For each items i in [0, block_size * ItemsPerThread - 1) across threads in a block + /// output[i] = op(input[i], input[i+1]); + /// // For the last item on the last thread use the tile successor + /// output[block_size * ItemsPerThread - 1] = + /// op(input[block_size * ItemsPerThread - 1], tile_successor) + /// \endcode + /// + /// \tparam Output - [inferred] the type of output, must be assignable from the result of `op` + /// \tparam ItemsPerThread - [inferred] the number of items processed by each thread + /// \tparam BinaryFunction - [inferred] the type of the function to apply + /// \param [in] input - array that data is loaded from partitioned across the threads in the block + /// \param [out] output - array where the result of function application will be written to + /// \param [in] op - binary function applied to the items. + /// The signature of the function should be equivalent to the following: + /// `bool f(const T &a, const T &b)` The signature does not need to have + /// `const &` but the function object must not modify the objects passed to it. + /// \param [in] tile_successor - the item after the tile, will be used as the input + /// of the last application of `op` + /// \param storage - reference to a temporary storage object of type #storage_type + /// \par Storage reuse + /// Synchronization barrier should be placed before `storage` is reused + /// or repurposed: `__syncthreads()` or \link syncthreads() rocprim::syncthreads() \endlink. 
+ template + ROCPRIM_DEVICE ROCPRIM_INLINE void subtract_right(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + const BinaryFunction op, + const T tile_successor, + storage_type& storage) + { + static constexpr auto as_flags = false; + static constexpr auto reversed = false; + static constexpr auto with_successor = true; + + base_type::template apply_right( + input, output, op, tile_successor, storage.get().right); + } + + /// \brief Apply a function to each consecutive pair of elements partitioned across threads in + /// the block and write the output to the position of the right item, in a partial tile. + /// + /// \code + /// // For each item i in [0, valid_items) across threads in a block + /// output[i] = op(input[i], input[i + 1]); + /// // Just copy "invalid" items in [valid_items, block_size * ItemsPerThread) + /// output[i] = input[i] + /// \endcode + /// + /// \tparam Output - [inferred] the type of output, must be assignable from the result of `op` + /// \tparam ItemsPerThread - [inferred] the number of items processed by each thread + /// \tparam BinaryFunction - [inferred] the type of the function to apply + /// \param [in] input - array that data is loaded from partitioned across the threads in the block + /// \param [out] output - array where the result of function application will be written to + /// \param [in] op - binary function applied to the items. + /// The signature of the function should be equivalent to the following: + /// `bool f(const T &a, const T &b)` The signature does not need to have + /// `const &` but the function object must not modify the objects passed to it. + /// \param [in] valid_items - number of items in the block which are considered "valid" and will + /// be used. 
Must be less or equal to `BlockSize` * `ItemsPerThread` + /// \param storage - reference to a temporary storage object of type #storage_type + /// \par Storage reuse + /// Synchronization barrier should be placed before `storage` is reused + /// or repurposed: `__syncthreads()` or \link syncthreads() rocprim::syncthreads() \endlink. + template + ROCPRIM_DEVICE ROCPRIM_INLINE void subtract_right_partial(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + const BinaryFunction op, + const unsigned int valid_items, + storage_type& storage) + { + static constexpr auto as_flags = false; + static constexpr auto reversed = false; + + base_type::template apply_right_partial( + input, output, op, valid_items, storage.get().right); + } +}; + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group blockmodule + +#endif // ROCPRIM_BLOCK_BLOCK_ADJACENT_DIFFERENCE_HPP_ diff --git a/3rdparty/cub/rocprim/block/block_discontinuity.hpp b/3rdparty/cub/rocprim/block/block_discontinuity.hpp new file mode 100644 index 0000000000000000000000000000000000000000..13a1d6809ee1955eb74ee414a3726a011a342ebb --- /dev/null +++ b/3rdparty/cub/rocprim/block/block_discontinuity.hpp @@ -0,0 +1,803 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_BLOCK_BLOCK_DISCONTINUITY_HPP_ +#define ROCPRIM_BLOCK_BLOCK_DISCONTINUITY_HPP_ + + +#include "detail/block_adjacent_difference_impl.hpp" + +#include "../config.hpp" +#include "../detail/various.hpp" + + + +/// \addtogroup blockmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief The \p block_discontinuity class is a block level parallel primitive which provides +/// methods for flagging items that are discontinued within an ordered set of items across +/// threads in a block. +/// +/// \tparam T - the input type. +/// \tparam BlockSize - the number of threads in a block. +/// +/// \par Overview +/// * There are two types of flags: +/// * Head flags. +/// * Tail flags. +/// * The above flags are used to differentiate items from their predecessors or successors. +/// * E.g. Head flags are convenient for differentiating disjoint data segments as part of a +/// segmented reduction/scan. +/// +/// \par Examples +/// \parblock +/// In the examples discontinuity operation is performed on block of 128 threads, using type +/// \p int. +/// +/// \code{.cpp} +/// __global__ void example_kernel(...) +/// { +/// // specialize discontinuity for int and a block of 128 threads +/// using block_discontinuity_int = rocprim::block_discontinuity; +/// // allocate storage in shared memory +/// __shared__ block_discontinuity_int::storage_type storage; +/// +/// // segment of consecutive items to be used +/// int input[8]; +/// ... 
+/// int head_flags[8]; +/// block_discontinuity_int b_discontinuity; +/// using flag_op_type = typename rocprim::greater; +/// b_discontinuity.flag_heads(head_flags, input, flag_op_type(), storage); +/// ... +/// } +/// \endcode +/// \endparblock +template< + class T, + unsigned int BlockSizeX, + unsigned int BlockSizeY = 1, + unsigned int BlockSizeZ = 1 +> +class block_discontinuity +#ifndef DOXYGEN_SHOULD_SKIP_THIS // hide implementation detail from documentation + : private detail::block_adjacent_difference_impl +#endif // DOXYGEN_SHOULD_SKIP_THIS +{ + using base_type = detail::block_adjacent_difference_impl; + + static constexpr unsigned BlockSize = base_type::BlockSize; + // Struct used for creating a raw_storage object for this primitive's temporary storage. + struct storage_type_ + { + typename base_type::storage_type left; + typename base_type::storage_type right; + }; + +public: + + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by related parallel primitive. + /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords __shared__. It can be aliased to + /// an externally allocated memory, or be a part of a union type with other storage types + /// to increase shared memory reusability. + #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen + using storage_type = detail::raw_storage; + #else + using storage_type = storage_type_; + #endif + + /// \brief Tags \p head_flags that indicate discontinuities between items partitioned + /// across the thread block, where the first item has no reference and is always + /// flagged. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. 
+ /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] head_flags - array that contains the head flags. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_discontinuity_int = rocprim::block_discontinuity; + /// // allocate storage in shared memory + /// __shared__ block_discontinuity_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// ... + /// int head_flags[8]; + /// block_discontinuity_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_heads(head_flags, input, flag_op_type(), storage); + /// ... 
+ /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_heads(Flag (&head_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = false; + static constexpr auto with_predecessor = false; + base_type::template apply_left( + input, head_flags, flag_op, input[0] /* predecessor */, storage.get().left); + } + + /// \overload + /// This overload does not take a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_heads(Flag (&head_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_heads(head_flags, input, flag_op, storage); + } + + /// \brief Tags \p head_flags that indicate discontinuities between items partitioned + /// across the thread block, where the first item of the first thread is compared against + /// a \p tile_predecessor_item. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] head_flags - array that contains the head flags. + /// \param [in] tile_predecessor_item - first tile item from thread to be compared + /// against. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. 
+ /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_discontinuity_int = rocprim::block_discontinuity; + /// // allocate storage in shared memory + /// __shared__ block_discontinuity_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// int tile_item = 0; + /// if (threadIdx.x == 0) + /// { + /// tile_item = ... + /// } + /// ... + /// int head_flags[8]; + /// block_discontinuity_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_heads(head_flags, tile_item, input, flag_op_type(), + /// storage); + /// ... + /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_heads(Flag (&head_flags)[ItemsPerThread], + T tile_predecessor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = false; + static constexpr auto with_predecessor = true; + base_type::template apply_left( + input, head_flags, flag_op, tile_predecessor_item, storage.get().left); + } + + /// \overload + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. 
+ template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_heads(Flag (&head_flags)[ItemsPerThread], + T tile_predecessor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_heads(head_flags, tile_predecessor_item, input, flag_op, storage); + } + + /// \brief Tags \p tail_flags that indicate discontinuities between items partitioned + /// across the thread block, where the last item has no reference and is always + /// flagged. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] tail_flags - array that contains the tail flags. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_discontinuity_int = rocprim::block_discontinuity; + /// // allocate storage in shared memory + /// __shared__ block_discontinuity_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// ... 
+ /// int tail_flags[8]; + /// block_discontinuity_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_tails(tail_flags, input, flag_op_type(), storage); + /// ... + /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_tails(Flag (&tail_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = false; + static constexpr auto with_successor = false; + base_type::template apply_right( + input, tail_flags, flag_op, input[0] /* successor */, storage.get().right); + } + + /// \overload + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_tails(Flag (&tail_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_tails(tail_flags, input, flag_op, storage); + } + + /// \brief Tags \p tail_flags that indicate discontinuities between items partitioned + /// across the thread block, where the last item of the last thread is compared against + /// a \p tile_successor_item. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] tail_flags - array that contains the tail flags. + /// \param [in] tile_successor_item - last tile item from thread to be compared + /// against. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. 
+ /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_discontinuity_int = rocprim::block_discontinuity; + /// // allocate storage in shared memory + /// __shared__ block_discontinuity_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// int tile_item = 0; + /// if (threadIdx.x == 0) + /// { + /// tile_item = ... + /// } + /// ... + /// int tail_flags[8]; + /// block_discontinuity_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_tails(tail_flags, tile_item, input, flag_op_type(), + /// storage); + /// ... + /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_tails(Flag (&tail_flags)[ItemsPerThread], + T tile_successor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = false; + static constexpr auto with_successor = true; + base_type::template apply_right( + input, tail_flags, flag_op, tile_successor_item, storage.get().right); + } + + /// \overload + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. 
Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_tails(Flag (&tail_flags)[ItemsPerThread], + T tile_successor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_tails(tail_flags, tile_successor_item, input, flag_op, storage); + } + + /// \brief Tags both \p head_flags and\p tail_flags that indicate discontinuities + /// between items partitioned across the thread block. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] head_flags - array that contains the head flags. + /// \param [out] tail_flags - array that contains the tail flags. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) 
+ /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_discontinuity_int = rocprim::block_discontinuity; + /// // allocate storage in shared memory + /// __shared__ block_discontinuity_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// ... + /// int head_flags[8]; + /// int tail_flags[8]; + /// block_discontinuity_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_heads_and_tails(head_flags, tail_flags, input, + /// flag_op_type(), storage); + /// ... + /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + Flag (&tail_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = false; + static constexpr auto with_predecessor = false; + static constexpr auto with_successor = false; + + // Copy items in case head_flags is aliased with input + T items[ItemsPerThread]; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) { + items[i] = input[i]; + } + + base_type::template apply_left( + items, head_flags, flag_op, items[0] /*predecessor*/, storage.get().left); + + base_type::template apply_right( + items, tail_flags, flag_op, items[0] /*successor*/, storage.get().right); + } + + /// \overload + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. 
+ template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + Flag (&tail_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_heads_and_tails(head_flags, tail_flags, input, flag_op, storage); + } + + /// \brief Tags both \p head_flags and\p tail_flags that indicate discontinuities + /// between items partitioned across the thread block, where the last item of the + /// last thread is compared against a \p tile_successor_item. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] head_flags - array that contains the head flags. + /// \param [out] tail_flags - array that contains the tail flags. + /// \param [in] tile_successor_item - last tile item from thread to be compared + /// against. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) 
+ /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_discontinuity_int = rocprim::block_discontinuity; + /// // allocate storage in shared memory + /// __shared__ block_discontinuity_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// int tile_item = 0; + /// if (threadIdx.x == 0) + /// { + /// tile_item = ... + /// } + /// ... + /// int head_flags[8]; + /// int tail_flags[8]; + /// block_discontinuity_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_heads_and_tails(head_flags, tail_flags, tile_item, + /// input, flag_op_type(), + /// storage); + /// ... + /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + Flag (&tail_flags)[ItemsPerThread], + T tile_successor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = false; + static constexpr auto with_predecessor = false; + static constexpr auto with_successor = true; + + // Copy items in case head_flags is aliased with input + T items[ItemsPerThread]; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) { + items[i] = input[i]; + } + + base_type::template apply_left( + items, head_flags, flag_op, items[0] /*predecessor*/, storage.get().left); + + base_type::template apply_right( + items, tail_flags, flag_op, tile_successor_item, storage.get().right); + } + + /// \overload + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. 
+ template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + Flag (&tail_flags)[ItemsPerThread], + T tile_successor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_heads_and_tails(head_flags, tail_flags, tile_successor_item, input, flag_op, storage); + } + + /// \brief Tags both \p head_flags and\p tail_flags that indicate discontinuities + /// between items partitioned across the thread block, where the first item of the + /// first thread is compared against a \p tile_predecessor_item. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] head_flags - array that contains the head flags. + /// \param [in] tile_predecessor_item - first tile item from thread to be compared + /// against. + /// \param [out] tail_flags - array that contains the tail flags. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) 
+ /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_discontinuity_int = rocprim::block_discontinuity; + /// // allocate storage in shared memory + /// __shared__ block_discontinuity_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// int tile_item = 0; + /// if (threadIdx.x == 0) + /// { + /// tile_item = ... + /// } + /// ... + /// int head_flags[8]; + /// int tail_flags[8]; + /// block_discontinuity_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_heads_and_tails(head_flags, tile_item, tail_flags, + /// input, flag_op_type(), + /// storage); + /// ... + /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + T tile_predecessor_item, + Flag (&tail_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = false; + static constexpr auto with_predecessor = true; + static constexpr auto with_successor = false; + + // Copy items in case head_flags is aliased with input + T items[ItemsPerThread]; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) { + items[i] = input[i]; + } + + base_type::template apply_left( + items, head_flags, flag_op, tile_predecessor_item, storage.get().left); + + base_type::template apply_right( + items, tail_flags, flag_op, items[0] /*successor*/, storage.get().right); + } + + /// \overload + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. 
+ template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + T tile_predecessor_item, + Flag (&tail_flags)[ItemsPerThread], + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_heads_and_tails(head_flags, tile_predecessor_item, tail_flags, input, flag_op, storage); + } + + /// \brief Tags both \p head_flags and\p tail_flags that indicate discontinuities + /// between items partitioned across the thread block, where the first and last items of + /// the first and last thread is compared against a \p tile_predecessor_item and + /// a \p tile_successor_item. + /// + /// \tparam ItemsPerThread - [inferred] the number of items to be processed by + /// each thread. + /// \tparam Flag - [inferred] the flag type. + /// \tparam FlagOp - [inferred] type of binary function used for flagging. + /// + /// \param [out] head_flags - array that contains the head flags. + /// \param [in] tile_predecessor_item - first tile item from thread to be compared + /// against. + /// \param [out] tail_flags - array that contains the tail flags. + /// \param [in] tile_successor_item - last tile item from thread to be compared + /// against. + /// \param [in] input - array that data is loaded from. + /// \param [in] flag_op - binary operation function object that will be used for flagging. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. + /// The signature does not need to have const &, but function object + /// must not modify the objects passed to it. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. 
+ /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize discontinuity for int and a block of 128 threads + /// using block_discontinuity_int = rocprim::block_discontinuity; + /// // allocate storage in shared memory + /// __shared__ block_discontinuity_int::storage_type storage; + /// + /// // segment of consecutive items to be used + /// int input[8]; + /// int tile_predecessor_item = 0; + /// int tile_successor_item = 0; + /// if (threadIdx.x == 0) + /// { + /// tile_predecessor_item = ... + /// tile_successor_item = ... + /// } + /// ... + /// int head_flags[8]; + /// int tail_flags[8]; + /// block_discontinuity_int b_discontinuity; + /// using flag_op_type = typename rocprim::greater; + /// b_discontinuity.flag_heads_and_tails(head_flags, tile_predecessor_item, + /// tail_flags, tile_successor_item, + /// input, flag_op_type(), + /// storage); + /// ... + /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + T tile_predecessor_item, + Flag (&tail_flags)[ItemsPerThread], + T tile_successor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op, + storage_type& storage) + { + static constexpr auto as_flags = true; + static constexpr auto reversed = false; + static constexpr auto with_predecessor = true; + static constexpr auto with_successor = true; + + // Copy items in case head_flags is aliased with input + T items[ItemsPerThread]; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) { + items[i] = input[i]; + } + + base_type::template apply_left( + items, head_flags, flag_op, tile_predecessor_item, storage.get().left); + + base_type::template apply_right( + items, tail_flags, flag_op, tile_successor_item, storage.get().right); + } + + /// \overload + /// This overload does not accept a reference to temporary storage, instead it is declared as + /// part of the function itself. 
Note that this does NOT decrease the shared memory requirements + /// of a kernel using this function. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], + T tile_predecessor_item, + Flag (&tail_flags)[ItemsPerThread], + T tile_successor_item, + const T (&input)[ItemsPerThread], + FlagOp flag_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + flag_heads_and_tails( + head_flags, tile_predecessor_item, tail_flags, tile_successor_item, + input, flag_op, storage + ); + } +}; + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group blockmodule + +#endif // ROCPRIM_BLOCK_BLOCK_DISCONTINUITY_HPP_ diff --git a/3rdparty/cub/rocprim/block/block_exchange.hpp b/3rdparty/cub/rocprim/block/block_exchange.hpp new file mode 100644 index 0000000000000000000000000000000000000000..70f4f00134344a9aa7345ad037b166b997b10b16 --- /dev/null +++ b/3rdparty/cub/rocprim/block/block_exchange.hpp @@ -0,0 +1,769 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_BLOCK_BLOCK_EXCHANGE_HPP_ +#define ROCPRIM_BLOCK_BLOCK_EXCHANGE_HPP_ + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" +#include "../types.hpp" + +/// \addtogroup blockmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief The \p block_exchange class is a block level parallel primitive which provides +/// methods for rearranging items partitioned across threads in a block. +/// +/// \tparam T - the input type. +/// \tparam BlockSize - the number of threads in a block. +/// \tparam ItemsPerThread - the number of items contributed by each thread. +/// +/// \par Overview +/// * The \p block_exchange class supports the following rearrangement methods: +/// * Transposing a blocked arrangement to a striped arrangement. +/// * Transposing a striped arrangement to a blocked arrangement. +/// * Transposing a blocked arrangement to a warp-striped arrangement. +/// * Transposing a warp-striped arrangement to a blocked arrangement. +/// * Scattering items to a blocked arrangement. +/// * Scattering items to a striped arrangement. +/// * Data is automatically be padded to ensure zero bank conflicts. +/// +/// \par Examples +/// \parblock +/// In the examples exchange operation is performed on block of 128 threads, using type +/// \p int with 8 items per thread. +/// +/// \code{.cpp} +/// __global__ void example_kernel(...) +/// { +/// // specialize block_exchange for int, block of 128 threads and 8 items per thread +/// using block_exchange_int = rocprim::block_exchange; +/// // allocate storage in shared memory +/// __shared__ block_exchange_int::storage_type storage; +/// +/// int items[8]; +/// ... 
+/// block_exchange_int b_exchange; +/// b_exchange.blocked_to_striped(items, items, storage); +/// ... +/// } +/// \endcode +/// \endparblock +template< + class T, + unsigned int BlockSizeX, + unsigned int ItemsPerThread, + unsigned int BlockSizeY = 1, + unsigned int BlockSizeZ = 1 +> +class block_exchange +{ + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + // Select warp size + static constexpr unsigned int warp_size = + detail::get_min_warp_size(BlockSize, ::rocprim::device_warp_size()); + // Number of warps in block + static constexpr unsigned int warps_no = (BlockSize + warp_size - 1) / warp_size; + + // Minimize LDS bank conflicts for power-of-two strides, i.e. when items accessed + // using `thread_id * ItemsPerThread` pattern where ItemsPerThread is power of two + // (all exchanges from/to blocked). + static constexpr bool has_bank_conflicts = + ItemsPerThread >= 2 && ::rocprim::detail::is_power_of_two(ItemsPerThread); + static constexpr unsigned int banks_no = ::rocprim::detail::get_lds_banks_no(); + static constexpr unsigned int bank_conflicts_padding = + has_bank_conflicts ? (BlockSize * ItemsPerThread / banks_no) : 0; + + // Struct used for creating a raw_storage object for this primitive's temporary storage. + struct storage_type_ + { + T buffer[BlockSize * ItemsPerThread + bank_conflicts_padding]; + }; + +public: + + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by related parallel primitive. + /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords __shared__. It can be aliased to + /// an externally allocated memory, or be a part of a union type with other storage types + /// to increase shared memory reusability. 
+ #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen + using storage_type = detail::raw_storage; + #else + using storage_type = storage_type_; // only for Doxygen + #endif + + /// \brief Transposes a blocked arrangement of items to a striped arrangement + /// across the thread block. + /// + /// \tparam U - [inferred] the output type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void blocked_to_striped(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread]) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + blocked_to_striped(input, output, storage); + } + + /// \brief Transposes a blocked arrangement of items to a striped arrangement + /// across the thread block, using temporary storage. + /// + /// \tparam U - [inferred] the output type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block_exchange for int, block of 128 threads and 8 items per thread + /// using block_exchange_int = rocprim::block_exchange; + /// // allocate storage in shared memory + /// __shared__ block_exchange_int::storage_type storage; + /// + /// int items[8]; + /// ... + /// block_exchange_int b_exchange; + /// b_exchange.blocked_to_striped(items, items, storage); + /// ... 
+ /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void blocked_to_striped(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + storage_type& storage) + { + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + storage_type_& storage_ = storage.get(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + storage_.buffer[index(flat_id * ItemsPerThread + i)] = input[i]; + } + ::rocprim::syncthreads(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = storage_.buffer[index(i * BlockSize + flat_id)]; + } + } + + /// \brief Transposes a striped arrangement of items to a blocked arrangement + /// across the thread block. + /// + /// \tparam U - [inferred] the output type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void striped_to_blocked(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread]) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + striped_to_blocked(input, output, storage); + } + + /// \brief Transposes a striped arrangement of items to a blocked arrangement + /// across the thread block, using temporary storage. + /// + /// \tparam U - [inferred] the output type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) 
+ /// { + /// // specialize block_exchange for int, block of 128 threads and 8 items per thread + /// using block_exchange_int = rocprim::block_exchange; + /// // allocate storage in shared memory + /// __shared__ block_exchange_int::storage_type storage; + /// + /// int items[8]; + /// ... + /// block_exchange_int b_exchange; + /// b_exchange.striped_to_blocked(items, items, storage); + /// ... + /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void striped_to_blocked(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + storage_type& storage) + { + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + storage_type_& storage_ = storage.get(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + storage_.buffer[index(i * BlockSize + flat_id)] = input[i]; + } + ::rocprim::syncthreads(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = storage_.buffer[index(flat_id * ItemsPerThread + i)]; + } + } + + /// \brief Transposes a blocked arrangement of items to a warp-striped arrangement + /// across the thread block. + /// + /// \tparam U - [inferred] the output type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void blocked_to_warp_striped(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread]) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + blocked_to_warp_striped(input, output, storage); + } + + /// \brief Transposes a blocked arrangement of items to a warp-striped arrangement + /// across the thread block, using temporary storage. + /// + /// \tparam U - [inferred] the output type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [in] storage - reference to a temporary storage object of type storage_type. 
+ /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block_exchange for int, block of 128 threads and 8 items per thread + /// using block_exchange_int = rocprim::block_exchange; + /// // allocate storage in shared memory + /// __shared__ block_exchange_int::storage_type storage; + /// + /// int items[8]; + /// ... + /// block_exchange_int b_exchange; + /// b_exchange.blocked_to_warp_striped(items, items, storage); + /// ... + /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void blocked_to_warp_striped(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + storage_type& storage) + { + constexpr unsigned int items_per_warp = warp_size * ItemsPerThread; + const unsigned int lane_id = ::rocprim::lane_id(); + const unsigned int warp_id = ::rocprim::warp_id(); + const unsigned int current_warp_size = get_current_warp_size(); + const unsigned int offset = warp_id * items_per_warp; + storage_type_& storage_ = storage.get(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + storage_.buffer[index(offset + lane_id * ItemsPerThread + i)] = input[i]; + } + + ::rocprim::wave_barrier(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = storage_.buffer[index(offset + i * current_warp_size + lane_id)]; + } + } + + /// \brief Transposes a warp-striped arrangement of items to a blocked arrangement + /// across the thread block. + /// + /// \tparam U - [inferred] the output type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. 
+ template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void warp_striped_to_blocked(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread]) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + warp_striped_to_blocked(input, output, storage); + } + + /// \brief Transposes a warp-striped arrangement of items to a blocked arrangement + /// across the thread block, using temporary storage. + /// + /// \tparam U - [inferred] the output type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block_exchange for int, block of 128 threads and 8 items per thread + /// using block_exchange_int = rocprim::block_exchange; + /// // allocate storage in shared memory + /// __shared__ block_exchange_int::storage_type storage; + /// + /// int items[8]; + /// ... + /// block_exchange_int b_exchange; + /// b_exchange.warp_striped_to_blocked(items, items, storage); + /// ... 
+ /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void warp_striped_to_blocked(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + storage_type& storage) + { + constexpr unsigned int items_per_warp = warp_size * ItemsPerThread; + const unsigned int lane_id = ::rocprim::lane_id(); + const unsigned int warp_id = ::rocprim::warp_id(); + const unsigned int current_warp_size = get_current_warp_size(); + const unsigned int offset = warp_id * items_per_warp; + storage_type_& storage_ = storage.get(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + storage_.buffer[index(offset + i * current_warp_size + lane_id)] = input[i]; + } + + ::rocprim::wave_barrier(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = storage_.buffer[index(offset + lane_id * ItemsPerThread + i)]; + } + } + + /// \brief Scatters items to a blocked arrangement based on their ranks + /// across the thread block. + /// + /// \tparam U - [inferred] the output type. + /// \tparam Offset - [inferred] the rank type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [out] ranks - array that has rank of data. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void scatter_to_blocked(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const Offset (&ranks)[ItemsPerThread]) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + scatter_to_blocked(input, output, ranks, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void gather_from_striped(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const Offset (&ranks)[ItemsPerThread]) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + gather_from_striped(input, output, ranks, storage); + } + + /// \brief Scatters items to a blocked arrangement based on their ranks + /// across the thread block, using temporary storage. + /// + /// \tparam U - [inferred] the output type. 
+ /// \tparam Offset - [inferred] the rank type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [out] ranks - array that has rank of data. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block_exchange for int, block of 128 threads and 8 items per thread + /// using block_exchange_int = rocprim::block_exchange; + /// // allocate storage in shared memory + /// __shared__ block_exchange_int::storage_type storage; + /// + /// int items[8]; + /// int ranks[8]; + /// ... + /// block_exchange_int b_exchange; + /// b_exchange.scatter_to_blocked(items, items, ranks, storage); + /// ... 
+ /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scatter_to_blocked(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const Offset (&ranks)[ItemsPerThread], + storage_type& storage) + { + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + storage_type_& storage_ = storage.get(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + const Offset rank = ranks[i]; + storage_.buffer[index(rank)] = input[i]; + } + ::rocprim::syncthreads(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = storage_.buffer[index(flat_id * ItemsPerThread + i)]; + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void gather_from_striped(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const Offset (&ranks)[ItemsPerThread], + storage_type& storage) + { + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + storage_type_& storage_ = storage.get(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + storage_.buffer[index(i * BlockSize + flat_id)] = input[i]; + } + ::rocprim::syncthreads(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + const Offset rank = ranks[i]; + output[i] = storage_.buffer[index(rank)]; + } + } + + /// \brief Scatters items to a striped arrangement based on their ranks + /// across the thread block. + /// + /// \tparam U - [inferred] the output type. + /// \tparam Offset - [inferred] the rank type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [out] ranks - array that has rank of data. 
+ template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void scatter_to_striped(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const Offset (&ranks)[ItemsPerThread]) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + scatter_to_striped(input, output, ranks, storage); + } + + /// \brief Scatters items to a striped arrangement based on their ranks + /// across the thread block, using temporary storage. + /// + /// \tparam U - [inferred] the output type. + /// \tparam Offset - [inferred] the rank type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [out] ranks - array that has rank of data. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block_exchange for int, block of 128 threads and 8 items per thread + /// using block_exchange_int = rocprim::block_exchange; + /// // allocate storage in shared memory + /// __shared__ block_exchange_int::storage_type storage; + /// + /// int items[8]; + /// int ranks[8]; + /// ... + /// block_exchange_int b_exchange; + /// b_exchange.scatter_to_striped(items, items, ranks, storage); + /// ... 
+ /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scatter_to_striped(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const Offset (&ranks)[ItemsPerThread], + storage_type& storage) + { + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + storage_type_& storage_ = storage.get(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + const Offset rank = ranks[i]; + storage_.buffer[rank] = input[i]; + } + ::rocprim::syncthreads(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = storage_.buffer[i * BlockSize + flat_id]; + } + } + + /// \brief Scatters items to a striped arrangement based on their ranks + /// across the thread block, guarded by rank. + /// + /// \par Overview + /// * Items with rank -1 are not scattered. + /// + /// \tparam U - [inferred] the output type. + /// \tparam Offset - [inferred] the rank type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [in] ranks - array that has rank of data. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void scatter_to_striped_guarded(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const Offset (&ranks)[ItemsPerThread]) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + scatter_to_striped_guarded(input, output, ranks, storage); + } + + /// \brief Scatters items to a striped arrangement based on their ranks + /// across the thread block, guarded by rank, using temporary storage. + /// + /// \par Overview + /// * Items with rank -1 are not scattered. + /// + /// \tparam U - [inferred] the output type. + /// \tparam Offset - [inferred] the rank type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [in] ranks - array that has rank of data. + /// \param [in] storage - reference to a temporary storage object of type storage_type. 
+ /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block_exchange for int, block of 128 threads and 8 items per thread + /// using block_exchange_int = rocprim::block_exchange; + /// // allocate storage in shared memory + /// __shared__ block_exchange_int::storage_type storage; + /// + /// int items[8]; + /// int ranks[8]; + /// ... + /// block_exchange_int b_exchange; + /// b_exchange.scatter_to_striped_guarded(items, items, ranks, storage); + /// ... + /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scatter_to_striped_guarded(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const Offset (&ranks)[ItemsPerThread], + storage_type& storage) + { + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + storage_type_& storage_ = storage.get(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + const Offset rank = ranks[i]; + if(rank >= 0) + { + storage_.buffer[rank] = input[i]; + } + } + ::rocprim::syncthreads(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = storage_.buffer[i * BlockSize + flat_id]; + } + } + + /// \brief Scatters items to a striped arrangement based on their ranks + /// across the thread block, with a flag to denote validity. + /// + /// \tparam U - [inferred] the output type. + /// \tparam Offset - [inferred] the rank type. + /// \tparam ValidFlag - [inferred] the validity flag type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [in] ranks - array that has rank of data. + /// \param [in] is_valid - array that has flags to denote validity. 
+ template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void scatter_to_striped_flagged(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const Offset (&ranks)[ItemsPerThread], + const ValidFlag (&is_valid)[ItemsPerThread]) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + scatter_to_striped_flagged(input, output, ranks, is_valid, storage); + } + + /// \brief Scatters items to a striped arrangement based on their ranks + /// across the thread block, with a flag to denote validity, using temporary + /// storage. + /// + /// \tparam U - [inferred] the output type. + /// \tparam Offset - [inferred] the rank type. + /// \tparam ValidFlag - [inferred] the validity flag type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [in] ranks - array that has rank of data. + /// \param [in] is_valid - array that has flags to denote validity. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block_exchange for int, block of 128 threads and 8 items per thread + /// using block_exchange_int = rocprim::block_exchange; + /// // allocate storage in shared memory + /// __shared__ block_exchange_int::storage_type storage; + /// + /// int items[8]; + /// int ranks[8]; + /// int flags[8]; + /// ... + /// block_exchange_int b_exchange; + /// b_exchange.scatter_to_striped_flagged(items, items, ranks, flags, storage); + /// ... 
+ /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scatter_to_striped_flagged(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const Offset (&ranks)[ItemsPerThread], + const ValidFlag (&is_valid)[ItemsPerThread], + storage_type& storage) + { + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + storage_type_& storage_ = storage.get(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + const Offset rank = ranks[i]; + if(is_valid[i]) + { + storage_.buffer[rank] = input[i]; + } + } + ::rocprim::syncthreads(); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = storage_.buffer[i * BlockSize + flat_id]; + } + } + +private: + + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int get_current_warp_size() const + { + const unsigned int warp_id = ::rocprim::warp_id(); + return (warp_id == warps_no - 1) + ? (BlockSize % warp_size > 0 ? BlockSize % warp_size : warp_size) + : warp_size; + } + + // Change index to minimize LDS bank conflicts if necessary + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int index(unsigned int n) + { + // Move every 32-bank wide "row" (32 banks * 4 bytes) by one item + return has_bank_conflicts ? (n + n / banks_no) : n; + } +}; + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group blockmodule + +#endif // ROCPRIM_BLOCK_BLOCK_EXCHANGE_HPP_ diff --git a/3rdparty/cub/rocprim/block/block_histogram.hpp b/3rdparty/cub/rocprim/block/block_histogram.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3135806b80a021a6b3ecc9af61ee736d5b762787 --- /dev/null +++ b/3rdparty/cub/rocprim/block/block_histogram.hpp @@ -0,0 +1,328 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_BLOCK_BLOCK_HISTOGRAM_HPP_ +#define ROCPRIM_BLOCK_BLOCK_HISTOGRAM_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" + +#include "detail/block_histogram_atomic.hpp" +#include "detail/block_histogram_sort.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup blockmodule +/// @{ + +/// \brief Available algorithms for block_histogram primitive. +enum class block_histogram_algorithm +{ + /// Atomic addition is used to update bin count directly. + /// \par Performance Notes: + /// * Performance is dependent on hardware implementation of atomic addition. + /// * Performance may decrease for non-uniform random input distributions + /// where many concurrent updates may be made to the same bin counter. + using_atomic, + + /// A two-phase operation is used:- + /// * Data is sorted using radix-sort. 
+ /// * "Runs" of same-valued keys are detected using discontinuity; run-lengths + /// are bin counts. + /// \par Performance Notes: + /// * Performance is consistent regardless of sample bin distribution. + using_sort, + + /// \brief Default block_histogram algorithm. + default_algorithm = using_atomic, +}; + +namespace detail +{ + +// Selector for block_histogram algorithm which gives block histogram implementation +// type based on passed block_histogram_algorithm enum +template +struct select_block_histogram_impl; + +template<> +struct select_block_histogram_impl +{ + template + using type = block_histogram_atomic; +}; + +template<> +struct select_block_histogram_impl +{ + template + using type = block_histogram_sort; +}; + +} // end namespace detail + +/// \brief The block_histogram class is a block level parallel primitive which provides methods +/// for constructing block-wide histograms from items partitioned across threads in a block. +/// +/// \tparam T - the input/output type. +/// \tparam BlockSize - the number of threads in a block. +/// \tparam ItemsPerThread - the number of items to be processed by each thread. +/// \tparam Bins - the number of bins within the histogram. +/// \tparam Algorithm - selected histogram algorithm, block_histogram_algorithm::default_algorithm by default. +/// +/// \par Overview +/// * block_histogram has two alternative implementations: \p block_histogram_algorithm::using_atomic +/// and block_histogram_algorithm::using_sort. +/// +/// \par Examples +/// \parblock +/// In the examples histogram operation is performed on block of 192 threads, each provides +/// one \p int value, result is returned using the same variable as for input. +/// +/// \code{.cpp} +/// __global__ void example_kernel(...) +/// { +/// // specialize block_histogram for int, logical block of 192 threads, +/// // 2 items per thread and a bin size of 192. 
+/// using block_histogram_int = rocprim::block_histogram; +/// // allocate storage in shared memory +/// __shared__ block_histogram_int::storage_type storage; +/// __shared__ int hist[192]; +/// +/// int value[2]; +/// ... +/// // execute histogram +/// block_histogram_int().histogram( +/// value, // input +/// hist, // output +/// storage +/// ); +/// ... +/// } +/// \endcode +/// \endparblock +template< + class T, + unsigned int BlockSizeX, + unsigned int ItemsPerThread, + unsigned int Bins, + block_histogram_algorithm Algorithm = block_histogram_algorithm::default_algorithm, + unsigned int BlockSizeY = 1, + unsigned int BlockSizeZ = 1 +> +class block_histogram +#ifndef DOXYGEN_SHOULD_SKIP_THIS + : private detail::select_block_histogram_impl::template type +#endif +{ + using base_type = typename detail::select_block_histogram_impl::template type; + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; +public: + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by related parallel primitive. + /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords __shared__. It can be aliased to + /// an externally allocated memory, or be a part of a union type with other storage types + /// to increase shared memory reusability. + using storage_type = typename base_type::storage_type; + + /// \brief Initialize histogram counters to zero. + /// + /// \tparam Counter - [inferred] counter type of histogram. + /// + /// \param [out] hist - histogram bin count. 
+ template + ROCPRIM_DEVICE ROCPRIM_INLINE + void init_histogram(Counter hist[Bins]) + { + const auto flat_tid = ::rocprim::flat_block_thread_id(); + + ROCPRIM_UNROLL + for(unsigned int offset = 0; offset < Bins; offset += BlockSize) + { + const unsigned int offset_tid = offset + flat_tid; + if(offset_tid < Bins) + { + hist[offset_tid] = Counter(); + } + } + } + + /// \brief Update an existing block-wide histogram. Each thread composites an array of + /// input elements. + /// + /// \tparam Counter - [inferred] counter type of histogram. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] hist - histogram bin count. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// In the examples histogram operation is performed on block of 192 threads, each provides + /// one \p int value, result is returned using the same variable as for input. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block_histogram for int, logical block of 192 threads, + /// // 2 items per thread and a bin size of 192. + /// using block_histogram_int = rocprim::block_histogram; + /// // allocate storage in shared memory + /// __shared__ block_histogram_int::storage_type storage; + /// __shared__ int hist[192]; + /// + /// int value[2]; + /// ... + /// // initialize histogram + /// block_histogram_int().init_histogram( + /// hist // output + /// ); + /// + /// rocprim::syncthreads(); + /// + /// // update histogram + /// block_histogram_int().composite( + /// value, // input + /// hist, // output + /// storage + /// ); + /// ... 
+ /// } + /// \endcode + /// \endparblock + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void composite(T (&input)[ItemsPerThread], + Counter hist[Bins], + storage_type& storage) + { + base_type::composite(input, hist, storage); + } + + /// \overload + /// \brief Update an existing block-wide histogram. Each thread composites an array of + /// input elements. + /// + /// * This overload does not accept storage argument. Required shared memory is + /// allocated by the method itself. + /// + /// \tparam Counter - [inferred] counter type of histogram. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] hist - histogram bin count. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void composite(T (&input)[ItemsPerThread], + Counter hist[Bins]) + { + base_type::composite(input, hist); + } + + /// \brief Construct a new block-wide histogram. Each thread contributes an array of + /// input elements. + /// + /// \tparam Counter - [inferred] counter type of histogram. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] hist - histogram bin count. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// In the examples histogram operation is performed on block of 192 threads, each provides + /// one \p int value, result is returned using the same variable as for input. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block_histogram for int, logical block of 192 threads, + /// // 2 items per thread and a bin size of 192. 
+ /// using block_histogram_int = rocprim::block_histogram; + /// // allocate storage in shared memory + /// __shared__ block_histogram_int::storage_type storage; + /// __shared__ int hist[192]; + /// + /// int value[2]; + /// ... + /// // execute histogram + /// block_histogram_int().histogram( + /// value, // input + /// hist, // output + /// storage + /// ); + /// ... + /// } + /// \endcode + /// \endparblock + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void histogram(T (&input)[ItemsPerThread], + Counter hist[Bins], + storage_type& storage) + { + init_histogram(hist); + ::rocprim::syncthreads(); + composite(input, hist, storage); + } + + /// \overload + /// \brief Construct a new block-wide histogram. Each thread contributes an array of + /// input elements. + /// + /// * This overload does not accept storage argument. Required shared memory is + /// allocated by the method itself. + /// + /// \tparam Counter - [inferred] counter type of histogram. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] hist - histogram bin count. + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void histogram(T (&input)[ItemsPerThread], + Counter hist[Bins]) + { + init_histogram(hist); + ::rocprim::syncthreads(); + composite(input, hist); + } +}; + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group blockmodule + +#endif // ROCPRIM_BLOCK_BLOCK_HISTOGRAM_HPP_ diff --git a/3rdparty/cub/rocprim/block/block_load.hpp b/3rdparty/cub/rocprim/block/block_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7b3df9929bf28e60e5021b6e26d3d48fc876403e --- /dev/null +++ b/3rdparty/cub/rocprim/block/block_load.hpp @@ -0,0 +1,891 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_BLOCK_BLOCK_LOAD_HPP_ +#define ROCPRIM_BLOCK_BLOCK_LOAD_HPP_ + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" +#include "../types.hpp" + +#include "block_load_func.hpp" +#include "block_exchange.hpp" + +/// \addtogroup blockmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief \p block_load_method enumerates the methods available to load data +/// from continuous memory into a blocked arrangement of items across the thread block +enum class block_load_method +{ + /// Data from continuous memory is loaded into a blocked arrangement of items. + /// \par Performance Notes: + /// * Performance decreases with increasing number of items per thread (stride + /// between reads), because of reduced memory coalescing. + block_load_direct, + + /// A striped arrangement of data is read directly from memory. 
+ block_load_striped, + + /// Data from continuous memory is loaded into a blocked arrangement of items + /// using vectorization as an optimization. + /// \par Performance Notes: + /// * Performance remains high due to increased memory coalescing, provided that + /// vectorization requirements are fulfilled. Otherwise, performance will default + /// to \p block_load_direct. + /// \par Requirements: + /// * The input offset (\p block_input) must be quad-item aligned. + /// * The following conditions will prevent vectorization and switch to default + /// \p block_load_direct: + /// * \p ItemsPerThread is odd. + /// * The datatype \p T is not a primitive or a HIP vector type (e.g. int2, + /// int4, etc. + block_load_vectorize, + + /// A striped arrangement of data from continuous memory is locally transposed + /// into a blocked arrangement of items. + /// \par Performance Notes: + /// * Performance remains high due to increased memory coalescing, regardless of the + /// number of items per thread. + /// * Performance may be better compared to \p block_load_direct and + /// \p block_load_vectorize due to reordering on local memory. + block_load_transpose, + + /// A warp-striped arrangement of data from continuous memory is locally transposed + /// into a blocked arrangement of items. + /// \par Requirements: + /// * The number of threads in the block must be a multiple of the size of hardware warp. + /// \par Performance Notes: + /// * Performance remains high due to increased memory coalescing, regardless of the + /// number of items per thread. + /// * Performance may be better compared to \p block_load_direct and + /// \p block_load_vectorize due to reordering on local memory. 
+ block_load_warp_transpose, + + /// Defaults to \p block_load_direct + default_method = block_load_direct +}; + +/// \brief The \p block_load class is a block level parallel primitive which provides methods +/// for loading data from continuous memory into a blocked arrangement of items across the thread +/// block. +/// +/// \tparam T - the input/output type. +/// \tparam BlockSize - the number of threads in a block. +/// \tparam ItemsPerThread - the number of items to be processed by +/// each thread. +/// \tparam Method - the method to load data. +/// +/// \par Overview +/// * The \p block_load class has a number of different methods to load data: +/// * [block_load_direct](\ref ::block_load_method::block_load_direct) +/// * [block_load_striped](\ref ::block_load_method::block_load_striped) +/// * [block_load_vectorize](\ref ::block_load_method::block_load_vectorize) +/// * [block_load_transpose](\ref ::block_load_method::block_load_transpose) +/// * [block_load_warp_transpose](\ref ::block_load_method::block_load_warp_transpose) +/// +/// \par Example: +/// \parblock +/// In the examples load operation is performed on block of 128 threads, using type +/// \p int and 8 items per thread. +/// +/// \code{.cpp} +/// __global__ void example_kernel(int * input, ...) +/// { +/// const int offset = blockIdx.x * 128 * 8; +/// int items[8]; +/// rocprim::block_load blockload; +/// blockload.load(input + offset, items); +/// ... +/// } +/// \endcode +/// \endparblock +template< + class T, + unsigned int BlockSizeX, + unsigned int ItemsPerThread, + block_load_method Method = block_load_method::block_load_direct, + unsigned int BlockSizeY = 1, + unsigned int BlockSizeZ = 1 +> +class block_load +{ +private: + using storage_type_ = typename ::rocprim::detail::empty_storage_type; + +public: + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by related parallel primitive. 
+ /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords \p __shared__. It can be aliased to + /// an externally allocated memory, or be a part of a union with other storage types + /// to increase shared memory reusability. + #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen + using storage_type = typename ::rocprim::detail::empty_storage_type; + #else + using storage_type = storage_type_; // only for Doxygen + #endif + + /// \brief Loads data from continuous memory into an arrangement of items across the + /// thread block. + /// + /// \tparam InputIterator - [inferred] an iterator type for input (can be a simple + /// pointer. + /// + /// \param [in] block_input - the input iterator from the thread block to load from. + /// \param [out] items - array that data is loaded to. + /// + /// \par Overview + /// * The type \p T must be such that an object of type \p InputIterator + /// can be dereferenced and then implicitly converted to \p T. + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread]) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_blocked(flat_id, block_input, items); + } + + /// \brief Loads data from continuous memory into an arrangement of items across the + /// thread block, which is guarded by range \p valid. + /// + /// \tparam InputIterator - [inferred] an iterator type for input (can be a simple + /// pointer. + /// + /// \param [in] block_input - the input iterator from the thread block to load from. 
+ /// \param [out] items - array that data is loaded to. + /// \param [in] valid - maximum range of valid numbers to load. + /// + /// \par Overview + /// * The type \p T must be such that an object of type \p InputIterator + /// can be dereferenced and then implicitly converted to \p T. + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_blocked(flat_id, block_input, items, valid); + } + + /// \brief Loads data from continuous memory into an arrangement of items across the + /// thread block, which is guarded by range with a fall-back value for out-of-bound + /// elements. + /// + /// \tparam InputIterator - [inferred] an iterator type for input (can be a simple + /// pointer. + /// \tparam Default - [inferred] The data type of the default value. + /// + /// \param [in] block_input - the input iterator from the thread block to load from. + /// \param [out] items - array that data is loaded to. + /// \param [in] valid - maximum range of valid numbers to load. + /// \param [in] out_of_bounds - default value assigned to out-of-bound items. + /// + /// \par Overview + /// * The type \p T must be such that an object of type \p InputIterator + /// can be dereferenced and then implicitly converted to \p T. 
+ template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_blocked(flat_id, block_input, items, valid, + out_of_bounds); + } + + /// \brief Loads data from continuous memory into an arrangement of items across the + /// thread block, using temporary storage. + /// + /// \tparam InputIterator - [inferred] an iterator type for input (can be a simple + /// pointer. + /// + /// \param [in] block_input - the input iterator from the thread block to load from. + /// \param [out] items - array that data is loaded to. + /// \param [in] storage - temporary storage for inputs. + /// + /// \par Overview + /// * The type \p T must be such that an object of type \p InputIterator + /// can be dereferenced and then implicitly converted to \p T. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// int items[8]; + /// using block_load_int = rocprim::block_load; + /// block_load_int bload; + /// __shared__ typename block_load_int::storage_type storage; + /// bload.load(..., items, storage); + /// ... 
+ /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + (void) storage; + load(block_input, items); + } + + /// \brief Loads data from continuous memory into an arrangement of items across the + /// thread block, which is guarded by range \p valid, using temporary storage. + /// + /// \tparam InputIterator - [inferred] an iterator type for input (can be a simple + /// pointer + /// + /// \param [in] block_input - the input iterator from the thread block to load from. + /// \param [out] items - array that data is loaded to. + /// \param [in] valid - maximum range of valid numbers to load. + /// \param [in] storage - temporary storage for inputs. + /// + /// \par Overview + /// * The type \p T must be such that an object of type \p InputIterator + /// can be dereferenced and then implicitly converted to \p T. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// int items[8]; + /// using block_load_int = rocprim::block_load; + /// block_load_int bload; + /// tile_static typename block_load_int::storage_type storage; + /// bload.load(..., items, valid, storage); + /// ... 
+ /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + (void) storage; + load(block_input, items, valid); + } + + /// \brief Loads data from continuous memory into an arrangement of items across the + /// thread block, which is guarded by range with a fall-back value for out-of-bound + /// elements, using temporary storage. + /// + /// \tparam InputIterator - [inferred] an iterator type for input (can be a simple + /// pointer. + /// \tparam Default - [inferred] The data type of the default value. + /// + /// \param [in] block_input - the input iterator from the thread block to load from. + /// \param [out] items - array that data is loaded to. + /// \param [in] valid - maximum range of valid numbers to load. + /// \param [in] out_of_bounds - default value assigned to out-of-bound items. + /// \param [in] storage - temporary storage for inputs. + /// + /// \par Overview + /// * The type \p T must be such that an object of type \p InputIterator + /// can be dereferenced and then implicitly converted to \p T. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// int items[8]; + /// using block_load_int = rocprim::block_load; + /// block_load_int bload; + /// __shared__ typename block_load_int::storage_type storage; + /// bload.load(..., items, valid, out_of_bounds, storage); + /// ... 
+ /// } + /// \endcode + template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds, + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + (void) storage; + load(block_input, items, valid, out_of_bounds); + } +}; + +/// @} +// end of group blockmodule + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +template< + class T, + unsigned int BlockSizeX, + unsigned int ItemsPerThread, + unsigned int BlockSizeY, + unsigned int BlockSizeZ + > +class block_load +{ + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + +private: + using storage_type_ = typename ::rocprim::detail::empty_storage_type; + +public: + #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen + using storage_type = typename ::rocprim::detail::empty_storage_type; + #else + using storage_type = storage_type_; // only for Doxygen + #endif + + template + ROCPRIM_DEVICE inline + void load(InputIterator block_input, + T (&items)[ItemsPerThread]) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_striped(flat_id, block_input, items); + } + + template + ROCPRIM_DEVICE inline + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be 
dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_striped(flat_id, block_input, items, valid); + } + + template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE inline + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_striped(flat_id, block_input, items, valid, + out_of_bounds); + } + + template + ROCPRIM_DEVICE inline + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + (void) storage; + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_striped(flat_id, block_input, items); + } + + template + ROCPRIM_DEVICE inline + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + (void) storage; + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_striped(flat_id, block_input, items, valid); + } + + template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE inline + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned 
int valid, + Default out_of_bounds, + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + (void) storage; + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_striped(flat_id, block_input, items, valid, + out_of_bounds); + } +}; + +template< + class T, + unsigned int BlockSizeX, + unsigned int ItemsPerThread, + unsigned int BlockSizeY, + unsigned int BlockSizeZ +> +class block_load +{ +private: + using storage_type_ = typename ::rocprim::detail::empty_storage_type; + +public: + #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen + using storage_type = typename ::rocprim::detail::empty_storage_type; + #else + using storage_type = storage_type_; // only for Doxygen + #endif + + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(T* block_input, + T (&_items)[ItemsPerThread]) + { + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_blocked_vectorized(flat_id, block_input, _items); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + U (&items)[ItemsPerThread]) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_blocked(flat_id, block_input, items); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be 
dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_blocked(flat_id, block_input, items, valid); + } + + template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_blocked(flat_id, block_input, items, valid, + out_of_bounds); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(T* block_input, + T (&items)[ItemsPerThread], + storage_type& storage) + { + (void) storage; + load(block_input, items); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + U (&items)[ItemsPerThread], + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + (void) storage; + load(block_input, items); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + (void) storage; + load(block_input, items, valid); + } + + template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], 
+ unsigned int valid, + Default out_of_bounds, + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + (void) storage; + load(block_input, items, valid, out_of_bounds); + } +}; + +template< + class T, + unsigned int BlockSizeX, + unsigned int ItemsPerThread, + unsigned int BlockSizeY, + unsigned int BlockSizeZ +> +class block_load +{ + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + +private: + using block_exchange_type = block_exchange; + +public: + using storage_type = typename block_exchange_type::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread]) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + ROCPRIM_SHARED_MEMORY storage_type storage; + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_striped(flat_id, block_input, items); + block_exchange_type().striped_to_blocked(items, items, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + ROCPRIM_SHARED_MEMORY storage_type storage; + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_striped(flat_id, block_input, items, valid); + block_exchange_type().striped_to_blocked(items, items, storage); + } + + 
template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + ROCPRIM_SHARED_MEMORY storage_type storage; + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_striped(flat_id, block_input, items, valid, + out_of_bounds); + block_exchange_type().striped_to_blocked(items, items, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_striped(flat_id, block_input, items); + block_exchange_type().striped_to_blocked(items, items, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_striped(flat_id, block_input, items, valid); + block_exchange_type().striped_to_blocked(items, items, storage); + } + + template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void 
load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds, + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_striped(flat_id, block_input, items, valid, + out_of_bounds); + block_exchange_type().striped_to_blocked(items, items, storage); + } +}; + +template< + class T, + unsigned int BlockSizeX, + unsigned int ItemsPerThread, + unsigned int BlockSizeY, + unsigned int BlockSizeZ +> +class block_load +{ + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; +private: + using block_exchange_type = block_exchange; + +public: + static_assert(BlockSize % ::rocprim::device_warp_size() == 0, + "BlockSize must be a multiple of hardware warpsize"); + + using storage_type = typename block_exchange_type::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread]) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + ROCPRIM_SHARED_MEMORY storage_type storage; + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_warp_striped(flat_id, block_input, items); + block_exchange_type().warp_striped_to_blocked(items, items, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an 
object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + ROCPRIM_SHARED_MEMORY storage_type storage; + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_warp_striped(flat_id, block_input, items, valid); + block_exchange_type().warp_striped_to_blocked(items, items, storage); + + } + + template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + ROCPRIM_SHARED_MEMORY storage_type storage; + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_warp_striped(flat_id, block_input, items, valid, + out_of_bounds); + block_exchange_type().warp_striped_to_blocked(items, items, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_warp_striped(flat_id, block_input, items); + block_exchange_type().warp_striped_to_blocked(items, items, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator 
" + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_warp_striped(flat_id, block_input, items, valid); + block_exchange_type().warp_striped_to_blocked(items, items, storage); + } + + template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds, + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + block_load_direct_warp_striped(flat_id, block_input, items, valid, + out_of_bounds); + block_exchange_type().warp_striped_to_blocked(items, items, storage); + } +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_BLOCK_BLOCK_LOAD_HPP_ diff --git a/3rdparty/cub/rocprim/block/block_load_func.hpp b/3rdparty/cub/rocprim/block/block_load_func.hpp new file mode 100644 index 0000000000000000000000000000000000000000..83ebd5d844e6691db114858e4de4f6119c46295a --- /dev/null +++ b/3rdparty/cub/rocprim/block/block_load_func.hpp @@ -0,0 +1,511 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_BLOCK_BLOCK_LOAD_FUNC_HPP_ +#define ROCPRIM_BLOCK_BLOCK_LOAD_FUNC_HPP_ + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" +#include "../types.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup blockmodule +/// @{ + +/// \brief Loads data from continuous memory into a blocked arrangement of items +/// across the thread block. +/// +/// The block arrangement is assumed to be (block-threads * \p ItemsPerThread) items +/// across a thread block. Each thread uses a \p flat_id to load a range of +/// \p ItemsPerThread into \p items. 
+/// +/// \tparam InputIterator - [inferred] an iterator type for input (can be a simple +/// pointer +/// \tparam T - [inferred] the data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_input - the input iterator from the thread block to load from +/// \param items - array that data is loaded to +template< + class InputIterator, + class T, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_load_direct_blocked(unsigned int flat_id, + InputIterator block_input, + T (&items)[ItemsPerThread]) +{ + unsigned int offset = flat_id * ItemsPerThread; + InputIterator thread_iter = block_input + offset; + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + items[item] = thread_iter[item]; + } +} + +/// \brief Loads data from continuous memory into a blocked arrangement of items +/// across the thread block, which is guarded by range \p valid. +/// +/// The block arrangement is assumed to be (block-threads * \p ItemsPerThread) items +/// across a thread block. Each thread uses a \p flat_id to load a range of +/// \p ItemsPerThread into \p items. 
+/// +/// \tparam InputIterator - [inferred] an iterator type for input (can be a simple +/// pointer +/// \tparam T - [inferred] the data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_input - the input iterator from the thread block to load from +/// \param items - array that data is loaded to +/// \param valid - maximum range of valid numbers to load +template< + class InputIterator, + class T, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_load_direct_blocked(unsigned int flat_id, + InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid) +{ + unsigned int offset = flat_id * ItemsPerThread; + InputIterator thread_iter = block_input + offset; + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + if (item + offset < valid) + { + items[item] = thread_iter[item]; + } + } +} + +/// \brief Loads data from continuous memory into a blocked arrangement of items +/// across the thread block, which is guarded by range with a fall-back value +/// for out-of-bound elements. +/// +/// The block arrangement is assumed to be (block-threads * \p ItemsPerThread) items +/// across a thread block. Each thread uses a \p flat_id to load a range of +/// \p ItemsPerThread into \p items. 
+/// +/// \tparam InputIterator - [inferred] an iterator type for input (can be a simple +/// pointer +/// \tparam T - [inferred] the data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// \tparam Default - [inferred] The data type of the default value +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_input - the input iterator from the thread block to load from +/// \param items - array that data is loaded to +/// \param valid - maximum range of valid numbers to load +/// \param out_of_bounds - default value assigned to out-of-bound items +template< + class InputIterator, + class T, + unsigned int ItemsPerThread, + class Default +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_load_direct_blocked(unsigned int flat_id, + InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds) +{ + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + items[item] = static_cast(out_of_bounds); + } + // TODO: Consider using std::fill for HIP-CPU, as uses memset() where appropriate + + block_load_direct_blocked(flat_id, block_input, items, valid); +} + +/// \brief Loads data from continuous memory into a blocked arrangement of items +/// across the thread block. +/// +/// The block arrangement is assumed to be (block-threads * \p ItemsPerThread) items +/// across a thread block. Each thread uses a \p flat_id to load a range of +/// \p ItemsPerThread into \p items. +/// +/// The input offset (\p block_input + offset) must be quad-item aligned. +/// +/// The following conditions will prevent vectorization and switch to default +/// block_load_direct_blocked: +/// * \p ItemsPerThread is odd. +/// * The datatype \p T is not a primitive or a HIP vector type (e.g. int2, +/// int4, etc. 
+/// +/// \tparam T - [inferred] the input data type +/// \tparam U - [inferred] the output data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// +/// The type \p T must be such that it can be implicitly converted to \p U. +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_input - the input iterator from the thread block to load from +/// \param items - array that data is loaded to +template< + class T, + class U, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto +block_load_direct_blocked_vectorized(unsigned int flat_id, + T* block_input, + U (&items)[ItemsPerThread]) -> typename std::enable_if::value>::type +{ + typedef typename detail::match_vector_type::type vector_type; + constexpr unsigned int vectors_per_thread = (sizeof(T) * ItemsPerThread) / sizeof(vector_type); + vector_type vector_items[vectors_per_thread]; + + const vector_type* vector_ptr = reinterpret_cast(block_input) + + (flat_id * vectors_per_thread); + + ROCPRIM_UNROLL + for (unsigned int item = 0; item < vectors_per_thread; item++) + { + vector_items[item] = *(vector_ptr + item); + } + + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + items[item] = *(reinterpret_cast(vector_items) + item); + } +} + +template< + class T, + class U, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto +block_load_direct_blocked_vectorized(unsigned int flat_id, + T* block_input, + U (&items)[ItemsPerThread]) -> typename std::enable_if::value>::type +{ + block_load_direct_blocked(flat_id, block_input, items); +} + +/// \brief Loads data from continuous memory into a striped arrangement of items +/// across the thread block. +/// +/// The striped arrangement is assumed to be (\p BlockSize * \p ItemsPerThread) items +/// across a thread block. Each thread uses a \p flat_id to load a range of +/// \p ItemsPerThread into \p items. 
+/// +/// \tparam BlockSize - the number of threads in a block +/// \tparam InputIterator - [inferred] an iterator type for input (can be a simple +/// pointer +/// \tparam T - [inferred] the data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_input - the input iterator from the thread block to load from +/// \param items - array that data is loaded to +template< + unsigned int BlockSize, + class InputIterator, + class T, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_load_direct_striped(unsigned int flat_id, + InputIterator block_input, + T (&items)[ItemsPerThread]) +{ + InputIterator thread_iter = block_input + flat_id; + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + items[item] = thread_iter[item * BlockSize]; + } +} + +/// \brief Loads data from continuous memory into a striped arrangement of items +/// across the thread block, which is guarded by range \p valid. +/// +/// The striped arrangement is assumed to be (\p BlockSize * \p ItemsPerThread) items +/// across a thread block. Each thread uses a \p flat_id to load a range of +/// \p ItemsPerThread into \p items. 
+/// +/// \tparam BlockSize - the number of threads in a block +/// \tparam InputIterator - [inferred] an iterator type for input (can be a simple +/// pointer +/// \tparam T - [inferred] the data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_input - the input iterator from the thread block to load from +/// \param items - array that data is loaded to +/// \param valid - maximum range of valid numbers to load +template< + unsigned int BlockSize, + class InputIterator, + class T, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_load_direct_striped(unsigned int flat_id, + InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid) +{ + InputIterator thread_iter = block_input + flat_id; + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + unsigned int offset = item * BlockSize; + if (flat_id + offset < valid) + { + items[item] = thread_iter[offset]; + } + } +} + +/// \brief Loads data from continuous memory into a striped arrangement of items +/// across the thread block, which is guarded by range with a fall-back value +/// for out-of-bound elements. +/// +/// The striped arrangement is assumed to be (\p BlockSize * \p ItemsPerThread) items +/// across a thread block. Each thread uses a \p flat_id to load a range of +/// \p ItemsPerThread into \p items. 
+/// +/// \tparam BlockSize - the number of threads in a block +/// \tparam InputIterator - [inferred] an iterator type for input (can be a simple +/// pointer +/// \tparam T - [inferred] the data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// \tparam Default - [inferred] The data type of the default value +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_input - the input iterator from the thread block to load from +/// \param items - array that data is loaded to +/// \param valid - maximum range of valid numbers to load +/// \param out_of_bounds - default value assigned to out-of-bound items +template< + unsigned int BlockSize, + class InputIterator, + class T, + unsigned int ItemsPerThread, + class Default +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_load_direct_striped(unsigned int flat_id, + InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds) +{ + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + items[item] = out_of_bounds; + } + + block_load_direct_striped(flat_id, block_input, items, valid); +} + +/// \brief Loads data from continuous memory into a warp-striped arrangement of items +/// across the thread block. +/// +/// The warp-striped arrangement is assumed to be (\p WarpSize * \p ItemsPerThread) items +/// across a thread block. Each thread uses a \p flat_id to load a range of +/// \p ItemsPerThread into \p items. +/// +/// * The number of threads in the block must be a multiple of \p WarpSize. +/// * The default \p WarpSize is a hardware warpsize and is an optimal value. +/// * \p WarpSize must be a power of two and equal or less than the size of +/// hardware warp. +/// * Using \p WarpSize smaller than hardware warpsize could result in lower +/// performance. 
+/// +/// \tparam WarpSize - [optional] the number of threads in a warp +/// \tparam InputIterator - [inferred] an iterator type for input (can be a simple +/// pointer +/// \tparam T - [inferred] the data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_input - the input iterator from the thread block to load from +/// \param items - array that data is loaded to +template< + unsigned int WarpSize = device_warp_size(), + class InputIterator, + class T, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_load_direct_warp_striped(unsigned int flat_id, + InputIterator block_input, + T (&items)[ItemsPerThread]) +{ + static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= device_warp_size(), + "WarpSize must be a power of two and equal or less" + "than the size of hardware warp."); + unsigned int thread_id = detail::logical_lane_id(); + unsigned int warp_id = flat_id / WarpSize; + unsigned int warp_offset = warp_id * WarpSize * ItemsPerThread; + + InputIterator thread_iter = block_input + thread_id + warp_offset; + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + items[item] = thread_iter[item * WarpSize]; + } +} + +/// \brief Loads data from continuous memory into a warp-striped arrangement of items +/// across the thread block, which is guarded by range \p valid. +/// +/// The warp-striped arrangement is assumed to be (\p WarpSize * \p ItemsPerThread) items +/// across a thread block. Each thread uses a \p flat_id to load a range of +/// \p ItemsPerThread into \p items. +/// +/// * The number of threads in the block must be a multiple of \p WarpSize. +/// * The default \p WarpSize is a hardware warpsize and is an optimal value. +/// * \p WarpSize must be a power of two and equal or less than the size of +/// hardware warp. 
+/// * Using \p WarpSize smaller than hardware warpsize could result in lower +/// performance. +/// +/// \tparam WarpSize - [optional] the number of threads in a warp +/// \tparam InputIterator - [inferred] an iterator type for input (can be a simple +/// pointer +/// \tparam T - [inferred] the data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_input - the input iterator from the thread block to load from +/// \param items - array that data is loaded to +/// \param valid - maximum range of valid numbers to load +template< + unsigned int WarpSize = device_warp_size(), + class InputIterator, + class T, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_load_direct_warp_striped(unsigned int flat_id, + InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid) +{ + static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= device_warp_size(), + "WarpSize must be a power of two and equal or less" + "than the size of hardware warp."); + unsigned int thread_id = detail::logical_lane_id(); + unsigned int warp_id = flat_id / WarpSize; + unsigned int warp_offset = warp_id * WarpSize * ItemsPerThread; + + InputIterator thread_iter = block_input + thread_id + warp_offset; + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + unsigned int offset = item * WarpSize; + if (warp_offset + thread_id + offset < valid) + { + items[item] = thread_iter[offset]; + } + } +} + +/// \brief Loads data from continuous memory into a warp-striped arrangement of items +/// across the thread block, which is guarded by range with a fall-back value +/// for out-of-bound elements. +/// +/// The warp-striped arrangement is assumed to be (\p WarpSize * \p ItemsPerThread) items +/// across a thread block. 
Each thread uses a \p flat_id to load a range of +/// \p ItemsPerThread into \p items. +/// +/// * The number of threads in the block must be a multiple of \p WarpSize. +/// * The default \p WarpSize is a hardware warpsize and is an optimal value. +/// * \p WarpSize must be a power of two and equal or less than the size of +/// hardware warp. +/// * Using \p WarpSize smaller than hardware warpsize could result in lower +/// performance. +/// +/// \tparam WarpSize - [optional] the number of threads in a warp +/// \tparam InputIterator - [inferred] an iterator type for input (can be a simple +/// pointer +/// \tparam T - [inferred] the data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// \tparam Default - [inferred] The data type of the default value +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_input - the input iterator from the thread block to load from +/// \param items - array that data is loaded to +/// \param valid - maximum range of valid numbers to load +/// \param out_of_bounds - default value assigned to out-of-bound items +template< + unsigned int WarpSize = device_warp_size(), + class InputIterator, + class T, + unsigned int ItemsPerThread, + class Default +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_load_direct_warp_striped(unsigned int flat_id, + InputIterator block_input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds) +{ + static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= device_warp_size(), + "WarpSize must be a power of two and equal or less" + "than the size of hardware warp."); + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + items[item] = out_of_bounds; + } + + block_load_direct_warp_striped(flat_id, block_input, items, valid); +} + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group blockmodule + +#endif // ROCPRIM_BLOCK_BLOCK_LOAD_FUNC_HPP_ diff 
--git a/3rdparty/cub/rocprim/block/block_radix_sort.hpp b/3rdparty/cub/rocprim/block/block_radix_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9dcd0f06085f27069fd82975aad05b4ed341a37e --- /dev/null +++ b/3rdparty/cub/rocprim/block/block_radix_sort.hpp @@ -0,0 +1,1016 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_ +#define ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" +#include "../detail/radix_sort.hpp" +#include "../warp/detail/warp_scan_crosslane.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" +#include "../types.hpp" + +#include "block_exchange.hpp" + +/// \addtogroup blockmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +/// Specialized block scan of bool (1 bit values) +/// It uses warp scan and reduce functions of bool (1 bit values) based on ballot and bit count. +/// They have much better performance (several times faster) than generic scan and reduce classes +/// because of using hardware ability to calculate which lanes have true predicate values. +template< + unsigned int BlockSizeX, + unsigned int BlockSizeY = 1, + unsigned int BlockSizeZ = 1 +> +class block_bit_plus_scan +{ + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + // Select warp size + static constexpr unsigned int warp_size = + detail::get_min_warp_size(BlockSize, ::rocprim::device_warp_size()); + // Number of warps in block + static constexpr unsigned int warps_no = (BlockSize + warp_size - 1) / warp_size; + + // typedef of warp_scan primitive that will be used to get prefix values for + // each warp (scanned carry-outs from warps before it) + // warp_scan_crosslane is an implementation of warp_scan that does not need storage, + // but requires logical warp size to be a power of two. + using warp_scan_prefix_type = + ::rocprim::detail::warp_scan_crosslane; + +public: + + struct storage_type_ + { + unsigned int warp_prefixes[warps_no]; + // ---------- Shared memory optimisation ---------- + // Since we use warp_scan_crosslane for warp scan, we don't need to allocate + // any temporary memory for it. 
+ }; + + using storage_type = detail::raw_storage; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(const unsigned int (&input)[ItemsPerThread], + unsigned int (&output)[ItemsPerThread], + unsigned int& reduction, + storage_type& storage) + { + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + const unsigned int lane_id = ::rocprim::lane_id(); + const unsigned int warp_id = ::rocprim::warp_id(flat_id); + storage_type_& storage_ = storage.get(); + + unsigned int warp_reduction = ::rocprim::bit_count(::rocprim::ballot(input[0])); + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + warp_reduction += ::rocprim::bit_count(::rocprim::ballot(input[i])); + } + if(lane_id == 0) + { + storage_.warp_prefixes[warp_id] = warp_reduction; + } + ::rocprim::syncthreads(); + + // Scan the warp reduction results to calculate warp prefixes + if(flat_id < warps_no) + { + unsigned int prefix = storage_.warp_prefixes[flat_id]; + warp_scan_prefix_type().inclusive_scan(prefix, prefix, ::rocprim::plus()); + storage_.warp_prefixes[flat_id] = prefix; + } +#ifdef __HIP_CPU_RT__ + else + { + // HIP-CPU doesn't implement lockstep behavior. Need to invoke the same number sync ops in divergent branch. + empty_type empty; + ::rocprim::detail::warp_scan_crosslane().inclusive_scan(empty, empty, empty_binary_op{}); + } +#endif + ::rocprim::syncthreads(); + + // Perform exclusive warp scan of bit values + unsigned int lane_prefix = 0; + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + lane_prefix = ::rocprim::masked_bit_count(::rocprim::ballot(input[i]), lane_prefix); + } + + // Scan the lane's items and calculate final scan results + output[0] = warp_id == 0 + ? 
lane_prefix + : lane_prefix + storage_.warp_prefixes[warp_id - 1]; + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + output[i] = output[i - 1] + input[i - 1]; + } + + // Get the final inclusive reduction result + reduction = storage_.warp_prefixes[warps_no - 1]; + } +}; + +} // end namespace detail + +/// \brief The block_radix_sort class is a block level parallel primitive which provides +/// methods sorting items (keys or key-value pairs) partitioned across threads in a block +/// using radix sort algorithm. +/// +/// \tparam Key - the key type. +/// \tparam BlockSize - the number of threads in a block. +/// \tparam ItemsPerThread - the number of items contributed by each thread. +/// \tparam Value - the value type. Default type empty_type indicates +/// a keys-only sort. +/// +/// \par Overview +/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point +/// type). +/// * Performance depends on \p BlockSize and \p ItemsPerThread. +/// * It is usually better of \p BlockSize is a multiple of the size of the hardware warp. +/// * It is usually increased when \p ItemsPerThread is greater than one. However, when there +/// are too many items per thread, each thread may need so much registers and/or shared memory +/// that occupancy will fall too low, decreasing the performance. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \par Examples +/// \parblock +/// In the examples radix sort is performed on a block of 256 threads, each thread provides +/// eight \p int value, results are returned using the same array as for input. +/// +/// \code{.cpp} +/// __global__ void example_kernel(...) 
+/// {
+///     // specialize block_radix_sort for int, block of 256 threads,
+///     // and eight items per thread; key-only sort
+///     using block_rsort_int = rocprim::block_radix_sort;
+///     // allocate storage in shared memory
+///     __shared__ block_rsort_int::storage_type storage;
+///
+///     int input[8] = ...;
+///     // execute block radix sort (ascending)
+///     block_rsort_int().sort(
+///         input,
+///         storage
+///     );
+///     ...
+/// }
+/// \endcode
+/// \endparblock
+template<
+    class Key,
+    unsigned int BlockSizeX,
+    unsigned int ItemsPerThread,
+    class Value = empty_type,
+    unsigned int BlockSizeY = 1,
+    unsigned int BlockSizeZ = 1
+>
+class block_radix_sort
+{
+    static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
+    static constexpr bool with_values = !std::is_same::value;
+
+    // NOTE(review): several inline template argument lists appear to have been lost from
+    // this patch (e.g. `std::is_same`, `radix_key_codec`, `block_bit_plus_scan` and
+    // `block_exchange` are used without `<...>` arguments) — verify against upstream rocPRIM.
+    using bit_key_type = typename ::rocprim::detail::radix_key_codec::bit_key_type;
+    using bit_block_scan = detail::block_bit_plus_scan;
+
+    using bit_keys_exchange_type = ::rocprim::block_exchange;
+    using values_exchange_type = ::rocprim::block_exchange;
+
+    // Struct used for creating a raw_storage object for this primitive's temporary storage.
+    struct storage_type_
+    {
+        // Key exchange and value exchange share memory; callers synchronize before the
+        // union storage is reused (see exchange_values / to_striped_values below).
+        union
+        {
+            typename bit_keys_exchange_type::storage_type bit_keys_exchange;
+            typename values_exchange_type::storage_type values_exchange;
+        };
+        typename block_radix_sort::bit_block_scan::storage_type bit_block_scan;
+    };
+
+public:
+
+    /// \brief Struct used to allocate a temporary memory that is required for thread
+    /// communication during operations provided by related parallel primitive.
+    ///
+    /// Depending on the implementation the operations exposed by parallel primitive may
+    /// require a temporary storage for thread communication. The storage should be allocated
+    /// using keywords __shared__. It can be aliased to
+    /// an externally allocated memory, or be a part of a union type with other storage types
+    /// to increase shared memory reusability.
+    #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen
+    using storage_type = detail::raw_storage;
+    #else
+    using storage_type = storage_type_; // only for Doxygen
+    #endif
+
+    /// \brief Performs ascending radix sort over keys partitioned across threads in a block.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in] storage - reference to a temporary storage object of type storage_type.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    ///
+    /// \par Storage reusage
+    /// Synchronization barrier should be placed before \p storage is reused
+    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
+    ///
+    /// \par Examples
+    /// \parblock
+    /// In the examples radix sort is performed on a block of 128 threads, each thread provides
+    /// two \p float value, results are returned using the same array as for input.
+    ///
+    /// \code{.cpp}
+    /// __global__ void example_kernel(...)
+    /// {
+    ///     // specialize block_radix_sort for float, block of 128 threads,
+    ///     // and two items per thread; key-only sort
+    ///     using block_rsort_float = rocprim::block_radix_sort;
+    ///     // allocate storage in shared memory
+    ///     __shared__ block_rsort_float::storage_type storage;
+    ///
+    ///     float input[2] = ...;
+    ///     // execute block radix sort (ascending)
+    ///     block_rsort_float().sort(
+    ///         input,
+    ///         storage
+    ///     );
+    ///     ...
+    /// }
+    /// \endcode
+    ///
+    /// If the \p input values across threads in a block are {[256, 255], ..., [4, 3], [2, 1]},
+    /// then after sort they will be equal {[1, 2], [3, 4] ..., [255, 256]}.
+    /// \endparblock
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void sort(Key (&keys)[ItemsPerThread],
+              storage_type& storage,
+              unsigned int begin_bit = 0,
+              unsigned int end_bit = 8 * sizeof(Key))
+    {
+        // Keys-only sort: pass dummy empty_type values to the shared implementation.
+        // NOTE(review): upstream rocPRIM dispatches here with explicit template arguments
+        // (ascending, non-striped); they appear to have been stripped from this patch — confirm.
+        empty_type values[ItemsPerThread];
+        sort_impl(keys, values, storage, begin_bit, end_bit);
+    }
+
+    /// \overload
+    /// \brief Performs ascending radix sort over keys partitioned across threads in a block.
+    ///
+    /// * This overload does not accept storage argument. Required shared memory is
+    /// allocated by the method itself.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
+    void sort(Key (&keys)[ItemsPerThread],
+              unsigned int begin_bit = 0,
+              unsigned int end_bit = 8 * sizeof(Key))
+    {
+        // Storage for this overload lives in shared memory allocated here.
+        ROCPRIM_SHARED_MEMORY storage_type storage;
+        sort(keys, storage, begin_bit, end_bit);
+    }
+
+    /// \brief Performs descending radix sort over keys partitioned across threads in a block.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in] storage - reference to a temporary storage object of type storage_type.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    ///
+    /// \par Storage reusage
+    /// Synchronization barrier should be placed before \p storage is reused
+    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
+    ///
+    /// \par Examples
+    /// \parblock
+    /// In the examples radix sort is performed on a block of 128 threads, each thread provides
+    /// two \p float value, results are returned using the same array as for input.
+    ///
+    /// \code{.cpp}
+    /// __global__ void example_kernel(...)
+    /// {
+    ///     // specialize block_radix_sort for float, block of 128 threads,
+    ///     // and two items per thread; key-only sort
+    ///     using block_rsort_float = rocprim::block_radix_sort;
+    ///     // allocate storage in shared memory
+    ///     __shared__ block_rsort_float::storage_type storage;
+    ///
+    ///     float input[2] = ...;
+    ///     // execute block radix sort (descending)
+    ///     block_rsort_float().sort_desc(
+    ///         input,
+    ///         storage
+    ///     );
+    ///     ...
+    /// }
+    /// \endcode
+    ///
+    /// If the \p input values across threads in a block are {[1, 2], [3, 4] ..., [255, 256]},
+    /// then after sort they will be equal {[256, 255], ..., [4, 3], [2, 1]}.
+    /// \endparblock
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void sort_desc(Key (&keys)[ItemsPerThread],
+                   storage_type& storage,
+                   unsigned int begin_bit = 0,
+                   unsigned int end_bit = 8 * sizeof(Key))
+    {
+        // Keys-only sort: pass dummy empty_type values to the shared implementation.
+        empty_type values[ItemsPerThread];
+        sort_impl(keys, values, storage, begin_bit, end_bit);
+    }
+
+    /// \overload
+    /// \brief Performs descending radix sort over keys partitioned across threads in a block.
+    ///
+    /// * This overload does not accept storage argument. Required shared memory is
+    /// allocated by the method itself.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
+    void sort_desc(Key (&keys)[ItemsPerThread],
+                   unsigned int begin_bit = 0,
+                   unsigned int end_bit = 8 * sizeof(Key))
+    {
+        // Storage for this overload lives in shared memory allocated here.
+        ROCPRIM_SHARED_MEMORY storage_type storage;
+        sort_desc(keys, storage, begin_bit, end_bit);
+    }
+
+    /// \brief Performs ascending radix sort over key-value pairs partitioned across
+    /// threads in a block.
+    ///
+    /// \pre Method is enabled only if \p Value type is different than empty_type.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in, out] values - reference to an array of values provided by a thread.
+    /// \param [in] storage - reference to a temporary storage object of type storage_type.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    ///
+    /// \par Storage reusage
+    /// Synchronization barrier should be placed before \p storage is reused
+    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
+    ///
+    /// \par Examples
+    /// \parblock
+    /// In the examples radix sort is performed on a block of 128 threads, each thread provides
+    /// two key-value int-float pairs, results are returned using the same
+    /// arrays as for input.
+    ///
+    /// \code{.cpp}
+    /// __global__ void example_kernel(...)
+    /// {
+    ///     // specialize block_radix_sort for int-float pairs, block of 128
+    ///     // threads, and two items per thread
+    ///     using block_rsort_ii = rocprim::block_radix_sort;
+    ///     // allocate storage in shared memory
+    ///     __shared__ block_rsort_ii::storage_type storage;
+    ///
+    ///     int keys[2] = ...;
+    ///     float values[2] = ...;
+    ///     // execute block radix sort-by-key (ascending)
+    ///     block_rsort_ii().sort(
+    ///         keys, values,
+    ///         storage
+    ///     );
+    ///     ...
+    /// }
+    /// \endcode
+    ///
+    /// If the \p keys across threads in a block are {[256, 255], ..., [4, 3], [2, 1]} and
+    /// the \p values are {[1, 1], [2, 2] ..., [128, 128]}, then after sort the \p keys
+    /// will be equal {[1, 2], [3, 4] ..., [255, 256]} and the \p values will be
+    /// equal {[128, 128], [127, 127] ..., [2, 2], [1, 1]}.
+    /// \endparblock
+    template
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void sort(Key (&keys)[ItemsPerThread],
+              typename std::enable_if::type (&values)[ItemsPerThread],
+              storage_type& storage,
+              unsigned int begin_bit = 0,
+              unsigned int end_bit = 8 * sizeof(Key))
+    {
+        sort_impl(keys, values, storage, begin_bit, end_bit);
+    }
+
+    /// \overload
+    /// \brief Performs ascending radix sort over key-value pairs partitioned across
+    /// threads in a block.
+    ///
+    /// * This overload does not accept storage argument. Required shared memory is
+    /// allocated by the method itself.
+    ///
+    /// \pre Method is enabled only if \p Value type is different than empty_type.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in, out] values - reference to an array of values provided by a thread.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    template
+    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
+    void sort(Key (&keys)[ItemsPerThread],
+              typename std::enable_if::type (&values)[ItemsPerThread],
+              unsigned int begin_bit = 0,
+              unsigned int end_bit = 8 * sizeof(Key))
+    {
+        // Storage for this overload lives in shared memory allocated here.
+        ROCPRIM_SHARED_MEMORY storage_type storage;
+        sort(keys, values, storage, begin_bit, end_bit);
+    }
+
+    /// \brief Performs descending radix sort over key-value pairs partitioned across
+    /// threads in a block.
+    ///
+    /// \pre Method is enabled only if \p Value type is different than empty_type.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in, out] values - reference to an array of values provided by a thread.
+    /// \param [in] storage - reference to a temporary storage object of type storage_type.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    ///
+    /// \par Storage reusage
+    /// Synchronization barrier should be placed before \p storage is reused
+    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
+    ///
+    /// \par Examples
+    /// \parblock
+    /// In the examples radix sort is performed on a block of 128 threads, each thread provides
+    /// two key-value int-float pairs, results are returned using the same
+    /// arrays as for input.
+    ///
+    /// \code{.cpp}
+    /// __global__ void example_kernel(...)
+    /// {
+    ///     // specialize block_radix_sort for int-float pairs, block of 128
+    ///     // threads, and two items per thread
+    ///     using block_rsort_ii = rocprim::block_radix_sort;
+    ///     // allocate storage in shared memory
+    ///     __shared__ block_rsort_ii::storage_type storage;
+    ///
+    ///     int keys[2] = ...;
+    ///     float values[2] = ...;
+    ///     // execute block radix sort-by-key (descending)
+    ///     block_rsort_ii().sort_desc(
+    ///         keys, values,
+    ///         storage
+    ///     );
+    ///     ...
+    /// }
+    /// \endcode
+    ///
+    /// If the \p keys across threads in a block are {[1, 2], [3, 4] ..., [255, 256]} and
+    /// the \p values are {[128, 128], [127, 127] ..., [2, 2], [1, 1]}, then after sort
+    /// the \p keys will be equal {[256, 255], ..., [4, 3], [2, 1]} and the \p values
+    /// will be equal {[1, 1], [2, 2] ..., [128, 128]}.
+    /// \endparblock
+    template
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void sort_desc(Key (&keys)[ItemsPerThread],
+                   typename std::enable_if::type (&values)[ItemsPerThread],
+                   storage_type& storage,
+                   unsigned int begin_bit = 0,
+                   unsigned int end_bit = 8 * sizeof(Key))
+    {
+        sort_impl(keys, values, storage, begin_bit, end_bit);
+    }
+
+    /// \overload
+    /// \brief Performs descending radix sort over key-value pairs partitioned across
+    /// threads in a block.
+    ///
+    /// * This overload does not accept storage argument. Required shared memory is
+    /// allocated by the method itself.
+    ///
+    /// \pre Method is enabled only if \p Value type is different than empty_type.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in, out] values - reference to an array of values provided by a thread.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    template
+    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
+    void sort_desc(Key (&keys)[ItemsPerThread],
+                   typename std::enable_if::type (&values)[ItemsPerThread],
+                   unsigned int begin_bit = 0,
+                   unsigned int end_bit = 8 * sizeof(Key))
+    {
+        // Storage for this overload lives in shared memory allocated here.
+        ROCPRIM_SHARED_MEMORY storage_type storage;
+        sort_desc(keys, values, storage, begin_bit, end_bit);
+    }
+
+    /// \brief Performs ascending radix sort over keys partitioned across threads in a block,
+    /// results are saved in a striped arrangement.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in] storage - reference to a temporary storage object of type storage_type.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    ///
+    /// \par Storage reusage
+    /// Synchronization barrier should be placed before \p storage is reused
+    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
+    ///
+    /// \par Examples
+    /// \parblock
+    /// In the examples radix sort is performed on a block of 128 threads, each thread provides
+    /// two \p float value, results are returned using the same array as for input.
+    ///
+    /// \code{.cpp}
+    /// __global__ void example_kernel(...)
+    /// {
+    ///     // specialize block_radix_sort for float, block of 128 threads,
+    ///     // and two items per thread; key-only sort
+    ///     using block_rsort_float = rocprim::block_radix_sort;
+    ///     // allocate storage in shared memory
+    ///     __shared__ block_rsort_float::storage_type storage;
+    ///
+    ///     float keys[2] = ...;
+    ///     // execute block radix sort (ascending)
+    ///     block_rsort_float().sort_to_striped(
+    ///         keys,
+    ///         storage
+    ///     );
+    ///     ...
+    /// }
+    /// \endcode
+    ///
+    /// If the \p input values across threads in a block are {[256, 255], ..., [4, 3], [2, 1]},
+    /// then after sort they will be equal {[1, 129], [2, 130] ..., [128, 256]}.
+    /// \endparblock
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void sort_to_striped(Key (&keys)[ItemsPerThread],
+                         storage_type& storage,
+                         unsigned int begin_bit = 0,
+                         unsigned int end_bit = 8 * sizeof(Key))
+    {
+        // Keys-only sort: pass dummy empty_type values to the shared implementation.
+        empty_type values[ItemsPerThread];
+        sort_impl(keys, values, storage, begin_bit, end_bit);
+    }
+
+    /// \overload
+    /// \brief Performs ascending radix sort over keys partitioned across threads in a block,
+    /// results are saved in a striped arrangement.
+    ///
+    /// * This overload does not accept storage argument. Required shared memory is
+    /// allocated by the method itself.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison.
 Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
+    void sort_to_striped(Key (&keys)[ItemsPerThread],
+                         unsigned int begin_bit = 0,
+                         unsigned int end_bit = 8 * sizeof(Key))
+    {
+        // Storage for this overload lives in shared memory allocated here.
+        ROCPRIM_SHARED_MEMORY storage_type storage;
+        sort_to_striped(keys, storage, begin_bit, end_bit);
+    }
+
+    /// \brief Performs descending radix sort over keys partitioned across threads in a block,
+    /// results are saved in a striped arrangement.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in] storage - reference to a temporary storage object of type storage_type.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    ///
+    /// \par Storage reusage
+    /// Synchronization barrier should be placed before \p storage is reused
+    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
+    ///
+    /// \par Examples
+    /// \parblock
+    /// In the examples radix sort is performed on a block of 128 threads, each thread provides
+    /// two \p float value, results are returned using the same array as for input.
+    ///
+    /// \code{.cpp}
+    /// __global__ void example_kernel(...)
+    /// {
+    ///     // specialize block_radix_sort for float, block of 128 threads,
+    ///     // and two items per thread; key-only sort
+    ///     using block_rsort_float = rocprim::block_radix_sort;
+    ///     // allocate storage in shared memory
+    ///     __shared__ block_rsort_float::storage_type storage;
+    ///
+    ///     float input[2] = ...;
+    ///     // execute block radix sort (descending)
+    ///     block_rsort_float().sort_desc_to_striped(
+    ///         input,
+    ///         storage
+    ///     );
+    ///     ...
+    /// }
+    /// \endcode
+    ///
+    /// If the \p input values across threads in a block are {[1, 2], [3, 4] ..., [255, 256]},
+    /// then after sort they will be equal {[256, 128], ..., [130, 2], [129, 1]}.
+    /// \endparblock
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void sort_desc_to_striped(Key (&keys)[ItemsPerThread],
+                              storage_type& storage,
+                              unsigned int begin_bit = 0,
+                              unsigned int end_bit = 8 * sizeof(Key))
+    {
+        // Keys-only sort: pass dummy empty_type values to the shared implementation.
+        empty_type values[ItemsPerThread];
+        sort_impl(keys, values, storage, begin_bit, end_bit);
+    }
+
+    /// \overload
+    /// \brief Performs descending radix sort over keys partitioned across threads in a block,
+    /// results are saved in a striped arrangement.
+    ///
+    /// * This overload does not accept storage argument. Required shared memory is
+    /// allocated by the method itself.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
+    void sort_desc_to_striped(Key (&keys)[ItemsPerThread],
+                              unsigned int begin_bit = 0,
+                              unsigned int end_bit = 8 * sizeof(Key))
+    {
+        // Storage for this overload lives in shared memory allocated here.
+        ROCPRIM_SHARED_MEMORY storage_type storage;
+        sort_desc_to_striped(keys, storage, begin_bit, end_bit);
+    }
+
+    /// \brief Performs ascending radix sort over key-value pairs partitioned across
+    /// threads in a block, results are saved in a striped arrangement.
+    ///
+    /// \pre Method is enabled only if \p Value type is different than empty_type.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in, out] values - reference to an array of values provided by a thread.
+    /// \param [in] storage - reference to a temporary storage object of type storage_type.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    ///
+    /// \par Storage reusage
+    /// Synchronization barrier should be placed before \p storage is reused
+    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
+    ///
+    /// \par Examples
+    /// \parblock
+    /// In the examples radix sort is performed on a block of 4 threads, each thread provides
+    /// two key-value int-float pairs, results are returned using the same
+    /// arrays as for input.
+    ///
+    /// \code{.cpp}
+    /// __global__ void example_kernel(...)
+    /// {
+    ///     // specialize block_radix_sort for int-float pairs, block of 4
+    ///     // threads, and two items per thread
+    ///     using block_rsort_ii = rocprim::block_radix_sort;
+    ///     // allocate storage in shared memory
+    ///     __shared__ block_rsort_ii::storage_type storage;
+    ///
+    ///     int keys[2] = ...;
+    ///     float values[2] = ...;
+    ///     // execute block radix sort-by-key (ascending)
+    ///     block_rsort_ii().sort_to_striped(
+    ///         keys, values,
+    ///         storage
+    ///     );
+    ///     ...
+    /// }
+    /// \endcode
+    ///
+    /// If the \p keys across threads in a block are {[8, 7], [6, 5], [4, 3], [2, 1]} and
+    /// the \p values are {[-1, -2], [-3, -4], [-5, -6], [-7, -8]}, then after sort the
+    /// \p keys will be equal {[1, 5], [2, 6], [3, 7], [4, 8]} and the \p values will be
+    /// equal {[-8, -4], [-7, -3], [-6, -2], [-5, -1]}.
+    /// \endparblock
+    template
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void sort_to_striped(Key (&keys)[ItemsPerThread],
+                         typename std::enable_if::type (&values)[ItemsPerThread],
+                         storage_type& storage,
+                         unsigned int begin_bit = 0,
+                         unsigned int end_bit = 8 * sizeof(Key))
+    {
+        sort_impl(keys, values, storage, begin_bit, end_bit);
+    }
+
+    /// \overload
+    /// \brief Performs ascending radix sort over key-value pairs partitioned across
+    /// threads in a block, results are saved in a striped arrangement.
+    ///
+    /// * This overload does not accept storage argument. Required shared memory is
+    /// allocated by the method itself.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in, out] values - reference to an array of values provided by a thread.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison.
 Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    template
+    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
+    void sort_to_striped(Key (&keys)[ItemsPerThread],
+                         typename std::enable_if::type (&values)[ItemsPerThread],
+                         unsigned int begin_bit = 0,
+                         unsigned int end_bit = 8 * sizeof(Key))
+    {
+        // Storage for this overload lives in shared memory allocated here.
+        ROCPRIM_SHARED_MEMORY storage_type storage;
+        sort_to_striped(keys, values, storage, begin_bit, end_bit);
+    }
+
+    /// \brief Performs descending radix sort over key-value pairs partitioned across
+    /// threads in a block, results are saved in a striped arrangement.
+    ///
+    /// \pre Method is enabled only if \p Value type is different than empty_type.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in, out] values - reference to an array of values provided by a thread.
+    /// \param [in] storage - reference to a temporary storage object of type storage_type.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// Non-default value not supported for floating-point key-types.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default
+    /// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types.
+    ///
+    /// \par Storage reusage
+    /// Synchronization barrier should be placed before \p storage is reused
+    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
+    ///
+    /// \par Examples
+    /// \parblock
+    /// In the examples radix sort is performed on a block of 4 threads, each thread provides
+    /// two key-value int-float pairs, results are returned using the same
+    /// arrays as for input.
+    ///
+    /// \code{.cpp}
+    /// __global__ void example_kernel(...)
+    /// {
+    ///     // specialize block_radix_sort for int-float pairs, block of 4
+    ///     // threads, and two items per thread
+    ///     using block_rsort_ii = rocprim::block_radix_sort;
+    ///     // allocate storage in shared memory
+    ///     __shared__ block_rsort_ii::storage_type storage;
+    ///
+    ///     int keys[2] = ...;
+    ///     float values[2] = ...;
+    ///     // execute block radix sort-by-key (descending)
+    ///     block_rsort_ii().sort_desc_to_striped(
+    ///         keys, values,
+    ///         storage
+    ///     );
+    ///     ...
+    /// }
+    /// \endcode
+    ///
+    /// If the \p keys across threads in a block are {[1, 2], [3, 4], [5, 6], [7, 8]} and
+    /// the \p values are {[80, 70], [60, 50], [40, 30], [20, 10]}, then after sort the
+    /// \p keys will be equal {[8, 4], [7, 3], [6, 2], [5, 1]} and the \p values will be
+    /// equal {[10, 50], [20, 60], [30, 70], [40, 80]}.
+    /// \endparblock
+    template
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void sort_desc_to_striped(Key (&keys)[ItemsPerThread],
+                              typename std::enable_if::type (&values)[ItemsPerThread],
+                              storage_type& storage,
+                              unsigned int begin_bit = 0,
+                              unsigned int end_bit = 8 * sizeof(Key))
+    {
+        sort_impl(keys, values, storage, begin_bit, end_bit);
+    }
+
+    /// \overload
+    /// \brief Performs descending radix sort over key-value pairs partitioned across
+    /// threads in a block, results are saved in a striped arrangement.
+    ///
+    /// * This overload does not accept storage argument. Required shared memory is
+    /// allocated by the method itself.
+    ///
+    /// \param [in, out] keys - reference to an array of keys provided by a thread.
+    /// \param [in, out] values - reference to an array of values provided by a thread.
+    /// \param [in] begin_bit - [optional] index of the first (least significant) bit used in
+    /// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0.
+    /// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in
+    /// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)].
 Default
+    /// value: \p 8 * sizeof(Key).
+    template
+    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
+    void sort_desc_to_striped(Key (&keys)[ItemsPerThread],
+                              typename std::enable_if::type (&values)[ItemsPerThread],
+                              unsigned int begin_bit = 0,
+                              unsigned int end_bit = 8 * sizeof(Key))
+    {
+        // Storage for this overload lives in shared memory allocated here.
+        ROCPRIM_SHARED_MEMORY storage_type storage;
+        sort_desc_to_striped(keys, values, storage, begin_bit, end_bit);
+    }
+
+private:
+
+    // Shared implementation behind all public sort methods: a 1-bit-per-pass
+    // (binary-split) radix sort over the encoded keys.
+    // NOTE(review): the template parameter list here (and the `<Descending, ToStriped>`
+    // arguments at the call sites, plus `radix_key_codec`'s arguments) appears to have
+    // been stripped from this patch — `ToStriped` is referenced below but not declared
+    // in the visible text; verify against upstream rocPRIM before merging.
+    template
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void sort_impl(Key (&keys)[ItemsPerThread],
+                   SortedValue (&values)[ItemsPerThread],
+                   storage_type& storage,
+                   unsigned int begin_bit,
+                   unsigned int end_bit)
+    {
+        using key_codec = ::rocprim::detail::radix_key_codec;
+        storage_type_& storage_ = storage.get();
+
+        const unsigned int flat_id = ::rocprim::flat_block_thread_id();
+
+        // Encode keys into a bit representation whose unsigned ordering matches the sort order.
+        bit_key_type bit_keys[ItemsPerThread];
+        for(unsigned int i = 0; i < ItemsPerThread; i++)
+        {
+            bit_keys[i] = key_codec::encode(keys[i]);
+        }
+
+        // Use binary digits (i.e. digits can be 0 or 1)
+        for(unsigned int bit = begin_bit; bit < end_bit; bit++)
+        {
+            unsigned int bits[ItemsPerThread];
+            for(unsigned int i = 0; i < ItemsPerThread; i++)
+            {
+                bits[i] = key_codec::extract_digit(bit_keys[i], bit, 1);
+            }
+
+            unsigned int ranks[ItemsPerThread];
+#ifdef __HIP_CPU_RT__
+            // TODO: Check if really necessary
+            // Initialize contents, as non-hipcc compilers don't unconditionally zero out allocated memory
+            std::memset(ranks, 0, ItemsPerThread * sizeof(decltype(ranks[0])));
+#endif
+            // ranks = exclusive prefix count of set bits; count = number of set bits in the block.
+            unsigned int count;
+            bit_block_scan().exclusive_scan(bits, ranks, count, storage_.bit_block_scan);
+
+            // Scatter keys to computed positions considering starting positions of their digit values
+            const unsigned int start = BlockSize * ItemsPerThread - count;
+            for(unsigned int i = 0; i < ItemsPerThread; i++)
+            {
+                // Calculate position for the first digit (0) value based on positions of the second (1)
+                ranks[i] = bits[i] != 0
+                    ? (start + ranks[i])
+                    : (flat_id * ItemsPerThread + i - ranks[i]);
+            }
+            exchange_keys(storage, bit_keys, ranks);
+            exchange_values(storage, values, ranks);
+        }
+
+        if(ToStriped)
+        {
+            to_striped_keys(storage, bit_keys);
+            to_striped_values(storage, values);
+        }
+
+        // Decode the sorted bit representations back into the caller's key array.
+        for(unsigned int i = 0; i < ItemsPerThread; i++)
+        {
+            keys[i] = key_codec::decode(bit_keys[i]);
+        }
+    }
+
+    // Reorder bit keys into blocked arrangement according to the computed ranks.
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void exchange_keys(storage_type& storage,
+                       bit_key_type (&bit_keys)[ItemsPerThread],
+                       const unsigned int (&ranks)[ItemsPerThread])
+    {
+        storage_type_& storage_ = storage.get();
+        // Synchronization is omitted here because bit_block_scan already calls it
+        bit_keys_exchange_type().scatter_to_blocked(bit_keys, bit_keys, ranks, storage_.bit_keys_exchange);
+    }
+
+    template
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void exchange_values(storage_type& storage,
+                         SortedValue (&values)[ItemsPerThread],
+                         const unsigned int (&ranks)[ItemsPerThread])
+    {
+        storage_type_& storage_ = storage.get();
+        ::rocprim::syncthreads(); // Storage will be reused (union), synchronization is needed
+        values_exchange_type().scatter_to_blocked(values, values, ranks, storage_.values_exchange);
+    }
+
+    // No-op overload used by keys-only sorts (values are empty_type).
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void exchange_values(storage_type& storage,
+                         empty_type (&values)[ItemsPerThread],
+                         const unsigned int (&ranks)[ItemsPerThread])
+    {
+        (void) storage;
+        (void) values;
+        (void) ranks;
+    }
+
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void to_striped_keys(storage_type& storage,
+                         bit_key_type (&bit_keys)[ItemsPerThread])
+    {
+        storage_type_& storage_ = storage.get();
+        ::rocprim::syncthreads();
+        bit_keys_exchange_type().blocked_to_striped(bit_keys, bit_keys, storage_.bit_keys_exchange);
+    }
+
+    template
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void to_striped_values(storage_type& storage,
+                           SortedValue (&values)[ItemsPerThread])
+    {
+        storage_type_& storage_ = storage.get();
+        ::rocprim::syncthreads(); // Storage will be reused (union), synchronization is needed
+        values_exchange_type().blocked_to_striped(values, values, storage_.values_exchange);
+    }
+
+    // No-op overload used by keys-only sorts.
+    // NOTE(review): unlike the empty_type exchange_values overload above this takes
+    // `empty_type *` (the caller's array argument decays to a pointer); consider using the
+    // array-reference form for consistency — confirm against upstream rocPRIM.
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void to_striped_values(storage_type& storage,
+                           empty_type * values)
+    {
+        (void) storage;
+        (void) values;
+    }
+};
+
+END_ROCPRIM_NAMESPACE
+
+/// @}
+// end of group blockmodule
+
+#endif // ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_
diff --git a/3rdparty/cub/rocprim/block/block_reduce.hpp b/3rdparty/cub/rocprim/block/block_reduce.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..7eb202d02296564f6e8c604fdc04f039ad26c0eb
--- /dev/null
+++ b/3rdparty/cub/rocprim/block/block_reduce.hpp
@@ -0,0 +1,414 @@
+// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+ +#ifndef ROCPRIM_BLOCK_BLOCK_REDUCE_HPP_ +#define ROCPRIM_BLOCK_BLOCK_REDUCE_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" + +#include "detail/block_reduce_warp_reduce.hpp" +#include "detail/block_reduce_raking_reduce.hpp" + + +/// \addtogroup blockmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Available algorithms for block_reduce primitive. +enum class block_reduce_algorithm +{ + /// \brief A warp_reduce based algorithm. + using_warp_reduce, + /// \brief An algorithm which limits calculations to a single hardware warp. + raking_reduce, + /// \brief raking reduce that supports only commutative operators + raking_reduce_commutative_only, + /// \brief Default block_reduce algorithm. + default_algorithm = using_warp_reduce, +}; + +namespace detail +{ + +// Selector for block_reduce algorithm which gives block reduce implementation +// type based on passed block_reduce_algorithm enum +template +struct select_block_reduce_impl; + +template<> +struct select_block_reduce_impl +{ + template + using type = block_reduce_warp_reduce; +}; + +template<> +struct select_block_reduce_impl +{ + template + using type = block_reduce_raking_reduce; +}; + +template<> +struct select_block_reduce_impl +{ + template + using type = block_reduce_raking_reduce; +}; + + +} // end namespace detail + +/// \brief The block_reduce class is a block level parallel primitive which provides methods +/// for performing reductions operations on items partitioned across threads in a block. +/// +/// \tparam T - the input/output type. +/// \tparam BlockSize - the number of threads in a block. +/// \tparam Algorithm - selected reduce algorithm, block_reduce_algorithm::default_algorithm by default. +/// +/// \par Overview +/// * Supports non-commutative reduce operators. However, a reduce operator should be +/// associative. 
When used with non-associative functions the results may be non-deterministic
+/// and/or vary in precision.
+/// * Computation can be more efficient when:
+///   * \p ItemsPerThread is greater than one,
+///   * \p T is an arithmetic type,
+///   * reduce operation is simple addition operator, and
+///   * the number of threads in the block is a multiple of the hardware warp size (see rocprim::device_warp_size()).
+/// * block_reduce has three alternative implementations: \p block_reduce_algorithm::using_warp_reduce,
+/// \p block_reduce_algorithm::raking_reduce and \p block_reduce_algorithm::raking_reduce_commutative_only.
+/// * If the block size is less than 64, only one warp reduction is used. The block reduction algorithm
+/// stores the result only in the first thread (lane_id = 0, warp_id = 0) when the block size is
+/// larger than the warp size.
+///
+/// \par Examples
+/// \parblock
+/// In the examples reduce operation is performed on a block of 192 threads, each provides
+/// one \p int value, result is returned using the same variable as for input.
+///
+/// \code{.cpp}
+/// __global__ void example_kernel(...)
+/// {
+///     // specialize block_reduce for int and a block of 192 threads
+///     using block_reduce_int = rocprim::block_reduce<int, 192>;
+///     // allocate storage in shared memory
+///     __shared__ block_reduce_int::storage_type storage;
+///
+///     int value = ...;
+///     // execute reduce
+///     block_reduce_int().reduce(
+///         value, // input
+///         value, // output
+///         storage
+///     );
+///     ...
+/// }
+/// \endcode
+/// \endparblock
+template<
+    class T,
+    unsigned int BlockSizeX,
+    block_reduce_algorithm Algorithm = block_reduce_algorithm::default_algorithm,
+    unsigned int BlockSizeY = 1,
+    unsigned int BlockSizeZ = 1
+>
+class block_reduce
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+    : private detail::select_block_reduce_impl<Algorithm>::template type<T, BlockSizeX, BlockSizeY, BlockSizeZ>
+#endif
+{
+    // Concrete implementation chosen by the Algorithm enum (warp-reduce based or raking).
+    using base_type = typename detail::select_block_reduce_impl<Algorithm>::template type<T, BlockSizeX, BlockSizeY, BlockSizeZ>;
+public:
+    /// \brief Struct used to allocate a temporary memory that is required for thread
+    /// communication during operations provided by related parallel primitive.
+    ///
+    /// Depending on the implementation the operations exposed by parallel primitive may
+    /// require a temporary storage for thread communication. The storage should be allocated
+    /// using keyword <tt>__shared__</tt>. It can be aliased to
+    /// an externally allocated memory, or be a part of a union type with other storage types
+    /// to increase shared memory reusability.
+    using storage_type = typename base_type::storage_type;
+
+    /// \brief Performs reduction across threads in a block.
+    ///
+    /// \tparam BinaryFunction - type of binary function used for reduce. Default type
+    /// is rocprim::plus<T>.
+    ///
+    /// \param [in] input - thread input value.
+    /// \param [out] output - reference to a thread output value. May be aliased with \p input.
+    /// \param [in] storage - reference to a temporary storage object of type storage_type.
+    /// \param [in] reduce_op - binary operation function object that will be used for reduce.
+    /// The signature of the function should be equivalent to the following:
+    /// <tt>T f(const T &a, const T &b);</tt>. The signature does not need to have
+    /// <tt>const &</tt>, but function object must not modify the objects passed to it.
+    ///
+    /// \par Storage reusage
+    /// Synchronization barrier should be placed before \p storage is reused
+    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
+    ///
+    /// \par Examples
+    /// \parblock
+    /// The examples present min reduce operations performed on a block of 256 threads,
+    /// each provides one \p float value.
+    ///
+    /// \code{.cpp}
+    /// __global__ void example_kernel(...) // blockDim.x = 256
+    /// {
+    ///     // specialize block_reduce for float and a block of 256 threads
+    ///     using block_reduce_float = rocprim::block_reduce<float, 256>;
+    ///     // allocate storage in shared memory for the block
+    ///     __shared__ block_reduce_float::storage_type storage;
+    ///
+    ///     float input = ...;
+    ///     float output;
+    ///     // execute min reduce
+    ///     block_reduce_float().reduce(
+    ///         input,
+    ///         output,
+    ///         storage,
+    ///         rocprim::minimum<float>()
+    ///     );
+    ///     ...
+    /// }
+    /// \endcode
+    ///
+    /// If the \p input values across threads in a block are {1, -2, 3, -4, ..., 255, -256}, then
+    /// \p output value will be {-256}.
+    /// \endparblock
+    template<class BinaryFunction = ::rocprim::plus<T>>
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void reduce(T input,
+                T& output,
+                storage_type& storage,
+                BinaryFunction reduce_op = BinaryFunction())
+    {
+        base_type::reduce(input, output, storage, reduce_op);
+    }
+
+    /// \overload
+    /// \brief Performs reduction across threads in a block.
+    ///
+    /// * This overload does not accept storage argument. Required shared memory is
+    /// allocated by the method itself.
+    ///
+    /// \tparam BinaryFunction - type of binary function used for reduce. Default type
+    /// is rocprim::plus<T>.
+    ///
+    /// \param [in] input - thread input value.
+    /// \param [out] output - reference to a thread output value. May be aliased with \p input.
+    /// \param [in] reduce_op - binary operation function object that will be used for reduce.
+    /// The signature of the function should be equivalent to the following:
+    /// <tt>T f(const T &a, const T &b);</tt>. The signature does not need to have
+    /// <tt>const &</tt>, but function object must not modify the objects passed to it.
+    template<class BinaryFunction = ::rocprim::plus<T>>
+    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
+    void reduce(T input,
+                T& output,
+                BinaryFunction reduce_op = BinaryFunction())
+    {
+        base_type::reduce(input, output, reduce_op);
+    }
+
+    /// \brief Performs reduction across threads in a block.
+    ///
+    /// \tparam ItemsPerThread - number of items in the \p input array.
+    /// \tparam BinaryFunction - type of binary function used for reduce. Default type
+    /// is rocprim::plus<T>.
+    ///
+    /// \param [in] input - reference to an array containing thread input values.
+    /// \param [out] output - reference to a thread output value. May be aliased with \p input.
+    /// \param [in] storage - reference to a temporary storage object of type storage_type.
+    /// \param [in] reduce_op - binary operation function object that will be used for reduce.
+    /// The signature of the function should be equivalent to the following:
+    /// <tt>T f(const T &a, const T &b);</tt>. The signature does not need to have
+    /// <tt>const &</tt>, but function object must not modify the objects passed to it.
+    ///
+    /// \par Storage reusage
+    /// Synchronization barrier should be placed before \p storage is reused
+    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
+    ///
+    /// \par Examples
+    /// \parblock
+    /// The examples present maximum reduce operations performed on a block of 128 threads,
+    /// each provides two \p long values.
+    ///
+    /// \code{.cpp}
+    /// __global__ void example_kernel(...) // blockDim.x = 128
+    /// {
+    ///     // specialize block_reduce for long and a block of 128 threads
+    ///     using block_reduce_long = rocprim::block_reduce<long, 128>;
+    ///     // allocate storage in shared memory for the block
+    ///     __shared__ block_reduce_long::storage_type storage;
+    ///
+    ///     long input[2] = ...;
+    ///     long output;
+    ///     // execute max reduce
+    ///     block_reduce_long().reduce(
+    ///         input,
+    ///         output,
+    ///         storage,
+    ///         rocprim::maximum<long>()
+    ///     );
+    ///     ...
+    /// }
+    /// \endcode
+    ///
+    /// If the \p input values across threads in a block are {-1, 2, -3, 4, ..., -255, 256}, then
+    /// \p output value will be {256}.
+    /// \endparblock
+    template<
+        unsigned int ItemsPerThread,
+        class BinaryFunction = ::rocprim::plus<T>
+    >
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void reduce(T (&input)[ItemsPerThread],
+                T& output,
+                storage_type& storage,
+                BinaryFunction reduce_op = BinaryFunction())
+    {
+        base_type::reduce(input, output, storage, reduce_op);
+    }
+
+    /// \overload
+    /// \brief Performs reduction across threads in a block.
+    ///
+    /// * This overload does not accept storage argument. Required shared memory is
+    /// allocated by the method itself.
+    ///
+    /// \tparam ItemsPerThread - number of items in the \p input array.
+    /// \tparam BinaryFunction - type of binary function used for reduce. Default type
+    /// is rocprim::plus<T>.
+    ///
+    /// \param [in] input - reference to an array containing thread input values.
+    /// \param [out] output - reference to a thread output value. May be aliased with \p input.
+    /// \param [in] reduce_op - binary operation function object that will be used for reduce.
+    /// The signature of the function should be equivalent to the following:
+    /// <tt>T f(const T &a, const T &b);</tt>. The signature does not need to have
+    /// <tt>const &</tt>, but function object must not modify the objects passed to it.
+    template<
+        unsigned int ItemsPerThread,
+        class BinaryFunction = ::rocprim::plus<T>
+    >
+    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
+    void reduce(T (&input)[ItemsPerThread],
+                T& output,
+                BinaryFunction reduce_op = BinaryFunction())
+    {
+        base_type::reduce(input, output, reduce_op);
+    }
+
+    /// \brief Performs reduction across threads in a block.
+    ///
+    /// \tparam BinaryFunction - type of binary function used for reduce. Default type
+    /// is rocprim::plus<T>.
+    ///
+    /// \param [in] input - thread input value.
+    /// \param [out] output - reference to a thread output value. May be aliased with \p input.
+    /// \param [in] valid_items - number of items that will be reduced in the block.
+    /// \param [in] storage - reference to a temporary storage object of type storage_type.
+    /// \param [in] reduce_op - binary operation function object that will be used for reduce.
+    /// The signature of the function should be equivalent to the following:
+    /// <tt>T f(const T &a, const T &b);</tt>. The signature does not need to have
+    /// <tt>const &</tt>, but function object must not modify the objects passed to it.
+    ///
+    /// \par Storage reusage
+    /// Synchronization barrier should be placed before \p storage is reused
+    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
+    ///
+    /// \par Examples
+    /// \parblock
+    /// The examples present min reduce operations performed on a block of 256 threads,
+    /// each provides one \p float value.
+    ///
+    /// \code{.cpp}
+    /// __global__ void example_kernel(...) // blockDim.x = 256
+    /// {
+    ///     // specialize block_reduce for float and a block of 256 threads
+    ///     using block_reduce_float = rocprim::block_reduce<float, 256>;
+    ///     // allocate storage in shared memory for the block
+    ///     __shared__ block_reduce_float::storage_type storage;
+    ///
+    ///     float input = ...;
+    ///     unsigned int valid_items = 250;
+    ///     float output;
+    ///     // execute min reduce
+    ///     block_reduce_float().reduce(
+    ///         input,
+    ///         output,
+    ///         valid_items,
+    ///         storage,
+    ///         rocprim::minimum<float>()
+    ///     );
+    ///     ...
+    /// }
+    /// \endcode
+    /// \endparblock
+    template<class BinaryFunction = ::rocprim::plus<T>>
+    ROCPRIM_DEVICE ROCPRIM_INLINE
+    void reduce(T input,
+                T& output,
+                unsigned int valid_items,
+                storage_type& storage,
+                BinaryFunction reduce_op = BinaryFunction())
+    {
+        base_type::reduce(input, output, valid_items, storage, reduce_op);
+    }
+
+    /// \overload
+    /// \brief Performs reduction across threads in a block.
+    ///
+    /// * This overload does not accept storage argument. Required shared memory is
+    /// allocated by the method itself.
+    /// * This overload reduces a single value per thread; see the array overloads
+    /// for reducing multiple items per thread.
+    ///
+    /// \tparam BinaryFunction - type of binary function used for reduce. Default type
+    /// is rocprim::plus<T>.
+    ///
+    /// \param [in] input - thread input value.
+    /// \param [out] output - reference to a thread output value. May be aliased with \p input.
+    /// \param [in] valid_items - number of items that will be reduced in the block.
+    /// \param [in] reduce_op - binary operation function object that will be used for reduce.
+    /// The signature of the function should be equivalent to the following:
+    /// <tt>T f(const T &a, const T &b);</tt>. The signature does not need to have
+    /// <tt>const &</tt>, but function object must not modify the objects passed to it.
+    template<class BinaryFunction = ::rocprim::plus<T>>
+    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
+    void reduce(T input,
+                T& output,
+                unsigned int valid_items,
+                BinaryFunction reduce_op = BinaryFunction())
+    {
+        base_type::reduce(input, output, valid_items, reduce_op);
+    }
+};
+
+END_ROCPRIM_NAMESPACE
+
+/// @}
+// end of group blockmodule
+
+#endif // ROCPRIM_BLOCK_BLOCK_REDUCE_HPP_
diff --git a/3rdparty/cub/rocprim/block/block_scan.hpp b/3rdparty/cub/rocprim/block/block_scan.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..aed0fcefe63f9fe5b958b5f586d856feb413d3a1
--- /dev/null
+++ b/3rdparty/cub/rocprim/block/block_scan.hpp
@@ -0,0 +1,1322 @@
+// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_BLOCK_BLOCK_SCAN_HPP_ +#define ROCPRIM_BLOCK_BLOCK_SCAN_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" + +#include "detail/block_scan_warp_scan.hpp" +#include "detail/block_scan_reduce_then_scan.hpp" + +/// \addtogroup blockmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Available algorithms for block_scan primitive. +enum class block_scan_algorithm +{ + /// \brief A warp_scan based algorithm. + using_warp_scan, + /// \brief An algorithm which limits calculations to a single hardware warp. + reduce_then_scan, + /// \brief Default block_scan algorithm. 
+ default_algorithm = using_warp_scan, +}; + +namespace detail +{ + +// Selector for block_scan algorithm which gives block scan implementation +// type based on passed block_scan_algorithm enum +template +struct select_block_scan_impl; + +template<> +struct select_block_scan_impl +{ + template + using type = block_scan_warp_scan; +}; + +template<> +struct select_block_scan_impl +{ + template + // When BlockSize is less than hardware warp size block_scan_warp_scan performs better than + // block_scan_reduce_then_scan by specializing for warps + using type = typename std::conditional< + (BlockSizeX * BlockSizeY * BlockSizeZ <= ::rocprim::device_warp_size()), + block_scan_warp_scan, + block_scan_reduce_then_scan + >::type; +}; + +} // end namespace detail + +/// \brief The block_scan class is a block level parallel primitive which provides methods +/// for performing inclusive and exclusive scan operations of items partitioned across +/// threads in a block. +/// +/// \tparam T - the input/output type. +/// \tparam BlockSizeX - the number of threads in a block's x dimension. +/// \tparam Algorithm - selected scan algorithm, block_scan_algorithm::default_algorithm by default. +/// \tparam BlockSizeY - the number of threads in a block's y dimension, defaults to 1. +/// \tparam BlockSizeZ - the number of threads in a block's z dimension, defaults to 1. +/// +/// \par Overview +/// * Supports non-commutative scan operators. However, a scan operator should be +/// associative. When used with non-associative functions the results may be non-deterministic +/// and/or vary in precision. +/// * Computation can more efficient when: +/// * \p ItemsPerThread is greater than one, +/// * \p T is an arithmetic type, +/// * scan operation is simple addition operator, and +/// * the number of threads in the block is a multiple of the hardware warp size (see rocprim::device_warp_size()). 
+/// * block_scan has two alternative implementations: \p block_scan_algorithm::using_warp_scan +/// and block_scan_algorithm::reduce_then_scan. +/// +/// \par Examples +/// \parblock +/// In the examples scan operation is performed on block of 192 threads, each provides +/// one \p int value, result is returned using the same variable as for input. +/// +/// \code{.cpp} +/// __global__ void example_kernel(...) +/// { +/// // specialize warp_scan for int and logical warp of 192 threads +/// using block_scan_int = rocprim::block_scan; +/// // allocate storage in shared memory +/// __shared__ block_scan_int::storage_type storage; +/// +/// int value = ...; +/// // execute inclusive scan +/// block_scan_int().inclusive_scan( +/// value, // input +/// value, // output +/// storage +/// ); +/// ... +/// } +/// \endcode +/// \endparblock +template< + class T, + unsigned int BlockSizeX, + block_scan_algorithm Algorithm = block_scan_algorithm::default_algorithm, + unsigned int BlockSizeY = 1, + unsigned int BlockSizeZ = 1 +> +class block_scan +#ifndef DOXYGEN_SHOULD_SKIP_THIS + : private detail::select_block_scan_impl::template type +#endif +{ + using base_type = typename detail::select_block_scan_impl::template type; +public: + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by related parallel primitive. + /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords __shared__. It can be aliased to + /// an externally allocated memory, or be a part of a union type with other storage types + /// to increase shared memory reusability. + using storage_type = typename base_type::storage_type; + + /// \brief Performs inclusive scan across threads in a block. + /// + /// \tparam BinaryFunction - type of binary function used for scan. 
Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present inclusive min scan operations performed on a block of 256 threads, + /// each provides one \p float value. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) // blockDim.x = 256 + /// { + /// // specialize block_scan for float and block of 256 threads + /// using block_scan_f = rocprim::block_scan; + /// // allocate storage in shared memory for the block + /// __shared__ block_scan_float::storage_type storage; + /// + /// float input = ...; + /// float output; + /// // execute inclusive min scan + /// block_scan_float().inclusive_scan( + /// input, + /// output, + /// storage, + /// rocprim::minimum() + /// ); + /// ... + /// } + /// \endcode + /// + /// If the \p input values across threads in a block are {1, -2, 3, -4, ..., 255, -256}, then + /// \p output values in will be {1, -2, -2, -4, ..., -254, -256}. + /// \endparblock + template> + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, + T& output, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + { + base_type::inclusive_scan(input, output, storage, scan_op); + } + + /// \overload + /// \brief Performs inclusive scan across threads in a block. 
+ /// + /// * This overload does not accept storage argument. Required shared memory is + /// allocated by the method itself. + /// + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + template> + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void inclusive_scan(T input, + T& output, + BinaryFunction scan_op = BinaryFunction()) + { + base_type::inclusive_scan(input, output, scan_op); + } + + /// \brief Performs inclusive scan and reduction across threads in a block. + /// + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [out] reduction - result of reducing of all \p input values in a block. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). 
+ /// + /// \par Examples + /// \parblock + /// The examples present inclusive min scan operations performed on a block of 256 threads, + /// each provides one \p float value. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) // blockDim.x = 256 + /// { + /// // specialize block_scan for float and block of 256 threads + /// using block_scan_f = rocprim::block_scan; + /// // allocate storage in shared memory for the block + /// __shared__ block_scan_float::storage_type storage; + /// + /// float input = ...; + /// float output; + /// float reduction; + /// // execute inclusive min scan + /// block_scan_float().inclusive_scan( + /// input, + /// output, + /// reduction, + /// storage, + /// rocprim::minimum() + /// ); + /// ... + /// } + /// \endcode + /// + /// If the \p input values across threads in a block are {1, -2, 3, -4, ..., 255, -256}, then + /// \p output values in will be {1, -2, -2, -4, ..., -254, -256}, and the \p reduction will + /// be -256. + /// \endparblock + template> + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, + T& output, + T& reduction, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + { + base_type::inclusive_scan(input, output, reduction, storage, scan_op); + } + + /// \overload + /// \brief Performs inclusive scan and reduction across threads in a block. + /// + /// * This overload does not accept storage argument. Required shared memory is + /// allocated by the method itself. + /// + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [out] reduction - result of reducing of all \p input values in a block. + /// \param [in] scan_op - binary operation function object that will be used for scan. 
+ /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + template> + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void inclusive_scan(T input, + T& output, + T& reduction, + BinaryFunction scan_op = BinaryFunction()) + { + base_type::inclusive_scan(input, output, reduction, scan_op); + } + + /// \brief Performs inclusive scan across threads in a block, and uses + /// \p prefix_callback_op to generate prefix value for the whole block. + /// + /// \tparam PrefixCallback - type of the unary function object used for generating + /// block-wide prefix value for the scan operation. + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in,out] prefix_callback_op - function object for generating block prefix value. + /// The signature of the \p prefix_callback_op should be equivalent to the following: + /// T f(const T &block_reduction);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// The object will be called by the first warp of the block with block reduction of + /// \p input values as input argument. The result of the first thread will be used as the + /// block-wide prefix. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. 
+ /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present inclusive prefix sum operations performed on a block of 256 threads, + /// each thread provides one \p int value. + /// + /// \code{.cpp} + /// + /// struct my_block_prefix + /// { + /// int prefix; + /// + /// __device__ my_block_prefix(int prefix) : prefix(prefix) {} + /// + /// __device__ int operator()(int block_reduction) + /// { + /// int old_prefix = prefix; + /// prefix = prefix + block_reduction; + /// return old_prefix; + /// } + /// }; + /// + /// __global__ void example_kernel(...) // blockDim.x = 256 + /// { + /// // specialize block_scan for int and block of 256 threads + /// using block_scan_f = rocprim::block_scan; + /// // allocate storage in shared memory for the block + /// __shared__ block_scan_int::storage_type storage; + /// + /// // init prefix functor + /// my_block_prefix prefix_callback(10); + /// + /// int input; + /// int output; + /// // execute inclusive prefix sum + /// block_scan_int().inclusive_scan( + /// input, + /// output, + /// storage, + /// prefix_callback, + /// rocprim::plus() + /// ); + /// ... + /// } + /// \endcode + /// + /// If the \p input values across threads in a block are {1, 1, 1, ..., 1}, then + /// \p output values in will be {11, 12, 13, ..., 266}, and the \p prefix will + /// be 266. + /// \endparblock + template< + class PrefixCallback, + class BinaryFunction = ::rocprim::plus + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, + T& output, + storage_type& storage, + PrefixCallback& prefix_callback_op, + BinaryFunction scan_op) + { + base_type::inclusive_scan(input, output, storage, prefix_callback_op, scan_op); + } + + /// \brief Performs inclusive scan across threads in a block. 
+ /// + /// \tparam ItemsPerThread - number of items in the \p input array. + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] output - reference to a thread output array. May be aliased with \p input. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present inclusive maximum scan operations performed on a block of 128 threads, + /// each provides two \p long value. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) // blockDim.x = 128 + /// { + /// // specialize block_scan for long and block of 128 threads + /// using block_scan_f = rocprim::block_scan; + /// // allocate storage in shared memory for the block + /// __shared__ block_scan_long::storage_type storage; + /// + /// long input[2] = ...; + /// long output[2]; + /// // execute inclusive min scan + /// block_scan_long().inclusive_scan( + /// input, + /// output, + /// storage, + /// rocprim::maximum() + /// ); + /// ... + /// } + /// \endcode + /// + /// If the \p input values across threads in a block are {-1, 2, -3, 4, ..., -255, 256}, then + /// \p output values in will be {-1, 2, 2, 4, ..., 254, 256}. 
+ /// \endparblock + template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::plus + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + { + if(ItemsPerThread == 1) + { + base_type::inclusive_scan(input[0], output[0], storage, scan_op); + } + else + { + base_type::inclusive_scan(input, output, storage, scan_op); + } + } + + /// \overload + /// \brief Performs inclusive scan across threads in a block. + /// + /// * This overload does not accept storage argument. Required shared memory is + /// allocated by the method itself. + /// + /// \tparam ItemsPerThread - number of items in the \p input array. + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] output - reference to a thread output array. May be aliased with \p input. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::plus + > + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + BinaryFunction scan_op = BinaryFunction()) + { + if(ItemsPerThread == 1) + { + base_type::inclusive_scan(input[0], output[0], scan_op); + } + else + { + base_type::inclusive_scan(input, output, scan_op); + } + } + + /// \brief Performs inclusive scan and reduction across threads in a block. + /// + /// \tparam ItemsPerThread - number of items in the \p input array. 
+ /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] output - reference to a thread output array. May be aliased with \p input. + /// \param [out] reduction - result of reducing of all \p input values in a block. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present inclusive maximum scan operations performed on a block of 128 threads, + /// each provides two \p long value. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) // blockDim.x = 128 + /// { + /// // specialize block_scan for long and block of 128 threads + /// using block_scan_f = rocprim::block_scan; + /// // allocate storage in shared memory for the block + /// __shared__ block_scan_long::storage_type storage; + /// + /// long input[2] = ...; + /// long output[2]; + /// long reduction; + /// // execute inclusive min scan + /// block_scan_long().inclusive_scan( + /// input, + /// output, + /// reduction, + /// storage, + /// rocprim::maximum() + /// ); + /// ... + /// } + /// \endcode + /// + /// If the \p input values across threads in a block are {-1, 2, -3, 4, ..., -255, 256}, then + /// \p output values in will be {-1, 2, 2, 4, ..., 254, 256} and the \p reduction will be \p 256. 
+ /// \endparblock + template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::plus + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T& reduction, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + { + if(ItemsPerThread == 1) + { + base_type::inclusive_scan(input[0], output[0], reduction, storage, scan_op); + } + else + { + base_type::inclusive_scan(input, output, reduction, storage, scan_op); + } + } + + /// \overload + /// \brief Performs inclusive scan and reduction across threads in a block. + /// + /// * This overload does not accept storage argument. Required shared memory is + /// allocated by the method itself. + /// + /// \tparam ItemsPerThread - number of items in the \p input array. + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] output - reference to a thread output array. May be aliased with \p input. + /// \param [out] reduction - result of reducing of all \p input values in a block. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. 
+ template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::plus + > + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T& reduction, + BinaryFunction scan_op = BinaryFunction()) + { + if(ItemsPerThread == 1) + { + base_type::inclusive_scan(input[0], output[0], reduction, scan_op); + } + else + { + base_type::inclusive_scan(input, output, reduction, scan_op); + } + } + + /// \brief Performs inclusive scan across threads in a block, and uses + /// \p prefix_callback_op to generate prefix value for the whole block. + /// + /// \tparam ItemsPerThread - number of items in the \p input array. + /// \tparam PrefixCallback - type of the unary function object used for generating + /// block-wide prefix value for the scan operation. + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] output - reference to a thread output array. May be aliased with \p input. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in,out] prefix_callback_op - function object for generating block prefix value. + /// The signature of the \p prefix_callback_op should be equivalent to the following: + /// T f(const T &block_reduction);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// The object will be called by the first warp of the block with block reduction of + /// \p input values as input argument. The result of the first thread will be used as the + /// block-wide prefix. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. 
The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present inclusive prefix sum operations performed on a block of 128 threads, + /// each thread provides two \p int value. + /// + /// \code{.cpp} + /// + /// struct my_block_prefix + /// { + /// int prefix; + /// + /// __device__ my_block_prefix(int prefix) : prefix(prefix) {} + /// + /// __device__ int operator()(int block_reduction) + /// { + /// int old_prefix = prefix; + /// prefix = prefix + block_reduction; + /// return old_prefix; + /// } + /// }; + /// + /// __global__ void example_kernel(...) // blockDim.x = 128 + /// { + /// // specialize block_scan for int and block of 128 threads + /// using block_scan_f = rocprim::block_scan; + /// // allocate storage in shared memory for the block + /// __shared__ block_scan_int::storage_type storage; + /// + /// // init prefix functor + /// my_block_prefix prefix_callback(10); + /// + /// int input[2] = ...; + /// int output[2]; + /// // execute inclusive prefix sum + /// block_scan_int().inclusive_scan( + /// input, + /// output, + /// storage, + /// prefix_callback, + /// rocprim::plus() + /// ); + /// ... + /// } + /// \endcode + /// + /// If the \p input values across threads in a block are {1, 1, 1, ..., 1}, then + /// \p output values in will be {11, 12, 13, ..., 266}, and the \p prefix will + /// be 266. 
+ /// \endparblock + template< + unsigned int ItemsPerThread, + class PrefixCallback, + class BinaryFunction + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + storage_type& storage, + PrefixCallback& prefix_callback_op, + BinaryFunction scan_op) + { + if(ItemsPerThread == 1) + { + base_type::inclusive_scan(input[0], output[0], storage, prefix_callback_op, scan_op); + } + else + { + base_type::inclusive_scan(input, output, storage, prefix_callback_op, scan_op); + } + } + + /// \brief Performs exclusive scan across threads in a block. + /// + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] init - initial value used to start the exclusive scan. Should be the same + /// for all threads in a block. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present exclusive min scan operations performed on a block of 256 threads, + /// each provides one \p float value. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) 
// blockDim.x = 256 + /// { + /// // specialize block_scan for float and block of 256 threads + /// using block_scan_f = rocprim::block_scan; + /// // allocate storage in shared memory for the block + /// __shared__ block_scan_float::storage_type storage; + /// + /// float init = ...; + /// float input = ...; + /// float output; + /// // execute exclusive min scan + /// block_scan_float().exclusive_scan( + /// input, + /// output, + /// init, + /// storage, + /// rocprim::minimum() + /// ); + /// ... + /// } + /// \endcode + /// + /// If the \p input values across threads in a block are {1, -2, 3, -4, ..., 255, -256} + /// and \p init is \p 0, then \p output values in will be {0, 0, -2, -2, -4, ..., -254, -254}. + /// \endparblock + template> + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, + T& output, + T init, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + { + base_type::exclusive_scan(input, output, init, storage, scan_op); + } + + /// \overload + /// \brief Performs exclusive scan across threads in a block. + /// + /// * This overload does not accept storage argument. Required shared memory is + /// allocated by the method itself. + /// + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [in] init - initial value used to start the exclusive scan. Should be the same + /// for all threads in a block. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. 
+ template> + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void exclusive_scan(T input, + T& output, + T init, + BinaryFunction scan_op = BinaryFunction()) + { + base_type::exclusive_scan(input, output, init, scan_op); + } + + /// \brief Performs exclusive scan and reduction across threads in a block. + /// + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [in] init - initial value used to start the exclusive scan. Should be the same + /// for all threads in a block. + /// \param [out] reduction - result of reducing of all \p input values in a block. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present exclusive min scan operations performed on a block of 256 threads, + /// each provides one \p float value. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) 
// blockDim.x = 256 + /// { + /// // specialize block_scan for float and block of 256 threads + /// using block_scan_f = rocprim::block_scan; + /// // allocate storage in shared memory for the block + /// __shared__ block_scan_float::storage_type storage; + /// + /// float init = 0; + /// float input = ...; + /// float output; + /// float reduction; + /// // execute exclusive min scan + /// block_scan_float().exclusive_scan( + /// input, + /// output, + /// init, + /// reduction, + /// storage, + /// rocprim::minimum() + /// ); + /// ... + /// } + /// \endcode + /// + /// If the \p input values across threads in a block are {1, -2, 3, -4, ..., 255, -256} + /// and \p init is \p 0, then \p output values in will be {0, 0, -2, -2, -4, ..., -254, -254} + /// and the \p reduction will be \p -256. + /// \endparblock + template> + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, + T& output, + T init, + T& reduction, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + { + base_type::exclusive_scan(input, output, init, reduction, storage, scan_op); + } + + /// \overload + /// \brief Performs exclusive scan and reduction across threads in a block. + /// + /// * This overload does not accept storage argument. Required shared memory is + /// allocated by the method itself. + /// + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [in] init - initial value used to start the exclusive scan. Should be the same + /// for all threads in a block. + /// \param [out] reduction - result of reducing of all \p input values in a block. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. 
The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + template> + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void exclusive_scan(T input, + T& output, + T init, + T& reduction, + BinaryFunction scan_op = BinaryFunction()) + { + base_type::exclusive_scan(input, output, init, reduction, scan_op); + } + + /// \brief Performs exclusive scan across threads in a block, and uses + /// \p prefix_callback_op to generate prefix value for the whole block. + /// + /// \tparam PrefixCallback - type of the unary function object used for generating + /// block-wide prefix value for the scan operation. + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in,out] prefix_callback_op - function object for generating block prefix value. + /// The signature of the \p prefix_callback_op should be equivalent to the following: + /// T f(const T &block_reduction);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// The object will be called by the first warp of the block with block reduction of + /// \p input values as input argument. The result of the first thread will be used as the + /// block-wide prefix. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. 
+ /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present exclusive prefix sum operations performed on a block of 256 threads, + /// each thread provides one \p int value. + /// + /// \code{.cpp} + /// + /// struct my_block_prefix + /// { + /// int prefix; + /// + /// __device__ my_block_prefix(int prefix) : prefix(prefix) {} + /// + /// __device__ int operator()(int block_reduction) + /// { + /// int old_prefix = prefix; + /// prefix = prefix + block_reduction; + /// return old_prefix; + /// } + /// }; + /// + /// __global__ void example_kernel(...) // blockDim.x = 256 + /// { + /// // specialize block_scan for int and block of 256 threads + /// using block_scan_f = rocprim::block_scan; + /// // allocate storage in shared memory for the block + /// __shared__ block_scan_int::storage_type storage; + /// + /// // init prefix functor + /// my_block_prefix prefix_callback(10); + /// + /// int input; + /// int output; + /// // execute exclusive prefix sum + /// block_scan_int().exclusive_scan( + /// input, + /// output, + /// storage, + /// prefix_callback, + /// rocprim::plus() + /// ); + /// ... + /// } + /// \endcode + /// + /// If the \p input values across threads in a block are {1, 1, 1, ..., 1}, then + /// \p output values in will be {10, 11, 12, 13, ..., 265}, and the \p prefix will + /// be 266. + /// \endparblock + template< + class PrefixCallback, + class BinaryFunction = ::rocprim::plus + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, + T& output, + storage_type& storage, + PrefixCallback& prefix_callback_op, + BinaryFunction scan_op) + { + base_type::exclusive_scan(input, output, storage, prefix_callback_op, scan_op); + } + + /// \brief Performs exclusive scan across threads in a block. 
+ /// + /// \tparam ItemsPerThread - number of items in the \p input array. + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] output - reference to a thread output array. May be aliased with \p input. + /// \param [in] init - initial value used to start the exclusive scan. Should be the same + /// for all threads in a block. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present exclusive maximum scan operations performed on a block of 128 threads, + /// each provides two \p long value. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) // blockDim.x = 128 + /// { + /// // specialize block_scan for long and block of 128 threads + /// using block_scan_f = rocprim::block_scan; + /// // allocate storage in shared memory for the block + /// __shared__ block_scan_long::storage_type storage; + /// + /// long init = ...; + /// long input[2] = ...; + /// long output[2]; + /// // execute exclusive min scan + /// block_scan_long().exclusive_scan( + /// input, + /// output, + /// init, + /// storage, + /// rocprim::maximum() + /// ); + /// ... 
+ /// } + /// \endcode + /// + /// If the \p input values across threads in a block are {-1, 2, -3, 4, ..., -255, 256} + /// and \p init is 0, then \p output values in will be {0, 0, 2, 2, 4, ..., 254, 254}. + /// \endparblock + template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::plus + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T init, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + { + if(ItemsPerThread == 1) + { + base_type::exclusive_scan(input[0], output[0], init, storage, scan_op); + } + else + { + base_type::exclusive_scan(input, output, init, storage, scan_op); + } + } + + /// \overload + /// \brief Performs exclusive scan across threads in a block. + /// + /// * This overload does not accept storage argument. Required shared memory is + /// allocated by the method itself. + /// + /// \tparam ItemsPerThread - number of items in the \p input array. + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] output - reference to a thread output array. May be aliased with \p input. + /// \param [in] init - initial value used to start the exclusive scan. Should be the same + /// for all threads in a block. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. 
+ template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::plus + > + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T init, + BinaryFunction scan_op = BinaryFunction()) + { + if(ItemsPerThread == 1) + { + base_type::exclusive_scan(input[0], output[0], init, scan_op); + } + else + { + base_type::exclusive_scan(input, output, init, scan_op); + } + } + + /// \brief Performs exclusive scan and reduction across threads in a block. + /// + /// \tparam ItemsPerThread - number of items in the \p input array. + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] output - reference to a thread output array. May be aliased with \p input. + /// \param [in] init - initial value used to start the exclusive scan. Should be the same + /// for all threads in a block. + /// \param [out] reduction - result of reducing of all \p input values in a block. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present exclusive maximum scan operations performed on a block of 128 threads, + /// each provides two \p long value. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) 
// blockDim.x = 128 + /// { + /// // specialize block_scan for long and block of 128 threads + /// using block_scan_f = rocprim::block_scan; + /// // allocate storage in shared memory for the block + /// __shared__ block_scan_long::storage_type storage; + /// + /// long init = ...; + /// long input[2] = ...; + /// long output[2]; + /// long reduction; + /// // execute exclusive min scan + /// block_scan_long().exclusive_scan( + /// input, + /// output, + /// init, + /// reduction, + /// storage, + /// rocprim::maximum() + /// ); + /// ... + /// } + /// \endcode + /// + /// If the \p input values across threads in a block are {-1, 2, -3, 4, ..., -255, 256} + /// and \p init is 0, then \p output values in will be {0, 0, 2, 2, 4, ..., 254, 254} + /// and the \p reduction will be \p 256. + /// \endparblock + template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::plus + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T init, + T& reduction, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + { + if(ItemsPerThread == 1) + { + base_type::exclusive_scan(input[0], output[0], init, reduction, storage, scan_op); + } + else + { + base_type::exclusive_scan(input, output, init, reduction, storage, scan_op); + } + } + + /// \overload + /// \brief Performs exclusive scan and reduction across threads in a block. + /// + /// * This overload does not accept storage argument. Required shared memory is + /// allocated by the method itself. + /// + /// \tparam ItemsPerThread - number of items in the \p input array. + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] output - reference to a thread output array. May be aliased with \p input. + /// \param [in] init - initial value used to start the exclusive scan. 
Should be the same + /// for all threads in a block. + /// \param [out] reduction - result of reducing of all \p input values in a block. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::plus + > + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T init, + T& reduction, + BinaryFunction scan_op = BinaryFunction()) + { + if(ItemsPerThread == 1) + { + base_type::exclusive_scan(input[0], output[0], init, reduction, scan_op); + } + else + { + base_type::exclusive_scan(input, output, init, reduction, scan_op); + } + } + + /// \brief Performs exclusive scan across threads in a block, and uses + /// \p prefix_callback_op to generate prefix value for the whole block. + /// + /// \tparam ItemsPerThread - number of items in the \p input array. + /// \tparam PrefixCallback - type of the unary function object used for generating + /// block-wide prefix value for the scan operation. + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - reference to an array containing thread input values. + /// \param [out] output - reference to a thread output array. May be aliased with \p input. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in,out] prefix_callback_op - function object for generating block prefix value. + /// The signature of the \p prefix_callback_op should be equivalent to the following: + /// T f(const T &block_reduction);. 
The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// The object will be called by the first warp of the block with block reduction of + /// \p input values as input argument. The result of the first thread will be used as the + /// block-wide prefix. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present exclusive prefix sum operations performed on a block of 128 threads, + /// each thread provides two \p int value. + /// + /// \code{.cpp} + /// + /// struct my_block_prefix + /// { + /// int prefix; + /// + /// __device__ my_block_prefix(int prefix) : prefix(prefix) {} + /// + /// __device__ int operator()(int block_reduction) + /// { + /// int old_prefix = prefix; + /// prefix = prefix + block_reduction; + /// return old_prefix; + /// } + /// }; + /// + /// __global__ void example_kernel(...) // blockDim.x = 128 + /// { + /// // specialize block_scan for int and block of 128 threads + /// using block_scan_f = rocprim::block_scan; + /// // allocate storage in shared memory for the block + /// __shared__ block_scan_int::storage_type storage; + /// + /// // init prefix functor + /// my_block_prefix prefix_callback(10); + /// + /// int input[2] = ...; + /// int output[2]; + /// // execute exclusive prefix sum + /// block_scan_int().exclusive_scan( + /// input, + /// output, + /// storage, + /// prefix_callback, + /// rocprim::plus() + /// ); + /// ... 
+ /// } + /// \endcode + /// + /// If the \p input values across threads in a block are {1, 1, 1, ..., 1}, then + /// \p output values in will be {10, 11, 12, 13, ..., 265}, and the \p prefix will + /// be 266. + /// \endparblock + template< + unsigned int ItemsPerThread, + class PrefixCallback, + class BinaryFunction + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + storage_type& storage, + PrefixCallback& prefix_callback_op, + BinaryFunction scan_op) + { + if(ItemsPerThread == 1) + { + base_type::exclusive_scan(input[0], output[0], storage, prefix_callback_op, scan_op); + } + else + { + base_type::exclusive_scan(input, output, storage, prefix_callback_op, scan_op); + } + } +}; + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group blockmodule + +#endif // ROCPRIM_BLOCK_BLOCK_SCAN_HPP_ diff --git a/3rdparty/cub/rocprim/block/block_shuffle.hpp b/3rdparty/cub/rocprim/block/block_shuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..188f96fa28f62af020c04363d644ea4f43fc93b2 --- /dev/null +++ b/3rdparty/cub/rocprim/block/block_shuffle.hpp @@ -0,0 +1,490 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef ROCPRIM_BLOCK_BLOCK_SHUFFLE_HPP_ +#define ROCPRIM_BLOCK_BLOCK_SHUFFLE_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" + +#include "detail/block_reduce_warp_reduce.hpp" +#include "detail/block_reduce_raking_reduce.hpp" + +/// \addtogroup blockmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief The block_shuffle class is a block level parallel primitive which provides methods +/// for shuffling data partitioned across a block +/// +/// \tparam T - the input/output type. +/// \tparam BlockSizeX - the number of threads in a block's x dimension, it has no defaults value. +/// \tparam BlockSizeY - the number of threads in a block's y dimension, defaults to 1. +/// \tparam BlockSizeZ - the number of threads in a block's z dimension, defaults to 1. 
+/// +/// \par Overview +/// It is commonplace for blocks of threads to rearrange data items between +/// threads. The BlockShuffle abstraction allows threads to efficiently shift items +/// either (a) up to their successor or (b) down to their predecessor. +/// * Computation can more efficient when: +/// * \p ItemsPerThread is greater than one, +/// * \p T is an arithmetic type, +/// * the number of threads in the block is a multiple of the hardware warp size (see rocprim::warp_size()). +/// +/// \par Examples +/// \parblock +/// In the examples shuffle operation is performed on block of 192 threads, each provides +/// one \p int value, result is returned using the same variable as for input. +/// +/// \code{.cpp} +/// __global__ void example_kernel(...) +/// { +/// // specialize block__shuffle_int for int and logical warp of 192 threads +/// using block__shuffle_int = rocprim::block_shuffle; +/// // allocate storage in shared memory +/// __shared__ block_shuffle::storage_type storage; +/// +/// int value = ...; +/// // execute block shuffle +/// block__shuffle_int().inclusive_up( +/// value, // input +/// value, // output +/// storage +/// ); +/// ... +/// } +/// \endcode +/// \endparblock +template< + class T, + unsigned int BlockSizeX, + unsigned int BlockSizeY = 1, + unsigned int BlockSizeZ = 1> +class block_shuffle +{ + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + + // Struct used for creating a raw_storage object for this primitive's temporary storage. + struct storage_type_ + { + T prev[BlockSize]; + T next[BlockSize]; + }; + +public: + + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by related parallel primitive. + /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords __shared__. 
It can be aliased to + /// an externally allocated memory, or be a part of a union type with other storage types + /// to increase shared memory reusability. + #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen + using storage_type = detail::raw_storage; + #else + using storage_type = storage_type_; // only for Doxygen + #endif + + /// \brief Shuffles data across threads in a block, offseted by the distance value. + /// + /// \par A thread with threadId i receives data from a thread with threadIdx (i-distance), whre distance may be a negative value. + /// allocated by the method itself. + /// \par Any shuffle operation with invalid input or output threadIds are not carried out, i.e. threadId < 0 || threadId >= BlockSize. + /// + /// \param [in] input - input data to be shuffled to another thread. + /// \param [out] output - reference to a output value, that receives data from another thread + /// \param [in] distance - The input threadId + distance = output threadId. + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block__shuffle_int for int and logical warp of 192 threads + /// using block__shuffle_int = rocprim::block_shuffle; + /// + /// int value = ...; + /// // execute block shuffle + /// block__shuffle_int().offset( + /// value, // input + /// value // output + /// ); + /// ... 
+ /// } + /// \endcode + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void offset(T input, + T& output, + int distance = 1) + { + offset( + ::rocprim::flat_block_thread_id(), + input, output, distance + ); + } + + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void offset(const size_t& flat_id, + T input, + T& output, + int distance) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + offset(flat_id, input, output, distance, storage); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void offset(const size_t& flat_id, + T input, + T& output, + int distance, + storage_type& storage) + { + storage_type_& storage_ = storage.get(); + storage_.prev[flat_id] = input; + + ::rocprim::syncthreads(); + + const int offset_tid = static_cast(flat_id) + distance; + if ((offset_tid >= 0) && (offset_tid < (int)BlockSize)) + { + output = storage_.prev[static_cast(offset_tid)]; + } + } + + /// \brief Shuffles data across threads in a block, offseted by the distance value. + /// + /// \par A thread with threadId i receives data from a thread with threadIdx (i-distance)%BlockSize, whre distance may be a negative value. + /// allocated by the method itself. + /// \par Data is rotated around the block, using (input_threadId + distance) modulous BlockSize to ensure valid threadIds. + /// + /// \param [in] input - input data to be shuffled to another thread. + /// \param [out] output - reference to a output value, that receives data from another thread + /// \param [in] distance - The input threadId + distance = output threadId. + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block__shuffle_int for int and logical warp of 192 threads + /// using block__shuffle_int = rocprim::block_shuffle; + /// + /// int value = ...; + /// // execute block shuffle + /// block__shuffle_int().rotate( + /// value, // input + /// value // output + /// ); + /// ... 
+ /// } + /// \endcode + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void rotate(T input, + T& output, + unsigned int distance = 1) + { + rotate( + ::rocprim::flat_block_thread_id(), + input, output, distance + ); + } + + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void rotate(const size_t& flat_id, + T input, + T& output, + unsigned int distance) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + rotate(flat_id, input, output, distance, storage); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void rotate(const size_t& flat_id, + T input, + T& output, + unsigned int distance, + storage_type& storage) + { + storage_type_& storage_ = storage.get(); + storage_.prev[flat_id] = input; + + ::rocprim::syncthreads(); + + unsigned int offset = threadIdx.x + distance; + if (offset >= BlockSize) + offset -= BlockSize; + + output = storage_.prev[offset]; + } + + + /// \brief The thread block rotates a blocked arrange of input items, + /// shifting it up by one item + /// + /// \param [in] input - The calling thread's input items + /// \param [out] prev - The corresponding predecessor items (may be aliased to \p input). + /// The item \p prev[0] is not updated for thread0. + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block__shuffle_int for int and logical warp of 192 threads + /// using block__shuffle_int = rocprim::block_shuffle; + /// + /// int value = ...; + /// // execute block shuffle + /// block__shuffle_int().up( + /// value, // input + /// value // output + /// ); + /// ... 
+ /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void up(T (&input)[ItemsPerThread], + T (&prev)[ItemsPerThread]) + { + this->up( + ::rocprim::flat_block_thread_id(), + input, prev + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void up(const size_t& flat_id, + T (&input)[ItemsPerThread], + T (&prev)[ItemsPerThread]) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->up(flat_id, input, prev, storage); + } + + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void up(const size_t& flat_id, + T (&input)[ItemsPerThread], + T (&prev)[ItemsPerThread], + storage_type& storage) + { + storage_type_& storage_ = storage.get(); + storage_.prev[flat_id] = input[ItemsPerThread -1]; + + ::rocprim::syncthreads(); + + ROCPRIM_UNROLL + for (unsigned int i = ItemsPerThread - 1; i > 0; --i) + { + prev[i] = input[i - 1]; + } + + if (flat_id > 0) + { + prev[0] = storage_.prev[flat_id - 1]; + } + } + + + + /// \brief The thread block rotates a blocked arrange of input items, + /// shifting it up by one item + /// + /// \param [in] input - The calling thread's input items + /// \param [out] prev - The corresponding predecessor items (may be aliased to \p input). + /// The item \p prev[0] is not updated for thread0. 
+ /// \param [out] block_suffix - The item \p input[ItemsPerThread-1] from + /// threadBlockSize-1, provided to all threads + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void up(T (&input)[ItemsPerThread], + T (&prev)[ItemsPerThread], + T &block_suffix) + { + this->up( + ::rocprim::flat_block_thread_id(), + input, prev, block_suffix + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void up(const size_t& flat_id, + T (&input)[ItemsPerThread], + T (&prev)[ItemsPerThread], + T &block_suffix) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->up(flat_id, input, prev, block_suffix, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void up(const size_t& flat_id, + T (&input)[ItemsPerThread], + T (&prev)[ItemsPerThread], + T &block_suffix, + storage_type& storage) + { + up(flat_id, input, prev, storage); + + // Update block prefix + block_suffix = storage->prev[BlockSize - 1]; + } + + /// \brief The thread block rotates a blocked arrange of input items, + /// shifting it down by one item + /// + /// \param [in] input - The calling thread's input items + /// \param [out] next - The corresponding successor items (may be aliased to \p input). + /// The item \p prev[0] is not updated for threadBlockSize - 1. + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block__shuffle_int for int and logical warp of 192 threads + /// using block__shuffle_int = rocprim::block_shuffle; + /// + /// int value = ...; + /// // execute block shuffle + /// block__shuffle_int().down( + /// value, // input + /// value // output + /// ); + /// ... 
+ /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void down(T (&input)[ItemsPerThread], + T (&next)[ItemsPerThread]) + { + this->down( + ::rocprim::flat_block_thread_id(), + input, next + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void down(const size_t& flat_id, + T (&input)[ItemsPerThread], + T (&next)[ItemsPerThread]) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->down(flat_id, input, next, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void down(const size_t& flat_id, + T (&input)[ItemsPerThread], + T (&next)[ItemsPerThread], + storage_type& storage) + { + storage_type_& storage_ = storage.get(); + storage_.next[flat_id] = input[0]; + + ::rocprim::syncthreads(); + + ROCPRIM_UNROLL + for (unsigned int i = 0; i < (ItemsPerThread - 1); ++i) + { + next[i] = input[i + 1]; + } + + if (flat_id <(BlockSize -1)) + { + next[ItemsPerThread -1] = storage_.next[flat_id + 1]; + } + } + + /// \brief The thread block rotates a blocked arrange of input items, + /// shifting it down by one item + /// + /// \param [in] input - The calling thread's input items + /// \param [out] next - The corresponding successor items (may be aliased to \p input). + /// The item \p prev[0] is not updated for threadBlockSize - 1. 
+ /// \param [out] block_prefix - The item \p input[0] from thread0, provided to all threads + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void down(T (&input)[ItemsPerThread], + T (&next)[ItemsPerThread], + T &block_prefix) + { + this->down( + ::rocprim::flat_block_thread_id(), + input, next, block_prefix + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void down(const size_t& flat_id, + T (&input)[ItemsPerThread], + T (&next)[ItemsPerThread], + T &block_prefix) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->down(flat_id, input, next, block_prefix, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void down(const size_t& flat_id, + T (&input)[ItemsPerThread], + T (&next)[ItemsPerThread], + T &block_prefix, + storage_type& storage) + { + this->down(flat_id, input, next, storage); + + // Update block prefixstorage_-> + block_prefix = storage->next[0]; + } +}; + + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group blockmodule + +#endif // ROCPRIM_BLOCK_BLOCK_SHUFFLE_HPP_ diff --git a/3rdparty/cub/rocprim/block/block_sort.hpp b/3rdparty/cub/rocprim/block/block_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..66de33ef056aa44422b9085d3c7b8b72548a9d90 --- /dev/null +++ b/3rdparty/cub/rocprim/block/block_sort.hpp @@ -0,0 +1,373 @@ +// Copyright (c) 2017-2020 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_BLOCK_BLOCK_SORT_HPP_ +#define ROCPRIM_BLOCK_BLOCK_SORT_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" + +#include "detail/block_sort_bitonic.hpp" + +/// \addtogroup blockmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Available algorithms for block_sort primitive. +enum class block_sort_algorithm +{ + /// \brief A bitonic sort based algorithm. + bitonic_sort, + /// \brief Default block_sort algorithm. + default_algorithm = bitonic_sort, +}; + +namespace detail +{ + +// Selector for block_sort algorithm which gives block sort implementation +// type based on passed block_sort_algorithm enum +template +struct select_block_sort_impl; + +template<> +struct select_block_sort_impl +{ + template + using type = block_sort_bitonic; +}; + +} // end namespace detail + +/// \brief The block_sort class is a block level parallel primitive which provides +/// methods sorting items (keys or key-value pairs) partitioned across threads in a block +/// using comparison-based sort algorithm. +/// +/// \tparam Key - the key type. +/// \tparam BlockSize - the number of threads in a block. +/// \tparam ItemsPerThread - number of items processed by each thread. +/// The total range will be BlockSize * ItemsPerThread long +/// \tparam Value - the value type. Default type empty_type indicates +/// a keys-only sort. 
+/// \tparam Algorithm - selected sort algorithm, block_sort_algorithm::default_algorithm by default. +/// +/// \par Overview +/// * Accepts custom compare_functions for sorting across a block. +/// * Performance depends on \p BlockSize. +/// * It is better if \p BlockSize is a power of two. +/// * If \p BlockSize is not a power of two, or when function with \p size overload is used +/// odd-even sort is used instead of bitonic sort, leading to decreased performance. +/// +/// \par Examples +/// \parblock +/// In the examples sort is performed on a block of 256 threads, each thread provides +/// one \p int value, results are returned using the same variable as for input. +/// +/// \code{.cpp} +/// __global__ void example_kernel(...) +/// { +/// // specialize block_sort for int, block of 256 threads, +/// // key-only sort +/// using block_sort_int = rocprim::block_sort; +/// // allocate storage in shared memory +/// __shared__ block_sort_int::storage_type storage; +/// +/// int input = ...; +/// // execute block sort (ascending) +/// block_sort_int().sort( +/// input, +/// storage +/// ); +/// ... +/// } +/// \endcode +/// \endparblock +template< + class Key, + unsigned int BlockSizeX, + unsigned int ItemsPerThread = 1, + class Value = empty_type, + block_sort_algorithm Algorithm = block_sort_algorithm::default_algorithm, + unsigned int BlockSizeY = 1, + unsigned int BlockSizeZ = 1 +> +class block_sort +#ifndef DOXYGEN_SHOULD_SKIP_THIS + : private detail::select_block_sort_impl::template type +#endif +{ + using base_type = typename detail::select_block_sort_impl::template type; +public: + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by related parallel primitive. + /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords __shared__. 
It can be aliased to + /// an externally allocated memory, or be a part of a union type with other storage types + /// to increase shared memory reusability. + using storage_type = typename base_type::storage_type; + + /// \brief Block sort for any data type. + /// + /// \tparam BinaryFunction - type of binary function used for sort. Default type + /// is rocprim::less. + /// + /// \param [in, out] thread_key - reference to a key provided by a thread. + /// \param [in] compare_function - comparison function object which returns true if the + /// first argument is is ordered before the second. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + template> + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void sort(Key& thread_key, + BinaryFunction compare_function = BinaryFunction()) + { + base_type::sort(thread_key, compare_function); + } + + template > + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void sort(Key (&thread_keys)[ItemsPerThread], + BinaryFunction compare_function = BinaryFunction()) + { + base_type::sort(thread_keys, compare_function); + } + + /// \brief Block sort for any data type. + /// + /// \tparam BinaryFunction - type of binary function used for sort. Default type + /// is rocprim::less. + /// + /// \param [in, out] thread_key - reference to a key provided by a thread. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] compare_function - comparison function object which returns true if the + /// first argument is is ordered before the second. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. 
    ///
    /// \par Storage reusage
    /// Synchronization barrier should be placed before \p storage is reused
    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
    ///
    /// \par Examples
    /// \parblock
    /// In the examples sort is performed on a block of 256 threads, each thread provides
    /// one \p int value, results are returned using the same variable as for input.
    ///
    /// \code{.cpp}
    /// __global__ void example_kernel(...)
    /// {
    ///     // specialize block_sort for int, block of 256 threads,
    ///     // key-only sort
    ///     using block_sort_int = rocprim::block_sort<int, 256>;
    ///     // allocate storage in shared memory
    ///     __shared__ block_sort_int::storage_type storage;
    ///
    ///     int input = ...;
    ///     // execute block sort (ascending)
    ///     block_sort_int().sort(
    ///         input,
    ///         storage
    ///     );
    ///     ...
    /// }
    /// \endcode
    /// \endparblock
    // NOTE(review): the method template parameter lists were lost in this paste;
    // they are reconstructed below as a BinaryFunction parameter defaulted to
    // ::rocprim::less<Key> (plus ItemsPerThread for the array overloads) --
    // TODO confirm against the original rocPRIM header.
    template<class BinaryFunction = ::rocprim::less<Key>>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void sort(Key& thread_key,
              storage_type& storage,
              BinaryFunction compare_function = BinaryFunction())
    {
        // Thin forwarder: the actual sorting network lives in the base class.
        base_type::sort(thread_key, storage, compare_function);
    }

    /// \brief Block sort of an array of keys per thread, using temporary storage.
    template<unsigned int ItemsPerThread, class BinaryFunction = ::rocprim::less<Key>>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void sort(Key (&thread_keys)[ItemsPerThread],
              storage_type& storage,
              BinaryFunction compare_function = BinaryFunction())
    {
        base_type::sort(thread_keys, storage, compare_function);
    }

    /// \brief Block sort by key for any data type.
    ///
    /// \tparam BinaryFunction - type of binary function used for sort. Default type
    /// is rocprim::less.
    ///
    /// \param [in, out] thread_key - reference to a key provided by a thread.
    /// \param [in, out] thread_value - reference to a value provided by a thread.
    /// \param [in] compare_function - comparison function object which returns true if the
    /// first argument is ordered before the second.
    /// The signature of the function should be equivalent to the following:
    /// bool f(const T &a, const T &b);. The signature does not need to have
    /// const &, but function object must not modify the objects passed to it.
    template<class BinaryFunction = ::rocprim::less<Key>>
    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
    void sort(Key& thread_key,
              Value& thread_value,
              BinaryFunction compare_function = BinaryFunction())
    {
        base_type::sort(thread_key, thread_value, compare_function);
    }

    /// \brief Block sort by key of arrays of keys and values per thread.
    template<unsigned int ItemsPerThread, class BinaryFunction = ::rocprim::less<Key>>
    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
    void sort(Key (&thread_keys)[ItemsPerThread],
              Value (&thread_values)[ItemsPerThread],
              BinaryFunction compare_function = BinaryFunction())
    {
        base_type::sort(thread_keys, thread_values, compare_function);
    }

    /// \brief Block sort by key for any data type.
    ///
    /// \tparam BinaryFunction - type of binary function used for sort. Default type
    /// is rocprim::less.
    ///
    /// \param [in, out] thread_key - reference to a key provided by a thread.
    /// \param [in, out] thread_value - reference to a value provided by a thread.
    /// \param [in] storage - reference to a temporary storage object of type storage_type.
    /// \param [in] compare_function - comparison function object which returns true if the
    /// first argument is ordered before the second.
    /// The signature of the function should be equivalent to the following:
    /// bool f(const T &a, const T &b);. The signature does not need to have
    /// const &, but function object must not modify the objects passed to it.
    ///
    /// \par Storage reusage
    /// Synchronization barrier should be placed before \p storage is reused
    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
    ///
    /// \parblock
    /// In the examples sort is performed on a block of 256 threads, each thread provides
    /// one \p int key and one \p int value, results are returned using the same variable as for input.
    ///
    /// \code{.cpp}
    /// __global__ void example_kernel(...)
    /// {
    ///     // specialize block_sort for int key and int value, block of 256 threads
    ///     using block_sort_int = rocprim::block_sort<int, 256, 1, 1, int>;
    ///     // allocate storage in shared memory
    ///     __shared__ block_sort_int::storage_type storage;
    ///
    ///     int key = ...;
    ///     int value = ...;
    ///     // execute block sort (ascending)
    ///     block_sort_int().sort(
    ///         key,
    ///         value,
    ///         storage
    ///     );
    ///     ...
    /// }
    /// \endcode
    /// \endparblock
    template<class BinaryFunction = ::rocprim::less<Key>>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void sort(Key& thread_key,
              Value& thread_value,
              storage_type& storage,
              BinaryFunction compare_function = BinaryFunction())
    {
        base_type::sort(thread_key, thread_value, storage, compare_function);
    }

    /// \brief Block sort by key of arrays of keys and values per thread, using temporary storage.
    template<unsigned int ItemsPerThread, class BinaryFunction = ::rocprim::less<Key>>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void sort(Key (&thread_keys)[ItemsPerThread],
              Value (&thread_values)[ItemsPerThread],
              storage_type& storage,
              BinaryFunction compare_function = BinaryFunction())
    {
        base_type::sort(thread_keys, thread_values, storage, compare_function);
    }

    /// \brief Block sort by key for any data type. If \p size is
    /// greater than \p BlockSize, this function does nothing.
    ///
    /// \tparam BinaryFunction - type of binary function used for sort. Default type
    /// is rocprim::less.
    ///
    /// \param [in, out] thread_key - reference to a key provided by a thread.
    /// \param [in] storage - reference to a temporary storage object of type storage_type.
    /// \param [in] size - custom size of block to be sorted.
    /// \param [in] compare_function - comparison function object which returns true if the
    /// first argument is ordered before the second.
    /// The signature of the function should be equivalent to the following:
    /// bool f(const T &a, const T &b);. The signature does not need to have
    /// const &, but function object must not modify the objects passed to it.
    template<class BinaryFunction = ::rocprim::less<Key>>
    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
    void sort(Key& thread_key,
              storage_type& storage,
              const unsigned int size,
              BinaryFunction compare_function = BinaryFunction())
    {
        base_type::sort(thread_key, storage, size, compare_function);
    }

    /// \brief Block sort by key for any data type. If \p size is
    /// greater than \p BlockSize, this function does nothing.
    ///
    /// \tparam BinaryFunction - type of binary function used for sort. Default type
    /// is rocprim::less.
    ///
    /// \param [in, out] thread_key - reference to a key provided by a thread.
    /// \param [in, out] thread_value - reference to a value provided by a thread.
    /// \param [in] storage - reference to a temporary storage object of type storage_type.
    /// \param [in] size - custom size of block to be sorted.
    /// \param [in] compare_function - comparison function object which returns true if the
    /// first argument is ordered before the second.
    /// The signature of the function should be equivalent to the following:
    /// bool f(const T &a, const T &b);. The signature does not need to have
    /// const &, but function object must not modify the objects passed to it.
    template<class BinaryFunction = ::rocprim::less<Key>>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void sort(Key& thread_key,
              Value& thread_value,
              storage_type& storage,
              const unsigned int size,
              BinaryFunction compare_function = BinaryFunction())
    {
        base_type::sort(thread_key, thread_value, storage, size, compare_function);
    }
};

END_ROCPRIM_NAMESPACE

/// @}
// end of group blockmodule

#endif // ROCPRIM_BLOCK_BLOCK_SORT_HPP_
diff --git a/3rdparty/cub/rocprim/block/block_store.hpp b/3rdparty/cub/rocprim/block/block_store.hpp new file mode 100644 index 0000000000000000000000000000000000000000..95daef94283f9712a1ae7e94e5d9fa67950e9a34 --- /dev/null +++ b/3rdparty/cub/rocprim/block/block_store.hpp @@ -0,0 +1,560 @@
// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#ifndef ROCPRIM_BLOCK_BLOCK_STORE_HPP_
#define ROCPRIM_BLOCK_BLOCK_STORE_HPP_

#include "../config.hpp"
#include "../detail/various.hpp"

#include "../intrinsics.hpp"
#include "../functional.hpp"
#include "../types.hpp"

#include "block_store_func.hpp"
#include "block_exchange.hpp"

/// \addtogroup blockmodule
/// @{

BEGIN_ROCPRIM_NAMESPACE

/// \brief \p block_store_method enumerates the methods available to store a striped arrangement
/// of items into a blocked/striped arrangement on continuous memory
enum class block_store_method
{
    /// A blocked arrangement of items is stored into a blocked arrangement on continuous
    /// memory.
    /// \par Performance Notes:
    /// * Performance decreases with increasing number of items per thread (stride
    /// between reads), because of reduced memory coalescing.
    block_store_direct,

    /// A striped arrangement of items is stored into a blocked arrangement on continuous
    /// memory.
    block_store_striped,

    /// A blocked arrangement of items is stored into a blocked arrangement on continuous
    /// memory using vectorization as an optimization.
    /// \par Performance Notes:
    /// * Performance remains high due to increased memory coalescing, provided that
    /// vectorization requirements are fulfilled. Otherwise, performance will default
    /// to \p block_store_direct.
    /// \par Requirements:
    /// * The output offset (\p block_output) must be quad-item aligned.
    /// * The following conditions will prevent vectorization and switch to default
    /// \p block_store_direct:
    ///   * \p ItemsPerThread is odd.
    ///   * The datatype \p T is not a primitive or a HIP vector type (e.g. int2,
    ///     int4, etc.)
    block_store_vectorize,

    /// A blocked arrangement of items is locally transposed and stored as a striped
    /// arrangement of data on continuous memory.
    /// \par Performance Notes:
    /// * Performance remains high due to increased memory coalescing, regardless of the
    /// number of items per thread.
    /// * Performance may be better compared to \p block_store_direct and
    /// \p block_store_vectorize due to reordering on local memory.
    block_store_transpose,

    /// A blocked arrangement of items is locally transposed and stored as a warp-striped
    /// arrangement of data on continuous memory.
    /// \par Requirements:
    /// * The number of threads in the block must be a multiple of the size of hardware warp.
    /// \par Performance Notes:
    /// * Performance remains high due to increased memory coalescing, regardless of the
    /// number of items per thread.
    /// * Performance may be better compared to \p block_store_direct and
    /// \p block_store_vectorize due to reordering on local memory.
    block_store_warp_transpose,

    /// Defaults to \p block_store_direct
    default_method = block_store_direct
};

/// \brief The \p block_store class is a block level parallel primitive which provides methods
/// for storing an arrangement of items into a blocked/striped arrangement on continuous memory.
///
/// \tparam T - the type of the stored items.
/// \tparam BlockSizeX - the number of threads in a block (x dimension).
/// \tparam ItemsPerThread - the number of items to be processed by
/// each thread.
/// \tparam Method - the method to store data.
///
/// \par Overview
/// * The \p block_store class has a number of different methods to store data:
///   * [block_store_direct](\ref ::block_store_method::block_store_direct)
///   * [block_store_striped](\ref ::block_store_method::block_store_striped)
///   * [block_store_vectorize](\ref ::block_store_method::block_store_vectorize)
///   * [block_store_transpose](\ref ::block_store_method::block_store_transpose)
///   * [block_store_warp_transpose](\ref ::block_store_method::block_store_warp_transpose)
///
/// \par Example:
/// \parblock
/// In the examples store operation is performed on block of 128 threads, using type
/// \p int and 8 items per thread.
///
/// \code{.cpp}
/// __global__ void kernel(int * output)
/// {
///     const int offset = blockIdx.x * 128 * 8;
///     int items[8];
///     rocprim::block_store<int, 128, 8> blockstore;
///     blockstore.store(output + offset, items);
///     ...
/// }
/// \endcode
/// \endparblock
template<
    class T,
    unsigned int BlockSizeX,
    unsigned int ItemsPerThread,
    block_store_method Method = block_store_method::block_store_direct,
    unsigned int BlockSizeY = 1,
    unsigned int BlockSizeZ = 1
>
class block_store
{
private:
    using storage_type_ = typename ::rocprim::detail::empty_storage_type;

public:
    /// \brief Struct used to allocate a temporary memory that is required for thread
    /// communication during operations provided by related parallel primitive.
    ///
    /// Depending on the implementation the operations exposed by parallel primitive may
    /// require a temporary storage for thread communication. The storage should be allocated
    /// using keywords \p __shared__. It can be aliased to
    /// an externally allocated memory, or be a part of a union with other storage types
    /// to increase shared memory reusability.
    #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen
    using storage_type = typename ::rocprim::detail::empty_storage_type;
    #else
    using storage_type = storage_type_; // only for Doxygen
    #endif

    /// \brief Stores an arrangement of items from across the thread block into an
    /// arrangement on continuous memory.
    ///
    /// \tparam OutputIterator - [inferred] an iterator type for output (can be a simple
    /// pointer).
    ///
    /// \param [out] block_output - the output iterator from the thread block to store to.
    /// \param [in] items - array that data is read from.
    ///
    /// \par Overview
    /// * The type \p T must be such that an object of type \p OutputIterator
    /// can be assigned a value implicitly converted from \p T.
    // NOTE(review): template arguments were stripped from this paste; calls to
    // flat_block_thread_id and block_store_direct_striped are reconstructed with
    // the full 3D block shape / BlockSize arguments -- TODO confirm against the
    // original rocPRIM header.
    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread])
    {
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_store_direct_blocked(flat_id, block_output, items);
    }

    /// \brief Stores an arrangement of items from across the thread block into an
    /// arrangement on continuous memory, which is guarded by range \p valid.
    ///
    /// \tparam OutputIterator - [inferred] an iterator type for output (can be a simple
    /// pointer).
    ///
    /// \param [out] block_output - the output iterator from the thread block to store to.
    /// \param [in] items - array that data is read from.
    /// \param [in] valid - maximum range of valid numbers to write.
    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               unsigned int valid)
    {
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_store_direct_blocked(flat_id, block_output, items, valid);
    }

    /// \brief Stores an arrangement of items from across the thread block into an
    /// arrangement on continuous memory, using temporary storage.
    ///
    /// \tparam OutputIterator - [inferred] an iterator type for output (can be a simple
    /// pointer).
    ///
    /// \param [out] block_output - the output iterator from the thread block to store to.
    /// \param [in] items - array that data is read from.
    /// \param [in] storage - temporary storage for outputs (unused by the direct method).
    ///
    /// \par Storage reusage
    /// Synchronization barrier should be placed before \p storage is reused
    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
    ///
    /// \par Example.
    /// \code{.cpp}
    /// __global__ void kernel(...)
    /// {
    ///     int items[8];
    ///     using block_store_int = rocprim::block_store<int, 128, 8>;
    ///     block_store_int bstore;
    ///     __shared__ typename block_store_int::storage_type storage;
    ///     bstore.store(..., items, storage);
    ///     ...
    /// }
    /// \endcode
    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               storage_type& storage)
    {
        // The direct method needs no shared memory; storage is accepted only
        // for interface uniformity with the other methods.
        (void) storage;
        store(block_output, items);
    }

    /// \brief Stores an arrangement of items from across the thread block into an
    /// arrangement on continuous memory, which is guarded by range \p valid,
    /// using temporary storage.
    ///
    /// \tparam OutputIterator - [inferred] an iterator type for output (can be a simple
    /// pointer).
    ///
    /// \param [out] block_output - the output iterator from the thread block to store to.
    /// \param [in] items - array that data is read from.
    /// \param [in] valid - maximum range of valid numbers to write.
    /// \param [in] storage - temporary storage for outputs (unused by the direct method).
    ///
    /// \par Storage reusage
    /// Synchronization barrier should be placed before \p storage is reused
    /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads().
    ///
    /// \par Example.
    /// \code{.cpp}
    /// __global__ void kernel(...)
    /// {
    ///     int items[8];
    ///     using block_store_int = rocprim::block_store<int, 128, 8>;
    ///     block_store_int bstore;
    ///     __shared__ typename block_store_int::storage_type storage;
    ///     bstore.store(..., items, valid, storage);
    ///     ...
    /// }
    /// \endcode
    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               unsigned int valid,
               storage_type& storage)
    {
        (void) storage;
        store(block_output, items, valid);
    }
};

/// @}
// end of group blockmodule

#ifndef DOXYGEN_SHOULD_SKIP_THIS

// Specialization for block_store_striped: items are written in a striped
// pattern (thread i owns elements i, i + BlockSize, ...), no shared memory.
template<
    class T,
    unsigned int BlockSizeX,
    unsigned int ItemsPerThread,
    unsigned int BlockSizeY,
    unsigned int BlockSizeZ
>
class block_store<T, BlockSizeX, ItemsPerThread, block_store_method::block_store_striped, BlockSizeY, BlockSizeZ>
{
    static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
private:
    using storage_type_ = typename ::rocprim::detail::empty_storage_type;

public:
    #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen
    using storage_type = typename ::rocprim::detail::empty_storage_type;
    #else
    using storage_type = storage_type_; // only for Doxygen
    #endif

    template<class OutputIterator>
    ROCPRIM_DEVICE inline
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread])
    {
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_store_direct_striped<BlockSize>(flat_id, block_output, items);
    }

    template<class OutputIterator>
    ROCPRIM_DEVICE inline
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               unsigned int valid)
    {
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_store_direct_striped<BlockSize>(flat_id, block_output, items, valid);
    }

    template<class OutputIterator>
    ROCPRIM_DEVICE inline
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               storage_type& storage)
    {
        (void) storage; // striped store needs no shared memory
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_store_direct_striped<BlockSize>(flat_id, block_output, items);
    }

    template<class OutputIterator>
    ROCPRIM_DEVICE inline
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               unsigned int valid,
               storage_type& storage)
    {
        (void) storage;
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_store_direct_striped<BlockSize>(flat_id, block_output, items, valid);
    }
};

// Specialization for block_store_vectorize: raw-pointer outputs of type T use
// the vectorized path; all other output iterators fall back to the direct path.
template<
    class T,
    unsigned int BlockSizeX,
    unsigned int ItemsPerThread,
    unsigned int BlockSizeY,
    unsigned int BlockSizeZ
>
class block_store<T, BlockSizeX, ItemsPerThread, block_store_method::block_store_vectorize, BlockSizeY, BlockSizeZ>
{
private:
    using storage_type_ = typename ::rocprim::detail::empty_storage_type;

public:
    #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen
    using storage_type = typename ::rocprim::detail::empty_storage_type;
    #else
    using storage_type = storage_type_; // only for Doxygen
    #endif

    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(T* block_output,
               T (&_items)[ItemsPerThread])
    {
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_store_direct_blocked_vectorized(flat_id, block_output, _items);
    }

    // Non-pointer outputs (or item type differing from T) cannot be vectorized.
    template<class OutputIterator, class U>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(OutputIterator block_output,
               U (&items)[ItemsPerThread])
    {
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_store_direct_blocked(flat_id, block_output, items);
    }

    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               unsigned int valid)
    {
        // The guarded path is never vectorized.
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_store_direct_blocked(flat_id, block_output, items, valid);
    }

    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(T* block_output,
               T (&items)[ItemsPerThread],
               storage_type& storage)
    {
        (void) storage;
        store(block_output, items);
    }

    template<class OutputIterator, class U>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(OutputIterator block_output,
               U (&items)[ItemsPerThread],
               storage_type& storage)
    {
        (void) storage;
        store(block_output, items);
    }

    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               unsigned int valid,
               storage_type& storage)
    {
        (void) storage;
        store(block_output, items, valid);
    }
};

// Specialization for block_store_transpose: items are reordered from blocked to
// striped layout in shared memory, then written with coalesced striped stores.
template<
    class T,
    unsigned int BlockSizeX,
    unsigned int ItemsPerThread,
    unsigned int BlockSizeY,
    unsigned int BlockSizeZ
>
class block_store<T, BlockSizeX, ItemsPerThread, block_store_method::block_store_transpose, BlockSizeY, BlockSizeZ>
{
    static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
private:
    using block_exchange_type = block_exchange<T, BlockSizeX, ItemsPerThread, BlockSizeY, BlockSizeZ>;

public:
    using storage_type = typename block_exchange_type::storage_type;

    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread])
    {
        ROCPRIM_SHARED_MEMORY storage_type storage;
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_exchange_type().blocked_to_striped(items, items, storage);
        block_store_direct_striped<BlockSize>(flat_id, block_output, items);
    }

    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               unsigned int valid)
    {
        ROCPRIM_SHARED_MEMORY storage_type storage;
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_exchange_type().blocked_to_striped(items, items, storage);
        block_store_direct_striped<BlockSize>(flat_id, block_output, items, valid);
    }

    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               storage_type& storage)
    {
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_exchange_type().blocked_to_striped(items, items, storage);
        block_store_direct_striped<BlockSize>(flat_id, block_output, items);
    }

    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               unsigned int valid,
               storage_type& storage)
    {
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_exchange_type().blocked_to_striped(items, items, storage);
        block_store_direct_striped<BlockSize>(flat_id, block_output, items, valid);
    }
};

// Specialization for block_store_warp_transpose: items are reordered from
// blocked to warp-striped layout in shared memory, then written with coalesced
// warp-striped stores. Requires the block size to be a warp-size multiple.
template<
    class T,
    unsigned int BlockSizeX,
    unsigned int ItemsPerThread,
    unsigned int BlockSizeY,
    unsigned int BlockSizeZ
>
class block_store<T, BlockSizeX, ItemsPerThread, block_store_method::block_store_warp_transpose, BlockSizeY, BlockSizeZ>
{
    static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
private:
    using block_exchange_type = block_exchange<T, BlockSizeX, ItemsPerThread, BlockSizeY, BlockSizeZ>;

public:
    static_assert(BlockSize % ::rocprim::device_warp_size() == 0,
                  "BlockSize must be a multiple of hardware warpsize");

    using storage_type = typename block_exchange_type::storage_type;

    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread])
    {
        ROCPRIM_SHARED_MEMORY storage_type storage;
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_exchange_type().blocked_to_warp_striped(items, items, storage);
        block_store_direct_warp_striped(flat_id, block_output, items);
    }

    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               unsigned int valid)
    {
        ROCPRIM_SHARED_MEMORY storage_type storage;
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_exchange_type().blocked_to_warp_striped(items, items, storage);
        block_store_direct_warp_striped(flat_id, block_output, items, valid);
    }

    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               storage_type& storage)
    {
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_exchange_type().blocked_to_warp_striped(items, items, storage);
        block_store_direct_warp_striped(flat_id, block_output, items);
    }

    template<class OutputIterator>
    ROCPRIM_DEVICE ROCPRIM_INLINE
    void store(OutputIterator block_output,
               T (&items)[ItemsPerThread],
               unsigned int valid,
               storage_type& storage)
    {
        const unsigned int flat_id = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
        block_exchange_type().blocked_to_warp_striped(items, items, storage);
        block_store_direct_warp_striped(flat_id, block_output, items, valid);
    }
};

#endif // DOXYGEN_SHOULD_SKIP_THIS

END_ROCPRIM_NAMESPACE

#endif // ROCPRIM_BLOCK_BLOCK_STORE_HPP_
diff --git a/3rdparty/cub/rocprim/block/block_store_func.hpp b/3rdparty/cub/rocprim/block/block_store_func.hpp new file mode
100644 index 0000000000000000000000000000000000000000..31c1e37cfa8ba4a54276403edfd753393e4b23ea --- /dev/null +++ b/3rdparty/cub/rocprim/block/block_store_func.hpp @@ -0,0 +1,393 @@
// Copyright (c) 2017-2019 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#ifndef ROCPRIM_BLOCK_BLOCK_STORE_FUNC_HPP_
#define ROCPRIM_BLOCK_BLOCK_STORE_FUNC_HPP_

#include "../config.hpp"
#include "../detail/various.hpp"

#include "../intrinsics.hpp"
#include "../functional.hpp"
#include "../types.hpp"

BEGIN_ROCPRIM_NAMESPACE

/// \addtogroup blockmodule
/// @{

/// \brief Stores a blocked arrangement of items from across the thread block
/// into a blocked arrangement on continuous memory.
///
/// The block arrangement is assumed to be (block-threads * \p ItemsPerThread) items
/// across a thread block. Each thread uses a \p flat_id to store a range of
/// \p ItemsPerThread \p items to the thread block.
///
/// \tparam OutputIterator - [inferred] an iterator type for output (can be a simple
/// pointer)
/// \tparam T - [inferred] the data type
/// \tparam ItemsPerThread - [inferred] the number of items to be processed by
/// each thread
///
/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread
/// \param block_output - the output iterator from the thread block to store to
/// \param items - array that data is stored to thread block
// NOTE(review): template-dependent expressions (static_assert arguments,
// enable_if conditions, match_vector_type arguments) were stripped from this
// paste and are reconstructed below -- TODO confirm against the original header.
template<
    class OutputIterator,
    class T,
    unsigned int ItemsPerThread
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void block_store_direct_blocked(unsigned int flat_id,
                                OutputIterator block_output,
                                T (&items)[ItemsPerThread])
{
    static_assert(std::is_assignable<decltype(block_output[0]), T>::value,
                  "The type T must be such that an object of type OutputIterator "
                  "can be dereferenced and assigned a value of type T.");

    // Each thread owns a contiguous run of ItemsPerThread elements.
    unsigned int offset = flat_id * ItemsPerThread;
    OutputIterator thread_iter = block_output + offset;
    ROCPRIM_UNROLL
    for (unsigned int item = 0; item < ItemsPerThread; item++)
    {
        thread_iter[item] = items[item];
    }
}

/// \brief Stores a blocked arrangement of items from across the thread block
/// into a blocked arrangement on continuous memory, which is guarded by range \p valid.
///
/// The block arrangement is assumed to be (block-threads * \p ItemsPerThread) items
/// across a thread block. Each thread uses a \p flat_id to store a range of
/// \p ItemsPerThread \p items to the thread block.
///
/// \tparam OutputIterator - [inferred] an iterator type for output (can be a simple
/// pointer)
/// \tparam T - [inferred] the data type
/// \tparam ItemsPerThread - [inferred] the number of items to be processed by
/// each thread
///
/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread
/// \param block_output - the output iterator from the thread block to store to
/// \param items - array that data is stored to thread block
/// \param valid - maximum range of valid numbers to store
template<
    class OutputIterator,
    class T,
    unsigned int ItemsPerThread
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void block_store_direct_blocked(unsigned int flat_id,
                                OutputIterator block_output,
                                T (&items)[ItemsPerThread],
                                unsigned int valid)
{
    static_assert(std::is_assignable<decltype(block_output[0]), T>::value,
                  "The type T must be such that an object of type OutputIterator "
                  "can be dereferenced and assigned a value of type T.");

    unsigned int offset = flat_id * ItemsPerThread;
    OutputIterator thread_iter = block_output + offset;
    ROCPRIM_UNROLL
    for (unsigned int item = 0; item < ItemsPerThread; item++)
    {
        // Skip writes past the valid range instead of clamping the loop, so the
        // unrolled body stays branch-uniform per item.
        if (item + offset < valid)
        {
            thread_iter[item] = items[item];
        }
    }
}

/// \brief Stores a blocked arrangement of items from across the thread block
/// into a blocked arrangement on continuous memory, using vectorized stores.
///
/// The block arrangement is assumed to be (block-threads * \p ItemsPerThread) items
/// across a thread block. Each thread uses a \p flat_id to store a range of
/// \p ItemsPerThread \p items to the thread block.
///
/// The output offset (\p block_output + offset) must be quad-item aligned.
///
/// The following conditions will prevent vectorization and switch to default
/// block_store_direct_blocked:
/// * \p ItemsPerThread is odd.
/// * The datatype \p T is not a primitive or a HIP vector type (e.g. int2,
/// int4, etc.)
///
/// \tparam T - [inferred] the output data type
/// \tparam U - [inferred] the input data type
/// \tparam ItemsPerThread - [inferred] the number of items to be processed by
/// each thread
///
/// The type \p U must be such that it can be implicitly converted to \p T.
///
/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread
/// \param block_output - the pointer from the thread block to store to
/// \param items - array that data is stored from
template<
    class T,
    class U,
    unsigned int ItemsPerThread
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto
block_store_direct_blocked_vectorized(unsigned int flat_id,
                                      T* block_output,
                                      U (&items)[ItemsPerThread])
    -> typename std::enable_if<detail::is_vectorizable<T, ItemsPerThread>::value>::type
{
    static_assert(std::is_convertible<U, T>::value,
                  "The type U must be such that it can be implicitly converted to T.");

    typedef typename detail::match_vector_type<T, ItemsPerThread>::type vector_type;
    constexpr unsigned int vectors_per_thread = (sizeof(T) * ItemsPerThread) / sizeof(vector_type);
    vector_type *vectors_ptr = reinterpret_cast<vector_type*>(const_cast<T*>(block_output));

    // Stage the items into a vector-typed local buffer, then store whole
    // vectors with the blocked path for wider memory transactions.
    vector_type raw_vector_items[vectors_per_thread];
    T *raw_items = reinterpret_cast<T*>(raw_vector_items);

    ROCPRIM_UNROLL
    for (unsigned int item = 0; item < ItemsPerThread; item++)
    {
        raw_items[item] = items[item];
    }

    block_store_direct_blocked(flat_id, vectors_ptr, raw_vector_items);
}

// Fallback overload: when T/ItemsPerThread cannot be vectorized, store with the
// plain blocked path.
template<
    class T,
    class U,
    unsigned int ItemsPerThread
>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto
block_store_direct_blocked_vectorized(unsigned int flat_id,
                                      T* block_output,
                                      U (&items)[ItemsPerThread])
    -> typename std::enable_if<!detail::is_vectorizable<T, ItemsPerThread>::value>::type
{
    block_store_direct_blocked(flat_id, block_output, items);
}

/// \brief Stores a striped arrangement of items from across the thread block
/// into a blocked arrangement on continuous memory.
///
/// The striped arrangement is assumed to be (\p BlockSize * \p ItemsPerThread) items
/// across a thread block.
Each thread uses a \p flat_id to store a range of +/// \p ItemsPerThread \p items to the thread block. +/// +/// \tparam BlockSize - the number of threads in a block +/// \tparam OutputIterator - [inferred] an iterator type for input (can be a simple +/// pointer +/// \tparam T - [inferred] the data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_output - the input iterator from the thread block to store to +/// \param items - array that data is stored to thread block +template< + unsigned int BlockSize, + class OutputIterator, + class T, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_store_direct_striped(unsigned int flat_id, + OutputIterator block_output, + T (&items)[ItemsPerThread]) +{ + static_assert(std::is_assignable::value, + "The type T must be such that an object of type OutputIterator " + "can be dereferenced and assigned a value of type T."); + + OutputIterator thread_iter = block_output + flat_id; + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + thread_iter[item * BlockSize] = items[item]; + } +} + +/// \brief Stores a striped arrangement of items from across the thread block +/// into a blocked arrangement on continuous memory, which is guarded by range \p valid. +/// +/// The striped arrangement is assumed to be (\p BlockSize * \p ItemsPerThread) items +/// across a thread block. Each thread uses a \p flat_id to store a range of +/// \p ItemsPerThread \p items to the thread block. 
+/// +/// \tparam BlockSize - the number of threads in a block +/// \tparam OutputIterator - [inferred] an iterator type for input (can be a simple +/// pointer +/// \tparam T - [inferred] the data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_output - the input iterator from the thread block to store to +/// \param items - array that data is stored to thread block +/// \param valid - maximum range of valid numbers to store +template< + unsigned int BlockSize, + class OutputIterator, + class T, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_store_direct_striped(unsigned int flat_id, + OutputIterator block_output, + T (&items)[ItemsPerThread], + unsigned int valid) +{ + static_assert(std::is_assignable::value, + "The type T must be such that an object of type OutputIterator " + "can be dereferenced and assigned a value of type T."); + + OutputIterator thread_iter = block_output + flat_id; + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + unsigned int offset = item * BlockSize; + if (flat_id + offset < valid) + { + thread_iter[offset] = items[item]; + } + } +} + +/// \brief Stores a warp-striped arrangement of items from across the thread block +/// into a blocked arrangement on continuous memory. +/// +/// The warp-striped arrangement is assumed to be (\p WarpSize * \p ItemsPerThread) items +/// across a thread block. Each thread uses a \p flat_id to store a range of +/// \p ItemsPerThread \p items to the thread block. +/// +/// * The number of threads in the block must be a multiple of \p WarpSize. +/// * The default \p WarpSize is a hardware warpsize and is an optimal value. +/// * \p WarpSize must be a power of two and equal or less than the size of +/// hardware warp. 
+/// * Using \p WarpSize smaller than hardware warpsize could result in lower +/// performance. +/// +/// \tparam WarpSize - [optional] the number of threads in a warp +/// \tparam OutputIterator - [inferred] an iterator type for input (can be a simple +/// pointer +/// \tparam T - [inferred] the data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_output - the input iterator from the thread block to store to +/// \param items - array that data is stored to thread block +template< + unsigned int WarpSize = device_warp_size(), + class OutputIterator, + class T, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_store_direct_warp_striped(unsigned int flat_id, + OutputIterator block_output, + T (&items)[ItemsPerThread]) +{ + static_assert(std::is_assignable::value, + "The type T must be such that an object of type OutputIterator " + "can be dereferenced and assigned a value of type T."); + + static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= device_warp_size(), + "WarpSize must be a power of two and equal or less" + "than the size of hardware warp."); + unsigned int thread_id = detail::logical_lane_id(); + unsigned int warp_id = flat_id / WarpSize; + unsigned int warp_offset = warp_id * WarpSize * ItemsPerThread; + + OutputIterator thread_iter = block_output + thread_id + warp_offset; + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + thread_iter[item * WarpSize] = items[item]; + } +} + +/// \brief Stores a warp-striped arrangement of items from across the thread block +/// into a blocked arrangement on continuous memory, which is guarded by range \p valid. +/// +/// The warp-striped arrangement is assumed to be (\p WarpSize * \p ItemsPerThread) items +/// across a thread block. 
Each thread uses a \p flat_id to store a range of +/// \p ItemsPerThread \p items to the thread block. +/// +/// * The number of threads in the block must be a multiple of \p WarpSize. +/// * The default \p WarpSize is a hardware warpsize and is an optimal value. +/// * \p WarpSize must be a power of two and equal or less than the size of +/// hardware warp. +/// * Using \p WarpSize smaller than hardware warpsize could result in lower +/// performance. +/// +/// \tparam WarpSize - [optional] the number of threads in a warp +/// \tparam OutputIterator - [inferred] an iterator type for input (can be a simple +/// pointer +/// \tparam T - [inferred] the data type +/// \tparam ItemsPerThread - [inferred] the number of items to be processed by +/// each thread +/// +/// \param flat_id - a local flat 1D thread id in a block (tile) for the calling thread +/// \param block_output - the input iterator from the thread block to store to +/// \param items - array that data is stored to thread block +/// \param valid - maximum range of valid numbers to store +template< + unsigned int WarpSize = device_warp_size(), + class OutputIterator, + class T, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_store_direct_warp_striped(unsigned int flat_id, + OutputIterator block_output, + T (&items)[ItemsPerThread], + unsigned int valid) +{ + static_assert(std::is_assignable::value, + "The type T must be such that an object of type OutputIterator " + "can be dereferenced and assigned a value of type T."); + + static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= device_warp_size(), + "WarpSize must be a power of two and equal or less" + "than the size of hardware warp."); + unsigned int thread_id = detail::logical_lane_id(); + unsigned int warp_id = flat_id / WarpSize; + unsigned int warp_offset = warp_id * WarpSize * ItemsPerThread; + + OutputIterator thread_iter = block_output + thread_id + warp_offset; + ROCPRIM_UNROLL + for (unsigned int item = 0; 
item < ItemsPerThread; item++) + { + unsigned int offset = item * WarpSize; + if (warp_offset + thread_id + offset < valid) + { + thread_iter[offset] = items[item]; + } + } +} + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group blockmodule + +#endif // ROCPRIM_BLOCK_BLOCK_STORE_FUNC_HPP_ diff --git a/3rdparty/cub/rocprim/block/detail/block_adjacent_difference_impl.hpp b/3rdparty/cub/rocprim/block/detail/block_adjacent_difference_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6d41d450ce2cddd0d321474faeab0da51a468478 --- /dev/null +++ b/3rdparty/cub/rocprim/block/detail/block_adjacent_difference_impl.hpp @@ -0,0 +1,347 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_ADJACENT_DIFFERENCE_IMPL_HPP_ +#define ROCPRIM_BLOCK_DETAIL_BLOCK_ADJACENT_DIFFERENCE_IMPL_HPP_ + +#include "../../config.hpp" +#include "../../detail/various.hpp" +#include "../../intrinsics/thread.hpp" + +#include + +#include + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// Wrapping function that allows to call BinaryFunction of any of these signatures: +// with b_index (a, b, b_index) or without it (a, b). +// Only in the case of discontinuity (when flags_style is true) is the operator allowed to take an +// index +// block_discontinuity and block_adjacent difference only differ in their implementations by the +// order the operators parameters are passed, so this method deals with this as well +template +ROCPRIM_DEVICE ROCPRIM_INLINE auto apply(BinaryFunction op, + const T& a, + const T& b, + unsigned int index, + bool_constant /*as_flags*/, + bool_constant /*reversed*/) -> decltype(op(b, a, index)) +{ + return op(a, b, index); +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE auto apply(BinaryFunction op, + const T& a, + const T& b, + unsigned int index, + bool_constant /*as_flags*/, + bool_constant /*reversed*/) + -> decltype(op(b, a, index)) +{ + return op(b, a, index); +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE auto apply(BinaryFunction op, + const T& a, + const T& b, + unsigned int, + bool_constant /*as_flags*/, + bool_constant /*reversed*/) -> decltype(op(b, a)) +{ + return op(a, b); +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE auto apply(BinaryFunction op, + const T& a, + const T& b, + unsigned int, + bool_constant /*as_flags*/, + bool_constant /*reversed*/) -> decltype(op(b, a)) +{ + return op(b, a); +} + +template +class block_adjacent_difference_impl +{ +public: + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + struct storage_type + { + T items[BlockSize]; + }; + + template + ROCPRIM_DEVICE void apply_left(const T (&input)[ItemsPerThread], + Output 
(&output)[ItemsPerThread], + BinaryFunction op, + const T tile_predecessor_item, + storage_type& storage) + { + static constexpr auto as_flags = bool_constant {}; + static constexpr auto reversed = bool_constant {}; + + const unsigned int flat_id + = ::rocprim::flat_block_thread_id(); + + // Save the last item of each thread + storage.items[flat_id] = input[ItemsPerThread - 1]; + + ROCPRIM_UNROLL + for(unsigned int i = ItemsPerThread - 1; i > 0; --i) + { + output[i] = detail::apply( + op, input[i - 1], input[i], flat_id * ItemsPerThread + i, as_flags, reversed); + } + ::rocprim::syncthreads(); + + if ROCPRIM_IF_CONSTEXPR (WithTilePredecessor) + { + T predecessor_item = tile_predecessor_item; + if(flat_id != 0) { + predecessor_item = storage.items[flat_id - 1]; + } + + output[0] = detail::apply( + op, predecessor_item, input[0], flat_id * ItemsPerThread, as_flags, reversed); + } + else + { + output[0] = get_default_item(input, 0, as_flags); + if(flat_id != 0) { + output[0] = detail::apply(op, + storage.items[flat_id - 1], + input[0], + flat_id * ItemsPerThread, + as_flags, + reversed); + } + } + } + + template + ROCPRIM_DEVICE void apply_left_partial(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + BinaryFunction op, + const T tile_predecessor_item, + const unsigned int valid_items, + storage_type& storage) + { + static constexpr auto as_flags = bool_constant {}; + static constexpr auto reversed = bool_constant {}; + + assert(valid_items <= BlockSize * ItemsPerThread); + + const unsigned int flat_id + = ::rocprim::flat_block_thread_id(); + + // Save the last item of each thread + storage.items[flat_id] = input[ItemsPerThread - 1]; + + ROCPRIM_UNROLL + for(unsigned int i = ItemsPerThread - 1; i > 0; --i) + { + const unsigned int index = flat_id * ItemsPerThread + i; + output[i] = get_default_item(input, i, as_flags); + if(index < valid_items) { + output[i] = detail::apply(op, input[i - 1], input[i], index, as_flags, reversed); + } + } + 
::rocprim::syncthreads(); + + const unsigned int index = flat_id * ItemsPerThread; + + if ROCPRIM_IF_CONSTEXPR (WithTilePredecessor) + { + T predecessor_item = tile_predecessor_item; + if(flat_id != 0) { + predecessor_item = storage.items[flat_id - 1]; + } + + output[0] = get_default_item(input, 0, as_flags); + if(index < valid_items) + { + output[0] + = detail::apply(op, predecessor_item, input[0], index, as_flags, reversed); + } + } + else + { + output[0] = get_default_item(input, 0, as_flags); + if(flat_id != 0 && index < valid_items) + { + output[0] = detail::apply(op, + storage.items[flat_id - 1], + input[0], + flat_id * ItemsPerThread, + as_flags, + reversed); + } + } + } + + template + ROCPRIM_DEVICE void apply_right(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + BinaryFunction op, + const T tile_successor_item, + storage_type& storage) + { + static constexpr auto as_flags = bool_constant {}; + static constexpr auto reversed = bool_constant {}; + + const unsigned int flat_id + = ::rocprim::flat_block_thread_id(); + + // Save the first item of each thread + storage.items[flat_id] = input[0]; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread - 1; ++i) + { + output[i] = detail::apply( + op, input[i], input[i + 1], flat_id * ItemsPerThread + i + 1, as_flags, reversed); + } + ::rocprim::syncthreads(); + + if ROCPRIM_IF_CONSTEXPR (WithTileSuccessor) + { + T successor_item = tile_successor_item; + if(flat_id != BlockSize - 1) { + successor_item = storage.items[flat_id + 1]; + } + + output[ItemsPerThread - 1] = detail::apply(op, + input[ItemsPerThread - 1], + successor_item, + flat_id * ItemsPerThread + ItemsPerThread, + as_flags, + reversed); + } + else + { + output[ItemsPerThread - 1] = get_default_item(input, ItemsPerThread - 1, as_flags); + if(flat_id != BlockSize - 1) { + output[ItemsPerThread - 1] + = detail::apply(op, + input[ItemsPerThread - 1], + storage.items[flat_id + 1], + flat_id * ItemsPerThread + ItemsPerThread, 
+ as_flags, + reversed); + } + } + } + template + ROCPRIM_DEVICE void apply_right_partial(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + BinaryFunction op, + const unsigned int valid_items, + storage_type& storage) + { + static constexpr auto as_flags = bool_constant {}; + static constexpr auto reversed = bool_constant {}; + + assert(valid_items <= BlockSize * ItemsPerThread); + + const unsigned int flat_id + = ::rocprim::flat_block_thread_id(); + + // Save the first item of each thread + storage.items[flat_id] = input[0]; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread - 1; ++i) + { + const unsigned int index = flat_id * ItemsPerThread + i + 1; + output[i] = get_default_item(input, i, as_flags); + if(index < valid_items) + { + output[i] = detail::apply(op, input[i], input[i + 1], index, as_flags, reversed); + } + } + ::rocprim::syncthreads(); + + output[ItemsPerThread - 1] = get_default_item(input, ItemsPerThread - 1, as_flags); + + const unsigned int next_thread_index = flat_id * ItemsPerThread + ItemsPerThread; + if(next_thread_index < valid_items) + { + output[ItemsPerThread - 1] = detail::apply(op, + input[ItemsPerThread - 1], + storage.items[flat_id + 1], + next_thread_index, + as_flags, + reversed); + } + } + +private: + template + ROCPRIM_DEVICE int get_default_item(const T (&)[ItemsPerThread], + unsigned int /*index*/, + bool_constant /*as_flags*/) + { + return 1; + } + + template + ROCPRIM_DEVICE T get_default_item(const T (&input)[ItemsPerThread], + const unsigned int index, + bool_constant /*as_flags*/) + { + return input[index]; + } +}; + +} // namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_ADJACENT_DIFFERENCE_IMPL_HPP_ diff --git a/3rdparty/cub/rocprim/block/detail/block_histogram_atomic.hpp b/3rdparty/cub/rocprim/block/detail/block_histogram_atomic.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b238c74e336e58f0872b4882de6fc75cc65cd0a9 --- /dev/null 
+++ b/3rdparty/cub/rocprim/block/detail/block_histogram_atomic.hpp @@ -0,0 +1,89 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_ATOMIC_HPP_ +#define ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_ATOMIC_HPP_ + +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int BlockSizeX, + unsigned int BlockSizeY, + unsigned int BlockSizeZ, + unsigned int ItemsPerThread, + unsigned int Bins +> +class block_histogram_atomic +{ + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + static_assert( + std::is_convertible::value, + "T must be convertible to unsigned int" + ); + +public: + using storage_type = typename ::rocprim::detail::empty_storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void composite(T (&input)[ItemsPerThread], + Counter hist[Bins]) + { + static_assert( + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, + "Counter must be type that is supported by atomics (float, int, unsigned int, unsigned long long)" + ); + ROCPRIM_UNROLL + for (unsigned int i = 0; i < ItemsPerThread; ++i) + { + ::rocprim::detail::atomic_add(&hist[static_cast(input[i])], Counter(1)); + } + ::rocprim::syncthreads(); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void composite(T (&input)[ItemsPerThread], + Counter hist[Bins], + storage_type& storage) + { + (void) storage; + this->composite(input, hist); + } +}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_ATOMIC_HPP_ diff --git a/3rdparty/cub/rocprim/block/detail/block_histogram_sort.hpp b/3rdparty/cub/rocprim/block/detail/block_histogram_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6076f7508c178cf4317a2929ac24adc4685f26ae --- /dev/null +++ b/3rdparty/cub/rocprim/block/detail/block_histogram_sort.hpp @@ -0,0 +1,172 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. 
All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_SORT_HPP_ +#define ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_SORT_HPP_ + +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" + +#include "../block_radix_sort.hpp" +#include "../block_discontinuity.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int BlockSizeX, + unsigned int BlockSizeY, + unsigned int BlockSizeZ, + unsigned int ItemsPerThread, + unsigned int Bins +> +class block_histogram_sort +{ + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + static_assert( + std::is_convertible::value, + "T must be convertible to unsigned int" + ); + +private: + using radix_sort = block_radix_sort; + using discontinuity = block_discontinuity; + +public: + union storage_type_ + { + typename radix_sort::storage_type sort; + struct + { + typename discontinuity::storage_type flag; + unsigned int start[Bins]; + unsigned int end[Bins]; + }; + }; + + using storage_type = detail::raw_storage; + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void composite(T (&input)[ItemsPerThread], + Counter hist[Bins]) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->composite(input, hist, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void composite(T (&input)[ItemsPerThread], + Counter hist[Bins], + storage_type& storage) + { + // TODO: Check, MSVC rejects the code with the static assertion, yet compiles fine for all tested types. 
Predicate likely too strict + //static_assert( + // std::is_convertible::value, + // "unsigned int must be convertible to Counter" + //); + constexpr auto tile_size = BlockSize * ItemsPerThread; + const auto flat_tid = ::rocprim::flat_block_thread_id(); + unsigned int head_flags[ItemsPerThread]; + discontinuity_op flags_op(storage); + storage_type_& storage_ = storage.get(); + + radix_sort().sort(input, storage_.sort); + ::rocprim::syncthreads(); // Fix race condition that appeared on Vega10 hardware, storage LDS is reused below. + + ROCPRIM_UNROLL + for(unsigned int offset = 0; offset < Bins; offset += BlockSize) + { + const unsigned int offset_tid = offset + flat_tid; + if(offset_tid < Bins) + { + storage_.start[offset_tid] = tile_size; + storage_.end[offset_tid] = tile_size; + } + } + ::rocprim::syncthreads(); + + discontinuity().flag_heads(head_flags, input, flags_op, storage_.flag); + ::rocprim::syncthreads(); + + // The start of the first bin is not overwritten since the input is sorted + // and the starts are based on the second item. + // The very first item is never used as `b` in the operator + // This means that this should not need synchromization, but in practice it does. 
+ if(flat_tid == 0) + { + storage_.start[static_cast(input[0])] = 0; + } + ::rocprim::syncthreads(); + + ROCPRIM_UNROLL + for(unsigned int offset = 0; offset < Bins; offset += BlockSize) + { + const unsigned int offset_tid = offset + flat_tid; + if(offset_tid < Bins) + { + Counter count = static_cast(storage_.end[offset_tid] - storage_.start[offset_tid]); + hist[offset_tid] += count; + } + } + } + +private: + struct discontinuity_op + { + storage_type &storage; + + ROCPRIM_DEVICE ROCPRIM_INLINE + discontinuity_op(storage_type &storage) : storage(storage) + { + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + bool operator()(const T& a, const T& b, unsigned int b_index) const + { + storage_type_& storage_ = storage.get(); + if(static_cast(a) != static_cast(b)) + { + storage_.start[static_cast(b)] = b_index; + storage_.end[static_cast(a)] = b_index; + return true; + } + else + { + return false; + } + } + }; +}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_SORT_HPP_ diff --git a/3rdparty/cub/rocprim/block/detail/block_reduce_raking_reduce.hpp b/3rdparty/cub/rocprim/block/detail/block_reduce_raking_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..38f993b6cb210214255242274de9bc36692e26a5 --- /dev/null +++ b/3rdparty/cub/rocprim/block/detail/block_reduce_raking_reduce.hpp @@ -0,0 +1,308 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_REDUCE_RAKING_REDUCE_HPP_ +#define ROCPRIM_BLOCK_DETAIL_BLOCK_REDUCE_RAKING_REDUCE_HPP_ + +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" + +#include "../../warp/warp_reduce.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int BlockSizeX, + unsigned int BlockSizeY, + unsigned int BlockSizeZ, + bool CommutativeOnly = false +> +class block_reduce_raking_reduce +{ + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + // Number of items to reduce per thread + static constexpr unsigned int thread_reduction_size_ = + (BlockSize + ::rocprim::device_warp_size() - 1)/ ::rocprim::device_warp_size(); + + // Warp reduce, warp_reduce_crosslane does not require shared memory (storage), but + // logical warp size must be a power of two. + static constexpr unsigned int warp_size_ = + detail::get_min_warp_size(BlockSize, ::rocprim::device_warp_size()); + + static constexpr bool commutative_only_ = CommutativeOnly && ((BlockSize % warp_size_ == 0) && (BlockSize > warp_size_)); + static constexpr unsigned int sharing_threads_ = ::rocprim::max(1, BlockSize - warp_size_); + static constexpr unsigned int segment_length_ = sharing_threads_ / warp_size_; + + // BlockSize is multiple of hardware warp + static constexpr bool block_size_smaller_than_warp_size_ = (BlockSize < warp_size_); + using warp_reduce_prefix_type = ::rocprim::detail::warp_reduce_crosslane; + + struct storage_type_ + { + T threads[BlockSize]; + }; + +public: + using storage_type = detail::raw_storage; + + /// \brief Computes a thread block-wide reduction using addition (+) as the reduction operator. The first num_valid threads each contribute one reduction partial. The return value is only valid for thread0. 
+ /// \param input [in] Calling thread's input to be reduced + /// \param output [out] Variable containing reduction output + /// \param storage [in] Temporary Storage used for the Reduction + /// \param reduce_op [in] Binary reduction operator + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, + T& output, + storage_type& storage, + BinaryFunction reduce_op) + { + this->reduce_impl( + ::rocprim::flat_block_thread_id(), + input, output, storage, reduce_op + ); + } + + /// \brief Computes a thread block-wide reduction using addition (+) as the reduction operator. The first num_valid threads each contribute one reduction partial. The return value is only valid for thread0. + /// \param input [in] Calling thread's input to be reduced + /// \param output [out] Variable containing reduction output + /// \param reduce_op [in] Binary reduction operator + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void reduce(T input, + T& output, + BinaryFunction reduce_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->reduce(input, output, storage, reduce_op); + } + + /// \brief Computes a thread block-wide reduction using addition (+) as the reduction operator. The first num_valid threads each contribute one reduction partial. The return value is only valid for thread0. 
+ /// \param input [in] Calling thread's input array to be reduced + /// \param output [out] Variable containing reduction output + /// \param storage [in] Temporary Storage used for the Reduction + /// \param reduce_op [in] Binary reduction operator + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T (&input)[ItemsPerThread], + T& output, + storage_type& storage, + BinaryFunction reduce_op) + { + // Reduce thread items + T thread_input = input[0]; + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + thread_input = reduce_op(thread_input, input[i]); + } + + // Reduction of reduced values to get partials + const auto flat_tid = ::rocprim::flat_block_thread_id(); + this->reduce_impl( + flat_tid, + thread_input, output, // input, output + storage, + reduce_op + ); + } + + /// \brief Computes a thread block-wide reduction using addition (+) as the reduction operator. The first num_valid threads each contribute one reduction partial. The return value is only valid for thread0. + /// \param input [in] Calling thread's input array to be reduced + /// \param output [out] Variable containing reduction output + /// \param reduce_op [in] Binary reduction operator + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void reduce(T (&input)[ItemsPerThread], + T& output, + BinaryFunction reduce_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->reduce(input, output, storage, reduce_op); + } + + /// \brief Computes a thread block-wide reduction using addition (+) as the reduction operator. The first num_valid threads each contribute one reduction partial. The return value is only valid for thread0. 
+ /// \param input [in] Calling thread's input partial reductions + /// \param output [out] Variable containing reduction output + /// \param valid_items [in] Number of valid elements (may be less than BlockSize) + /// \param storage [in] Temporary Storage used for reduction + /// \param reduce_op [in] Binary reduction operator + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, + T& output, + unsigned int valid_items, + storage_type& storage, + BinaryFunction reduce_op) + { + this->reduce_impl( + ::rocprim::flat_block_thread_id(), + input, output, valid_items, storage, reduce_op + ); + } + + + /// \brief Computes a thread block-wide reduction using addition (+) as the reduction operator. The first num_valid threads each contribute one reduction partial. The return value is only valid for thread0. + /// \param input [in] Calling thread's input partial reductions + /// \param output [out] Variable containing reduction output + /// \param valid_items [in] Number of valid elements (may be less than BlockSize) + /// \param reduce_op [in] Binary reduction operator + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void reduce(T input, + T& output, + unsigned int valid_items, + BinaryFunction reduce_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->reduce(input, output, valid_items, storage, reduce_op); + } + +private: + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto reduce_impl(const unsigned int flat_tid, + T input, + T& output, + storage_type& storage, + BinaryFunction reduce_op) + -> typename std::enable_if<(!FunctionCommutativeOnly), void>::type + { + storage_type_& storage_ = storage.get(); + storage_.threads[flat_tid] = input; + ::rocprim::syncthreads(); + + if (flat_tid < warp_size_) + { + T thread_reduction = storage_.threads[flat_tid]; + for(unsigned int i = warp_size_ + flat_tid; i < BlockSize; i += warp_size_) + { + thread_reduction = reduce_op( + thread_reduction, storage_.threads[i] + ); + } + warp_reduce( + thread_reduction, 
output, BlockSize, reduce_op + ); + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto reduce_impl(const unsigned int flat_tid, + T input, + T& output, + storage_type& storage, + BinaryFunction reduce_op) + -> typename std::enable_if<(FunctionCommutativeOnly), void>::type + { + storage_type_& storage_ = storage.get(); + + if (flat_tid >= warp_size_) + storage_.threads[flat_tid - warp_size_] = input; + + ::rocprim::syncthreads(); + + if (flat_tid < warp_size_) + { + T thread_reduction = input; + T* storage_pointer = &storage_.threads[flat_tid * segment_length_]; + #pragma unroll + for( unsigned int i = 0; i < segment_length_; i++ ) + { + thread_reduction = reduce_op( + thread_reduction, storage_pointer[i] + ); + } + warp_reduce( + thread_reduction, output, BlockSize, reduce_op + ); + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto warp_reduce(T input, + T& output, + const unsigned int valid_items, + BinaryFunction reduce_op) + -> typename std::enable_if::type + { + WarpReduce().reduce( + input, output, valid_items, reduce_op + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto warp_reduce(T input, + T& output, + const unsigned int valid_items, + BinaryFunction reduce_op) + -> typename std::enable_if::type + { + (void) valid_items; + WarpReduce().reduce( + input, output, reduce_op + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce_impl(const unsigned int flat_tid, + T input, + T& output, + const unsigned int valid_items, + storage_type& storage, + BinaryFunction reduce_op) + { + storage_type_& storage_ = storage.get(); + storage_.threads[flat_tid] = input; + ::rocprim::syncthreads(); + + if (flat_tid < warp_size_) + { + T thread_reduction = storage_.threads[flat_tid]; + for(unsigned int i = warp_size_ + flat_tid; i < BlockSize; i += warp_size_) + { + if(i < valid_items) + { + thread_reduction = reduce_op(thread_reduction, storage_.threads[i]); + } + } + warp_reduce_prefix_type().reduce(thread_reduction, output, valid_items, 
reduce_op); + } + } +}; +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_REDUCE_RAKING_REDUCE_HPP_ diff --git a/3rdparty/cub/rocprim/block/detail/block_reduce_warp_reduce.hpp b/3rdparty/cub/rocprim/block/detail/block_reduce_warp_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..590b572202807a3accf01bf89662a55c7a086e96 --- /dev/null +++ b/3rdparty/cub/rocprim/block/detail/block_reduce_warp_reduce.hpp @@ -0,0 +1,271 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_REDUCE_WARP_REDUCE_HPP_ +#define ROCPRIM_BLOCK_DETAIL_BLOCK_REDUCE_WARP_REDUCE_HPP_ + +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" + +#include "../../warp/warp_reduce.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int BlockSizeX, + unsigned int BlockSizeY, + unsigned int BlockSizeZ +> +class block_reduce_warp_reduce +{ + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + // Select warp size + static constexpr unsigned int warp_size_ = + detail::get_min_warp_size(BlockSize, ::rocprim::device_warp_size()); + // Number of warps in block + static constexpr unsigned int warps_no_ = (BlockSize + warp_size_ - 1) / warp_size_; + + // Check if we have to pass number of valid items into warp reduction primitive + static constexpr bool block_size_is_warp_multiple_ = ((BlockSize % warp_size_) == 0); + static constexpr bool warps_no_is_pow_of_two_ = detail::is_power_of_two(warps_no_); + + // typedef of warp_reduce primitive that will be used to perform warp-level + // reduce operation on input values. + // warp_reduce_crosslane is an implementation of warp_reduce that does not need storage, + // but requires logical warp size to be a power of two. + using warp_reduce_input_type = ::rocprim::detail::warp_reduce_crosslane; + // typedef of warp_reduce primitive that will be used to perform reduction + // of results of warp-level reduction. 
+ using warp_reduce_output_type = ::rocprim::detail::warp_reduce_crosslane< + T, detail::next_power_of_two(warps_no_), false + >; + + struct storage_type_ + { + T warp_partials[warps_no_]; + }; + +public: + using storage_type = detail::raw_storage; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, + T& output, + storage_type& storage, + BinaryFunction reduce_op) + { + this->reduce_impl( + ::rocprim::flat_block_thread_id(), + input, output, storage, reduce_op + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void reduce(T input, + T& output, + BinaryFunction reduce_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->reduce(input, output, storage, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T (&input)[ItemsPerThread], + T& output, + storage_type& storage, + BinaryFunction reduce_op) + { + // Reduce thread items + T thread_input = input[0]; + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + thread_input = reduce_op(thread_input, input[i]); + } + + // Reduction of reduced values to get partials + const auto flat_tid = ::rocprim::flat_block_thread_id(); + this->reduce_impl( + flat_tid, + thread_input, output, // input, output + storage, + reduce_op + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void reduce(T (&input)[ItemsPerThread], + T& output, + BinaryFunction reduce_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->reduce(input, output, storage, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, + T& output, + unsigned int valid_items, + storage_type& storage, + BinaryFunction reduce_op) + { + this->reduce_impl( + ::rocprim::flat_block_thread_id(), + input, output, valid_items, storage, reduce_op + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void reduce(T input, + T& output, + unsigned int valid_items, + BinaryFunction reduce_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + 
this->reduce(input, output, valid_items, storage, reduce_op); + } + +private: + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce_impl(const unsigned int flat_tid, + T input, + T& output, + storage_type& storage, + BinaryFunction reduce_op) + { + const auto warp_id = ::rocprim::warp_id(flat_tid); + const auto lane_id = ::rocprim::lane_id(); + const unsigned int warp_offset = warp_id * warp_size_; + const unsigned int num_valid = + (warp_offset < BlockSize) ? BlockSize - warp_offset : 0; + storage_type_& storage_ = storage.get(); + + // Perform warp reduce + warp_reduce( + input, output, num_valid, reduce_op + ); + + // i-th warp will have its partial stored in storage_.warp_partials[i-1] + if(lane_id == 0) + { + storage_.warp_partials[warp_id] = output; + } + ::rocprim::syncthreads(); + + if(flat_tid < warps_no_) + { + // Use warp partial to calculate the final reduce results for every thread + auto warp_partial = storage_.warp_partials[lane_id]; + + warp_reduce( + warp_partial, output, warps_no_, reduce_op + ); + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto warp_reduce(T input, + T& output, + const unsigned int valid_items, + BinaryFunction reduce_op) + -> typename std::enable_if::type + { + WarpReduce().reduce( + input, output, valid_items, reduce_op + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto warp_reduce(T input, + T& output, + const unsigned int valid_items, + BinaryFunction reduce_op) + -> typename std::enable_if::type + { + (void) valid_items; + WarpReduce().reduce( + input, output, reduce_op + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce_impl(const unsigned int flat_tid, + T input, + T& output, + const unsigned int valid_items, + storage_type& storage, + BinaryFunction reduce_op) + { + const auto warp_id = ::rocprim::warp_id(flat_tid); + const auto lane_id = ::rocprim::lane_id(); + const unsigned int warp_offset = warp_id * warp_size_; + const unsigned int num_valid = + (warp_offset < valid_items) ? 
valid_items - warp_offset : 0; + storage_type_& storage_ = storage.get(); + + // Perform warp reduce + warp_reduce_input_type().reduce( + input, output, num_valid, reduce_op + ); + + // i-th warp will have its partial stored in storage_.warp_partials[i-1] + if(lane_id == 0) + { + storage_.warp_partials[warp_id] = output; + } + ::rocprim::syncthreads(); + + if(flat_tid < warps_no_) + { + // Use warp partial to calculate the final reduce results for every thread + auto warp_partial = storage_.warp_partials[lane_id]; + + unsigned int valid_warps_no = (valid_items + warp_size_ - 1) / warp_size_; + warp_reduce_output_type().reduce( + warp_partial, output, valid_warps_no, reduce_op + ); + } + } +}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_REDUCE_WARP_REDUCE_HPP_ diff --git a/3rdparty/cub/rocprim/block/detail/block_scan_reduce_then_scan.hpp b/3rdparty/cub/rocprim/block/detail/block_scan_reduce_then_scan.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c3599b805a4e194e9c4133fa89a99c17e06974e4 --- /dev/null +++ b/3rdparty/cub/rocprim/block/detail/block_scan_reduce_then_scan.hpp @@ -0,0 +1,631 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_REDUCE_THEN_SCAN_HPP_ +#define ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_REDUCE_THEN_SCAN_HPP_ + +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" + +#include "../../warp/warp_scan.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int BlockSizeX, + unsigned int BlockSizeY, + unsigned int BlockSizeZ +> +class block_scan_reduce_then_scan +{ + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + // Number of items to reduce per thread + static constexpr unsigned int thread_reduction_size_ = + (BlockSize + ::rocprim::device_warp_size() - 1)/ ::rocprim::device_warp_size(); + + // Warp scan, warp_scan_crosslane does not require shared memory (storage), but + // logical warp size must be a power of two. + static constexpr unsigned int warp_size_ = + detail::get_min_warp_size(BlockSize, ::rocprim::device_warp_size()); + using warp_scan_prefix_type = ::rocprim::detail::warp_scan_crosslane; + + // Minimize LDS bank conflicts + static constexpr unsigned int banks_no_ = ::rocprim::detail::get_lds_banks_no(); + static constexpr bool has_bank_conflicts_ = + ::rocprim::detail::is_power_of_two(thread_reduction_size_) && thread_reduction_size_ > 1; + static constexpr unsigned int bank_conflicts_padding = + has_bank_conflicts_ ? 
(warp_size_ * thread_reduction_size_ / banks_no_) : 0; + + struct storage_type_ + { + T threads[warp_size_ * thread_reduction_size_ + bank_conflicts_padding]; + }; + +public: + using storage_type = detail::raw_storage; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, + T& output, + storage_type& storage, + BinaryFunction scan_op) + { + const auto flat_tid = ::rocprim::flat_block_thread_id(); + this->inclusive_scan_impl(flat_tid, input, output, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void inclusive_scan(T input, + T& output, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->inclusive_scan(input, output, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, + T& output, + T& reduction, + storage_type& storage, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + this->inclusive_scan(input, output, storage, scan_op); + reduction = storage_.threads[index(BlockSize - 1)]; + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void inclusive_scan(T input, + T& output, + T& reduction, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->inclusive_scan(input, output, reduction, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, + T& output, + storage_type& storage, + PrefixCallback& prefix_callback_op, + BinaryFunction scan_op) + { + const auto flat_tid = ::rocprim::flat_block_thread_id(); + const auto warp_id = ::rocprim::warp_id(flat_tid); + storage_type_& storage_ = storage.get(); + this->inclusive_scan_impl(flat_tid, input, output, storage, scan_op); + // Include block prefix (this operation overwrites storage_.threads[0]) + T block_prefix = this->get_block_prefix( + flat_tid, warp_id, + storage_.threads[index(BlockSize - 1)], // block reduction + prefix_callback_op, storage + ); + output = scan_op(block_prefix, output); + } + + 
template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + storage_type& storage, + BinaryFunction scan_op) + { + // Reduce thread items + T thread_input = input[0]; + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + thread_input = scan_op(thread_input, input[i]); + } + + // Scan of reduced values to get prefixes + const auto flat_tid = ::rocprim::flat_block_thread_id(); + this->exclusive_scan_impl( + flat_tid, + thread_input, thread_input, // input, output + storage, + scan_op + ); + + // Include prefix (first thread does not have prefix) + output[0] = input[0]; + if(flat_tid != 0) output[0] = scan_op(thread_input, input[0]); + // Final thread-local scan + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + output[i] = scan_op(output[i-1], input[i]); + } + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->inclusive_scan(input, output, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T& reduction, + storage_type& storage, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + this->inclusive_scan(input, output, storage, scan_op); + // Save reduction result + reduction = storage_.threads[index(BlockSize - 1)]; + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T& reduction, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->inclusive_scan(input, output, reduction, storage, scan_op); + } + + template< + class PrefixCallback, + unsigned int ItemsPerThread, + class BinaryFunction + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T 
(&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + storage_type& storage, + PrefixCallback& prefix_callback_op, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + // Reduce thread items + T thread_input = input[0]; + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + thread_input = scan_op(thread_input, input[i]); + } + + // Scan of reduced values to get prefixes + const auto flat_tid = ::rocprim::flat_block_thread_id(); + this->exclusive_scan_impl( + flat_tid, + thread_input, thread_input, // input, output + storage, + scan_op + ); + + // this operation overwrites storage_.threads[0] + T block_prefix = this->get_block_prefix( + flat_tid, ::rocprim::warp_id(flat_tid), + storage_.threads[index(BlockSize - 1)], // block reduction + prefix_callback_op, storage + ); + + // Include prefix (first thread does not have prefix) + output[0] = input[0]; + if(flat_tid != 0) output[0] = scan_op(thread_input, input[0]); + // Include block prefix + output[0] = scan_op(block_prefix, output[0]); + // Final thread-local scan + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + output[i] = scan_op(output[i-1], input[i]); + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, + T& output, + T init, + storage_type& storage, + BinaryFunction scan_op) + { + const auto flat_tid = ::rocprim::flat_block_thread_id(); + this->exclusive_scan_impl(flat_tid, input, output, init, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void exclusive_scan(T input, + T& output, + T init, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->exclusive_scan(input, output, init, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, + T& output, + T init, + T& reduction, + storage_type& storage, + BinaryFunction scan_op) + { + const auto flat_tid = ::rocprim::flat_block_thread_id(); + storage_type_& storage_ 
= storage.get(); + this->exclusive_scan_impl( + flat_tid, input, output, init, storage, scan_op + ); + // Save reduction result + reduction = storage_.threads[index(BlockSize - 1)]; + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void exclusive_scan(T input, + T& output, + T init, + T& reduction, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->exclusive_scan(input, output, init, reduction, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, + T& output, + storage_type& storage, + PrefixCallback& prefix_callback_op, + BinaryFunction scan_op) + { + const auto flat_tid = ::rocprim::flat_block_thread_id(); + const auto warp_id = ::rocprim::warp_id(flat_tid); + storage_type_& storage_ = storage.get(); + this->exclusive_scan_impl( + flat_tid, input, output, storage, scan_op + ); + // Get reduction result + T reduction = storage_.threads[index(BlockSize - 1)]; + // Include block prefix (this operation overwrites storage_.threads[0]) + T block_prefix = this->get_block_prefix( + flat_tid, warp_id, reduction, + prefix_callback_op, storage + ); + output = scan_op(block_prefix, output); + if(flat_tid == 0) output = block_prefix; + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T init, + storage_type& storage, + BinaryFunction scan_op) + { + // Reduce thread items + T thread_input = input[0]; + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + thread_input = scan_op(thread_input, input[i]); + } + + // Scan of reduced values to get prefixes + const auto flat_tid = ::rocprim::flat_block_thread_id(); + this->exclusive_scan_impl( + flat_tid, + thread_input, thread_input, // input, output + init, + storage, + scan_op + ); + + // Include init value + T prev = input[0]; + T exclusive = init; + if(flat_tid != 0) + { + exclusive = thread_input; + } + output[0] = exclusive; + ROCPRIM_UNROLL + 
for(unsigned int i = 1; i < ItemsPerThread; i++) + { + exclusive = scan_op(exclusive, prev); + prev = input[i]; + output[i] = exclusive; + } + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T init, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->exclusive_scan(input, output, init, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T init, + T& reduction, + storage_type& storage, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + this->exclusive_scan(input, output, init, storage, scan_op); + // Save reduction result + reduction = storage_.threads[index(BlockSize - 1)]; + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T init, + T& reduction, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->exclusive_scan(input, output, init, reduction, storage, scan_op); + } + + template< + class PrefixCallback, + unsigned int ItemsPerThread, + class BinaryFunction + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + storage_type& storage, + PrefixCallback& prefix_callback_op, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + // Reduce thread items + T thread_input = input[0]; + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + thread_input = scan_op(thread_input, input[i]); + } + + // Scan of reduced values to get prefixes + const auto flat_tid = ::rocprim::flat_block_thread_id(); + this->exclusive_scan_impl( + flat_tid, + thread_input, thread_input, // input, output + storage, + scan_op + ); + + // this operation overwrites storage_.warp_prefixes[0] + T block_prefix = this->get_block_prefix( + flat_tid, 
::rocprim::warp_id(flat_tid), + storage_.threads[index(BlockSize - 1)], // block reduction + prefix_callback_op, storage + ); + + // Include init value and block prefix + T prev = input[0]; + T exclusive = block_prefix; + if(flat_tid != 0) + { + exclusive = scan_op(block_prefix, thread_input); + } + output[0] = exclusive; + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + exclusive = scan_op(exclusive, prev); + prev = input[i]; + output[i] = exclusive; + } + } + +private: + + // Calculates inclusive scan results and stores them in storage_.threads, + // result for each thread is stored in storage_.threads[flat_tid], and sets + // output to storage_.threads[flat_tid] + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan_impl(const unsigned int flat_tid, + T input, + T& output, + storage_type& storage, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + // Calculate inclusive scan, + // result for each thread is stored in storage_.threads[flat_tid] + this->inclusive_scan_base(flat_tid, input, storage, scan_op); + output = storage_.threads[index(flat_tid)]; + } + + // Calculates inclusive scan results and stores them in storage_.threads, + // result for each thread is stored in storage_.threads[flat_tid] + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan_base(const unsigned int flat_tid, + T input, + storage_type& storage, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + storage_.threads[index(flat_tid)] = input; + ::rocprim::syncthreads(); + if(flat_tid < warp_size_) + { + const unsigned int idx_start = index(flat_tid * thread_reduction_size_); + const unsigned int idx_end = idx_start + thread_reduction_size_; + + T thread_reduction = storage_.threads[idx_start]; + ROCPRIM_UNROLL + for(unsigned int i = idx_start + 1; i < idx_end; i++) + { + thread_reduction = scan_op( + thread_reduction, storage_.threads[i] + ); + } + + // Calculate warp prefixes + 
warp_scan_prefix_type().inclusive_scan(thread_reduction, thread_reduction, scan_op); + thread_reduction = warp_shuffle_up(thread_reduction, 1, warp_size_); + + // Include warp prefix + thread_reduction = scan_op(thread_reduction, storage_.threads[idx_start]); + if(flat_tid == 0) + { + thread_reduction = input; + } + + storage_.threads[idx_start] = thread_reduction; + ROCPRIM_UNROLL + for(unsigned int i = idx_start + 1; i < idx_end; i++) + { + thread_reduction = scan_op( + thread_reduction, storage_.threads[i] + ); + storage_.threads[i] = thread_reduction; + } + } + ::rocprim::syncthreads(); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan_impl(const unsigned int flat_tid, + T input, + T& output, + T init, + storage_type& storage, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + // Calculates inclusive scan, result for each thread is stored in storage_.threads[flat_tid] + this->inclusive_scan_base(flat_tid, input, storage, scan_op); + output = init; + if(flat_tid != 0) output = scan_op(init, storage_.threads[index(flat_tid-1)]); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan_impl(const unsigned int flat_tid, + T input, + T& output, + storage_type& storage, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + // Calculates inclusive scan, result for each thread is stored in storage_.threads[flat_tid] + this->inclusive_scan_base(flat_tid, input, storage, scan_op); + if(flat_tid > 0) + { + output = storage_.threads[index(flat_tid-1)]; + } + } + + // OVERWRITES storage_.threads[0] + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void include_block_prefix(const unsigned int flat_tid, + const unsigned int warp_id, + const T input, + T& output, + const T reduction, + PrefixCallback& prefix_callback_op, + storage_type& storage, + BinaryFunction scan_op) + { + T block_prefix = this->get_block_prefix( + flat_tid, warp_id, reduction, + prefix_callback_op, storage + ); + output = 
scan_op(block_prefix, input); + } + + // OVERWRITES storage_.threads[0] + template + ROCPRIM_DEVICE ROCPRIM_INLINE + T get_block_prefix(const unsigned int flat_tid, + const unsigned int warp_id, + const T reduction, + PrefixCallback& prefix_callback_op, + storage_type& storage) + { + storage_type_& storage_ = storage.get(); + if(warp_id == 0) + { + T block_prefix = prefix_callback_op(reduction); + if(flat_tid == 0) + { + // Reuse storage_.threads[0] which should not be + // needed at that point. + storage_.threads[0] = block_prefix; + } + } + ::rocprim::syncthreads(); + return storage_.threads[0]; + } + + // Change index to minimize LDS bank conflicts if necessary + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int index(unsigned int n) const + { + // Move every 32-bank wide "row" (32 banks * 4 bytes) by one item + return has_bank_conflicts_ ? (n + (n/banks_no_)) : n; + } +}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_REDUCE_THEN_SCAN_HPP_ diff --git a/3rdparty/cub/rocprim/block/detail/block_scan_warp_scan.hpp b/3rdparty/cub/rocprim/block/detail/block_scan_warp_scan.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d3708b26ca92cce57ff567d2563e7a6b0dd2240e --- /dev/null +++ b/3rdparty/cub/rocprim/block/detail/block_scan_warp_scan.hpp @@ -0,0 +1,750 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_WARP_SCAN_HPP_ +#define ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_WARP_SCAN_HPP_ + +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" + +#include "../../warp/warp_scan.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int BlockSizeX, + unsigned int BlockSizeY, + unsigned int BlockSizeZ +> +class block_scan_warp_scan +{ + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + // Select warp size + static constexpr unsigned int warp_size_ = + detail::get_min_warp_size(BlockSize, ::rocprim::device_warp_size()); + // Number of warps in block + static constexpr unsigned int warps_no_ = (BlockSize + warp_size_ - 1) / warp_size_; + + // typedef of warp_scan primitive that will be used to perform warp-level + // inclusive/exclusive scan operations on input values. + // warp_scan_crosslane is an implementation of warp_scan that does not need storage, + // but requires logical warp size to be a power of two. + using warp_scan_input_type = ::rocprim::detail::warp_scan_crosslane; + // typedef of warp_scan primitive that will be used to get prefix values for + // each warp (scanned carry-outs from warps before it). + using warp_scan_prefix_type = ::rocprim::detail::warp_scan_crosslane; + + struct storage_type_ + { + T warp_prefixes[warps_no_]; + // ---------- Shared memory optimisation ---------- + // Since warp_scan_input and warp_scan_prefix are typedef of warp_scan_crosslane, + // we don't need to allocate any temporary memory for them. + // If we just use warp_scan, we would need to add following union to this struct: + // union + // { + // typename warp_scan_input::storage_type wscan[warps_no_]; + // typename warp_scan_prefix::storage_type wprefix_scan; + // }; + // and use storage_.wscan[warp_id] and storage.wprefix_scan when calling + // warp_scan_input().inclusive_scan(..) 
and warp_scan_prefix().inclusive_scan(..). + }; + +public: + using storage_type = detail::raw_storage; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, + T& output, + storage_type& storage, + BinaryFunction scan_op) + { + this->inclusive_scan_impl( + ::rocprim::flat_block_thread_id(), + input, output, storage, scan_op + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void inclusive_scan(T input, + T& output, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->inclusive_scan(input, output, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, + T& output, + T& reduction, + storage_type& storage, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + this->inclusive_scan(input, output, storage, scan_op); + // Save reduction result + reduction = storage_.warp_prefixes[warps_no_ - 1]; + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void inclusive_scan(T input, + T& output, + T& reduction, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->inclusive_scan(input, output, reduction, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, + T& output, + storage_type& storage, + PrefixCallback& prefix_callback_op, + BinaryFunction scan_op) + { + const auto flat_tid = ::rocprim::flat_block_thread_id(); + const auto warp_id = ::rocprim::warp_id(flat_tid); + storage_type_& storage_ = storage.get(); + this->inclusive_scan_impl(flat_tid, input, output, storage, scan_op); + // Include block prefix (this operation overwrites storage_.warp_prefixes[warps_no_ - 1]) + T block_prefix = this->get_block_prefix( + flat_tid, warp_id, + storage_.warp_prefixes[warps_no_ - 1], // block reduction + prefix_callback_op, storage + ); + output = scan_op(block_prefix, output); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T 
(&output)[ItemsPerThread], + storage_type& storage, + BinaryFunction scan_op) + { + // Reduce thread items + T thread_input = input[0]; + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + thread_input = scan_op(thread_input, input[i]); + } + + // Scan of reduced values to get prefixes + const auto flat_tid = ::rocprim::flat_block_thread_id(); + this->exclusive_scan_impl( + flat_tid, + thread_input, thread_input, // input, output + storage, + scan_op + ); + + // Include prefix (first thread does not have prefix) + output[0] = input[0]; + if(flat_tid != 0) + { + output[0] = scan_op(thread_input, input[0]); + } + + // Final thread-local scan + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + output[i] = scan_op(output[i-1], input[i]); + } + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->inclusive_scan(input, output, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T& reduction, + storage_type& storage, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + this->inclusive_scan(input, output, storage, scan_op); + // Save reduction result + reduction = storage_.warp_prefixes[warps_no_ - 1]; + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T& reduction, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->inclusive_scan(input, output, reduction, storage, scan_op); + } + + template< + class PrefixCallback, + unsigned int ItemsPerThread, + class BinaryFunction + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + storage_type& storage, + PrefixCallback& 
prefix_callback_op, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + // Reduce thread items + T thread_input = input[0]; + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + thread_input = scan_op(thread_input, input[i]); + } + + // Scan of reduced values to get prefixes + const auto flat_tid = ::rocprim::flat_block_thread_id(); + this->exclusive_scan_impl( + flat_tid, + thread_input, thread_input, // input, output + storage, + scan_op + ); + + // this operation overwrites storage_.warp_prefixes[warps_no_ - 1] + T block_prefix = this->get_block_prefix( + flat_tid, ::rocprim::warp_id(flat_tid), + storage_.warp_prefixes[warps_no_ - 1], // block reduction + prefix_callback_op, storage + ); + + // Include prefix (first thread does not have prefix) + output[0] = input[0]; + if(flat_tid != 0) + { + output[0] = scan_op(thread_input, input[0]); + } + // Include block prefix + output[0] = scan_op(block_prefix, output[0]); + // Final thread-local scan + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + output[i] = scan_op(output[i-1], input[i]); + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, + T& output, + T init, + storage_type& storage, + BinaryFunction scan_op) + { + this->exclusive_scan_impl( + ::rocprim::flat_block_thread_id(), + input, output, init, storage, scan_op + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void exclusive_scan(T input, + T& output, + T init, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->exclusive_scan( + input, output, init, storage, scan_op + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, + T& output, + T init, + T& reduction, + storage_type& storage, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + this->exclusive_scan( + input, output, init, storage, scan_op + ); + // Save reduction result + reduction = 
storage_.warp_prefixes[warps_no_ - 1]; + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void exclusive_scan(T input, + T& output, + T init, + T& reduction, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->exclusive_scan( + input, output, init, reduction, storage, scan_op + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, + T& output, + storage_type& storage, + PrefixCallback& prefix_callback_op, + BinaryFunction scan_op) + { + const auto flat_tid = ::rocprim::flat_block_thread_id(); + const auto warp_id = ::rocprim::warp_id(flat_tid); + storage_type_& storage_ = storage.get(); + this->exclusive_scan_impl( + flat_tid, input, output, storage, scan_op + ); + // Include block prefix (this operation overwrites storage_.warp_prefixes[warps_no_ - 1]) + T block_prefix = this->get_block_prefix( + flat_tid, warp_id, + storage_.warp_prefixes[warps_no_ - 1], // block reduction + prefix_callback_op, storage + ); + output = scan_op(block_prefix, output); + if(flat_tid == 0) output = block_prefix; + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T init, + storage_type& storage, + BinaryFunction scan_op) + { + // Reduce thread items + T thread_input = input[0]; + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + thread_input = scan_op(thread_input, input[i]); + } + + // Scan of reduced values to get prefixes + const auto flat_tid = ::rocprim::flat_block_thread_id(); + this->exclusive_scan_impl( + flat_tid, + thread_input, thread_input, // input, output + init, + storage, + scan_op + ); + + // Include init value + T prev = input[0]; + T exclusive = init; + if(flat_tid != 0) + { + exclusive = thread_input; + } + output[0] = exclusive; + + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + exclusive = scan_op(exclusive, prev); + prev = input[i]; + output[i] = exclusive; + } + } + + 
template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T init, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->exclusive_scan(input, output, init, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T init, + T& reduction, + storage_type& storage, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + this->exclusive_scan(input, output, init, storage, scan_op); + // Save reduction result + reduction = storage_.warp_prefixes[warps_no_ - 1]; + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T init, + T& reduction, + BinaryFunction scan_op) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->exclusive_scan(input, output, init, reduction, storage, scan_op); + } + + template< + class PrefixCallback, + unsigned int ItemsPerThread, + class BinaryFunction + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + storage_type& storage, + PrefixCallback& prefix_callback_op, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + // Reduce thread items + T thread_input = input[0]; + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + thread_input = scan_op(thread_input, input[i]); + } + + // Scan of reduced values to get prefixes + const auto flat_tid = ::rocprim::flat_block_thread_id(); + this->exclusive_scan_impl( + flat_tid, + thread_input, thread_input, // input, output + storage, + scan_op + ); + + // this operation overwrites storage_.warp_prefixes[warps_no_ - 1] + T block_prefix = this->get_block_prefix( + flat_tid, ::rocprim::warp_id(flat_tid), + storage_.warp_prefixes[warps_no_ - 1], // block reduction + prefix_callback_op, storage + ); + + // Include 
init value and block prefix + T prev = input[0]; + T exclusive = block_prefix; + if(flat_tid != 0) + { + exclusive = scan_op(block_prefix, thread_input); + } + output[0] = exclusive; + + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ItemsPerThread; i++) + { + exclusive = scan_op(exclusive, prev); + prev = input[i]; + output[i] = exclusive; + } + } + +private: + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto inclusive_scan_impl(const unsigned int flat_tid, + T input, + T& output, + storage_type& storage, + BinaryFunction scan_op) + -> typename std::enable_if<(BlockSize_ > ::rocprim::device_warp_size())>::type + { + storage_type_& storage_ = storage.get(); + // Perform warp scan + warp_scan_input_type().inclusive_scan( + // not using shared mem, see note in storage_type + input, output, scan_op + ); + + // i-th warp will have its prefix stored in storage_.warp_prefixes[i-1] + const auto warp_id = ::rocprim::warp_id(flat_tid); + this->calculate_warp_prefixes(flat_tid, warp_id, output, storage, scan_op); + + // Use warp prefix to calculate the final scan results for every thread + if(warp_id != 0) + { + auto warp_prefix = storage_.warp_prefixes[warp_id - 1]; + output = scan_op(warp_prefix, output); + } + } + + // When BlockSize is less than warp_size we dont need the extra prefix calculations. 
+ template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto inclusive_scan_impl(unsigned int flat_tid, + T input, + T& output, + storage_type& storage, + BinaryFunction scan_op) + -> typename std::enable_if ::rocprim::device_warp_size())>::type + { + (void) storage; + (void) flat_tid; + storage_type_& storage_ = storage.get(); + // Perform warp scan + warp_scan_input_type().inclusive_scan( + // not using shared mem, see note in storage_type + input, output, scan_op + ); + + if(flat_tid == BlockSize_ - 1) + { + storage_.warp_prefixes[0] = output; + } + ::rocprim::syncthreads(); + } + + // Exclusive scan with initial value when BlockSize is bigger than warp_size + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto exclusive_scan_impl(const unsigned int flat_tid, + T input, + T& output, + T init, + storage_type& storage, + BinaryFunction scan_op) + -> typename std::enable_if<(BlockSize_ > ::rocprim::device_warp_size())>::type + { + storage_type_& storage_ = storage.get(); + // Perform warp scan on input values + warp_scan_input_type().inclusive_scan( + // not using shared mem, see note in storage_type + input, output, scan_op + ); + + // i-th warp will have its prefix stored in storage_.warp_prefixes[i-1] + const auto warp_id = ::rocprim::warp_id(flat_tid); + this->calculate_warp_prefixes(flat_tid, warp_id, output, storage, scan_op); + + // Include initial value in warp prefixes, and fix warp prefixes + // for exclusive scan (first warp prefix is init) + auto warp_prefix = init; + if(warp_id != 0) + { + warp_prefix = scan_op(init, storage_.warp_prefixes[warp_id-1]); + } + + // Use warp prefix to calculate the final scan results for every thread + output = scan_op(warp_prefix, output); // include warp prefix in scan results + output = warp_shuffle_up(output, 1, warp_size_); // shift to get exclusive results + if(::rocprim::lane_id() == 0) + { + output = warp_prefix; + } + } + + // Exclusive scan with initial value when BlockSize is less than warp_size. 
+ // When BlockSize is less than warp_size we dont need the extra prefix calculations. + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto exclusive_scan_impl(const unsigned int flat_tid, + T input, + T& output, + T init, + storage_type& storage, + BinaryFunction scan_op) + -> typename std::enable_if ::rocprim::device_warp_size())>::type + { + (void) flat_tid; + (void) storage; + (void) init; + storage_type_& storage_ = storage.get(); + // Perform warp scan on input values + warp_scan_input_type().inclusive_scan( + // not using shared mem, see note in storage_type + input, output, scan_op + ); + + if(flat_tid == BlockSize_ - 1) + { + storage_.warp_prefixes[0] = output; + } + ::rocprim::syncthreads(); + + // Use warp prefix to calculate the final scan results for every thread + output = scan_op(init, output); // include warp prefix in scan results + output = warp_shuffle_up(output, 1, warp_size_); // shift to get exclusive results + if(::rocprim::lane_id() == 0) + { + output = init; + } + } + + // Exclusive scan with unknown initial value + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto exclusive_scan_impl(const unsigned int flat_tid, + T input, + T& output, + storage_type& storage, + BinaryFunction scan_op) + -> typename std::enable_if<(BlockSize_ > ::rocprim::device_warp_size())>::type + { + storage_type_& storage_ = storage.get(); + // Perform warp scan on input values + warp_scan_input_type().inclusive_scan( + // not using shared mem, see note in storage_type + input, output, scan_op + ); + + // i-th warp will have its prefix stored in storage_.warp_prefixes[i-1] + const auto warp_id = ::rocprim::warp_id(flat_tid); + this->calculate_warp_prefixes(flat_tid, warp_id, output, storage, scan_op); + + // Use warp prefix to calculate the final scan results for every thread + T warp_prefix; + if(warp_id != 0) + { + warp_prefix = storage_.warp_prefixes[warp_id - 1]; + output = scan_op(warp_prefix, output); + } + output = warp_shuffle_up(output, 1, warp_size_); // shift to 
get exclusive results + if(::rocprim::lane_id() == 0) + { + output = warp_prefix; + } + } + + // Exclusive scan with unknown initial value, when BlockSize less than warp_size. + // When BlockSize is less than warp_size we dont need the extra prefix calculations. + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto exclusive_scan_impl(const unsigned int flat_tid, + T input, + T& output, + storage_type& storage, + BinaryFunction scan_op) + -> typename std::enable_if ::rocprim::device_warp_size())>::type + { + (void) flat_tid; + (void) storage; + storage_type_& storage_ = storage.get(); + // Perform warp scan on input values + warp_scan_input_type().inclusive_scan( + // not using shared mem, see note in storage_type + input, output, scan_op + ); + + if(flat_tid == BlockSize_ - 1) + { + storage_.warp_prefixes[0] = output; + } + ::rocprim::syncthreads(); + output = warp_shuffle_up(output, 1, warp_size_); // shift to get exclusive results + } + + // i-th warp will have its prefix stored in storage_.warp_prefixes[i-1] + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void calculate_warp_prefixes(const unsigned int flat_tid, + const unsigned int warp_id, + T inclusive_input, + storage_type& storage, + BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + // Save the warp reduction result, that is the scan result + // for last element in each warp + if(flat_tid == ::rocprim::min((warp_id+1) * warp_size_, BlockSize_) - 1) + { + storage_.warp_prefixes[warp_id] = inclusive_input; + } + ::rocprim::syncthreads(); + + // Scan the warp reduction results and store in storage_.warp_prefixes + if(flat_tid < warps_no_) + { + auto warp_prefix = storage_.warp_prefixes[flat_tid]; + warp_scan_prefix_type().inclusive_scan( + // not using shared mem, see note in storage_type + warp_prefix, warp_prefix, scan_op + ); + storage_.warp_prefixes[flat_tid] = warp_prefix; + } + ::rocprim::syncthreads(); + } + + // THIS OVERWRITES storage_.warp_prefixes[warps_no_ - 1] + template + 
ROCPRIM_DEVICE ROCPRIM_INLINE + T get_block_prefix(const unsigned int flat_tid, + const unsigned int warp_id, + const T reduction, + PrefixCallback& prefix_callback_op, + storage_type& storage) + { + storage_type_& storage_ = storage.get(); + if(warp_id == 0) + { + T block_prefix = prefix_callback_op(reduction); + if(flat_tid == 0) + { + // Reuse storage_.warp_prefixes[warps_no_ - 1] to store block prefix + storage_.warp_prefixes[warps_no_ - 1] = block_prefix; + } + } + ::rocprim::syncthreads(); + return storage_.warp_prefixes[warps_no_ - 1]; + } +}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_WARP_SCAN_HPP_ diff --git a/3rdparty/cub/rocprim/block/detail/block_sort_bitonic.hpp b/3rdparty/cub/rocprim/block/detail/block_sort_bitonic.hpp new file mode 100644 index 0000000000000000000000000000000000000000..467c7296f6a06382edbf3742321a29b463db4e6e --- /dev/null +++ b/3rdparty/cub/rocprim/block/detail/block_sort_bitonic.hpp @@ -0,0 +1,606 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_ +#define ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_ + +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" + +#include "../../warp/warp_sort.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class Key, + unsigned int BlockSizeX, + unsigned int BlockSizeY, + unsigned int BlockSizeZ, + unsigned int ItemsPerThread, + class Value +> +class block_sort_bitonic +{ + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + + template + struct storage_type_ + { + KeyType key[BlockSize * ItemsPerThread]; + ValueType value[BlockSize * ItemsPerThread]; + }; + + template + struct storage_type_ + { + KeyType key[BlockSize * ItemsPerThread]; + }; + +public: + using storage_type = detail::raw_storage>; + + static_assert(detail::is_power_of_two(ItemsPerThread), "ItemsPerThread must be a power of two!"); + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(Key& thread_key, + storage_type& storage, + BinaryFunction compare_function) + { + this->sort_impl( + ::rocprim::flat_block_thread_id(), + storage, compare_function, + thread_key + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(Key (&thread_keys)[ItemsPerThread], + storage_type& storage, + BinaryFunction compare_function) + { + this->sort_impl( + ::rocprim::flat_block_thread_id(), + storage, compare_function, + thread_keys + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void sort(Key& thread_key, + BinaryFunction compare_function) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->sort(thread_key, storage, compare_function); + } 
+ + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void sort(Key (&thread_keys)[ItemsPerThread], + BinaryFunction compare_function) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->sort(thread_keys, storage, compare_function); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(Key& thread_key, + Value& thread_value, + storage_type& storage, + BinaryFunction compare_function) + { + this->sort_impl( + ::rocprim::flat_block_thread_id(), + storage, compare_function, + thread_key, thread_value + ); + } + + template + ROCPRIM_DEVICE inline + void sort(Key (&thread_keys)[ItemsPerThread], + Value (&thread_values)[ItemsPerThread], + storage_type& storage, + BinaryFunction compare_function) + { + this->sort_impl( + ::rocprim::flat_block_thread_id(), + storage, compare_function, + thread_keys, thread_values + ); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void sort(Key& thread_key, + Value& thread_value, + BinaryFunction compare_function) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->sort(thread_key, thread_value, storage, compare_function); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void sort(Key (&thread_keys)[ItemsPerThread], + Value (&thread_values)[ItemsPerThread], + BinaryFunction compare_function) + { + ROCPRIM_SHARED_MEMORY storage_type storage; + this->sort(thread_keys, thread_values, storage, compare_function); + } + + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(Key& thread_key, + storage_type& storage, + const unsigned int size, + BinaryFunction compare_function) + { + this->sort_impl( + ::rocprim::flat_block_thread_id(), size, + storage, compare_function, + thread_key + ); + } + +private: + ROCPRIM_DEVICE ROCPRIM_INLINE + void copy_to_shared(Key& k, const unsigned int flat_tid, storage_type& storage) + { + storage_type_& storage_ = storage.get(); + storage_.key[flat_tid] = k; + ::rocprim::syncthreads(); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void copy_to_shared(Key (&k)[ItemsPerThread], const 
unsigned int flat_tid, storage_type& storage) { + storage_type_& storage_ = storage.get(); + ROCPRIM_UNROLL + for(unsigned int item = 0; item < ItemsPerThread; ++item) { + storage_.key[item * BlockSize + flat_tid] = k[item]; + } + ::rocprim::syncthreads(); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void copy_to_shared(Key& k, Value& v, const unsigned int flat_tid, storage_type& storage) + { + storage_type_& storage_ = storage.get(); + storage_.key[flat_tid] = k; + storage_.value[flat_tid] = v; + ::rocprim::syncthreads(); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void copy_to_shared(Key (&k)[ItemsPerThread], + Value (&v)[ItemsPerThread], + const unsigned int flat_tid, + storage_type& storage) + { + storage_type_& storage_ = storage.get(); + ROCPRIM_UNROLL + for(unsigned int item = 0; item < ItemsPerThread; ++item) { + storage_.key[item * BlockSize + flat_tid] = k[item]; + storage_.value[item * BlockSize + flat_tid] = v[item]; + } + ::rocprim::syncthreads(); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void swap(Key& key, + const unsigned int flat_tid, + const unsigned int next_id, + const bool dir, + storage_type& storage, + BinaryFunction compare_function) + { + storage_type_& storage_ = storage.get(); + Key next_key = storage_.key[next_id]; + bool compare = (next_id < flat_tid) ? compare_function(key, next_key) : compare_function(next_key, key); + bool swap = compare ^ dir; + if(swap) + { + key = next_key; + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void swap(Key (&key)[ItemsPerThread], + const unsigned int flat_tid, + const unsigned int next_id, + const bool dir, + storage_type& storage, + BinaryFunction compare_function) + { + storage_type_& storage_ = storage.get(); + ROCPRIM_UNROLL + for(unsigned int item = 0; item < ItemsPerThread; ++item) { + Key next_key = storage_.key[item * BlockSize + next_id]; + bool compare = (next_id < flat_tid) ? 
compare_function(key[item], next_key) : compare_function(next_key, key[item]); + bool swap = compare ^ dir; + if(swap) + { + key[item] = next_key; + } + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void swap(Key& key, + Value& value, + const unsigned int flat_tid, + const unsigned int next_id, + const bool dir, + storage_type& storage, + BinaryFunction compare_function) + { + storage_type_& storage_ = storage.get(); + Key next_key = storage_.key[next_id]; + bool b = next_id < flat_tid; + bool compare = compare_function(b ? key : next_key, b ? next_key : key); + bool swap = compare ^ dir; + if(swap) + { + key = next_key; + value = storage_.value[next_id]; + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void swap(Key (&key)[ItemsPerThread], + Value (&value)[ItemsPerThread], + const unsigned int flat_tid, + const unsigned int next_id, + const bool dir, + storage_type& storage, + BinaryFunction compare_function) + { + storage_type_& storage_ = storage.get(); + ROCPRIM_UNROLL + for(unsigned int item = 0; item < ItemsPerThread; ++item) { + Key next_key = storage_.key[item * BlockSize + next_id]; + bool b = next_id < flat_tid; + bool compare = compare_function(b ? key[item] : next_key, b ? next_key : key[item]); + bool swap = compare ^ dir; + if(swap) + { + key[item] = next_key; + value[item] = storage_.value[item * BlockSize + next_id]; + } + } + } + + template< + unsigned int Size, + class BinaryFunction, + class... KeyValue + > + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(Size <= ::rocprim::device_warp_size())>::type + sort_power_two(const unsigned int flat_tid, + storage_type& storage, + BinaryFunction compare_function, + KeyValue&... 
kv) + { + (void) flat_tid; + (void) storage; + + ::rocprim::warp_sort wsort; + wsort.sort(kv..., compare_function); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void warp_swap(Key& k, Value& v, int mask, bool dir, BinaryFunction compare_function) + { + Key k1 = warp_shuffle_xor(k, mask); + bool swap = compare_function(dir ? k : k1, dir ? k1 : k); + if (swap) + { + k = k1; + v = warp_shuffle_xor(v, mask); + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void warp_swap(Key (&k)[ItemsPerThread], + Value (&v)[ItemsPerThread], + int mask, + bool dir, + BinaryFunction compare_function) + { + ROCPRIM_UNROLL + for(unsigned int item = 0; item < ItemsPerThread; ++item) { + Key k1 = warp_shuffle_xor(k[item], mask); + bool swap = compare_function(dir ? k[item] : k1, dir ? k1 : k[item]); + if (swap) + { + k[item] = k1; + v[item] = warp_shuffle_xor(v[item], mask); + } + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void warp_swap(Key& k, int mask, bool dir, BinaryFunction compare_function) + { + Key k1 = warp_shuffle_xor(k, mask); + bool swap = compare_function(dir ? k : k1, dir ? k1 : k); + if (swap) + { + k = k1; + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void warp_swap(Key (&k)[ItemsPerThread], int mask, bool dir, BinaryFunction compare_function) + { + ROCPRIM_UNROLL + for(unsigned int item = 0; item < ItemsPerThread; ++item) { + Key k1 = warp_shuffle_xor(k[item], mask); + bool swap = compare_function(dir ? k[item] : k1, dir ? k1 : k[item]); + if (swap) + { + k[item] = k1; + } + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(Items < 2)>::type + thread_merge(bool /*dir*/, BinaryFunction /*compare_function*/, KeyValue&... 
/*kv*/) + { + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void thread_swap(Key (&k)[ItemsPerThread], + Value (&v)[ItemsPerThread], + bool dir, + unsigned int i, + unsigned int j, + BinaryFunction compare_function) + { + if(compare_function(k[i], k[j]) == dir) + { + Key k_temp = k[i]; + k[i] = k[j]; + k[j] = k_temp; + Value v_temp = v[i]; + v[i] = v[j]; + v[j] = v_temp; + } + } + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void thread_swap(Key (&k)[ItemsPerThread], + bool dir, + unsigned int i, + unsigned int j, + BinaryFunction compare_function) + { + if(compare_function(k[i], k[j]) == dir) + { + Key k_temp = k[i]; + k[i] = k[j]; + k[j] = k_temp; + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void thread_shuffle(unsigned int offset, bool dir, BinaryFunction compare_function, KeyValue&... kv) + { + ROCPRIM_UNROLL + for(unsigned base = 0; base < ItemsPerThread; base += 2 * offset) + { + ROCPRIM_UNROLL + for(unsigned i = 0; i < offset; ++i) + { + thread_swap(kv..., dir, base + i, base + i + offset, compare_function); + } + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if::type + thread_merge(bool dir, BinaryFunction compare_function, KeyValue&... kv) + { + ROCPRIM_UNROLL + for(unsigned int k = ItemsPerThread / 2; k > 0; k /= 2) + { + thread_shuffle(k, dir, compare_function, kv...); + } + } + + template< + unsigned int Size, + class BinaryFunction, + class... KeyValue + > + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(Size > ::rocprim::device_warp_size())>::type + sort_power_two(const unsigned int flat_tid, + storage_type& storage, + BinaryFunction compare_function, + KeyValue&... 
kv) + { + const auto warp_id_is_even = ((flat_tid / ::rocprim::device_warp_size()) % 2) == 0; + ::rocprim::warp_sort wsort; + auto compare_function2 = + [compare_function, warp_id_is_even](const Key& a, const Key& b) mutable -> bool + { + auto r = compare_function(a, b); + if(warp_id_is_even) + return r; + return !r; + }; + wsort.sort(kv..., compare_function2); + + ROCPRIM_UNROLL + for(unsigned int length = ::rocprim::device_warp_size(); length < Size; length *= 2) + { + const bool dir = (flat_tid & (length * 2)) != 0; + ROCPRIM_UNROLL + for(unsigned int k = length; k > ::rocprim::device_warp_size() / 2; k /= 2) + { + copy_to_shared(kv..., flat_tid, storage); + swap(kv..., flat_tid, flat_tid ^ k, dir, storage, compare_function); + ::rocprim::syncthreads(); + } + + ROCPRIM_UNROLL + for(unsigned int k = ::rocprim::device_warp_size() / 2; k > 0; k /= 2) + { + const bool length_even = ((detail::logical_lane_id<::rocprim::device_warp_size()>() / k ) % 2 ) == 0; + const bool local_dir = length_even ? dir : !dir; + warp_swap(kv..., k, local_dir, compare_function); + } + thread_merge(dir, compare_function, kv...); + } + } + + template< + unsigned int Size, + class BinaryFunction, + class... KeyValue + > + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if::type + sort_impl(const unsigned int flat_tid, + storage_type& storage, + BinaryFunction compare_function, + KeyValue&... kv) + { + static constexpr unsigned int PairSize = sizeof...(KeyValue); + static_assert( + PairSize < 3, + "KeyValue parameter pack can 1 or 2 elements (key, or key and value)" + ); + + sort_power_two(flat_tid, storage, compare_function, kv...); + } + + // In case BlockSize is not a power-of-two, the slower odd-even mergesort function is used + // instead of the bitonic sort function + template< + unsigned int Size, + class BinaryFunction, + class... 
KeyValue + > + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if::type + sort_impl(const unsigned int flat_tid, + storage_type& storage, + BinaryFunction compare_function, + KeyValue&... kv) + { + static constexpr unsigned int PairSize = sizeof...(KeyValue); + static_assert( + PairSize < 3, + "KeyValue parameter pack can 1 or 2 elements (key, or key and value)" + ); + + copy_to_shared(kv..., flat_tid, storage); + + bool is_even = (flat_tid % 2) == 0; + unsigned int odd_id = (is_even) ? ::rocprim::max(flat_tid, 1u) - 1 : ::rocprim::min(flat_tid + 1, Size - 1); + unsigned int even_id = (is_even) ? ::rocprim::min(flat_tid + 1, Size - 1) : ::rocprim::max(flat_tid, 1u) - 1; + + ROCPRIM_UNROLL + for(unsigned int length = 0; length < Size; length++) + { + unsigned int next_id = (length % 2) == 0 ? even_id : odd_id; + swap(kv..., flat_tid, next_id, 0, storage, compare_function); + ::rocprim::syncthreads(); + copy_to_shared(kv..., flat_tid, storage); + } + } + + template< + class BinaryFunction, + class... KeyValue + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_impl(const unsigned int flat_tid, + const unsigned int size, + storage_type& storage, + BinaryFunction compare_function, + KeyValue&... kv) + { + static constexpr unsigned int PairSize = sizeof...(KeyValue); + static_assert( + PairSize < 3, + "KeyValue parameter pack can 1 or 2 elements (key, or key and value)" + ); + + if(size > BlockSize) + { + return; + } + + copy_to_shared(kv..., flat_tid, storage); + + bool is_even = (flat_tid % 2 == 0); + unsigned int odd_id = (is_even) ? ::rocprim::max(flat_tid, 1u) - 1 : ::rocprim::min(flat_tid + 1, size - 1); + unsigned int even_id = (is_even) ? ::rocprim::min(flat_tid + 1, size - 1) : ::rocprim::max(flat_tid, 1u) - 1; + + for(unsigned int length = 0; length < size; length++) + { + unsigned int next_id = (length % 2 == 0) ? 
even_id : odd_id; + // Use only "valid" keys to ensure that compare_function will not use garbage keys + // for example, as indices of an array (a lookup table) + if(flat_tid < size) + { + swap(kv..., flat_tid, next_id, 0, storage, compare_function); + } + ::rocprim::syncthreads(); + copy_to_shared(kv..., flat_tid, storage); + } + } +}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_ diff --git a/3rdparty/cub/rocprim/config.hpp b/3rdparty/cub/rocprim/config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f2710adf9811eb9fc1c79ad666215ffb4d79752f --- /dev/null +++ b/3rdparty/cub/rocprim/config.hpp @@ -0,0 +1,123 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_CONFIG_HPP_ +#define ROCPRIM_CONFIG_HPP_ + +#define BEGIN_ROCPRIM_NAMESPACE \ + namespace rocprim { + +#define END_ROCPRIM_NAMESPACE \ + } /* rocprim */ + +#include + +#include +#include +#include + +#ifndef ROCPRIM_DEVICE + #define ROCPRIM_DEVICE __device__ + #define ROCPRIM_HOST __host__ + #define ROCPRIM_HOST_DEVICE __host__ __device__ + #define ROCPRIM_SHARED_MEMORY __shared__ + #ifdef WIN32 + #define ROCPRIM_KERNEL __global__ static + #else + #define ROCPRIM_KERNEL __global__ + #endif + // TODO: These paremeters should be tuned for NAVI in the close future. + #ifndef ROCPRIM_DEFAULT_MAX_BLOCK_SIZE + #define ROCPRIM_DEFAULT_MAX_BLOCK_SIZE 256 + #endif + #ifndef ROCPRIM_DEFAULT_MIN_WARPS_PER_EU + #define ROCPRIM_DEFAULT_MIN_WARPS_PER_EU 1 + #endif + // Currently HIP on Windows has a bug involving inline device functions generating + // local memory/register allocation errors during compilation. Current workaround is to + // use __attribute__((always_inline)) for the affected functions + #ifdef WIN32 + #define ROCPRIM_INLINE inline __attribute__((always_inline)) + #else + #define ROCPRIM_INLINE inline + #endif + #define ROCPRIM_FORCE_INLINE __attribute__((always_inline)) +#endif + +#ifndef ROCPRIM_DISABLE_DPP + #define ROCPRIM_DETAIL_USE_DPP true +#endif + +#ifdef ROCPRIM_DISABLE_LOOKBACK_SCAN + #define ROCPRIM_DETAIL_USE_LOOKBACK_SCAN false +#else + #define ROCPRIM_DETAIL_USE_LOOKBACK_SCAN true +#endif + +#ifndef ROCPRIM_THREAD_LOAD_USE_CACHE_MODIFIERS + #define ROCPRIM_THREAD_LOAD_USE_CACHE_MODIFIERS 1 +#endif + +#ifndef ROCPRIM_THREAD_STORE_USE_CACHE_MODIFIERS + #define ROCPRIM_THREAD_STORE_USE_CACHE_MODIFIERS 1 +#endif + + +// Defines targeted AMD architecture. 
Supported values: +// * 803 (gfx803) +// * 900 (gfx900) +// * 906 (gfx906) +// * 908 (gfx908) +// * 910 (gfx90a) +#ifndef ROCPRIM_TARGET_ARCH + #define ROCPRIM_TARGET_ARCH 0 +#endif + +#if (__gfx1010__ || __gfx1011__ || __gfx1012__ || __gfx1030__ || __gfx1031__) + #define ROCPRIM_NAVI 1 +#else + #define ROCPRIM_NAVI 0 +#endif +#define ROCPRIM_ARCH_90a 910 + +/// Supported warp sizes +#define ROCPRIM_WARP_SIZE_32 32u +#define ROCPRIM_WARP_SIZE_64 64u +#define ROCPRIM_MAX_WARP_SIZE ROCPRIM_WARP_SIZE_64 + +#if (defined(_MSC_VER) && !defined(__clang__)) || (defined(__GNUC__) && !defined(__clang__)) +#define ROCPRIM_UNROLL +#define ROCPRIM_NO_UNROLL +#else +#define ROCPRIM_UNROLL _Pragma("unroll") +#define ROCPRIM_NO_UNROLL _Pragma("nounroll") +#endif + +#ifndef ROCPRIM_GRID_SIZE_LIMIT +#define ROCPRIM_GRID_SIZE_LIMIT std::numeric_limits::max() +#endif + +#if __cpp_if_constexpr >= 201606 +#define ROCPRIM_IF_CONSTEXPR constexpr +#else +#define ROCPRIM_IF_CONSTEXPR +#endif + +#endif // ROCPRIM_CONFIG_HPP_ diff --git a/3rdparty/cub/rocprim/detail/all_true.hpp b/3rdparty/cub/rocprim/detail/all_true.hpp new file mode 100644 index 0000000000000000000000000000000000000000..29176548dc33fad4ca1515f788a506cb9d3c86d5 --- /dev/null +++ b/3rdparty/cub/rocprim/detail/all_true.hpp @@ -0,0 +1,52 @@ +// Copyright (c) 2017-2019 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DETAIL_ALL_TRUE_HPP_ +#define ROCPRIM_DETAIL_ALL_TRUE_HPP_ + +#include + +#include "../config.hpp" + +BEGIN_ROCPRIM_NAMESPACE +namespace detail +{ + + +// all_of +template +struct all_true : std::true_type +{ +}; + +template +struct all_true : all_true +{ +}; + +template +struct all_true : std::false_type +{ +}; + +} // end namespace detail +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DETAIL_ALL_TRUE_HPP_ diff --git a/3rdparty/cub/rocprim/detail/binary_op_wrappers.hpp b/3rdparty/cub/rocprim/detail/binary_op_wrappers.hpp new file mode 100644 index 0000000000000000000000000000000000000000..aff4dd17127000f8af5bfeeb137c4f02f07d69e1 --- /dev/null +++ b/3rdparty/cub/rocprim/detail/binary_op_wrappers.hpp @@ -0,0 +1,132 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DETAIL_BINARY_OP_WRAPPERS_HPP_ +#define ROCPRIM_DETAIL_BINARY_OP_WRAPPERS_HPP_ + +#include + +#include "../config.hpp" +#include "../intrinsics.hpp" +#include "../types.hpp" +#include "../functional.hpp" + +#include "../detail/various.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class BinaryFunction, + class ResultType = typename BinaryFunction::result_type, + class InputType = typename BinaryFunction::input_type +> +struct reverse_binary_op_wrapper +{ + using result_type = ResultType; + using input_type = InputType; + + ROCPRIM_HOST_DEVICE inline + reverse_binary_op_wrapper() = default; + + ROCPRIM_HOST_DEVICE inline + reverse_binary_op_wrapper(BinaryFunction binary_op) + : binary_op_(binary_op) + { + } + + ROCPRIM_HOST_DEVICE inline + ~reverse_binary_op_wrapper() = default; + + ROCPRIM_HOST_DEVICE inline + result_type operator()(const input_type& t1, const input_type& t2) + { + return binary_op_(t2, t1); + } + +private: + BinaryFunction binary_op_; +}; + +// Wrapper for performing head-flagged scan +template +struct headflag_scan_op_wrapper +{ + static_assert(std::is_convertible::value, "F must be convertible to bool"); + + using result_type = rocprim::tuple; + using input_type = result_type; + + ROCPRIM_HOST_DEVICE inline + headflag_scan_op_wrapper() = default; + + ROCPRIM_HOST_DEVICE inline + headflag_scan_op_wrapper(BinaryFunction scan_op) + : scan_op_(scan_op) + { + } + + ROCPRIM_HOST_DEVICE inline + ~headflag_scan_op_wrapper() = 
default; + + ROCPRIM_HOST_DEVICE inline + result_type operator()(const input_type& t1, const input_type& t2) + { + return rocprim::make_tuple(!rocprim::get<1>(t2) + ? scan_op_(rocprim::get<0>(t1), rocprim::get<0>(t2)) + : rocprim::get<0>(t2), + F {rocprim::get<1>(t2) || rocprim::get<1>(t1)}); + } + +private: + BinaryFunction scan_op_; +}; + + +template +struct inequality_wrapper +{ + using equality_op_type = EqualityOp; + + ROCPRIM_HOST_DEVICE inline + inequality_wrapper() = default; + + ROCPRIM_HOST_DEVICE inline + inequality_wrapper(equality_op_type equality_op) + : equality_op(equality_op) + {} + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + bool operator()(const T &a, const U &b) + { + return !equality_op(a, b); + } + + equality_op_type equality_op; +}; + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DETAIL_BINARY_OP_WRAPPERS_HPP_ diff --git a/3rdparty/cub/rocprim/detail/match_result_type.hpp b/3rdparty/cub/rocprim/detail/match_result_type.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3add1f1a219d500f45dc52e48131ab5fddda88f2 --- /dev/null +++ b/3rdparty/cub/rocprim/detail/match_result_type.hpp @@ -0,0 +1,111 @@ +// Copyright (c) 2018-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ +#define ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ + +#include + +#include "../config.hpp" + +BEGIN_ROCPRIM_NAMESPACE +namespace detail +{ + +// invoke_result is based on std::invoke_result. +// The main difference is using ROCPRIM_HOST_DEVICE, this allows to +// use invoke_result with device-only lambdas/functors in host-only functions +// on HIP-clang. + +template +struct is_reference_wrapper : std::false_type {}; +template +struct is_reference_wrapper> : std::true_type {}; + +template +struct invoke_impl { + template + ROCPRIM_HOST_DEVICE + static auto call(F&& f, Args&&... args) + -> decltype(std::forward(f)(std::forward(args)...)); +}; + +template +struct invoke_impl +{ + template::type, + class = typename std::enable_if::value>::type + > + ROCPRIM_HOST_DEVICE + static auto get(T&& t) -> T&&; + + template::type, + class = typename std::enable_if::value>::type + > + ROCPRIM_HOST_DEVICE + static auto get(T&& t) -> decltype(t.get()); + + template::type, + class = typename std::enable_if::value>::type, + class = typename std::enable_if::value>::type + > + ROCPRIM_HOST_DEVICE + static auto get(T&& t) -> decltype(*std::forward(t)); + + template::value>::type + > + ROCPRIM_HOST_DEVICE + static auto call(MT1 B::*pmf, T&& t, Args&&... 
args) + -> decltype((invoke_impl::get(std::forward(t)).*pmf)(std::forward(args)...)); + + template + ROCPRIM_HOST_DEVICE + static auto call(MT B::*pmd, T&& t) + -> decltype(invoke_impl::get(std::forward(t)).*pmd); +}; + +template::type> +ROCPRIM_HOST_DEVICE +auto INVOKE(F&& f, Args&&... args) + -> decltype(invoke_impl::call(std::forward(f), std::forward(args)...)); + +// Conforming C++14 implementation (is also a valid C++11 implementation): +template +struct invoke_result_impl { }; +template +struct invoke_result_impl(), std::declval()...))), F, Args...> +{ + using type = decltype(INVOKE(std::declval(), std::declval()...)); +}; + +template +struct invoke_result : invoke_result_impl {}; + +template +struct match_result_type +{ + using type = typename invoke_result::type; +}; + +} // end namespace detail +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ diff --git a/3rdparty/cub/rocprim/detail/radix_sort.hpp b/3rdparty/cub/rocprim/detail/radix_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a679fb02b96415f8d00e7f59f7ee28b2cf61d8a2 --- /dev/null +++ b/3rdparty/cub/rocprim/detail/radix_sort.hpp @@ -0,0 +1,255 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DETAIL_RADIX_SORT_HPP_ +#define ROCPRIM_DETAIL_RADIX_SORT_HPP_ + +#include + +#include "../config.hpp" +#include "../type_traits.hpp" + +BEGIN_ROCPRIM_NAMESPACE +namespace detail +{ + +// Encode and decode integral and floating point values for radix sort in such a way that preserves +// correct order of negative and positive keys (i.e. negative keys go before positive ones, +// which is not true for a simple reinterpetation of the key's bits). + +// Digit extractor takes into account that (+0.0 == -0.0) is true for floats, +// so both +0.0 and -0.0 are reflected into the same bit pattern for digit extraction. +// Maximum digit length is 32. 
+ +template +struct radix_key_codec_integral { }; + +template +struct radix_key_codec_integral::value>::type> +{ + using bit_key_type = BitKey; + + ROCPRIM_DEVICE ROCPRIM_INLINE + static bit_key_type encode(Key key) + { + return __builtin_bit_cast(bit_key_type, key); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + static Key decode(bit_key_type bit_key) + { + return __builtin_bit_cast(Key, bit_key); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + static unsigned int extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + return static_cast(bit_key >> start) & mask; + } +}; + +template +struct radix_key_codec_integral::value>::type> +{ + using bit_key_type = BitKey; + + static constexpr bit_key_type sign_bit = bit_key_type(1) << (sizeof(bit_key_type) * 8 - 1); + + ROCPRIM_DEVICE ROCPRIM_INLINE + static bit_key_type encode(Key key) + { + const bit_key_type bit_key = __builtin_bit_cast(bit_key_type, key); + return sign_bit ^ bit_key; + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + static Key decode(bit_key_type bit_key) + { + bit_key ^= sign_bit; + return __builtin_bit_cast(Key, bit_key); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + static unsigned int extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + return static_cast(bit_key >> start) & mask; + } +}; + +template +struct float_bit_mask; + +template<> +struct float_bit_mask +{ + static constexpr uint32_t sign_bit = 0x80000000; + static constexpr uint32_t exponent = 0x7F800000; + static constexpr uint32_t mantissa = 0x007FFFFF; + using bit_type = uint32_t; +}; + +template<> +struct float_bit_mask +{ + static constexpr uint64_t sign_bit = 0x8000000000000000; + static constexpr uint64_t exponent = 0x7FF0000000000000; + static constexpr uint64_t mantissa = 0x000FFFFFFFFFFFFF; + using bit_type = uint64_t; +}; + +template<> +struct float_bit_mask +{ + static constexpr uint16_t sign_bit = 0x8000; + static 
constexpr uint16_t exponent = 0x7F80; + static constexpr uint16_t mantissa = 0x007F; + using bit_type = uint16_t; +}; + +template<> +struct float_bit_mask +{ + static constexpr uint16_t sign_bit = 0x8000; + static constexpr uint16_t exponent = 0x7C00; + static constexpr uint16_t mantissa = 0x03FF; + using bit_type = uint16_t; +}; + +template +struct radix_key_codec_floating +{ + using bit_key_type = BitKey; + + static constexpr bit_key_type sign_bit = float_bit_mask::sign_bit; + + ROCPRIM_DEVICE ROCPRIM_INLINE + static bit_key_type encode(Key key) + { + bit_key_type bit_key = __builtin_bit_cast(bit_key_type, key); + bit_key ^= (sign_bit & bit_key) == 0 ? sign_bit : bit_key_type(-1); + return bit_key; + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + static Key decode(bit_key_type bit_key) + { + bit_key ^= (sign_bit & bit_key) == 0 ? bit_key_type(-1) : sign_bit; + return __builtin_bit_cast(Key, bit_key); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + static unsigned int extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + + // -0.0 should be treated as +0.0 for stable sort + // -0.0 is encoded as inverted sign_bit, +0.0 as sign_bit + // or vice versa for descending sort + bit_key_type key = bit_key == sign_bit ? 
bit_key_type(~sign_bit) : bit_key; + + return static_cast(key >> start) & mask; + } +}; + +template +struct radix_key_codec_base +{ + static_assert(sizeof(Key) == 0, + "Only integral and floating point types supported as radix sort keys"); +}; + +template +struct radix_key_codec_base< + Key, + typename std::enable_if<::rocprim::is_integral::value>::type +> : radix_key_codec_integral::type> { }; + +template<> +struct radix_key_codec_base +{ + using bit_key_type = unsigned char; + + ROCPRIM_DEVICE ROCPRIM_INLINE + static bit_key_type encode(bool key) + { + return static_cast(key); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + static bool decode(bit_key_type bit_key) + { + return static_cast(bit_key); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + static unsigned int extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + return static_cast(bit_key >> start) & mask; + } +}; + +template<> +struct radix_key_codec_base<::rocprim::half> : radix_key_codec_floating<::rocprim::half, unsigned short> { }; + +template<> +struct radix_key_codec_base<::rocprim::bfloat16> : radix_key_codec_floating<::rocprim::bfloat16, unsigned short> { }; + +template<> +struct radix_key_codec_base : radix_key_codec_floating { }; + +template<> +struct radix_key_codec_base : radix_key_codec_floating { }; + +template +class radix_key_codec : protected radix_key_codec_base +{ + using base_type = radix_key_codec_base; + +public: + using bit_key_type = typename base_type::bit_key_type; + + ROCPRIM_DEVICE ROCPRIM_INLINE + static bit_key_type encode(Key key) + { + bit_key_type bit_key = base_type::encode(key); + return (Descending ? ~bit_key : bit_key); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + static Key decode(bit_key_type bit_key) + { + bit_key = (Descending ? 
~bit_key : bit_key); + return base_type::decode(bit_key); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + static unsigned int extract_digit(bit_key_type bit_key, unsigned int start, unsigned int radix_bits) + { + return base_type::extract_digit(bit_key, start, radix_bits); + } +}; + +} // end namespace detail +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DETAIL_RADIX_SORT_HPP_ diff --git a/3rdparty/cub/rocprim/detail/various.hpp b/3rdparty/cub/rocprim/detail/various.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f248f6fba0decb827212b984cf63ec3a31db8f06 --- /dev/null +++ b/3rdparty/cub/rocprim/detail/various.hpp @@ -0,0 +1,318 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DETAIL_VARIOUS_HPP_ +#define ROCPRIM_DETAIL_VARIOUS_HPP_ + +#include + +#include "../config.hpp" +#include "../types.hpp" +#include "../type_traits.hpp" + +// TODO: Refactor when it gets crowded + +BEGIN_ROCPRIM_NAMESPACE +namespace detail +{ + +struct empty_storage_type +{ + +}; + +template +ROCPRIM_HOST_DEVICE inline +constexpr bool is_power_of_two(const T x) +{ + static_assert(::rocprim::is_integral::value, "T must be integer type"); + return (x > 0) && ((x & (x - 1)) == 0); +} + +template +ROCPRIM_HOST_DEVICE inline +constexpr T next_power_of_two(const T x, const T acc = 1) +{ + static_assert(::rocprim::is_unsigned::value, "T must be unsigned type"); + return acc >= x ? acc : next_power_of_two(x, 2 * acc); +} + +template < + typename T, + typename U, + std::enable_if_t<::rocprim::is_integral::value && ::rocprim::is_unsigned::value, int> = 0> +ROCPRIM_HOST_DEVICE inline constexpr auto ceiling_div(const T a, const U b) +{ + return a / b + (a % b > 0 ? 1 : 0); +} + +ROCPRIM_HOST_DEVICE inline +size_t align_size(size_t size, size_t alignment = 256) +{ + return ceiling_div(size, alignment) * alignment; +} + +// TOOD: Put the block algorithms with warp size variables at device side with macro. +// Temporary workaround +template +ROCPRIM_HOST_DEVICE inline +constexpr T warp_size_in_class(const T warp_size) +{ + return warp_size; +} + +// Select the minimal warp size for block of size block_size, it's +// useful for blocks smaller than maximal warp size. +template +ROCPRIM_HOST_DEVICE inline +constexpr T get_min_warp_size(const T block_size, const T max_warp_size) +{ + static_assert(::rocprim::is_unsigned::value, "T must be unsigned type"); + return block_size >= max_warp_size ? max_warp_size : next_power_of_two(block_size); +} + +template +struct is_warpsize_shuffleable { + static const bool value = detail::is_power_of_two(WarpSize); +}; + +// Selects an appropriate vector_type based on the input T and size N. 
+// The byte size is calculated and used to select an appropriate vector_type. +template +struct match_vector_type +{ + static constexpr unsigned int size = sizeof(T) * N; + using vector_base_type = + typename std::conditional< + sizeof(T) >= 4, + int, + typename std::conditional< + sizeof(T) >= 2, + short, + char + >::type + >::type; + + using vector_4 = typename make_vector_type::type; + using vector_2 = typename make_vector_type::type; + using vector_1 = typename make_vector_type::type; + + using type = + typename std::conditional< + size % sizeof(vector_4) == 0, + vector_4, + typename std::conditional< + size % sizeof(vector_2) == 0, + vector_2, + vector_1 + >::type + >::type; +}; + +// Checks if Items is odd and ensures that size of T is smaller than vector_type. +template +struct is_vectorizable : std::integral_constant::type))> {}; + +// Returns the number of LDS (local data share) banks. +ROCPRIM_HOST_DEVICE +constexpr unsigned int get_lds_banks_no() +{ + // Currently all devices supported by ROCm have 32 banks (4 bytes each) + return 32; +} + +// Finds biggest fundamental type for type T that sizeof(T) is +// a multiple of that type's size. 
+template +struct match_fundamental_type +{ + using type = + typename std::conditional< + sizeof(T)%8 == 0, + unsigned long long, + typename std::conditional< + sizeof(T)%4 == 0, + unsigned int, + typename std::conditional< + sizeof(T)%2 == 0, + unsigned short, + unsigned char + >::type + >::type + >::type; +}; + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto store_volatile(T * output, T value) + -> typename std::enable_if::value>::type +{ + // TODO: check GCC + // error: binding reference of type ‘const half_float::half&’ to ‘volatile half_float::half’ discards qualifiers +#if !(defined(__HIP_CPU_RT__ ) && defined(__GNUC__)) + *const_cast(output) = value; +#else + *output = value; +#endif +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto store_volatile(T * output, T value) + -> typename std::enable_if::value>::type +{ + using fundamental_type = typename match_fundamental_type::type; + constexpr unsigned int n = sizeof(T) / sizeof(fundamental_type); + + auto input_ptr = reinterpret_cast(&value); + auto output_ptr = reinterpret_cast(output); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < n; i++) + { + output_ptr[i] = input_ptr[i]; + } +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto load_volatile(T * input) + -> typename std::enable_if::value, T>::type +{ + // TODO: check GCC + // error: binding reference of type ‘const half_float::half&’ to ‘volatile half_float::half’ discards qualifiers +#if !(defined(__HIP_CPU_RT__ ) && defined(__GNUC__)) + T retval = *const_cast(input); + return retval; +#else + return *input; +#endif +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto load_volatile(T * input) + -> typename std::enable_if::value, T>::type +{ + using fundamental_type = typename match_fundamental_type::type; + constexpr unsigned int n = sizeof(T) / sizeof(fundamental_type); + + T retval; + auto output_ptr = reinterpret_cast(&retval); + auto input_ptr = reinterpret_cast(input); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < n; i++) + { + output_ptr[i] = 
input_ptr[i]; + } + return retval; +} + +// A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions +template +struct raw_storage +{ + // Biggest memory-access word that T is a whole multiple of and is not larger than the alignment of T + typedef typename detail::match_fundamental_type::type device_word; + + // Backing storage + device_word storage[sizeof(T) / sizeof(device_word)]; + + // Alias + ROCPRIM_HOST_DEVICE T& get() + { + return reinterpret_cast(*this); + } +}; + +// Checks if two iterators have the same type and value +template +inline +bool are_iterators_equal(Iterator1, Iterator2) +{ + return false; +} + +template +inline +bool are_iterators_equal(Iterator iter1, Iterator iter2) +{ + return iter1 == iter2; +} + +template +using void_t = void; + +template +struct type_identity { + using type = T; +}; + +template +struct extract_type_impl : type_identity { }; + +template +struct extract_type_impl > : extract_type_impl { }; + +template +using extract_type = typename extract_type_impl::type; + +template +struct select_type_case +{ + static constexpr bool value = Value; + using type = T; +}; + +template +struct select_type_impl + : std::conditional< + Case::value, + type_identity>, + select_type_impl + >::type { }; + +template +struct select_type_impl> : type_identity> { }; + +template +struct select_type_impl> +{ + static_assert( + sizeof(T) == 0, + "Cannot select any case. " + "The last case must have true condition or be a fallback type." 
+ ); +}; + +template +struct select_type_impl : type_identity> { }; + +template +using select_type = typename select_type_impl::type; + +template +using bool_constant = std::integral_constant; + +} // end namespace detail +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DETAIL_VARIOUS_HPP_ diff --git a/3rdparty/cub/rocprim/device/config_types.hpp b/3rdparty/cub/rocprim/device/config_types.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a44c614e63d53bde66fce2ce8dca1cfd65083497 --- /dev/null +++ b/3rdparty/cub/rocprim/device/config_types.hpp @@ -0,0 +1,126 @@ +// Copyright (c) 2018-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_CONFIG_TYPES_HPP_ +#define ROCPRIM_DEVICE_CONFIG_TYPES_HPP_ + +#include + +#include "../config.hpp" +#include "../intrinsics/thread.hpp" +#include "../detail/various.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Special type used to show that the given device-level operation +/// will be executed with optimal configuration dependent on types of the function's parameters +/// and the target device architecture specified by ROCPRIM_TARGET_ARCH. +struct default_config { }; + +/// \brief Configuration of particular kernels launched by device-level operation +/// +/// \tparam BlockSize - number of threads in a block. +/// \tparam ItemsPerThread - number of items processed by each thread. +template +struct kernel_config +{ + /// \brief Number of threads in a block. + static constexpr unsigned int block_size = BlockSize; + /// \brief Number of items processed by each thread. + static constexpr unsigned int items_per_thread = ItemsPerThread; + /// \brief Number of items processed by a single kernel launch. + static constexpr unsigned int size_limit = SizeLimit; +}; + +namespace detail +{ + +template< + unsigned int MaxBlockSize, + unsigned int SharedMemoryPerThread, + // Most kernels require block sizes not smaller than warp + unsigned int MinBlockSize, + // Can fit in shared memory? 
+ // Although GPUs have 64KiB, 32KiB is used here as a "soft" limit, + // because some additional memory may be required in kernels + bool = (MaxBlockSize * SharedMemoryPerThread <= (1u << 15)) +> +struct limit_block_size +{ + // No, then try to decrease block size + static constexpr unsigned int value = + limit_block_size< + detail::next_power_of_two(MaxBlockSize) / 2, + SharedMemoryPerThread, + MinBlockSize + >::value; +}; + +template< + unsigned int MaxBlockSize, + unsigned int SharedMemoryPerThread, + unsigned int MinBlockSize +> +struct limit_block_size +{ + static_assert(MaxBlockSize >= MinBlockSize, "Data is too large, it cannot fit in shared memory"); + + static constexpr unsigned int value = MaxBlockSize; +}; + +template +struct select_arch_case +{ + static constexpr unsigned int arch = Arch; + using type = T; +}; + +template +struct select_arch + : std::conditional< + Case::arch == TargetArch, + extract_type, + select_arch + >::type { }; + +template +struct select_arch : extract_type { }; + +template +using default_or_custom_config = + typename std::conditional< + std::is_same::value, + Default, + Config + >::type; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_CONFIG_TYPES_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_adjacent_difference.hpp b/3rdparty/cub/rocprim/device/detail/device_adjacent_difference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..383ee401df21d5f104d24c692398e95239baf8bd --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_adjacent_difference.hpp @@ -0,0 +1,264 @@ +// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_ADJACENT_DIFFERENCE_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_ADJACENT_DIFFERENCE_HPP_ + +#include "../../block/block_adjacent_difference.hpp" +#include "../../block/block_load.hpp" +#include "../../block/block_store.hpp" + +#include "../../detail/various.hpp" + +#include "../../config.hpp" + +#include + +#include + +#include + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +struct adjacent_diff_helper +{ + using adjacent_diff_type = ::rocprim::block_adjacent_difference; + using storage_type = typename adjacent_diff_type::storage_type; + + template + ROCPRIM_DEVICE void dispatch(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + const BinaryFunction op, + const InputIt previous_values, + const unsigned int block_id, + const std::size_t starting_block, + const std::size_t num_blocks, + const std::size_t size, + storage_type& storage, + bool_constant /*in_place*/, + std::false_type /*right*/) + { + static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + + // Not the first block, i.e. has a predecessor + if(starting_block + block_id != 0) + { + // `previous_values` needs to be accessed with a stride of `items_per_block` if the + // operation is out-of-place + const unsigned int block_offset = InPlace ? block_id : block_id * items_per_block; + const InputIt block_previous_values = previous_values + block_offset; + + const T tile_predecessor = block_previous_values[-1]; + // Not the last (i.e. full block) + if(starting_block + block_id != num_blocks - 1) + { + adjacent_diff_type {}.subtract_left(input, output, op, tile_predecessor, storage); + } + else + { + const unsigned int valid_items + = static_cast(size - (num_blocks - 1) * items_per_block); + adjacent_diff_type {}.subtract_left_partial( + input, output, op, tile_predecessor, valid_items, storage); + } + } + else + { + // Not the last (i.e. 
full block) + if(starting_block + block_id != num_blocks - 1) + { + adjacent_diff_type {}.subtract_left(input, output, op, storage); + } + else + { + const unsigned int valid_items + = static_cast(size - (num_blocks - 1) * items_per_block); + adjacent_diff_type {}.subtract_left_partial( + input, output, op, valid_items, storage); + } + } + } + + template + ROCPRIM_DEVICE void dispatch(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + const BinaryFunction op, + const InputIt previous_values, + const unsigned int block_id, + const std::size_t starting_block, + const std::size_t num_blocks, + const std::size_t size, + storage_type& storage, + bool_constant /*in_place*/, + std::true_type /*right*/) + { + static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + + // Not the last (i.e. full) block and has a successor + if(starting_block + block_id != num_blocks - 1) + { + // `previous_values` needs to be accessed with a stride of `items_per_block` if the + // operation is out-of-place + // When in-place, the first block does not save its value (since it won't be used) + // so the block values are shifted right one. This means that next block's first value + // is in the position `block_id` + const unsigned int block_offset = InPlace ? 
block_id : (block_id + 1) * items_per_block; + + const InputIt next_block_values = previous_values + block_offset; + const T tile_successor = *next_block_values; + + adjacent_diff_type {}.subtract_right(input, output, op, tile_successor, storage); + } + else + { + const unsigned int valid_items + = static_cast(size - (num_blocks - 1) * items_per_block); + adjacent_diff_type {}.subtract_right_partial(input, output, op, valid_items, storage); + } + } +}; + +template +ROCPRIM_DEVICE ROCPRIM_INLINE auto select_previous_values_iterator(T* previous_values, + InputIterator /*input*/, + std::true_type /*in_place*/) +{ + return previous_values; +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE auto select_previous_values_iterator(T* /*previous_values*/, + InputIterator input, + std::false_type /*in_place*/) +{ + return input; +} + +template +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void adjacent_difference_kernel_impl( + const InputIt input, + const OutputIt output, + const std::size_t size, + const BinaryFunction op, + const typename std::iterator_traits::value_type* previous_values, + const std::size_t starting_block) +{ + using input_type = typename std::iterator_traits::value_type; + using output_type = typename std::iterator_traits::value_type; + + static constexpr unsigned int block_size = Config::block_size; + static constexpr unsigned int items_per_thread = Config::items_per_thread; + static constexpr unsigned int items_per_block = block_size * items_per_thread; + + using block_load_type + = ::rocprim::block_load; + using block_store_type + = ::rocprim::block_store; + + using adjacent_helper = adjacent_diff_helper; + + ROCPRIM_SHARED_MEMORY struct + { + typename block_load_type::storage_type load; + typename adjacent_helper::storage_type adjacent_diff; + typename block_store_type::storage_type store; + } storage; + + const unsigned int block_id = blockIdx.x; + const unsigned int block_offset = block_id * items_per_block; + + const std::size_t num_blocks = ceiling_div(size, 
items_per_block); + + input_type thread_input[items_per_thread]; + if(starting_block + block_id < num_blocks - 1) + { + block_load_type {}.load(input + block_offset, thread_input, storage.load); + } + else + { + const unsigned int valid_items + = static_cast(size - (num_blocks - 1) * items_per_block); + block_load_type {}.load(input + block_offset, thread_input, valid_items, storage.load); + } + ::rocprim::syncthreads(); + + // Type tags for tag dispatch. + static constexpr auto in_place = bool_constant {}; + static constexpr auto right = bool_constant {}; + + // When doing the operation in-place the last/first items of each block have been copied out + // in advance and written to the contiguos locations, since accessing them would be a data race + // with the writing of their new values. In this case `select_previous_values_iterator` returns + // a pointer to the copied values, and it should be addressed by block_id. + // Otherwise (when the transform is out-of-place) it just returns the input iterator, and the + // first/last values of the blocks can be accessed with a stride of `items_per_block` + const auto previous_values_it + = select_previous_values_iterator(previous_values, input, in_place); + + output_type thread_output[items_per_thread]; + // Do tag dispatch on `right` to select either `subtract_right` or `subtract_left`. + // Note that the function is overloaded on its last parameter. 
+ adjacent_helper {}.dispatch(thread_input, + thread_output, + op, + previous_values_it, + block_id, + starting_block, + num_blocks, + size, + storage.adjacent_diff, + in_place, + right); + ::rocprim::syncthreads(); + + if(starting_block + block_id < num_blocks - 1) + { + block_store_type {}.store(output + block_offset, thread_output, storage.store); + } + else + { + const unsigned int valid_items + = static_cast(size - (num_blocks - 1) * items_per_block); + block_store_type {}.store(output + block_offset, thread_output, valid_items, storage.store); + } +} + +} // namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_ADJACENT_DIFFERENCE_HPP_ \ No newline at end of file diff --git a/3rdparty/cub/rocprim/device/detail/device_binary_search.hpp b/3rdparty/cub/rocprim/device/detail/device_binary_search.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6d8c040ef3fa091917c2a4fc71f1dedf88cf47da --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_binary_search.hpp @@ -0,0 +1,120 @@ +// Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_BINARY_SEARCH_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_BINARY_SEARCH_HPP_ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +Size get_binary_search_middle(Size left, Size right) +{ + const Size d = right - left; + return left + d / 2 + d / 64; +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +Size lower_bound_n(RandomAccessIterator first, + Size size, + const T& value, + BinaryPredicate compare_op) +{ + Size left = 0; + Size right = size; + while(left < right) + { + const Size mid = get_binary_search_middle(left, right); + if(compare_op(first[mid], value)) + { + left = mid + 1; + } + else + { + right = mid; + } + } + return left; +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +Size upper_bound_n(RandomAccessIterator first, + Size size, + const T& value, + BinaryPredicate compare_op) +{ + Size left = 0; + Size right = size; + while(left < right) + { + const Size mid = get_binary_search_middle(left, right); + if(compare_op(value, first[mid])) + { + right = mid; + } + else + { + left = mid + 1; + } + } + return left; +} + +struct lower_bound_search_op +{ + template + ROCPRIM_DEVICE ROCPRIM_INLINE + Size operator()(HaystackIterator haystack, Size size, const T& value, CompareOp compare_op) const + { + return lower_bound_n(haystack, size, value, compare_op); + } +}; + +struct upper_bound_search_op +{ + template + ROCPRIM_DEVICE ROCPRIM_INLINE + Size operator()(HaystackIterator haystack, Size size, const T& value, CompareOp compare_op) const + { + return upper_bound_n(haystack, size, value, compare_op); + } +}; + +struct binary_search_op +{ + template + ROCPRIM_DEVICE ROCPRIM_INLINE + bool operator()(HaystackIterator 
haystack, Size size, const T& value, CompareOp compare_op) const + { + const Size n = lower_bound_n(haystack, size, value, compare_op); + return n != size && !compare_op(value, haystack[n]); + } +}; + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_BINARY_SEARCH_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_config_helper.hpp b/3rdparty/cub/rocprim/device/detail/device_config_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c25c217625bcbe7c1597fe4ec2358d8cd9e7398b --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_config_helper.hpp @@ -0,0 +1,64 @@ +// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_CONFIG_HELPER_HPP_ +#define ROCPRIM_DEVICE_DETAIL_CONFIG_HELPER_HPP_ + +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../block/block_reduce.hpp" + +#include "../config_types.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of device-level reduce primitives. +/// +/// \tparam BlockSize - number of threads in a block. +/// \tparam ItemsPerThread - number of items processed by each thread. +/// \tparam BlockReduceMethod - algorithm for block reduce. +/// \tparam SizeLimit - limit on the number of items reduced by a single launch +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + ::rocprim::block_reduce_algorithm BlockReduceMethod, + unsigned int SizeLimit = ROCPRIM_GRID_SIZE_LIMIT +> +struct reduce_config +{ + /// \brief Number of threads in a block. + static constexpr unsigned int block_size = BlockSize; + /// \brief Number of items processed by each thread. + static constexpr unsigned int items_per_thread = ItemsPerThread; + /// \brief Algorithm for block reduce. + static constexpr block_reduce_algorithm block_reduce_method = BlockReduceMethod; + /// \brief Limit on the number of items reduced by a single launch + static constexpr unsigned int size_limit = SizeLimit; +}; + +END_ROCPRIM_NAMESPACE + +#endif //ROCPRIM_DEVICE_DETAIL_CONFIG_HELPER_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_histogram.hpp b/3rdparty/cub/rocprim/device/detail/device_histogram.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e7615c593a18aed08d30c5d351d4ddf855451f5a --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_histogram.hpp @@ -0,0 +1,553 @@ +// Copyright (c) 2017-2020 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_HISTOGRAM_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_HISTOGRAM_HPP_ + +#include +#include +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" + +#include "../../block/block_load.hpp" + +#include "uint_fast_div.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// Special wrapper for passing fixed-length arrays (i.e. 
T values[Size]) into kernels +template +class fixed_array +{ +private: + T values[Size]; + +public: + + ROCPRIM_HOST_DEVICE inline + fixed_array(const T values[Size]) + { + for(unsigned int i = 0; i < Size; i++) + { + this->values[i] = values[i]; + } + } + + ROCPRIM_HOST_DEVICE inline + T& operator[](unsigned int index) + { + return values[index]; + } + + ROCPRIM_HOST_DEVICE inline + const T& operator[](unsigned int index) const + { + return values[index]; + } +}; + +template +struct sample_to_bin_even +{ + unsigned int bins; + Level lower_level; + Level upper_level; + Level scale; + + ROCPRIM_HOST_DEVICE inline + sample_to_bin_even() = default; + + ROCPRIM_HOST_DEVICE inline + sample_to_bin_even(unsigned int bins, Level lower_level, Level upper_level) + : bins(bins), + lower_level(lower_level), + upper_level(upper_level), + scale((upper_level - lower_level) / bins) + {} + + template + ROCPRIM_HOST_DEVICE inline + bool operator()(Sample sample, unsigned int& bin) const + { + const Level s = static_cast(sample); + if(s >= lower_level && s < upper_level) + { + bin = static_cast((s - lower_level) / scale); + return true; + } + return false; + } +}; + +// This specialization uses fast division (uint_fast_div) for integers smaller than 64 bit +template +struct sample_to_bin_even::value && (sizeof(Level) <= 4)>::type> +{ + unsigned int bins; + Level lower_level; + Level upper_level; + uint_fast_div scale; + + ROCPRIM_HOST_DEVICE inline + sample_to_bin_even() = default; + + ROCPRIM_HOST_DEVICE inline + sample_to_bin_even(unsigned int bins, Level lower_level, Level upper_level) + : bins(bins), + lower_level(lower_level), + upper_level(upper_level), + scale((upper_level - lower_level) / bins) + {} + + template + ROCPRIM_HOST_DEVICE inline + bool operator()(Sample sample, unsigned int& bin) const + { + const Level s = static_cast(sample); + if(s >= lower_level && s < upper_level) + { + bin = static_cast(s - lower_level) / scale; + return true; + } + return false; + } +}; + 
+// This specialization uses multiplication by inv divisor for floats +template +struct sample_to_bin_even::value>::type> +{ + unsigned int bins; + Level lower_level; + Level upper_level; + Level inv_scale; + + ROCPRIM_HOST_DEVICE inline + sample_to_bin_even() = default; + + ROCPRIM_HOST_DEVICE inline + sample_to_bin_even(unsigned int bins, Level lower_level, Level upper_level) + : bins(bins), + lower_level(lower_level), + upper_level(upper_level), + inv_scale(bins / (upper_level - lower_level)) + {} + + template + ROCPRIM_HOST_DEVICE inline + bool operator()(Sample sample, unsigned int& bin) const + { + const Level s = static_cast(sample); + if(s >= lower_level && s < upper_level) + { + bin = static_cast((s - lower_level) * inv_scale); + return true; + } + return false; + } +}; + +// Returns index of the first element in values that is greater than value, or count if no such element is found. +template +ROCPRIM_HOST_DEVICE inline +unsigned int upper_bound(const T * values, unsigned int count, T value) +{ + unsigned int current = 0; + while(count > 0) + { + const unsigned int step = count / 2; + const unsigned int next = current + step; + if(value < values[next]) + { + count = step; + } + else + { + current = next + 1; + count -= step + 1; + } + } + return current; +} + +template +struct sample_to_bin_range +{ + unsigned int bins; + const Level * level_values; + + ROCPRIM_HOST_DEVICE inline + sample_to_bin_range() = default; + + ROCPRIM_HOST_DEVICE inline + sample_to_bin_range(unsigned int bins, const Level * level_values) + : bins(bins), level_values(level_values) + {} + + template + ROCPRIM_HOST_DEVICE inline + bool operator()(Sample sample, unsigned int& bin) const + { + const Level s = static_cast(sample); + bin = upper_bound(level_values, bins + 1, s) - 1; + return bin < bins; + } +}; + +template +struct sample_vector +{ + T values[Size]; +}; + +// Checks if it is possible to load 2 or 4 sample_vector as one 32-bit value +template< + unsigned int 
ItemsPerThread, + unsigned int Channels, + class Sample +> +struct is_sample_vectorizable + : std::integral_constant< + bool, + ((sizeof(Sample) * Channels == 1) || (sizeof(Sample) * Channels == 2)) && + (sizeof(Sample) * Channels * ItemsPerThread % sizeof(int) == 0) && + (sizeof(Sample) * Channels * ItemsPerThread / sizeof(int) > 0) + > { }; + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int Channels, + class Sample +> +ROCPRIM_DEVICE ROCPRIM_INLINE +typename std::enable_if::value>::type +load_samples(unsigned int flat_id, + Sample * samples, + sample_vector (&values)[ItemsPerThread]) +{ + using packed_samples_type = int[sizeof(Sample) * Channels * ItemsPerThread / sizeof(int)]; + + if(reinterpret_cast(samples) % sizeof(int) == 0) + { + // the pointer is aligned by 4 bytes + block_load_direct_striped( + flat_id, + reinterpret_cast(samples), + reinterpret_cast(values) + ); + } + else + { + block_load_direct_striped( + flat_id, + reinterpret_cast *>(samples), + values + ); + } +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int Channels, + class Sample +> +ROCPRIM_DEVICE ROCPRIM_INLINE +typename std::enable_if::value>::type +load_samples(unsigned int flat_id, + Sample * samples, + sample_vector (&values)[ItemsPerThread]) +{ + block_load_direct_striped( + flat_id, + reinterpret_cast *>(samples), + values + ); +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int Channels, + class Sample, + class SampleIterator +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void load_samples(unsigned int flat_id, + SampleIterator samples, + sample_vector (&values)[ItemsPerThread]) +{ + Sample tmp[Channels * ItemsPerThread]; + block_load_direct_blocked( + flat_id, + samples, + tmp + ); + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + for(unsigned int channel = 0; channel < Channels; channel++) + { + values[i].values[channel] = tmp[i * Channels + channel]; + } + } +} + +template< + 
unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int Channels, + class Sample, + class SampleIterator +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void load_samples(unsigned int flat_id, + SampleIterator samples, + sample_vector (&values)[ItemsPerThread], + unsigned int valid_count) +{ + Sample tmp[Channels * ItemsPerThread]; + block_load_direct_blocked( + flat_id, + samples, + tmp, + valid_count * Channels + ); + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + for(unsigned int channel = 0; channel < Channels; channel++) + { + values[i].values[channel] = tmp[i * Channels + channel]; + } + } +} + +template< + unsigned int BlockSize, + unsigned int ActiveChannels, + class Counter +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void init_histogram(fixed_array histogram, + fixed_array bins) +{ + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int block_id = ::rocprim::detail::block_id<0>(); + + const unsigned int index = block_id * BlockSize + flat_id; + for(unsigned int channel = 0; channel < ActiveChannels; channel++) + { + if(index < bins[channel]) + { + histogram[channel][index] = 0; + } + } +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int Channels, + unsigned int ActiveChannels, + class SampleIterator, + class Counter, + class SampleToBinOp +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void histogram_shared(SampleIterator samples, + unsigned int columns, + unsigned int rows, + unsigned int row_stride, + unsigned int rows_per_block, + fixed_array histogram, + fixed_array sample_to_bin_op, + fixed_array bins, + unsigned int * block_histogram_start) +{ + using sample_type = typename std::iterator_traits::value_type; + using sample_vector_type = sample_vector; + + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int block_id0 = ::rocprim::detail::block_id<0>(); + const unsigned int block_id1 
= ::rocprim::detail::block_id<1>(); + const unsigned int grid_size0 = ::rocprim::detail::grid_size<0>(); + + unsigned int * block_histogram[ActiveChannels]; + for(unsigned int channel = 0; channel < ActiveChannels; channel++) + { + block_histogram[channel] = block_histogram_start; + block_histogram_start += bins[channel]; + for(unsigned int bin = flat_id; bin < bins[channel]; bin += BlockSize) + { + block_histogram[channel][bin] = 0; + } + } + ::rocprim::syncthreads(); + + const unsigned int start_row = block_id1 * rows_per_block; + const unsigned int end_row = ::rocprim::min(rows, start_row + rows_per_block); + for(unsigned int row = start_row; row < end_row; row++) + { + SampleIterator row_samples = samples + row * row_stride; + + unsigned int block_offset = block_id0 * items_per_block; + while(block_offset < columns) + { + sample_vector_type values[ItemsPerThread]; + + if(block_offset + items_per_block <= columns) + { + load_samples(flat_id, row_samples + Channels * block_offset, values); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + for(unsigned int channel = 0; channel < ActiveChannels; channel++) + { + unsigned int bin; + if(sample_to_bin_op[channel](values[i].values[channel], bin)) + { + ::rocprim::detail::atomic_add(&block_histogram[channel][bin], 1); + } + } + } + } + else + { + const unsigned int valid_count = columns - block_offset; + load_samples(flat_id, row_samples + Channels * block_offset, values, valid_count); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + if(flat_id * ItemsPerThread + i < valid_count) + { + for(unsigned int channel = 0; channel < ActiveChannels; channel++) + { + unsigned int bin; + if(sample_to_bin_op[channel](values[i].values[channel], bin)) + { + ::rocprim::detail::atomic_add(&block_histogram[channel][bin], 1); + } + } + } + } + } + + block_offset += grid_size0 * items_per_block; + } + } + ::rocprim::syncthreads(); + + for(unsigned int channel = 0; channel < ActiveChannels; channel++) + { + for(unsigned 
int bin = flat_id; bin < bins[channel]; bin += BlockSize) + { + if(block_histogram[channel][bin] > 0) + { + ::rocprim::detail::atomic_add(&histogram[channel][bin], block_histogram[channel][bin]); + } + } + } +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int Channels, + unsigned int ActiveChannels, + class SampleIterator, + class Counter, + class SampleToBinOp +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void histogram_global(SampleIterator samples, + unsigned int columns, + unsigned int row_stride, + fixed_array histogram, + fixed_array sample_to_bin_op, + fixed_array bins_bits) +{ + using sample_type = typename std::iterator_traits::value_type; + using sample_vector_type = sample_vector; + + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int block_id0 = ::rocprim::detail::block_id<0>(); + const unsigned int block_id1 = ::rocprim::detail::block_id<1>(); + const unsigned int block_offset = block_id0 * items_per_block; + + samples += block_id1 * row_stride + Channels * block_offset; + + sample_vector_type values[ItemsPerThread]; + unsigned int valid_count; + if(block_offset + items_per_block <= columns) + { + valid_count = items_per_block; + load_samples(flat_id, samples, values); + } + else + { + valid_count = columns - block_offset; + load_samples(flat_id, samples, values, valid_count); + } + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + for(unsigned int channel = 0; channel < ActiveChannels; channel++) + { + unsigned int bin; + if(sample_to_bin_op[channel](values[i].values[channel], bin)) + { + const unsigned int pos = flat_id * ItemsPerThread + i; + lane_mask_type same_bin_lanes_mask = ::rocprim::ballot(pos < valid_count); + for(unsigned int b = 0; b < bins_bits[channel]; b++) + { + const unsigned int bit_set = bin & (1u << b); + const lane_mask_type bit_set_mask = ::rocprim::ballot(bit_set); + same_bin_lanes_mask 
&= (bit_set ? bit_set_mask : ~bit_set_mask); + } + const unsigned int same_bin_count = ::rocprim::bit_count(same_bin_lanes_mask); + const unsigned int prev_same_bin_count = ::rocprim::masked_bit_count(same_bin_lanes_mask); + if(prev_same_bin_count == 0) + { + // Write the number of lanes having this bin, + // if the current lane is the first (and maybe only) lane with this bin. + ::rocprim::detail::atomic_add(&histogram[channel][bin], same_bin_count); + } + } + } + } +} + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_HISTOGRAM_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_merge.hpp b/3rdparty/cub/rocprim/device/detail/device_merge.hpp new file mode 100644 index 0000000000000000000000000000000000000000..db7df6af5f508e8b3c35f2debad865570f43a546 --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_merge.hpp @@ -0,0 +1,447 @@ +// Copyright (c) 2017-2020 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_HPP_ + +#include +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" +#include "../../types.hpp" + +#include "../../block/block_store.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +struct range_t +{ + unsigned int begin1; + unsigned int end1; + unsigned int begin2; + unsigned int end2; + + ROCPRIM_DEVICE ROCPRIM_INLINE + constexpr unsigned int count1() const + { + return end1 - begin1; + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + constexpr unsigned int count2() const + { + return end2 - begin2; + } +}; + +ROCPRIM_DEVICE ROCPRIM_INLINE +range_t compute_range(const unsigned int id, + const unsigned int size1, + const unsigned int size2, + const unsigned int spacing, + const unsigned int p1, + const unsigned int p2) +{ + unsigned int diag1 = id * spacing; + unsigned int diag2 = min(size1 + size2, diag1 + spacing); + + return range_t{p1, p2, diag1 - p1, diag2 - p2}; +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE OffsetT merge_path(KeysInputIterator1 keys_input1, + KeysInputIterator2 keys_input2, + const OffsetT input1_size, + const OffsetT input2_size, + const OffsetT diag, + BinaryFunction compare_function) +{ + using key_type_1 = typename std::iterator_traits::value_type; + using key_type_2 = typename std::iterator_traits::value_type; + + OffsetT begin = diag < input2_size ? 
0u : diag - input2_size; + OffsetT end = min(diag, input1_size); + + while(begin < end) + { + OffsetT a = (begin + end) / 2; + OffsetT b = diag - 1 - a; + key_type_1 input_a = keys_input1[a]; + key_type_2 input_b = keys_input2[b]; + if(!compare_function(input_b, input_a)) + { + begin = a + 1; + } + else + { + end = a; + } + } + + return begin; +} + +template< + class IndexIterator, + class KeysInputIterator1, + class KeysInputIterator2, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void partition_kernel_impl(IndexIterator indices, + KeysInputIterator1 keys_input1, + KeysInputIterator2 keys_input2, + const size_t input1_size, + const size_t input2_size, + const unsigned int spacing, + BinaryFunction compare_function) +{ + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + const unsigned int flat_block_size = ::rocprim::detail::block_size<0>(); + + unsigned int id = flat_block_id * flat_block_size + flat_id; + + unsigned int partition_id = id * spacing; + size_t diag = min(static_cast(partition_id), input1_size + input2_size); + + unsigned int begin = + merge_path( + keys_input1, + keys_input2, + input1_size, + input2_size, + diag, + compare_function + ); + + indices[id] = begin; +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class KeysInputIterator1, + class KeysInputIterator2, + class KeyType +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void load(unsigned int flat_id, + KeysInputIterator1 keys_input1, + KeysInputIterator2 keys_input2, + KeyType * keys_shared, + const size_t input1_size, + const size_t input2_size) +{ + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) + { + unsigned int index = BlockSize * i + flat_id; + if(index < input1_size) + { + keys_shared[index] = keys_input1[index]; + } + else if(index < input1_size + input2_size) + { + keys_shared[index] = keys_input2[index - input1_size]; + } + } + + 
::rocprim::syncthreads(); +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +void serial_merge(KeyType * keys_shared, + KeyType (&inputs)[ItemsPerThread], + unsigned int (&index)[ItemsPerThread], + range_t range, + BinaryFunction compare_function) +{ + KeyType a = keys_shared[range.begin1]; + KeyType b = keys_shared[range.begin2]; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) + { + bool compare = (range.begin2 >= range.end2) || + ((range.begin1 < range.end1) && !compare_function(b, a)); + unsigned int x = compare ? range.begin1 : range.begin2; + + inputs[i] = compare ? a : b; + index[i] = x; + + KeyType c = keys_shared[++x]; + if(compare) + { + a = c; + range.begin1 = x; + } + else + { + b = c; + range.begin2 = x; + } + } + ::rocprim::syncthreads(); +} + +template< + unsigned int BlockSize, + class KeysInputIterator1, + class KeysInputIterator2, + class KeyType, + unsigned int ItemsPerThread, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void merge_keys(unsigned int flat_id, + KeysInputIterator1 keys_input1, + KeysInputIterator2 keys_input2, + KeyType (&key_inputs)[ItemsPerThread], + unsigned int (&index)[ItemsPerThread], + KeyType * keys_shared, + range_t range, + BinaryFunction compare_function) +{ + load( + flat_id, keys_input1 + range.begin1, keys_input2 + range.begin2, + keys_shared, range.count1(), range.count2() + ); + + range_t range_local = + range_t { + 0, range.count1(), range.count1(), + (range.count1() + range.count2()) + }; + + unsigned int diag = ItemsPerThread * flat_id; + unsigned int partition = + merge_path( + keys_shared + range_local.begin1, + keys_shared + range_local.begin2, + range_local.count1(), + range_local.count2(), + diag, + compare_function + ); + + range_t range_partition = + range_t { + range_local.begin1 + partition, + range_local.end1, + range_local.begin2 + diag - partition, + range_local.end2 + }; + + serial_merge( + keys_shared, key_inputs, index, range_partition, + compare_function + ); +} + 
+template< + bool WithValues, + unsigned int BlockSize, + class ValuesInputIterator1, + class ValuesInputIterator2, + class ValuesOutputIterator, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +typename std::enable_if::type +merge_values(unsigned int flat_id, + ValuesInputIterator1 values_input1, + ValuesInputIterator2 values_input2, + ValuesOutputIterator values_output, + unsigned int (&index)[ItemsPerThread], + const size_t input1_size, + const size_t input2_size) +{ + using value_type = typename std::iterator_traits::value_type; + + unsigned int count = input1_size + input2_size; + + value_type values[ItemsPerThread]; + + if(count >= ItemsPerThread * BlockSize) + { + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) + { + values[i] = (index[i] < input1_size) ? values_input1[index[i]] : + values_input2[index[i] - input1_size]; + } + } + else + { + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) + { + if(flat_id * ItemsPerThread + i < count) + { + values[i] = (index[i] < input1_size) ? 
values_input1[index[i]] : + values_input2[index[i] - input1_size]; + } + } + } + + ::rocprim::syncthreads(); + + block_store_direct_blocked( + flat_id, + values_output, + values, + count + ); +} + +template< + bool WithValues, + unsigned int BlockSize, + class ValuesInputIterator1, + class ValuesInputIterator2, + class ValuesOutputIterator, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +typename std::enable_if::type +merge_values(unsigned int flat_id, + ValuesInputIterator1 values_input1, + ValuesInputIterator2 values_input2, + ValuesOutputIterator values_output, + unsigned int (&index)[ItemsPerThread], + const size_t input1_size, + const size_t input2_size) +{ + (void) flat_id; + (void) values_input1; + (void) values_input2; + (void) values_output; + (void) index; + (void) input1_size; + (void) input2_size; +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class IndexIterator, + class KeysInputIterator1, + class KeysInputIterator2, + class KeysOutputIterator, + class ValuesInputIterator1, + class ValuesInputIterator2, + class ValuesOutputIterator, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void merge_kernel_impl(IndexIterator indices, + KeysInputIterator1 keys_input1, + KeysInputIterator2 keys_input2, + KeysOutputIterator keys_output, + ValuesInputIterator1 values_input1, + ValuesInputIterator2 values_input2, + ValuesOutputIterator values_output, + const size_t input1_size, + const size_t input2_size, + BinaryFunction compare_function) +{ + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + using keys_store_type = ::rocprim::block_store< + key_type, BlockSize, ItemsPerThread, + ::rocprim::block_store_method::block_store_transpose + >; + constexpr bool with_values = !std::is_same::value; + + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + constexpr unsigned int input_block_size = BlockSize * ItemsPerThread 
+ 1; + + ROCPRIM_SHARED_MEMORY union + { + typename detail::raw_storage keys_shared; + typename keys_store_type::storage_type keys_store; + } storage; + + key_type input[ItemsPerThread]; + unsigned int index[ItemsPerThread]; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + const unsigned int block_offset = flat_block_id * items_per_block; + const unsigned int count = input1_size + input2_size; + const unsigned int valid_in_last_block = count - block_offset; + const bool is_incomplete_block = valid_in_last_block < items_per_block; + + const unsigned int p1 = indices[flat_block_id]; + const unsigned int p2 = indices[flat_block_id + 1]; + + range_t range = + compute_range( + flat_block_id, input1_size, input2_size, items_per_block, + p1, p2 + ); + + merge_keys( + flat_id, keys_input1, keys_input2, input, index, + storage.keys_shared.get(), + range, compare_function + ); + + ::rocprim::syncthreads(); + + if(is_incomplete_block) // # elements in last block may not equal items_per_block for the last block + { + keys_store_type().store( + keys_output + block_offset, + input, + valid_in_last_block, + storage.keys_store + ); + } + else + { + keys_store_type().store( + keys_output + block_offset, + input, + storage.keys_store + ); + } + + merge_values( + flat_id, values_input1 + range.begin1, values_input2 + range.begin2, + values_output + block_offset, index, + range.count1(), range.count2() + ); +} + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_merge_sort.hpp b/3rdparty/cub/rocprim/device/detail/device_merge_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..91dc4c0ce94149f2023ae75ade3bd8d4f28007df --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_merge_sort.hpp @@ -0,0 +1,529 @@ +// Copyright (c) 2017-2022 Advanced Micro 
Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_ + +#include +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" +#include "../../types.hpp" + +#include "../../block/block_load.hpp" +#include "../../block/block_sort.hpp" +#include "../../block/block_store.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class Key +> +struct block_load_keys_impl { + using block_load_type = ::rocprim::block_load; + + using storage_type = typename block_load_type::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(const OffsetT block_offset, + const unsigned int valid_in_last_block, + const bool is_incomplete_block, + KeysInputIterator keys_input, + Key (&keys)[ItemsPerThread], + storage_type& storage) + { + if(is_incomplete_block) + { + block_load_type().load( + keys_input + block_offset, + keys, + valid_in_last_block, + storage + ); + } + else + { + block_load_type().load( + keys_input + block_offset, + keys, + storage + ); + } + + } +}; + +template +struct block_load_values_impl +{ + using storage_type = empty_storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(const unsigned int flat_id, + const unsigned int (&ranks)[ItemsPerThread], + const OffsetT block_offset, + const unsigned int valid_in_last_block, + const bool is_incomplete_block, + ValuesInputIterator values_input, + Value (&values)[ItemsPerThread], + storage_type& storage) + { + (void) flat_id; + (void) ranks; + (void) block_offset; + (void) valid_in_last_block; + (void) is_incomplete_block; + (void) values_input; + (void) values; + (void) storage; + } +}; + +template +struct block_load_values_impl +{ + using block_exchange = ::rocprim::block_exchange; + + using storage_type = typename block_exchange::storage_type; + + template + ROCPRIM_DEVICE 
ROCPRIM_INLINE + void load(const unsigned int flat_id, + const unsigned int (&ranks)[ItemsPerThread], + const OffsetT block_offset, + const unsigned int valid_in_last_block, + const bool is_incomplete_block, + ValuesInputIterator values_input, + Value (&values)[ItemsPerThread], + storage_type& storage) + { + if(is_incomplete_block) + { + block_load_direct_striped( + flat_id, + values_input + block_offset, + values, + valid_in_last_block + ); + } + else + { + block_load_direct_striped( + flat_id, + values_input + block_offset, + values + ); + } + + // Synchronize before reusing shared memory + ::rocprim::syncthreads(); + block_exchange().gather_from_striped(values, values, ranks, storage); + } +}; + +template< + bool WithValues, + unsigned int BlockSize, + unsigned int ItemsPerThread, + class Key, + class Value +> +struct block_store_impl { + using block_store_type + = block_store; + + using storage_type = typename block_store_type::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void store(const OffsetT block_offset, + const unsigned int valid_in_last_block, + const bool is_incomplete_block, + KeysOutputIterator keys_output, + ValuesOutputIterator values_output, + Key (&keys)[ItemsPerThread], + Value (&values)[ItemsPerThread], + storage_type& storage) + { + (void) values_output; + (void) values; + + // Synchronize before reusing shared memory + ::rocprim::syncthreads(); + + if(is_incomplete_block) + { + block_store_type().store( + keys_output + block_offset, + keys, + valid_in_last_block, + storage + ); + } + else + { + block_store_type().store( + keys_output + block_offset, + keys, + storage + ); + } + } +}; + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class Key, + class Value +> +struct block_store_impl { + using block_store_key_type = block_store; + using block_store_value_type = block_store; + + union storage_type { + typename block_store_key_type::storage_type keys; + typename block_store_value_type::storage_type values; 
+ }; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void store(const OffsetT block_offset, + const unsigned int valid_in_last_block, + const bool is_incomplete_block, + KeysOutputIterator keys_output, + ValuesOutputIterator values_output, + Key (&keys)[ItemsPerThread], + Value (&values)[ItemsPerThread], + storage_type& storage) + { + // Synchronize before reusing shared memory + ::rocprim::syncthreads(); + + if(is_incomplete_block) + { + block_store_key_type().store( + keys_output + block_offset, + keys, + valid_in_last_block, + storage.keys + ); + + ::rocprim::syncthreads(); + + block_store_value_type().store( + values_output + block_offset, + values, + valid_in_last_block, + storage.values + ); + } + else + { + block_store_key_type().store( + keys_output + block_offset, + keys, + storage.keys + ); + + ::rocprim::syncthreads(); + + block_store_value_type().store( + values_output + block_offset, + values, + storage.values + ); + } + } +}; + +template +struct block_sort_impl +{ + using stable_key_type = rocprim::tuple; + using block_sort_type = ::rocprim::block_sort; + + using storage_type = typename block_sort_type::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(stable_key_type (&keys)[ItemsPerThread], + storage_type& storage, + const unsigned int valid_in_last_block, + const bool is_incomplete_block, + BinaryFunction compare_function) + { + if(is_incomplete_block) + { + // Special comparison that sorts out of range values after any "valid" values + auto oor_compare + = [compare_function, valid_in_last_block]( + const stable_key_type& lhs, const stable_key_type& rhs) mutable -> bool { + const bool left_oor = rocprim::get<1>(lhs) >= valid_in_last_block; + const bool right_oor = rocprim::get<1>(rhs) >= valid_in_last_block; + return (left_oor || right_oor) ? 
!left_oor : compare_function(lhs, rhs); + }; + block_sort_type().sort(keys, // keys_input + storage, + oor_compare); + } + else + { + block_sort_type() + .sort( + keys, // keys_input + storage, + compare_function + ); + } + } +}; + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class OffsetT, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void block_sort_kernel_impl(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const OffsetT input_size, + BinaryFunction compare_function) +{ + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + constexpr bool with_values = !std::is_same::value; + + const unsigned int flat_id = block_thread_id<0>(); + const unsigned int flat_block_id = block_id<0>(); + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + + const OffsetT block_offset = flat_block_id * items_per_block; + const unsigned int valid_in_last_block = input_size - block_offset; + const bool is_incomplete_block = flat_block_id == (input_size / items_per_block); + + key_type keys[ItemsPerThread]; + value_type values[ItemsPerThread]; + + using block_load_keys_impl = block_load_keys_impl; + using block_sort_impl = block_sort_impl; + using block_load_values_impl = block_load_values_impl; + using block_store_impl = block_store_impl; + + ROCPRIM_SHARED_MEMORY union { + typename block_load_keys_impl::storage_type load_keys; + typename block_sort_impl::storage_type sort; + typename block_load_values_impl::storage_type load_values; + typename block_store_impl::storage_type store; + } storage; + + block_load_keys_impl().load( + block_offset, + valid_in_last_block, + is_incomplete_block, + keys_input, + keys, + storage.load_keys + ); + + using 
stable_key_type = typename block_sort_impl::stable_key_type; + + // Special comparison that preserves relative order of equal keys + auto stable_compare_function = [compare_function](const stable_key_type& a, const stable_key_type& b) mutable -> bool + { + const bool ab = compare_function(rocprim::get<0>(a), rocprim::get<0>(b)); + const bool ba = compare_function(rocprim::get<0>(b), rocprim::get<0>(a)); + return ab || (!ba && (rocprim::get<1>(a) < rocprim::get<1>(b))); + }; + + stable_key_type stable_keys[ItemsPerThread]; + ROCPRIM_UNROLL + for(unsigned int item = 0; item < ItemsPerThread; ++item) { + stable_keys[item] = rocprim::make_tuple(keys[item], ItemsPerThread * flat_id + item); + } + + // Synchronize before reusing shared memory + ::rocprim::syncthreads(); + + block_sort_impl().sort( + stable_keys, + storage.sort, + valid_in_last_block, + is_incomplete_block, + stable_compare_function + ); + + unsigned int ranks[ItemsPerThread]; + + ROCPRIM_UNROLL + for(unsigned int item = 0; item < ItemsPerThread; ++item) { + keys[item] = rocprim::get<0>(stable_keys[item]); + ranks[item] = rocprim::get<1>(stable_keys[item]); + } + + // Load the values with the already sorted indices + block_load_values_impl().load( + flat_id, + ranks, + block_offset, + valid_in_last_block, + is_incomplete_block, + values_input, + values, + storage.load_values + ); + + block_store_impl().store( + block_offset, + valid_in_last_block, + is_incomplete_block, + keys_output, + values_output, + keys, + values, + storage.store + ); +} + +template< + unsigned int BlockSize, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class OffsetT, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void block_merge_kernel_impl(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const OffsetT input_size, + const unsigned int block_size, + 
BinaryFunction compare_function) +{ + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + constexpr bool with_values = !std::is_same::value; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + unsigned int id = (flat_block_id * BlockSize) + flat_id; + + if (id >= input_size) + { + return; + } + + key_type key; + value_type value; + + key = keys_input[id]; + if(with_values) + { + value = values_input[id]; + } + + const unsigned int block_id = id / block_size; + const bool block_id_is_odd = block_id & 1; + const unsigned int next_block_id = block_id_is_odd ? block_id - 1 : + block_id + 1; + const unsigned int block_start = min(block_id * block_size, (unsigned int) input_size); + const unsigned int next_block_start = min(next_block_id * block_size, (unsigned int) input_size); + const unsigned int next_block_end = min((next_block_id + 1) * block_size, (unsigned int) input_size); + + if(next_block_start == input_size) + { + keys_output[id] = key; + if(with_values) + { + values_output[id] = value; + } + return; + } + + unsigned int left_id = next_block_start; + unsigned int right_id = next_block_end; + + while(left_id < right_id) + { + unsigned int mid_id = (left_id + right_id) / 2; + key_type mid_key = keys_input[mid_id]; + bool smaller = compare_function(mid_key, key); + left_id = smaller ? mid_id + 1 : left_id; + right_id = smaller ? right_id : mid_id; + } + + right_id = next_block_end; + if(block_id_is_odd && left_id != right_id) + { + key_type upper_key = keys_input[left_id]; + while(!compare_function(upper_key, key) && + !compare_function(key, upper_key) && + left_id < right_id) + { + unsigned int mid_id = (left_id + right_id) / 2; + key_type mid_key = keys_input[mid_id]; + bool equal = !compare_function(mid_key, key) && + !compare_function(key, mid_key); + left_id = equal ? 
mid_id + 1 : left_id + 1; + right_id = equal ? right_id : mid_id; + upper_key = keys_input[left_id]; + } + } + + unsigned int offset = 0; + offset += id - block_start; + offset += left_id - next_block_start; + offset += min(block_start, next_block_start); + keys_output[offset] = key; + if(with_values) + { + values_output[offset] = value; + } +} + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_merge_sort_mergepath.hpp b/3rdparty/cub/rocprim/device/detail/device_merge_sort_mergepath.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cdd8278e89af466403f74063b2524482d60e6313 --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_merge_sort_mergepath.hpp @@ -0,0 +1,439 @@ +/****************************************************************************** +* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. +* Modifications Copyright (c) 2022, Advanced Micro Devices, Inc. All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* * Redistributions of source code must retain the above copyright +* notice, this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright +* notice, this list of conditions and the following disclaimer in the +* documentation and/or other materials provided with the distribution. +* * Neither the name of the NVIDIA CORPORATION nor the +* names of its contributors may be used to endorse or promote products +* derived from this software without specific prior written permission. 
+* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +******************************************************************************/ + +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_MERGEPATH_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_MERGEPATH_HPP_ + +#include + +#include "../../detail/various.hpp" + +#include "device_merge_sort.hpp" +#include "device_merge.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + // Load items from input1 and input2 from global memory + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void gmem_to_reg(KeyT (&output)[ItemsPerThread], + InputIterator input1, + InputIterator input2, + unsigned int count1, + unsigned int count2, + bool IsLastTile) + { + if(IsLastTile) + { + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; ++item) + { + unsigned int idx = rocprim::flat_block_size() * item + threadIdx.x; + if (idx < count1 + count2) + { + output[item] = (idx < count1) ? input1[idx] : input2[idx - count1]; + } + } + + } + else + { + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; ++item) + { + unsigned int idx = rocprim::flat_block_size() * item + threadIdx.x; + output[item] = (idx < count1) ? 
input1[idx] : input2[idx - count1]; + } + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reg_to_shared(OutputIterator output, + KeyT (&input)[ItemsPerThread]) + { + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; ++item) + { + unsigned int idx = BlockSize * item + threadIdx.x; + output[idx] = input[item]; + } + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE auto + block_merge_process_tile(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const OffsetT input_size, + const OffsetT sorted_block_size, + BinaryFunction compare_function, + const OffsetT* merge_partitions) + -> std::enable_if_t< + (!std::is_trivially_copyable< + typename std::iterator_traits::value_type>::value + || rocprim::is_floating_point< + typename std::iterator_traits::value_type>::value + || std::is_integral< + typename std::iterator_traits::value_type>::value), + void> + { + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + constexpr bool with_values = !std::is_same::value; + constexpr unsigned int items_per_tile = BlockSize * ItemsPerThread; + + using block_store = block_store_impl; + + using keys_storage_ = key_type[items_per_tile + 1]; + using values_storage_ = value_type[items_per_tile + 1]; + + ROCPRIM_SHARED_MEMORY union { + typename block_store::storage_type store; + detail::raw_storage keys; + detail::raw_storage values; + } storage; + + auto& keys_shared = storage.keys.get(); + auto& values_shared = storage.values.get(); + + const unsigned short flat_id = block_thread_id<0>(); + const unsigned int flat_block_id = block_id<0>(); + const bool IsIncompleteTile = flat_block_id == (input_size/items_per_tile); + + const OffsetT partition_beg = merge_partitions[flat_block_id]; + const OffsetT partition_end = merge_partitions[flat_block_id + 1]; + + const unsigned int merged_tiles_number = 
sorted_block_size / items_per_tile; + const unsigned int target_merged_tiles_number = merged_tiles_number * 2; + const unsigned int mask = target_merged_tiles_number - 1; + const unsigned int tilegroup_start_id = ~mask & flat_block_id; + const OffsetT tilegroup_start = items_per_tile * tilegroup_start_id; // Tile-group starts here + + const OffsetT diag = items_per_tile * flat_block_id - tilegroup_start; + + const OffsetT keys1_beg = partition_beg; + OffsetT keys1_end = partition_end; + const OffsetT keys2_beg = rocprim::min(input_size, 2 * tilegroup_start + sorted_block_size + diag - partition_beg); + OffsetT keys2_end = rocprim::min(input_size, 2 * tilegroup_start + sorted_block_size + diag + items_per_tile - partition_end); + + if (mask == (mask & flat_block_id)) // If last tile in the tile-group + { + keys1_end = rocprim::min(input_size, tilegroup_start + sorted_block_size); + keys2_end = rocprim::min(input_size, tilegroup_start + sorted_block_size * 2); + } + + // Number of keys per tile + const unsigned int num_keys1 = static_cast(keys1_end - keys1_beg); + const unsigned int num_keys2 = static_cast(keys2_end - keys2_beg); + // Load keys1 & keys2 + key_type keys[ItemsPerThread]; + gmem_to_reg(keys, + keys_input + keys1_beg, + keys_input + keys2_beg, + num_keys1, + num_keys2, + IsIncompleteTile); + // Load keys into shared memory + reg_to_shared(keys_shared, keys); + + value_type values[ItemsPerThread]; + if ROCPRIM_IF_CONSTEXPR(with_values){ + gmem_to_reg(values, + values_input + keys1_beg, + values_input + keys2_beg, + num_keys1, + num_keys2, + IsIncompleteTile); + } + rocprim::syncthreads(); + + const unsigned int diag0_local = rocprim::min(num_keys1 + num_keys2, ItemsPerThread * flat_id); + + const unsigned int keys1_beg_local = merge_path(keys_shared, + &keys_shared[num_keys1], + num_keys1, + num_keys2, + diag0_local, + compare_function); + const unsigned int keys1_end_local = num_keys1; + const unsigned int keys2_beg_local = diag0_local - keys1_beg_local; 
+ const unsigned int keys2_end_local = num_keys2; + range_t range_local = {keys1_beg_local, + keys1_end_local, + keys2_beg_local + keys1_end_local, + keys2_end_local + keys1_end_local}; + + unsigned int indices[ItemsPerThread]; + + serial_merge(keys_shared, + keys, + indices, + range_local, + compare_function); + + if ROCPRIM_IF_CONSTEXPR(with_values){ + reg_to_shared(values_shared, values); + + rocprim::syncthreads(); + + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; ++item) + { + values[item] = values_shared[indices[item]]; + } + + rocprim::syncthreads(); + } + + const OffsetT offset = flat_block_id * items_per_tile; + block_store().store(offset, + input_size - offset, + IsIncompleteTile, + keys_output, + values_output, + keys, + values, + storage.store); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE auto + block_merge_process_tile(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const OffsetT input_size, + const OffsetT sorted_block_size, + BinaryFunction compare_function, + const OffsetT* merge_partitions) + -> std::enable_if_t< + (std::is_trivially_copyable< + typename std::iterator_traits::value_type>::value + && !rocprim::is_floating_point< + typename std::iterator_traits::value_type>::value + && !std::is_integral< + typename std::iterator_traits::value_type>::value), + void> + { + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + constexpr bool with_values = !std::is_same::value; + constexpr unsigned int items_per_tile = BlockSize * ItemsPerThread; + + using block_store = block_store_impl; + + using keys_storage_ = key_type[items_per_tile + 1]; + using values_storage_ = value_type[items_per_tile + 1]; + + ROCPRIM_SHARED_MEMORY union { + typename block_store::storage_type store; + detail::raw_storage keys; + detail::raw_storage values; + } storage; + + auto& 
keys_shared = storage.keys.get(); + auto& values_shared = storage.values.get(); + + const unsigned short flat_id = block_thread_id<0>(); + const unsigned int flat_block_id = block_id<0>(); + const bool IsIncompleteTile = flat_block_id == (input_size / items_per_tile); + + const OffsetT partition_beg = merge_partitions[flat_block_id]; + const OffsetT partition_end = merge_partitions[flat_block_id + 1]; + + const unsigned int merged_tiles_number = sorted_block_size / items_per_tile; + const unsigned int target_merged_tiles_number = merged_tiles_number * 2; + const unsigned int mask = target_merged_tiles_number - 1; + const unsigned int tilegroup_start_id = ~mask & flat_block_id; + const OffsetT tilegroup_start = items_per_tile * tilegroup_start_id; // Tile-group starts here + + const OffsetT diag = items_per_tile * flat_block_id - tilegroup_start; + + const OffsetT keys1_beg = partition_beg; + OffsetT keys1_end = partition_end; + const OffsetT keys2_beg = rocprim::min(input_size, 2 * tilegroup_start + sorted_block_size + diag - partition_beg); + OffsetT keys2_end = rocprim::min(input_size, 2 * tilegroup_start + sorted_block_size + diag + items_per_tile - partition_end); + + if (mask == (mask & flat_block_id)) // If last tile in the tile-group + { + keys1_end = rocprim::min(input_size, tilegroup_start + sorted_block_size); + keys2_end = rocprim::min(input_size, tilegroup_start + sorted_block_size * 2); + } + + // Number of keys per tile + const unsigned int num_keys1 = static_cast(keys1_end - keys1_beg); + const unsigned int num_keys2 = static_cast(keys2_end - keys2_beg); + // Load keys1 & keys2 + key_type keys[ItemsPerThread]; + gmem_to_reg(keys, + keys_input + keys1_beg, + keys_input + keys2_beg, + num_keys1, + num_keys2, + IsIncompleteTile); + // Load keys into shared memory + reg_to_shared(keys_shared, keys); + + rocprim::syncthreads(); + + const unsigned int diag0_local = rocprim::min(num_keys1 + num_keys2, ItemsPerThread * flat_id); + + const unsigned int 
keys1_beg_local = merge_path(keys_shared, + &keys_shared[num_keys1], + num_keys1, + num_keys2, + diag0_local, + compare_function); + const unsigned int keys1_end_local = num_keys1; + const unsigned int keys2_beg_local = diag0_local - keys1_beg_local; + const unsigned int keys2_end_local = num_keys2; + range_t range_local = {keys1_beg_local, + keys1_end_local, + keys2_beg_local + keys1_end_local, + keys2_end_local + keys1_end_local}; + + unsigned int indices[ItemsPerThread]; + + serial_merge(keys_shared, + keys, + indices, + range_local, + compare_function); + + if ROCPRIM_IF_CONSTEXPR(with_values) + { + const ValuesInputIterator input1 = values_input + keys1_beg; + const ValuesInputIterator input2 = values_input + keys2_beg; + if(IsIncompleteTile) + { + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; ++item) + { + unsigned int idx = BlockSize * item + threadIdx.x; + if(idx < num_keys1) + { + values_shared[idx] = input1[idx]; + } + else if(idx - num_keys1 < num_keys2) + { + values_shared[idx] = input2[idx - num_keys1]; + } + } + } + else + { + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; ++item) + { + unsigned int idx = BlockSize * item + threadIdx.x; + if(idx < num_keys1) + { + values_shared[idx] = input1[idx]; + } + else + { + values_shared[idx] = input2[idx - num_keys1]; + } + } + } + + rocprim::syncthreads(); + + const OffsetT offset = (flat_block_id * items_per_tile) + (threadIdx.x * ItemsPerThread); + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; ++item) + { + values_output[offset + item] = values_shared[indices[item]]; + } + + rocprim::syncthreads(); + } + + const OffsetT offset = flat_block_id * items_per_tile; + value_type values[ItemsPerThread]; + block_store().store(offset, + input_size - offset, + IsIncompleteTile, + keys_output, + values_output, + keys, + values, + storage.store); + } + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void + block_merge_kernel_impl(KeysInputIterator 
keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const OffsetT input_size, + const OffsetT sorted_block_size, + BinaryFunction compare_function, + const OffsetT* merge_partitions) + { + block_merge_process_tile(keys_input, + keys_output, + values_input, + values_output, + input_size, + sorted_block_size, + compare_function, + merge_partitions); + } + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_MERGEPATH_HPP_ \ No newline at end of file diff --git a/3rdparty/cub/rocprim/device/detail/device_partition.hpp b/3rdparty/cub/rocprim/device/detail/device_partition.hpp new file mode 100644 index 0000000000000000000000000000000000000000..20f7c54aa0ae53c3bb7c14ab056b69916fbd93a2 --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_partition.hpp @@ -0,0 +1,897 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_PARTITION_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_PARTITION_HPP_ + +#include +#include + +#include "../../detail/various.hpp" +#include "../../intrinsics.hpp" +#include "../../functional.hpp" +#include "../../types.hpp" + +#include "../../block/block_load.hpp" +#include "../../block/block_store.hpp" +#include "../../block/block_scan.hpp" +#include "../../block/block_discontinuity.hpp" + +#include "device_scan_lookback.hpp" +#include "lookback_scan_state.hpp" +#include "ordered_block_id.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +class offset_lookback_scan_prefix_op : public lookback_scan_prefix_op, LookbackScanState> +{ + using base_type = lookback_scan_prefix_op, LookbackScanState>; + using binary_op_type = ::rocprim::plus; +public: + + struct storage_type + { + T block_reduction; + T exclusive_prefix; + }; + + ROCPRIM_DEVICE ROCPRIM_INLINE + offset_lookback_scan_prefix_op(unsigned int block_id, + LookbackScanState &state, + storage_type& storage) + : base_type(block_id, binary_op_type(), state), storage_(storage) + { + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + ~offset_lookback_scan_prefix_op() = default; + + ROCPRIM_DEVICE ROCPRIM_INLINE + T operator()(T reduction) + { + auto prefix = base_type::operator()(reduction); + if(::rocprim::lane_id() == 0) + { + storage_.block_reduction = reduction; + storage_.exclusive_prefix = prefix; + } + return prefix; + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + T get_reduction() const + { + return storage_.block_reduction; + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + T get_exclusive_prefix() const + { + return storage_.exclusive_prefix; + } + +private: + storage_type& storage_; +}; + +enum class 
select_method +{ + flag = 0, + predicate = 1, + unique = 2 +}; + +template< + select_method SelectMethod, + unsigned int BlockSize, + class BlockLoadFlagsType, + class BlockDiscontinuityType, + class InputIterator, + class FlagIterator, + class ValueType, + unsigned int ItemsPerThread, + class UnaryPredicate, + class InequalityOp, + class StorageType +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto partition_block_load_flags(InputIterator /* block_predecessor */, + FlagIterator block_flags, + ValueType (&/* values */)[ItemsPerThread], + bool (&is_selected)[ItemsPerThread], + UnaryPredicate /* predicate */, + InequalityOp /* inequality_op */, + StorageType& storage, + const unsigned int /* block_id */, + const unsigned int /* block_thread_id */, + const bool is_last_block, + const unsigned int valid_in_last_block) + -> typename std::enable_if::type +{ + if(is_last_block) // last block + { + BlockLoadFlagsType() + .load( + block_flags, + is_selected, + valid_in_last_block, + false, + storage.load_flags + ); + } + else + { + BlockLoadFlagsType() + .load( + block_flags, + is_selected, + storage.load_flags + ); + } + ::rocprim::syncthreads(); // sync threads to reuse shared memory +} + +template< + select_method SelectMethod, + unsigned int BlockSize, + class BlockLoadFlagsType, + class BlockDiscontinuityType, + class InputIterator, + class FlagIterator, + class ValueType, + unsigned int ItemsPerThread, + class UnaryPredicate, + class InequalityOp, + class StorageType +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto partition_block_load_flags(InputIterator /* block_predecessor */, + FlagIterator /* block_flags */, + ValueType (&values)[ItemsPerThread], + bool (&is_selected)[ItemsPerThread], + UnaryPredicate predicate, + InequalityOp /* inequality_op */, + StorageType& /* storage */, + const unsigned int /* block_id */, + const unsigned int block_thread_id, + const bool is_last_block, + const unsigned int valid_in_last_block) + -> typename std::enable_if::type +{ + if(is_last_block) // 
last block + { + const auto offset = block_thread_id * ItemsPerThread; + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + if((offset + i) < valid_in_last_block) + { + is_selected[i] = predicate(values[i]); + } + else + { + is_selected[i] = false; + } + } + } + else + { + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + is_selected[i] = predicate(values[i]); + } + } +} + +// This wrapper processes only part of items and flags (valid_count - 1)th item (for tails) +// and (valid_count)th item (for heads), all items after valid_count are unflagged. +template +struct guarded_inequality_op +{ + InequalityOp inequality_op; + unsigned int valid_count; + + ROCPRIM_DEVICE ROCPRIM_INLINE + guarded_inequality_op(InequalityOp inequality_op, unsigned int valid_count) + : inequality_op(inequality_op), valid_count(valid_count) + {} + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + bool operator()(const T& a, const U& b, unsigned int b_index) + { + return (b_index < valid_count && inequality_op(a, b)); + } +}; + +template< + select_method SelectMethod, + unsigned int BlockSize, + class BlockLoadFlagsType, + class BlockDiscontinuityType, + class InputIterator, + class FlagIterator, + class ValueType, + unsigned int ItemsPerThread, + class UnaryPredicate, + class InequalityOp, + class StorageType +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto partition_block_load_flags(InputIterator block_predecessor, + FlagIterator /* block_flags */, + ValueType (&values)[ItemsPerThread], + bool (&is_selected)[ItemsPerThread], + UnaryPredicate /* predicate */, + InequalityOp inequality_op, + StorageType& storage, + const unsigned int block_id, + const unsigned int block_thread_id, + const bool is_last_block, + const unsigned int valid_in_last_block) + -> typename std::enable_if::type +{ + if(block_id > 0) + { + const ValueType predecessor = *block_predecessor; + if(is_last_block) + { + BlockDiscontinuityType() + .flag_heads( + is_selected, + predecessor, + values, 
+ guarded_inequality_op( + inequality_op, + valid_in_last_block + ), + storage.discontinuity_values + ); + } + else + { + BlockDiscontinuityType() + .flag_heads( + is_selected, + predecessor, + values, + inequality_op, + storage.discontinuity_values + ); + } + } + else + { + if(is_last_block) + { + BlockDiscontinuityType() + .flag_heads( + is_selected, + values, + guarded_inequality_op( + inequality_op, + valid_in_last_block + ), + storage.discontinuity_values + ); + } + else + { + BlockDiscontinuityType() + .flag_heads( + is_selected, + values, + inequality_op, + storage.discontinuity_values + ); + } + } + + + // Set is_selected for invalid items to false + if(is_last_block) + { + const auto offset = block_thread_id * ItemsPerThread; + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + if((offset + i) >= valid_in_last_block) + { + is_selected[i] = false; + } + } + } + ::rocprim::syncthreads(); // sync threads to reuse shared memory +} + +template< + select_method SelectMethod, + unsigned int BlockSize, + class BlockLoadFlagsType, + class BlockDiscontinuityType, + class InputIterator, + class FlagIterator, + class ValueType, + unsigned int ItemsPerThread, + class FirstUnaryPredicate, + class SecondUnaryPredicate, + class InequalityOp, + class StorageType +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void partition_block_load_flags(InputIterator /*block_predecessor*/, + FlagIterator /* block_flags */, + ValueType (&values)[ItemsPerThread], + bool (&is_selected)[2][ItemsPerThread], + FirstUnaryPredicate select_first_part_op, + SecondUnaryPredicate select_second_part_op, + InequalityOp /*inequality_op*/, + StorageType& /*storage*/, + const unsigned int /*block_id*/, + const unsigned int block_thread_id, + const bool is_last_block, + const unsigned int valid_in_last_block) +{ + if(is_last_block) + { + const auto offset = block_thread_id * ItemsPerThread; + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + if((offset + i) < 
valid_in_last_block) + { + is_selected[0][i] = select_first_part_op(values[i]); + is_selected[1][i] = !is_selected[0][i] && select_second_part_op(values[i]); + } + else + { + is_selected[0][i] = false; + is_selected[1][i] = false; + } + } + } + else + { + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + is_selected[0][i] = select_first_part_op(values[i]); + is_selected[1][i] = !is_selected[0][i] && select_second_part_op(values[i]); + } + } +} + +template< + bool OnlySelected, + unsigned int BlockSize, + class ValueType, + unsigned int ItemsPerThread, + class OffsetType, + class OutputIterator, + class ScatterStorageType +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto partition_scatter(ValueType (&values)[ItemsPerThread], + bool (&is_selected)[ItemsPerThread], + OffsetType (&output_indices)[ItemsPerThread], + OutputIterator output, + const size_t size, + const OffsetType selected_prefix, + const OffsetType selected_in_block, + ScatterStorageType& storage, + const unsigned int flat_block_id, + const unsigned int flat_block_thread_id, + const bool is_last_block, + const unsigned int valid_in_last_block, + size_t* /* prev_selected_count */) + -> typename std::enable_if::type +{ + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + + // Scatter selected/rejected values to shared memory + auto scatter_storage = storage.get(); + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + unsigned int item_index = (flat_block_thread_id * ItemsPerThread) + i; + unsigned int selected_item_index = output_indices[i] - selected_prefix; + unsigned int rejected_item_index = (item_index - selected_item_index) + selected_in_block; + // index of item in scatter_storage + unsigned int scatter_index = is_selected[i] ? 
selected_item_index : rejected_item_index; + scatter_storage[scatter_index] = values[i]; + } + ::rocprim::syncthreads(); // sync threads to reuse shared memory + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + unsigned int item_index = (i * BlockSize) + flat_block_thread_id; + unsigned int selected_item_index = item_index; + unsigned int rejected_item_index = item_index - selected_in_block; + // number of values rejected in previous blocks + unsigned int rejected_prefix = (flat_block_id * items_per_block) - selected_prefix; + // destination index of item scatter_storage[item_index] in output + OffsetType scatter_index = item_index < selected_in_block + ? selected_prefix + selected_item_index + : size - (rejected_prefix + rejected_item_index + 1); + + // last block can store only valid_in_last_block items + if(!is_last_block || item_index < valid_in_last_block) + { + output[scatter_index] = scatter_storage[item_index]; + } + } +} + +template< + bool OnlySelected, + unsigned int BlockSize, + class ValueType, + unsigned int ItemsPerThread, + class OffsetType, + class OutputIterator, + class ScatterStorageType +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto partition_scatter(ValueType (&values)[ItemsPerThread], + bool (&is_selected)[ItemsPerThread], + OffsetType (&output_indices)[ItemsPerThread], + OutputIterator output, + const size_t size, + const OffsetType selected_prefix, + const OffsetType selected_in_block, + ScatterStorageType& storage, + const unsigned int flat_block_id, + const unsigned int flat_block_thread_id, + const bool is_last_block, + const unsigned int valid_in_last_block, + size_t* prev_selected_count) + -> typename std::enable_if::type +{ + (void) size; + (void) storage; + (void) flat_block_id; + (void) flat_block_thread_id; + (void) valid_in_last_block; + + if(selected_in_block > BlockSize) + { + // Scatter selected values to shared memory + auto scatter_storage = storage.get(); + ROCPRIM_UNROLL + for(unsigned int i = 0; i < 
ItemsPerThread; i++) + { + unsigned int scatter_index = output_indices[i] - selected_prefix; + if(is_selected[i]) + { + scatter_storage[scatter_index] = values[i]; + } + } + ::rocprim::syncthreads(); // sync threads to reuse shared memory + + // Coalesced write from shared memory to global memory + for(unsigned int i = flat_block_thread_id; i < selected_in_block; i += BlockSize) + { + output[prev_selected_count[0] + selected_prefix + i] = scatter_storage[i]; + } + } + else + { + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + if(!is_last_block || output_indices[i] < (selected_prefix + selected_in_block)) + { + if(is_selected[i]) + { + output[prev_selected_count[0] + output_indices[i]] = values[i]; + } + } + } + } +} + +template< + bool OnlySelected, + unsigned int BlockSize, + class ValueType, + unsigned int ItemsPerThread, + class OffsetType, + class OutputType, + class ScatterStorageType +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void partition_scatter(ValueType (&values)[ItemsPerThread], + bool (&is_selected)[2][ItemsPerThread], + OffsetType (&output_indices)[ItemsPerThread], + OutputType output, + const size_t /*size*/, + const OffsetType selected_prefix, + const OffsetType selected_in_block, + ScatterStorageType& storage, + const unsigned int flat_block_id, + const unsigned int flat_block_thread_id, + const bool is_last_block, + const unsigned int valid_in_last_block, + size_t* /* prev_selected_count */) +{ + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + auto scatter_storage = storage.get(); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + const unsigned int first_selected_item_index = output_indices[i].x - selected_prefix.x; + const unsigned int second_selected_item_index = output_indices[i].y - selected_prefix.y + + selected_in_block.x; + unsigned int scatter_index{}; + + if(is_selected[0][i]) + { + scatter_index = first_selected_item_index; + } + else if(is_selected[1][i]) + { + 
scatter_index = second_selected_item_index; + } + else + { + const unsigned int item_index = (flat_block_thread_id * ItemsPerThread) + i; + const unsigned int unselected_item_index = (item_index - first_selected_item_index - second_selected_item_index) + + 2*selected_in_block.x + selected_in_block.y; + scatter_index = unselected_item_index; + } + scatter_storage[scatter_index] = values[i]; + } + ::rocprim::syncthreads(); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + const unsigned int item_index = (i * BlockSize) + flat_block_thread_id; + if (!is_last_block || item_index < valid_in_last_block) + { + if(item_index < selected_in_block.x) + { + get<0>(output)[item_index + selected_prefix.x] = scatter_storage[item_index]; + } + else if(item_index < selected_in_block.x + selected_in_block.y) + { + get<1>(output)[item_index - selected_in_block.x + selected_prefix.y] + = scatter_storage[item_index]; + } + else + { + const unsigned int all_items_in_previous_blocks = items_per_block * flat_block_id; + const unsigned int unselected_items_in_previous_blocks = all_items_in_previous_blocks + - selected_prefix.x - selected_prefix.y; + const unsigned int output_index = item_index + unselected_items_in_previous_blocks + - selected_in_block.x - selected_in_block.y; + get<2>(output)[output_index] = scatter_storage[item_index]; + } + } + } +} + +template< + unsigned int items_per_thread, + class offset_type +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void convert_selected_to_indices(offset_type (&output_indices)[items_per_thread], + bool (&is_selected)[items_per_thread]) +{ + ROCPRIM_UNROLL + for(unsigned int i = 0; i < items_per_thread; i++) + { + output_indices[i] = is_selected[i] ? 
1 : 0; + } +} + +template< + unsigned int items_per_thread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void convert_selected_to_indices(uint2 (&output_indices)[items_per_thread], + bool (&is_selected)[2][items_per_thread]) +{ + ROCPRIM_UNROLL + for(unsigned int i = 0; i < items_per_thread; i++) + { + output_indices[i].x = is_selected[0][i] ? 1 : 0; + output_indices[i].y = is_selected[1][i] ? 1 : 0; + } +} + +template< + class OffsetT +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void store_selected_count(size_t* selected_count, + size_t* prev_selected_count, + const OffsetT selected_prefix, + const OffsetT selected_in_block) +{ + selected_count[0] = prev_selected_count[0] + selected_prefix + selected_in_block; +} + +template< +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void store_selected_count(size_t* selected_count, + size_t* prev_selected_count, + const uint2 selected_prefix, + const uint2 selected_in_block) +{ + selected_count[0] = prev_selected_count[0] + selected_prefix.x + selected_in_block.x; + selected_count[1] = prev_selected_count[1] + selected_prefix.y + selected_in_block.y; +} + +template< + select_method SelectMethod, + bool OnlySelected, + class Config, + class KeyIterator, + class ValueIterator, // Can be rocprim::empty_type* if key only + class FlagIterator, + class OutputKeyIterator, + class OutputValueIterator, + class InequalityOp, + class OffsetLookbackScanState, + class... UnaryPredicates +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void partition_kernel_impl(KeyIterator keys_input, + ValueIterator values_input, + FlagIterator flags, + OutputKeyIterator keys_output, + OutputValueIterator values_output, + size_t* selected_count, + size_t* prev_selected_count, + const size_t size, + InequalityOp inequality_op, + OffsetLookbackScanState offset_scan_state, + const unsigned int number_of_blocks, + ordered_block_id ordered_bid, + UnaryPredicates... 
predicates) +{ + constexpr auto block_size = Config::block_size; + constexpr auto items_per_thread = Config::items_per_thread; + constexpr unsigned int items_per_block = block_size * items_per_thread; + + using offset_type = typename OffsetLookbackScanState::value_type; + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + // Block primitives + using block_load_key_type = ::rocprim::block_load< + key_type, block_size, items_per_thread, + Config::key_block_load_method + >; + using block_load_value_type = ::rocprim::block_load< + value_type, block_size, items_per_thread, + Config::value_block_load_method + >; + using block_load_flag_type = ::rocprim::block_load< + bool, block_size, items_per_thread, + Config::flag_block_load_method + >; + using block_scan_offset_type = ::rocprim::block_scan< + offset_type, block_size, + Config::block_scan_method + >; + using block_discontinuity_key_type = ::rocprim::block_discontinuity< + key_type, block_size + >; + using order_bid_type = ordered_block_id; + + // Offset prefix operation type + using offset_scan_prefix_op_type = offset_lookback_scan_prefix_op< + offset_type, OffsetLookbackScanState + >; + + // Memory required for 2-phase scatter + using exchange_keys_storage_type = key_type[items_per_block]; + using raw_exchange_keys_storage_type = typename detail::raw_storage; + using exchange_values_storage_type = value_type[items_per_block]; + using raw_exchange_values_storage_type = typename detail::raw_storage; + + using is_selected_type = std::conditional_t< + sizeof...(UnaryPredicates) == 1, + bool[items_per_thread], + bool[sizeof...(UnaryPredicates)][items_per_thread]>; + + ROCPRIM_SHARED_MEMORY struct + { + typename order_bid_type::storage_type ordered_bid; + union + { + raw_exchange_keys_storage_type exchange_keys; + raw_exchange_values_storage_type exchange_values; + typename block_load_key_type::storage_type load_keys; + typename 
block_load_value_type::storage_type load_values; + typename block_load_flag_type::storage_type load_flags; + typename block_discontinuity_key_type::storage_type discontinuity_values; + typename block_scan_offset_type::storage_type scan_offsets; + }; + } storage; + + const auto flat_block_thread_id = ::rocprim::detail::block_thread_id<0>(); + const auto flat_block_id = ordered_bid.get(flat_block_thread_id, storage.ordered_bid); + const unsigned int block_offset = flat_block_id * items_per_block; + const auto valid_in_last_block = size - items_per_block * (number_of_blocks - 1); + + key_type keys[items_per_thread]; + is_selected_type is_selected; + offset_type output_indices[items_per_thread]; + + // Load input values into values + const bool is_last_block = flat_block_id == (number_of_blocks - 1); + if(is_last_block) // last block + { + block_load_key_type() + .load( + keys_input + block_offset, + keys, + valid_in_last_block, + storage.load_keys + ); + } + else + { + block_load_key_type() + .load( + keys_input + block_offset, + keys, + storage.load_keys + ); + } + ::rocprim::syncthreads(); // sync threads to reuse shared memory + + // Load selection flags into is_selected, generate them using + // input value and selection predicate, or generate them using + // block_discontinuity primitive + partition_block_load_flags< + SelectMethod, block_size, + block_load_flag_type, block_discontinuity_key_type + >( + keys_input + block_offset - 1, + flags + block_offset, + keys, + is_selected, + predicates ..., + inequality_op, + storage, + flat_block_id, + flat_block_thread_id, + is_last_block, + valid_in_last_block + ); + + // Convert true/false is_selected flags to 0s and 1s + convert_selected_to_indices(output_indices, is_selected); + + // Number of selected values in previous blocks + offset_type selected_prefix{}; + // Number of selected values in this block + offset_type selected_in_block{}; + + // Calculate number of selected values in block and their indices + 
if(flat_block_id == 0) + { + block_scan_offset_type() + .exclusive_scan( + output_indices, + output_indices, + offset_type{}, /** initial value */ + selected_in_block, + storage.scan_offsets, + ::rocprim::plus() + ); + if(flat_block_thread_id == 0) + { + offset_scan_state.set_complete(flat_block_id, selected_in_block); + } + ::rocprim::syncthreads(); // sync threads to reuse shared memory + } + else + { + ROCPRIM_SHARED_MEMORY typename offset_scan_prefix_op_type::storage_type storage_prefix_op; + auto prefix_op = offset_scan_prefix_op_type( + flat_block_id, + offset_scan_state, + storage_prefix_op + ); + block_scan_offset_type() + .exclusive_scan( + output_indices, + output_indices, + storage.scan_offsets, + prefix_op, + ::rocprim::plus() + ); + ::rocprim::syncthreads(); // sync threads to reuse shared memory + + selected_in_block = prefix_op.get_reduction(); + selected_prefix = prefix_op.get_exclusive_prefix(); + } + + // Scatter selected and rejected values + partition_scatter( + keys, is_selected, output_indices, keys_output, size, + selected_prefix, selected_in_block, storage.exchange_keys, + flat_block_id, flat_block_thread_id, + is_last_block, valid_in_last_block, + prev_selected_count + ); + + static constexpr bool with_values = !std::is_same::value; + + if ROCPRIM_IF_CONSTEXPR (with_values) { + value_type values[items_per_thread]; + + ::rocprim::syncthreads(); // sync threads to reuse shared memory + if(is_last_block) + { + block_load_value_type() + .load( + values_input + block_offset, + values, + valid_in_last_block, + storage.load_values + ); + } + else + { + block_load_value_type() + .load( + values_input + block_offset, + values, + storage.load_values + ); + } + ::rocprim::syncthreads(); // sync threads to reuse shared memory + + partition_scatter( + values, is_selected, output_indices, values_output, size, + selected_prefix, selected_in_block, storage.exchange_values, + flat_block_id, flat_block_thread_id, + is_last_block, valid_in_last_block, + 
prev_selected_count + ); + } + + // Last block in grid stores number of selected values + if(is_last_block && flat_block_thread_id == 0) + { + store_selected_count(selected_count, prev_selected_count, selected_prefix, selected_in_block); + } +} + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_PARTITION_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_radix_sort.hpp b/3rdparty/cub/rocprim/device/detail/device_radix_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..03696cf624542a9895ab2c64e3f293b6d9fb2f6d --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_radix_sort.hpp @@ -0,0 +1,1070 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_ + +#include +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" +#include "../../detail/radix_sort.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" +#include "../../types.hpp" + +#include "../../block/block_discontinuity.hpp" +#include "../../block/block_exchange.hpp" +#include "../../block/block_load.hpp" +#include "../../block/block_load_func.hpp" +#include "../../block/block_scan.hpp" +#include "../../block/block_radix_sort.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// Wrapping functions that allow to call proper methods (with or without values) +// (a variant with values is enabled only when Value is not empty_type) +template +ROCPRIM_DEVICE ROCPRIM_INLINE +void sort_block(SortType sorter, + SortKey (&keys)[ItemsPerThread], + SortValue (&values)[ItemsPerThread], + typename SortType::storage_type& storage, + unsigned int begin_bit, + unsigned int end_bit) +{ + if(Descending) + { + sorter.sort_desc(keys, values, storage, begin_bit, end_bit); + } + else + { + sorter.sort(keys, values, storage, begin_bit, end_bit); + } +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +void sort_block(SortType sorter, + SortKey (&keys)[ItemsPerThread], + ::rocprim::empty_type (&values)[ItemsPerThread], + typename SortType::storage_type& storage, + unsigned int begin_bit, + unsigned int end_bit) +{ + (void) values; + if(Descending) + { + sorter.sort_desc(keys, storage, begin_bit, end_bit); + } + else + { + sorter.sort(keys, storage, begin_bit, end_bit); + } +} + +template< + unsigned int WarpSize, + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int RadixBits, + bool Descending +> +struct radix_digit_count_helper +{ + static constexpr unsigned int radix_size = 1 << RadixBits; + + static constexpr unsigned int warp_size = WarpSize; + static constexpr unsigned int warps_no = BlockSize / warp_size; 
+ static_assert(BlockSize % ::rocprim::device_warp_size() == 0, "BlockSize must be divisible by warp size"); + static_assert(radix_size <= BlockSize, "Radix size must not exceed BlockSize"); + + struct storage_type + { + unsigned int digit_counts[warps_no][radix_size]; + }; + + template< + bool IsFull = false, + class KeysInputIterator, + class Offset + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void count_digits(KeysInputIterator keys_input, + Offset begin_offset, + Offset end_offset, + unsigned int bit, + unsigned int current_radix_bits, + storage_type& storage, + unsigned int& digit_count) // i-th thread will get i-th digit's value + { + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + + using key_type = typename std::iterator_traits::value_type; + + using key_codec = radix_key_codec; + using bit_key_type = typename key_codec::bit_key_type; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int warp_id = ::rocprim::warp_id<0, 1, 1>(); + + if(flat_id < radix_size) + { + for(unsigned int w = 0; w < warps_no; w++) + { + storage.digit_counts[w][flat_id] = 0; + } + } + ::rocprim::syncthreads(); + + for(Offset block_offset = begin_offset; block_offset < end_offset; block_offset += items_per_block) + { + key_type keys[ItemsPerThread]; + unsigned int valid_count; + // Use loading into a striped arrangement because an order of items is irrelevant, + // only totals matter + if(IsFull || (block_offset + items_per_block <= end_offset)) + { + valid_count = items_per_block; + block_load_direct_striped(flat_id, keys_input + block_offset, keys); + } + else + { + valid_count = end_offset - block_offset; + block_load_direct_striped(flat_id, keys_input + block_offset, keys, valid_count); + } + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + const bit_key_type bit_key = key_codec::encode(keys[i]); + const unsigned int digit = key_codec::extract_digit(bit_key, bit, current_radix_bits); + const unsigned int pos = i * 
BlockSize + flat_id; + lane_mask_type same_digit_lanes_mask = ::rocprim::ballot(IsFull || (pos < valid_count)); + for(unsigned int b = 0; b < RadixBits; b++) + { + const unsigned int bit_set = digit & (1u << b); + const lane_mask_type bit_set_mask = ::rocprim::ballot(bit_set); + same_digit_lanes_mask &= (bit_set ? bit_set_mask : ~bit_set_mask); + } + const unsigned int same_digit_count = ::rocprim::bit_count(same_digit_lanes_mask); + const unsigned int prev_same_digit_count = ::rocprim::masked_bit_count(same_digit_lanes_mask); + if(prev_same_digit_count == 0) + { + // Write the number of lanes having this digit, + // if the current lane is the first (and maybe only) lane with this digit. + storage.digit_counts[warp_id][digit] += same_digit_count; + } + } + } + ::rocprim::syncthreads(); + + digit_count = 0; + if(flat_id < radix_size) + { + for(unsigned int w = 0; w < warps_no; w++) + { + digit_count += storage.digit_counts[w][flat_id]; + } + } + } +}; + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + bool Descending, + class Key, + class Value +> +struct radix_sort_single_helper +{ + static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + + using key_type = Key; + using value_type = Value; + + using key_codec = radix_key_codec; + using bit_key_type = typename key_codec::bit_key_type; + using keys_load_type = ::rocprim::block_load< + key_type, BlockSize, ItemsPerThread, + ::rocprim::block_load_method::block_load_transpose>; + using values_load_type = ::rocprim::block_load< + value_type, BlockSize, ItemsPerThread, + ::rocprim::block_load_method::block_load_transpose>; + using sort_type = ::rocprim::block_radix_sort; + + static constexpr bool with_values = !std::is_same::value; + + struct storage_type + { + union + { + typename keys_load_type::storage_type keys_load; + typename values_load_type::storage_type values_load; + typename sort_type::storage_type sort; + }; + }; + + template< + class KeysInputIterator, + class 
KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int current_radix_bits, + storage_type& storage) + { + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + const unsigned int block_offset = flat_block_id * items_per_block; + const unsigned int number_of_blocks = (size + items_per_block - 1) / items_per_block; + unsigned int valid_in_last_block; + const bool last_block = flat_block_id == (number_of_blocks - 1); + + using key_type = typename std::iterator_traits::value_type; + + using key_codec = radix_key_codec; + using bit_key_type = typename key_codec::bit_key_type; + + key_type keys[ItemsPerThread]; + value_type values[ItemsPerThread]; + if(!last_block) + { + valid_in_last_block = items_per_block; + keys_load_type().load(keys_input + block_offset, keys, storage.keys_load); + if(with_values) + { + ::rocprim::syncthreads(); + values_load_type().load(values_input + block_offset, values, storage.values_load); + } + } + else + { + const key_type out_of_bounds = key_codec::decode(bit_key_type(-1)); + valid_in_last_block = size - items_per_block * (number_of_blocks - 1); + keys_load_type().load(keys_input + block_offset, keys, valid_in_last_block, out_of_bounds, storage.keys_load); + if(with_values) + { + ::rocprim::syncthreads(); + values_load_type().load(values_input + block_offset, values, valid_in_last_block, storage.values_load); + } + } + + ::rocprim::syncthreads(); + + sort_block(sort_type(), keys, values, storage.sort, bit, bit + current_radix_bits); + + // Store keys and values + #pragma unroll + for (unsigned int i = 0; i < ItemsPerThread; ++i) + { + unsigned int item_offset = flat_id * ItemsPerThread 
+ i; + if (item_offset < valid_in_last_block) + { + keys_output[block_offset + item_offset] = keys[i]; + if (with_values) + values_output[block_offset + item_offset] = values[i]; + } + } + } +}; + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int RadixBits, + bool Descending, + class Key, + class Value, + class Offset +> +struct radix_sort_and_scatter_helper +{ + static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + static constexpr unsigned int radix_size = 1 << RadixBits; + + using key_type = Key; + using value_type = Value; + + using key_codec = radix_key_codec; + using bit_key_type = typename key_codec::bit_key_type; + using keys_load_type = ::rocprim::block_load< + key_type, BlockSize, ItemsPerThread, + ::rocprim::block_load_method::block_load_transpose>; + using values_load_type = ::rocprim::block_load< + value_type, BlockSize, ItemsPerThread, + ::rocprim::block_load_method::block_load_transpose>; + using sort_type = ::rocprim::block_radix_sort; + using discontinuity_type = ::rocprim::block_discontinuity; + using bit_keys_exchange_type = ::rocprim::block_exchange; + using values_exchange_type = ::rocprim::block_exchange; + + static constexpr bool with_values = !std::is_same::value; + + struct storage_type + { + union + { + typename keys_load_type::storage_type keys_load; + typename values_load_type::storage_type values_load; + typename sort_type::storage_type sort; + typename discontinuity_type::storage_type discontinuity; + typename bit_keys_exchange_type::storage_type bit_keys_exchange; + typename values_exchange_type::storage_type values_exchange; + }; + + unsigned short starts[radix_size]; + unsigned short ends[radix_size]; + + Offset digit_starts[radix_size]; + }; + + template< + bool IsFull = false, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_and_scatter(KeysInputIterator 
keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + Offset begin_offset, + Offset end_offset, + unsigned int bit, + unsigned int current_radix_bits, + Offset digit_start, // i-th thread must pass i-th digit's value + storage_type& storage) + { + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + + if(flat_id < radix_size) + { + storage.digit_starts[flat_id] = digit_start; + } + + for(Offset block_offset = begin_offset; block_offset < end_offset; block_offset += items_per_block) + { + key_type keys[ItemsPerThread]; + value_type values[ItemsPerThread]; + unsigned int valid_count; + if(IsFull || (block_offset + items_per_block <= end_offset)) + { + valid_count = items_per_block; + keys_load_type().load(keys_input + block_offset, keys, storage.keys_load); + if(with_values) + { + ::rocprim::syncthreads(); + values_load_type().load(values_input + block_offset, values, storage.values_load); + } + } + else + { + valid_count = end_offset - block_offset; + // Sort will leave "invalid" (out of size) items at the end of the sorted sequence + const key_type out_of_bounds = key_codec::decode(bit_key_type(-1)); + keys_load_type().load(keys_input + block_offset, keys, valid_count, out_of_bounds, storage.keys_load); + if(with_values) + { + ::rocprim::syncthreads(); + values_load_type().load(values_input + block_offset, values, valid_count, storage.values_load); + } + } + + if(flat_id < radix_size) + { + storage.starts[flat_id] = valid_count; + storage.ends[flat_id] = valid_count; + } + + ::rocprim::syncthreads(); + sort_block(sort_type(), keys, values, storage.sort, bit, bit + current_radix_bits); + + bit_key_type bit_keys[ItemsPerThread]; + unsigned int digits[ItemsPerThread]; + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + bit_keys[i] = key_codec::encode(keys[i]); + digits[i] = key_codec::extract_digit(bit_keys[i], bit, current_radix_bits); + } + + bool head_flags[ItemsPerThread]; 
+ bool tail_flags[ItemsPerThread]; + ::rocprim::not_equal_to flag_op; + + ::rocprim::syncthreads(); + discontinuity_type().flag_heads_and_tails(head_flags, tail_flags, digits, flag_op, storage.discontinuity); + + // Fill start and end position of subsequence for every digit + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + const unsigned int digit = digits[i]; + const unsigned int pos = flat_id * ItemsPerThread + i; + if(head_flags[i]) + { + storage.starts[digit] = pos; + } + if(tail_flags[i]) + { + storage.ends[digit] = pos; + } + } + + ::rocprim::syncthreads(); + // Rearrange to striped arrangement to have faster coalesced writes instead of + // scattering of blocked-arranged items + bit_keys_exchange_type().blocked_to_striped(bit_keys, bit_keys, storage.bit_keys_exchange); + if(with_values) + { + ::rocprim::syncthreads(); + values_exchange_type().blocked_to_striped(values, values, storage.values_exchange); + } + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + const unsigned int digit = key_codec::extract_digit(bit_keys[i], bit, current_radix_bits); + const unsigned int pos = i * BlockSize + flat_id; + if(IsFull || (pos < valid_count)) + { + const Offset dst = pos - storage.starts[digit] + storage.digit_starts[digit]; + keys_output[dst] = key_codec::decode(bit_keys[i]); + if(with_values) + { + values_output[dst] = values[i]; + } + } + } + + ::rocprim::syncthreads(); + + // Accumulate counts of the current block + if(flat_id < radix_size) + { + const unsigned int digit = flat_id; + const unsigned int start = storage.starts[digit]; + const unsigned int end = storage.ends[digit]; + if(start < valid_count) + { + storage.digit_starts[digit] += (::rocprim::min(valid_count - 1, end) - start + 1); + } + } + } + } +}; + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int RadixBits, + bool Descending, + class KeysInputIterator, + class Offset +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void fill_digit_counts(KeysInputIterator 
keys_input, + Offset size, + Offset * batch_digit_counts, + unsigned int bit, + unsigned int current_radix_bits, + unsigned int blocks_per_full_batch, + unsigned int full_batches) +{ + constexpr unsigned int radix_size = 1 << RadixBits; + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + + using count_helper_type = radix_digit_count_helper<::rocprim::device_warp_size(), BlockSize, ItemsPerThread, RadixBits, Descending>; + + ROCPRIM_SHARED_MEMORY typename count_helper_type::storage_type storage; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int batch_id = ::rocprim::detail::block_id<0>(); + + Offset block_offset; + unsigned int blocks_per_batch; + if(batch_id < full_batches) + { + blocks_per_batch = blocks_per_full_batch; + block_offset = batch_id * blocks_per_batch; + } + else + { + blocks_per_batch = blocks_per_full_batch - 1; + block_offset = batch_id * blocks_per_batch + full_batches; + } + block_offset *= items_per_block; + + unsigned int digit_count; + if(batch_id < ::rocprim::detail::grid_size<0>() - 1) + { + count_helper_type().template count_digits( + keys_input, + block_offset, block_offset + blocks_per_batch * items_per_block, + bit, current_radix_bits, + storage, + digit_count + ); + } + else + { + count_helper_type().template count_digits( + keys_input, + block_offset, size, + bit, current_radix_bits, + storage, + digit_count + ); + } + + if(flat_id < radix_size) + { + batch_digit_counts[batch_id * radix_size + flat_id] = digit_count; + } +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int RadixBits, + class Offset +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void scan_batches(Offset * batch_digit_counts, + Offset * digit_counts, + unsigned int batches) +{ + constexpr unsigned int radix_size = 1 << RadixBits; + + using scan_type = typename ::rocprim::block_scan; + + const unsigned int digit = ::rocprim::detail::block_id<0>(); + const unsigned int flat_id = 
::rocprim::detail::block_thread_id<0>(); + + Offset values[ItemsPerThread]; + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + const unsigned int batch_id = flat_id * ItemsPerThread + i; + values[i] = (batch_id < batches ? batch_digit_counts[batch_id * radix_size + digit] : 0); + } + + Offset digit_count; + scan_type().exclusive_scan(values, values, 0, digit_count); + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + const unsigned int batch_id = flat_id * ItemsPerThread + i; + if(batch_id < batches) + { + batch_digit_counts[batch_id * radix_size + digit] = values[i]; + } + } + + if(flat_id == 0) + { + digit_counts[digit] = digit_count; + } +} + +template< + unsigned int RadixBits, + class Offset +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void scan_digits(Offset * digit_counts) +{ + constexpr unsigned int radix_size = 1 << RadixBits; + + using scan_type = typename ::rocprim::block_scan; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + + Offset value = digit_counts[flat_id]; + scan_type().exclusive_scan(value, value, 0); + digit_counts[flat_id] = value; +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int current_radix_bits) +{ + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + using sort_single_helper = radix_sort_single_helper< + BlockSize, ItemsPerThread, Descending, + key_type, value_type + >; + + ROCPRIM_SHARED_MEMORY typename sort_single_helper::storage_type storage; + + sort_single_helper().template sort_single( + keys_input, keys_output, 
values_input, values_output, + size, bit, current_radix_bits, + storage + ); +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int RadixBits, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class Offset +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void sort_and_scatter(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + Offset size, + const Offset * batch_digit_starts, + const Offset * digit_starts, + unsigned int bit, + unsigned int current_radix_bits, + unsigned int blocks_per_full_batch, + unsigned int full_batches) +{ + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + constexpr unsigned int radix_size = 1 << RadixBits; + + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + using sort_and_scatter_helper = radix_sort_and_scatter_helper< + BlockSize, ItemsPerThread, RadixBits, Descending, + key_type, value_type, Offset + >; + + ROCPRIM_SHARED_MEMORY typename sort_and_scatter_helper::storage_type storage; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int batch_id = ::rocprim::detail::block_id<0>(); + + Offset block_offset; + unsigned int blocks_per_batch; + if(batch_id < full_batches) + { + blocks_per_batch = blocks_per_full_batch; + block_offset = batch_id * blocks_per_batch; + } + else + { + blocks_per_batch = blocks_per_full_batch - 1; + block_offset = batch_id * blocks_per_batch + full_batches; + } + block_offset *= items_per_block; + + Offset digit_start = 0; + if(flat_id < radix_size) + { + digit_start = digit_starts[flat_id] + batch_digit_starts[batch_id * radix_size + flat_id]; + } + + if(batch_id < ::rocprim::detail::grid_size<0>() - 1) + { + sort_and_scatter_helper().template sort_and_scatter( + keys_input, 
keys_output, values_input, values_output, + block_offset, block_offset + blocks_per_batch * items_per_block, + bit, current_radix_bits, + digit_start, + storage + ); + } + else + { + sort_and_scatter_helper().template sort_and_scatter( + keys_input, keys_output, values_input, values_output, + block_offset, size, + bit, current_radix_bits, + digit_start, + storage + ); + } +} + +template< + bool WithValues, + class KeysInputIterator, + class ValuesInputIterator, + class Key, + class Value, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +typename std::enable_if::type +block_load_radix_impl(const unsigned int flat_id, + const unsigned int block_offset, + const unsigned int valid_in_last_block, + const bool last_block, + KeysInputIterator keys_input, + ValuesInputIterator values_input, + Key (&keys)[ItemsPerThread], + Value (&values)[ItemsPerThread]) +{ + (void) values_input; + (void) values; + + if(last_block) + { + block_load_direct_blocked( + flat_id, + keys_input + block_offset, + keys, + valid_in_last_block + ); + } + else + { + block_load_direct_blocked( + flat_id, + keys_input + block_offset, + keys + ); + } +} + +template< + bool WithValues, + class KeysInputIterator, + class ValuesInputIterator, + class Key, + class Value, + unsigned int ItemsPerThread +> +ROCPRIM_DEVICE ROCPRIM_INLINE +typename std::enable_if::type +block_load_radix_impl(const unsigned int flat_id, + const unsigned int block_offset, + const unsigned int valid_in_last_block, + const bool last_block, + KeysInputIterator keys_input, + ValuesInputIterator values_input, + Key (&keys)[ItemsPerThread], + Value (&values)[ItemsPerThread]) +{ + if(last_block) + { + block_load_direct_blocked( + flat_id, + keys_input + block_offset, + keys, + valid_in_last_block + ); + + block_load_direct_blocked( + flat_id, + values_input + block_offset, + values, + valid_in_last_block + ); + } + else + { + block_load_direct_blocked( + flat_id, + keys_input + block_offset, + keys + ); + + 
block_load_direct_blocked( + flat_id, + values_input + block_offset, + values + ); + } +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto compare_nan_sensitive(const T& a, const T& b) + -> typename std::enable_if::value, bool>::type +{ + // Beware: the performance of this function is extremely vulnerable to refactoring. + // Always check benchmark_device_segmented_radix_sort and benchmark_device_radix_sort + // when making changes to this function. + + using bit_key_type = typename float_bit_mask::bit_type; + static constexpr auto sign_bit = float_bit_mask::sign_bit; + + auto a_bits = __builtin_bit_cast(bit_key_type, a); + auto b_bits = __builtin_bit_cast(bit_key_type, b); + + // convert -0.0 to +0.0 + a_bits = a_bits == sign_bit ? 0 : a_bits; + b_bits = b_bits == sign_bit ? 0 : b_bits; + // invert negatives, put 1 into sign bit for positives + a_bits ^= (sign_bit & a_bits) == 0 ? sign_bit : bit_key_type(-1); + b_bits ^= (sign_bit & b_bits) == 0 ? sign_bit : bit_key_type(-1); + + // sort numbers and NaNs according to their bit representation + return a_bits > b_bits; +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto compare_nan_sensitive(const T& a, const T& b) + -> typename std::enable_if::value, bool>::type +{ + return a > b; +} + +template< + bool Descending, + bool UseRadixMask, + class T, + class Enable = void +> +struct radix_merge_compare; + +template +struct radix_merge_compare +{ + ROCPRIM_DEVICE ROCPRIM_INLINE + bool operator()(const T& a, const T& b) const + { + return compare_nan_sensitive(b, a); + } +}; + +template +struct radix_merge_compare +{ + ROCPRIM_DEVICE ROCPRIM_INLINE + bool operator()(const T& a, const T& b) const + { + return compare_nan_sensitive(a, b); + } +}; + +template +struct radix_merge_compare::value>::type> +{ + T radix_mask; + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE + radix_merge_compare(const unsigned int start_bit, const unsigned int current_radix_bits) + { + T radix_mask_upper = (T(1) << (current_radix_bits + start_bit)) - 1; + 
T radix_mask_bottom = (T(1) << start_bit) - 1; + radix_mask = radix_mask_upper ^ radix_mask_bottom; + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + bool operator()(const T& a, const T& b) const + { + const T masked_key_a = a & radix_mask; + const T masked_key_b = b & radix_mask; + return masked_key_b > masked_key_a; + } +}; + +template +struct radix_merge_compare::value>::type> +{ + T radix_mask; + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE + radix_merge_compare(const unsigned int start_bit, const unsigned int current_radix_bits) + { + T radix_mask_upper = (T(1) << (current_radix_bits + start_bit)) - 1; + T radix_mask_bottom = (T(1) << start_bit) - 1; + radix_mask = (radix_mask_upper ^ radix_mask_bottom); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + bool operator()(const T& a, const T& b) const + { + const T masked_key_a = a & radix_mask; + const T masked_key_b = b & radix_mask; + return masked_key_a > masked_key_b; + } +}; + +template +struct radix_merge_compare::value>::type> +{ + // radix_merge_compare supports masks only for integrals. + // even though masks are never used for floating point-types, + // it needs to be able to compile. 
+ ROCPRIM_HOST_DEVICE ROCPRIM_INLINE + radix_merge_compare(const unsigned int, const unsigned int){} + + ROCPRIM_DEVICE ROCPRIM_INLINE + bool operator()(const T&, const T&) const { return false; } +}; + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void radix_block_merge_impl(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const size_t input_size, + const unsigned int merge_items_per_block_size, + BinaryFunction compare_function) +{ + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + constexpr bool with_values = !std::is_same::value; + + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + const unsigned int block_offset = flat_block_id * items_per_block; + const unsigned int number_of_blocks = (input_size + items_per_block - 1) / items_per_block; + const bool last_block = flat_block_id == (number_of_blocks - 1); + auto valid_in_last_block = last_block ? 
input_size - items_per_block * (number_of_blocks - 1) : items_per_block; + + unsigned int start_id = (flat_block_id * items_per_block) + flat_id * ItemsPerThread; + if (start_id >= input_size) + { + return; + } + + + key_type keys[ItemsPerThread]; + value_type values[ItemsPerThread]; + + block_load_radix_impl( + flat_id, + block_offset, + valid_in_last_block, + last_block, + keys_input, + values_input, + keys, + values + ); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + if( flat_id * ItemsPerThread + i < valid_in_last_block ) + { + const unsigned int id = start_id + i; + const unsigned int block_id = id / merge_items_per_block_size; + const bool block_id_is_odd = block_id & 1; + const unsigned int next_block_id = block_id_is_odd ? block_id - 1 : + block_id + 1; + const unsigned int block_start = min(block_id * merge_items_per_block_size, (unsigned int) input_size); + const unsigned int next_block_start = min(next_block_id * merge_items_per_block_size, (unsigned int) input_size); + const unsigned int next_block_end = min((next_block_id + 1) * merge_items_per_block_size, (unsigned int) input_size); + + if(next_block_start == input_size) + { + keys_output[id] = keys[i]; + if(with_values) + { + values_output[id] = values[i]; + } + } + + unsigned int left_id = next_block_start; + unsigned int right_id = next_block_end; + + while(left_id < right_id) + { + unsigned int mid_id = (left_id + right_id) / 2; + key_type mid_key = keys_input[mid_id]; + bool smaller = compare_function(mid_key, keys[i]); + left_id = smaller ? mid_id + 1 : left_id; + right_id = smaller ? 
right_id : mid_id; + } + + + right_id = next_block_end; + if(block_id_is_odd && left_id != right_id) + { + key_type upper_key = keys_input[left_id]; + while(!compare_function(upper_key, keys[i]) && + !compare_function(keys[i], upper_key) && + left_id < right_id) + { + unsigned int mid_id = (left_id + right_id) / 2; + key_type mid_key = keys_input[mid_id]; + bool equal = !compare_function(mid_key, keys[i]) && + !compare_function(keys[i], mid_key); + left_id = equal ? mid_id + 1 : left_id + 1; + right_id = equal ? right_id : mid_id; + upper_key = keys_input[left_id]; + } + } + + unsigned int offset = 0; + offset += id - block_start; + offset += left_id - next_block_start; + offset += min(block_start, next_block_start); + + keys_output[offset] = keys[i]; + if(with_values) + { + values_output[offset] = values[i]; + } + } + } +} + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_reduce.hpp b/3rdparty/cub/rocprim/device/detail/device_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..854b73af4f853cc6dd809570435f215ebda6932e --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_reduce.hpp @@ -0,0 +1,184 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_HPP_ + +#include +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" +#include "../../types.hpp" + +#include "../../block/block_load.hpp" +#include "../../block/block_reduce.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// Helper functions for reducing final value with +// initial value. +template< + bool WithInitialValue, + class T, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto reduce_with_initial(T output, + T initial_value, + BinaryFunction reduce_op) + -> typename std::enable_if::type +{ + return reduce_op(initial_value, output); +} + +template< + bool WithInitialValue, + class T, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto reduce_with_initial(T output, + T initial_value, + BinaryFunction reduce_op) + -> typename std::enable_if::type +{ + (void) initial_value; + (void) reduce_op; + return output; +} + +template< + bool WithInitialValue, + class Config, + class ResultType, + class InputIterator, + class OutputIterator, + class InitValueType, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void block_reduce_kernel_impl(InputIterator input, + const size_t input_size, + OutputIterator output, + InitValueType initial_value, + BinaryFunction reduce_op) +{ + constexpr unsigned int block_size = Config::block_size; + constexpr unsigned 
int items_per_thread = Config::items_per_thread; + + using result_type = ResultType; + + using block_reduce_type = ::rocprim::block_reduce< + result_type, block_size, + Config::block_reduce_method + >; + constexpr unsigned int items_per_block = block_size * items_per_thread; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + const unsigned int block_offset = flat_block_id * items_per_block; + const unsigned int number_of_blocks = ::rocprim::detail::grid_size<0>(); + auto valid_in_last_block = input_size - items_per_block * (number_of_blocks - 1); + + result_type values[items_per_thread]; + result_type output_value; + if(flat_block_id == (number_of_blocks - 1)) // last block + { + block_load_direct_striped( + flat_id, + input + block_offset, + values, + valid_in_last_block + ); + + output_value = values[0]; + ROCPRIM_UNROLL + for(unsigned int i = 1; i < items_per_thread; i++) + { + unsigned int offset = i * block_size; + if(flat_id + offset < valid_in_last_block) + { + output_value = reduce_op(output_value, values[i]); + } + } + + block_reduce_type() + .reduce( + output_value, // input + output_value, // output + valid_in_last_block, + reduce_op + ); + } + else + { + block_load_direct_striped( + flat_id, + input + block_offset, + values + ); + + // load input values into values + block_reduce_type() + .reduce( + values, // input + output_value, // output + reduce_op + ); + } + + // Save value into output + if(flat_id == 0) + { + output[flat_block_id] = input_size == 0 + ? static_cast(initial_value) + : reduce_with_initial( + output_value, + static_cast(initial_value), + reduce_op + ); + } +} + +// Returns size of temporary storage in bytes. 
+template +size_t reduce_get_temporary_storage_bytes(size_t input_size, + size_t items_per_block) +{ + if(input_size <= items_per_block) + { + return 0; + } + auto size = (input_size + items_per_block - 1)/(items_per_block); + return size * sizeof(T) + reduce_get_temporary_storage_bytes(size, items_per_block); +} + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_reduce_by_key.hpp b/3rdparty/cub/rocprim/device/detail/device_reduce_by_key.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e277652bc50f42b0429d2b0dcaf41b5e8687f92c --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_reduce_by_key.hpp @@ -0,0 +1,644 @@ +// Copyright (c) 2017-2020 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_BY_KEY_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_BY_KEY_HPP_ + +#include +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" + +#include "../../block/block_discontinuity.hpp" +#include "../../block/block_load_func.hpp" +#include "../../block/block_load.hpp" +#include "../../block/block_store.hpp" +#include "../../block/block_scan.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +struct carry_out +{ + ROCPRIM_DEVICE ROCPRIM_INLINE + carry_out() = default; + + ROCPRIM_DEVICE ROCPRIM_INLINE + carry_out(const carry_out& rhs) = default; + + ROCPRIM_DEVICE ROCPRIM_INLINE + carry_out& operator=(const carry_out& rhs) + { + value = rhs.value; + destination = rhs.destination; + next_has_carry_in = rhs.next_has_carry_in; + return *this; + } + + Value value; // carry-out of the current batch + unsigned int destination; + bool next_has_carry_in; // the next batch has carry-in (i.e. it continues the last segment from the current batch) +}; + +template +struct scan_by_key_pair +{ + ROCPRIM_DEVICE ROCPRIM_INLINE + scan_by_key_pair() = default; + + ROCPRIM_DEVICE ROCPRIM_INLINE + scan_by_key_pair(const scan_by_key_pair& rhs) = default; + + ROCPRIM_DEVICE ROCPRIM_INLINE + scan_by_key_pair& operator=(const scan_by_key_pair& rhs) + { + key = rhs.key; + value = rhs.value; + return *this; + } + + unsigned int key; + Value value; +}; + +// Special operator which allows to calculate scan-by-key using block_scan. +// block_scan supports non-commutative scan operators. +// Initial values of pairs' keys must be 1 for the first item (head) of segment and 0 otherwise. +// As a result key contains the current segment's index and value contains segmented scan result. 
+template +struct scan_by_key_op +{ + BinaryFunction reduce_op; + + ROCPRIM_DEVICE ROCPRIM_INLINE + scan_by_key_op(BinaryFunction reduce_op) + : reduce_op(reduce_op) + {} + + ROCPRIM_DEVICE ROCPRIM_INLINE + Pair operator()(const Pair& a, const Pair& b) + { + Pair c; + c.key = a.key + b.key; + c.value = b.key != 0 + ? b.value + : reduce_op(a.value, b.value); + return c; + } +}; + +// Wrappers that reverse results of key comparizon functions to use them as flag_op of block_discontinuity +// (for example, equal_to will work as not_equal_to and divide items into segments by keys) +template +struct key_flag_op +{ + KeyCompareFunction key_compare_op; + + ROCPRIM_DEVICE ROCPRIM_INLINE + key_flag_op(KeyCompareFunction key_compare_op) + : key_compare_op(key_compare_op) + {} + + ROCPRIM_DEVICE ROCPRIM_INLINE + bool operator()(const Key& a, const Key& b) + { + return !key_compare_op(a, b); + } +}; + +// This wrapper processes only part of items and flags (valid_count - 1)th item (for tails) +// and (valid_count)th item (for heads), all items after valid_count are unflagged. 
+template +struct guarded_key_flag_op +{ + KeyCompareFunction key_compare_op; + unsigned int valid_count; + + ROCPRIM_DEVICE ROCPRIM_INLINE + guarded_key_flag_op(KeyCompareFunction key_compare_op, unsigned int valid_count) + : key_compare_op(key_compare_op), valid_count(valid_count) + {} + + ROCPRIM_DEVICE ROCPRIM_INLINE + bool operator()(const Key& a, const Key& b, unsigned int b_index) + { + return (b_index < valid_count && !key_compare_op(a, b)) || b_index == valid_count; + } +}; + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class KeysInputIterator, + class KeyCompareFunction +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void fill_unique_counts(KeysInputIterator keys_input, + unsigned int size, + unsigned int * unique_counts, + KeyCompareFunction key_compare_op, + unsigned int blocks_per_full_batch, + unsigned int full_batches) +{ + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + constexpr unsigned int warp_size = ::rocprim::device_warp_size(); + constexpr unsigned int warps_no = BlockSize / warp_size; + + using key_type = typename std::iterator_traits::value_type; + + using keys_load_type = ::rocprim::block_load< + key_type, BlockSize, ItemsPerThread, + ::rocprim::block_load_method::block_load_transpose>; + using discontinuity_type = ::rocprim::block_discontinuity; + + ROCPRIM_SHARED_MEMORY struct + { + union + { + typename keys_load_type::storage_type keys_load; + typename discontinuity_type::storage_type discontinuity; + }; + unsigned int unique_counts[warps_no]; + } storage; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int batch_id = ::rocprim::detail::block_id<0>(); + const unsigned int lane_id = ::rocprim::lane_id(); + const unsigned int warp_id = ::rocprim::warp_id<0, 1, 1>(); + + unsigned int block_offset; + unsigned int blocks_per_batch; + if(batch_id < full_batches) + { + blocks_per_batch = blocks_per_full_batch; + block_offset = batch_id * blocks_per_batch; + } + 
else + { + blocks_per_batch = blocks_per_full_batch - 1; + block_offset = batch_id * blocks_per_batch + full_batches; + } + block_offset *= items_per_block; + + unsigned int warp_unique_count = 0; + + for(unsigned int bi = 0; bi < blocks_per_batch; bi++) + { + const bool is_last_block = (block_offset + items_per_block >= size); + + key_type keys[ItemsPerThread]; + unsigned int valid_count; + ::rocprim::syncthreads(); + if(is_last_block) + { + valid_count = size - block_offset; + keys_load_type().load(keys_input + block_offset, keys, valid_count, storage.keys_load); + } + else + { + valid_count = items_per_block; + keys_load_type().load(keys_input + block_offset, keys, storage.keys_load); + } + + bool tail_flags[ItemsPerThread]; + key_type successor_key = keys[ItemsPerThread - 1]; + ::rocprim::syncthreads(); + if(is_last_block) + { + discontinuity_type().flag_tails( + tail_flags, successor_key, keys, + guarded_key_flag_op(key_compare_op, valid_count), + storage.discontinuity + ); + } + else + { + if(flat_id == BlockSize - 1) + { + successor_key = keys_input[block_offset + items_per_block]; + } + discontinuity_type().flag_tails( + tail_flags, successor_key, keys, + key_flag_op(key_compare_op), + storage.discontinuity + ); + } + + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + warp_unique_count += ::rocprim::bit_count(::rocprim::ballot(tail_flags[i])); + } + + block_offset += items_per_block; + } + + if(lane_id == 0) + { + storage.unique_counts[warp_id] = warp_unique_count; + } + ::rocprim::syncthreads(); + + if(flat_id == 0) + { + unsigned int batch_unique_count = 0; + for(unsigned int w = 0; w < warps_no; w++) + { + batch_unique_count += storage.unique_counts[w]; + } + unique_counts[batch_id] = batch_unique_count; + } +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class UniqueCountOutputIterator +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void scan_unique_counts(unsigned int * unique_counts, + UniqueCountOutputIterator 
unique_count_output, + unsigned int batches) +{ + using load_type = ::rocprim::block_load< + unsigned int, BlockSize, ItemsPerThread, + ::rocprim::block_load_method::block_load_transpose>; + using store_type = ::rocprim::block_store< + unsigned int, BlockSize, ItemsPerThread, + ::rocprim::block_store_method::block_store_transpose>; + using scan_type = typename ::rocprim::block_scan; + + ROCPRIM_SHARED_MEMORY union + { + typename load_type::storage_type load; + typename store_type::storage_type store; + typename scan_type::storage_type scan; + } storage; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + + unsigned int values[ItemsPerThread]; + load_type().load(unique_counts, values, batches, 0, storage.load); + + unsigned int unique_count; + ::rocprim::syncthreads(); + scan_type().exclusive_scan(values, values, 0, unique_count); + + ::rocprim::syncthreads(); + store_type().store(unique_counts, values, batches, storage.store); + + if(flat_id == 0) + { + *unique_count_output = unique_count; + } +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class KeysInputIterator, + class ValuesInputIterator, + class Result, + class UniqueOutputIterator, + class AggregatesOutputIterator, + class KeyCompareFunction, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void reduce_by_key(KeysInputIterator keys_input, + ValuesInputIterator values_input, + unsigned int size, + const unsigned int * unique_starts, + carry_out * carry_outs, + Result * leading_aggregates, + UniqueOutputIterator unique_output, + AggregatesOutputIterator aggregates_output, + KeyCompareFunction key_compare_op, + BinaryFunction reduce_op, + unsigned int blocks_per_full_batch, + unsigned int full_batches) +{ + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + + using key_type = typename std::iterator_traits::value_type; + using result_type = Result; + + using keys_load_type = ::rocprim::block_load< + key_type, BlockSize, 
ItemsPerThread, + ::rocprim::block_load_method::block_load_transpose>; + using values_load_type = ::rocprim::block_load< + result_type, BlockSize, ItemsPerThread, + ::rocprim::block_load_method::block_load_transpose>; + using discontinuity_type = ::rocprim::block_discontinuity; + using scan_type = ::rocprim::block_scan, BlockSize>; + + ROCPRIM_SHARED_MEMORY struct + { + union + { + typename keys_load_type::storage_type keys_load; + typename values_load_type::storage_type values_load; + typename discontinuity_type::storage_type discontinuity; + typename scan_type::storage_type scan; + }; + unsigned int unique_count; + bool has_carry_in; + detail::raw_storage carry_in; + } storage; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int batch_id = ::rocprim::detail::block_id<0>(); + + unsigned int block_offset; + unsigned int blocks_per_batch; + if(batch_id < full_batches) + { + blocks_per_batch = blocks_per_full_batch; + block_offset = batch_id * blocks_per_batch; + } + else + { + blocks_per_batch = blocks_per_full_batch - 1; + block_offset = batch_id * blocks_per_batch + full_batches; + } + block_offset *= items_per_block; + + const unsigned int batch_start = unique_starts[batch_id]; + unsigned int block_start = batch_start; + + if(flat_id == 0) + { + storage.has_carry_in = + (block_offset > 0) && + key_compare_op(keys_input[block_offset - 1], keys_input[block_offset]); + } + + for(unsigned int bi = 0; bi < blocks_per_batch; bi++) + { + const bool is_last_block = (block_offset + items_per_block >= size); + + key_type keys[ItemsPerThread]; + result_type values[ItemsPerThread]; + unsigned int valid_count; + if(is_last_block) + { + valid_count = size - block_offset; + keys_load_type().load(keys_input + block_offset, keys, valid_count, storage.keys_load); + ::rocprim::syncthreads(); + values_load_type().load(values_input + block_offset, values, valid_count, storage.values_load); + } + else + { + valid_count = items_per_block; + 
keys_load_type().load(keys_input + block_offset, keys, storage.keys_load); + ::rocprim::syncthreads(); + values_load_type().load(values_input + block_offset, values, storage.values_load); + } + + if(bi > 0 && flat_id == 0 && storage.has_carry_in) + { + // Apply carry-out of the previous block as carry-in for the first segment + values[0] = reduce_op(storage.carry_in.get(), values[0]); + } + + bool head_flags[ItemsPerThread]; + bool tail_flags[ItemsPerThread]; + key_type successor_key = keys[ItemsPerThread - 1]; + ::rocprim::syncthreads(); + if(is_last_block) + { + discontinuity_type().flag_heads_and_tails( + head_flags, tail_flags, successor_key, keys, + guarded_key_flag_op(key_compare_op, valid_count), + storage.discontinuity + ); + } + else + { + if(flat_id == BlockSize - 1) + { + successor_key = keys_input[block_offset + items_per_block]; + } + discontinuity_type().flag_heads_and_tails( + head_flags, tail_flags, successor_key, keys, + key_flag_op(key_compare_op), + storage.discontinuity + ); + } + + // Build pairs and run non-commutative inclusive scan to calculate scan-by-key + // and indices (ranks) of each segment: + // input: + // keys | 1 1 1 2 3 3 4 4 | + // head_flags | + + + + | + // values | 2 0 1 4 2 3 1 5 | + // result: + // scan values | 2 2 3 4 2 5 1 6 | + // scan keys | 1 1 1 2 3 3 4 4 | + // ranks (key-1) | 0 0 0 1 2 2 3 3 | + scan_by_key_pair pairs[ItemsPerThread]; + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + pairs[i].key = head_flags[i]; + pairs[i].value = values[i]; + } + scan_by_key_op, BinaryFunction> scan_op(reduce_op); + ::rocprim::syncthreads(); + scan_type().inclusive_scan(pairs, pairs, storage.scan, scan_op); + + unsigned int ranks[ItemsPerThread]; + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + ranks[i] = pairs[i].key - 1; // The first item is always flagged as head, so indices start from 1 + values[i] = pairs[i].value; + } + + if(flat_id == BlockSize - 1) + { + storage.unique_count = ranks[ItemsPerThread - 1] + 
(tail_flags[ItemsPerThread - 1] ? 1 : 0); + } + + ::rocprim::syncthreads(); + const unsigned int unique_count = storage.unique_count; + if(flat_id == 0) + { + // The first item must be written only if it is the first item of the current segment + // (otherwise it is written by one of previous blocks) + head_flags[0] = !storage.has_carry_in; + } + if(is_last_block) + { + // Unflag the head after the last segment as it will be written out of bounds + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + if(ranks[i] >= unique_count) + { + head_flags[i] = false; + } + } + } + + ::rocprim::syncthreads(); + if(flat_id == BlockSize - 1) + { + if(bi == blocks_per_batch - 1) + { + // Save carry-out of the last block of the current batch + carry_outs[batch_id].value = values[ItemsPerThread - 1]; + carry_outs[batch_id].destination = block_start + ranks[ItemsPerThread - 1]; + carry_outs[batch_id].next_has_carry_in = !tail_flags[ItemsPerThread - 1]; + } + else + { + // Save carry-out to use it as carry-in for the next block of the current batch + storage.has_carry_in = !tail_flags[ItemsPerThread - 1]; + storage.carry_in.get() = values[ItemsPerThread - 1]; + } + } + if(batch_id > 0 && block_start == batch_start) + { + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + // Write the scanned value of the last item of the first segment of the current batch + // (the leading possible incomplete aggregate) to calculate the final aggregate in the next kernel. + // The intermediate array is used instead of aggregates_output because + // aggregates_output may be write-only. 
+ if(tail_flags[i] && ranks[i] == 0) + { + leading_aggregates[batch_id - 1] = values[i]; + } + } + } + + // Save unique keys and aggregates (some aggregates contains partial values + // and will be updated later by calculating scan-by-key of carry-outs) + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + if(head_flags[i]) + { + // Write the key of the first item of the segment as a unique key + unique_output[block_start + ranks[i]] = keys[i]; + } + if(tail_flags[i]) + { + // Write the scanned value of the last item of the segment as an aggregate (reduction of the segment) + aggregates_output[block_start + ranks[i]] = values[i]; + } + } + + block_offset += items_per_block; + block_start += unique_count; + } +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class Result, + class AggregatesOutputIterator, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void scan_and_scatter_carry_outs(const carry_out * carry_outs, + const Result * leading_aggregates, + AggregatesOutputIterator aggregates_output, + BinaryFunction reduce_op, + unsigned int batches) +{ + using result_type = Result; + + using discontinuity_type = ::rocprim::block_discontinuity; + using scan_type = ::rocprim::block_scan, BlockSize>; + + ROCPRIM_SHARED_MEMORY struct + { + typename discontinuity_type::storage_type discontinuity; + typename scan_type::storage_type scan; + } storage; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + + carry_out cs[ItemsPerThread]; + block_load_direct_blocked(flat_id, carry_outs, cs, batches - 1); + + unsigned int destinations[ItemsPerThread]; + result_type values[ItemsPerThread]; + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + destinations[i] = cs[i].destination; + values[i] = cs[i].value; + } + + bool head_flags[ItemsPerThread]; + bool tail_flags[ItemsPerThread]; + ::rocprim::equal_to compare_op; + // If a carry-out of the current batch has the same destination as previous batches, + // then we 
need to scan its value with values of those previous batches. + discontinuity_type().flag_heads_and_tails( + head_flags, tail_flags, + destinations[ItemsPerThread - 1], // Do not always flag the last item in the block + destinations, + guarded_key_flag_op(compare_op, batches - 1), + storage.discontinuity + ); + + scan_by_key_pair pairs[ItemsPerThread]; + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + pairs[i].key = head_flags[i]; + pairs[i].value = values[i]; + } + + scan_by_key_op, BinaryFunction> scan_op(reduce_op); + scan_type().inclusive_scan(pairs, pairs, storage.scan, scan_op); + + // Scatter the last carry-out of each segment as carry-ins + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + if(tail_flags[i]) + { + const unsigned int dst = destinations[i]; + const result_type aggregate = pairs[i].value; + if(cs[i].next_has_carry_in) + { + // The next batch continues the last segment from the current batch, + // combine two partial aggregates + aggregates_output[dst] = reduce_op(aggregate, leading_aggregates[flat_id * ItemsPerThread + i]); + } + else + { + // Overwrite the aggregate because the next batch starts with a different key + aggregates_output[dst] = aggregate; + } + } + } +} + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_BY_KEY_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_scan_by_key.hpp b/3rdparty/cub/rocprim/device/detail/device_scan_by_key.hpp new file mode 100644 index 0000000000000000000000000000000000000000..70ca9cb99ee2bb6720a99103ff2ef557a1d54ff7 --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_scan_by_key.hpp @@ -0,0 +1,388 @@ +// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_BY_KEY_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_BY_KEY_HPP_ + +#include "device_scan_common.hpp" +#include "lookback_scan_state.hpp" +#include "ordered_block_id.hpp" + +#include "../../block/block_discontinuity.hpp" +#include "../../block/block_load.hpp" +#include "../../block/block_scan.hpp" +#include "../../block/block_store.hpp" +#include "../../config.hpp" +#include "../../detail/binary_op_wrappers.hpp" +#include "../../intrinsics/thread.hpp" +#include "../../types/tuple.hpp" + +#include + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + template + struct load_values_flagged + { + using block_load_keys + = ::rocprim::block_load; + + using block_discontinuity = ::rocprim::block_discontinuity; + + using block_load_values + = ::rocprim::block_load; + + union storage_type { + struct { + typename block_load_keys::storage_type load; + typename block_discontinuity::storage_type flag; + } keys; + typename block_load_values::storage_type load_values; + }; + + // Load flagged values + // - if the scan is exlusive the last item of each segment (range where the keys compare equal) + // is flagged and reset to the initial value. Adding the last item of the range to the + // second to last using `headflag_scan_op_wrapper` will return the initial_value, + // which is exactly what should be saved at the start of the next range. 
+ // - if the scan is inclusive, then the first item of each segment is marked, and it will + // restart the scan from that value + template + ROCPRIM_DEVICE void + load(KeyIterator keys_input, + ValueIterator values_input, + CompareFunction compare, + const result_type initial_value, + const unsigned int flat_block_id, + const size_t starting_block, + const size_t number_of_blocks, + const unsigned int flat_thread_id, + const size_t size, + rocprim::tuple (&wrapped_values)[items_per_thread], + storage_type& storage) + { + constexpr static unsigned int items_per_block = items_per_thread * block_size; + const unsigned int block_offset = flat_block_id * items_per_block; + KeyIterator block_keys = keys_input + block_offset; + ValueIterator block_values = values_input + block_offset; + + key_type keys[items_per_thread]; + result_type values[items_per_thread]; + bool flags[items_per_thread]; + + auto not_equal + = [compare](const auto& a, const auto& b) mutable { return !compare(a, b); }; + + const auto flag_segment_boundaries = [&]() { + if(Exclusive) + { + const key_type tile_successor + = starting_block + flat_block_id < number_of_blocks - 1 + ? block_keys[items_per_block] + : *block_keys; + block_discontinuity {}.flag_tails( + flags, tile_successor, keys, not_equal, storage.keys.flag); + } + else + { + const key_type tile_predecessor = starting_block + flat_block_id > 0 + ? 
block_keys[-1] + : *block_keys; + block_discontinuity {}.flag_heads( + flags, tile_predecessor, keys, not_equal, storage.keys.flag); + } + }; + + if(starting_block + flat_block_id < number_of_blocks - 1) + { + block_load_keys{}.load( + block_keys, + keys, + storage.keys.load + ); + + flag_segment_boundaries(); + // Reusing shared memory for loading values + ::rocprim::syncthreads(); + + block_load_values{}.load( + block_values, + values, + storage.load_values + ); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < items_per_thread; ++i) { + rocprim::get<0>(wrapped_values[i]) + = (Exclusive && flags[i]) ? initial_value : values[i]; + rocprim::get<1>(wrapped_values[i]) = flags[i]; + } + } + else + { + const unsigned int valid_in_last_block + = static_cast(size - items_per_block * (number_of_blocks - 1)); + + block_load_keys {}.load( + block_keys, + keys, + valid_in_last_block, + *block_keys, // Any value is okay, so discontinuity doesn't access undefined items + storage.keys.load); + + flag_segment_boundaries(); + // Reusing shared memory for loading values + ::rocprim::syncthreads(); + + block_load_values{}.load( + block_values, + values, + valid_in_last_block, + storage.load_values + ); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < items_per_thread; ++i) { + if(flat_thread_id * items_per_thread + i >= valid_in_last_block) { + break; + } + + rocprim::get<0>(wrapped_values[i]) + = (Exclusive && flags[i]) ? 
initial_value : values[i]; + rocprim::get<1>(wrapped_values[i]) = flags[i]; + } + } + } + }; + + template + struct unwrap_store + { + using block_store_values + = ::rocprim::block_store; + + using storage_type = typename block_store_values::storage_type; + + template + ROCPRIM_DEVICE void + store(OutputIterator output, + const unsigned int flat_block_id, + const size_t starting_block, + const size_t number_of_blocks, + const unsigned int flat_thread_id, + const size_t size, + const rocprim::tuple (&wrapped_values)[items_per_thread], + storage_type& storage) + { + constexpr static unsigned int items_per_block = items_per_thread * block_size; + const unsigned int block_offset = flat_block_id * items_per_block; + OutputIterator block_output = output + block_offset; + + result_type thread_values[items_per_thread]; + + if(starting_block + flat_block_id < number_of_blocks - 1) + { + ROCPRIM_UNROLL + for(unsigned int i = 0; i < items_per_thread; ++i) { + thread_values[i] = rocprim::get<0>(wrapped_values[i]); + } + + // Reusing shared memory from scan to perform store + rocprim::syncthreads(); + + block_store_values {}.store(block_output, thread_values, storage); + } + else + { + const unsigned int valid_in_last_block + = static_cast(size - items_per_block * (number_of_blocks - 1)); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < items_per_thread; ++i) { + if(flat_thread_id * items_per_thread + i >= valid_in_last_block) { + break; + } + + thread_values[i] = rocprim::get<0>(wrapped_values[i]); + } + + // Reusing shared memory from scan to perform store + rocprim::syncthreads(); + + block_store_values {}.store( + block_output, thread_values, valid_in_last_block, storage); + } + } + }; + + template + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void device_scan_by_key_kernel_impl( + KeyInputIterator keys, + InputIterator values, + OutputIterator output, + ResultType initial_value, + const CompareFunction compare, + const BinaryFunction scan_op, + LookbackScanState scan_state, + const 
size_t size, + const size_t starting_block, + const size_t number_of_blocks, + ordered_block_id ordered_bid, + const rocprim::tuple* const previous_last_value) + { + using result_type = ResultType; + static_assert(std::is_same, + typename LookbackScanState::value_type>::value, + "value_type of LookbackScanState must be tuple of result type and flag"); + + constexpr auto block_size = Config::block_size; + constexpr auto items_per_thread = Config::items_per_thread; + constexpr auto load_keys_method = Config::block_load_method; + constexpr auto load_values_method = load_keys_method; + + using key_type = typename std::iterator_traits::value_type; + using load_flagged = load_values_flagged; + + auto wrapped_op = headflag_scan_op_wrapper{scan_op}; + using wrapped_type = rocprim::tuple; + + using block_scan_type + = ::rocprim::block_scan; + + constexpr auto store_method = Config::block_store_method; + using store_unwrap = unwrap_store; + + using order_bid_type = ordered_block_id; + + ROCPRIM_SHARED_MEMORY union + { + struct + { + typename load_flagged::storage_type load; + typename order_bid_type::storage_type ordered_bid; + }; + typename block_scan_type::storage_type scan; + typename store_unwrap::storage_type store; + } storage; + + const auto flat_thread_id = ::rocprim::detail::block_thread_id<0>(); + const auto flat_block_id = ordered_bid.get(flat_thread_id, storage.ordered_bid); + + // Load input + wrapped_type wrapped_values[items_per_thread]; + load_flagged {}.load(keys, + values, + compare, + initial_value, + flat_block_id, + starting_block, + number_of_blocks, + flat_thread_id, + size, + wrapped_values, + storage.load); + + // Reusing the storage from load to perform the scan + ::rocprim::syncthreads(); + + // Perform look back scan scan + if(flat_block_id == 0) + { + auto wrapped_initial_value = rocprim::make_tuple(initial_value, false); + + // previous_last_value is used to pass the value from the previous grid, if this is a + // multi grid launch + 
if(previous_last_value != nullptr) + { + if(Exclusive) { + rocprim::get<0>(wrapped_initial_value) = rocprim::get<0>(*previous_last_value); + } else if (flat_thread_id == 0) { + wrapped_values[0] = wrapped_op(*previous_last_value, wrapped_values[0]); + } + } + + wrapped_type reduction; + lookback_block_scan(wrapped_values, + wrapped_initial_value, + reduction, + storage.scan, + wrapped_op); + + if(flat_thread_id == 0) + { + scan_state.set_complete(flat_block_id, reduction); + } + } + else + { + auto prefix_op = lookback_scan_prefix_op { + flat_block_id, wrapped_op, scan_state}; + + // Scan of block values + lookback_block_scan( + wrapped_values, + storage.scan, + prefix_op, + wrapped_op); + } + + // Store output + // synchronization is inside the function after unwrapping + store_unwrap {}.store(output, + flat_block_id, + starting_block, + number_of_blocks, + flat_thread_id, + size, + wrapped_values, + storage.store); + } +} // namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_BY_KEY_HPP_ \ No newline at end of file diff --git a/3rdparty/cub/rocprim/device/detail/device_scan_common.hpp b/3rdparty/cub/rocprim/device/detail/device_scan_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2a453002e7b530c9c5d7841a12df3b86c16f950b --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_scan_common.hpp @@ -0,0 +1,153 @@ +// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_SCAN_COMMON_HPP_ +#define ROCPRIM_DEVICE_SCAN_COMMON_HPP_ + +#include "../../config.hpp" +#include "../../intrinsics/thread.hpp" + +#include "lookback_scan_state.hpp" +#include "ordered_block_id.hpp" + +#include + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + template + ROCPRIM_KERNEL + __launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) void init_lookback_scan_state_kernel( + LookBackScanState lookback_scan_state, + const unsigned int number_of_blocks, + ordered_block_id ordered_bid, + unsigned int save_index = 0, + typename LookBackScanState::value_type* const save_dest = nullptr) + { + const unsigned int block_id = ::rocprim::detail::block_id<0>(); + const unsigned int block_size = ::rocprim::detail::block_size<0>(); + const unsigned int block_thread_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int id = (block_id * block_size) + block_thread_id; + + // Reset ordered_block_id + if(id == 0) + { + ordered_bid.reset(); + } + // Save the reduction (i.e. the last prefix) from the previous user of lookback_scan_state + // If the thread that should reset it is participating then it saves the value before + // reseting, otherwise the first thread saves it (it won't be reset by any thread). 
+ if(save_dest != nullptr + && ((number_of_blocks <= save_index && id == 0) || id == save_index)) + { + typename LookBackScanState::value_type value; + typename LookBackScanState::flag_type dummy_flag; + lookback_scan_state.get(save_index, dummy_flag, value); + + *save_dest = value; + } + // Initialize lookback scan status + lookback_scan_state.initialize_prefix(id, number_of_blocks); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE auto + lookback_block_scan(T (&values)[ItemsPerThread], + T /* initial_value */, + T& reduction, + typename BlockScan::storage_type& storage, + BinaryFunction scan_op) -> typename std::enable_if::type + { + BlockScan().inclusive_scan(values, // input + values, // output + reduction, + storage, + scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE auto + lookback_block_scan(T (&values)[ItemsPerThread], + T initial_value, + T& reduction, + typename BlockScan::storage_type& storage, + BinaryFunction scan_op) -> typename std::enable_if::type + { + BlockScan().exclusive_scan(values, // input + values, // output + initial_value, + reduction, + storage, + scan_op); + reduction = scan_op(initial_value, reduction); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE auto + lookback_block_scan(T (&values)[ItemsPerThread], + typename BlockScan::storage_type& storage, + PrefixCallback& prefix_callback_op, + BinaryFunction scan_op) -> typename std::enable_if::type + { + BlockScan().inclusive_scan(values, // input + values, // output + storage, + prefix_callback_op, + scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE auto + lookback_block_scan(T (&values)[ItemsPerThread], + typename BlockScan::storage_type& storage, + PrefixCallback& prefix_callback_op, + BinaryFunction scan_op) -> typename std::enable_if::type + { + BlockScan().exclusive_scan(values, // input + values, // output + storage, + prefix_callback_op, + scan_op); + } + +} // namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_SCAN_COMMON_HPP_ \ No newline at 
end of file diff --git a/3rdparty/cub/rocprim/device/detail/device_scan_lookback.hpp b/3rdparty/cub/rocprim/device/detail/device_scan_lookback.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b3b068a325d38a486c246eed6552ad5ca784156b --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_scan_lookback.hpp @@ -0,0 +1,222 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_LOOKBACK_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_LOOKBACK_HPP_ + +#include +#include + +#include "../../detail/various.hpp" +#include "../../intrinsics.hpp" +#include "../../functional.hpp" +#include "../../types.hpp" + +#include "../../block/block_load.hpp" +#include "../../block/block_store.hpp" +#include "../../block/block_scan.hpp" + +#include "device_scan_common.hpp" +#include "lookback_scan_state.hpp" +#include "ordered_block_id.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +// Single pass prefix scan was implemented based on: +// Merrill, D. and Garland, M. Single-pass Parallel Prefix Scan with Decoupled Look-back. +// Technical Report NVR2016-001, NVIDIA Research. Mar. 2016. + +namespace detail +{ + +template< + bool Exclusive, + class Config, + class InputIterator, + class OutputIterator, + class BinaryFunction, + class ResultType, + class LookbackScanState +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void lookback_scan_kernel_impl(InputIterator input, + OutputIterator output, + const size_t size, + ResultType initial_value, + BinaryFunction scan_op, + LookbackScanState scan_state, + const unsigned int number_of_blocks, + ordered_block_id ordered_bid, + ResultType * previous_last_element = nullptr, + ResultType * new_last_element = nullptr, + bool override_first_value = false, + bool save_last_value = false) +{ + using result_type = ResultType; + static_assert( + std::is_same::value, + "value_type of LookbackScanState must be result_type" + ); + + constexpr auto block_size = Config::block_size; + constexpr auto items_per_thread = Config::items_per_thread; + constexpr unsigned int items_per_block = block_size * items_per_thread; + + using block_load_type = ::rocprim::block_load< + result_type, block_size, items_per_thread, + Config::block_load_method + >; + using block_store_type = ::rocprim::block_store< + result_type, block_size, items_per_thread, + Config::block_store_method + >; + using block_scan_type = 
::rocprim::block_scan< + result_type, block_size, + Config::block_scan_method + >; + + using order_bid_type = ordered_block_id; + using lookback_scan_prefix_op_type = lookback_scan_prefix_op< + result_type, BinaryFunction, LookbackScanState + >; + + ROCPRIM_SHARED_MEMORY struct + { + typename order_bid_type::storage_type ordered_bid; + union + { + typename block_load_type::storage_type load; + typename block_store_type::storage_type store; + typename block_scan_type::storage_type scan; + }; + } storage; + + const auto flat_block_thread_id = ::rocprim::detail::block_thread_id<0>(); + const auto flat_block_id = ordered_bid.get(flat_block_thread_id, storage.ordered_bid); + const unsigned int block_offset = flat_block_id * items_per_block; + const auto valid_in_last_block = size - items_per_block * (number_of_blocks - 1); + + // For input values + result_type values[items_per_thread]; + + // load input values into values + if(flat_block_id == (number_of_blocks - 1)) // last block + { + block_load_type() + .load( + input + block_offset, + values, + valid_in_last_block, + *(input + block_offset), + storage.load + ); + } + else + { + block_load_type() + .load( + input + block_offset, + values, + storage.load + ); + } + ::rocprim::syncthreads(); // sync threads to reuse shared memory + + if(flat_block_id == 0) + { + // override_first_value only true when the first chunk already processed + // and input iterator starts from an offset. 
+ if(override_first_value) + { + if(Exclusive) + initial_value = scan_op(previous_last_element[0], static_cast(*(input-1))); + else if(flat_block_thread_id == 0) + values[0] = scan_op(previous_last_element[0], values[0]); + } + + result_type reduction; + lookback_block_scan( + values, // input/output + initial_value, + reduction, + storage.scan, + scan_op + ); + + if(flat_block_thread_id == 0) + { + scan_state.set_complete(flat_block_id, reduction); + } + } + else + { + // Scan of block values + auto prefix_op = lookback_scan_prefix_op_type( + flat_block_id, scan_op, scan_state + ); + lookback_block_scan( + values, // input/output + storage.scan, + prefix_op, + scan_op + ); + } + ::rocprim::syncthreads(); // sync threads to reuse shared memory + + // Save values into output array + if(flat_block_id == (number_of_blocks - 1)) // last block + { + block_store_type() + .store( + output + block_offset, + values, + valid_in_last_block, + storage.store + ); + + if(save_last_value && + (::rocprim::detail::block_thread_id<0>() == + (valid_in_last_block - 1) / items_per_thread)) + { + for(unsigned int i = 0; i < items_per_thread; i++) + { + if(i == (valid_in_last_block - 1) % items_per_thread) + { + new_last_element[0] = values[i]; + } + } + } + } + else + { + block_store_type() + .store( + output + block_offset, + values, + storage.store + ); + } +} + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_LOOKBACK_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_scan_reduce_then_scan.hpp b/3rdparty/cub/rocprim/device/detail/device_scan_reduce_then_scan.hpp new file mode 100644 index 0000000000000000000000000000000000000000..349a0b2428431afd3410e6793ff515f8f674a9b8 --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_scan_reduce_then_scan.hpp @@ -0,0 +1,469 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_REDUCE_THEN_SCAN_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_REDUCE_THEN_SCAN_HPP_ + +#include +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" +#include "../../types.hpp" + +#include "../../block/block_load.hpp" +#include "../../block/block_store.hpp" +#include "../../block/block_scan.hpp" +#include "../../block/block_reduce.hpp" + + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// Helper functions for performing exclusive or inclusive +// block scan in single_scan. 
+template< + bool Exclusive, + class BlockScan, + class T, + unsigned int ItemsPerThread, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto single_scan_block_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T initial_value, + typename BlockScan::storage_type& storage, + BinaryFunction scan_op) + -> typename std::enable_if::type +{ + BlockScan() + .exclusive_scan( + input, // input + output, // output + initial_value, + storage, + scan_op + ); +} + +template< + bool Exclusive, + class BlockScan, + class T, + unsigned int ItemsPerThread, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto single_scan_block_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T initial_value, + typename BlockScan::storage_type& storage, + BinaryFunction scan_op) + -> typename std::enable_if::type +{ + (void) initial_value; + BlockScan() + .inclusive_scan( + input, // input + output, // output + storage, + scan_op + ); +} + +template< + bool Exclusive, + class Config, + class InputIterator, + class OutputIterator, + class BinaryFunction, + class ResultType +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void single_scan_kernel_impl(InputIterator input, + const size_t input_size, + ResultType initial_value, + OutputIterator output, + BinaryFunction scan_op) +{ + constexpr unsigned int block_size = Config::block_size; + constexpr unsigned int items_per_thread = Config::items_per_thread; + + using result_type = ResultType; + + using block_load_type = ::rocprim::block_load< + result_type, block_size, items_per_thread, + Config::block_load_method + >; + using block_store_type = ::rocprim::block_store< + result_type, block_size, items_per_thread, + Config::block_store_method + >; + using block_scan_type = ::rocprim::block_scan< + result_type, block_size, + Config::block_scan_method + >; + + ROCPRIM_SHARED_MEMORY union + { + typename block_load_type::storage_type load; + typename block_store_type::storage_type store; + typename 
block_scan_type::storage_type scan; + } storage; + + result_type values[items_per_thread]; + // load input values into values + block_load_type() + .load( + input, + values, + input_size, + *(input), + storage.load + ); + ::rocprim::syncthreads(); // sync threads to reuse shared memory + + single_scan_block_scan( + values, // input + values, // output + initial_value, + storage.scan, + scan_op + ); + ::rocprim::syncthreads(); // sync threads to reuse shared memory + + // Save values into output array + block_store_type() + .store( + output, + values, + input_size, + storage.store + ); +} + +// Calculates block prefixes that will be used in final_scan +// when performing block scan operations. +template< + class Config, + class InputIterator, + class BinaryFunction, + class ResultType +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void block_reduce_kernel_impl(InputIterator input, + BinaryFunction scan_op, + ResultType * block_prefixes) +{ + constexpr unsigned int block_size = Config::block_size; + constexpr unsigned int items_per_thread = Config::items_per_thread; + + using result_type = ResultType; + using block_reduce_type = ::rocprim::block_reduce< + result_type, block_size, + ::rocprim::block_reduce_algorithm::using_warp_reduce + >; + using block_load_type = ::rocprim::block_load< + result_type, block_size, items_per_thread, + Config::block_load_method + >; + + ROCPRIM_SHARED_MEMORY union + { + typename block_load_type::storage_type load; + typename block_reduce_type::storage_type reduce; + } storage; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + const unsigned int block_offset = flat_block_id * items_per_thread * block_size; + + // For input values + result_type values[items_per_thread]; + result_type block_prefix; + + block_load_type() + .load( + input + block_offset, + values, + storage.load + ); + ::rocprim::syncthreads(); // sync threads to reuse shared memory + 
+ block_reduce_type() + .reduce( + values, // input + block_prefix, // output + storage.reduce, + scan_op + ); + + // Save block prefix + if(flat_id == 0) + { + block_prefixes[flat_block_id] = block_prefix; + } +} + +// Helper functions for performing exclusive or inclusive +// block scan operation in final_scan +template< + bool Exclusive, + class BlockScan, + class T, + unsigned int ItemsPerThread, + class ResultType, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto final_scan_block_scan(const unsigned int flat_block_id, + T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T initial_value, + ResultType * block_prefixes, + typename BlockScan::storage_type& storage, + BinaryFunction scan_op) + -> typename std::enable_if::type +{ + if(flat_block_id != 0) + { + // Include initial value in block prefix + initial_value = scan_op( + initial_value, block_prefixes[flat_block_id - 1] + ); + } + BlockScan() + .exclusive_scan( + input, // input + output, // output + initial_value, + storage, + scan_op + ); +} + +template< + bool Exclusive, + class BlockScan, + class T, + unsigned int ItemsPerThread, + class ResultType, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto final_scan_block_scan(const unsigned int flat_block_id, + T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T initial_value, + ResultType * block_prefixes, + typename BlockScan::storage_type& storage, + BinaryFunction scan_op) + -> typename std::enable_if::type +{ + (void) initial_value; + if(flat_block_id == 0) + { + BlockScan() + .inclusive_scan( + input, // input + output, // output + storage, + scan_op + ); + } + else + { + auto block_prefix_op = + [&block_prefixes, &flat_block_id](const T& /*not used*/) + { + return block_prefixes[flat_block_id - 1]; + }; + BlockScan() + .inclusive_scan( + input, // input + output, // output + storage, + block_prefix_op, + scan_op + ); + } +} + +template< + bool Exclusive, + class Config, + class InputIterator, + class 
OutputIterator, + class BinaryFunction, + class ResultType +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void final_scan_kernel_impl(InputIterator input, + const size_t input_size, + OutputIterator output, + ResultType initial_value, + BinaryFunction scan_op, + ResultType * block_prefixes, + ResultType * previous_last_element = nullptr, + ResultType * new_last_element = nullptr, + bool override_first_value = false, + bool save_last_value = false) +{ + constexpr unsigned int block_size = Config::block_size; + constexpr unsigned int items_per_thread = Config::items_per_thread; + + using result_type = ResultType; + + using block_load_type = ::rocprim::block_load< + result_type, block_size, items_per_thread, + Config::block_load_method + >; + using block_store_type = ::rocprim::block_store< + result_type, block_size, items_per_thread, + Config::block_store_method + >; + using block_scan_type = ::rocprim::block_scan< + result_type, block_size, + Config::block_scan_method + >; + + ROCPRIM_SHARED_MEMORY union + { + typename block_load_type::storage_type load; + typename block_store_type::storage_type store; + typename block_scan_type::storage_type scan; + } storage; + + // It's assumed kernel is executed in 1D + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + + constexpr unsigned int items_per_block = block_size * items_per_thread; + const unsigned int block_offset = flat_block_id * items_per_block; + // TODO: number_of_blocks can be calculated on host + const unsigned int number_of_blocks = (input_size + items_per_block - 1)/items_per_block; + + // For input values + result_type values[items_per_thread]; + + // TODO: valid_in_last_block can be calculated on host + auto valid_in_last_block = input_size - items_per_block * (number_of_blocks - 1); + // load input values into values + if(flat_block_id == (number_of_blocks - 1)) // last block + { + block_load_type() + .load( + input + block_offset, + values, + valid_in_last_block, + *(input + block_offset), 
+ storage.load + ); + } + else + { + block_load_type() + .load( + input + block_offset, + values, + storage.load + ); + } + ::rocprim::syncthreads(); // sync threads to reuse shared memory + + // override_first_value only true when the first chunk already processed + // and input iterator starts from an offset. + if(override_first_value && flat_block_id == 0) + { + if(Exclusive) + initial_value = scan_op(previous_last_element[0], *(input-1)); + else if(::rocprim::detail::block_thread_id<0>() == 0) + values[0] = scan_op(previous_last_element[0], values[0]); + } + + final_scan_block_scan( + flat_block_id, + values, // input + values, // output + initial_value, + block_prefixes, + storage.scan, + scan_op + ); + ::rocprim::syncthreads(); // sync threads to reuse shared memory + + // Save values into output array + if(flat_block_id == (number_of_blocks - 1)) // last block + { + block_store_type() + .store( + output + block_offset, + values, + valid_in_last_block, + storage.store + ); + + if(save_last_value && + (::rocprim::detail::block_thread_id<0>() == + (valid_in_last_block - 1) / items_per_thread)) + { + for(unsigned int i = 0; i < items_per_thread; i++) + { + if(i == (valid_in_last_block - 1) % items_per_thread) + { + new_last_element[0] = values[i]; + } + } + } + } + else + { + block_store_type() + .store( + output + block_offset, + values, + storage.store + ); + } +} + +// Returns size of temporary storage in bytes. 
+template +size_t scan_get_temporary_storage_bytes(size_t input_size, + size_t items_per_block) +{ + if(input_size <= items_per_block) + { + return 0; + } + auto size = (input_size + items_per_block - 1)/(items_per_block); + return size * sizeof(T) + scan_get_temporary_storage_bytes(size, items_per_block); +} + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_REDUCE_THEN_SCAN_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_segmented_radix_sort.hpp b/3rdparty/cub/rocprim/device/detail/device_segmented_radix_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b1a8cf33b503b6943626b72fcf9c385ff7f70e45 --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_segmented_radix_sort.hpp @@ -0,0 +1,990 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_RADIX_SORT_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_RADIX_SORT_HPP_ + +#include +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" +#include "../../types.hpp" + +#include "../../block/block_load.hpp" +#include "../../block/block_store.hpp" +#include "../../block/block_scan.hpp" + +#include "../../warp/warp_load.hpp" +#include "../../warp/warp_sort.hpp" +#include "../../warp/warp_store.hpp" + +#include "../device_segmented_radix_sort_config.hpp" +#include "device_radix_sort.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class Key, + class Value, + unsigned int WarpSize, + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int RadixBits, + bool Descending +> +class segmented_radix_sort_helper +{ + static constexpr unsigned int radix_size = 1 << RadixBits; + + using key_type = Key; + using value_type = Value; + + using count_helper_type = radix_digit_count_helper; + using scan_type = typename ::rocprim::block_scan; + using sort_and_scatter_helper = radix_sort_and_scatter_helper< + BlockSize, ItemsPerThread, RadixBits, Descending, + key_type, value_type, unsigned int>; + +public: + + union storage_type + { + typename segmented_radix_sort_helper::count_helper_type::storage_type count_helper; + typename segmented_radix_sort_helper::sort_and_scatter_helper::storage_type sort_and_scatter_helper; + }; + + template< + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void sort(KeysInputIterator keys_input, + key_type * keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + value_type * values_tmp, + ValuesOutputIterator values_output, + bool to_output, + unsigned int begin_offset, + unsigned int end_offset, + unsigned int bit, + unsigned int begin_bit, + 
unsigned int end_bit, + storage_type& storage) + { + // Handle cases when (end_bit - bit) is not divisible by radix_bits, i.e. the last + // iteration has a shorter mask. + const unsigned int current_radix_bits = ::rocprim::min(RadixBits, end_bit - bit); + + const bool is_first_iteration = (bit == begin_bit); + + if(is_first_iteration) + { + if(to_output) + { + sort( + keys_input, keys_output, values_input, values_output, + begin_offset, end_offset, + bit, current_radix_bits, + storage + ); + } + else + { + sort( + keys_input, keys_tmp, values_input, values_tmp, + begin_offset, end_offset, + bit, current_radix_bits, + storage + ); + } + } + else + { + if(to_output) + { + sort( + keys_tmp, keys_output, values_tmp, values_output, + begin_offset, end_offset, + bit, current_radix_bits, + storage + ); + } + else + { + sort( + keys_output, keys_tmp, values_output, values_tmp, + begin_offset, end_offset, + bit, current_radix_bits, + storage + ); + } + } + } + + // When all iterators are raw pointers, this overload is used to minimize code duplication in the kernel + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void sort(key_type * keys_input, + key_type * keys_tmp, + key_type * keys_output, + value_type * values_input, + value_type * values_tmp, + value_type * values_output, + bool to_output, + unsigned int begin_offset, + unsigned int end_offset, + unsigned int bit, + unsigned int begin_bit, + unsigned int end_bit, + storage_type& storage) + { + // Handle cases when (end_bit - bit) is not divisible by radix_bits, i.e. the last + // iteration has a shorter mask. 
+ const unsigned int current_radix_bits = ::rocprim::min(RadixBits, end_bit - bit); + + const bool is_first_iteration = (bit == begin_bit); + + key_type * current_keys_input; + key_type * current_keys_output; + value_type * current_values_input; + value_type * current_values_output; + if(is_first_iteration) + { + if(to_output) + { + current_keys_input = keys_input; + current_keys_output = keys_output; + current_values_input = values_input; + current_values_output = values_output; + } + else + { + current_keys_input = keys_input; + current_keys_output = keys_tmp; + current_values_input = values_input; + current_values_output = values_tmp; + } + } + else + { + if(to_output) + { + current_keys_input = keys_tmp; + current_keys_output = keys_output; + current_values_input = values_tmp; + current_values_output = values_output; + } + else + { + current_keys_input = keys_output; + current_keys_output = keys_tmp; + current_values_input = values_output; + current_values_output = values_tmp; + } + } + sort( + current_keys_input, current_keys_output, current_values_input, current_values_output, + begin_offset, end_offset, + bit, current_radix_bits, + storage + ); + } + +private: + + template< + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE + void sort(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int begin_offset, + unsigned int end_offset, + unsigned int bit, + unsigned int current_radix_bits, + storage_type& storage) + { + unsigned int digit_count; + count_helper_type().count_digits( + keys_input, + begin_offset, end_offset, + bit, current_radix_bits, + storage.count_helper, + digit_count + ); + + unsigned int digit_start; + scan_type().exclusive_scan(digit_count, digit_start, 0); + digit_start += begin_offset; + + ::rocprim::syncthreads(); + + 
sort_and_scatter_helper().sort_and_scatter( + keys_input, keys_output, values_input, values_output, + begin_offset, end_offset, + bit, current_radix_bits, + digit_start, + storage.sort_and_scatter_helper + ); + + ::rocprim::syncthreads(); + } +}; + +template< + class Key, + class Value, + unsigned int BlockSize, + unsigned int ItemsPerThread, + bool Descending +> +class segmented_radix_sort_single_block_helper +{ + using key_type = Key; + using value_type = Value; + + using key_codec = radix_key_codec; + using bit_key_type = typename key_codec::bit_key_type; + using keys_load_type = ::rocprim::block_load< + key_type, BlockSize, ItemsPerThread, + ::rocprim::block_load_method::block_load_transpose>; + using values_load_type = ::rocprim::block_load< + value_type, BlockSize, ItemsPerThread, + ::rocprim::block_load_method::block_load_transpose>; + using sort_type = ::rocprim::block_radix_sort; + using keys_store_type = ::rocprim::block_store< + key_type, BlockSize, ItemsPerThread, + ::rocprim::block_store_method::block_store_transpose>; + using values_store_type = ::rocprim::block_store< + value_type, BlockSize, ItemsPerThread, + ::rocprim::block_store_method::block_store_transpose>; + + static constexpr bool with_values = !std::is_same::value; + +public: + + union storage_type + { + typename keys_load_type::storage_type keys_load; + typename values_load_type::storage_type values_load; + typename sort_type::storage_type sort; + typename keys_store_type::storage_type keys_store; + typename values_store_type::storage_type values_store; + }; + + template< + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(KeysInputIterator keys_input, + key_type * keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + value_type * values_tmp, + ValuesOutputIterator values_output, + bool to_output, + unsigned int begin_offset, + unsigned int end_offset, + 
unsigned int begin_bit, + unsigned int end_bit, + storage_type& storage) + { + if(to_output) + { + sort( + keys_input, keys_output, values_input, values_output, + begin_offset, end_offset, + begin_bit, end_bit, + storage + ); + } + else + { + sort( + keys_input, keys_tmp, values_input, values_tmp, + begin_offset, end_offset, + begin_bit, end_bit, + storage + ); + } + } + + // When all iterators are raw pointers, this overload is used to minimize code duplication in the kernel + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(key_type * keys_input, + key_type * keys_tmp, + key_type * keys_output, + value_type * values_input, + value_type * values_tmp, + value_type * values_output, + bool to_output, + unsigned int begin_offset, + unsigned int end_offset, + unsigned int begin_bit, + unsigned int end_bit, + storage_type& storage) + { + sort( + keys_input, (to_output ? keys_output : keys_tmp), values_input, (to_output ? values_output : values_tmp), + begin_offset, end_offset, + begin_bit, end_bit, + storage + ); + } + + template< + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + ROCPRIM_DEVICE ROCPRIM_INLINE + bool sort(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int begin_offset, + unsigned int end_offset, + unsigned int begin_bit, + unsigned int end_bit, + storage_type& storage) + { + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + + using shorter_single_block_helper = segmented_radix_sort_single_block_helper< + key_type, value_type, + BlockSize, ItemsPerThread / 2, Descending + >; + + // Segment is longer than supported by this function + if(end_offset - begin_offset > items_per_block) + { + return false; + } + + // Recursively chech if it is possible to sort the segment using fewer items per thread + const bool processed_by_shorter = + shorter_single_block_helper().sort( + 
keys_input, keys_output, values_input, values_output, + begin_offset, end_offset, + begin_bit, end_bit, + reinterpret_cast(storage) + ); + if(processed_by_shorter) + { + return true; + } + + key_type keys[ItemsPerThread]; + value_type values[ItemsPerThread]; + const unsigned int valid_count = end_offset - begin_offset; + // Sort will leave "invalid" (out of size) items at the end of the sorted sequence + const key_type out_of_bounds = key_codec::decode(bit_key_type(-1)); + keys_load_type().load(keys_input + begin_offset, keys, valid_count, out_of_bounds, storage.keys_load); + if(with_values) + { + ::rocprim::syncthreads(); + values_load_type().load(values_input + begin_offset, values, valid_count, storage.values_load); + } + + ::rocprim::syncthreads(); + sort_block(sort_type(), keys, values, storage.sort, begin_bit, end_bit); + + ::rocprim::syncthreads(); + keys_store_type().store(keys_output + begin_offset, keys, valid_count, storage.keys_store); + if(with_values) + { + ::rocprim::syncthreads(); + values_store_type().store(values_output + begin_offset, values, valid_count, storage.values_store); + } + + return true; + } +}; + +template< + class Key, + class Value, + unsigned int BlockSize, + bool Descending +> +class segmented_radix_sort_single_block_helper +{ +public: + + struct storage_type { }; + + template< + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + ROCPRIM_DEVICE ROCPRIM_INLINE + bool sort(KeysInputIterator, + KeysOutputIterator, + ValuesInputIterator, + ValuesOutputIterator, + unsigned int, + unsigned int, + unsigned int, + unsigned int, + storage_type&) + { + // It can't sort anything because ItemsPerThread is 0. + // The segment will be sorted by the calles (i.e. 
using ItemsPerThread = 1) + return false; + } +}; + +template +struct WarpSortHelperConfig +{ + static constexpr unsigned int logical_warp_size = LogicalWarpSize; + static constexpr unsigned int items_per_thread = ItemsPerThread; + static constexpr unsigned int block_size = BlockSize; +}; + +struct DisabledWarpSortHelperConfig +{ + static constexpr unsigned int logical_warp_size = 1; + static constexpr unsigned int items_per_thread = 1; + static constexpr unsigned int block_size = 1; +}; + +template +using select_warp_sort_helper_config_small_t + = std::conditional_t::value, + DisabledWarpSortHelperConfig, + WarpSortHelperConfig>; + +template +using select_warp_sort_helper_config_medium_t + = std::conditional_t::value, + DisabledWarpSortHelperConfig, + WarpSortHelperConfig>; + +template< + class Config, + class Key, + class Value, + bool Descending, + class Enable = void +> +struct segmented_warp_sort_helper +{ + static constexpr unsigned int items_per_warp = 0; + using storage_type = ::rocprim::empty_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(Args&&...) 
+ { + } +}; + +template +class segmented_warp_sort_helper< + Config, + Key, + Value, + Descending, + std::enable_if_t::value>> +{ + static constexpr unsigned int logical_warp_size = Config::logical_warp_size; + static constexpr unsigned int items_per_thread = Config::items_per_thread; + + using key_type = Key; + using value_type = Value; + using key_codec = ::rocprim::detail::radix_key_codec; + using bit_key_type = typename key_codec::bit_key_type; + + using keys_load_type = ::rocprim::warp_load; + using values_load_type = ::rocprim::warp_load; + using keys_store_type = ::rocprim::warp_store; + using values_store_type = ::rocprim::warp_store; + template + using radix_comparator_type = ::rocprim::detail::radix_merge_compare; + using stable_key_type = ::rocprim::tuple; + using sort_type = ::rocprim::warp_sort; + + static constexpr bool with_values = !std::is_same::value; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + decltype(auto) make_stable_comparator(ComparatorT comparator) + { + return [comparator](const stable_key_type& a, const stable_key_type& b) -> bool + { + const bool ab = comparator(rocprim::get<0>(a), rocprim::get<0>(b)); + const bool ba = comparator(rocprim::get<0>(b), rocprim::get<0>(a)); + return ab || (!ba && (rocprim::get<1>(a) < rocprim::get<1>(b))); + }; + } + +public: + static constexpr unsigned int items_per_warp = items_per_thread * logical_warp_size; + + union storage_type + { + typename keys_load_type::storage_type keys_load; + typename values_load_type::storage_type values_load; + typename keys_store_type::storage_type keys_store; + typename values_store_type::storage_type values_store; + typename sort_type::storage_type sort; + }; + + template< + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + 
unsigned int begin_offset, + unsigned int end_offset, + unsigned int begin_bit, + unsigned int end_bit, + storage_type& storage) + { + const unsigned int num_items = end_offset - begin_offset; + const key_type out_of_bounds = key_codec::decode(bit_key_type(-1)); + + key_type keys[items_per_thread]; + stable_key_type stable_keys[items_per_thread]; + value_type values[items_per_thread]; + keys_load_type().load(keys_input + begin_offset, keys, num_items, out_of_bounds, storage.keys_load); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < items_per_thread; i++) + { + ::rocprim::get<0>(stable_keys[i]) = keys[i]; + ::rocprim::get<1>(stable_keys[i]) = + ::rocprim::detail::logical_lane_id() + logical_warp_size * i; + } + + if(with_values) + { + ::rocprim::wave_barrier(); + values_load_type().load(values_input + begin_offset, values, num_items, storage.values_load); + } + + ::rocprim::wave_barrier(); + if(begin_bit == 0 && end_bit == 8 * sizeof(key_type)) + { + sort_type().sort(stable_keys, + values, + storage.sort, + make_stable_comparator(radix_comparator_type{})); + } + else + { + radix_comparator_type comparator(begin_bit, end_bit - begin_bit); + sort_type().sort(stable_keys, values, storage.sort, make_stable_comparator(comparator)); + } + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < items_per_thread; i++) + { + keys[i] = ::rocprim::get<0>(stable_keys[i]); + } + ::rocprim::wave_barrier(); + keys_store_type().store(keys_output + begin_offset, keys, num_items, storage.keys_store); + + if(with_values) + { + ::rocprim::wave_barrier(); + values_store_type().store(values_output + begin_offset, values, num_items, storage.values_store); + } + } + + template< + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(KeysInputIterator keys_input, + key_type * keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + value_type * values_tmp, + 
ValuesOutputIterator values_output, + bool to_output, + unsigned int begin_offset, + unsigned int end_offset, + unsigned int begin_bit, + unsigned int end_bit, + storage_type& storage) + { + if(to_output) + { + sort( + keys_input, keys_output, values_input, values_output, + begin_offset, end_offset, + begin_bit, end_bit, + storage + ); + } + else + { + sort( + keys_input, keys_tmp, values_input, values_tmp, + begin_offset, end_offset, + begin_bit, end_bit, + storage + ); + } + } +}; + +template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class OffsetIterator +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void segmented_sort(KeysInputIterator keys_input, + typename std::iterator_traits::value_type * keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type * values_tmp, + ValuesOutputIterator values_output, + bool to_output, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int long_iterations, + unsigned int short_iterations, + unsigned int begin_bit, + unsigned int end_bit) +{ + constexpr unsigned int long_radix_bits = Config::long_radix_bits; + constexpr unsigned int short_radix_bits = Config::short_radix_bits; + constexpr unsigned int block_size = Config::sort::block_size; + constexpr unsigned int items_per_thread = Config::sort::items_per_thread; + constexpr unsigned int items_per_block = block_size * items_per_thread; + constexpr bool warp_sort_enabled = Config::warp_sort_config::enable_unpartitioned_warp_sort; + + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + using single_block_helper_type = segmented_radix_sort_single_block_helper< + key_type, value_type, + block_size, items_per_thread, + Descending + >; + using long_radix_helper_type = segmented_radix_sort_helper< + key_type, value_type, + 
::rocprim::device_warp_size(), block_size, items_per_thread, + long_radix_bits, Descending + >; + using short_radix_helper_type = segmented_radix_sort_helper< + key_type, value_type, + ::rocprim::device_warp_size(), block_size, items_per_thread, + short_radix_bits, Descending + >; + using warp_sort_helper_type = segmented_warp_sort_helper< + select_warp_sort_helper_config_small_t, + key_type, + value_type, + Descending>; + static constexpr unsigned int items_per_warp = warp_sort_helper_type::items_per_warp; + + ROCPRIM_SHARED_MEMORY union + { + typename single_block_helper_type::storage_type single_block_helper; + typename long_radix_helper_type::storage_type long_radix_helper; + typename short_radix_helper_type::storage_type short_radix_helper; + typename warp_sort_helper_type::storage_type warp_sort_helper; + } storage; + + const unsigned int segment_id = ::rocprim::detail::block_id<0>(); + + const unsigned int begin_offset = begin_offsets[segment_id]; + const unsigned int end_offset = end_offsets[segment_id]; + + // Empty segment + if(end_offset <= begin_offset) + { + return; + } + + if(end_offset - begin_offset > items_per_block) + { + // Large segment + unsigned int bit = begin_bit; + for(unsigned int i = 0; i < long_iterations; i++) + { + long_radix_helper_type().sort( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, + to_output, + begin_offset, end_offset, + bit, begin_bit, end_bit, + storage.long_radix_helper + ); + + to_output = !to_output; + bit += long_radix_bits; + } + for(unsigned int i = 0; i < short_iterations; i++) + { + short_radix_helper_type().sort( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, + to_output, + begin_offset, end_offset, + bit, begin_bit, end_bit, + storage.short_radix_helper + ); + + to_output = !to_output; + bit += short_radix_bits; + } + } + else if(!warp_sort_enabled || end_offset - begin_offset > items_per_warp) + { + // Small segment + 
single_block_helper_type().sort( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, + ((long_iterations + short_iterations) % 2 == 0) != to_output, + begin_offset, end_offset, + begin_bit, end_bit, + storage.single_block_helper + ); + } + else if(::rocprim::flat_block_thread_id() < Config::warp_sort_config::logical_warp_size_small) + { + // Single warp segment + warp_sort_helper_type().sort( + keys_input, keys_tmp, keys_output, + values_input, values_tmp, values_output, + ((long_iterations + short_iterations) % 2 == 0) != to_output, + begin_offset, end_offset, + begin_bit, end_bit, storage.warp_sort_helper + ); + } +} + +template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class SegmentIndexIterator, + class OffsetIterator +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void segmented_sort_large(KeysInputIterator keys_input, + typename std::iterator_traits::value_type * keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type * values_tmp, + ValuesOutputIterator values_output, + bool to_output, + SegmentIndexIterator segment_indices, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int long_iterations, + unsigned int short_iterations, + unsigned int begin_bit, + unsigned int end_bit) +{ + constexpr unsigned int long_radix_bits = Config::long_radix_bits; + constexpr unsigned int short_radix_bits = Config::short_radix_bits; + constexpr unsigned int block_size = Config::sort::block_size; + constexpr unsigned int items_per_thread = Config::sort::items_per_thread; + constexpr unsigned int items_per_block = block_size * items_per_thread; + + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + using single_block_helper_type = segmented_radix_sort_single_block_helper< + key_type, 
value_type, + block_size, items_per_thread, + Descending + >; + using long_radix_helper_type = segmented_radix_sort_helper< + key_type, value_type, + ::rocprim::device_warp_size(), block_size, items_per_thread, + long_radix_bits, Descending + >; + using short_radix_helper_type = segmented_radix_sort_helper< + key_type, value_type, + ::rocprim::device_warp_size(), block_size, items_per_thread, + short_radix_bits, Descending + >; + + ROCPRIM_SHARED_MEMORY union + { + typename single_block_helper_type::storage_type single_block_helper; + typename long_radix_helper_type::storage_type long_radix_helper; + typename short_radix_helper_type::storage_type short_radix_helper; + } storage; + + const unsigned int block_id = ::rocprim::detail::block_id<0>(); + const unsigned int segment_id = segment_indices[block_id]; + const unsigned int begin_offset = begin_offsets[segment_id]; + const unsigned int end_offset = end_offsets[segment_id]; + + if(end_offset <= begin_offset) + { + return; + } + + if(end_offset - begin_offset > items_per_block) + { + unsigned int bit = begin_bit; + for(unsigned int i = 0; i < long_iterations; i++) + { + long_radix_helper_type().sort( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, + to_output, + begin_offset, end_offset, + bit, begin_bit, end_bit, + storage.long_radix_helper + ); + + to_output = !to_output; + bit += long_radix_bits; + } + for(unsigned int i = 0; i < short_iterations; i++) + { + short_radix_helper_type().sort( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, + to_output, + begin_offset, end_offset, + bit, begin_bit, end_bit, + storage.short_radix_helper + ); + + to_output = !to_output; + bit += short_radix_bits; + } + } + else + { + single_block_helper_type().sort( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, + ((long_iterations + short_iterations) % 2 == 0) != to_output, + begin_offset, end_offset, + begin_bit, end_bit, + 
storage.single_block_helper + ); + } +} + +template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class SegmentIndexIterator, + class OffsetIterator +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void segmented_sort_small(KeysInputIterator keys_input, + typename std::iterator_traits::value_type * keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type * values_tmp, + ValuesOutputIterator values_output, + bool to_output, + unsigned int num_segments, + SegmentIndexIterator segment_indices, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit, + unsigned int end_bit) +{ + static constexpr unsigned int block_size = Config::block_size; + static constexpr unsigned int logical_warp_size = Config::logical_warp_size; + static_assert(block_size % logical_warp_size == 0, "logical_warp_size must be a divisor of block_size"); + static constexpr unsigned int warps_per_block = block_size / logical_warp_size; + + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + using warp_sort_helper_type = segmented_warp_sort_helper< + Config, key_type, value_type, Descending + >; + + ROCPRIM_SHARED_MEMORY typename warp_sort_helper_type::storage_type storage; + + const unsigned int block_id = ::rocprim::detail::block_id<0>(); + const unsigned int logical_warp_id = ::rocprim::detail::logical_warp_id(); + const unsigned int segment_index = block_id * warps_per_block + logical_warp_id; + if(segment_index >= num_segments) + { + return; + } + + const unsigned int segment_id = segment_indices[segment_index]; + const unsigned int begin_offset = begin_offsets[segment_id]; + const unsigned int end_offset = end_offsets[segment_id]; + if(end_offset <= begin_offset) + { + return; + } + warp_sort_helper_type().sort( + keys_input, 
keys_tmp, keys_output, + values_input, values_tmp, values_output, + to_output, begin_offset, end_offset, + begin_bit, end_bit, storage + ); +} + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_RADIX_SORT_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_segmented_reduce.hpp b/3rdparty/cub/rocprim/device/detail/device_segmented_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2656a66cc836bc4ecbdb359a435648cf7f6846c9 --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_segmented_reduce.hpp @@ -0,0 +1,166 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_REDUCE_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_REDUCE_HPP_ + +#include +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../types.hpp" + +#include "../../block/block_load_func.hpp" +#include "../../block/block_reduce.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class Config, + class InputIterator, + class OutputIterator, + class OffsetIterator, + class ResultType, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void segmented_reduce(InputIterator input, + OutputIterator output, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + BinaryFunction reduce_op, + ResultType initial_value) +{ + constexpr unsigned int block_size = Config::block_size; + constexpr unsigned int items_per_thread = Config::items_per_thread; + constexpr unsigned int items_per_block = block_size * items_per_thread; + + using reduce_type = ::rocprim::block_reduce< + ResultType, block_size, + Config::block_reduce_method + >; + + ROCPRIM_SHARED_MEMORY typename reduce_type::storage_type reduce_storage; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int segment_id = ::rocprim::detail::block_id<0>(); + + const unsigned int begin_offset = begin_offsets[segment_id]; + const unsigned int end_offset = end_offsets[segment_id]; + + // Empty segment + if(end_offset <= begin_offset) + { + if(flat_id == 0) + { + output[segment_id] = initial_value; + } + return; + } + + ResultType result; + unsigned int block_offset = begin_offset; + if(block_offset + items_per_block > end_offset) + { + // Segment is shorter than items_per_block + + // Load the partial block and reduce the current thread's values + const unsigned int valid_count = end_offset - block_offset; + if(flat_id < valid_count) + { + unsigned int offset = block_offset + flat_id; + result = input[offset]; + offset += 
block_size; + while(offset < end_offset) + { + result = reduce_op(result, static_cast(input[offset])); + offset += block_size; + } + } + + // Reduce threads' reductions to compute the final result + if(valid_count >= block_size) + { + // All threads have at least one value, i.e. result has valid value + reduce_type().reduce(result, result, reduce_storage, reduce_op); + } + else + { + reduce_type().reduce(result, result, valid_count, reduce_storage, reduce_op); + } + } + else + { + // Long segments + + ResultType values[items_per_thread]; + + // Load the first block and reduce the current thread's values + block_load_direct_striped(flat_id, input + block_offset, values); + result = values[0]; + for(unsigned int i = 1; i < items_per_thread; i++) + { + result = reduce_op(result, values[i]); + } + block_offset += items_per_block; + + // Load next full blocks and continue reduction + while(block_offset + items_per_block < end_offset) + { + block_load_direct_striped(flat_id, input + block_offset, values); + for(unsigned int i = 0; i < items_per_thread; i++) + { + result = reduce_op(result, values[i]); + } + block_offset += items_per_block; + } + + // Load the last (probably partial) block and continue reduction + const unsigned int valid_count = end_offset - block_offset; + block_load_direct_striped(flat_id, input + block_offset, values, valid_count); + for(unsigned int i = 0; i < items_per_thread; i++) + { + if(i * block_size + flat_id < valid_count) + { + result = reduce_op(result, values[i]); + } + } + + // Reduce threads' reductions to compute the final result + reduce_type().reduce(result, result, reduce_storage, reduce_op); + } + + if(flat_id == 0) + { + output[segment_id] = reduce_op(initial_value, result); + } +} + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_REDUCE_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_segmented_scan.hpp 
b/3rdparty/cub/rocprim/device/detail/device_segmented_scan.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a76ce001c9904a78e93d9780f12dad86987adb86 --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_segmented_scan.hpp @@ -0,0 +1,236 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_SCAN_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_SCAN_HPP_ + +#include +#include + +#include "../../config.hpp" +#include "../../intrinsics.hpp" +#include "../../types.hpp" + +#include "../../detail/various.hpp" +#include "../../detail/binary_op_wrappers.hpp" + +#include "../../block/block_load.hpp" +#include "../../block/block_store.hpp" +#include "../../block/block_scan.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + bool Exclusive, + bool UsePrefix, + class BlockScanType, + class T, + unsigned int ItemsPerThread, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto segmented_scan_block_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T& prefix, + typename BlockScanType::storage_type& storage, + BinaryFunction scan_op) + -> typename std::enable_if::type +{ + auto prefix_op = + [&prefix, scan_op](const T& reduction) + { + auto saved_prefix = prefix; + prefix = scan_op(prefix, reduction); + return saved_prefix; + }; + BlockScanType() + .exclusive_scan( + input, output, + storage, prefix_op, scan_op + ); +} + +template< + bool Exclusive, + bool UsePrefix, + class BlockScanType, + class T, + unsigned int ItemsPerThread, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_INLINE +auto segmented_scan_block_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T& prefix, + typename BlockScanType::storage_type& storage, + BinaryFunction scan_op) + -> typename std::enable_if::type +{ + if(UsePrefix) + { + auto prefix_op = + [&prefix, scan_op](const T& reduction) + { + auto saved_prefix = prefix; + prefix = scan_op(prefix, reduction); + return saved_prefix; + }; + BlockScanType() + .inclusive_scan( + input, output, + storage, prefix_op, scan_op + ); + return; + } + BlockScanType() + .inclusive_scan( + input, output, prefix, + storage, scan_op + ); +} + +template< + bool Exclusive, + class Config, + class ResultType, + class InputIterator, + class 
OutputIterator, + class OffsetIterator, + class InitValueType, + class BinaryFunction +> +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void segmented_scan(InputIterator input, + OutputIterator output, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + InitValueType initial_value, + BinaryFunction scan_op) +{ + constexpr auto block_size = Config::block_size; + constexpr auto items_per_thread = Config::items_per_thread; + constexpr unsigned int items_per_block = block_size * items_per_thread; + + using result_type = ResultType; + using block_load_type = ::rocprim::block_load< + result_type, block_size, items_per_thread, + Config::block_load_method + >; + using block_store_type = ::rocprim::block_store< + result_type, block_size, items_per_thread, + Config::block_store_method + >; + using block_scan_type = ::rocprim::block_scan< + result_type, block_size, + Config::block_scan_method + >; + + ROCPRIM_SHARED_MEMORY union + { + typename block_load_type::storage_type load; + typename block_store_type::storage_type store; + typename block_scan_type::storage_type scan; + } storage; + + const unsigned int segment_id = ::rocprim::detail::block_id<0>(); + const unsigned int begin_offset = begin_offsets[segment_id]; + const unsigned int end_offset = end_offsets[segment_id]; + + // Empty segment + if(end_offset <= begin_offset) + { + return; + } + + // Input values + result_type values[items_per_thread]; + result_type prefix = initial_value; + + unsigned int block_offset = begin_offset; + if(block_offset + items_per_block > end_offset) + { + // Segment is shorter than items_per_block + + // Load the partial block + const unsigned int valid_count = end_offset - block_offset; + block_load_type().load(input + block_offset, values, valid_count, storage.load); + ::rocprim::syncthreads(); + // Perform scan operation + segmented_scan_block_scan( + values, values, prefix, storage.scan, scan_op + ); + ::rocprim::syncthreads(); + // Store the partial block + 
block_store_type().store(output + block_offset, values, valid_count, storage.store); + } + else + { + // Long segments + + // Load the first block of input values + block_load_type().load(input + block_offset, values, storage.load); + ::rocprim::syncthreads(); + // Perform scan operation + segmented_scan_block_scan( + values, values, prefix, storage.scan, scan_op + ); + ::rocprim::syncthreads(); + // Store + block_store_type().store(output + block_offset, values, storage.store); + ::rocprim::syncthreads(); + block_offset += items_per_block; + + // Load next full blocks and continue scanning + while(block_offset + items_per_block < end_offset) + { + block_load_type().load(input + block_offset, values, storage.load); + ::rocprim::syncthreads(); + // Perform scan operation + segmented_scan_block_scan( + values, values, prefix, storage.scan, scan_op + ); + ::rocprim::syncthreads(); + block_store_type().store(output + block_offset, values, storage.store); + ::rocprim::syncthreads(); + block_offset += items_per_block; + } + + // Load the last (probably partial) block and continue scanning + const unsigned int valid_count = end_offset - block_offset; + block_load_type().load(input + block_offset, values, valid_count, storage.load); + ::rocprim::syncthreads(); + // Perform scan operation + segmented_scan_block_scan( + values, values, prefix, storage.scan, scan_op + ); + ::rocprim::syncthreads(); + // Store the partial block + block_store_type().store(output + block_offset, values, valid_count, storage.store); + } +} + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_REDUCE_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/device_transform.hpp b/3rdparty/cub/rocprim/device/detail/device_transform.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e84ee0731e21372b86fd5c796c7df0d4790bda84 --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/device_transform.hpp @@ -0,0 +1,154 @@ +// Copyright 
(c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_TRANSFORM_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_TRANSFORM_HPP_ + +#include +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" +#include "../../detail/match_result_type.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" +#include "../../types.hpp" + +#include "../../block/block_load.hpp" +#include "../../block/block_store.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// Wrapper for unpacking tuple to be used with BinaryFunction. +// See transform function which accepts two input iterators. 
+template +struct unpack_binary_op +{ + using result_type = typename ::rocprim::detail::invoke_result::type; + + ROCPRIM_HOST_DEVICE inline + unpack_binary_op() = default; + + ROCPRIM_HOST_DEVICE inline + unpack_binary_op(BinaryFunction binary_op) : binary_op_(binary_op) + { + } + + ROCPRIM_HOST_DEVICE inline + ~unpack_binary_op() = default; + + ROCPRIM_HOST_DEVICE inline + result_type operator()(const ::rocprim::tuple& t) + { + return binary_op_(::rocprim::get<0>(t), ::rocprim::get<1>(t)); + } + +private: + BinaryFunction binary_op_; +}; + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class ResultType, + class InputIterator, + class OutputIterator, + class UnaryFunction +> +ROCPRIM_DEVICE ROCPRIM_INLINE +void transform_kernel_impl(InputIterator input, + const size_t input_size, + OutputIterator output, + UnaryFunction transform_op) +{ + using input_type = typename std::iterator_traits::value_type; + using output_type = typename std::iterator_traits::value_type; + using result_type = + typename std::conditional< + std::is_void::value, ResultType, output_type + >::type; + + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + const unsigned int block_offset = flat_block_id * items_per_block; + const unsigned int number_of_blocks = ::rocprim::detail::grid_size<0>(); + const unsigned int valid_in_last_block = input_size - block_offset; + + input_type input_values[ItemsPerThread]; + result_type output_values[ItemsPerThread]; + + if(flat_block_id == (number_of_blocks - 1)) // last block + { + block_load_direct_striped( + flat_id, + input + block_offset, + input_values, + valid_in_last_block + ); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + if(BlockSize * i + flat_id < valid_in_last_block) + { + output_values[i] = transform_op(input_values[i]); + } + } + + 
block_store_direct_striped( + flat_id, + output + block_offset, + output_values, + valid_in_last_block + ); + } + else + { + block_load_direct_striped( + flat_id, + input + block_offset, + input_values + ); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output_values[i] = transform_op(input_values[i]); + } + + block_store_direct_striped( + flat_id, + output + block_offset, + output_values + ); + } +} + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_TRANSFORM_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/lookback_scan_state.hpp b/3rdparty/cub/rocprim/device/detail/lookback_scan_state.hpp new file mode 100644 index 0000000000000000000000000000000000000000..126575a30a3f43eb573df828daac5500d67b6b36 --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/lookback_scan_state.hpp @@ -0,0 +1,459 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_LOOKBACK_SCAN_STATE_HPP_ +#define ROCPRIM_DEVICE_DETAIL_LOOKBACK_SCAN_STATE_HPP_ + +#include + +#include "../../intrinsics.hpp" +#include "../../types.hpp" +#include "../../type_traits.hpp" + +#include "../../warp/detail/warp_reduce_crosslane.hpp" +#include "../../warp/detail/warp_scan_crosslane.hpp" + +#include "../../detail/various.hpp" +#include "../../detail/binary_op_wrappers.hpp" + +extern "C" +{ + void __builtin_amdgcn_s_sleep(int); +} +BEGIN_ROCPRIM_NAMESPACE + +// Single pass prefix scan was implemented based on: +// Merrill, D. and Garland, M. Single-pass Parallel Prefix Scan with Decoupled Look-back. +// Technical Report NVR2016-001, NVIDIA Research. Mar. 2016. + +namespace detail +{ + +enum prefix_flag +{ + // flag for padding, values should be discarded + PREFIX_INVALID = -1, + // initialized, not result in value + PREFIX_EMPTY = 0, + // partial prefix value (from single block) + PREFIX_PARTIAL = 1, + // final prefix value + PREFIX_COMPLETE = 2 +}; + +// lookback_scan_state object keeps track of prefixes status for +// a look-back prefix scan. Initially every prefix can be either +// invalid (padding values) or empty. One thread in a block should +// later set it to partial, and later to complete. +template +struct lookback_scan_state; + +// Packed flag and prefix value are loaded/stored in one atomic operation. +template +struct lookback_scan_state +{ +private: + using flag_type_ = char; + + // Type which is used in store/load operations of block prefix (flag and value). + // It is 32-bit or 64-bit int and can be loaded/stored using single atomic instruction. 
+ using prefix_underlying_type = + typename std::conditional< + (sizeof(T) > 2), + unsigned long long, + unsigned int + >::type; + + // Helper struct + struct alignas(sizeof(prefix_underlying_type)) prefix_type + { + flag_type_ flag; + T value; + }; + + static_assert(sizeof(prefix_underlying_type) == sizeof(prefix_type), ""); + +public: + // Type used for flag/flag of block prefix + using flag_type = flag_type_; + using value_type = T; + + // temp_storage must point to allocation of get_storage_size(number_of_blocks) bytes + ROCPRIM_HOST static inline + lookback_scan_state create(void* temp_storage, const unsigned int number_of_blocks) + { + (void) number_of_blocks; + lookback_scan_state state; + state.prefixes = reinterpret_cast(temp_storage); + return state; + } + + ROCPRIM_HOST static inline + size_t get_storage_size(const unsigned int number_of_blocks) + { + return sizeof(prefix_underlying_type) * (::rocprim::host_warp_size() + number_of_blocks); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void initialize_prefix(const unsigned int block_id, + const unsigned int number_of_blocks) + { + constexpr unsigned int padding = ::rocprim::device_warp_size(); + + if(block_id < number_of_blocks) + { + prefix_type prefix; + prefix.flag = PREFIX_EMPTY; + prefix_underlying_type p; +#ifndef __HIP_CPU_RT__ + __builtin_memcpy(&p, &prefix, sizeof(prefix_type)); +#else + std::memcpy(&p, &prefix, sizeof(prefix_type)); +#endif + prefixes[padding + block_id] = p; + } + if(block_id < padding) + { + prefix_type prefix; + prefix.flag = PREFIX_INVALID; + prefix_underlying_type p; +#ifndef __HIP_CPU_RT__ + __builtin_memcpy(&p, &prefix, sizeof(prefix_type)); +#else + std::memcpy(&p, &prefix, sizeof(prefix_type)); +#endif + prefixes[block_id] = p; + } + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void set_partial(const unsigned int block_id, const T value) + { + this->set(block_id, PREFIX_PARTIAL, value); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void set_complete(const unsigned int block_id, const T 
value) + { + this->set(block_id, PREFIX_COMPLETE, value); + } + + // block_id must be > 0 + ROCPRIM_DEVICE ROCPRIM_INLINE + void get(const unsigned int block_id, flag_type& flag, T& value) + { + constexpr unsigned int padding = ::rocprim::device_warp_size(); + + prefix_type prefix; + + const unsigned int SLEEP_MAX = 32; + unsigned int times_through = 1; + + prefix_underlying_type p = ::rocprim::detail::atomic_add(&prefixes[padding + block_id], 0); +#ifndef __HIP_CPU_RT__ + __builtin_memcpy(&prefix, &p, sizeof(prefix_type)); +#else + std::memcpy(&prefix, &p, sizeof(prefix_type)); +#endif + while(prefix.flag == PREFIX_EMPTY) + { + if (UseSleep) + { + for (unsigned int j = 0; j < times_through; j++) +#ifndef __HIP_CPU_RT__ + __builtin_amdgcn_s_sleep(1); +#else + std::this_thread::sleep_for(std::chrono::microseconds{1}); +#endif + if (times_through < SLEEP_MAX) + times_through++; + } + // atomic_add(..., 0) is used to load values atomically + prefix_underlying_type p = ::rocprim::detail::atomic_add(&prefixes[padding + block_id], 0); +#ifndef __HIP_CPU_RT__ + __builtin_memcpy(&prefix, &p, sizeof(prefix_type)); +#else + std::memcpy(&prefix, &p, sizeof(prefix_type)); +#endif + } + + // return + flag = prefix.flag; + value = prefix.value; + } + +private: + ROCPRIM_DEVICE ROCPRIM_INLINE + void set(const unsigned int block_id, const flag_type flag, const T value) + { + constexpr unsigned int padding = ::rocprim::device_warp_size(); + + prefix_type prefix = { flag, value }; + prefix_underlying_type p; +#ifndef __HIP_CPU_RT__ + __builtin_memcpy(&p, &prefix, sizeof(prefix_type)); +#else + std::memcpy(&p, &prefix, sizeof(prefix_type)); +#endif + ::rocprim::detail::atomic_exch(&prefixes[padding + block_id], p); + } + + prefix_underlying_type * prefixes; +}; + +// Flag, partial and final prefixes are stored in separate arrays. +// Consistency ensured by memory fences between flag and prefixes load/store operations. 
+template +struct lookback_scan_state +{ + +public: + using flag_type = char; + using value_type = T; + + // temp_storage must point to allocation of get_storage_size(number_of_blocks) bytes + ROCPRIM_HOST static inline + lookback_scan_state create(void* temp_storage, const unsigned int number_of_blocks) + { + const auto n = ::rocprim::host_warp_size() + number_of_blocks; + lookback_scan_state state; + + auto ptr = static_cast(temp_storage); + + state.prefixes_flags = reinterpret_cast(ptr); + ptr += ::rocprim::detail::align_size(n * sizeof(flag_type)); + + state.prefixes_partial_values = reinterpret_cast(ptr); + ptr += ::rocprim::detail::align_size(n * sizeof(T)); + + state.prefixes_complete_values = reinterpret_cast(ptr); + return state; + } + + ROCPRIM_HOST static inline + size_t get_storage_size(const unsigned int number_of_blocks) + { + const auto n = ::rocprim::host_warp_size() + number_of_blocks; + size_t size = ::rocprim::detail::align_size(n * sizeof(flag_type)); + size += 2 * ::rocprim::detail::align_size(n * sizeof(T)); + return size; + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void initialize_prefix(const unsigned int block_id, + const unsigned int number_of_blocks) + { + constexpr unsigned int padding = ::rocprim::device_warp_size(); + if(block_id < number_of_blocks) + { + prefixes_flags[padding + block_id] = PREFIX_EMPTY; + } + if(block_id < padding) + { + prefixes_flags[block_id] = PREFIX_INVALID; + } + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void set_partial(const unsigned int block_id, const T value) + { + constexpr unsigned int padding = ::rocprim::device_warp_size(); + + store_volatile(&prefixes_partial_values[padding + block_id], value); + ::rocprim::detail::memory_fence_device(); + store_volatile(&prefixes_flags[padding + block_id], PREFIX_PARTIAL); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void set_complete(const unsigned int block_id, const T value) + { + constexpr unsigned int padding = ::rocprim::device_warp_size(); + + 
store_volatile(&prefixes_complete_values[padding + block_id], value); + ::rocprim::detail::memory_fence_device(); + store_volatile(&prefixes_flags[padding + block_id], PREFIX_COMPLETE); + } + + // block_id must be > 0 + ROCPRIM_DEVICE ROCPRIM_INLINE + void get(const unsigned int block_id, flag_type& flag, T& value) + { + constexpr unsigned int padding = ::rocprim::device_warp_size(); + + const unsigned int SLEEP_MAX = 32; + unsigned int times_through = 1; + + flag = load_volatile(&prefixes_flags[padding + block_id]); + ::rocprim::detail::memory_fence_device(); + while(flag == PREFIX_EMPTY) + { + if (UseSleep) + { + for (unsigned int j = 0; j < times_through; j++) +#ifndef __HIP_CPU_RT__ + __builtin_amdgcn_s_sleep(1); +#else + std::this_thread::sleep_for(std::chrono::microseconds{1}); +#endif + if (times_through < SLEEP_MAX) + times_through++; + } + + flag = load_volatile(&prefixes_flags[padding + block_id]); + ::rocprim::detail::memory_fence_device(); + } + + if(flag == PREFIX_PARTIAL) + value = load_volatile(&prefixes_partial_values[padding + block_id]); + else + value = load_volatile(&prefixes_complete_values[padding + block_id]); + } + +private: + flag_type * prefixes_flags; + // We need to separate arrays for partial and final prefixes, because + // value can be overwritten before flag is changed (flag and value are + // not stored in single instruction). 
+ T * prefixes_partial_values; + T * prefixes_complete_values; +}; + +template +class lookback_scan_prefix_op +{ + using flag_type = typename LookbackScanState::flag_type; + static_assert( + std::is_same::value, + "T must be LookbackScanState::value_type" + ); + +public: + ROCPRIM_DEVICE ROCPRIM_INLINE + lookback_scan_prefix_op(unsigned int block_id, + BinaryFunction scan_op, + LookbackScanState &scan_state) + : block_id_(block_id), + scan_op_(scan_op), + scan_state_(scan_state) + { + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + ~lookback_scan_prefix_op() = default; + + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce_partial_prefixes(unsigned int block_id, + flag_type& flag, + T& partial_prefix) + { + // Order of reduction must be reversed, because 0th thread has + // prefix from the (block_id_ - 1) block, 1st thread has prefix + // from (block_id_ - 2) block etc. + using headflag_scan_op_type = reverse_binary_op_wrapper< + BinaryFunction, T, T + >; + using warp_reduce_prefix_type = warp_reduce_crosslane< + T, ::rocprim::device_warp_size(), false + >; + + T block_prefix; + scan_state_.get(block_id, flag, block_prefix); + + auto headflag_scan_op = headflag_scan_op_type(scan_op_); + warp_reduce_prefix_type() + .tail_segmented_reduce( + block_prefix, + partial_prefix, + (flag == PREFIX_COMPLETE), + headflag_scan_op + ); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + T get_prefix() + { + flag_type flag; + T partial_prefix; + unsigned int previous_block_id = block_id_ - ::rocprim::lane_id() - 1; + + // reduce last warp_size() number of prefixes to + // get the complete prefix for this block. 
+ reduce_partial_prefixes(previous_block_id, flag, partial_prefix); + T prefix = partial_prefix; + + // while we don't load a complete prefix, reduce partial prefixes + while(::rocprim::detail::warp_all(flag != PREFIX_COMPLETE)) + { + previous_block_id -= ::rocprim::device_warp_size(); + reduce_partial_prefixes(previous_block_id, flag, partial_prefix); + prefix = scan_op_(partial_prefix, prefix); + } + return prefix; + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + T operator()(T reduction) + { + // Set partial prefix for next block + if(::rocprim::lane_id() == 0) + { + scan_state_.set_partial(block_id_, reduction); + } + + // Get prefix + auto prefix = get_prefix(); + + // Set complete prefix for next block + if(::rocprim::lane_id() == 0) + { + scan_state_.set_complete(block_id_, scan_op_(prefix, reduction)); + } + return prefix; + } + +protected: + unsigned int block_id_; + BinaryFunction scan_op_; + LookbackScanState& scan_state_; +}; + +inline cudaError_t is_sleep_scan_state_used(bool& use_sleep) +{ + cudaDeviceProp prop; + int deviceId; + if(const cudaError_t error = cudaGetDevice(&deviceId)) + { + return error; + } + else if(const cudaError_t error = cudaGetDeviceProperties(&prop, deviceId)) + { + return error; + } + + const int asicRevision = 0; + use_sleep = 0; + return cudaSuccess; +} + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_LOOKBACK_SCAN_STATE_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/ordered_block_id.hpp b/3rdparty/cub/rocprim/device/detail/ordered_block_id.hpp new file mode 100644 index 0000000000000000000000000000000000000000..60b95b7beca1456cf072f518511831002bba5c95 --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/ordered_block_id.hpp @@ -0,0 +1,87 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_ORDERED_BLOCK_ID_HPP_ +#define ROCPRIM_DEVICE_DETAIL_ORDERED_BLOCK_ID_HPP_ + +#include +#include + +#include "../../detail/various.hpp" +#include "../../intrinsics.hpp" +#include "../../types.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// Helper struct for generating ordered unique ids for blocks in a grid. 
+template +struct ordered_block_id +{ + static_assert(std::is_integral::value, "T must be integer"); + using id_type = T; + + // shared memory temporary storage type + struct storage_type + { + id_type id; + }; + + ROCPRIM_HOST static inline + ordered_block_id create(id_type * id) + { + ordered_block_id ordered_id; + ordered_id.id = id; + return ordered_id; + } + + ROCPRIM_HOST static inline + size_t get_storage_size() + { + return sizeof(id_type); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void reset() + { + *id = static_cast(0); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + id_type get(unsigned int tid, storage_type& storage) + { + if(tid == 0) + { + storage.id = ::rocprim::detail::atomic_add(this->id, 1); + } + ::rocprim::syncthreads(); + return storage.id; + } + + id_type* id; +}; + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_ORDERED_BLOCK_ID_HPP_ diff --git a/3rdparty/cub/rocprim/device/detail/uint_fast_div.hpp b/3rdparty/cub/rocprim/device/detail/uint_fast_div.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0c33a120d2592d9c4010c44e69cbf3d0ebf0a4a1 --- /dev/null +++ b/3rdparty/cub/rocprim/device/detail/uint_fast_div.hpp @@ -0,0 +1,106 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_UINT_FAST_DIV_HPP_ +#define ROCPRIM_DEVICE_DETAIL_UINT_FAST_DIV_HPP_ + +#include "../../config.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +struct uint_fast_div +{ + unsigned int magic; // Magic number + unsigned int shift; // shift amount + unsigned int add; // "add" indicator + + ROCPRIM_HOST_DEVICE inline + uint_fast_div() = default; + + ROCPRIM_HOST_DEVICE inline + uint_fast_div(unsigned int d) + { + // Must have 1 <= d <= 2**32-1. + + if(d == 1) + { + magic = 0; + shift = 0; + add = 0; + return; + } + + int p; + unsigned int p32 = 1, q, r, delta; + add = 0; // Initialize "add" indicator. + p = 31; // Initialize p. + q = 0x7FFFFFFF/d; // Initialize q = (2**p - 1)/d. + r = 0x7FFFFFFF - q*d; // Init. r = rem(2**p - 1, d). + do { + p = p + 1; + if(p == 32) p32 = 1; // Set p32 = 2**(p-32). 
+ else p32 = 2*p32; + if(r + 1 >= d - r) + { + if(q >= 0x7FFFFFFF) add = 1; + q = 2*q + 1; + r = 2*r + 1 - d; + } + else + { + if(q >= 0x80000000) add = 1; + q = 2*q; + r = 2*r + 1; + } + delta = d - 1 - r; + } while (p < 64 && p32 < delta); + magic = q + 1; // Magic number and + shift = p - 32; // shift amount + + if(add) shift--; + } +}; + +ROCPRIM_HOST_DEVICE inline +unsigned int operator/(unsigned int n, const uint_fast_div& divisor) +{ + if(divisor.magic == 0) + { + // Special case for 1 + return n; + } + + // Higher 32-bit of 64-bit multiplication + unsigned int q = (static_cast(divisor.magic) * static_cast(n)) >> 32; + if(divisor.add) + { + q = ((n - q) >> 1) + q; + } + return q >> divisor.shift; +} + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_UINT_FAST_DIV_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_adjacent_difference.hpp b/3rdparty/cub/rocprim/device/device_adjacent_difference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0202a2e1988f1f87cdb9c1641dc2744ef1fbbaf7 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_adjacent_difference.hpp @@ -0,0 +1,523 @@ +// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ +#define ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ + +#include "detail/device_adjacent_difference.hpp" + +#include "device_adjacent_difference_config.hpp" + +#include "config_types.hpp" +#include "device_transform.hpp" + +#include "../config.hpp" +#include "../functional.hpp" + +#include "../detail/various.hpp" +#include "../iterator/counting_iterator.hpp" +#include "../iterator/transform_iterator.hpp" + +#include + +#include +#include +#include + +#include + +/// \file +/// +/// Device level adjacent_difference parallel primitives + +BEGIN_ROCPRIM_NAMESPACE + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != cudaSuccess) \ + return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) \ + return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } + +namespace detail +{ +template +void ROCPRIM_KERNEL __launch_bounds__(Config::block_size) adjacent_difference_kernel( + const InputIt input, + const OutputIt output, + const std::size_t size, + const BinaryFunction op, + const typename std::iterator_traits::value_type* previous_values, + 
const std::size_t starting_block) +{ + adjacent_difference_kernel_impl( + input, output, size, op, previous_values, starting_block); +} + +template +cudaError_t adjacent_difference_impl(void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + const std::size_t size, + const BinaryFunction op, + const cudaStream_t stream, + const bool debug_synchronous) +{ + using value_type = typename std::iterator_traits::value_type; + + using config = detail::default_or_custom_config< + Config, + detail::default_adjacent_difference_config>; + + static constexpr unsigned int block_size = config::block_size; + static constexpr unsigned int items_per_thread = config::items_per_thread; + static constexpr unsigned int items_per_block = block_size * items_per_thread; + + const std::size_t num_blocks = ceiling_div(size, items_per_block); + + if(temporary_storage == nullptr) + { + if(InPlace && num_blocks >= 2) + { + storage_size = align_size((num_blocks - 1) * sizeof(value_type)); + } + else + { + // Make sure user won't try to allocate 0 bytes memory, because + // cudaMalloc will return nullptr when size is zero. + storage_size = 4; + } + + return cudaSuccess; + } + + if(num_blocks == 0) + { + return cudaSuccess; + } + + // Copy values before they are overwritten to use as tile predecessors/successors + // this is not dereferenced when the operation is not in place + auto* const previous_values = static_cast(temporary_storage); + if ROCPRIM_IF_CONSTEXPR(InPlace) + { + // If doing left adjacent diff then the last item of each block is needed for the + // next block, otherwise the first item is needed for the previous block + static constexpr auto offset = items_per_block - (Right ? 
0 : 1); + + const auto block_starts_iter = make_transform_iterator( + rocprim::make_counting_iterator(std::size_t {0}), + [base = input + offset](std::size_t i) { return base[i * items_per_block]; }); + + const cudaError_t error = ::rocprim::transform(block_starts_iter, + previous_values, + num_blocks - 1, + rocprim::identity<> {}, + stream, + debug_synchronous); + if(error != cudaSuccess) + { + return error; + } + } + + static constexpr unsigned int size_limit = config::size_limit; + static constexpr auto number_of_blocks_limit = std::max(size_limit / items_per_block, 1u); + static constexpr auto aligned_size_limit = number_of_blocks_limit * items_per_block; + + // Launch number_of_blocks_limit blocks while there is still at least as many blocks + // left as the limit + const auto number_of_launch = ceiling_div(size, aligned_size_limit); + + if(debug_synchronous) + { + std::cout << "----------------------------------\n"; + std::cout << "size: " << size << '\n'; + std::cout << "aligned_size_limit: " << aligned_size_limit << '\n'; + std::cout << "number_of_launch: " << number_of_launch << '\n'; + std::cout << "block_size: " << block_size << '\n'; + std::cout << "items_per_block: " << items_per_block << '\n'; + std::cout << "----------------------------------\n"; + } + + for(std::size_t i = 0, offset = 0; i < number_of_launch; ++i, offset += aligned_size_limit) + { + const auto current_size + = static_cast(std::min(size - offset, aligned_size_limit)); + const auto current_blocks = ceiling_div(current_size, items_per_block); + const auto starting_block = i * number_of_blocks_limit; + + std::chrono::time_point start; + if(debug_synchronous) + { + std::cout << "index: " << i << '\n'; + std::cout << "current_size: " << current_size << '\n'; + std::cout << "number of blocks: " << current_blocks << '\n'; + + start = std::chrono::high_resolution_clock::now(); + } + adjacent_difference_kernel<<>>( + input + offset, + output + offset, + size, + op, + previous_values + 
starting_block, + starting_block); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR( + "adjacent_difference_kernel", current_size, start); + } + return cudaSuccess; +} +} // namespace detail + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + +#endif // DOXYGEN_SHOULD_SKIP_THIS + +/// \addtogroup devicemodule +/// @{ + +/// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements +/// in device accessible memory. Writes the output to the position of the left item. +/// +/// Copies the first item to the output then performs calls the supplied operator with each pair +/// of neighboring elements and writes its result to the location of the second element. +/// Equivalent to the following code +/// \code{.cpp} +/// output[0] = input[0]; +/// for(std::size_t int i = 1; i < size; ++i) +/// { +/// output[i] = op(input[i], input[i - 1]); +/// } +/// \endcode +/// +/// \tparam Config - [optional] configuration of the primitive. It can be +/// `adjacent_difference_config` or a class with the same members. +/// \tparam InputIt - [inferred] random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIt - [inferred] random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to +/// consecutive items. The signature of the function should be equivalent to the following: +/// `U f(const T1& a, const T2& b)`. The signature does not need to have +/// `const &`, but function object must not modify the object passed to it +/// \param temporary_storage - pointer to a device-accessible temporary storage. 
When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and function returns without performing the scan operation +/// \param storage_size - reference to a size (in bytes) of `temporary_storage` +/// \param input - iterator to the input range +/// \param output - iterator to the output range, must have any overlap with input +/// \param size - number of items in the input +/// \param op - [optional] the binary operation to apply +/// \param stream - [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors and extra debugging info is printed to the +/// standard output. Default value is `false` +/// +/// \return `cudaSuccess` (0) after successful scan, otherwise the HIP runtime error of +/// type `cudaError_t` +/// +/// \par Example +/// \parblock +/// In this example a device-level adjacent_difference operation is performed on integer values. +/// +/// \code{.cpp} +/// #include //or +/// +/// // custom binary function +/// auto binary_op = +/// [] __device__ (int a, int b) -> int +/// { +/// return a - b; +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// std::size_t size; // e.g., 8 +/// int* input1; // e.g., [8, 7, 6, 5, 4, 3, 2, 1] +/// int* output; // empty array of 8 elements +/// +/// std::size_t temporary_storage_size_bytes; +/// void* temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::adjacent_difference( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, size, binary_op +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform adjacent difference +/// rocprim::adjacent_difference( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, size, binary_op +/// ); +/// // output: [8, 1, 1, 1, 1, 1, 1, 1] +/// \endcode +/// \endparblock +template > +cudaError_t adjacent_difference(void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + const std::size_t size, + const BinaryFunction op = BinaryFunction {}, + const cudaStream_t stream = 0, + const bool debug_synchronous = false) +{ + static constexpr bool in_place = false; + static constexpr bool right = false; + return detail::adjacent_difference_impl( + temporary_storage, storage_size, input, output, size, op, stream, debug_synchronous); +} + +/// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements +/// in device accessible memory. Writes the output to the position of the left item in place. +/// +/// Copies the first item to the output then performs calls the supplied operator with each pair +/// of neighboring elements and writes its result to the location of the second element. +/// Equivalent to the following code +/// \code{.cpp} +/// for(std::size_t int i = size - 1; i > 0; --i) +/// { +/// input[i] = op(input[i], input[i - 1]); +/// } +/// \endcode +/// +/// \tparam Config - [optional] configuration of the primitive. 
It can be +/// `adjacent_difference_config` or a class with the same members. +/// \tparam InputIt - [inferred] random-access iterator type of the value range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to +/// consecutive items. The signature of the function should be equivalent to the following: +/// `U f(const T1& a, const T2& b)`. The signature does not need to have +/// `const &`, but function object must not modify the object passed to it +/// \param temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and function returns without performing the scan operation +/// \param storage_size - reference to a size (in bytes) of `temporary_storage` +/// \param values - iterator to the range values, will be overwritten with the results +/// \param size - number of items in the input +/// \param op - [optional] the binary operation to apply +/// \param stream - [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors and extra debugging info is printed to the +/// standard output. 
Default value is `false` +/// +/// \return `cudaSuccess` (0) after successful scan, otherwise the HIP runtime error of +/// type `cudaError_t` +template > +cudaError_t adjacent_difference_inplace(void* const temporary_storage, + std::size_t& storage_size, + const InputIt values, + const std::size_t size, + const BinaryFunction op = BinaryFunction {}, + const cudaStream_t stream = 0, + const bool debug_synchronous = false) +{ + static constexpr bool in_place = true; + static constexpr bool right = false; + return detail::adjacent_difference_impl( + temporary_storage, storage_size, values, values, size, op, stream, debug_synchronous); +} + +/// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements +/// in device accessible memory. Writes the output to the position of the right item. +/// +/// Copies the last item to the output then performs calls the supplied operator with each pair +/// of neighboring elements and writes its result to the location of the first element. +/// Equivalent to the following code +/// \code{.cpp} +/// output[size - 1] = input[size - 1]; +/// for(std::size_t int i = 0; i < size - 1; ++i) +/// { +/// output[i] = op(input[i], input[i + 1]); +/// } +/// \endcode +/// +/// \tparam Config - [optional] configuration of the primitive. It can be +/// `adjacent_difference_config` or a class with the same members. +/// \tparam InputIt - [inferred] random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIt - [inferred] random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to +/// consecutive items. The signature of the function should be equivalent to the following: +/// `U f(const T1& a, const T2& b)`. 
The signature does not need to have +/// `const &`, but function object must not modify the object passed to it +/// \param temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and function returns without performing the scan operation +/// \param storage_size - reference to a size (in bytes) of `temporary_storage` +/// \param input - iterator to the input range +/// \param output - iterator to the output range, must have any overlap with input +/// \param size - number of items in the input +/// \param op - [optional] the binary operation to apply +/// \param stream - [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors and extra debugging info is printed to the +/// standard output. Default value is `false` +/// +/// \return `cudaSuccess` (0) after successful scan, otherwise the HIP runtime error of +/// type `cudaError_t` +/// +/// \par Example +/// \parblock +/// In this example a device-level adjacent_difference operation is performed on integer values. +/// +/// \code{.cpp} +/// #include //or +/// +/// // custom binary function +/// auto binary_op = +/// [] __device__ (int a, int b) -> int +/// { +/// return a - b; +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// std::size_t size; // e.g., 8 +/// int* input1; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int* output; // empty array of 8 elements +/// +/// std::size_t temporary_storage_size_bytes; +/// void* temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::adjacent_difference_right( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, size, binary_op +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform adjacent difference +/// rocprim::adjacent_difference_right( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, size, binary_op +/// ); +/// // output: [1, 1, 1, 1, 1, 1, 1, 8] +/// \endcode +/// \endparblock +template > +cudaError_t adjacent_difference_right(void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + const std::size_t size, + const BinaryFunction op = BinaryFunction {}, + const cudaStream_t stream = 0, + const bool debug_synchronous = false) +{ + static constexpr bool in_place = false; + static constexpr bool right = true; + return detail::adjacent_difference_impl( + temporary_storage, storage_size, input, output, size, op, stream, debug_synchronous); +} + +/// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements +/// in device accessible memory. Writes the output to the position of the right item in place. +/// +/// Copies the last item to the output then performs calls the supplied operator with each pair +/// of neighboring elements and writes its result to the location of the first element. +/// Equivalent to the following code +/// \code{.cpp} +/// for(std::size_t int i = 0; i < size - 1; --i) +/// { +/// input[i] = op(input[i], input[i + 1]); +/// } +/// \endcode +/// +/// \tparam Config - [optional] configuration of the primitive. 
It can be +/// `adjacent_difference_config` or a class with the same members. +/// \tparam InputIt - [inferred] random-access iterator type of the value range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to +/// consecutive items. The signature of the function should be equivalent to the following: +/// `U f(const T1& a, const T2& b)`. The signature does not need to have +/// `const &`, but function object must not modify the object passed to it +/// \param temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and function returns without performing the scan operation +/// \param storage_size - reference to a size (in bytes) of `temporary_storage` +/// \param values - iterator to the range values, will be overwritten with the results +/// \param size - number of items in the input +/// \param op - [optional] the binary operation to apply +/// \param stream - [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors and extra debugging info is printed to the +/// standard output. 
Default value is `false` +/// +/// \return `cudaSuccess` (0) after successful scan, otherwise the HIP runtime error of +/// type `cudaError_t` +template > +cudaError_t adjacent_difference_right_inplace(void* const temporary_storage, + std::size_t& storage_size, + const InputIt values, + const std::size_t size, + const BinaryFunction op = BinaryFunction {}, + const cudaStream_t stream = 0, + const bool debug_synchronous = false) +{ + static constexpr bool in_place = true; + static constexpr bool right = true; + return detail::adjacent_difference_impl( + temporary_storage, storage_size, values, values, size, op, stream, debug_synchronous); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_adjacent_difference_config.hpp b/3rdparty/cub/rocprim/device/device_adjacent_difference_config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..045e872032cbfab33ce0425e63d8e78fa20b26c9 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_adjacent_difference_config.hpp @@ -0,0 +1,84 @@ +// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_CONFIG_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" +#include "../functional.hpp" + +#include "config_types.hpp" + +#include "../block/block_load.hpp" +#include "../block/block_store.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of device-level adjacent_difference primitives. +/// +/// \tparam BlockSize - number of threads in a block. +/// \tparam ItemsPerThread - number of items processed by each thread +/// \tparam LoadMethod - method for loading input values +/// \tparam StoreMethod - method for storing values +/// \tparam SizeLimit - limit on the number of items for a single adjacent_difference kernel launch. +/// Larger input sizes will be broken up to multiple kernel launches. 
+template +struct adjacent_difference_config : kernel_config +{ + static constexpr block_load_method load_method = LoadMethod; + static constexpr block_store_method store_method = StoreMethod; +}; + +namespace detail +{ + +template +struct adjacent_difference_config_fallback +{ + static constexpr unsigned int item_scale + = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); + + using type = adjacent_difference_config<256, ::rocprim::max(1u, 16u / item_scale)>; +}; + +template +struct default_adjacent_difference_config + : select_arch> +{ +}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_ADJACENT_DIFFERENCE_CONFIG_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_binary_search.hpp b/3rdparty/cub/rocprim/device/device_binary_search.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7f4a5211be534d35cd3d61003e1a4c6354559f23 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_binary_search.hpp @@ -0,0 +1,177 @@ +// Copyright (c) 2019 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_BINARY_SEARCH_HPP_ +#define ROCPRIM_DEVICE_DEVICE_BINARY_SEARCH_HPP_ + +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "detail/device_binary_search.hpp" + +#include "device_transform.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +namespace detail +{ + +template< + class Config, + class HaystackIterator, + class NeedlesIterator, + class OutputIterator, + class SearchFunction, + class CompareFunction +> +inline +cudaError_t binary_search(void * temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + OutputIterator output, + size_t haystack_size, + size_t needles_size, + SearchFunction search_op, + CompareFunction compare_op, + cudaStream_t stream, + bool debug_synchronous) +{ + using value_type = typename std::iterator_traits::value_type; + + if(temporary_storage == nullptr) + { + // Make sure user won't try to allocate 0 bytes memory, otherwise + // user may again pass nullptr as temporary_storage + storage_size = 4; + return cudaSuccess; + } + + return transform( + needles, output, + needles_size, + [haystack, haystack_size, search_op, compare_op] + ROCPRIM_DEVICE + (const value_type& value) + { + return search_op(haystack, haystack_size, value, compare_op); + }, + stream, debug_synchronous + ); +} + +} // end of detail namespace + +template< + class Config = default_config, + class HaystackIterator, + class NeedlesIterator, + class OutputIterator, + class CompareFunction = ::rocprim::less<> +> +inline +cudaError_t lower_bound(void * temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + 
OutputIterator output, + size_t haystack_size, + size_t needles_size, + CompareFunction compare_op = CompareFunction(), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::binary_search( + temporary_storage, storage_size, + haystack, needles, output, + haystack_size, needles_size, + detail::lower_bound_search_op(), compare_op, + stream, debug_synchronous + ); +} + +template< + class Config = default_config, + class HaystackIterator, + class NeedlesIterator, + class OutputIterator, + class CompareFunction = ::rocprim::less<> +> +inline +cudaError_t upper_bound(void * temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + OutputIterator output, + size_t haystack_size, + size_t needles_size, + CompareFunction compare_op = CompareFunction(), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::binary_search( + temporary_storage, storage_size, + haystack, needles, output, + haystack_size, needles_size, + detail::upper_bound_search_op(), compare_op, + stream, debug_synchronous + ); +} + +template< + class Config = default_config, + class HaystackIterator, + class NeedlesIterator, + class OutputIterator, + class CompareFunction = ::rocprim::less<> +> +inline +cudaError_t binary_search(void * temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + OutputIterator output, + size_t haystack_size, + size_t needles_size, + CompareFunction compare_op = CompareFunction(), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::binary_search( + temporary_storage, storage_size, + haystack, needles, output, + haystack_size, needles_size, + detail::binary_search_op(), compare_op, + stream, debug_synchronous + ); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_BINARY_SEARCH_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_histogram.hpp 
b/3rdparty/cub/rocprim/device/device_histogram.hpp new file mode 100644 index 0000000000000000000000000000000000000000..49a36ab0d9e6fa5e044b520375e18731db5bf8ee --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_histogram.hpp @@ -0,0 +1,1208 @@ +// Copyright (c) 2017-2019 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_HISTOGRAM_HPP_ +#define ROCPRIM_DEVICE_DEVICE_HISTOGRAM_HPP_ + +#include +#include +#include +#include +#include + +#include "../config.hpp" +#include "../functional.hpp" +#include "../detail/various.hpp" + +#include "device_histogram_config.hpp" +#include "detail/device_histogram.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +namespace detail +{ + +template< + unsigned int BlockSize, + unsigned int ActiveChannels, + class Counter +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void init_histogram_kernel(fixed_array histogram, + fixed_array bins) +{ + init_histogram(histogram, bins); +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int Channels, + unsigned int ActiveChannels, + class SampleIterator, + class Counter, + class SampleToBinOp +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void histogram_shared_kernel(SampleIterator samples, + unsigned int columns, + unsigned int rows, + unsigned int row_stride, + unsigned int rows_per_block, + fixed_array histogram, + fixed_array sample_to_bin_op, + fixed_array bins) +{ + HIP_DYNAMIC_SHARED(unsigned int, block_histogram); + + histogram_shared( + samples, columns, rows, row_stride, rows_per_block, + histogram, + sample_to_bin_op, bins, + block_histogram + ); +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int Channels, + unsigned int ActiveChannels, + class SampleIterator, + class Counter, + class SampleToBinOp +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void histogram_global_kernel(SampleIterator samples, + unsigned int columns, + unsigned int row_stride, + fixed_array histogram, + fixed_array sample_to_bin_op, + fixed_array bins_bits) +{ + histogram_global( + samples, columns, row_stride, + histogram, + sample_to_bin_op, bins_bits + ); +} + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != 
cudaSuccess) return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } + +template< + unsigned int Channels, + unsigned int ActiveChannels, + class Config, + class SampleIterator, + class Counter, + class SampleToBinOp +> +inline +cudaError_t histogram_impl(void * temporary_storage, + size_t& storage_size, + SampleIterator samples, + unsigned int columns, + unsigned int rows, + size_t row_stride_bytes, + Counter * histogram[ActiveChannels], + unsigned int levels[ActiveChannels], + SampleToBinOp sample_to_bin_op[ActiveChannels], + cudaStream_t stream, + bool debug_synchronous) +{ + using sample_type = typename std::iterator_traits::value_type; + + using config = default_or_custom_config< + Config, + default_histogram_config + >; + + static constexpr unsigned int block_size = config::histogram::block_size; + static constexpr unsigned int items_per_thread = config::histogram::items_per_thread; + static constexpr unsigned int items_per_block = block_size * items_per_thread; + + if(row_stride_bytes % sizeof(sample_type) != 0) + { + // Row stride must be a whole multiple of the sample data type size + return cudaErrorInvalidValue; + } + + const unsigned int blocks_x = ::rocprim::detail::ceiling_div(columns, items_per_block); + const unsigned int row_stride = row_stride_bytes / sizeof(sample_type); + + if(temporary_storage == nullptr) + { + // Make sure user won't try to allocate 0 bytes memory, because + // cudaMalloc will return nullptr. 
+ storage_size = 4; + return cudaSuccess; + } + + if(debug_synchronous) + { + std::cout << "columns " << columns << '\n'; + std::cout << "rows " << rows << '\n'; + std::cout << "blocks_x " << blocks_x << '\n'; + cudaError_t error = cudaStreamSynchronize(stream); + if(error != cudaSuccess) return error; + } + + unsigned int bins[ActiveChannels]; + unsigned int bins_bits[ActiveChannels]; + unsigned int total_bins = 0; + unsigned int max_bins = 0; + for(unsigned int channel = 0; channel < ActiveChannels; channel++) + { + bins[channel] = levels[channel] - 1; + bins_bits[channel] = static_cast(std::log2(detail::next_power_of_two(bins[channel]))); + total_bins += bins[channel]; + max_bins = std::max(max_bins, bins[channel]); + } + + std::chrono::high_resolution_clock::time_point start; + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + init_histogram_kernel<<>>( + fixed_array(histogram), + fixed_array(bins) + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("init_histogram", max_bins, start); + + if(columns == 0 || rows == 0) + { + return cudaSuccess; + } + + if(total_bins <= config::shared_impl_max_bins) + { + dim3 grid_size; + grid_size.x = std::min(config::max_grid_size, blocks_x); + grid_size.y = std::min(rows, config::max_grid_size / grid_size.x); + const size_t block_histogram_bytes = total_bins * sizeof(unsigned int); + const unsigned int rows_per_block = ::rocprim::detail::ceiling_div(rows, grid_size.y); + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + histogram_shared_kernel< + block_size, items_per_thread, Channels, ActiveChannels + > + <<>>( + samples, columns, rows, row_stride, rows_per_block, + fixed_array(histogram), + fixed_array(sample_to_bin_op), + fixed_array(bins) + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("histogram_shared", grid_size.x * grid_size.y * block_size, start); + } + else + { + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + histogram_global_kernel< + 
block_size, items_per_thread, Channels, ActiveChannels + > + <<>>( + samples, columns, row_stride, + fixed_array(histogram), + fixed_array(sample_to_bin_op), + fixed_array(bins_bits) + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("histogram_global", blocks_x * block_size * rows, start); + } + + return cudaSuccess; +} + +template< + unsigned int Channels, + unsigned int ActiveChannels, + class Config, + class SampleIterator, + class Counter, + class Level +> +inline +cudaError_t histogram_even_impl(void * temporary_storage, + size_t& storage_size, + SampleIterator samples, + unsigned int columns, + unsigned int rows, + size_t row_stride_bytes, + Counter * histogram[ActiveChannels], + unsigned int levels[ActiveChannels], + Level lower_level[ActiveChannels], + Level upper_level[ActiveChannels], + cudaStream_t stream, + bool debug_synchronous) +{ + for(unsigned int channel = 0; channel < ActiveChannels; channel++) + { + if(levels[channel] < 2) + { + // Histogram must have at least 1 bin + return cudaErrorInvalidValue; + } + } + + sample_to_bin_even sample_to_bin_op[ActiveChannels]; + for(unsigned int channel = 0; channel < ActiveChannels; channel++) + { + sample_to_bin_op[channel] = sample_to_bin_even( + levels[channel] - 1, + lower_level[channel], upper_level[channel] + ); + } + + return histogram_impl( + temporary_storage, storage_size, + samples, columns, rows, row_stride_bytes, + histogram, + levels, sample_to_bin_op, + stream, debug_synchronous + ); +} + +template< + unsigned int Channels, + unsigned int ActiveChannels, + class Config, + class SampleIterator, + class Counter, + class Level +> +inline +cudaError_t histogram_range_impl(void * temporary_storage, + size_t& storage_size, + SampleIterator samples, + unsigned int columns, + unsigned int rows, + size_t row_stride_bytes, + Counter * histogram[ActiveChannels], + unsigned int levels[ActiveChannels], + Level * level_values[ActiveChannels], + cudaStream_t stream, + bool debug_synchronous) +{ + for(unsigned 
int channel = 0; channel < ActiveChannels; channel++) + { + if(levels[channel] < 2) + { + // Histogram must have at least 1 bin + return cudaErrorInvalidValue; + } + } + + sample_to_bin_range sample_to_bin_op[ActiveChannels]; + for(unsigned int channel = 0; channel < ActiveChannels; channel++) + { + sample_to_bin_op[channel] = sample_to_bin_range( + levels[channel] - 1, + level_values[channel] + ); + } + + return histogram_impl( + temporary_storage, storage_size, + samples, columns, rows, row_stride_bytes, + histogram, + levels, sample_to_bin_op, + stream, debug_synchronous + ); +} + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + +} // end of detail namespace + +/// \brief Computes a histogram from a sequence of samples using equal-width bins. +/// +/// \par +/// * The number of histogram bins is (\p levels - 1). +/// * Bins are evenly-segmented and include the same width of sample values: +/// (\p upper_level - \p lower_level) / (\p levels - 1). +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or +/// a custom class with the same members. +/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam Counter - integer type for histogram bin counters. +/// \tparam Level - type of histogram boundaries (levels) +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the reduction operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] samples - iterator to the first element in the range of input samples. 
+/// \param [in] size - number of elements in the samples range. +/// \param [out] histogram - pointer to the first element in the histogram range. +/// \param [in] levels - number of boundaries (levels) for histogram bins. +/// \param [in] lower_level - lower sample value bound (inclusive) for the first histogram bin. +/// \param [in] upper_level - upper sample value bound (exclusive) for the last histogram bin. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level histogram of 5 bins is computed on an array of float samples. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// unsigned int size; // e.g., 8 +/// float * samples; // e.g., [-10.0, 0.3, 9.5, 8.1, 1.5, 1.9, 100.0, 5.1] +/// int * histogram; // empty array of at least 5 elements +/// unsigned int levels; // e.g., 6 (for 5 bins) +/// float lower_level; // e.g., 0.0 +/// float upper_level; // e.g., 10.0 +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::histogram_even( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, size, +/// histogram, levels, lower_level, upper_level +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // compute histogram +/// rocprim::histogram_even( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, size, +/// histogram, levels, lower_level, upper_level +/// ); +/// // histogram: [3, 0, 1, 0, 2] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class SampleIterator, + class Counter, + class Level +> +inline +cudaError_t histogram_even(void * temporary_storage, + size_t& storage_size, + SampleIterator samples, + unsigned int size, + Counter * histogram, + unsigned int levels, + Level lower_level, + Level upper_level, + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + Counter * histogram_single[1] = { histogram }; + unsigned int levels_single[1] = { levels }; + Level lower_level_single[1] = { lower_level }; + Level upper_level_single[1] = { upper_level }; + + return detail::histogram_even_impl<1, 1, Config>( + temporary_storage, storage_size, + samples, size, 1, 0, + histogram_single, + levels_single, lower_level_single, upper_level_single, + stream, debug_synchronous + ); +} + +/// \brief Computes a histogram from a two-dimensional region of samples using equal-width bins. 
+/// +/// \par +/// * The two-dimensional region of interest within \p samples can be specified using the \p columns, +/// \p rows and \p row_stride_bytes parameters. +/// * The row stride must be a whole multiple of the sample data type size, +/// i.e., (row_stride_bytes % sizeof(std::iterator_traits::value_type)) == 0. +/// * The number of histogram bins is (\p levels - 1). +/// * Bins are evenly-segmented and include the same width of sample values: +/// (\p upper_level - \p lower_level) / (\p levels - 1). +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or +/// a custom class with the same members. +/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam Counter - integer type for histogram bin counters. +/// \tparam Level - type of histogram boundaries (levels) +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the reduction operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] samples - iterator to the first element in the range of input samples. +/// \param [in] columns - number of elements in each row of the region. +/// \param [in] rows - number of rows of the region. +/// \param [in] row_stride_bytes - number of bytes between starts of consecutive rows of the region. +/// \param [out] histogram - pointer to the first element in the histogram range. +/// \param [in] levels - number of boundaries (levels) for histogram bins. 
+/// \param [in] lower_level - lower sample value bound (inclusive) for the first histogram bin. +/// \param [in] upper_level - upper sample value bound (exclusive) for the last histogram bin. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level histogram of 5 bins is computed on an array of float samples. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// unsigned int columns; // e.g., 4 +/// unsigned int rows; // e.g., 2 +/// size_t row_stride_bytes; // e.g., 6 * sizeof(float) +/// float * samples; // e.g., [-10.0, 0.3, 9.5, 8.1, -, -, 1.5, 1.9, 100.0, 5.1, -, -] +/// int * histogram; // empty array of at least 5 elements +/// unsigned int levels; // e.g., 6 (for 5 bins) +/// float lower_level; // e.g., 0.0 +/// float upper_level; // e.g., 10.0 +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::histogram_even( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, columns, rows, row_stride_bytes, +/// histogram, levels, lower_level, upper_level +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // compute histogram +/// rocprim::histogram_even( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, columns, rows, row_stride_bytes, +/// histogram, levels, lower_level, upper_level +/// ); +/// // histogram: [3, 0, 1, 0, 2] +/// \endcode +/// \endparblock +template< 
+ class Config = default_config, + class SampleIterator, + class Counter, + class Level +> +inline +cudaError_t histogram_even(void * temporary_storage, + size_t& storage_size, + SampleIterator samples, + unsigned int columns, + unsigned int rows, + size_t row_stride_bytes, + Counter * histogram, + unsigned int levels, + Level lower_level, + Level upper_level, + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + Counter * histogram_single[1] = { histogram }; + unsigned int levels_single[1] = { levels }; + Level lower_level_single[1] = { lower_level }; + Level upper_level_single[1] = { upper_level }; + + return detail::histogram_even_impl<1, 1, Config>( + temporary_storage, storage_size, + samples, columns, rows, row_stride_bytes, + histogram_single, + levels_single, lower_level_single, upper_level_single, + stream, debug_synchronous + ); +} + +/// \brief Computes histograms from a sequence of multi-channel samples using equal-width bins. +/// +/// \par +/// * The input is a sequence of pixel structures, where each pixel comprises +/// a record of \p Channels consecutive data samples (e.g., \p Channels = 4 for RGBA samples). +/// * The first \p ActiveChannels channels of total \p Channels channels will be used for computing histograms +/// (e.g., \p ActiveChannels = 3 for computing histograms of only RGB from RGBA samples). +/// * For channeli the number of histogram bins is (\p levels[i] - 1). +/// * For channeli bins are evenly-segmented and include the same width of sample values: +/// (\p upper_level[i] - \p lower_level[i]) / (\p levels[i] - 1). +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// +/// \tparam Channels - number of channels interleaved in the input samples. +/// \tparam ActiveChannels - number of channels being used for computing histograms. +/// \tparam Config - [optional] configuration of the primitive. 
It can be \p histogram_config or +/// a custom class with the same members. +/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam Counter - integer type for histogram bin counters. +/// \tparam Level - type of histogram boundaries (levels) +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the reduction operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] samples - iterator to the first element in the range of input samples. +/// \param [in] size - number of pixels in the samples range. +/// \param [out] histogram - pointers to the first element in the histogram range, one for each active channel. +/// \param [in] levels - number of boundaries (levels) for histogram bins in each active channel. +/// \param [in] lower_level - lower sample value bound (inclusive) for the first histogram bin in each active channel. +/// \param [in] upper_level - upper sample value bound (exclusive) for the last histogram bin in each active channel. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example histograms for 3 channels (RGB) are computed on an array of 8-bit RGBA samples. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// unsigned int size; // e.g., 8 +/// unsigned char * samples; // e.g., [(3, 1, 5, 255), (3, 1, 5, 255), (4, 2, 6, 127), (3, 2, 6, 127), +/// // (0, 0, 0, 100), (0, 1, 0, 100), (0, 0, 1, 255), (0, 1, 1, 255)] +/// int * histogram[3]; // 3 empty arrays of at least 256 elements each +/// unsigned int levels[3]; // e.g., [257, 257, 257] (for 256 bins) +/// int lower_level[3]; // e.g., [0, 0, 0] +/// int upper_level[3]; // e.g., [256, 256, 256] +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::multi_histogram_even<4, 3>( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, size, +/// histogram, levels, lower_level, upper_level +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // compute histograms +/// rocprim::multi_histogram_even<4, 3>( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, size, +/// histogram, levels, lower_level, upper_level +/// ); +/// // histogram: [[4, 0, 0, 3, 1, 0, 0, ..., 0], +/// // [2, 4, 2, 0, 0, 0, 0, ..., 0], +/// // [2, 2, 0, 0, 0, 2, 2, ..., 0]] +/// \endcode +/// \endparblock +template< + unsigned int Channels, + unsigned int ActiveChannels, + class Config = default_config, + class SampleIterator, + class Counter, + class Level +> +inline +cudaError_t multi_histogram_even(void * temporary_storage, + size_t& storage_size, + SampleIterator samples, + unsigned int size, + Counter * histogram[ActiveChannels], + unsigned int levels[ActiveChannels], + Level lower_level[ActiveChannels], + Level upper_level[ActiveChannels], + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::histogram_even_impl( + temporary_storage, storage_size, + samples, size, 1, 0, + histogram, + levels, lower_level, upper_level, + stream, debug_synchronous + ); +} + +/// \brief Computes histograms from a two-dimensional 
region of multi-channel samples using equal-width bins. +/// +/// \par +/// * The two-dimensional region of interest within \p samples can be specified using the \p columns, +/// \p rows and \p row_stride_bytes parameters. +/// * The row stride must be a whole multiple of the sample data type size, +/// i.e., (row_stride_bytes % sizeof(std::iterator_traits::value_type)) == 0. +/// * The input is a sequence of pixel structures, where each pixel comprises +/// a record of \p Channels consecutive data samples (e.g., \p Channels = 4 for RGBA samples). +/// * The first \p ActiveChannels channels of total \p Channels channels will be used for computing histograms +/// (e.g., \p ActiveChannels = 3 for computing histograms of only RGB from RGBA samples). +/// * For channeli the number of histogram bins is (\p levels[i] - 1). +/// * For channeli bins are evenly-segmented and include the same width of sample values: +/// (\p upper_level[i] - \p lower_level[i]) / (\p levels[i] - 1). +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// +/// \tparam Channels - number of channels interleaved in the input samples. +/// \tparam ActiveChannels - number of channels being used for computing histograms. +/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or +/// a custom class with the same members. +/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam Counter - integer type for histogram bin counters. +/// \tparam Level - type of histogram boundaries (levels) +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the reduction operation. 
+/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] samples - iterator to the first element in the range of input samples. +/// \param [in] columns - number of elements in each row of the region. +/// \param [in] rows - number of rows of the region. +/// \param [in] row_stride_bytes - number of bytes between starts of consecutive rows of the region. +/// \param [out] histogram - pointers to the first element in the histogram range, one for each active channel. +/// \param [in] levels - number of boundaries (levels) for histogram bins in each active channel. +/// \param [in] lower_level - lower sample value bound (inclusive) for the first histogram bin in each active channel. +/// \param [in] upper_level - upper sample value bound (exclusive) for the last histogram bin in each active channel. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example histograms for 3 channels (RGB) are computed on an array of 8-bit RGBA samples. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// unsigned int columns; // e.g., 4 +/// unsigned int rows; // e.g., 2 +/// size_t row_stride_bytes; // e.g., 5 * sizeof(unsigned char) +/// unsigned char * samples; // e.g., [(3, 1, 5, 255), (3, 1, 5, 255), (4, 2, 6, 127), (3, 2, 6, 127), (-, -, -, -), +/// // (0, 0, 0, 100), (0, 1, 0, 100), (0, 0, 1, 255), (0, 1, 1, 255), (-, -, -, -)] +/// int * histogram[3]; // 3 empty arrays of at least 256 elements each +/// unsigned int levels[3]; // e.g., [257, 257, 257] (for 256 bins) +/// int lower_level[3]; // e.g., [0, 0, 0] +/// int upper_level[3]; // e.g., [256, 256, 256] +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::multi_histogram_even<4, 3>( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, columns, rows, row_stride_bytes, +/// histogram, levels, lower_level, upper_level +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // compute histograms +/// rocprim::multi_histogram_even<4, 3>( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, columns, rows, row_stride_bytes, +/// histogram, levels, lower_level, upper_level +/// ); +/// // histogram: [[4, 0, 0, 3, 1, 0, 0, ..., 0], +/// // [2, 4, 2, 0, 0, 0, 0, ..., 0], +/// // [2, 2, 0, 0, 0, 2, 2, ..., 0]] +/// \endcode +/// \endparblock +template< + unsigned int Channels, + unsigned int ActiveChannels, + class Config = default_config, + class SampleIterator, + class Counter, + class Level +> +inline +cudaError_t multi_histogram_even(void * temporary_storage, + size_t& storage_size, + SampleIterator samples, + unsigned int columns, + unsigned int rows, + size_t row_stride_bytes, + Counter * histogram[ActiveChannels], + unsigned int levels[ActiveChannels], + Level lower_level[ActiveChannels], + Level upper_level[ActiveChannels], + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + 
return detail::histogram_even_impl( + temporary_storage, storage_size, + samples, columns, rows, row_stride_bytes, + histogram, + levels, lower_level, upper_level, + stream, debug_synchronous + ); +} + +/// \brief Computes a histogram from a sequence of samples using the specified bin boundary levels. +/// +/// \par +/// * The number of histogram bins is (\p levels - 1). +/// * The range for binj is [level_values[j], level_values[j+1]). +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or +/// a custom class with the same members. +/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam Counter - integer type for histogram bin counters. +/// \tparam Level - type of histogram boundaries (levels) +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the reduction operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] samples - iterator to the first element in the range of input samples. +/// \param [in] size - number of elements in the samples range. +/// \param [out] histogram - pointer to the first element in the histogram range. +/// \param [in] levels - number of boundaries (levels) for histogram bins. +/// \param [in] level_values - pointer to the array of bin boundaries. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. 
Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level histogram of 5 bins is computed on an array of float samples. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// unsigned int size; // e.g., 8 +/// float * samples; // e.g., [-10.0, 0.3, 9.5, 8.1, 1.5, 1.9, 100.0, 5.1] +/// int * histogram; // empty array of at least 5 elements +/// unsigned int levels; // e.g., 6 (for 5 bins) +/// float * level_values; // e.g., [0.0, 1.0, 5.0, 10.0, 20.0, 50.0] +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::histogram_range( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, size, +/// histogram, levels, level_values +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // compute histogram +/// rocprim::histogram_range( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, size, +/// histogram, levels, level_values +/// ); +/// // histogram: [1, 2, 3, 0, 0] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class SampleIterator, + class Counter, + class Level +> +inline +cudaError_t histogram_range(void * temporary_storage, + size_t& storage_size, + SampleIterator samples, + unsigned int size, + Counter * histogram, + unsigned int levels, + Level * level_values, + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + Counter * histogram_single[1] = { histogram }; + unsigned int levels_single[1] = { levels }; + Level * level_values_single[1] = { level_values }; + + return detail::histogram_range_impl<1, 1, Config>( + temporary_storage, storage_size, + samples, size, 
1, 0, + histogram_single, + levels_single, level_values_single, + stream, debug_synchronous + ); +} + +/// \brief Computes a histogram from a two-dimensional region of samples using the specified bin boundary levels. +/// +/// \par +/// * The two-dimensional region of interest within \p samples can be specified using the \p columns, +/// \p rows and \p row_stride_bytes parameters. +/// * The row stride must be a whole multiple of the sample data type size, +/// i.e., (row_stride_bytes % sizeof(std::iterator_traits::value_type)) == 0. +/// * The number of histogram bins is (\p levels - 1). +/// * The range for binj is [level_values[j], level_values[j+1]). +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or +/// a custom class with the same members. +/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam Counter - integer type for histogram bin counters. +/// \tparam Level - type of histogram boundaries (levels) +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the reduction operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] samples - iterator to the first element in the range of input samples. +/// \param [in] columns - number of elements in each row of the region. +/// \param [in] rows - number of rows of the region. +/// \param [in] row_stride_bytes - number of bytes between starts of consecutive rows of the region. +/// \param [out] histogram - pointer to the first element in the histogram range. 
+/// \param [in] levels - number of boundaries (levels) for histogram bins. +/// \param [in] level_values - pointer to the array of bin boundaries. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level histogram of 5 bins is computed on an array of float samples. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// unsigned int columns; // e.g., 4 +/// unsigned int rows; // e.g., 2 +/// size_t row_stride_bytes; // e.g., 6 * sizeof(float) +/// float * samples; // e.g., [-10.0, 0.3, 9.5, 8.1, 1.5, 1.9, 100.0, 5.1] +/// int * histogram; // empty array of at least 5 elements +/// unsigned int levels; // e.g., 6 (for 5 bins) +/// float level_values; // e.g., [0.0, 1.0, 5.0, 10.0, 20.0, 50.0] +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::histogram_range( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, columns, rows, row_stride_bytes, +/// histogram, levels, level_values +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // compute histogram +/// rocprim::histogram_range( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, columns, rows, row_stride_bytes, +/// histogram, levels, level_values +/// ); +/// // histogram: [1, 2, 3, 0, 0] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class SampleIterator, + class Counter, + class 
Level +> +inline +cudaError_t histogram_range(void * temporary_storage, + size_t& storage_size, + SampleIterator samples, + unsigned int columns, + unsigned int rows, + size_t row_stride_bytes, + Counter * histogram, + unsigned int levels, + Level * level_values, + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + Counter * histogram_single[1] = { histogram }; + unsigned int levels_single[1] = { levels }; + Level * level_values_single[1] = { level_values }; + + return detail::histogram_range_impl<1, 1, Config>( + temporary_storage, storage_size, + samples, columns, rows, row_stride_bytes, + histogram_single, + levels_single, level_values_single, + stream, debug_synchronous + ); +} + +/// \brief Computes histograms from a sequence of multi-channel samples using the specified bin boundary levels. +/// +/// \par +/// * The input is a sequence of pixel structures, where each pixel comprises +/// a record of \p Channels consecutive data samples (e.g., \p Channels = 4 for RGBA samples). +/// * The first \p ActiveChannels channels of total \p Channels channels will be used for computing histograms +/// (e.g., \p ActiveChannels = 3 for computing histograms of only RGB from RGBA samples). +/// * For channeli the number of histogram bins is (\p levels[i] - 1). +/// * For channeli the range for binj is +/// [level_values[i][j], level_values[i][j+1]). +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// +/// \tparam Channels - number of channels interleaved in the input samples. +/// \tparam ActiveChannels - number of channels being used for computing histograms. +/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or +/// a custom class with the same members. +/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. 
+/// \tparam Counter - integer type for histogram bin counters. +/// \tparam Level - type of histogram boundaries (levels) +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the reduction operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] samples - iterator to the first element in the range of input samples. +/// \param [in] size - number of pixels in the samples range. +/// \param [out] histogram - pointers to the first element in the histogram range, one for each active channel. +/// \param [in] levels - number of boundaries (levels) for histogram bins in each active channel. +/// \param [in] level_values - pointer to the array of bin boundaries for each active channel. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example histograms for 3 channels (RGB) are computed on an array of 8-bit RGBA samples. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// unsigned int size; // e.g., 8 +/// unsigned char * samples; // e.g., [(0, 0, 80, 255), (120, 0, 80, 255), (123, 0, 82, 127), (10, 1, 83, 127), +/// // (51, 1, 8, 100), (52, 1, 8, 100), (53, 0, 81, 255), (54, 50, 81, 255)] +/// int * histogram[3]; // 3 empty arrays of at least 256 elements each +/// unsigned int levels[3]; // e.g., [4, 4, 3] +/// int * level_values[3]; // e.g., [[0, 50, 100, 200], [0, 20, 40, 60], [0, 10, 100]] +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::multi_histogram_range<4, 3>( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, size, +/// histogram, levels, level_values +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // compute histograms +/// rocprim::multi_histogram_range<4, 3>( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, size, +/// histogram, levels, level_values +/// ); +/// // histogram: [[2, 4, 2], [7, 0, 1], [2, 6]] +/// \endcode +/// \endparblock +template< + unsigned int Channels, + unsigned int ActiveChannels, + class Config = default_config, + class SampleIterator, + class Counter, + class Level +> +inline +cudaError_t multi_histogram_range(void * temporary_storage, + size_t& storage_size, + SampleIterator samples, + unsigned int size, + Counter * histogram[ActiveChannels], + unsigned int levels[ActiveChannels], + Level * level_values[ActiveChannels], + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::histogram_range_impl( + temporary_storage, storage_size, + samples, size, 1, 0, + histogram, + levels, level_values, + stream, debug_synchronous + ); +} + +/// \brief Computes histograms from a two-dimensional region of multi-channel samples using the specified bin +/// boundary levels. 
+/// +/// \par +/// * The two-dimensional region of interest within \p samples can be specified using the \p columns, +/// \p rows and \p row_stride_bytes parameters. +/// * The row stride must be a whole multiple of the sample data type size, +/// i.e., (row_stride_bytes % sizeof(std::iterator_traits::value_type)) == 0. +/// * The input is a sequence of pixel structures, where each pixel comprises +/// a record of \p Channels consecutive data samples (e.g., \p Channels = 4 for RGBA samples). +/// * The first \p ActiveChannels channels of total \p Channels channels will be used for computing histograms +/// (e.g., \p ActiveChannels = 3 for computing histograms of only RGB from RGBA samples). +/// * For channeli the number of histogram bins is (\p levels[i] - 1). +/// * For channeli the range for binj is +/// [level_values[i][j], level_values[i][j+1]). +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// +/// \tparam Channels - number of channels interleaved in the input samples. +/// \tparam ActiveChannels - number of channels being used for computing histograms. +/// \tparam Config - [optional] configuration of the primitive. It can be \p histogram_config or +/// a custom class with the same members. +/// \tparam SampleIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam Counter - integer type for histogram bin counters. +/// \tparam Level - type of histogram boundaries (levels) +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the reduction operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. 
+/// \param [in] samples - iterator to the first element in the range of input samples. +/// \param [in] columns - number of elements in each row of the region. +/// \param [in] rows - number of rows of the region. +/// \param [in] row_stride_bytes - number of bytes between starts of consecutive rows of the region. +/// \param [out] histogram - pointers to the first element in the histogram range, one for each active channel. +/// \param [in] levels - number of boundaries (levels) for histogram bins in each active channel. +/// \param [in] level_values - pointer to the array of bin boundaries for each active channel. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example histograms for 3 channels (RGB) are computed on an array of 8-bit RGBA samples. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// unsigned int columns; // e.g., 4 +/// unsigned int rows; // e.g., 2 +/// size_t row_stride_bytes; // e.g., 5 * sizeof(unsigned char) +/// unsigned char * samples; // e.g., [(0, 0, 80, 0), (120, 0, 80, 0), (123, 0, 82, 0), (10, 1, 83, 0), (-, -, -, -), +/// // (51, 1, 8, 0), (52, 1, 8, 0), (53, 0, 81, 0), (54, 50, 81, 0), (-, -, -, -)] +/// int * histogram[3]; // 3 empty arrays +/// unsigned int levels[3]; // e.g., [4, 4, 3] +/// int * level_values[3]; // e.g., [[0, 50, 100, 200], [0, 20, 40, 60], [0, 10, 100]] +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::multi_histogram_range<4, 3>( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, columns, rows, row_stride_bytes, +/// histogram, levels, level_values +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // compute histograms +/// rocprim::multi_histogram_range<4, 3>( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// samples, columns, rows, row_stride_bytes, +/// histogram, levels, level_values +/// ); +/// // histogram: [[2, 4, 2], [7, 0, 1], [2, 6]] +/// \endcode +/// \endparblock +template< + unsigned int Channels, + unsigned int ActiveChannels, + class Config = default_config, + class SampleIterator, + class Counter, + class Level +> +inline +cudaError_t multi_histogram_range(void * temporary_storage, + size_t& storage_size, + SampleIterator samples, + unsigned int columns, + unsigned int rows, + size_t row_stride_bytes, + Counter * histogram[ActiveChannels], + unsigned int levels[ActiveChannels], + Level * level_values[ActiveChannels], + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::histogram_range_impl( + temporary_storage, storage_size, + samples, columns, rows, row_stride_bytes, + histogram, + levels, level_values, + stream, debug_synchronous + ); +} + 
+/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_HISTOGRAM_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_histogram_config.hpp b/3rdparty/cub/rocprim/device/device_histogram_config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c4180c79b32cf72f67453dc9e6ec6c63d89232fa --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_histogram_config.hpp @@ -0,0 +1,128 @@ +// Copyright (c) 2018-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_HISTOGRAM_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_HISTOGRAM_CONFIG_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "config_types.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of device-level histogram operation. 
+/// +/// \tparam HistogramConfig - configuration of histogram kernel. Must be \p kernel_config. +/// \tparam MaxGridSize - maximim number of blocks to launch. +/// \tparam SharedImplMaxBins - maximum total number of bins for all active channels +/// for the shared memory histogram implementation (samples -> shared memory bins -> global memory bins), +/// when exceeded the global memory implementation is used (samples -> global memory bins). +template< + class HistogramConfig, + unsigned int MaxGridSize = 1024, + unsigned int SharedImplMaxBins = 2048 +> +struct histogram_config +{ +#ifndef DOXYGEN_SHOULD_SKIP_THIS + using histogram = HistogramConfig; + + static constexpr unsigned int max_grid_size = MaxGridSize; + static constexpr unsigned int shared_impl_max_bins = SharedImplMaxBins; +#endif +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template< + class HistogramConfig, + unsigned int MaxGridSize, + unsigned int SharedImplMaxBins +> constexpr unsigned int +histogram_config::max_grid_size; +template< + class HistogramConfig, + unsigned int MaxGridSize, + unsigned int SharedImplMaxBins +> constexpr unsigned int +histogram_config::shared_impl_max_bins; +#endif + +namespace detail +{ + +template +struct histogram_config_803 +{ + static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Sample), sizeof(int)); + + using type = histogram_config>; +}; + +template +struct histogram_config_900 +{ + static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Sample), sizeof(int)); + + using type = histogram_config>; +}; + +// TODO: We need to update these parameters +template +struct histogram_config_90a +{ + static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Sample), sizeof(int)); + + using type = histogram_config>; +}; + +// TODO: We need to update these parameters +template +struct histogram_config_1030 +{ + static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Sample), 
sizeof(int)); + + using type = histogram_config>; +}; + +template +struct default_histogram_config + : select_arch< + TargetArch, + select_arch_case<803, histogram_config_803 >, + select_arch_case<900, histogram_config_900 >, + select_arch_case >, + select_arch_case<1030, histogram_config_1030 >, + histogram_config_900 + > { }; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_HISTOGRAM_CONFIG_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_merge.hpp b/3rdparty/cub/rocprim/device/device_merge.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4a2a0d05098ba4273d9db734a52955271aaf56b6 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_merge.hpp @@ -0,0 +1,438 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_MERGE_HPP_ +#define ROCPRIM_DEVICE_DEVICE_MERGE_HPP_ + +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "device_merge_config.hpp" +#include "detail/device_merge.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +namespace detail +{ + +template< + class IndexIterator, + class KeysInputIterator1, + class KeysInputIterator2, + class BinaryFunction +> +ROCPRIM_KERNEL +__launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) +void partition_kernel(IndexIterator index, + KeysInputIterator1 keys_input1, + KeysInputIterator2 keys_input2, + const size_t input1_size, + const size_t input2_size, + const unsigned int spacing, + BinaryFunction compare_function) +{ + partition_kernel_impl( + index, keys_input1, keys_input2, input1_size, input2_size, + spacing, compare_function + ); +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class IndexIterator, + class KeysInputIterator1, + class KeysInputIterator2, + class KeysOutputIterator, + class ValuesInputIterator1, + class ValuesInputIterator2, + class ValuesOutputIterator, + class BinaryFunction +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void merge_kernel(IndexIterator index, + KeysInputIterator1 keys_input1, + KeysInputIterator2 keys_input2, + KeysOutputIterator keys_output, + ValuesInputIterator1 values_input1, + ValuesInputIterator2 values_input2, + ValuesOutputIterator values_output, + const size_t input1_size, + const size_t input2_size, + BinaryFunction compare_function) +{ + merge_kernel_impl( + index, keys_input1, keys_input2, keys_output, + values_input1, values_input2, values_output, + input1_size, input2_size, compare_function + ); +} + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != cudaSuccess) return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = 
cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } + +template< + class Config, + class KeysInputIterator1, + class KeysInputIterator2, + class KeysOutputIterator, + class ValuesInputIterator1, + class ValuesInputIterator2, + class ValuesOutputIterator, + class BinaryFunction +> +inline +cudaError_t merge_impl(void * temporary_storage, + size_t& storage_size, + KeysInputIterator1 keys_input1, + KeysInputIterator2 keys_input2, + KeysOutputIterator keys_output, + ValuesInputIterator1 values_input1, + ValuesInputIterator2 values_input2, + ValuesOutputIterator values_output, + const size_t input1_size, + const size_t input2_size, + BinaryFunction compare_function, + const cudaStream_t stream, + bool debug_synchronous) + +{ + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + // Get default config if Config is default_config + using config = detail::default_or_custom_config< + Config, + detail::default_merge_config + >; + + static constexpr unsigned int block_size = config::block_size; + static constexpr unsigned int half_block = block_size / 2; + static constexpr unsigned int items_per_thread = config::items_per_thread; + static constexpr auto items_per_block = block_size * items_per_thread; + + const unsigned int partitions = ((input1_size + input2_size) + items_per_block - 1) / items_per_block; + const size_t partition_bytes = (partitions + 1) * sizeof(unsigned int); + + if(temporary_storage == nullptr) + { + // storage_size is never zero + storage_size = partition_bytes; + return cudaSuccess; + } + + if( partitions == 0u ) + return cudaSuccess; + + // Start point for time measurements + std::chrono::high_resolution_clock::time_point start; + + auto number_of_blocks = partitions; 
+ if(debug_synchronous) + { + std::cout << "block_size " << block_size << '\n'; + std::cout << "number of blocks " << number_of_blocks << '\n'; + std::cout << "items_per_block " << items_per_block << '\n'; + } + + unsigned int * index = reinterpret_cast(temporary_storage); + + const unsigned partition_blocks = ((partitions + 1) + half_block - 1) / half_block; + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + detail::partition_kernel + <<>>( + index, keys_input1, keys_input2, input1_size, input2_size, + items_per_block, compare_function + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("partition_kernel", input1_size, start); + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + detail::merge_kernel + <<>>( + index, keys_input1, keys_input2, keys_output, + values_input1, values_input2, values_output, + input1_size, input2_size, compare_function + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("merge_kernel", input1_size, start); + + return cudaSuccess; +} + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + +} // end of detail namespace + +/// \brief Parallel merge primitive for device level. +/// +/// \p merge function performs a device-wide merge. +/// Function merges two ordered sets of input values based on comparison function. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the merging function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Accepts custom compare_functions for merging across the device. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p merge_config or +/// a custom class with the same members. +/// \tparam InputIterator1 - random-access iterator type of the first input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. 
+/// \tparam InputIterator2 - random-access iterator type of the second input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input1 - iterator to the first element in the first range to merge. +/// \param [in] input2 - iterator to the first element in the second range to merge. +/// \param [out] output - iterator to the first element in the output range. +/// \param [in] input1_size - number of element in the first input range. +/// \param [in] input2_size - number of element in the second input range. +/// \param [in] compare_function - binary operation function object that will be used for comparison. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The default value is \p BinaryFunction(). +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. 
+/// +/// \par Example +/// \parblock +/// In this example a device-level ascending merge is performed on an array of +/// \p int values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size1; // e.g., 4 +/// size_t input_size2; // e.g., 4 +/// int * input1; // e.g., [0, 1, 2, 3] +/// int * input2; // e.g., [0, 1, 2, 3] +/// int * output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::merge( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input1, input2, output, input_size1, input_size2 +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform merge +/// rocprim::merge( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input1, input2, output, input_size1, input_size2 +/// ); +/// // output: [0, 0, 1, 1, 2, 2, 3, 3] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator1, + class InputIterator2, + class OutputIterator, + class BinaryFunction = ::rocprim::less::value_type> +> +inline +cudaError_t merge(void * temporary_storage, + size_t& storage_size, + InputIterator1 input1, + InputIterator2 input2, + OutputIterator output, + const size_t input1_size, + const size_t input2_size, + BinaryFunction compare_function = BinaryFunction(), + const cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + empty_type * values = nullptr; + return detail::merge_impl( + temporary_storage, storage_size, + input1, input2, output, + values, values, values, + input1_size, input2_size, compare_function, + stream, debug_synchronous + ); +} + +/// \brief Parallel merge primitive for device level. +/// +/// \p merge function performs a device-wide merge of (key, value) pairs. 
+/// Function merges two ordered sets of input keys and corresponding values +/// based on key comparison function. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the merging function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Accepts custom compare_functions for merging across the device. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p merge_config or +/// a custom class with the same members. +/// \tparam KeysInputIterator1 - random-access iterator type of the first keys input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysInputIterator2 - random-access iterator type of the second keys input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator - random-access iterator type of the keys output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator1 - random-access iterator type of the first values input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator2 - random-access iterator type of the second values input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesOutputIterator - random-access iterator type of the values output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. 
+/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input1 - iterator to the first key in the first range to merge. +/// \param [in] keys_input2 - iterator to the first key in the second range to merge. +/// \param [out] keys_output - iterator to the first key in the output range. +/// \param [in] values_input1 - iterator to the first value in the first range to merge. +/// \param [in] values_input2 - iterator to the first value in the second range to merge. +/// \param [out] values_output - iterator to the first value in the output range. +/// \param [in] input1_size - number of element in the first input range. +/// \param [in] input2_size - number of element in the second input range. +/// \param [in] compare_function - binary operation function object that will be used for key comparison. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The default value is \p BinaryFunction(). +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending merge is performed on an array of +/// \p int values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t input_size1; // e.g., 4 +/// size_t input_size2; // e.g., 4 +/// int * keys_input1; // e.g., [0, 1, 2, 3] +/// int * keys_input2; // e.g., [0, 1, 2, 3] +/// int * keys_output; // empty array of 8 elements +/// int * values_input1; // e.g., [10, 11, 12, 13] +/// int * values_input2; // e.g., [20, 21, 22, 23] +/// int * values_output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::merge( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input1, keys_input2, keys_output, +/// values_input1, values_input2, values_output, +// input_size1, input_size2 +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform merge +/// rocprim::merge( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input1, keys_input2, keys_output, +/// values_input1, values_input2, values_output, +// input_size1, input_size2 +/// ); +/// // keys_output: [0, 0, 1, 1, 2, 2, 3, 3] +/// // values_output: [10, 20, 11, 21, 12, 22, 13, 23] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class KeysInputIterator1, + class KeysInputIterator2, + class KeysOutputIterator, + class ValuesInputIterator1, + class ValuesInputIterator2, + class ValuesOutputIterator, + class BinaryFunction = ::rocprim::less::value_type> +> +inline +cudaError_t merge(void * temporary_storage, + size_t& storage_size, + KeysInputIterator1 keys_input1, + KeysInputIterator2 keys_input2, + KeysOutputIterator keys_output, + ValuesInputIterator1 values_input1, + ValuesInputIterator2 values_input2, + ValuesOutputIterator values_output, + const size_t input1_size, + const size_t input2_size, + BinaryFunction compare_function = BinaryFunction(), + const cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::merge_impl( + 
temporary_storage, storage_size, + keys_input1, keys_input2, keys_output, + values_input1, values_input2, values_output, + input1_size, input2_size, compare_function, + stream, debug_synchronous + ); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_MERGE_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_merge_config.hpp b/3rdparty/cub/rocprim/device/device_merge_config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..07c7ffc684091ed201f14435ba87b882c4fdc0d7 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_merge_config.hpp @@ -0,0 +1,159 @@ +// Copyright (c) 2018-2019 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_MERGE_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_MERGE_CONFIG_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "config_types.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of device-level merge primitives. +template +using merge_config = kernel_config; + +namespace detail +{ + +template +struct merge_config_803 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + // TODO Tune when merge-by-key is ready + using type = merge_config<256, ::rocprim::max(1u, 10u / item_scale)>; +}; + +template +struct merge_config_803 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); + + using type = select_type< + select_type_case >, + select_type_case >, + select_type_case >, + merge_config<256, ::rocprim::max(1u, 10u / item_scale)> + >; +}; + +template +struct merge_config_900 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + // TODO Tune when merge-by-key is ready + using type = merge_config<256, ::rocprim::max(1u, 10u / item_scale)>; +}; + +template +struct merge_config_900 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); + + using type = select_type< + select_type_case >, + select_type_case >, + select_type_case >, + merge_config<256, ::rocprim::max(1u, 10u / item_scale)> + >; +}; + +// TODO: We need to update these parameters +template +struct merge_config_90a +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + // TODO Tune when merge-by-key is ready + using type = merge_config<256, ::rocprim::max(1u, 10u / item_scale)>; +}; + +template +struct 
merge_config_90a +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); + + using type = select_type< + select_type_case >, + select_type_case >, + select_type_case >, + merge_config<256, ::rocprim::max(1u, 10u / item_scale)> + >; +}; + +// TODO: We need to update these parameters +template +struct merge_config_1030 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + // TODO Tune when merge-by-key is ready + using type = merge_config<256, ::rocprim::max(1u, 10u / item_scale)>; +}; + +template +struct merge_config_1030 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); + + using type = select_type< + select_type_case >, + select_type_case >, + select_type_case >, + merge_config<256, ::rocprim::max(1u, 10u / item_scale)> + >; +}; + +template +struct default_merge_config + : select_arch< + TargetArch, + select_arch_case<803, merge_config_803>, + select_arch_case<900, merge_config_900>, + select_arch_case>, + select_arch_case<1030, merge_config_1030>, + merge_config_900 + > { }; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_MERGE_CONFIG_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_merge_sort.hpp b/3rdparty/cub/rocprim/device/device_merge_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..36e5190ac25b6182b8b10e338f43c8aeaff20eca --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_merge_sort.hpp @@ -0,0 +1,590 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_SORT_HPP_ +#define ROCPRIM_DEVICE_DEVICE_SORT_HPP_ + +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "detail/device_merge.hpp" +#include "detail/device_merge_sort.hpp" +#include "detail/device_merge_sort_mergepath.hpp" +#include "device_transform.hpp" +#include "device_merge_sort_config.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +namespace detail +{ + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class OffsetT, + class BinaryFunction +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void block_sort_kernel(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const OffsetT size, + BinaryFunction compare_function) +{ + block_sort_kernel_impl( + keys_input, keys_output, values_input, values_output, + size, compare_function + ); +} + +template< + unsigned int BlockSize, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class OffsetT, + class BinaryFunction +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void block_merge_kernel(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const OffsetT input_size, + const OffsetT sorted_block_size, + BinaryFunction compare_function) +{ + block_merge_kernel_impl(keys_input, + keys_output, + values_input, + values_output, + input_size, + sorted_block_size, + compare_function); +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class OffsetT, + class BinaryFunction +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) 
+void block_merge_kernel(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const OffsetT input_size, + const OffsetT sorted_block_size, + BinaryFunction compare_function, + const OffsetT* merge_partitions) +{ + block_merge_kernel_impl(keys_input, + keys_output, + values_input, + values_output, + input_size, + sorted_block_size, + compare_function, + merge_partitions); +} + +#define ROCPRIM_DETAIL_HIP_SYNC(name, size, start) \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto error = cudaStreamSynchronize(stream); \ + if(error != cudaSuccess) return error; \ + auto end = std::chrono::high_resolution_clock::now(); \ + auto d = std::chrono::duration_cast>(end - start); \ + std::cout << " " << d.count() * 1000 << " ms" << '\n'; \ + } + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != cudaSuccess) return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } + +template +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void device_mergepath_partition_kernel(KeysInputIterator keys, + const OffsetT input_size, + const unsigned int num_partitions, + OffsetT *merge_partitions, + const CompareOpT compare_op, + const OffsetT sorted_block_size) +{ + const OffsetT partition_id = blockIdx.x * BlockSize + threadIdx.x; + + if (partition_id >= num_partitions) + { + return; + } + + const unsigned int merged_tiles = sorted_block_size / ItemsPerTile; + const unsigned int target_merged_tiles = merged_tiles * 2; + const unsigned int mask = target_merged_tiles - 1; + const 
unsigned int tilegroup_start_id = ~mask & partition_id; // id of the first tile in the current tile-group + const OffsetT tilegroup_start = ItemsPerTile * tilegroup_start_id; // index of the first item in the current tile-group + + const unsigned int local_tile_id = mask & partition_id; // id of the current tile in the current tile-group + + const OffsetT keys1_beg = rocprim::min(input_size, tilegroup_start); + const OffsetT keys1_end = rocprim::min(input_size, tilegroup_start + sorted_block_size); + const OffsetT keys2_beg = keys1_end; + const OffsetT keys2_end = rocprim::min(input_size, keys2_beg + sorted_block_size); + + const OffsetT partition_at = rocprim::min(keys2_end - keys1_beg, ItemsPerTile * local_tile_id); + + const OffsetT partition_diag = ::rocprim::detail::merge_path(keys + keys1_beg, + keys + keys2_beg, + keys1_end - keys1_beg, + keys2_end - keys2_beg, + partition_at, + compare_op); + + merge_partitions[partition_id] = keys1_beg + partition_diag; +} + +template< + class Config, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class BinaryFunction +> +inline +cudaError_t merge_sort_impl(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const unsigned int size, + BinaryFunction compare_function, + const cudaStream_t stream, + bool debug_synchronous) +{ + using OffsetT = unsigned int; + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + constexpr bool with_values = !std::is_same::value; + + // Get default config if Config is default_config + using config = default_or_custom_config< + Config, + default_merge_sort_config + >; + + static constexpr unsigned int sort_block_size = config::sort_config::block_size; + static constexpr unsigned int sort_items_per_thread = 
config::sort_config::items_per_thread; + static constexpr unsigned int sort_items_per_block = sort_block_size * sort_items_per_thread; + + static constexpr unsigned int merge_impl1_block_size = config::merge_impl1_config::block_size; + static constexpr unsigned int merge_impl1_items_per_thread = config::merge_impl1_config::items_per_thread; + static constexpr unsigned int merge_impl1_items_per_block = merge_impl1_block_size * merge_impl1_items_per_thread; + + static constexpr unsigned int merge_partition_block_size = config::merge_mergepath_partition_config::block_size; + static constexpr unsigned int merge_mergepath_block_size = config::merge_mergepath_config::block_size; + static constexpr unsigned int merge_mergepath_items_per_thread = config::merge_mergepath_config::items_per_thread; + static constexpr unsigned int merge_mergepath_items_per_block = merge_mergepath_block_size * merge_mergepath_items_per_thread; + + static_assert(merge_mergepath_items_per_block >= sort_items_per_block, + "merge_mergepath_items_per_block must be greater than or equal to sort_items_per_block"); + static_assert(sort_items_per_block % config::merge_impl1_config::block_size == 0, + "Merge block size must be a divisor of the items per block of the sort step"); + + const size_t keys_bytes = ::rocprim::detail::align_size(size * sizeof(key_type)); + const size_t values_bytes = with_values ? 
::rocprim::detail::align_size(size * sizeof(value_type)) : 0; + + const unsigned int sort_number_of_blocks = ceiling_div(size, sort_items_per_block); + const unsigned int merge_impl1_number_of_blocks = ceiling_div(size, merge_impl1_items_per_block); + const unsigned int merge_mergepath_number_of_blocks = ceiling_div(size, merge_mergepath_items_per_block); + + bool use_mergepath = size > config::min_input_size_mergepath; + // variables below used for mergepath + const unsigned int merge_num_partitions = merge_mergepath_number_of_blocks + 1; + const unsigned int merge_partition_number_of_blocks = ceiling_div(merge_num_partitions, merge_partition_block_size); + const size_t d_merge_partitions_bytes = use_mergepath ? merge_num_partitions * sizeof(OffsetT) : 0; + + if(temporary_storage == nullptr) + { + storage_size = d_merge_partitions_bytes + keys_bytes + values_bytes; + // Make sure user won't try to allocate 0 bytes memory + storage_size = storage_size == 0 ? 4 : storage_size; + return cudaSuccess; + } + + if( size == size_t(0) ) + return cudaSuccess; + + if(debug_synchronous) + { + std::cout << "-----" << '\n'; + std::cout << "size: " << size << '\n'; + std::cout << "sort_block_size: " << sort_block_size << '\n'; + std::cout << "sort_items_per_thread: " << sort_items_per_thread << '\n'; + std::cout << "sort_items_per_block: " << sort_items_per_block << '\n'; + std::cout << "sort_number_of_blocks: " << sort_number_of_blocks << '\n'; + std::cout << "merge_impl1_block_size: " << merge_impl1_block_size << '\n'; + std::cout << "merge_impl1_number_of_blocks: " << merge_impl1_number_of_blocks << '\n'; + std::cout << "merge_impl1_items_per_thread: " << merge_impl1_items_per_thread << '\n'; + std::cout << "merge_impl1_items_per_block: " << merge_impl1_items_per_block << '\n'; + std::cout << "merge_mergepath_block_size: " << merge_mergepath_block_size << '\n'; + std::cout << "merge_mergepath_number_of_blocks: " << merge_mergepath_number_of_blocks << '\n'; + std::cout << 
"merge_mergepath_items_per_thread: " << merge_mergepath_items_per_thread << '\n'; + std::cout << "merge_mergepath_items_per_block: " << merge_mergepath_items_per_block << '\n'; + std::cout << "num_partitions: " << merge_num_partitions << '\n'; + std::cout << "merge_mergepath_partition_block_size: " << merge_partition_block_size << '\n'; + std::cout << "merge_mergepath_partition_number_of_blocks: " << merge_partition_number_of_blocks << '\n'; + } + + char* ptr = reinterpret_cast(temporary_storage); + OffsetT* d_merge_partitions = reinterpret_cast(ptr); + ptr += d_merge_partitions_bytes; + key_type * keys_buffer = reinterpret_cast(ptr); + ptr += keys_bytes; + value_type * values_buffer = with_values ? reinterpret_cast(ptr) : nullptr; + + // Start point for time measurements + std::chrono::high_resolution_clock::time_point start; + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + + block_sort_kernel + <<>>( + keys_input, keys_buffer, values_input, values_buffer, + size, compare_function + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("block_sort_kernel", size, start); + + bool temporary_store = true; + for(OffsetT block = sort_items_per_block; block < size; block *= 2) + { + temporary_store = !temporary_store; + + const auto merge_step = [&](auto keys_input_, + auto keys_output_, + auto values_input_, + auto values_output_) -> cudaError_t { + if(use_mergepath) + { + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + device_mergepath_partition_kernel + <<>>( + keys_input_, size, merge_num_partitions, d_merge_partitions, + compare_function, block); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("device_mergepath_partition_kernel", size, start); + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + block_merge_kernel + <<>>( + keys_input_, keys_output_, values_input_, values_output_, + size, block, compare_function, d_merge_partitions + ); + 
ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("block_merge_kernel", size, start); + } + else + { + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + block_merge_kernel + <<>>( + keys_input_, keys_output_, values_input_, values_output_, + size, block, compare_function + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("block_merge_kernel", size, start) + } + return cudaSuccess; + }; + + cudaError_t error; + if(temporary_store) + { + error = merge_step(keys_output, keys_buffer, values_output, values_buffer); + } + else + { + error = merge_step(keys_buffer, keys_output, values_buffer, values_output); + } + if(error != cudaSuccess) return error; + } + + if(temporary_store) + { + cudaError_t error = ::rocprim::transform( + keys_buffer, keys_output, size, + ::rocprim::identity(), stream, debug_synchronous + ); + if(error != cudaSuccess) return error; + + if(with_values) + { + cudaError_t error = ::rocprim::transform( + values_buffer, values_output, size, + ::rocprim::identity(), stream, debug_synchronous + ); + if(error != cudaSuccess) return error; + } + } + + return cudaSuccess; +} + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR +#undef ROCPRIM_DETAIL_HIP_SYNC + +} // end of detail namespace + +/// \brief Parallel merge sort primitive for device level. +/// +/// \p merge_sort function performs a device-wide merge sort +/// of keys. Function sorts input keys based on comparison function. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Accepts custom compare_functions for sorting across the device. +/// +/// \tparam KeysInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator - random-access iterator type of the output range. 
Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input - pointer to the first element in the range to sort. +/// \param [out] keys_output - pointer to the first element in the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] compare_function - binary operation function object that will be used for comparison. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The default value is \p BinaryFunction(). +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending merge sort is performed on an array of +/// \p float values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// float * input; // e.g., [0.6, 0.3, 0.65, 0.4, 0.2, 0.08, 1, 0.7] +/// float * output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::merge_sort( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::merge_sort( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size +/// ); +/// // keys_output: [0.08, 0.2, 0.3, 0.4, 0.6, 0.65, 0.7, 1] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class KeysInputIterator, + class KeysOutputIterator, + class BinaryFunction = ::rocprim::less::value_type> +> +inline +cudaError_t merge_sort(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + const size_t size, + BinaryFunction compare_function = BinaryFunction(), + const cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + empty_type * values = nullptr; + return detail::merge_sort_impl( + temporary_storage, storage_size, + keys_input, keys_output, values, values, size, + compare_function, stream, debug_synchronous + ); +} + +/// \brief Parallel ascending merge sort-by-key primitive for device level. +/// +/// \p merge_sort function performs a device-wide merge sort +/// of (key, value) pairs. Function sorts input pairs based on comparison function. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Accepts custom compare_functions for sorting across the device. 
+/// +/// \tparam KeysInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input - pointer to the first element in the range to sort. +/// \param [out] keys_output - pointer to the first element in the output range. +/// \param [in] values_input - pointer to the first element in the range to sort. +/// \param [out] values_output - pointer to the first element in the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] compare_function - binary operation function object that will be used for comparison. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The default value is \p BinaryFunction(). +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). 
+/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending merge sort is performed where input keys are +/// represented by an array of unsigned integers and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 2, 7] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// unsigned int * keys_output; // empty array of 8 elements +/// double * values_output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::merge_sort( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::merge_sort( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size +/// ); +/// // keys_output: [ 1, 2, 3, 4, 5, 6, 7, 8] +/// // values_output: [-1, -2, 2, 3, -4, -5, 7, -8] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class BinaryFunction = ::rocprim::less::value_type> +> +inline +cudaError_t merge_sort(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, 
+ KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const size_t size, + BinaryFunction compare_function = BinaryFunction(), + const cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::merge_sort_impl( + temporary_storage, storage_size, + keys_input, keys_output, values_input, values_output, size, + compare_function, stream, debug_synchronous + ); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_SORT_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_merge_sort_config.hpp b/3rdparty/cub/rocprim/device/device_merge_sort_config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..83b9003794368a3c67353daced4945d17a710502 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_merge_sort_config.hpp @@ -0,0 +1,223 @@ +// Copyright (c) 2018-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_MERGE_SORT_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_MERGE_SORT_CONFIG_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" +#include "../functional.hpp" + +#include "config_types.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + template + struct merge_sort_config_impl + { + using sort_config = kernel_config; + using merge_impl1_config = kernel_config; + using merge_mergepath_partition_config = kernel_config; + using merge_mergepath_config + = kernel_config; + static constexpr unsigned int min_input_size_mergepath = MinInputSizeMergepath; + }; +} + +/// \brief Configuration of device-level merge primitives. 
+/// +/// \tparam SortBlockSize - block size in the block-sort step +/// \tparam SortItemsPerThread - ItemsPerThread in the block-sort step +/// \tparam MergeImpl1BlockSize - block size in the block merge step using impl1 (used when input_size < MinInputSizeMergepath) +/// \tparam MergeImplMPPartitionBlockSize - block size of the partition kernel in the block merge step using mergepath impl +/// \tparam MergeImplMPBlockSize - block size in the block merge step using mergepath impl +/// \tparam MergeImplMPItemsPerThread - ItemsPerThread in the block merge step using mergepath impl +/// \tparam MinInputSizeMergepath - breakpoint of input-size to use mergepath impl for block merge step +template +using merge_sort_config = detail::merge_sort_config_impl; + +namespace detail +{ + +template +struct merge_sort_config_803 +{ + using type = select_type< + select_type_case< + (sizeof(Key) == 1 && sizeof(Value) <= 8), + merge_sort_config<64U> + >, + select_type_case< + (sizeof(Key) == 2 && sizeof(Value) <= 8), + merge_sort_config<256U> + >, + select_type_case< + (sizeof(Key) == 4 && sizeof(Value) <= 8), + merge_sort_config<512U> + >, + select_type_case< + (sizeof(Key) == 8 && sizeof(Value) <= 8), + merge_sort_config<1024U> + >, + merge_sort_config::value> + >; +}; + +template +struct merge_sort_config_803 +{ + using type = merge_sort_config::value>; +}; + +template +struct merge_sort_config_803 + : select_type< + select_type_case >, + select_type_case >, + select_type_case >, + select_type_case= 8, merge_sort_config::value> > + > { }; + +template<> +struct merge_sort_config_803 +{ + using type = merge_sort_config<256U>; +}; + +template::value> +struct merge_sort_config_900 +{ + using type = select_type< + // clang-format off + select_type_case<(sizeof(Key) == 1 && sizeof(Value) <= 16), merge_sort_config<512U, 512U, 2U>>, + select_type_case<(sizeof(Key) == 2 && sizeof(Value) <= 16), merge_sort_config<512U, 256U, 4U>>, + select_type_case<(sizeof(Key) == 4 && sizeof(Value) <= 
16), merge_sort_config<512U, 256U, 4U>>, + select_type_case<(sizeof(Key) == 8 && sizeof(Value) <= 16), merge_sort_config<256U, 256U, 4U>>, + // clang-format on + merge_sort_config< + limit_block_size<1024U, + ::rocprim::max(sizeof(Key) + sizeof(unsigned int), sizeof(Value)), + ROCPRIM_WARP_SIZE_64>::value>>; +}; + +template +struct merge_sort_config_900 +{ + using type = select_type< + // clang-format off + select_type_case<(sizeof(Key) == 8 && sizeof(Value) <= 16), merge_sort_config<512U, 512U, 2U>>, + select_type_case<(sizeof(Key) == 16 && sizeof(Value) <= 16), merge_sort_config<512U, 512U, 2U>>, + // clang-format on + merge_sort_config< + limit_block_size<512U, + ::rocprim::max(sizeof(Key) + sizeof(unsigned int), sizeof(Value)), + ROCPRIM_WARP_SIZE_64>::value>>; +}; + +// TODO: We need to update these parameters +template +struct merge_sort_config_1030 +{ + using type = select_type< + select_type_case< + (sizeof(Key) == 1 && sizeof(Value) <= 8), + merge_sort_config<64U> + >, + select_type_case< + (sizeof(Key) == 2 && sizeof(Value) <= 8), + merge_sort_config<256U> + >, + select_type_case< + (sizeof(Key) == 4 && sizeof(Value) <= 8), + merge_sort_config<512U> + >, + select_type_case< + (sizeof(Key) == 8 && sizeof(Value) <= 8), + merge_sort_config<1024U> + >, + merge_sort_config::value> + >; +}; + +template +struct merge_sort_config_1030 +{ + using type = merge_sort_config::value>; +}; + +template +struct merge_sort_config_1030 + : select_type< + select_type_case >, + select_type_case >, + select_type_case >, + select_type_case= 8, merge_sort_config::value> > + > { }; + +template<> +struct merge_sort_config_1030 +{ + using type = merge_sort_config<256U>; +}; + +template +struct default_merge_sort_config + : select_arch< + TargetArch, + select_arch_case<803, merge_sort_config_803>, + select_arch_case<900, merge_sort_config_900>, + select_arch_case<1030, merge_sort_config_1030>, + merge_sort_config_900 + > { }; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + 
+/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_MERGE_SORT_CONFIG_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_partition.hpp b/3rdparty/cub/rocprim/device/device_partition.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8be346fc4ff94e6516b65989d540a242e9ea25fb --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_partition.hpp @@ -0,0 +1,707 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_PARTITION_HPP_ +#define ROCPRIM_DEVICE_DEVICE_PARTITION_HPP_ + +#include +#include +#include + +#include "../config.hpp" +#include "../functional.hpp" +#include "../types.hpp" +#include "../type_traits.hpp" +#include "../detail/various.hpp" + +#include "device_select_config.hpp" +#include "detail/device_scan_common.hpp" +#include "detail/device_partition.hpp" +#include "device_transform.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +namespace detail +{ + +template< + select_method SelectMethod, + bool OnlySelected, + class Config, + class KeyIterator, + class ValueIterator, + class FlagIterator, + class OutputKeyIterator, + class OutputValueIterator, + class InequalityOp, + class OffsetLookbackScanState, + class... UnaryPredicates +> +ROCPRIM_KERNEL +__launch_bounds__(Config::block_size) +void partition_kernel(KeyIterator keys_input, + ValueIterator values_input, + FlagIterator flags, + OutputKeyIterator keys_output, + OutputValueIterator values_output, + size_t* selected_count, + size_t* prev_selected_count, + const size_t size, + InequalityOp inequality_op, + OffsetLookbackScanState offset_scan_state, + const unsigned int number_of_blocks, + ordered_block_id ordered_bid, + UnaryPredicates... predicates) +{ + partition_kernel_impl( + keys_input, values_input, flags, keys_output, values_output, selected_count, prev_selected_count, + size, inequality_op, offset_scan_state, number_of_blocks, ordered_bid, predicates... 
+ ); +} + +#define ROCPRIM_DETAIL_HIP_SYNC(name, size, start) \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto error = cudaStreamSynchronize(stream); \ + if(error != cudaSuccess) return error; \ + auto end = std::chrono::high_resolution_clock::now(); \ + auto d = std::chrono::duration_cast>(end - start); \ + std::cout << " " << d.count() * 1000 << " ms" << '\n'; \ + } + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != cudaSuccess) return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } + +template< + // Method of selection: flag, predicate, unique + select_method SelectMethod, + // if true, it doesn't copy rejected values to output + bool OnlySelected, + class Config, + class OffsetT, + class KeyIterator, + class ValueIterator, // can be rocprim::empty_type* for key only + class FlagIterator, + class OutputKeyIterator, + class OutputValueIterator, // can be rocprim::empty_type* for key only + class InequalityOp, + class SelectedCountOutputIterator, + class... UnaryPredicates +> +inline +cudaError_t partition_impl(void * temporary_storage, + size_t& storage_size, + KeyIterator keys_input, + ValueIterator values_input, + FlagIterator flags, + OutputKeyIterator keys_output, + OutputValueIterator values_output, + SelectedCountOutputIterator selected_count_output, + const size_t size, + InequalityOp inequality_op, + const cudaStream_t stream, + bool debug_synchronous, + UnaryPredicates... 
predicates) +{ + using offset_type = OffsetT; + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + // Get default config if Config is default_config + using config = default_or_custom_config< + Config, + default_select_config + >; + + using offset_scan_state_type = detail::lookback_scan_state; + using offset_scan_state_with_sleep_type = detail::lookback_scan_state; + using ordered_block_id_type = detail::ordered_block_id; + + + static constexpr unsigned int block_size = config::block_size; + static constexpr unsigned int items_per_thread = config::items_per_thread; + static constexpr auto items_per_block = block_size * items_per_thread; + + static constexpr bool is_three_way = sizeof...(UnaryPredicates) == 2; + + static constexpr size_t size_limit = config::size_limit; + static constexpr size_t aligned_size_limit = ::rocprim::max(size_limit - (size_limit % items_per_block), items_per_block); + const size_t limited_size = std::min(size, aligned_size_limit); + const bool use_limited_size = limited_size == aligned_size_limit; + + const unsigned int number_of_blocks = + static_cast(::rocprim::detail::ceiling_div(limited_size, items_per_block)); + + // Calculate required temporary storage + size_t offset_scan_state_bytes = ::rocprim::detail::align_size( + // This is valid even with offset_scan_state_with_sleep_type + offset_scan_state_type::get_storage_size(number_of_blocks) + ); + size_t ordered_block_id_bytes = ::rocprim::detail::align_size( + ordered_block_id_type::get_storage_size(), + alignof(size_t) + ); + + if(temporary_storage == nullptr) + { + // storage_size is never zero + storage_size = offset_scan_state_bytes + ordered_block_id_bytes + (sizeof(size_t) * 2 * (is_three_way ? 
2 : 1)); + + return cudaSuccess; + } + + // Start point for time measurements + std::chrono::high_resolution_clock::time_point start; + + // Create and initialize lookback_scan_state obj + auto offset_scan_state = offset_scan_state_type::create( + temporary_storage, number_of_blocks + ); + auto offset_scan_state_with_sleep = offset_scan_state_with_sleep_type::create( + temporary_storage, number_of_blocks + ); + // Create ad initialize ordered_block_id obj + auto ptr = reinterpret_cast(temporary_storage); + auto ordered_bid = ordered_block_id_type::create( + reinterpret_cast(ptr + offset_scan_state_bytes) + ); + + size_t* selected_count = reinterpret_cast(ptr + offset_scan_state_bytes + + ordered_block_id_bytes); + size_t* prev_selected_count + = reinterpret_cast(ptr + offset_scan_state_bytes + ordered_block_id_bytes + + (is_three_way ? 2 : 1) * sizeof(size_t)); + + cudaError_t error; + + // Memset selected_count and prev_selected_count at once + error = cudaMemsetAsync(selected_count, + 0, + sizeof(*selected_count) * 2 * (is_three_way ? 
2 : 1), + stream); + if (error != cudaSuccess) return error; + + cudaDeviceProp prop; + int deviceId; + static_cast(cudaGetDevice(&deviceId)); + static_cast(cudaGetDeviceProperties(&prop, deviceId)); + + + int asicRevision = 0; + + + const size_t number_of_launches = ::rocprim::detail::ceiling_div(size, aligned_size_limit); + + if(debug_synchronous) + { + std::cout << "use_limited_size " << use_limited_size << '\n'; + std::cout << "aligned_size_limit " << aligned_size_limit << '\n'; + std::cout << "number_of_launches " << number_of_launches << '\n'; + std::cout << "size " << size << '\n'; + std::cout << "block_size " << block_size << '\n'; + std::cout << "number of blocks " << number_of_blocks << '\n'; + std::cout << "items_per_block " << items_per_block << '\n'; + } + + for (size_t i = 0, offset = 0; i < number_of_launches; i++, offset+=limited_size) + { + const unsigned int current_size = static_cast(std::min(size - offset, limited_size)); + + const unsigned int current_number_of_blocks = ::rocprim::detail::ceiling_div(current_size, items_per_block); + + auto grid_size = ::rocprim::detail::ceiling_div(number_of_blocks, block_size); + + if(debug_synchronous) + { + std::cout << "current size " << current_size << '\n'; + std::cout << "current number of blocks " << current_number_of_blocks << '\n'; + + start = std::chrono::high_resolution_clock::now(); + } + + + + init_lookback_scan_state_kernel + <<>>( + offset_scan_state, current_number_of_blocks, ordered_bid + ); + + + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("init_offset_scan_state_kernel", current_number_of_blocks, start) + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + + grid_size = current_number_of_blocks; + + + + partition_kernel< + SelectMethod, OnlySelected, config + > + <<>>( + keys_input + offset, values_input + offset, flags + offset, keys_output, values_output, selected_count, prev_selected_count, + current_size, inequality_op, offset_scan_state, 
current_number_of_blocks, ordered_bid, predicates... + ); + + + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("partition_kernel", size, start) + + std::swap(selected_count, prev_selected_count); + } + + error = ::rocprim::transform( + prev_selected_count, selected_count_output, (is_three_way ? 2 : 1), + ::rocprim::identity<>{}, + stream, debug_synchronous + ); + if (error != cudaSuccess) return error; + + return cudaSuccess; +} + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR +#undef ROCPRIM_DETAIL_HIP_SYNC + +} // end of detail namespace + +/// \brief Parallel select primitive for device level using range of flags. +/// +/// Performs a device-wide partition based on input \p flags. Partition copies +/// the values from \p input to \p output in such a way that all values for which the corresponding +/// items from /p flags are \p true (or can be implicitly converted to \p true) precede +/// the elements for which the corresponding items from /p flags are \p false. +/// +/// \par Overview +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p input, \p flags and \p output must have at least \p size elements. +/// * Range specified by \p selected_count_output must have at least 1 element. +/// * Values of \p flag range should be implicitly convertible to `bool` type. +/// * Relative order is preserved for the elements for which the corresponding values from \p flags +/// are \p true. Other elements are copied in reverse order. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p select_config or +/// a custom class with the same members. +/// \tparam InputIterator - random-access iterator type of the input range. It can be +/// a simple pointer type. +/// \tparam FlagIterator - random-access iterator type of the flag range. It can be +/// a simple pointer type. 
+/// \tparam OutputIterator - random-access iterator type of the output range. It can be
+/// a simple pointer type.
+/// \tparam SelectedCountOutputIterator - random-access iterator type of the selected_count_output
+/// value. It can be a simple pointer type.
+///
+/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When
+/// a null pointer is passed, the required allocation size (in bytes) is written to
+/// \p storage_size and function returns without performing the select operation.
+/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage.
+/// \param [in] input - iterator to the first element in the range to select values from.
+/// \param [in] flags - iterator to the selection flag corresponding to the first element from \p input range.
+/// \param [out] output - iterator to the first element in the output range.
+/// \param [out] selected_count_output - iterator to the total number of selected values (length of \p output).
+/// \param [in] size - number of elements in the input range.
+/// \param [in] stream - [optional] CUDA stream object. The default is \p 0 (default stream).
+/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel
+/// launch is forced in order to check for errors. The default value is \p false.
+///
+/// \par Example
+/// \parblock
+/// In this example a device-level partition operation is performed on an array of
+/// integer values with array of chars used as flags.
+///
+/// \code{.cpp}
+/// #include <rocprim/rocprim.hpp>
+///
+/// // Prepare input and output (declare pointers, allocate device memory etc.)
+/// size_t input_size;          // e.g., 8
+/// int * input;                // e.g., [1, 2, 3, 4, 5, 6, 7, 8]
+/// char * flags;               // e.g., [0, 1, 1, 0, 0, 1, 0, 1]
+/// int * output;               // empty array of 8 elements
+/// size_t * output_count;      // empty array of 1 element
+///
+/// size_t temporary_storage_size_bytes;
+/// void * temporary_storage_ptr = nullptr;
+/// // Get required size of the temporary storage
+/// rocprim::partition(
+///     temporary_storage_ptr, temporary_storage_size_bytes,
+///     input, flags,
+///     output, output_count,
+///     input_size
+/// );
+///
+/// // allocate temporary storage
+/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes);
+///
+/// // perform partition
+/// rocprim::partition(
+///     temporary_storage_ptr, temporary_storage_size_bytes,
+///     input, flags,
+///     output, output_count,
+///     input_size
+/// );
+/// // output: [2, 3, 6, 8, 7, 5, 4, 1]
+/// // output_count: 4
+/// \endcode
+/// \endparblock
+template<
+    class Config = default_config,
+    class InputIterator,
+    class FlagIterator,
+    class OutputIterator,
+    class SelectedCountOutputIterator
+>
+inline
+cudaError_t partition(void * temporary_storage,
+                      size_t& storage_size,
+                      InputIterator input,
+                      FlagIterator flags,
+                      OutputIterator output,
+                      SelectedCountOutputIterator selected_count_output,
+                      const size_t size,
+                      const cudaStream_t stream = 0,
+                      const bool debug_synchronous = false)
+{
+    // Dummy unary predicate: selection is driven entirely by the flags range
+    using unary_predicate_type = ::rocprim::empty_type;
+    // Dummy inequality operation: not used by the flag-based select method
+    using inequality_op_type = ::rocprim::empty_type;
+    using offset_type = unsigned int;
+    rocprim::empty_type* const no_values = nullptr; // key only
+
+    // NOTE(review): explicit template arguments of partition_impl appear to have been
+    // stripped by the extraction (offset_type is declared but unused here) — verify
+    // against the upstream rocPRIM source before relying on this text.
+    return detail::partition_impl(
+        temporary_storage, storage_size, input, no_values, flags, output, no_values, selected_count_output,
+        size, inequality_op_type(), stream, debug_synchronous, unary_predicate_type()
+    );
+}
+
+/// \brief Parallel select primitive for device level using selection predicate.
+///
+/// Performs a device-wide partition using selection predicate. Partition copies
+/// the values from \p input to \p output in such a way that all values for which
+/// the \p predicate returns \p true precede the elements for which it returns \p false.
+///
+/// \par Overview
+/// * Returns the required size of \p temporary_storage in \p storage_size
+/// if \p temporary_storage is a null pointer.
+/// * Ranges specified by \p input and \p output must have at least \p size elements.
+/// * Range specified by \p selected_count_output must have at least 1 element.
+/// * Relative order is preserved for the elements for which the \p predicate returns \p true. Other
+/// elements are copied in reverse order.
+///
+/// \tparam Config - [optional] configuration of the primitive. It can be \p select_config or
+/// a custom class with the same members.
+/// \tparam InputIterator - random-access iterator type of the input range. It can be
+/// a simple pointer type.
+/// \tparam OutputIterator - random-access iterator type of the output range. It can be
+/// a simple pointer type.
+/// \tparam SelectedCountOutputIterator - random-access iterator type of the selected_count_output
+/// value. It can be a simple pointer type.
+/// \tparam UnaryPredicate - type of a unary selection predicate.
+///
+/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When
+/// a null pointer is passed, the required allocation size (in bytes) is written to
+/// \p storage_size and function returns without performing the select operation.
+/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage.
+/// \param [in] input - iterator to the first element in the range to select values from.
+/// \param [out] output - iterator to the first element in the output range.
+/// \param [out] selected_count_output - iterator to the total number of selected values (length of \p output).
+/// \param [in] size - number of elements in the input range.
+/// \param [in] predicate - unary function object which returns \p true if the element should be
+/// ordered before other elements.
+/// The signature of the function should be equivalent to the following:
+/// bool f(const T &a);. The signature does not need to have
+/// const &, but function object must not modify the object passed to it.
+/// \param [in] stream - [optional] CUDA stream object. The default is \p 0 (default stream).
+/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel
+/// launch is forced in order to check for errors. The default value is \p false.
+///
+/// \par Example
+/// \parblock
+/// In this example a device-level partition operation is performed on an array of
+/// integer values, even values are copied before odd values.
+///
+/// \code{.cpp}
+/// #include <rocprim/rocprim.hpp>
+///
+/// auto predicate =
+///     [] __device__ (int a) -> bool
+///     {
+///         return (a%2) == 0;
+///     };
+///
+/// // Prepare input and output (declare pointers, allocate device memory etc.)
+/// size_t input_size;          // e.g., 8
+/// int * input;                // e.g., [1, 2, 3, 4, 5, 6, 7, 8]
+/// int * output;               // empty array of 8 elements
+/// size_t * output_count;      // empty array of 1 element
+///
+/// size_t temporary_storage_size_bytes;
+/// void * temporary_storage_ptr = nullptr;
+/// // Get required size of the temporary storage
+/// rocprim::partition(
+///     temporary_storage_ptr, temporary_storage_size_bytes,
+///     input,
+///     output, output_count,
+///     input_size,
+///     predicate
+/// );
+///
+/// // allocate temporary storage
+/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes);
+///
+/// // perform partition
+/// rocprim::partition(
+///     temporary_storage_ptr, temporary_storage_size_bytes,
+///     input,
+///     output, output_count,
+///     input_size,
+///     predicate
+/// );
+/// // output: [2, 4, 6, 8, 7, 5, 3, 1]
+/// // output_count: 4
+/// \endcode
+/// \endparblock
+template<
+    class Config = default_config,
+    class InputIterator,
+    class OutputIterator,
+    class SelectedCountOutputIterator,
+    class UnaryPredicate
+>
+inline
+cudaError_t partition(void * temporary_storage,
+                      size_t& storage_size,
+                      InputIterator input,
+                      OutputIterator output,
+                      SelectedCountOutputIterator selected_count_output,
+                      const size_t size,
+                      UnaryPredicate predicate,
+                      const cudaStream_t stream = 0,
+                      const bool debug_synchronous = false)
+{
+    // Dummy flag type: this overload selects via the predicate, not a flags range
+    using flag_type = ::rocprim::empty_type;
+    flag_type * flags = nullptr;
+    // Dummy inequality operation: not used by the predicate-based select method
+    using inequality_op_type = ::rocprim::empty_type;
+    using offset_type = unsigned int;
+    rocprim::empty_type* const no_values = nullptr; // key only
+
+    // NOTE(review): explicit template arguments of partition_impl appear to have been
+    // stripped by the extraction (offset_type is declared but unused here) — verify
+    // against the upstream rocPRIM source before relying on this text.
+    return detail::partition_impl(
+        temporary_storage, storage_size, input, no_values, flags, output, no_values, selected_count_output,
+        size, inequality_op_type(), stream, debug_synchronous, predicate
+    );
+}
+
+/// \brief Parallel select primitive for device level using two selection predicates.
+/// +/// Performs a device-wide three-way partition using two selection predicates. Partition copies +/// the values from \p input to either \p output_first_part or \p output_second_part or +/// \p output_unselected according to the following criteria: +/// The value is copied to \p output_first_part if the predicate \p select_first_part_op invoked +/// with the value returns \p true. It is copied to \p output_second_part if \p select_first_part_op +/// returns \p false and \p select_second_part_op returns \p true, and it is copied to +/// \p output_unselected otherwise. +/// +/// \par Overview +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * Range specified by \p selected_count_output must have at least 2 elements. +/// * Relative order is preserved for the elements. +/// * The number of elements written to \p output_first_part is equal to the number of elements +/// in the input for which \p select_first_part_op returned \p true. +/// * The number of elements written to \p output_second_part is equal to the number of elements +/// in the input for which \p select_first_part_op returned \p false and \p select_second_part_op +/// returned \p true. +/// * The number of elements written to \p output_unselected is equal to the number of input elements +/// minus the number of elements written to \p output_first_part minus the number of elements written +/// to \p output_second_part. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p select_config or +/// a custom class with the same members. +/// \tparam InputIterator - random-access iterator type of the input range. It can be +/// a simple pointer type. +/// \tparam FirstOutputIterator - random-access iterator type of the first output range. It can be +/// a simple pointer type. +/// \tparam SecondOutputIterator - random-access iterator type of the second output range. 
It can be +/// a simple pointer type. +/// \tparam UnselectedOutputIterator - random-access iterator type of the unselected output range. +/// It can be a simple pointer type. +/// \tparam SelectedCountOutputIterator - random-access iterator type of the selected_count_output +/// value. It can be a simple pointer type. +/// \tparam FirstUnaryPredicate - type of the first unary selection predicate. +/// \tparam SecondUnaryPredicate - type of the second unary selection predicate. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the select operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to select values from. +/// \param [out] output_first_part - iterator to the first element in the first output range. +/// \param [out] output_second_part - iterator to the first element in the second output range. +/// \param [out] output_unselected - iterator to the first element in the unselected output range. +/// \param [out] selected_count_output - iterator to the total number of selected values in +/// \p output_first_part and \p output_second_part respectively. +/// \param [in] size - number of element in the input range. +/// \param [in] select_first_part_op - unary function object which returns \p true if the element +/// should be in \p output_first_part range +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a);. The signature does not need to have +/// const &, but function object must not modify the object passed to it. 
+/// \param [in] select_second_part_op - unary function object which returns \p true if the element +/// should be in \p output_second_part range (given that \p select_first_part_op returned \p false) +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a);. The signature does not need to have +/// const &, but function object must not modify the object passed to it. +/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. The default value is \p false. +/// +/// \par Example +/// \parblock +/// In this example a device-level three-way partition operation is performed on an array of +/// integer values, even values are copied to the first partition, odd and 3-divisible values +/// are copied to the second partition, and the rest of the values are copied to the +/// unselected partition +/// +/// \code{.cpp} +/// #include +/// +/// auto first_predicate = +/// [] __device__ (int a) -> bool +/// { +/// return (a%2) == 0; +/// }; +/// auto second_predicate = +/// [] __device__ (int a) -> bool +/// { +/// return (a%3) == 0; +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// int * input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int * output_first_part; // array of 8 elements +/// int * output_second_part; // array of 8 elements +/// int * output_unselected; // array of 8 elements +/// size_t * output_count; // array of 2 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::partition_three_way( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, +/// output_first_part, output_second_part, output_unselected, +/// output_count, +/// input_size, +/// first_predicate, +/// second_predicate +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform partition +/// rocprim::partition_three_way( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, +/// output_first_part, output_second_part, output_unselected, +/// output_count, +/// input_size, +/// first_predicate, +/// second_predicate +/// ); +/// // elements denoted by '*' were not modified +/// // output_first_part: [2, 4, 6, 8, *, *, *, *] +/// // output_second_part: [3, *, *, *, *, *, *, *] +/// // output_unselected: [1, 5, 7, *, *, *, *, *] +/// // output_count: [4, 1] +/// \endcode +/// \endparblock +template < + class Config = default_config, + typename InputIterator, + typename FirstOutputIterator, + typename SecondOutputIterator, + typename UnselectedOutputIterator, + typename SelectedCountOutputIterator, + typename FirstUnaryPredicate, + typename SecondUnaryPredicate> +inline +cudaError_t partition_three_way(void * temporary_storage, + size_t& storage_size, + InputIterator input, + FirstOutputIterator output_first_part, + SecondOutputIterator output_second_part, + UnselectedOutputIterator output_unselected, + SelectedCountOutputIterator selected_count_output, + const size_t size, + FirstUnaryPredicate 
select_first_part_op, + SecondUnaryPredicate select_second_part_op, + const cudaStream_t stream = 0, + const bool debug_synchronous = false) +{ + // Dummy flag type + using flag_type = ::rocprim::empty_type; + flag_type * flags = nullptr; + // Dummy inequality operation + using inequality_op_type = ::rocprim::empty_type; + using offset_type = uint2; + using output_key_iterator_tuple = tuple< + FirstOutputIterator, + SecondOutputIterator, + UnselectedOutputIterator>; + using output_value_iterator_tuple + = tuple<::rocprim::empty_type*, ::rocprim::empty_type*, ::rocprim::empty_type*>; + rocprim::empty_type* const no_input_values = nullptr; // key only + const output_value_iterator_tuple no_output_values {nullptr, nullptr, nullptr}; // key only + + output_key_iterator_tuple output{ output_first_part, output_second_part, output_unselected }; + + return detail::partition_impl( + temporary_storage, storage_size, input, no_input_values, flags, output, no_output_values, selected_count_output, + size, inequality_op_type(), stream, debug_synchronous, + select_first_part_op, select_second_part_op + ); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_PARTITION_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_radix_sort.hpp b/3rdparty/cub/rocprim/device/device_radix_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..71c5f997e656fb477f3c8a705946d88467c7ddd9 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_radix_sort.hpp @@ -0,0 +1,1677 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_RADIX_SORT_HPP_ +#define ROCPRIM_DEVICE_DEVICE_RADIX_SORT_HPP_ + +#include +#include +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" +#include "../detail/radix_sort.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" +#include "../types.hpp" + +#include "device_radix_sort_config.hpp" +#include "device_transform.hpp" +#include "detail/device_radix_sort.hpp" +#include "specialization/device_radix_single_sort.hpp" +#include "specialization/device_radix_merge_sort.hpp" + +/// \addtogroup devicemodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int RadixBits, + bool Descending, + class KeysInputIterator, + class Offset +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void fill_digit_counts_kernel(KeysInputIterator keys_input, + Offset size, + Offset * batch_digit_counts, + unsigned int bit, + unsigned int current_radix_bits, + unsigned int blocks_per_full_batch, + unsigned int full_batches) +{ + fill_digit_counts( + keys_input, size, + batch_digit_counts, + bit, current_radix_bits, + blocks_per_full_batch, full_batches + ); +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int RadixBits, + class Offset +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void scan_batches_kernel(Offset * batch_digit_counts, + Offset * digit_counts, + unsigned int batches) +{ + scan_batches(batch_digit_counts, digit_counts, batches); +} + +template< + unsigned int RadixBits, + class Offset +> +ROCPRIM_KERNEL +__launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) +void scan_digits_kernel(Offset * digit_counts) +{ + scan_digits(digit_counts); +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int RadixBits, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class Offset +> 
+ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void sort_and_scatter_kernel(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + Offset size, + const Offset * batch_digit_starts, + const Offset * digit_starts, + unsigned int bit, + unsigned int current_radix_bits, + unsigned int blocks_per_full_batch, + unsigned int full_batches) +{ + sort_and_scatter( + keys_input, keys_output, values_input, values_output, size, + batch_digit_starts, digit_starts, + bit, current_radix_bits, + blocks_per_full_batch, full_batches + ); +} + +#ifndef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != cudaSuccess) return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } + +#endif + +template< + class Config, + unsigned int RadixBits, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class Offset +> +inline +cudaError_t radix_sort_iteration(KeysInputIterator keys_input, + typename std::iterator_traits::value_type * keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type * values_tmp, + ValuesOutputIterator values_output, + Offset size, + Offset * batch_digit_counts, + Offset * digit_counts, + bool from_input, + bool to_output, + unsigned int bit, + unsigned int end_bit, + unsigned int blocks_per_full_batch, + unsigned int full_batches, + unsigned int batches, + cudaStream_t stream, + bool debug_synchronous) +{ + 
constexpr unsigned int radix_size = 1 << RadixBits; + + // Handle cases when (end_bit - bit) is not divisible by RadixBits, i.e. the last + // iteration has a shorter mask. + const unsigned int current_radix_bits = ::rocprim::min(RadixBits, end_bit - bit); + + std::chrono::high_resolution_clock::time_point start; + + if(debug_synchronous) + { + std::cout << "RadixBits " << RadixBits << '\n'; + std::cout << "bit " << bit << '\n'; + std::cout << "current_radix_bits " << current_radix_bits << '\n'; + } + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + if(from_input) + { + fill_digit_counts_kernel< + Config::sort::block_size, Config::sort::items_per_thread, RadixBits, Descending + > + <<>>( + keys_input, size, + batch_digit_counts, + bit, current_radix_bits, + blocks_per_full_batch, full_batches + ); + } + else + { + if(to_output) + { + fill_digit_counts_kernel< + Config::sort::block_size, Config::sort::items_per_thread, RadixBits, Descending + > + <<>>( + keys_tmp, size, + batch_digit_counts, + bit, current_radix_bits, + blocks_per_full_batch, full_batches + ); + } + else + { + fill_digit_counts_kernel< + Config::sort::block_size, Config::sort::items_per_thread, RadixBits, Descending + > + <<>>( + keys_output, size, + batch_digit_counts, + bit, current_radix_bits, + blocks_per_full_batch, full_batches + ); + } + } + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("fill_digit_counts", size, start) + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + scan_batches_kernel + <<>>( + batch_digit_counts, digit_counts, batches + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("scan_batches", radix_size * Config::scan::block_size, start) + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + scan_digits_kernel + <<>>( + digit_counts + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("scan_digits", radix_size, start) + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + 
if(from_input) + { + if(to_output) + { + sort_and_scatter_kernel< + Config::sort::block_size, Config::sort::items_per_thread, RadixBits, Descending + > + <<>>( + keys_input, keys_output, values_input, values_output, size, + const_cast(batch_digit_counts), + const_cast(digit_counts), + bit, current_radix_bits, + blocks_per_full_batch, full_batches + ); + } + else + { + sort_and_scatter_kernel< + Config::sort::block_size, Config::sort::items_per_thread, RadixBits, Descending + > + <<>>( + keys_input, keys_tmp, values_input, values_tmp, size, + const_cast(batch_digit_counts), + const_cast(digit_counts), + bit, current_radix_bits, + blocks_per_full_batch, full_batches + ); + } + } + else + { + if(to_output) + { + sort_and_scatter_kernel< + Config::sort::block_size, Config::sort::items_per_thread, RadixBits, Descending + > + <<>>( + keys_tmp, keys_output, values_tmp, values_output, size, + const_cast(batch_digit_counts), + const_cast(digit_counts), + bit, current_radix_bits, + blocks_per_full_batch, full_batches + ); + } + else + { + sort_and_scatter_kernel< + Config::sort::block_size, Config::sort::items_per_thread, RadixBits, Descending + > + <<>>( + keys_output, keys_tmp, values_output, values_tmp, size, + const_cast(batch_digit_counts), + const_cast(digit_counts), + bit, current_radix_bits, + blocks_per_full_batch, full_batches + ); + } + } + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("sort_and_scatter", size, start) + + return cudaSuccess; +} + +template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator +> +inline +cudaError_t radix_sort_single_impl(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + bool& is_result_in_output, + unsigned int begin_bit, + unsigned int end_bit, + cudaStream_t stream, + bool 
debug_synchronous) +{ + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + using config = default_or_custom_config< + Config, + default_radix_sort_config + >; + + const size_t minimum_bytes = ::rocprim::detail::align_size(1); + if(temporary_storage == nullptr) + { + storage_size = minimum_bytes; + return cudaSuccess; + } + + if( size == 0u ) + return cudaSuccess; + + if(debug_synchronous) + { + std::cout << "temporary_storage " << temporary_storage << '\n'; + cudaError_t error = cudaStreamSynchronize(stream); + if(error != cudaSuccess) return error; + } + + cudaError_t error = radix_sort_single( + keys_input, keys_output, values_input, values_output, size, + begin_bit, end_bit, + stream, debug_synchronous + ); + if(error != cudaSuccess) return error; + + is_result_in_output = true; + return cudaSuccess; +} + +template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator +> +inline +cudaError_t radix_sort_merge_impl(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + typename std::iterator_traits::value_type * keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type * values_tmp, + ValuesOutputIterator values_output, + unsigned int size, + bool& is_result_in_output, + unsigned int begin_bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) +{ + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + using config = default_or_custom_config< + Config, + default_radix_sort_config + >; + + constexpr bool with_values = !std::is_same::value; + + const bool with_double_buffer = keys_tmp != nullptr; + const size_t keys_bytes = ::rocprim::detail::align_size(size * sizeof(key_type)); + const size_t values_bytes = 
with_values ? ::rocprim::detail::align_size(size * sizeof(value_type)) : 0; + + const size_t minimum_bytes = ::rocprim::detail::align_size(1); + if(temporary_storage == nullptr) + { + if(!with_double_buffer) + storage_size = keys_bytes + values_bytes; + else + storage_size = minimum_bytes; + return cudaSuccess; + } + + if(debug_synchronous) + { + std::cout << "temporary_storage " << temporary_storage << '\n'; + cudaError_t error = cudaStreamSynchronize(stream); + if(error != cudaSuccess) return error; + } + + if(!with_double_buffer) + { + char * ptr = reinterpret_cast(temporary_storage); + keys_tmp = reinterpret_cast(ptr); + ptr += keys_bytes; + values_tmp = with_values ? reinterpret_cast(ptr) : nullptr; + } + + cudaError_t error = radix_sort_merge( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, size, + begin_bit, end_bit, + stream, debug_synchronous + ); + if(error != cudaSuccess) return error; + + is_result_in_output = true; + return cudaSuccess; +} + +template +using offset_type_t = std::conditional_t< + sizeof(Size) <= 4, + unsigned int, + size_t +>; + +template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class Size +> +inline +cudaError_t radix_sort_iterations_impl(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + typename std::iterator_traits::value_type * keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type * values_tmp, + ValuesOutputIterator values_output, + Size size, + bool& is_result_in_output, + unsigned int begin_bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) +{ + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + using offset_type = offset_type_t; + + using config = default_or_custom_config< + Config, + 
default_radix_sort_config + >; + + constexpr bool with_values = !std::is_same::value; + + constexpr unsigned int max_radix_size = 1 << config::long_radix_bits; + + constexpr unsigned int scan_size = config::scan::block_size * config::scan::items_per_thread; + constexpr unsigned int sort_size = config::sort::block_size * config::sort::items_per_thread; + + const unsigned int blocks = static_cast(::rocprim::detail::ceiling_div(size, sort_size)); + const unsigned int blocks_per_full_batch = ::rocprim::detail::ceiling_div(blocks, scan_size); + const unsigned int full_batches = blocks % scan_size != 0 + ? blocks % scan_size + : scan_size; + const unsigned int batches = (blocks_per_full_batch == 1 ? full_batches : scan_size); + const bool with_double_buffer = keys_tmp != nullptr; + + const unsigned int bits = end_bit - begin_bit; + const unsigned int iterations = ::rocprim::detail::ceiling_div(bits, config::long_radix_bits); + const unsigned int radix_bits_diff = config::long_radix_bits - config::short_radix_bits; + const unsigned int short_iterations = radix_bits_diff != 0 + ? ::rocprim::min(iterations, (config::long_radix_bits * iterations - bits) / std::max(1u, radix_bits_diff)) + : 0; + const unsigned int long_iterations = iterations - short_iterations; + + const size_t batch_digit_counts_bytes = + ::rocprim::detail::align_size(batches * max_radix_size * sizeof(offset_type)); + const size_t digit_counts_bytes = ::rocprim::detail::align_size(max_radix_size * sizeof(offset_type)); + const size_t keys_bytes = ::rocprim::detail::align_size(size * sizeof(key_type)); + const size_t values_bytes = with_values ? 
::rocprim::detail::align_size(size * sizeof(value_type)) : 0; + if(temporary_storage == nullptr) + { + storage_size = batch_digit_counts_bytes + digit_counts_bytes; + if(!with_double_buffer) + { + storage_size += keys_bytes + values_bytes; + } + return cudaSuccess; + } + + if( size == 0u ) + return cudaSuccess; + + if(debug_synchronous) + { + std::cout << "scan_size " << scan_size << '\n'; + std::cout << "sort_size " << sort_size << '\n'; + std::cout << "blocks " << blocks << '\n'; + std::cout << "blocks_per_full_batch " << blocks_per_full_batch << '\n'; + std::cout << "full_batches " << full_batches << '\n'; + std::cout << "batches " << batches << '\n'; + std::cout << "iterations " << iterations << '\n'; + std::cout << "long_iterations " << long_iterations << '\n'; + std::cout << "short_iterations " << short_iterations << '\n'; + cudaError_t error = cudaStreamSynchronize(stream); + if(error != cudaSuccess) return error; + } + + char * ptr = reinterpret_cast(temporary_storage); + offset_type * batch_digit_counts = reinterpret_cast(ptr); + ptr += batch_digit_counts_bytes; + offset_type * digit_counts = reinterpret_cast(ptr); + ptr += digit_counts_bytes; + if(!with_double_buffer) + { + keys_tmp = reinterpret_cast(ptr); + ptr += keys_bytes; + values_tmp = with_values ? 
reinterpret_cast(ptr) : nullptr; + } + + bool to_output = with_double_buffer || (iterations - 1) % 2 == 0; + bool from_input = true; + if(!with_double_buffer && to_output) + { + // Copy input keys and values if necessary (in-place sorting: input and output iterators are equal) + const bool keys_equal = ::rocprim::detail::are_iterators_equal(keys_input, keys_output); + const bool values_equal = with_values && ::rocprim::detail::are_iterators_equal(values_input, values_output); + if(keys_equal || values_equal) + { + cudaError_t error = ::rocprim::transform( + keys_input, keys_tmp, size, + ::rocprim::identity(), stream, debug_synchronous + ); + if(error != cudaSuccess) return error; + + if(with_values) + { + cudaError_t error = ::rocprim::transform( + values_input, values_tmp, size, + ::rocprim::identity(), stream, debug_synchronous + ); + if(error != cudaSuccess) return error; + } + + from_input = false; + } + } + + unsigned int bit = begin_bit; + for(unsigned int i = 0; i < long_iterations; i++) + { + cudaError_t error = radix_sort_iteration( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, + static_cast(size), batch_digit_counts, digit_counts, + from_input, to_output, + bit, end_bit, + blocks_per_full_batch, full_batches, batches, + stream, debug_synchronous + ); + if(error != cudaSuccess) return error; + + is_result_in_output = to_output; + from_input = false; + to_output = !to_output; + bit += config::long_radix_bits; + } + for(unsigned int i = 0; i < short_iterations; i++) + { + cudaError_t error = radix_sort_iteration( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, + static_cast(size), batch_digit_counts, digit_counts, + from_input, to_output, + bit, end_bit, + blocks_per_full_batch, full_batches, batches, + stream, debug_synchronous + ); + if(error != cudaSuccess) return error; + + is_result_in_output = to_output; + from_input = false; + to_output = !to_output; + bit += config::short_radix_bits; + } + 
+ return cudaSuccess; +} + +template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class Size +> +inline +cudaError_t radix_sort_impl(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + typename std::iterator_traits::value_type * keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type * values_tmp, + ValuesOutputIterator values_output, + Size size, + bool& is_result_in_output, + unsigned int begin_bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) +{ + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + static_assert( + std::is_same::value_type>::value, + "KeysInputIterator and KeysOutputIterator must have the same value_type" + ); + static_assert( + std::is_same::value_type>::value, + "ValuesInputIterator and ValuesOutputIterator must have the same value_type" + ); + + using config = default_or_custom_config< + Config, + default_radix_sort_config + >; + + constexpr unsigned int single_sort_limit = config::sort_single::block_size * config::sort_single::items_per_thread; + constexpr unsigned int merge_sort_limit = config::sort_merge::block_size * config::sort_merge::items_per_thread * config::merge_size_limit_blocks; + + if( size <= single_sort_limit ) + { + return radix_sort_single_impl( + temporary_storage, + storage_size, + keys_input, + keys_output, + values_input, + values_output, + static_cast(size), + is_result_in_output, + begin_bit, + end_bit, + stream, + debug_synchronous + ); + } + else if( size <= merge_sort_limit ) + { + return radix_sort_merge_impl( + temporary_storage, + storage_size, + keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + static_cast(size), + is_result_in_output, + begin_bit, + end_bit, + stream, 
+ debug_synchronous + ); + } + else + { + return radix_sort_iterations_impl( + temporary_storage, + storage_size, + keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + size, + is_result_in_output, + begin_bit, + end_bit, + stream, + debug_synchronous + ); + } +} + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + +} // end namespace detail + +/// \brief Parallel ascending radix sort primitive for device level. +/// +/// \p radix_sort_keys function performs a device-wide radix sort +/// of keys. Function sorts input keys in ascending order. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be +/// an arithmetic type (that is, an integral type or a floating-point type). +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p radix_sort_config or +/// a custom class with the same members. +/// \tparam KeysInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam Size - integral type that represents the problem size. 
+/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input - pointer to the first element in the range to sort. +/// \param [out] keys_output - pointer to the first element in the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending radix sort is performed on an array of +/// \p float values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// float * input; // e.g., [0.6, 0.3, 0.65, 0.4, 0.2, 0.08, 1, 0.7] +/// float * output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_keys( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_keys( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size +/// ); +/// // keys_output: [0.08, 0.2, 0.3, 0.4, 0.6, 0.65, 0.7, 1] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class KeysInputIterator, + class KeysOutputIterator, + class Size, + class Key = typename std::iterator_traits::value_type +> +inline +cudaError_t radix_sort_keys(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + empty_type * values = nullptr; + bool ignored; + return detail::radix_sort_impl( + temporary_storage, storage_size, + keys_input, nullptr, keys_output, + values, nullptr, values, + size, ignored, + begin_bit, end_bit, + stream, debug_synchronous + ); +} + +/// \brief Parallel descending radix sort primitive for device level. +/// +/// \p radix_sort_keys_desc function performs a device-wide radix sort +/// of keys. Function sorts input keys in descending order. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. 
+/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be +/// an arithmetic type (that is, an integral type or a floating-point type). +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p radix_sort_config or +/// a custom class with the same members. +/// \tparam KeysInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam Size - integral type that represents the problem size. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input - pointer to the first element in the range to sort. +/// \param [out] keys_output - pointer to the first element in the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. 
+/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed on an array of +/// integer values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// int * input; // e.g., [6, 3, 5, 4, 2, 8, 1, 7] +/// int * output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size +/// ); +/// // keys_output: [8, 7, 6, 5, 4, 3, 2, 1] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class KeysInputIterator, + class KeysOutputIterator, + class Size, + class Key = typename std::iterator_traits::value_type +> +inline +cudaError_t radix_sort_keys_desc(void * temporary_storage, + size_t& storage_size, 
+ KeysInputIterator keys_input, + KeysOutputIterator keys_output, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + empty_type * values = nullptr; + bool ignored; + return detail::radix_sort_impl( + temporary_storage, storage_size, + keys_input, nullptr, keys_output, + values, nullptr, values, + size, ignored, + begin_bit, end_bit, + stream, debug_synchronous + ); +} + +/// \brief Parallel ascending radix sort-by-key primitive for device level. +/// +/// \p radix_sort_pairs_desc function performs a device-wide radix sort +/// of (key, value) pairs. Function sorts input pairs in ascending order of keys. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be +/// an arithmetic type (that is, an integral type or a floating-point type). +/// * Ranges specified by \p keys_input, \p keys_output, \p values_input and \p values_output must +/// have at least \p size elements. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p radix_sort_config or +/// a custom class with the same members. +/// \tparam KeysInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator - random-access iterator type of the output range. 
Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam Size - integral type that represents the problem size. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input - pointer to the first element in the range to sort. +/// \param [out] keys_output - pointer to the first element in the output range. +/// \param [in] values_input - pointer to the first element in the range to sort. +/// \param [out] values_output - pointer to the first element in the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). 
+/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending radix sort is performed where input keys are +/// represented by an array of unsigned integers and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// unsigned int * keys_output; // empty array of 8 elements +/// double * values_output; // empty array of 8 elements +/// +/// // Keys are in range [0; 8], so we can limit compared bit to bits on indexes +/// // 0, 1, 2, 3, and 4. In order to do this begin_bit is set to 0 and end_bit +/// // is set to 5. 
+/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, 0, 5 +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, 0, 5 +/// ); +/// // keys_output: [ 1, 1, 3, 4, 5, 6, 7, 8] +/// // values_output: [-1, -2, 2, 3, -4, -5, 7, -8] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class Size, + class Key = typename std::iterator_traits::value_type +> +inline +cudaError_t radix_sort_pairs(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool ignored; + return detail::radix_sort_impl( + temporary_storage, storage_size, + keys_input, nullptr, keys_output, + values_input, nullptr, values_output, + size, ignored, + begin_bit, end_bit, + stream, debug_synchronous + ); +} + +/// \brief Parallel descending radix sort-by-key primitive for device level. +/// +/// \p radix_sort_pairs_desc function performs a device-wide radix sort +/// of (key, value) pairs. Function sorts input pairs in descending order of keys. 
+/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be +/// an arithmetic type (that is, an integral type or a floating-point type). +/// * Ranges specified by \p keys_input, \p keys_output, \p values_input and \p values_output must +/// have at least \p size elements. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p radix_sort_config or +/// a custom class with the same members. +/// \tparam KeysInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam Size - integral type that represents the problem size. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. 
When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input - pointer to the first element in the range to sort. +/// \param [out] keys_output - pointer to the first element in the output range. +/// \param [in] values_input - pointer to the first element in the range to sort. +/// \param [out] values_output - pointer to the first element in the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed where input keys are +/// represented by an array of integers and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// int * keys_output; // empty array of 8 elements +/// double * values_output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size +/// ); +/// // keys_output: [ 8, 7, 6, 5, 4, 3, 1, 1] +/// // values_output: [-8, 7, -5, -4, 3, 2, -1, -2] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class Size, + class Key = typename std::iterator_traits::value_type +> +inline +cudaError_t radix_sort_pairs_desc(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool ignored; + return detail::radix_sort_impl( + temporary_storage, storage_size, + keys_input, nullptr, keys_output, + values_input, nullptr, values_output, + size, ignored, + begin_bit, end_bit, + stream, debug_synchronous + ); +} + +/// \brief Parallel ascending 
radix sort primitive for device level. +/// +/// \p radix_sort_keys function performs a device-wide radix sort +/// of keys. Function sorts input keys in ascending order. +/// +/// \par Overview +/// * The contents of both buffers of \p keys may be altered by the sorting function. +/// * \p current() of \p keys is used as the input. +/// * The function will update \p current() of \p keys to point to the buffer +/// that contains the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point +/// type). +/// * Buffers of \p keys must have at least \p size elements. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p radix_sort_config or +/// a custom class with the same members. +/// \tparam Key - key type. Must be an integral type or a floating-point type. +/// \tparam Size - integral type that represents the problem size. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys - reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. 
+/// \param [in] size - number of element in the input range. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending radix sort is performed on an array of +/// \p float values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// float * input; // e.g., [0.6, 0.3, 0.65, 0.4, 0.2, 0.08, 1, 0.7] +/// float * tmp; // empty array of 8 elements +/// // Create double-buffer +/// rocprim::double_buffer keys(input, tmp); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_keys( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_keys( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size +/// ); +/// // keys.current(): [0.08, 0.2, 0.3, 0.4, 0.6, 0.65, 0.7, 1] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class Key, + class Size +> +inline +cudaError_t radix_sort_keys(void * temporary_storage, + size_t& storage_size, + double_buffer& keys, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + empty_type * values = nullptr; + bool is_result_in_output; + cudaError_t error = detail::radix_sort_impl( + temporary_storage, storage_size, + keys.current(), keys.current(), keys.alternate(), + values, values, values, + size, is_result_in_output, + begin_bit, end_bit, + stream, debug_synchronous + ); + if(temporary_storage != nullptr && is_result_in_output) + { + keys.swap(); + } + return error; +} + +/// \brief Parallel descending radix sort primitive for device level. +/// +/// \p radix_sort_keys_desc function performs a device-wide radix sort +/// of keys. Function sorts input keys in descending order. +/// +/// \par Overview +/// * The contents of both buffers of \p keys may be altered by the sorting function. 
+/// * \p current() of \p keys is used as the input. +/// * The function will update \p current() of \p keys to point to the buffer +/// that contains the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point +/// type). +/// * Buffers of \p keys must have at least \p size elements. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p radix_sort_config or +/// a custom class with the same members. +/// \tparam Key - key type. Must be an integral type or a floating-point type. +/// \tparam Size - integral type that represents the problem size. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys - reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. 
+/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed on an array of +/// integer values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// int * input; // e.g., [6, 3, 5, 4, 2, 8, 1, 7] +/// int * tmp; // empty array of 8 elements +/// // Create double-buffer +/// rocprim::double_buffer keys(input, tmp); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size +/// ); +/// // keys.current(): [8, 7, 6, 5, 4, 3, 2, 1] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class Key, + class Size +> +inline +cudaError_t radix_sort_keys_desc(void * temporary_storage, + size_t& storage_size, + double_buffer& keys, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + 
cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + empty_type * values = nullptr; + bool is_result_in_output; + cudaError_t error = detail::radix_sort_impl( + temporary_storage, storage_size, + keys.current(), keys.current(), keys.alternate(), + values, values, values, + size, is_result_in_output, + begin_bit, end_bit, + stream, debug_synchronous + ); + if(temporary_storage != nullptr && is_result_in_output) + { + keys.swap(); + } + return error; +} + +/// \brief Parallel ascending radix sort-by-key primitive for device level. +/// +/// \p radix_sort_pairs_desc function performs a device-wide radix sort +/// of (key, value) pairs. Function sorts input pairs in ascending order of keys. +/// +/// \par Overview +/// * The contents of both buffers of \p keys and \p values may be altered by the sorting function. +/// * \p current() of \p keys and \p values are used as the input. +/// * The function will update \p current() of \p keys and \p values to point to buffers +/// that contains the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point +/// type). +/// * Buffers of \p keys must have at least \p size elements. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p radix_sort_config or +/// a custom class with the same members. +/// \tparam Key - key type. 
Must be an integral type or a floating-point type. +/// \tparam Value - value type. +/// \tparam Size - integral type that represents the problem size. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys - reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in,out] values - reference to the double-buffer of values, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. 
+/// +/// \par Example +/// \parblock +/// In this example a device-level ascending radix sort is performed where input keys are +/// represented by an array of unsigned integers and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// unsigned int * keys_tmp; // empty array of 8 elements +/// double* values_tmp; // empty array of 8 elements +/// // Create double-buffers +/// rocprim::double_buffer keys(keys_input, keys_tmp); +/// rocprim::double_buffer values(values_input, values_tmp); +/// +/// // Keys are in range [0; 8], so we can limit compared bit to bits on indexes +/// // 0, 1, 2, 3, and 4. In order to do this begin_bit is set to 0 and end_bit +/// // is set to 5. +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size, +/// 0, 5 +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size, +/// 0, 5 +/// ); +/// // keys.current(): [ 1, 1, 3, 4, 5, 6, 7, 8] +/// // values.current(): [-1, -2, 2, 3, -4, -5, 7, -8] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class Key, + class Value, + class Size +> +inline +cudaError_t radix_sort_pairs(void * temporary_storage, + size_t& storage_size, + double_buffer& keys, + double_buffer& values, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool 
debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool is_result_in_output; + cudaError_t error = detail::radix_sort_impl( + temporary_storage, storage_size, + keys.current(), keys.current(), keys.alternate(), + values.current(), values.current(), values.alternate(), + size, is_result_in_output, + begin_bit, end_bit, + stream, debug_synchronous + ); + if(temporary_storage != nullptr && is_result_in_output) + { + keys.swap(); + values.swap(); + } + return error; +} + +/// \brief Parallel descending radix sort-by-key primitive for device level. +/// +/// \p radix_sort_pairs_desc function performs a device-wide radix sort +/// of (key, value) pairs. Function sorts input pairs in descending order of keys. +/// +/// \par Overview +/// * The contents of both buffers of \p keys and \p values may be altered by the sorting function. +/// * \p current() of \p keys and \p values are used as the input. +/// * The function will update \p current() of \p keys and \p values to point to buffers +/// that contains the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point +/// type). +/// * Buffers of \p keys must have at least \p size elements. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p radix_sort_config or +/// a custom class with the same members. +/// \tparam Key - key type. Must be an integral type or a floating-point type. 
+/// \tparam Value - value type. +/// \tparam Size - integral type that represents the problem size. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys - reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in,out] values - reference to the double-buffer of values, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed where input keys are +/// represented by an array of integers and input values by an array of doubles. 
+/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// int * keys_tmp; // empty array of 8 elements +/// double * values_tmp; // empty array of 8 elements +/// // Create double-buffers +/// rocprim::double_buffer keys(keys_input, keys_tmp); +/// rocprim::double_buffer values(values_input, values_tmp); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size +/// ); +/// // keys.current(): [ 8, 7, 6, 5, 4, 3, 1, 1] +/// // values.current(): [-8, 7, -5, -4, 3, 2, -1, -2] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class Key, + class Value, + class Size +> +inline +cudaError_t radix_sort_pairs_desc(void * temporary_storage, + size_t& storage_size, + double_buffer& keys, + double_buffer& values, + Size size, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + static_assert(std::is_integral::value, "Size must be an integral type."); + bool is_result_in_output; + cudaError_t error = detail::radix_sort_impl( + temporary_storage, storage_size, + keys.current(), keys.current(), keys.alternate(), + values.current(), values.current(), values.alternate(), + size, is_result_in_output, + begin_bit, end_bit, + stream, debug_synchronous + ); + if(temporary_storage != 
nullptr && is_result_in_output) + { + keys.swap(); + values.swap(); + } + return error; +} + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group devicemodule + +#endif // ROCPRIM_DEVICE_DEVICE_RADIX_SORT_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_radix_sort_config.hpp b/3rdparty/cub/rocprim/device/device_radix_sort_config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..79b6caf5e326cae22f819157ae52bd6432076460 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_radix_sort_config.hpp @@ -0,0 +1,390 @@ +// Copyright (c) 2018-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_RADIX_SORT_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_RADIX_SORT_CONFIG_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "config_types.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of device-level radix sort operation. +/// +/// Radix sort is excecuted in a single tile (at size < BlocksPerItem) or +/// few iterations (passes) depending on total number of bits to be sorted +/// (\p begin_bit and \p end_bit), each iteration sorts either \p LongRadixBits or \p ShortRadixBits bits +/// choosen to cover whole bit range in optimal way. +/// +/// For example, if \p LongRadixBits is 7, \p ShortRadixBits is 6, \p begin_bit is 0 and \p end_bit is 32 +/// there will be 5 iterations: 7 + 7 + 6 + 6 + 6 = 32 bits. +/// +/// \tparam LongRadixBits - number of bits in long iterations. +/// \tparam ShortRadixBits - number of bits in short iterations, must be equal to or less than \p LongRadixBits. +/// \tparam ScanConfig - configuration of digits scan kernel. Must be \p kernel_config. +/// \tparam SortConfig - configuration of radix sort kernel. Must be \p kernel_config. +template< + unsigned int LongRadixBits, + unsigned int ShortRadixBits, + class ScanConfig, + class SortConfig, + class SortSingleConfig = kernel_config<256, 10>, + class SortMergeConfig = kernel_config<1024, 1>, + unsigned int MergeSizeLimitBlocks = 1024U, + bool ForceSingleKernelConfig = false +> +struct radix_sort_config +{ + /// \brief Number of bits in long iterations. + static constexpr unsigned int long_radix_bits = LongRadixBits; + /// \brief Number of bits in short iterations. + static constexpr unsigned int short_radix_bits = ShortRadixBits; + /// \brief Limit number of blocks to use merge kernel. + static constexpr unsigned int merge_size_limit_blocks = MergeSizeLimitBlocks; + + /// \brief Configuration of digits scan kernel. 
+ using scan = ScanConfig; + /// \brief Configuration of radix sort kernel. + using sort = SortConfig; + /// \brief Configuration of radix sort single kernel. + using sort_single = SortSingleConfig; + /// \brief Configuration of radix sort merge kernel. + using sort_merge = SortMergeConfig; + /// \brief Force use radix sort single kernel configuration. + static constexpr bool force_single_kernel_config = ForceSingleKernelConfig; +}; + +namespace detail +{ + +template +struct radix_sort_config_803 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + using scan = kernel_config<256, 2>; + + using type = select_type< + select_type_case< + (sizeof(Key) == 1 && sizeof(Value) <= 8), + radix_sort_config< + 8, 7, scan, + kernel_config<256, 10>, kernel_config<256, 19> + > + >, + select_type_case< + (sizeof(Key) == 2 && sizeof(Value) <= 8), + radix_sort_config< + 8, 7, scan, + kernel_config<256, 10>, kernel_config<256, 17> + > + >, + select_type_case< + (sizeof(Key) == 4 && sizeof(Value) <= 8), + radix_sort_config< + 7, 6, scan, + kernel_config<256, 15>, kernel_config<256, 13> + > + >, + select_type_case< + (sizeof(Key) == 8 && sizeof(Value) <= 8), + radix_sort_config< + 7, 6, scan, + kernel_config<256, 13>, kernel_config<256, 10> + > + >, + radix_sort_config< + 6, 4, scan, + kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 15u / item_scale) + >, + kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale) + >, + kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale) + > + > + >; +}; + +template +struct radix_sort_config_803 + : select_type< + select_type_case, kernel_config<256, 10>, kernel_config<256, 19> > >, + select_type_case, kernel_config<256, 10>, kernel_config<256, 16> > >, + 
select_type_case, kernel_config<256, 9>, kernel_config<256, 15> > >, + select_type_case, kernel_config<256, 7>, kernel_config<256, 12> > > + > { }; + +template +struct radix_sort_config_900 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + using scan = kernel_config<256, 2>; + + using type = select_type< + select_type_case< + (sizeof(Key) == 1 && sizeof(Value) <= 8), + radix_sort_config<4, 4, scan, + kernel_config<256, 10>, kernel_config<256, 19> > + >, + select_type_case< + (sizeof(Key) == 2 && sizeof(Value) <= 8), + radix_sort_config<6, 5, scan, + kernel_config<256, 10>, kernel_config<256, 17> > + >, + select_type_case< + (sizeof(Key) == 4 && sizeof(Value) <= 8), + radix_sort_config<7, 6, scan, + kernel_config<256, 15>, kernel_config<256, 15> > + >, + select_type_case< + (sizeof(Key) == 8 && sizeof(Value) <= 8), + radix_sort_config<7, 6, scan, + kernel_config<256, 15>, kernel_config<256, 12> > + >, + radix_sort_config< + 6, 4, scan, + kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 15u / item_scale) + >, + kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale) + >, + kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale) + > + > + >; +}; + +template +struct radix_sort_config_900 + : select_type< + select_type_case, kernel_config<256, 10>, kernel_config<256, 19> > >, + select_type_case, kernel_config<256, 10>, kernel_config<256, 16> > >, + select_type_case, kernel_config<256, 17>, kernel_config<256, 15> > >, + select_type_case, kernel_config<256, 15>, kernel_config<256, 12> > > + > { }; + + +template +struct radix_sort_config_908 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + using scan 
= kernel_config<256, 2>; + + using type = select_type< + select_type_case< + (sizeof(Key) == 1 && sizeof(Value) <= 8), + radix_sort_config<4, 4, scan, + kernel_config<256, 10>, kernel_config<256, 19> > + >, + select_type_case< + (sizeof(Key) == 2 && sizeof(Value) <= 8), + radix_sort_config<6, 5, scan, + kernel_config<256, 10>, kernel_config<256, 17> > + >, + select_type_case< + (sizeof(Key) == 4 && sizeof(Value) <= 8), + radix_sort_config<7, 6, kernel_config<256, 4>, + kernel_config<256, 15>, kernel_config<256, 15> > + >, + select_type_case< + (sizeof(Key) == 8 && sizeof(Value) <= 8), + radix_sort_config<7, 6, kernel_config<256, 4>, + kernel_config<256, 14>, kernel_config<256, 12> > + >, + radix_sort_config< + 6, 4, scan, + kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 15u / item_scale) + >, + kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale) + >, + kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale) + > + > + >; +}; + +template +struct radix_sort_config_908 + : select_type< + select_type_case, kernel_config<256, 10>, kernel_config<256, 19> > >, + select_type_case, kernel_config<256, 10>, kernel_config<256, 17> > >, + select_type_case, kernel_config<256, 17>, kernel_config<256, 15> > >, + select_type_case, kernel_config<256, 15>, kernel_config<256, 12> > > + > { }; + +// TODO: We need to update these parameters +template +struct radix_sort_config_90a +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + using scan = kernel_config<256, 1>; + + using type = select_type< + select_type_case< + (sizeof(Key) == 1 && sizeof(Value) <= 8), + radix_sort_config<4, 4, scan, + kernel_config<256, 5>, kernel_config<256, 19> > + >, + select_type_case< + (sizeof(Key) == 2 && sizeof(Value) <= 
8), + radix_sort_config<6, 5, scan, + kernel_config<256, 5>, kernel_config<256, 17> > + >, + select_type_case< + (sizeof(Key) == 4 && sizeof(Value) <= 8), + radix_sort_config<7, 6, scan, + kernel_config<256, 7>, kernel_config<256, 15> > + >, + select_type_case< + (sizeof(Key) == 8 && sizeof(Value) <= 8), + radix_sort_config<7, 6, scan, + kernel_config<256, 7>, kernel_config<256, 14> > + >, + radix_sort_config< + 6, 4, scan, + kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 15u / item_scale) + >, + kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale) + > + > + >; +}; + +template +struct radix_sort_config_90a + : select_type< + select_type_case, kernel_config<256, 5>, kernel_config<256, 19> > >, + select_type_case, kernel_config<256, 5>, kernel_config<256, 17> > >, + select_type_case, kernel_config<256, 8>, kernel_config<256, 15> > >, + select_type_case, kernel_config<256, 7>, kernel_config<256, 14> > > + > { }; + +// TODO: We need to update these parameters +template +struct radix_sort_config_1030 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + using scan = kernel_config<256, 2>; + + using type = select_type< + select_type_case< + (sizeof(Key) == 1 && sizeof(Value) <= 8), + radix_sort_config<4, 4, scan, + kernel_config<256, 10>, kernel_config<256, 19> > + >, + select_type_case< + (sizeof(Key) == 2 && sizeof(Value) <= 8), + radix_sort_config<6, 5, scan, + kernel_config<256, 10>, kernel_config<256, 17> > + >, + select_type_case< + (sizeof(Key) == 4 && sizeof(Value) <= 8), + radix_sort_config<7, 6, scan, + kernel_config<256, 15>, kernel_config<256, 15> > + >, + select_type_case< + (sizeof(Key) == 8 && sizeof(Value) <= 8), + radix_sort_config<7, 6, scan, + kernel_config<256, 15>, kernel_config<256, 14> > + >, + radix_sort_config< + 6, 4, scan, + 
kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_32>::value, + ::rocprim::max(1u, 15u / item_scale) + >, + kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale) + >, + kernel_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale) + > + > + >; +}; + +template +struct radix_sort_config_1030 + : select_type< + select_type_case, kernel_config<256, 10>, kernel_config<256, 19> > >, + select_type_case, kernel_config<256, 10>, kernel_config<256, 19> > >, + select_type_case, kernel_config<256, 17>, kernel_config<256, 17> > >, + select_type_case, kernel_config<256, 15>, kernel_config<256, 15> > > + > { }; + +template +struct default_radix_sort_config + : select_arch< + TargetArch, + select_arch_case<803, radix_sort_config_803 >, + select_arch_case<900, radix_sort_config_900 >, + select_arch_case<908, radix_sort_config_908 >, + select_arch_case >, + select_arch_case<1030, radix_sort_config_1030 >, + radix_sort_config_900 + > { }; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_RADIX_SORT_CONFIG_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_reduce.hpp b/3rdparty/cub/rocprim/device/device_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..08451470ad22aef6b43ecba795f0e2034bb10a20 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_reduce.hpp @@ -0,0 +1,496 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_REDUCE_HPP_ +#define ROCPRIM_DEVICE_DEVICE_REDUCE_HPP_ + +#include +#include +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" +#include "../detail/match_result_type.hpp" + +#include "device_reduce_config.hpp" +#include "detail/device_reduce.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +namespace detail +{ + +template< + bool WithInitialValue, + class Config, + class ResultType, + class InputIterator, + class OutputIterator, + class InitValueType, + class BinaryFunction +> +ROCPRIM_KERNEL +__launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) +void block_reduce_kernel(InputIterator input, + const size_t size, + OutputIterator output, + InitValueType initial_value, + BinaryFunction reduce_op) +{ + block_reduce_kernel_impl( + input, size, output, initial_value, reduce_op + ); +} + +#define ROCPRIM_DETAIL_HIP_SYNC(name, size, start) \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto _error = cudaStreamSynchronize(stream); \ + if(_error != cudaSuccess) return _error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != cudaSuccess) return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } + + +template< + bool WithInitialValue, // true when inital_value should be used in reduction + class Config, + class InputIterator, + class OutputIterator, + class InitValueType, + class 
BinaryFunction +> +inline +cudaError_t reduce_impl(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + const InitValueType initial_value, + const size_t size, + BinaryFunction reduce_op, + const cudaStream_t stream, + bool debug_synchronous) +{ + using input_type = typename std::iterator_traits::value_type; + using result_type = typename ::rocprim::detail::match_result_type< + input_type, BinaryFunction + >::type; + + // Get default config if Config is default_config + using config = default_or_custom_config< + Config, + default_reduce_config + >; + + constexpr unsigned int block_size = config::block_size; + constexpr unsigned int items_per_thread = config::items_per_thread; + constexpr auto items_per_block = block_size * items_per_thread; + + if(temporary_storage == nullptr) + { + storage_size = reduce_get_temporary_storage_bytes(size, items_per_block); + // Make sure user won't try to allocate 0 bytes memory + storage_size = storage_size == 0 ? 
4 : storage_size; + return cudaSuccess; + } + + // Start point for time measurements + std::chrono::high_resolution_clock::time_point start; + + static constexpr auto size_limit = config::size_limit; + static constexpr auto number_of_blocks_limit = ::rocprim::max(size_limit / items_per_block, 1); + + auto number_of_blocks = (size + items_per_block - 1)/items_per_block; + if(debug_synchronous) + { + std::cout << "block_size " << block_size << '\n'; + std::cout << "number of blocks " << number_of_blocks << '\n'; + std::cout << "number of blocks limit " << number_of_blocks_limit << '\n'; + std::cout << "items_per_block " << items_per_block << '\n'; + } + + if(number_of_blocks > 1) + { + // Pointer to array with block_prefixes + result_type * block_prefixes = static_cast(temporary_storage); + static constexpr auto aligned_size_limit = number_of_blocks_limit * items_per_block; + + // Launch number_of_blocks_limit blocks while there is still at least as many blocks left as the limit + const auto number_of_launch = (size + aligned_size_limit - 1) / aligned_size_limit; + for(size_t i = 0, offset = 0; i < number_of_launch; ++i, offset += aligned_size_limit) { + const auto current_size = std::min(size - offset, aligned_size_limit); + const auto current_blocks = (current_size + items_per_block - 1) / items_per_block; + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + detail::block_reduce_kernel + <<>>( + input + offset, + current_size, + block_prefixes + i * number_of_blocks_limit, + initial_value, + reduce_op); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("block_reduce_kernel", current_size, start); + } + + void * nested_temp_storage = static_cast(block_prefixes + number_of_blocks); + auto nested_temp_storage_size = storage_size - (number_of_blocks * sizeof(result_type)); + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + auto error = reduce_impl( + nested_temp_storage, + nested_temp_storage_size, + block_prefixes, // 
input + output, // output + initial_value, + number_of_blocks, // input size + reduce_op, + stream, + debug_synchronous + ); + if(error != cudaSuccess) return error; + ROCPRIM_DETAIL_HIP_SYNC("nested_device_reduce", number_of_blocks, start); + } + else + { + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + detail::block_reduce_kernel + <<>>( + input, size, output, initial_value, reduce_op + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("block_reduce_kernel", size, start); + } + + return cudaSuccess; +} + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR +#undef ROCPRIM_DETAIL_HIP_SYNC + +} // end of detail namespace + +/// \brief Parallel reduction primitive for device level. +/// +/// reduce function performs a device-wide reduction operation +/// using binary \p reduce_op operator. +/// +/// \par Overview +/// * Does not support non-commutative reduction operators. Reduction operator should also be +/// associative. When used with non-associative functions the results may be non-deterministic +/// and/or vary in precision. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p input must have at least \p size elements, while \p output +/// only needs one element. +/// * By default, the input type is used for accumulation. A custom type +/// can be specified using rocprim::transform_iterator, see the example below. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p reduce_config or +/// a custom class with the same members. +/// \tparam InputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. 
+/// \tparam InitValueType - type of the initial value. +/// \tparam BinaryFunction - type of binary function used for reduction. Default type +/// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the reduction operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to reduce. +/// \param [out] output - iterator to the first element in the output range. It can be +/// same as \p input. +/// \param [in] initial_value - initial value to start the reduction. +/// \param [in] size - number of element in the input range. +/// \param [in] reduce_op - binary operation function object that will be used for reduction. +/// The signature of the function should be equivalent to the following: +/// T f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The default value is \p BinaryFunction(). +/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. The default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful reduction; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level min-reduction operation is performed on an array of +/// integer values (shorts are reduced into ints) using custom operator. 
+/// +/// \code{.cpp} +/// #include +/// +/// // custom reduce function +/// auto min_op = +/// [] __device__ (int a, int b) -> int +/// { +/// return a < b ? a : b; +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// short * input; // e.g., [4, 7, 6, 2, 5, 1, 3, 8] +/// int * output; // empty array of 1 element +/// int start_value; // e.g., 9 +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::reduce( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, start_value, input_size, min_op +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform reduce +/// rocprim::reduce( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, start_value, input_size, min_op +/// ); +/// // output: [1] +/// \endcode +/// +/// The same example as above, but now a custom accumulator type is specified. +/// +/// \code{.cpp} +/// #include +/// +/// auto min_op = +/// [] __device__ (int a, int b) -> int +/// { +/// return a < b ? 
a : b; +/// }; +/// +/// size_t input_size; +/// short * input; +/// int * output; +/// int start_value; +/// +/// // Use a transform iterator to specifiy a custom accumulator type +/// auto input_iterator = rocprim::make_transform_iterator( +/// input, [] __device__ (T in) { return static_cast(in); }); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Use the transform iterator +/// rocprim::reduce( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input_iterator, output, start_value, input_size, min_op +/// ); +/// +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// rocprim::reduce( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input_iterator, output, start_value, input_size, min_op +/// ); +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator, + class OutputIterator, + class InitValueType, + class BinaryFunction = ::rocprim::plus::value_type> +> +inline +cudaError_t reduce(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + const InitValueType initial_value, + const size_t size, + BinaryFunction reduce_op = BinaryFunction(), + const cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::reduce_impl( + temporary_storage, storage_size, + input, output, initial_value, size, + reduce_op, stream, debug_synchronous + ); +} + +/// \brief Parallel reduce primitive for device level. +/// +/// reduce function performs a device-wide reduction operation +/// using binary \p reduce_op operator. +/// +/// \par Overview +/// * Does not support non-commutative reduction operators. Reduction operator should also be +/// associative. When used with non-associative functions the results may be non-deterministic +/// and/or vary in precision. 
+/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p input must have at least \p size elements, while \p output +/// only needs one element. +/// * By default, the input type is used for accumulation. A custom type +/// can be specified using rocprim::transform_iterator, see the example below. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p reduce_config or +/// a custom class with the same members. +/// \tparam InputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction - type of binary function used for reduction. Default type +/// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the reduction operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to reduce. +/// \param [out] output - iterator to the first element in the output range. It can be +/// same as \p input. +/// \param [in] size - number of element in the input range. +/// \param [in] reduce_op - binary operation function object that will be used for reduction. +/// The signature of the function should be equivalent to the following: +/// T f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. 
+/// Default is BinaryFunction(). +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful reduction; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level sum operation is performed on an array of +/// integer values (shorts are reduced into ints). +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// short * input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int * output; // empty array of 1 element +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::reduce( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size, rocprim::plus() +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform reduce +/// rocprim::reduce( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size, rocprim::plus() +/// ); +/// // output: [36] +/// \endcode +/// +/// The same example as above, but now a custom accumulator type is specified. 
+/// +/// \code{.cpp} +/// #include +/// +/// size_t input_size; +/// short * input; +/// int * output; +/// +/// // Use a transform iterator to specifiy a custom accumulator type +/// auto input_iterator = rocprim::make_transform_iterator( +/// input, [] __device__ (T in) { return static_cast(in); }); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Use the transform iterator +/// rocprim::reduce( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input_iterator, output, start_value, input_size, rocprim::plus() +/// ); +/// +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// rocprim::reduce( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input_iterator, output, start_value, input_size, rocprim::plus() +/// ); +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator, + class OutputIterator, + class BinaryFunction = ::rocprim::plus::value_type> +> +inline +cudaError_t reduce(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + const size_t size, + BinaryFunction reduce_op = BinaryFunction(), + const cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + using input_type = typename std::iterator_traits::value_type; + + return detail::reduce_impl( + temporary_storage, storage_size, + input, output, input_type(), size, + reduce_op, stream, debug_synchronous + ); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_REDUCE_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_reduce_by_key.hpp b/3rdparty/cub/rocprim/device/device_reduce_by_key.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b03ae3a8f855152776f87de75c7bbfdf69ce07ff --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_reduce_by_key.hpp @@ -0,0 +1,413 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_REDUCE_BY_KEY_HPP_ +#define ROCPRIM_DEVICE_DEVICE_REDUCE_BY_KEY_HPP_ + +#include +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" +#include "../detail/match_result_type.hpp" + +#include "../functional.hpp" + +#include "device_reduce_by_key_config.hpp" +#include "detail/device_reduce_by_key.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +namespace detail +{ + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class KeysInputIterator, + class KeyCompareFunction +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void fill_unique_counts_kernel(KeysInputIterator keys_input, + unsigned int size, + unsigned int * unique_counts, + KeyCompareFunction key_compare_op, + unsigned int blocks_per_full_batch, + unsigned int full_batches) +{ + fill_unique_counts( + keys_input, size, + unique_counts, + key_compare_op, + blocks_per_full_batch, full_batches + ); +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class UniqueCountOutputIterator +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void scan_unique_counts_kernel(unsigned int * unique_counts, + UniqueCountOutputIterator unique_count_output, + unsigned int batches) +{ + scan_unique_counts(unique_counts, unique_count_output, batches); +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class KeysInputIterator, + class ValuesInputIterator, + class Result, + class UniqueOutputIterator, + class AggregatesOutputIterator, + class KeyCompareFunction, + class BinaryFunction +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void reduce_by_key_kernel(KeysInputIterator keys_input, + ValuesInputIterator values_input, + unsigned int size, + const unsigned int * unique_starts, + carry_out * carry_outs, + Result * leading_aggregates, + UniqueOutputIterator unique_output, + AggregatesOutputIterator aggregates_output, + KeyCompareFunction key_compare_op, + BinaryFunction reduce_op, + 
unsigned int blocks_per_full_batch, + unsigned int full_batches) +{ + reduce_by_key( + keys_input, values_input, size, + unique_starts, carry_outs, leading_aggregates, + unique_output, aggregates_output, + key_compare_op, reduce_op, + blocks_per_full_batch, full_batches + ); +} + +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + class Result, + class AggregatesOutputIterator, + class BinaryFunction +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void scan_and_scatter_carry_outs_kernel(const carry_out * carry_outs, + const Result * leading_aggregates, + AggregatesOutputIterator aggregates_output, + BinaryFunction reduce_op, + unsigned int batches) +{ + scan_and_scatter_carry_outs( + carry_outs, leading_aggregates, aggregates_output, + reduce_op, + batches + ); +} + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != cudaSuccess) return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } + +template< + class Config, + class KeysInputIterator, + class ValuesInputIterator, + class UniqueOutputIterator, + class AggregatesOutputIterator, + class UniqueCountOutputIterator, + class BinaryFunction, + class KeyCompareFunction +> +inline +cudaError_t reduce_by_key_impl(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + ValuesInputIterator values_input, + const unsigned int size, + UniqueOutputIterator unique_output, + AggregatesOutputIterator aggregates_output, + UniqueCountOutputIterator unique_count_output, + BinaryFunction reduce_op, + KeyCompareFunction key_compare_op, + const cudaStream_t stream, + const bool 
debug_synchronous) +{ + using key_type = typename std::iterator_traits::value_type; + using result_type = typename ::rocprim::detail::match_result_type< + typename std::iterator_traits::value_type, + BinaryFunction + >::type; + using carry_out_type = carry_out; + + using config = default_or_custom_config< + Config, + default_reduce_by_key_config + >; + + constexpr unsigned int items_per_block = config::reduce::block_size * config::reduce::items_per_thread; + constexpr unsigned int scan_items_per_block = config::scan::block_size * config::scan::items_per_thread; + + const unsigned int blocks = std::max(1u, ::rocprim::detail::ceiling_div(size, items_per_block)); + const unsigned int blocks_per_full_batch = ::rocprim::detail::ceiling_div(blocks, scan_items_per_block); + const unsigned int full_batches = blocks % scan_items_per_block != 0 + ? blocks % scan_items_per_block + : scan_items_per_block; + const unsigned int batches = (blocks_per_full_batch == 1 ? full_batches : scan_items_per_block); + + const size_t unique_counts_bytes = ::rocprim::detail::align_size(batches * sizeof(unsigned int)); + const size_t carry_outs_bytes = ::rocprim::detail::align_size(batches * sizeof(carry_out_type)); + const size_t leading_aggregates_bytes = ::rocprim::detail::align_size(batches * sizeof(result_type)); + if(temporary_storage == nullptr) + { + storage_size = unique_counts_bytes + carry_outs_bytes + leading_aggregates_bytes; + return cudaSuccess; + } + + if(debug_synchronous) + { + std::cout << "blocks " << blocks << '\n'; + std::cout << "blocks_per_full_batch " << blocks_per_full_batch << '\n'; + std::cout << "full_batches " << full_batches << '\n'; + std::cout << "batches " << batches << '\n'; + std::cout << "storage_size " << storage_size << '\n'; + cudaError_t error = cudaStreamSynchronize(stream); + if(error != cudaSuccess) return error; + } + + char * ptr = reinterpret_cast(temporary_storage); + unsigned int * unique_counts = reinterpret_cast(ptr); + ptr += 
unique_counts_bytes; + carry_out_type * carry_outs = reinterpret_cast(ptr); + ptr += carry_outs_bytes; + result_type * leading_aggregates = reinterpret_cast(ptr); + + // Start point for time measurements + std::chrono::high_resolution_clock::time_point start; + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + fill_unique_counts_kernel + <<>>( + keys_input, size, unique_counts, key_compare_op, + blocks_per_full_batch, full_batches + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("fill_unique_counts", size, start) + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + scan_unique_counts_kernel + <<>>( + unique_counts, unique_count_output, + batches + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("scan_unique_counts", config::scan::block_size, start) + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + reduce_by_key_kernel + <<>>( + keys_input, values_input, size, + const_cast(unique_counts), carry_outs, leading_aggregates, + unique_output, aggregates_output, + key_compare_op, reduce_op, + blocks_per_full_batch, full_batches + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("reduce_by_key", size, start) + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + scan_and_scatter_carry_outs_kernel + <<>>( + const_cast(carry_outs), const_cast(leading_aggregates), + aggregates_output, + reduce_op, + batches + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("scan_and_scatter_carry_outs", config::scan::block_size, start) + + return cudaSuccess; +} + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + +} // end of detail namespace + +/// \brief Parallel reduce-by-key primitive for device level. +/// +/// reduce_by_key function performs a device-wide reduction operation of groups +/// of consecutive values having the same key using binary \p reduce_op operator. 
The first key of each group +/// is copied to \p unique_output and reduction of the group is written to \p aggregates_output. +/// The total number of group is written to \p unique_count_output. +/// +/// \par Overview +/// * Supports non-commutative reduction operators. However, a reduction operator should be +/// associative. When used with non-associative functions the results may be non-deterministic +/// and/or vary in precision. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p keys_input and \p values_input must have at least \p size elements. +/// * Range specified by \p unique_count_output must have at least 1 element. +/// * Ranges specified by \p unique_output and \p aggregates_output must have at least +/// *unique_count_output (i.e. the number of unique keys) elements. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p reduce_by_key_config or +/// a custom class with the same members. +/// \tparam KeysInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam UniqueOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam AggregatesOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam UniqueCountOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. 
+/// \tparam BinaryFunction - type of binary function used for reduction. Default type +/// is \p rocprim::plus, where \p T is a \p value_type of \p ValuesInputIterator. +/// \tparam KeyCompareFunction - type of binary function used to determine keys equality. Default type +/// is \p rocprim::equal_to, where \p T is a \p value_type of \p KeysInputIterator. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the reduction operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input - iterator to the first element in the range of keys. +/// \param [in] values_input - iterator to the first element in the range of values to reduce. +/// \param [in] size - number of element in the input range. +/// \param [out] unique_output - iterator to the first element in the output range of unique keys. +/// \param [out] aggregates_output - iterator to the first element in the output range of reductions. +/// \param [out] unique_count_output - iterator to total number of groups. +/// \param [in] reduce_op - binary operation function object that will be used for reduction. +/// The signature of the function should be equivalent to the following: +/// T f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// Default is BinaryFunction(). +/// \param [in] key_compare_op - binary operation function object that will be used to determine keys equality. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// Default is KeyCompareFunction(). 
+/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful reduction; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level sum operation is performed on an array of +/// integer values and integer keys. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// int * keys_input; // e.g., [1, 1, 1, 2, 10, 10, 10, 88] +/// int * values_input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int * unique_output; // empty array of at least 4 elements +/// int * aggregates_output; // empty array of at least 4 elements +/// int * unique_count_output; // empty array of 1 element +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::reduce_by_key( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, values_input, input_size, +/// unique_output, aggregates_output, unique_count_output +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform reduction +/// rocprim::reduce_by_key( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, values_input, input_size, +/// unique_output, aggregates_output, unique_count_output +/// ); +/// // unique_output: [1, 2, 10, 88] +/// // aggregates_output: [6, 4, 18, 8] +/// // unique_count_output: [4] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class KeysInputIterator, + class ValuesInputIterator, + class UniqueOutputIterator, + class AggregatesOutputIterator, 
+ class UniqueCountOutputIterator, + class BinaryFunction = ::rocprim::plus::value_type>, + class KeyCompareFunction = ::rocprim::equal_to::value_type> +> +inline +cudaError_t reduce_by_key(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + ValuesInputIterator values_input, + unsigned int size, + UniqueOutputIterator unique_output, + AggregatesOutputIterator aggregates_output, + UniqueCountOutputIterator unique_count_output, + BinaryFunction reduce_op = BinaryFunction(), + KeyCompareFunction key_compare_op = KeyCompareFunction(), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::reduce_by_key_impl( + temporary_storage, storage_size, + keys_input, values_input, size, + unique_output, aggregates_output, unique_count_output, + reduce_op, key_compare_op, + stream, debug_synchronous + ); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_REDUCE_BY_KEY_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_reduce_by_key_config.hpp b/3rdparty/cub/rocprim/device/device_reduce_by_key_config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..62f6e000e328d8dd4090f38432b6589e8271ded4 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_reduce_by_key_config.hpp @@ -0,0 +1,143 @@ +// Copyright (c) 2018-2019 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_REDUCE_BY_KEY_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_REDUCE_BY_KEY_CONFIG_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "config_types.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of device-level reduce-by-key operation. +/// +/// \tparam ScanConfig - configuration of carry-outs scan kernel. Must be \p kernel_config. +/// \tparam ReduceConfig - configuration of the main reduce-by-key kernel. Must be \p kernel_config. +template< + class ScanConfig, + class ReduceConfig +> +struct reduce_by_key_config +{ + /// \brief Configuration of carry-outs scan kernel. + using scan = ScanConfig; + /// \brief Configuration of the main reduce-by-key kernel. 
+ using reduce = ReduceConfig; +}; + +namespace detail +{ + +template +struct reduce_by_key_config_803 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Key) + sizeof(Value), 2 * sizeof(int)); + + using scan = kernel_config<256, 4>; + + using type = select_type< + select_type_case< + (sizeof(Key) <= 8 && sizeof(Value) <= 8), + reduce_by_key_config > + >, + reduce_by_key_config::value, ::rocprim::max(1u, 15u / item_scale)> > + >; +}; + +template +struct reduce_by_key_config_900 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Key) + sizeof(Value), 2 * sizeof(int)); + + using scan = kernel_config<256, 2>; + + using type = select_type< + select_type_case< + (sizeof(Key) <= 8 && sizeof(Value) <= 8), + reduce_by_key_config > + >, + reduce_by_key_config::value, ::rocprim::max(1u, 15u / item_scale)> > + >; +}; + +// TODO: We need to update these parameters +template +struct reduce_by_key_config_90a +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Key) + sizeof(Value), 2 * sizeof(int)); + + using scan = kernel_config<256, 2>; + + using type = select_type< + select_type_case< + (sizeof(Key) <= 8 && sizeof(Value) <= 8), + reduce_by_key_config > + >, + reduce_by_key_config::value, ::rocprim::max(1u, 15u / item_scale)> > + >; +}; + +// TODO: We need to update these parameters +template +struct reduce_by_key_config_1030 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Key) + sizeof(Value), 2 * sizeof(int)); + + using scan = kernel_config<256, 2>; + + using type = select_type< + select_type_case< + (sizeof(Key) <= 8 && sizeof(Value) <= 8), + reduce_by_key_config > + >, + reduce_by_key_config::value, ::rocprim::max(1u, 15u / item_scale)> > + >; +}; + +template +struct default_reduce_by_key_config + : select_arch< + TargetArch, + select_arch_case<803, reduce_by_key_config_803 >, + select_arch_case<900, reduce_by_key_config_900 
>, + select_arch_case >, + select_arch_case<1030, reduce_by_key_config_1030 >, + reduce_by_key_config_900 + > { }; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_REDUCE_BY_KEY_CONFIG_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_reduce_config.hpp b/3rdparty/cub/rocprim/device/device_reduce_config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..10da280c7d99db44034834dd3dd662e95d7bbb3f --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_reduce_config.hpp @@ -0,0 +1,115 @@ +// Copyright (c) 2018-2019 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_REDUCE_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_REDUCE_CONFIG_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../block/block_reduce.hpp" + +#include "config_types.hpp" +#include "detail/device_config_helper.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + + +namespace detail +{ + +template +struct reduce_config_803 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); + + using type = reduce_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale), + ::rocprim::block_reduce_algorithm::using_warp_reduce + >; +}; + +template +struct reduce_config_900 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); + + using type = reduce_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale), + ::rocprim::block_reduce_algorithm::using_warp_reduce + >; +}; + +// TODO: We need to update these parameters +template +struct reduce_config_90a +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); + + using type = reduce_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale), + ::rocprim::block_reduce_algorithm::using_warp_reduce + >; +}; + +// TODO: We need to update these parameters +template +struct reduce_config_1030 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); + + using type = reduce_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_32>::value, + ::rocprim::max(1u, 16u / item_scale), + ::rocprim::block_reduce_algorithm::using_warp_reduce + >; +}; + +template +struct default_reduce_config + : select_arch< + TargetArch, + 
select_arch_case<803, reduce_config_803>, + select_arch_case<900, reduce_config_900>, + select_arch_case>, + select_arch_case<1030, reduce_config_1030>, + reduce_config_900 + > { }; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_REDUCE_CONFIG_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_run_length_encode.hpp b/3rdparty/cub/rocprim/device/device_run_length_encode.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c5cf03a50f393cf6ef0f3f711400d9433eb182fa --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_run_length_encode.hpp @@ -0,0 +1,411 @@ +// Copyright (c) 2018-2020 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_RUN_LENGTH_ENCODE_HPP_ +#define ROCPRIM_DEVICE_DEVICE_RUN_LENGTH_ENCODE_HPP_ + +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../iterator/constant_iterator.hpp" +#include "../iterator/counting_iterator.hpp" +#include "../iterator/discard_iterator.hpp" +#include "../iterator/zip_iterator.hpp" + +#include "device_run_length_encode_config.hpp" +#include "device_reduce_by_key.hpp" +#include "device_select.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +namespace detail +{ + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + if(error != cudaSuccess) return error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto error = cudaStreamSynchronize(stream); \ + if(error != cudaSuccess) return error; \ + auto end = std::chrono::high_resolution_clock::now(); \ + auto d = std::chrono::duration_cast>(end - start); \ + std::cout << " " << d.count() * 1000 << " ms" << '\n'; \ + } \ + } + +} // end detail namespace + +/// \brief Parallel run-length encoding for device level. +/// +/// run_length_encode function performs a device-wide run-length encoding of runs (groups) +/// of consecutive values. The first value of each run is copied to \p unique_output and +/// the length of the run is written to \p counts_output. +/// The total number of runs is written to \p runs_count_output. +/// +/// \par Overview +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Range specified by \p input must have at least \p size elements. +/// * Range specified by \p runs_count_output must have at least 1 element. +/// * Ranges specified by \p unique_output and \p counts_output must have at least +/// *runs_count_output (i.e. the number of runs) elements. +/// +/// \tparam Config - [optional] configuration of the primitive. 
It can be \p run_length_encode_config or +/// a custom class with the same members. +/// \tparam InputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam UniqueOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam CountsOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam RunsCountOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range of values. +/// \param [in] size - number of element in the input range. +/// \param [out] unique_output - iterator to the first element in the output range of unique values. +/// \param [out] counts_output - iterator to the first element in the output range of lenghts. +/// \param [out] runs_count_output - iterator to total number of runs. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful operation; otherwise a HIP runtime error of +/// type \p cudaError_t. 
+/// +/// \par Example +/// \parblock +/// In this example a device-level run-length encoding operation is performed on an array of +/// integer values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// int * input; // e.g., [1, 1, 1, 2, 10, 10, 10, 88] +/// int * unique_output; // empty array of at least 4 elements +/// int * counts_output; // empty array of at least 4 elements +/// int * runs_count_output; // empty array of 1 element +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::run_length_encode( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, input_size, +/// unique_output, counts_output, runs_count_output +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform encoding +/// rocprim::run_length_encode( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, input_size, +/// unique_output, counts_output, runs_count_output +/// ); +/// // unique_output: [1, 2, 10, 88] +/// // counts_output: [3, 1, 3, 1] +/// // runs_count_output: [4] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator, + class UniqueOutputIterator, + class CountsOutputIterator, + class RunsCountOutputIterator +> +inline +cudaError_t run_length_encode(void * temporary_storage, + size_t& storage_size, + InputIterator input, + unsigned int size, + UniqueOutputIterator unique_output, + CountsOutputIterator counts_output, + RunsCountOutputIterator runs_count_output, + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + using input_type = typename std::iterator_traits::value_type; + using count_type = unsigned int; + + using config = detail::default_or_custom_config< + Config, + 
detail::default_run_length_encode_config + >; + + return ::rocprim::reduce_by_key( + temporary_storage, storage_size, + input, make_constant_iterator(1), size, + unique_output, counts_output, runs_count_output, + ::rocprim::plus(), ::rocprim::equal_to(), + stream, debug_synchronous + ); +} + +/// \brief Parallel run-length encoding of non-trivial runs for device level. +/// +/// run_length_encode_non_trivial_runs function performs a device-wide run-length encoding of +/// non-trivial runs (groups) of consecutive values (groups of more than one element). +/// The offset of the first value of each non-trivial run is copied to \p offsets_output and +/// the length of the run (the count of elements) is written to \p counts_output. +/// The total number of non-trivial runs is written to \p runs_count_output. +/// +/// \par Overview +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Range specified by \p input must have at least \p size elements. +/// * Range specified by \p runs_count_output must have at least 1 element. +/// * Ranges specified by \p offsets_output and \p counts_output must have at least +/// *runs_count_output (i.e. the number of non-trivial runs) elements. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p run_length_encode_config or +/// a custom class with the same members. +/// \tparam InputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OffsetsOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam CountsOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. 
+/// \tparam RunsCountOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range of values. +/// \param [in] size - number of element in the input range. +/// \param [out] offsets_output - iterator to the first element in the output range of offsets. +/// \param [out] counts_output - iterator to the first element in the output range of lenghts. +/// \param [out] runs_count_output - iterator to total number of runs. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful operation; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level run-length encoding of non-trivial runs is performed on an array of +/// integer values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// int * input; // e.g., [1, 1, 1, 2, 10, 10, 10, 88] +/// int * offsets_output; // empty array of at least 2 elements +/// int * counts_output; // empty array of at least 2 elements +/// int * runs_count_output; // empty array of 1 element +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::run_length_encode_non_trivial_runs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, input_size, +/// offsets_output, counts_output, runs_count_output +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform encoding +/// rocprim::run_length_encode_non_trivial_runs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, input_size, +/// offsets_output, counts_output, runs_count_output +/// ); +/// // offsets_output: [0, 4] +/// // counts_output: [3, 3] +/// // runs_count_output: [2] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator, + class OffsetsOutputIterator, + class CountsOutputIterator, + class RunsCountOutputIterator +> +inline +cudaError_t run_length_encode_non_trivial_runs(void * temporary_storage, + size_t& storage_size, + InputIterator input, + unsigned int size, + OffsetsOutputIterator offsets_output, + CountsOutputIterator counts_output, + RunsCountOutputIterator runs_count_output, + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + using input_type = typename std::iterator_traits::value_type; + using offset_type = unsigned int; + using count_type = unsigned int; + using offset_count_pair = typename ::rocprim::tuple; + + using config = detail::default_or_custom_config< + Config, + detail::default_run_length_encode_config + >; + + cudaError_t error; + + auto reduce_op = [] __device__ (const offset_count_pair& a, const offset_count_pair& b) + { + 
return offset_count_pair( + ::rocprim::get<0>(a), // Always use offset of the first item of the run + ::rocprim::get<1>(a) + ::rocprim::get<1>(b) // Number of items in the run + ); + }; + auto non_trivial_runs_select_op = [] __device__ (const offset_count_pair& a) + { + return ::rocprim::get<1>(a) > 1; + }; + + offset_type * offsets_tmp = nullptr; + count_type * counts_tmp = nullptr; + count_type * all_runs_count_tmp = nullptr; + + // Calculate size of temporary storage for reduce_by_key operation + size_t reduce_by_key_bytes; + error = ::rocprim::reduce_by_key( + nullptr, reduce_by_key_bytes, + input, + ::rocprim::make_zip_iterator( + ::rocprim::make_tuple( + ::rocprim::make_counting_iterator(0), + ::rocprim::make_constant_iterator(1) + ) + ), + size, + ::rocprim::make_discard_iterator(), + ::rocprim::make_zip_iterator(::rocprim::make_tuple(offsets_tmp, counts_tmp)), + all_runs_count_tmp, + reduce_op, ::rocprim::equal_to(), + stream, debug_synchronous + ); + if(error != cudaSuccess) return error; + reduce_by_key_bytes = ::rocprim::detail::align_size(reduce_by_key_bytes); + + // Calculate size of temporary storage for select operation + size_t select_bytes; + error = ::rocprim::select( + nullptr, select_bytes, + ::rocprim::make_zip_iterator(::rocprim::make_tuple(offsets_tmp, counts_tmp)), + ::rocprim::make_zip_iterator(::rocprim::make_tuple(offsets_output, counts_output)), + runs_count_output, + size, + non_trivial_runs_select_op, + stream, debug_synchronous + ); + if(error != cudaSuccess) return error; + select_bytes = ::rocprim::detail::align_size(select_bytes); + + const size_t offsets_tmp_bytes = ::rocprim::detail::align_size(size * sizeof(offset_type)); + const size_t counts_tmp_bytes = ::rocprim::detail::align_size(size * sizeof(count_type)); + const size_t all_runs_count_tmp_bytes = sizeof(count_type); + if(temporary_storage == nullptr) + { + storage_size = ::rocprim::max(reduce_by_key_bytes, select_bytes) + + offsets_tmp_bytes + counts_tmp_bytes + 
all_runs_count_tmp_bytes; + return cudaSuccess; + } + + char * ptr = reinterpret_cast(temporary_storage); + ptr += ::rocprim::max(reduce_by_key_bytes, select_bytes); + offsets_tmp = reinterpret_cast(ptr); + ptr += offsets_tmp_bytes; + counts_tmp = reinterpret_cast(ptr); + ptr += counts_tmp_bytes; + all_runs_count_tmp = reinterpret_cast(ptr); + + std::chrono::high_resolution_clock::time_point start; + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + error = ::rocprim::reduce_by_key( + temporary_storage, reduce_by_key_bytes, + input, + ::rocprim::make_zip_iterator( + ::rocprim::make_tuple( + ::rocprim::make_counting_iterator(0), + ::rocprim::make_constant_iterator(1) + ) + ), + size, + ::rocprim::make_discard_iterator(), // Ignore unique output + ::rocprim::make_zip_iterator(rocprim::make_tuple(offsets_tmp, counts_tmp)), + all_runs_count_tmp, + reduce_op, ::rocprim::equal_to(), + stream, debug_synchronous + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("rocprim::reduce_by_key", size, start) + + // Read count of all runs (including trivial runs) + count_type all_runs_count; + // cudaMemcpyWithStream is only supported on rocm 3.1 and above + error = cudaMemcpyAsync(&all_runs_count, all_runs_count_tmp, sizeof(count_type), cudaMemcpyDeviceToHost, stream); + if(error != cudaSuccess) return error; + error = cudaStreamSynchronize(stream); + + + + // Select non-trivial runs + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + error = ::rocprim::select( + temporary_storage, select_bytes, + ::rocprim::make_zip_iterator(::rocprim::make_tuple(offsets_tmp, counts_tmp)), + ::rocprim::make_zip_iterator(::rocprim::make_tuple(offsets_output, counts_output)), + runs_count_output, + all_runs_count, + non_trivial_runs_select_op, + stream, debug_synchronous + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("rocprim::select", all_runs_count, start) + + return cudaSuccess; +} + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + 
+/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_RUN_LENGTH_ENCODE_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_run_length_encode_config.hpp b/3rdparty/cub/rocprim/device/device_run_length_encode_config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a64987d362d99e0f710376a1e3b72362348d62ab --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_run_length_encode_config.hpp @@ -0,0 +1,66 @@ +// Copyright (c) 2018-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_RUN_LENGTH_ENCODE_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_RUN_LENGTH_ENCODE_CONFIG_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "config_types.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of device-level run-length encoding operation. +/// +/// \tparam ReduceByKeyConfig - configuration of device-level reduce-by-key operation. +/// Must be \p reduce_by_key_config or \p default_config. +/// \tparam SelectConfig - configuration of device-level select operation. +/// Must be \p select_config or \p default_config. +template< + class ReduceByKeyConfig, + class SelectConfig = default_config +> +struct run_length_encode_config +{ + /// \brief Configuration of device-level reduce-by-key operation. + using reduce_by_key = ReduceByKeyConfig; + /// \brief Configuration of device-level select operation. + using select = SelectConfig; +}; + +namespace detail +{ + +using default_run_length_encode_config = run_length_encode_config; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_RUN_LENGTH_ENCODE_CONFIG_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_scan.hpp b/3rdparty/cub/rocprim/device/device_scan.hpp new file mode 100644 index 0000000000000000000000000000000000000000..20e27f0a9dd53c9deaa40e885b86e48a1fdc69ad --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_scan.hpp @@ -0,0 +1,826 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_SCAN_HPP_ +#define ROCPRIM_DEVICE_DEVICE_SCAN_HPP_ + +#include +#include + +#include "../config.hpp" +#include "../functional.hpp" +#include "../type_traits.hpp" +#include "../types/future_value.hpp" +#include "../detail/various.hpp" + +#include "device_scan_config.hpp" +#include "device_transform.hpp" +#include "detail/device_scan_common.hpp" +#include "detail/device_scan_lookback.hpp" +#include "detail/device_scan_reduce_then_scan.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +namespace detail +{ + +// Single kernel scan (performs scan on one thread block only) +template< + bool Exclusive, + class Config, + class InputIterator, + class OutputIterator, + class BinaryFunction, + class InitValueType +> +ROCPRIM_KERNEL +__launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) +void single_scan_kernel(InputIterator input, + const size_t size, + const InitValueType initial_value, + OutputIterator output, + BinaryFunction scan_op) +{ + single_scan_kernel_impl( + input, size, get_input_value(initial_value), output, scan_op + ); +} + +// Reduce-then-scan kernels + +// Calculates block prefixes that will be used in final_scan_kernel +// when performing block scan operations. 
+template< + class Config, + class InputIterator, + class BinaryFunction, + class ResultType +> +ROCPRIM_KERNEL +__launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) +void block_reduce_kernel(InputIterator input, + BinaryFunction scan_op, + ResultType * block_prefixes) +{ + block_reduce_kernel_impl( + input, scan_op, block_prefixes + ); +} + +template< + bool Exclusive, + class Config, + class InputIterator, + class OutputIterator, + class BinaryFunction, + class InitValueType +> +ROCPRIM_KERNEL +__launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) +void final_scan_kernel(InputIterator input, + const size_t size, + OutputIterator output, + const InitValueType initial_value, + BinaryFunction scan_op, + input_type_t* block_prefixes, + input_type_t* previous_last_element = nullptr, + input_type_t* new_last_element = nullptr, + bool override_first_value = false, + bool save_last_value = false) +{ + final_scan_kernel_impl( + input, size, output, get_input_value(initial_value), + scan_op, block_prefixes, + previous_last_element, new_last_element, + override_first_value, save_last_value + ); +} + +// Single pass (look-back kernels) + +template< + bool Exclusive, + class Config, + class InputIterator, + class OutputIterator, + class BinaryFunction, + class InitValueType, + class LookBackScanState +> +ROCPRIM_KERNEL +__launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) +void lookback_scan_kernel(InputIterator input, + OutputIterator output, + const size_t size, + const InitValueType initial_value, + BinaryFunction scan_op, + LookBackScanState lookback_scan_state, + const unsigned int number_of_blocks, + ordered_block_id ordered_bid, + input_type_t* previous_last_element = nullptr, + input_type_t* new_last_element = nullptr, + bool override_first_value = false, + bool save_last_value = false) +{ + lookback_scan_kernel_impl( + input, output, size, get_input_value(initial_value), scan_op, + lookback_scan_state, number_of_blocks, ordered_bid, + previous_last_element, new_last_element, + 
override_first_value, save_last_value + ); +} + +#define ROCPRIM_DETAIL_HIP_SYNC(name, size, start) \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto error = cudaStreamSynchronize(stream); \ + if(error != cudaSuccess) return error; \ + auto end = std::chrono::high_resolution_clock::now(); \ + auto d = std::chrono::duration_cast>(end - start); \ + std::cout << " " << d.count() * 1000 << " ms" << '\n'; \ + } + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != cudaSuccess) return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } + +template< + bool Exclusive, + class Config, + class InputIterator, + class OutputIterator, + class InitValueType, + class BinaryFunction +> +inline +auto scan_impl(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + const InitValueType initial_value, + const size_t size, + BinaryFunction scan_op, + const cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if::type +{ + using config = Config; + using real_init_value_type = input_type_t; + + constexpr unsigned int block_size = config::block_size; + constexpr unsigned int items_per_thread = config::items_per_thread; + constexpr auto items_per_block = block_size * items_per_thread; + + static constexpr size_t size_limit = config::size_limit; + static constexpr size_t aligned_size_limit = ::rocprim::max(size_limit - size_limit % items_per_block, items_per_block); + size_t limited_size = std::min(size, aligned_size_limit); + const bool use_limited_size = limited_size == aligned_size_limit; + size_t 
nested_prefixes_size_bytes = scan_get_temporary_storage_bytes(limited_size, items_per_block); + + // Calculate required temporary storage + if(temporary_storage == nullptr) + { + storage_size = nested_prefixes_size_bytes; + + if(use_limited_size) + storage_size += 4 * sizeof(real_init_value_type); + + // Make sure user won't try to allocate 0 bytes memory, because + // cudaMalloc will return nullptr when size is zero. + storage_size = storage_size == 0 ? 4 : storage_size; + return cudaSuccess; + } + + // Start point for time measurements + std::chrono::high_resolution_clock::time_point start; + + auto number_of_blocks = (size + items_per_block - 1)/items_per_block; + + if( number_of_blocks == 0u ) + return cudaSuccess; + + if(number_of_blocks > 1) + { + unsigned int number_of_launch = (size + limited_size - 1)/limited_size; + for (size_t i = 0, offset = 0; i < number_of_launch; i++, offset+=limited_size ) + { + size_t current_size = std::min(size - offset, limited_size); + number_of_blocks = (current_size + items_per_block - 1)/items_per_block; + if(debug_synchronous) + { + std::cout << "use_limited_size " << use_limited_size << '\n'; + std::cout << "number_of_launch " << number_of_launch << '\n'; + std::cout << "inex " << i << '\n'; + std::cout << "aligned_size_limit " << aligned_size_limit << '\n'; + std::cout << "size " << current_size << '\n'; + std::cout << "block_size " << block_size << '\n'; + std::cout << "number of blocks " << number_of_blocks << '\n'; + std::cout << "items_per_block " << items_per_block << '\n'; + std::cout.flush(); + } + + // Pointer to array with block_prefixes + char * ptr = reinterpret_cast(temporary_storage); + real_init_value_type* block_prefixes = reinterpret_cast(ptr); + real_init_value_type* previous_last_element = nullptr; + real_init_value_type* new_last_element = nullptr; + if(use_limited_size) + { + ptr += nested_prefixes_size_bytes; + previous_last_element = reinterpret_cast(ptr); + + ptr += sizeof(real_init_value_type); + 
new_last_element = reinterpret_cast(ptr); + } + + // Grid size for block_reduce_kernel, we don't need to calculate reduction + // of the last block as it will never be used as prefix for other blocks + auto grid_size = number_of_blocks - 1; + if( grid_size != 0 ) + { + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + detail::block_reduce_kernel< + config, InputIterator, BinaryFunction, real_init_value_type + > + <<>>( + input + offset, scan_op, block_prefixes + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("block_reduce_kernel", current_size, start) + + if( !Exclusive && i > 0 ) + { + cudaError_t error = ::rocprim::transform( + previous_last_element, block_prefixes, block_prefixes, 1, + scan_op, stream, debug_synchronous + ); + if(error != cudaSuccess) return error; + } + + // TODO: Performance may increase if for (number_of_blocks < 8192) (or some other + // threshold) we would just use CPU to calculate prefixes. + + // Calculate size of temporary storage for nested device scan operation + void * nested_temp_storage = static_cast(block_prefixes + number_of_blocks); + auto nested_temp_storage_size = storage_size - (number_of_blocks * sizeof(real_init_value_type)); + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + auto error = scan_impl( + nested_temp_storage, + nested_temp_storage_size, + block_prefixes, // input + block_prefixes, // output + real_init_value_type(), // dummy initial value + number_of_blocks, // input size + scan_op, + stream, + debug_synchronous + ); + if(error != cudaSuccess) return error; + ROCPRIM_DETAIL_HIP_SYNC("nested_device_scan", number_of_blocks, start); + + } + + // Grid size for final_scan_kernel + grid_size = number_of_blocks; + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + detail::final_scan_kernel< + Exclusive, // flag for exclusive scan operation + config, // kernel configuration (block size, ipt) + InputIterator, OutputIterator, + BinaryFunction, 
InitValueType + > + <<>>( + input + offset, + current_size, + output + offset, + initial_value, + scan_op, + block_prefixes, + previous_last_element, + new_last_element, + i != size_t(0) && ((!Exclusive && number_of_blocks == 1) || Exclusive), + number_of_launch > 1 + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("final_scan_kernel", size, start); + + // Swap the last_elements if it's necessary + if(number_of_launch > 1) + { + cudaError_t error = ::rocprim::transform( + new_last_element, previous_last_element, 1, + ::rocprim::identity(), + stream, debug_synchronous + ); + if(error != cudaSuccess) return error; + } + } + } + else + { + if(debug_synchronous) + { + std::cout << "block_size " << block_size << '\n'; + std::cout << "number of blocks " << number_of_blocks << '\n'; + std::cout << "items_per_block " << items_per_block << '\n'; + } + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + detail::single_scan_kernel< + Exclusive, // flag for exclusive scan operation + config, // kernel configuration (block size, ipt) + InputIterator, OutputIterator, BinaryFunction + > + <<>>( + input, size, initial_value, output, scan_op + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("single_scan_kernel", size, start); + } + return cudaSuccess; +} + +template< + bool Exclusive, + class Config, + class InputIterator, + class OutputIterator, + class InitValueType, + class BinaryFunction +> +inline +auto scan_impl(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + const InitValueType initial_value, + const size_t size, + BinaryFunction scan_op, + const cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if::type +{ + using config = Config; + using real_init_value_type = input_type_t; + + using scan_state_type = detail::lookback_scan_state; + using scan_state_with_sleep_type = detail::lookback_scan_state; + using ordered_block_id_type = detail::ordered_block_id; + + constexpr 
unsigned int block_size = config::block_size; + constexpr unsigned int items_per_thread = config::items_per_thread; + constexpr auto items_per_block = block_size * items_per_thread; + + static constexpr size_t size_limit = config::size_limit; + static constexpr size_t aligned_size_limit = ::rocprim::max(size_limit - size_limit % items_per_block, items_per_block); + size_t limited_size = std::min(size, aligned_size_limit); + const bool use_limited_size = limited_size == aligned_size_limit; + + unsigned int number_of_blocks = (limited_size + items_per_block - 1)/items_per_block; + + // Calculate required temporary storage + size_t scan_state_bytes = ::rocprim::detail::align_size( + // This is valid even with scan_state_with_sleep_type + scan_state_type::get_storage_size(number_of_blocks) + ); + size_t ordered_block_id_bytes = ordered_block_id_type::get_storage_size(); + if(temporary_storage == nullptr) + { + // storage_size is never zero + storage_size = scan_state_bytes + ordered_block_id_bytes; + + if(use_limited_size) + storage_size += 2 * sizeof(real_init_value_type); + + return cudaSuccess; + } + + // Start point for time measurements + std::chrono::high_resolution_clock::time_point start; + + if( number_of_blocks == 0u ) + return cudaSuccess; + + if(number_of_blocks > 1 || use_limited_size) + { + // Create and initialize lookback_scan_state obj + auto scan_state = scan_state_type::create(temporary_storage, number_of_blocks); + auto scan_state_with_sleep = scan_state_with_sleep_type::create(temporary_storage, number_of_blocks); + // Create ad initialize ordered_block_id obj + auto ptr = reinterpret_cast(temporary_storage); + auto ordered_bid = ordered_block_id_type::create( + reinterpret_cast(ptr + scan_state_bytes) + ); + + // The last element + real_init_value_type* previous_last_element = nullptr; + real_init_value_type* new_last_element = nullptr; + if(use_limited_size) + { + ptr += storage_size - sizeof(real_init_value_type); + new_last_element = 
reinterpret_cast(ptr); + ptr -= sizeof(real_init_value_type); + previous_last_element = reinterpret_cast(ptr); + } + + cudaDeviceProp prop; + int deviceId; + static_cast(cudaGetDevice(&deviceId)); + static_cast(cudaGetDeviceProperties(&prop, deviceId)); + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + + + int asicRevision = 0; + + + size_t number_of_launch = (size + limited_size - 1)/limited_size; + for (size_t i = 0, offset = 0; i < number_of_launch; i++, offset+=limited_size ) + { + size_t current_size = std::min(size - offset, limited_size); + number_of_blocks = (current_size + items_per_block - 1)/items_per_block; + auto grid_size = (number_of_blocks + block_size - 1)/block_size; + + if(debug_synchronous) + { + std::cout << "use_limited_size " << use_limited_size << '\n'; + std::cout << "aligned_size_limit " << aligned_size_limit << '\n'; + std::cout << "number_of_launch " << number_of_launch << '\n'; + std::cout << "index " << i << '\n'; + std::cout << "size " << current_size << '\n'; + std::cout << "block_size " << block_size << '\n'; + std::cout << "number of blocks " << number_of_blocks << '\n'; + std::cout << "items_per_block " << items_per_block << '\n'; + } + + + + init_lookback_scan_state_kernel + <<>>( + scan_state, number_of_blocks, ordered_bid + ); + + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("init_lookback_scan_state_kernel", number_of_blocks, start) + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + grid_size = number_of_blocks; + + + if(debug_synchronous) + { + std::cout << "use_limited_size " << use_limited_size << '\n'; + std::cout << "aligned_size_limit " << aligned_size_limit << '\n'; + std::cout << "size " << current_size << '\n'; + std::cout << "block_size " << block_size << '\n'; + std::cout << "number of blocks " << number_of_blocks << '\n'; + std::cout << "items_per_block " << items_per_block << '\n'; + } + + lookback_scan_kernel< + Exclusive, // flag for exclusive scan 
operation + config, // kernel configuration (block size, ipt) + InputIterator, OutputIterator, + BinaryFunction, InitValueType, scan_state_type + > + <<>>( + input + offset, output + offset, current_size, initial_value, + scan_op, scan_state, number_of_blocks, ordered_bid, + previous_last_element, new_last_element, + i != size_t(0), number_of_launch > 1 + ); + + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("lookback_scan_kernel", current_size, start) + + // Swap the last_elements + if(number_of_launch > 1) + { + cudaError_t error = ::rocprim::transform( + new_last_element, previous_last_element, 1, + ::rocprim::identity(), + stream, debug_synchronous + ); + if(error != cudaSuccess) return error; + } + } + } + else + { + if(debug_synchronous) + { + std::cout << "size " << size << '\n'; + std::cout << "block_size " << block_size << '\n'; + std::cout << "number of blocks " << number_of_blocks << '\n'; + std::cout << "items_per_block " << items_per_block << '\n'; + } + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + single_scan_kernel< + Exclusive, // flag for exclusive scan operation + config, // kernel configuration (block size, ipt) + InputIterator, OutputIterator, BinaryFunction + > + <<>>( + input, size, initial_value, output, scan_op + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("single_scan_kernel", size, start); + } + return cudaSuccess; +} + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR +#undef ROCPRIM_DETAIL_HIP_SYNC + +} // end of detail namespace + +/// \brief Parallel inclusive scan primitive for device level. +/// +/// inclusive_scan function performs a device-wide inclusive prefix scan operation +/// using binary \p scan_op operator. +/// +/// \par Overview +/// * Supports non-commutative scan operators. However, a scan operator should be +/// associative. When used with non-associative functions the results may be non-deterministic +/// and/or vary in precision. 
+/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p input and \p output must have at least \p size elements. +/// * By default, the input type is used for accumulation. A custom type +/// can be specified using rocprim::transform_iterator, see the example below. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p scan_config or +/// a custom class with the same members. +/// \tparam InputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction - type of binary function used for scan. Default type +/// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the scan operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to scan. +/// \param [out] output - iterator to the first element in the output range. It can be +/// same as \p input. +/// \param [in] size - number of element in the input range. +/// \param [in] scan_op - binary operation function object that will be used for scan. +/// The signature of the function should be equivalent to the following: +/// T f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// Default is BinaryFunction(). 
+/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful scan; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level inclusive sum operation is performed on an array of +/// integer values (shorts are scanned into ints). +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// short * input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int * output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::inclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size, rocprim::plus() +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform scan +/// rocprim::inclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size, rocprim::plus() +/// ); +/// // output: [1, 3, 6, 10, 15, 21, 28, 36] +/// \endcode +/// +/// The same example as above, but now a custom accumulator type is specified. 
+/// +/// \code{.cpp} +/// #include +/// +/// size_t input_size; +/// short * input; +/// int * output; +/// +/// // Use a transform iterator to specifiy a custom accumulator type +/// auto input_iterator = rocprim::make_transform_iterator( +/// input, [] __device__ (T in) { return static_cast(in); }); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Use the transform iterator +/// rocprim::inclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input_iterator, output, input_size, rocprim::plus() +/// ); +/// +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// rocprim::inclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input_iterator, output, input_size, rocprim::plus() +/// ); +/// \endcode +/// \endparblock + +template< + class Config = default_config, + class InputIterator, + class OutputIterator, + class BinaryFunction = ::rocprim::plus::value_type> +> +inline +cudaError_t inclusive_scan(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + const size_t size, + BinaryFunction scan_op = BinaryFunction(), + const cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + using input_type = typename std::iterator_traits::value_type; + + // Get default config if Config is default_config + using config = detail::default_or_custom_config< + Config, + detail::default_scan_config + >; + + return detail::scan_impl( + temporary_storage, storage_size, + // input_type() is a dummy initial value (not used) + input, output, input_type(), size, + scan_op, stream, debug_synchronous + ); +} + +/// \brief Parallel exclusive scan primitive for device level. +/// +/// exclusive_scan function performs a device-wide exclusive prefix scan operation +/// using binary \p scan_op operator. +/// +/// \par Overview +/// * Supports non-commutative scan operators. 
However, a scan operator should be +/// associative. When used with non-associative functions the results may be non-deterministic +/// and/or vary in precision. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p input and \p output must have at least \p size elements. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p scan_config or +/// a custom class with the same members. +/// \tparam InputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam InitValueType - type of the initial value. +/// \tparam BinaryFunction - type of binary function used for scan. Default type +/// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the scan operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to scan. +/// \param [out] output - iterator to the first element in the output range. It can be +/// same as \p input. +/// \param [in] initial_value - initial value to start the scan. +/// A rocpim::future_value may be passed to use a value that will be later computed. +/// \param [in] size - number of element in the input range. +/// \param [in] scan_op - binary operation function object that will be used for scan. 
+/// The signature of the function should be equivalent to the following: +/// T f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The default value is \p BinaryFunction(). +/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. The default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful scan; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level exclusive min-scan operation is performed on an array of +/// integer values (shorts are scanned into ints) using custom operator. +/// +/// \code{.cpp} +/// #include +/// +/// // custom scan function +/// auto min_op = +/// [] __device__ (int a, int b) -> int +/// { +/// return a < b ? a : b; +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// short * input; // e.g., [4, 7, 6, 2, 5, 1, 3, 8] +/// int * output; // empty array of 8 elements +/// int start_value; // e.g., 9 +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::exclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, start_value, input_size, min_op +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform scan +/// rocprim::exclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, start_value, input_size, min_op +/// ); +/// // output: [9, 4, 7, 6, 2, 2, 1, 1] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator, + class OutputIterator, + class InitValueType, + class BinaryFunction = ::rocprim::plus::value_type> +> +inline +cudaError_t exclusive_scan(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + const InitValueType initial_value, + const size_t size, + BinaryFunction scan_op = BinaryFunction(), + const cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + using real_init_value_type = detail::input_type_t; + + // Get default config if Config is default_config + using config = detail::default_or_custom_config< + Config, + detail::default_scan_config + >; + + return detail::scan_impl( + temporary_storage, storage_size, + input, output, initial_value, size, + scan_op, stream, debug_synchronous + ); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_SCAN_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_scan_by_key.hpp b/3rdparty/cub/rocprim/device/device_scan_by_key.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ee02cee681d47e7f167daa56e6b297eef5f7327a --- /dev/null 
+++ b/3rdparty/cub/rocprim/device/device_scan_by_key.hpp @@ -0,0 +1,558 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_SCAN_BY_KEY_HPP_ +#define ROCPRIM_DEVICE_DEVICE_SCAN_BY_KEY_HPP_ + +#include "detail/device_scan_by_key.hpp" +#include "detail/lookback_scan_state.hpp" +#include "detail/ordered_block_id.hpp" + +#include "config_types.hpp" +#include "device_scan_by_key_config.hpp" + +#include "../config.hpp" +#include "../detail/various.hpp" +#include "../functional.hpp" +#include "../types/future_value.hpp" +#include "../types/tuple.hpp" + +#include + +#include +#include +#include + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + + template + void __global__ __launch_bounds__(Config::block_size) device_scan_by_key_kernel( + const KeyInputIterator keys, + const InputIterator values, + const OutputIterator output, + const InitialValueType initial_value, + const CompareFunction compare, + const BinaryFunction scan_op, + const LookbackScanState scan_state, + const size_t size, + const size_t starting_block, + const size_t number_of_blocks, + const ordered_block_id ordered_bid, + const ::rocprim::tuple* const previous_last_value) + { + device_scan_by_key_kernel_impl(keys, + values, + output, + get_input_value(initial_value), + compare, + scan_op, + scan_state, + size, + starting_block, + number_of_blocks, + ordered_bid, + previous_last_value); + } + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + do \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != cudaSuccess) \ + return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) \ + return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } while(false) + + template + inline cudaError_t scan_by_key_impl(void* const temporary_storage, + size_t& storage_size, + KeysInputIterator keys, + InputIterator input, + 
OutputIterator output, + const InitValueType initial_value, + const size_t size, + const BinaryFunction scan_op, + const CompareFunction compare, + const cudaStream_t stream, + const bool debug_synchronous) + { + using config = Config; + using real_init_value_type = input_type_t; + + using wrapped_type = ::rocprim::tuple; + + using scan_state_type = detail::lookback_scan_state; + using scan_state_with_sleep_type = detail::lookback_scan_state; + using ordered_block_id_type = detail::ordered_block_id; + + constexpr unsigned int block_size = config::block_size; + constexpr unsigned int items_per_thread = config::items_per_thread; + constexpr auto items_per_block = block_size * items_per_thread; + + static constexpr size_t size_limit = config::size_limit; + static constexpr size_t aligned_size_limit + = ::rocprim::max(size_limit - size_limit % items_per_block, items_per_block); + + const size_t limited_size = std::min(size, aligned_size_limit); + const bool use_limited_size = limited_size == aligned_size_limit; + + // Number of blocks in a single launch (or the only launch if it fits) + const unsigned int number_of_blocks = ceiling_div(limited_size, items_per_block); + + // Calculate required temporary storage, this is valid even with scan_state_with_sleep_type + const size_t scan_state_bytes + = align_size(scan_state_type::get_storage_size(number_of_blocks)); + if(temporary_storage == nullptr) + { + const size_t ordered_block_id_bytes + = align_size(ordered_block_id_type::get_storage_size(), alignof(wrapped_type)); + + // storage_size is never zero + storage_size = scan_state_bytes + ordered_block_id_bytes + + (use_limited_size ? 
sizeof(wrapped_type) : 0); + + return cudaSuccess; + } + + if(number_of_blocks == 0u) + { + return cudaSuccess; + } + + bool use_sleep; + if(const cudaError_t error = is_sleep_scan_state_used(use_sleep)) + { + return error; + } + + // Call the provided function with either scan_state or scan_state_with_sleep based on + // the value of use_sleep_scan_state + auto with_scan_state + = [use_sleep, + scan_state = scan_state_type::create(temporary_storage, number_of_blocks), + scan_state_with_sleep = scan_state_with_sleep_type::create( + temporary_storage, number_of_blocks)](auto&& func) mutable -> decltype(auto) { + if(use_sleep) + { + return func(scan_state_with_sleep); + } + else + { + return func(scan_state); + } + }; + + // Create and initialize ordered_block_id obj + auto* const ptr = static_cast(temporary_storage); + const auto ordered_bid = ordered_block_id_type::create( + reinterpret_cast(ptr + scan_state_bytes)); + + // The last element + auto* const previous_last_value + = use_limited_size + ? 
reinterpret_cast(ptr + storage_size - sizeof(wrapped_type)) + : nullptr; + + // Total number of blocks in all launches + const auto total_number_of_blocks = ceiling_div(size, items_per_block); + const size_t number_of_launch = ceiling_div(size, limited_size); + + if(debug_synchronous) + { + std::cout << "----------------------------------\n"; + std::cout << "size: " << size << '\n'; + std::cout << "aligned_size_limit: " << aligned_size_limit << '\n'; + std::cout << "use_limited_size: " << std::boolalpha << use_limited_size << '\n'; + std::cout << "number_of_launch: " << number_of_launch << '\n'; + std::cout << "block_size: " << block_size << '\n'; + std::cout << "items_per_block: " << items_per_block << '\n'; + std::cout << "----------------------------------\n"; + } + + for(size_t i = 0, offset = 0; i < number_of_launch; i++, offset += limited_size) + { + const size_t current_size = std::min(size - offset, limited_size); + const auto scan_blocks = ceiling_div(current_size, items_per_block); + const auto init_grid_size = ceiling_div(scan_blocks, block_size); + + // Start point for time measurements + std::chrono::high_resolution_clock::time_point start; + if(debug_synchronous) + { + std::cout << "index: " << i << '\n'; + std::cout << "current_size: " << current_size << '\n'; + std::cout << "number of blocks: " << scan_blocks << '\n'; + + start = std::chrono::high_resolution_clock::now(); + } + + with_scan_state([&](const auto scan_state) { + init_lookback_scan_state_kernel<<< + dim3(init_grid_size), + dim3(block_size), + 0, + stream>>>( + scan_state, + scan_blocks, + ordered_bid, + number_of_blocks - 1, + i > 0 ? 
previous_last_value : nullptr); + }); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR( + "init_lookback_scan_state_kernel", scan_blocks, start); + + if(debug_synchronous) + { + start = std::chrono::high_resolution_clock::now(); + } + with_scan_state([&](auto& scan_state) { + device_scan_by_key_kernel<<< + dim3(scan_blocks), + dim3(block_size), + 0, + stream>>>( + keys + offset, + input + offset, + output + offset, + initial_value, + compare, + scan_op, + scan_state, + size, + i * number_of_blocks, + total_number_of_blocks, + ordered_bid, + i > 0 ? previous_last_value : nullptr); + }); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR( + "device_scan_by_key_kernel", current_size, start); + } + return cudaSuccess; + } + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR +} + +/// \addtogroup devicemodule +/// @{ + +/// \brief Parallel inclusive scan-by-key primitive for device level. +/// +/// inclusive_scan_by_key function performs a device-wide inclusive prefix scan-by-key +/// operation using binary \p scan_op operator. +/// +/// \par Overview +/// * Supports non-commutative scan operators. However, a scan operator should be +/// associative. When used with non-associative functions the results may be non-deterministic +/// and/or vary in precision. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p keys_input, \p values_input, and \p values_output must have +/// at least \p size elements. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p scan_config or +/// a custom class with the same members. +/// \tparam KeysInputIterator - random-access iterator type of the input range. It can be +/// a simple pointer type. +/// \tparam ValuesInputIterator - random-access iterator type of the input range. It can be +/// a simple pointer type. +/// \tparam ValuesOutputIterator - random-access iterator type of the output range. 
It can be +/// a simple pointer type. +/// \tparam BinaryFunction - type of binary function used for scan. Default type +/// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// \tparam KeyCompareFunction - type of binary function used to determine keys equality. Default type +/// is \p rocprim::equal_to, where \p T is a \p value_type of \p KeysInputIterator. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the scan operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input - iterator to the first element in the range of keys. +/// \param [in] values_input - iterator to the first element in the range of values to scan. +/// \param [out] values_output - iterator to the first element in the output value range. +/// \param [in] size - number of element in the input range. +/// \param [in] scan_op - binary operation function object that will be used for scanning +/// input values. +/// The signature of the function should be equivalent to the following: +/// T f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// Default is BinaryFunction(). +/// \param [in] key_compare_op - binary operation function object that will be used to determine keys equality. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// Default is KeyCompareFunction(). +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). 
+/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful scan; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level inclusive sum-by-key operation is performed on an array of +/// integer values (shorts are scanned into ints). +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t size; // e.g., 8 +/// int * keys_input; // e.g., [1, 1, 2, 2, 3, 3, 3, 5] +/// short * values_input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int * values_output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::inclusive_scan_by_key( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, values_input, +/// values_output, size, +/// rocprim::plus() +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform scan-by-key +/// rocprim::inclusive_scan_by_key( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, values_input, +/// values_output, size, +/// rocprim::plus() +/// ); +/// // values_output: [1, 2, 3, 7, 5, 11, 18, 8] +/// \endcode +/// \endparblock +template ::value_type>, + typename KeyCompareFunction + = ::rocprim::equal_to::value_type>> +inline cudaError_t inclusive_scan_by_key(void* const temporary_storage, + size_t& storage_size, + const KeysInputIterator keys_input, + const ValuesInputIterator values_input, + const ValuesOutputIterator values_output, + const size_t size, + const BinaryFunction scan_op = BinaryFunction(), + const KeyCompareFunction key_compare_op + = KeyCompareFunction(), + const 
cudaStream_t stream = 0, + const bool debug_synchronous = false) +{ + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + // Get default config if Config is default_config + using config = detail::default_or_custom_config< + Config, + detail::default_scan_by_key_config>; + + return detail::scan_by_key_impl(temporary_storage, + storage_size, + keys_input, + values_input, + values_output, + value_type(), + size, + scan_op, + key_compare_op, + stream, + debug_synchronous); +} + +/// \brief Parallel exclusive scan-by-key primitive for device level. +/// +/// inclusive_scan_by_key function performs a device-wide exclusive prefix scan-by-key +/// operation using binary \p scan_op operator. +/// +/// \par Overview +/// * Supports non-commutative scan operators. However, a scan operator should be +/// associative. When used with non-associative functions the results may be non-deterministic +/// and/or vary in precision. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p keys_input, \p values_input, and \p values_output must have +/// at least \p size elements. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p scan_config or +/// a custom class with the same members. +/// \tparam KeysInputIterator - random-access iterator type of the input range. It can be +/// a simple pointer type. +/// \tparam ValuesInputIterator - random-access iterator type of the input range. It can be +/// a simple pointer type. +/// \tparam ValuesOutputIterator - random-access iterator type of the output range. It can be +/// a simple pointer type. +/// \tparam InitValueType - type of the initial value. +/// \tparam BinaryFunction - type of binary function used for scan. Default type +/// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. 
+/// \tparam KeyCompareFunction - type of binary function used to determine keys equality. Default type +/// is \p rocprim::equal_to, where \p T is a \p value_type of \p KeysInputIterator. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the scan operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input - iterator to the first element in the range of keys. +/// \param [in] values_input - iterator to the first element in the range of values to scan. +/// \param [out] values_output - iterator to the first element in the output value range. +/// \param [in] initial_value - initial value to start the scan. +/// A rocpim::future_value may be passed to use a value that will be later computed. +/// \param [in] size - number of element in the input range. +/// \param [in] scan_op - binary operation function object that will be used for scanning +/// input values. +/// The signature of the function should be equivalent to the following: +/// T f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// Default is BinaryFunction(). +/// \param [in] key_compare_op - binary operation function object that will be used to determine keys equality. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// Default is KeyCompareFunction(). +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. 
Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful scan; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level inclusive sum-by-key operation is performed on an array of +/// integer values (shorts are scanned into ints). +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t size; // e.g., 8 +/// int * keys_input; // e.g., [1, 1, 1, 2, 2, 3, 3, 4] +/// short * values_input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int start_value; // e.g., 9 +/// int * values_output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::exclusive_scan_by_key( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, values_input, +/// values_output, start_value, +/// size,rocprim::plus() +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform scan-by-key +/// rocprim::exclusive_scan_by_key( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, values_input, +/// values_output, start_value, +/// size,rocprim::plus() +/// ); +/// // values_output: [9, 10, 12, 9, 13, 9, 15, 9] +/// \endcode +/// \endparblock +template ::value_type>, + typename KeyCompareFunction + = ::rocprim::equal_to::value_type>> +inline cudaError_t exclusive_scan_by_key(void* const temporary_storage, + size_t& storage_size, + const KeysInputIterator keys_input, + const ValuesInputIterator values_input, + const ValuesOutputIterator values_output, + const InitialValueType initial_value, + const size_t size, + const BinaryFunction scan_op = BinaryFunction(), + const KeyCompareFunction key_compare_op + = KeyCompareFunction(), + const cudaStream_t stream = 0, + const bool 
debug_synchronous = false) +{ + using key_type = typename std::iterator_traits::value_type; + using real_init_value_type = detail::input_type_t; + + // Get default config if Config is default_config + using config = detail::default_or_custom_config< + Config, + detail::default_scan_by_key_config + >; + + return detail::scan_by_key_impl(temporary_storage, + storage_size, + keys_input, + values_input, + values_output, + initial_value, + size, + scan_op, + key_compare_op, + stream, + debug_synchronous); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_SCAN_BY_KEY_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_scan_by_key_config.hpp b/3rdparty/cub/rocprim/device/device_scan_by_key_config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5511afa6144357e8b355d14c78392ce5e9c641b5 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_scan_by_key_config.hpp @@ -0,0 +1,158 @@ +// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_SCAN_BY_KEY_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_SCAN_BY_KEY_CONFIG_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "config_types.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of device-level scan-by-key operation. +/// +/// \tparam BlockSize - number of threads in a block. +/// \tparam ItemsPerThread - number of items processed by each thread. +/// \tparam UseLookback - whether to use lookback scan or reduce-then-scan algorithm. +/// \tparam BlockLoadMethod - method for loading input values. +/// \tparam StoreLoadMethod - method for storing values. +/// \tparam BlockScanMethod - algorithm for block scan. +/// \tparam SizeLimit - limit on the number of items for a single scan kernel launch. +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + bool UseLookback, + ::rocprim::block_load_method BlockLoadMethod, + ::rocprim::block_store_method BlockStoreMethod, + ::rocprim::block_scan_algorithm BlockScanMethod, + unsigned int SizeLimit = ROCPRIM_GRID_SIZE_LIMIT +> +struct scan_by_key_config +{ + /// \brief Number of threads in a block. + static constexpr unsigned int block_size = BlockSize; + /// \brief Number of items processed by each thread. + static constexpr unsigned int items_per_thread = ItemsPerThread; + /// \brief Whether to use lookback scan or reduce-then-scan algorithm. + static constexpr bool use_lookback = UseLookback; + /// \brief Method for loading input values. + static constexpr ::rocprim::block_load_method block_load_method = BlockLoadMethod; + /// \brief Method for storing values. 
+ static constexpr ::rocprim::block_store_method block_store_method = BlockStoreMethod; + /// \brief Algorithm for block scan. + static constexpr ::rocprim::block_scan_algorithm block_scan_method = BlockScanMethod; + /// \brief Limit on the number of items for a single scan kernel launch. + static constexpr unsigned int size_limit = SizeLimit; +}; + +namespace detail +{ + +template +struct scan_by_key_config_900 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Key) + sizeof(Value), 2 * sizeof(int)); + + using type = scan_config< + limit_block_size<256U, sizeof(Key) + sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale), + ROCPRIM_DETAIL_USE_LOOKBACK_SCAN, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + >; +}; + +template +struct scan_by_key_config_90a +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Key) + sizeof(Value), 2 * sizeof(int)); + + using type = scan_config< + limit_block_size<256U, sizeof(Key) + sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale), + ROCPRIM_DETAIL_USE_LOOKBACK_SCAN, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + >; +}; + +template +struct scan_by_key_config_908 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Key) + sizeof(Value), 2 * sizeof(int)); + + using type = scan_config< + limit_block_size<256U, sizeof(Key) + sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 20u / item_scale), + ROCPRIM_DETAIL_USE_LOOKBACK_SCAN, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + >; +}; + +// TODO: We need to update 
these parameters +template +struct scan_by_key_config_1030 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Key) + sizeof(Value), 2 * sizeof(int)); + + using type = scan_config< + limit_block_size<256U, sizeof(Key) + sizeof(Value), ROCPRIM_WARP_SIZE_32>::value, + ::rocprim::max(1u, 15u / item_scale), + ROCPRIM_DETAIL_USE_LOOKBACK_SCAN, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + >; +}; + +template +struct default_scan_by_key_config + : select_arch< + TargetArch, + select_arch_case<900, scan_by_key_config_900>, + select_arch_case>, + select_arch_case<908, scan_by_key_config_908>, + select_arch_case<1030, scan_by_key_config_1030>, + scan_by_key_config_900 + > { }; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_SCAN_BY_KEY_CONFIG_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_scan_config.hpp b/3rdparty/cub/rocprim/device/device_scan_config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6cd117fc5a81c334ad02e3a090a3bbf1a075bd15 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_scan_config.hpp @@ -0,0 +1,180 @@ +// Copyright (c) 2018-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_SCAN_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_SCAN_CONFIG_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../block/block_load.hpp" +#include "../block/block_store.hpp" +#include "../block/block_scan.hpp" + +#include "config_types.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of device-level scan primitives. +/// +/// \tparam BlockSize - number of threads in a block. +/// \tparam ItemsPerThread - number of items processed by each thread. +/// \tparam UseLookback - whether to use lookback scan or reduce-then-scan algorithm. +/// \tparam BlockLoadMethod - method for loading input values. +/// \tparam StoreLoadMethod - method for storing values. +/// \tparam BlockScanMethod - algorithm for block scan. +/// \tparam SizeLimit - limit on the number of items for a single scan kernel launch. +template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + bool UseLookback, + ::rocprim::block_load_method BlockLoadMethod, + ::rocprim::block_store_method BlockStoreMethod, + ::rocprim::block_scan_algorithm BlockScanMethod, + unsigned int SizeLimit = ROCPRIM_GRID_SIZE_LIMIT +> +struct scan_config +{ + /// \brief Number of threads in a block. + static constexpr unsigned int block_size = BlockSize; + /// \brief Number of items processed by each thread. 
+ static constexpr unsigned int items_per_thread = ItemsPerThread; + /// \brief Whether to use lookback scan or reduce-then-scan algorithm. + static constexpr bool use_lookback = UseLookback; + /// \brief Method for loading input values. + static constexpr ::rocprim::block_load_method block_load_method = BlockLoadMethod; + /// \brief Method for storing values. + static constexpr ::rocprim::block_store_method block_store_method = BlockStoreMethod; + /// \brief Algorithm for block scan. + static constexpr ::rocprim::block_scan_algorithm block_scan_method = BlockScanMethod; + /// \brief Limit on the number of items for a single scan kernel launch. + static constexpr unsigned int size_limit = SizeLimit; +}; + +namespace detail +{ + +template +struct scan_config_803 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); + + using type = scan_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale), + ROCPRIM_DETAIL_USE_LOOKBACK_SCAN, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + >; +}; + +template +struct scan_config_900 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); + + using type = scan_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale), + ROCPRIM_DETAIL_USE_LOOKBACK_SCAN, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + >; +}; + +// TODO: We need to update these parameters +template +struct scan_config_90a +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); + + using type = scan_config< + limit_block_size<256U, sizeof(Value), 
ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale), + ROCPRIM_DETAIL_USE_LOOKBACK_SCAN, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + >; +}; + +template +struct scan_config_908 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); + + using type = scan_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 20u / item_scale), + ROCPRIM_DETAIL_USE_LOOKBACK_SCAN, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + >; +}; + +// TODO: We need to update these parameters +template +struct scan_config_1030 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); + + using type = scan_config< + limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_32>::value, + ::rocprim::max(1u, 15u / item_scale), + ROCPRIM_DETAIL_USE_LOOKBACK_SCAN, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + >; +}; + +template +struct default_scan_config + : select_arch< + TargetArch, + select_arch_case<803, scan_config_803>, + select_arch_case<900, scan_config_900>, + select_arch_case>, + select_arch_case<908, scan_config_908>, + select_arch_case<1030, scan_config_1030>, + scan_config_900 + > { }; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_SCAN_CONFIG_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_segmented_radix_sort.hpp b/3rdparty/cub/rocprim/device/device_segmented_radix_sort.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..182588d124fb4f6881577a742716d8f721fe476e --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_segmented_radix_sort.hpp @@ -0,0 +1,1640 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_HPP_ +#define ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_HPP_ + +#include +#include +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" +#include "../detail/radix_sort.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" +#include "../types.hpp" + +#include "../block/block_load.hpp" +#include "../iterator/counting_iterator.hpp" +#include "../iterator/reverse_iterator.hpp" +#include "detail/device_segmented_radix_sort.hpp" +#include "device_partition.hpp" +#include "device_segmented_radix_sort_config.hpp" + +/// \addtogroup devicemodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class Config, + bool Descending, + unsigned int BlockSize, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class OffsetIterator +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void segmented_sort_kernel(KeysInputIterator keys_input, + typename std::iterator_traits::value_type * keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type * values_tmp, + ValuesOutputIterator values_output, + bool to_output, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int long_iterations, + unsigned int short_iterations, + unsigned int begin_bit, + unsigned int end_bit) +{ + segmented_sort( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, + to_output, + begin_offsets, end_offsets, + long_iterations, short_iterations, + begin_bit, end_bit + ); +} + +template< + class Config, + bool Descending, + unsigned int BlockSize, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class SegmentIndexIterator, + class OffsetIterator +> +ROCPRIM_KERNEL +__launch_bounds__(BlockSize) +void segmented_sort_large_kernel(KeysInputIterator keys_input, + 
typename std::iterator_traits::value_type * keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type * values_tmp, + ValuesOutputIterator values_output, + bool to_output, + SegmentIndexIterator segment_indices, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int long_iterations, + unsigned int short_iterations, + unsigned int begin_bit, + unsigned int end_bit) +{ + segmented_sort_large( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, + to_output, segment_indices, + begin_offsets, end_offsets, + long_iterations, short_iterations, + begin_bit, end_bit + ); +} + +template +ROCPRIM_KERNEL __launch_bounds__(BlockSize) void segmented_sort_small_or_medium_kernel( + KeysInputIterator keys_input, + typename std::iterator_traits::value_type* keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type* values_tmp, + ValuesOutputIterator values_output, + bool to_output, + unsigned int num_segments, + SegmentIndexIterator segment_indices, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit, + unsigned int end_bit) +{ + segmented_sort_small( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, + to_output, num_segments, segment_indices, + begin_offsets, end_offsets, + begin_bit, end_bit + ); +} + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != cudaSuccess) return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } + +struct TwoWayPartitioner 
+{ + template + cudaError_t operator()(void* temporary_storage, + size_t& storage_size, + InputIterator input, + FirstOutputIterator output_first_part, + SecondOutputIterator /*output_second_part*/, + UnselectedOutputIterator /*output_unselected*/, + SelectedCountOutputIterator selected_count_output, + const size_t size, + FirstUnaryPredicate select_first_part_op, + SecondUnaryPredicate /*select_second_part_op*/, + const cudaStream_t stream, + const bool debug_synchronous) + { + return partition(temporary_storage, + storage_size, + input, + output_first_part, + selected_count_output, + size, + select_first_part_op, + stream, + debug_synchronous); + } +}; + +struct ThreeWayPartitioner +{ + template + cudaError_t operator()(void* temporary_storage, + size_t& storage_size, + InputIterator input, + FirstOutputIterator output_first_part, + SecondOutputIterator output_second_part, + UnselectedOutputIterator output_unselected, + SelectedCountOutputIterator selected_count_output, + const size_t size, + FirstUnaryPredicate select_first_part_op, + SecondUnaryPredicate select_second_part_op, + const cudaStream_t stream, + const bool debug_synchronous) + { + return partition_three_way(temporary_storage, + storage_size, + input, + output_first_part, + output_second_part, + output_unselected, + selected_count_output, + size, + select_first_part_op, + select_second_part_op, + stream, + debug_synchronous); + } +}; + +template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class OffsetIterator +> +inline +cudaError_t segmented_radix_sort_impl(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + typename std::iterator_traits::value_type * keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type * values_tmp, + ValuesOutputIterator values_output, + unsigned int size, + bool& 
is_result_in_output, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) +{ + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + using segment_index_type = unsigned int; + using segment_index_iterator = counting_iterator; + + static_assert( + std::is_same::value_type>::value, + "KeysInputIterator and KeysOutputIterator must have the same value_type" + ); + static_assert( + std::is_same::value_type>::value, + "ValuesInputIterator and ValuesOutputIterator must have the same value_type" + ); + + using config = default_or_custom_config< + Config, + default_segmented_radix_sort_config + >; + + static constexpr bool with_values = !std::is_same::value; + static constexpr bool partitioning_allowed = + !std::is_same::value; + static constexpr unsigned int max_small_segment_length + = config::warp_sort_config::items_per_thread_small + * config::warp_sort_config::logical_warp_size_small; + static constexpr unsigned int small_segments_per_block + = config::warp_sort_config::block_size_small + / config::warp_sort_config::logical_warp_size_small; + static constexpr unsigned int max_medium_segment_length + = config::warp_sort_config::items_per_thread_medium + * config::warp_sort_config::logical_warp_size_medium; + static constexpr unsigned int medium_segments_per_block + = config::warp_sort_config::block_size_medium + / config::warp_sort_config::logical_warp_size_medium; + static_assert( + max_small_segment_length <= max_medium_segment_length, + "The max length of small segments cannot be higher than the max length of medium segments"); + // Don't waste cycles on 3-way partitioning, if the small and medium segments are equal length + static constexpr bool three_way_partitioning + = max_small_segment_length < max_medium_segment_length; + using partitioner_type + = 
std::conditional_t; + partitioner_type partitioner; + + const auto large_segment_selector = [=](const unsigned int segment_index) mutable -> bool + { + const unsigned int segment_length + = end_offsets[segment_index] - begin_offsets[segment_index]; + return segment_length > max_medium_segment_length; + }; + const auto medium_segment_selector = [=](const unsigned int segment_index) mutable -> bool + { + const unsigned int segment_length = end_offsets[segment_index] - begin_offsets[segment_index]; + return segment_length > max_small_segment_length; + }; + + const bool with_double_buffer = keys_tmp != nullptr; + const unsigned int bits = end_bit - begin_bit; + const unsigned int iterations = ::rocprim::detail::ceiling_div(bits, config::long_radix_bits); + const bool to_output = with_double_buffer || (iterations - 1) % 2 == 0; + is_result_in_output = (iterations % 2 == 0) != to_output; + const unsigned int radix_bits_diff = config::long_radix_bits - config::short_radix_bits; + const unsigned int short_iterations = radix_bits_diff != 0 + ? ::rocprim::min(iterations, (config::long_radix_bits * iterations - bits) / radix_bits_diff) + : 0; + const unsigned int long_iterations = iterations - short_iterations; + const bool do_partitioning = partitioning_allowed + && segments >= config::warp_sort_config::partitioning_threshold; + + const size_t keys_bytes = ::rocprim::detail::align_size(size * sizeof(key_type)); + const size_t values_bytes = with_values ? ::rocprim::detail::align_size(size * sizeof(value_type)) : 0; + const size_t large_and_small_segment_indices_bytes + = ::rocprim::detail::align_size(segments * sizeof(segment_index_type)); + const size_t medium_segment_indices_bytes + = three_way_partitioning + ? ::rocprim::detail::align_size(segments * sizeof(segment_index_type)) + : 0; + static constexpr size_t segment_count_output_size = three_way_partitioning ? 
2 : 1; + const size_t segment_count_output_bytes + = ::rocprim::detail::align_size(segment_count_output_size * sizeof(segment_index_type)); + + segment_index_type* large_segment_indices_output{}; + // The total number of large and small segments is not above the number of segments + // The same buffer is filled with the large and small indices from both directions + auto small_segment_indices_output + = make_reverse_iterator(large_segment_indices_output + segments); + segment_index_type* medium_segment_indices_output{}; + segment_index_type* segment_count_output{}; + size_t partition_storage_size{}; + void* partition_temporary_storage{}; + if(temporary_storage == nullptr) + { + storage_size = with_double_buffer ? 0 : (keys_bytes + values_bytes); + if(do_partitioning) + { + storage_size += large_and_small_segment_indices_bytes; + storage_size += medium_segment_indices_bytes; + storage_size += segment_count_output_bytes; + const auto partition_result = partitioner(partition_temporary_storage, + partition_storage_size, + segment_index_iterator{}, + large_segment_indices_output, + medium_segment_indices_output, + small_segment_indices_output, + segment_count_output, + segments, + large_segment_selector, + medium_segment_selector, + stream, + debug_synchronous); + if(cudaSuccess != partition_result) + { + return partition_result; + } + storage_size += partition_storage_size; + } + + // Make sure user won't try to allocate 0 bytes memory, otherwise + // user may again pass nullptr as temporary_storage + storage_size = storage_size == 0 ? 
4 : storage_size; + return cudaSuccess; + } + if(segments == 0u) + { + return cudaSuccess; + } + if(debug_synchronous) + { + std::cout << "begin_bit " << begin_bit << '\n'; + std::cout << "end_bit " << end_bit << '\n'; + std::cout << "bits " << bits << '\n'; + std::cout << "segments " << segments << '\n'; + std::cout << "radix_bits_diff " << radix_bits_diff << '\n'; + std::cout << "storage_size " << storage_size << '\n'; + std::cout << "iterations " << iterations << '\n'; + std::cout << "long_iterations " << long_iterations << '\n'; + std::cout << "short_iterations " << short_iterations << '\n'; + std::cout << "do_partitioning " << do_partitioning << '\n'; + std::cout << "config::sort::block_size: " << config::sort::block_size << '\n'; + std::cout << "config::sort::items_per_thread: " << config::sort::items_per_thread << '\n'; + cudaError_t error = cudaStreamSynchronize(stream); + if(error != cudaSuccess) return error; + } + + char* ptr = reinterpret_cast(temporary_storage); + if(!with_double_buffer) + { + keys_tmp = reinterpret_cast(ptr); + ptr += keys_bytes; + values_tmp = with_values ? 
reinterpret_cast(ptr) : nullptr; + ptr += values_bytes; + } + large_segment_indices_output = reinterpret_cast(ptr); + ptr += large_and_small_segment_indices_bytes; + medium_segment_indices_output = reinterpret_cast(ptr); + ptr += medium_segment_indices_bytes; + small_segment_indices_output = make_reverse_iterator(large_segment_indices_output + segments); + segment_count_output = reinterpret_cast(ptr); + ptr += segment_count_output_bytes; + partition_temporary_storage = ptr; + ptr += partition_storage_size; + + if(do_partitioning) + { + cudaError_t result = partitioner(partition_temporary_storage, + partition_storage_size, + segment_index_iterator{}, + large_segment_indices_output, + medium_segment_indices_output, + small_segment_indices_output, + segment_count_output, + segments, + large_segment_selector, + medium_segment_selector, + stream, + debug_synchronous); + if(cudaSuccess != result) + { + return result; + } + segment_index_type segment_counts[segment_count_output_size]{}; + result = cudaMemcpyAsync(&segment_counts, + segment_count_output, + segment_count_output_bytes, + cudaMemcpyDeviceToHost, + stream); + if(cudaSuccess != result) + { + return result; + } + result = cudaStreamSynchronize(stream); + if(cudaSuccess != result) + { + return result; + } + const auto large_segment_count = segment_counts[0]; + const auto medium_segment_count = three_way_partitioning ? 
segment_counts[1] : 0; + const auto small_segment_count = segments - large_segment_count - medium_segment_count; + if(debug_synchronous) + { + std::cout << "large_segment_count " << large_segment_count << '\n'; + std::cout << "medium_segment_count " << medium_segment_count << '\n'; + std::cout << "small_segment_count " << small_segment_count << '\n'; + } + if(large_segment_count > 0) + { + std::chrono::high_resolution_clock::time_point start; + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + segmented_sort_large_kernel + <<>>( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, + to_output, large_segment_indices_output, + begin_offsets, end_offsets, + long_iterations, short_iterations, + begin_bit, end_bit + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:large_segments", + large_segment_count, + start) + } + if(three_way_partitioning && medium_segment_count > 0) + { + const auto medium_segment_grid_size + = ::rocprim::detail::ceiling_div(medium_segment_count, medium_segments_per_block); + std::chrono::high_resolution_clock::time_point start; + if(debug_synchronous) + start = std::chrono::high_resolution_clock::now(); + + segmented_sort_small_or_medium_kernel< + select_warp_sort_helper_config_medium_t, + Descending, + config::warp_sort_config::block_size_medium> + <<>>( + keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + is_result_in_output, + medium_segment_count, + medium_segment_indices_output, + begin_offsets, + end_offsets, + begin_bit, + end_bit); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:medium_segments", + medium_segment_count, + start) + } + if(small_segment_count > 0) + { + const auto small_segment_grid_size = ::rocprim::detail::ceiling_div(small_segment_count, + small_segments_per_block); + std::chrono::high_resolution_clock::time_point start; + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + + 
segmented_sort_small_or_medium_kernel< + select_warp_sort_helper_config_small_t, + Descending, + config::warp_sort_config::block_size_small> + <<>>( + keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + is_result_in_output, + small_segment_count, + small_segment_indices_output, + begin_offsets, + end_offsets, + begin_bit, + end_bit); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:small_segments", + small_segment_count, + start) + } + } + else + { + std::chrono::high_resolution_clock::time_point start; + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + segmented_sort_kernel + <<>>( + keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, + to_output, + begin_offsets, end_offsets, + long_iterations, short_iterations, + begin_bit, end_bit + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort", segments, start) + } + return cudaSuccess; +} + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + +} // end namespace detail + +/// \brief Parallel ascending radix sort primitive for device level. +/// +/// \p segmented_radix_sort_keys function performs a device-wide radix sort across multiple, +/// non-overlapping sequences of keys. Function sorts input keys in ascending order. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be +/// an arithmetic type (that is, an integral type or a floating-point type). +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. +/// * Ranges specified by \p begin_offsets and \p end_offsets must have +/// at least \p segments elements. 
They may use the same sequence offsets of at least +/// segments + 1 elements: offsets for \p begin_offsets and +/// offsets + 1 for \p end_offsets. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be +/// \p segmented_radix_sort_config or a custom class with the same members. +/// \tparam KeysInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam OffsetIterator - random-access iterator type of segment offsets. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input - pointer to the first element in the range to sort. +/// \param [out] keys_output - pointer to the first element in the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] segments - number of segments in the input range. +/// \param [in] begin_offsets - iterator to the first element in the range of beginning offsets. +/// \param [in] end_offsets - iterator to the first element in the range of ending offsets. 
+/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending radix sort is performed on an array of +/// \p float values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// float * input; // e.g., [0.6, 0.3, 0.65, 0.4, 0.2, 0.08, 1, 0.7] +/// float * output; // empty array of 8 elements +/// unsigned int segments; // e.g., 3 +/// int * offsets; // e.g. 
[0, 2, 3, 8] +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::segmented_radix_sort_keys( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size, +/// segments, offsets, offsets + 1 +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::segmented_radix_sort_keys( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size, +/// segments, offsets, offsets + 1 +/// ); +/// // keys_output: [0.3, 0.6, 0.65, 0.08, 0.2, 0.4, 0.7, 1] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class KeysInputIterator, + class KeysOutputIterator, + class OffsetIterator, + class Key = typename std::iterator_traits::value_type +> +inline +cudaError_t segmented_radix_sort_keys(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + unsigned int size, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + empty_type * values = nullptr; + bool ignored; + return detail::segmented_radix_sort_impl( + temporary_storage, storage_size, + keys_input, nullptr, keys_output, + values, nullptr, values, + size, ignored, + segments, begin_offsets, end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); +} + +/// \brief Parallel descending radix sort primitive for device level. +/// +/// \p segmented_radix_sort_keys_desc function performs a device-wide radix sort across multiple, +/// non-overlapping sequences of keys. Function sorts input keys in descending order. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. 
+/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be +/// an arithmetic type (that is, an integral type or a floating-point type). +/// * Ranges specified by \p keys_input and \p keys_output must have at least \p size elements. +/// * Ranges specified by \p begin_offsets and \p end_offsets must have +/// at least \p segments elements. They may use the same sequence offsets of at least +/// segments + 1 elements: offsets for \p begin_offsets and +/// offsets + 1 for \p end_offsets. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be +/// \p segmented_radix_sort_config or a custom class with the same members. +/// \tparam KeysInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam OffsetIterator - random-access iterator type of segment offsets. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. 
+/// \param [in] keys_input - pointer to the first element in the range to sort. +/// \param [out] keys_output - pointer to the first element in the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] segments - number of segments in the input range. +/// \param [in] begin_offsets - iterator to the first element in the range of beginning offsets. +/// \param [in] end_offsets - iterator to the first element in the range of ending offsets. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed on an array of +/// integer values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// int * input; // e.g., [6, 3, 5, 4, 2, 8, 1, 7] +/// int * output; // empty array of 8 elements +/// unsigned int segments; // e.g., 3 +/// int * offsets; // e.g. 
[0, 2, 3, 8] +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::segmented_radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size, +/// segments, offsets, offsets + 1 +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::segmented_radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, input_size, +/// segments, offsets, offsets + 1 +/// ); +/// // keys_output: [6, 3, 5, 8, 7, 4, 2, 1] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class KeysInputIterator, + class KeysOutputIterator, + class OffsetIterator, + class Key = typename std::iterator_traits::value_type +> +inline +cudaError_t segmented_radix_sort_keys_desc(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + unsigned int size, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + empty_type * values = nullptr; + bool ignored; + return detail::segmented_radix_sort_impl( + temporary_storage, storage_size, + keys_input, nullptr, keys_output, + values, nullptr, values, + size, ignored, + segments, begin_offsets, end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); +} + +/// \brief Parallel ascending radix sort-by-key primitive for device level. +/// +/// \p segmented_radix_sort_pairs_desc function performs a device-wide radix sort across multiple, +/// non-overlapping sequences of (key, value) pairs. Function sorts input pairs in ascending order of keys. 
+/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be +/// an arithmetic type (that is, an integral type or a floating-point type). +/// * Ranges specified by \p keys_input, \p keys_output, \p values_input and \p values_output must +/// have at least \p size elements. +/// * Ranges specified by \p begin_offsets and \p end_offsets must have +/// at least \p segments elements. They may use the same sequence offsets of at least +/// segments + 1 elements: offsets for \p begin_offsets and +/// offsets + 1 for \p end_offsets. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be +/// \p segmented_radix_sort_config or a custom class with the same members. +/// \tparam KeysInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. 
+/// \tparam OffsetIterator - random-access iterator type of segment offsets. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input - pointer to the first element in the range to sort. +/// \param [out] keys_output - pointer to the first element in the output range. +/// \param [in] values_input - pointer to the first element in the range to sort. +/// \param [out] values_output - pointer to the first element in the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] segments - number of segments in the input range. +/// \param [in] begin_offsets - iterator to the first element in the range of beginning offsets. +/// \param [in] end_offsets - iterator to the first element in the range of ending offsets. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. 
+/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending radix sort is performed where input keys are +/// represented by an array of unsigned integers and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// unsigned int * keys_output; // empty array of 8 elements +/// double * values_output; // empty array of 8 elements +/// unsigned int segments; // e.g., 3 +/// int * offsets; // e.g. [0, 2, 3, 8] +/// +/// // Keys are in range [0; 8], so we can limit compared bit to bits on indexes +/// // 0, 1, 2, 3, and 4. In order to do this begin_bit is set to 0 and end_bit +/// // is set to 5. 
+/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::segmented_radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, input_size, +/// segments, offsets, offsets + 1, +/// 0, 5 +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::segmented_radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, input_size, +/// segments, offsets, offsets + 1, +/// 0, 5 +/// ); +/// // keys_output: [3, 6, 5, 1, 1, 4, 7, 8] +/// // values_output: [2, -5, -4, -1, -2, 3, 7, -8] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, + class OffsetIterator, + class Key = typename std::iterator_traits::value_type +> +inline +cudaError_t segmented_radix_sort_pairs(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + bool ignored; + return detail::segmented_radix_sort_impl( + temporary_storage, storage_size, + keys_input, nullptr, keys_output, + values_input, nullptr, values_output, + size, ignored, + segments, begin_offsets, end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); +} + +/// \brief Parallel descending radix sort-by-key primitive for device level. 
+/// +/// \p segmented_radix_sort_pairs_desc function performs a device-wide radix sort across multiple, +/// non-overlapping sequences of (key, value) pairs. Function sorts input pairs in descending order of keys. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the sorting function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * \p Key type (a \p value_type of \p KeysInputIterator and \p KeysOutputIterator) must be +/// an arithmetic type (that is, an integral type or a floating-point type). +/// * Ranges specified by \p keys_input, \p keys_output, \p values_input and \p values_output must +/// have at least \p size elements. +/// * Ranges specified by \p begin_offsets and \p end_offsets must have +/// at least \p segments elements. They may use the same sequence offsets of at least +/// segments + 1 elements: offsets for \p begin_offsets and +/// offsets + 1 for \p end_offsets. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be +/// \p segmented_radix_sort_config or a custom class with the same members. +/// \tparam KeysInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. 
+/// \tparam ValuesOutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam OffsetIterator - random-access iterator type of segment offsets. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input - pointer to the first element in the range to sort. +/// \param [out] keys_output - pointer to the first element in the output range. +/// \param [in] values_input - pointer to the first element in the range to sort. +/// \param [out] values_output - pointer to the first element in the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] segments - number of segments in the input range. +/// \param [in] begin_offsets - iterator to the first element in the range of beginning offsets. +/// \param [in] end_offsets - iterator to the first element in the range of ending offsets. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). 
+/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed where input keys are +/// represented by an array of integers and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// int * keys_output; // empty array of 8 elements +/// double * values_output; // empty array of 8 elements +/// unsigned int segments; // e.g., 3 +/// int * offsets; // e.g. [0, 2, 3, 8] +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::segmented_radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, +/// segments, offsets, offsets + 1 +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::segmented_radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input, keys_output, values_input, values_output, +/// input_size, +/// segments, offsets, offsets + 1 +/// ); +/// // keys_output: [ 6, 3, 5, 8, 7, 4, 1, 1] +/// // values_output: [-5, 2, -4, -8, 7, 3, -1, -2] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator, 
+ class OffsetIterator, + class Key = typename std::iterator_traits::value_type +> +inline +cudaError_t segmented_radix_sort_pairs_desc(void * temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + bool ignored; + return detail::segmented_radix_sort_impl( + temporary_storage, storage_size, + keys_input, nullptr, keys_output, + values_input, nullptr, values_output, + size, ignored, + segments, begin_offsets, end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); +} + +/// \brief Parallel ascending radix sort primitive for device level. +/// +/// \p segmented_radix_sort_keys function performs a device-wide radix sort across multiple, +/// non-overlapping sequences of keys. Function sorts input keys in ascending order. +/// +/// \par Overview +/// * The contents of both buffers of \p keys may be altered by the sorting function. +/// * \p current() of \p keys is used as the input. +/// * The function will update \p current() of \p keys to point to the buffer +/// that contains the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point +/// type). +/// * Buffers of \p keys must have at least \p size elements. +/// * Ranges specified by \p begin_offsets and \p end_offsets must have +/// at least \p segments elements. 
They may use the same sequence offsets of at least +/// segments + 1 elements: offsets for \p begin_offsets and +/// offsets + 1 for \p end_offsets. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be +/// \p segmented_radix_sort_config or a custom class with the same members. +/// \tparam Key - key type. Must be an integral type or a floating-point type. +/// \tparam OffsetIterator - random-access iterator type of segment offsets. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys - reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] segments - number of segments in the input range. +/// \param [in] begin_offsets - iterator to the first element in the range of beginning offsets. +/// \param [in] end_offsets - iterator to the first element in the range of ending offsets. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. 
+/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending radix sort is performed on an array of +/// \p float values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// float * input; // e.g., [0.6, 0.3, 0.65, 0.4, 0.2, 0.08, 1, 0.7] +/// float * tmp; // empty array of 8 elements +/// unsigned int segments; // e.g., 3 +/// int * offsets; // e.g. 
[0, 2, 3, 8] +/// // Create double-buffer +/// rocprim::double_buffer keys(input, tmp); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::segmented_radix_sort_keys( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size, +/// segments, offsets, offsets + 1 +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::segmented_radix_sort_keys( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size, +/// segments, offsets, offsets + 1 +/// ); +/// // keys.current(): [0.3, 0.6, 0.65, 0.08, 0.2, 0.4, 0.7, 1] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class Key, + class OffsetIterator +> +inline +cudaError_t segmented_radix_sort_keys(void * temporary_storage, + size_t& storage_size, + double_buffer& keys, + unsigned int size, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + empty_type * values = nullptr; + bool is_result_in_output; + cudaError_t error = detail::segmented_radix_sort_impl( + temporary_storage, storage_size, + keys.current(), keys.current(), keys.alternate(), + values, values, values, + size, is_result_in_output, + segments, begin_offsets, end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); + if(temporary_storage != nullptr && is_result_in_output) + { + keys.swap(); + } + return error; +} + +/// \brief Parallel descending radix sort primitive for device level. +/// +/// \p segmented_radix_sort_keys_desc function performs a device-wide radix sort across multiple, +/// non-overlapping sequences of keys. Function sorts input keys in descending order. 
+/// +/// \par Overview +/// * The contents of both buffers of \p keys may be altered by the sorting function. +/// * \p current() of \p keys is used as the input. +/// * The function will update \p current() of \p keys to point to the buffer +/// that contains the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point +/// type). +/// * Buffers of \p keys must have at least \p size elements. +/// * Ranges specified by \p begin_offsets and \p end_offsets must have +/// at least \p segments elements. They may use the same sequence offsets of at least +/// segments + 1 elements: offsets for \p begin_offsets and +/// offsets + 1 for \p end_offsets. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be +/// \p segmented_radix_sort_config or a custom class with the same members. +/// \tparam Key - key type. Must be an integral type or a floating-point type. +/// \tparam OffsetIterator - random-access iterator type of segment offsets. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. 
+/// \param [in,out] keys - reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] segments - number of segments in the input range. +/// \param [in] begin_offsets - iterator to the first element in the range of beginning offsets. +/// \param [in] end_offsets - iterator to the first element in the range of ending offsets. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed on an array of +/// integer values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// int * input; // e.g., [6, 3, 5, 4, 2, 8, 1, 7] +/// int * tmp; // empty array of 8 elements +/// unsigned int segments; // e.g., 3 +/// int * offsets; // e.g. 
[0, 2, 3, 8] +/// // Create double-buffer +/// rocprim::double_buffer keys(input, tmp); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::segmented_radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size, +/// segments, offsets, offsets + 1 +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::segmented_radix_sort_keys_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, input_size, +/// segments, offsets, offsets + 1 +/// ); +/// // keys.current(): [6, 3, 5, 8, 7, 4, 2, 1] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class Key, + class OffsetIterator +> +inline +cudaError_t segmented_radix_sort_keys_desc(void * temporary_storage, + size_t& storage_size, + double_buffer& keys, + unsigned int size, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + empty_type * values = nullptr; + bool is_result_in_output; + cudaError_t error = detail::segmented_radix_sort_impl( + temporary_storage, storage_size, + keys.current(), keys.current(), keys.alternate(), + values, values, values, + size, is_result_in_output, + segments, begin_offsets, end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); + if(temporary_storage != nullptr && is_result_in_output) + { + keys.swap(); + } + return error; +} + +/// \brief Parallel ascending radix sort-by-key primitive for device level. +/// +/// \p segmented_radix_sort_pairs_desc function performs a device-wide radix sort across multiple, +/// non-overlapping sequences of (key, value) pairs. Function sorts input pairs in ascending order of keys. 
+/// +/// \par Overview +/// * The contents of both buffers of \p keys and \p values may be altered by the sorting function. +/// * \p current() of \p keys and \p values are used as the input. +/// * The function will update \p current() of \p keys and \p values to point to buffers +/// that contains the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point +/// type). +/// * Buffers of \p keys must have at least \p size elements. +/// * Ranges specified by \p begin_offsets and \p end_offsets must have +/// at least \p segments elements. They may use the same sequence offsets of at least +/// segments + 1 elements: offsets for \p begin_offsets and +/// offsets + 1 for \p end_offsets. +/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be +/// \p segmented_radix_sort_config or a custom class with the same members. +/// \tparam Key - key type. Must be an integral type or a floating-point type. +/// \tparam Value - value type. +/// \tparam OffsetIterator - random-access iterator type of segment offsets. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. 
+/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys - reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in,out] values - reference to the double-buffer of values, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] segments - number of segments in the input range. +/// \param [in] begin_offsets - iterator to the first element in the range of beginning offsets. +/// \param [in] end_offsets - iterator to the first element in the range of ending offsets. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. +/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending radix sort is performed where input keys are +/// represented by an array of unsigned integers and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// unsigned int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// unsigned int * keys_tmp; // empty array of 8 elements +/// double* values_tmp; // empty array of 8 elements +/// unsigned int segments; // e.g., 3 +/// int * offsets; // e.g. [0, 2, 3, 8] +/// // Create double-buffers +/// rocprim::double_buffer keys(keys_input, keys_tmp); +/// rocprim::double_buffer values(values_input, values_tmp); +/// +/// // Keys are in range [0; 8], so we can limit compared bit to bits on indexes +/// // 0, 1, 2, 3, and 4. In order to do this begin_bit is set to 0 and end_bit +/// // is set to 5. +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::segmented_radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size, +/// segments, offsets, offsets + 1 +/// 0, 5 +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::segmented_radix_sort_pairs( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size, +/// segments, offsets, offsets + 1 +/// 0, 5 +/// ); +/// // keys.current(): [3, 6, 5, 1, 1, 4, 7, 8] +/// // values.current(): [2, -5, -4, -1, -2, 3, 7, -8] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class Key, + class Value, + class OffsetIterator +> +inline +cudaError_t segmented_radix_sort_pairs(void * temporary_storage, + size_t& storage_size, + double_buffer& keys, + double_buffer& values, + unsigned int size, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + bool 
is_result_in_output; + cudaError_t error = detail::segmented_radix_sort_impl( + temporary_storage, storage_size, + keys.current(), keys.current(), keys.alternate(), + values.current(), values.current(), values.alternate(), + size, is_result_in_output, + segments, begin_offsets, end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); + if(temporary_storage != nullptr && is_result_in_output) + { + keys.swap(); + values.swap(); + } + return error; +} + +/// \brief Parallel descending radix sort-by-key primitive for device level. +/// +/// \p segmented_radix_sort_pairs_desc function performs a device-wide radix sort across multiple, +/// non-overlapping sequences of (key, value) pairs. Function sorts input pairs in descending order of keys. +/// +/// \par Overview +/// * The contents of both buffers of \p keys and \p values may be altered by the sorting function. +/// * \p current() of \p keys and \p values are used as the input. +/// * The function will update \p current() of \p keys and \p values to point to buffers +/// that contains the output range. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * The function requires small \p temporary_storage as it does not need +/// a temporary buffer of \p size elements. +/// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point +/// type). +/// * Buffers of \p keys must have at least \p size elements. +/// * Ranges specified by \p begin_offsets and \p end_offsets must have +/// at least \p segments elements. They may use the same sequence offsets of at least +/// segments + 1 elements: offsets for \p begin_offsets and +/// offsets + 1 for \p end_offsets. 
+/// * If \p Key is an integer type and the range of keys is known in advance, the performance +/// can be improved by setting \p begin_bit and \p end_bit, for example if all keys are in range +/// [100, 10000], begin_bit = 0 and end_bit = 14 will cover the whole range. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be +/// \p segmented_radix_sort_config or a custom class with the same members. +/// \tparam Key - key type. Must be an integral type or a floating-point type. +/// \tparam Value - value type. +/// \tparam OffsetIterator - random-access iterator type of segment offsets. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in,out] keys - reference to the double-buffer of keys, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in,out] values - reference to the double-buffer of values, its \p current() +/// contains the input range and will be updated to point to the output range. +/// \param [in] size - number of element in the input range. +/// \param [in] segments - number of segments in the input range. +/// \param [in] begin_offsets - iterator to the first element in the range of beginning offsets. +/// \param [in] end_offsets - iterator to the first element in the range of ending offsets. +/// \param [in] begin_bit - [optional] index of the first (least significant) bit used in +/// key comparison. Must be in range [0; 8 * sizeof(Key)). Default value: \p 0. +/// Non-default value not supported for floating-point key-types. 
+/// \param [in] end_bit - [optional] past-the-end index (most significant) bit used in +/// key comparison. Must be in range (begin_bit; 8 * sizeof(Key)]. Default +/// value: \p 8 * sizeof(Key). Non-default value not supported for floating-point key-types. +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level descending radix sort is performed where input keys are +/// represented by an array of integers and input values by an array of doubles. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and tmp (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// int * keys_input; // e.g., [ 6, 3, 5, 4, 1, 8, 1, 7] +/// double * values_input; // e.g., [-5, 2, -4, 3, -1, -8, -2, 7] +/// int * keys_tmp; // empty array of 8 elements +/// double * values_tmp; // empty array of 8 elements +/// unsigned int segments; // e.g., 3 +/// int * offsets; // e.g. 
[0, 2, 3, 8] +/// // Create double-buffers +/// rocprim::double_buffer keys(keys_input, keys_tmp); +/// rocprim::double_buffer values(values_input, values_tmp); +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::segmented_radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size, +/// segments, offsets, offsets + 1 +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform sort +/// rocprim::segmented_radix_sort_pairs_desc( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys, values, input_size, +/// segments, offsets, offsets + 1 +/// ); +/// // keys.current(): [ 6, 3, 5, 8, 7, 4, 1, 1] +/// // values.current(): [-5, 2, -4, -8, 7, 3, -1, -2] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class Key, + class Value, + class OffsetIterator +> +inline +cudaError_t segmented_radix_sort_pairs_desc(void * temporary_storage, + size_t& storage_size, + double_buffer& keys, + double_buffer& values, + unsigned int size, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + bool is_result_in_output; + cudaError_t error = detail::segmented_radix_sort_impl( + temporary_storage, storage_size, + keys.current(), keys.current(), keys.alternate(), + values.current(), values.current(), values.alternate(), + size, is_result_in_output, + segments, begin_offsets, end_offsets, + begin_bit, end_bit, + stream, debug_synchronous + ); + if(temporary_storage != nullptr && is_result_in_output) + { + keys.swap(); + values.swap(); + } + return error; +} + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group devicemodule + +#endif // 
ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_segmented_radix_sort_config.hpp b/3rdparty/cub/rocprim/device/device_segmented_radix_sort_config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b44c7679aed50c0335e6b4b8564807dc45b2f49b --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_segmented_radix_sort_config.hpp @@ -0,0 +1,362 @@ +// Copyright (c) 2018-2020 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_CONFIG_HPP_ + +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "config_types.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of the warp sort part of the device segmented radix sort operation. +/// Short enough segments are processed on warp level. +/// +/// \tparam LogicalWarpSizeSmall - number of threads in the logical warp of the kernel +/// that processes small segments. +/// \tparam ItemsPerThreadSmall - number of items processed by a thread in the kernel that processes +/// small segments. +/// \tparam BlockSizeSmall - number of threads per block in the kernel which processes the small segments. +/// \tparam PartitioningThreshold - if the number of segments is at least this threshold, the +/// segments are partitioned to a small, a medium and a large segment collection. Both collections +/// are sorted by different kernels. Otherwise, all segments are sorted by a single kernel. +/// \tparam EnableUnpartitionedWarpSort - If set to \p true, warp sort can be used to sort +/// the small segments, even if the total number of segments is below \p PartitioningThreshold. +/// \tparam LogicalWarpSizeMedium - number of threads in the logical warp of the kernel +/// that processes medium segments. +/// \tparam ItemsPerThreadMedium - number of items processed by a thread in the kernel that processes +/// medium segments. +/// \tparam BlockSizeMedium - number of threads per block in the kernel which processes the medium segments. 
+template +struct WarpSortConfig +{ + static_assert(LogicalWarpSizeSmall * ItemsPerThreadSmall + <= LogicalWarpSizeMedium * ItemsPerThreadMedium, + "The number of items processed by a small warp cannot be larger than the number " + "of items processed by a medium warp"); + /// \brief The number of threads in the logical warp in the small segment processing kernel. + static constexpr unsigned int logical_warp_size_small = LogicalWarpSizeSmall; + /// \brief The number of items processed by a thread in the small segment processing kernel. + static constexpr unsigned int items_per_thread_small = ItemsPerThreadSmall; + /// \brief The number of threads per block in the small segment processing kernel. + static constexpr unsigned int block_size_small = BlockSizeSmall; + /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into + /// small and large segment groups, and each group is handled by a different, specialized kernel. + static constexpr unsigned int partitioning_threshold = PartitioningThreshold; + /// \brief If set to \p true, warp sort can be used to sort the small segments, even if the total number of + /// segments is below \p PartitioningThreshold. + static constexpr bool enable_unpartitioned_warp_sort = EnableUnpartitionedWarpSort; + /// \brief The number of threads in the logical warp in the medium segment processing kernel. + static constexpr unsigned int logical_warp_size_medium = LogicalWarpSizeMedium; + /// \brief The number of items processed by a thread in the medium segment processing kernel. + static constexpr unsigned int items_per_thread_medium = ItemsPerThreadMedium; + /// \brief The number of threads per block in the medium segment processing kernel. + static constexpr unsigned int block_size_medium = BlockSizeMedium; +}; + +/// \brief Indicates if the warp level sorting is disabled in the +/// device segmented radix sort configuration. 
+struct DisabledWarpSortConfig +{ + /// \brief The number of threads in the logical warp in the small segment processing kernel. + static constexpr unsigned int logical_warp_size_small = 1; + /// \brief The number of items processed by a thread in the small segment processing kernel. + static constexpr unsigned int items_per_thread_small = 1; + /// \brief The number of threads per block in the small segment processing kernel. + static constexpr unsigned int block_size_small = 1; + /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into + /// small and large segment groups, and each group is handled by a different, specialized kernel. + static constexpr unsigned int partitioning_threshold = 0; + /// \brief If set to \p true, warp sort can be used to sort the small segments, even if the total number of + /// segments is below \p PartitioningThreshold. + static constexpr bool enable_unpartitioned_warp_sort = false; + /// \brief The number of threads in the logical warp in the medium segment processing kernel. + static constexpr unsigned int logical_warp_size_medium = 1; + /// \brief The number of items processed by a thread in the medium segment processing kernel. + static constexpr unsigned int items_per_thread_medium = 1; + /// \brief The number of threads per block in the medium segment processing kernel. + static constexpr unsigned int block_size_medium = 1; +}; + +/// \brief Selects the appropriate \p WarpSortConfig based on the size of the key type. +/// +/// \tparam Key - the type of the sorted keys. +/// \tparam MediumWarpSize - the logical warp size of the medium segment processing kernel. 
+template +using select_warp_sort_config_t + = std::conditional_t 2), //< enable unpartitioned warp sort + MediumWarpSize, //< logical warp size - medium kernel + 4, //< items per thread - medium kernel + 256 //< block size - medium kernel + >>; + +/// \brief Configuration of device-level segmented radix sort operation. +/// +/// Radix sort is excecuted in a few iterations (passes) depending on total number of bits to be sorted +/// (\p begin_bit and \p end_bit), each iteration sorts either \p LongRadixBits or \p ShortRadixBits bits +/// choosen to cover whole bit range in optimal way. +/// +/// For example, if \p LongRadixBits is 7, \p ShortRadixBits is 6, \p begin_bit is 0 and \p end_bit is 32 +/// there will be 5 iterations: 7 + 7 + 6 + 6 + 6 = 32 bits. +/// +/// If a segment's element count is low ( <= warp_sort_config::items_per_thread * warp_sort_config::logical_warp_size ), +/// it is sorted by a special warp-level sorting method. +/// +/// \tparam LongRadixBits - number of bits in long iterations. +/// \tparam ShortRadixBits - number of bits in short iterations, must be equal to or less than \p LongRadixBits. +/// \tparam SortConfig - configuration of radix sort kernel. Must be \p kernel_config. +/// \tparam WarpSortConfig - configuration of the warp sort that is used on the short segments. +template< + unsigned int LongRadixBits, + unsigned int ShortRadixBits, + class SortConfig, + class WarpSortConfig = DisabledWarpSortConfig +> +struct segmented_radix_sort_config +{ + /// \brief Number of bits in long iterations. + static constexpr unsigned int long_radix_bits = LongRadixBits; + /// \brief Number of bits in short iterations + static constexpr unsigned int short_radix_bits = ShortRadixBits; + /// \brief Configuration of radix sort kernel. + using sort = SortConfig; + /// \brief Configuration of the warp sort method. 
+ using warp_sort_config = WarpSortConfig; +}; + +namespace detail +{ + +template +struct segmented_radix_sort_config_803 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + using type = select_type< + select_type_case< + (sizeof(Key) == 1 && sizeof(Value) <= 8), + segmented_radix_sort_config<8, 7, kernel_config<256, 10>, select_warp_sort_config_t > + >, + select_type_case< + (sizeof(Key) == 2 && sizeof(Value) <= 8), + segmented_radix_sort_config<8, 7, kernel_config<256, 10>, select_warp_sort_config_t > + >, + select_type_case< + (sizeof(Key) == 4 && sizeof(Value) <= 8), + segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > + >, + select_type_case< + (sizeof(Key) == 8 && sizeof(Value) <= 8), + segmented_radix_sort_config<7, 6, kernel_config<256, 13>, select_warp_sort_config_t > + >, + segmented_radix_sort_config<7, 6, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, select_warp_sort_config_t > + >; +}; + +template +struct segmented_radix_sort_config_803 + : select_type< + select_type_case, select_warp_sort_config_t > >, + select_type_case, select_warp_sort_config_t > >, + select_type_case, select_warp_sort_config_t > >, + select_type_case, select_warp_sort_config_t > > + > { }; + +template +struct segmented_radix_sort_config_900 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + using type = select_type< + select_type_case< + (sizeof(Key) == 1 && sizeof(Value) <= 8), + segmented_radix_sort_config<4, 4, kernel_config<256, 10>, select_warp_sort_config_t > + >, + select_type_case< + (sizeof(Key) == 2 && sizeof(Value) <= 8), + segmented_radix_sort_config<6, 5, kernel_config<256, 10>, select_warp_sort_config_t > + >, + select_type_case< + (sizeof(Key) == 4 && sizeof(Value) <= 8), + segmented_radix_sort_config<7, 6, kernel_config<256, 15>, 
select_warp_sort_config_t > + >, + select_type_case< + (sizeof(Key) == 8 && sizeof(Value) <= 8), + segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > + >, + segmented_radix_sort_config<7, 6, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, select_warp_sort_config_t > + >; +}; + +template +struct segmented_radix_sort_config_900 + : select_type< + select_type_case, select_warp_sort_config_t > >, + select_type_case, select_warp_sort_config_t > >, + select_type_case, select_warp_sort_config_t > >, + select_type_case, select_warp_sort_config_t > > + > { }; + +template +struct segmented_radix_sort_config_90a +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + using type = select_type< + select_type_case< + (sizeof(Key) == 1 && sizeof(Value) <= 8), + segmented_radix_sort_config<4, + 4, + kernel_config<256, 10>, + select_warp_sort_config_t>>, + select_type_case< + (sizeof(Key) == 2 && sizeof(Value) <= 8), + segmented_radix_sort_config<6, + 5, + kernel_config<256, 10>, + select_warp_sort_config_t>>, + select_type_case< + (sizeof(Key) == 4 && sizeof(Value) <= 8), + segmented_radix_sort_config<7, + 6, + kernel_config<256, 15>, + select_warp_sort_config_t>>, + select_type_case< + (sizeof(Key) == 8 && sizeof(Value) <= 8), + segmented_radix_sort_config<7, + 6, + kernel_config<256, 15>, + select_warp_sort_config_t>>, + segmented_radix_sort_config<7, + 6, + kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, + select_warp_sort_config_t>>; +}; + +template +struct segmented_radix_sort_config_90a + : select_type< + select_type_case< + sizeof(Key) == 1, + segmented_radix_sort_config<4, + 3, + kernel_config<256, 10>, + select_warp_sort_config_t>>, + select_type_case< + sizeof(Key) == 2, + segmented_radix_sort_config<6, + 5, + kernel_config<256, 10>, + select_warp_sort_config_t>>, + select_type_case< + sizeof(Key) == 4, + 
segmented_radix_sort_config<7, + 6, + kernel_config<256, 17>, + select_warp_sort_config_t>>, + select_type_case< + sizeof(Key) == 8, + segmented_radix_sort_config<7, + 6, + kernel_config<256, 15>, + select_warp_sort_config_t>>> +{}; + +template +struct segmented_radix_sort_config_1030 +{ + static constexpr unsigned int item_scale = + ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + using type = select_type< + select_type_case< + (sizeof(Key) == 1 && sizeof(Value) <= 8), + segmented_radix_sort_config<4, 4, kernel_config<256, 10>, select_warp_sort_config_t > + >, + select_type_case< + (sizeof(Key) == 2 && sizeof(Value) <= 8), + segmented_radix_sort_config<6, 5, kernel_config<256, 10>, select_warp_sort_config_t > + >, + select_type_case< + (sizeof(Key) == 4 && sizeof(Value) <= 8), + segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > + >, + select_type_case< + (sizeof(Key) == 8 && sizeof(Value) <= 8), + segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > + >, + segmented_radix_sort_config<7, 6, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, select_warp_sort_config_t > + >; +}; + +template +struct segmented_radix_sort_config_1030 + : select_type< + select_type_case, select_warp_sort_config_t > >, + select_type_case, select_warp_sort_config_t > >, + select_type_case, select_warp_sort_config_t > >, + select_type_case, select_warp_sort_config_t > > + > { }; + +template +struct default_segmented_radix_sort_config + : select_arch< + TargetArch, + select_arch_case<803, detail::segmented_radix_sort_config_803>, + select_arch_case<900, detail::segmented_radix_sort_config_900>, + select_arch_case<906, detail::segmented_radix_sort_config_90a>, + select_arch_case<908, detail::segmented_radix_sort_config_90a>, + select_arch_case>, + select_arch_case<1030, detail::segmented_radix_sort_config_1030>, + detail::segmented_radix_sort_config_900> +{}; + +} // end 
namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_CONFIG_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_segmented_reduce.hpp b/3rdparty/cub/rocprim/device/device_segmented_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6a151030969cb6a9f4380aab8d44ef799f509918 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_segmented_reduce.hpp @@ -0,0 +1,276 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_SEGMENTED_REDUCE_HPP_ +#define ROCPRIM_DEVICE_DEVICE_SEGMENTED_REDUCE_HPP_ + +#include +#include +#include + +#include "device_reduce_config.hpp" + +#include "../config.hpp" +#include "../functional.hpp" +#include "../detail/various.hpp" +#include "../detail/match_result_type.hpp" + +#include "detail/device_segmented_reduce.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +namespace detail +{ + +template< + class Config, + class InputIterator, + class OutputIterator, + class OffsetIterator, + class ResultType, + class BinaryFunction +> +ROCPRIM_KERNEL +__launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) +void segmented_reduce_kernel(InputIterator input, + OutputIterator output, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + BinaryFunction reduce_op, + ResultType initial_value) +{ + segmented_reduce( + input, output, + begin_offsets, end_offsets, + reduce_op, initial_value + ); +} + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != cudaSuccess) return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } + +template< + class Config, + class InputIterator, + class OutputIterator, + class OffsetIterator, + class InitValueType, + class BinaryFunction +> +inline +cudaError_t segmented_reduce_impl(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + BinaryFunction reduce_op, + InitValueType initial_value, + cudaStream_t stream, + bool debug_synchronous) +{ + using input_type = 
typename std::iterator_traits::value_type; + using result_type = typename ::rocprim::detail::match_result_type< + input_type, BinaryFunction + >::type; + + // Get default config if Config is default_config + using config = default_or_custom_config< + Config, + default_reduce_config + >; + + constexpr unsigned int block_size = config::block_size; + + if(temporary_storage == nullptr) + { + // Make sure user won't try to allocate 0 bytes memory, because + // cudaMalloc will return nullptr when size is zero. + storage_size = 4; + return cudaSuccess; + } + + if( segments == 0u ) + return cudaSuccess; + + std::chrono::high_resolution_clock::time_point start; + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + segmented_reduce_kernel + <<>>( + input, output, + begin_offsets, end_offsets, + reduce_op, static_cast(initial_value) + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_reduce", segments, start); + + return cudaSuccess; +} + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + +} // end of detail namespace + +/// \brief Parallel segmented reduction primitive for device level. +/// +/// segmented_reduce function performs a device-wide reduction operation across multiple sequences +/// using binary \p reduce_op operator. +/// +/// \par Overview +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p input must have at least \p size elements, \p output must have +/// \p segments elements. +/// * Ranges specified by \p begin_offsets and \p end_offsets must have +/// at least \p segments elements. They may use the same sequence offsets of at least +/// segments + 1 elements: offsets for \p begin_offsets and +/// offsets + 1 for \p end_offsets. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p reduce_config or +/// a custom class with the same members. 
+/// \tparam InputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam OffsetIterator - random-access iterator type of segment offsets. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction - type of binary function used for reduction. Default type +/// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// \tparam InitValueType - type of the initial value. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the reduction operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to reduce. +/// \param [out] output - iterator to the first element in the output range. +/// \param [in] segments - number of segments in the input range. +/// \param [in] begin_offsets - iterator to the first element in the range of beginning offsets. +/// \param [in] end_offsets - iterator to the first element in the range of ending offsets. +/// \param [in] initial_value - initial value to start the reduction. +/// \param [in] reduce_op - binary operation function object that will be used for reduction. +/// The signature of the function should be equivalent to the following: +/// T f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The default value is \p BinaryFunction(). 
+/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. The default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful reduction; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level segmented min-reduction operation is performed on an array of +/// integer values (shorts are reduced into ints) using custom operator. +/// +/// \code{.cpp} +/// #include +/// +/// // custom reduce function +/// auto min_op = +/// [] __device__ (int a, int b) -> int +/// { +/// return a < b ? a : b; +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// unsigned int segments; // e.g., 3 +/// short * input; // e.g., [4, 7, 6, 2, 5, 1, 3, 8] +/// int * output; // empty array of 3 elements +/// int * offsets; // e.g. 
[0, 2, 3, 8] +/// int init_value; // e.g., 9 +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::segmented_reduce( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, +/// segments, offsets, offsets + 1, +/// min_op, init_value +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform segmented reduction +/// rocprim::segmented_reduce( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, +/// segments, offsets, offsets + 1, +/// min_op, init_value +/// ); +/// // output: [4, 6, 1] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator, + class OutputIterator, + class OffsetIterator, + class BinaryFunction = ::rocprim::plus::value_type>, + class InitValueType = typename std::iterator_traits::value_type +> +inline +cudaError_t segmented_reduce(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + BinaryFunction reduce_op = BinaryFunction(), + InitValueType initial_value = InitValueType(), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::segmented_reduce_impl( + temporary_storage, storage_size, + input, output, + segments, begin_offsets, end_offsets, + reduce_op, initial_value, + stream, debug_synchronous + ); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_SEGMENTED_REDUCE_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_segmented_scan.hpp b/3rdparty/cub/rocprim/device/device_segmented_scan.hpp new file mode 100644 index 0000000000000000000000000000000000000000..56a4acf9061276fc33a37092be2ee00d9dbede04 --- /dev/null +++ 
b/3rdparty/cub/rocprim/device/device_segmented_scan.hpp @@ -0,0 +1,643 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_SEGMENTED_SCAN_HPP_ +#define ROCPRIM_DEVICE_DEVICE_SEGMENTED_SCAN_HPP_ + +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" +#include "../detail/match_result_type.hpp" + +#include "../iterator/zip_iterator.hpp" +#include "../iterator/discard_iterator.hpp" +#include "../iterator/transform_iterator.hpp" +#include "../iterator/counting_iterator.hpp" +#include "../types/tuple.hpp" + +#include "device_scan_config.hpp" +#include "device_scan.hpp" +#include "detail/device_segmented_scan.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +namespace detail +{ + +template< + bool Exclusive, + class Config, + class ResultType, + class InputIterator, + class OutputIterator, + class OffsetIterator, + class InitValueType, + class BinaryFunction +> +ROCPRIM_KERNEL +__launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) +void segmented_scan_kernel(InputIterator input, + OutputIterator output, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + InitValueType initial_value, + BinaryFunction scan_op) +{ + segmented_scan( + input, output, begin_offsets, end_offsets, + static_cast(initial_value), scan_op + ); +} + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != cudaSuccess) return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } + +template< + bool Exclusive, + class Config, + class InputIterator, + class OutputIterator, + class OffsetIterator, + class InitValueType, + class BinaryFunction +> +inline +cudaError_t segmented_scan_impl(void * temporary_storage, + size_t& storage_size, + 
InputIterator input, + OutputIterator output, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + const InitValueType initial_value, + BinaryFunction scan_op, + cudaStream_t stream, + bool debug_synchronous) +{ + using input_type = typename std::iterator_traits::value_type; + using result_type = typename std::conditional::type; + + // Get default config if Config is default_config + using config = default_or_custom_config< + Config, + default_scan_config + >; + + constexpr unsigned int block_size = config::block_size; + + if(temporary_storage == nullptr) + { + // Make sure user won't try to allocate 0 bytes memory, because + // cudaMalloc will return nullptr when size is zero. + storage_size = 4; + return cudaSuccess; + } + + if( segments == 0u ) + return cudaSuccess; + + std::chrono::high_resolution_clock::time_point start; + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + segmented_scan_kernel + <<>>( + input, output, + begin_offsets, end_offsets, + initial_value, scan_op + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_scan", segments, start); + return cudaSuccess; +} + +#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + +} // end of detail namespace + +/// \brief Parallel segmented inclusive scan primitive for device level. +/// +/// segmented_inclusive_scan function performs a device-wide inclusive scan operation +/// across multiple sequences from \p input using binary \p scan_op operator. +/// +/// \par Overview +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p input and \p output must have at least \p size elements. +/// * Ranges specified by \p begin_offsets and \p end_offsets must have +/// at least \p segments elements. They may use the same sequence offsets of at least +/// segments + 1 elements: offsets for \p begin_offsets and +/// offsets + 1 for \p end_offsets. 
+/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p scan_config or +/// a custom class with the same members. +/// \tparam InputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ RandomAccessIterator concept. It can be a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ RandomAccessIterator concept. It can be a simple pointer type. +/// \tparam OffsetIterator - random-access iterator type of segment offsets. Must meet the +/// requirements of a C++ RandomAccessIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction - type of binary function used for scan operation. Default type +/// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the scan operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to scan. +/// \param [out] output - iterator to the first element in the output range. +/// \param [in] segments - number of segments in the input range. +/// \param [in] begin_offsets - iterator to the first element in the range of beginning offsets. +/// \param [in] end_offsets - iterator to the first element in the range of ending offsets. +/// \param [in] scan_op - binary operation function object that will be used for scan. +/// The signature of the function should be equivalent to the following: +/// T f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The default value is \p BinaryFunction(). 
+/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. The default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful scan; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level segmented inclusive min-scan operation is performed on +/// an array of integer values (shorts are scanned into ints) using custom operator. +/// +/// \code{.cpp} +/// #include +/// +/// // custom scan function +/// auto min_op = +/// [] __device__ (int a, int b) -> int +/// { +/// return a < b ? a : b; +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// short * input; // e.g., [4, 7, 6, 2, 5, 1, 3, 8] +/// int * output; // empty array of 8 elements +/// size_t segments; // e.g., 3 +/// int * offsets; // e.g. 
[0, 2, 4, 8] +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::segmented_inclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, segments, offsets, offsets + 1, min_op +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform scan +/// rocprim::inclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, segments, offsets, offsets + 1, min_op +/// ); +/// // output: [4, 4, 6, 2, 5, 1, 1, 1] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator, + class OutputIterator, + class OffsetIterator, + class BinaryFunction = ::rocprim::plus::value_type> +> +inline +cudaError_t segmented_inclusive_scan(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + BinaryFunction scan_op = BinaryFunction(), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + using input_type = typename std::iterator_traits::value_type; + using result_type = input_type; + + return detail::segmented_scan_impl( + temporary_storage, storage_size, + input, output, segments, begin_offsets, end_offsets, result_type(), + scan_op, stream, debug_synchronous + ); +} + +/// \brief Parallel segmented exclusive scan primitive for device level. +/// +/// segmented_exclusive_scan function performs a device-wide exclusive scan operation +/// across multiple sequences from \p input using binary \p scan_op operator. +/// +/// \par Overview +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p input and \p output must have at least \p size elements. 
+/// * Ranges specified by \p begin_offsets and \p end_offsets must have +/// at least \p segments elements. They may use the same sequence offsets of at least +/// segments + 1 elements: offsets for \p begin_offsets and +/// offsets + 1 for \p end_offsets. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p scan_config or +/// a custom class with the same members. +/// \tparam InputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ RandomAccessIterator concept. It can be a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ RandomAccessIterator concept. It can be a simple pointer type. +/// \tparam OffsetIterator - random-access iterator type of segment offsets. Must meet the +/// requirements of a C++ RandomAccessIterator concept. It can be a simple pointer type. +/// \tparam InitValueType - type of the initial value. +/// \tparam BinaryFunction - type of binary function used for scan operation. Default type +/// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the scan operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to scan. +/// \param [out] output - iterator to the first element in the output range. +/// \param [in] segments - number of segments in the input range. +/// \param [in] begin_offsets - iterator to the first element in the range of beginning offsets. +/// \param [in] end_offsets - iterator to the first element in the range of ending offsets. 
+/// \param [in] initial_value - initial value to start the scan. +/// \param [in] scan_op - binary operation function object that will be used for scan. +/// The signature of the function should be equivalent to the following: +/// T f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The default value is \p BinaryFunction(). +/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. The default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful scan; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level segmented exclusive min-scan operation is performed on +/// an array of integer values (shorts are scanned into ints) using custom operator. +/// +/// \code{.cpp} +/// #include +/// +/// // custom scan function +/// auto min_op = +/// [] __device__ (int a, int b) -> int +/// { +/// return a < b ? a : b; +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// int start_value; // e.g., 9 +/// short * input; // e.g., [4, 7, 6, 2, 5, 1, 3, 8] +/// int * output; // empty array of 8 elements +/// size_t segments; // e.g., 3 +/// int * offsets; // e.g. 
[0, 2, 4, 8] +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::segmented_exclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, segments, offsets, offsets + 1 +/// start_value, min_op +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform scan +/// rocprim::exclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, segments, offsets, offsets + 1 +/// start_value, min_op +/// ); +/// // output: [9, 4, 9, 6, 9, 5, 1, 1] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator, + class OutputIterator, + class OffsetIterator, + class InitValueType, + class BinaryFunction = ::rocprim::plus::value_type> +> +inline +cudaError_t segmented_exclusive_scan(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + unsigned int segments, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + const InitValueType initial_value, + BinaryFunction scan_op = BinaryFunction(), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::segmented_scan_impl( + temporary_storage, storage_size, + input, output, segments, begin_offsets, end_offsets, initial_value, + scan_op, stream, debug_synchronous + ); +} + +/// \brief Parallel segmented inclusive scan primitive for device level. +/// +/// segmented_inclusive_scan function performs a device-wide inclusive scan operation +/// across multiple sequences from \p input using binary \p scan_op operator. Beginnings +/// of the segments should be marked by value convertible to \p true at corresponding +/// position in \p flags range. 
+/// +/// \par Overview +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p input, \p output, and \p flags must have at least \p size elements. +/// * \p value_type of \p HeadFlagIterator iterator should be convertible to \p bool type. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p scan_config or +/// a custom class with the same members. +/// \tparam InputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ RandomAccessIterator concept. It can be a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. Must meet the +/// requirements of a C++ RandomAccessIterator concept. It can be a simple pointer type. +/// \tparam HeadFlagIterator - random-access iterator type of flags. Must meet the +/// requirements of a C++ RandomAccessIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction - type of binary function used for scan operation. Default type +/// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the scan operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to scan. +/// \param [out] output - iterator to the first element in the output range. +/// \param [in] head_flags - iterator to the first element in the range of head flags marking +/// beginnings of each segment in the input range. +/// \param [in] size - number of element in the input range. +/// \param [in] scan_op - binary operation function object that will be used for scan. 
+/// The signature of the function should be equivalent to the following: +/// T f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The default value is \p BinaryFunction(). +/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. The default value is \p false. +/// +/// \returns \p cudaSuccess (\p 0) after successful scan; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level segmented inclusive sum operation is performed on +/// an array of integer values (shorts are added into ints). +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t size; // e.g., 8 +/// short * input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int * flags; // e.g., [1, 0, 0, 1, 0, 1, 0, 0] +/// int * output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::segmented_inclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, flags, size, ::rocprim::plus() +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform scan +/// rocprim::inclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, flags, size, ::rocprim::plus() +/// ); +/// // output: [1, 3, 6, 4, 9, 6, 13, 21] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator, + class OutputIterator, + class HeadFlagIterator, + class BinaryFunction = ::rocprim::plus::value_type> +> +inline 
+cudaError_t segmented_inclusive_scan(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + HeadFlagIterator head_flags, + size_t size, + BinaryFunction scan_op = BinaryFunction(), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + using input_type = typename std::iterator_traits::value_type; + using result_type = input_type; + using flag_type = typename std::iterator_traits::value_type; + using headflag_scan_op_wrapper_type = + detail::headflag_scan_op_wrapper< + result_type, flag_type, BinaryFunction + >; + + return inclusive_scan( + temporary_storage, storage_size, + rocprim::make_zip_iterator(rocprim::make_tuple(input, head_flags)), + rocprim::make_zip_iterator(rocprim::make_tuple(output, rocprim::make_discard_iterator())), + size, headflag_scan_op_wrapper_type(scan_op), + stream, debug_synchronous + ); +} + +/// \brief Parallel segmented exclusive scan primitive for device level. +/// +/// segmented_exclusive_scan function performs a device-wide exclusive scan operation +/// across multiple sequences from \p input using binary \p scan_op operator. Beginnings +/// of the segments should be marked by value convertible to \p true at corresponding +/// position in \p flags range. +/// +/// \par Overview +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p input, \p output, and \p flags must have at least \p size elements. +/// * \p value_type of \p HeadFlagIterator iterator should be convertible to \p bool type. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p scan_config or +/// a custom class with the same members. +/// \tparam InputIterator - random-access iterator type of the input range. Must meet the +/// requirements of a C++ RandomAccessIterator concept. It can be a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. 
Must meet the +/// requirements of a C++ RandomAccessIterator concept. It can be a simple pointer type. +/// \tparam HeadFlagIterator - random-access iterator type of flags. Must meet the +/// requirements of a C++ RandomAccessIterator concept. It can be a simple pointer type. +/// \tparam InitValueType - type of the initial value. +/// \tparam BinaryFunction - type of binary function used for scan operation. Default type +/// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the scan operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to scan. +/// \param [out] output - iterator to the first element in the output range. +/// \param [in] head_flags - iterator to the first element in the range of head flags marking +/// beginnings of each segment in the input range. +/// \param [in] initial_value - initial value to start the scan. +/// \param [in] size - number of element in the input range. +/// \param [in] scan_op - binary operation function object that will be used for scan. +/// The signature of the function should be equivalent to the following: +/// T f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The default value is \p BinaryFunction(). +/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. The default value is \p false. 
+/// +/// \returns \p cudaSuccess (\p 0) after successful scan; otherwise a HIP runtime error of +/// type \p cudaError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level segmented exclusive sum operation is performed on +/// an array of integer values (shorts are added into ints). +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t size; // e.g., 8 +/// short * input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int * flags; // e.g., [1, 0, 0, 1, 0, 1, 0, 0] +/// int init; // e.g., 9 +/// int * output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::segmented_exclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, flags, init, size, ::rocprim::plus() +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform scan +/// rocprim::exclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, flags, init, size, ::rocprim::plus() +/// ); +/// // output: [9, 10, 12, 9, 13, 9, 15, 22] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator, + class OutputIterator, + class InitValueType, + class HeadFlagIterator, + class BinaryFunction = ::rocprim::plus::value_type> +> +inline +cudaError_t segmented_exclusive_scan(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + HeadFlagIterator head_flags, + const InitValueType initial_value, + size_t size, + BinaryFunction scan_op = BinaryFunction(), + cudaStream_t stream = 0, + bool debug_synchronous = false) +{ + using result_type = InitValueType; + using flag_type = typename std::iterator_traits::value_type; + using headflag_scan_op_wrapper_type = + 
detail::headflag_scan_op_wrapper< + result_type, flag_type, BinaryFunction + >; + + const result_type initial_value_converted = static_cast(initial_value); + + // Flag the last item of each segment as the next segment's head, use initial_value as its value, + // then run exclusive scan + return exclusive_scan( + temporary_storage, storage_size, + rocprim::make_transform_iterator( + rocprim::make_counting_iterator(0), + [input, head_flags, initial_value_converted, size] + ROCPRIM_DEVICE + (const size_t i) + { + flag_type flag(false); + if(i + 1 < size) + { + flag = head_flags[i + 1]; + } + result_type value = initial_value_converted; + if(!flag) + { + value = input[i]; + } + return rocprim::make_tuple(value, flag); + } + ), + rocprim::make_zip_iterator(rocprim::make_tuple(output, rocprim::make_discard_iterator())), + rocprim::make_tuple(initial_value_converted, flag_type(true)), // init value is a head of the first segment + size, + headflag_scan_op_wrapper_type(scan_op), + stream, + debug_synchronous + ); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_SEGMENTED_SCAN_HPP_ diff --git a/3rdparty/cub/rocprim/device/device_select.hpp b/3rdparty/cub/rocprim/device/device_select.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cf83f904fd44c167a7b22f79c79c76092af0cab3 --- /dev/null +++ b/3rdparty/cub/rocprim/device/device_select.hpp @@ -0,0 +1,490 @@ +// Copyright (c) 2018-2019 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_SELECT_HPP_ +#define ROCPRIM_DEVICE_DEVICE_SELECT_HPP_ + +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" +#include "../detail/binary_op_wrappers.hpp" + +#include "../iterator/transform_iterator.hpp" + +#include "device_scan.hpp" +#include "device_partition.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +namespace detail +{ + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + if(error != cudaSuccess) return error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto error = cudaStreamSynchronize(stream); \ + if(error != cudaSuccess) return error; \ + auto end = std::chrono::high_resolution_clock::now(); \ + auto d = std::chrono::duration_cast>(end - start); \ + std::cout << " " << d.count() * 1000 << " ms" << '\n'; \ + } \ + } + +} // end detail namespace + +/// \brief Parallel select primitive for device level using range of flags. +/// +/// Performs a device-wide selection based on input \p flags. If a value from \p input +/// should be selected and copied into \p output range the corresponding item from +/// \p flags range should be set to such value that can be implicitly converted to +/// \p true (\p bool type). +/// +/// \par Overview +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Ranges specified by \p input and \p flags must have at least \p size elements. +/// * Range specified by \p output must have at least so many elements, that all positively +/// flagged values can be copied into it. +/// * Range specified by \p selected_count_output must have at least 1 element. +/// * Values of \p flag range should be implicitly convertible to `bool` type. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p select_config or +/// a custom class with the same members. 
+/// \tparam InputIterator - random-access iterator type of the input range. It can be +/// a simple pointer type. +/// \tparam FlagIterator - random-access iterator type of the flag range. It can be +/// a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. It can be +/// a simple pointer type. +/// \tparam SelectedCountOutputIterator - random-access iterator type of the selected_count_output +/// value. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the select operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to select values from. +/// \param [in] flags - iterator to the selection flag corresponding to the first element from \p input range. +/// \param [out] output - iterator to the first element in the output range. +/// \param [out] selected_count_output - iterator to the total number of selected values (length of \p output). +/// \param [in] size - number of element in the input range. +/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. The default value is \p false. +/// +/// \par Example +/// \parblock +/// In this example a device-level select operation is performed on an array of +/// integer values with array of chars used as flags. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// int * input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// char * flags; // e.g., [0, 1, 1, 0, 0, 1, 0, 1] +/// int * output; // empty array of 8 elements +/// size_t * output_count; // empty array of 1 element +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::select( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, flags, +/// output, output_count, +/// input_size +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform selection +/// rocprim::select( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, flags, +/// output, output_count, +/// input_size +/// ); +/// // output: [2, 3, 6, 8] +/// // output_count: 4 +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator, + class FlagIterator, + class OutputIterator, + class SelectedCountOutputIterator +> +inline +cudaError_t select(void * temporary_storage, + size_t& storage_size, + InputIterator input, + FlagIterator flags, + OutputIterator output, + SelectedCountOutputIterator selected_count_output, + const size_t size, + const cudaStream_t stream = 0, + const bool debug_synchronous = false) +{ + // Dummy unary predicate + using unary_predicate_type = ::rocprim::empty_type; + // Dummy inequality operation + using inequality_op_type = ::rocprim::empty_type; + using offset_type = unsigned int; + rocprim::empty_type* const no_values = nullptr; // key only + + return detail::partition_impl( + temporary_storage, storage_size, input, no_values, flags, output, no_values, selected_count_output, + size, inequality_op_type(), stream, debug_synchronous, unary_predicate_type() + ); +} + +/// \brief Parallel select primitive for device level using selection operator. 
+/// +/// Performs a device-wide selection using selection operator. If a value \p x from \p input +/// should be selected and copied into \p output range, then predicate(x) has to +/// return \p true. +/// +/// \par Overview +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Range specified by \p input must have at least \p size elements. +/// * Range specified by \p output must have at least so many elements, that all selected +/// values can be copied into it. +/// * Range specified by \p selected_count_output must have at least 1 element. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p select_config or +/// a custom class with the same members. +/// \tparam InputIterator - random-access iterator type of the input range. It can be +/// a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. It can be +/// a simple pointer type. +/// \tparam SelectedCountOutputIterator - random-access iterator type of the selected_count_output +/// value. It can be a simple pointer type. +/// \tparam UnaryPredicate - type of a unary selection predicate. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the select operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to select values from. +/// \param [out] output - iterator to the first element in the output range. +/// \param [out] selected_count_output - iterator to the total number of selected values (length of \p output). +/// \param [in] size - number of element in the input range. 
+/// \param [in] predicate - unary function object that will be used for selecting values. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a);. The signature does not need to have +/// const &, but function object must not modify the object passed to it. +/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. The default value is \p false. +/// +/// \par Example +/// \parblock +/// In this example a device-level select operation is performed on an array of +/// integer values, only even values are selected. +/// +/// \code{.cpp} +/// #include +/// +/// auto predicate = +/// [] __device__ (int a) -> bool +/// { +/// return (a%2) == 0; +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// int * input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int * output; // empty array of 8 elements +/// size_t * output_count; // empty array of 1 element +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::select( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, output_count, +/// predicate, input_size +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform selection +/// rocprim::select( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, output_count, +/// predicate, input_size +/// ); +/// // output: [2, 4, 6, 8] +/// // output_count: 4 +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator, + class OutputIterator, + class SelectedCountOutputIterator, + class UnaryPredicate +> +inline +cudaError_t 
select(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + SelectedCountOutputIterator selected_count_output, + const size_t size, + UnaryPredicate predicate, + const cudaStream_t stream = 0, + const bool debug_synchronous = false) +{ + // Dummy flag type + using flag_type = ::rocprim::empty_type; + using offset_type = unsigned int; + flag_type * flags = nullptr; + // Dummy inequality operation + using inequality_op_type = ::rocprim::empty_type; + rocprim::empty_type* const no_values = nullptr; // key only + + return detail::partition_impl( + temporary_storage, storage_size, input, no_values, flags, output, no_values, selected_count_output, + size, inequality_op_type(), stream, debug_synchronous, predicate + ); +} + +/// \brief Device-level parallel unique primitive. +/// +/// From given \p input range unique primitive eliminates all but the first element from every +/// consecutive group of equivalent elements and copies them into \p output. +/// +/// \par Overview +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * Range specified by \p input must have at least \p size elements. +/// * Range specified by \p output must have at least so many elements, that all selected +/// values can be copied into it. +/// * Range specified by \p unique_count_output must have at least 1 element. +/// * By default InputIterator::value_type's equality operator is used to check +/// if elements are equivalent. +/// +/// \tparam InputIterator - random-access iterator type of the input range. It can be +/// a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. It can be +/// a simple pointer type. +/// \tparam UniqueCountOutputIterator - random-access iterator type of the unique_count_output +/// value used to return number of unique values. It can be a simple pointer type. 
+/// \tparam EqualityOp - type of an binary operator used to compare values for equality. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the unique operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to select values from. +/// \param [out] output - iterator to the first element in the output range. +/// \param [out] unique_count_output - iterator to the total number of selected values (length of \p output). +/// \param [in] size - number of element in the input range. +/// \param [in] equality_op - [optional] binary function object used to compare input values for equality. +/// The signature of the function should be equivalent to the following: +/// bool equal_to(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the object passed to it. +/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. The default value is \p false. +/// +/// \par Example +/// \parblock +/// In this example a device-level unique operation is performed on an array of integer values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t input_size; // e.g., 8 +/// int * input; // e.g., [1, 4, 2, 4, 4, 7, 7, 7] +/// int * output; // empty array of 8 elements +/// size_t * output_count; // empty array of 1 element +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::unique( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, output_count, +/// input_size +/// ); +/// +/// // allocate temporary storage +/// cudaMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform unique operation +/// rocprim::unique( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, output_count, +/// input_size +/// ); +/// // output: [1, 4, 2, 4, 7] +/// // output_count: 5 +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator, + class OutputIterator, + class UniqueCountOutputIterator, + class EqualityOp = ::rocprim::equal_to::value_type> +> +inline +cudaError_t unique(void * temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + UniqueCountOutputIterator unique_count_output, + const size_t size, + EqualityOp equality_op = EqualityOp(), + const cudaStream_t stream = 0, + const bool debug_synchronous = false) +{ + // Dummy unary predicate + using unary_predicate_type = ::rocprim::empty_type; + using offset_type = unsigned int; + // Dummy flag type + using flag_type = ::rocprim::empty_type; + const flag_type * flags = nullptr; + rocprim::empty_type* const no_values = nullptr; // key only + + // Convert equality operator to inequality operator + auto inequality_op = detail::inequality_wrapper(equality_op); + + return detail::partition_impl( + temporary_storage, storage_size, input, no_values, flags, output, no_values, unique_count_output, + size, inequality_op, stream, debug_synchronous, unary_predicate_type() + ); +} + +/// \brief Device-level parallel 
unique by key primitive.
+///
+/// From given \p input range unique primitive eliminates all but the first element from every
+/// consecutive group of equivalent elements and copies them and their corresponding keys into
+/// \p output.
+///
+/// \par Overview
+/// * Returns the required size of \p temporary_storage in \p storage_size
+/// if \p temporary_storage is a null pointer.
+/// * Ranges specified by \p keys_input and \p values_input must have at least \p size elements each.
+/// * Ranges specified by \p keys_output and values_output each must have at least so many elements,
+/// that all selected values can be copied into them.
+/// * Range specified by \p unique_count_output must have at least 1 element.
+/// * By default KeyIterator::value_type's equality operator is used to check
+/// if elements are equivalent.
+///
+/// \tparam KeyIterator - random-access iterator type of the input key range. It can be
+/// a simple pointer type.
+/// \tparam ValueIterator - random-access iterator type of the input value range. It can be
+/// a simple pointer type.
+/// \tparam OutputKeyIterator - random-access iterator type of the output key range. It can be
+/// a simple pointer type.
+/// \tparam OutputValueIterator - random-access iterator type of the output value range. It can be
+/// a simple pointer type.
+/// \tparam UniqueCountOutputIterator - random-access iterator type of the unique_count_output
+/// value used to return number of unique keys and values. It can be a simple pointer type.
+/// \tparam EqualityOp - type of a binary operator used to compare keys for equality.
+///
+/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When
+/// a null pointer is passed, the required allocation size (in bytes) is written to
+/// \p storage_size and function returns without performing the unique operation.
+/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage.
+/// \param [in] keys_input - iterator to the first element in the range to select keys from.
+/// \param [in] values_input - iterator to the first element in the range of values corresponding to keys
+/// \param [out] keys_output - iterator to the first element in the output key range.
+/// \param [out] values_output - iterator to the first element in the output value range.
+/// \param [out] unique_count_output - iterator to the total number of selected values (length of \p output).
+/// \param [in] size - number of elements in the input range.
+/// \param [in] equality_op - [optional] binary function object used to compare input values for equality.
+/// The signature of the function should be equivalent to the following:
+/// bool equal_to(const T &a, const T &b);. The signature does not need to have
+/// const &, but function object must not modify the object passed to it.
+/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream).
+/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel
+/// launch is forced in order to check for errors. The default value is \p false.
+template ::value_type>>
+inline cudaError_t unique_by_key(void* temporary_storage,
+                                 size_t& storage_size,
+                                 const KeyIterator keys_input,
+                                 const ValueIterator values_input,
+                                 const OutputKeyIterator keys_output,
+                                 const OutputValueIterator values_output,
+                                 const UniqueCountOutputIterator unique_count_output,
+                                 const size_t size,
+                                 const EqualityOp equality_op = EqualityOp(),
+                                 const cudaStream_t stream = 0,
+                                 const bool debug_synchronous = false)
+{
+    // Same strategy as unique(): select a (key, value) pair exactly when the
+    // key differs from its predecessor, implemented via the shared
+    // partition machinery with the equality operator wrapped as inequality.
+    using offset_type = unsigned int;
+    // Dummy flag
+    ::rocprim::empty_type* const no_flags = nullptr;
+    // Dummy predicate
+    const auto no_predicate = ::rocprim::empty_type{};
+
+    // Convert equality operator to inequality operator
+    const auto inequality_op = detail::inequality_wrapper(equality_op);
+
+    // partition_impl also implements the two-phase protocol: when
+    // temporary_storage is null it only writes the required storage_size.
+    return detail::partition_impl(
+        temporary_storage,
+        storage_size,
+        keys_input,
+        values_input,
+        no_flags,
+        keys_output,
+        values_output,
+        unique_count_output,
+        size,
+        inequality_op,
+        stream,
+        debug_synchronous,
+        no_predicate);
+}
+
+#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR
+
+/// @}
+// end of group devicemodule
+
+END_ROCPRIM_NAMESPACE
+
+#endif // ROCPRIM_DEVICE_DEVICE_SELECT_HPP_
diff --git a/3rdparty/cub/rocprim/device/device_select_config.hpp b/3rdparty/cub/rocprim/device/device_select_config.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..af7fbd988218b4b56893830a890586303d98d02b
--- /dev/null
+++ b/3rdparty/cub/rocprim/device/device_select_config.hpp
@@ -0,0 +1,161 @@
+// Copyright (c) 2018-2019 Advanced Micro Devices, Inc. All rights reserved.
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_SELECT_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_SELECT_CONFIG_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../block/block_load.hpp" +#include "../block/block_scan.hpp" + +#include "config_types.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of device-level select operation. +/// +/// \tparam BlockSize - number of threads in a block. +/// \tparam ItemsPerThread - number of items processed by each thread. +/// \tparam KeyBlockLoadMethod - method for loading input keys. +/// \tparam ValueBlockLoadMethod - method for loading input values. +/// \tparam FlagBlockLoadMethod - method for loading flag values. +/// \tparam BlockScanMethod - algorithm for block scan. +/// \tparam SizeLimit - limit on the number of items for a single select kernel launch. 
+template<
+    unsigned int BlockSize,
+    unsigned int ItemsPerThread,
+    ::rocprim::block_load_method KeyBlockLoadMethod,
+    ::rocprim::block_load_method ValueBlockLoadMethod,
+    ::rocprim::block_load_method FlagBlockLoadMethod,
+    ::rocprim::block_scan_algorithm BlockScanMethod,
+    unsigned int SizeLimit = ROCPRIM_GRID_SIZE_LIMIT
+>
+struct select_config
+{
+    /// \brief Number of threads in a block.
+    static constexpr unsigned int block_size = BlockSize;
+    /// \brief Number of items processed by each thread.
+    static constexpr unsigned int items_per_thread = ItemsPerThread;
+    /// \brief Method for loading input keys.
+    static constexpr block_load_method key_block_load_method = KeyBlockLoadMethod;
+    /// \brief Method for loading input values.
+    static constexpr block_load_method value_block_load_method = ValueBlockLoadMethod;
+    /// \brief Method for loading flag values.
+    static constexpr block_load_method flag_block_load_method = FlagBlockLoadMethod;
+    /// \brief Algorithm for block scan.
+    static constexpr block_scan_algorithm block_scan_method = BlockScanMethod;
+    /// \brief Limit on the number of items for a single select kernel launch.
+    static constexpr unsigned int size_limit = SizeLimit;
+};
+
+namespace detail
+{
+
+// Per-architecture default tunings. item_scale shrinks items_per_thread for
+// large element types so the per-thread register/LDS footprint stays bounded.
+// NOTE(review): the gfx803/gfx900 defaults scale by sizeof(Key) while the
+// gfx90a/gfx1030 defaults scale by sizeof(Value) — confirm this asymmetry is
+// intentional (matches upstream tuning) rather than a copy-paste slip.
+template
+struct select_config_803
+{
+    static constexpr unsigned int item_scale =
+        ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int));
+
+    using type = select_config<
+        limit_block_size<256U, sizeof(Key), ROCPRIM_WARP_SIZE_64>::value,
+        ::rocprim::max(1u, 13u / item_scale),
+        ::rocprim::block_load_method::block_load_transpose,
+        ::rocprim::block_load_method::block_load_transpose,
+        ::rocprim::block_load_method::block_load_transpose,
+        ::rocprim::block_scan_algorithm::using_warp_scan
+    >;
+};
+
+template
+struct select_config_900
+{
+    static constexpr unsigned int item_scale =
+        ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int));
+
+    using type = select_config<
+        limit_block_size<256U, sizeof(Key), ROCPRIM_WARP_SIZE_64>::value,
+        ::rocprim::max(1u, 15u / item_scale),
+        ::rocprim::block_load_method::block_load_transpose,
+        ::rocprim::block_load_method::block_load_transpose,
+        ::rocprim::block_load_method::block_load_transpose,
+        ::rocprim::block_scan_algorithm::using_warp_scan
+    >;
+};
+
+template
+struct select_config_90a
+{
+    static constexpr unsigned int item_scale =
+        ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int));
+
+    using type = select_config<
+        limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value,
+        ::rocprim::max(1u, 15u / item_scale),
+        ::rocprim::block_load_method::block_load_transpose,
+        ::rocprim::block_load_method::block_load_transpose,
+        ::rocprim::block_load_method::block_load_transpose,
+        ::rocprim::block_scan_algorithm::using_warp_scan
+    >;
+};
+
+// gfx1030 (RDNA2) runs with wave32, hence ROCPRIM_WARP_SIZE_32 below.
+template
+struct select_config_1030
+{
+    static constexpr unsigned int item_scale =
+        ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int));
+
+    using type = select_config<
+        limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_32>::value,
+        ::rocprim::max(1u, 15u / item_scale),
+        ::rocprim::block_load_method::block_load_transpose,
+        ::rocprim::block_load_method::block_load_transpose,
+        ::rocprim::block_load_method::block_load_transpose,
+        ::rocprim::block_scan_algorithm::using_warp_scan
+    >;
+};
+
+
+// Maps a target architecture id to the matching tuning above; the last
+// argument (select_config_803) is the fallback for unlisted architectures.
+template
+struct default_select_config
+    : select_arch<
+        TargetArch,
+        select_arch_case<803, select_config_803>,
+        select_arch_case<900, select_config_900>,
+        select_arch_case>,
+        select_arch_case<1030, select_config_1030>,
+        select_config_803
+    > { };
+
+} // end namespace detail
+
+END_ROCPRIM_NAMESPACE
+
+/// @}
+// end of group primitivesmodule_deviceconfigs
+
+#endif // ROCPRIM_DEVICE_DEVICE_SELECT_CONFIG_HPP_
diff --git a/3rdparty/cub/rocprim/device/device_transform.hpp b/3rdparty/cub/rocprim/device/device_transform.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..153dd87799c6229104daf69cf4c3a9f06feaf9ee
--- /dev/null
+++ b/3rdparty/cub/rocprim/device/device_transform.hpp
@@ -0,0 +1,295 @@
+// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+#ifndef ROCPRIM_DEVICE_DEVICE_TRANSFORM_HPP_
+#define ROCPRIM_DEVICE_DEVICE_TRANSFORM_HPP_
+
+#include
+#include
+#include
+
+#include "../config.hpp"
+#include "../detail/various.hpp"
+#include "../detail/match_result_type.hpp"
+#include "../types/tuple.hpp"
+#include "../iterator/zip_iterator.hpp"
+
+#include "device_transform_config.hpp"
+#include "detail/device_transform.hpp"
+#include
+
+BEGIN_ROCPRIM_NAMESPACE
+
+/// \addtogroup devicemodule
+/// @{
+
+namespace detail
+{
+
+// Thin kernel wrapper: the per-thread work is implemented in
+// transform_kernel_impl (detail/device_transform.hpp); this entry point only
+// fixes the launch bounds for the configured block size.
+template<
+    unsigned int BlockSize,
+    unsigned int ItemsPerThread,
+    class ResultType,
+    class InputIterator,
+    class OutputIterator,
+    class UnaryFunction
+>
+ROCPRIM_KERNEL
+__launch_bounds__(BlockSize)
+void transform_kernel(InputIterator input,
+                      const size_t size,
+                      OutputIterator output,
+                      UnaryFunction transform_op)
+{
+    transform_kernel_impl(
+        input, size, output, transform_op
+    );
+}
+
+// Checks the last launch for errors and returns early on failure. When
+// debug_synchronous is set it additionally synchronizes the stream and
+// prints the kernel name, problem size, and elapsed time. Expects
+// `stream` and `debug_synchronous` to be in scope at the expansion site;
+// #undef'd at the end of this header.
+#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \
+    { \
+        auto _error = cudaGetLastError(); \
+        if(_error != cudaSuccess) return _error; \
+        if(debug_synchronous) \
+        { \
+            std::cout << name << "(" << size << ")"; \
+            _error = cudaStreamSynchronize(stream); \
+            if(_error != cudaSuccess) return _error; \
+            auto _end = std::chrono::high_resolution_clock::now(); \
+            auto _d = std::chrono::duration_cast>(_end - start); \
+            std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \
+        } \
+    }
+
+} // end of detail namespace
+
+/// \brief Parallel transform primitive for device level.
+///
+/// transform function performs a device-wide transformation operation
+/// using unary \p transform_op operator.
+///
+/// \par Overview
+/// * Ranges specified by \p input and \p output must have at least \p size elements.
+///
+/// \tparam Config - [optional] configuration of the primitive. It can be \p transform_config or
+/// a custom class with the same members.
+/// \tparam InputIterator - random-access iterator type of the input range. Must meet the
+/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
+/// \tparam OutputIterator - random-access iterator type of the output range. Must meet the
+/// requirements of a C++ OutputIterator concept. It can be a simple pointer type.
+/// \tparam UnaryFunction - type of unary function used for transform.
+///
+/// \param [in] input - iterator to the first element in the range to transform.
+/// \param [out] output - iterator to the first element in the output range.
+/// \param [in] size - number of elements in the input range.
+/// \param [in] transform_op - unary operation function object that will be used for transform.
+/// The signature of the function should be equivalent to the following:
+/// U f(const T &a);. The signature does not need to have
+/// const &, but function object must not modify the object passed to it.
+/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream).
+/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel
+/// launch is forced in order to check for errors. The default value is \p false.
+///
+/// \par Example
+/// \parblock
+/// In this example a device-level transform operation is performed on an array of
+/// integer values (shorts are transformed into ints).
+///
+/// \code{.cpp}
+/// #include 
+///
+/// // custom transform function
+/// auto transform_op =
+/// [] __device__ (int a) -> int
+/// {
+/// return a + 5;
+/// };
+///
+/// // Prepare input and output (declare pointers, allocate device memory etc.)
+/// size_t input_size; // e.g., 8
+/// short * input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8]
+/// int * output; // empty array of 8 elements
+///
+/// // perform transform
+/// rocprim::transform(
+///     input, output, input_size, transform_op
+/// );
+/// // output: [6, 7, 8, 9, 10, 11, 12, 13]
+/// \endcode
+/// \endparblock
+template<
+    class Config = default_config,
+    class InputIterator,
+    class OutputIterator,
+    class UnaryFunction
+>
+inline
+cudaError_t transform(InputIterator input,
+                      OutputIterator output,
+                      const size_t size,
+                      UnaryFunction transform_op,
+                      const cudaStream_t stream = 0,
+                      bool debug_synchronous = false)
+{
+    // Nothing to do for an empty range; avoids a zero-block launch.
+    if( size == size_t(0) )
+        return cudaSuccess;
+
+    using input_type = typename std::iterator_traits::value_type;
+    using result_type = typename ::rocprim::detail::invoke_result::type;
+
+    // Get default config if Config is default_config
+    using config = detail::default_or_custom_config<
+        Config,
+        detail::default_transform_config
+    >;
+
+    static constexpr unsigned int block_size = config::block_size;
+    static constexpr unsigned int items_per_thread = config::items_per_thread;
+    static constexpr auto items_per_block = block_size * items_per_thread;
+
+    // Start point for time measurements
+    std::chrono::high_resolution_clock::time_point start;
+
+    static constexpr auto size_limit = config::size_limit;
+    static constexpr auto number_of_blocks_limit
+        = ::rocprim::max(size_limit / items_per_block, 1);
+
+    auto number_of_blocks = (size + items_per_block - 1)/items_per_block;
+    if(debug_synchronous)
+    {
+        std::cout << "block_size " << block_size << '\n';
+        std::cout << "number of blocks " << number_of_blocks << '\n';
+        std::cout << "number of blocks limit " << number_of_blocks_limit << '\n';
+        std::cout << "items_per_block " << items_per_block << '\n';
+    }
+
+    // Largest item count a single launch may cover (grid-size-limit aligned
+    // down to a whole number of blocks).
+    static constexpr auto aligned_size_limit = number_of_blocks_limit * items_per_block;
+
+    // Launch number_of_blocks_limit blocks while there is still at least as many blocks left as the limit
+    const auto number_of_launch = (size + aligned_size_limit - 1) / aligned_size_limit;
+    for(size_t i = 0, offset = 0; i < number_of_launch; ++i, offset += aligned_size_limit) {
+        // The final launch handles the (possibly smaller) tail.
+        const auto current_size = std::min(size - offset, aligned_size_limit);
+        const auto current_blocks = (current_size + items_per_block - 1) / items_per_block;
+
+        if(debug_synchronous) start = std::chrono::high_resolution_clock::now();
+        detail::transform_kernel<
+            block_size, items_per_thread, result_type,
+            InputIterator, OutputIterator, UnaryFunction
+        >
+        <<>>(
+            input + offset, current_size, output + offset, transform_op
+        );
+        ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("transform_kernel", current_size, start);
+    }
+
+    return cudaSuccess;
+}
+
+/// \brief Parallel device-level transform primitive for two inputs.
+///
+/// transform function performs a device-wide transformation operation
+/// on two input ranges using binary \p transform_op operator.
+///
+/// \par Overview
+/// * Ranges specified by \p input1, \p input2, and \p output must have at least \p size elements.
+///
+/// \tparam Config - [optional] configuration of the primitive. It can be \p transform_config or
+/// a custom class with the same members.
+/// \tparam InputIterator1 - random-access iterator type of the input range. Must meet the
+/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
+/// \tparam InputIterator2 - random-access iterator type of the input range. Must meet the
+/// requirements of a C++ InputIterator concept. It can be a simple pointer type.
+/// \tparam OutputIterator - random-access iterator type of the output range. Must meet the
+/// requirements of a C++ OutputIterator concept. It can be a simple pointer type.
+/// \tparam BinaryFunction - type of binary function used for transform.
+///
+/// \param [in] input1 - iterator to the first element in the 1st range to transform.
+/// \param [in] input2 - iterator to the first element in the 2nd range to transform.
+/// \param [out] output - iterator to the first element in the output range.
+/// \param [in] size - number of elements in the input range.
+/// \param [in] transform_op - binary operation function object that will be used for transform.
+/// The signature of the function should be equivalent to the following:
+/// U f(const T1& a, const T2& b);. The signature does not need to have
+/// const &, but function object must not modify the object passed to it.
+/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream).
+/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel
+/// launch is forced. Default value is \p false.
+///
+/// \par Example
+/// \parblock
+/// In this example a device-level transform operation is performed on two arrays of
+/// integer values (element-wise sum is performed).
+///
+/// \code{.cpp}
+/// #include 
+///
+/// // custom transform function
+/// auto transform_op =
+/// [] __device__ (int a, int b) -> int
+/// {
+/// return a + b;
+/// };
+///
+/// // Prepare input and output (declare pointers, allocate device memory etc.)
+/// size_t size; // e.g., 8
+/// int* input1; // e.g., [1, 2, 3, 4, 5, 6, 7, 8]
+/// int* input2; // e.g., [1, 2, 3, 4, 5, 6, 7, 8]
+/// int* output; // empty array of 8 elements
+///
+/// // perform transform
+/// rocprim::transform(
+///     input1, input2, output, size, transform_op
+/// );
+/// // output: [2, 4, 6, 8, 10, 12, 14, 16]
+/// \endcode
+/// \endparblock
+template<
+    class Config = default_config,
+    class InputIterator1,
+    class InputIterator2,
+    class OutputIterator,
+    class BinaryFunction
+>
+inline
+cudaError_t transform(InputIterator1 input1,
+                      InputIterator2 input2,
+                      OutputIterator output,
+                      const size_t size,
+                      BinaryFunction transform_op,
+                      const cudaStream_t stream = 0,
+                      bool debug_synchronous = false)
+{
+    // Reduce the binary case to the unary overload above: zip the two input
+    // ranges into one iterator of tuples and unpack each tuple into
+    // transform_op's two arguments.
+    using value_type1 = typename std::iterator_traits::value_type;
+    using value_type2 = typename std::iterator_traits::value_type;
+    return transform(
+        ::rocprim::make_zip_iterator(::rocprim::make_tuple(input1, input2)), output,
+        size, detail::unpack_binary_op(transform_op),
+        stream, debug_synchronous
+    );
+}
+
+#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR
+
+/// @}
+// end of group devicemodule
+
+END_ROCPRIM_NAMESPACE
+
+#endif // ROCPRIM_DEVICE_DEVICE_TRANSFORM_HPP_
diff --git a/3rdparty/cub/rocprim/device/device_transform_config.hpp b/3rdparty/cub/rocprim/device/device_transform_config.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..7e22ba74596d149377fc9e76d571e848bef7d0c0
--- /dev/null
+++ b/3rdparty/cub/rocprim/device/device_transform_config.hpp
@@ -0,0 +1,100 @@
+// Copyright (c) 2018-2021 Advanced Micro Devices, Inc. All rights reserved.
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_TRANSFORM_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_TRANSFORM_CONFIG_HPP_ + +#include + +#include "../config.hpp" +#include "../functional.hpp" +#include "../detail/various.hpp" + +#include "config_types.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Configuration of device-level transform primitives. 
+template
+using transform_config = kernel_config;
+
+namespace detail
+{
+
+// Per-architecture default tunings for transform. item_scale shrinks
+// items_per_thread for large value types to bound per-thread resource use.
+// All four architectures currently share the same 256-thread / 16-item
+// baseline; they are kept separate so each can be retuned independently.
+template
+struct transform_config_803
+{
+    static constexpr unsigned int item_scale =
+        ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int));
+
+    using type = transform_config<256, ::rocprim::max(1u, 16u / item_scale)>;
+};
+
+template
+struct transform_config_900
+{
+    static constexpr unsigned int item_scale =
+        ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int));
+
+    using type = transform_config<256, ::rocprim::max(1u, 16u / item_scale)>;
+};
+
+template
+struct transform_config_90a
+{
+    static constexpr unsigned int item_scale =
+        ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int));
+
+    using type = transform_config<256, ::rocprim::max(1u, 16u / item_scale)>;
+};
+
+template
+struct transform_config_1030
+{
+    static constexpr unsigned int item_scale =
+        ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int));
+
+    using type = transform_config<256, ::rocprim::max(1u, 16u / item_scale)>;
+};
+
+// Maps a target architecture id to the matching tuning above; the final
+// argument (transform_config_900) is the fallback for unlisted architectures.
+// NOTE(review): select's counterpart falls back to the 803 config — confirm
+// the differing fallbacks are intentional.
+template
+struct default_transform_config
+    : select_arch<
+        TargetArch,
+        select_arch_case<803, transform_config_803>,
+        select_arch_case<900, transform_config_900>,
+        select_arch_case>,
+        select_arch_case<1030, transform_config_1030>,
+        transform_config_900
+    > { };
+
+} // end namespace detail
+
+END_ROCPRIM_NAMESPACE
+
+/// @}
+// end of group primitivesmodule_deviceconfigs
+
+#endif // ROCPRIM_DEVICE_DEVICE_TRANSFORM_CONFIG_HPP_
diff --git a/3rdparty/cub/rocprim/device/specialization/device_radix_merge_sort.hpp b/3rdparty/cub/rocprim/device/specialization/device_radix_merge_sort.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..521ebab18ce6d5fec8fd403936d9523a3697d427
--- /dev/null
+++ b/3rdparty/cub/rocprim/device/specialization/device_radix_merge_sort.hpp
@@ -0,0 +1,184 @@
+// Copyright (c) 2021 Advanced Micro Devices, Inc. All rights reserved.
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+
+#ifndef ROCPRIM_DEVICE_SPECIALIZATION_DEVICE_RADIX_MERGE_SORT_HPP_
+#define ROCPRIM_DEVICE_SPECIALIZATION_DEVICE_RADIX_MERGE_SORT_HPP_
+
+#include "../detail/device_radix_sort.hpp"
+#include "../specialization/device_radix_single_sort.hpp"
+
+BEGIN_ROCPRIM_NAMESPACE
+
+namespace detail
+{
+    // Thin kernel wrapper around radix_block_merge_impl: merges pairs of
+    // adjacent sorted runs of merge_items_per_block_size items each.
+    template<
+        unsigned int BlockSize,
+        unsigned int ItemsPerThread,
+        class KeysInputIterator,
+        class KeysOutputIterator,
+        class ValuesInputIterator,
+        class ValuesOutputIterator,
+        class BinaryFunction
+    >
+    ROCPRIM_KERNEL
+    __launch_bounds__(BlockSize)
+    void radix_block_merge_kernel(KeysInputIterator keys_input,
+                                  KeysOutputIterator keys_output,
+                                  ValuesInputIterator values_input,
+                                  ValuesOutputIterator values_output,
+                                  const size_t input_size,
+                                  const unsigned int merge_items_per_block_size,
+                                  BinaryFunction compare_function)
+    {
+        radix_block_merge_impl(
+            keys_input, keys_output,
+            values_input, values_output,
+            input_size, merge_items_per_block_size,
+            compare_function
+        );
+    }
+
+    // Radix sort specialization for mid-size inputs: block-sorts
+    // items_per_block-sized runs, then merges runs pairwise, ping-ponging
+    // between the user output arrays and the temporary buffers, and finally
+    // copies the result back to the output arrays if the last pass ended in
+    // the buffers.
+    template<
+        class Config,
+        bool Descending,
+        class KeysInputIterator,
+        class KeysOutputIterator,
+        class ValuesInputIterator,
+        class ValuesOutputIterator
+    >
+    inline
+    cudaError_t radix_sort_merge(KeysInputIterator keys_input,
+                                 typename std::iterator_traits::value_type * keys_buffer,
+                                 KeysOutputIterator keys_output,
+                                 ValuesInputIterator values_input,
+                                 typename std::iterator_traits::value_type * values_buffer,
+                                 ValuesOutputIterator values_output,
+                                 unsigned int size,
+                                 unsigned int bit,
+                                 unsigned int end_bit,
+                                 cudaStream_t stream,
+                                 bool debug_synchronous)
+    {
+        using key_type = typename std::iterator_traits::value_type;
+        using value_type = typename std::iterator_traits::value_type;
+
+        // Key-only sort when the value type is empty_type.
+        constexpr bool with_values = !std::is_same::value;
+
+        constexpr unsigned int items_per_thread = Config::sort_merge::items_per_thread;
+        constexpr unsigned int block_size = Config::sort_merge::block_size;
+        constexpr unsigned int items_per_block = block_size * items_per_thread;
+
+        const unsigned int current_radix_bits = end_bit - bit;
+        auto number_of_blocks = (size + items_per_block - 1) / items_per_block;
+
+        std::chrono::high_resolution_clock::time_point start;
+        if(debug_synchronous)
+        {
+            std::cout << "block size " << block_size << '\n';
+            std::cout << "items per thread " << items_per_thread << '\n';
+            std::cout << "number of blocks " << number_of_blocks << '\n';
+            std::cout << "bit " << bit << '\n';
+            std::cout << "current_radix_bits " << current_radix_bits << '\n';
+        }
+
+        if(debug_synchronous) start = std::chrono::high_resolution_clock::now();
+
+        // Phase 1: sort each items_per_block-sized run independently, writing
+        // the sorted runs into the temporary buffers.
+        sort_single_kernel<
+            block_size, items_per_thread , Descending
+        >
+        <<>>(
+            keys_input, keys_buffer, values_input, values_buffer,
+            size, bit, current_radix_bits
+        );
+        ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("radix_sort_single", size, start)
+
+        // Phase 2: pairwise merge passes, doubling the run length each
+        // iteration. temporary_store flips every pass; it starts at true and
+        // is flipped before use, so the first pass reads the buffers (written
+        // above) and writes the outputs. When the full key range is being
+        // sorted the plain comparator is used; otherwise the comparator
+        // restricted to [bit, bit + current_radix_bits).
+        bool temporary_store = true;
+        for(unsigned int block = items_per_block; block < size; block *= 2)
+        {
+            temporary_store = !temporary_store;
+            if(temporary_store)
+            {
+                if(debug_synchronous) start = std::chrono::high_resolution_clock::now();
+                if( current_radix_bits == sizeof(key_type) * 8 )
+                {
+                    radix_block_merge_kernel
+                    <<>>(
+                        keys_output, keys_buffer, values_output, values_buffer,
+                        size, block, radix_merge_compare()
+                    );
+                }
+                else
+                {
+                    radix_block_merge_kernel
+                    <<>>(
+                        keys_output, keys_buffer, values_output, values_buffer,
+                        size, block, radix_merge_compare(bit, current_radix_bits)
+                    );
+                }
+                ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("radix_block_merge_kernel", size, start);
+            }
+            else
+            {
+                if(debug_synchronous) start = std::chrono::high_resolution_clock::now();
+                if( current_radix_bits == sizeof(key_type) * 8 )
+                {
+                    radix_block_merge_kernel
+                    <<>>(
+                        keys_buffer, keys_output, values_buffer, values_output,
+                        size, block, radix_merge_compare()
+                    );
+                }
+                else
+                {
+                    radix_block_merge_kernel
+                    <<>>(
+                        keys_buffer, keys_output, values_buffer, values_output,
+                        size, block, radix_merge_compare(bit, current_radix_bits)
+                    );
+                }
+                ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("radix_block_merge_kernel", size, start);
+            }
+        }
+
+        // If the last pass wrote into the buffers, copy the final result to
+        // the user-provided output arrays (identity transform == copy).
+        if(temporary_store)
+        {
+            cudaError_t error = ::rocprim::transform(
+                keys_buffer, keys_output, size,
+                ::rocprim::identity(), stream, debug_synchronous
+            );
+            if(error != cudaSuccess) return error;
+
+            if(with_values)
+            {
+                // NOTE(review): this inner `error` intentionally shadows the
+                // outer one; harmless, but renaming one would aid readability.
+                cudaError_t error = ::rocprim::transform(
+                    values_buffer, values_output, size,
+                    ::rocprim::identity(), stream, debug_synchronous
+                );
+                if(error != cudaSuccess) return error;
+            }
+        }
+
+        return cudaSuccess;
+    }
+} // end namespace detail
+
+END_ROCPRIM_NAMESPACE
+
+#endif // ROCPRIM_DEVICE_SPECIALIZATION_DEVICE_RADIX_MERGE_SORT_HPP_
diff --git a/3rdparty/cub/rocprim/device/specialization/device_radix_single_sort.hpp b/3rdparty/cub/rocprim/device/specialization/device_radix_single_sort.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..5588997289e63360905acdc8f7d87686e0e6585a
--- /dev/null
+++ b/3rdparty/cub/rocprim/device/specialization/device_radix_single_sort.hpp
@@ -0,0 +1,1010 @@
+// Copyright (c) 2021 Advanced Micro Devices, Inc. All rights reserved.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_SPECIALIZATION_DEVICE_RADIX_SINGLE_SORT_HPP_ +#define ROCPRIM_DEVICE_SPECIALIZATION_DEVICE_RADIX_SINGLE_SORT_HPP_ + +#include "../detail/device_radix_sort.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + { \ + auto _error = cudaGetLastError(); \ + if(_error != cudaSuccess) return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = cudaStreamSynchronize(stream); \ + if(__error != cudaSuccess) return __error; \ + auto _end = std::chrono::high_resolution_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } + + template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + ROCPRIM_KERNEL + __launch_bounds__(BlockSize) + void sort_single_kernel(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int current_radix_bits) + { + sort_single( + keys_input, keys_output, + values_input, values_output, + size, bit, current_radix_bits + ); + } + + template< + unsigned int BlockSize, + unsigned int ItemsPerThread, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + 
ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + const unsigned int current_radix_bits = end_bit - bit; + + std::chrono::high_resolution_clock::time_point start; + if(debug_synchronous) + { + std::cout << "BlockSize " << BlockSize << '\n'; + std::cout << "ItemsPerThread " << ItemsPerThread << '\n'; + std::cout << "bit " << bit << '\n'; + std::cout << "current_radix_bits " << current_radix_bits << '\n'; + } + + if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + + sort_single_kernel< + BlockSize, ItemsPerThread, Descending + > + <<>>( + keys_input, keys_output, values_input, values_output, + size, bit, current_radix_bits + ); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("radix_sort_single", size, start) + + return cudaSuccess; + } + + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t radix_sort_single_limit64(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + return radix_sort_single<64U, 1U, Descending>( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t radix_sort_single_limit128(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + if( 
!Config::force_single_kernel_config && size <= 64U ) + return radix_sort_single_limit64( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single<64U, 2U, Descending>( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t radix_sort_single_limit192(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + if( !Config::force_single_kernel_config && size <= 128U ) + return radix_sort_single_limit128( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single<64U, 3U, Descending>( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t radix_sort_single_limit256(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + if( !Config::force_single_kernel_config && size <= 192U ) + return radix_sort_single_limit192( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single<64U, 4U, Descending>( + keys_input, keys_output, values_input, 
values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t radix_sort_single_limit320(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + if( !Config::force_single_kernel_config && size <= 256U ) + return radix_sort_single_limit256( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single<64U, 5U, Descending>( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t radix_sort_single_limit512(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + if( !Config::force_single_kernel_config && size <= 320U ) + return radix_sort_single_limit320( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single<256U, 2U, Descending>( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t 
radix_sort_single_limit768(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + if( !Config::force_single_kernel_config && size <= 512U ) + return radix_sort_single_limit512( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single<256U, 3U, Descending>( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t radix_sort_single_limit1024(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + if( !Config::force_single_kernel_config && size <= 768U ) + return radix_sort_single_limit768( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single<256U, 4U, Descending>( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t radix_sort_single_limit1536(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + if( 
!Config::force_single_kernel_config && size <= 1024U ) + return radix_sort_single_limit1024( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single<256U, 6U, Descending>( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t radix_sort_single_limit2048(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + if( !Config::force_single_kernel_config && size <= 1536U ) + return radix_sort_single_limit1536( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single<256U, 8U, Descending>( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t radix_sort_single_limit2560(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + if( !Config::force_single_kernel_config && size <= 2048U ) + return radix_sort_single_limit2048( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single<256U, 10U, Descending>( + keys_input, keys_output, 
values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t radix_sort_single_limit3072(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + if( !Config::force_single_kernel_config && size <= 2560U ) + return radix_sort_single_limit2560( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single<256U, 12U, Descending>( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t radix_sort_single_limit3584(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + if( !Config::force_single_kernel_config && size <= 3072U ) + return radix_sort_single_limit3072( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single<256U, 14U, Descending>( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + cudaError_t 
radix_sort_single_limit4096(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + { + if( !Config::force_single_kernel_config && size <= 3584U ) + return radix_sort_single_limit3584( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single<256U, 16U, Descending>( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + Config::sort_single::items_per_thread * Config::sort_single::block_size <= 64U, + cudaError_t + >::type + { + return radix_sort_single_limit64( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 64U) && + 
Config::sort_single::items_per_thread * Config::sort_single::block_size <= 128U, + cudaError_t + >::type + { + return radix_sort_single_limit128( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 128U) && + Config::sort_single::items_per_thread * Config::sort_single::block_size <= 192U, + cudaError_t + >::type + { + return radix_sort_single_limit192( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 192U) && + Config::sort_single::items_per_thread * Config::sort_single::block_size <= 256U, + cudaError_t + >::type + { + return radix_sort_single_limit256( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + 
class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 256U) && + Config::sort_single::items_per_thread * Config::sort_single::block_size <= 320U, + cudaError_t + >::type + { + return radix_sort_single_limit320( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 320U) && + Config::sort_single::items_per_thread * Config::sort_single::block_size <= 512U, + cudaError_t + >::type + { + return radix_sort_single_limit512( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t 
stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 512U) && + Config::sort_single::items_per_thread * Config::sort_single::block_size <= 768U, + cudaError_t + >::type + { + return radix_sort_single_limit768( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 768U) && + Config::sort_single::items_per_thread * Config::sort_single::block_size <= 1024U, + cudaError_t + >::type + { + return radix_sort_single_limit1024( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 1024U) && + Config::sort_single::items_per_thread * Config::sort_single::block_size <= 1536U, + cudaError_t + >::type + { + return radix_sort_single_limit1536( + keys_input, keys_output, 
values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 1536U) && + Config::sort_single::items_per_thread * Config::sort_single::block_size <= 2048U, + cudaError_t + >::type + { + return radix_sort_single_limit2048( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 2048U) && + Config::sort_single::items_per_thread * Config::sort_single::block_size <= 2560U, + cudaError_t + >::type + { + return radix_sort_single_limit2560( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + 
KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 2560) && + Config::sort_single::items_per_thread * Config::sort_single::block_size <= 3072, + cudaError_t + >::type + { + return radix_sort_single_limit3072( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 3072) && + Config::sort_single::items_per_thread * Config::sort_single::block_size <= 3584, + cudaError_t + >::type + { + return radix_sort_single_limit3584( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 3584) && + 
Config::sort_single::items_per_thread * Config::sort_single::block_size <= 4096, + cudaError_t + >::type + { + return radix_sort_single_limit4096( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + + template< + class Config, + bool Descending, + class KeysInputIterator, + class KeysOutputIterator, + class ValuesInputIterator, + class ValuesOutputIterator + > + inline + auto radix_sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + unsigned int bit, + unsigned int end_bit, + cudaStream_t stream, + bool debug_synchronous) + -> typename std::enable_if< + (Config::sort_single::items_per_thread * Config::sort_single::block_size > 4096), + cudaError_t + >::type + { + if( size < 4096 ) + return radix_sort_single_limit4096( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + else + return radix_sort_single< + Config::sort_single::block_size, + Config::sort_single::items_per_thread, + Descending + >( + keys_input, keys_output, values_input, values_output, + size, bit, end_bit, stream, debug_synchronous + ); + } + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_SPECIALIZATION_DEVICE_RADIX_SINGLE_SORT_HPP_ diff --git a/3rdparty/cub/rocprim/functional.hpp b/3rdparty/cub/rocprim/functional.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3f36b8c75e76674fd312d1cb11894ca85eb0bf5d --- /dev/null +++ b/3rdparty/cub/rocprim/functional.hpp @@ -0,0 +1,384 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_FUNCTIONAL_HPP_ +#define ROCPRIM_FUNCTIONAL_HPP_ + +#include + +// Meta configuration for rocPRIM +#include "config.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup utilsmodule_functional +/// @{ + +#define ROCPRIM_PRINT_ERROR_ONCE(message) \ +{ \ + unsigned int idx = threadIdx.x + (blockIdx.x * blockDim.x); \ + idx += threadIdx.y + (blockIdx.y * blockDim.y); \ + idx += threadIdx.z + (blockIdx.z * blockDim.z); \ + if (idx == 0) \ + printf("%s\n", #message); \ +} + +template +ROCPRIM_HOST_DEVICE inline +constexpr T max(const T& a, const T& b) +{ + return a < b ? b : a; +} + +template +ROCPRIM_HOST_DEVICE inline +constexpr T min(const T& a, const T& b) +{ + return a < b ? 
a : b; +} + +template +ROCPRIM_HOST_DEVICE inline +void swap(T& a, T& b) +{ + T c = a; + a = b; + b = c; +} + +template +struct less +{ + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a < b; + } +}; + +template<> +struct less +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const U& b) const + { + return a < b; + } +}; + +template +struct less_equal +{ + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a <= b; + } +}; + +template<> +struct less_equal +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a <= b; + } +}; + +template +struct greater +{ + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a > b; + } +}; + +template<> +struct greater +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a > b; + } +}; + +template +struct greater_equal +{ + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a >= b; + } +}; + +template<> +struct greater_equal +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a >= b; + } +}; + +template +struct equal_to +{ + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a == b; + } +}; + +template<> +struct equal_to +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a == b; + } +}; + +template +struct not_equal_to +{ + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a != b; + } +}; + +template<> +struct not_equal_to +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a != b; + } +}; + +template +struct plus +{ + 
ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T& a, const T& b) const + { + return a + b; + } +}; + +template<> +struct plus +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T& a, const T& b) const + { + return a + b; + } +}; + +template +struct minus +{ + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T& a, const T& b) const + { + return a - b; + } +}; + +template<> +struct minus +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T& a, const T& b) const + { + return a - b; + } +}; + +template +struct multiplies +{ + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T& a, const T& b) const + { + return a * b; + } +}; + +template<> +struct multiplies +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T& a, const T& b) const + { + return a * b; + } +}; + +template +struct maximum +{ + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T& a, const T& b) const + { + return a < b ? b : a; + } +}; + +template<> +struct maximum +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T& a, const T& b) const + { + return a < b ? b : a; + } +}; + +template +struct minimum +{ + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T& a, const T& b) const + { + return a < b ? a : b; + } +}; + +template<> +struct minimum +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T& a, const T& b) const + { + return a < b ? a : b; + } +}; + +template +struct identity +{ + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T& a) const + { + return a; + } +}; + +template<> +struct identity +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T& a) const + { + return a; + } +}; + +/** + * \brief Statically determine log2(N), rounded up. 
+ * + * For example: + * Log2<8>::VALUE // 3 + * Log2<3>::VALUE // 2 + */ +template +struct Log2 +{ + /// Static logarithm value + enum { VALUE = Log2> 1), COUNT + 1>::VALUE }; // Inductive case +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template +struct Log2 +{ + enum {VALUE = (1 << (COUNT - 1) < N) ? // Base case + COUNT : + COUNT - 1 }; +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + +/****************************************************************************** + * Conditional types + ******************************************************************************/ + +/** + * \brief Type equality test + */ +template +struct Equals +{ + enum { + VALUE = 0, + NEGATE = 1 + }; +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template +struct Equals +{ + enum { + VALUE = 1, + NEGATE = 0 + }; +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + +template +struct Int2Type +{ + enum {VALUE = A}; +}; + +/// @} +// end of group utilsmodule_functional + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_FUNCTIONAL_HPP_ diff --git a/3rdparty/cub/rocprim/intrinsics.hpp b/3rdparty/cub/rocprim/intrinsics.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d5142d52b53718f07769417185af4f6b8a08b5bf --- /dev/null +++ b/3rdparty/cub/rocprim/intrinsics.hpp @@ -0,0 +1,33 @@ +// Copyright (c) 2017-2020 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+#ifndef ROCPRIM_INTRINSICS_HPP_
+#define ROCPRIM_INTRINSICS_HPP_
+
+// Meta configuration for rocPRIM
+#include "config.hpp"
+
+// Umbrella header: pulls in all rocPRIM intrinsics wrappers.
+#include "intrinsics/atomic.hpp"
+#include "intrinsics/bit.hpp"
+#include "intrinsics/thread.hpp"
+#include "intrinsics/warp.hpp"
+#include "intrinsics/warp_shuffle.hpp"
+
+#endif // ROCPRIM_INTRINSICS_HPP_
diff --git a/3rdparty/cub/rocprim/intrinsics/atomic.hpp b/3rdparty/cub/rocprim/intrinsics/atomic.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..264afe4b5a3954f168cddb861d480cab14fc10f3
--- /dev/null
+++ b/3rdparty/cub/rocprim/intrinsics/atomic.hpp
@@ -0,0 +1,75 @@
+// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_INTRINSICS_ATOMIC_HPP_ +#define ROCPRIM_INTRINSICS_ATOMIC_HPP_ + +#include "../config.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int atomic_add(unsigned int * address, unsigned int value) + { + return ::atomicAdd(address, value); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + int atomic_add(int * address, int value) + { + return ::atomicAdd(address, value); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + float atomic_add(float * address, float value) + { + return ::atomicAdd(address, value); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned long long atomic_add(unsigned long long * address, unsigned long long value) + { + return ::atomicAdd(address, value); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int atomic_wrapinc(unsigned int * address, unsigned int value) + { + return ::atomicInc(address, value); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int atomic_exch(unsigned int * address, unsigned int value) + { + return ::atomicExch(address, value); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned long long atomic_exch(unsigned long long * address, unsigned long long value) + { + return ::atomicExch(address, value); + } +} + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_INTRINSICS_ATOMIC_HPP_ diff --git a/3rdparty/cub/rocprim/intrinsics/bit.hpp b/3rdparty/cub/rocprim/intrinsics/bit.hpp new file mode 100644 index 0000000000000000000000000000000000000000..816208098080c6717fb719f5de3b99424c9a89e2 --- /dev/null +++ 
b/3rdparty/cub/rocprim/intrinsics/bit.hpp @@ -0,0 +1,61 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_INTRINSICS_BIT_HPP_ +#define ROCPRIM_INTRINSICS_BIT_HPP_ + +#include "../config.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup intrinsicsmodule +/// @{ + +/// \brief Returns a single bit at 'i' from 'x' +ROCPRIM_DEVICE ROCPRIM_INLINE +int get_bit(int x, int i) +{ + return (x >> i) & 1; +} + +/// \brief Bit count +/// +/// Returns the number of bit of \p x set. +ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int bit_count(unsigned int x) +{ + return __popc(x); +} + +/// \brief Bit count +/// +/// Returns the number of bit of \p x set. 
+ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int bit_count(unsigned long long x) +{ + return __popcll(x); +} + +/// @} +// end of group intrinsicsmodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_INTRINSICS_BIT_HPP_ diff --git a/3rdparty/cub/rocprim/intrinsics/thread.hpp b/3rdparty/cub/rocprim/intrinsics/thread.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5d5a1c0897753520832b1034b605cfcda0d3b170 --- /dev/null +++ b/3rdparty/cub/rocprim/intrinsics/thread.hpp @@ -0,0 +1,344 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_INTRINSICS_THREAD_HPP_ +#define ROCPRIM_INTRINSICS_THREAD_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup intrinsicsmodule +/// @{ + +// Sizes + +/// \brief [DEPRECATED] Returns a number of threads in a hardware warp. 
+/// +/// It is constant for a device. +/// This function is not supported for the gfx1030 architecture and will be removed in a future release. +/// Please use the new host_warp_size() and device_warp_size() functions. +ROCPRIM_HOST_DEVICE inline +constexpr unsigned int warp_size() +{ + return warpSize; +} + +/// \brief Returns a number of threads in a hardware warp for the actual device. +/// At host side this constant is available at runtime time only. +/// +/// It is constant for a device. +ROCPRIM_HOST inline +unsigned int host_warp_size() +{ + int default_hip_device; + cudaError_t success = cudaGetDevice(&default_hip_device); + cudaDeviceProp device_prop; + success = cudaGetDeviceProperties(&device_prop,default_hip_device); + + if(success != cudaSuccess) + return -1; + else + return device_prop.warpSize; +}; + +/// \brief Returns a number of threads in a hardware warp for the actual target. +/// At device side this constant is available at compile time. +/// +/// It is constant for a device. +ROCPRIM_DEVICE ROCPRIM_INLINE +constexpr unsigned int device_warp_size() +{ + return warpSize; +} + +/// \brief Returns flat size of a multidimensional block (tile). +ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int flat_block_size() +{ + return blockDim.z * blockDim.y * blockDim.x; +} + +/// \brief Returns flat size of a multidimensional tile (block). +ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int flat_tile_size() +{ + return flat_block_size(); +} + +// IDs + +/// \brief Returns thread identifier in a warp. +ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int lane_id() +{ +#ifndef __HIP_CPU_RT__ + return ::__lane_id(); +#else + using namespace hip::detail; + return id(Fiber::this_fiber()) % warpSize; +#endif +} + +/// \brief Returns flat (linear, 1D) thread identifier in a multidimensional block (tile). 
+ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int flat_block_thread_id() +{ + return (threadIdx.z * blockDim.y * blockDim.x) + + (threadIdx.y * blockDim.x) + + threadIdx.x; +} + +/// \brief Returns flat (linear, 1D) thread identifier in a multidimensional block (tile). Use template parameters to optimize 1D or 2D kernels. +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto flat_block_thread_id() + -> typename std::enable_if<(BlockSizeY == 1 && BlockSizeZ == 1), unsigned int>::type +{ + return threadIdx.x; +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto flat_block_thread_id() + -> typename std::enable_if<(BlockSizeY > 1 && BlockSizeZ == 1), unsigned int>::type +{ + return threadIdx.x + (threadIdx.y * blockDim.x); +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto flat_block_thread_id() + -> typename std::enable_if<(BlockSizeY > 1 && BlockSizeZ > 1), unsigned int>::type +{ + return threadIdx.x + (threadIdx.y * blockDim.x) + + (threadIdx.z * blockDim.y * blockDim.x); +} + +/// \brief Returns flat (linear, 1D) thread identifier in a multidimensional tile (block). +ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int flat_tile_thread_id() +{ + return flat_block_thread_id(); +} + +/// \brief Returns warp id in a block (tile). +ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int warp_id() +{ + return flat_block_thread_id()/device_warp_size(); +} + +ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int warp_id(unsigned int flat_id) +{ + return flat_id/device_warp_size(); +} + +/// \brief Returns warp id in a block (tile). Use template parameters to optimize 1D or 2D kernels. +template +ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int warp_id() +{ + return flat_block_thread_id()/device_warp_size(); +} + +/// \brief Returns flat (linear, 1D) block identifier in a multidimensional grid. 
+ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int flat_block_id() +{ + return (blockIdx.z * gridDim.y * gridDim.x) + + (blockIdx.y * gridDim.x) + + blockIdx.x; +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto flat_block_id() + -> typename std::enable_if<(BlockSizeY == 1 && BlockSizeZ == 1), unsigned int>::type +{ + return blockIdx.x; +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto flat_block_id() + -> typename std::enable_if<(BlockSizeY > 1 && BlockSizeZ == 1), unsigned int>::type +{ + return blockIdx.x + (blockIdx.y * gridDim.x); +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto flat_block_id() + -> typename std::enable_if<(BlockSizeY > 1 && BlockSizeZ > 1), unsigned int>::type +{ + return blockIdx.x + (blockIdx.y * gridDim.x) + + (blockIdx.z * gridDim.y * gridDim.x); +} + +// Sync + +/// \brief Synchronize all threads in a block (tile) +ROCPRIM_DEVICE ROCPRIM_INLINE +void syncthreads() +{ + __syncthreads(); +} + +/// \brief All lanes in a wave come to convergence point simultaneously +/// with SIMT, thus no special instruction is needed in the ISA +ROCPRIM_DEVICE ROCPRIM_INLINE +void wave_barrier() +{ + __builtin_amdgcn_wave_barrier(); +} + +namespace detail +{ + /// \brief Returns thread identifier in a multidimensional block (tile) by dimension. + template + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int block_thread_id() + { + static_assert(Dim > 2, "Dim must be 0, 1 or 2"); + // dummy return, correct values handled by specializations + return 0; + } + + /// \brief Returns block identifier in a multidimensional grid by dimension. + template + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int block_id() + { + static_assert(Dim > 2, "Dim must be 0, 1 or 2"); + // dummy return, correct values handled by specializations + return 0; + } + + /// \brief Returns block size in a multidimensional grid by dimension. 
+ template + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int block_size() + { + static_assert(Dim > 2, "Dim must be 0, 1 or 2"); + // dummy return, correct values handled by specializations + return 0; + } + + /// \brief Returns grid size by dimension. + template + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int grid_size() + { + static_assert(Dim > 2, "Dim must be 0, 1 or 2"); + // dummy return, correct values handled by specializations + return 0; + } + + #define ROCPRIM_DETAIL_CONCAT(A, B) A B + #define ROCPRIM_DETAIL_DEFINE_HIP_API_ID_FUNC(name, prefix, dim, suffix) \ + template<> \ + ROCPRIM_DEVICE ROCPRIM_INLINE \ + unsigned int name() \ + { \ + return ROCPRIM_DETAIL_CONCAT(prefix, suffix); \ + } + #define ROCPRIM_DETAIL_DEFINE_HIP_API_ID_FUNCS(name, prefix) \ + ROCPRIM_DETAIL_DEFINE_HIP_API_ID_FUNC(name, prefix, 0, x) \ + ROCPRIM_DETAIL_DEFINE_HIP_API_ID_FUNC(name, prefix, 1, y) \ + ROCPRIM_DETAIL_DEFINE_HIP_API_ID_FUNC(name, prefix, 2, z) + + ROCPRIM_DETAIL_DEFINE_HIP_API_ID_FUNCS(block_thread_id, threadIdx.) + ROCPRIM_DETAIL_DEFINE_HIP_API_ID_FUNCS(block_id, blockIdx.) + ROCPRIM_DETAIL_DEFINE_HIP_API_ID_FUNCS(block_size, blockDim.) + ROCPRIM_DETAIL_DEFINE_HIP_API_ID_FUNCS(grid_size, gridDim.) + + #undef ROCPRIM_DETAIL_DEFINE_HIP_API_ID_FUNCS + #undef ROCPRIM_DETAIL_DEFINE_HIP_API_ID_FUNC + #undef ROCPRIM_DETAIL_CONCAT + + // Return thread id in a "logical warp", which can be smaller than a hardware warp size. 
+ template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto logical_lane_id() + -> typename std::enable_if::type + { + return lane_id() & (LogicalWarpSize-1); // same as land_id()%WarpSize + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto logical_lane_id() + -> typename std::enable_if::type + { + return lane_id()%LogicalWarpSize; + } + + template<> + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int logical_lane_id() + { + return lane_id(); + } + + // Return id of "logical warp" in a block + template + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int logical_warp_id() + { + return flat_block_thread_id()/LogicalWarpSize; + } + + template<> + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int logical_warp_id() + { + return warp_id(); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void memory_fence_system() + { + ::__threadfence_system(); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void memory_fence_block() + { + ::__threadfence_block(); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void memory_fence_device() + { + ::__threadfence(); + } +} + +/// @} +// end of group intrinsicsmodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_INTRINSICS_THREAD_HPP_ diff --git a/3rdparty/cub/rocprim/intrinsics/warp.hpp b/3rdparty/cub/rocprim/intrinsics/warp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6629b7438dacc306ebd0505d3f09f57c4063ec19 --- /dev/null +++ b/3rdparty/cub/rocprim/intrinsics/warp.hpp @@ -0,0 +1,151 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_INTRINSICS_WARP_HPP_ +#define ROCPRIM_INTRINSICS_WARP_HPP_ + +#include "../config.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup intrinsicsmodule +/// @{ + +/// Evaluate predicate for all active work-items in the warp and return an integer +/// whose i-th bit is set if and only if \p predicate is true +/// for the i-th thread of the warp and the i-th thread is active. +/// +/// \param predicate - input to be evaluated for all active lanes +ROCPRIM_DEVICE ROCPRIM_INLINE +lane_mask_type ballot(int predicate) +{ + return ::__ballot(predicate); +} + +/// \brief Masked bit count +/// +/// For each thread, this function returns the number of active threads which +/// have i-th bit of \p x set and come before the current thread. 
+ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int masked_bit_count(lane_mask_type x, unsigned int add = 0) +{ + int c; + #ifndef __HIP_CPU_RT__ + #if __AMDGCN_WAVEFRONT_SIZE == 32 + #ifdef __CUDACC__ + c = ::__builtin_amdgcn_mbcnt_lo(x, add); + #else + c = ::__mbcnt_lo(x, add); + #endif + #else + #ifdef __CUDACC__ + c = ::__builtin_amdgcn_mbcnt_lo(static_cast(x), add); + c = ::__builtin_amdgcn_mbcnt_hi(static_cast(x >> 32), c); + #else + c = ::__mbcnt_lo(static_cast(x), add); + c = ::__mbcnt_hi(static_cast(x >> 32), c); + #endif + #endif + #else + using namespace hip::detail; + const auto tidx{id(Fiber::this_fiber()) % warpSize}; + std::bitset bits{x >> (warpSize - tidx)}; + c = static_cast(bits.count()) + add; + #endif + return c; +} + +namespace detail +{ + +ROCPRIM_DEVICE ROCPRIM_INLINE +int warp_any(int predicate) +{ +#ifndef __HIP_CPU_RT__ + return ::__any(predicate); +#else + using namespace hip::detail; + const auto tidx{id(Fiber::this_fiber()) % warpSize}; + auto& lds{Tile::scratchpad, 1>()[0]}; + + lds[tidx] = static_cast(predicate); + + barrier(Tile::this_tile()); + + return lds.any(); +#endif +} + +ROCPRIM_DEVICE ROCPRIM_INLINE +int warp_all(int predicate) +{ +#ifndef __HIP_CPU_RT__ + return ::__all(predicate); +#else + using namespace hip::detail; + const auto tidx{id(Fiber::this_fiber()) % warpSize}; + auto& lds{Tile::scratchpad, 1>()[0]}; + + lds[tidx] = static_cast(predicate); + + barrier(Tile::this_tile()); + + return lds.all(); +#endif +} + +} // end detail namespace + +/// @} +// end of group intrinsicsmodule + +/** + * Compute a 32b mask of threads having the same least-significant + * LABEL_BITS of \p label as the calling thread. 
+ */ +template +ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int MatchAny(unsigned int label) +{ + unsigned int retval; + + // Extract masks of common threads for each bit + ROCPRIM_UNROLL + for (int BIT = 0; BIT < LABEL_BITS; ++BIT) + { + unsigned long long mask; + unsigned long long current_bit = 1 << BIT; + mask = label & current_bit; + bool bit_match = (mask==current_bit); + mask = ballot(bit_match); + if(!bit_match) + { + mask = ! mask; + } + // Remove peers who differ + retval = (BIT == 0) ? mask : retval & mask; + } + + return retval; + +} +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_INTRINSICS_WARP_HPP_ diff --git a/3rdparty/cub/rocprim/intrinsics/warp_shuffle.hpp b/3rdparty/cub/rocprim/intrinsics/warp_shuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..00c3f1d1035b834cee9fa3c5c95b598f36d423af --- /dev/null +++ b/3rdparty/cub/rocprim/intrinsics/warp_shuffle.hpp @@ -0,0 +1,262 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_INTRINSICS_WARP_SHUFFLE_HPP_ +#define ROCPRIM_INTRINSICS_WARP_SHUFFLE_HPP_ + +#include + +#include "../config.hpp" +#include "thread.hpp" + +/// \addtogroup warpmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +#ifdef __HIP_CPU_RT__ +// TODO: consider adding macro checks relaying to std::bit_cast when compiled +// using C++20. +template +typename std::enable_if_t< + sizeof(To) == sizeof(From) && + std::is_trivially_copyable_v && + std::is_trivially_copyable_v, + To> +// constexpr support needs compiler magic +bit_cast(const From& src) noexcept +{ + To dst; + std::memcpy(&dst, &src, sizeof(To)); + return dst; +} +#endif + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +typename std::enable_if::value && (sizeof(T) % sizeof(int) == 0), T>::type +warp_shuffle_op(const T& input, ShuffleOp&& op) +{ + constexpr int words_no = (sizeof(T) + sizeof(int) - 1) / sizeof(int); + + struct V { int words[words_no]; }; +#ifdef __HIP_CPU_RT__ + V a = bit_cast(input); +#else + V a = __builtin_bit_cast(V, input); +#endif + + ROCPRIM_UNROLL + for(int i = 0; i < words_no; i++) + { + a.words[i] = op(a.words[i]); + } + +#ifdef __HIP_CPU_RT__ + return bit_cast(a); +#else + return __builtin_bit_cast(T, a); +#endif +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +typename std::enable_if::value && (sizeof(T) % sizeof(int) == 0)), T>::type +warp_shuffle_op(const T& input, ShuffleOp&& op) +{ + constexpr int words_no = (sizeof(T) + sizeof(int) - 1) / sizeof(int); + + T output; + ROCPRIM_UNROLL + for(int i = 0; i < words_no; i++) + { + const size_t s = std::min(sizeof(int), sizeof(T) - i * sizeof(int)); + int word; +#ifdef __HIP_CPU_RT__ + std::memcpy(&word, reinterpret_cast(&input) + i * 
sizeof(int), s); +#else + __builtin_memcpy(&word, reinterpret_cast(&input) + i * sizeof(int), s); +#endif + word = op(word); +#ifdef __HIP_CPU_RT__ + std::memcpy(reinterpret_cast(&output) + i * sizeof(int), &word, s); +#else + __builtin_memcpy(reinterpret_cast(&output) + i * sizeof(int), &word, s); +#endif + } + + return output; + +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +T warp_move_dpp(const T& input) +{ + return detail::warp_shuffle_op( + input, + [=](int v) -> int + { + // TODO: clean-up, this function activates based ROCPRIM_DETAIL_USE_DPP, however inclusion and + // parsing of the template happens unconditionally. The condition causing compilation to + // fail is ordinary host-compilers looking at the headers. Non-hipcc compilers don't define + // __builtin_amdgcn_update_dpp, hence fail to parse the template altogether. (Except MSVC + // because even using /permissive- they somehow still do delayed parsing of the body of + // function templates, even though they pinky-swear they don't.) +#if !defined(__HIP_CPU_RT__) + return ::__builtin_amdgcn_mov_dpp(v, dpp_ctrl, row_mask, bank_mask, bound_ctrl); +#else + return v; +#endif + } + ); +} + +/// \brief Swizzle for any data type. +/// +/// Each thread in warp obtains \p input from src_lane-th thread +/// in warp, where src_lane is current lane with a mask applied. +/// +/// \param input - input to pass to other threads +template +ROCPRIM_DEVICE ROCPRIM_INLINE +T warp_swizzle(const T& input) +{ + return detail::warp_shuffle_op( + input, + [=](int v) -> int + { + return ::__builtin_amdgcn_ds_swizzle(v, mask); + } + ); +} + +} // end namespace detail + +/// \brief Shuffle for any data type. +/// +/// Each thread in warp obtains \p input from src_lane-th thread +/// in warp. If \p width is less than device_warp_size() then each subsection of the +/// warp behaves as a separate entity with a starting logical lane id of 0. 
+/// If \p src_lane is not in [0; \p width) range, the returned value is +/// equal to \p input passed by the src_lane modulo width thread. +/// +/// Note: The optional \p width parameter must be a power of 2; results are +/// undefined if it is not a power of 2, or it is greater than device_warp_size(). +/// +/// \param input - input to pass to other threads +/// \param src_lane - warp if of a thread whose \p input should be returned +/// \param width - logical warp width +template +ROCPRIM_DEVICE ROCPRIM_INLINE +T warp_shuffle(const T& input, const int src_lane, const int width = device_warp_size()) +{ + return detail::warp_shuffle_op( + input, + [=](int v) -> int + { + return __shfl(v, src_lane, width); + } + ); +} + +/// \brief Shuffle up for any data type. +/// +/// i-th thread in warp obtains \p input from i-delta-th +/// thread in warp. If \p i-delta is not in [0; \p width) range, +/// thread's own \p input is returned. +/// +/// Note: The optional \p width parameter must be a power of 2; results are +/// undefined if it is not a power of 2, or it is greater than device_warp_size(). +/// +/// \param input - input to pass to other threads +/// \param delta - offset for calculating source lane id +/// \param width - logical warp width +template +ROCPRIM_DEVICE ROCPRIM_INLINE +T warp_shuffle_up(const T& input, const unsigned int delta, const int width = device_warp_size()) +{ + return detail::warp_shuffle_op( + input, + [=](int v) -> int + { + return __shfl_up(v, delta, width); + } + ); +} + +/// \brief Shuffle down for any data type. +/// +/// i-th thread in warp obtains \p input from i+delta-th +/// thread in warp. If \p i+delta is not in [0; \p width) range, +/// thread's own \p input is returned. +/// +/// Note: The optional \p width parameter must be a power of 2; results are +/// undefined if it is not a power of 2, or it is greater than device_warp_size(). 
+/// +/// \param input - input to pass to other threads +/// \param delta - offset for calculating source lane id +/// \param width - logical warp width +template +ROCPRIM_DEVICE ROCPRIM_INLINE +T warp_shuffle_down(const T& input, const unsigned int delta, const int width = device_warp_size()) +{ + return detail::warp_shuffle_op( + input, + [=](int v) -> int + { + return __shfl_down(v, delta, width); + } + ); +} + +/// \brief Shuffle XOR for any data type. +/// +/// i-th thread in warp obtains \p input from i^lane_mask-th +/// thread in warp. +/// +/// Note: The optional \p width parameter must be a power of 2; results are +/// undefined if it is not a power of 2, or it is greater than device_warp_size(). +/// +/// \param input - input to pass to other threads +/// \param lane_mask - mask used for calculating source lane id +/// \param width - logical warp width +template +ROCPRIM_DEVICE ROCPRIM_INLINE +T warp_shuffle_xor(const T& input, const int lane_mask, const int width = device_warp_size()) +{ + return detail::warp_shuffle_op( + input, + [=](int v) -> int + { + return __shfl_xor(v, lane_mask, width); + } + ); +} + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_INTRINSICS_WARP_SHUFFLE_HPP_ + +/// @} +// end of group warpmodule diff --git a/3rdparty/cub/rocprim/iterator.hpp b/3rdparty/cub/rocprim/iterator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..65d54582817fa6e071d474bd806c3f0661a6f7f8 --- /dev/null +++ b/3rdparty/cub/rocprim/iterator.hpp @@ -0,0 +1,37 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_ITERATOR_HPP_ +#define ROCPRIM_ITERATOR_HPP_ + +// Meta configuration for rocPRIM +#include "config.hpp" + +#include "iterator/arg_index_iterator.hpp" +#include "iterator/constant_iterator.hpp" +#include "iterator/counting_iterator.hpp" +#include "iterator/discard_iterator.hpp" +#ifndef __HIP_CPU_RT__ +#include "iterator/texture_cache_iterator.hpp" +#endif +#include "iterator/transform_iterator.hpp" +#include "iterator/zip_iterator.hpp" + +#endif // ROCPRIM_ITERATOR_HPP_ diff --git a/3rdparty/cub/rocprim/iterator/arg_index_iterator.hpp b/3rdparty/cub/rocprim/iterator/arg_index_iterator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0b08f532ed726fa5e81ffade1fcd099415c7ae49 --- /dev/null +++ b/3rdparty/cub/rocprim/iterator/arg_index_iterator.hpp @@ -0,0 +1,266 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_ITERATOR_ARG_INDEX_ITERATOR_HPP_ +#define ROCPRIM_ITERATOR_ARG_INDEX_ITERATOR_HPP_ + +#include +#include +#include +#include + +#include "../config.hpp" +#include "../types/key_value_pair.hpp" + +/// \addtogroup iteratormodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \class arg_index_iterator +/// \brief A random-access input (read-only) iterator adaptor for pairing dereferenced values +/// with their indices. +/// +/// \par Overview +/// * Dereferencing arg_index_iterator return a value of \p key_value_pair +/// type, which includes value from the underlying range and its index in that range. +/// * \p std::iterator_traits::value_type should be convertible to \p InputValueType. +/// +/// \tparam InputIterator - type of the underlying random-access input iterator. Must be +/// a random-access iterator. 
+/// \tparam Difference - type used for identify distance between iterators and as the index type +/// in the output pair type (see \p value_type). +/// \tparam InputValueType - value type used in the output pair type (see \p value_type). +template< + class InputIterator, + class Difference = std::ptrdiff_t, + class InputValueType = typename std::iterator_traits::value_type +> +class arg_index_iterator +{ +private: + using input_category = typename std::iterator_traits::iterator_category; + +public: + /// The type of the value that can be obtained by dereferencing the iterator. + using value_type = ::rocprim::key_value_pair; + /// \brief A reference type of the type iterated over (\p value_type). + /// It's `const` since arg_index_iterator is a read-only iterator. + using reference = const value_type&; + /// \brief A pointer type of the type iterated over (\p value_type). + /// It's `const` since arg_index_iterator is a read-only iterator. + using pointer = const value_type*; + /// A type used for identify distance between iterators. + using difference_type = Difference; + /// The category of the iterator. + using iterator_category = std::random_access_iterator_tag; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + using self_type = arg_index_iterator; +#endif + + static_assert( + std::is_same::value, + "InputIterator must be a random-access iterator" + ); + + ROCPRIM_HOST_DEVICE inline + ~arg_index_iterator() = default; + + /// \brief Creates a new arg_index_iterator. + /// + /// \param iterator input iterator pointing to the input range. + /// \param offset index of the \p iterator in the input range. + ROCPRIM_HOST_DEVICE inline + arg_index_iterator(InputIterator iterator, difference_type offset = 0) + : iterator_(iterator), offset_(offset) + { + } + + ROCPRIM_HOST_DEVICE inline + arg_index_iterator& operator++() + { + iterator_++; + offset_++; + return *this; + } + + //! 
\skip_doxy_start + ROCPRIM_HOST_DEVICE inline + arg_index_iterator operator++(int) + { + arg_index_iterator old_ai = *this; + iterator_++; + offset_++; + return old_ai; + } + + ROCPRIM_HOST_DEVICE inline + value_type operator*() const + { + value_type ret(offset_, *iterator_); + return ret; + } + + ROCPRIM_HOST_DEVICE inline + pointer operator->() const + { + return &(*(*this)); + } + + ROCPRIM_HOST_DEVICE inline + arg_index_iterator operator+(difference_type distance) const + { + return arg_index_iterator(iterator_ + distance, offset_ + distance); + } + + ROCPRIM_HOST_DEVICE inline + arg_index_iterator& operator+=(difference_type distance) + { + iterator_ += distance; + offset_ += distance; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + arg_index_iterator operator-(difference_type distance) const + { + return arg_index_iterator(iterator_ - distance, offset_ - distance); + } + + ROCPRIM_HOST_DEVICE inline + arg_index_iterator& operator-=(difference_type distance) + { + iterator_ -= distance; + offset_ -= distance; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + difference_type operator-(arg_index_iterator other) const + { + return iterator_ - other.iterator_; + } + + ROCPRIM_HOST_DEVICE inline + value_type operator[](difference_type distance) const + { + arg_index_iterator i = (*this) + distance; + return *i; + } + + ROCPRIM_HOST_DEVICE inline + bool operator==(arg_index_iterator other) const + { + return (iterator_ == other.iterator_) && (offset_ == other.offset_); + } + + ROCPRIM_HOST_DEVICE inline + bool operator!=(arg_index_iterator other) const + { + return (iterator_ != other.iterator_) || (offset_ != other.offset_); + } + + ROCPRIM_HOST_DEVICE inline + bool operator<(arg_index_iterator other) const + { + return (iterator_ - other.iterator_) > 0; + } + + ROCPRIM_HOST_DEVICE inline + bool operator<=(arg_index_iterator other) const + { + return (iterator_ - other.iterator_) >= 0; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>(arg_index_iterator 
other) const + { + return (iterator_ - other.iterator_) < 0; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>=(arg_index_iterator other) const + { + return (iterator_ - other.iterator_) <= 0; + } + + ROCPRIM_HOST_DEVICE inline + void normalize() + { + offset_ = 0; + } + + friend std::ostream& operator<<(std::ostream& os, const arg_index_iterator& /* iter */) + { + return os; + } + //! \skip_doxy_end + +private: + InputIterator iterator_; + difference_type offset_; +}; + +template< + class InputIterator, + class Difference, + class InputValueType +> +ROCPRIM_HOST_DEVICE inline +arg_index_iterator +operator+(typename arg_index_iterator::difference_type distance, + const arg_index_iterator& iterator) +{ + return iterator + distance; +} + + +/// make_arg_index_iterator creates a arg_index_iterator using \p iterator as +/// the underlying iterator and \p offset as the position (index) of \p iterator +/// in the input range. +/// +/// \tparam InputIterator - type of the underlying random-access input iterator. Must be +/// a random-access iterator. +/// \tparam Difference - type used for identify distance between iterators and as the index type +/// in the output pair type (see \p value_type in arg_index_iterator). +/// \tparam InputValueType - value type used in the output pair type (see \p value_type +/// in arg_index_iterator). +/// +/// \param iterator input iterator pointing to the input range. +/// \param offset index of the \p iterator in the input range. 
+template< + class InputIterator, + class Difference = std::ptrdiff_t, + class InputValueType = typename std::iterator_traits::value_type +> +ROCPRIM_HOST_DEVICE inline +arg_index_iterator +make_arg_index_iterator(InputIterator iterator, Difference offset = 0) +{ + return arg_index_iterator(iterator, offset); +} + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group iteratormodule + +#endif // ROCPRIM_ITERATOR_ARG_INDEX_ITERATOR_HPP_ diff --git a/3rdparty/cub/rocprim/iterator/constant_iterator.hpp b/3rdparty/cub/rocprim/iterator/constant_iterator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..00ae82ed2b6c64402b2c57169c7e7606cec121ee --- /dev/null +++ b/3rdparty/cub/rocprim/iterator/constant_iterator.hpp @@ -0,0 +1,261 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_ITERATOR_CONSTANT_ITERATOR_HPP_ +#define ROCPRIM_ITERATOR_CONSTANT_ITERATOR_HPP_ + +#include +#include +#include +#include + +#include "../config.hpp" + +/// \addtogroup iteratormodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \class constant_iterator +/// \brief A random-access input (read-only) iterator which generates a sequence +/// of homogeneous values. +/// +/// \par Overview +/// * A constant_iterator represents a pointer into a range of same values. +/// * Using it for simulating a range filled with a sequence of same values saves +/// memory capacity and bandwidth. +/// +/// \tparam ValueType - type of value that can be obtained by dereferencing the iterator. +/// \tparam Difference - a type used for identify distance between iterators +template< + class ValueType, + class Difference = std::ptrdiff_t +> +class constant_iterator +{ +public: + /// The type of the value that can be obtained by dereferencing the iterator. + using value_type = typename std::remove_const::type; + /// \brief A reference type of the type iterated over (\p value_type). + /// It's same as `value_type` since constant_iterator is a read-only + /// iterator and does not have underlying buffer. + using reference = value_type; // constant_iterator is not writable + /// \brief A pointer type of the type iterated over (\p value_type). + /// It's `const` since constant_iterator is a read-only iterator. + using pointer = const value_type*; // constant_iterator is not writable + /// A type used for identify distance between iterators. + using difference_type = Difference; + /// The category of the iterator. + using iterator_category = std::random_access_iterator_tag; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + using self_type = constant_iterator; +#endif + + /// \brief Creates constant_iterator and sets its initial value to \p value. 
+ /// + /// \param value initial value + /// \param index optional index for constant_iterator + ROCPRIM_HOST_DEVICE inline + explicit constant_iterator(const value_type value, const size_t index = 0) + : value_(value), index_(index) + { + } + + ROCPRIM_HOST_DEVICE inline + ~constant_iterator() = default; + + //! \skip_doxy_start + ROCPRIM_HOST_DEVICE inline + value_type operator*() const + { + return value_; + } + + ROCPRIM_HOST_DEVICE inline + pointer operator->() const + { + return &value_; + } + + ROCPRIM_HOST_DEVICE inline + constant_iterator& operator++() + { + index_++; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + constant_iterator operator++(int) + { + constant_iterator old_ci = *this; + index_++; + return old_ci; + } + + ROCPRIM_HOST_DEVICE inline + constant_iterator& operator--() + { + index_--; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + constant_iterator operator--(int) + { + constant_iterator old_ci = *this; + index_--; + return old_ci; + } + + ROCPRIM_HOST_DEVICE inline + constant_iterator operator+(difference_type distance) const + { + return constant_iterator(value_, index_ + distance); + } + + ROCPRIM_HOST_DEVICE inline + constant_iterator& operator+=(difference_type distance) + { + index_ += distance; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + constant_iterator operator-(difference_type distance) const + { + return constant_iterator(value_, index_ - distance); + } + + ROCPRIM_HOST_DEVICE inline + constant_iterator& operator-=(difference_type distance) + { + index_ -= distance; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + difference_type operator-(constant_iterator other) const + { + return static_cast(index_ - other.index_); + } + //! \skip_doxy_end + + /// Constant_iterator is not writable, so we don't return reference, + /// just something convertible to reference. 
That matches requirement + /// of RandomAccessIterator concept + ROCPRIM_HOST_DEVICE inline + value_type operator[](difference_type) const + { + return value_; + } + + //! \skip_doxy_start + ROCPRIM_HOST_DEVICE inline + bool operator==(constant_iterator other) const + { + return value_ == other.value_ && index_ == other.index_; + } + + ROCPRIM_HOST_DEVICE inline + bool operator!=(constant_iterator other) const + { + return !(*this == other); + } + + ROCPRIM_HOST_DEVICE inline + bool operator<(constant_iterator other) const + { + return distance_to(other) > 0; + } + + ROCPRIM_HOST_DEVICE inline + bool operator<=(constant_iterator other) const + { + return distance_to(other) >= 0; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>(constant_iterator other) const + { + return distance_to(other) < 0; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>=(constant_iterator other) const + { + return distance_to(other) <= 0; + } + + friend std::ostream& operator<<(std::ostream& os, const constant_iterator& iter) + { + os << "[" << iter.value_ << "]"; + return os; + } + //! \skip_doxy_end + +private: + inline + difference_type distance_to(const constant_iterator& other) const + { + return difference_type(other.index_) - difference_type(index_); + } + + value_type value_; + size_t index_; +}; + +template< + class ValueType, + class Difference +> +ROCPRIM_HOST_DEVICE inline +constant_iterator +operator+(typename constant_iterator::difference_type distance, + const constant_iterator& iter) +{ + return iter + distance; +} + +/// make_constant_iterator creates a constant_iterator with its initial value +/// set to \p value. +/// +/// \tparam ValueType - type of value that can be obtained by dereferencing created iterator. +/// \tparam Difference - a type used for identify distance between constant_iterator iterators. +/// +/// \param value - initial value for constant_iterator. +/// \param index - optional index for constant_iterator. 
+template< + class ValueType, + class Difference = std::ptrdiff_t +> +ROCPRIM_HOST_DEVICE inline +constant_iterator +make_constant_iterator(ValueType value, size_t index = 0) +{ + return constant_iterator(value, index); +} + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group iteratormodule + +#endif // ROCPRIM_ITERATOR_CONSTANT_ITERATOR_HPP_ diff --git a/3rdparty/cub/rocprim/iterator/counting_iterator.hpp b/3rdparty/cub/rocprim/iterator/counting_iterator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..209fa129c5f7157dfb4bf36e8661dc5ee2549667 --- /dev/null +++ b/3rdparty/cub/rocprim/iterator/counting_iterator.hpp @@ -0,0 +1,269 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_ITERATOR_COUNTING_ITERATOR_HPP_ +#define ROCPRIM_ITERATOR_COUNTING_ITERATOR_HPP_ + +#include +#include +#include +#include + +#include "../config.hpp" +#include "../type_traits.hpp" + +/// \addtogroup iteratormodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \class counting_iterator +/// \brief A random-access input (read-only) iterator over a sequence of consecutive integer values. +/// +/// \par Overview +/// * A counting_iterator represents a pointer into a range of sequentially increasing values. +/// * Using it for simulating a range filled with a sequence of consecutive values saves +/// memory capacity and bandwidth. +/// +/// \tparam Incrementable - type of value that can be obtained by dereferencing the iterator. +/// \tparam Difference - a type used for identify distance between iterators +template< + class Incrementable, + class Difference = std::ptrdiff_t +> +class counting_iterator +{ +public: + /// The type of the value that can be obtained by dereferencing the iterator. + using value_type = typename std::remove_const::type; + /// \brief A reference type of the type iterated over (\p value_type). + /// It's same as `value_type` since constant_iterator is a read-only + /// iterator and does not have underlying buffer. + using reference = value_type; // counting_iterator is not writable + /// \brief A pointer type of the type iterated over (\p value_type). + /// It's `const` since counting_iterator is a read-only iterator. + using pointer = const value_type*; // counting_iterator is not writable + /// A type used for identify distance between iterators. + using difference_type = Difference; + /// The category of the iterator. 
+ using iterator_category = std::random_access_iterator_tag; + + static_assert(std::is_integral::value, "Incrementable must be integral type"); + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + using self_type = counting_iterator; +#endif + + ROCPRIM_HOST_DEVICE inline + counting_iterator() = default; + + /// \brief Creates counting_iterator with its initial value initialized + /// to its default value (usually 0). + ROCPRIM_HOST_DEVICE inline + ~counting_iterator() = default; + + /// \brief Creates counting_iterator and sets its initial value to \p value_. + /// + /// \param value initial value + ROCPRIM_HOST_DEVICE inline + explicit counting_iterator(const value_type value) : value_(value) + { + } + + //! \skip_doxy_start + ROCPRIM_HOST_DEVICE inline + counting_iterator& operator++() + { + value_++; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + counting_iterator operator++(int) + { + counting_iterator old_ci = *this; + value_++; + return old_ci; + } + + ROCPRIM_HOST_DEVICE inline + counting_iterator& operator--() + { + value_--; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + counting_iterator operator--(int) + { + counting_iterator old_ci = *this; + value_--; + return old_ci; + } + + ROCPRIM_HOST_DEVICE inline + value_type operator*() const + { + return value_; + } + + ROCPRIM_HOST_DEVICE inline + pointer operator->() const + { + return &value_; + } + + ROCPRIM_HOST_DEVICE inline + counting_iterator operator+(difference_type distance) const + { + return counting_iterator(value_ + static_cast(distance)); + } + + ROCPRIM_HOST_DEVICE inline + counting_iterator& operator+=(difference_type distance) + { + value_ += static_cast(distance); + return *this; + } + + ROCPRIM_HOST_DEVICE inline + counting_iterator operator-(difference_type distance) const + { + return counting_iterator(value_ - static_cast(distance)); + } + + ROCPRIM_HOST_DEVICE inline + counting_iterator& operator-=(difference_type distance) + { + value_ -= static_cast(distance); + return *this; + } + + 
ROCPRIM_HOST_DEVICE inline + difference_type operator-(counting_iterator other) const + { + return static_cast(value_ - other.value_); + } + + // counting_iterator is not writable, so we don't return reference, + // just something convertible to reference. That matches requirement + // of RandomAccessIterator concept + ROCPRIM_HOST_DEVICE inline + value_type operator[](difference_type distance) const + { + return value_ + static_cast(distance); + } + + ROCPRIM_HOST_DEVICE inline + bool operator==(counting_iterator other) const + { + return this->equal_value(value_, other.value_); + } + + ROCPRIM_HOST_DEVICE inline + bool operator!=(counting_iterator other) const + { + return !(*this == other); + } + + ROCPRIM_HOST_DEVICE inline + bool operator<(counting_iterator other) const + { + return distance_to(other) > 0; + } + + ROCPRIM_HOST_DEVICE inline + bool operator<=(counting_iterator other) const + { + return distance_to(other) >= 0; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>(counting_iterator other) const + { + return distance_to(other) < 0; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>=(counting_iterator other) const + { + return distance_to(other) <= 0; + } + + friend std::ostream& operator<<(std::ostream& os, const counting_iterator& iter) + { + os << "[" << iter.value_ << "]"; + return os; + } + //! \skip_doxy_end + +private: + template + inline + bool equal_value(const T& x, const T& y) const + { + return (x == y); + } + + inline + difference_type distance_to(const counting_iterator& other) const + { + return difference_type(other.value_) - difference_type(value_); + } + + value_type value_; +}; + +template< + class Incrementable, + class Difference +> +ROCPRIM_HOST_DEVICE inline +counting_iterator +operator+(typename counting_iterator::difference_type distance, + const counting_iterator& iter) +{ + return iter + distance; +} + +/// make_counting_iterator creates a counting_iterator with its initial value +/// set to \p value. 
+/// +/// \tparam Incrementable - type of value that can be obtained by dereferencing created iterator. +/// \tparam Difference - a type used for identify distance between counting_iterator iterators. +/// +/// \param value - initial value for counting_iterator. +template< + class Incrementable, + class Difference = std::ptrdiff_t +> +ROCPRIM_HOST_DEVICE inline +counting_iterator +make_counting_iterator(Incrementable value) +{ + return counting_iterator(value); +} + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group iteratormodule + +#endif // ROCPRIM_ITERATOR_COUNTING_ITERATOR_HPP_ diff --git a/3rdparty/cub/rocprim/iterator/detail/replace_first_iterator.hpp b/3rdparty/cub/rocprim/iterator/detail/replace_first_iterator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..49edfa8ad60a35a6c77e4c5210e876c30ee6acda --- /dev/null +++ b/3rdparty/cub/rocprim/iterator/detail/replace_first_iterator.hpp @@ -0,0 +1,133 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_ITERATOR_REPLACE_FIRST_ITERATOR_HPP_ +#define ROCPRIM_ITERATOR_REPLACE_FIRST_ITERATOR_HPP_ + +#include +#include +#include + +#include "../../config.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// Replaces first value of given range with given value. Used in exclusive scan-by-key +// and exclusive segmented scan to avoid allocating additional memory and/or running +// additional kernels. +// +// Important: it does not dereference the first item in given range, so it does not matter +// if it's an invalid pointer. +// +// Usage: +// * input - start of your input range +// * value - value that should be used as first element of new range. +// +// replace_first_iterator(input - 1, value); +// +// (input - 1) will never be dereferenced. 
+template +class replace_first_iterator +{ +private: + using input_category = typename std::iterator_traits::iterator_category; + static_assert( + std::is_same::value, + "InputIterator must be a random-access iterator" + ); + +public: + using value_type = typename std::iterator_traits::value_type; + using reference = value_type; + using pointer = const value_type*; + using difference_type = typename std::iterator_traits::difference_type; + using iterator_category = std::random_access_iterator_tag; + + ROCPRIM_HOST_DEVICE inline + ~replace_first_iterator() = default; + + ROCPRIM_HOST_DEVICE inline + replace_first_iterator(InputIterator iterator, value_type value, size_t index = 0) + : iterator_(iterator), value_(value), index_(index) + { + } + + ROCPRIM_HOST_DEVICE inline + replace_first_iterator& operator++() + { + iterator_++; + index_++; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + replace_first_iterator operator++(int) + { + replace_first_iterator old = *this; + iterator_++; + index_++; + return old; + } + + ROCPRIM_HOST_DEVICE inline + value_type operator*() const + { + if(index_ == 0) + { + return value_; + } + return *iterator_; + } + + ROCPRIM_HOST_DEVICE inline + value_type operator[](difference_type distance) const + { + replace_first_iterator i = (*this) + distance; + return *i; + } + + ROCPRIM_HOST_DEVICE inline + replace_first_iterator operator+(difference_type distance) const + { + return replace_first_iterator(iterator_ + distance, value_, index_ + distance); + } + + ROCPRIM_HOST_DEVICE inline + replace_first_iterator& operator+=(difference_type distance) + { + iterator_ += distance; + index_ += distance; + return *this; + } + +private: + InputIterator iterator_; + value_type value_; + size_t index_; +}; + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_ITERATOR_REPLACE_FIRST_ITERATOR_HPP_ diff --git a/3rdparty/cub/rocprim/iterator/discard_iterator.hpp b/3rdparty/cub/rocprim/iterator/discard_iterator.hpp new file 
mode 100644 index 0000000000000000000000000000000000000000..fb2236b5b34cba4bf68ec0a1dbc45cb8b13c2d57 --- /dev/null +++ b/3rdparty/cub/rocprim/iterator/discard_iterator.hpp @@ -0,0 +1,238 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_ITERATOR_DISCARD_ITERATOR_HPP_ +#define ROCPRIM_ITERATOR_DISCARD_ITERATOR_HPP_ + +#include +#include +#include + +#include "../config.hpp" + +/// \addtogroup iteratormodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \class discard_iterator +/// \brief A random-access iterator which discards values assigned to it upon dereference. +/// +/// \par Overview +/// * discard_iterator does not have any underlying array (memory) and does not save values +/// written to it upon dereference. +/// * discard_iterator can be used to safely ignore certain output of algorithms, which +/// saves memory capacity and/or bandwidth. 
+class discard_iterator +{ +public: + struct discard_value + { + ROCPRIM_HOST_DEVICE inline + discard_value() = default; + + template + ROCPRIM_HOST_DEVICE inline + discard_value(T) {}; + + ROCPRIM_HOST_DEVICE inline + ~discard_value() = default; + + template + ROCPRIM_HOST_DEVICE inline + discard_value& operator=(const T&) + { + return *this; + } + }; + + /// The type of the value that can be obtained by dereferencing the iterator. + using value_type = discard_value; + /// \brief A reference type of the type iterated over (\p value_type). + using reference = discard_value; + /// \brief A pointer type of the type iterated over (\p value_type). + using pointer = discard_value*; + /// A type used for identify distance between iterators. + using difference_type = std::ptrdiff_t; + /// The category of the iterator. + using iterator_category = std::random_access_iterator_tag; + + /// \brief Creates a new discard_iterator. + /// + /// \param index - optional index of discard iterator (default = 0). + ROCPRIM_HOST_DEVICE inline + discard_iterator(size_t index = 0) + : index_(index) + { + } + + ROCPRIM_HOST_DEVICE inline + ~discard_iterator() = default; + + //! 
\skip_doxy_start + ROCPRIM_HOST_DEVICE inline + discard_iterator& operator++() + { + index_++; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + discard_iterator operator++(int) + { + discard_iterator old = *this; + index_++; + return old; + } + + ROCPRIM_HOST_DEVICE inline + discard_iterator& operator--() + { + index_--; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + discard_iterator operator--(int) + { + discard_iterator old = *this; + index_--; + return old; + } + + ROCPRIM_HOST_DEVICE inline + discard_value operator*() const + { + return discard_value(); + } + + ROCPRIM_HOST_DEVICE inline + discard_value operator[](difference_type distance) const + { + discard_iterator i = (*this) + distance; + return *i; + } + + ROCPRIM_HOST_DEVICE inline + discard_iterator operator+(difference_type distance) const + { + auto i = static_cast(static_cast(index_) + distance); + return discard_iterator(i); + } + + ROCPRIM_HOST_DEVICE inline + discard_iterator& operator+=(difference_type distance) + { + index_ = static_cast(static_cast(index_) + distance); + return *this; + } + + ROCPRIM_HOST_DEVICE inline + discard_iterator operator-(difference_type distance) const + { + auto i = static_cast(static_cast(index_) - distance); + return discard_iterator(i); + } + + ROCPRIM_HOST_DEVICE inline + discard_iterator& operator-=(difference_type distance) + { + index_ = static_cast(static_cast(index_) - distance); + return *this; + } + + ROCPRIM_HOST_DEVICE inline + difference_type operator-(discard_iterator other) const + { + return index_ - other.index_; + } + + ROCPRIM_HOST_DEVICE inline + bool operator==(discard_iterator other) const + { + return index_ == other.index_; + } + + ROCPRIM_HOST_DEVICE inline + bool operator!=(discard_iterator other) const + { + return index_ != other.index_; + } + + ROCPRIM_HOST_DEVICE inline + bool operator<(discard_iterator other) const + { + return index_ < other.index_; + } + + ROCPRIM_HOST_DEVICE inline + bool operator<=(discard_iterator other) 
const + { + return index_ <= other.index_; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>(discard_iterator other) const + { + return index_ > other.index_; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>=(discard_iterator other) const + { + return index_ >= other.index_; + } + + friend std::ostream& operator<<(std::ostream& os, const discard_iterator& /* iter */) + { + return os; + } + //! \skip_doxy_end + +private: + mutable size_t index_; +}; + +ROCPRIM_HOST_DEVICE inline +discard_iterator +operator+(typename discard_iterator::difference_type distance, + const discard_iterator& iterator) +{ + return iterator + distance; +} + +/// make_discard_iterator creates a discard_iterator using optional +/// index parameter \p index. +/// +/// \param index - optional index of discard iterator (default = 0). +/// \return A new discard_iterator object. +ROCPRIM_HOST_DEVICE inline +discard_iterator +make_discard_iterator(size_t index = 0) +{ + return discard_iterator(index); +} + +/// @} +// end of group iteratormodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_ITERATOR_DISCARD_ITERATOR_HPP_ diff --git a/3rdparty/cub/rocprim/iterator/reverse_iterator.hpp b/3rdparty/cub/rocprim/iterator/reverse_iterator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5b787401862c40d91f8be437e747076c610a5fc6 --- /dev/null +++ b/3rdparty/cub/rocprim/iterator/reverse_iterator.hpp @@ -0,0 +1,211 @@ +// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_ITERATOR_REVERSE_ITERATOR_HPP_ +#define ROCPRIM_ITERATOR_REVERSE_ITERATOR_HPP_ + +#include +#include +#include + +#include "../config.hpp" + +/// \addtogroup iteratormodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \class reverse_iterator +/// \brief A reverse iterator is an iterator adaptor that reverses the direction of a wrapped iterator. +/// +/// \par Overview +/// * reverse_iterator can be used with random access iterators to reverse the direction of the iteration. +/// * The increment operators on the reverse iterator are mapped to decrements on the wrapped iterator, +/// * And the decrement operators on the reverse iterators are mapped to increments on the wrapped iterator. +/// * Use it to iterate over the elements of a container in reverse. +/// +/// \tparam SourceIterator - type of the wrapped iterator. 
+template +class reverse_iterator +{ +public: + static_assert( + std::is_base_of::iterator_category>::value, + "SourceIterator must be a random access iterator"); + + /// The type of the value that can be obtained by dereferencing the iterator. + using value_type = typename std::iterator_traits::value_type; + /// \brief A reference type of the type iterated over (\p value_type). + using reference = typename std::iterator_traits::reference; + /// \brief A pointer type of the type iterated over (\p value_type). + using pointer = typename std::iterator_traits::pointer; + /// A type used for identify distance between iterators. + using difference_type = typename std::iterator_traits::difference_type; + /// The category of the iterator. + using iterator_category = std::random_access_iterator_tag; + + ROCPRIM_HOST_DEVICE + reverse_iterator(SourceIterator source_iterator) : source_iterator_(source_iterator) {} + + //! \skip_doxy_start + ROCPRIM_HOST_DEVICE + reverse_iterator& operator++() + { + --source_iterator_; + return *this; + } + + ROCPRIM_HOST_DEVICE + reverse_iterator operator++(int) + { + reverse_iterator old = *this; + --source_iterator_; + return old; + } + + ROCPRIM_HOST_DEVICE + reverse_iterator& operator--() + { + ++source_iterator_; + return *this; + } + + ROCPRIM_HOST_DEVICE + reverse_iterator operator--(int) + { + reverse_iterator old = *this; + ++source_iterator_; + return old; + } + + ROCPRIM_HOST_DEVICE + reference operator*() + { + return *(source_iterator_ - static_cast(1)); + } + + ROCPRIM_HOST_DEVICE + reference operator[](difference_type distance) + { + reverse_iterator i = (*this) + distance; + return *i; + } + + ROCPRIM_HOST_DEVICE + reverse_iterator operator+(difference_type distance) const + { + return reverse_iterator(source_iterator_ - distance); + } + + ROCPRIM_HOST_DEVICE + reverse_iterator& operator+=(difference_type distance) + { + source_iterator_ -= distance; + return *this; + } + + ROCPRIM_HOST_DEVICE + reverse_iterator 
operator-(difference_type distance) const + { + return reverse_iterator(source_iterator_ + distance); + } + + ROCPRIM_HOST_DEVICE + reverse_iterator& operator-=(difference_type distance) + { + source_iterator_ += distance; + return *this; + } + + ROCPRIM_HOST_DEVICE + difference_type operator-(reverse_iterator other) const + { + return other.source_iterator_ - source_iterator_; + } + + ROCPRIM_HOST_DEVICE + bool operator==(reverse_iterator other) const + { + return source_iterator_ == other.source_iterator_; + } + + ROCPRIM_HOST_DEVICE + bool operator!=(reverse_iterator other) const + { + return source_iterator_ != other.source_iterator_; + } + + ROCPRIM_HOST_DEVICE + bool operator<(reverse_iterator other) const + { + return other.source_iterator_ < source_iterator_; + } + + ROCPRIM_HOST_DEVICE + bool operator<=(reverse_iterator other) const + { + return other.source_iterator_ <= source_iterator_; + } + + ROCPRIM_HOST_DEVICE + bool operator>(reverse_iterator other) const + { + return other.source_iterator_ > source_iterator_; + } + + ROCPRIM_HOST_DEVICE + bool operator>=(reverse_iterator other) const + { + return other.source_iterator_ >= source_iterator_; + } + //! \skip_doxy_end + +private: + SourceIterator source_iterator_; +}; + +template +ROCPRIM_HOST_DEVICE reverse_iterator + operator+(typename reverse_iterator::difference_type distance, + const reverse_iterator& iterator) +{ + return iterator + distance; +} + +/// make_reverse_iterator creates a \p reverse_iterator wrapping \p source_iterator. +/// +/// \tparam SourceIterator - type of \p source_iterator. +/// +/// \param source_iterator - the iterator to wrap in the created \p reverse_iterator. +/// \return A \p reverse_iterator that wraps \p source_iterator. 
+template +ROCPRIM_HOST_DEVICE reverse_iterator + make_reverse_iterator(SourceIterator source_iterator) +{ + return reverse_iterator(source_iterator); +} + +/// @} +// end of group iteratormodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_ITERATOR_REVERSE_ITERATOR_HPP_ diff --git a/3rdparty/cub/rocprim/iterator/texture_cache_iterator.hpp b/3rdparty/cub/rocprim/iterator/texture_cache_iterator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c849a204e90f1e03f2b60f186330bfa2cb0cc19f --- /dev/null +++ b/3rdparty/cub/rocprim/iterator/texture_cache_iterator.hpp @@ -0,0 +1,349 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_ITERATOR_TEXTURE_CACHE_ITERATOR_HPP_ +#define ROCPRIM_ITERATOR_TEXTURE_CACHE_ITERATOR_HPP_ + +#include +#include +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +/// \addtogroup iteratormodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ +// Takes a scalar type T and matches to a texture type based on NumElements. +template +struct make_texture_type +{ + using type = void; +}; + +template <> +struct make_texture_type +{ + using type = char; +}; + +template <> +struct make_texture_type +{ + using type = int; +}; + +template <> +struct make_texture_type +{ + using type = short; +}; + +#define DEFINE_MAKE_TEXTURE_TYPE(base, suffix) \ +\ +template<> \ +struct make_texture_type \ +{ \ + using type = ::base##suffix; \ +}; + +DEFINE_MAKE_TEXTURE_TYPE(char, 2); +DEFINE_MAKE_TEXTURE_TYPE(char, 4); +DEFINE_MAKE_TEXTURE_TYPE(int, 2); +DEFINE_MAKE_TEXTURE_TYPE(int, 4); +DEFINE_MAKE_TEXTURE_TYPE(short, 2); +DEFINE_MAKE_TEXTURE_TYPE(short, 4); + +// Selects an appropriate vector_type based on the input T and size N. +// The byte size is calculated and used to select an appropriate vector_type. +template +struct match_texture_type +{ + static constexpr unsigned int size = sizeof(T); + using texture_base_type = + typename std::conditional< + sizeof(T) >= 4, + int, + typename std::conditional< + sizeof(T) >= 2, + short, + char + >::type + >::type; + + using texture_4 = typename make_texture_type::type; + using texture_2 = typename make_texture_type::type; + using texture_1 = typename make_texture_type::type; + + using type = + typename std::conditional< + size % sizeof(texture_4) == 0, + texture_4, + typename std::conditional< + size % sizeof(texture_2) == 0, + texture_2, + texture_1 + >::type + >::type; +}; +} + +/// \class texture_cache_iterator +/// \brief A random-access input (read-only) iterator adaptor for dereferencing array values +/// through texture cache. 
+/// +/// \par Overview +/// * A texture_cache_iterator wraps a device pointer of type T, where values are obtained +/// by dereferencing through texture cache. +/// * Can be exchanged and manipulated within and between host and device functions. +/// * Can only be constructed within host functions, and can only be dereferenced within +/// device functions. +/// * Accepts any data type from memory, and loads through texture cache. +/// +/// \tparam T - type of value that can be obtained by dereferencing the iterator. +/// \tparam Difference - a type used for identify distance between iterators. +template< + class T, + class Difference = std::ptrdiff_t +> +class texture_cache_iterator +{ +public: + /// The type of the value that can be obtained by dereferencing the iterator. + using value_type = typename std::remove_const::type; + /// \brief A reference type of the type iterated over (\p value_type). + using reference = const value_type&; + /// \brief A pointer type of the type iterated over (\p value_type). + using pointer = const value_type*; + /// A type used for identify distance between iterators. + using difference_type = Difference; + /// The category of the iterator. 
+ using iterator_category = std::random_access_iterator_tag; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + using self_type = texture_cache_iterator; +#endif + + ROCPRIM_HOST_DEVICE inline + ~texture_cache_iterator() = default; + + ROCPRIM_HOST_DEVICE inline + texture_cache_iterator() + : ptr(NULL), texture_offset(0), texture_object(0) + { + } + + template + inline + cudaError_t bind_texture(Qualified* ptr, + size_t bytes = size_t(-1), + size_t texture_offset = 0) + { + this->ptr = const_cast::type*>(ptr); + this->texture_offset = texture_offset; + + cudaChannelFormatDesc channel_desc = cudaCreateChannelDesc(); + cudaResourceDesc resourse_desc; + cudaTextureDesc texture_desc; + memset(&resourse_desc, 0, sizeof(cudaResourceDesc)); + memset(&texture_desc, 0, sizeof(cudaTextureDesc)); + resourse_desc.resType = cudaResourceTypeLinear; + resourse_desc.res.linear.devPtr = this->ptr; + resourse_desc.res.linear.desc = channel_desc; + resourse_desc.res.linear.sizeInBytes = bytes; + texture_desc.readMode = cudaReadModeElementType; + + return cudaCreateTextureObject(&texture_object, &resourse_desc, &texture_desc, NULL); + } + + inline + cudaError_t unbind_texture() + { + return cudaDestroyTextureObject(texture_object); + } + + //! 
\skip_doxy_start + ROCPRIM_HOST_DEVICE inline + texture_cache_iterator& operator++() + { + ptr++; + texture_offset++; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + texture_cache_iterator operator++(int) + { + texture_cache_iterator old_tc = *this; + ptr++; + texture_offset++; + return old_tc; + } + + ROCPRIM_HOST_DEVICE inline + value_type operator*() const + { + #ifndef __CUDA_ARCH__ + return ptr[texture_offset]; + #else + texture_type words[multiple]; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < multiple; i++) + { + tex1Dfetch( + &words[i], + texture_object, + (texture_offset * multiple) + i + ); + } + + return *reinterpret_cast(words); + #endif + } + + ROCPRIM_HOST_DEVICE inline + pointer operator->() const + { + return &(*(*this)); + } + + ROCPRIM_HOST_DEVICE inline + texture_cache_iterator operator+(difference_type distance) const + { + self_type retval; + retval.ptr = ptr + distance; + retval.texture_object = texture_object; + retval.texture_offset = texture_offset + distance; + return retval; + } + + ROCPRIM_HOST_DEVICE inline + texture_cache_iterator& operator+=(difference_type distance) + { + ptr += distance; + texture_offset += distance; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + texture_cache_iterator operator-(difference_type distance) const + { + self_type retval; + retval.ptr = ptr - distance; + retval.texture_object = texture_object; + retval.texture_offset = texture_offset - distance; + return retval; + } + + ROCPRIM_HOST_DEVICE inline + texture_cache_iterator& operator-=(difference_type distance) + { + ptr -= distance; + texture_offset -= distance; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + difference_type operator-(texture_cache_iterator other) const + { + return ptr - other.ptr; + } + + ROCPRIM_HOST_DEVICE inline + value_type operator[](difference_type distance) const + { + texture_cache_iterator i = (*this) + distance; + return *i; + } + + ROCPRIM_HOST_DEVICE inline + bool operator==(texture_cache_iterator other) 
const + { + return (ptr == other.ptr) && (texture_offset == other.texture_offset); + } + + ROCPRIM_HOST_DEVICE inline + bool operator!=(texture_cache_iterator other) const + { + return (ptr != other.ptr) || (texture_offset != other.texture_offset); + } + + ROCPRIM_HOST_DEVICE inline + bool operator<(texture_cache_iterator other) const + { + return (ptr - other.ptr) > 0; + } + + ROCPRIM_HOST_DEVICE inline + bool operator<=(texture_cache_iterator other) const + { + return (ptr - other.ptr) >= 0; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>(texture_cache_iterator other) const + { + return (ptr - other.ptr) < 0; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>=(texture_cache_iterator other) const + { + return (ptr - other.ptr) <= 0; + } + + friend std::ostream& operator<<(std::ostream& os, const texture_cache_iterator& /* iter */) + { + return os; + } + //! \skip_doxy_end + +private: + using texture_type = typename ::rocprim::detail::match_texture_type::type; + static constexpr unsigned int multiple = sizeof(T) / sizeof(texture_type); + value_type* ptr; + difference_type texture_offset; + cudaTextureObject_t texture_object; +}; + +template< + class T, + class Difference +> +ROCPRIM_HOST_DEVICE inline +texture_cache_iterator +operator+(typename texture_cache_iterator::difference_type distance, + const texture_cache_iterator& iterator) +{ + return iterator + distance; +} + +/// @} +// end of group iteratormodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_ITERATOR_TEXTURE_CACHE_ITERATOR_HPP_ diff --git a/3rdparty/cub/rocprim/iterator/transform_iterator.hpp b/3rdparty/cub/rocprim/iterator/transform_iterator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d5b632ef588e60e4323d2631a517d6bbea2e879c --- /dev/null +++ b/3rdparty/cub/rocprim/iterator/transform_iterator.hpp @@ -0,0 +1,262 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_ITERATOR_TRANSFORM_ITERATOR_HPP_ +#define ROCPRIM_ITERATOR_TRANSFORM_ITERATOR_HPP_ + +#include +#include +#include + +#include "../config.hpp" +#include "../detail/match_result_type.hpp" + +/// \addtogroup iteratormodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \class transform_iterator +/// \brief A random-access input (read-only) iterator adaptor for transforming dereferenced values. +/// +/// \par Overview +/// * A transform_iterator uses functor of type UnaryFunction to transform value obtained +/// by dereferencing underlying iterator. +/// * Using it for simulating a range filled with results of applying functor of type +/// \p UnaryFunction to another range saves memory capacity and/or bandwidth. +/// +/// \tparam InputIterator - type of the underlying random-access input iterator. Must be +/// a random-access iterator. 
+/// \tparam UnaryFunction - type of the transform functor. +/// \tparam ValueType - type of value that can be obtained by dereferencing the iterator. +/// By default it is the return type of \p UnaryFunction. +template< + class InputIterator, + class UnaryFunction, + class ValueType = + typename ::rocprim::detail::invoke_result< + UnaryFunction, typename std::iterator_traits::value_type + >::type +> +class transform_iterator +{ +public: + /// The type of the value that can be obtained by dereferencing the iterator. + using value_type = ValueType; + /// \brief A reference type of the type iterated over (\p value_type). + /// It's `const` since transform_iterator is a read-only iterator. + using reference = const value_type&; + /// \brief A pointer type of the type iterated over (\p value_type). + /// It's `const` since transform_iterator is a read-only iterator. + using pointer = const value_type*; + /// A type used for identify distance between iterators. + using difference_type = typename std::iterator_traits::difference_type; + /// The category of the iterator. + using iterator_category = std::random_access_iterator_tag; + /// The type of unary function used to transform input range. + using unary_function = UnaryFunction; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + using self_type = transform_iterator; +#endif + + ROCPRIM_HOST_DEVICE inline + ~transform_iterator() = default; + + /// \brief Creates a new transform_iterator. + /// + /// \param iterator input iterator to iterate over and transform. + /// \param transform unary function used to transform values obtained + /// from range pointed by \p iterator. + ROCPRIM_HOST_DEVICE inline + transform_iterator(InputIterator iterator, UnaryFunction transform) + : iterator_(iterator), transform_(transform) + { + } + + //! 
\skip_doxy_start + ROCPRIM_HOST_DEVICE inline + transform_iterator& operator++() + { + iterator_++; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + transform_iterator operator++(int) + { + transform_iterator old = *this; + iterator_++; + return old; + } + + ROCPRIM_HOST_DEVICE inline + transform_iterator& operator--() + { + iterator_--; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + transform_iterator operator--(int) + { + transform_iterator old = *this; + iterator_--; + return old; + } + + ROCPRIM_HOST_DEVICE inline + value_type operator*() const + { + return transform_(*iterator_); + } + + ROCPRIM_HOST_DEVICE inline + pointer operator->() const + { + return &(*(*this)); + } + + ROCPRIM_HOST_DEVICE inline + value_type operator[](difference_type distance) const + { + transform_iterator i = (*this) + distance; + return *i; + } + + ROCPRIM_HOST_DEVICE inline + transform_iterator operator+(difference_type distance) const + { + return transform_iterator(iterator_ + distance, transform_); + } + + ROCPRIM_HOST_DEVICE inline + transform_iterator& operator+=(difference_type distance) + { + iterator_ += distance; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + transform_iterator operator-(difference_type distance) const + { + return transform_iterator(iterator_ - distance, transform_); + } + + ROCPRIM_HOST_DEVICE inline + transform_iterator& operator-=(difference_type distance) + { + iterator_ -= distance; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + difference_type operator-(transform_iterator other) const + { + return iterator_ - other.iterator_; + } + + ROCPRIM_HOST_DEVICE inline + bool operator==(transform_iterator other) const + { + return iterator_ == other.iterator_; + } + + ROCPRIM_HOST_DEVICE inline + bool operator!=(transform_iterator other) const + { + return iterator_ != other.iterator_; + } + + ROCPRIM_HOST_DEVICE inline + bool operator<(transform_iterator other) const + { + return iterator_ < other.iterator_; + } + + ROCPRIM_HOST_DEVICE 
inline + bool operator<=(transform_iterator other) const + { + return iterator_ <= other.iterator_; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>(transform_iterator other) const + { + return iterator_ > other.iterator_; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>=(transform_iterator other) const + { + return iterator_ >= other.iterator_; + } + + friend std::ostream& operator<<(std::ostream& os, const transform_iterator& /* iter */) + { + return os; + } + //! \skip_doxy_end + +private: + InputIterator iterator_; + UnaryFunction transform_; +}; + +template< + class InputIterator, + class UnaryFunction, + class ValueType +> +ROCPRIM_HOST_DEVICE inline +transform_iterator +operator+(typename transform_iterator::difference_type distance, + const transform_iterator& iterator) +{ + return iterator + distance; +} + +/// make_transform_iterator creates a transform_iterator using \p iterator as +/// the underlying iterator and \p transform as the unary function. +/// +/// \tparam InputIterator - type of the underlying random-access input iterator. +/// \tparam UnaryFunction - type of the transform functor. +/// +/// \param iterator - input iterator. +/// \param transform - transform functor to use in created transform_iterator. +/// \return A new transform_iterator object which transforms the range pointed +/// by \p iterator using \p transform functor. 
+template< + class InputIterator, + class UnaryFunction +> +ROCPRIM_HOST_DEVICE inline +transform_iterator +make_transform_iterator(InputIterator iterator, UnaryFunction transform) +{ + return transform_iterator(iterator, transform); +} + +/// @} +// end of group iteratormodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_ITERATOR_TRANSFORM_ITERATOR_HPP_ diff --git a/3rdparty/cub/rocprim/iterator/zip_iterator.hpp b/3rdparty/cub/rocprim/iterator/zip_iterator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7eb60dbb0d584950b16c0c0ce7ff24a860d8fb88 --- /dev/null +++ b/3rdparty/cub/rocprim/iterator/zip_iterator.hpp @@ -0,0 +1,341 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_ITERATOR_ZIP_ITERATOR_HPP_ +#define ROCPRIM_ITERATOR_ZIP_ITERATOR_HPP_ + +#include +#include +#include + +#include "../config.hpp" +#include "../types/tuple.hpp" + +/// \addtogroup iteratormodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +struct tuple_of_references; + +template +struct tuple_of_references<::rocprim::tuple> +{ + using type = ::rocprim::tuple::reference...>; +}; + +template +struct tuple_of_values; + +template +struct tuple_of_values<::rocprim::tuple> +{ + using type = ::rocprim::tuple::value_type...>; +}; + +template +ROCPRIM_HOST_DEVICE inline +void for_each_in_tuple_impl(::rocprim::tuple& t, + Function f, + ::rocprim::index_sequence) +{ + auto swallow = { (f(::rocprim::get(t)), 0)... }; + (void) swallow; +} + +template +ROCPRIM_HOST_DEVICE inline +void for_each_in_tuple(::rocprim::tuple& t, Function f) +{ + for_each_in_tuple_impl(t, f, ::rocprim::index_sequence_for()); +} + +struct increment_iterator +{ + template + ROCPRIM_HOST_DEVICE inline + void operator()(Iterator& it) + { + ++it; + } +}; + +struct decrement_iterator +{ + template + ROCPRIM_HOST_DEVICE inline + void operator()(Iterator& it) + { + --it; + } +}; + +template +struct advance_iterator +{ + ROCPRIM_HOST_DEVICE inline + advance_iterator(Difference distance) + : distance_(distance) + { + } + + template + ROCPRIM_HOST_DEVICE inline + void operator()(Iterator& it) + { + using it_distance_type = typename std::iterator_traits::difference_type; + it += static_cast(distance_); + } + +private: + Difference distance_; +}; + +template +ROCPRIM_HOST_DEVICE inline +ReferenceTuple dereference_iterator_tuple_impl(const ::rocprim::tuple& t, + ::rocprim::index_sequence) +{ + ReferenceTuple rt { *::rocprim::get(t)... 
}; + return rt; +} + +template +ROCPRIM_HOST_DEVICE inline +ReferenceTuple dereference_iterator_tuple(const ::rocprim::tuple& t) +{ + return dereference_iterator_tuple_impl( + t, ::rocprim::index_sequence_for() + ); +} + +} // end detail namespace + +/// \class zip_iterator +/// \brief TBD +/// +/// \par Overview +/// * TBD +/// +/// \tparam IteratorTuple - +template +class zip_iterator +{ +public: + /// \brief A reference type of the type iterated over. + /// + /// The type of the tuple made of the reference types of the iterator + /// types in the IteratorTuple argument. + using reference = typename detail::tuple_of_references::type; + /// The type of the value that can be obtained by dereferencing the iterator. + using value_type = typename detail::tuple_of_values::type; + /// \brief A pointer type of the type iterated over (\p value_type). + using pointer = value_type*; + /// \brief A type used for identify distance between iterators. + /// + /// The difference_type member of zip_iterator is the difference_type of + /// the first of the iterator types in the IteratorTuple argument. + using difference_type = typename std::iterator_traits< + typename ::rocprim::tuple_element<0, IteratorTuple>::type + >::difference_type; + /// The category of the iterator. + using iterator_category = std::random_access_iterator_tag; + + ROCPRIM_HOST_DEVICE inline + ~zip_iterator() = default; + + /// \brief Creates a new zip_iterator. + /// + /// \param iterator_tuple tuple of iterators + ROCPRIM_HOST_DEVICE inline + zip_iterator(IteratorTuple iterator_tuple) + : iterator_tuple_(iterator_tuple) + { + } + + //! 
\skip_doxy_start + ROCPRIM_HOST_DEVICE inline + zip_iterator& operator++() + { + detail::for_each_in_tuple(iterator_tuple_, detail::increment_iterator()); + return *this; + } + + ROCPRIM_HOST_DEVICE inline + zip_iterator operator++(int) + { + zip_iterator old = *this; + ++(*this); + return old; + } + + ROCPRIM_HOST_DEVICE inline + zip_iterator& operator--() + { + detail::for_each_in_tuple(iterator_tuple_, detail::decrement_iterator()); + return *this; + } + + ROCPRIM_HOST_DEVICE inline + zip_iterator operator--(int) + { + zip_iterator old = *this; + --(*this); + return old; + } + + ROCPRIM_HOST_DEVICE inline + reference operator*() const + { + return detail::dereference_iterator_tuple(iterator_tuple_); + } + + ROCPRIM_HOST_DEVICE inline + pointer operator->() const + { + return &(*(*this)); + } + + ROCPRIM_HOST_DEVICE inline + reference operator[](difference_type distance) const + { + zip_iterator i = (*this) + distance; + return *i; + } + + ROCPRIM_HOST_DEVICE inline + zip_iterator operator+(difference_type distance) const + { + zip_iterator copy = *this; + copy += distance; + return copy; + } + + ROCPRIM_HOST_DEVICE inline + zip_iterator& operator+=(difference_type distance) + { + detail::for_each_in_tuple( + iterator_tuple_, + detail::advance_iterator(distance) + ); + return *this; + } + + ROCPRIM_HOST_DEVICE inline + zip_iterator operator-(difference_type distance) const + { + auto copy = *this; + copy -= distance; + return copy; + } + + ROCPRIM_HOST_DEVICE inline + zip_iterator& operator-=(difference_type distance) + { + *this += -distance; + return *this; + } + + ROCPRIM_HOST_DEVICE inline + difference_type operator-(zip_iterator other) const + { + return ::rocprim::get<0>(iterator_tuple_) - ::rocprim::get<0>(other.iterator_tuple_); + } + + ROCPRIM_HOST_DEVICE inline + bool operator==(zip_iterator other) const + { + return iterator_tuple_ == other.iterator_tuple_; + } + + ROCPRIM_HOST_DEVICE inline + bool operator!=(zip_iterator other) const + { + return 
!(*this == other); + } + + ROCPRIM_HOST_DEVICE inline + bool operator<(zip_iterator other) const + { + return ::rocprim::get<0>(iterator_tuple_) < ::rocprim::get<0>(other.iterator_tuple_); + } + + ROCPRIM_HOST_DEVICE inline + bool operator<=(zip_iterator other) const + { + return !(other < *this); + } + + ROCPRIM_HOST_DEVICE inline + bool operator>(zip_iterator other) const + { + return other < *this; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>=(zip_iterator other) const + { + return !(*this < other); + } + + friend std::ostream& operator<<(std::ostream& os, const zip_iterator& /* iter */) + { + return os; + } + //! \skip_doxy_end + +private: + IteratorTuple iterator_tuple_; +}; + +template +ROCPRIM_HOST_DEVICE inline +zip_iterator +operator+(typename zip_iterator::difference_type distance, + const zip_iterator& iterator) +{ + return iterator + distance; +} + +/// make_zip_iterator creates a zip_iterator using \p iterator_tuple as +/// the underlying tuple of iterator. +/// +/// \tparam IteratorTuple - iterator tuple type +/// +/// \param iterator_tuple - tuple of iterators to use +/// \return A new zip_iterator object +template +ROCPRIM_HOST_DEVICE inline +zip_iterator +make_zip_iterator(IteratorTuple iterator_tuple) +{ + return zip_iterator(iterator_tuple); +} + +/// @} +// end of group iteratormodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_ITERATOR_ZIP_ITERATOR_HPP_ diff --git a/3rdparty/cub/rocprim/rocprim.hpp b/3rdparty/cub/rocprim/rocprim.hpp new file mode 100644 index 0000000000000000000000000000000000000000..193fc041bdf1c3175ebe5c1c9c34a0b8375c02e1 --- /dev/null +++ b/3rdparty/cub/rocprim/rocprim.hpp @@ -0,0 +1,82 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_HPP_ +#define ROCPRIM_HPP_ + +/// \file +/// +/// Meta-header to include rocPRIM API. 
+ +// Meta configuration for rocPRIM +#include "config.hpp" + +#include "rocprim_version.hpp" + +#include "intrinsics.hpp" +#include "functional.hpp" +#include "types.hpp" +#include "type_traits.hpp" +#include "iterator.hpp" + +#include "warp/warp_reduce.hpp" +#include "warp/warp_scan.hpp" +#include "warp/warp_sort.hpp" + +#include "block/block_discontinuity.hpp" +#include "block/block_exchange.hpp" +#include "block/block_histogram.hpp" +#include "block/block_load.hpp" +#include "block/block_radix_sort.hpp" +#include "block/block_scan.hpp" +#include "block/block_sort.hpp" +#include "block/block_store.hpp" + +#include "device/device_adjacent_difference.hpp" +#include "device/device_binary_search.hpp" +#include "device/device_histogram.hpp" +#include "device/device_merge.hpp" +#include "device/device_merge_sort.hpp" +#include "device/device_partition.hpp" +#include "device/device_radix_sort.hpp" +#include "device/device_reduce_by_key.hpp" +#include "device/device_reduce.hpp" +#include "device/device_run_length_encode.hpp" +#include "device/device_scan_by_key.hpp" +#include "device/device_scan.hpp" +#include "device/device_segmented_radix_sort.hpp" +#include "device/device_segmented_reduce.hpp" +#include "device/device_segmented_scan.hpp" +#include "device/device_select.hpp" +#include "device/device_transform.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Returns version of rocPRIM library. +/// \return version of rocPRIM library +ROCPRIM_HOST_DEVICE inline +unsigned int version() +{ + return ROCPRIM_VERSION; +} + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_HPP_ diff --git a/3rdparty/cub/rocprim/rocprim_version.hpp b/3rdparty/cub/rocprim/rocprim_version.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4cf17e967ff8b13709ec91f7a9327d8be15d6fb9 --- /dev/null +++ b/3rdparty/cub/rocprim/rocprim_version.hpp @@ -0,0 +1,41 @@ +// Copyright (c) 2017-2018 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_VERSION_HPP_ +#define ROCPRIM_VERSION_HPP_ + +/// \def ROCPRIM_VERSION +/// \brief ROCPRIM library version +/// +/// Version number may not be visible in the documentation. +/// +/// ROCPRIM_VERSION % 100 is the patch level, +/// ROCPRIM_VERSION / 100 % 1000 is the minor version, +/// ROCPRIM_VERSION / 100000 is the major version. +/// +/// For example, if ROCPRIM_VERSION is 100500, then the major version is 1, +/// the minor version is 5, and the patch level is 0. 
+#define ROCPRIM_VERSION 2 * 100000 + 11 * 100 + 1 + +#define ROCPRIM_VERSION_MAJOR 2 +#define ROCPRIM_VERSION_MINOR 11 +#define ROCPRIM_VERSION_PATCH 1 + +#endif // ROCPRIM_VERSION_HPP_ diff --git a/3rdparty/cub/rocprim/thread/thread_load.hpp b/3rdparty/cub/rocprim/thread/thread_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6992a5f2376ddc871fa7520218688979bd83890d --- /dev/null +++ b/3rdparty/cub/rocprim/thread/thread_load.hpp @@ -0,0 +1,145 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef ROCPRIM_THREAD_THREAD_LOAD_HPP_ +#define ROCPRIM_THREAD_THREAD_LOAD_HPP_ + +#include "../config.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +enum cache_load_modifier : int +{ + load_default, ///< Default (no modifier) + load_ca, ///< Cache at all levels + load_cg, ///< Cache at global level + load_cs, ///< Cache streaming (likely to be accessed once) + load_cv, ///< Cache as volatile (including cached system lines) + load_ldg, ///< Cache as texture + load_volatile, ///< Volatile (any memory space) +}; + +namespace detail +{ + +template +ROCPRIM_DEVICE __forceinline__ T AsmThreadLoad(void * ptr) +{ + T retval = 0; + __builtin_memcpy(&retval, ptr, sizeof(T)); + return retval; +} + +#if ROCPRIM_THREAD_LOAD_USE_CACHE_MODIFIERS == 1 + +// Important for syncing. 
Check section 9.2.2 or 7.3 in the following document +// http://developer.amd.com/wordpress/media/2013/12/AMD_GCN3_Instruction_Set_Architecture_rev1.1.pdf +#define ROCPRIM_ASM_THREAD_LOAD(cache_modifier, \ + llvm_cache_modifier, \ + type, \ + interim_type, \ + asm_operator, \ + output_modifier, \ + wait_cmd) \ + template<> \ + ROCPRIM_DEVICE __forceinline__ type AsmThreadLoad(void * ptr) \ + { \ + interim_type retval; \ + asm volatile(#asm_operator " %0, %1 " llvm_cache_modifier : "=" #output_modifier(retval) : "v"(ptr)); \ + asm volatile("s_waitcnt " wait_cmd "(%0)" : : "I"(0x00)); \ + return retval; \ + } + +// TODO Add specialization for custom larger data types +#define ROCPRIM_ASM_THREAD_LOAD_GROUP(cache_modifier, llvm_cache_modifier, wait_cmd) \ + ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_load_sbyte, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_load_sshort, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint8_t, uint16_t, flat_load_ubyte, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint16_t, uint16_t, flat_load_ushort, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint32_t, uint32_t, flat_load_dword, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_load_dword, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_load_dwordx2, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_load_dwordx2, v, wait_cmd); + +// [HIP-CPU] MSVC: erronous inline assembly specification (Triggers error C2059: syntax error: 'volatile') +#ifndef __HIP_CPU_RT__ +ROCPRIM_ASM_THREAD_LOAD_GROUP(load_ca, "glc", ""); +ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cg, "glc slc", ""); +ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cv, "glc", "vmcnt"); 
+ROCPRIM_ASM_THREAD_LOAD_GROUP(load_volatile, "glc", "vmcnt"); + +// TODO find correct modifiers to match these +ROCPRIM_ASM_THREAD_LOAD_GROUP(load_ldg, "", ""); +ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cs, "", ""); +#endif // __HIP_CPU_RT__ + +#endif + +} + +/// \brief Store data using the default load instruction. No support for cache modified stores yet +/// \tparam MODIFIER - Value in enum for determine which type of cache store modifier to be used +/// \tparam InputIteratorT - Type of Output Iterator +/// \param itr [in] - Iterator to location where data is to be stored +/// \return Data that is loaded from memory +template < + cache_load_modifier MODIFIER = load_default, + typename InputIteratorT> +ROCPRIM_DEVICE ROCPRIM_INLINE +typename std::iterator_traits::value_type +thread_load(InputIteratorT itr) +{ + using T = typename std::iterator_traits::value_type; + T retval = thread_load(&(*itr)); + return *itr; +} + +/// \brief Load data using the default load instruction. No support for cache modified loads yet +/// \tparam MODIFIER - Value in enum for determine which type of cache store modifier to be used +/// \tparam T - Type of Data to be loaded +/// \param ptr [in] - Pointer to data to be loaded +/// \return Data that is loaded from memory +template < + cache_load_modifier MODIFIER = load_default, + typename T> +ROCPRIM_DEVICE ROCPRIM_INLINE +T thread_load(T* ptr) +{ +#ifndef __HIP_CPU_RT__ + return detail::AsmThreadLoad(ptr); +#else + T retval; + std::memcpy(&retval, ptr, sizeof(T)); + return retval; +#endif +} + +END_ROCPRIM_NAMESPACE + +#endif diff --git a/3rdparty/cub/rocprim/thread/thread_operators.hpp b/3rdparty/cub/rocprim/thread/thread_operators.hpp new file mode 100644 index 0000000000000000000000000000000000000000..60f3cf0b91ddfa9ba2cd87f0e7aceffbcd9f16a6 --- /dev/null +++ b/3rdparty/cub/rocprim/thread/thread_operators.hpp @@ -0,0 +1,200 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, 
Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef ROCPRIM_THREAD_THREAD_OPERATORS_HPP_ +#define ROCPRIM_THREAD_THREAD_OPERATORS_HPP_ + +#include "../config.hpp" +#include "../types.hpp" + + +BEGIN_ROCPRIM_NAMESPACE + +struct equality +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a == b; + } +}; + +struct inequality +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a != b; + } +}; + +template +struct inequality_wrapper +{ + EqualityOp op; + + ROCPRIM_HOST_DEVICE inline + inequality_wrapper(EqualityOp op) : op(op) {} + + template + ROCPRIM_HOST_DEVICE inline + bool operator()(const T &a, const T &b) + { + return !op(a, b); + } +}; + +struct sum +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T &a, const T &b) const + { + return a + b; + } +}; + +struct max +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T &a, const T &b) const + { + return a < b ? b : a; + } +}; + +struct min +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr T operator()(const T &a, const T &b) const + { + return a < b ? a : b; + } +}; + +struct arg_max +{ + template< + class Key, + class Value + > + ROCPRIM_HOST_DEVICE inline + constexpr key_value_pair + operator()(const key_value_pair& a, + const key_value_pair& b) const + { + return ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a; + } +}; + +struct arg_min +{ + template< + class Key, + class Value + > + ROCPRIM_HOST_DEVICE inline + constexpr key_value_pair + operator()(const key_value_pair& a, + const key_value_pair& b) const + { + return ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) ? 
b : a; + } +}; + +namespace detail +{ + +// CUB uses value_type of OutputIteratorT (if not void) as a type of intermediate results in scan and reduce, +// for example: +// +// /// The output value type +// typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? +// typename std::iterator_traits::value_type, // ... then the input iterator's value type, +// typename std::iterator_traits::value_type>::Type OutputT; // ... else the output iterator's value type +// +// rocPRIM (as well as Thrust) uses result type of BinaryFunction instead (if not void): +// +// using input_type = typename std::iterator_traits::value_type; +// using result_type = typename ::rocprim::detail::match_result_type< +// input_type, BinaryFunction +// >::type; +// +// For short -> float using Sum() +// CUB: float Sum(float, float) +// rocPRIM: short Sum(short, short) +// +// This wrapper allows to have compatibility with CUB in hipCUB. +template< + class InputIteratorT, + class OutputIteratorT, + class BinaryFunction +> +struct convert_result_type_wrapper +{ + using input_type = typename std::iterator_traits::value_type; + using output_type = typename std::iterator_traits::value_type; + using result_type = + typename std::conditional< + std::is_void::value, input_type, output_type + >::type; + + convert_result_type_wrapper(BinaryFunction op) : op(op) {} + + template + ROCPRIM_HOST_DEVICE inline + constexpr result_type operator()(const T &a, const T &b) const + { + return static_cast(op(a, b)); + } + + BinaryFunction op; +}; + +template< + class InputIteratorT, + class OutputIteratorT, + class BinaryFunction +> +inline +convert_result_type_wrapper +convert_result_type(BinaryFunction op) +{ + return convert_result_type_wrapper(op); +} + +} // end detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_THREAD_THREAD_OPERATORS_HPP_ diff --git a/3rdparty/cub/rocprim/thread/thread_reduce.hpp b/3rdparty/cub/rocprim/thread/thread_reduce.hpp 
new file mode 100644 index 0000000000000000000000000000000000000000..d26f57bd96cc604f88115505b7c2043616af9527 --- /dev/null +++ b/3rdparty/cub/rocprim/thread/thread_reduce.hpp @@ -0,0 +1,110 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef ROCPRIM_THREAD_THREAD_REDUCE_HPP_ +#define ROCPRIM_THREAD_THREAD_REDUCE_HPP_ + + +#include "../config.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Carry out a reduction on an array of elements in one thread +/// \tparam LENGTH - Length of the array to be reduced +/// \tparam T - the input/output type +/// \tparam ReductionOp - Binary Operation that used to carry out the reduction +/// \tparam NoPrefix - Boolean, determining whether to have a initialization value for the reduction accumulator +/// \param input [in] - Pointer to the first element of the array to be reduced +/// \param reduction_op [in] - Instance of the reduction operator functor +/// \param prefix [in] - Value to be used as prefix, if NoPrefix is false +/// \return - Value obtained from reduction of input array +template < + int LENGTH, + typename T, + typename ReductionOp, + bool NoPrefix = false> +ROCPRIM_DEVICE ROCPRIM_INLINE T thread_reduce( + T* input, + ReductionOp reduction_op, + T prefix = T(0)) +{ + T retval; + if(NoPrefix) + retval = input[0]; + else + retval = prefix; + + ROCPRIM_UNROLL + for (int i = 0 + NoPrefix; i < LENGTH; ++i) + retval = reduction_op(retval, input[i]); + + return retval; +} + +/// \brief Carry out a reduction on an array of elements in one thread +/// \tparam LENGTH - Length of the array to be reduced +/// \tparam T - the input/output type +/// \tparam ReductionOp - Binary Operation that used to carry out the reduction +/// \param input [in] - Pointer to the first element of the array to be reduced +/// \param reduction_op [in] - Instance of the reduction operator functor +/// \param prefix [in] - Value to be used as prefix +/// \return - Value obtained from reduction of input array +template < + int LENGTH, + typename T, + typename ReductionOp> +ROCPRIM_DEVICE ROCPRIM_INLINE T thread_reduce( + T (&input)[LENGTH], + ReductionOp reduction_op, + T prefix) +{ + return 
thread_reduce((T*)input, reduction_op, prefix); +} + +/// \brief Carry out a reduction on an array of elements in one thread +/// \tparam LENGTH - Length of the array to be reduced +/// \tparam T - the input/output type +/// \tparam ReductionOp - Binary Operation that used to carry out the reduction +/// \param input [in] - Pointer to the first element of the array to be reduced +/// \param reduction_op [in] - Instance of the reduction operator functor +/// \return - Value obtained from reduction of input array +template < + int LENGTH, + typename T, + typename ReductionOp> +ROCPRIM_DEVICE ROCPRIM_INLINE T thread_reduce( + T (&input)[LENGTH], + ReductionOp reduction_op) +{ + return thread_reduce((T*)input, reduction_op); +} + +END_ROCPRIM_NAMESPACE + +#endif diff --git a/3rdparty/cub/rocprim/thread/thread_scan.hpp b/3rdparty/cub/rocprim/thread/thread_scan.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f602e6877a9bac9b3f86f8edf5e70dda34b18ae0 --- /dev/null +++ b/3rdparty/cub/rocprim/thread/thread_scan.hpp @@ -0,0 +1,289 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef ROCPRIM_THREAD_THREAD_SCAN_HPP_ +#define ROCPRIM_THREAD_THREAD_SCAN_HPP_ + + +#include "../config.hpp" +#include "../functional.hpp" + +BEGIN_ROCPRIM_NAMESPACE + + /** + * \addtogroup UtilModule + * @{ + */ + + /** + * \name Sequential prefix scan over statically-sized array types + * @{ + */ + + /// \brief Perform a sequential exclusive prefix scan over \p LENGTH elements of the \p input array. The aggregate is returned. + /// \tparam LENGTH - Length of \p input and \p output arrays + /// \tparam T - [inferred] The data type to be scanned. 
+ /// \tparam ScanOp - [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + /// \param inclusive [in] - Initial value for inclusive aggregate + /// \param exclusive [in] - Initial value for exclusive aggregate + /// \param input [in] - Input array + /// \param output [out] - Output array (may be aliased to \p input) + /// \param scan_op [in] - Binary scan operator + /// \return - Aggregate of the scan + template < + int LENGTH, + typename T, + typename ScanOp> + ROCPRIM_DEVICE ROCPRIM_INLINE + T thread_scan_exclusive( + T inclusive, + T exclusive, + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + Int2Type /*length*/) + { + ROCPRIM_UNROLL + for (int i = 0; i < LENGTH; ++i) + { + inclusive = scan_op(exclusive, input[i]); + output[i] = exclusive; + exclusive = inclusive; + } + + return inclusive; + } + + + + /// \brief Perform a sequential exclusive prefix scan over \p LENGTH elements of the \p input array. The aggregate is returned. + /// \tparam LENGTH - Length of \p input and \p output arrays + /// \tparam T - [inferred] The data type to be scanned. + /// \tparam ScanOp - [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + /// \param input [in] - Input array + /// \param output [out] - Output array (may be aliased to \p input) + /// \param scan_op [in] - Binary scan operator + /// \param prefix [in] - Prefix to seed scan with + /// \param apply_prefix [in] - Whether or not the calling thread should apply its prefix. (Handy for preventing thread-0 from applying a prefix.) 
+ /// \return - Aggregate of the scan + template < + int LENGTH, + typename T, + typename ScanOp> + ROCPRIM_DEVICE ROCPRIM_INLINE + T thread_scan_exclusive( + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T prefix, ///< [in] Prefix to seed scan with + bool apply_prefix = true) ///< [in] Whether or not the calling thread should apply its prefix. If not, the first output element is undefined. (Handy for preventing thread-0 from applying a prefix.) + { + T inclusive = input[0]; + if (apply_prefix) + { + inclusive = scan_op(prefix, inclusive); + } + output[0] = prefix; + T exclusive = inclusive; + + return thread_scan_exclusive(inclusive, exclusive, input + 1, output + 1, scan_op, Int2Type()); + } + + /// \brief Perform a sequential exclusive prefix scan over \p LENGTH elements of the \p input array. The aggregate is returned. + /// \tparam LENGTH - Length of \p input and \p output arrays + /// \tparam T - [inferred] The data type to be scanned. + /// \tparam ScanOp - [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + /// \param input [in] - Input array + /// \param output [out] - Output array (may be aliased to \p input) + /// \param scan_op [in] - Binary scan operator + /// \param prefix [in] - Prefix to seed scan with + /// \param apply_prefix [in] - Whether or not the calling thread should apply its prefix. (Handy for preventing thread-0 from applying a prefix.) + /// \return - Aggregate of the scan + template < + int LENGTH, + typename T, + typename ScanOp> + ROCPRIM_DEVICE ROCPRIM_INLINE + T thread_scan_exclusive( + T (&input)[LENGTH], ///< [in] Input array + T (&output)[LENGTH], ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T prefix, ///< [in] Prefix to seed scan with + bool apply_prefix = true) ///< [in] Whether or not the calling thread should apply its prefix. 
(Handy for preventing thread-0 from applying a prefix.) + { + return thread_scan_exclusive((T*) input, (T*) output, scan_op, prefix, apply_prefix); + } + + /// \brief Perform a sequential exclusive prefix scan over \p LENGTH elements of the \p input array. The aggregate is returned. + /// \tparam LENGTH - Length of \p input and \p output arrays + /// \tparam T - [inferred] The data type to be scanned. + /// \tparam ScanOp - [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + /// \param inclusive [in] - Initial value for inclusive aggregate + /// \param input [in] - Input array + /// \param output [out] - Output array (may be aliased to \p input) + /// \param scan_op [in] - Binary scan operator + /// \return - Aggregate of the scan + template < + int LENGTH, + typename T, + typename ScanOp> + ROCPRIM_DEVICE ROCPRIM_INLINE + T thread_scan_inclusive( + T inclusive, + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + Int2Type /*length*/) + { + ROCPRIM_UNROLL + for (int i = 0; i < LENGTH; ++i) + { + inclusive = scan_op(inclusive, input[i]); + output[i] = inclusive; + } + + return inclusive; + } + + +/// \brief Perform a sequential inclusive prefix scan over \p LENGTH elements of the \p input array. The aggregate is returned. +/// \tparam LENGTH - LengthT of \p input and \p output arrays +/// \tparam T - [inferred] The data type to be scanned. 
+/// \tparam ScanOp - [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) +/// \param input [in] - Input array +/// \param output [out] - Output array (may be aliased to \p input) +/// \param scan_op [in] - Binary scan operator +/// \return - Aggregate of the scan + template < + int LENGTH, + typename T, + typename ScanOp> + ROCPRIM_DEVICE ROCPRIM_INLINE + T thread_scan_inclusive( + T *input, + T *output, + ScanOp scan_op) + { + T inclusive = input[0]; + output[0] = inclusive; + + // Continue scan + return thread_scan_inclusive(inclusive, input + 1, output + 1, scan_op, Int2Type()); + } + + + /// \brief Perform a sequential inclusive prefix scan over \p LENGTH elements of the \p input array. The aggregate is returned. + /// \tparam LENGTH - LengthT of \p input and \p output arrays + /// \tparam T - [inferred] The data type to be scanned. + /// \tparam ScanOp - [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + /// \param input [in] - Input array + /// \param output [out] - Output array (may be aliased to \p input) + /// \param scan_op [in] - Binary scan operator + /// \return - Aggregate of the scan + template < + int LENGTH, + typename T, + typename ScanOp> + ROCPRIM_DEVICE ROCPRIM_INLINE + T thread_scan_inclusive( + T (&input)[LENGTH], ///< [in] Input array + T (&output)[LENGTH], ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan operator + { + return thread_scan_inclusive((T*) input, (T*) output, scan_op); + } + + + /// \brief Perform a sequential inclusive prefix scan over \p LENGTH elements of the \p input array. The aggregate is returned. + /// \tparam LENGTH - LengthT of \p input and \p output arrays + /// \tparam T - [inferred] The data type to be scanned. 
+ /// \tparam ScanOp - [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + /// \param input [in] - Input array + /// \param output [out] - Output array (may be aliased to \p input) + /// \param scan_op [in] - Binary scan operator + /// \param prefix [in] - Prefix to seed scan with + /// \param apply_prefix [in] - Whether or not the calling thread should apply its prefix. (Handy for preventing thread-0 from applying a prefix.) + /// \return - Aggregate of the scan + template < + int LENGTH, + typename T, + typename ScanOp> + ROCPRIM_DEVICE ROCPRIM_INLINE + T thread_scan_inclusive( + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T prefix, ///< [in] Prefix to seed scan with + bool apply_prefix = true) ///< [in] Whether or not the calling thread should apply its prefix. (Handy for preventing thread-0 from applying a prefix.) + { + T inclusive = input[0]; + if (apply_prefix) + { + inclusive = scan_op(prefix, inclusive); + } + output[0] = inclusive; + + // Continue scan + return thread_scan_inclusive(inclusive, input + 1, output + 1, scan_op, Int2Type()); + } + + + /// \brief Perform a sequential inclusive prefix scan over \p LENGTH elements of the \p input array. The aggregate is returned. + /// \tparam LENGTH - LengthT of \p input and \p output arrays + /// \tparam T - [inferred] The data type to be scanned. + /// \tparam ScanOp - [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + /// \param input [in] - Input array + /// \param output [out] - Output array (may be aliased to \p input) + /// \param scan_op [in] - Binary scan operator + /// \param prefix [in] - Prefix to seed scan with + /// \param apply_prefix [in] - Whether or not the calling thread should apply its prefix. (Handy for preventing thread-0 from applying a prefix.) 
+ /// \return - Aggregate of the scan + template < + int LENGTH, + typename T, + typename ScanOp> + ROCPRIM_DEVICE ROCPRIM_INLINE + T thread_scan_inclusive( + T (&input)[LENGTH], + T (&output)[LENGTH], + ScanOp scan_op, + T prefix, + bool apply_prefix = true) + { + return thread_scan_inclusive((T*) input, (T*) output, scan_op, prefix, apply_prefix); + } + + + //@} end member group + + /** @} */ // end group UtilModule + + END_ROCPRIM_NAMESPACE + + #endif // ROCPRIM_THREAD_THREAD_SCAN_HPP_ diff --git a/3rdparty/cub/rocprim/thread/thread_search.hpp b/3rdparty/cub/rocprim/thread/thread_search.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bc18e20e211e63ec40f36c9d4fcda78b26a5f82c --- /dev/null +++ b/3rdparty/cub/rocprim/thread/thread_search.hpp @@ -0,0 +1,155 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ ******************************************************************************/
+
+ #ifndef ROCPRIM_THREAD_THREAD_SEARCH_HPP_
+ #define ROCPRIM_THREAD_THREAD_SEARCH_HPP_
+
+ #include <iterator>
+ #include "../config.hpp"
+
+ BEGIN_ROCPRIM_NAMESPACE
+
+/**
+ * Computes the begin offsets into A and B for the specific diagonal
+ */
+template <
+    typename AIteratorT,
+    typename BIteratorT,
+    typename OffsetT,
+    typename CoordinateT>
+ROCPRIM_HOST_DEVICE inline void merge_path_search(
+    OffsetT diagonal,
+    AIteratorT a,
+    BIteratorT b,
+    OffsetT a_len,
+    OffsetT b_len,
+    CoordinateT& path_coordinate)
+{
+    /// The value type of the input iterator
+    typedef typename std::iterator_traits<AIteratorT>::value_type T;
+
+    OffsetT split_min = ::rocprim::max(diagonal - b_len, 0);
+    OffsetT split_max = ::rocprim::min(diagonal, a_len);
+
+    while (split_min < split_max)
+    {
+        OffsetT split_pivot = (split_min + split_max) >> 1;
+        if (a[split_pivot] <= b[diagonal - split_pivot - 1])
+        {
+            // Move candidate split range up A, down B
+            split_min = split_pivot + 1;
+        }
+        else
+        {
+            // Move candidate split range up B, down A
+            split_max = split_pivot;
+        }
+    }
+
+    path_coordinate.x = ::rocprim::min(split_min, a_len);
+    path_coordinate.y =
diagonal - split_min; +} + + + + +/// \brief Returns the offset of the first value within \p input which does not compare less than \p val +/// \tparam InputIteratorT - [inferred] Type of iterator for the input data to be searched +/// \tparam OffsetT - [inferred] The data type of num_items +/// \tparam T - [inferred] The data type of the input sequence elements +/// \param input [in] - Input sequence +/// \param num_items [out] - Input sequence length +/// \param val [in] - Search Key +/// \return - Offset at which val was found +template < + typename InputIteratorT, + typename OffsetT, + typename T> +ROCPRIM_DEVICE ROCPRIM_INLINE OffsetT lower_bound( + InputIteratorT input, + OffsetT num_items, + T val) +{ + OffsetT retval = 0; + while (num_items > 0) + { + OffsetT half = num_items >> 1; + if (input[retval + half] < val) + { + retval = retval + (half + 1); + num_items = num_items - (half + 1); + } + else + { + num_items = half; + } + } + + return retval; +} + + +/// \brief Returns the offset of the first value within \p input which compares greater than \p val +/// \tparam InputIteratorT - [inferred] Type of iterator for the input data to be searched +/// \tparam OffsetT - [inferred] The data type of num_items +/// \tparam T - [inferred] The data type of the input sequence elements +/// \param input [in] - Input sequence +/// \param num_items [out] - Input sequence length +/// \param val [in] - Search Key +/// \return - Offset at which val was found +template < + typename InputIteratorT, + typename OffsetT, + typename T> +ROCPRIM_DEVICE ROCPRIM_INLINE OffsetT upper_bound( + InputIteratorT input, ///< [in] Input sequence + OffsetT num_items, ///< [in] Input sequence length + T val) ///< [in] Search key +{ + OffsetT retval = 0; + while (num_items > 0) + { + OffsetT half = num_items >> 1; + if (val < input[retval + half]) + { + num_items = half; + } + else + { + retval = retval + (half + 1); + num_items = num_items - (half + 1); + } + } + + return retval; +} + 
+END_ROCPRIM_NAMESPACE
+
+#endif // ROCPRIM_THREAD_THREAD_SEARCH_HPP_
diff --git a/3rdparty/cub/rocprim/thread/thread_store.hpp b/3rdparty/cub/rocprim/thread/thread_store.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..7b53ab9a13d5cda80f49986d50c4100c9dbb1ee8
--- /dev/null
+++ b/3rdparty/cub/rocprim/thread/thread_store.hpp
@@ -0,0 +1,146 @@
+/******************************************************************************
+ * Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
+ * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
+ * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright
+ *       notice, this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *       notice, this list of conditions and the following disclaimer in the
+ *       documentation and/or other materials provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the
+ *       names of its contributors may be used to endorse or promote products
+ *       derived from this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef ROCPRIM_THREAD_THREAD_STORE_HPP_ +#define ROCPRIM_THREAD_THREAD_STORE_HPP_ + + +#include "../config.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +enum cache_store_modifier +{ + store_default, ///< Default (no modifier) + store_wb, ///< Cache write-back all coherent levels + store_cg, ///< Cache at global level + store_cs, ///< Cache streaming (likely to be accessed once) + store_wt, ///< Cache write-through (to system memory) + store_volatile, ///< Volatile shared (any memory space) +}; + +namespace detail +{ + +template +ROCPRIM_DEVICE __forceinline__ void AsmThreadStore(void * ptr, T val) +{ + __builtin_memcpy(ptr, &val, sizeof(T)); +} + +#if ROCPRIM_THREAD_STORE_USE_CACHE_MODIFIERS == 1 + +// NOTE: the reason there is an interim_type is because of a bug for 8bit types. +// TODO fix flat_store_ubyte and flat_store_sbyte issues + +// Important for syncing. 
Check section 9.2.2 or 7.3 in the following document +// http://developer.amd.com/wordpress/media/2013/12/AMD_GCN3_Instruction_Set_Architecture_rev1.1.pdf +#define ROCPRIM_ASM_THREAD_STORE(cache_modifier, \ + llvm_cache_modifier, \ + type, \ + interim_type, \ + asm_operator, \ + output_modifier, \ + wait_cmd) \ + template<> \ + ROCPRIM_DEVICE __forceinline__ void AsmThreadStore(void * ptr, type val) \ + { \ + interim_type temp_val = val; \ + asm volatile(#asm_operator " %0, %1 " llvm_cache_modifier : : "v"(ptr), #output_modifier(temp_val)); \ + asm volatile("s_waitcnt " wait_cmd "(%0)" : : "I"(0x00)); \ + } + +// TODO fix flat_store_ubyte and flat_store_sbyte issues +// TODO Add specialization for custom larger data types +#define ROCPRIM_ASM_THREAD_STORE_GROUP(cache_modifier, llvm_cache_modifier, wait_cmd) \ + ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_store_byte, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_store_short, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint8_t, uint16_t, flat_store_byte, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint16_t, uint16_t, flat_store_short, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint32_t, uint32_t, flat_store_dword, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_store_dword, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_store_dwordx2, v, wait_cmd); \ + ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_store_dwordx2, v, wait_cmd); + +// [HIP-CPU] MSVC: erronous inline assembly specification (Triggers error C2059: syntax error: 'volatile') +#ifndef __HIP_CPU_RT__ +ROCPRIM_ASM_THREAD_STORE_GROUP(store_wb, "glc", ""); +ROCPRIM_ASM_THREAD_STORE_GROUP(store_cg, "glc slc", ""); 
+ROCPRIM_ASM_THREAD_STORE_GROUP(store_wt, "glc", "vmcnt"); +ROCPRIM_ASM_THREAD_STORE_GROUP(store_volatile, "glc", "vmcnt"); + +// TODO find correct modifiers to match these +ROCPRIM_ASM_THREAD_STORE_GROUP(store_cs, "", ""); +#endif // __HIP_CPU_RT__ + +#endif + +} + +/// \brief Store data using the default load instruction. No support for cache modified stores yet +/// \tparam MODIFIER - Value in enum for determine which type of cache store modifier to be used +/// \tparam OutputIteratorT - Type of Output Iterator +/// \tparam T - Type of Data to be stored +/// \param itr [in] - Iterator to location where data is to be stored +/// \param val [in] - Data to be stored +template < + cache_store_modifier MODIFIER = store_default, + typename OutputIteratorT, + typename T +> +ROCPRIM_DEVICE ROCPRIM_INLINE void thread_store( + OutputIteratorT itr, + T val) +{ + thread_store(&(*itr), val); +} + +/// \brief Store data using the default load instruction. No support for cache modified stores yet +/// \tparam MODIFIER - Value in enum for determine which type of cache store modifier to be used +/// \tparam T - Type of Data to be stored +/// \param ptr [in] - Pointer to location where data is to be stored +/// \param val [in] - Data to be stored +template < + cache_store_modifier MODIFIER = store_default, + typename T +> +ROCPRIM_DEVICE ROCPRIM_INLINE void thread_store( + T *ptr, + T val) +{ +#ifndef __HIP_CPU_RT__ + detail::AsmThreadStore(ptr, val); +#else + std::memcpy(ptr, &val, sizeof(T)); +#endif +} + +END_ROCPRIM_NAMESPACE + +#endif diff --git a/3rdparty/cub/rocprim/type_traits.hpp b/3rdparty/cub/rocprim/type_traits.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7b6afb6d11216e3ac9d705e18bd140bd7391395e --- /dev/null +++ b/3rdparty/cub/rocprim/type_traits.hpp @@ -0,0 +1,200 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+#ifndef ROCPRIM_TYPE_TRAITS_HPP_
+#define ROCPRIM_TYPE_TRAITS_HPP_
+
+#include <type_traits>
+
+// Meta configuration for rocPRIM
+#include "config.hpp"
+#include "types.hpp"
+
+/// \addtogroup utilsmodule_typetraits
+/// @{
+
+BEGIN_ROCPRIM_NAMESPACE
+
+/// \brief Behaves like std::is_floating_point, but also includes half-precision and bfloat16-precision
+/// floating point type (rocprim::half).
+template<class T>
+struct is_floating_point
+    : std::integral_constant<
+        bool,
+        std::is_floating_point<T>::value ||
+        std::is_same<::rocprim::half, typename std::remove_cv<T>::type>::value ||
+        std::is_same<::rocprim::bfloat16, typename std::remove_cv<T>::type>::value
+    > {};
+
+/// \brief Alias for std::is_integral.
+template<class T>
+using is_integral = std::is_integral<T>;
+
+/// \brief Behaves like std::is_arithmetic, but also includes half-precision and bfloat16-precision
+/// floating point type (\ref rocprim::half).
+template +struct is_arithmetic + : std::integral_constant< + bool, + std::is_arithmetic::value || + std::is_same<::rocprim::half, typename std::remove_cv::type>::value || + std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value + > {}; + +/// \brief Behaves like std::is_fundamental, but also includes half-precision and bfloat16-precision +/// floating point type (\ref rocprim::half). +template +struct is_fundamental + : std::integral_constant< + bool, + std::is_fundamental::value || + std::is_same<::rocprim::half, typename std::remove_cv::type>::value || + std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value +> {}; + +/// \brief Alias for std::is_unsigned. +template +using is_unsigned = std::is_unsigned; + +/// \brief Behaves like std::is_signed, but also includes half-precision and bfloat16-precision +/// floating point type (\ref rocprim::half). +template +struct is_signed + : std::integral_constant< + bool, + std::is_signed::value || + std::is_same<::rocprim::half, typename std::remove_cv::type>::value || + std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value + > {}; + +/// \brief Behaves like std::is_scalar, but also includes half-precision and bfloat16-precision +/// floating point type (\ref rocprim::half). +template +struct is_scalar + : std::integral_constant< + bool, + std::is_scalar::value || + std::is_same<::rocprim::half, typename std::remove_cv::type>::value || + std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value + > {}; + +/// \brief Behaves like std::is_compound, but also supports half-precision +/// floating point type (\ref rocprim::half). `value` for \ref rocprim::half is `false`. 
+template +struct is_compound + : std::integral_constant< + bool, + !is_fundamental::value + > {}; + +template +struct get_unsigned_bits_type +{ + typedef typename get_unsigned_bits_type::unsigned_type unsigned_type; +}; + +template +struct get_unsigned_bits_type +{ + typedef uint8_t unsigned_type; +}; + + +template +struct get_unsigned_bits_type +{ + typedef uint16_t unsigned_type; +}; + + +template +struct get_unsigned_bits_type +{ + typedef uint32_t unsigned_type; +}; + + +template +struct get_unsigned_bits_type +{ + typedef uint64_t unsigned_type; +}; + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto TwiddleIn(UnsignedBits key) + -> typename std::enable_if::value, UnsignedBits>::type +{ + static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); + UnsignedBits mask = (key & HIGH_BIT) ? UnsignedBits(-1) : HIGH_BIT; + return key ^ mask; +} + +template +static ROCPRIM_DEVICE ROCPRIM_INLINE +auto TwiddleIn(UnsignedBits key) + -> typename std::enable_if::value, UnsignedBits>::type +{ + return key ; +}; + +template +static ROCPRIM_DEVICE ROCPRIM_INLINE +auto TwiddleIn(UnsignedBits key) + -> typename std::enable_if::value && is_signed::value, UnsignedBits>::type +{ + static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); + return key ^ HIGH_BIT; +}; + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto TwiddleOut(UnsignedBits key) + -> typename std::enable_if::value, UnsignedBits>::type +{ + static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); + UnsignedBits mask = (key & HIGH_BIT) ? 
HIGH_BIT : UnsignedBits(-1); + return key ^ mask; +} + +template +static ROCPRIM_DEVICE ROCPRIM_INLINE +auto TwiddleOut(UnsignedBits key) + -> typename std::enable_if::value, UnsignedBits>::type +{ + return key; +}; + +template +static ROCPRIM_DEVICE ROCPRIM_INLINE +auto TwiddleOut(UnsignedBits key) + -> typename std::enable_if::value && is_signed::value, UnsignedBits>::type +{ + static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); + return key ^ HIGH_BIT; +}; + + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group utilsmodule_typetraits + +#endif // ROCPRIM_TYPE_TRAITS_HPP_ diff --git a/3rdparty/cub/rocprim/types.hpp b/3rdparty/cub/rocprim/types.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0733c9b7c79b8c36f7c87986a655f6bf6c22cebd --- /dev/null +++ b/3rdparty/cub/rocprim/types.hpp @@ -0,0 +1,179 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_TYPES_HPP_ +#define ROCPRIM_TYPES_HPP_ + +#include + +// Meta configuration for rocPRIM +#include "config.hpp" + +#include "types/future_value.hpp" +#include "types/double_buffer.hpp" +#include "types/integer_sequence.hpp" +#include "types/key_value_pair.hpp" +#include "types/tuple.hpp" + +/// \addtogroup utilsmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ +// Define vector types that will be used by rocPRIM internally. +// We don't use HIP vector types because they don't generate correct +// load/store operations, see https://github.com/RadeonOpenCompute/ROCm/issues/341 +#ifndef _MSC_VER +#define DEFINE_VECTOR_TYPE(name, base) \ +\ +struct alignas(sizeof(base) * 2) name##2 \ +{ \ + typedef base vector_value_type __attribute__((ext_vector_type(2))); \ + union { \ + vector_value_type data; \ + struct { base x, y; }; \ + }; \ +}; \ +\ +struct alignas(sizeof(base) * 4) name##4 \ +{ \ + typedef base vector_value_type __attribute__((ext_vector_type(4))); \ + union { \ + vector_value_type data; \ + struct { base x, y, w, z; }; \ + }; \ +}; +#else +#define DEFINE_VECTOR_TYPE(name, base) \ +\ +struct alignas(sizeof(base) * 2) name##2 \ +{ \ + typedef base vector_value_type; \ + union { \ + vector_value_type data; \ + struct { base x, y; }; \ + }; \ +}; \ +\ +struct alignas(sizeof(base) * 4) name##4 \ +{ \ + typedef base vector_value_type; \ + union { \ + vector_value_type data; \ + struct { base x, y, w, z; }; \ + }; \ +}; +#endif + +#ifdef _MSC_VER +#pragma warning( push ) +#pragma warning( disable : 4201 ) // nonstandard extension used: nameless struct/union +#endif +DEFINE_VECTOR_TYPE(char, char); +DEFINE_VECTOR_TYPE(short, short); +DEFINE_VECTOR_TYPE(int, int); 
+DEFINE_VECTOR_TYPE(longlong, long long); +#ifdef _MSC_VER +#pragma warning( pop ) +#endif +// Takes a scalar type T and matches to a vector type based on NumElements. +template +struct make_vector_type +{ + using type = void; +}; + +#define DEFINE_MAKE_VECTOR_N_TYPE(name, base, suffix) \ +template<> \ +struct make_vector_type \ +{ \ + using type = name##suffix; \ +}; + +#define DEFINE_MAKE_VECTOR_TYPE(name, base) \ +\ +template <> \ +struct make_vector_type \ +{ \ + using type = base; \ +}; \ +DEFINE_MAKE_VECTOR_N_TYPE(name, base, 2) \ +DEFINE_MAKE_VECTOR_N_TYPE(name, base, 4) + +DEFINE_MAKE_VECTOR_TYPE(char, char); +DEFINE_MAKE_VECTOR_TYPE(short, short); +DEFINE_MAKE_VECTOR_TYPE(int, int); +DEFINE_MAKE_VECTOR_TYPE(longlong, long long); + +#undef DEFINE_VECTOR_TYPE +#undef DEFINE_MAKE_VECTOR_TYPE +#undef DEFINE_MAKE_VECTOR_N_TYPE + +} // end namespace detail + +/// \brief Empty type used as a placeholder, usually used to flag that given +/// template parameter should not be used. +struct empty_type {}; + +/// \brief Binary operator that takes two instances of empty_type, usually used +/// as nop replacement for the HIP-CPU back-end +struct empty_binary_op +{ + constexpr empty_type operator()(const empty_type&, const empty_type&) const { return empty_type{}; } +}; + +/// \brief Half-precision floating point type +using half = ::__half; +/// \brief bfloat16 floating point type +using bfloat16 = ::cuda_bfloat16; + +// The lane_mask_type only exist at device side +#ifndef __AMDGCN_WAVEFRONT_SIZE +// When not compiling with hipcc, we're compiling with HIP-CPU +// TODO: introduce a ROCPRIM-specific macro to query this +#define __AMDGCN_WAVEFRONT_SIZE 64 +#endif +#if __AMDGCN_WAVEFRONT_SIZE == 32 +using lane_mask_type = unsigned int; +#elif __AMDGCN_WAVEFRONT_SIZE == 64 +using lane_mask_type = unsigned long long int; +#endif + +#ifdef __HIP_CPU_RT__ +using native_half = half; +#else +using native_half = _Float16; +#endif + +#ifdef __HIP_CPU_RT__ +// TODO: Find a better 
type +using native_bfloat16 = bfloat16; +#else +using native_bfloat16 = bfloat16; +#endif + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group utilsmodule + +#endif // ROCPRIM_TYPES_HPP_ diff --git a/3rdparty/cub/rocprim/types/double_buffer.hpp b/3rdparty/cub/rocprim/types/double_buffer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2185dd8c1242199cfb844ed5c654b6d42f96b6cc --- /dev/null +++ b/3rdparty/cub/rocprim/types/double_buffer.hpp @@ -0,0 +1,80 @@ +// Copyright (c) 2017-2019 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_TYPES_DOUBLE_BUFFER_HPP_ +#define ROCPRIM_TYPES_DOUBLE_BUFFER_HPP_ + +#include "../config.hpp" + +/// \addtogroup utilsmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +template +class double_buffer +{ + T * buffers[2]; + + unsigned int selector; + +public: + + ROCPRIM_HOST_DEVICE inline + double_buffer() + { + selector = 0; + buffers[0] = nullptr; + buffers[1] = nullptr; + } + + ROCPRIM_HOST_DEVICE inline + double_buffer(T * current, T * alternate) + { + selector = 0; + buffers[0] = current; + buffers[1] = alternate; + } + + ROCPRIM_HOST_DEVICE inline + T * current() const + { + return buffers[selector]; + } + + ROCPRIM_HOST_DEVICE inline + T * alternate() const + { + return buffers[selector ^ 1]; + } + + ROCPRIM_HOST_DEVICE inline + void swap() + { + selector ^= 1; + } +}; + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group utilsmodule + +#endif // ROCPRIM_TYPES_DOUBLE_BUFFER_HPP_ diff --git a/3rdparty/cub/rocprim/types/future_value.hpp b/3rdparty/cub/rocprim/types/future_value.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ddc93dd1d8ac531794a63c3842c051cae2b96225 --- /dev/null +++ b/3rdparty/cub/rocprim/types/future_value.hpp @@ -0,0 +1,118 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_TYPES_FUTURE_VALUE_HPP_ +#define ROCPRIM_TYPES_FUTURE_VALUE_HPP_ + +#include "../config.hpp" + +/// \addtogroup utilsmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/** + * \brief Allows passing values that are not yet known at launch time as paramters to device algorithms. + * + * \note It is the users responsibility to ensure that value is available when the algorithm executes. + * This can be guaranteed with stream dependencies or explicit external synchronization. + * + * \code + * int* intermediate_result = nullptr; + * cudaMalloc(reinterpret_cast(&intermediate_result), sizeof(intermediate_result)); + * hipLaunchKernelGGL(compute_intermediate, blocks, threads, 0, stream, arg1, arg2, itermediate_result); + * const auto initial_value = rocprim::future_value{intermediate_result}; + * rocprim::exclusive_scan(temporary_storage, + * storage_size, + * input, + * output, + * initial_value, + * size); + * hipFree(intermediate_result) + * \endcode + * + * \tparam T + * \tparam Iter + */ +template +class future_value +{ +public: + using value_type = T; + using iterator_type = Iter; + + explicit ROCPRIM_HOST_DEVICE future_value(const Iter iter) + : iter_ {iter} + { + } + + ROCPRIM_HOST_DEVICE operator T() + { + return *iter_; + } + + ROCPRIM_HOST_DEVICE operator T() const + { + return *iter_; + } +private: + Iter iter_; +}; + +namespace detail +{ + /// \brief Used for unpacking a future_value, basically just a cast but its more explicit + /// this way. 
+ template + ROCPRIM_HOST_DEVICE T get_input_value(const T value) + { + return value; + } + + template + ROCPRIM_HOST_DEVICE T get_input_value(::rocprim::future_value future) { + return future; + } + + template + struct input_value_traits { + using value_type = T; + }; + + template + struct input_value_traits<::rocprim::future_value> + { + using value_type = T; + using iterator_type = Iter; + }; + + template + using input_type_t = typename input_value_traits::value_type; + + template + using input_iterator_t = typename input_value_traits::iterator_type; +} + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group utilsmodule + +#endif diff --git a/3rdparty/cub/rocprim/types/integer_sequence.hpp b/3rdparty/cub/rocprim/types/integer_sequence.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ece3d689f38b0a97411858734b8d430a8ab56182 --- /dev/null +++ b/3rdparty/cub/rocprim/types/integer_sequence.hpp @@ -0,0 +1,94 @@ +// Copyright (c) 2018-2020 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_TYPES_INTEGER_SEQUENCE_HPP_ +#define ROCPRIM_TYPES_INTEGER_SEQUENCE_HPP_ + +#include + +#include "../config.hpp" + +BEGIN_ROCPRIM_NAMESPACE +#if defined(__cpp_lib_integer_sequence) && !defined(DOXYGEN_SHOULD_SKIP_THIS) +// For C++14 or newer we just use standard implementation +using std::integer_sequence; +using std::index_sequence; +using std::make_integer_sequence; +using std::make_index_sequence; +using std::index_sequence_for; +#else +/// \brief Compile-time sequence of integers +/// +/// Implements std::integer_sequence for C++11. When C++14 is supported +/// it is just an alias for std::integer_sequence. +template +class integer_sequence +{ + using value_type = T; + + static inline constexpr size_t size() noexcept + { + return sizeof...(Ints); + } +}; + +template +using index_sequence = integer_sequence; + +// DETAILS +namespace detail +{ + +template +struct integer_sequence_cat; + +template +struct integer_sequence_cat> +{ + using type = typename ::rocprim::integer_sequence; +}; + +template +struct make_integer_sequence_impl : + integer_sequence_cat::type> +{ +}; + +template +struct make_integer_sequence_impl +{ + using type = ::rocprim::integer_sequence; +}; + +} // end detail namespace + +template +using make_integer_sequence = typename detail::make_integer_sequence_impl::type; + +template +using make_index_sequence = make_integer_sequence; + +template +using index_sequence_for = make_index_sequence; +#endif + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_TYPES_INTEGER_SEQUENCE_HPP_ diff --git a/3rdparty/cub/rocprim/types/key_value_pair.hpp b/3rdparty/cub/rocprim/types/key_value_pair.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..469a0c7d69c6fcb55fd5597b6b201484cddd8071 --- /dev/null +++ b/3rdparty/cub/rocprim/types/key_value_pair.hpp @@ -0,0 +1,81 @@ +// Copyright (c) 2017-2019 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_TYPES_KEY_VALUE_PAIR_HPP_ +#define ROCPRIM_TYPES_KEY_VALUE_PAIR_HPP_ + +#include "../config.hpp" + +/// \addtogroup utilsmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +template< + class Key_, + class Value_ +> +struct key_value_pair +{ + #ifndef DOXYGEN_SHOULD_SKIP_THIS + using Key = Key_; + using Value = Value_; + #endif + + using key_type = Key_; + using value_type = Value_; + + key_type key; + value_type value; + + ROCPRIM_HOST_DEVICE inline + key_value_pair() = default; + + ROCPRIM_HOST_DEVICE inline + ~key_value_pair() = default; + + ROCPRIM_HOST_DEVICE inline + key_value_pair(const key_type key, const value_type value) : key(key), value(value) + { + } + + #if __hcc_major__ < 1 || __hcc_major__ == 1 && __hcc_minor__ < 2 + ROCPRIM_HOST_DEVICE inline + key_value_pair& operator =(const key_value_pair& kvb) + { + key = kvb.key; + value = kvb.value; + return *this; + } + #endif + + ROCPRIM_HOST_DEVICE inline + bool operator !=(const key_value_pair& kvb) + { + return (key != kvb.key) || (value != kvb.value); + } +}; + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group utilsmodule + +#endif // ROCPRIM_TYPES_KEY_VALUE_PAIR_HPP_ diff --git a/3rdparty/cub/rocprim/types/tuple.hpp b/3rdparty/cub/rocprim/types/tuple.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e2b951cbbec58f4d29f89d46a2e4b4affd2b4fc3 --- /dev/null +++ b/3rdparty/cub/rocprim/types/tuple.hpp @@ -0,0 +1,1127 @@ +// Copyright (c) 2018-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_TYPES_TUPLE_HPP_ +#define ROCPRIM_TYPES_TUPLE_HPP_ + +#include +#include + +#include "../config.hpp" +#include "../detail/all_true.hpp" + +#include "integer_sequence.hpp" + +/// \addtogroup utilsmodule_tuple +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +// //////////////////////// +// tuple (FORWARD DECLARATION) +// //////////////////////// +template +class tuple; + +// //////////////////////// +// tuple_size +// //////////////////////// + +/// \brief Provides access to the number of elements in a tuple as a compile-time constant expression. +/// +/// tuple_size is undefined for types \p T that are not tuples. +template +class tuple_size; + +/// \brief For \p T that is tuple, \p tuple_size::value is the +/// the number of elements in a tuple (equal to sizeof...(Types)). 
+/// +/// \see std::integral_constant +template +class tuple_size<::rocprim::tuple> : public std::integral_constant +{ + // All member functions of std::integral_constant are constexpr, so it should work + // without problems on HIP +}; +/// const T specialization of \ref tuple_size +template +class tuple_size + : public std::integral_constant::value> +{ + +}; +/// volatile T specialization of \ref tuple_size +template +class tuple_size + : public std::integral_constant::value> +{ + +}; +/// const volatile T specialization of \ref tuple_size +template +class tuple_size + : public std::integral_constant::value> +{ + +}; + +// //////////////////////// +// tuple_element +// //////////////////////// + +/// \brief Provides compile-time indexed access to the types of the elements of the tuple. +/// +/// tuple_element is undefined for types \p T that are not tuples. +template +struct tuple_element; // rocprim::tuple_size is defined only for rocprim::tuple + +namespace detail +{ + +template +struct tuple_element_impl; + +template +struct tuple_element_impl> + : tuple_element_impl> +{ + +}; + +template +struct tuple_element_impl<0, ::rocprim::tuple> +{ + using type = T; +}; + +template +struct tuple_element_impl> +{ + static_assert(I != I, "tuple_element index out of range"); +}; + +} // end detail namespace + +/// \brief For \p T that is tuple, \p tuple_element::type is the +/// type of Ith element of that tuple. 
+template +struct tuple_element> +{ + /// \brief The type of Ith element of the tuple, where \p I is in [0, sizeof...(Types)) + #ifndef DOXYGEN_SHOULD_SKIP_THIS + using type = typename detail::tuple_element_impl>::type; + #else + typedef type; + #endif +}; +/// const T specialization of \ref tuple_element +template +struct tuple_element +{ + /// \brief The type of Ith element of the tuple, where \p I is in [0, sizeof...(Types)) + using type = typename std::add_const::type>::type; +}; +/// volatile T specialization of \ref tuple_element +template +struct tuple_element +{ + /// \brief The type of Ith element of the tuple, where \p I is in [0, sizeof...(Types)) + using type = typename std::add_volatile::type>::type; +}; +/// const volatile T specialization of \ref tuple_element +template +struct tuple_element +{ + /// \brief The type of Ith element of the tuple, where \p I is in [0, sizeof...(Types)) + using type = typename std::add_cv::type>::type; +}; + +template +using tuple_element_t = typename tuple_element::type; + +// get forward declaration +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template +ROCPRIM_HOST_DEVICE +const tuple_element_t>& get(const tuple&) noexcept; + +template +ROCPRIM_HOST_DEVICE +tuple_element_t>& get(tuple&) noexcept; + +template +ROCPRIM_HOST_DEVICE +tuple_element_t>&& get(tuple&&) noexcept; +#endif + +// //////////////////////// +// tuple +// //////////////////////// + +namespace detail +{ + + template + ROCPRIM_HOST_DEVICE + inline T&& custom_forward(typename std::remove_reference::type& t) noexcept + { + return static_cast(t); + } + + template + ROCPRIM_HOST_DEVICE + inline T&& custom_forward(typename std::remove_reference::type&& t) noexcept + { + static_assert(!std::is_lvalue_reference::value, + "Can not forward an rvalue as an lvalue."); + return static_cast(t); + } + + +#ifdef __cpp_lib_is_final + template + using is_final = std::is_final; +#elif defined(__HCC__) // use clang extention + template + using is_final = std::integral_constant; 
+#else + template + struct is_final : std::false_type + { + }; +#endif + +// tuple_value - represents single element in a tuple +template< + size_t I, + class T, + bool /* Empty base optimization switch */ = std::is_empty::value && !is_final::value +> +struct tuple_value +{ + T value; + + ROCPRIM_HOST_DEVICE inline + constexpr tuple_value() noexcept : value() + { + static_assert(!std::is_reference::value, "can't default construct a reference element in a tuple" ); + } + + ROCPRIM_HOST_DEVICE inline + tuple_value(const tuple_value&) = default; + + ROCPRIM_HOST_DEVICE inline + tuple_value(tuple_value&&) = default; + + ROCPRIM_HOST_DEVICE inline + explicit tuple_value(T value) noexcept + : value(value) + { + // This is workaround for hcc which fails during linking without + // this constructor with undefine reference errors when U from ctors + // below is exactly T. Example: + // rocprim::tuple t(1, 2, 3); + // Produced error: + // undefined reference to `rocprim::detail::tuple_value<0ul, int>::tuple_value(int) + } + + template< + class U, + typename = typename std::enable_if< + !std::is_same::type, tuple_value>::value + >::type, + typename = typename std::enable_if< + std::is_constructible::value + >::type + > + ROCPRIM_HOST_DEVICE inline + explicit tuple_value(const U& v) noexcept : value(v) + { + } + + template< + class U, + typename = typename std::enable_if< + // So U can't be tuple_value + !std::is_same::type, tuple_value>::value + >::type, + typename = typename std::enable_if< + std::is_constructible::value + >::type + > + ROCPRIM_HOST_DEVICE inline + explicit tuple_value(U&& v) noexcept : value(::rocprim::detail::custom_forward(v)) + { + } + + ROCPRIM_HOST_DEVICE inline + ~tuple_value() = default; + + template + ROCPRIM_HOST_DEVICE inline + tuple_value& operator=(U&& v) noexcept + { + value = ::rocprim::detail::custom_forward(v); + return *this; + } + + ROCPRIM_HOST_DEVICE inline + void swap(tuple_value& v) noexcept + { + auto tmp = std::move(v.value); + 
v.value = std::move(this->value); + this->value = std::move(tmp); + } + + ROCPRIM_HOST_DEVICE inline + T& get() noexcept + { + return value; + } + + ROCPRIM_HOST_DEVICE inline + const T& get() const noexcept + { + return value; + } +}; + +// Specialization for empty base optimization +template +struct tuple_value : private T +{ + ROCPRIM_HOST_DEVICE inline + constexpr tuple_value() noexcept : T() + { + static_assert(!std::is_reference::value, "can't default construct a reference element in a tuple" ); + } + + ROCPRIM_HOST_DEVICE inline + tuple_value(const tuple_value&) = default; + + ROCPRIM_HOST_DEVICE inline + tuple_value(tuple_value&&) = default; + + ROCPRIM_HOST_DEVICE inline + explicit tuple_value(T value) noexcept + : T(value) + { + // This is workaround for hcc which fails during linking without + // this constructor with undefine reference errors when U from ctors + // below is exactly T. Example: + // rocprim::tuple t(1, 2, 3); + // Produced error: + // undefined reference to `rocprim::detail::tuple_value<0ul, int>::tuple_value(int) + } + + template< + class U, + typename = typename std::enable_if< + !std::is_same::type, tuple_value>::value + >::type, + typename = typename std::enable_if< + std::is_constructible::value + >::type + > + ROCPRIM_HOST_DEVICE inline + explicit tuple_value(const U& v) noexcept : T(v) + { + } + + template< + class U, + typename = typename std::enable_if< + // So U can't be tuple_value + !std::is_same::type, tuple_value>::value + >::type, + typename = typename std::enable_if< + std::is_constructible::value + >::type + > + ROCPRIM_HOST_DEVICE inline + explicit tuple_value(U&& v) noexcept : T(::rocprim::detail::custom_forward(v)) + { + } + + ROCPRIM_HOST_DEVICE inline + ~tuple_value() = default; + + template + ROCPRIM_HOST_DEVICE inline + tuple_value& operator=(U&& v) noexcept + { + T::operator=(::rocprim::detail::custom_forward(v)); + return *this; + } + + ROCPRIM_HOST_DEVICE inline + void swap(tuple_value& v) noexcept + { + auto 
tmp = std::move(v); + v = std::move(*this); + *this = std::move(tmp); + } + + ROCPRIM_HOST_DEVICE inline + T& get() noexcept + { + return static_cast(*this); + } + + ROCPRIM_HOST_DEVICE inline + const T& get() const noexcept + { + return static_cast(*this); + } +}; + +template +ROCPRIM_HOST_DEVICE inline +void swallow(Types&&...) noexcept {} + +template +struct tuple_impl; + +template +struct tuple_impl<::rocprim::index_sequence, Types...> + : tuple_value... +{ + ROCPRIM_HOST_DEVICE inline + constexpr tuple_impl() = default; + + ROCPRIM_HOST_DEVICE inline + tuple_impl(const tuple_impl&) = default; + + ROCPRIM_HOST_DEVICE inline + tuple_impl(tuple_impl&&) = default; + + ROCPRIM_HOST_DEVICE inline + explicit tuple_impl(Types... values) + : tuple_value(values)... + { + // This is workaround for hcc which fails during linking without + // this constructor with undefine reference errors when UTypes + // are exactly Types (see constructor below). Example: + // rocprim::tuple t(1, 2, 3); + // Produced error: + // undefined reference to `rocprim::detail::tuple_impl< + // rocprim::integer_sequence, int, int, int + // >::tuple_impl(int, int, int)' + } + + template< + class... UTypes, + typename = typename std::enable_if< + sizeof...(UTypes) == sizeof...(Types) + >::type, + typename = typename std::enable_if< + sizeof...(Types) >= 1 + >::type + > + ROCPRIM_HOST_DEVICE inline + explicit tuple_impl(UTypes&&... values) + : tuple_value(::rocprim::detail::custom_forward(values))... + { + } + + template< + class... UTypes, + typename = typename std::enable_if< + sizeof...(UTypes) == sizeof...(Types) + >::type, + typename = typename std::enable_if< + sizeof...(Types) >= 1 + >::type + > + ROCPRIM_HOST_DEVICE inline + tuple_impl(::rocprim::tuple&& other) + : tuple_value(::rocprim::detail::custom_forward(::rocprim::get(other)))... + { + } + + template< + class... 
UTypes, + typename = typename std::enable_if< + sizeof...(UTypes) == sizeof...(Types) + >::type, + typename = typename std::enable_if< + sizeof...(Types) >= 1 + >::type + > + ROCPRIM_HOST_DEVICE inline + tuple_impl(const ::rocprim::tuple& other) + : tuple_value(::rocprim::get(other))... + { + } + + ROCPRIM_HOST_DEVICE inline + ~tuple_impl() = default; + + ROCPRIM_HOST_DEVICE inline + tuple_impl& operator=(const tuple_impl& other) noexcept + { + swallow( + tuple_value::operator=( + static_cast&>(other).get() + )... + ); + return *this; + } + + ROCPRIM_HOST_DEVICE inline + tuple_impl& operator=(tuple_impl&& other) noexcept + { + swallow( + tuple_value::operator=( + static_cast&>(other).get() + )... + ); + return *this; + } + + template + ROCPRIM_HOST_DEVICE inline + tuple_impl& operator=(const ::rocprim::tuple& other) noexcept + { + swallow(tuple_value::operator=(::rocprim::get(other))...); + return *this; + } + + template + ROCPRIM_HOST_DEVICE inline + tuple_impl& operator=(::rocprim::tuple&& other) noexcept + { + swallow( + tuple_value::operator=( + ::rocprim::get(std::move(other)) + )... + ); + return *this; + } + + ROCPRIM_HOST_DEVICE inline + tuple_impl& swap(tuple_impl& other) noexcept + { + swallow( + (static_cast&>(*this).swap( + static_cast&>(other) + ), 0)... + ); + return *this; + } +}; + +template +using tuple_base = + tuple_impl< + typename ::rocprim::index_sequence_for, + Types... + >; + +} // end detail namespace + +/// \brief Fixed-size collection of heterogeneous values. +/// +/// \tparam Types... the types (zero or more) of the elements that the tuple stores. +/// +/// \pre +/// * For all types in \p Types... following operations should not throw exceptions: +/// construction, copy and move assignment, and swapping. 
+/// +/// \see std::tuple +template +class tuple +{ + using base_type = detail::tuple_base; + // tuple_impl + base_type base; + + template + struct check_constructor + { + template + static constexpr bool enable_default() + { + return detail::all_true::value...>::value; + } + + template + static constexpr bool enable_copy() + { + return detail::all_true::value...>::value; + } + }; + + #ifndef DOXYGEN_SHOULD_SKIP_THIS + template + ROCPRIM_HOST_DEVICE + friend const tuple_element_t>& get(const tuple&) noexcept; + + template + ROCPRIM_HOST_DEVICE + friend tuple_element_t>& get(tuple&) noexcept; + + template + ROCPRIM_HOST_DEVICE + friend tuple_element_t>&& get(tuple&&) noexcept; + #endif + +public: + /// \brief Default constructor. Performs value-initialization of all elements. + /// + /// This overload only participates in overload resolution if: + /// * std::is_default_constructible::value is \p true for all \p i. + #ifndef DOXYGEN_SHOULD_SKIP_THIS + template< + class Dummy = void, + typename = typename std::enable_if< + check_constructor::template enable_default() + >::type + > + #endif + ROCPRIM_HOST_DEVICE inline + constexpr tuple() noexcept : base() {}; + + /// \brief Implicitly-defined copy constructor. + ROCPRIM_HOST_DEVICE inline + tuple(const tuple&) = default; + + /// \brief Implicitly-defined move constructor. + ROCPRIM_HOST_DEVICE inline + tuple(tuple&&) = default; + + #ifndef DOXYGEN_SHOULD_SKIP_THIS + ROCPRIM_HOST_DEVICE inline + explicit tuple(Types... values) noexcept + : base(values...) + { + // Workaround for HCC compiler, without this we get undefined reference + // errors during linking. Example: + // rocprim::tuple t1(1, 2) + // Produces error: + // 'undefined reference to `rocprim::tuple::tuple(int, double)' + } + #endif + + /// \brief Direct constructor. Initializes each element of the tuple with + /// the corresponding input value. 
+ /// + /// This overload only participates in overload resolution if: + /// * std::is_copy_constructible::value is \p true for all \p i. + #ifndef DOXYGEN_SHOULD_SKIP_THIS + template< + class Dummy = void, + typename = typename std::enable_if< + check_constructor::template enable_copy() + >::type + > + #endif + ROCPRIM_HOST_DEVICE inline + explicit tuple(const Types&... values) + : base(values...) + { + } + + /// \brief Converting constructor. Initializes each element of the tuple + /// with the corresponding value in \p ::rocprim::detail::custom_forward(values). + /// + /// This overload only participates in overload resolution if: + /// * sizeof...(Types) == sizeof...(UTypes), + /// * sizeof...(Types) >= 1, and + /// * std::is_constructible::value is \p true for all \p i. + template< + class... UTypes + #ifndef DOXYGEN_SHOULD_SKIP_THIS + ,typename = typename std::enable_if< + sizeof...(UTypes) == sizeof...(Types) + >::type, + typename = typename std::enable_if< + sizeof...(Types) >= 1 + >::type, + typename = typename std::enable_if< + detail::all_true::value...>::value + >::type + #endif + > + ROCPRIM_HOST_DEVICE inline + explicit tuple(UTypes&&... values) noexcept + : base(::rocprim::detail::custom_forward(values)...) + { + } + + /// \brief Converting copy constructor. Initializes each element of the tuple + /// with the corresponding value from \p other. + /// + /// This overload only participates in overload resolution if: + /// * sizeof...(Types) == sizeof...(UTypes), + /// * sizeof...(Types) >= 1, and + /// * std::is_constructible::value is \p true for all \p i. + template< + class... 
UTypes, + #ifndef DOXYGEN_SHOULD_SKIP_THIS + typename = typename std::enable_if< + sizeof...(UTypes) == sizeof...(Types) + >::type, + typename = typename std::enable_if< + sizeof...(Types) >= 1 + >::type, + typename = typename std::enable_if< + detail::all_true::value...>::value + >::type + #endif + > + ROCPRIM_HOST_DEVICE inline + tuple(const tuple& other) noexcept + : base(other) + { + } + + /// \brief Converting move constructor. Initializes each element of the tuple + /// with the corresponding value from \p other. + /// + /// This overload only participates in overload resolution if: + /// * sizeof...(Types) == sizeof...(UTypes), + /// * sizeof...(Types) >= 1, and + /// * std::is_constructible::value is \p true for all \p i. + template< + class... UTypes, + #ifndef DOXYGEN_SHOULD_SKIP_THIS + typename = typename std::enable_if< + sizeof...(UTypes) == sizeof...(Types) + >::type, + typename = typename std::enable_if< + sizeof...(Types) >= 1 + >::type, + typename = typename std::enable_if< + detail::all_true::value...>::value + >::type + #endif + > + ROCPRIM_HOST_DEVICE inline + tuple(tuple&& other) noexcept + : base(::rocprim::detail::custom_forward>(other)) + { + } + + /// \brief Implicitly-defined destructor. + ROCPRIM_HOST_DEVICE inline + ~tuple() noexcept = default; + + #ifndef DOXYGEN_SHOULD_SKIP_THIS + template< + class T, + typename = typename std::enable_if< + std::is_assignable::value + >::type + > + ROCPRIM_HOST_DEVICE inline + tuple& operator=(T&& v) noexcept + { + base = ::rocprim::detail::custom_forward(v); + return *this; + } + + ROCPRIM_HOST_DEVICE inline + tuple& operator=(const tuple& other) noexcept + { + base = other.base; + return *this; + } + #else // For documentation + /// \brief Copy assignment operator. + /// \param other tuple to replace the contents of this tuple + tuple& operator=(const tuple& other) noexcept; + /// \brief Move assignment operator. 
+ /// \param other tuple to replace the contents of this tuple + tuple& operator=(tuple&& other) noexcept; + /// \brief For all \p i, assigns \p rocprim::get(other) to \p rocprim::get(*this). + /// \param other tuple to replace the contents of this tuple + template + tuple& operator=(const tuple& other) noexcept; + /// \brief For all \p i, assigns \p ::rocprim::detail::custom_forward(get(other)) to \p rocprim::get(*this). + /// \param other tuple to replace the contents of this tuple + template + tuple& operator=(tuple&& other) noexcept; + #endif + + /// \brief Swaps the content of the tuple (\p *this) with the content \p other + /// \param other tuple of values to swap + void swap(tuple& other) noexcept + { + base.swap(other.base); + } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template<> +class tuple<> +{ +public: + ROCPRIM_HOST_DEVICE inline + constexpr tuple() noexcept + { + } + + ROCPRIM_HOST_DEVICE inline + ~tuple() = default; + + ROCPRIM_HOST_DEVICE inline + void swap(tuple&) noexcept + { + } +}; +#endif + +namespace detail +{ + +template +struct tuple_equal_to +{ + template + ROCPRIM_HOST_DEVICE inline + bool operator()(const T& lhs, const U& rhs) const + { + return tuple_equal_to()(lhs, rhs) && get(lhs) == get(rhs); + } +}; + +template<> +struct tuple_equal_to<0> +{ + template + ROCPRIM_HOST_DEVICE inline + bool operator()(const T&, const U&) const + { + return true; + } +}; + +template +struct tuple_less_than +{ + template + ROCPRIM_HOST_DEVICE inline + bool operator()(const T& lhs, const U& rhs) const + { + constexpr size_t idx = tuple_size::value - I; + if(get(lhs) < get(rhs)) + return true; + if(get(rhs) < get(lhs)) + return false; + return tuple_less_than()(lhs, rhs); + } +}; + +template<> +struct tuple_less_than<0> +{ + template + ROCPRIM_HOST_DEVICE inline + bool operator()(const T&, const U&) const + { + return false; + } +}; + +} // end namespace detail + +/// \brief Equal to operator for tuples. +/// +/// \tparam TTypes... 
- the element types of \p lhs tuple. +/// \tparam UTypes... - the element types of \p rhs tuple. +/// +/// Compares every element of the tuple lhs with the corresponding element +/// of the tuple rhs, and returns \p true if all are equal. +/// +/// \param lhs tuple to compare with \p rhs +/// \param rhs tuple to compare with \p lhs +/// \return \p true if rocprim::get(lhs) == rocprim::get(rhs) for all +/// \p i in [0, sizeof...(TTypes)); otherwise - \p false. Comparing two +/// empty tuples returns \p true. +template< + class... TTypes, + class... UTypes, + typename = typename std::enable_if< + sizeof...(TTypes) == sizeof...(UTypes) + >::type +> +ROCPRIM_HOST_DEVICE inline +bool operator==(const tuple& lhs, const tuple& rhs) +{ + return detail::tuple_equal_to()(lhs, rhs); +} + +/// \brief Not equal to operator for tuples. +/// +/// \tparam TTypes... - the element types of \p lhs tuple. +/// \tparam UTypes... - the element types of \p rhs tuple. +/// +/// Compares every element of the tuple lhs with the corresponding element +/// of the tuple rhs, and returns \p true if at least one of such pairs is +/// not equal. +/// +/// \param lhs tuple to compare with \p rhs +/// \param rhs tuple to compare with \p lhs +/// \return !(lhr == rhs) +template +ROCPRIM_HOST_DEVICE inline +bool operator!=(const tuple& lhs, const tuple& rhs) +{ + return !(lhs == rhs); +} + +/// \brief Less than operator for tuples. +/// +/// \tparam TTypes... - the element types of \p lhs tuple. +/// \tparam UTypes... - the element types of \p rhs tuple. +/// +/// Compares lhs and rhs lexicographically. +/// +/// \param lhs tuple to compare with \p rhs +/// \param rhs tuple to compare with \p lhs +/// \return (bool)(rocprim::get<0>(lhs) < rocprim::get<0>(rhs)) || +/// (!(bool)(rocprim::get<0>(rhs) < rocprim::get<0>(lhs)) && lhstail < rhstail), where +/// \p lhstail is \p lhs without its first element, and \p rhstail is \p rhs without its first +/// element. For two empty tuples, it returns \p false. 
+template< + class... TTypes, + class... UTypes, + typename = typename std::enable_if< + sizeof...(TTypes) == sizeof...(UTypes) + >::type +> +ROCPRIM_HOST_DEVICE inline +bool operator<(const tuple& lhs, const tuple& rhs) +{ + return detail::tuple_less_than()(lhs, rhs); +} + +/// \brief Greater than operator for tuples. +/// +/// \tparam TTypes... - the element types of \p lhs tuple. +/// \tparam UTypes... - the element types of \p rhs tuple. +/// +/// Compares lhs and rhs lexicographically. +/// +/// \param lhs tuple to compare with \p rhs +/// \param rhs tuple to compare with \p lhs +/// \return rhs < lhs +template +ROCPRIM_HOST_DEVICE inline +bool operator>(const tuple& lhs, const tuple& rhs) +{ + return rhs < lhs; +} + +/// \brief Less than or equal to operator for tuples. +/// +/// \tparam TTypes... - the element types of \p lhs tuple. +/// \tparam UTypes... - the element types of \p rhs tuple. +/// +/// Compares lhs and rhs lexicographically. +/// +/// \param lhs tuple to compare with \p rhs +/// \param rhs tuple to compare with \p lhs +/// \return !(rhs < lhs) +template +ROCPRIM_HOST_DEVICE inline +bool operator<=(const tuple& lhs, const tuple& rhs) +{ + return !(rhs < lhs); +} + +/// \brief Greater than or equal to operator for tuples. +/// +/// \tparam TTypes... - the element types of \p lhs tuple. +/// \tparam UTypes... - the element types of \p rhs tuple. +/// +/// Compares lhs and rhs lexicographically. 
+/// +/// \param lhs tuple to compare with \p rhs +/// \param rhs tuple to compare with \p lhs +/// \return !(lhs < rhs) +template +ROCPRIM_HOST_DEVICE inline +bool operator>=(const tuple& lhs, const tuple& rhs) +{ + return !(lhs < rhs); +} + +// //////////////////////// +// swap +// //////////////////////// + +/// \brief Swaps the content of \p lhs tuple with the content \p rhs +/// \param lhs,rhs tuples whose contents to swap +template +ROCPRIM_HOST_DEVICE inline +void swap(tuple& lhs, tuple& rhs) noexcept +{ + lhs.swap(rhs); +} + +// //////////////////////// +// get +// //////////////////////// + +/// \brief Extracts the I-th element from the tuple, where \p I is +/// an integer value from range [0, sizeof...(Types)). +/// \param t tuple whose contents to extract +/// \return constant refernce to the selected element of input tuple \p t. +template +ROCPRIM_HOST_DEVICE inline +const tuple_element_t>& get(const tuple& t) noexcept +{ + using type = detail::tuple_value>>; + return static_cast(t.base).get(); +} + +/// \brief Extracts the I-th element from the tuple, where \p I is +/// an integer value from range [0, sizeof...(Types)). +/// \param t tuple whose contents to extract +/// \return refernce to the selected element of input tuple \p t. +template +ROCPRIM_HOST_DEVICE inline +tuple_element_t>& get(tuple& t) noexcept +{ + using type = detail::tuple_value>>; + return static_cast(t.base).get(); +} + +/// \brief Extracts the I-th element from the tuple, where \p I is +/// an integer value from range [0, sizeof...(Types)). +/// \param t tuple whose contents to extract +/// \return rvalue refernce to the selected element of input tuple \p t. 
+template +ROCPRIM_HOST_DEVICE inline +tuple_element_t>&& get(tuple&& t) noexcept +{ + using value_type = tuple_element_t>; + using type = detail::tuple_value>>; + return static_cast(static_cast(t.base).get()); +} + +// //////////////////////// +// make_tuple +// //////////////////////// + +namespace detail +{ + +template +struct make_tuple_return +{ + using type = T; +}; + +template +struct make_tuple_return> +{ + using type = T&; +}; + +template +using make_tuple_return_t = typename make_tuple_return::type>::type; + +} // end detail namespace + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template +ROCPRIM_HOST_DEVICE inline +tuple...> make_tuple(Types&&... args) noexcept +{ + return tuple...>(::rocprim::detail::custom_forward(args)...); +} +#else +/// \brief Creates a tuple, returned tuple type is deduced from the types of arguments. +/// +/// Returned tuple type \p tuple is deduced like this: For each \p Ti in +/// \p Types..., the corresponding type \p Vi in \p VTypes... is \p std::decay::type +/// unless \p std::decay::type results in \p std::reference_wrapper for some type U, +/// in which case the deduced type is U&. +/// +/// \param args - zero or more arguments to create tuple from +/// +/// \see std::tuple +template +tuple make_tuple(Types&&... args); +#endif + +// //////////////////////// +// ignore +// //////////////////////// + +namespace detail +{ + +struct ignore_t +{ + ROCPRIM_HOST_DEVICE inline + ignore_t() = default; + + ROCPRIM_HOST_DEVICE inline + ~ignore_t() = default; + + template + ROCPRIM_HOST_DEVICE inline + const ignore_t& operator=(const T&) const + { + return *this; + } +}; + +} +#ifndef DOXYGEN_SHOULD_SKIP_THIS +using ignore_type = detail::ignore_t; +#else +struct ignore_type; +#endif +/// \brief Assigning value to ignore object has no effect. +/// +/// Intended for use with \ref rocprim::tie when unpacking a \ref tuple, +/// as a placeholder for the arguments that are not used. 
+/// +/// \see std::ignore +const ignore_type ignore; + +// //////////////////////// +// tie +// //////////////////////// + +/// \brief Creates a tuple of lvalue references to its arguments \p args or instances +/// of \ref rocprim::ignore. +/// +/// \param args - zero or more input lvalue references used to create tuple +/// +/// \see std::tie +template +ROCPRIM_HOST_DEVICE inline +tuple tie(Types&... args) noexcept +{ + return ::rocprim::tuple(args...); +} + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group utilsmodule_tuple + +#endif // ROCPRIM_TYPES_TUPLE_HPP_ diff --git a/3rdparty/cub/rocprim/warp/detail/warp_reduce_crosslane.hpp b/3rdparty/cub/rocprim/warp/detail/warp_reduce_crosslane.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e9cba59bf23ec2ebf573c431caa94b9d3c6c6be1 --- /dev/null +++ b/3rdparty/cub/rocprim/warp/detail/warp_reduce_crosslane.hpp @@ -0,0 +1,53 @@ +// Copyright (c) 2018-2020 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_WARP_DETAIL_WARP_REDUCE_CROSSLANE_HPP_ +#define ROCPRIM_WARP_DETAIL_WARP_REDUCE_CROSSLANE_HPP_ + +#include + +#include "../../config.hpp" + +#include "warp_reduce_dpp.hpp" +#include "warp_reduce_shuffle.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int WarpSize, + bool UseAllReduce, + bool UseDPP = ROCPRIM_DETAIL_USE_DPP +> +using warp_reduce_crosslane = + typename std::conditional< + UseDPP, + warp_reduce_dpp, + warp_reduce_shuffle + >::type; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_WARP_DETAIL_WARP_REDUCE_CROSSLANE_HPP_ diff --git a/3rdparty/cub/rocprim/warp/detail/warp_reduce_dpp.hpp b/3rdparty/cub/rocprim/warp/detail/warp_reduce_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8d1bc20f2399e6125b73db32eb61cf1e1abf4291 --- /dev/null +++ b/3rdparty/cub/rocprim/warp/detail/warp_reduce_dpp.hpp @@ -0,0 +1,167 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_WARP_DETAIL_WARP_REDUCE_DPP_HPP_ +#define ROCPRIM_WARP_DETAIL_WARP_REDUCE_DPP_HPP_ + +#include + +#include "../../config.hpp" +#include "../../intrinsics.hpp" +#include "../../types.hpp" +#include "../../detail/various.hpp" + +#include "warp_reduce_shuffle.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int WarpSize, + bool UseAllReduce +> +class warp_reduce_dpp +{ +public: + static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2"); + + using storage_type = detail::empty_storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, T& output, BinaryFunction reduce_op) + { + output = input; + + if(WarpSize > 1) + { + // quad_perm:[1,0,3,2] -> 10110001 + output = reduce_op(warp_move_dpp(output), output); + } + if(WarpSize > 2) + { + // quad_perm:[2,3,0,1] -> 01001110 + output = reduce_op(warp_move_dpp(output), output); + } + if(WarpSize > 4) + { + // row_shr:4 + output = reduce_op(warp_move_dpp(output), output); + } + if(WarpSize > 8) + { + // row_shr:8 + output = reduce_op(warp_move_dpp(output), output); + } +#if ROCPRIM_NAVI + if(WarpSize > 16) + { + // row_bcast:15 + output = reduce_op(warp_swizzle(output), output); + } +#else + if(WarpSize > 16) + { + // row_bcast:15 + output = reduce_op(warp_move_dpp(output), output); + } + if(WarpSize > 32) + { + // row_bcast:31 + output = reduce_op(warp_move_dpp(output), output); + } +#endif + // Read the result from the last lane of the 
logical warp + output = warp_shuffle(output, WarpSize - 1, WarpSize); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, T& output, storage_type& storage, BinaryFunction reduce_op) + { + (void) storage; // disables unused parameter warning + this->reduce(input, output, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, T& output, unsigned int valid_items, BinaryFunction reduce_op) + { + // Fallback to shuffle-based implementation + warp_reduce_shuffle() + .reduce(input, output, valid_items, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, T& output, unsigned int valid_items, + storage_type& storage, BinaryFunction reduce_op) + { + (void) storage; // disables unused parameter warning + this->reduce(input, output, valid_items, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void head_segmented_reduce(T input, T& output, Flag flag, BinaryFunction reduce_op) + { + // Fallback to shuffle-based implementation + warp_reduce_shuffle() + .head_segmented_reduce(input, output, flag, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void tail_segmented_reduce(T input, T& output, Flag flag, BinaryFunction reduce_op) + { + // Fallback to shuffle-based implementation + warp_reduce_shuffle() + .tail_segmented_reduce(input, output, flag, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void head_segmented_reduce(T input, T& output, Flag flag, + storage_type& storage, BinaryFunction reduce_op) + { + // Fallback to shuffle-based implementation + warp_reduce_shuffle() + .head_segmented_reduce(input, output, flag, storage, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void tail_segmented_reduce(T input, T& output, Flag flag, + storage_type& storage, BinaryFunction reduce_op) + { + // Fallback to shuffle-based implementation + warp_reduce_shuffle() + .tail_segmented_reduce(input, output, flag, storage, reduce_op); + } +}; + +} // end 
namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_WARP_DETAIL_WARP_REDUCE_DPP_HPP_ diff --git a/3rdparty/cub/rocprim/warp/detail/warp_reduce_shared_mem.hpp b/3rdparty/cub/rocprim/warp/detail/warp_reduce_shared_mem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a6c7d4e09dd546e1dea7eb75cf79be69d7c71a2f --- /dev/null +++ b/3rdparty/cub/rocprim/warp/detail/warp_reduce_shared_mem.hpp @@ -0,0 +1,167 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_WARP_DETAIL_WARP_REDUCE_SHARED_MEM_HPP_ +#define ROCPRIM_WARP_DETAIL_WARP_REDUCE_SHARED_MEM_HPP_ + +#include + +#include "../../config.hpp" +#include "../../intrinsics.hpp" +#include "../../types.hpp" +#include "../../detail/various.hpp" + +#include "warp_segment_bounds.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int WarpSize, + bool UseAllReduce +> +class warp_reduce_shared_mem +{ + struct storage_type_ + { + T values[WarpSize]; + }; + +public: + using storage_type = detail::raw_storage; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, T& output, storage_type& storage, BinaryFunction reduce_op) + { + constexpr unsigned int ceiling = next_power_of_two(WarpSize); + const unsigned int lid = detail::logical_lane_id(); + storage_type_& storage_ = storage.get(); + + output = input; + store_volatile(&storage_.values[lid], output); + ROCPRIM_UNROLL + for(unsigned int i = ceiling >> 1; i > 0; i >>= 1) + { + if (lid + i < WarpSize && lid < i) + { + output = load_volatile(&storage_.values[lid]); + T other = load_volatile(&storage_.values[lid + i]); + output = reduce_op(output, other); + store_volatile(&storage_.values[lid], output); + } + } + set_output(output, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, T& output, unsigned int valid_items, + storage_type& storage, BinaryFunction reduce_op) + { + constexpr unsigned int ceiling = next_power_of_two(WarpSize); + const unsigned int lid = detail::logical_lane_id(); + storage_type_& storage_ = storage.get(); + + output = input; + store_volatile(&storage_.values[lid], output); + ROCPRIM_UNROLL + for(unsigned int i = ceiling >> 1; i > 0; i >>= 1) + { + if((lid + i) < WarpSize && lid < i && (lid + i) < valid_items) + { + output = load_volatile(&storage_.values[lid]); + T other = load_volatile(&storage_.values[lid + i]); + output = reduce_op(output, other); + store_volatile(&storage_.values[lid], output); + } + } + 
set_output(output, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void head_segmented_reduce(T input, T& output, Flag flag, + storage_type& storage, BinaryFunction reduce_op) + { + this->segmented_reduce(input, output, flag, storage, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void tail_segmented_reduce(T input, T& output, Flag flag, + storage_type& storage, BinaryFunction reduce_op) + { + this->segmented_reduce(input, output, flag, storage, reduce_op); + } + +private: + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void segmented_reduce(T input, T& output, Flag flag, + storage_type& storage, BinaryFunction reduce_op) + { + const unsigned int lid = detail::logical_lane_id(); + constexpr unsigned int ceiling = next_power_of_two(WarpSize); + storage_type_& storage_ = storage.get(); + // Get logical lane id of the last valid value in the segment + auto last = last_in_warp_segment(flag); + + output = input; + ROCPRIM_UNROLL + for(unsigned int i = 1; i < ceiling; i *= 2) + { + store_volatile(&storage_.values[lid], output); + if((lid + i) <= last) + { + T other = load_volatile(&storage_.values[lid + i]); + output = reduce_op(output, other); + } + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(Switch == false)>::type + set_output(T& output, storage_type& storage) + { + (void) output; + (void) storage; + // output already set correctly + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(Switch == true)>::type + set_output(T& output, storage_type& storage) + { + storage_type_& storage_ = storage.get(); + output = load_volatile(&storage_.values[0]); + } +}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_WARP_DETAIL_WARP_REDUCE_SHARED_MEM_HPP_ diff --git a/3rdparty/cub/rocprim/warp/detail/warp_reduce_shuffle.hpp b/3rdparty/cub/rocprim/warp/detail/warp_reduce_shuffle.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..3a7bfd38379c3b0ceeac8dff4405044c6abf79cc --- /dev/null +++ b/3rdparty/cub/rocprim/warp/detail/warp_reduce_shuffle.hpp @@ -0,0 +1,165 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_WARP_DETAIL_WARP_REDUCE_SHUFFLE_HPP_ +#define ROCPRIM_WARP_DETAIL_WARP_REDUCE_SHUFFLE_HPP_ + +#include + +#include "../../config.hpp" +#include "../../intrinsics.hpp" +#include "../../types.hpp" +#include "../../detail/various.hpp" + +#include "warp_segment_bounds.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int WarpSize, + bool UseAllReduce +> +class warp_reduce_shuffle +{ +public: + static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2"); + + using storage_type = detail::empty_storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, T& output, BinaryFunction reduce_op) + { + output = input; + + T value; + ROCPRIM_UNROLL + for(unsigned int offset = 1; offset < WarpSize; offset *= 2) + { + value = warp_shuffle_down(output, offset, WarpSize); + output = reduce_op(output, value); + } + set_output(output); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, T& output, storage_type& storage, BinaryFunction reduce_op) + { + (void) storage; // disables unused parameter warning + this->reduce(input, output, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, T& output, unsigned int valid_items, BinaryFunction reduce_op) + { + output = input; + + T value; + ROCPRIM_UNROLL + for(unsigned int offset = 1; offset < WarpSize; offset *= 2) + { + value = warp_shuffle_down(output, offset, WarpSize); + unsigned int id = detail::logical_lane_id(); + if (id + offset < valid_items) output = reduce_op(output, value); + } + set_output(output); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void reduce(T input, T& output, unsigned int valid_items, + storage_type& storage, BinaryFunction reduce_op) + { + (void) storage; // disables unused parameter warning + this->reduce(input, output, valid_items, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void head_segmented_reduce(T input, T& output, Flag flag, 
BinaryFunction reduce_op) + { + this->segmented_reduce(input, output, flag, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void tail_segmented_reduce(T input, T& output, Flag flag, BinaryFunction reduce_op) + { + this->segmented_reduce(input, output, flag, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void head_segmented_reduce(T input, T& output, Flag flag, + storage_type& storage, BinaryFunction reduce_op) + { + (void) storage; + this->segmented_reduce(input, output, flag, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void tail_segmented_reduce(T input, T& output, Flag flag, + storage_type& storage, BinaryFunction reduce_op) + { + (void) storage; + this->segmented_reduce(input, output, flag, reduce_op); + } + +private: + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void segmented_reduce(T input, T& output, Flag flag, BinaryFunction reduce_op) + { + // Get logical lane id of the last valid value in the segment, + // and convert it to number of valid values in segment. 
+ auto valid_items_in_segment = last_in_warp_segment(flag) + 1U; + this->reduce(input, output, valid_items_in_segment, reduce_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(Switch == false)>::type + set_output(T& output) + { + (void) output; + // output already set correctly + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(Switch == true)>::type + set_output(T& output) + { + output = warp_shuffle(output, 0, WarpSize); + } +}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_WARP_DETAIL_WARP_REDUCE_SHUFFLE_HPP_ diff --git a/3rdparty/cub/rocprim/warp/detail/warp_scan_crosslane.hpp b/3rdparty/cub/rocprim/warp/detail/warp_scan_crosslane.hpp new file mode 100644 index 0000000000000000000000000000000000000000..acf7bbb7d6cfc08ac62e474fac48c19f88170d72 --- /dev/null +++ b/3rdparty/cub/rocprim/warp/detail/warp_scan_crosslane.hpp @@ -0,0 +1,52 @@ +// Copyright (c) 2018-2020 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_WARP_DETAIL_WARP_SCAN_CROSSLANE_HPP_ +#define ROCPRIM_WARP_DETAIL_WARP_SCAN_CROSSLANE_HPP_ + +#include + +#include "../../config.hpp" + +#include "warp_scan_dpp.hpp" +#include "warp_scan_shuffle.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int WarpSize, + bool UseDPP = ROCPRIM_DETAIL_USE_DPP +> +using warp_scan_crosslane = + typename std::conditional< + UseDPP, + warp_scan_dpp, + warp_scan_shuffle + >::type; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_WARP_DETAIL_WARP_SCAN_CROSSLANE_HPP_ diff --git a/3rdparty/cub/rocprim/warp/detail/warp_scan_dpp.hpp b/3rdparty/cub/rocprim/warp/detail/warp_scan_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cbe13674f1f9f46e0d1249b97954b527b76d14f9 --- /dev/null +++ b/3rdparty/cub/rocprim/warp/detail/warp_scan_dpp.hpp @@ -0,0 +1,270 @@ +// Copyright (c) 2018-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_WARP_DETAIL_WARP_SCAN_DPP_HPP_ +#define ROCPRIM_WARP_DETAIL_WARP_SCAN_DPP_HPP_ + +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../types.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int WarpSize +> +class warp_scan_dpp +{ +public: + static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2"); + + using storage_type = detail::empty_storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, T& output, BinaryFunction scan_op) + { + const unsigned int lane_id = ::rocprim::lane_id(); + const unsigned int row_lane_id = lane_id % ::rocprim::min(16u, WarpSize); + + output = input; + + if(WarpSize > 1) + { + T t = scan_op(warp_move_dpp(output), output); // row_shr:1 + if(row_lane_id >= 1) output = t; + } + if(WarpSize > 2) + { + T t = scan_op(warp_move_dpp(output), output); // row_shr:2 + if(row_lane_id >= 2) output = t; + } + if(WarpSize > 4) + { + T t = scan_op(warp_move_dpp(output), output); // row_shr:4 + if(row_lane_id >= 4) output = t; + } + if(WarpSize > 8) + { + T t = scan_op(warp_move_dpp(output), output); // row_shr:8 + if(row_lane_id >= 8) output = t; + } +#if ROCPRIM_NAVI + if(WarpSize > 16) + { + T t = scan_op(warp_swizzle(output), output); // row_bcast:15 + if(lane_id % 32 >= 16) output = t; + } +#else + if(WarpSize > 16) + { + T t = scan_op(warp_move_dpp(output), output); // 
row_bcast:15 + if(lane_id % 32 >= 16) output = t; + } + if(WarpSize > 32) + { + T t = scan_op(warp_move_dpp(output), output); // row_bcast:31 + if(lane_id >= 32) output = t; + } +#endif + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, T& output, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; // disables unused parameter warning + inclusive_scan(input, output, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, T& output, T& reduction, + BinaryFunction scan_op) + { + inclusive_scan(input, output, scan_op); + // Broadcast value from the last thread in warp + reduction = warp_shuffle(output, WarpSize-1, WarpSize); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, T& output, T& reduction, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; + inclusive_scan(input, output, reduction, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, T init, BinaryFunction scan_op) + { + inclusive_scan(input, output, scan_op); + // Convert inclusive scan result to exclusive + to_exclusive(output, output, init, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, T init, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; // disables unused parameter warning + exclusive_scan(input, output, init, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; // disables unused parameter warning + inclusive_scan(input, output, scan_op); + // Convert inclusive scan result to exclusive + to_exclusive(output, output); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, T init, T& reduction, + BinaryFunction scan_op) + { + inclusive_scan(input, output, scan_op); + // Broadcast value from the 
last thread in warp + reduction = warp_shuffle(output, WarpSize-1, WarpSize); + // Convert inclusive scan result to exclusive + to_exclusive(output, output, init, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, T init, T& reduction, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; + exclusive_scan(input, output, init, reduction, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan(T input, T& inclusive_output, T& exclusive_output, T init, + BinaryFunction scan_op) + { + inclusive_scan(input, inclusive_output, scan_op); + // Convert inclusive scan result to exclusive + to_exclusive(inclusive_output, exclusive_output, init, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan(T input, T& inclusive_output, T& exclusive_output, T init, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; // disables unused parameter warning + scan(input, inclusive_output, exclusive_output, init, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan(T input, T& inclusive_output, T& exclusive_output, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; // disables unused parameter warning + inclusive_scan(input, inclusive_output, scan_op); + // Convert inclusive scan result to exclusive + to_exclusive(inclusive_output, exclusive_output); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan(T input, T& inclusive_output, T& exclusive_output, T init, T& reduction, + BinaryFunction scan_op) + { + inclusive_scan(input, inclusive_output, scan_op); + // Broadcast value from the last thread in warp + reduction = warp_shuffle(inclusive_output, WarpSize-1, WarpSize); + // Convert inclusive scan result to exclusive + to_exclusive(inclusive_output, exclusive_output, init, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan(T input, T& inclusive_output, T& exclusive_output, T init, T& reduction, + storage_type& 
storage, BinaryFunction scan_op) + { + (void) storage; + scan(input, inclusive_output, exclusive_output, init, reduction, scan_op); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + T broadcast(T input, const unsigned int src_lane, storage_type& storage) + { + (void) storage; + return warp_shuffle(input, src_lane, WarpSize); + } + +protected: + ROCPRIM_DEVICE ROCPRIM_INLINE + void to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) + { + (void) storage; + return to_exclusive(inclusive_input, exclusive_output); + } + +private: + // Changes inclusive scan results to exclusive scan results + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void to_exclusive(T inclusive_input, T& exclusive_output, T init, + BinaryFunction scan_op) + { + // include init value in scan results + exclusive_output = scan_op(init, inclusive_input); + // get exclusive results + exclusive_output = warp_shuffle_up(exclusive_output, 1, WarpSize); + if(detail::logical_lane_id() == 0) + { + exclusive_output = init; + } + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void to_exclusive(T inclusive_input, T& exclusive_output) + { + // shift to get exclusive results + exclusive_output = warp_shuffle_up(inclusive_input, 1, WarpSize); + } +}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_WARP_DETAIL_WARP_SCAN_DPP_HPP_ diff --git a/3rdparty/cub/rocprim/warp/detail/warp_scan_shared_mem.hpp b/3rdparty/cub/rocprim/warp/detail/warp_scan_shared_mem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cced3e80ea3d6b68b61413098c1ac55400c42850 --- /dev/null +++ b/3rdparty/cub/rocprim/warp/detail/warp_scan_shared_mem.hpp @@ -0,0 +1,191 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_WARP_DETAIL_WARP_SCAN_SHARED_MEM_HPP_ +#define ROCPRIM_WARP_DETAIL_WARP_SCAN_SHARED_MEM_HPP_ + +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../types.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int WarpSize +> +class warp_scan_shared_mem +{ + struct storage_type_ + { + T threads[WarpSize]; + }; +public: + using storage_type = detail::raw_storage; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, T& output, + storage_type& storage, BinaryFunction scan_op) + { + const unsigned int lid = detail::logical_lane_id(); + storage_type_& storage_ = storage.get(); + + T me = input; + store_volatile(&storage_.threads[lid], me); + for(unsigned int i = 1; i < WarpSize; i *= 2) + { + if(lid >= i) + { + T other = load_volatile(&storage_.threads[lid - i]); + me = scan_op(other, me); + store_volatile(&storage_.threads[lid], me); + } + } + output = me; + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, T& output, T& reduction, + storage_type& storage, BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + inclusive_scan(input, output, storage, scan_op); + reduction = load_volatile(&storage_.threads[WarpSize - 1]); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, T init, + storage_type& storage, BinaryFunction scan_op) + { + inclusive_scan(input, output, storage, scan_op); + to_exclusive(output, init, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, + storage_type& storage, BinaryFunction scan_op) + { + inclusive_scan(input, output, storage, scan_op); + to_exclusive(output, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, T init, T& reduction, + storage_type& storage, BinaryFunction scan_op) + { + storage_type_& 
storage_ = storage.get(); + inclusive_scan(input, output, storage, scan_op); + reduction = load_volatile(&storage_.threads[WarpSize - 1]); + to_exclusive(output, init, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan(T input, T& inclusive_output, T& exclusive_output, T init, + storage_type& storage, BinaryFunction scan_op) + { + inclusive_scan(input, inclusive_output, storage, scan_op); + to_exclusive(exclusive_output, init, storage, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan(T input, T& inclusive_output, T& exclusive_output, + storage_type& storage, BinaryFunction scan_op) + { + inclusive_scan(input, inclusive_output, storage, scan_op); + to_exclusive(exclusive_output, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan(T input, T& inclusive_output, T& exclusive_output, T init, T& reduction, + storage_type& storage, BinaryFunction scan_op) + { + storage_type_& storage_ = storage.get(); + inclusive_scan(input, inclusive_output, storage, scan_op); + reduction = load_volatile(&storage_.threads[WarpSize - 1]); + to_exclusive(exclusive_output, init, storage, scan_op); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + T broadcast(T input, const unsigned int src_lane, storage_type& storage) + { + storage_type_& storage_ = storage.get(); + if(src_lane == detail::logical_lane_id()) + { + store_volatile(&storage_.threads[src_lane], input); + } + return load_volatile(&storage_.threads[src_lane]); + } + +protected: + ROCPRIM_DEVICE ROCPRIM_INLINE + void to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) + { + (void) inclusive_input; + return to_exclusive(exclusive_output, storage); + } + +private: + // Calculate exclusive results base on inclusive scan results in storage.threads[]. 
+ template + ROCPRIM_DEVICE ROCPRIM_INLINE + void to_exclusive(T& exclusive_output, T init, + storage_type& storage, BinaryFunction scan_op) + { + const unsigned int lid = detail::logical_lane_id(); + storage_type_& storage_ = storage.get(); + exclusive_output = init; + if(lid != 0) + { + exclusive_output = scan_op(init, load_volatile(&storage_.threads[lid-1])); + } + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void to_exclusive(T& exclusive_output, storage_type& storage) + { + const unsigned int lid = detail::logical_lane_id(); + storage_type_& storage_ = storage.get(); + if(lid != 0) + { + exclusive_output = load_volatile(&storage_.threads[lid-1]); + } + } +}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_WARP_DETAIL_WARP_SCAN_SHARED_MEM_HPP_ diff --git a/3rdparty/cub/rocprim/warp/detail/warp_scan_shuffle.hpp b/3rdparty/cub/rocprim/warp/detail/warp_scan_shuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..457297528e89593931f6844916ee715679436f93 --- /dev/null +++ b/3rdparty/cub/rocprim/warp/detail/warp_scan_shuffle.hpp @@ -0,0 +1,237 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_WARP_DETAIL_WARP_SCAN_SHUFFLE_HPP_ +#define ROCPRIM_WARP_DETAIL_WARP_SCAN_SHUFFLE_HPP_ + +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../types.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class T, + unsigned int WarpSize +> +class warp_scan_shuffle +{ +public: + static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2"); + + using storage_type = detail::empty_storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, T& output, BinaryFunction scan_op) + { + output = input; + + T value; + const unsigned int id = detail::logical_lane_id(); + ROCPRIM_UNROLL + for(unsigned int offset = 1; offset < WarpSize; offset *= 2) + { + value = warp_shuffle_up(output, offset, WarpSize); + if(id >= offset) output = scan_op(value, output); + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, T& output, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; // disables unused parameter warning + inclusive_scan(input, output, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void inclusive_scan(T input, T& output, T& reduction, + BinaryFunction scan_op) + { + inclusive_scan(input, output, scan_op); + // Broadcast value from the last thread in warp + reduction = warp_shuffle(output, WarpSize-1, WarpSize); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void 
inclusive_scan(T input, T& output, T& reduction, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; + inclusive_scan(input, output, reduction, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, T init, BinaryFunction scan_op) + { + inclusive_scan(input, output, scan_op); + // Convert inclusive scan result to exclusive + to_exclusive(output, output, init, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, T init, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; // disables unused parameter warning + exclusive_scan(input, output, init, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; // disables unused parameter warning + inclusive_scan(input, output, scan_op); + // Convert inclusive scan result to exclusive + to_exclusive(output, output); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, T init, T& reduction, + BinaryFunction scan_op) + { + inclusive_scan(input, output, scan_op); + // Broadcast value from the last thread in warp + reduction = warp_shuffle(output, WarpSize-1, WarpSize); + // Convert inclusive scan result to exclusive + to_exclusive(output, output, init, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, T init, T& reduction, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; + exclusive_scan(input, output, init, reduction, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan(T input, T& inclusive_output, T& exclusive_output, T init, + BinaryFunction scan_op) + { + inclusive_scan(input, inclusive_output, scan_op); + // Convert inclusive scan result to exclusive + to_exclusive(inclusive_output, exclusive_output, init, scan_op); + } + + template + ROCPRIM_DEVICE 
ROCPRIM_INLINE + void scan(T input, T& inclusive_output, T& exclusive_output, T init, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; // disables unused parameter warning + scan(input, inclusive_output, exclusive_output, init, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan(T input, T& inclusive_output, T& exclusive_output, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; // disables unused parameter warning + inclusive_scan(input, inclusive_output, scan_op); + // Convert inclusive scan result to exclusive + to_exclusive(inclusive_output, exclusive_output); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan(T input, T& inclusive_output, T& exclusive_output, T init, T& reduction, + BinaryFunction scan_op) + { + inclusive_scan(input, inclusive_output, scan_op); + // Broadcast value from the last thread in warp + reduction = warp_shuffle(inclusive_output, WarpSize-1, WarpSize); + // Convert inclusive scan result to exclusive + to_exclusive(inclusive_output, exclusive_output, init, scan_op); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan(T input, T& inclusive_output, T& exclusive_output, T init, T& reduction, + storage_type& storage, BinaryFunction scan_op) + { + (void) storage; + scan(input, inclusive_output, exclusive_output, init, reduction, scan_op); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + T broadcast(T input, const unsigned int src_lane, storage_type& storage) + { + (void) storage; + return warp_shuffle(input, src_lane, WarpSize); + } + +protected: + ROCPRIM_DEVICE ROCPRIM_INLINE + void to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) + { + (void) storage; + return to_exclusive(inclusive_input, exclusive_output); + } + +private: + // Changes inclusive scan results to exclusive scan results + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void to_exclusive(T inclusive_input, T& exclusive_output, T init, + BinaryFunction scan_op) + { + // include init 
value in scan results + exclusive_output = scan_op(init, inclusive_input); + // get exclusive results + exclusive_output = warp_shuffle_up(exclusive_output, 1, WarpSize); + if(detail::logical_lane_id() == 0) + { + exclusive_output = init; + } + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void to_exclusive(T inclusive_input, T& exclusive_output) + { + // shift to get exclusive results + exclusive_output = warp_shuffle_up(inclusive_input, 1, WarpSize); + } +}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_WARP_DETAIL_WARP_SCAN_SHUFFLE_HPP_ diff --git a/3rdparty/cub/rocprim/warp/detail/warp_segment_bounds.hpp b/3rdparty/cub/rocprim/warp/detail/warp_segment_bounds.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bf1c142c5a1206747ea89b87248b495989b845a0 --- /dev/null +++ b/3rdparty/cub/rocprim/warp/detail/warp_segment_bounds.hpp @@ -0,0 +1,80 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_WARP_DETAIL_WARP_SEGMENT_BOUNDS_HPP_ +#define ROCPRIM_WARP_DETAIL_WARP_SEGMENT_BOUNDS_HPP_ + +#include + +#include "../../config.hpp" +#include "../../intrinsics.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// Returns logical warp id of the last thread in thread's segment +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto last_in_warp_segment(Flag flag) + -> typename std::enable_if<(WarpSize <= __AMDGCN_WAVEFRONT_SIZE), unsigned int>::type +{ + // Get flags (now every thread know where the flags are) + lane_mask_type warp_flags = ::rocprim::ballot(flag); + + // In case of head flags change them to tail flags + if(HeadSegmented) + { + warp_flags >>= 1; + } + const auto lane_id = ::rocprim::lane_id(); + // Zero bits from thread with lower lane id + warp_flags &= lane_mask_type(-1) ^ ((lane_mask_type(1) << lane_id) - 1U); + // Ignore bits from thread from other (previous) logical warps + warp_flags >>= (lane_id / WarpSize) * WarpSize; + // Make sure last item in logical warp is marked as a tail + warp_flags |= lane_mask_type(1) << (WarpSize - 1U); + // Calculate logical lane id of the last valid value in the segment +#ifndef __HIP_CPU_RT__ + #if __AMDGCN_WAVEFRONT_SIZE == 32 + return ::__ffs(warp_flags) - 1; + #else + return ::__ffsll(warp_flags) - 1; + #endif +#else +#if _MSC_VER + // TODO: verify correctness + unsigned long tmp = 0; + _BitScanReverse64(&tmp, warp_flags); + return 1u << tmp; +#elif __GNUC__ + return __builtin_ctzl(warp_flags); +#else + static_assert(false, "Look for GCC/Clang implementation"); +#endif +#endif +} + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_WARP_DETAIL_WARP_SEGMENT_BOUNDS_HPP_ diff --git 
a/3rdparty/cub/rocprim/warp/detail/warp_sort_shuffle.hpp b/3rdparty/cub/rocprim/warp/detail/warp_sort_shuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ace5441c1900c03cffb033428485c3d2cab3ef80 --- /dev/null +++ b/3rdparty/cub/rocprim/warp/detail/warp_sort_shuffle.hpp @@ -0,0 +1,511 @@ +// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_WARP_DETAIL_WARP_SORT_SHUFFLE_HPP_ +#define ROCPRIM_WARP_DETAIL_WARP_SORT_SHUFFLE_HPP_ + +#include + +#include "../../config.hpp" +#include "../../detail/various.hpp" + +#include "../../intrinsics.hpp" +#include "../../functional.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template< + class Key, + unsigned int WarpSize, + class Value +> +class warp_sort_shuffle +{ +private: + template + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if warp)>::type + swap(Key& k, V& v, int mask, bool dir, BinaryFunction compare_function) + { + (void) k; + (void) v; + (void) mask; + (void) dir; + (void) compare_function; + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(WarpSize > warp)>::type + swap(Key& k, V& v, int mask, bool dir, BinaryFunction compare_function) + { + Key k1 = warp_shuffle_xor(k, mask, WarpSize); + //V v1 = warp_shuffle_xor(v, mask, WarpSize); + bool swap = compare_function(dir ? k : k1, dir ? k1 : k); + if (swap) + { + k = k1; + v = warp_shuffle_xor(v, mask, WarpSize); + } + } + + template< + int warp, + class V, + class BinaryFunction, + unsigned int ItemsPerThread + > + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if warp)>::type + swap(Key (&k)[ItemsPerThread], + V (&v)[ItemsPerThread], + int mask, + bool dir, + BinaryFunction compare_function) + { + (void) k; + (void) v; + (void) mask; + (void) dir; + (void) compare_function; + } + + template< + int warp, + class V, + class BinaryFunction, + unsigned int ItemsPerThread + > + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(WarpSize > warp)>::type + swap(Key (&k)[ItemsPerThread], + V (&v)[ItemsPerThread], + int mask, + bool dir, + BinaryFunction compare_function) + { + Key k1[ItemsPerThread]; + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + k1[item]= warp_shuffle_xor(k[item], mask, WarpSize); + //V v1 = warp_shuffle_xor(v, mask, WarpSize); + bool swap = compare_function(dir ? 
k[item] : k1[item], dir ? k1[item] : k[item]); + if (swap) + { + k[item] = k1[item]; + v[item] = warp_shuffle_xor(v[item], mask, WarpSize); + } + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if warp)>::type + swap(Key& k, int mask, bool dir, BinaryFunction compare_function) + { + (void) k; + (void) mask; + (void) dir; + (void) compare_function; + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(WarpSize > warp)>::type + swap(Key& k, int mask, bool dir, BinaryFunction compare_function) + { + Key k1 = warp_shuffle_xor(k, mask, WarpSize); + bool swap = compare_function(dir ? k : k1, dir ? k1 : k); + if (swap) + { + k = k1; + } + } + + template< + int warp, + class BinaryFunction, + unsigned int ItemsPerThread + > + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if warp)>::type + swap(Key (&k)[ItemsPerThread], int mask, bool dir, BinaryFunction compare_function) + { + (void) k; + (void) mask; + (void) dir; + (void) compare_function; + } + + template< + int warp, + class BinaryFunction, + unsigned int ItemsPerThread + > + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(WarpSize > warp)>::type + swap(Key (&k)[ItemsPerThread], int mask, bool dir, BinaryFunction compare_function) + { + Key k1[ItemsPerThread]; + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + k1[item]= warp_shuffle_xor(k[item], mask, WarpSize); + bool swap = compare_function(dir ? k[item] : k1[item], dir ? 
k1[item] : k[item]); + if (swap) + { + k[item] = k1[item]; + } + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void thread_swap(Key (&k)[ItemsPerThread], + unsigned int i, + unsigned int j, + bool dir, + BinaryFunction compare_function) + { + if(compare_function(k[i], k[j]) == dir) + { + Key temp = k[i]; + k[i] = k[j]; + k[j] = temp; + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void thread_swap(Key (&k)[ItemsPerThread], + V (&v)[ItemsPerThread], + unsigned int i, + unsigned int j, + bool dir, + BinaryFunction compare_function) + { + if(compare_function(k[i], k[j]) == dir) + { + Key k_temp = k[i]; + k[i] = k[j]; + k[j] = k_temp; + V v_temp = v[i]; + v[i] = v[j]; + v[j] = v_temp; + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void thread_shuffle(unsigned int group_size, + unsigned int offset, + bool dir, + BinaryFunction compare_function, + KeyValue&... kv) + { + ROCPRIM_UNROLL + for(unsigned int base = 0; base < ItemsPerThread; base += 2 * offset) { + // The local direction must change every group_size items + // and is flipped if dir is true + const bool local_dir = ((base & group_size) > 0) != dir; + + ROCPRIM_UNROLL + for(unsigned i = 0; i < offset; ++i) { + thread_swap(kv..., base + i, base + i + offset, local_dir, compare_function); + } + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void thread_sort(bool dir, BinaryFunction compare_function, KeyValue&... kv) + { + ROCPRIM_UNROLL + for(unsigned int k = 2; k <= ItemsPerThread; k *= 2) + { + ROCPRIM_UNROLL + for(unsigned int j = k / 2; j > 0; j /= 2) + { + thread_shuffle(k, j, dir, compare_function, kv...); + } + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(WarpSize > warp)>::type + thread_merge(bool dir, BinaryFunction compare_function, KeyValue&... 
kv) + { + ROCPRIM_UNROLL + for(unsigned int j = ItemsPerThread / 2; j > 0; j /= 2) + { + thread_shuffle(ItemsPerThread, j, dir, compare_function, kv...); + } + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if warp)>::type + thread_merge(bool /*dir*/, BinaryFunction /*compare_function*/, KeyValue&... /*kv*/) + { + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void bitonic_sort(BinaryFunction compare_function, KeyValue&... kv) + { + static_assert( + sizeof...(KeyValue) < 3, + "KeyValue parameter pack can 1 or 2 elements (key, or key and value)" + ); + + unsigned int id = detail::logical_lane_id(); + swap< 2>(kv..., 1, get_bit(id, 1) != get_bit(id, 0), compare_function); + + swap< 4>(kv..., 2, get_bit(id, 2) != get_bit(id, 1), compare_function); + swap< 4>(kv..., 1, get_bit(id, 2) != get_bit(id, 0), compare_function); + + swap< 8>(kv..., 4, get_bit(id, 3) != get_bit(id, 2), compare_function); + swap< 8>(kv..., 2, get_bit(id, 3) != get_bit(id, 1), compare_function); + swap< 8>(kv..., 1, get_bit(id, 3) != get_bit(id, 0), compare_function); + + swap<16>(kv..., 8, get_bit(id, 4) != get_bit(id, 3), compare_function); + swap<16>(kv..., 4, get_bit(id, 4) != get_bit(id, 2), compare_function); + swap<16>(kv..., 2, get_bit(id, 4) != get_bit(id, 1), compare_function); + swap<16>(kv..., 1, get_bit(id, 4) != get_bit(id, 0), compare_function); + + swap<32>(kv..., 16, get_bit(id, 5) != get_bit(id, 4), compare_function); + swap<32>(kv..., 8, get_bit(id, 5) != get_bit(id, 3), compare_function); + swap<32>(kv..., 4, get_bit(id, 5) != get_bit(id, 2), compare_function); + swap<32>(kv..., 2, get_bit(id, 5) != get_bit(id, 1), compare_function); + swap<32>(kv..., 1, get_bit(id, 5) != get_bit(id, 0), compare_function); + + swap<32>(kv..., 32, get_bit(id, 5) != 0, compare_function); + swap<16>(kv..., 16, get_bit(id, 4) != 0, compare_function); + swap< 8>(kv..., 8, get_bit(id, 3) != 0, compare_function); + swap< 4>(kv..., 4, get_bit(id, 2) != 0, compare_function); + 
swap< 2>(kv..., 2, get_bit(id, 1) != 0, compare_function); + swap< 0>(kv..., 1, get_bit(id, 0) != 0, compare_function); + } + + template< + unsigned int ItemsPerThread, + class BinaryFunction, + class... KeyValue + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void bitonic_sort(BinaryFunction compare_function, KeyValue&... kv) + { + static_assert( + sizeof...(KeyValue) < 3, + "KeyValue parameter pack can 1 or 2 elements (key, or key and value)" + ); + + static_assert(detail::is_power_of_two(ItemsPerThread), "ItemsPerThread must be power of 2"); + + unsigned int id = detail::logical_lane_id(); + thread_sort(get_bit(id, 0) != 0, compare_function, kv...); + + swap< 2>(kv..., 1, get_bit(id, 1) != get_bit(id, 0), compare_function); + thread_merge<2, ItemsPerThread>(get_bit(id, 1) != 0, compare_function, kv...); + + swap< 4>(kv..., 2, get_bit(id, 2) != get_bit(id, 1), compare_function); + swap< 4>(kv..., 1, get_bit(id, 2) != get_bit(id, 0), compare_function); + thread_merge<4, ItemsPerThread>(get_bit(id, 2) != 0, compare_function, kv...); + + swap< 8>(kv..., 4, get_bit(id, 3) != get_bit(id, 2), compare_function); + swap< 8>(kv..., 2, get_bit(id, 3) != get_bit(id, 1), compare_function); + swap< 8>(kv..., 1, get_bit(id, 3) != get_bit(id, 0), compare_function); + thread_merge<8, ItemsPerThread>(get_bit(id, 3) != 0, compare_function, kv...); + + swap<16>(kv..., 8, get_bit(id, 4) != get_bit(id, 3), compare_function); + swap<16>(kv..., 4, get_bit(id, 4) != get_bit(id, 2), compare_function); + swap<16>(kv..., 2, get_bit(id, 4) != get_bit(id, 1), compare_function); + swap<16>(kv..., 1, get_bit(id, 4) != get_bit(id, 0), compare_function); + thread_merge<16, ItemsPerThread>(get_bit(id, 4) != 0, compare_function, kv...); + + swap<32>(kv..., 16, get_bit(id, 5) != get_bit(id, 4), compare_function); + swap<32>(kv..., 8, get_bit(id, 5) != get_bit(id, 3), compare_function); + swap<32>(kv..., 4, get_bit(id, 5) != get_bit(id, 2), compare_function); + swap<32>(kv..., 2, get_bit(id, 5) != get_bit(id, 
1), compare_function); + swap<32>(kv..., 1, get_bit(id, 5) != get_bit(id, 0), compare_function); + thread_merge<32, ItemsPerThread>(get_bit(id, 5) != 0, compare_function, kv...); + + swap<32>(kv..., 32, get_bit(id, 5) != 0, compare_function); + swap<16>(kv..., 16, get_bit(id, 4) != 0, compare_function); + swap< 8>(kv..., 8, get_bit(id, 3) != 0, compare_function); + swap< 4>(kv..., 4, get_bit(id, 2) != 0, compare_function); + swap< 2>(kv..., 2, get_bit(id, 1) != 0, compare_function); + swap< 0>(kv..., 1, get_bit(id, 0) != 0, compare_function); + thread_merge<1, ItemsPerThread>(false, compare_function, kv...); + } + +public: + static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2"); + + using storage_type = ::rocprim::detail::empty_storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(Key& thread_value, BinaryFunction compare_function) + { + // sort by value only + bitonic_sort(compare_function, thread_value); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(Key& thread_value, storage_type& storage, + BinaryFunction compare_function) + { + (void) storage; + sort(thread_value, compare_function); + } + + template< + unsigned int ItemsPerThread, + class BinaryFunction + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(Key (&thread_values)[ItemsPerThread], + BinaryFunction compare_function) + { + // sort by value only + bitonic_sort(compare_function, thread_values); + } + + template< + unsigned int ItemsPerThread, + class BinaryFunction + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(Key (&thread_values)[ItemsPerThread], + storage_type& storage, + BinaryFunction compare_function) + { + (void) storage; + sort(thread_values, compare_function); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(sizeof(V) <= sizeof(int))>::type + sort(Key& thread_key, Value& thread_value, + BinaryFunction compare_function) + { + bitonic_sort(compare_function, thread_key, thread_value); + } + + template + 
ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if::type + sort(Key& thread_key, Value& thread_value, + BinaryFunction compare_function) + { + // Instead of passing large values between lanes we pass indices and gather values after sorting. + unsigned int v = detail::logical_lane_id(); + bitonic_sort(compare_function, thread_key, v); + thread_value = warp_shuffle(thread_value, v, WarpSize); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(Key& thread_key, Value& thread_value, + storage_type& storage, BinaryFunction compare_function) + { + (void) storage; + sort(compare_function, thread_key, thread_value); + } + + template< + unsigned int ItemsPerThread, + class BinaryFunction, + class V = Value + > + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(sizeof(V) <= sizeof(int))>::type + sort(Key (&thread_keys)[ItemsPerThread], + Value (&thread_values)[ItemsPerThread], + BinaryFunction compare_function) + { + bitonic_sort(compare_function, thread_keys, thread_values); + } + + template< + unsigned int ItemsPerThread, + class BinaryFunction, + class V = Value + > + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if::type + sort(Key (&thread_keys)[ItemsPerThread], + Value (&thread_values)[ItemsPerThread], + BinaryFunction compare_function) + { + // Instead of passing large values between lanes we pass indices and gather values after sorting. 
+ unsigned int v[ItemsPerThread]; + ROCPRIM_UNROLL + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + v[item] = ItemsPerThread * detail::logical_lane_id() + item; + } + + bitonic_sort(compare_function, thread_keys, v); + + V copy[ItemsPerThread]; + ROCPRIM_UNROLL + for(unsigned item = 0; item < ItemsPerThread; ++item) { + copy[item] = thread_values[item]; + } + + ROCPRIM_UNROLL + for(unsigned int dst_item = 0; dst_item < ItemsPerThread; ++dst_item) { + ROCPRIM_UNROLL + for(unsigned src_item = 0; src_item < ItemsPerThread; ++src_item) { + V temp = warp_shuffle(copy[src_item], v[dst_item] / ItemsPerThread, WarpSize); + if(v[dst_item] % ItemsPerThread == src_item) + thread_values[dst_item] = temp; + } + } + } + + template< + unsigned int ItemsPerThread, + class BinaryFunction + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(Key (&thread_keys)[ItemsPerThread], + Value (&thread_values)[ItemsPerThread], + storage_type& storage, BinaryFunction compare_function) + { + (void) storage; + sort(thread_keys, thread_values, compare_function); + } +}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_WARP_DETAIL_WARP_SORT_SHUFFLE_HPP_ diff --git a/3rdparty/cub/rocprim/warp/warp_exchange.hpp b/3rdparty/cub/rocprim/warp/warp_exchange.hpp new file mode 100644 index 0000000000000000000000000000000000000000..99581091bb58efb4f5840c12d0d44e77f9d37d80 --- /dev/null +++ b/3rdparty/cub/rocprim/warp/warp_exchange.hpp @@ -0,0 +1,420 @@ +// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_WARP_WARP_EXCHANGE_HPP_ +#define ROCPRIM_WARP_WARP_EXCHANGE_HPP_ + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../intrinsics.hpp" +#include "../intrinsics/warp_shuffle.hpp" +#include "../functional.hpp" +#include "../types.hpp" + +/// \addtogroup warpmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief The \p warp_exchange class is a warp level parallel primitive which provides +/// methods for rearranging items partitioned across threads in a warp. +/// +/// \tparam T - the input type. +/// \tparam ItemsPerThread - the number of items contributed by each thread. +/// \tparam WarpSize - the number of threads in a warp. +/// +/// \par Overview +/// * The \p warp_exchange class supports the following rearrangement methods: +/// * Transposing a blocked arrangement to a striped arrangement. 
+/// * Transposing a striped arrangement to a blocked arrangement. +/// +/// \par Examples +/// \parblock +/// In the example an exchange operation is performed on a warp of 8 threads, using type +/// \p int with 4 items per thread. +/// +/// \code{.cpp} +/// __global__ void example_kernel(...) +/// { +/// constexpr unsigned int threads_per_block = 128; +/// constexpr unsigned int threads_per_warp = 8; +/// constexpr unsigned int items_per_thread = 4; +/// constexpr unsigned int warps_per_block = threads_per_block / threads_per_warp; +/// const unsigned int warp_id = hipThreadIdx_x / threads_per_warp; +/// // specialize warp_exchange for int, warp of 8 threads and 4 items per thread +/// using warp_exchange_int = rocprim::warp_exchange; +/// // allocate storage in shared memory +/// __shared__ warp_exchange_int::storage_type storage[warps_per_block]; +/// +/// int items[items_per_thread]; +/// ... +/// warp_exchange_int w_exchange; +/// w_exchange.blocked_to_striped(items, items, storage[warp_id]); +/// ... +/// } +/// \endcode +/// \endparblock +template< + class T, + unsigned int ItemsPerThread, + unsigned int WarpSize = ::rocprim::device_warp_size() +> +class warp_exchange +{ + static_assert(::rocprim::detail::is_power_of_two(WarpSize), + "Logical warp size must be a power of two."); + static_assert(WarpSize <= ::rocprim::device_warp_size(), + "Logical warp size cannot be larger than physical warp size."); + + // Struct used for creating a raw_storage object for this primitive's temporary storage. + struct storage_type_ + { + T buffer[WarpSize * ItemsPerThread]; + }; + +public: + + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by the related parallel primitive. + /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords __shared__. 
It can be aliased to + /// an externally allocated memory, or be a part of a union type with other storage types + /// to increase shared memory reusability. + #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen + using storage_type = detail::raw_storage; + #else + using storage_type = storage_type_; // only for Doxygen + #endif + + /// \brief Transposes a blocked arrangement of items to a striped arrangement + /// across the warp, using temporary storage. + /// + /// \tparam U - [inferred] the output type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// constexpr unsigned int threads_per_block = 128; + /// constexpr unsigned int threads_per_warp = 8; + /// constexpr unsigned int items_per_thread = 4; + /// constexpr unsigned int warps_per_block = threads_per_block / threads_per_warp; + /// const unsigned int warp_id = hipThreadIdx_x / threads_per_warp; + /// // specialize warp_exchange for int, warp of 8 threads and 4 items per thread + /// using warp_exchange_int = rocprim::warp_exchange; + /// // allocate storage in shared memory + /// __shared__ warp_exchange_int::storage_type storage[warps_per_block]; + /// + /// int items[items_per_thread]; + /// ... + /// warp_exchange_int w_exchange; + /// w_exchange.blocked_to_striped(items, items, storage[warp_id]); + /// ... 
+ /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void blocked_to_striped(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + storage_type& storage) + { + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + storage_type_& storage_ = storage.get(); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + storage_.buffer[flat_id * ItemsPerThread + i] = input[i]; + } + ::rocprim::wave_barrier(); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = storage_.buffer[i * WarpSize + flat_id]; + } + } + + /// \brief Transposes a blocked arrangement of items to a striped arrangement + /// across the warp, using warp shuffle operations. + /// Caution: this API is experimental. Performance might not be consistent. + /// ItemsPerThread must be a divisor of WarpSize. + /// + /// \tparam U - [inferred] the output type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// constexpr unsigned int threads_per_block = 128; + /// constexpr unsigned int threads_per_warp = 8; + /// constexpr unsigned int items_per_thread = 4; + /// constexpr unsigned int warps_per_block = threads_per_block / threads_per_warp; + /// const unsigned int warp_id = hipThreadIdx_x / threads_per_warp; + /// // specialize warp_exchange for int, warp of 8 threads and 4 items per thread + /// using warp_exchange_int = rocprim::warp_exchange; + /// + /// int items[items_per_thread]; + /// ... + /// warp_exchange_int w_exchange; + /// w_exchange.blocked_to_striped_shuffle(items, items); + /// ... 
+ /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void blocked_to_striped_shuffle(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread]) + { + static_assert(WarpSize % ItemsPerThread == 0, + "ItemsPerThread must be a divisor of WarpSize to use blocked_to_striped_shuffle"); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + U work_array[ItemsPerThread]; + + ROCPRIM_UNROLL + for(unsigned int dst_idx = 0; dst_idx < ItemsPerThread; dst_idx++) + { + ROCPRIM_UNROLL + for(unsigned int src_idx = 0; src_idx < ItemsPerThread; src_idx++) + { + const auto value = ::rocprim::warp_shuffle( + input[src_idx], + flat_id / ItemsPerThread + dst_idx * (WarpSize / ItemsPerThread) + ); + if(src_idx == flat_id % ItemsPerThread) + { + work_array[dst_idx] = value; + } + } + } + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = work_array[i]; + } + } + + /// \brief Transposes a striped arrangement of items to a blocked arrangement + /// across the warp, using temporary storage. + /// + /// \tparam U - [inferred] the output type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) 
+ /// { + /// constexpr unsigned int threads_per_block = 128; + /// constexpr unsigned int threads_per_warp = 8; + /// constexpr unsigned int items_per_thread = 4; + /// constexpr unsigned int warps_per_block = threads_per_block / threads_per_warp; + /// const unsigned int warp_id = hipThreadIdx_x / threads_per_warp; + /// // specialize warp_exchange for int, warp of 8 threads and 4 items per thread + /// using warp_exchange_int = rocprim::warp_exchange; + /// // allocate storage in shared memory + /// __shared__ warp_exchange_int::storage_type storage[warps_per_block]; + /// + /// int items[items_per_thread]; + /// ... + /// warp_exchange_int w_exchange; + /// w_exchange.striped_to_blocked(items, items, storage[warp_id]); + /// ... + /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void striped_to_blocked(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + storage_type& storage) + { + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + storage_type_& storage_ = storage.get(); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + storage_.buffer[i * WarpSize + flat_id] = input[i]; + } + ::rocprim::wave_barrier(); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = storage_.buffer[flat_id * ItemsPerThread + i]; + } + } + + /// \brief Transposes a striped arrangement of items to a blocked arrangement + /// across the warp, using warp shuffle operations. + /// Caution: this API is experimental. Performance might not be consistent. + /// ItemsPerThread must be a divisor of WarpSize. + /// + /// \tparam U - [inferred] the output type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) 
+ /// { + /// constexpr unsigned int threads_per_block = 128; + /// constexpr unsigned int threads_per_warp = 8; + /// constexpr unsigned int items_per_thread = 4; + /// constexpr unsigned int warps_per_block = threads_per_block / threads_per_warp; + /// const unsigned int warp_id = hipThreadIdx_x / threads_per_warp; + /// // specialize warp_exchange for int, warp of 8 threads and 4 items per thread + /// using warp_exchange_int = rocprim::warp_exchange; + /// + /// int items[items_per_thread]; + /// ... + /// warp_exchange_int w_exchange; + /// w_exchange.striped_to_blocked_shuffle(items, items); + /// ... + /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void striped_to_blocked_shuffle(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread]) + { + static_assert(WarpSize % ItemsPerThread == 0, + "ItemsPerThread must be a divisor of WarpSize to use striped_to_blocked_shuffle"); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + U work_array[ItemsPerThread]; + + ROCPRIM_UNROLL + for(unsigned int dst_idx = 0; dst_idx < ItemsPerThread; dst_idx++) + { + ROCPRIM_UNROLL + for(unsigned int src_idx = 0; src_idx < ItemsPerThread; src_idx++) + { + const auto value = ::rocprim::warp_shuffle( + input[src_idx], + (ItemsPerThread * flat_id + dst_idx) % WarpSize + ); + if(flat_id / (WarpSize / ItemsPerThread) == src_idx) + { + work_array[dst_idx] = value; + } + } + } + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = work_array[i]; + } + } + + /// \brief Orders \p input values according to ranks using temporary storage, + /// then writes the values to \p output in a striped manner. + /// No values in \p ranks should exists that exceed \p WarpSize*ItemsPerThread-1 . + /// \tparam U - [inferred] the output type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [in] ranks - array containing the positions. 
+ /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// constexpr unsigned int threads_per_block = 128; + /// constexpr unsigned int threads_per_warp = 8; + /// constexpr unsigned int items_per_thread = 4; + /// constexpr unsigned int warps_per_block = threads_per_block / threads_per_warp; + /// const unsigned int warp_id = hipThreadIdx_x / threads_per_warp; + /// // specialize warp_exchange for int, warp of 8 threads and 4 items per thread + /// using warp_exchange_int = rocprim::warp_exchange; + /// // allocate storage in shared memory + /// __shared__ warp_exchange_int::storage_type storage[warps_per_block]; + /// + /// int items[items_per_thread]; + /// + /// // data-type of `ranks` should be able to contain warp_size*items_per_thread unique elements + /// // unsigned short is sufficient for up to 1024*64 elements + /// unsigned short ranks[items_per_thread]; + /// ... + /// warp_exchange_int w_exchange; + /// w_exchange.scatter_to_striped(items, items, ranks, storage[warp_id]); + /// ... 
+ /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scatter_to_striped( + const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const OffsetT (&ranks)[ItemsPerThread], + storage_type& storage) + { + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + storage_type_& storage_ = storage.get(); + + ROCPRIM_UNROLL + for (unsigned int i = 0; i < ItemsPerThread; i++) + { + storage_.buffer[ranks[i]] = input[i]; + } + ::rocprim::wave_barrier(); + + ROCPRIM_UNROLL + for (unsigned int i = 0; i < ItemsPerThread; i++) + { + unsigned int item_offset = (i * WarpSize) + flat_id; + output[i] = storage_.buffer[item_offset]; + } + } +}; + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group warpmodule + +#endif // ROCPRIM_WARP_WARP_EXCHANGE_HPP_ diff --git a/3rdparty/cub/rocprim/warp/warp_load.hpp b/3rdparty/cub/rocprim/warp/warp_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..14347ef92186a1292f9eb973135d1ea64dd84574 --- /dev/null +++ b/3rdparty/cub/rocprim/warp/warp_load.hpp @@ -0,0 +1,458 @@ +// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_WARP_WARP_LOAD_HPP_ +#define ROCPRIM_WARP_WARP_LOAD_HPP_ + +#include "../config.hpp" +#include "../intrinsics.hpp" +#include "../detail/various.hpp" + +#include "warp_exchange.hpp" +#include "../block/block_load_func.hpp" + +/// \addtogroup warpmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief \p warp_load_method enumerates the methods available to load data +/// from continuous memory into a blocked/striped arrangement of items across the warp +enum class warp_load_method +{ + /// Data from continuous memory is loaded into a blocked arrangement of items. + /// \par Performance Notes: + /// * Performance decreases with increasing number of items per thread (stride + /// between reads), because of reduced memory coalescing. + warp_load_direct, + + /// A striped arrangement of data is read directly from memory. + warp_load_striped, + + /// Data from continuous memory is loaded into a blocked arrangement of items + /// using vectorization as an optimization. + /// \par Performance Notes: + /// * Performance remains high due to increased memory coalescing, provided that + /// vectorization requirements are fulfilled. Otherwise, performance will default + /// to \p warp_load_direct. + /// \par Requirements: + /// * The input offset (\p block_input) must be quad-item aligned. + /// * The following conditions will prevent vectorization and switch to default + /// \p warp_load_direct: + /// * \p ItemsPerThread is odd. + /// * The datatype \p T is not a primitive or a HIP vector type (e.g. int2, + /// int4, etc. + warp_load_vectorize, + + /// A striped arrangement of data from continuous memory is locally transposed + /// into a blocked arrangement of items. 
+ /// \par Performance Notes: + /// * Performance remains high due to increased memory coalescing, regardless of the + /// number of items per thread. + /// * Performance may be better compared to \p warp_load_direct and + /// \p warp_load_vectorize due to reordering on local memory. + warp_load_transpose, + + /// Defaults to \p warp_load_direct + default_method = warp_load_direct +}; + +/// \brief The \p warp_load class is a warp level parallel primitive which provides methods +/// for loading data from continuous memory into a blocked arrangement of items across a warp. +/// +/// \tparam T - the input/output type. +/// \tparam ItemsPerThread - the number of items to be processed by +/// each thread. +/// \tparam WarpSize - the number of threads in the warp. It must be a divisor of the +/// kernel block size. +/// \tparam Method - the method to load data. +/// +/// \par Overview +/// * The \p warp_load class has a number of different methods to load data: +/// * [warp_load_direct](\ref ::warp_load_method::warp_load_direct) +/// * [warp_load_striped](\ref ::warp_load_method::warp_load_striped) +/// * [warp_load_vectorize](\ref ::warp_load_method::warp_load_vectorize) +/// * [warp_load_transpose](\ref ::warp_load_method::warp_load_transpose) +/// +/// \par Example: +/// \parblock +/// In the example a load operation is performed on a warp of 8 threads, using type +/// \p int and 4 items per thread. +/// +/// \code{.cpp} +/// __global__ void example_kernel(int * input, ...) 
+/// { +/// constexpr unsigned int threads_per_block = 128; +/// constexpr unsigned int threads_per_warp = 8; +/// constexpr unsigned int items_per_thread = 4; +/// constexpr unsigned int warps_per_block = threads_per_block / threads_per_warp; +/// const unsigned int warp_id = hipThreadIdx_x / threads_per_warp; +/// const int offset = blockIdx.x * threads_per_block * items_per_thread +/// + warp_id * threads_per_warp * items_per_thread; +/// int items[items_per_thread]; +/// rocprim::warp_load warp_load; +/// warp_load.load(input + offset, items); +/// ... +/// } +/// \endcode +/// \endparblock +template< + class T, + unsigned int ItemsPerThread, + unsigned int WarpSize = ::rocprim::device_warp_size(), + warp_load_method Method = warp_load_method::warp_load_direct +> +class warp_load +{ + static_assert(::rocprim::detail::is_power_of_two(WarpSize), + "Logical warp size must be a power of two."); + static_assert(WarpSize <= ::rocprim::device_warp_size(), + "Logical warp size cannot be larger than physical warp size."); + +private: + using storage_type_ = typename ::rocprim::detail::empty_storage_type; + +public: + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by related parallel primitive. + /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords \p __shared__. It can be aliased to + /// an externally allocated memory, or be a part of a union with other storage types + /// to increase shared memory reusability. + #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen + using storage_type = typename ::rocprim::detail::empty_storage_type; + #else + using storage_type = storage_type_; // only for Doxygen + #endif + + /// \brief Loads data from continuous memory into an arrangement of items across the + /// warp. 
+ /// + /// \tparam InputIterator - [inferred] an iterator type for input (can be a simple + /// pointer. + /// + /// \param [in] input - the input iterator to load from. + /// \param [out] items - array that data is loaded to. + /// \param [in] storage - temporary storage for inputs. + /// + /// \par Overview + /// * The type \p T must be such that an object of type \p InputIterator + /// can be dereferenced and then implicitly converted to \p T. + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator input, + T (&items)[ItemsPerThread], + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_blocked(flat_id, input, items); + } + + /// \brief Loads data from continuous memory into an arrangement of items across the + /// warp. + /// + /// \tparam InputIterator - [inferred] an iterator type for input (can be a simple + /// pointer. + /// + /// \param [in] input - the input iterator to load from. + /// \param [out] items - array that data is loaded to. + /// \param [in] valid - maximum range of valid numbers to load. + /// \param [in] storage - temporary storage for inputs. + /// + /// \par Overview + /// * The type \p T must be such that an object of type \p InputIterator + /// can be dereferenced and then implicitly converted to \p T. 
+ template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator input, + T (&items)[ItemsPerThread], + unsigned int valid, + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_blocked(flat_id, input, items, valid); + } + + /// \brief Loads data from continuous memory into an arrangement of items across the + /// warp. + /// + /// \tparam InputIterator - [inferred] an iterator type for input (can be a simple + /// pointer. + /// + /// \param [in] input - the input iterator to load from. + /// \param [out] items - array that data is loaded to. + /// \param [in] valid - maximum range of valid numbers to load. + /// \param [in] out_of_bounds - default value assigned to out-of-bound items. + /// \param [in] storage - temporary storage for inputs. + /// + /// \par Overview + /// * The type \p T must be such that an object of type \p InputIterator + /// can be dereferenced and then implicitly converted to \p T. 
+ template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds, + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_blocked(flat_id, input, items, valid, + out_of_bounds); + } +}; + +/// @} +// end of group warpmodule + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +template< + class T, + unsigned int ItemsPerThread, + unsigned int WarpSize +> +class warp_load +{ + static_assert(::rocprim::detail::is_power_of_two(WarpSize), + "Logical warp size must be a power of two."); + static_assert(WarpSize <= ::rocprim::device_warp_size(), + "Logical warp size cannot be larger than physical warp size."); + +public: + using storage_type = typename ::rocprim::detail::empty_storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator input, + T (&items)[ItemsPerThread], + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_warp_striped(flat_id, input, items); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator input, + T (&items)[ItemsPerThread], + unsigned int valid, + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly 
converted to T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_warp_striped(flat_id, input, items, valid); + } + + template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds, + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_warp_striped(flat_id, input, items, valid, + out_of_bounds); + } +}; + +template< + class T, + unsigned int ItemsPerThread, + unsigned int WarpSize +> +class warp_load +{ + static_assert(::rocprim::detail::is_power_of_two(WarpSize), + "Logical warp size must be a power of two."); + static_assert(WarpSize <= ::rocprim::device_warp_size(), + "Logical warp size cannot be larger than physical warp size."); + +public: + using storage_type = typename ::rocprim::detail::empty_storage_type; + + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(T* input, + T (&items)[ItemsPerThread], + storage_type& /*storage*/) + { + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_blocked_vectorized(flat_id, input, items); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator input, + T (&items)[ItemsPerThread], + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_blocked(flat_id, input, items); + } + + template + ROCPRIM_DEVICE 
ROCPRIM_INLINE + void load(InputIterator input, + T (&items)[ItemsPerThread], + unsigned int valid, + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_blocked(flat_id, input, items, valid); + } + + template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds, + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_blocked(flat_id, input, items, valid, + out_of_bounds); + } +}; + +template< + class T, + unsigned int ItemsPerThread, + unsigned int WarpSize +> +class warp_load +{ + static_assert(::rocprim::detail::is_power_of_two(WarpSize), + "Logical warp size must be a power of two."); + static_assert(WarpSize <= ::rocprim::device_warp_size(), + "Logical warp size cannot be larger than physical warp size."); + +private: + using exchange_type = ::rocprim::warp_exchange; + +public: + using storage_type = typename exchange_type::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator input, + T (&items)[ItemsPerThread], + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = 
::rocprim::detail::logical_lane_id(); + block_load_direct_warp_striped(flat_id, input, items); + exchange_type().striped_to_blocked(items, items, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator input, + T (&items)[ItemsPerThread], + unsigned int valid, + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_warp_striped(flat_id, input, items, valid); + exchange_type().striped_to_blocked(items, items, storage); + } + + template< + class InputIterator, + class Default + > + ROCPRIM_DEVICE ROCPRIM_INLINE + void load(InputIterator input, + T (&items)[ItemsPerThread], + unsigned int valid, + Default out_of_bounds, + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type InputIterator " + "can be dereferenced and then implicitly converted to T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_warp_striped(flat_id, input, items, valid, + out_of_bounds); + exchange_type().striped_to_blocked(items, items, storage); + } +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_WARP_WARP_LOAD_HPP_ diff --git a/3rdparty/cub/rocprim/warp/warp_reduce.hpp b/3rdparty/cub/rocprim/warp/warp_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4cd04d578e12ad23e16a068efd2213c38326dde9 --- /dev/null +++ b/3rdparty/cub/rocprim/warp/warp_reduce.hpp @@ -0,0 +1,377 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_WARP_WARP_REDUCE_HPP_ +#define ROCPRIM_WARP_WARP_REDUCE_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" +#include "../types.hpp" + +#include "detail/warp_reduce_crosslane.hpp" +#include "detail/warp_reduce_shared_mem.hpp" + +/// \addtogroup warpmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// Select warp_reduce implementation based WarpSize +template +struct select_warp_reduce_impl +{ + typedef typename std::conditional< + // can we use crosslane (DPP or shuffle-based) implementation? 
+ detail::is_warpsize_shuffleable::value, + detail::warp_reduce_crosslane, // yes + detail::warp_reduce_shared_mem // no + >::type type; +}; + +} // end namespace detail + +/// \brief The warp_reduce class is a warp level parallel primitive which provides methods +/// for performing reduction operations on items partitioned across threads in a hardware +/// warp. +/// +/// \tparam T - the input/output type. +/// \tparam WarpSize - the size of logical warp size, which can be equal to or less than +/// the size of hardware warp (see rocprim::device_warp_size()). Reduce operations are performed +/// separately within groups determined by WarpSize. +/// \tparam UseAllReduce - input parameter to determine whether to broadcast final reduction +/// value to all threads (default is false). +/// +/// \par Overview +/// * \p WarpSize must be equal to or less than the size of hardware warp (see +/// rocprim::device_warp_size()). If it is less, reduce is performed separately within groups +/// determined by WarpSize. \n +/// For example, if \p WarpSize is 4, hardware warp is 64, reduction will be performed in logical +/// warps grouped like this: `{ {0, 1, 2, 3}, {4, 5, 6, 7 }, ..., {60, 61, 62, 63} }` +/// (thread is represented here by its id within hardware warp). +/// * Logical warp is a group of \p WarpSize consecutive threads from the same hardware warp. +/// * Supports non-commutative reduce operators. However, a reduce operator should be +/// associative. When used with non-associative functions the results may be non-deterministic +/// and/or vary in precision. +/// * Number of threads executing warp_reduce's function must be a multiple of \p WarpSize; +/// * All threads from a logical warp must be in the same hardware warp. +/// +/// \par Examples +/// \parblock +/// In the examples reduce operation is performed on groups of 16 threads, each provides +/// one \p int value, result is returned using the same variable as for input. Hardware +/// warp size is 64. 
Block (tile) size is 64. +/// +/// \code{.cpp} +/// __global__ void example_kernel(...) +/// { +/// // specialize warp_reduce for int and logical warp of 16 threads +/// using warp_reduce_int = rocprim::warp_reduce; +/// // allocate storage in shared memory +/// __shared__ warp_reduce_int::storage_type temp[4]; +/// +/// int logical_warp_id = threadIdx.x/16; +/// int value = ...; +/// // execute reduce +/// warp_reduce_int().reduce( +/// value, // input +/// value, // output +/// temp[logical_warp_id] +/// ); +/// ... +/// } +/// \endcode +/// \endparblock +template< + class T, + unsigned int WarpSize = device_warp_size(), + bool UseAllReduce = false +> +class warp_reduce +#ifndef DOXYGEN_SHOULD_SKIP_THIS + : private detail::select_warp_reduce_impl::type +#endif +{ + using base_type = typename detail::select_warp_reduce_impl::type; + + // Check if WarpSize is valid for the targets + static_assert(WarpSize <= ROCPRIM_MAX_WARP_SIZE, "WarpSize can't be greater than hardware warp size."); + +public: + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by related parallel primitive. + /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords __shared__. It can be aliased to + /// an externally allocated memory, or be a part of a union type with other storage types + /// to increase shared memory reusability. + using storage_type = typename base_type::storage_type; + + /// \brief Performs reduction across threads in a logical warp. + /// + /// \tparam BinaryFunction - type of binary function used for reduce. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. 
+ /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] reduce_op - binary operation function object that will be used for reduce. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// In the examples reduce operation is performed on groups of 16 threads, each provides + /// one \p int value, result is returned using the same variable as for input. Hardware + /// warp size is 64. Block (tile) size is 64. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize warp_reduce for int and logical warp of 16 threads + /// using warp_reduce_int = rocprim::warp_reduce; + /// // allocate storage in shared memory + /// __shared__ warp_reduce_int::storage_type temp[4]; + /// + /// int logical_warp_id = threadIdx.x/16; + /// int value = ...; + /// // execute reduction + /// warp_reduce_int().reduce( + /// value, // input + /// value, // output + /// temp[logical_warp_id], + /// rocprim::minimum() + /// ); + /// ... + /// } + /// \endcode + /// \endparblock + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto reduce(T input, + T& output, + storage_type& storage, + BinaryFunction reduce_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::reduce(input, output, storage, reduce_op); + } + + /// \brief Performs reduction across threads in a logical warp. 
+ /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto reduce(T , + T& , + storage_type& , + BinaryFunction reduce_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) reduce_op; + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort."); + return; + } + + /// \brief Performs reduction across threads in a logical warp. + /// + /// \tparam BinaryFunction - type of binary function used for reduce. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [in] valid_items - number of items that will be reduced in the warp. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] reduce_op - binary operation function object that will be used for reduce. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// In the examples reduce operation is performed on groups of 16 threads, each provides + /// one \p int value, result is returned using the same variable as for input. Hardware + /// warp size is 64. Block (tile) size is 64. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) 
+ /// { + /// // specialize warp_reduce for int and logical warp of 16 threads + /// using warp_reduce_int = rocprim::warp_reduce; + /// // allocate storage in shared memory + /// __shared__ warp_reduce_int::storage_type temp[4]; + /// + /// int logical_warp_id = threadIdx.x/16; + /// int value = ...; + /// int valid_items = 4; + /// // execute reduction + /// warp_reduce_int().reduce( + /// value, // input + /// value, // output + /// valid_items, + /// temp[logical_warp_id] + /// ); + /// ... + /// } + /// \endcode + /// \endparblock + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto reduce(T input, + T& output, + int valid_items, + storage_type& storage, + BinaryFunction reduce_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::reduce(input, output, valid_items, storage, reduce_op); + } + + /// \brief Performs reduction across threads in a logical warp. + /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto reduce(T , + T& , + int , + storage_type& , + BinaryFunction reduce_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) reduce_op; + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort."); + return; + } + + /// \brief Performs head-segmented reduction across threads in a logical warp. + /// + /// \tparam Flag - type of head flags. Must be contextually convertible to \p bool. + /// \tparam BinaryFunction - type of binary function used for reduce. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [in] flag - thread head flag, \p true flags mark beginnings of segments. 
+ /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] reduce_op - binary operation function object that will be used for reduce. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto head_segmented_reduce(T input, + T& output, + Flag flag, + storage_type& storage, + BinaryFunction reduce_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::head_segmented_reduce(input, output, flag, storage, reduce_op); + } + + /// \brief Performs head-segmented reduction across threads in a logical warp. + /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto head_segmented_reduce(T , + T& , + Flag , + storage_type& , + BinaryFunction reduce_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) reduce_op; + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort."); + return; + } + + /// \brief Performs tail-segmented reduction across threads in a logical warp. + /// + /// \tparam Flag - type of tail flags. Must be contextually convertible to \p bool. + /// \tparam BinaryFunction - type of binary function used for reduce. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. 
+ /// \param [in] flag - thread tail flag, \p true flags mark ends of segments. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] reduce_op - binary operation function object that will be used for reduce. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto tail_segmented_reduce(T input, + T& output, + Flag flag, + storage_type& storage, + BinaryFunction reduce_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::tail_segmented_reduce(input, output, flag, storage, reduce_op); + } + + /// \brief Performs tail-segmented reduction across threads in a logical warp. + /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto tail_segmented_reduce(T , + T& , + Flag , + storage_type& , + BinaryFunction reduce_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) reduce_op; + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. 
Aborting warp sort."); + return; + } +}; + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group warpmodule + +#endif // ROCPRIM_WARP_WARP_REDUCE_HPP_ diff --git a/3rdparty/cub/rocprim/warp/warp_scan.hpp b/3rdparty/cub/rocprim/warp/warp_scan.hpp new file mode 100644 index 0000000000000000000000000000000000000000..116d10360297d520b961fc8e83bb17810d44643b --- /dev/null +++ b/3rdparty/cub/rocprim/warp/warp_scan.hpp @@ -0,0 +1,686 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_WARP_WARP_SCAN_HPP_ +#define ROCPRIM_WARP_WARP_SCAN_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" +#include "../types.hpp" + +#include "detail/warp_scan_crosslane.hpp" +#include "detail/warp_scan_shared_mem.hpp" + +/// \addtogroup warpmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// Select warp_scan implementation based WarpSize +template +struct select_warp_scan_impl +{ + typedef typename std::conditional< + // can we use crosslane (DPP or shuffle-based) implementation? + detail::is_warpsize_shuffleable::value, + detail::warp_scan_crosslane, // yes + detail::warp_scan_shared_mem // no + >::type type; +}; + +} // end namespace detail + +/// \brief The warp_scan class is a warp level parallel primitive which provides methods +/// for performing inclusive and exclusive scan operations of items partitioned across +/// threads in a hardware warp. +/// +/// \tparam T - the input/output type. +/// \tparam WarpSize - the size of logical warp size, which can be equal to or less than +/// the size of hardware warp (see rocprim::device_warp_size()). Scan operations are performed +/// separately within groups determined by WarpSize. +/// +/// \par Overview +/// * \p WarpSize must be equal to or less than the size of hardware warp (see +/// rocprim::device_warp_size()). If it is less, scan is performed separately within groups +/// determined by WarpSize. \n +/// For example, if \p WarpSize is 4, hardware warp is 64, scan will be performed in logical +/// warps grouped like this: `{ {0, 1, 2, 3}, {4, 5, 6, 7 }, ..., {60, 61, 62, 63} }` +/// (thread is represented here by its id within hardware warp). +/// * Logical warp is a group of \p WarpSize consecutive threads from the same hardware warp. +/// * Supports non-commutative scan operators. However, a scan operator should be +/// associative. 
When used with non-associative functions the results may be non-deterministic +/// and/or vary in precision. +/// * Number of threads executing warp_scan's function must be a multiple of \p WarpSize; +/// * All threads from a logical warp must be in the same hardware warp. +/// +/// \par Examples +/// \parblock +/// In the examples scan operation is performed on groups of 16 threads, each provides +/// one \p int value, result is returned using the same variable as for input. Hardware +/// warp size is 64. Block (tile) size is 64. +/// +/// \code{.cpp} +/// __global__ void example_kernel(...) +/// { +/// // specialize warp_scan for int and logical warp of 16 threads +/// using warp_scan_int = rocprim::warp_scan; +/// // allocate storage in shared memory +/// __shared__ warp_scan_int::storage_type temp[4]; +/// +/// int logical_warp_id = threadIdx.x/16; +/// int value = ...; +/// // execute inclusive scan +/// warp_scan_int().inclusive_scan( +/// value, // input +/// value, // output +/// temp[logical_warp_id] +/// ); +/// ... +/// } +/// \endcode +/// \endparblock +template< + class T, + unsigned int WarpSize = device_warp_size() +> +class warp_scan +#ifndef DOXYGEN_SHOULD_SKIP_THIS + : private detail::select_warp_scan_impl::type +#endif +{ + using base_type = typename detail::select_warp_scan_impl::type; + + // Check if WarpSize is valid for the targets + static_assert(WarpSize <= ROCPRIM_MAX_WARP_SIZE, "WarpSize can't be greater than hardware warp size."); + +public: + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by related parallel primitive. + /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords __shared__. 
It can be aliased to + /// an externally allocated memory, or be a part of a union type with other storage types + /// to increase shared memory reusability. + using storage_type = typename base_type::storage_type; + + /// \brief Performs inclusive scan across threads in a logical warp. + /// + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present inclusive min scan operations performed on groups of 32 threads, + /// each provides one \p float value, result is returned using the same variable as for input. + /// Hardware warp size is 64. Block (tile) size is 256. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) 
// blockDim.x = 256 + /// { + /// // specialize warp_scan for float and logical warp of 32 threads + /// using warp_scan_f = rocprim::warp_scan; + /// // allocate storage in shared memory + /// __shared__ warp_scan_float::storage_type temp[8]; // 256/32 = 8 + /// + /// int logical_warp_id = threadIdx.x/32; + /// float value = ...; + /// // execute inclusive min scan + /// warp_scan_float().inclusive_scan( + /// value, // input + /// value, // output + /// temp[logical_warp_id], + /// rocprim::minimum() + /// ); + /// ... + /// } + /// \endcode + /// + /// If the input values across threads in a block/tile are {1, -2, 3, -4, ..., 255, -256}, then + /// output values in the first logical warp will be {1, -2, -2, -4, ..., -32}, in the second: + /// {33, -34, -34, -36, ..., -64} etc. + /// \endparblock + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto inclusive_scan(T input, + T& output, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::inclusive_scan(input, output, storage, scan_op); + } + + /// \brief Performs inclusive scan across threads in a logical warp. + /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto inclusive_scan(T , + T& , + storage_type& , + BinaryFunction scan_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) scan_op; + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size . Aborting warp sort."); + return; + } + + /// \brief Performs inclusive scan and reduction across threads in a logical warp. + /// + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. 
+ /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [out] reduction - result of reducing of all \p input values in logical warp. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present inclusive prefix sum operations performed on groups of 64 threads, + /// each thread provides one \p int value. Hardware warp size is 64. Block (tile) size is 256. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) // blockDim.x = 256 + /// { + /// // specialize warp_scan for int and logical warp of 64 threads + /// using warp_scan_int = rocprim::warp_scan; + /// // allocate storage in shared memory + /// __shared__ warp_scan_int::storage_type temp[4]; // 256/64 = 4 + /// + /// int logical_warp_id = threadIdx.x/64; + /// int input = ...; + /// int output, reduction; + /// // inclusive prefix sum + /// warp_scan_int().inclusive_scan( + /// input, + /// output, + /// reduction, + /// temp[logical_warp_id] + /// ); + /// ... + /// } + /// \endcode + /// + /// If the \p input values across threads in a block/tile are {1, 1, 1, 1, ..., 1, 1}, then + /// \p output values in the every logical warp will be {1, 2, 3, 4, ..., 64}. + /// The \p reduction will be equal \p 64. 
+ /// \endparblock + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto inclusive_scan(T input, + T& output, + T& reduction, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::inclusive_scan(input, output, reduction, storage, scan_op); + } + + /// \brief Performs inclusive scan and reduction across threads in a logical warp. + /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto inclusive_scan(T , + T& , + T& , + storage_type& , + BinaryFunction scan_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) scan_op; + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size . Aborting warp sort."); + return; + } + + /// \brief Performs exclusive scan across threads in a logical warp. + /// + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [in] init - initial value used to start the exclusive scan. Should be the same + /// for all threads in a logical warp. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). 
+ /// + /// \par Examples + /// \parblock + /// The examples present exclusive min scan operations performed on groups of 32 threads, + /// each provides one \p float value, result is returned using the same variable as for input. + /// Hardware warp size is 64. Block (tile) size is 256. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) // blockDim.x = 256 + /// { + /// // specialize warp_scan for float and logical warp of 32 threads + /// using warp_scan_f = rocprim::warp_scan; + /// // allocate storage in shared memory + /// __shared__ warp_scan_float::storage_type temp[8]; // 256/32 = 8 + /// + /// int logical_warp_id = threadIdx.x/32; + /// float value = ...; + /// // execute exclusive min scan + /// warp_scan_float().exclusive_scan( + /// value, // input + /// value, // output + /// 100.0f, // init + /// temp[logical_warp_id], + /// rocprim::minimum() + /// ); + /// ... + /// } + /// \endcode + /// + /// If the initial value is \p 100 and input values across threads in a block/tile are + /// {1, -2, 3, -4, ..., 255, -256}, then output values in the first logical + /// warp will be {100, 1, -2, -2, -4, ..., -30}, in the second: + /// {100, 33, -34, -34, -36, ..., -62} etc. + /// \endparblock + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto exclusive_scan(T input, + T& output, + T init, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::exclusive_scan(input, output, init, storage, scan_op); + } + + /// \brief Performs exclusive scan across threads in a logical warp. 
+ /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto exclusive_scan(T , + T& , + T , + storage_type& , + BinaryFunction scan_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) scan_op; + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size . Aborting warp sort."); + return; + } + + /// \brief Performs exclusive scan and reduction across threads in a logical warp. + /// + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] output - reference to a thread output value. May be aliased with \p input. + /// \param [in] init - initial value used to start the exclusive scan. Should be the same + /// for all threads in a logical warp. + /// \param [out] reduction - result of reducing of all \p input values in logical warp. + /// \p init value is not included in the reduction. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present exclusive prefix sum operations performed on groups of 64 threads, + /// each thread provides one \p int value. Hardware warp size is 64. Block (tile) size is 256. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) 
// blockDim.x = 256 + /// { + /// // specialize warp_scan for int and logical warp of 64 threads + /// using warp_scan_int = rocprim::warp_scan; + /// // allocate storage in shared memory + /// __shared__ warp_scan_int::storage_type temp[4]; // 256/64 = 4 + /// + /// int logical_warp_id = threadIdx.x/64; + /// int input = ...; + /// int output, reduction; + /// // exclusive prefix sum + /// warp_scan_int().exclusive_scan( + /// input, + /// output, + /// 10, // init + /// reduction, + /// temp[logical_warp_id] + /// ); + /// ... + /// } + /// \endcode + /// + /// If the initial value is \p 10 and \p input values across threads in a block/tile are + /// {1, 1, ..., 1, 1}, then \p output values in every logical warp will be + /// {10, 11, 12, 13, ..., 73}. The \p reduction will be 64. + /// \endparblock + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto exclusive_scan(T input, + T& output, + T init, + T& reduction, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::exclusive_scan(input, output, init, reduction, storage, scan_op); + } + + /// \brief Performs exclusive scan and reduction across threads in a logical warp. + /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto exclusive_scan(T , + T& , + T , + T& , + storage_type& , + BinaryFunction scan_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) scan_op; + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size . Aborting warp sort."); + return; + } + + /// \brief Performs inclusive and exclusive scan operations across threads + /// in a logical warp. + /// + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. 
+ /// + /// \param [in] input - thread input value. + /// \param [out] inclusive_output - reference to a thread inclusive-scan output value. + /// \param [out] exclusive_output - reference to a thread exclusive-scan output value. + /// \param [in] init - initial value used to start the exclusive scan. Should be the same + /// for all threads in a logical warp. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present min inclusive and exclusive scan operations performed on groups of 32 threads, + /// each provides one \p float value, result is returned using the same variable as for input. + /// Hardware warp size is 64. Block (tile) size is 256. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) // blockDim.x = 256 + /// { + /// // specialize warp_scan for float and logical warp of 32 threads + /// using warp_scan_f = rocprim::warp_scan; + /// // allocate storage in shared memory + /// __shared__ warp_scan_float::storage_type temp[8]; // 256/32 = 8 + /// + /// int logical_warp_id = threadIdx.x/32; + /// float input = ...; + /// float ex_output, in_output; + /// // execute exclusive min scan + /// warp_scan_float().scan( + /// input, + /// in_output, + /// ex_output, + /// 100.0f, // init + /// temp[logical_warp_id], + /// rocprim::minimum() + /// ); + /// ... 
+ /// } + /// \endcode + /// + /// If the initial value is \p 100 and input values across threads in a block/tile are + /// {1, -2, 3, -4, ..., 255, -256}, then \p in_output values in the first logical + /// warp will be {1, -2, -2, -4, ..., -32}, in the second: + /// {33, -34, -34, -36, ..., -64} and so forth, \p ex_output values in the first + /// logical warp will be {100, 1, -2, -2, -4, ..., -30}, in the second: + /// {100, 33, -34, -34, -36, ..., -62} etc. + /// \endparblock + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto scan(T input, + T& inclusive_output, + T& exclusive_output, + T init, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::scan(input, inclusive_output, exclusive_output, init, storage, scan_op); + } + + /// \brief Performs inclusive and exclusive scan operations across threads + /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto scan(T , + T& , + T& , + T , + storage_type& , + BinaryFunction scan_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) scan_op; + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size . Aborting warp sort."); + return; + } + + /// \brief Performs inclusive and exclusive scan operations, and reduction across + /// threads in a logical warp. + /// + /// \tparam BinaryFunction - type of binary function used for scan. Default type + /// is rocprim::plus. + /// + /// \param [in] input - thread input value. + /// \param [out] inclusive_output - reference to a thread inclusive-scan output value. + /// \param [out] exclusive_output - reference to a thread exclusive-scan output value. + /// \param [in] init - initial value used to start the exclusive scan. 
Should be the same + /// for all threads in a logical warp. + /// \param [out] reduction - result of reducing of all \p input values in logical warp. + /// \p init value is not included in the reduction. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// \param [in] scan_op - binary operation function object that will be used for scan. + /// The signature of the function should be equivalent to the following: + /// T f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Examples + /// \parblock + /// The examples present inclusive and exclusive prefix sum operations performed on groups + /// of 64 threads, each thread provides one \p int value. Hardware warp size is 64. + /// Block (tile) size is 256. + /// + /// \code{.cpp} + /// __global__ void example_kernel(...) // blockDim.x = 256 + /// { + /// // specialize warp_scan for int and logical warp of 64 threads + /// using warp_scan_int = rocprim::warp_scan; + /// // allocate storage in shared memory + /// __shared__ warp_scan_int::storage_type temp[4]; // 256/64 = 4 + /// + /// int logical_warp_id = threadIdx.x/64; + /// int input = ...; + /// int in_output, ex_output, reduction; + /// // inclusive and exclusive prefix sum + /// warp_scan_int().scan( + /// input, + /// in_output, + /// ex_output, + /// init, + /// reduction, + /// temp[logical_warp_id] + /// ); + /// ... + /// } + /// \endcode + /// + /// If the initial value is \p 10 and \p input values across threads in a block/tile are + /// {1, 1, ..., 1, 1}, then \p in_output values in every logical warp will be + /// {1, 2, 3, 4, ..., 63, 64}, and \p ex_output values in every logical warp will + /// be {10, 11, 12, 13, ..., 73}. 
The \p reduction will be 64. + /// \endparblock + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto scan(T input, + T& inclusive_output, + T& exclusive_output, + T init, + T& reduction, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::scan( + input, inclusive_output, exclusive_output, init, reduction, + storage, scan_op + ); + } + + /// \brief Performs inclusive and exclusive scan operations across threads + /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto scan(T , + T& , + T& , + T , + T& , + storage_type& , + BinaryFunction scan_op = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) scan_op; + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size . Aborting warp sort."); + return; + } + + /// \brief Broadcasts value from one thread to all threads in logical warp. + /// + /// \param [in] input - value to broadcast. + /// \param [in] src_lane - id of the thread whose value should be broadcasted + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto broadcast(T input, + const unsigned int src_lane, + storage_type& storage) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), T>::type + { + return base_type::broadcast(input, src_lane, storage); + } + + /// \brief Broadcasts value from one thread to all threads in logical warp. 
+ /// Invalid Warp Size + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto broadcast(T , + const unsigned int , + storage_type& ) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), T>::type + { + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort."); + return T(); + } + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +protected: + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + return base_type::to_exclusive(inclusive_input, exclusive_output, storage); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto to_exclusive(T , T& , storage_type&) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort."); + return; + } +#endif +}; + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group warpmodule + +#endif // ROCPRIM_WARP_WARP_SCAN_HPP_ diff --git a/3rdparty/cub/rocprim/warp/warp_sort.hpp b/3rdparty/cub/rocprim/warp/warp_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a4b60396c324e831e049dc5186b0267fe2d40745 --- /dev/null +++ b/3rdparty/cub/rocprim/warp/warp_sort.hpp @@ -0,0 +1,523 @@ +// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_WARP_WARP_SORT_HPP_ +#define ROCPRIM_WARP_WARP_SORT_HPP_ + +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "../intrinsics.hpp" +#include "../functional.hpp" + +#include "detail/warp_sort_shuffle.hpp" + +/// \addtogroup warpmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief The warp_sort class provides warp-wide methods for computing a parallel +/// sort of items across thread warps. This class currently implements parallel +/// bitonic sort, and only accepts warp sizes that are powers of two. +/// +/// \tparam Key Data type for parameter Key +/// \tparam WarpSize [optional] The number of threads in a warp +/// \tparam Value [optional] Data type for parameter Value. By default, it's empty_type +/// +/// \par Overview +/// * \p WarpSize must be power of two. 
+/// * \p WarpSize must be equal to or less than the size of hardware warp (see +/// rocprim::device_warp_size()). If it is less, sort is performed separately within groups +/// determined by WarpSize. +/// For example, if \p WarpSize is 4, hardware warp is 64, sort will be performed in logical +/// warps grouped like this: `{ {0, 1, 2, 3}, {4, 5, 6, 7 }, ..., {60, 61, 62, 63} }` +/// (thread is represented here by its id within hardware warp). +/// * Accepts custom compare_functions for sorting across a warp. +/// * Number of threads executing warp_sort's function must be a multiple of \p WarpSize. +/// +/// \par Example: +/// \parblock +/// Every thread within the warp uses the warp_sort class by first specializing the +/// warp_sort type, and instantiating an object that will be used to invoke a +/// member function. +/// +/// \code{.cpp} +/// __global__ void example_kernel(...) +/// { +/// const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; +/// +/// int value = input[i]; +/// rocprim::warp_sort wsort; +/// wsort.sort(value); +/// input[i] = value; +/// } +/// \endcode +/// +/// Below is a snippet demonstrating how to pass a custom compare function: +/// \code{.cpp} +/// __device__ bool customCompare(const int& a, const int& b) +/// { +/// return a < b; +/// } +/// ... +/// __global__ void example_kernel(...) 
+/// { +/// const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; +/// +/// int value = input[i]; +/// rocprim::warp_sort wsort; +/// wsort.sort(value, customCompare); +/// input[i] = value; +/// } +/// \endcode +/// \endparblock +template< + class Key, + unsigned int WarpSize = device_warp_size(), + class Value = empty_type +> +class warp_sort : detail::warp_sort_shuffle +{ + typedef typename detail::warp_sort_shuffle base_type; + + // Check if WarpSize is valid for the targets + static_assert(WarpSize <= ROCPRIM_MAX_WARP_SIZE, "WarpSize can't be greater than hardware warp size."); + +public: + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by related parallel primitive. + /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords \p __shared__. It can be aliased to + /// an externally allocated memory, or be a part of a union with other storage types + /// to increase shared memory reusability. + typedef typename base_type::storage_type storage_type; + + /// \brief Warp sort for any data type. + /// + /// \tparam BinaryFunction - type of binary function used for sort. Default type + /// is rocprim::less. + /// + /// \param thread_key - input/output to pass to other threads + /// \param compare_function - binary operation function object that will be used for sort. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. 
+ template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key& thread_key, + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::sort(thread_key, compare_function); + } + + /// \brief Warp sort for any data type. + /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key& , + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) compare_function; // disables unused parameter warning + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort."); + return; + } + + /// \brief Warp sort for any data type. + /// + /// \tparam BinaryFunction - type of binary function used for sort. Default type + /// is rocprim::less. + /// + /// \param thread_keys - input/output keys to pass to other threads + /// \param compare_function - binary operation function object that will be used for sort. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::less, + unsigned int FunctionWarpSize = WarpSize + > + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key (&thread_keys)[ItemsPerThread], + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::sort(thread_keys, compare_function); + } + + /// \brief Warp sort for any data type. 
+ /// Invalid Warp Size + template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::less, + unsigned int FunctionWarpSize = WarpSize + > + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key (&thread_keys)[ItemsPerThread], + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) thread_keys; // disables unused parameter warning + (void) compare_function; // disables unused parameter warning + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort."); + return; + } + + /// \brief Warp sort for any data type using temporary storage. + /// + /// \tparam BinaryFunction - type of binary function used for sort. Default type + /// is rocprim::less. + /// + /// \param thread_key - input/output to pass to other threads + /// \param storage - temporary storage for inputs + /// \param compare_function - binary operation function object that will be used for sort. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// int value = ...; + /// using warp_sort_int = rp::warp_sort; + /// warp_sort_int wsort; + /// __shared__ typename warp_sort_int::storage_type storage; + /// wsort.sort(value, storage); + /// ... 
+ /// } + /// \endcode + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key& thread_key, + storage_type& storage, + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::sort( + thread_key, storage, compare_function + ); + } + + /// \brief Warp sort for any data type using temporary storage. + /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key& , + storage_type& , + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) compare_function; // disables unused parameter warning + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort."); + return; + } + + + /// \brief Warp sort for any data type using temporary storage. + /// + /// \tparam BinaryFunction - type of binary function used for sort. Default type + /// is rocprim::less. + /// + /// \param thread_keys - input/output keys to pass to other threads + /// \param storage - temporary storage for inputs + /// \param compare_function - binary operation function object that will be used for sort. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) 
+ /// { + /// int value = ...; + /// using warp_sort_int = rp::warp_sort; + /// warp_sort_int wsort; + /// __shared__ typename warp_sort_int::storage_type storage; + /// wsort.sort(value, storage); + /// ... + /// } + /// \endcode + template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::less, + unsigned int FunctionWarpSize = WarpSize + > + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key (&thread_keys)[ItemsPerThread], + storage_type& storage, + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::sort( + thread_keys, storage, compare_function + ); + } + + /// \brief Warp sort for any data type using temporary storage. + /// Invalid Warp Size + template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::less, + unsigned int FunctionWarpSize = WarpSize + > + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key (&thread_keys)[ItemsPerThread], + storage_type& , + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) thread_keys; // disables unused parameter warning + (void) compare_function; // disables unused parameter warning + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort."); + return; + } + + /// \brief Warp sort by key for any data type. + /// + /// \tparam BinaryFunction - type of binary function used for sort. Default type + /// is rocprim::less. + /// + /// \param thread_key - input/output key to pass to other threads + /// \param thread_value - input/output value to pass to other threads + /// \param compare_function - binary operation function object that will be used for sort. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b);. 
The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key& thread_key, + Value& thread_value, + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::sort( + thread_key, thread_value, compare_function + ); + } + + /// \brief Warp sort by key for any data type. + /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key& , + Value& , + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) compare_function; // disables unused parameter warning + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort."); + return; + } + + /// \brief Warp sort by key for any data type. + /// + /// \tparam BinaryFunction - type of binary function used for sort. Default type + /// is rocprim::less. + /// + /// \param thread_keys - input/output keys to pass to other threads + /// \param thread_values - input/outputs values to pass to other threads + /// \param compare_function - binary operation function object that will be used for sort. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. 
+ template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::less, + unsigned int FunctionWarpSize = WarpSize + > + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key (&thread_keys)[ItemsPerThread], + Value (&thread_values)[ItemsPerThread], + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::sort( + thread_keys, thread_values, compare_function + ); + } + + /// \brief Warp sort by key for any data type. + /// Invalid Warp Size + template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::less, + unsigned int FunctionWarpSize = WarpSize + > + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key (&thread_keys)[ItemsPerThread], + Value (&thread_values)[ItemsPerThread], + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) thread_keys; // disables unused parameter warning + (void) thread_values; // disables unused parameter warning + (void) compare_function; // disables unused parameter warning + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort."); + return; + } + + /// \brief Warp sort by key for any data type using temporary storage. + /// + /// \tparam BinaryFunction - type of binary function used for sort. Default type + /// is rocprim::less. + /// + /// \param thread_key - input/output key to pass to other threads + /// \param thread_value - input/output value to pass to other threads + /// \param storage - temporary storage for inputs + /// \param compare_function - binary operation function object that will be used for sort. + /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. 
+ /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// int value = ...; + /// using warp_sort_int = rp::warp_sort; + /// warp_sort_int wsort; + /// __shared__ typename warp_sort_int::storage_type storage; + /// wsort.sort(key, value, storage); + /// ... + /// } + /// \endcode + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key& thread_key, + Value& thread_value, + storage_type& storage, + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::sort( + thread_key, thread_value, storage, compare_function + ); + } + + /// \brief Warp sort by key for any data type using temporary storage. + /// Invalid Warp Size + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key& , + Value& , + storage_type& , + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) compare_function; // disables unused parameter warning + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort."); + return; + } + + + /// \brief Warp sort by key for any data type using temporary storage. + /// + /// \tparam BinaryFunction - type of binary function used for sort. Default type + /// is rocprim::less. + /// + /// \param thread_keys - input/output keys to pass to other threads + /// \param thread_values - input/output values to pass to other threads + /// \param storage - temporary storage for inputs + /// \param compare_function - binary operation function object that will be used for sort. 
+ /// The signature of the function should be equivalent to the following: + /// bool f(const T &a, const T &b);. The signature does not need to have + /// const &, but function object must not modify the objects passed to it. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// int value = ...; + /// using warp_sort_int = rp::warp_sort; + /// warp_sort_int wsort; + /// __shared__ typename warp_sort_int::storage_type storage; + /// wsort.sort(key, value, storage); + /// ... + /// } + /// \endcode + template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::less, + unsigned int FunctionWarpSize = WarpSize + > + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key (&thread_keys)[ItemsPerThread], + Value (&thread_values)[ItemsPerThread], + storage_type& storage, + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + { + base_type::sort( + thread_keys, thread_values, storage, compare_function + ); + } + + /// \brief Warp sort by key for any data type using temporary storage. 
+ /// Invalid Warp Size + template< + unsigned int ItemsPerThread, + class BinaryFunction = ::rocprim::less, + unsigned int FunctionWarpSize = WarpSize + > + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Key (&thread_keys)[ItemsPerThread], + Value (&thread_values)[ItemsPerThread], + storage_type& , + BinaryFunction compare_function = BinaryFunction()) + -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + { + (void) thread_keys; // disables unused parameter warning + (void) thread_values; // disables unused parameter warning + (void) compare_function; // disables unused parameter warning + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort."); + return; + } +}; + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group warpmodule + +#endif // ROCPRIM_WARP_WARP_SORT_HPP_ diff --git a/3rdparty/cub/rocprim/warp/warp_store.hpp b/3rdparty/cub/rocprim/warp/warp_store.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4155aa2894c7a6fa22e8cb33c74cfc239d61acc2 --- /dev/null +++ b/3rdparty/cub/rocprim/warp/warp_store.hpp @@ -0,0 +1,373 @@ +// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_WARP_WARP_STORE_HPP_ +#define ROCPRIM_WARP_WARP_STORE_HPP_ + +#include "../config.hpp" +#include "../intrinsics.hpp" +#include "../detail/various.hpp" + +#include "warp_exchange.hpp" +#include "../block/block_store_func.hpp" + +/// \addtogroup warpmodule +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief \p warp_store_method enumerates the methods available to store a blocked/striped +/// arrangement of items into a blocked/striped arrangement in continuous memory +enum class warp_store_method +{ + /// A blocked arrangement of items is stored into a blocked arrangement on continuous + /// memory. + /// \par Performance Notes: + /// * Performance decreases with increasing number of items per thread (stride + /// between reads), because of reduced memory coalescing. + warp_store_direct, + + /// A striped arrangement of items is stored into a blocked arrangement on continuous + /// memory. + warp_store_striped, + + /// A blocked arrangement of items is stored into a blocked arrangement on continuous + /// memory using vectorization as an optimization. + /// \par Performance Notes: + /// * Performance remains high due to increased memory coalescing, provided that + /// vectorization requirements are fulfilled. Otherwise, performance will default + /// to \p warp_store_direct. + /// \par Requirements: + /// * The output offset (\p block_output) must be quad-item aligned. 
+ /// * The following conditions will prevent vectorization and switch to default + /// \p warp_store_direct: + /// * \p ItemsPerThread is odd. + /// * The datatype \p T is not a primitive or a HIP vector type (e.g. int2, + /// int4, etc. + warp_store_vectorize, + + /// A blocked arrangement of items is locally transposed and stored as a striped + /// arrangement of data on continuous memory. + /// \par Performance Notes: + /// * Performance remains high due to increased memory coalescing, regardless of the + /// number of items per thread. + /// * Performance may be better compared to \p warp_store_direct and + /// \p warp_store_vectorize due to reordering on local memory. + warp_store_transpose, + + /// Defaults to \p warp_store_direct + default_method = warp_store_direct +}; + +/// \brief The \p warp_store class is a warp level parallel primitive which provides methods +/// for storing an arrangement of items into a blocked/striped arrangement on continous memory. +/// +/// \tparam T - the output/output type. +/// \tparam ItemsPerThread - the number of items to be processed by +/// each thread. +/// \tparam WarpSize - the number of threads in a warp. It must be a divisor of the +/// kernel block size. +/// \tparam Method - the method to store data. +/// +/// \par Overview +/// * The \p warp_store class has a number of different methods to store data: +/// * [warp_store_direct](\ref ::warp_store_method::warp_store_direct) +/// * [warp_store_striped](\ref ::warp_store_method::warp_store_striped) +/// * [warp_store_vectorize](\ref ::warp_store_method::warp_store_vectorize) +/// * [warp_store_transpose](\ref ::warp_store_method::warp_store_transpose) +/// +/// \par Example: +/// \parblock +/// In the example a store operation is performed on a warp of 8 threads, using type +/// \p int and 4 items per thread. +/// +/// \code{.cpp} +/// __global__ void example_kernel(int * output, ...) 
+/// { +/// constexpr unsigned int threads_per_block = 128; +/// constexpr unsigned int threads_per_warp = 8; +/// constexpr unsigned int items_per_thread = 4; +/// constexpr unsigned int warps_per_block = threads_per_block / threads_per_warp; +/// const unsigned int warp_id = hipThreadIdx_x / threads_per_warp; +/// const int offset = blockIdx.x * threads_per_block * items_per_thread +/// + warp_id * threads_per_warp * items_per_thread; +/// int items[items_per_thread]; +/// rocprim::warp_store warp_store; +/// warp_store.store(output + offset, items); +/// ... +/// } +/// \endcode +/// \endparblock +template< + class T, + unsigned int ItemsPerThread, + unsigned int WarpSize = ::rocprim::device_warp_size(), + warp_store_method Method = warp_store_method::warp_store_direct +> +class warp_store +{ + static_assert(::rocprim::detail::is_power_of_two(WarpSize), + "Logical warp size must be a power of two."); + static_assert(WarpSize <= ::rocprim::device_warp_size(), + "Logical warp size cannot be larger than physical warp size."); + +private: + using storage_type_ = typename ::rocprim::detail::empty_storage_type; + +public: + /// \brief Struct used to allocate a temporary memory that is required for thread + /// communication during operations provided by related parallel primitive. + /// + /// Depending on the implemention the operations exposed by parallel primitive may + /// require a temporary storage for thread communication. The storage should be allocated + /// using keywords \p __shared__. It can be aliased to + /// an externally allocated memory, or be a part of a union with other storage types + /// to increase shared memory reusability. 
+ #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen + using storage_type = typename ::rocprim::detail::empty_storage_type; + #else + using storage_type = storage_type_; // only for Doxygen + #endif + + /// \brief Stores an arrangement of items from across the warp into an + /// arrangement on continuous memory. + /// + /// \tparam OutputIterator - [inferred] an iterator type for output (can be a simple + /// pointer. + /// + /// \param [out] block_output - the output iterator to store to. + /// \param [in] items - array that data is read from. + /// \param [in] storage - temporary storage for outputs. + /// + /// \par Overview + /// * The type \p T must be such that an object of type \p OutputIterator + /// can be dereferenced and then implicitly assigned from \p T. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void store(OutputIterator output, + T (&items)[ItemsPerThread], + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type OutputIterator " + "can be dereferenced and then implicitly assigned from T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_store_direct_blocked(flat_id, output, items); + } + + /// \brief Stores an arrangement of items from across the warp into an + /// arrangement on continuous memory, which is guarded by range \p valid, + /// using temporary storage + /// + /// \tparam OutputIterator - [inferred] an iterator type for output (can be a simple + /// pointer. + /// + /// \param [out] block_output - the output iterator to store to. + /// \param [in] items - array that data is read from. + /// \param [in] valid - maximum range of valid numbers to read. 
+ /// \param [in] storage - temporary storage for outputs. + /// + /// \par Overview + /// * The type \p T must be such that an object of type \p OutputIterator + /// can be dereferenced and then implicitly assigned from \p T. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void store(OutputIterator output, + T (&items)[ItemsPerThread], + unsigned int valid, + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type OutputIterator " + "can be dereferenced and then implicitly assigned from T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_store_direct_blocked(flat_id, output, items, valid); + } +}; + +/// @} +// end of group warpmodule + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +template< + class T, + unsigned int ItemsPerThread, + unsigned int WarpSize +> +class warp_store +{ + static_assert(::rocprim::detail::is_power_of_two(WarpSize), + "Logical warp size must be a power of two."); + static_assert(WarpSize <= ::rocprim::device_warp_size(), + "Logical warp size cannot be larger than physical warp size."); + +public: + using storage_type = typename ::rocprim::detail::empty_storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void store(OutputIterator output, + T (&items)[ItemsPerThread], + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type OutputIterator " + "can be dereferenced and then implicitly assigned from T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_store_direct_warp_striped(flat_id, output, items); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + 
void store(OutputIterator output, + T (&items)[ItemsPerThread], + unsigned int valid, + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type OutputIterator " + "can be dereferenced and then implicitly assigned from T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_store_direct_warp_striped(flat_id, output, items, valid); + } +}; + +template< + class T, + unsigned int ItemsPerThread, + unsigned int WarpSize +> +class warp_store +{ + static_assert(::rocprim::detail::is_power_of_two(WarpSize), + "Logical warp size must be a power of two."); + static_assert(WarpSize <= ::rocprim::device_warp_size(), + "Logical warp size cannot be larger than physical warp size."); + +public: + using storage_type = typename ::rocprim::detail::empty_storage_type; + + ROCPRIM_DEVICE ROCPRIM_INLINE + void store(T* output, + T (&items)[ItemsPerThread], + storage_type& /*storage*/) + { + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_store_direct_blocked_vectorized(flat_id, output, items); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void store(OutputIterator output, + T (&items)[ItemsPerThread], + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type OutputIterator " + "can be dereferenced and then implicitly assigned from T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_store_direct_blocked(flat_id, output, items); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void store(OutputIterator output, + T (&items)[ItemsPerThread], + unsigned int valid, + storage_type& /*storage*/) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an 
object of type OutputIterator " + "can be dereferenced and then implicitly assigned from T."); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_store_direct_blocked(flat_id, output, items, valid); + } +}; + +template< + class T, + unsigned int ItemsPerThread, + unsigned int WarpSize +> +class warp_store +{ + static_assert(::rocprim::detail::is_power_of_two(WarpSize), + "Logical warp size must be a power of two."); + static_assert(WarpSize <= ::rocprim::device_warp_size(), + "Logical warp size cannot be larger than physical warp size."); + +private: + using exchange_type = ::rocprim::warp_exchange; + +public: + using storage_type = typename exchange_type::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void store(OutputIterator output, + T (&items)[ItemsPerThread], + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type OutputIterator " + "can be dereferenced and then implicitly assigned from T."); + exchange_type().blocked_to_striped(items, items, storage); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_store_direct_warp_striped(flat_id, output, items); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void store(OutputIterator output, + T (&items)[ItemsPerThread], + unsigned int valid, + storage_type& storage) + { + using value_type = typename std::iterator_traits::value_type; + static_assert(std::is_convertible::value, + "The type T must be such that an object of type OutputIterator " + "can be dereferenced and then implicitly assigned from T."); + exchange_type().blocked_to_striped(items, items, storage); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_store_direct_warp_striped(flat_id, output, items, valid); + } +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_WARP_WARP_STORE_HPP_ diff --git 
a/3rdparty/cub/thread/thread_load.cuh b/3rdparty/cub/thread/thread_load.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e71d796325ad6dfb2a9ecf6b8af1a11832bf59f8 --- /dev/null +++ b/3rdparty/cub/thread/thread_load.cuh @@ -0,0 +1,118 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_THREAD_THREAD_LOAD_HPP_ +#define HIPCUB_ROCPRIM_THREAD_THREAD_LOAD_HPP_ + +#include "../config.hpp" +#include "../util_type.cuh" + +BEGIN_HIPCUB_NAMESPACE + +enum CacheLoadModifier : int32_t +{ + LOAD_DEFAULT, ///< Default (no modifier) + LOAD_CA, ///< Cache at all levels + LOAD_CG, ///< Cache at global level + LOAD_CS, ///< Cache streaming (likely to be accessed once) + LOAD_CV, ///< Cache as volatile (including cached system lines) + LOAD_LDG, ///< Cache as texture + LOAD_VOLATILE, ///< Volatile (any memory space) +}; + +template +HIPCUB_DEVICE __forceinline__ T AsmThreadLoad(void * ptr) +{ + T retval = 0; + __builtin_memcpy(&retval, ptr, sizeof(T)); + return retval; +} + +#if HIPCUB_THREAD_LOAD_USE_CACHE_MODIFIERS == 1 + +// Important for syncing. 
Check section 9.2.2 or 7.3 in the following document +// http://developer.amd.com/wordpress/media/2013/12/AMD_GCN3_Instruction_Set_Architecture_rev1.1.pdf +#define HIPCUB_ASM_THREAD_LOAD(cache_modifier, \ + llvm_cache_modifier, \ + type, \ + interim_type, \ + asm_operator, \ + output_modifier, \ + wait_cmd) \ + template<> \ + HIPCUB_DEVICE __forceinline__ type AsmThreadLoad(void * ptr) \ + { \ + interim_type retval; \ + asm volatile( \ + #asm_operator " %0, %1 " llvm_cache_modifier "\n" \ + "\ts_waitcnt " wait_cmd "(0)" : "=" #output_modifier(retval) : "v"(ptr) \ + ); \ + return retval; \ + } + +// TODO Add specialization for custom larger data types +#define HIPCUB_ASM_THREAD_LOAD_GROUP(cache_modifier, llvm_cache_modifier, wait_cmd) \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_load_sbyte, v, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_load_sshort, v, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint8_t, uint16_t, flat_load_ubyte, v, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint16_t, uint16_t, flat_load_ushort, v, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint32_t, uint32_t, flat_load_dword, v, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_load_dword, v, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_load_dwordx2, v, wait_cmd); \ + HIPCUB_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_load_dwordx2, v, wait_cmd); + +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CA, "glc", ""); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CG, "glc slc", ""); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CV, "glc", "vmcnt"); +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_VOLATILE, "glc", "vmcnt"); + +// TODO find correct modifiers to match these +HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_LDG, "", ""); 
+HIPCUB_ASM_THREAD_LOAD_GROUP(LOAD_CS, "", ""); + +#endif + +template +HIPCUB_DEVICE __forceinline__ +typename std::iterator_traits::value_type ThreadLoad(InputIteratorT itr) +{ + using T = typename std::iterator_traits::value_type; + T retval = ThreadLoad(&(*itr)); + return retval; +} + +template +HIPCUB_DEVICE __forceinline__ T +ThreadLoad(T * ptr) +{ + return AsmThreadLoad(ptr); +} + +END_HIPCUB_NAMESPACE +#endif diff --git a/3rdparty/cub/thread/thread_operators.cuh b/3rdparty/cub/thread/thread_operators.cuh new file mode 100644 index 0000000000000000000000000000000000000000..79deacccb0f30375b6229281907678bfdc8ec6d0 --- /dev/null +++ b/3rdparty/cub/thread/thread_operators.cuh @@ -0,0 +1,341 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIBCUB_ROCPRIM_THREAD_THREAD_OPERATORS_HPP_ +#define HIBCUB_ROCPRIM_THREAD_THREAD_OPERATORS_HPP_ + +#include "../config.hpp" + +#include "../util_type.cuh" + +BEGIN_HIPCUB_NAMESPACE + +struct Equality +{ + template + HIPCUB_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a == b; + } +}; + +struct Inequality +{ + template + HIPCUB_HOST_DEVICE inline + constexpr bool operator()(const T& a, const T& b) const + { + return a != b; + } +}; + +template +struct InequalityWrapper +{ + EqualityOp op; + + HIPCUB_HOST_DEVICE inline + InequalityWrapper(EqualityOp op) : op(op) {} + + template + HIPCUB_HOST_DEVICE inline + bool operator()(const T &a, const T &b) + { + return !op(a, b); + } +}; + +struct Sum +{ + template + HIPCUB_HOST_DEVICE inline + constexpr T operator()(const T &a, const T &b) const + { + return a + b; + } +}; + +struct Difference +{ + template + HIPCUB_HOST_DEVICE inline + constexpr T operator()(const T &a, const T &b) const + { + return a - b; + } +}; + +struct Division +{ + template + HIPCUB_HOST_DEVICE inline + constexpr T operator()(const T &a, const T &b) const + { + return a / b; + } +}; + +struct Max +{ + template + HIPCUB_HOST_DEVICE inline + constexpr T operator()(const T &a, const T &b) const + { + return a < b ? 
b : a; + } +}; + +struct Min +{ + template + HIPCUB_HOST_DEVICE inline + constexpr T operator()(const T &a, const T &b) const + { + return a < b ? a : b; + } +}; + +struct ArgMax +{ + template< + class Key, + class Value + > + HIPCUB_HOST_DEVICE inline + constexpr KeyValuePair + operator()(const KeyValuePair& a, + const KeyValuePair& b) const + { + return ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a; + } +}; + +struct ArgMin +{ + template< + class Key, + class Value + > + HIPCUB_HOST_DEVICE inline + constexpr KeyValuePair + operator()(const KeyValuePair& a, + const KeyValuePair& b) const + { + return ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a; + } +}; + +template +struct CastOp +{ + template + HIPCUB_HOST_DEVICE inline + B operator()(const A &a) const + { + return (B)a; + } +}; + +template +class SwizzleScanOp +{ +private: + ScanOp scan_op; + +public: + HIPCUB_HOST_DEVICE inline + SwizzleScanOp(ScanOp scan_op) : scan_op(scan_op) + { + } + + template + HIPCUB_HOST_DEVICE inline + T operator()(const T &a, const T &b) + { + T _a(a); + T _b(b); + + return scan_op(_b, _a); + } +}; + +template +struct ReduceBySegmentOp +{ + ReductionOpT op; + + HIPCUB_HOST_DEVICE inline + ReduceBySegmentOp() + { + } + + HIPCUB_HOST_DEVICE inline + ReduceBySegmentOp(ReductionOpT op) : op(op) + { + } + + template + HIPCUB_HOST_DEVICE inline + KeyValuePairT operator()( + const KeyValuePairT &first, + const KeyValuePairT &second) + { + KeyValuePairT retval; + retval.key = first.key + second.key; + retval.value = (second.key) ? 
+ second.value : + op(first.value, second.value); + return retval; + } +}; + +template +struct ReduceByKeyOp +{ + ReductionOpT op; + + HIPCUB_HOST_DEVICE inline + ReduceByKeyOp() + { + } + + HIPCUB_HOST_DEVICE inline + ReduceByKeyOp(ReductionOpT op) : op(op) + { + } + + template + HIPCUB_HOST_DEVICE inline + KeyValuePairT operator()( + const KeyValuePairT &first, + const KeyValuePairT &second) + { + KeyValuePairT retval = second; + + if (first.key == second.key) + { + retval.value = op(first.value, retval.value); + } + return retval; + } +}; + +template +struct BinaryFlip +{ + BinaryOpT binary_op; + + HIPCUB_HOST_DEVICE + explicit BinaryFlip(BinaryOpT binary_op) : binary_op(binary_op) + { + } + + template + HIPCUB_DEVICE auto + operator()(T &&t, U &&u) -> decltype(binary_op(std::forward(u), + std::forward(t))) + { + return binary_op(std::forward(u), std::forward(t)); + } +}; + +template +HIPCUB_HOST_DEVICE +BinaryFlip MakeBinaryFlip(BinaryOpT binary_op) +{ + return BinaryFlip(binary_op); +} + +namespace detail +{ + +// CUB uses value_type of OutputIteratorT (if not void) as a type of intermediate results in reduce, +// for example: +// +// /// The output value type +// typedef typename If<(Equals::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? +// typename std::iterator_traits::value_type, // ... then the input iterator's value type, +// typename std::iterator_traits::value_type>::Type OutputT; // ... else the output iterator's value type +// +// rocPRIM (as well as Thrust) uses result type of BinaryFunction instead (if not void): +// +// using input_type = typename std::iterator_traits::value_type; +// using result_type = typename ::rocprim::detail::match_result_type< +// input_type, BinaryFunction +// >::type; +// +// For short -> float using Sum() +// CUB: float Sum(float, float) +// rocPRIM: short Sum(short, short) +// +// This wrapper allows to have compatibility with CUB in hipCUB. 
+template< + class InputIteratorT, + class OutputIteratorT, + class BinaryFunction +> +struct convert_result_type_wrapper +{ + using input_type = typename std::iterator_traits::value_type; + using output_type = typename std::iterator_traits::value_type; + using result_type = + typename std::conditional< + std::is_void::value, input_type, output_type + >::type; + + convert_result_type_wrapper(BinaryFunction op) : op(op) {} + + template + HIPCUB_HOST_DEVICE inline + constexpr result_type operator()(const T &a, const T &b) const + { + return static_cast(op(a, b)); + } + + BinaryFunction op; +}; + +template< + class InputIteratorT, + class OutputIteratorT, + class BinaryFunction +> +inline +convert_result_type_wrapper +convert_result_type(BinaryFunction op) +{ + return convert_result_type_wrapper(op); +} + +} // end detail namespace + +END_HIPCUB_NAMESPACE + +#endif // HIBCUB_ROCPRIM_THREAD_THREAD_OPERATORS_HPP_ diff --git a/3rdparty/cub/thread/thread_reduce.cuh b/3rdparty/cub/thread/thread_reduce.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7d3674f5f82fe4a7661b35654631aa593e4f614d --- /dev/null +++ b/3rdparty/cub/thread/thread_reduce.cuh @@ -0,0 +1,88 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_THREAD_THREAD_REDUCE_HPP_ +#define HIPCUB_ROCPRIM_THREAD_THREAD_REDUCE_HPP_ + +BEGIN_HIPCUB_NAMESPACE + +/// Internal namespace (to prevent ADL mishaps between static functions when mixing different CUB installations) +namespace internal { + +template < + int LENGTH, + typename T, + typename ReductionOp, + bool NoPrefix = false> +__device__ __forceinline__ T ThreadReduce( + T* input, + ReductionOp reduction_op, + T prefix = T(0)) +{ + T retval; + if(NoPrefix) + retval = input[0]; + else + retval = prefix; + + #pragma unroll + for (int i = 0 + NoPrefix; i < LENGTH; ++i) + retval = reduction_op(retval, input[i]); + + return retval; +} + +template < + int LENGTH, + typename T, + typename ReductionOp> +__device__ __forceinline__ T ThreadReduce( + T (&input)[LENGTH], + ReductionOp reduction_op, + T prefix) +{ + return ThreadReduce((T*)input, reduction_op, prefix); +} + +template < + int LENGTH, + typename T, + typename ReductionOp> +__device__ __forceinline__ T ThreadReduce( + T (&input)[LENGTH], + ReductionOp reduction_op) +{ + return ThreadReduce((T*)input, reduction_op); +} + +} + +END_HIPCUB_NAMESPACE + +#endif diff --git a/3rdparty/cub/thread/thread_scan.cuh b/3rdparty/cub/thread/thread_scan.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cea96fb65eb276eca5a81dcd4bfe862b95f513fd --- /dev/null +++ b/3rdparty/cub/thread/thread_scan.cuh @@ -0,0 +1,255 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIBCUB_ROCPRIM_THREAD_THREAD_SCAN_HPP_ +#define HIBCUB_ROCPRIM_THREAD_THREAD_SCAN_HPP_ + + +#include "../config.hpp" +#include "../util_type.cuh" + +BEGIN_HIPCUB_NAMESPACE + +/// Internal namespace (to prevent ADL mishaps between static functions when mixing different CUB installations) +namespace internal { + + /** + * \addtogroup UtilModule + * @{ + */ + + /** + * \name Sequential prefix scan over statically-sized array types + * @{ + */ + + template < + int LENGTH, + typename T, + typename ScanOp> + __device__ __forceinline__ T ThreadScanExclusive( + T inclusive, + T exclusive, + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + Int2Type /*length*/) + { + #pragma unroll + for (int i = 0; i < LENGTH; ++i) + { + inclusive = scan_op(exclusive, input[i]); + output[i] = exclusive; + exclusive = inclusive; + } + + return inclusive; + } + + #ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + /** + * \brief Perform a sequential exclusive prefix scan over \p LENGTH elements of the \p input array, seeded with the specified \p prefix. The aggregate is returned. + * + * \tparam LENGTH LengthT of \p input and \p output arrays + * \tparam T [inferred] The data type to be scanned. + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template < + int LENGTH, + typename T, + typename ScanOp> + __device__ __forceinline__ T ThreadScanExclusive( + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T prefix, ///< [in] Prefix to seed scan with + bool apply_prefix = true) ///< [in] Whether or not the calling thread should apply its prefix. If not, the first output element is undefined. 
(Handy for preventing thread-0 from applying a prefix.) + { + T inclusive = input[0]; + if (apply_prefix) + { + inclusive = scan_op(prefix, inclusive); + } + output[0] = prefix; + T exclusive = inclusive; + + return ThreadScanExclusive(inclusive, exclusive, input + 1, output + 1, scan_op, Int2Type()); + } + + /** + * \brief Perform a sequential exclusive prefix scan over the statically-sized \p input array, seeded with the specified \p prefix. The aggregate is returned. + * + * \tparam LENGTH [inferred] LengthT of \p input and \p output arrays + * \tparam T [inferred] The data type to be scanned. + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template < + int LENGTH, + typename T, + typename ScanOp> + __device__ __forceinline__ T ThreadScanExclusive( + T (&input)[LENGTH], ///< [in] Input array + T (&output)[LENGTH], ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T prefix, ///< [in] Prefix to seed scan with + bool apply_prefix = true) ///< [in] Whether or not the calling thread should apply its prefix. (Handy for preventing thread-0 from applying a prefix.) + { + return ThreadScanExclusive((T*) input, (T*) output, scan_op, prefix, apply_prefix); + } + + #endif + + template < + int LENGTH, + typename T, + typename ScanOp> + __device__ __forceinline__ T ThreadScanInclusive( + T inclusive, + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + Int2Type /*length*/) + { + #pragma unroll + for (int i = 0; i < LENGTH; ++i) + { + inclusive = scan_op(inclusive, input[i]); + output[i] = inclusive; + } + + return inclusive; + } + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + + /** + * \brief Perform a sequential inclusive prefix scan over \p LENGTH elements of the \p input array. The aggregate is returned. 
+ * + * \tparam LENGTH LengthT of \p input and \p output arrays + * \tparam T [inferred] The data type to be scanned. + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template < + int LENGTH, + typename T, + typename ScanOp> + __device__ __forceinline__ T ThreadScanInclusive( + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan operator + { + T inclusive = input[0]; + output[0] = inclusive; + + // Continue scan + return ThreadScanInclusive(inclusive, input + 1, output + 1, scan_op, Int2Type()); + } + + /** + * \brief Perform a sequential inclusive prefix scan over the statically-sized \p input array. The aggregate is returned. + * + * \tparam LENGTH [inferred] LengthT of \p input and \p output arrays + * \tparam T [inferred] The data type to be scanned. + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template < + int LENGTH, + typename T, + typename ScanOp> + __device__ __forceinline__ T ThreadScanInclusive( + T (&input)[LENGTH], ///< [in] Input array + T (&output)[LENGTH], ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op) ///< [in] Binary scan operator + { + return ThreadScanInclusive((T*) input, (T*) output, scan_op); + } + + /** + * \brief Perform a sequential inclusive prefix scan over \p LENGTH elements of the \p input array, seeded with the specified \p prefix. The aggregate is returned. + * + * \tparam LENGTH LengthT of \p input and \p output arrays + * \tparam T [inferred] The data type to be scanned. 
+ * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template < + int LENGTH, + typename T, + typename ScanOp> + __device__ __forceinline__ T ThreadScanInclusive( + T *input, ///< [in] Input array + T *output, ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T prefix, ///< [in] Prefix to seed scan with + bool apply_prefix = true) ///< [in] Whether or not the calling thread should apply its prefix. (Handy for preventing thread-0 from applying a prefix.) + { + T inclusive = input[0]; + if (apply_prefix) + { + inclusive = scan_op(prefix, inclusive); + } + output[0] = inclusive; + + // Continue scan + return ThreadScanInclusive(inclusive, input + 1, output + 1, scan_op, Int2Type()); + } + + /** + * \brief Perform a sequential inclusive prefix scan over the statically-sized \p input array, seeded with the specified \p prefix. The aggregate is returned. + * + * \tparam LENGTH [inferred] LengthT of \p input and \p output arrays + * \tparam T [inferred] The data type to be scanned. + * \tparam ScanOp [inferred] Binary scan operator type having member T operator()(const T &a, const T &b) + */ + template < + int LENGTH, + typename T, + typename ScanOp> + __device__ __forceinline__ T ThreadScanInclusive( + T (&input)[LENGTH], ///< [in] Input array + T (&output)[LENGTH], ///< [out] Output array (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + T prefix, ///< [in] Prefix to seed scan with + bool apply_prefix = true) ///< [in] Whether or not the calling thread should apply its prefix. (Handy for preventing thread-0 from applying a prefix.) 
+ { + return ThreadScanInclusive((T*) input, (T*) output, scan_op, prefix, apply_prefix); + } + + #endif + + //@} end member group + + /** @} */ // end group UtilModule + + + } // internal namespace + + END_HIPCUB_NAMESPACE + + #endif // HIBCUB_ROCPRIM_THREAD_THREAD_SCAN_HPP_ diff --git a/3rdparty/cub/thread/thread_search.cuh b/3rdparty/cub/thread/thread_search.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ea3a9460a7e33c3bff227b327d0790fb05fc9b94 --- /dev/null +++ b/3rdparty/cub/thread/thread_search.cuh @@ -0,0 +1,145 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + + #ifndef HIBCUB_ROCPRIM_THREAD_THREAD_SEARCH_HPP_ + #define HIBCUB_ROCPRIM_THREAD_THREAD_SEARCH_HPP_ + + #include + + BEGIN_HIPCUB_NAMESPACE + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +/** + * \brief Computes the begin offsets into A and B for the specific diagonal + * + */ +template < + typename AIteratorT, + typename BIteratorT, + typename OffsetT, + typename CoordinateT> +__host__ __device__ __forceinline__ void MergePathSearch( + OffsetT diagonal, + AIteratorT a, + BIteratorT b, + OffsetT a_len, + OffsetT b_len, + CoordinateT& path_coordinate) +{ + OffsetT split_min = CUB_MAX(diagonal - b_len, 0); + OffsetT split_max = CUB_MIN(diagonal, a_len); + + while (split_min < split_max) + { + OffsetT split_pivot = (split_min + split_max) >> 1; + if (a[split_pivot] <= b[diagonal - split_pivot - 1]) + { + // Move candidate split range up A, down B + split_min = split_pivot + 1; + } + else + { + // Move candidate split range up B, down A + split_max = split_pivot; + } + } + + path_coordinate.x = CUB_MIN(split_min, a_len); + path_coordinate.y = diagonal - split_min; +} + + + +/** + * \brief Returns the offset of the first value within \p input which does not compare less than \p val + */ +template < + typename InputIteratorT, + typename OffsetT, + typename T> +__device__ __forceinline__ OffsetT LowerBound( + InputIteratorT input, ///< [in] Input sequence + 
OffsetT num_items, ///< [in] Input sequence length + T val) ///< [in] Search key +{ + OffsetT retval = 0; + while (num_items > 0) + { + OffsetT half = num_items >> 1; + if (input[retval + half] < val) + { + retval = retval + (half + 1); + num_items = num_items - (half + 1); + } + else + { + num_items = half; + } + } + + return retval; +} + + +/** + * \brief Returns the offset of the first value within \p input which compares greater than \p val + */ +template < + typename InputIteratorT, + typename OffsetT, + typename T> +__device__ __forceinline__ OffsetT UpperBound( + InputIteratorT input, ///< [in] Input sequence + OffsetT num_items, ///< [in] Input sequence length + T val) ///< [in] Search key +{ + OffsetT retval = 0; + while (num_items > 0) + { + OffsetT half = num_items >> 1; + if (val < input[retval + half]) + { + num_items = half; + } + else + { + retval = retval + (half + 1); + num_items = num_items - (half + 1); + } + } + + return retval; +} + +#endif + +END_HIPCUB_NAMESPACE + +#endif // HIBCUB_ROCPRIM_THREAD_THREAD_SCAN_HPP_ diff --git a/3rdparty/cub/thread/thread_sort.hpp b/3rdparty/cub/thread/thread_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f7108e60657acbdeb80838199398c08970dea4f8 --- /dev/null +++ b/3rdparty/cub/thread/thread_sort.hpp @@ -0,0 +1,112 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_THREAD_SORT_HPP_ +#define HIPCUB_ROCPRIM_THREAD_SORT_HPP_ + +#include "../config.hpp" + +#include "../util_ptx.cuh" +#include "../util_type.cuh" + +#include + +BEGIN_HIPCUB_NAMESPACE + + +template +HIPCUB_DEVICE __forceinline__ void Swap(T &lhs, T &rhs) +{ + T temp = lhs; + lhs = rhs; + rhs = temp; +} + + +/** + * @brief Sorts data using odd-even sort method + * + * The sorting method is stable. Further details can be found in: + * A. Nico Habermann. Parallel neighbor sort (or the glory of the induction + * principle). Technical Report AD-759 248, Carnegie Mellon University, 1972. + * + * @tparam KeyT + * Key type + * + * @tparam ValueT + * Value type. 
If `hipcub::NullType` is used as `ValueT`, only keys are sorted. + * + * @tparam CompareOp + * functor type having member `bool operator()(KeyT lhs, KeyT rhs)` + * + * @tparam ITEMS_PER_THREAD + * The number of items per thread + * + * @param[in,out] keys + * Keys to sort + * + * @param[in,out] items + * Values to sort + * + * @param[in] compare_op + * Comparison function object which returns true if the first argument is + * ordered before the second + */ +template +HIPCUB_DEVICE __forceinline__ void +StableOddEvenSort(KeyT (&keys)[ITEMS_PER_THREAD], + ValueT (&items)[ITEMS_PER_THREAD], + CompareOp compare_op) +{ + constexpr bool KEYS_ONLY = ::rocprim::Equals::VALUE; + + #pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) + { + #pragma unroll + for (int j = 1 & i; j < ITEMS_PER_THREAD - 1; j += 2) + { + if (compare_op(keys[j + 1], keys[j])) + { + Swap(keys[j], keys[j + 1]); + if (!KEYS_ONLY) + { + Swap(items[j], items[j + 1]); + } + } + } // inner loop + } // outer loop +} + + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_THREAD_SORT_HPP_ diff --git a/3rdparty/cub/thread/thread_store.cuh b/3rdparty/cub/thread/thread_store.cuh new file mode 100644 index 0000000000000000000000000000000000000000..fd31db0be10883d11c5160dfb4b18f52c131f51b --- /dev/null +++ b/3rdparty/cub/thread/thread_store.cuh @@ -0,0 +1,109 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_THREAD_THREAD_STORE_HPP_ +#define HIPCUB_ROCPRIM_THREAD_THREAD_STORE_HPP_ +BEGIN_HIPCUB_NAMESPACE + +enum CacheStoreModifier +{ + STORE_DEFAULT, ///< Default (no modifier) + STORE_WB, ///< Cache write-back all coherent levels + STORE_CG, ///< Cache at global level + STORE_CS, ///< Cache streaming (likely to be accessed once) + STORE_WT, ///< Cache write-through (to system memory) + STORE_VOLATILE, ///< Volatile shared (any memory space) +}; + +// TODO add to detail namespace +// TODO cleanup +template +HIPCUB_DEVICE __forceinline__ void AsmThreadStore(void * ptr, T val) +{ + __builtin_memcpy(ptr, &val, sizeof(T)); +} + +#if HIPCUB_THREAD_STORE_USE_CACHE_MODIFIERS == 1 + +// NOTE: the reason there is an interim_type is because of a bug for 8bit types. +// TODO fix flat_store_ubyte and flat_store_sbyte issues + +// Important for syncing. Check section 9.2.2 or 7.3 in the following document +// http://developer.amd.com/wordpress/media/2013/12/AMD_GCN3_Instruction_Set_Architecture_rev1.1.pdf +#define HIPCUB_ASM_THREAD_STORE(cache_modifier, \ + llvm_cache_modifier, \ + type, \ + interim_type, \ + asm_operator, \ + output_modifier, \ + wait_cmd) \ + template<> \ + HIPCUB_DEVICE __forceinline__ void AsmThreadStore(void * ptr, type val) \ + { \ + interim_type temp_val = val; \ + asm volatile(#asm_operator " %0, %1 " llvm_cache_modifier : : "v"(ptr), #output_modifier(temp_val)); \ + asm volatile("s_waitcnt " wait_cmd "(%0)" : : "I"(0x00)); \ + } + +// TODO fix flat_store_ubyte and flat_store_sbyte issues +// TODO Add specialization for custom larger data types +#define HIPCUB_ASM_THREAD_STORE_GROUP(cache_modifier, llvm_cache_modifier, wait_cmd) \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int8_t, int16_t, flat_store_byte, v, wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, int16_t, int16_t, flat_store_short, v, 
wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint8_t, uint16_t, flat_store_byte, v, wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint16_t, uint16_t, flat_store_short, v, wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint32_t, uint32_t, flat_store_dword, v, wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, float, uint32_t, flat_store_dword, v, wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_store_dwordx2, v, wait_cmd); \ + HIPCUB_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_store_dwordx2, v, wait_cmd); + +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WB, "glc", ""); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CG, "glc slc", ""); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_WT, "glc", "vmcnt"); +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_VOLATILE, "glc", "vmcnt"); + +// TODO find correct modifiers to match these +HIPCUB_ASM_THREAD_STORE_GROUP(STORE_CS, "", ""); + +#endif + +template +__device__ __forceinline__ void ThreadStore(OutputIteratorT itr, T val) +{ + ThreadStore(&(*itr), val); +} + +template +__device__ __forceinline__ void ThreadStore(T * ptr, T val) +{ + AsmThreadStore(ptr, val); +} + +END_HIPCUB_NAMESPACE +#endif diff --git a/3rdparty/cub/util_allocator.cuh b/3rdparty/cub/util_allocator.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7c425309f32e192b39c0bf8e74426ca59661e7a6 --- /dev/null +++ b/3rdparty/cub/util_allocator.cuh @@ -0,0 +1,647 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2019-2020, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_UTIL_ALLOCATOR_HPP_ +#define HIPCUB_ROCPRIM_UTIL_ALLOCATOR_HPP_ + +#include "config.hpp" + +#include +#include +#include + +#include +#include + +BEGIN_HIPCUB_NAMESPACE + +#define _HipcubLog(format, ...) 
printf(format, __VA_ARGS__); + +// Hipified version of cub/util_allocator.cuh + +struct CachingDeviceAllocator +{ + //--------------------------------------------------------------------- + // Constants + //--------------------------------------------------------------------- + + /// Out-of-bounds bin + static const unsigned int INVALID_BIN = (unsigned int) -1; + + /// Invalid size + static const size_t INVALID_SIZE = (size_t) -1; + + /// Invalid device ordinal + static const int INVALID_DEVICE_ORDINAL = -1; + + //--------------------------------------------------------------------- + // Type definitions and helper types + //--------------------------------------------------------------------- + + /** + * Descriptor for device memory allocations + */ + struct BlockDescriptor + { + void* d_ptr; // Device pointer + size_t bytes; // Size of allocation in bytes + unsigned int bin; // Bin enumeration + int device; // device ordinal + cudaStream_t associated_stream; // Associated associated_stream + cudaEvent_t ready_event; // Signal when associated stream has run to the point at which this block was freed + + // Constructor (suitable for searching maps for a specific block, given its pointer and device) + BlockDescriptor(void *d_ptr, int device) : + d_ptr(d_ptr), + bytes(0), + bin(INVALID_BIN), + device(device), + associated_stream(0), + ready_event(0) + {} + + // Constructor (suitable for searching maps for a range of suitable blocks, given a device) + BlockDescriptor(int device) : + d_ptr(NULL), + bytes(0), + bin(INVALID_BIN), + device(device), + associated_stream(0), + ready_event(0) + {} + + // Comparison functor for comparing device pointers + static bool PtrCompare(const BlockDescriptor &a, const BlockDescriptor &b) + { + if (a.device == b.device) + return (a.d_ptr < b.d_ptr); + else + return (a.device < b.device); + } + + // Comparison functor for comparing allocation sizes + static bool SizeCompare(const BlockDescriptor &a, const BlockDescriptor &b) + { + if 
(a.device == b.device) + return (a.bytes < b.bytes); + else + return (a.device < b.device); + } + }; + + /// BlockDescriptor comparator function interface + typedef bool (*Compare)(const BlockDescriptor &, const BlockDescriptor &); + + class TotalBytes { + public: + size_t free; + size_t live; + TotalBytes() { free = live = 0; } + }; + + /// Set type for cached blocks (ordered by size) + typedef std::multiset CachedBlocks; + + /// Set type for live blocks (ordered by ptr) + typedef std::multiset BusyBlocks; + + /// Map type of device ordinals to the number of cached bytes cached by each device + typedef std::map GpuCachedBytes; + + + //--------------------------------------------------------------------- + // Utility functions + //--------------------------------------------------------------------- + + /** + * Integer pow function for unsigned base and exponent + */ + static unsigned int IntPow( + unsigned int base, + unsigned int exp) + { + unsigned int retval = 1; + while (exp > 0) + { + if (exp & 1) { + retval = retval * base; // multiply the result by the current base + } + base = base * base; // square the base + exp = exp >> 1; // divide the exponent in half + } + return retval; + } + + + /** + * Round up to the nearest power-of + */ + void NearestPowerOf( + unsigned int &power, + size_t &rounded_bytes, + unsigned int base, + size_t value) + { + power = 0; + rounded_bytes = 1; + + if (value * base < value) + { + // Overflow + power = sizeof(size_t) * 8; + rounded_bytes = size_t(0) - 1; + return; + } + + while (rounded_bytes < value) + { + rounded_bytes *= base; + power++; + } + } + + + //--------------------------------------------------------------------- + // Fields + //--------------------------------------------------------------------- + + std::mutex mutex; /// Mutex for thread-safety + + unsigned int bin_growth; /// Geometric growth factor for bin-sizes + unsigned int min_bin; /// Minimum bin enumeration + unsigned int max_bin; /// Maximum bin 
enumeration + + size_t min_bin_bytes; /// Minimum bin size + size_t max_bin_bytes; /// Maximum bin size + size_t max_cached_bytes; /// Maximum aggregate cached bytes per device + + const bool skip_cleanup; /// Whether or not to skip a call to FreeAllCached() when destructor is called. (The CUDA runtime may have already shut down for statically declared allocators) + bool debug; /// Whether or not to print (de)allocation events to stdout + + GpuCachedBytes cached_bytes; /// Map of device ordinal to aggregate cached bytes on that device + CachedBlocks cached_blocks; /// Set of cached device allocations available for reuse + BusyBlocks live_blocks; /// Set of live device allocations currently in use + + //--------------------------------------------------------------------- + // Methods + //--------------------------------------------------------------------- + + /** + * \brief Constructor. + */ + CachingDeviceAllocator( + unsigned int bin_growth, ///< Geometric growth factor for bin-sizes + unsigned int min_bin = 1, ///< Minimum bin (default is bin_growth ^ 1) + unsigned int max_bin = INVALID_BIN, ///< Maximum bin (default is no max bin) + size_t max_cached_bytes = INVALID_SIZE, ///< Maximum aggregate cached bytes per device (default is no limit) + bool skip_cleanup = false, ///< Whether or not to skip a call to \p FreeAllCached() when the destructor is called (default is to deallocate) + bool debug = false) ///< Whether or not to print (de)allocation events to stdout (default is no stderr output) + : + bin_growth(bin_growth), + min_bin(min_bin), + max_bin(max_bin), + min_bin_bytes(IntPow(bin_growth, min_bin)), + max_bin_bytes(IntPow(bin_growth, max_bin)), + max_cached_bytes(max_cached_bytes), + skip_cleanup(skip_cleanup), + debug(debug), + cached_blocks(BlockDescriptor::SizeCompare), + live_blocks(BlockDescriptor::PtrCompare) + {} + + + /** + * \brief Default constructor. 
+ * + * Configured with: + * \par + * - \p bin_growth = 8 + * - \p min_bin = 3 + * - \p max_bin = 7 + * - \p max_cached_bytes = (\p bin_growth ^ \p max_bin) * 3) - 1 = 6,291,455 bytes + * + * which delineates five bin-sizes: 512B, 4KB, 32KB, 256KB, and 2MB and + * sets a maximum of 6,291,455 cached bytes per device + */ + CachingDeviceAllocator( + bool skip_cleanup = false, + bool debug = false) + : + bin_growth(8), + min_bin(3), + max_bin(7), + min_bin_bytes(IntPow(bin_growth, min_bin)), + max_bin_bytes(IntPow(bin_growth, max_bin)), + max_cached_bytes((max_bin_bytes * 3) - 1), + skip_cleanup(skip_cleanup), + debug(debug), + cached_blocks(BlockDescriptor::SizeCompare), + live_blocks(BlockDescriptor::PtrCompare) + {} + + + /** + * \brief Sets the limit on the number bytes this allocator is allowed to cache per device. + * + * Changing the ceiling of cached bytes does not cause any allocations (in-use or + * cached-in-reserve) to be freed. See \p FreeAllCached(). + */ + cudaError_t SetMaxCachedBytes( + size_t max_cached_bytes) + { + // Lock + mutex.lock(); + + if (debug) _HipcubLog("Changing max_cached_bytes (%lld -> %lld)\n", (long long) this->max_cached_bytes, (long long) max_cached_bytes); + + this->max_cached_bytes = max_cached_bytes; + + // Unlock + mutex.unlock(); + + return cudaSuccess; + } + + + /** + * \brief Provides a suitable allocation of device memory for the given size on the specified device. + * + * Once freed, the allocation becomes available immediately for reuse within the \p active_stream + * with which it was associated with during allocation, and it becomes available for reuse within other + * streams when all prior work submitted to \p active_stream has completed. 
+ */ + cudaError_t DeviceAllocate( + int device, ///< [in] Device on which to place the allocation + void **d_ptr, ///< [out] Reference to pointer to the allocation + size_t bytes, ///< [in] Minimum number of bytes for the allocation + cudaStream_t active_stream = 0) ///< [in] The stream to be associated with this allocation + { + *d_ptr = NULL; + int entrypoint_device = INVALID_DEVICE_ORDINAL; + cudaError_t error = cudaSuccess; + + if (device == INVALID_DEVICE_ORDINAL) + { + if (cubDebug(error = cudaGetDevice(&entrypoint_device))) return error; + device = entrypoint_device; + } + + // Create a block descriptor for the requested allocation + bool found = false; + BlockDescriptor search_key(device); + search_key.associated_stream = active_stream; + NearestPowerOf(search_key.bin, search_key.bytes, bin_growth, bytes); + + if (search_key.bin > max_bin) + { + // Bin is greater than our maximum bin: allocate the request + // exactly and give out-of-bounds bin. It will not be cached + // for reuse when returned. + search_key.bin = INVALID_BIN; + search_key.bytes = bytes; + } + else + { + // Search for a suitable cached allocation: lock + mutex.lock(); + + if (search_key.bin < min_bin) + { + // Bin is less than minimum bin: round up + search_key.bin = min_bin; + search_key.bytes = min_bin_bytes; + } + + // Iterate through the range of cached blocks on the same device in the same bin + CachedBlocks::iterator block_itr = cached_blocks.lower_bound(search_key); + while ((block_itr != cached_blocks.end()) + && (block_itr->device == device) + && (block_itr->bin == search_key.bin)) + { + // To prevent races with reusing blocks returned by the host but still + // in use by the device, only consider cached blocks that are + // either (from the active stream) or (from an idle stream) + if ((active_stream == block_itr->associated_stream) || + (cudaEventQuery(block_itr->ready_event) != cudaErrorNotReady)) + { + // Reuse existing cache block. Insert into live blocks. 
+ found = true; + search_key = *block_itr; + search_key.associated_stream = active_stream; + live_blocks.insert(search_key); + + // Remove from free blocks + cached_bytes[device].free -= search_key.bytes; + cached_bytes[device].live += search_key.bytes; + + if (debug) _HipcubLog("\tDevice %d reused cached block at %p (%lld bytes) for stream %lld (previously associated with stream %lld).\n", + device, search_key.d_ptr, (long long) search_key.bytes, (long long) search_key.associated_stream, (long long) block_itr->associated_stream); + + cached_blocks.erase(block_itr); + + break; + } + block_itr++; + } + + // Done searching: unlock + mutex.unlock(); + } + + // Allocate the block if necessary + if (!found) + { + // Set runtime's current device to specified device (entrypoint may not be set) + if (device != entrypoint_device) + { + if (cubDebug(error = cudaGetDevice(&entrypoint_device))) return error; + if (cubDebug(error = cudaSetDevice(device))) return error; + } + + // Attempt to allocate + if (cubDebug(error = cudaMalloc(&search_key.d_ptr, search_key.bytes)) == cudaErrorMemoryAllocation) + { + // The allocation attempt failed: free all cached blocks on device and retry + if (debug) _HipcubLog("\tDevice %d failed to allocate %lld bytes for stream %lld, retrying after freeing cached allocations", + device, (long long) search_key.bytes, (long long) search_key.associated_stream); + + error = cudaGetLastError(); // Reset error + error = cudaSuccess; // Reset the error we will return + + // Lock + mutex.lock(); + + // Iterate the range of free blocks on the same device + BlockDescriptor free_key(device); + CachedBlocks::iterator block_itr = cached_blocks.lower_bound(free_key); + + while ((block_itr != cached_blocks.end()) && (block_itr->device == device)) + { + // No need to worry about synchronization with the device: hipFree is + // blocking and will synchronize across all kernels executing + // on the current device + + // Free device memory and destroy stream event. 
+ if (cubDebug(error = cudaFree(block_itr->d_ptr))) break; + if (cubDebug(error = cudaEventDestroy(block_itr->ready_event))) break; + + // Reduce balance and erase entry + cached_bytes[device].free -= block_itr->bytes; + + if (debug) _HipcubLog("\tDevice %d freed %lld bytes.\n\t\t %lld available blocks cached (%lld bytes), %lld live blocks (%lld bytes) outstanding.\n", + device, (long long) block_itr->bytes, (long long) cached_blocks.size(), (long long) cached_bytes[device].free, (long long) live_blocks.size(), (long long) cached_bytes[device].live); + + cached_blocks.erase(block_itr); + + block_itr++; + } + + // Unlock + mutex.unlock(); + + // Return under error + if (error) return error; + + // Try to allocate again + if (cubDebug(error = cudaMalloc(&search_key.d_ptr, search_key.bytes))) return error; + } + + // Create ready event + if (cubDebug(error = cudaEventCreateWithFlags(&search_key.ready_event, cudaEventDisableTiming))) + return error; + + // Insert into live blocks + mutex.lock(); + live_blocks.insert(search_key); + cached_bytes[device].live += search_key.bytes; + mutex.unlock(); + + if (debug) _HipcubLog("\tDevice %d allocated new device block at %p (%lld bytes associated with stream %lld).\n", + device, search_key.d_ptr, (long long) search_key.bytes, (long long) search_key.associated_stream); + + // Attempt to revert back to previous device if necessary + if ((entrypoint_device != INVALID_DEVICE_ORDINAL) && (entrypoint_device != device)) + { + if (cubDebug(error = cudaSetDevice(entrypoint_device))) return error; + } + } + + // Copy device pointer to output parameter + *d_ptr = search_key.d_ptr; + + if (debug) _HipcubLog("\t\t%lld available blocks cached (%lld bytes), %lld live blocks outstanding(%lld bytes).\n", + (long long) cached_blocks.size(), (long long) cached_bytes[device].free, (long long) live_blocks.size(), (long long) cached_bytes[device].live); + + return error; + } + + + /** + * \brief Provides a suitable allocation of device memory for 
   the given size on the current device.
 *
 * Once freed, the allocation becomes available immediately for reuse within the \p active_stream
 * with which it was associated with during allocation, and it becomes available for reuse within other
 * streams when all prior work submitted to \p active_stream has completed.
 */
cudaError_t DeviceAllocate(
    void **d_ptr,                       ///< [out] Reference to pointer to the allocation
    size_t bytes,                       ///< [in] Minimum number of bytes for the allocation
    cudaStream_t active_stream = 0)     ///< [in] The stream to be associated with this allocation
{
    // Delegate to the device-qualified overload; INVALID_DEVICE_ORDINAL
    // instructs it to resolve the current device via cudaGetDevice().
    return DeviceAllocate(INVALID_DEVICE_ORDINAL, d_ptr, bytes, active_stream);
}


/**
 * \brief Frees a live allocation of device memory on the specified device, returning it to the allocator.
 *
 * Once freed, the allocation becomes available immediately for reuse within the \p active_stream
 * with which it was associated with during allocation, and it becomes available for reuse within other
 * streams when all prior work submitted to \p active_stream has completed.
+ */ + cudaError_t DeviceFree( + int device, + void* d_ptr) + { + int entrypoint_device = INVALID_DEVICE_ORDINAL; + cudaError_t error = cudaSuccess; + + if (device == INVALID_DEVICE_ORDINAL) + { + if (cubDebug(error = cudaGetDevice(&entrypoint_device))) + return error; + device = entrypoint_device; + } + + // Lock + mutex.lock(); + + // Find corresponding block descriptor + bool recached = false; + BlockDescriptor search_key(d_ptr, device); + BusyBlocks::iterator block_itr = live_blocks.find(search_key); + if (block_itr != live_blocks.end()) + { + // Remove from live blocks + search_key = *block_itr; + live_blocks.erase(block_itr); + cached_bytes[device].live -= search_key.bytes; + + // Keep the returned allocation if bin is valid and we won't exceed the max cached threshold + if ((search_key.bin != INVALID_BIN) && (cached_bytes[device].free + search_key.bytes <= max_cached_bytes)) + { + // Insert returned allocation into free blocks + recached = true; + cached_blocks.insert(search_key); + cached_bytes[device].free += search_key.bytes; + + if (debug) _HipcubLog("\tDevice %d returned %lld bytes from associated stream %lld.\n\t\t %lld available blocks cached (%lld bytes), %lld live blocks outstanding. 
(%lld bytes)\n", + device, (long long) search_key.bytes, (long long) search_key.associated_stream, (long long) cached_blocks.size(), + (long long) cached_bytes[device].free, (long long) live_blocks.size(), (long long) cached_bytes[device].live); + } + } + + // First set to specified device (entrypoint may not be set) + if (device != entrypoint_device) + { + if (cubDebug(error = cudaGetDevice(&entrypoint_device))) return error; + if (cubDebug(error = cudaSetDevice(device))) return error; + } + + if (recached) + { + // Insert the ready event in the associated stream (must have current device set properly) + if (cubDebug(error = cudaEventRecord(search_key.ready_event, search_key.associated_stream))) return error; + } + + // Unlock + mutex.unlock(); + + if (!recached) + { + // Free the allocation from the runtime and cleanup the event. + if (cubDebug(error = cudaFree(d_ptr))) return error; + if (cubDebug(error = cudaEventDestroy(search_key.ready_event))) return error; + + if (debug) _HipcubLog("\tDevice %d freed %lld bytes from associated stream %lld.\n\t\t %lld available blocks cached (%lld bytes), %lld live blocks (%lld bytes) outstanding.\n", + device, (long long) search_key.bytes, (long long) search_key.associated_stream, (long long) cached_blocks.size(), (long long) cached_bytes[device].free, (long long) live_blocks.size(), (long long) cached_bytes[device].live); + } + + // Reset device + if ((entrypoint_device != INVALID_DEVICE_ORDINAL) && (entrypoint_device != device)) + { + if (cubDebug(error = cudaSetDevice(entrypoint_device))) return error; + } + + return error; + } + + + /** + * \brief Frees a live allocation of device memory on the current device, returning it to the allocator. + * + * Once freed, the allocation becomes available immediately for reuse within the \p active_stream + * with which it was associated with during allocation, and it becomes available for reuse within other + * streams when all prior work submitted to \p active_stream has completed. 
 */
cudaError_t DeviceFree(
    void* d_ptr)    ///< [in] Pointer previously returned by DeviceAllocate() on the current device
{
    // Delegate; INVALID_DEVICE_ORDINAL resolves the current device.
    return DeviceFree(INVALID_DEVICE_ORDINAL, d_ptr);
}


/**
 * \brief Frees all cached device allocations on all devices
 */
cudaError_t FreeAllCached()
{
    cudaError_t error = cudaSuccess;
    int entrypoint_device = INVALID_DEVICE_ORDINAL;
    int current_device = INVALID_DEVICE_ORDINAL;

    mutex.lock();

    while (!cached_blocks.empty())
    {
        // Get first block
        CachedBlocks::iterator begin = cached_blocks.begin();

        // Get entry-point device ordinal if necessary
        if (entrypoint_device == INVALID_DEVICE_ORDINAL)
        {
            if (cubDebug(error = cudaGetDevice(&entrypoint_device))) break;
        }

        // Set current device ordinal if necessary
        // (blocks are grouped per device; switch only when the device changes)
        if (begin->device != current_device)
        {
            if (cubDebug(error = cudaSetDevice(begin->device))) break;
            current_device = begin->device;
        }

        // Free device memory
        if (cubDebug(error = cudaFree(begin->d_ptr))) break;
        if (cubDebug(error = cudaEventDestroy(begin->ready_event))) break;

        // Reduce balance and erase entry
        cached_bytes[current_device].free -= begin->bytes;

        if (debug) _HipcubLog("\tDevice %d freed %lld bytes.\n\t\t %lld available blocks cached (%lld bytes), %lld live blocks (%lld bytes) outstanding.\n",
            current_device, (long long) begin->bytes, (long long) cached_blocks.size(), (long long) cached_bytes[current_device].free, (long long) live_blocks.size(), (long long) cached_bytes[current_device].live);

        // Safe: `begin` is re-fetched from cached_blocks.begin() each iteration.
        cached_blocks.erase(begin);
    }

    mutex.unlock();

    // Attempt to revert back to entry-point device if necessary
    if (entrypoint_device != INVALID_DEVICE_ORDINAL)
    {
        if (cubDebug(error = cudaSetDevice(entrypoint_device))) return error;
    }

    return error;
}


/**
 * \brief Destructor
 */
virtual ~CachingDeviceAllocator()
{
    // skip_cleanup lets callers avoid touching the CUDA runtime during
    // static destruction (the runtime may already be torn down).
    if (!skip_cleanup)
        FreeAllCached();
}

};

END_HIPCUB_NAMESPACE

#endif // HIPCUB_ROCPRIM_UTIL_ALLOCATOR_HPP_
diff --git a/3rdparty/cub/util_math.cuh b/3rdparty/cub/util_math.cuh new file
mode 100644 index 0000000000000000000000000000000000000000..1d466cd89397c86c9dca28e32be893ffa4ef0da5 --- /dev/null +++ b/3rdparty/cub/util_math.cuh @@ -0,0 +1,54 @@ +/****************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_UTIL_MATH_HPP_ +#define HIPCUB_ROCPRIM_UTIL_MATH_HPP_ + +/** + * \file + * Define helper math functions. + */ + +BEGIN_HIPCUB_NAMESPACE + +/** + * \brief Computes the midpoint of the integers + * + * Extra operation is performed in order to prevent overflow. + * + * \return Half the sum of \p begin and \p end + */ +template +constexpr __device__ __host__ T MidPoint(T begin, T end) +{ + return begin + (end - begin) / 2; +} + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_UTIL_MATH_HPP_ diff --git a/3rdparty/cub/util_ptx.cuh b/3rdparty/cub/util_ptx.cuh new file mode 100644 index 0000000000000000000000000000000000000000..bcc11ba1a8f26c49b2f708b658a21f64a3945542 --- /dev/null +++ b/3rdparty/cub/util_ptx.cuh @@ -0,0 +1,325 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_UTIL_PTX_HPP_ +#define HIPCUB_ROCPRIM_UTIL_PTX_HPP_ + +#include +#include + +#include "config.hpp" + +#include + +#define HIPCUB_WARP_THREADS ::rocprim::warp_size() +#define HIPCUB_DEVICE_WARP_THREADS ::rocprim::device_warp_size() +#define HIPCUB_HOST_WARP_THREADS ::rocprim::host_warp_size() +#define HIPCUB_ARCH 1 // ignored with rocPRIM backend + + +BEGIN_HIPCUB_NAMESPACE + +// Missing compared to CUB: +// * ThreadExit - not supported +// * ThreadTrap - not supported +// * FFMA_RZ, FMUL_RZ - not in CUB public API +// * WARP_SYNC - not supported, not CUB public API +// * CTA_SYNC_AND - not supported, not CUB public API +// * MatchAny - not in CUB public API +// +// Differences: +// * Warp thread masks (when used) are 64-bit unsigned integers +// * member_mask argument is ignored in WARP_[ALL|ANY|BALLOT] funcs +// * Arguments first_lane, last_lane, and member_mask are ignored +// in Shuffle* funcs +// * count in BAR is ignored, BAR works like CTA_SYNC + +// ID functions etc. 
+ +HIPCUB_DEVICE inline +int RowMajorTid(int block_dim_x, int block_dim_y, int block_dim_z) +{ + return ((block_dim_z == 1) ? 0 : (threadIdx.z * block_dim_x * block_dim_y)) + + ((block_dim_y == 1) ? 0 : (threadIdx.y * block_dim_x)) + + threadIdx.x; +} + +HIPCUB_DEVICE inline +unsigned int LaneId() +{ + return ::rocprim::lane_id(); +} + +HIPCUB_DEVICE inline +unsigned int WarpId() +{ + return ::rocprim::warp_id(); +} + +template +HIPCUB_DEVICE inline +uint64_t WarpMask(unsigned int warp_id) { + constexpr bool is_pow_of_two = ::rocprim::detail::is_power_of_two(LOGICAL_WARP_THREADS); + constexpr bool is_arch_warp = + LOGICAL_WARP_THREADS == ::rocprim::device_warp_size(); + + uint64_t member_mask = uint64_t(-1) >> (64 - LOGICAL_WARP_THREADS); + + if (is_pow_of_two && !is_arch_warp) { + member_mask <<= warp_id * LOGICAL_WARP_THREADS; + } + + return member_mask; +} + +// Returns the warp lane mask of all lanes less than the calling thread +HIPCUB_DEVICE inline +uint64_t LaneMaskLt() +{ + return (uint64_t(1) << LaneId()) - 1; +} + +// Returns the warp lane mask of all lanes less than or equal to the calling thread +HIPCUB_DEVICE inline +uint64_t LaneMaskLe() +{ + return ((uint64_t(1) << LaneId()) << 1) - 1; +} + +// Returns the warp lane mask of all lanes greater than the calling thread +HIPCUB_DEVICE inline +uint64_t LaneMaskGt() +{ + return uint64_t(-1)^LaneMaskLe(); +} + +// Returns the warp lane mask of all lanes greater than or equal to the calling thread +HIPCUB_DEVICE inline +uint64_t LaneMaskGe() +{ + return uint64_t(-1)^LaneMaskLt(); +} + +// Shuffle funcs + +template < + int LOGICAL_WARP_THREADS, + typename T +> +HIPCUB_DEVICE inline +T ShuffleUp(T input, + int src_offset, + int first_thread, + unsigned int member_mask) +{ + // Not supported in rocPRIM. + (void) first_thread; + // Member mask is not supported in rocPRIM, because it's + // not supported in ROCm. 
+ (void) member_mask; + return ::rocprim::warp_shuffle_up( + input, src_offset, LOGICAL_WARP_THREADS + ); +} + +template < + int LOGICAL_WARP_THREADS, + typename T +> +HIPCUB_DEVICE inline +T ShuffleDown(T input, + int src_offset, + int last_thread, + unsigned int member_mask) +{ + // Not supported in rocPRIM. + (void) last_thread; + // Member mask is not supported in rocPRIM, because it's + // not supported in ROCm. + (void) member_mask; + return ::rocprim::warp_shuffle_down( + input, src_offset, LOGICAL_WARP_THREADS + ); +} + +template < + int LOGICAL_WARP_THREADS, + typename T +> +HIPCUB_DEVICE inline +T ShuffleIndex(T input, + int src_lane, + unsigned int member_mask) +{ + // Member mask is not supported in rocPRIM, because it's + // not supported in ROCm. + (void) member_mask; + return ::rocprim::warp_shuffle( + input, src_lane, LOGICAL_WARP_THREADS + ); +} + +// Other + +HIPCUB_DEVICE inline +unsigned int SHR_ADD(unsigned int x, + unsigned int shift, + unsigned int addend) +{ + return (x >> shift) + addend; +} + +HIPCUB_DEVICE inline +unsigned int SHL_ADD(unsigned int x, + unsigned int shift, + unsigned int addend) +{ + return (x << shift) + addend; +} + +namespace detail { + +template +HIPCUB_DEVICE inline +auto unsigned_bit_extract(UnsignedBits source, + unsigned int bit_start, + unsigned int num_bits) + -> typename std::enable_if::type +{ + #ifdef __CUDACC__ + return __bitextract_u64(source, bit_start, num_bits); + #else + return (source << (64 - bit_start - num_bits)) >> (64 - num_bits); + #endif // __HIP_PLATFORM_AMD__ +} + +template +HIPCUB_DEVICE inline +auto unsigned_bit_extract(UnsignedBits source, + unsigned int bit_start, + unsigned int num_bits) + -> typename std::enable_if::type +{ + #ifdef __CUDACC__ + return __bitextract_u32(source, bit_start, num_bits); + #else + return (static_cast(source) << (32 - bit_start - num_bits)) >> (32 - num_bits); + #endif // __HIP_PLATFORM_AMD__ +} + +} // end namespace detail + +// Bitfield-extract. 
+// Extracts \p num_bits from \p source starting at bit-offset \p bit_start. +// The input \p source may be an 8b, 16b, 32b, or 64b unsigned integer type. +template +HIPCUB_DEVICE inline +unsigned int BFE(UnsignedBits source, + unsigned int bit_start, + unsigned int num_bits) +{ + static_assert(std::is_unsigned::value, "UnsignedBits must be unsigned"); + return detail::unsigned_bit_extract(source, bit_start, num_bits); +} + +// Bitfield insert. +// Inserts the \p num_bits least significant bits of \p y into \p x at bit-offset \p bit_start. +HIPCUB_DEVICE inline +void BFI(unsigned int &ret, + unsigned int x, + unsigned int y, + unsigned int bit_start, + unsigned int num_bits) +{ + #ifdef __CUDACC__ + ret = __bitinsert_u32(x, y, bit_start, num_bits); + #else + x <<= bit_start; + unsigned int MASK_X = ((1 << num_bits) - 1) << bit_start; + unsigned int MASK_Y = ~MASK_X; + ret = (y & MASK_Y) | (x & MASK_X); + #endif // __HIP_PLATFORM_AMD__ +} + +HIPCUB_DEVICE inline +unsigned int IADD3(unsigned int x, unsigned int y, unsigned int z) +{ + return x + y + z; +} + +HIPCUB_DEVICE inline +int PRMT(unsigned int a, unsigned int b, unsigned int index) +{ + return ::__byte_perm(a, b, index); +} + +HIPCUB_DEVICE inline +void BAR(int count) +{ + (void) count; + __syncthreads(); +} + +HIPCUB_DEVICE inline +void CTA_SYNC() +{ + __syncthreads(); +} + +HIPCUB_DEVICE inline +void WARP_SYNC(unsigned int member_mask) +{ + (void) member_mask; + ::rocprim::wave_barrier(); +} + +HIPCUB_DEVICE inline +int WARP_ANY(int predicate, uint64_t member_mask) +{ + (void) member_mask; + return ::__any(predicate); +} + +HIPCUB_DEVICE inline +int WARP_ALL(int predicate, uint64_t member_mask) +{ + (void) member_mask; + return ::__all(predicate); +} + +HIPCUB_DEVICE inline +int64_t WARP_BALLOT(int predicate, uint64_t member_mask) +{ + (void) member_mask; + return __ballot(predicate); +} + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_UTIL_PTX_HPP_ diff --git a/3rdparty/cub/util_type.cuh 
b/3rdparty/cub/util_type.cuh new file mode 100644 index 0000000000000000000000000000000000000000..2c8672e084e9d436464b10693d0a62dee84bdc18 --- /dev/null +++ b/3rdparty/cub/util_type.cuh @@ -0,0 +1,645 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_UTIL_TYPE_HPP_ +#define HIPCUB_ROCPRIM_UTIL_TYPE_HPP_ + +#include +#include + +#include "config.hpp" + +#include +#include + +#include +#include + +BEGIN_HIPCUB_NAMESPACE + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +using NullType = ::rocprim::empty_type; + +#endif + +template struct +[[deprecated("[Since 1.16] If is deprecated use std::conditional instead.")]] If +{ + using Type = typename std::conditional::type; +}; + +template struct +[[deprecated("[Since 1.16] IsPointer is deprecated use std::is_pointer instead.")]] IsPointer +{ + static constexpr bool VALUE = std::is_pointer::value; +}; + +template struct +[[deprecated("[Since 1.16] IsVolatile is deprecated use std::is_volatile instead.")]] IsVolatile +{ + static constexpr bool VALUE = std::is_volatile::value; +}; + +template struct +[[deprecated("[Since 1.16] RemoveQualifiers is deprecated use std::remove_cv instead.")]] RemoveQualifiers +{ + using Type = typename std::remove_cv::type; +}; + +template +struct PowerOfTwo +{ + static constexpr bool VALUE = ::rocprim::detail::is_power_of_two(N); +}; + +namespace detail +{ + +template +struct Log2Impl +{ + static constexpr int VALUE = Log2Impl> 1), COUNT + 1>::VALUE; +}; + +template +struct Log2Impl +{ + static constexpr int VALUE = (1 << (COUNT - 1) < N) ? 
COUNT : COUNT - 1; +}; + +} // end of detail namespace + +template +struct Log2 +{ + static_assert(N != 0, "The logarithm of zero is undefined"); + static constexpr int VALUE = detail::Log2Impl::VALUE; +}; + +template +struct DoubleBuffer +{ + T * d_buffers[2]; + + int selector; + + HIPCUB_HOST_DEVICE inline + DoubleBuffer() + { + selector = 0; + d_buffers[0] = nullptr; + d_buffers[1] = nullptr; + } + + HIPCUB_HOST_DEVICE inline + DoubleBuffer(T * d_current, T * d_alternate) + { + selector = 0; + d_buffers[0] = d_current; + d_buffers[1] = d_alternate; + } + + HIPCUB_HOST_DEVICE inline + T * Current() + { + return d_buffers[selector]; + } + + HIPCUB_HOST_DEVICE inline + T * Alternate() + { + return d_buffers[selector ^ 1]; + } +}; + +template +struct Int2Type +{ + enum {VALUE = A}; +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template< + class Key, + class Value +> +using KeyValuePair = ::rocprim::key_value_pair; + +#endif + +template +using FutureValue = ::rocprim::future_value; + +namespace detail +{ + +template +inline +::rocprim::double_buffer to_double_buffer(DoubleBuffer& source) +{ + return ::rocprim::double_buffer(source.Current(), source.Alternate()); +} + +template +inline +void update_double_buffer(DoubleBuffer& target, ::rocprim::double_buffer& source) +{ + if(target.Current() != source.current()) + { + target.selector ^= 1; + } +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +template +using is_integral_or_enum = + std::integral_constant::value || std::is_enum::value>; + +#endif + +} + +template +HIPCUB_HOST_DEVICE __forceinline__ constexpr NumeratorT +DivideAndRoundUp(NumeratorT n, DenominatorT d) +{ + static_assert(cub::detail::is_integral_or_enum::value && + cub::detail::is_integral_or_enum::value, + "DivideAndRoundUp is only intended for integral types."); + + // Static cast to undo integral promotion. + return static_cast(n / d + (n % d != 0 ? 
1 : 0)); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +/****************************************************************************** + * Size and alignment + ******************************************************************************/ + +/// Structure alignment +template +struct AlignBytes +{ + struct Pad + { + T val; + char byte; + }; + + enum + { + /// The "true CUDA" alignment of T in bytes + ALIGN_BYTES = sizeof(Pad) - sizeof(T) + }; + + /// The "truly aligned" type + typedef T Type; +}; + +// Specializations where host C++ compilers (e.g., 32-bit Windows) may disagree +// with device C++ compilers (EDG) on types passed as template parameters through +// kernel functions + +#define __HIPCUB_ALIGN_BYTES(t, b) \ + template <> struct AlignBytes \ + { enum { ALIGN_BYTES = b }; typedef __align__(b) t Type; }; + +__HIPCUB_ALIGN_BYTES(short4, 8) +__HIPCUB_ALIGN_BYTES(ushort4, 8) +__HIPCUB_ALIGN_BYTES(int2, 8) +__HIPCUB_ALIGN_BYTES(uint2, 8) +__HIPCUB_ALIGN_BYTES(long long, 8) +__HIPCUB_ALIGN_BYTES(unsigned long long, 8) +__HIPCUB_ALIGN_BYTES(float2, 8) +__HIPCUB_ALIGN_BYTES(double, 8) +#ifdef _WIN32 + __HIPCUB_ALIGN_BYTES(long2, 8) + __HIPCUB_ALIGN_BYTES(ulong2, 8) +#else + __HIPCUB_ALIGN_BYTES(long2, 16) + __HIPCUB_ALIGN_BYTES(ulong2, 16) +#endif +__HIPCUB_ALIGN_BYTES(int4, 16) +__HIPCUB_ALIGN_BYTES(uint4, 16) +__HIPCUB_ALIGN_BYTES(float4, 16) +__HIPCUB_ALIGN_BYTES(long4, 16) +__HIPCUB_ALIGN_BYTES(ulong4, 16) +__HIPCUB_ALIGN_BYTES(longlong2, 16) +__HIPCUB_ALIGN_BYTES(ulonglong2, 16) +__HIPCUB_ALIGN_BYTES(double2, 16) +__HIPCUB_ALIGN_BYTES(longlong4, 16) +__HIPCUB_ALIGN_BYTES(ulonglong4, 16) +__HIPCUB_ALIGN_BYTES(double4, 16) + +template struct AlignBytes : AlignBytes {}; +template struct AlignBytes : AlignBytes {}; +template struct AlignBytes : AlignBytes {}; + + +/// Unit-words of data movement +template +struct UnitWord +{ + enum { + ALIGN_BYTES = AlignBytes::ALIGN_BYTES + }; + + template + struct IsMultiple + { + enum { + UNIT_ALIGN_BYTES = 
AlignBytes::ALIGN_BYTES, + IS_MULTIPLE = (sizeof(T) % sizeof(Unit) == 0) && (int(ALIGN_BYTES) % int(UNIT_ALIGN_BYTES) == 0) + }; + }; + + /// Biggest shuffle word that T is a whole multiple of and is not larger than the alignment of T + typedef typename std::conditional::IS_MULTIPLE, + unsigned int, + typename std::conditional::IS_MULTIPLE, + unsigned short, + unsigned char>::type>::type ShuffleWord; + + /// Biggest volatile word that T is a whole multiple of and is not larger than the alignment of T + typedef typename std::conditional::IS_MULTIPLE, + unsigned long long, + ShuffleWord>::type VolatileWord; + + /// Biggest memory-access word that T is a whole multiple of and is not larger than the alignment of T + typedef typename std::conditional::IS_MULTIPLE, + ulonglong2, + VolatileWord>::type DeviceWord; + + /// Biggest texture reference word that T is a whole multiple of and is not larger than the alignment of T + typedef typename std::conditional::IS_MULTIPLE, + uint4, + typename std::conditional::IS_MULTIPLE, + uint2, + ShuffleWord>::type>::type TextureWord; +}; + + +// float2 specialization workaround (for SM10-SM13) +template <> +struct UnitWord +{ + typedef int ShuffleWord; + typedef unsigned long long VolatileWord; + typedef unsigned long long DeviceWord; + typedef float2 TextureWord; +}; + +// float4 specialization workaround (for SM10-SM13) +template <> +struct UnitWord +{ + typedef int ShuffleWord; + typedef unsigned long long VolatileWord; + typedef ulonglong2 DeviceWord; + typedef float4 TextureWord; +}; + + +// char2 specialization workaround (for SM10-SM13) +template <> +struct UnitWord +{ + typedef unsigned short ShuffleWord; + typedef unsigned short VolatileWord; + typedef unsigned short DeviceWord; + typedef unsigned short TextureWord; +}; + + +template struct UnitWord : UnitWord {}; +template struct UnitWord : UnitWord {}; +template struct UnitWord : UnitWord {}; + + +#endif // DOXYGEN_SHOULD_SKIP_THIS + + + + 
+/****************************************************************************** + * Wrapper types + ******************************************************************************/ + +/** + * \brief A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions + */ +template +struct Uninitialized +{ + /// Biggest memory-access word that T is a whole multiple of and is not larger than the alignment of T + typedef typename UnitWord::DeviceWord DeviceWord; + + static constexpr std::size_t DATA_SIZE = sizeof(T); + static constexpr std::size_t WORD_SIZE = sizeof(DeviceWord); + static constexpr std::size_t WORDS = DATA_SIZE / WORD_SIZE; + + /// Backing storage + DeviceWord storage[WORDS]; + + /// Alias + HIPCUB_HOST_DEVICE __forceinline__ T& Alias() + { + return reinterpret_cast(*this); + } +}; + + +/****************************************************************************** + * Simple type traits utilities. + * + * For example: + * Traits::CATEGORY // SIGNED_INTEGER + * Traits::NULL_TYPE // true + * Traits::CATEGORY // NOT_A_NUMBER + * Traits::PRIMITIVE; // false + * + ******************************************************************************/ + + #ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document + +/** + * \brief Basic type traits categories + */ +enum Category +{ + NOT_A_NUMBER, + SIGNED_INTEGER, + UNSIGNED_INTEGER, + FLOATING_POINT +}; + + +/** + * \brief Basic type traits + */ +template +struct BaseTraits +{ + /// Category + static const Category CATEGORY = _CATEGORY; + enum + { + PRIMITIVE = _PRIMITIVE, + NULL_TYPE = _NULL_TYPE, + }; +}; + + +/** + * Basic type traits (unsigned primitive specialization) + */ +template +struct BaseTraits +{ + typedef _UnsignedBits UnsignedBits; + + static const Category CATEGORY = UNSIGNED_INTEGER; + static const UnsignedBits LOWEST_KEY = UnsignedBits(0); + static const UnsignedBits MAX_KEY = UnsignedBits(-1); + + enum + { + PRIMITIVE = true, + NULL_TYPE = false, + }; + + + static 
HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) + { + return key; + } + + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) + { + return key; + } + + static HIPCUB_HOST_DEVICE __forceinline__ T Max() + { + UnsignedBits retval_bits = MAX_KEY; + T retval; + memcpy(&retval, &retval_bits, sizeof(T)); + return retval; + } + + static HIPCUB_HOST_DEVICE __forceinline__ T Lowest() + { + UnsignedBits retval_bits = LOWEST_KEY; + T retval; + memcpy(&retval, &retval_bits, sizeof(T)); + return retval; + } +}; + + +/** + * Basic type traits (signed primitive specialization) + */ +template +struct BaseTraits +{ + typedef _UnsignedBits UnsignedBits; + + static const Category CATEGORY = SIGNED_INTEGER; + static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); + static const UnsignedBits LOWEST_KEY = HIGH_BIT; + static const UnsignedBits MAX_KEY = UnsignedBits(-1) ^ HIGH_BIT; + + enum + { + PRIMITIVE = true, + NULL_TYPE = false, + }; + + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) + { + return key ^ HIGH_BIT; + }; + + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) + { + return key ^ HIGH_BIT; + }; + + static HIPCUB_HOST_DEVICE __forceinline__ T Max() + { + UnsignedBits retval = MAX_KEY; + return reinterpret_cast(retval); + } + + static HIPCUB_HOST_DEVICE __forceinline__ T Lowest() + { + UnsignedBits retval = LOWEST_KEY; + return reinterpret_cast(retval); + } +}; + +template +struct FpLimits; + +template <> +struct FpLimits +{ + static HIPCUB_HOST_DEVICE __forceinline__ float Max() { + return std::numeric_limits::max(); + } + + static HIPCUB_HOST_DEVICE __forceinline__ float Lowest() { + return std::numeric_limits::max() * float(-1); + } +}; + +template <> +struct FpLimits +{ + static HIPCUB_HOST_DEVICE __forceinline__ double Max() { + return std::numeric_limits::max(); + } + + static HIPCUB_HOST_DEVICE 
__forceinline__ double Lowest() { + return std::numeric_limits::max() * double(-1); + } +}; + +template <> +struct FpLimits<__half> +{ + static HIPCUB_HOST_DEVICE __forceinline__ __half Max() { + unsigned short max_word = 0x7BFF; + return reinterpret_cast<__half&>(max_word); + } + + static HIPCUB_HOST_DEVICE __forceinline__ __half Lowest() { + unsigned short lowest_word = 0xFBFF; + return reinterpret_cast<__half&>(lowest_word); + } +}; + +template <> +struct FpLimits +{ + static HIPCUB_HOST_DEVICE __forceinline__ cuda_bfloat16 Max() { + unsigned short max_word = 0x7F7F; + return reinterpret_cast(max_word); + } + + static HIPCUB_HOST_DEVICE __forceinline__ cuda_bfloat16 Lowest() { + unsigned short lowest_word = 0xFF7F; + return reinterpret_cast(lowest_word); + } +}; + +/** + * Basic type traits (fp primitive specialization) + */ +template +struct BaseTraits +{ + typedef _UnsignedBits UnsignedBits; + + static const Category CATEGORY = FLOATING_POINT; + static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); + static const UnsignedBits LOWEST_KEY = UnsignedBits(-1); + static const UnsignedBits MAX_KEY = UnsignedBits(-1) ^ HIGH_BIT; + + enum + { + PRIMITIVE = true, + NULL_TYPE = false, + }; + + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) + { + UnsignedBits mask = (key & HIGH_BIT) ? UnsignedBits(-1) : HIGH_BIT; + return key ^ mask; + }; + + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) + { + UnsignedBits mask = (key & HIGH_BIT) ? 
HIGH_BIT : UnsignedBits(-1); + return key ^ mask; + }; + + static HIPCUB_HOST_DEVICE __forceinline__ T Max() { + return FpLimits::Max(); + } + + static HIPCUB_HOST_DEVICE __forceinline__ T Lowest() { + return FpLimits::Lowest(); + } +}; + + +/** + * \brief Numeric type traits + */ +template struct NumericTraits : BaseTraits {}; + +template <> struct NumericTraits : BaseTraits {}; + +template <> struct NumericTraits : BaseTraits<(std::numeric_limits::is_signed) ? SIGNED_INTEGER : UNSIGNED_INTEGER, true, false, unsigned char, char> {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; + +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; + +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; +template <> struct NumericTraits<__half> : BaseTraits {}; +template <> struct NumericTraits : BaseTraits {}; + +template <> struct NumericTraits : BaseTraits::VolatileWord, bool> {}; + +/** + * \brief Type traits + */ +template +struct Traits : NumericTraits::type> {}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_UTIL_TYPE_HPP_ diff --git a/3rdparty/cub/version.cuh b/3rdparty/cub/version.cuh new file mode 100644 index 0000000000000000000000000000000000000000..56d63c7e8db035f0e522735b020aeb81dd5d8080 --- /dev/null +++ b/3rdparty/cub/version.cuh @@ -0,0 +1,21 @@ +#ifndef HIPCUB_VERSION_HPP_ +#define HIPCUB_VERSION_HPP_ + +/// \def HIPCUB_VERSION +/// \brief HIPCUB library version +/// +/// Version number may not be visible in the documentation. 
+/// +/// HIPCUB_VERSION % 100 is the patch level, +/// HIPCUB_VERSION / 100 % 1000 is the minor version, +/// HIPCUB_VERSION / 100000 is the major version. +/// +/// For example, if HIPCUB_VERSION is 100500, then the major version is 1, +/// the minor version is 5, and the patch level is 0. +// NOTE(review): replacement list must be parenthesized — without parens, +// expressions such as `HIPCUB_VERSION % 100` expand to +// `2 * 100000 + 10 * 100 + (12 % 100)` due to operator precedence. +#define HIPCUB_VERSION (2 * 100000 + 10 * 100 + 12) + +#define HIPCUB_VERSION_MAJOR 2 +#define HIPCUB_VERSION_MINOR 10 +#define HIPCUB_VERSION_PATCH 12 + +#endif // HIPCUB_VERSION_HPP_ \ No newline at end of file diff --git a/3rdparty/cub/warp/warp_exchange.hpp b/3rdparty/cub/warp/warp_exchange.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d83d65c8849fa04e1c902a0bc837cdf6b87b1311 --- /dev/null +++ b/3rdparty/cub/warp/warp_exchange.hpp @@ -0,0 +1,109 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_WARP_WARP_EXCHANGE_HPP_ +#define HIPCUB_ROCPRIM_WARP_WARP_EXCHANGE_HPP_ + +#include "../config.hpp" +#include "../util_type.cuh" + +#include + +BEGIN_HIPCUB_NAMESPACE + +template < + typename InputT, + int ITEMS_PER_THREAD, + int LOGICAL_WARP_THREADS = HIPCUB_DEVICE_WARP_THREADS, + int ARCH = HIPCUB_ARCH +> +class WarpExchange +{ + using base_type = typename rocprim::warp_exchange; + +public: + using TempStorage = typename base_type::storage_type; + +private: + TempStorage &temp_storage; + +public: + WarpExchange() = delete; + + explicit HIPCUB_DEVICE __forceinline__ + WarpExchange(TempStorage &temp_storage) : + temp_storage(temp_storage) + { + } + + template + HIPCUB_DEVICE __forceinline__ + void BlockedToStriped( + const InputT (&input_items)[ITEMS_PER_THREAD], + OutputT (&output_items)[ITEMS_PER_THREAD]) + { + base_type rocprim_warp_exchange; + rocprim_warp_exchange.blocked_to_striped(input_items, output_items, temp_storage); + } + + template + HIPCUB_DEVICE __forceinline__ + void StripedToBlocked( + const InputT (&input_items)[ITEMS_PER_THREAD], + OutputT 
(&output_items)[ITEMS_PER_THREAD]) + { + base_type rocprim_warp_exchange; + rocprim_warp_exchange.striped_to_blocked(input_items, output_items, temp_storage); + } + + template + HIPCUB_DEVICE __forceinline__ + void ScatterToStriped( + InputT (&items)[ITEMS_PER_THREAD], + OffsetT (&ranks)[ITEMS_PER_THREAD]) + { + ScatterToStriped(items, items, ranks); + } + + template + HIPCUB_DEVICE __forceinline__ + void ScatterToStriped( + const InputT (&input_items)[ITEMS_PER_THREAD], + OutputT (&output_items)[ITEMS_PER_THREAD], + OffsetT (&ranks)[ITEMS_PER_THREAD]) + { + base_type rocprim_warp_exchange; + rocprim_warp_exchange.scatter_to_striped(input_items, output_items, ranks, temp_storage); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_WARP_WARP_EXCHANGE_HPP_ diff --git a/3rdparty/cub/warp/warp_load.hpp b/3rdparty/cub/warp/warp_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7d059efd28332fac9db637aa304bcd4f046594ab --- /dev/null +++ b/3rdparty/cub/warp/warp_load.hpp @@ -0,0 +1,412 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_WARP_WARP_LOAD_HPP_ +#define HIPCUB_ROCPRIM_WARP_WARP_LOAD_HPP_ + +#include "../config.hpp" + +#include "../util_type.cuh" +#include "../iterator/cache_modified_input_iterator.cuh" +#include "./warp_exchange.hpp" + +#include + +BEGIN_HIPCUB_NAMESPACE + +enum WarpLoadAlgorithm +{ + WARP_LOAD_DIRECT, + WARP_LOAD_STRIPED, + WARP_LOAD_VECTORIZE, + WARP_LOAD_TRANSPOSE +}; + +template< + class InputT, + int ITEMS_PER_THREAD, + WarpLoadAlgorithm ALGORITHM = WARP_LOAD_DIRECT, + int LOGICAL_WARP_THREADS = HIPCUB_DEVICE_WARP_THREADS, + int ARCH = HIPCUB_ARCH +> +class WarpLoad +{ +private: + constexpr static bool IS_ARCH_WARP + = static_cast(LOGICAL_WARP_THREADS) == HIPCUB_DEVICE_WARP_THREADS; + + template + struct LoadInternal; + + template <> + struct LoadInternal + { + using TempStorage = NullType; + int linear_tid; + + HIPCUB_DEVICE __forceinline__ + LoadInternal( + TempStorage & 
/*temp_storage*/, + int linear_tid) + : linear_tid(linear_tid) + { + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD]) + { + ::rocprim::block_load_direct_blocked( + static_cast(linear_tid), + block_itr, + items + ); + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD], + int valid_items) + { + ::rocprim::block_load_direct_blocked( + static_cast(linear_tid), + block_itr, + items, + static_cast(valid_items) + ); + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD], + int valid_items, + DefaultT oob_default) + { + ::rocprim::block_load_direct_blocked( + static_cast(linear_tid), + block_itr, + items, + static_cast(valid_items), + oob_default + ); + } + }; + + template <> + struct LoadInternal + { + using TempStorage = NullType; + int linear_tid; + + HIPCUB_DEVICE __forceinline__ + LoadInternal( + TempStorage & /*temp_storage*/, + int linear_tid) + : linear_tid(linear_tid) + { + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD]) + { + ::rocprim::block_load_direct_warp_striped( + static_cast(linear_tid), + block_itr, + items + ); + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD], + int valid_items) + { + ::rocprim::block_load_direct_warp_striped( + static_cast(linear_tid), + block_itr, + items, + static_cast(valid_items) + ); + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD], + int valid_items, + DefaultT oob_default) + { + ::rocprim::block_load_direct_warp_striped( + static_cast(linear_tid), + block_itr, + items, + static_cast(valid_items), + oob_default + ); + } + }; + + template <> + struct LoadInternal + { + using TempStorage = NullType; + 
int linear_tid; + + HIPCUB_DEVICE __forceinline__ LoadInternal( + TempStorage & /*temp_storage*/, + int linear_tid) + : linear_tid(linear_tid) + { + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputT *block_ptr, + InputT (&items)[ITEMS_PER_THREAD]) + { + ::rocprim::block_load_direct_blocked_vectorized( + static_cast(linear_tid), + block_ptr, + items + ); + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + const InputT *block_ptr, + InputT (&items)[ITEMS_PER_THREAD]) + { + ::rocprim::block_load_direct_blocked_vectorized( + static_cast(linear_tid), + block_ptr, + items + ); + } + + template< + CacheLoadModifier MODIFIER, + typename ValueType, + typename OffsetT + > + HIPCUB_DEVICE __forceinline__ void Load( + CacheModifiedInputIterator block_itr, + InputT (&items)[ITEMS_PER_THREAD]) + { + ::rocprim::block_load_direct_blocked_vectorized( + static_cast(linear_tid), + block_itr, + items + ); + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + _InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD]) + { + ::rocprim::block_load_direct_blocked_vectorized( + static_cast(linear_tid), + block_itr, + items + ); + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD], + int valid_items) + { + ::rocprim::block_load_direct_blocked_vectorized( + static_cast(linear_tid), + block_itr, + items, + static_cast(valid_items) + ); + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD], + int valid_items, + DefaultT oob_default) + { + // vectorized overload does not exist + // fall back to direct blocked + ::rocprim::block_load_direct_blocked( + static_cast(linear_tid), + block_itr, + items, + static_cast(valid_items), + oob_default + ); + } + }; + + template <> + struct LoadInternal + { + using WarpExchangeT = WarpExchange< + InputT, + ITEMS_PER_THREAD, + LOGICAL_WARP_THREADS, + ARCH + >; + using TempStorage = 
typename WarpExchangeT::TempStorage; + TempStorage& temp_storage; + int linear_tid; + + HIPCUB_DEVICE __forceinline__ LoadInternal( + TempStorage &temp_storage, + int linear_tid) : + temp_storage(temp_storage), + linear_tid(linear_tid) + { + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD]) + { + ::rocprim::block_load_direct_warp_striped( + static_cast(linear_tid), + block_itr, + items + ); + WarpExchangeT(temp_storage).StripedToBlocked(items, items); + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD], + int valid_items) + { + ::rocprim::block_load_direct_warp_striped( + static_cast(linear_tid), + block_itr, + items, + static_cast(valid_items) + ); + WarpExchangeT(temp_storage).StripedToBlocked(items, items); + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD], + int valid_items, + DefaultT oob_default) + { + ::rocprim::block_load_direct_warp_striped( + static_cast(linear_tid), + block_itr, + items, + static_cast(valid_items), + oob_default + ); + WarpExchangeT(temp_storage).StripedToBlocked(items, items); + } + }; + + using InternalLoad = LoadInternal; + + using _TempStorage = typename InternalLoad::TempStorage; + + HIPCUB_DEVICE __forceinline__ _TempStorage &PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + _TempStorage &temp_storage; + int linear_tid; + +public: + struct TempStorage : Uninitialized<_TempStorage> + { + }; + + HIPCUB_DEVICE __forceinline__ + WarpLoad() : + temp_storage(PrivateStorage()), + linear_tid(IS_ARCH_WARP ? ::rocprim::lane_id() : (::rocprim::lane_id() % LOGICAL_WARP_THREADS)) + { + } + + HIPCUB_DEVICE __forceinline__ + WarpLoad(TempStorage &temp_storage) : + temp_storage(temp_storage.Alias()), + linear_tid(IS_ARCH_WARP ? 
::rocprim::lane_id() : (::rocprim::lane_id() % LOGICAL_WARP_THREADS)) + { + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD]) + { + InternalLoad(temp_storage, linear_tid) + .Load(block_itr, items); + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD], + int valid_items) + { + InternalLoad(temp_storage, linear_tid) + .Load(block_itr, items, valid_items); + } + + template + HIPCUB_DEVICE __forceinline__ void Load( + InputIteratorT block_itr, + InputT (&items)[ITEMS_PER_THREAD], + int valid_items, + DefaultT oob_default) + { + InternalLoad(temp_storage, linear_tid) + .Load(block_itr, items, valid_items, oob_default); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_WARP_WARP_LOAD_HPP_ diff --git a/3rdparty/cub/warp/warp_merge_sort.hpp b/3rdparty/cub/warp/warp_merge_sort.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3c8973aaaceb488c4c22a783f76c9a84583b7ac7 --- /dev/null +++ b/3rdparty/cub/warp/warp_merge_sort.hpp @@ -0,0 +1,179 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2021, Advanced Micro Devices, Inc. All + * rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_WARP_WARP_MERGE_SORT_ +#define HIPCUB_ROCPRIM_WARP_WARP_MERGE_SORT_ + +#include "../config.hpp" + +#include "../block/block_merge_sort.hpp" +#include "../util_ptx.cuh" +#include "../util_type.cuh" + +#include +#include + +BEGIN_HIPCUB_NAMESPACE + +/** + * @brief The WarpMergeSort class provides methods for sorting items partitioned + * across a CUDA warp using a merge sorting method. + * @ingroup WarpModule + * + * @tparam KeyT + * Key type + * + * @tparam ITEMS_PER_THREAD + * The number of items per thread + * + * @tparam LOGICAL_WARP_THREADS + * [optional] The number of threads per "logical" warp (may be less + * than the number of hardware warp threads). Default is the warp size of the + * targeted CUDA compute-capability (e.g., 32 threads for SM86). Must be a + * power of two. 
+ * + * @tparam ValueT + * [optional] Value type (default: cub::NullType, which indicates a + * keys-only sort) + * + * @tparam PTX_ARCH + * [optional] \ptxversion + * + * @par Overview + * WarpMergeSort arranges items into ascending order using a comparison + * functor with less-than semantics. Merge sort can handle arbitrary types + * and comparison functors. + * + * @par A Simple Example + * @par + * The code snippet below illustrates a sort of 64 integer keys that are + * partitioned across 16 threads where each thread owns 4 consecutive items. + * @par + * @code + * #include // or equivalently + * + * struct CustomLess + * { + * template + * __device__ bool operator()(const DataType &lhs, const DataType &rhs) + * { + * return lhs < rhs; + * } + * }; + * + * __global__ void ExampleKernel(...) + * { + * constexpr int warp_threads = 16; + * constexpr int block_threads = 256; + * constexpr int items_per_thread = 4; + * constexpr int warps_per_block = block_threads / warp_threads; + * const int warp_id = static_cast(threadIdx.x) / warp_threads; + * + * // Specialize WarpMergeSort for a virtual warp of 16 threads + * // owning 4 integer items each + * using WarpMergeSortT = + * cub::WarpMergeSort; + * + * // Allocate shared memory for WarpMergeSort + * __shared__ typename WarpMergeSort::TempStorage temp_storage[warps_per_block]; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_keys[items_per_thread]; + * // ... + * + * WarpMergeSort(temp_storage[warp_id]).Sort(thread_keys, CustomLess()); + * // ... + * } + * @endcode + * @par + * Suppose the set of input @p thread_keys across the block of threads is + * { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. + * The corresponding output @p thread_keys in those threads will be + * { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }. 
+ */ +template < + typename KeyT, + int ITEMS_PER_THREAD, + int LOGICAL_WARP_THREADS = HIPCUB_DEVICE_WARP_THREADS, + typename ValueT = NullType, + int PTX_ARCH = HIPCUB_ARCH> +class WarpMergeSort + : public BlockMergeSortStrategy< + KeyT, + ValueT, + LOGICAL_WARP_THREADS, + ITEMS_PER_THREAD, + WarpMergeSort> +{ +private: + constexpr static bool IS_ARCH_WARP = LOGICAL_WARP_THREADS == HIPCUB_DEVICE_WARP_THREADS; + constexpr static bool KEYS_ONLY = ::rocprim::Equals::VALUE; + constexpr static int TILE_SIZE = ITEMS_PER_THREAD * LOGICAL_WARP_THREADS; + + using BlockMergeSortStrategyT = BlockMergeSortStrategy; + + const unsigned int warp_id; + const uint64_t member_mask; + +public: + WarpMergeSort() = delete; + + HIPCUB_DEVICE __forceinline__ + WarpMergeSort(typename BlockMergeSortStrategyT::TempStorage &temp_storage) + : BlockMergeSortStrategyT(temp_storage, + IS_ARCH_WARP + ? LaneId() + : (LaneId() % LOGICAL_WARP_THREADS)) + , warp_id(IS_ARCH_WARP ? 0 : (LaneId() / LOGICAL_WARP_THREADS)) + , member_mask(WarpMask(warp_id)) + { + } + + HIPCUB_DEVICE __forceinline__ uint64_t get_member_mask() const + { + return member_mask; + } + +private: + HIPCUB_DEVICE __forceinline__ void SyncImplementation() const + { + WARP_SYNC(member_mask); + } + + friend BlockMergeSortStrategyT; +}; + + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_WARP_WARP_MERGE_SORT_ diff --git a/3rdparty/cub/warp/warp_reduce.cuh b/3rdparty/cub/warp/warp_reduce.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e1c8518914700caea1588d696d8ff8205518e2a7 --- /dev/null +++ b/3rdparty/cub/warp/warp_reduce.cuh @@ -0,0 +1,130 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_WARP_WARP_REDUCE_HPP_ +#define HIPCUB_ROCPRIM_WARP_WARP_REDUCE_HPP_ + +#include "../config.hpp" + +#include "../util_ptx.cuh" +#include "../thread/thread_operators.cuh" + +#include + +BEGIN_HIPCUB_NAMESPACE + +template< + typename T, + int LOGICAL_WARP_THREADS = HIPCUB_DEVICE_WARP_THREADS, + int ARCH = HIPCUB_ARCH> +class WarpReduce : private ::rocprim::warp_reduce +{ + static_assert(LOGICAL_WARP_THREADS > 0, "LOGICAL_WARP_THREADS must be greater than 0"); + using base_type = typename ::rocprim::warp_reduce; + + typename base_type::storage_type &temp_storage_; + +public: + using TempStorage = typename base_type::storage_type; + + HIPCUB_DEVICE inline + WarpReduce(TempStorage& temp_storage) : temp_storage_(temp_storage) + { + } + + HIPCUB_DEVICE inline + T Sum(T input) + { + base_type::reduce(input, input, temp_storage_); + return input; + } + + HIPCUB_DEVICE inline + T Sum(T input, int valid_items) + { + base_type::reduce(input, input, valid_items, temp_storage_); + return input; + } + + template + HIPCUB_DEVICE inline + T HeadSegmentedSum(T input, FlagT head_flag) + { + base_type::head_segmented_reduce(input, input, head_flag, temp_storage_); + return input; + } + + template + HIPCUB_DEVICE inline + T TailSegmentedSum(T input, FlagT tail_flag) + { + base_type::tail_segmented_reduce(input, input, tail_flag, temp_storage_); + return input; + } + + template + HIPCUB_DEVICE inline + T Reduce(T input, ReduceOp reduce_op) + { + base_type::reduce(input, input, temp_storage_, reduce_op); + return input; + } + + template + HIPCUB_DEVICE inline + T Reduce(T input, ReduceOp reduce_op, int valid_items) + { + base_type::reduce(input, input, valid_items, temp_storage_, reduce_op); + return input; + } + + template + HIPCUB_DEVICE inline + T HeadSegmentedReduce(T input, FlagT head_flag, ReduceOp reduce_op) + { + base_type::head_segmented_reduce( + input, input, 
head_flag, temp_storage_, reduce_op + ); + return input; + } + + template + HIPCUB_DEVICE inline + T TailSegmentedReduce(T input, FlagT tail_flag, ReduceOp reduce_op) + { + base_type::tail_segmented_reduce( + input, input, tail_flag, temp_storage_, reduce_op + ); + return input; + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_WARP_WARP_REDUCE_HPP_ diff --git a/3rdparty/cub/warp/warp_scan.cuh b/3rdparty/cub/warp/warp_scan.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cd0960f32eca45d8185e2ea0a8ddbd87686d270c --- /dev/null +++ b/3rdparty/cub/warp/warp_scan.cuh @@ -0,0 +1,172 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_WARP_WARP_SCAN_HPP_ +#define HIPCUB_ROCPRIM_WARP_WARP_SCAN_HPP_ + +#include "../config.hpp" + +#include "../util_ptx.cuh" +#include "../thread/thread_operators.cuh" + +#include + +BEGIN_HIPCUB_NAMESPACE + +template< + typename T, + int LOGICAL_WARP_THREADS = HIPCUB_DEVICE_WARP_THREADS, + int ARCH = HIPCUB_ARCH> +class WarpScan : private ::rocprim::warp_scan +{ + static_assert(LOGICAL_WARP_THREADS > 0, "LOGICAL_WARP_THREADS must be greater than 0"); + using base_type = typename ::rocprim::warp_scan; + + typename base_type::storage_type &temp_storage_; + +public: + using TempStorage = typename base_type::storage_type; + + HIPCUB_DEVICE inline + WarpScan(TempStorage& temp_storage) : temp_storage_(temp_storage) + { + } + + HIPCUB_DEVICE inline + void InclusiveSum(T input, T& inclusive_output) + { + base_type::inclusive_scan(input, inclusive_output, temp_storage_); + } + + HIPCUB_DEVICE inline + void InclusiveSum(T input, T& inclusive_output, T& warp_aggregate) + { + base_type::inclusive_scan(input, inclusive_output, warp_aggregate, temp_storage_); + } + + HIPCUB_DEVICE inline + void ExclusiveSum(T input, T& exclusive_output) + { + base_type::exclusive_scan(input, exclusive_output, T(0), temp_storage_); + } + + HIPCUB_DEVICE inline + void ExclusiveSum(T input, T& exclusive_output, T& warp_aggregate) + { + base_type::exclusive_scan(input, 
exclusive_output, T(0), warp_aggregate, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void InclusiveScan(T input, T& inclusive_output, ScanOp scan_op) + { + base_type::inclusive_scan(input, inclusive_output, temp_storage_, scan_op); + } + + template + HIPCUB_DEVICE inline + void InclusiveScan(T input, T& inclusive_output, ScanOp scan_op, T& warp_aggregate) + { + base_type::inclusive_scan( + input, inclusive_output, warp_aggregate, + temp_storage_, scan_op + ); + } + + template + HIPCUB_DEVICE inline + void ExclusiveScan(T input, T& exclusive_output, ScanOp scan_op) + { + base_type::inclusive_scan(input, exclusive_output, temp_storage_, scan_op); + base_type::to_exclusive(exclusive_output, exclusive_output, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void ExclusiveScan(T input, T& exclusive_output, T initial_value, ScanOp scan_op) + { + base_type::exclusive_scan( + input, exclusive_output, initial_value, + temp_storage_, scan_op + ); + } + + template + HIPCUB_DEVICE inline + void ExclusiveScan(T input, T& exclusive_output, ScanOp scan_op, T& warp_aggregate) + { + base_type::inclusive_scan( + input, exclusive_output, warp_aggregate, temp_storage_, scan_op + ); + base_type::to_exclusive(exclusive_output, exclusive_output, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void ExclusiveScan(T input, T& exclusive_output, T initial_value, ScanOp scan_op, T& warp_aggregate) + { + base_type::exclusive_scan( + input, exclusive_output, initial_value, warp_aggregate, + temp_storage_, scan_op + ); + } + + template + HIPCUB_DEVICE inline + void Scan(T input, T& inclusive_output, T& exclusive_output, ScanOp scan_op) + { + base_type::inclusive_scan(input, inclusive_output, temp_storage_, scan_op); + base_type::to_exclusive(inclusive_output, exclusive_output, temp_storage_); + } + + template + HIPCUB_DEVICE inline + void Scan(T input, T& inclusive_output, T& exclusive_output, T initial_value, ScanOp scan_op) + { + base_type::scan( + input, 
inclusive_output, exclusive_output, initial_value, + temp_storage_, scan_op + ); + // In CUB documentation it's unclear if inclusive_output should include initial_value, + // however, the implementation includes initial_value in inclusive_output in WarpScan::Scan(). + // In rocPRIM it's not included, and this is a fix to match CUB implementation. + // After confirmation from CUB's developers we will most probably change rocPRIM too. + inclusive_output = scan_op(initial_value, inclusive_output); + } + + HIPCUB_DEVICE inline + T Broadcast(T input, unsigned int src_lane) + { + return base_type::broadcast(input, src_lane, temp_storage_); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_WARP_WARP_SCAN_HPP_ diff --git a/3rdparty/cub/warp/warp_store.hpp b/3rdparty/cub/warp/warp_store.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c0064feb3dfd9b04c9994ac78ccf104abf279724 --- /dev/null +++ b/3rdparty/cub/warp/warp_store.hpp @@ -0,0 +1,317 @@ +/****************************************************************************** + * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#ifndef HIPCUB_ROCPRIM_WARP_WARP_STORE_HPP_ +#define HIPCUB_ROCPRIM_WARP_WARP_STORE_HPP_ + +#include "../config.hpp" + +#include "../util_type.cuh" +#include "./warp_exchange.hpp" + +#include + +BEGIN_HIPCUB_NAMESPACE + +enum WarpStoreAlgorithm +{ + WARP_STORE_DIRECT, + WARP_STORE_STRIPED, + WARP_STORE_VECTORIZE, + WARP_STORE_TRANSPOSE +}; + +template< + class T, + int ITEMS_PER_THREAD, + WarpStoreAlgorithm ALGORITHM = WARP_STORE_DIRECT, + int LOGICAL_WARP_THREADS = HIPCUB_DEVICE_WARP_THREADS, + int ARCH = HIPCUB_ARCH +> +class WarpStore +{ +private: + constexpr static bool IS_ARCH_WARP + = static_cast(LOGICAL_WARP_THREADS) == HIPCUB_DEVICE_WARP_THREADS; + + template + struct StoreInternal; + + template <> + struct StoreInternal + { + using TempStorage = NullType; + int linear_tid; + + HIPCUB_DEVICE __forceinline__ StoreInternal( + TempStorage & /*temp_storage*/, + int linear_tid) + : linear_tid(linear_tid) + { 
+ } + + template + HIPCUB_DEVICE __forceinline__ void Store( + OutputIteratorT block_itr, + T (&items)[ITEMS_PER_THREAD]) + { + ::rocprim::block_store_direct_blocked( + static_cast(linear_tid), + block_itr, + items + ); + } + + template + HIPCUB_DEVICE __forceinline__ void Store( + OutputIteratorT block_itr, + T (&items)[ITEMS_PER_THREAD], + int valid_items) + { + ::rocprim::block_store_direct_blocked( + static_cast(linear_tid), + block_itr, + items, + static_cast(valid_items) + ); + } + }; + + template <> + struct StoreInternal + { + using TempStorage = NullType; + int linear_tid; + + HIPCUB_DEVICE __forceinline__ StoreInternal( + TempStorage & /*temp_storage*/, + int linear_tid) + : linear_tid(linear_tid) + { + } + + template + HIPCUB_DEVICE __forceinline__ void Store( + OutputIteratorT block_itr, + T (&items)[ITEMS_PER_THREAD]) + { + ::rocprim::block_store_direct_warp_striped( + static_cast(linear_tid), + block_itr, + items + ); + } + + template + HIPCUB_DEVICE __forceinline__ void Store( + OutputIteratorT block_itr, + T (&items)[ITEMS_PER_THREAD], + int valid_items) + { + ::rocprim::block_store_direct_warp_striped( + static_cast(linear_tid), + block_itr, + items, + static_cast(valid_items) + ); + } + }; + + template <> + struct StoreInternal + { + using TempStorage = NullType; + int linear_tid; + + HIPCUB_DEVICE __forceinline__ StoreInternal( + TempStorage & /*temp_storage*/, + int linear_tid) + : linear_tid(linear_tid) + { + } + + template + HIPCUB_DEVICE __forceinline__ void Store( + T *block_ptr, + T (&items)[ITEMS_PER_THREAD]) + { + ::rocprim::block_store_direct_blocked_vectorized( + static_cast(linear_tid), + block_ptr, + items + ); + } + + template + HIPCUB_DEVICE __forceinline__ void Store( + _OutputIteratorT block_itr, + T (&items)[ITEMS_PER_THREAD]) + { + ::rocprim::block_store_direct_blocked_vectorized( + static_cast(linear_tid), + block_itr, + items + ); + } + + template + HIPCUB_DEVICE __forceinline__ void Store( + OutputIteratorT block_itr, + T 
(&items)[ITEMS_PER_THREAD], + int valid_items) + { + // vectorized overload does not exist + // fall back to direct blocked + ::rocprim::block_store_direct_blocked( + static_cast(linear_tid), + block_itr, + items, + static_cast(valid_items) + ); + } + }; + + template <> + struct StoreInternal + { + using WarpExchangeT = WarpExchange< + T, + ITEMS_PER_THREAD, + LOGICAL_WARP_THREADS, + ARCH + >; + using TempStorage = typename WarpExchangeT::TempStorage; + TempStorage& temp_storage; + int linear_tid; + + HIPCUB_DEVICE __forceinline__ StoreInternal( + TempStorage &temp_storage, + int linear_tid) : + temp_storage(temp_storage), + linear_tid(linear_tid) + { + } + + template + HIPCUB_DEVICE __forceinline__ void Store( + OutputIteratorT block_itr, + T (&items)[ITEMS_PER_THREAD]) + { + WarpExchangeT(temp_storage).BlockedToStriped(items, items); + ::rocprim::block_store_direct_warp_striped( + static_cast(linear_tid), + block_itr, + items + ); + } + + template + HIPCUB_DEVICE __forceinline__ void Store( + OutputIteratorT block_itr, + T (&items)[ITEMS_PER_THREAD], + int valid_items) + { + WarpExchangeT(temp_storage).BlockedToStriped(items, items); + ::rocprim::block_store_direct_warp_striped( + static_cast(linear_tid), + block_itr, + items, + static_cast(valid_items) + ); + + } + }; + + using InternalStore = StoreInternal; + + using _TempStorage = typename InternalStore::TempStorage; + + HIPCUB_DEVICE __forceinline__ _TempStorage &PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + _TempStorage &temp_storage; + int linear_tid; + +public: + struct TempStorage : Uninitialized<_TempStorage> + { + }; + + HIPCUB_DEVICE __forceinline__ + WarpStore() : + temp_storage(PrivateStorage()), + linear_tid(IS_ARCH_WARP ? ::rocprim::lane_id() : (::rocprim::lane_id() % LOGICAL_WARP_THREADS)) + { + } + + HIPCUB_DEVICE __forceinline__ + WarpStore(TempStorage &temp_storage) : + temp_storage(temp_storage.Alias()), + linear_tid(IS_ARCH_WARP ? 
::rocprim::lane_id() : (::rocprim::lane_id() % LOGICAL_WARP_THREADS)) + { + } + + template + HIPCUB_DEVICE __forceinline__ void Store( + OutputIteratorT block_itr, + T (&items)[ITEMS_PER_THREAD]) + { + InternalStore(temp_storage, linear_tid) + .Store(block_itr, items); + } + + template + HIPCUB_DEVICE __forceinline__ void Store( + OutputIteratorT block_itr, + T (&items)[ITEMS_PER_THREAD], + int valid_items) + { + InternalStore(temp_storage, linear_tid) + .Store(block_itr, items, valid_items); + } + + template + HIPCUB_DEVICE __forceinline__ void Store( + OutputIteratorT block_itr, + T (&items)[ITEMS_PER_THREAD], + int valid_items, + DefaultT oob_default) + { + InternalStore(temp_storage, linear_tid) + .Store(block_itr, items, valid_items, oob_default); + } +}; + +END_HIPCUB_NAMESPACE + +#endif // HIPCUB_ROCPRIM_WARP_WARP_STORE_HPP_ diff --git a/README.md b/README.md index d638978ae661aa38e144f54187e0e9134a3a4740..727bb0c5d75bd0efcc9653d5525f472894038ba0 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# GLM130B_FT +# GLM130B_fastertransformer ## 论文 @@ -68,6 +68,10 @@ nvcc CMakeFiles/test_logprob_kernels.dir/test_logprob_kernels.cu.o -o ../../bin/ ``` +## 数据集 + +无 + ## 推理 ### 原版模型下载与转换 @@ -122,7 +126,7 @@ python ../examples/pytorch/glm/glm_tokenize.py [ModelZoo / GLM130B_FT · GitLab (hpccube.com)](https://developer.hpccube.com/codes/modelzoo/glm130b_ft) -## 参考 +## 参考资料 [THUDM/GLM-130B: GLM-130B: An Open Bilingual Pre-Trained Model (ICLR 2023) (github.com)](https://github.com/THUDM/GLM-130B)