"mmdet3d/structures/ops/iou3d_calculator.py" did not exist on "806802743ff395823594abff42f55c1ed2929bd3"
Commit f8a481f8 authored by zhouxiang

Add the cub header files from dtk

parent 7b7c64c5
// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_ATOMIC_HPP_
#define ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_ATOMIC_HPP_
#include <type_traits>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
template<
class T,
unsigned int BlockSizeX,
unsigned int BlockSizeY,
unsigned int BlockSizeZ,
unsigned int ItemsPerThread,
unsigned int Bins
>
class block_histogram_atomic
{
static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
static_assert(
std::is_convertible<T, unsigned int>::value,
"T must be convertible to unsigned int"
);
public:
using storage_type = typename ::rocprim::detail::empty_storage_type;
template<class Counter>
ROCPRIM_DEVICE ROCPRIM_INLINE
void composite(T (&input)[ItemsPerThread],
Counter hist[Bins])
{
static_assert(
std::is_same<Counter, unsigned int>::value || std::is_same<Counter, int>::value ||
std::is_same<Counter, float>::value || std::is_same<Counter, unsigned long long>::value,
"Counter must be type that is supported by atomics (float, int, unsigned int, unsigned long long)"
);
ROCPRIM_UNROLL
for (unsigned int i = 0; i < ItemsPerThread; ++i)
{
::rocprim::detail::atomic_add(&hist[static_cast<unsigned int>(input[i])], Counter(1));
}
::rocprim::syncthreads();
}
template<class Counter>
ROCPRIM_DEVICE ROCPRIM_INLINE
void composite(T (&input)[ItemsPerThread],
Counter hist[Bins],
storage_type& storage)
{
(void) storage;
this->composite(input, hist);
}
};
} // end namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_ATOMIC_HPP_
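For orientation, a minimal sketch (not part of this commit) of how the atomic specialization above is normally reached through the public rocprim::block_histogram front-end with the using_atomic algorithm; the kernel name, tile sizes and bin count below are assumptions chosen for the example.

#include <hip/hip_runtime.h>
#include <rocprim/block/block_histogram.hpp>

// Hypothetical kernel: each block builds a 256-bin histogram of its tile and
// merges it into a global histogram.
__global__ void block_histogram_atomic_example(const unsigned char* input, unsigned int* output)
{
    constexpr unsigned int BlockSize      = 256;
    constexpr unsigned int ItemsPerThread = 4;
    constexpr unsigned int Bins           = 256;

    using block_histogram_type = rocprim::block_histogram<
        unsigned char, BlockSize, ItemsPerThread, Bins,
        rocprim::block_histogram_algorithm::using_atomic>;

    __shared__ typename block_histogram_type::storage_type storage;
    __shared__ unsigned int block_hist[Bins];

    // Blocked load of this thread's items (bounds handling omitted for brevity).
    unsigned char items[ItemsPerThread];
    const unsigned int base = (blockIdx.x * BlockSize + threadIdx.x) * ItemsPerThread;
    for(unsigned int i = 0; i < ItemsPerThread; i++)
    {
        items[i] = input[base + i];
    }

    // histogram() zero-initializes block_hist and then composites the block's
    // items, which with using_atomic dispatches to block_histogram_atomic above.
    block_histogram_type().histogram(items, block_hist, storage);

    // Merge the per-block result into the global histogram.
    for(unsigned int bin = threadIdx.x; bin < Bins; bin += BlockSize)
    {
        atomicAdd(&output[bin], block_hist[bin]);
    }
}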
// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_SORT_HPP_
#define ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_SORT_HPP_
#include <type_traits>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../block_radix_sort.hpp"
#include "../block_discontinuity.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
template<
class T,
unsigned int BlockSizeX,
unsigned int BlockSizeY,
unsigned int BlockSizeZ,
unsigned int ItemsPerThread,
unsigned int Bins
>
class block_histogram_sort
{
static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
static_assert(
std::is_convertible<T, unsigned int>::value,
"T must be convertible to unsigned int"
);
private:
using radix_sort = block_radix_sort<T, BlockSizeX, ItemsPerThread, empty_type, BlockSizeY, BlockSizeZ>;
using discontinuity = block_discontinuity<T, BlockSizeX, BlockSizeY, BlockSizeZ>;
public:
union storage_type_
{
typename radix_sort::storage_type sort;
struct
{
typename discontinuity::storage_type flag;
unsigned int start[Bins];
unsigned int end[Bins];
};
};
using storage_type = detail::raw_storage<storage_type_>;
template<class Counter>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void composite(T (&input)[ItemsPerThread],
Counter hist[Bins])
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->composite(input, hist, storage);
}
template<class Counter>
ROCPRIM_DEVICE ROCPRIM_INLINE
void composite(T (&input)[ItemsPerThread],
Counter hist[Bins],
storage_type& storage)
{
// TODO: Check this. MSVC rejects the code with the static assertion enabled, yet it compiles fine for all tested types; the predicate is likely too strict.
//static_assert(
// std::is_convertible<unsigned int, Counter>::value,
// "unsigned int must be convertible to Counter"
//);
constexpr auto tile_size = BlockSize * ItemsPerThread;
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
unsigned int head_flags[ItemsPerThread];
discontinuity_op flags_op(storage);
storage_type_& storage_ = storage.get();
radix_sort().sort(input, storage_.sort);
::rocprim::syncthreads(); // Fix race condition that appeared on Vega10 hardware, storage LDS is reused below.
ROCPRIM_UNROLL
for(unsigned int offset = 0; offset < Bins; offset += BlockSize)
{
const unsigned int offset_tid = offset + flat_tid;
if(offset_tid < Bins)
{
storage_.start[offset_tid] = tile_size;
storage_.end[offset_tid] = tile_size;
}
}
::rocprim::syncthreads();
discontinuity().flag_heads(head_flags, input, flags_op, storage_.flag);
::rocprim::syncthreads();
// The start of the first bin is not overwritten here, since the input is sorted
// and the starts are based on the second item onwards: the very first item is
// never used as `b` in the operator.
// This means that this should not need synchronization, but in practice it does.
if(flat_tid == 0)
{
storage_.start[static_cast<unsigned int>(input[0])] = 0;
}
::rocprim::syncthreads();
ROCPRIM_UNROLL
for(unsigned int offset = 0; offset < Bins; offset += BlockSize)
{
const unsigned int offset_tid = offset + flat_tid;
if(offset_tid < Bins)
{
Counter count = static_cast<Counter>(storage_.end[offset_tid] - storage_.start[offset_tid]);
hist[offset_tid] += count;
}
}
}
private:
struct discontinuity_op
{
storage_type &storage;
ROCPRIM_DEVICE ROCPRIM_INLINE
discontinuity_op(storage_type &storage) : storage(storage)
{
}
ROCPRIM_DEVICE ROCPRIM_INLINE
bool operator()(const T& a, const T& b, unsigned int b_index) const
{
storage_type_& storage_ = storage.get();
if(static_cast<unsigned int>(a) != static_cast<unsigned int>(b))
{
storage_.start[static_cast<unsigned int>(b)] = b_index;
storage_.end[static_cast<unsigned int>(a)] = b_index;
return true;
}
else
{
return false;
}
}
};
};
} // end namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_SORT_HPP_
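The sort-based path above works by sorting the tile, flagging discontinuities to find where each bin's run starts and ends, and taking end minus start as the count. Below is a small host-side C++ model of that idea, an illustrative sketch with made-up tile contents rather than code from this commit.

#include <algorithm>
#include <array>
#include <cstdio>
#include <vector>

// Sequential model of the sort-based strategy: sort the tile, record the first
// and one-past-last position of each bin value, and take the difference as the
// bin count.
int main()
{
    constexpr unsigned int Bins = 8;
    std::vector<unsigned int> tile = {3, 1, 3, 7, 1, 1, 0, 3};

    std::sort(tile.begin(), tile.end());

    std::array<unsigned int, Bins> start{}, end{};
    start.fill(tile.size());
    end.fill(tile.size());

    // Equivalent of discontinuity_op: a boundary between two different values
    // marks the end of the left bin and the start of the right one.
    start[tile.front()] = 0;
    for(unsigned int i = 1; i < tile.size(); i++)
    {
        if(tile[i] != tile[i - 1])
        {
            end[tile[i - 1]] = i;
            start[tile[i]]   = i;
        }
    }
    end[tile.back()] = tile.size();

    for(unsigned int bin = 0; bin < Bins; bin++)
    {
        printf("bin %u: %u\n", bin, end[bin] - start[bin]);
    }
    return 0;
}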
// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_REDUCE_RAKING_REDUCE_HPP_
#define ROCPRIM_BLOCK_DETAIL_BLOCK_REDUCE_RAKING_REDUCE_HPP_
#include <type_traits>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../warp/warp_reduce.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
template<
class T,
unsigned int BlockSizeX,
unsigned int BlockSizeY,
unsigned int BlockSizeZ,
bool CommutativeOnly = false
>
class block_reduce_raking_reduce
{
static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
// Number of items to reduce per thread
static constexpr unsigned int thread_reduction_size_ =
(BlockSize + ::rocprim::device_warp_size() - 1)/ ::rocprim::device_warp_size();
// Warp reduce, warp_reduce_crosslane does not require shared memory (storage), but
// logical warp size must be a power of two.
static constexpr unsigned int warp_size_ =
detail::get_min_warp_size(BlockSize, ::rocprim::device_warp_size());
static constexpr bool commutative_only_ = CommutativeOnly && ((BlockSize % warp_size_ == 0) && (BlockSize > warp_size_));
static constexpr unsigned int sharing_threads_ = ::rocprim::max<int>(1, BlockSize - warp_size_);
static constexpr unsigned int segment_length_ = sharing_threads_ / warp_size_;
// Whether the block is smaller than the (logical) warp size
static constexpr bool block_size_smaller_than_warp_size_ = (BlockSize < warp_size_);
using warp_reduce_prefix_type = ::rocprim::detail::warp_reduce_crosslane<T, warp_size_, false>;
struct storage_type_
{
T threads[BlockSize];
};
public:
using storage_type = detail::raw_storage<storage_type_>;
/// \brief Computes a thread block-wide reduction using the specified binary reduction operator. Every thread contributes one reduction partial. The return value is only valid for thread<sub>0</sub>.
/// \param input [in] Calling thread's input to be reduced
/// \param output [out] Variable containing the reduction output
/// \param storage [in] Temporary storage used for the reduction
/// \param reduce_op [in] Binary reduction operator
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void reduce(T input,
T& output,
storage_type& storage,
BinaryFunction reduce_op)
{
this->reduce_impl(
::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
input, output, storage, reduce_op
);
}
/// \brief Computes a thread block-wide reduction using the specified binary reduction operator, allocating its own temporary storage. Every thread contributes one reduction partial. The return value is only valid for thread<sub>0</sub>.
/// \param input [in] Calling thread's input to be reduced
/// \param output [out] Variable containing the reduction output
/// \param reduce_op [in] Binary reduction operator
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void reduce(T input,
T& output,
BinaryFunction reduce_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->reduce(input, output, storage, reduce_op);
}
/// \brief Computes a thread block-wide reduction over an array of items per thread using the specified binary reduction operator. The return value is only valid for thread<sub>0</sub>.
/// \param input [in] Calling thread's input array to be reduced
/// \param output [out] Variable containing the reduction output
/// \param storage [in] Temporary storage used for the reduction
/// \param reduce_op [in] Binary reduction operator
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void reduce(T (&input)[ItemsPerThread],
T& output,
storage_type& storage,
BinaryFunction reduce_op)
{
// Reduce thread items
T thread_input = input[0];
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
thread_input = reduce_op(thread_input, input[i]);
}
// Reduction of reduced values to get partials
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
this->reduce_impl(
flat_tid,
thread_input, output, // input, output
storage,
reduce_op
);
}
/// \brief Computes a thread block-wide reduction over an array of items per thread using the specified binary reduction operator, allocating its own temporary storage. The return value is only valid for thread<sub>0</sub>.
/// \param input [in] Calling thread's input array to be reduced
/// \param output [out] Variable containing the reduction output
/// \param reduce_op [in] Binary reduction operator
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void reduce(T (&input)[ItemsPerThread],
T& output,
BinaryFunction reduce_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->reduce(input, output, storage, reduce_op);
}
/// \brief Computes a thread block-wide reduction using the specified binary reduction operator. Only the first valid_items threads each contribute one reduction partial. The return value is only valid for thread<sub>0</sub>.
/// \param input [in] Calling thread's input partial reduction
/// \param output [out] Variable containing the reduction output
/// \param valid_items [in] Number of valid elements (may be less than BlockSize)
/// \param storage [in] Temporary storage used for the reduction
/// \param reduce_op [in] Binary reduction operator
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void reduce(T input,
T& output,
unsigned int valid_items,
storage_type& storage,
BinaryFunction reduce_op)
{
this->reduce_impl(
::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
input, output, valid_items, storage, reduce_op
);
}
/// \brief Computes a thread block-wide reduction using the specified binary reduction operator, allocating its own temporary storage. Only the first valid_items threads each contribute one reduction partial. The return value is only valid for thread<sub>0</sub>.
/// \param input [in] Calling thread's input partial reduction
/// \param output [out] Variable containing the reduction output
/// \param valid_items [in] Number of valid elements (may be less than BlockSize)
/// \param reduce_op [in] Binary reduction operator
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void reduce(T input,
T& output,
unsigned int valid_items,
BinaryFunction reduce_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->reduce(input, output, valid_items, storage, reduce_op);
}
private:
template<class BinaryFunction, bool FunctionCommutativeOnly = commutative_only_>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto reduce_impl(const unsigned int flat_tid,
T input,
T& output,
storage_type& storage,
BinaryFunction reduce_op)
-> typename std::enable_if<(!FunctionCommutativeOnly), void>::type
{
storage_type_& storage_ = storage.get();
storage_.threads[flat_tid] = input;
::rocprim::syncthreads();
if (flat_tid < warp_size_)
{
T thread_reduction = storage_.threads[flat_tid];
for(unsigned int i = warp_size_ + flat_tid; i < BlockSize; i += warp_size_)
{
thread_reduction = reduce_op(
thread_reduction, storage_.threads[i]
);
}
warp_reduce<block_size_smaller_than_warp_size_, warp_reduce_prefix_type>(
thread_reduction, output, BlockSize, reduce_op
);
}
}
template<class BinaryFunction, bool FunctionCommutativeOnly = commutative_only_>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto reduce_impl(const unsigned int flat_tid,
T input,
T& output,
storage_type& storage,
BinaryFunction reduce_op)
-> typename std::enable_if<(FunctionCommutativeOnly), void>::type
{
storage_type_& storage_ = storage.get();
if (flat_tid >= warp_size_)
storage_.threads[flat_tid - warp_size_] = input;
::rocprim::syncthreads();
if (flat_tid < warp_size_)
{
T thread_reduction = input;
T* storage_pointer = &storage_.threads[flat_tid * segment_length_];
ROCPRIM_UNROLL
for(unsigned int i = 0; i < segment_length_; i++)
{
thread_reduction = reduce_op(
thread_reduction, storage_pointer[i]
);
}
warp_reduce<block_size_smaller_than_warp_size_, warp_reduce_prefix_type>(
thread_reduction, output, BlockSize, reduce_op
);
}
}
template<bool UseValid, class WarpReduce, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto warp_reduce(T input,
T& output,
const unsigned int valid_items,
BinaryFunction reduce_op)
-> typename std::enable_if<UseValid>::type
{
WarpReduce().reduce(
input, output, valid_items, reduce_op
);
}
template<bool UseValid, class WarpReduce, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto warp_reduce(T input,
T& output,
const unsigned int valid_items,
BinaryFunction reduce_op)
-> typename std::enable_if<!UseValid>::type
{
(void) valid_items;
WarpReduce().reduce(
input, output, reduce_op
);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void reduce_impl(const unsigned int flat_tid,
T input,
T& output,
const unsigned int valid_items,
storage_type& storage,
BinaryFunction reduce_op)
{
storage_type_& storage_ = storage.get();
storage_.threads[flat_tid] = input;
::rocprim::syncthreads();
if (flat_tid < warp_size_)
{
T thread_reduction = storage_.threads[flat_tid];
for(unsigned int i = warp_size_ + flat_tid; i < BlockSize; i += warp_size_)
{
if(i < valid_items)
{
thread_reduction = reduce_op(thread_reduction, storage_.threads[i]);
}
}
warp_reduce_prefix_type().reduce(thread_reduction, output, valid_items, reduce_op);
}
}
};
} // end namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_REDUCE_RAKING_REDUCE_HPP_
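As a usage sketch (assuming the public rocprim::block_reduce front-end and its raking_reduce algorithm enumerator, which this header backs; the kernel name and sizes are invented for illustration), a block-wide sum could look like this:

#include <hip/hip_runtime.h>
#include <rocprim/block/block_reduce.hpp>

// Hypothetical kernel: each block reduces BlockSize ints with the raking
// back-end selected explicitly; only thread 0 of each block writes the result.
__global__ void block_reduce_example(const int* input, int* block_sums)
{
    constexpr unsigned int BlockSize = 256;

    using block_reduce_type = rocprim::block_reduce<
        int, BlockSize, rocprim::block_reduce_algorithm::raking_reduce>;

    __shared__ typename block_reduce_type::storage_type storage;

    const int value = input[blockIdx.x * BlockSize + threadIdx.x];

    int sum;
    block_reduce_type().reduce(value, sum, storage, rocprim::plus<int>());

    if(threadIdx.x == 0)
    {
        block_sums[blockIdx.x] = sum;
    }
}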
// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_REDUCE_WARP_REDUCE_HPP_
#define ROCPRIM_BLOCK_DETAIL_BLOCK_REDUCE_WARP_REDUCE_HPP_
#include <type_traits>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../warp/warp_reduce.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
template<
class T,
unsigned int BlockSizeX,
unsigned int BlockSizeY,
unsigned int BlockSizeZ
>
class block_reduce_warp_reduce
{
static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
// Select warp size
static constexpr unsigned int warp_size_ =
detail::get_min_warp_size(BlockSize, ::rocprim::device_warp_size());
// Number of warps in block
static constexpr unsigned int warps_no_ = (BlockSize + warp_size_ - 1) / warp_size_;
// Check if we have to pass number of valid items into warp reduction primitive
static constexpr bool block_size_is_warp_multiple_ = ((BlockSize % warp_size_) == 0);
static constexpr bool warps_no_is_pow_of_two_ = detail::is_power_of_two(warps_no_);
// typedef of warp_reduce primitive that will be used to perform warp-level
// reduce operation on input values.
// warp_reduce_crosslane is an implementation of warp_reduce that does not need storage,
// but requires logical warp size to be a power of two.
using warp_reduce_input_type = ::rocprim::detail::warp_reduce_crosslane<T, warp_size_, false>;
// typedef of warp_reduce primitive that will be used to perform reduction
// of results of warp-level reduction.
using warp_reduce_output_type = ::rocprim::detail::warp_reduce_crosslane<
T, detail::next_power_of_two(warps_no_), false
>;
struct storage_type_
{
T warp_partials[warps_no_];
};
public:
using storage_type = detail::raw_storage<storage_type_>;
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void reduce(T input,
T& output,
storage_type& storage,
BinaryFunction reduce_op)
{
this->reduce_impl(
::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
input, output, storage, reduce_op
);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void reduce(T input,
T& output,
BinaryFunction reduce_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->reduce(input, output, storage, reduce_op);
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void reduce(T (&input)[ItemsPerThread],
T& output,
storage_type& storage,
BinaryFunction reduce_op)
{
// Reduce thread items
T thread_input = input[0];
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
thread_input = reduce_op(thread_input, input[i]);
}
// Reduction of reduced values to get partials
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
this->reduce_impl(
flat_tid,
thread_input, output, // input, output
storage,
reduce_op
);
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void reduce(T (&input)[ItemsPerThread],
T& output,
BinaryFunction reduce_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->reduce(input, output, storage, reduce_op);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void reduce(T input,
T& output,
unsigned int valid_items,
storage_type& storage,
BinaryFunction reduce_op)
{
this->reduce_impl(
::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
input, output, valid_items, storage, reduce_op
);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void reduce(T input,
T& output,
unsigned int valid_items,
BinaryFunction reduce_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->reduce(input, output, valid_items, storage, reduce_op);
}
private:
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void reduce_impl(const unsigned int flat_tid,
T input,
T& output,
storage_type& storage,
BinaryFunction reduce_op)
{
const auto warp_id = ::rocprim::warp_id(flat_tid);
const auto lane_id = ::rocprim::lane_id();
const unsigned int warp_offset = warp_id * warp_size_;
const unsigned int num_valid =
(warp_offset < BlockSize) ? BlockSize - warp_offset : 0;
storage_type_& storage_ = storage.get();
// Perform warp reduce
warp_reduce<!block_size_is_warp_multiple_, warp_reduce_input_type>(
input, output, num_valid, reduce_op
);
// i-th warp will have its partial stored in storage_.warp_partials[i-1]
if(lane_id == 0)
{
storage_.warp_partials[warp_id] = output;
}
::rocprim::syncthreads();
if(flat_tid < warps_no_)
{
// Use warp partial to calculate the final reduce results for every thread
auto warp_partial = storage_.warp_partials[lane_id];
warp_reduce<!warps_no_is_pow_of_two_, warp_reduce_output_type>(
warp_partial, output, warps_no_, reduce_op
);
}
}
template<bool UseValid, class WarpReduce, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto warp_reduce(T input,
T& output,
const unsigned int valid_items,
BinaryFunction reduce_op)
-> typename std::enable_if<UseValid>::type
{
WarpReduce().reduce(
input, output, valid_items, reduce_op
);
}
template<bool UseValid, class WarpReduce, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto warp_reduce(T input,
T& output,
const unsigned int valid_items,
BinaryFunction reduce_op)
-> typename std::enable_if<!UseValid>::type
{
(void) valid_items;
WarpReduce().reduce(
input, output, reduce_op
);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void reduce_impl(const unsigned int flat_tid,
T input,
T& output,
const unsigned int valid_items,
storage_type& storage,
BinaryFunction reduce_op)
{
const auto warp_id = ::rocprim::warp_id(flat_tid);
const auto lane_id = ::rocprim::lane_id();
const unsigned int warp_offset = warp_id * warp_size_;
const unsigned int num_valid =
(warp_offset < valid_items) ? valid_items - warp_offset : 0;
storage_type_& storage_ = storage.get();
// Perform warp reduce
warp_reduce_input_type().reduce(
input, output, num_valid, reduce_op
);
// i-th warp will have its partial stored in storage_.warp_partials[i-1]
if(lane_id == 0)
{
storage_.warp_partials[warp_id] = output;
}
::rocprim::syncthreads();
if(flat_tid < warps_no_)
{
// Use warp partial to calculate the final reduce results for every thread
auto warp_partial = storage_.warp_partials[lane_id];
unsigned int valid_warps_no = (valid_items + warp_size_ - 1) / warp_size_;
warp_reduce_output_type().reduce(
warp_partial, output, valid_warps_no, reduce_op
);
}
}
};
} // end namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_REDUCE_WARP_REDUCE_HPP_
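The class above reduces in two levels: each warp reduces its own inputs (the last warp may be partially filled), lane 0 of each warp stores the warp partial, and the first warp then reduces the partials. Below is a sequential host-side model of that flow, with invented sizes, shown only to illustrate the scheme.

#include <cstdio>
#include <vector>

int main()
{
    constexpr unsigned int warp_size  = 64;   // e.g. device_warp_size() on gfx9
    constexpr unsigned int block_size = 200;  // deliberately not a warp multiple
    constexpr unsigned int warps_no   = (block_size + warp_size - 1) / warp_size;

    std::vector<int> block_input(block_size, 1);

    // First level: per-warp reduction; the last warp sees num_valid < warp_size.
    std::vector<int> warp_partials(warps_no, 0);
    for(unsigned int warp = 0; warp < warps_no; warp++)
    {
        const unsigned int num_valid =
            (warp + 1) * warp_size <= block_size ? warp_size : block_size - warp * warp_size;
        for(unsigned int lane = 0; lane < num_valid; lane++)
        {
            warp_partials[warp] += block_input[warp * warp_size + lane];
        }
    }

    // Second level: reduce the warp partials (done by the first warp on device).
    int block_result = 0;
    for(unsigned int warp = 0; warp < warps_no; warp++)
    {
        block_result += warp_partials[warp];
    }

    printf("block reduction = %d (expected %u)\n", block_result, block_size);
    return 0;
}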
// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_REDUCE_THEN_SCAN_HPP_
#define ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_REDUCE_THEN_SCAN_HPP_
#include <type_traits>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../warp/warp_scan.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
template<
class T,
unsigned int BlockSizeX,
unsigned int BlockSizeY,
unsigned int BlockSizeZ
>
class block_scan_reduce_then_scan
{
static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
// Number of items to reduce per thread
static constexpr unsigned int thread_reduction_size_ =
(BlockSize + ::rocprim::device_warp_size() - 1)/ ::rocprim::device_warp_size();
// Warp scan, warp_scan_crosslane does not require shared memory (storage), but
// logical warp size must be a power of two.
static constexpr unsigned int warp_size_ =
detail::get_min_warp_size(BlockSize, ::rocprim::device_warp_size());
using warp_scan_prefix_type = ::rocprim::detail::warp_scan_crosslane<T, warp_size_>;
// Minimize LDS bank conflicts
static constexpr unsigned int banks_no_ = ::rocprim::detail::get_lds_banks_no();
static constexpr bool has_bank_conflicts_ =
::rocprim::detail::is_power_of_two(thread_reduction_size_) && thread_reduction_size_ > 1;
static constexpr unsigned int bank_conflicts_padding =
has_bank_conflicts_ ? (warp_size_ * thread_reduction_size_ / banks_no_) : 0;
struct storage_type_
{
T threads[warp_size_ * thread_reduction_size_ + bank_conflicts_padding];
};
public:
using storage_type = detail::raw_storage<storage_type_>;
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan(T input,
T& output,
storage_type& storage,
BinaryFunction scan_op)
{
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
this->inclusive_scan_impl(flat_tid, input, output, storage, scan_op);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void inclusive_scan(T input,
T& output,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->inclusive_scan(input, output, storage, scan_op);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan(T input,
T& output,
T& reduction,
storage_type& storage,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
this->inclusive_scan(input, output, storage, scan_op);
reduction = storage_.threads[index(BlockSize - 1)];
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void inclusive_scan(T input,
T& output,
T& reduction,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->inclusive_scan(input, output, reduction, storage, scan_op);
}
template<class PrefixCallback, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan(T input,
T& output,
storage_type& storage,
PrefixCallback& prefix_callback_op,
BinaryFunction scan_op)
{
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
const auto warp_id = ::rocprim::warp_id(flat_tid);
storage_type_& storage_ = storage.get();
this->inclusive_scan_impl(flat_tid, input, output, storage, scan_op);
// Include block prefix (this operation overwrites storage_.threads[0])
T block_prefix = this->get_block_prefix(
flat_tid, warp_id,
storage_.threads[index(BlockSize - 1)], // block reduction
prefix_callback_op, storage
);
output = scan_op(block_prefix, output);
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
storage_type& storage,
BinaryFunction scan_op)
{
// Reduce thread items
T thread_input = input[0];
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
thread_input = scan_op(thread_input, input[i]);
}
// Scan of reduced values to get prefixes
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
this->exclusive_scan_impl(
flat_tid,
thread_input, thread_input, // input, output
storage,
scan_op
);
// Include prefix (first thread does not have prefix)
output[0] = input[0];
if(flat_tid != 0) output[0] = scan_op(thread_input, input[0]);
// Final thread-local scan
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
output[i] = scan_op(output[i-1], input[i]);
}
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void inclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->inclusive_scan(input, output, storage, scan_op);
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
T& reduction,
storage_type& storage,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
this->inclusive_scan(input, output, storage, scan_op);
// Save reduction result
reduction = storage_.threads[index(BlockSize - 1)];
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void inclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
T& reduction,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->inclusive_scan(input, output, reduction, storage, scan_op);
}
template<
class PrefixCallback,
unsigned int ItemsPerThread,
class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
storage_type& storage,
PrefixCallback& prefix_callback_op,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
// Reduce thread items
T thread_input = input[0];
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
thread_input = scan_op(thread_input, input[i]);
}
// Scan of reduced values to get prefixes
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
this->exclusive_scan_impl(
flat_tid,
thread_input, thread_input, // input, output
storage,
scan_op
);
// this operation overwrites storage_.threads[0]
T block_prefix = this->get_block_prefix(
flat_tid, ::rocprim::warp_id(flat_tid),
storage_.threads[index(BlockSize - 1)], // block reduction
prefix_callback_op, storage
);
// Include prefix (first thread does not have prefix)
output[0] = input[0];
if(flat_tid != 0) output[0] = scan_op(thread_input, input[0]);
// Include block prefix
output[0] = scan_op(block_prefix, output[0]);
// Final thread-local scan
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
output[i] = scan_op(output[i-1], input[i]);
}
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan(T input,
T& output,
T init,
storage_type& storage,
BinaryFunction scan_op)
{
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
this->exclusive_scan_impl(flat_tid, input, output, init, storage, scan_op);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void exclusive_scan(T input,
T& output,
T init,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->exclusive_scan(input, output, init, storage, scan_op);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan(T input,
T& output,
T init,
T& reduction,
storage_type& storage,
BinaryFunction scan_op)
{
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
storage_type_& storage_ = storage.get();
this->exclusive_scan_impl(
flat_tid, input, output, init, storage, scan_op
);
// Save reduction result
reduction = storage_.threads[index(BlockSize - 1)];
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void exclusive_scan(T input,
T& output,
T init,
T& reduction,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->exclusive_scan(input, output, init, reduction, storage, scan_op);
}
template<class PrefixCallback, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan(T input,
T& output,
storage_type& storage,
PrefixCallback& prefix_callback_op,
BinaryFunction scan_op)
{
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
const auto warp_id = ::rocprim::warp_id(flat_tid);
storage_type_& storage_ = storage.get();
this->exclusive_scan_impl(
flat_tid, input, output, storage, scan_op
);
// Get reduction result
T reduction = storage_.threads[index(BlockSize - 1)];
// Include block prefix (this operation overwrites storage_.threads[0])
T block_prefix = this->get_block_prefix(
flat_tid, warp_id, reduction,
prefix_callback_op, storage
);
output = scan_op(block_prefix, output);
if(flat_tid == 0) output = block_prefix;
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
T init,
storage_type& storage,
BinaryFunction scan_op)
{
// Reduce thread items
T thread_input = input[0];
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
thread_input = scan_op(thread_input, input[i]);
}
// Scan of reduced values to get prefixes
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
this->exclusive_scan_impl(
flat_tid,
thread_input, thread_input, // input, output
init,
storage,
scan_op
);
// Include init value
T prev = input[0];
T exclusive = init;
if(flat_tid != 0)
{
exclusive = thread_input;
}
output[0] = exclusive;
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
exclusive = scan_op(exclusive, prev);
prev = input[i];
output[i] = exclusive;
}
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void exclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
T init,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->exclusive_scan(input, output, init, storage, scan_op);
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
T init,
T& reduction,
storage_type& storage,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
this->exclusive_scan(input, output, init, storage, scan_op);
// Save reduction result
reduction = storage_.threads[index(BlockSize - 1)];
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void exclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
T init,
T& reduction,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->exclusive_scan(input, output, init, reduction, storage, scan_op);
}
template<
class PrefixCallback,
unsigned int ItemsPerThread,
class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
storage_type& storage,
PrefixCallback& prefix_callback_op,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
// Reduce thread items
T thread_input = input[0];
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
thread_input = scan_op(thread_input, input[i]);
}
// Scan of reduced values to get prefixes
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
this->exclusive_scan_impl(
flat_tid,
thread_input, thread_input, // input, output
storage,
scan_op
);
// this operation overwrites storage_.threads[0]
T block_prefix = this->get_block_prefix(
flat_tid, ::rocprim::warp_id(flat_tid),
storage_.threads[index(BlockSize - 1)], // block reduction
prefix_callback_op, storage
);
// Include init value and block prefix
T prev = input[0];
T exclusive = block_prefix;
if(flat_tid != 0)
{
exclusive = scan_op(block_prefix, thread_input);
}
output[0] = exclusive;
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
exclusive = scan_op(exclusive, prev);
prev = input[i];
output[i] = exclusive;
}
}
private:
// Calculates inclusive scan results and stores them in storage_.threads,
// result for each thread is stored in storage_.threads[flat_tid], and sets
// output to storage_.threads[flat_tid]
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan_impl(const unsigned int flat_tid,
T input,
T& output,
storage_type& storage,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
// Calculate inclusive scan,
// result for each thread is stored in storage_.threads[flat_tid]
this->inclusive_scan_base(flat_tid, input, storage, scan_op);
output = storage_.threads[index(flat_tid)];
}
// Calculates inclusive scan results and stores them in storage_.threads,
// result for each thread is stored in storage_.threads[flat_tid]
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan_base(const unsigned int flat_tid,
T input,
storage_type& storage,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
storage_.threads[index(flat_tid)] = input;
::rocprim::syncthreads();
if(flat_tid < warp_size_)
{
const unsigned int idx_start = index(flat_tid * thread_reduction_size_);
const unsigned int idx_end = idx_start + thread_reduction_size_;
T thread_reduction = storage_.threads[idx_start];
ROCPRIM_UNROLL
for(unsigned int i = idx_start + 1; i < idx_end; i++)
{
thread_reduction = scan_op(
thread_reduction, storage_.threads[i]
);
}
// Calculate warp prefixes
warp_scan_prefix_type().inclusive_scan(thread_reduction, thread_reduction, scan_op);
thread_reduction = warp_shuffle_up(thread_reduction, 1, warp_size_);
// Include warp prefix
thread_reduction = scan_op(thread_reduction, storage_.threads[idx_start]);
if(flat_tid == 0)
{
thread_reduction = input;
}
storage_.threads[idx_start] = thread_reduction;
ROCPRIM_UNROLL
for(unsigned int i = idx_start + 1; i < idx_end; i++)
{
thread_reduction = scan_op(
thread_reduction, storage_.threads[i]
);
storage_.threads[i] = thread_reduction;
}
}
::rocprim::syncthreads();
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan_impl(const unsigned int flat_tid,
T input,
T& output,
T init,
storage_type& storage,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
// Calculates inclusive scan, result for each thread is stored in storage_.threads[flat_tid]
this->inclusive_scan_base(flat_tid, input, storage, scan_op);
output = init;
if(flat_tid != 0) output = scan_op(init, storage_.threads[index(flat_tid-1)]);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan_impl(const unsigned int flat_tid,
T input,
T& output,
storage_type& storage,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
// Calculates inclusive scan, result for each thread is stored in storage_.threads[flat_tid]
this->inclusive_scan_base(flat_tid, input, storage, scan_op);
if(flat_tid > 0)
{
output = storage_.threads[index(flat_tid-1)];
}
}
// OVERWRITES storage_.threads[0]
template<class PrefixCallback, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void include_block_prefix(const unsigned int flat_tid,
const unsigned int warp_id,
const T input,
T& output,
const T reduction,
PrefixCallback& prefix_callback_op,
storage_type& storage,
BinaryFunction scan_op)
{
T block_prefix = this->get_block_prefix(
flat_tid, warp_id, reduction,
prefix_callback_op, storage
);
output = scan_op(block_prefix, input);
}
// OVERWRITES storage_.threads[0]
template<class PrefixCallback>
ROCPRIM_DEVICE ROCPRIM_INLINE
T get_block_prefix(const unsigned int flat_tid,
const unsigned int warp_id,
const T reduction,
PrefixCallback& prefix_callback_op,
storage_type& storage)
{
storage_type_& storage_ = storage.get();
if(warp_id == 0)
{
T block_prefix = prefix_callback_op(reduction);
if(flat_tid == 0)
{
// Reuse storage_.threads[0] which should not be
// needed at that point.
storage_.threads[0] = block_prefix;
}
}
::rocprim::syncthreads();
return storage_.threads[0];
}
// Change index to minimize LDS bank conflicts if necessary
ROCPRIM_DEVICE ROCPRIM_INLINE
unsigned int index(unsigned int n) const
{
// Move every 32-bank wide "row" (32 banks * 4 bytes) by one item
return has_bank_conflicts_ ? (n + (n/banks_no_)) : n;
}
};
} // end namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_REDUCE_THEN_SCAN_HPP_
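A usage sketch for the reduce-then-scan path (assuming the public rocprim::block_scan front-end with the reduce_then_scan algorithm enumerator; the kernel name and tile sizes are illustrative only):

#include <hip/hip_runtime.h>
#include <rocprim/block/block_scan.hpp>

// Hypothetical kernel: block-wide inclusive scan of 4 items per thread, with
// the reduce-then-scan back-end selected explicitly.
__global__ void block_scan_example(const int* input, int* output)
{
    constexpr unsigned int BlockSize      = 128;
    constexpr unsigned int ItemsPerThread = 4;

    using block_scan_type = rocprim::block_scan<
        int, BlockSize, rocprim::block_scan_algorithm::reduce_then_scan>;

    __shared__ typename block_scan_type::storage_type storage;

    // Blocked load: thread i owns items [i*4, i*4 + 3] of this block's tile.
    int items[ItemsPerThread];
    const unsigned int base = (blockIdx.x * BlockSize + threadIdx.x) * ItemsPerThread;
    for(unsigned int i = 0; i < ItemsPerThread; i++)
    {
        items[i] = input[base + i];
    }

    // Internally: per-thread reduction, exclusive scan of the thread totals,
    // then a thread-local scan seeded with the resulting prefix (see above).
    block_scan_type().inclusive_scan(items, items, storage, rocprim::plus<int>());

    for(unsigned int i = 0; i < ItemsPerThread; i++)
    {
        output[base + i] = items[i];
    }
}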
// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_WARP_SCAN_HPP_
#define ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_WARP_SCAN_HPP_
#include <type_traits>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../warp/warp_scan.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
template<
class T,
unsigned int BlockSizeX,
unsigned int BlockSizeY,
unsigned int BlockSizeZ
>
class block_scan_warp_scan
{
static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
// Select warp size
static constexpr unsigned int warp_size_ =
detail::get_min_warp_size(BlockSize, ::rocprim::device_warp_size());
// Number of warps in block
static constexpr unsigned int warps_no_ = (BlockSize + warp_size_ - 1) / warp_size_;
// typedef of warp_scan primitive that will be used to perform warp-level
// inclusive/exclusive scan operations on input values.
// warp_scan_crosslane is an implementation of warp_scan that does not need storage,
// but requires logical warp size to be a power of two.
using warp_scan_input_type = ::rocprim::detail::warp_scan_crosslane<T, warp_size_>;
// typedef of warp_scan primitive that will be used to get prefix values for
// each warp (scanned carry-outs from warps before it).
using warp_scan_prefix_type = ::rocprim::detail::warp_scan_crosslane<T, detail::next_power_of_two(warps_no_)>;
struct storage_type_
{
T warp_prefixes[warps_no_];
// ---------- Shared memory optimisation ----------
// Since warp_scan_input and warp_scan_prefix are typedef of warp_scan_crosslane,
// we don't need to allocate any temporary memory for them.
// If we just use warp_scan, we would need to add following union to this struct:
// union
// {
// typename warp_scan_input::storage_type wscan[warps_no_];
// typename warp_scan_prefix::storage_type wprefix_scan;
// };
// and use storage_.wscan[warp_id] and storage.wprefix_scan when calling
// warp_scan_input().inclusive_scan(..) and warp_scan_prefix().inclusive_scan(..).
};
public:
using storage_type = detail::raw_storage<storage_type_>;
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan(T input,
T& output,
storage_type& storage,
BinaryFunction scan_op)
{
this->inclusive_scan_impl(
::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
input, output, storage, scan_op
);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void inclusive_scan(T input,
T& output,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->inclusive_scan(input, output, storage, scan_op);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan(T input,
T& output,
T& reduction,
storage_type& storage,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
this->inclusive_scan(input, output, storage, scan_op);
// Save reduction result
reduction = storage_.warp_prefixes[warps_no_ - 1];
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void inclusive_scan(T input,
T& output,
T& reduction,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->inclusive_scan(input, output, reduction, storage, scan_op);
}
template<class PrefixCallback, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan(T input,
T& output,
storage_type& storage,
PrefixCallback& prefix_callback_op,
BinaryFunction scan_op)
{
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
const auto warp_id = ::rocprim::warp_id(flat_tid);
storage_type_& storage_ = storage.get();
this->inclusive_scan_impl(flat_tid, input, output, storage, scan_op);
// Include block prefix (this operation overwrites storage_.warp_prefixes[warps_no_ - 1])
T block_prefix = this->get_block_prefix(
flat_tid, warp_id,
storage_.warp_prefixes[warps_no_ - 1], // block reduction
prefix_callback_op, storage
);
output = scan_op(block_prefix, output);
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
storage_type& storage,
BinaryFunction scan_op)
{
// Reduce thread items
T thread_input = input[0];
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
thread_input = scan_op(thread_input, input[i]);
}
// Scan of reduced values to get prefixes
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
this->exclusive_scan_impl(
flat_tid,
thread_input, thread_input, // input, output
storage,
scan_op
);
// Include prefix (first thread does not have prefix)
output[0] = input[0];
if(flat_tid != 0)
{
output[0] = scan_op(thread_input, input[0]);
}
// Final thread-local scan
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
output[i] = scan_op(output[i-1], input[i]);
}
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void inclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->inclusive_scan(input, output, storage, scan_op);
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
T& reduction,
storage_type& storage,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
this->inclusive_scan(input, output, storage, scan_op);
// Save reduction result
reduction = storage_.warp_prefixes[warps_no_ - 1];
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void inclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
T& reduction,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->inclusive_scan(input, output, reduction, storage, scan_op);
}
template<
class PrefixCallback,
unsigned int ItemsPerThread,
class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void inclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
storage_type& storage,
PrefixCallback& prefix_callback_op,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
// Reduce thread items
T thread_input = input[0];
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
thread_input = scan_op(thread_input, input[i]);
}
// Scan of reduced values to get prefixes
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
this->exclusive_scan_impl(
flat_tid,
thread_input, thread_input, // input, output
storage,
scan_op
);
// this operation overwrites storage_.warp_prefixes[warps_no_ - 1]
T block_prefix = this->get_block_prefix(
flat_tid, ::rocprim::warp_id(flat_tid),
storage_.warp_prefixes[warps_no_ - 1], // block reduction
prefix_callback_op, storage
);
// Include prefix (first thread does not have prefix)
output[0] = input[0];
if(flat_tid != 0)
{
output[0] = scan_op(thread_input, input[0]);
}
// Include block prefix
output[0] = scan_op(block_prefix, output[0]);
// Final thread-local scan
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
output[i] = scan_op(output[i-1], input[i]);
}
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan(T input,
T& output,
T init,
storage_type& storage,
BinaryFunction scan_op)
{
this->exclusive_scan_impl(
::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
input, output, init, storage, scan_op
);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void exclusive_scan(T input,
T& output,
T init,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->exclusive_scan(
input, output, init, storage, scan_op
);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan(T input,
T& output,
T init,
T& reduction,
storage_type& storage,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
this->exclusive_scan(
input, output, init, storage, scan_op
);
// Save reduction result
reduction = storage_.warp_prefixes[warps_no_ - 1];
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void exclusive_scan(T input,
T& output,
T init,
T& reduction,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->exclusive_scan(
input, output, init, reduction, storage, scan_op
);
}
template<class PrefixCallback, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan(T input,
T& output,
storage_type& storage,
PrefixCallback& prefix_callback_op,
BinaryFunction scan_op)
{
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
const auto warp_id = ::rocprim::warp_id(flat_tid);
storage_type_& storage_ = storage.get();
this->exclusive_scan_impl(
flat_tid, input, output, storage, scan_op
);
// Include block prefix (this operation overwrites storage_.warp_prefixes[warps_no_ - 1])
T block_prefix = this->get_block_prefix(
flat_tid, warp_id,
storage_.warp_prefixes[warps_no_ - 1], // block reduction
prefix_callback_op, storage
);
output = scan_op(block_prefix, output);
if(flat_tid == 0) output = block_prefix;
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
T init,
storage_type& storage,
BinaryFunction scan_op)
{
// Reduce thread items
T thread_input = input[0];
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
thread_input = scan_op(thread_input, input[i]);
}
// Scan of reduced values to get prefixes
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
this->exclusive_scan_impl(
flat_tid,
thread_input, thread_input, // input, output
init,
storage,
scan_op
);
// Include init value
T prev = input[0];
T exclusive = init;
if(flat_tid != 0)
{
exclusive = thread_input;
}
output[0] = exclusive;
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
exclusive = scan_op(exclusive, prev);
prev = input[i];
output[i] = exclusive;
}
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void exclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
T init,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->exclusive_scan(input, output, init, storage, scan_op);
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
T init,
T& reduction,
storage_type& storage,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
this->exclusive_scan(input, output, init, storage, scan_op);
// Save reduction result
reduction = storage_.warp_prefixes[warps_no_ - 1];
}
template<unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void exclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
T init,
T& reduction,
BinaryFunction scan_op)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->exclusive_scan(input, output, init, reduction, storage, scan_op);
}
template<
class PrefixCallback,
unsigned int ItemsPerThread,
class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void exclusive_scan(T (&input)[ItemsPerThread],
T (&output)[ItemsPerThread],
storage_type& storage,
PrefixCallback& prefix_callback_op,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
// Reduce thread items
T thread_input = input[0];
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
thread_input = scan_op(thread_input, input[i]);
}
// Scan of reduced values to get prefixes
const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
this->exclusive_scan_impl(
flat_tid,
thread_input, thread_input, // input, output
storage,
scan_op
);
// this operation overwrites storage_.warp_prefixes[warps_no_ - 1]
T block_prefix = this->get_block_prefix(
flat_tid, ::rocprim::warp_id(flat_tid),
storage_.warp_prefixes[warps_no_ - 1], // block reduction
prefix_callback_op, storage
);
// Include init value and block prefix
T prev = input[0];
T exclusive = block_prefix;
if(flat_tid != 0)
{
exclusive = scan_op(block_prefix, thread_input);
}
output[0] = exclusive;
ROCPRIM_UNROLL
for(unsigned int i = 1; i < ItemsPerThread; i++)
{
exclusive = scan_op(exclusive, prev);
prev = input[i];
output[i] = exclusive;
}
}
private:
template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto inclusive_scan_impl(const unsigned int flat_tid,
T input,
T& output,
storage_type& storage,
BinaryFunction scan_op)
-> typename std::enable_if<(BlockSize_ > ::rocprim::device_warp_size())>::type
{
storage_type_& storage_ = storage.get();
// Perform warp scan
warp_scan_input_type().inclusive_scan(
// not using shared mem, see note in storage_type
input, output, scan_op
);
// i-th warp will have its prefix stored in storage_.warp_prefixes[i-1]
const auto warp_id = ::rocprim::warp_id(flat_tid);
this->calculate_warp_prefixes(flat_tid, warp_id, output, storage, scan_op);
// Use warp prefix to calculate the final scan results for every thread
if(warp_id != 0)
{
auto warp_prefix = storage_.warp_prefixes[warp_id - 1];
output = scan_op(warp_prefix, output);
}
}
// When BlockSize is not larger than warp_size we don't need the extra prefix calculations.
template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto inclusive_scan_impl(unsigned int flat_tid,
T input,
T& output,
storage_type& storage,
BinaryFunction scan_op)
-> typename std::enable_if<!(BlockSize_ > ::rocprim::device_warp_size())>::type
{
(void) storage;
(void) flat_tid;
storage_type_& storage_ = storage.get();
// Perform warp scan
warp_scan_input_type().inclusive_scan(
// not using shared mem, see note in storage_type
input, output, scan_op
);
if(flat_tid == BlockSize_ - 1)
{
storage_.warp_prefixes[0] = output;
}
::rocprim::syncthreads();
}
// Exclusive scan with initial value when BlockSize is bigger than warp_size
template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto exclusive_scan_impl(const unsigned int flat_tid,
T input,
T& output,
T init,
storage_type& storage,
BinaryFunction scan_op)
-> typename std::enable_if<(BlockSize_ > ::rocprim::device_warp_size())>::type
{
storage_type_& storage_ = storage.get();
// Perform warp scan on input values
warp_scan_input_type().inclusive_scan(
// not using shared mem, see note in storage_type
input, output, scan_op
);
// i-th warp will have its prefix stored in storage_.warp_prefixes[i-1]
const auto warp_id = ::rocprim::warp_id(flat_tid);
this->calculate_warp_prefixes(flat_tid, warp_id, output, storage, scan_op);
// Include initial value in warp prefixes, and fix warp prefixes
// for exclusive scan (first warp prefix is init)
auto warp_prefix = init;
if(warp_id != 0)
{
warp_prefix = scan_op(init, storage_.warp_prefixes[warp_id-1]);
}
// Use warp prefix to calculate the final scan results for every thread
output = scan_op(warp_prefix, output); // include warp prefix in scan results
output = warp_shuffle_up(output, 1, warp_size_); // shift to get exclusive results
if(::rocprim::lane_id() == 0)
{
output = warp_prefix;
}
}
// Exclusive scan with initial value when BlockSize is not larger than warp_size.
// In that case we don't need the extra prefix calculations.
template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto exclusive_scan_impl(const unsigned int flat_tid,
T input,
T& output,
T init,
storage_type& storage,
BinaryFunction scan_op)
-> typename std::enable_if<!(BlockSize_ > ::rocprim::device_warp_size())>::type
{
(void) flat_tid;
(void) storage;
(void) init;
storage_type_& storage_ = storage.get();
// Perform warp scan on input values
warp_scan_input_type().inclusive_scan(
// not using shared mem, see note in storage_type
input, output, scan_op
);
if(flat_tid == BlockSize_ - 1)
{
storage_.warp_prefixes[0] = output;
}
::rocprim::syncthreads();
// Use warp prefix to calculate the final scan results for every thread
output = scan_op(init, output); // include warp prefix in scan results
output = warp_shuffle_up(output, 1, warp_size_); // shift to get exclusive results
if(::rocprim::lane_id() == 0)
{
output = init;
}
}
// Exclusive scan with unknown initial value
template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto exclusive_scan_impl(const unsigned int flat_tid,
T input,
T& output,
storage_type& storage,
BinaryFunction scan_op)
-> typename std::enable_if<(BlockSize_ > ::rocprim::device_warp_size())>::type
{
storage_type_& storage_ = storage.get();
// Perform warp scan on input values
warp_scan_input_type().inclusive_scan(
// not using shared mem, see note in storage_type
input, output, scan_op
);
// i-th warp will have its prefix stored in storage_.warp_prefixes[i-1]
const auto warp_id = ::rocprim::warp_id(flat_tid);
this->calculate_warp_prefixes(flat_tid, warp_id, output, storage, scan_op);
// Use warp prefix to calculate the final scan results for every thread;
// without an initial value the exclusive result of the very first thread
// of the block is left unspecified
T warp_prefix;
if(warp_id != 0)
{
warp_prefix = storage_.warp_prefixes[warp_id - 1];
output = scan_op(warp_prefix, output);
}
output = warp_shuffle_up(output, 1, warp_size_); // shift to get exclusive results
if(::rocprim::lane_id() == 0)
{
output = warp_prefix;
}
}
// Exclusive scan with unknown initial value, when BlockSize is not larger than warp_size.
// In that case we don't need the extra prefix calculations.
template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto exclusive_scan_impl(const unsigned int flat_tid,
T input,
T& output,
storage_type& storage,
BinaryFunction scan_op)
-> typename std::enable_if<!(BlockSize_ > ::rocprim::device_warp_size())>::type
{
(void) flat_tid;
(void) storage;
storage_type_& storage_ = storage.get();
// Perform warp scan on input values
warp_scan_input_type().inclusive_scan(
// not using shared mem, see note in storage_type
input, output, scan_op
);
if(flat_tid == BlockSize_ - 1)
{
storage_.warp_prefixes[0] = output;
}
::rocprim::syncthreads();
output = warp_shuffle_up(output, 1, warp_size_); // shift to get exclusive results
}
// i-th warp will have its prefix stored in storage_.warp_prefixes[i-1]
template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
ROCPRIM_DEVICE ROCPRIM_INLINE
void calculate_warp_prefixes(const unsigned int flat_tid,
const unsigned int warp_id,
T inclusive_input,
storage_type& storage,
BinaryFunction scan_op)
{
storage_type_& storage_ = storage.get();
// Save the warp reduction result, that is the scan result
// for last element in each warp
if(flat_tid == ::rocprim::min((warp_id+1) * warp_size_, BlockSize_) - 1)
{
storage_.warp_prefixes[warp_id] = inclusive_input;
}
::rocprim::syncthreads();
// Scan the warp reduction results and store in storage_.warp_prefixes
if(flat_tid < warps_no_)
{
auto warp_prefix = storage_.warp_prefixes[flat_tid];
warp_scan_prefix_type().inclusive_scan(
// not using shared mem, see note in storage_type
warp_prefix, warp_prefix, scan_op
);
storage_.warp_prefixes[flat_tid] = warp_prefix;
}
::rocprim::syncthreads();
}
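// Worked example (illustrative sketch, assuming warp_size_ == 64 and BlockSize == 256,
// i.e. warps_no_ == 4): threads 63, 127, 191 and 255 store the inclusive result of
// their warp into warp_prefixes[0..3]; after the first sync the first 4 threads run an
// inclusive warp scan over these values, so warp_prefixes[i] ends up holding the
// combined reduction of warps 0..i, and warp i (i > 0) later uses warp_prefixes[i-1]
// as its prefix.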
// THIS OVERWRITES storage_.warp_prefixes[warps_no_ - 1]
template<class PrefixCallback>
ROCPRIM_DEVICE ROCPRIM_INLINE
T get_block_prefix(const unsigned int flat_tid,
const unsigned int warp_id,
const T reduction,
PrefixCallback& prefix_callback_op,
storage_type& storage)
{
storage_type_& storage_ = storage.get();
if(warp_id == 0)
{
T block_prefix = prefix_callback_op(reduction);
if(flat_tid == 0)
{
// Reuse storage_.warp_prefixes[warps_no_ - 1] to store block prefix
storage_.warp_prefixes[warps_no_ - 1] = block_prefix;
}
}
::rocprim::syncthreads();
return storage_.warp_prefixes[warps_no_ - 1];
}
};
} // end namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_WARP_SCAN_HPP_
// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_
#define ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_
#include <type_traits>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../warp/warp_sort.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
template<
class Key,
unsigned int BlockSizeX,
unsigned int BlockSizeY,
unsigned int BlockSizeZ,
unsigned int ItemsPerThread,
class Value
>
class block_sort_bitonic
{
static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
template<class KeyType, class ValueType>
struct storage_type_
{
KeyType key[BlockSize * ItemsPerThread];
ValueType value[BlockSize * ItemsPerThread];
};
template<class KeyType>
struct storage_type_<KeyType, empty_type>
{
KeyType key[BlockSize * ItemsPerThread];
};
public:
using storage_type = detail::raw_storage<storage_type_<Key, Value>>;
static_assert(detail::is_power_of_two(ItemsPerThread), "ItemsPerThread must be a power of two!");
template <class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key& thread_key,
storage_type& storage,
BinaryFunction compare_function)
{
this->sort_impl<BlockSize>(
::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
storage, compare_function,
thread_key
);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
storage_type& storage,
BinaryFunction compare_function)
{
this->sort_impl<BlockSize>(
::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
storage, compare_function,
thread_keys
);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void sort(Key& thread_key,
BinaryFunction compare_function)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->sort(thread_key, storage, compare_function);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
BinaryFunction compare_function)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->sort(thread_keys, storage, compare_function);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key& thread_key,
Value& thread_value,
storage_type& storage,
BinaryFunction compare_function)
{
this->sort_impl<BlockSize>(
::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
storage, compare_function,
thread_key, thread_value
);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
Value (&thread_values)[ItemsPerThread],
storage_type& storage,
BinaryFunction compare_function)
{
this->sort_impl<BlockSize>(
::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
storage, compare_function,
thread_keys, thread_values
);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void sort(Key& thread_key,
Value& thread_value,
BinaryFunction compare_function)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->sort(thread_key, thread_value, storage, compare_function);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
Value (&thread_values)[ItemsPerThread],
BinaryFunction compare_function)
{
ROCPRIM_SHARED_MEMORY storage_type storage;
this->sort(thread_keys, thread_values, storage, compare_function);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key& thread_key,
storage_type& storage,
const unsigned int size,
BinaryFunction compare_function)
{
this->sort_impl(
::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(), size,
storage, compare_function,
thread_key
);
}
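// Illustrative usage sketch (hypothetical kernel code, assuming a one-dimensional
// 256-thread block, key-only sorting, and rocprim::less / rocprim::empty_type from
// the accompanying functional and type headers):
//
//   using block_sort = rocprim::detail::block_sort_bitonic<int, 256, 1, 1, 1, rocprim::empty_type>;
//   ROCPRIM_SHARED_MEMORY block_sort::storage_type storage;
//   int key = ...;                                  // one key per thread
//   block_sort().sort(key, storage, rocprim::less<int>());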
private:
ROCPRIM_DEVICE ROCPRIM_INLINE
void copy_to_shared(Key& k, const unsigned int flat_tid, storage_type& storage)
{
storage_type_<Key, Value>& storage_ = storage.get();
storage_.key[flat_tid] = k;
::rocprim::syncthreads();
}
ROCPRIM_DEVICE ROCPRIM_INLINE
void copy_to_shared(Key (&k)[ItemsPerThread], const unsigned int flat_tid, storage_type& storage) {
storage_type_<Key, Value>& storage_ = storage.get();
ROCPRIM_UNROLL
for(unsigned int item = 0; item < ItemsPerThread; ++item) {
storage_.key[item * BlockSize + flat_tid] = k[item];
}
::rocprim::syncthreads();
}
ROCPRIM_DEVICE ROCPRIM_INLINE
void copy_to_shared(Key& k, Value& v, const unsigned int flat_tid, storage_type& storage)
{
storage_type_<Key, Value>& storage_ = storage.get();
storage_.key[flat_tid] = k;
storage_.value[flat_tid] = v;
::rocprim::syncthreads();
}
ROCPRIM_DEVICE ROCPRIM_INLINE
void copy_to_shared(Key (&k)[ItemsPerThread],
Value (&v)[ItemsPerThread],
const unsigned int flat_tid,
storage_type& storage)
{
storage_type_<Key, Value>& storage_ = storage.get();
ROCPRIM_UNROLL
for(unsigned int item = 0; item < ItemsPerThread; ++item) {
storage_.key[item * BlockSize + flat_tid] = k[item];
storage_.value[item * BlockSize + flat_tid] = v[item];
}
::rocprim::syncthreads();
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void swap(Key& key,
const unsigned int flat_tid,
const unsigned int next_id,
const bool dir,
storage_type& storage,
BinaryFunction compare_function)
{
storage_type_<Key, Value>& storage_ = storage.get();
Key next_key = storage_.key[next_id];
bool compare = (next_id < flat_tid) ? compare_function(key, next_key) : compare_function(next_key, key);
bool swap = compare ^ dir;
if(swap)
{
key = next_key;
}
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void swap(Key (&key)[ItemsPerThread],
const unsigned int flat_tid,
const unsigned int next_id,
const bool dir,
storage_type& storage,
BinaryFunction compare_function)
{
storage_type_<Key, Value>& storage_ = storage.get();
ROCPRIM_UNROLL
for(unsigned int item = 0; item < ItemsPerThread; ++item) {
Key next_key = storage_.key[item * BlockSize + next_id];
bool compare = (next_id < flat_tid) ? compare_function(key[item], next_key) : compare_function(next_key, key[item]);
bool swap = compare ^ dir;
if(swap)
{
key[item] = next_key;
}
}
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void swap(Key& key,
Value& value,
const unsigned int flat_tid,
const unsigned int next_id,
const bool dir,
storage_type& storage,
BinaryFunction compare_function)
{
storage_type_<Key, Value>& storage_ = storage.get();
Key next_key = storage_.key[next_id];
bool b = next_id < flat_tid;
bool compare = compare_function(b ? key : next_key, b ? next_key : key);
bool swap = compare ^ dir;
if(swap)
{
key = next_key;
value = storage_.value[next_id];
}
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void swap(Key (&key)[ItemsPerThread],
Value (&value)[ItemsPerThread],
const unsigned int flat_tid,
const unsigned int next_id,
const bool dir,
storage_type& storage,
BinaryFunction compare_function)
{
storage_type_<Key, Value>& storage_ = storage.get();
ROCPRIM_UNROLL
for(unsigned int item = 0; item < ItemsPerThread; ++item) {
Key next_key = storage_.key[item * BlockSize + next_id];
bool b = next_id < flat_tid;
bool compare = compare_function(b ? key[item] : next_key, b ? next_key : key[item]);
bool swap = compare ^ dir;
if(swap)
{
key[item] = next_key;
value[item] = storage_.value[item * BlockSize + next_id];
}
}
}
template<
unsigned int Size,
class BinaryFunction,
class... KeyValue
>
ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<(Size <= ::rocprim::device_warp_size())>::type
sort_power_two(const unsigned int flat_tid,
storage_type& storage,
BinaryFunction compare_function,
KeyValue&... kv)
{
(void) flat_tid;
(void) storage;
::rocprim::warp_sort<Key, Size, Value> wsort;
wsort.sort(kv..., compare_function);
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void warp_swap(Key& k, Value& v, int mask, bool dir, BinaryFunction compare_function)
{
Key k1 = warp_shuffle_xor(k, mask);
bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
if (swap)
{
k = k1;
v = warp_shuffle_xor(v, mask);
}
}
template <class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void warp_swap(Key (&k)[ItemsPerThread],
Value (&v)[ItemsPerThread],
int mask,
bool dir,
BinaryFunction compare_function)
{
ROCPRIM_UNROLL
for(unsigned int item = 0; item < ItemsPerThread; ++item) {
Key k1 = warp_shuffle_xor(k[item], mask);
bool swap = compare_function(dir ? k[item] : k1, dir ? k1 : k[item]);
if (swap)
{
k[item] = k1;
v[item] = warp_shuffle_xor(v[item], mask);
}
}
}
template<class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void warp_swap(Key& k, int mask, bool dir, BinaryFunction compare_function)
{
Key k1 = warp_shuffle_xor(k, mask);
bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
if (swap)
{
k = k1;
}
}
template <class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void warp_swap(Key (&k)[ItemsPerThread], int mask, bool dir, BinaryFunction compare_function)
{
ROCPRIM_UNROLL
for(unsigned int item = 0; item < ItemsPerThread; ++item) {
Key k1 = warp_shuffle_xor(k[item], mask);
bool swap = compare_function(dir ? k[item] : k1, dir ? k1 : k[item]);
if (swap)
{
k[item] = k1;
}
}
}
template <class BinaryFunction, unsigned int Items = ItemsPerThread, class... KeyValue>
ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<(Items < 2)>::type
thread_merge(bool /*dir*/, BinaryFunction /*compare_function*/, KeyValue&... /*kv*/)
{
}
template <class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void thread_swap(Key (&k)[ItemsPerThread],
Value (&v)[ItemsPerThread],
bool dir,
unsigned int i,
unsigned int j,
BinaryFunction compare_function)
{
if(compare_function(k[i], k[j]) == dir)
{
Key k_temp = k[i];
k[i] = k[j];
k[j] = k_temp;
Value v_temp = v[i];
v[i] = v[j];
v[j] = v_temp;
}
}
template <class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void thread_swap(Key (&k)[ItemsPerThread],
bool dir,
unsigned int i,
unsigned int j,
BinaryFunction compare_function)
{
if(compare_function(k[i], k[j]) == dir)
{
Key k_temp = k[i];
k[i] = k[j];
k[j] = k_temp;
}
}
template <class BinaryFunction, class... KeyValue>
ROCPRIM_DEVICE ROCPRIM_INLINE
void thread_shuffle(unsigned int offset, bool dir, BinaryFunction compare_function, KeyValue&... kv)
{
ROCPRIM_UNROLL
for(unsigned base = 0; base < ItemsPerThread; base += 2 * offset)
{
ROCPRIM_UNROLL
for(unsigned i = 0; i < offset; ++i)
{
thread_swap(kv..., dir, base + i, base + i + offset, compare_function);
}
}
}
template <class BinaryFunction, unsigned int Items = ItemsPerThread, class... KeyValue>
ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<!(Items < 2)>::type
thread_merge(bool dir, BinaryFunction compare_function, KeyValue&... kv)
{
ROCPRIM_UNROLL
for(unsigned int k = ItemsPerThread / 2; k > 0; k /= 2)
{
thread_shuffle(k, dir, compare_function, kv...);
}
}
template<
unsigned int Size,
class BinaryFunction,
class... KeyValue
>
ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<(Size > ::rocprim::device_warp_size())>::type
sort_power_two(const unsigned int flat_tid,
storage_type& storage,
BinaryFunction compare_function,
KeyValue&... kv)
{
const auto warp_id_is_even = ((flat_tid / ::rocprim::device_warp_size()) % 2) == 0;
::rocprim::warp_sort<Key, ::rocprim::device_warp_size(), Value> wsort;
auto compare_function2 =
[compare_function, warp_id_is_even](const Key& a, const Key& b) mutable -> bool
{
auto r = compare_function(a, b);
if(warp_id_is_even)
return r;
return !r;
};
wsort.sort(kv..., compare_function2);
ROCPRIM_UNROLL
for(unsigned int length = ::rocprim::device_warp_size(); length < Size; length *= 2)
{
const bool dir = (flat_tid & (length * 2)) != 0;
ROCPRIM_UNROLL
for(unsigned int k = length; k > ::rocprim::device_warp_size() / 2; k /= 2)
{
copy_to_shared(kv..., flat_tid, storage);
swap(kv..., flat_tid, flat_tid ^ k, dir, storage, compare_function);
::rocprim::syncthreads();
}
ROCPRIM_UNROLL
for(unsigned int k = ::rocprim::device_warp_size() / 2; k > 0; k /= 2)
{
const bool length_even = ((detail::logical_lane_id<::rocprim::device_warp_size()>() / k ) % 2 ) == 0;
const bool local_dir = length_even ? dir : !dir;
warp_swap(kv..., k, local_dir, compare_function);
}
thread_merge(dir, compare_function, kv...);
}
}
template<
unsigned int Size,
class BinaryFunction,
class... KeyValue
>
ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<detail::is_power_of_two(Size)>::type
sort_impl(const unsigned int flat_tid,
storage_type& storage,
BinaryFunction compare_function,
KeyValue&... kv)
{
static constexpr unsigned int PairSize = sizeof...(KeyValue);
static_assert(
PairSize < 3,
"KeyValue parameter pack can 1 or 2 elements (key, or key and value)"
);
sort_power_two<Size, BinaryFunction>(flat_tid, storage, compare_function, kv...);
}
// In case BlockSize is not a power-of-two, the slower odd-even mergesort function is used
// instead of the bitonic sort function
template<
unsigned int Size,
class BinaryFunction,
class... KeyValue
>
ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<!detail::is_power_of_two(Size)>::type
sort_impl(const unsigned int flat_tid,
storage_type& storage,
BinaryFunction compare_function,
KeyValue&... kv)
{
static constexpr unsigned int PairSize = sizeof...(KeyValue);
static_assert(
PairSize < 3,
"KeyValue parameter pack can 1 or 2 elements (key, or key and value)"
);
copy_to_shared(kv..., flat_tid, storage);
bool is_even = (flat_tid % 2) == 0;
unsigned int odd_id = (is_even) ? ::rocprim::max(flat_tid, 1u) - 1 : ::rocprim::min(flat_tid + 1, Size - 1);
unsigned int even_id = (is_even) ? ::rocprim::min(flat_tid + 1, Size - 1) : ::rocprim::max(flat_tid, 1u) - 1;
ROCPRIM_UNROLL
for(unsigned int length = 0; length < Size; length++)
{
unsigned int next_id = (length % 2) == 0 ? even_id : odd_id;
swap(kv..., flat_tid, next_id, 0, storage, compare_function);
::rocprim::syncthreads();
copy_to_shared(kv..., flat_tid, storage);
}
}
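// Worked example for the odd-even exchange above (illustrative, assuming Size == 6):
// thread 2 computes odd_id == 1 and even_id == 3, so it compares against thread 3 on
// even-numbered iterations and against thread 1 on odd-numbered ones; after Size
// passes every neighbour pair has been compared often enough for the keys to be
// fully sorted.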
template<
class BinaryFunction,
class... KeyValue
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void sort_impl(const unsigned int flat_tid,
const unsigned int size,
storage_type& storage,
BinaryFunction compare_function,
KeyValue&... kv)
{
static constexpr unsigned int PairSize = sizeof...(KeyValue);
static_assert(
PairSize < 3,
"KeyValue parameter pack can 1 or 2 elements (key, or key and value)"
);
if(size > BlockSize)
{
return;
}
copy_to_shared(kv..., flat_tid, storage);
bool is_even = (flat_tid % 2 == 0);
unsigned int odd_id = (is_even) ? ::rocprim::max(flat_tid, 1u) - 1 : ::rocprim::min(flat_tid + 1, size - 1);
unsigned int even_id = (is_even) ? ::rocprim::min(flat_tid + 1, size - 1) : ::rocprim::max(flat_tid, 1u) - 1;
for(unsigned int length = 0; length < size; length++)
{
unsigned int next_id = (length % 2 == 0) ? even_id : odd_id;
// Use only "valid" keys to ensure that compare_function will not use garbage keys,
// for example as indices of an array (a lookup table)
if(flat_tid < size)
{
swap(kv..., flat_tid, next_id, 0, storage, compare_function);
}
::rocprim::syncthreads();
copy_to_shared(kv..., flat_tid, storage);
}
}
};
} // end namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_CONFIG_HPP_
#define ROCPRIM_CONFIG_HPP_
#define BEGIN_ROCPRIM_NAMESPACE \
namespace rocprim {
#define END_ROCPRIM_NAMESPACE \
} /* rocprim */
#include <limits>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <thrust/system/cuda/cuda_bfloat16.h>
#ifndef ROCPRIM_DEVICE
#define ROCPRIM_DEVICE __device__
#define ROCPRIM_HOST __host__
#define ROCPRIM_HOST_DEVICE __host__ __device__
#define ROCPRIM_SHARED_MEMORY __shared__
#ifdef WIN32
#define ROCPRIM_KERNEL __global__ static
#else
#define ROCPRIM_KERNEL __global__
#endif
// TODO: These parameters should be tuned for NAVI in the near future.
#ifndef ROCPRIM_DEFAULT_MAX_BLOCK_SIZE
#define ROCPRIM_DEFAULT_MAX_BLOCK_SIZE 256
#endif
#ifndef ROCPRIM_DEFAULT_MIN_WARPS_PER_EU
#define ROCPRIM_DEFAULT_MIN_WARPS_PER_EU 1
#endif
// Currently HIP on Windows has a bug involving inline device functions generating
// local memory/register allocation errors during compilation. Current workaround is to
// use __attribute__((always_inline)) for the affected functions
#ifdef WIN32
#define ROCPRIM_INLINE inline __attribute__((always_inline))
#else
#define ROCPRIM_INLINE inline
#endif
#define ROCPRIM_FORCE_INLINE __attribute__((always_inline))
#endif
#ifndef ROCPRIM_DISABLE_DPP
#define ROCPRIM_DETAIL_USE_DPP true
#endif
#ifdef ROCPRIM_DISABLE_LOOKBACK_SCAN
#define ROCPRIM_DETAIL_USE_LOOKBACK_SCAN false
#else
#define ROCPRIM_DETAIL_USE_LOOKBACK_SCAN true
#endif
#ifndef ROCPRIM_THREAD_LOAD_USE_CACHE_MODIFIERS
#define ROCPRIM_THREAD_LOAD_USE_CACHE_MODIFIERS 1
#endif
#ifndef ROCPRIM_THREAD_STORE_USE_CACHE_MODIFIERS
#define ROCPRIM_THREAD_STORE_USE_CACHE_MODIFIERS 1
#endif
// Defines targeted AMD architecture. Supported values:
// * 803 (gfx803)
// * 900 (gfx900)
// * 906 (gfx906)
// * 908 (gfx908)
// * 910 (gfx90a)
#ifndef ROCPRIM_TARGET_ARCH
#define ROCPRIM_TARGET_ARCH 0
#endif
#if (__gfx1010__ || __gfx1011__ || __gfx1012__ || __gfx1030__ || __gfx1031__)
#define ROCPRIM_NAVI 1
#else
#define ROCPRIM_NAVI 0
#endif
#define ROCPRIM_ARCH_90a 910
/// Supported warp sizes
#define ROCPRIM_WARP_SIZE_32 32u
#define ROCPRIM_WARP_SIZE_64 64u
#define ROCPRIM_MAX_WARP_SIZE ROCPRIM_WARP_SIZE_64
#if (defined(_MSC_VER) && !defined(__clang__)) || (defined(__GNUC__) && !defined(__clang__))
#define ROCPRIM_UNROLL
#define ROCPRIM_NO_UNROLL
#else
#define ROCPRIM_UNROLL _Pragma("unroll")
#define ROCPRIM_NO_UNROLL _Pragma("nounroll")
#endif
#ifndef ROCPRIM_GRID_SIZE_LIMIT
#define ROCPRIM_GRID_SIZE_LIMIT std::numeric_limits<unsigned int>::max()
#endif
#if __cpp_if_constexpr >= 201606
#define ROCPRIM_IF_CONSTEXPR constexpr
#else
#define ROCPRIM_IF_CONSTEXPR
#endif
#endif // ROCPRIM_CONFIG_HPP_
// Copyright (c) 2017-2019 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DETAIL_ALL_TRUE_HPP_
#define ROCPRIM_DETAIL_ALL_TRUE_HPP_
#include <type_traits>
#include "../config.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
// all_of
template<bool... values>
struct all_true : std::true_type
{
};
template<bool... values>
struct all_true<true, values...> : all_true<values...>
{
};
template<bool... values>
struct all_true<false, values...> : std::false_type
{
};
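// Illustrative compile-time checks (sketch of the expected behaviour):
//   static_assert(all_true<true, true, true>::value, "every flag set");
//   static_assert(!all_true<true, false, true>::value, "one flag cleared");
//   static_assert(all_true<>::value, "an empty pack is vacuously true");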
} // end namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DETAIL_ALL_TRUE_HPP_
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DETAIL_BINARY_OP_WRAPPERS_HPP_
#define ROCPRIM_DETAIL_BINARY_OP_WRAPPERS_HPP_
#include <type_traits>
#include "../config.hpp"
#include "../intrinsics.hpp"
#include "../types.hpp"
#include "../functional.hpp"
#include "../detail/various.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
template<
class BinaryFunction,
class ResultType = typename BinaryFunction::result_type,
class InputType = typename BinaryFunction::input_type
>
struct reverse_binary_op_wrapper
{
using result_type = ResultType;
using input_type = InputType;
ROCPRIM_HOST_DEVICE inline
reverse_binary_op_wrapper() = default;
ROCPRIM_HOST_DEVICE inline
reverse_binary_op_wrapper(BinaryFunction binary_op)
: binary_op_(binary_op)
{
}
ROCPRIM_HOST_DEVICE inline
~reverse_binary_op_wrapper() = default;
ROCPRIM_HOST_DEVICE inline
result_type operator()(const input_type& t1, const input_type& t2)
{
return binary_op_(t2, t1);
}
private:
BinaryFunction binary_op_;
};
// Wrapper for performing head-flagged scan
template<class V, class F, class BinaryFunction>
struct headflag_scan_op_wrapper
{
static_assert(std::is_convertible<F, bool>::value, "F must be convertible to bool");
using result_type = rocprim::tuple<V, F>;
using input_type = result_type;
ROCPRIM_HOST_DEVICE inline
headflag_scan_op_wrapper() = default;
ROCPRIM_HOST_DEVICE inline
headflag_scan_op_wrapper(BinaryFunction scan_op)
: scan_op_(scan_op)
{
}
ROCPRIM_HOST_DEVICE inline
~headflag_scan_op_wrapper() = default;
ROCPRIM_HOST_DEVICE inline
result_type operator()(const input_type& t1, const input_type& t2)
{
return rocprim::make_tuple(!rocprim::get<1>(t2)
? scan_op_(rocprim::get<0>(t1), rocprim::get<0>(t2))
: rocprim::get<0>(t2),
F {rocprim::get<1>(t2) || rocprim::get<1>(t1)});
}
private:
BinaryFunction scan_op_;
};
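// Illustrative sketch of a head-flagged (segmented) scan with this wrapper, assuming
// V = int, F = int and rocprim::plus<int> as the wrapped scan operator: scanning the
// (value, flag) pairs (1,1) (2,0) (3,1) (4,0) from left to right yields the values
// 1, 3, 3, 7 - the running sum restarts whenever the right-hand flag is set.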
template<class EqualityOp>
struct inequality_wrapper
{
using equality_op_type = EqualityOp;
ROCPRIM_HOST_DEVICE inline
inequality_wrapper() = default;
ROCPRIM_HOST_DEVICE inline
inequality_wrapper(equality_op_type equality_op)
: equality_op(equality_op)
{}
template<class T, class U>
ROCPRIM_DEVICE ROCPRIM_INLINE
bool operator()(const T &a, const U &b)
{
return !equality_op(a, b);
}
equality_op_type equality_op;
};
} // end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DETAIL_BINARY_OP_WRAPPERS_HPP_
// Copyright (c) 2018-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_
#define ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_
#include <type_traits>
#include "../config.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
// invoke_result is based on std::invoke_result.
// The main difference is the use of ROCPRIM_HOST_DEVICE, which allows
// invoke_result to be used with device-only lambdas/functors in host-only
// functions on HIP-clang.
template <class T>
struct is_reference_wrapper : std::false_type {};
template <class U>
struct is_reference_wrapper<std::reference_wrapper<U>> : std::true_type {};
template<class T>
struct invoke_impl {
template<class F, class... Args>
ROCPRIM_HOST_DEVICE
static auto call(F&& f, Args&&... args)
-> decltype(std::forward<F>(f)(std::forward<Args>(args)...));
};
template<class B, class MT>
struct invoke_impl<MT B::*>
{
template<class T, class Td = typename std::decay<T>::type,
class = typename std::enable_if<std::is_base_of<B, Td>::value>::type
>
ROCPRIM_HOST_DEVICE
static auto get(T&& t) -> T&&;
template<class T, class Td = typename std::decay<T>::type,
class = typename std::enable_if<is_reference_wrapper<Td>::value>::type
>
ROCPRIM_HOST_DEVICE
static auto get(T&& t) -> decltype(t.get());
template<class T, class Td = typename std::decay<T>::type,
class = typename std::enable_if<!std::is_base_of<B, Td>::value>::type,
class = typename std::enable_if<!is_reference_wrapper<Td>::value>::type
>
ROCPRIM_HOST_DEVICE
static auto get(T&& t) -> decltype(*std::forward<T>(t));
template<class T, class... Args, class MT1,
class = typename std::enable_if<std::is_function<MT1>::value>::type
>
ROCPRIM_HOST_DEVICE
static auto call(MT1 B::*pmf, T&& t, Args&&... args)
-> decltype((invoke_impl::get(std::forward<T>(t)).*pmf)(std::forward<Args>(args)...));
template<class T>
ROCPRIM_HOST_DEVICE
static auto call(MT B::*pmd, T&& t)
-> decltype(invoke_impl::get(std::forward<T>(t)).*pmd);
};
template<class F, class... Args, class Fd = typename std::decay<F>::type>
ROCPRIM_HOST_DEVICE
auto INVOKE(F&& f, Args&&... args)
-> decltype(invoke_impl<Fd>::call(std::forward<F>(f), std::forward<Args>(args)...));
// Conforming C++14 implementation (is also a valid C++11 implementation):
template <typename AlwaysVoid, typename, typename...>
struct invoke_result_impl { };
template <typename F, typename...Args>
struct invoke_result_impl<decltype(void(INVOKE(std::declval<F>(), std::declval<Args>()...))), F, Args...>
{
using type = decltype(INVOKE(std::declval<F>(), std::declval<Args>()...));
};
template <class F, class... ArgTypes>
struct invoke_result : invoke_result_impl<void, F, ArgTypes...> {};
template<class InputType, class BinaryFunction>
struct match_result_type
{
using type = typename invoke_result<BinaryFunction, InputType, InputType>::type;
};
} // end namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DETAIL_RADIX_SORT_HPP_
#define ROCPRIM_DETAIL_RADIX_SORT_HPP_
#include <type_traits>
#include "../config.hpp"
#include "../type_traits.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
// Encode and decode integral and floating point values for radix sort in such a way that preserves
// the correct order of negative and positive keys (i.e. negative keys go before positive ones,
// which is not true for a simple reinterpretation of the key's bits).
// Digit extractor takes into account that (+0.0 == -0.0) is true for floats,
// so both +0.0 and -0.0 are reflected into the same bit pattern for digit extraction.
// Maximum digit length is 32.
template<class Key, class BitKey, class Enable = void>
struct radix_key_codec_integral { };
template<class Key, class BitKey>
struct radix_key_codec_integral<Key, BitKey, typename std::enable_if<::rocprim::is_unsigned<Key>::value>::type>
{
using bit_key_type = BitKey;
ROCPRIM_DEVICE ROCPRIM_INLINE
static bit_key_type encode(Key key)
{
return __builtin_bit_cast(bit_key_type, key);
}
ROCPRIM_DEVICE ROCPRIM_INLINE
static Key decode(bit_key_type bit_key)
{
return __builtin_bit_cast(Key, bit_key);
}
ROCPRIM_DEVICE ROCPRIM_INLINE
static unsigned int extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length)
{
unsigned int mask = (1u << length) - 1;
return static_cast<unsigned int>(bit_key >> start) & mask;
}
};
template<class Key, class BitKey>
struct radix_key_codec_integral<Key, BitKey, typename std::enable_if<::rocprim::is_signed<Key>::value>::type>
{
using bit_key_type = BitKey;
static constexpr bit_key_type sign_bit = bit_key_type(1) << (sizeof(bit_key_type) * 8 - 1);
ROCPRIM_DEVICE ROCPRIM_INLINE
static bit_key_type encode(Key key)
{
const bit_key_type bit_key = __builtin_bit_cast(bit_key_type, key);
return sign_bit ^ bit_key;
}
ROCPRIM_DEVICE ROCPRIM_INLINE
static Key decode(bit_key_type bit_key)
{
bit_key ^= sign_bit;
return __builtin_bit_cast(Key, bit_key);
}
ROCPRIM_DEVICE ROCPRIM_INLINE
static unsigned int extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length)
{
unsigned int mask = (1u << length) - 1;
return static_cast<unsigned int>(bit_key >> start) & mask;
}
};
template<class Key>
struct float_bit_mask;
template<>
struct float_bit_mask<float>
{
static constexpr uint32_t sign_bit = 0x80000000;
static constexpr uint32_t exponent = 0x7F800000;
static constexpr uint32_t mantissa = 0x007FFFFF;
using bit_type = uint32_t;
};
template<>
struct float_bit_mask<double>
{
static constexpr uint64_t sign_bit = 0x8000000000000000;
static constexpr uint64_t exponent = 0x7FF0000000000000;
static constexpr uint64_t mantissa = 0x000FFFFFFFFFFFFF;
using bit_type = uint64_t;
};
template<>
struct float_bit_mask<rocprim::bfloat16>
{
static constexpr uint16_t sign_bit = 0x8000;
static constexpr uint16_t exponent = 0x7F80;
static constexpr uint16_t mantissa = 0x007F;
using bit_type = uint16_t;
};
template<>
struct float_bit_mask<rocprim::half>
{
static constexpr uint16_t sign_bit = 0x8000;
static constexpr uint16_t exponent = 0x7C00;
static constexpr uint16_t mantissa = 0x03FF;
using bit_type = uint16_t;
};
template<class Key, class BitKey>
struct radix_key_codec_floating
{
using bit_key_type = BitKey;
static constexpr bit_key_type sign_bit = float_bit_mask<Key>::sign_bit;
ROCPRIM_DEVICE ROCPRIM_INLINE
static bit_key_type encode(Key key)
{
bit_key_type bit_key = __builtin_bit_cast(bit_key_type, key);
bit_key ^= (sign_bit & bit_key) == 0 ? sign_bit : bit_key_type(-1);
return bit_key;
}
ROCPRIM_DEVICE ROCPRIM_INLINE
static Key decode(bit_key_type bit_key)
{
bit_key ^= (sign_bit & bit_key) == 0 ? bit_key_type(-1) : sign_bit;
return __builtin_bit_cast(Key, bit_key);
}
ROCPRIM_DEVICE ROCPRIM_INLINE
static unsigned int extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length)
{
unsigned int mask = (1u << length) - 1;
// -0.0 should be treated as +0.0 for stable sort
// -0.0 is encoded as inverted sign_bit, +0.0 as sign_bit
// or vice versa for descending sort
bit_key_type key = bit_key == sign_bit ? bit_key_type(~sign_bit) : bit_key;
return static_cast<unsigned int>(key >> start) & mask;
}
};
template<class Key, class Enable = void>
struct radix_key_codec_base
{
static_assert(sizeof(Key) == 0,
"Only integral and floating point types supported as radix sort keys");
};
template<class Key>
struct radix_key_codec_base<
Key,
typename std::enable_if<::rocprim::is_integral<Key>::value>::type
> : radix_key_codec_integral<Key, typename std::make_unsigned<Key>::type> { };
template<>
struct radix_key_codec_base<bool>
{
using bit_key_type = unsigned char;
ROCPRIM_DEVICE ROCPRIM_INLINE
static bit_key_type encode(bool key)
{
return static_cast<bit_key_type>(key);
}
ROCPRIM_DEVICE ROCPRIM_INLINE
static bool decode(bit_key_type bit_key)
{
return static_cast<bool>(bit_key);
}
ROCPRIM_DEVICE ROCPRIM_INLINE
static unsigned int extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length)
{
unsigned int mask = (1u << length) - 1;
return static_cast<unsigned int>(bit_key >> start) & mask;
}
};
template<>
struct radix_key_codec_base<::rocprim::half> : radix_key_codec_floating<::rocprim::half, unsigned short> { };
template<>
struct radix_key_codec_base<::rocprim::bfloat16> : radix_key_codec_floating<::rocprim::bfloat16, unsigned short> { };
template<>
struct radix_key_codec_base<float> : radix_key_codec_floating<float, unsigned int> { };
template<>
struct radix_key_codec_base<double> : radix_key_codec_floating<double, unsigned long long> { };
template<class Key, bool Descending = false>
class radix_key_codec : protected radix_key_codec_base<Key>
{
using base_type = radix_key_codec_base<Key>;
public:
using bit_key_type = typename base_type::bit_key_type;
ROCPRIM_DEVICE ROCPRIM_INLINE
static bit_key_type encode(Key key)
{
bit_key_type bit_key = base_type::encode(key);
return (Descending ? ~bit_key : bit_key);
}
ROCPRIM_DEVICE ROCPRIM_INLINE
static Key decode(bit_key_type bit_key)
{
bit_key = (Descending ? ~bit_key : bit_key);
return base_type::decode(bit_key);
}
ROCPRIM_DEVICE ROCPRIM_INLINE
static unsigned int extract_digit(bit_key_type bit_key, unsigned int start, unsigned int radix_bits)
{
return base_type::extract_digit(bit_key, start, radix_bits);
}
};
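// Illustrative sketch of the resulting key order (float keys, ascending, i.e.
// radix_key_codec<float>): encode(-2.0f) < encode(-0.0f) and encode(+0.0f) < encode(1.0f)
// when the bit keys are compared as unsigned integers, and extract_digit() maps +0.0f
// and -0.0f to identical digits, so the two zeroes are not separated by the sort.
// With Descending == true the bit keys are simply inverted, reversing that order.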
} // end namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DETAIL_RADIX_SORT_HPP_
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DETAIL_VARIOUS_HPP_
#define ROCPRIM_DETAIL_VARIOUS_HPP_
#include <type_traits>
#include "../config.hpp"
#include "../types.hpp"
#include "../type_traits.hpp"
// TODO: Refactor when it gets crowded
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
struct empty_storage_type
{
};
template<class T>
ROCPRIM_HOST_DEVICE inline
constexpr bool is_power_of_two(const T x)
{
static_assert(::rocprim::is_integral<T>::value, "T must be integer type");
return (x > 0) && ((x & (x - 1)) == 0);
}
template<class T>
ROCPRIM_HOST_DEVICE inline
constexpr T next_power_of_two(const T x, const T acc = 1)
{
static_assert(::rocprim::is_unsigned<T>::value, "T must be unsigned type");
return acc >= x ? acc : next_power_of_two(x, 2 * acc);
}
template <
typename T,
typename U,
std::enable_if_t<::rocprim::is_integral<T>::value && ::rocprim::is_unsigned<U>::value, int> = 0>
ROCPRIM_HOST_DEVICE inline constexpr auto ceiling_div(const T a, const U b)
{
return a / b + (a % b > 0 ? 1 : 0);
}
ROCPRIM_HOST_DEVICE inline
size_t align_size(size_t size, size_t alignment = 256)
{
return ceiling_div(size, alignment) * alignment;
}
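// Worked examples (illustrative): ceiling_div(10, 4u) == 3, and with the default
// 256-byte alignment align_size(1000) == 1024, i.e. the size is rounded up to the
// next multiple of the alignment.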
// TODO: Put the block algorithms with warp-size variables on the device side with a macro.
// Temporary workaround
template<class T>
ROCPRIM_HOST_DEVICE inline
constexpr T warp_size_in_class(const T warp_size)
{
return warp_size;
}
// Select the minimal warp size for a block of size block_size; this is
// useful for blocks smaller than the maximal warp size.
template<class T>
ROCPRIM_HOST_DEVICE inline
constexpr T get_min_warp_size(const T block_size, const T max_warp_size)
{
static_assert(::rocprim::is_unsigned<T>::value, "T must be unsigned type");
return block_size >= max_warp_size ? max_warp_size : next_power_of_two(block_size);
}
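// Worked examples (illustrative): get_min_warp_size(24u, 64u) == 32 (the next power of
// two not smaller than the block size), while get_min_warp_size(128u, 64u) == 64 (the
// block spans at least one full warp, so the maximal warp size is used).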
template<unsigned int WarpSize>
struct is_warpsize_shuffleable {
static const bool value = detail::is_power_of_two(WarpSize);
};
// Selects an appropriate vector_type based on the input T and size N.
// The byte size is calculated and used to select an appropriate vector_type.
template<class T, unsigned int N>
struct match_vector_type
{
static constexpr unsigned int size = sizeof(T) * N;
using vector_base_type =
typename std::conditional<
sizeof(T) >= 4,
int,
typename std::conditional<
sizeof(T) >= 2,
short,
char
>::type
>::type;
using vector_4 = typename make_vector_type<vector_base_type, 4>::type;
using vector_2 = typename make_vector_type<vector_base_type, 2>::type;
using vector_1 = typename make_vector_type<vector_base_type, 1>::type;
using type =
typename std::conditional<
size % sizeof(vector_4) == 0,
vector_4,
typename std::conditional<
size % sizeof(vector_2) == 0,
vector_2,
vector_1
>::type
>::type;
};
// Checks that Items is even and that the size of T is smaller than the size of the matched vector_type.
template<class T, unsigned int Items>
struct is_vectorizable : std::integral_constant<bool, (Items % 2 == 0) && (sizeof(T) < sizeof(typename match_vector_type<T, Items>::type))> {};
// Returns the number of LDS (local data share) banks.
ROCPRIM_HOST_DEVICE
constexpr unsigned int get_lds_banks_no()
{
// Currently all devices supported by ROCm have 32 banks (4 bytes each)
return 32;
}
// Finds the biggest fundamental type for type T such that sizeof(T) is
// a multiple of that type's size.
template<class T>
struct match_fundamental_type
{
using type =
typename std::conditional<
sizeof(T)%8 == 0,
unsigned long long,
typename std::conditional<
sizeof(T)%4 == 0,
unsigned int,
typename std::conditional<
sizeof(T)%2 == 0,
unsigned short,
unsigned char
>::type
>::type
>::type;
};
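// Worked examples (illustrative): for a 12-byte T the matched fundamental type is
// unsigned int (12 is a multiple of 4 but not of 8), for a 6-byte T it is
// unsigned short, and for an 8- or 16-byte T it is unsigned long long.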
template<class T>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto store_volatile(T * output, T value)
-> typename std::enable_if<std::is_fundamental<T>::value>::type
{
// TODO: check GCC
// error: binding reference of type ‘const half_float::half&’ to ‘volatile half_float::half’ discards qualifiers
#if !(defined(__HIP_CPU_RT__ ) && defined(__GNUC__))
*const_cast<volatile T*>(output) = value;
#else
*output = value;
#endif
}
template<class T>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto store_volatile(T * output, T value)
-> typename std::enable_if<!std::is_fundamental<T>::value>::type
{
using fundamental_type = typename match_fundamental_type<T>::type;
constexpr unsigned int n = sizeof(T) / sizeof(fundamental_type);
auto input_ptr = reinterpret_cast<volatile fundamental_type*>(&value);
auto output_ptr = reinterpret_cast<volatile fundamental_type*>(output);
ROCPRIM_UNROLL
for(unsigned int i = 0; i < n; i++)
{
output_ptr[i] = input_ptr[i];
}
}
template<class T>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto load_volatile(T * input)
-> typename std::enable_if<std::is_fundamental<T>::value, T>::type
{
// TODO: check GCC
// error: binding reference of type ‘const half_float::half&’ to ‘volatile half_float::half’ discards qualifiers
#if !(defined(__HIP_CPU_RT__ ) && defined(__GNUC__))
T retval = *const_cast<volatile T*>(input);
return retval;
#else
return *input;
#endif
}
template<class T>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto load_volatile(T * input)
-> typename std::enable_if<!std::is_fundamental<T>::value, T>::type
{
using fundamental_type = typename match_fundamental_type<T>::type;
constexpr unsigned int n = sizeof(T) / sizeof(fundamental_type);
T retval;
auto output_ptr = reinterpret_cast<volatile fundamental_type*>(&retval);
auto input_ptr = reinterpret_cast<volatile fundamental_type*>(input);
ROCPRIM_UNROLL
for(unsigned int i = 0; i < n; i++)
{
output_ptr[i] = input_ptr[i];
}
return retval;
}
// A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions
template <typename T>
struct raw_storage
{
// Biggest memory-access word that T is a whole multiple of and is not larger than the alignment of T
typedef typename detail::match_fundamental_type<T>::type device_word;
// Backing storage
device_word storage[sizeof(T) / sizeof(device_word)];
// Alias
ROCPRIM_HOST_DEVICE T& get()
{
return reinterpret_cast<T&>(*this);
}
};
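// Illustrative usage sketch (hypothetical type MyType with a non-trivial constructor):
//
//   ROCPRIM_SHARED_MEMORY detail::raw_storage<MyType> storage; // no constructor runs in LDS
//   MyType& value = storage.get();                             // alias the raw words as MyType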
// Checks if two iterators have the same type and value
template<class Iterator1, class Iterator2>
inline
bool are_iterators_equal(Iterator1, Iterator2)
{
return false;
}
template<class Iterator>
inline
bool are_iterators_equal(Iterator iter1, Iterator iter2)
{
return iter1 == iter2;
}
template<class...>
using void_t = void;
template<typename T>
struct type_identity {
using type = T;
};
template<class T, class = void>
struct extract_type_impl : type_identity<T> { };
template<class T>
struct extract_type_impl<T, void_t<typename T::type> > : extract_type_impl<typename T::type> { };
template <typename T>
using extract_type = typename extract_type_impl<T>::type;
template<bool Value, class T>
struct select_type_case
{
static constexpr bool value = Value;
using type = T;
};
template<class Case, class... OtherCases>
struct select_type_impl
: std::conditional<
Case::value,
type_identity<extract_type<typename Case::type>>,
select_type_impl<OtherCases...>
>::type { };
template<class T>
struct select_type_impl<select_type_case<true, T>> : type_identity<extract_type<T>> { };
template<class T>
struct select_type_impl<select_type_case<false, T>>
{
static_assert(
sizeof(T) == 0,
"Cannot select any case. "
"The last case must have true condition or be a fallback type."
);
};
template<class Fallback>
struct select_type_impl<Fallback> : type_identity<extract_type<Fallback>> { };
template <typename... Cases>
using select_type = typename select_type_impl<Cases...>::type;
template <bool Value>
using bool_constant = std::integral_constant<bool, Value>;
} // end namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DETAIL_VARIOUS_HPP_
// Copyright (c) 2018-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_CONFIG_TYPES_HPP_
#define ROCPRIM_DEVICE_CONFIG_TYPES_HPP_
#include <type_traits>
#include "../config.hpp"
#include "../intrinsics/thread.hpp"
#include "../detail/various.hpp"
/// \addtogroup primitivesmodule_deviceconfigs
/// @{
BEGIN_ROCPRIM_NAMESPACE
/// \brief Special type used to show that the given device-level operation
/// will be executed with optimal configuration dependent on types of the function's parameters
/// and the target device architecture specified by ROCPRIM_TARGET_ARCH.
struct default_config { };
/// \brief Configuration of particular kernels launched by device-level operation
///
/// \tparam BlockSize - number of threads in a block.
/// \tparam ItemsPerThread - number of items processed by each thread.
template <unsigned int BlockSize,
unsigned int ItemsPerThread,
unsigned int SizeLimit = ROCPRIM_GRID_SIZE_LIMIT>
struct kernel_config
{
/// \brief Number of threads in a block.
static constexpr unsigned int block_size = BlockSize;
/// \brief Number of items processed by each thread.
static constexpr unsigned int items_per_thread = ItemsPerThread;
/// \brief Number of items processed by a single kernel launch.
static constexpr unsigned int size_limit = SizeLimit;
};
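// Illustrative usage sketch: a custom configuration of 256 threads per block with
// 4 items per thread could be expressed as
//
//   using my_config = rocprim::kernel_config<256, 4>;
//   static_assert(my_config::block_size == 256 && my_config::items_per_thread == 4, "");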
namespace detail
{
template<
unsigned int MaxBlockSize,
unsigned int SharedMemoryPerThread,
// Most kernels require block sizes not smaller than the warp size
unsigned int MinBlockSize,
// Can fit in shared memory?
// Although GPUs have 64KiB, 32KiB is used here as a "soft" limit,
// because some additional memory may be required in kernels
bool = (MaxBlockSize * SharedMemoryPerThread <= (1u << 15))
>
struct limit_block_size
{
// No, then try to decrease block size
static constexpr unsigned int value =
limit_block_size<
detail::next_power_of_two(MaxBlockSize) / 2,
SharedMemoryPerThread,
MinBlockSize
>::value;
};
template<
unsigned int MaxBlockSize,
unsigned int SharedMemoryPerThread,
unsigned int MinBlockSize
>
struct limit_block_size<MaxBlockSize, SharedMemoryPerThread, MinBlockSize, true>
{
static_assert(MaxBlockSize >= MinBlockSize, "Data is too large, it cannot fit in shared memory");
static constexpr unsigned int value = MaxBlockSize;
};
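// Worked example (illustrative): limit_block_size<1024, 64, 64>::value first tries
// 1024 threads (1024 * 64 = 64 KiB, above the 32 KiB soft limit), halves to 512
// (512 * 64 = 32 KiB, which fits) and therefore evaluates to 512.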
template<unsigned int Arch, class T>
struct select_arch_case
{
static constexpr unsigned int arch = Arch;
using type = T;
};
template<unsigned int TargetArch, class Case, class... OtherCases>
struct select_arch
: std::conditional<
Case::arch == TargetArch,
extract_type<typename Case::type>,
select_arch<TargetArch, OtherCases...>
>::type { };
template<unsigned int TargetArch, class Universal>
struct select_arch<TargetArch, Universal> : extract_type<Universal> { };
template<class Config, class Default>
using default_or_custom_config =
typename std::conditional<
std::is_same<Config, default_config>::value,
Default,
Config
>::type;
} // end namespace detail
END_ROCPRIM_NAMESPACE
/// @}
// end of group primitivesmodule_deviceconfigs
#endif // ROCPRIM_DEVICE_CONFIG_TYPES_HPP_
// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_ADJACENT_DIFFERENCE_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_ADJACENT_DIFFERENCE_HPP_
#include "../../block/block_adjacent_difference.hpp"
#include "../../block/block_load.hpp"
#include "../../block/block_store.hpp"
#include "../../detail/various.hpp"
#include "../../config.hpp"
#include <cuda_runtime.h>
#include <type_traits>
#include <cstdint>
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
template <typename T, unsigned int BlockSize>
struct adjacent_diff_helper
{
using adjacent_diff_type = ::rocprim::block_adjacent_difference<T, BlockSize>;
using storage_type = typename adjacent_diff_type::storage_type;
template <unsigned int ItemsPerThread,
typename Output,
typename BinaryFunction,
typename InputIt,
bool InPlace>
ROCPRIM_DEVICE void dispatch(const T (&input)[ItemsPerThread],
Output (&output)[ItemsPerThread],
const BinaryFunction op,
const InputIt previous_values,
const unsigned int block_id,
const std::size_t starting_block,
const std::size_t num_blocks,
const std::size_t size,
storage_type& storage,
bool_constant<InPlace> /*in_place*/,
std::false_type /*right*/)
{
static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
// Not the first block, i.e. has a predecessor
if(starting_block + block_id != 0)
{
// `previous_values` needs to be accessed with a stride of `items_per_block` if the
// operation is out-of-place
const unsigned int block_offset = InPlace ? block_id : block_id * items_per_block;
const InputIt block_previous_values = previous_values + block_offset;
const T tile_predecessor = block_previous_values[-1];
            // Not the last block (i.e. a full block)
if(starting_block + block_id != num_blocks - 1)
{
adjacent_diff_type {}.subtract_left(input, output, op, tile_predecessor, storage);
}
else
{
const unsigned int valid_items
= static_cast<unsigned int>(size - (num_blocks - 1) * items_per_block);
adjacent_diff_type {}.subtract_left_partial(
input, output, op, tile_predecessor, valid_items, storage);
}
}
else
{
            // Not the last block (i.e. a full block)
if(starting_block + block_id != num_blocks - 1)
{
adjacent_diff_type {}.subtract_left(input, output, op, storage);
}
else
{
const unsigned int valid_items
= static_cast<unsigned int>(size - (num_blocks - 1) * items_per_block);
adjacent_diff_type {}.subtract_left_partial(
input, output, op, valid_items, storage);
}
}
}
template <unsigned int ItemsPerThread,
typename Output,
typename BinaryFunction,
typename InputIt,
bool InPlace>
ROCPRIM_DEVICE void dispatch(const T (&input)[ItemsPerThread],
Output (&output)[ItemsPerThread],
const BinaryFunction op,
const InputIt previous_values,
const unsigned int block_id,
const std::size_t starting_block,
const std::size_t num_blocks,
const std::size_t size,
storage_type& storage,
bool_constant<InPlace> /*in_place*/,
std::true_type /*right*/)
{
static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
// Not the last (i.e. full) block and has a successor
if(starting_block + block_id != num_blocks - 1)
{
// `previous_values` needs to be accessed with a stride of `items_per_block` if the
// operation is out-of-place
            // When in-place, the first block does not save its first value (it is never used),
            // so the saved values are offset by one block: the next block's first value is
            // stored at position `block_id`
const unsigned int block_offset = InPlace ? block_id : (block_id + 1) * items_per_block;
const InputIt next_block_values = previous_values + block_offset;
const T tile_successor = *next_block_values;
adjacent_diff_type {}.subtract_right(input, output, op, tile_successor, storage);
}
else
{
const unsigned int valid_items
= static_cast<unsigned int>(size - (num_blocks - 1) * items_per_block);
adjacent_diff_type {}.subtract_right_partial(input, output, op, valid_items, storage);
}
}
};
template <typename T, typename InputIterator>
ROCPRIM_DEVICE ROCPRIM_INLINE auto select_previous_values_iterator(T* previous_values,
InputIterator /*input*/,
std::true_type /*in_place*/)
{
return previous_values;
}
template <typename T, typename InputIterator>
ROCPRIM_DEVICE ROCPRIM_INLINE auto select_previous_values_iterator(T* /*previous_values*/,
InputIterator input,
std::false_type /*in_place*/)
{
return input;
}
template <typename Config,
bool InPlace,
bool Right,
typename InputIt,
typename OutputIt,
typename BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void adjacent_difference_kernel_impl(
const InputIt input,
const OutputIt output,
const std::size_t size,
const BinaryFunction op,
const typename std::iterator_traits<InputIt>::value_type* previous_values,
const std::size_t starting_block)
{
using input_type = typename std::iterator_traits<InputIt>::value_type;
using output_type = typename std::iterator_traits<OutputIt>::value_type;
static constexpr unsigned int block_size = Config::block_size;
static constexpr unsigned int items_per_thread = Config::items_per_thread;
static constexpr unsigned int items_per_block = block_size * items_per_thread;
using block_load_type
= ::rocprim::block_load<input_type, block_size, items_per_thread, Config::load_method>;
using block_store_type
= ::rocprim::block_store<output_type, block_size, items_per_thread, Config::store_method>;
using adjacent_helper = adjacent_diff_helper<input_type, block_size>;
ROCPRIM_SHARED_MEMORY struct
{
typename block_load_type::storage_type load;
typename adjacent_helper::storage_type adjacent_diff;
typename block_store_type::storage_type store;
} storage;
const unsigned int block_id = blockIdx.x;
const unsigned int block_offset = block_id * items_per_block;
const std::size_t num_blocks = ceiling_div(size, items_per_block);
input_type thread_input[items_per_thread];
if(starting_block + block_id < num_blocks - 1)
{
block_load_type {}.load(input + block_offset, thread_input, storage.load);
}
else
{
const unsigned int valid_items
= static_cast<unsigned int>(size - (num_blocks - 1) * items_per_block);
block_load_type {}.load(input + block_offset, thread_input, valid_items, storage.load);
}
::rocprim::syncthreads();
// Type tags for tag dispatch.
static constexpr auto in_place = bool_constant<InPlace> {};
static constexpr auto right = bool_constant<Right> {};
    // When doing the operation in-place, the last/first items of each block have been copied
    // out in advance to contiguous scratch locations, since reading them directly would race
    // with the writes of their new values. In this case `select_previous_values_iterator`
    // returns a pointer to the copied values, and it should be addressed by block_id.
    // Otherwise (when the transform is out-of-place) it just returns the input iterator, and
    // the first/last values of the blocks can be accessed with a stride of `items_per_block`
const auto previous_values_it
= select_previous_values_iterator(previous_values, input, in_place);
output_type thread_output[items_per_thread];
// Do tag dispatch on `right` to select either `subtract_right` or `subtract_left`.
// Note that the function is overloaded on its last parameter.
adjacent_helper {}.dispatch(thread_input,
thread_output,
op,
previous_values_it,
block_id,
starting_block,
num_blocks,
size,
storage.adjacent_diff,
in_place,
right);
::rocprim::syncthreads();
if(starting_block + block_id < num_blocks - 1)
{
block_store_type {}.store(output + block_offset, thread_output, storage.store);
}
else
{
const unsigned int valid_items
= static_cast<unsigned int>(size - (num_blocks - 1) * items_per_block);
block_store_type {}.store(output + block_offset, thread_output, valid_items, storage.store);
}
}
} // namespace detail
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_ADJACENT_DIFFERENCE_HPP_
// Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_BINARY_SEARCH_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_BINARY_SEARCH_HPP_
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
template<class Size>
ROCPRIM_DEVICE ROCPRIM_INLINE
Size get_binary_search_middle(Size left, Size right)
{
const Size d = right - left;
return left + d / 2 + d / 64;
}
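// Worked example: for left = 0 and right = 1024 the probe index is 0 + 512 + 16 = 528
// rather than the exact middle; the extra d / 64 skews the split but the result always
// stays within [left, right), so the searches below remain correct.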
template<class RandomAccessIterator, class Size, class T, class BinaryPredicate>
ROCPRIM_DEVICE ROCPRIM_INLINE
Size lower_bound_n(RandomAccessIterator first,
Size size,
const T& value,
BinaryPredicate compare_op)
{
Size left = 0;
Size right = size;
while(left < right)
{
const Size mid = get_binary_search_middle(left, right);
if(compare_op(first[mid], value))
{
left = mid + 1;
}
else
{
right = mid;
}
}
return left;
}
template<class RandomAccessIterator, class Size, class T, class BinaryPredicate>
ROCPRIM_DEVICE ROCPRIM_INLINE
Size upper_bound_n(RandomAccessIterator first,
Size size,
const T& value,
BinaryPredicate compare_op)
{
Size left = 0;
Size right = size;
while(left < right)
{
const Size mid = get_binary_search_middle(left, right);
if(compare_op(value, first[mid]))
{
right = mid;
}
else
{
left = mid + 1;
}
}
return left;
}
struct lower_bound_search_op
{
template<class HaystackIterator, class CompareOp, class Size, class T>
ROCPRIM_DEVICE ROCPRIM_INLINE
Size operator()(HaystackIterator haystack, Size size, const T& value, CompareOp compare_op) const
{
return lower_bound_n(haystack, size, value, compare_op);
}
};
struct upper_bound_search_op
{
template<class HaystackIterator, class CompareOp, class Size, class T>
ROCPRIM_DEVICE ROCPRIM_INLINE
Size operator()(HaystackIterator haystack, Size size, const T& value, CompareOp compare_op) const
{
return upper_bound_n(haystack, size, value, compare_op);
}
};
struct binary_search_op
{
template<class HaystackIterator, class CompareOp, class Size, class T>
ROCPRIM_DEVICE ROCPRIM_INLINE
bool operator()(HaystackIterator haystack, Size size, const T& value, CompareOp compare_op) const
{
const Size n = lower_bound_n(haystack, size, value, compare_op);
return n != size && !compare_op(value, haystack[n]);
}
};
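// Worked example: on the sorted haystack {1, 2, 2, 4} with value 2 and compare_op = less,
// lower_bound_n returns 1 (first element not less than 2), upper_bound_n returns 3 (first
// element greater than 2), and binary_search_op returns true because the element at the
// lower bound compares equal to the value.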
} // end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_BINARY_SEARCH_HPP_
// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_CONFIG_HELPER_HPP_
#define ROCPRIM_DEVICE_DETAIL_CONFIG_HELPER_HPP_
#include <type_traits>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../block/block_reduce.hpp"
#include "../config_types.hpp"
/// \addtogroup primitivesmodule_deviceconfigs
/// @{
BEGIN_ROCPRIM_NAMESPACE
/// \brief Configuration of device-level reduce primitives.
///
/// \tparam BlockSize - number of threads in a block.
/// \tparam ItemsPerThread - number of items processed by each thread.
/// \tparam BlockReduceMethod - algorithm for block reduce.
/// \tparam SizeLimit - limit on the number of items reduced by a single launch
template<
unsigned int BlockSize,
unsigned int ItemsPerThread,
::rocprim::block_reduce_algorithm BlockReduceMethod,
unsigned int SizeLimit = ROCPRIM_GRID_SIZE_LIMIT
>
struct reduce_config
{
/// \brief Number of threads in a block.
static constexpr unsigned int block_size = BlockSize;
/// \brief Number of items processed by each thread.
static constexpr unsigned int items_per_thread = ItemsPerThread;
/// \brief Algorithm for block reduce.
static constexpr block_reduce_algorithm block_reduce_method = BlockReduceMethod;
/// \brief Limit on the number of items reduced by a single launch
static constexpr unsigned int size_limit = SizeLimit;
};
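/// A minimal usage sketch (with assumed parameters): 256 threads per block, 8 items per
/// thread, and the warp-reduction based block algorithm from block_reduce.hpp.
static_assert(
    reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce>::block_size
        == 256,
    "a reduce_config exposes its block size");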
END_ROCPRIM_NAMESPACE
#endif //ROCPRIM_DEVICE_DETAIL_CONFIG_HELPER_HPP_
// Copyright (c) 2017-2020 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_HISTOGRAM_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_HISTOGRAM_HPP_
#include <cmath>
#include <type_traits>
#include <iterator>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../block/block_load.hpp"
#include "uint_fast_div.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
// Special wrapper for passing fixed-length arrays (i.e. T values[Size]) into kernels
template<class T, unsigned int Size>
class fixed_array
{
private:
T values[Size];
public:
ROCPRIM_HOST_DEVICE inline
fixed_array(const T values[Size])
{
for(unsigned int i = 0; i < Size; i++)
{
this->values[i] = values[i];
}
}
ROCPRIM_HOST_DEVICE inline
T& operator[](unsigned int index)
{
return values[index];
}
ROCPRIM_HOST_DEVICE inline
const T& operator[](unsigned int index) const
{
return values[index];
}
};
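// Usage sketch: a host-side array such as `unsigned int bins[2] = {256, 256};` can be
// wrapped as `fixed_array<unsigned int, 2>(bins)` so that it is passed to a kernel by value.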
template<class Level, class Enable = void>
struct sample_to_bin_even
{
unsigned int bins;
Level lower_level;
Level upper_level;
Level scale;
ROCPRIM_HOST_DEVICE inline
sample_to_bin_even() = default;
ROCPRIM_HOST_DEVICE inline
sample_to_bin_even(unsigned int bins, Level lower_level, Level upper_level)
: bins(bins),
lower_level(lower_level),
upper_level(upper_level),
scale((upper_level - lower_level) / bins)
{}
template<class Sample>
ROCPRIM_HOST_DEVICE inline
bool operator()(Sample sample, unsigned int& bin) const
{
const Level s = static_cast<Level>(sample);
if(s >= lower_level && s < upper_level)
{
bin = static_cast<unsigned int>((s - lower_level) / scale);
return true;
}
return false;
}
};
// This specialization uses fast integer division (uint_fast_div) for integral types no wider than 32 bits
template<class Level>
struct sample_to_bin_even<Level, typename std::enable_if<std::is_integral<Level>::value && (sizeof(Level) <= 4)>::type>
{
unsigned int bins;
Level lower_level;
Level upper_level;
uint_fast_div scale;
ROCPRIM_HOST_DEVICE inline
sample_to_bin_even() = default;
ROCPRIM_HOST_DEVICE inline
sample_to_bin_even(unsigned int bins, Level lower_level, Level upper_level)
: bins(bins),
lower_level(lower_level),
upper_level(upper_level),
scale((upper_level - lower_level) / bins)
{}
template<class Sample>
ROCPRIM_HOST_DEVICE inline
bool operator()(Sample sample, unsigned int& bin) const
{
const Level s = static_cast<Level>(sample);
if(s >= lower_level && s < upper_level)
{
bin = static_cast<unsigned int>(s - lower_level) / scale;
return true;
}
return false;
}
};
// This specialization uses multiplication by the inverse of the divisor for floating-point types
template<class Level>
struct sample_to_bin_even<Level, typename std::enable_if<std::is_floating_point<Level>::value>::type>
{
unsigned int bins;
Level lower_level;
Level upper_level;
Level inv_scale;
ROCPRIM_HOST_DEVICE inline
sample_to_bin_even() = default;
ROCPRIM_HOST_DEVICE inline
sample_to_bin_even(unsigned int bins, Level lower_level, Level upper_level)
: bins(bins),
lower_level(lower_level),
upper_level(upper_level),
inv_scale(bins / (upper_level - lower_level))
{}
template<class Sample>
ROCPRIM_HOST_DEVICE inline
bool operator()(Sample sample, unsigned int& bin) const
{
const Level s = static_cast<Level>(sample);
if(s >= lower_level && s < upper_level)
{
bin = static_cast<unsigned int>((s - lower_level) * inv_scale);
return true;
}
return false;
}
};
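// Worked example: with bins = 4, lower_level = 0.0f and upper_level = 8.0f, inv_scale is
// 0.5f; a sample of 5.0f maps to bin 2 ((5.0 - 0.0) * 0.5 = 2.5, truncated), while a
// sample of 8.0f lies outside [lower_level, upper_level) and is rejected.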
// Returns the index of the first element in values that is greater than value, or count if no such element exists.
template<class T>
ROCPRIM_HOST_DEVICE inline
unsigned int upper_bound(const T * values, unsigned int count, T value)
{
unsigned int current = 0;
while(count > 0)
{
const unsigned int step = count / 2;
const unsigned int next = current + step;
if(value < values[next])
{
count = step;
}
else
{
current = next + 1;
count -= step + 1;
}
}
return current;
}
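// Worked example: for values = {1, 3, 3, 7} and count = 4, upper_bound returns 3 for
// value = 3 (the index of the first element greater than 3) and 4 (== count) for value = 9.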
template<class Level>
struct sample_to_bin_range
{
unsigned int bins;
const Level * level_values;
ROCPRIM_HOST_DEVICE inline
sample_to_bin_range() = default;
ROCPRIM_HOST_DEVICE inline
sample_to_bin_range(unsigned int bins, const Level * level_values)
: bins(bins), level_values(level_values)
{}
template<class Sample>
ROCPRIM_HOST_DEVICE inline
bool operator()(Sample sample, unsigned int& bin) const
{
const Level s = static_cast<Level>(sample);
bin = upper_bound(level_values, bins + 1, s) - 1;
return bin < bins;
}
};
template<class T, unsigned int Size>
struct sample_vector
{
T values[Size];
};
// Checks whether 2 or 4 sample_vector<Sample, Channels> values can be loaded together as one 32-bit word
template<
unsigned int ItemsPerThread,
unsigned int Channels,
class Sample
>
struct is_sample_vectorizable
: std::integral_constant<
bool,
((sizeof(Sample) * Channels == 1) || (sizeof(Sample) * Channels == 2)) &&
(sizeof(Sample) * Channels * ItemsPerThread % sizeof(int) == 0) &&
(sizeof(Sample) * Channels * ItemsPerThread / sizeof(int) > 0)
> { };
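// Sketch (assumed sample layouts): two 8-bit channels per pixel with 8 items per thread
// pack into 16 bytes, which can be loaded as four 32-bit words; two 32-bit channels are
// too wide for this packing.
static_assert(is_sample_vectorizable<8, 2, unsigned char>::value,
              "2 x 8-bit channels, 8 items per thread is vectorizable");
static_assert(!is_sample_vectorizable<8, 2, unsigned int>::value,
              "2 x 32-bit channels cannot be packed into 32-bit loads");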
template<
unsigned int BlockSize,
unsigned int ItemsPerThread,
unsigned int Channels,
class Sample
>
ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<is_sample_vectorizable<ItemsPerThread, Channels, Sample>::value>::type
load_samples(unsigned int flat_id,
Sample * samples,
sample_vector<Sample, Channels> (&values)[ItemsPerThread])
{
using packed_samples_type = int[sizeof(Sample) * Channels * ItemsPerThread / sizeof(int)];
if(reinterpret_cast<uintptr_t>(samples) % sizeof(int) == 0)
{
        // the pointer is aligned to 4 bytes, so the samples can be loaded as packed 32-bit words
block_load_direct_striped<BlockSize>(
flat_id,
reinterpret_cast<const int *>(samples),
reinterpret_cast<packed_samples_type&>(values)
);
}
else
{
block_load_direct_striped<BlockSize>(
flat_id,
reinterpret_cast<const sample_vector<Sample, Channels> *>(samples),
values
);
}
}
template<
unsigned int BlockSize,
unsigned int ItemsPerThread,
unsigned int Channels,
class Sample
>
ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<!is_sample_vectorizable<ItemsPerThread, Channels, Sample>::value>::type
load_samples(unsigned int flat_id,
Sample * samples,
sample_vector<Sample, Channels> (&values)[ItemsPerThread])
{
block_load_direct_striped<BlockSize>(
flat_id,
reinterpret_cast<const sample_vector<Sample, Channels> *>(samples),
values
);
}
template<
unsigned int BlockSize,
unsigned int ItemsPerThread,
unsigned int Channels,
class Sample,
class SampleIterator
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void load_samples(unsigned int flat_id,
SampleIterator samples,
sample_vector<Sample, Channels> (&values)[ItemsPerThread])
{
Sample tmp[Channels * ItemsPerThread];
block_load_direct_blocked(
flat_id,
samples,
tmp
);
for(unsigned int i = 0; i < ItemsPerThread; i++)
{
for(unsigned int channel = 0; channel < Channels; channel++)
{
values[i].values[channel] = tmp[i * Channels + channel];
}
}
}
template<
unsigned int BlockSize,
unsigned int ItemsPerThread,
unsigned int Channels,
class Sample,
class SampleIterator
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void load_samples(unsigned int flat_id,
SampleIterator samples,
sample_vector<Sample, Channels> (&values)[ItemsPerThread],
unsigned int valid_count)
{
Sample tmp[Channels * ItemsPerThread];
block_load_direct_blocked(
flat_id,
samples,
tmp,
valid_count * Channels
);
for(unsigned int i = 0; i < ItemsPerThread; i++)
{
for(unsigned int channel = 0; channel < Channels; channel++)
{
values[i].values[channel] = tmp[i * Channels + channel];
}
}
}
template<
unsigned int BlockSize,
unsigned int ActiveChannels,
class Counter
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void init_histogram(fixed_array<Counter *, ActiveChannels> histogram,
fixed_array<unsigned int, ActiveChannels> bins)
{
const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
const unsigned int block_id = ::rocprim::detail::block_id<0>();
const unsigned int index = block_id * BlockSize + flat_id;
for(unsigned int channel = 0; channel < ActiveChannels; channel++)
{
if(index < bins[channel])
{
histogram[channel][index] = 0;
}
}
}
template<
unsigned int BlockSize,
unsigned int ItemsPerThread,
unsigned int Channels,
unsigned int ActiveChannels,
class SampleIterator,
class Counter,
class SampleToBinOp
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void histogram_shared(SampleIterator samples,
unsigned int columns,
unsigned int rows,
unsigned int row_stride,
unsigned int rows_per_block,
fixed_array<Counter *, ActiveChannels> histogram,
fixed_array<SampleToBinOp, ActiveChannels> sample_to_bin_op,
fixed_array<unsigned int, ActiveChannels> bins,
unsigned int * block_histogram_start)
{
using sample_type = typename std::iterator_traits<SampleIterator>::value_type;
using sample_vector_type = sample_vector<sample_type, Channels>;
constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
const unsigned int block_id0 = ::rocprim::detail::block_id<0>();
const unsigned int block_id1 = ::rocprim::detail::block_id<1>();
const unsigned int grid_size0 = ::rocprim::detail::grid_size<0>();
unsigned int * block_histogram[ActiveChannels];
for(unsigned int channel = 0; channel < ActiveChannels; channel++)
{
block_histogram[channel] = block_histogram_start;
block_histogram_start += bins[channel];
for(unsigned int bin = flat_id; bin < bins[channel]; bin += BlockSize)
{
block_histogram[channel][bin] = 0;
}
}
::rocprim::syncthreads();
const unsigned int start_row = block_id1 * rows_per_block;
const unsigned int end_row = ::rocprim::min(rows, start_row + rows_per_block);
for(unsigned int row = start_row; row < end_row; row++)
{
SampleIterator row_samples = samples + row * row_stride;
unsigned int block_offset = block_id0 * items_per_block;
while(block_offset < columns)
{
sample_vector_type values[ItemsPerThread];
if(block_offset + items_per_block <= columns)
{
load_samples<BlockSize>(flat_id, row_samples + Channels * block_offset, values);
for(unsigned int i = 0; i < ItemsPerThread; i++)
{
for(unsigned int channel = 0; channel < ActiveChannels; channel++)
{
unsigned int bin;
if(sample_to_bin_op[channel](values[i].values[channel], bin))
{
::rocprim::detail::atomic_add(&block_histogram[channel][bin], 1);
}
}
}
}
else
{
const unsigned int valid_count = columns - block_offset;
load_samples<BlockSize>(flat_id, row_samples + Channels * block_offset, values, valid_count);
for(unsigned int i = 0; i < ItemsPerThread; i++)
{
if(flat_id * ItemsPerThread + i < valid_count)
{
for(unsigned int channel = 0; channel < ActiveChannels; channel++)
{
unsigned int bin;
if(sample_to_bin_op[channel](values[i].values[channel], bin))
{
::rocprim::detail::atomic_add(&block_histogram[channel][bin], 1);
}
}
}
}
}
block_offset += grid_size0 * items_per_block;
}
}
::rocprim::syncthreads();
for(unsigned int channel = 0; channel < ActiveChannels; channel++)
{
for(unsigned int bin = flat_id; bin < bins[channel]; bin += BlockSize)
{
if(block_histogram[channel][bin] > 0)
{
::rocprim::detail::atomic_add(&histogram[channel][bin], block_histogram[channel][bin]);
}
}
}
}
template<
unsigned int BlockSize,
unsigned int ItemsPerThread,
unsigned int Channels,
unsigned int ActiveChannels,
class SampleIterator,
class Counter,
class SampleToBinOp
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void histogram_global(SampleIterator samples,
unsigned int columns,
unsigned int row_stride,
fixed_array<Counter *, ActiveChannels> histogram,
fixed_array<SampleToBinOp, ActiveChannels> sample_to_bin_op,
fixed_array<unsigned int, ActiveChannels> bins_bits)
{
using sample_type = typename std::iterator_traits<SampleIterator>::value_type;
using sample_vector_type = sample_vector<sample_type, Channels>;
constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
const unsigned int block_id0 = ::rocprim::detail::block_id<0>();
const unsigned int block_id1 = ::rocprim::detail::block_id<1>();
const unsigned int block_offset = block_id0 * items_per_block;
samples += block_id1 * row_stride + Channels * block_offset;
sample_vector_type values[ItemsPerThread];
unsigned int valid_count;
if(block_offset + items_per_block <= columns)
{
valid_count = items_per_block;
load_samples<BlockSize>(flat_id, samples, values);
}
else
{
valid_count = columns - block_offset;
load_samples<BlockSize>(flat_id, samples, values, valid_count);
}
for(unsigned int i = 0; i < ItemsPerThread; i++)
{
for(unsigned int channel = 0; channel < ActiveChannels; channel++)
{
unsigned int bin;
if(sample_to_bin_op[channel](values[i].values[channel], bin))
{
const unsigned int pos = flat_id * ItemsPerThread + i;
lane_mask_type same_bin_lanes_mask = ::rocprim::ballot(pos < valid_count);
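                    // Narrow the mask of valid lanes down to those whose sample falls into the
                    // same bin, by intersecting agreement masks for every bit of the bin index.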
for(unsigned int b = 0; b < bins_bits[channel]; b++)
{
const unsigned int bit_set = bin & (1u << b);
const lane_mask_type bit_set_mask = ::rocprim::ballot(bit_set);
same_bin_lanes_mask &= (bit_set ? bit_set_mask : ~bit_set_mask);
}
const unsigned int same_bin_count = ::rocprim::bit_count(same_bin_lanes_mask);
const unsigned int prev_same_bin_count = ::rocprim::masked_bit_count(same_bin_lanes_mask);
if(prev_same_bin_count == 0)
{
// Write the number of lanes having this bin,
// if the current lane is the first (and maybe only) lane with this bin.
::rocprim::detail::atomic_add(&histogram[channel][bin], same_bin_count);
}
}
}
}
}
} // end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_HISTOGRAM_HPP_
// Copyright (c) 2017-2020 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_HPP_
#include <type_traits>
#include <iterator>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../types.hpp"
#include "../../block/block_store.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
struct range_t
{
unsigned int begin1;
unsigned int end1;
unsigned int begin2;
unsigned int end2;
ROCPRIM_DEVICE ROCPRIM_INLINE
constexpr unsigned int count1() const
{
return end1 - begin1;
}
ROCPRIM_DEVICE ROCPRIM_INLINE
constexpr unsigned int count2() const
{
return end2 - begin2;
}
};
ROCPRIM_DEVICE ROCPRIM_INLINE
range_t compute_range(const unsigned int id,
const unsigned int size1,
const unsigned int size2,
const unsigned int spacing,
const unsigned int p1,
const unsigned int p2)
{
unsigned int diag1 = id * spacing;
unsigned int diag2 = min(size1 + size2, diag1 + spacing);
return range_t{p1, p2, diag1 - p1, diag2 - p2};
}
template<class KeysInputIterator1, class KeysInputIterator2, class OffsetT, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE OffsetT merge_path(KeysInputIterator1 keys_input1,
KeysInputIterator2 keys_input2,
const OffsetT input1_size,
const OffsetT input2_size,
const OffsetT diag,
BinaryFunction compare_function)
{
using key_type_1 = typename std::iterator_traits<KeysInputIterator1>::value_type;
using key_type_2 = typename std::iterator_traits<KeysInputIterator2>::value_type;
OffsetT begin = diag < input2_size ? 0u : diag - input2_size;
OffsetT end = min(diag, input1_size);
while(begin < end)
{
OffsetT a = (begin + end) / 2;
OffsetT b = diag - 1 - a;
key_type_1 input_a = keys_input1[a];
key_type_2 input_b = keys_input2[b];
if(!compare_function(input_b, input_a))
{
begin = a + 1;
}
else
{
end = a;
}
}
return begin;
}
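// Worked example: merging {1, 3, 5} and {2, 4, 6} with compare_function = less and diag = 3,
// merge_path returns 2: the first three merged keys {1, 2, 3} take two keys from the first
// input and one from the second.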
template<
class IndexIterator,
class KeysInputIterator1,
class KeysInputIterator2,
class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void partition_kernel_impl(IndexIterator indices,
KeysInputIterator1 keys_input1,
KeysInputIterator2 keys_input2,
const size_t input1_size,
const size_t input2_size,
const unsigned int spacing,
BinaryFunction compare_function)
{
const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
const unsigned int flat_block_size = ::rocprim::detail::block_size<0>();
unsigned int id = flat_block_id * flat_block_size + flat_id;
unsigned int partition_id = id * spacing;
size_t diag = min(static_cast<size_t>(partition_id), input1_size + input2_size);
unsigned int begin =
merge_path(
keys_input1,
keys_input2,
input1_size,
input2_size,
diag,
compare_function
);
indices[id] = begin;
}
template<
unsigned int BlockSize,
unsigned int ItemsPerThread,
class KeysInputIterator1,
class KeysInputIterator2,
class KeyType
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void load(unsigned int flat_id,
KeysInputIterator1 keys_input1,
KeysInputIterator2 keys_input2,
KeyType * keys_shared,
const size_t input1_size,
const size_t input2_size)
{
ROCPRIM_UNROLL
for(unsigned int i = 0; i < ItemsPerThread; ++i)
{
unsigned int index = BlockSize * i + flat_id;
if(index < input1_size)
{
keys_shared[index] = keys_input1[index];
}
else if(index < input1_size + input2_size)
{
keys_shared[index] = keys_input2[index - input1_size];
}
}
::rocprim::syncthreads();
}
template <class KeyType, unsigned int ItemsPerThread, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void serial_merge(KeyType * keys_shared,
KeyType (&inputs)[ItemsPerThread],
unsigned int (&index)[ItemsPerThread],
range_t range,
BinaryFunction compare_function)
{
KeyType a = keys_shared[range.begin1];
KeyType b = keys_shared[range.begin2];
ROCPRIM_UNROLL
for(unsigned int i = 0; i < ItemsPerThread; ++i)
{
bool compare = (range.begin2 >= range.end2) ||
((range.begin1 < range.end1) && !compare_function(b, a));
unsigned int x = compare ? range.begin1 : range.begin2;
inputs[i] = compare ? a : b;
index[i] = x;
KeyType c = keys_shared[++x];
if(compare)
{
a = c;
range.begin1 = x;
}
else
{
b = c;
range.begin2 = x;
}
}
::rocprim::syncthreads();
}
template<
unsigned int BlockSize,
class KeysInputIterator1,
class KeysInputIterator2,
class KeyType,
unsigned int ItemsPerThread,
class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void merge_keys(unsigned int flat_id,
KeysInputIterator1 keys_input1,
KeysInputIterator2 keys_input2,
KeyType (&key_inputs)[ItemsPerThread],
unsigned int (&index)[ItemsPerThread],
KeyType * keys_shared,
range_t range,
BinaryFunction compare_function)
{
load<BlockSize, ItemsPerThread>(
flat_id, keys_input1 + range.begin1, keys_input2 + range.begin2,
keys_shared, range.count1(), range.count2()
);
range_t range_local =
range_t {
0, range.count1(), range.count1(),
(range.count1() + range.count2())
};
unsigned int diag = ItemsPerThread * flat_id;
unsigned int partition =
merge_path(
keys_shared + range_local.begin1,
keys_shared + range_local.begin2,
range_local.count1(),
range_local.count2(),
diag,
compare_function
);
range_t range_partition =
range_t {
range_local.begin1 + partition,
range_local.end1,
range_local.begin2 + diag - partition,
range_local.end2
};
serial_merge(
keys_shared, key_inputs, index, range_partition,
compare_function
);
}
template<
bool WithValues,
unsigned int BlockSize,
class ValuesInputIterator1,
class ValuesInputIterator2,
class ValuesOutputIterator,
unsigned int ItemsPerThread
>
ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<WithValues>::type
merge_values(unsigned int flat_id,
ValuesInputIterator1 values_input1,
ValuesInputIterator2 values_input2,
ValuesOutputIterator values_output,
unsigned int (&index)[ItemsPerThread],
const size_t input1_size,
const size_t input2_size)
{
using value_type = typename std::iterator_traits<ValuesInputIterator1>::value_type;
unsigned int count = input1_size + input2_size;
value_type values[ItemsPerThread];
if(count >= ItemsPerThread * BlockSize)
{
ROCPRIM_UNROLL
for(unsigned int i = 0; i < ItemsPerThread; ++i)
{
values[i] = (index[i] < input1_size) ? values_input1[index[i]] :
values_input2[index[i] - input1_size];
}
}
else
{
ROCPRIM_UNROLL
for(unsigned int i = 0; i < ItemsPerThread; ++i)
{
if(flat_id * ItemsPerThread + i < count)
{
values[i] = (index[i] < input1_size) ? values_input1[index[i]] :
values_input2[index[i] - input1_size];
}
}
}
::rocprim::syncthreads();
block_store_direct_blocked(
flat_id,
values_output,
values,
count
);
}
template<
bool WithValues,
unsigned int BlockSize,
class ValuesInputIterator1,
class ValuesInputIterator2,
class ValuesOutputIterator,
unsigned int ItemsPerThread
>
ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<!WithValues>::type
merge_values(unsigned int flat_id,
ValuesInputIterator1 values_input1,
ValuesInputIterator2 values_input2,
ValuesOutputIterator values_output,
unsigned int (&index)[ItemsPerThread],
const size_t input1_size,
const size_t input2_size)
{
(void) flat_id;
(void) values_input1;
(void) values_input2;
(void) values_output;
(void) index;
(void) input1_size;
(void) input2_size;
}
template<
unsigned int BlockSize,
unsigned int ItemsPerThread,
class IndexIterator,
class KeysInputIterator1,
class KeysInputIterator2,
class KeysOutputIterator,
class ValuesInputIterator1,
class ValuesInputIterator2,
class ValuesOutputIterator,
class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void merge_kernel_impl(IndexIterator indices,
KeysInputIterator1 keys_input1,
KeysInputIterator2 keys_input2,
KeysOutputIterator keys_output,
ValuesInputIterator1 values_input1,
ValuesInputIterator2 values_input2,
ValuesOutputIterator values_output,
const size_t input1_size,
const size_t input2_size,
BinaryFunction compare_function)
{
using key_type = typename std::iterator_traits<KeysInputIterator1>::value_type;
using value_type = typename std::iterator_traits<ValuesInputIterator1>::value_type;
using keys_store_type = ::rocprim::block_store<
key_type, BlockSize, ItemsPerThread,
::rocprim::block_store_method::block_store_transpose
>;
constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
constexpr unsigned int input_block_size = BlockSize * ItemsPerThread + 1;
ROCPRIM_SHARED_MEMORY union
{
typename detail::raw_storage<key_type[input_block_size]> keys_shared;
typename keys_store_type::storage_type keys_store;
} storage;
key_type input[ItemsPerThread];
unsigned int index[ItemsPerThread];
const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
const unsigned int block_offset = flat_block_id * items_per_block;
const unsigned int count = input1_size + input2_size;
const unsigned int valid_in_last_block = count - block_offset;
const bool is_incomplete_block = valid_in_last_block < items_per_block;
const unsigned int p1 = indices[flat_block_id];
const unsigned int p2 = indices[flat_block_id + 1];
range_t range =
compute_range(
flat_block_id, input1_size, input2_size, items_per_block,
p1, p2
);
merge_keys<BlockSize>(
flat_id, keys_input1, keys_input2, input, index,
storage.keys_shared.get(),
range, compare_function
);
::rocprim::syncthreads();
    if(is_incomplete_block) // the last block may hold fewer than items_per_block elements
{
keys_store_type().store(
keys_output + block_offset,
input,
valid_in_last_block,
storage.keys_store
);
}
else
{
keys_store_type().store(
keys_output + block_offset,
input,
storage.keys_store
);
}
merge_values<with_values, BlockSize>(
flat_id, values_input1 + range.begin1, values_input2 + range.begin2,
values_output + block_offset, index,
range.count1(), range.count2()
);
}
} // end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_HPP_
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_
#include <type_traits>
#include <iterator>
#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../types.hpp"
#include "../../block/block_load.hpp"
#include "../../block/block_sort.hpp"
#include "../../block/block_store.hpp"
BEGIN_ROCPRIM_NAMESPACE
namespace detail
{
template<
unsigned int BlockSize,
unsigned int ItemsPerThread,
class Key
>
struct block_load_keys_impl {
using block_load_type = ::rocprim::block_load<Key,
BlockSize,
ItemsPerThread,
rocprim::block_load_method::block_load_transpose>;
using storage_type = typename block_load_type::storage_type;
template <class KeysInputIterator, class OffsetT>
ROCPRIM_DEVICE ROCPRIM_INLINE
void load(const OffsetT block_offset,
const unsigned int valid_in_last_block,
const bool is_incomplete_block,
KeysInputIterator keys_input,
Key (&keys)[ItemsPerThread],
storage_type& storage)
{
if(is_incomplete_block)
{
block_load_type().load(
keys_input + block_offset,
keys,
valid_in_last_block,
storage
);
}
else
{
block_load_type().load(
keys_input + block_offset,
keys,
storage
);
}
}
};
template <bool WithValues, unsigned int BlockSize, unsigned int ItemsPerThread, class Value>
struct block_load_values_impl
{
using storage_type = empty_storage_type;
template <class ValuesInputIterator, class OffsetT>
ROCPRIM_DEVICE ROCPRIM_INLINE
void load(const unsigned int flat_id,
const unsigned int (&ranks)[ItemsPerThread],
const OffsetT block_offset,
const unsigned int valid_in_last_block,
const bool is_incomplete_block,
ValuesInputIterator values_input,
Value (&values)[ItemsPerThread],
storage_type& storage)
{
(void) flat_id;
(void) ranks;
(void) block_offset;
(void) valid_in_last_block;
(void) is_incomplete_block;
(void) values_input;
(void) values;
(void) storage;
}
};
template <unsigned int BlockSize, unsigned int ItemsPerThread, class Value>
struct block_load_values_impl<true, BlockSize, ItemsPerThread, Value>
{
using block_exchange = ::rocprim::block_exchange<Value, BlockSize, ItemsPerThread>;
using storage_type = typename block_exchange::storage_type;
template <class ValuesInputIterator, class OffsetT>
ROCPRIM_DEVICE ROCPRIM_INLINE
void load(const unsigned int flat_id,
const unsigned int (&ranks)[ItemsPerThread],
const OffsetT block_offset,
const unsigned int valid_in_last_block,
const bool is_incomplete_block,
ValuesInputIterator values_input,
Value (&values)[ItemsPerThread],
storage_type& storage)
{
if(is_incomplete_block)
{
block_load_direct_striped<BlockSize>(
flat_id,
values_input + block_offset,
values,
valid_in_last_block
);
}
else
{
block_load_direct_striped<BlockSize>(
flat_id,
values_input + block_offset,
values
);
}
// Synchronize before reusing shared memory
::rocprim::syncthreads();
block_exchange().gather_from_striped(values, values, ranks, storage);
}
};
template<
bool WithValues,
unsigned int BlockSize,
unsigned int ItemsPerThread,
class Key,
class Value
>
struct block_store_impl {
using block_store_type
= block_store<Key, BlockSize, ItemsPerThread, block_store_method::block_store_transpose>;
using storage_type = typename block_store_type::storage_type;
template <class KeysOutputIterator, class ValuesOutputIterator, class OffsetT>
ROCPRIM_DEVICE ROCPRIM_INLINE
void store(const OffsetT block_offset,
const unsigned int valid_in_last_block,
const bool is_incomplete_block,
KeysOutputIterator keys_output,
ValuesOutputIterator values_output,
Key (&keys)[ItemsPerThread],
Value (&values)[ItemsPerThread],
storage_type& storage)
{
(void) values_output;
(void) values;
// Synchronize before reusing shared memory
::rocprim::syncthreads();
if(is_incomplete_block)
{
block_store_type().store(
keys_output + block_offset,
keys,
valid_in_last_block,
storage
);
}
else
{
block_store_type().store(
keys_output + block_offset,
keys,
storage
);
}
}
};
template<
unsigned int BlockSize,
unsigned int ItemsPerThread,
class Key,
class Value
>
struct block_store_impl<true, BlockSize, ItemsPerThread, Key, Value> {
using block_store_key_type = block_store<Key, BlockSize, ItemsPerThread, block_store_method::block_store_transpose>;
using block_store_value_type = block_store<Value, BlockSize, ItemsPerThread, block_store_method::block_store_transpose>;
union storage_type {
typename block_store_key_type::storage_type keys;
typename block_store_value_type::storage_type values;
};
template <class KeysOutputIterator, class ValuesOutputIterator, class OffsetT>
ROCPRIM_DEVICE ROCPRIM_INLINE
void store(const OffsetT block_offset,
const unsigned int valid_in_last_block,
const bool is_incomplete_block,
KeysOutputIterator keys_output,
ValuesOutputIterator values_output,
Key (&keys)[ItemsPerThread],
Value (&values)[ItemsPerThread],
storage_type& storage)
{
// Synchronize before reusing shared memory
::rocprim::syncthreads();
if(is_incomplete_block)
{
block_store_key_type().store(
keys_output + block_offset,
keys,
valid_in_last_block,
storage.keys
);
::rocprim::syncthreads();
block_store_value_type().store(
values_output + block_offset,
values,
valid_in_last_block,
storage.values
);
}
else
{
block_store_key_type().store(
keys_output + block_offset,
keys,
storage.keys
);
::rocprim::syncthreads();
block_store_value_type().store(
values_output + block_offset,
values,
storage.values
);
}
}
};
template <unsigned int BlockSize, unsigned int ItemsPerThread, class Key>
struct block_sort_impl
{
using stable_key_type = rocprim::tuple<Key, unsigned int>;
using block_sort_type = ::rocprim::block_sort<stable_key_type, BlockSize, ItemsPerThread>;
using storage_type = typename block_sort_type::storage_type;
template <class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(stable_key_type (&keys)[ItemsPerThread],
storage_type& storage,
const unsigned int valid_in_last_block,
const bool is_incomplete_block,
BinaryFunction compare_function)
{
if(is_incomplete_block)
{
            // Special comparison that sorts out-of-range values after any "valid" values
auto oor_compare
= [compare_function, valid_in_last_block](
const stable_key_type& lhs, const stable_key_type& rhs) mutable -> bool {
const bool left_oor = rocprim::get<1>(lhs) >= valid_in_last_block;
const bool right_oor = rocprim::get<1>(rhs) >= valid_in_last_block;
return (left_oor || right_oor) ? !left_oor : compare_function(lhs, rhs);
};
block_sort_type().sort(keys, // keys_input
storage,
oor_compare);
}
else
{
block_sort_type()
.sort(
keys, // keys_input
storage,
compare_function
);
}
}
};
template<
unsigned int BlockSize,
unsigned int ItemsPerThread,
class KeysInputIterator,
class KeysOutputIterator,
class ValuesInputIterator,
class ValuesOutputIterator,
class OffsetT,
class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void block_sort_kernel_impl(KeysInputIterator keys_input,
KeysOutputIterator keys_output,
ValuesInputIterator values_input,
ValuesOutputIterator values_output,
const OffsetT input_size,
BinaryFunction compare_function)
{
using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
const unsigned int flat_id = block_thread_id<0>();
const unsigned int flat_block_id = block_id<0>();
constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
const OffsetT block_offset = flat_block_id * items_per_block;
const unsigned int valid_in_last_block = input_size - block_offset;
const bool is_incomplete_block = flat_block_id == (input_size / items_per_block);
key_type keys[ItemsPerThread];
value_type values[ItemsPerThread];
using block_load_keys_impl = block_load_keys_impl<BlockSize, ItemsPerThread, key_type>;
using block_sort_impl = block_sort_impl<BlockSize, ItemsPerThread, key_type>;
using block_load_values_impl = block_load_values_impl<with_values, BlockSize, ItemsPerThread, value_type>;
using block_store_impl = block_store_impl<with_values, BlockSize, ItemsPerThread, key_type, value_type>;
ROCPRIM_SHARED_MEMORY union {
typename block_load_keys_impl::storage_type load_keys;
typename block_sort_impl::storage_type sort;
typename block_load_values_impl::storage_type load_values;
typename block_store_impl::storage_type store;
} storage;
block_load_keys_impl().load(
block_offset,
valid_in_last_block,
is_incomplete_block,
keys_input,
keys,
storage.load_keys
);
using stable_key_type = typename block_sort_impl::stable_key_type;
// Special comparison that preserves relative order of equal keys
auto stable_compare_function = [compare_function](const stable_key_type& a, const stable_key_type& b) mutable -> bool
{
const bool ab = compare_function(rocprim::get<0>(a), rocprim::get<0>(b));
const bool ba = compare_function(rocprim::get<0>(b), rocprim::get<0>(a));
return ab || (!ba && (rocprim::get<1>(a) < rocprim::get<1>(b)));
};
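    // Example: with compare_function = less, two equal keys fall back to comparing their
    // original positions (rocprim::get<1>), so the earlier key sorts first and the sort is stable.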
stable_key_type stable_keys[ItemsPerThread];
ROCPRIM_UNROLL
for(unsigned int item = 0; item < ItemsPerThread; ++item) {
stable_keys[item] = rocprim::make_tuple(keys[item], ItemsPerThread * flat_id + item);
}
// Synchronize before reusing shared memory
::rocprim::syncthreads();
block_sort_impl().sort(
stable_keys,
storage.sort,
valid_in_last_block,
is_incomplete_block,
stable_compare_function
);
unsigned int ranks[ItemsPerThread];
ROCPRIM_UNROLL
for(unsigned int item = 0; item < ItemsPerThread; ++item) {
keys[item] = rocprim::get<0>(stable_keys[item]);
ranks[item] = rocprim::get<1>(stable_keys[item]);
}
// Load the values with the already sorted indices
block_load_values_impl().load(
flat_id,
ranks,
block_offset,
valid_in_last_block,
is_incomplete_block,
values_input,
values,
storage.load_values
);
block_store_impl().store(
block_offset,
valid_in_last_block,
is_incomplete_block,
keys_output,
values_output,
keys,
values,
storage.store
);
}
template<
unsigned int BlockSize,
class KeysInputIterator,
class KeysOutputIterator,
class ValuesInputIterator,
class ValuesOutputIterator,
class OffsetT,
class BinaryFunction
>
ROCPRIM_DEVICE ROCPRIM_INLINE
void block_merge_kernel_impl(KeysInputIterator keys_input,
KeysOutputIterator keys_output,
ValuesInputIterator values_input,
ValuesOutputIterator values_output,
const OffsetT input_size,
const unsigned int block_size,
BinaryFunction compare_function)
{
using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
unsigned int id = (flat_block_id * BlockSize) + flat_id;
if (id >= input_size)
{
return;
}
key_type key;
value_type value;
key = keys_input[id];
if(with_values)
{
value = values_input[id];
}
const unsigned int block_id = id / block_size;
const bool block_id_is_odd = block_id & 1;
const unsigned int next_block_id = block_id_is_odd ? block_id - 1 :
block_id + 1;
const unsigned int block_start = min(block_id * block_size, (unsigned int) input_size);
const unsigned int next_block_start = min(next_block_id * block_size, (unsigned int) input_size);
const unsigned int next_block_end = min((next_block_id + 1) * block_size, (unsigned int) input_size);
if(next_block_start == input_size)
{
keys_output[id] = key;
if(with_values)
{
values_output[id] = value;
}
return;
}
unsigned int left_id = next_block_start;
unsigned int right_id = next_block_end;
while(left_id < right_id)
{
unsigned int mid_id = (left_id + right_id) / 2;
key_type mid_key = keys_input[mid_id];
bool smaller = compare_function(mid_key, key);
left_id = smaller ? mid_id + 1 : left_id;
right_id = smaller ? right_id : mid_id;
}
right_id = next_block_end;
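    // For keys that compare equal, odd blocks additionally advance past the equal run in their
    // (even) neighbour so that equal keys keep the even-block-first order and the merge stays stable.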
if(block_id_is_odd && left_id != right_id)
{
key_type upper_key = keys_input[left_id];
while(!compare_function(upper_key, key) &&
!compare_function(key, upper_key) &&
left_id < right_id)
{
unsigned int mid_id = (left_id + right_id) / 2;
key_type mid_key = keys_input[mid_id];
bool equal = !compare_function(mid_key, key) &&
!compare_function(key, mid_key);
left_id = equal ? mid_id + 1 : left_id + 1;
right_id = equal ? right_id : mid_id;
upper_key = keys_input[left_id];
}
}
unsigned int offset = 0;
offset += id - block_start;
offset += left_id - next_block_start;
offset += min(block_start, next_block_start);
keys_output[offset] = key;
if(with_values)
{
values_output[offset] = value;
}
}
} // end of detail namespace
END_ROCPRIM_NAMESPACE
#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_